Lightweight, interactive QSAR importance plots

dataviz
cheminformatics
A self-contained visualization of atomic contributions to models
Author

Akshay Balsubramani

Modified

October 30, 2024

Introduction

I’ve often run into situations in small-molecule drug discovery in which chemists want to visualize a dataset of chemical structures. Often, this is in the context of some learned embedding space, and sometimes even in the context of a “QSAR” (Quantitative Structure-Activity Relationship) model, which predicts a response for each structure (e.g. its efficacy as a drug in vivo, or its activity in an experimental validation platform). QSAR models can further highlight atoms of interest through an atomic importance plot, which shows how much each atom contributes to the model’s prediction.

A scatterplot is the typical way to display an embedding, but a single static plot is an unwieldy way to show this much heterogeneous information. Many thousands of structures typically merit visualization at once, each represented by its own chemical structure diagram. In practice, an interactive scatterplot conveys this information well and is easy to render, as we have previously discussed.

Such a scatterplot provides a way to visualize a data-driven learned chemical space, and interactively see the structures in it. This post extends that visualization in two ways:

  • We visualize the atomic contributions to a QSAR model’s predictions: each chemical structure diagram is annotated with atom-by-atom contributions to the model’s prediction for that structure. These can be calculated in a black-box manner with standard RDKit functionality, as we demonstrate.
  • The plot is saved as self-contained HTML, so viewing it requires no Python environment or code. HTML is a universal standard, so the file opens in any browser; we take advantage of a couple of modern features of HTML, as exploited by the wonderful Bokeh library.
Exploring chemical space amounts to searching over an embedded space of chemical structures. (Source.)

As in the previous post, the workflow to interactively render data has two parts:

  • Preprocessing script: given a dataset of chemicals, prepare a dataframe that’s ready to be viewed by a front-end. This includes loading (in this case, training for demonstration purposes) a QSAR model and computing its prediction for each structure.
  • Interactive presentation: given that dataframe, calculate the atomic contributions to each structure’s prediction, and view the results interactively in a browser.

Embedding to organize a dataset of chemical structures

First, we compute an embedding of a given dataset of chemical structures (SMILES strings), and load or train the desired prediction model.

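Embedding is a common task in cheminformatics, which we have discussed in a previous post. Following that post, we will work with a combined dataset of:

  • FDA-approved drugs (the “fda” subset of ZINC20), and
  • drugs approved elsewhere in the world but not by the FDA (the “world-not-fda” subset of ZINC20).

Together, these comprise >5,000 chemicals.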

CODE
import time
import pandas as pd
from rdkit import Chem
from rdkit.Chem import Draw
from rdkit.Chem.Draw import IPythonConsole  # renders molecules inline in Jupyter
import scanpy as sc, anndata


# ZINC20 subsets, downloaded as headerless tab-separated text with columns:
# ZINC ID (used as the index), SMILES, preferred name, molecular weight, log P, rotatable bonds.
fda_url = 'https://zinc20.docking.org/substances/subsets/fda.txt:zinc_id+smiles+preferred_name+mwt+logp+rb?count=all'
fda_molecules = pd.read_csv(fda_url, header=None, index_col=0, sep='\t')
fda_molecules['dataset'] = 'FDA'

worldnotfda_url = 'https://zinc20.docking.org/substances/subsets/world-not-fda.txt:zinc_id+smiles+preferred_name+mwt+logp+rb?count=all'
worldnotfda_molecules = pd.read_csv(worldnotfda_url, header=None, index_col=0, sep='\t')
worldnotfda_molecules['dataset'] = 'world-not-FDA'

# Combine both subsets and give the columns readable names
all_chems = pd.concat([fda_molecules, worldnotfda_molecules])
all_chems.columns = ['SMILES', 'Preferred name', 'Molecular weight', 'Log P', 'Rotatable bonds', 'dataset']

Calculate feature fingerprints for the data

CODE
from rdkit.Chem.Draw import SimilarityMaps

import numpy as np
from io import BytesIO
from functools import partial

itime = time.time()
# Morgan (circular) fingerprints with radius 3, folded to 2048 bits.
# The same fp_func is reused later to compute atomic contributions, so the model sees identical features.
fp_func = partial(SimilarityMaps.GetMorganFingerprint, nBits=2048, radius=3)
feature_mat = np.array([np.array(fp_func(Chem.MolFromSmiles(x))) for x in all_chems['SMILES']])
print("Data featurized. Time: {}".format(time.time() - itime))
Data featurized. Time: 18.869210720062256
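The `radius` and `nBits` parameters can be made concrete with a simplified, hypothetical sketch of how a circular (Morgan/ECFP-style) fingerprint is built. Each atom starts with an identifier, and each round hashes an atom's identifier together with its neighbors' identifiers, which grows the covered environment by one bond. Every identifier seen is then folded into `n_bits` positions by a modulo. The sketch ignores bond orders, chirality, and the duplicate-environment pruning RDKit performs, and the graph encoding is our own assumption, so it will not reproduce RDKit's actual bits.

```python
import hashlib

def _stable_hash(*parts):
    # Deterministic hash (Python's built-in hash() of strings is randomized per process)
    digest = hashlib.sha1(repr(parts).encode()).hexdigest()
    return int(digest[:15], 16)

def toy_circular_fingerprint(atoms, bonds, radius=3, n_bits=2048):
    """Hypothetical toy: atoms is a list of element symbols, bonds a list of (i, j) pairs."""
    neighbors = {i: [] for i in range(len(atoms))}
    for i, j in bonds:
        neighbors[i].append(j)
        neighbors[j].append(i)

    # Radius 0: the identifier depends only on the atom itself (here: element and degree)
    ids = [_stable_hash(sym, len(neighbors[i])) for i, sym in enumerate(atoms)]
    seen = set(ids)
    for _ in range(radius):
        # Each round absorbs the neighbors' identifiers from the previous round
        ids = [_stable_hash(ids[i], tuple(sorted(ids[j] for j in neighbors[i])))
               for i in range(len(atoms))]
        seen.update(ids)

    # Fold the huge identifier space into n_bits positions (collisions are possible)
    bits = [0] * n_bits
    for identifier in seen:
        bits[identifier % n_bits] = 1
    return bits

ethanol = toy_circular_fingerprint(["C", "C", "O"], [(0, 1), (1, 2)])
print(len(ethanol), sum(ethanol))
```

Because identifiers depend only on local environments, the result does not depend on the order in which atoms are listed. This locality is also what later makes per-atom importances cheap to compute.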

Calculate neighborhoods in the embedding space

CODE
# Per-structure metadata (SMILES, name, properties, source dataset), indexed by ZINC ID
metadata_df = all_chems.copy()
metadata_df.index.name = 'ZINC_ID'

# Store fingerprints as float32 explicitly, avoiding anndata's implicit int64 -> float32 conversion warning
anndata_all = anndata.AnnData(X=feature_mat.astype(np.float32), obs=metadata_df)

itime = time.time()
# Do PCA for denoising and dimensionality reduction
sc.pp.pca(anndata_all, n_comps=10)
print("PCA computed. Time: {}".format(time.time() - itime))

# Compute neighborhood graph in PCA space
sc.pp.neighbors(anndata_all)
print("Neighbors computed. Time: {}".format(time.time() - itime))

# Compute UMAP using neighborhood graph
sc.tl.umap(anndata_all)
print("UMAP calculated. Time: {}".format(time.time() - itime))

# Quick static preview of the embedding, colored by molecular weight
sc.pl.umap(anndata_all, color='Molecular weight')

# Keep the 2D UMAP coordinates alongside the metadata for plotting
anndata_all.obs['x'] = anndata_all.obsm['X_umap'][:, 0]
anndata_all.obs['y'] = anndata_all.obsm['X_umap'][:, 1]
# Uncomment to persist the processed data to disk
#anndata_all.write('approved_drugs.h5')
print("Data written. Time: {}".format(time.time() - itime))
PCA computed. Time: 0.2940680980682373
Neighbors computed. Time: 1.092061996459961
UMAP calculated. Time: 9.20744776725769

Data written. Time: 9.378067016601562
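The two scanpy preprocessing calls can be approximated with a simplified NumPy sketch. PCA is computed as an SVD of the centered feature matrix, followed by an exact k-nearest-neighbor search in PCA space. This is a stand-in for intuition, not scanpy's implementation. By default, sc.pp.neighbors uses 15 neighbors, may use approximate search, and converts distances into weighted connectivities for UMAP, none of which is modeled here. The random binary matrix is an assumption standing in for real fingerprints.

```python
import numpy as np

def pca_embed(X, n_comps=10):
    # Center each feature, then project onto the top right-singular vectors
    Xc = X - X.mean(axis=0)
    _, _, Vt = np.linalg.svd(Xc, full_matrices=False)
    return Xc @ Vt[:n_comps].T

def knn_indices(Z, k=15):
    # Pairwise squared Euclidean distances in the reduced space
    sq = (Z ** 2).sum(axis=1)
    d2 = sq[:, None] + sq[None, :] - 2 * Z @ Z.T
    np.fill_diagonal(d2, np.inf)  # a point is not its own neighbor
    return np.argsort(d2, axis=1)[:, :k]

rng = np.random.default_rng(0)
X = rng.integers(0, 2, size=(200, 2048)).astype(np.float32)  # random stand-in for fingerprints
Z = pca_embed(X, n_comps=10)
nbrs = knn_indices(Z, k=15)
print(Z.shape, nbrs.shape)  # (200, 10) (200, 15)
```

Reducing to 10 components first makes the neighbor search cheaper and less sensitive to the many sparse, noisy fingerprint bits.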

Train a model and make predictions

We train a model to predict log-P (a measure of lipophilicity, namely the logarithm of the octanol-water partition coefficient) from the chemical structures.

In practice, we could instead load a pre-trained model (example).

CODE
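from sklearn.ensemble import RandomForestRegressor

# Features: the Morgan fingerprints stored in the AnnData object; target: ZINC's log P values
X_train = anndata_all.X
y_train = all_chems['Log P'].to_numpy()

model = RandomForestRegressor()
model.fit(X_train, y_train)

# For demonstration, the model is trained on the full dataset, so the predictions below are in-sample.
# In practice, evaluate on held-out data with a regression metric such as R^2
# (sklearn.metrics.r2_score); accuracy_score only applies to classification.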
RandomForestRegressor()
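Because the forest above is fit on every structure, its predictions on those same structures are optimistic. Here is a minimal held-out check using only NumPy. It splits indices at random and scores with the coefficient of determination R² = 1 - SS_res/SS_tot, the quantity sklearn.metrics.r2_score reports, instead of a classification accuracy. The helper names are hypothetical, and a least-squares fit on synthetic data stands in for the random forest.

```python
import numpy as np

def train_test_indices(n, test_frac=0.2, seed=0):
    # Random permutation split into disjoint train and test index arrays
    perm = np.random.default_rng(seed).permutation(n)
    n_test = int(round(test_frac * n))
    return perm[n_test:], perm[:n_test]

def r2(y_true, y_pred):
    # Coefficient of determination: 1 - residual sum of squares / total sum of squares
    y_true, y_pred = np.asarray(y_true, float), np.asarray(y_pred, float)
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    return 1.0 - ss_res / ss_tot

# Synthetic stand-in data: a noisy linear target on random binary features
rng = np.random.default_rng(1)
X = rng.integers(0, 2, size=(500, 20)).astype(float)
y = X @ rng.normal(size=20) + 0.1 * rng.normal(size=500)

train_idx, test_idx = train_test_indices(len(y))
coef, *_ = np.linalg.lstsq(X[train_idx], y[train_idx], rcond=None)  # stand-in model
print("held-out R^2:", r2(y[test_idx], X[test_idx] @ coef))
```

With the real data, the same split would be applied to `X_train` and `y_train` before fitting `RandomForestRegressor`.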

Make predictions on the data

CODE
import scipy.stats
import matplotlib.colors, matplotlib.pyplot as plt, base64

# Per-tree predictions, shape (n_estimators, n_structures)
tree_preds = np.array([tree.predict(X_train) for tree in model.estimators_])

# The forest's prediction is the mean over trees (equal to model.predict(X_train));
# the spread across trees serves as a rough uncertainty estimate.
anndata_all.obs["predictions"] = tree_preds.mean(axis=0)
anndata_all.obs["uncertainty"] = tree_preds.std(axis=0)
# Quantile rank of each prediction, in (0, 1]
anndata_all.obs["q_predictions"] = scipy.stats.rankdata(anndata_all.obs["predictions"])/len(anndata_all.obs["predictions"])

# anndata_all.obs["q_predictions"] is between 0 and 1;
# replace each of these values by the corresponding hex viridis color.
anndata_all.obs["colormapped_predictions"] = anndata_all.obs["q_predictions"].apply(lambda x: matplotlib.colors.rgb2hex(plt.cm.viridis(x)))
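The color mapping above reduces to two steps. First, each prediction is converted to its quantile rank in (0, 1]. Then that quantile is looked up in a colormap and the RGB triple is formatted as hex. Here is a dependency-light sketch, assuming NumPy only. It breaks ties by order of appearance, whereas scipy.stats.rankdata averages tied ranks by default. The three-stop colormap is a crude, hypothetical approximation of viridis using its endpoint and midpoint colors.

```python
import numpy as np

def quantile_ranks(values):
    # Rank 1..n (ties broken by position), divided by n, giving values in (0, 1]
    values = np.asarray(values)
    ranks = np.empty(len(values))
    ranks[np.argsort(values, kind="stable")] = np.arange(1, len(values) + 1)
    return ranks / len(values)

def rgb_to_hex(rgb):
    # Floats in [0, 1] -> "#rrggbb"
    return "#" + "".join(f"{int(round(c * 255)):02x}" for c in rgb)

# Hypothetical 3-stop approximation of viridis (dark purple -> teal -> yellow)
STOPS = np.array([[0.267004, 0.004874, 0.329415],
                  [0.127568, 0.566949, 0.550556],
                  [0.993248, 0.906157, 0.143936]])

def approx_viridis(q):
    # Piecewise-linear interpolation between the stops, channel by channel
    xs = np.linspace(0, 1, len(STOPS))
    return [np.interp(q, xs, STOPS[:, ch]) for ch in range(3)]

preds = [2.5, -1.0, 4.2, 0.3]
q = quantile_ranks(preds)  # quantiles 0.75, 0.25, 1.0, 0.5
print([rgb_to_hex(approx_viridis(x)) for x in q])
```

Coloring by quantile rather than by raw value spreads the colors evenly across the dataset, even when a few outlying predictions would otherwise compress most points into one hue.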

Plot atomic importances in an interactive scatterplot

There is a very simple conceptual way to calculate the importance of any atom to a prediction: if we remove the atom from the molecule but leave it otherwise equivalent, how much does the prediction change?

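This is efficient to implement with the circular (Morgan) fingerprint features common in cheminformatics, since each bit corresponds to a local atom environment. “Removing” an atom then amounts to switching off the bits whose environments include it and re-querying the model, with no need to edit the molecule itself. RDKit’s Greg Landrum has some comprehensive explanations on his blog, in addition to the published paper.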
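The leave-one-atom-out procedure can be written generically. This sketch assumes three inputs: a fingerprint given as a set of on-bits, a hypothetical `env_atoms` map from each bit to the atoms its environment covers, and any prediction function. Atom i's weight is the base prediction minus the prediction with every bit touching atom i switched off. As we understand it, this is how RDKit's SimilarityMaps.GetAtomicWeightsForModel combines the two predictions. The toy linear model and bit map are made up for illustration.

```python
def atomic_weights(n_atoms, on_bits, env_atoms, predict):
    """
    n_atoms:   number of atoms in the molecule
    on_bits:   set of fingerprint bits that are on for the full molecule
    env_atoms: dict bit -> set of atom indices in that bit's environment (hypothetical input)
    predict:   function mapping a set of on-bits to a scalar prediction
    """
    base = predict(on_bits)
    weights = []
    for atom in range(n_atoms):
        # "Remove" the atom: switch off every bit whose environment contains it
        reduced = {b for b in on_bits if atom not in env_atoms[b]}
        weights.append(base - predict(reduced))
    return weights

# Toy example: a 3-atom molecule, four on-bits, and a linear model over bits
env_atoms = {0: {0}, 1: {1}, 2: {2}, 3: {1, 2}}
coef = {0: 0.5, 1: -1.0, 2: 2.0, 3: 0.25}
predict = lambda bits: sum(coef[b] for b in bits)

print(atomic_weights(3, set(env_atoms), env_atoms, predict))  # [0.5, -0.75, 2.25]
```

For a linear model, each atom's weight is simply the sum of the coefficients of the bits it participates in. For a random forest, the same recipe works in a black-box manner, at the cost of one extra prediction per atom.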

CODE
import base64
from io import BytesIO
import numpy as np
import bokeh.plotting, bokeh.models, bokeh.resources
from rdkit import Chem
from rdkit.Chem import Draw
from rdkit.Chem.Draw import SimilarityMaps

def get_pred(fp, pred_function):
    # Wrap a single RDKit bit vector as a (1, n_bits) array and return its scalar prediction
    fp = np.array([list(fp)])
    return pred_function(fp)[0]

def render_molecule_scatterplot(
    df, fp_func, smiles_col="SMILES", name_col="Preferred name",
    color_col="color", legend_col="cluster", circ_line_width=0.1, pt_size=5, title='Molecule scatterplot',
    output_html='interactive_scatterplot.html', pred_model=None, size_tuple=(250, 250)
):
    """
    Renders an interactive scatterplot of chemical structures using Bokeh and saves it as
    self-contained HTML. Points are colored by the color_col column; hovering shows each
    structure, with atomic contributions to pred_model's prediction if a model is given.

    Parameters:
    df (pd.DataFrame): DataFrame with columns 'x', 'y', 'predictions', 'uncertainty',
        smiles_col, name_col, color_col, and legend_col (if not None)
    fp_func (callable): fingerprint function (mol, atomId) -> bit vector, matching the model's features
    pred_model: fitted model with a .predict method, or None to draw plain structures
    output_html (str): Filename for the output HTML file

    Returns:
    None
    """
    df = df.copy()  # avoid adding columns to the caller's DataFrame

    # Generate images of molecules and encode them as base64 data URIs
    images = []
    for smiles in df[smiles_col]:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            images.append(None)
            continue
        if pred_model is None:
            # Plain structure diagram
            img = Draw.MolToImage(mol, size=size_tuple)
            buffered = BytesIO()
            img.save(buffered, format="PNG")
            png_bytes = buffered.getvalue()
        else:
            # Query prediction model to generate "similarity map" of atom contributions
            d = Draw.MolDraw2DCairo(*size_tuple)
            SimilarityMaps.GetSimilarityMapForModel(
                mol,
                fp_func,
                lambda x: get_pred(x, pred_model.predict), colorMap='coolwarm',
                draw2d=d
            )
            d.FinishDrawing()
            png_bytes = d.GetDrawingText()
        images.append('data:image/png;base64,' + base64.b64encode(png_bytes).decode())

    df['img'] = images
    df['name'] = list(df[name_col])

    # Prepare data source
    source = bokeh.models.ColumnDataSource(df)

    # Create a Bokeh figure
    p = bokeh.plotting.figure(
        tools='pan,box_zoom,reset,hover',
        title=title,
        width=900,
        height=600
    )

    # Add circle markers; scatter(marker='circle') replaces circle(size=...), deprecated in Bokeh 3.4
    glyph_kwargs = dict(
        marker='circle', size=pt_size, line_width=circ_line_width,
        source=source, fill_alpha=0.6, fill_color=color_col
    )
    if legend_col is not None:
        glyph_kwargs['legend_field'] = legend_col
    p.scatter('x', 'y', **glyph_kwargs)
    p.toolbar.logo = None

    # Configure hover tool
    hover = p.select_one(bokeh.models.HoverTool)
    hover.tooltips = """
    <div>
        <div><img src="@img" alt="Molecule image" style="width:300px;height:300px;"></div>
        <div><span style="font-size: 12px;"><b>Name:</b> @name</span></div>
        <div><span style="font-size: 12px;"><b>Prediction:</b> @predictions +/- @uncertainty </span></div>
    </div>
    """
    if legend_col is not None:
        p.legend.title = legend_col
        p.legend.location = 'center'
        p.legend.background_fill_alpha = 0.5
        p.legend.click_policy = "hide"
        p.legend.orientation = 'vertical'
        p.legend.ncols = 2  # lay out legend entries in two columns
        p.add_layout(p.legend[0], 'right')  # move the legend outside the plot area

    # INLINE resources embed BokehJS itself in the file, so the plot also works offline
    bokeh.plotting.save(p, filename=output_html, resources=bokeh.resources.INLINE, title=title)
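The self-contained part of the HTML comes from the data URIs built in the function above. Each PNG is base64-encoded and inlined into the tooltip's img src, so no separate image files are needed. Here is a standard-library sketch of that encoding, with hypothetical helper names, along with its size cost. Base64 emits 4 characters per 3 input bytes, so each embedded image grows by about a third. The byte string below only mimics a PNG header and is not a valid image.

```python
import base64

def png_data_uri(png_bytes: bytes) -> str:
    # Inline binary PNG data as text that an <img src=...> attribute can hold
    return "data:image/png;base64," + base64.b64encode(png_bytes).decode("ascii")

def base64_length(n_bytes: int) -> int:
    # base64 encodes each (padded) 3-byte group as 4 characters
    return 4 * ((n_bytes + 2) // 3)

fake_png = b"\x89PNG\r\n\x1a\n" + bytes(100)  # PNG signature + padding, not a valid image
uri = png_data_uri(fake_png)
print(len(fake_png), len(uri) - len("data:image/png;base64,"))  # 108 144
```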

Render self-contained HTML plot

Rendering is not optimized in any way, so it may take a while: budget ~4-10 structures per second.

Caution: This generates a large HTML file, because every rendered structure, including its atomic contributions, is embedded in the file as a base64-encoded PNG image. Budget ~5MB per 100 structures.
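To put those budgets in perspective for a dataset of this size, a back-of-the-envelope helper converts them into rendering time and file size. The helper is hypothetical and uses only the rough rates quoted above; 5,000 is a round stand-in for the >5,000 structures here.

```python
def render_budget(n_structures, per_sec=(4, 10), mb_per_100=5.0):
    # Rendering time range in minutes, from the slowest and fastest quoted rates
    slow_min = n_structures / per_sec[0] / 60
    fast_min = n_structures / per_sec[1] / 60
    # File size from the quoted ~5MB per 100 structures
    size_mb = n_structures / 100 * mb_per_100
    return fast_min, slow_min, size_mb

fast, slow, size = render_budget(5000)
print(f"~{fast:.0f}-{slow:.0f} min, ~{size:.0f} MB")  # ~8-21 min, ~250 MB
```

For quick iteration, passing a subset of rows to render_molecule_scatterplot keeps both numbers manageable.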

CODE
render_molecule_scatterplot(
    anndata_all.obs, fp_func,
    smiles_col="SMILES",
    name_col="Preferred name",
    color_col="colormapped_predictions",
    pred_model=model,
    pt_size=4,
    output_html='msscatter.html',
    legend_col=None,
    title='Log-P predictions on known drugs',
    size_tuple=(250, 250)
)

The resulting file displays in any browser, and can be shared with collaborators or viewed offline. For each structure, it shows the importance of each atom to the log-P prediction. These importances pass a basic sanity check: lipophilic portions of the molecules are highlighted in red, contributing positively towards higher log-P. Hydrophilic portions are in blue, as those atoms contribute negatively towards log-P.
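Before drawing, RDKit's similarity-map code (as we understand it) rescales the atomic weights by their largest absolute value, which is the maxWeight returned by GetSimilarityMapForModel. The sign therefore sets the hue: red for positive and blue for negative under coolwarm. The magnitude sets the intensity. Here is a plain-Python sketch of that rescaling, with made-up weights and a hypothetical helper name:

```python
def standardize_weights(weights):
    # Divide by the largest |w| so values lie in [-1, 1]; leave all-zero weights unchanged
    max_abs = max(abs(w) for w in weights)
    if max_abs == 0:
        return list(weights), max_abs
    return [w / max_abs for w in weights], max_abs

scaled, max_weight = standardize_weights([0.5, -0.75, 2.25])  # hypothetical weights
signs = ["red (raises log-P)" if w > 0 else "blue (lowers log-P)" if w < 0 else "neutral"
         for w in scaled]
print(max_weight, signs)
```

A consequence of this per-molecule rescaling is that color intensities are comparable between atoms of the same structure, but not across structures.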