Visualizing in silico chemical decompositions

dataviz
cheminformatics
LNP
Ways to view ultralarge chemical spaces
Author

Akshay Balsubramani

Visualizing chemical decompositions

In working with ultralarge chemical spaces, it is often useful for drug designers and other scientists like medicinal chemists to visualize the chemical decompositions of molecules, with a view towards understanding the local chemical space and the relationships between molecules.

In this post, I’ll cover a few different ways to do this which are appropriate for different settings.

Context

We often encounter a situation where we have computed a decomposition of a molecule into a set of structural fragments. The molecule itself is then represented by these fragments together with the reaction rules that reassemble them.

Highlighting the origin of individual atoms

It’s useful to be able to manipulate atom-level visualizations of these fragment decompositions. For each atom, we’d like to be able to track which fragment it came from.

This is an extra bit of bookkeeping that can be done while evaluating the infix expression corresponding to each structure. To do this, we refactor the `evaluate_expression` function so that, alongside the list of product SMILES, it returns a parallel list of RDKit `Mol` objects whose atoms carry a `frag_id` property recording the fragment(s) each atom came from:

CODE
from collections import defaultdict
import matplotlib.colors, io
from rdkit import Chem, RDLogger
from rdkit.Chem import AllChem, Draw
from rdkit.Chem.Draw import rdMolDraw2D
from IPython.display import SVG, display
import matplotlib.pyplot as plt, numpy as np
from PIL import Image


RDLogger.DisableLog('rdApp.*')  # keep the output tidy

def _propagate_frag_ids(prod, reactants):
    """Copy ``frag_id`` labels from reactant atoms onto a reaction product.

    Parameters
    ----------
    prod : RDKit Mol
        A product returned by ``ChemicalReaction.RunReactants``.
    reactants : sequence of RDKit Mol
        The reactants passed to ``RunReactants``, in the same order.

    Returns
    -------
    RDKit Mol
        A copy of ``prod``. Every atom carries a ``frag_id`` property: the
        comma-joined, sorted union of the ``frag_id`` values of the reactant
        atom(s) it came from, or ``"NEW"`` if no origin could be found. Atom
        map numbers on the copy are cleared.
    """
    # Index reactant fragment ids by atom-map number once. The old_mapno
    # fallback below then needs one lookup per product atom instead of a
    # rescan of every reactant atom.
    ids_by_mapno = defaultdict(set)
    for rm in reactants:
        for ra in rm.GetAtoms():
            if ra.HasProp("frag_id"):
                ids_by_mapno[ra.GetAtomMapNum()].update(ra.GetProp("frag_id").split(","))

    p = Chem.Mol(prod)
    for at in p.GetAtoms():
        ids = set()
        if at.HasProp("react_idx") and at.HasProp("react_atom_idx"):
            # Exact provenance: which reactant, and which atom inside it.
            r = int(at.GetProp("react_idx"))
            a = int(at.GetProp("react_atom_idx"))
            ra = reactants[r].GetAtomWithIdx(a)
            if ra.HasProp("frag_id"):
                ids.update(ra.GetProp("frag_id").split(","))
        elif at.HasProp("old_mapno"):
            # Fallback: match on the template map number; this unions the ids
            # of every reactant atom that carries that map number.
            ids.update(ids_by_mapno.get(int(at.GetProp("old_mapno")), ()))
        if not ids:
            ids.add("NEW")
        at.SetProp("frag_id", ",".join(sorted(ids)))
        at.SetAtomMapNum(0)
    return p

def evaluate_expression(expr, frags_program, rxn_dict):
    """Evaluate ``"F<i>"`` or ``"F<i>_<code>_F<j>"`` and annotate fragment origins.

    Parameters
    ----------
    expr : str
        Either a single fragment token ``"F<i>"`` or a binary expression
        ``"<left>_<code>_<right>"`` whose operands are fragment tokens.
    frags_program : sequence of str
        Fragment SMILES; ``"F<i>"`` refers to ``frags_program[i]``.
    rxn_dict : mapping of str -> str
        Reaction code (e.g. ``"R1"``) -> reaction SMARTS.

    Returns
    -------
    (list of str, list of RDKit Mol)
        Unique canonical product SMILES and, in parallel, a copy of each
        product whose atoms carry a ``frag_id`` property.

    Raises
    ------
    ValueError
        If ``expr`` is not one of the two supported shapes, or a fragment
        SMILES cannot be parsed.
    """
    tok = expr.split('_')
    if len(tok) == 1:
        token = tok[0]
        if not (token.startswith('F') and token[1:].isdigit()):
            raise ValueError(f"Expected a fragment token like 'F0', got {token!r}")
        idx = int(token[1:])
        mol = Chem.MolFromSmiles(frags_program[idx])
        if mol is None:
            raise ValueError(f"Fragment F{idx} has unparsable SMILES {frags_program[idx]!r}")
        for a in mol.GetAtoms():
            a.SetProp("frag_id", f"F{idx}")
        return [Chem.CanonSmiles(frags_program[idx])], [mol]

    # Only the binary form  F# _ R# _ F#  is supported here; use
    # evaluate_expression_mod for chains and parentheses.
    if len(tok) != 3:
        raise ValueError(
            f"Expected 'F<i>' or 'F<i>_<reaction>_F<j>', got {expr!r}"
        )
    left, rcode, right = tok
    l_smi, l_mol = evaluate_expression(left, frags_program, rxn_dict)
    r_smi, r_mol = evaluate_expression(right, frags_program, rxn_dict)

    rxn = AllChem.ReactionFromSmarts(rxn_dict[rcode])
    products, annotated, seen = [], [], set()
    for lm in l_mol:
        for rm in r_mol:
            for prods in rxn.RunReactants((lm, rm)):
                prod = prods[0]  # only the first product template is used
                smi = Chem.CanonSmiles(Chem.MolToSmiles(prod, True))
                if smi in seen:  # different matches can yield the same product
                    continue
                seen.add(smi)
                products.append(smi)
                annotated.append(_propagate_frag_ids(prod, (lm, rm)))
    return products, annotated


def _avg_rgb(cols):
    return tuple(float(np.mean([c[i] for c in cols])) for i in range(3))

_FRAGMENT_PALETTE = (
    "#bdbdbd", '#e6194b', '#3cb44b', '#ffe119', '#4363d8', '#f58231',
    '#911eb4', '#46f0f0', '#f032e6', '#bcf60c', '#fabebe', '#008080',
    '#e6beff', '#9a6324', '#fffac8', '#800000', '#aaffc3', '#808000',
    '#ffd8b1', '#000075', '#808080', '#7d87b9', '#bec1d4', '#d6bcc0',
)


def highlight_fragments(mol, size=(300, 200), palette=None, new_color="#dddddd"):
    """Display ``mol`` with atoms and bonds coloured by their source fragment.

    Parameters
    ----------
    mol : RDKit Mol
        Molecule whose atoms may carry a comma-separated ``frag_id`` property
        (e.g. ``"F0"`` or ``"F0,F2"``). Atoms without it are not highlighted.
    size : (int, int)
        Canvas width and height in pixels.
    palette : sequence of Matplotlib colour specs, optional
        Fragment ``F<i>`` uses ``palette[i % len(palette)]``. Defaults to
        ``_FRAGMENT_PALETTE``.
    new_color : Matplotlib colour spec
        Colour for ids that are not of the form ``F<i>``, such as ``"NEW"``.

    Atoms belonging to several fragments get one colour per fragment. A bond
    gets the colour of every fragment that both of its atoms belong to. The
    drawing is rendered as PNG with Cairo and shown with IPython ``display``.

    Raises
    ------
    ValueError
        If ``palette`` is empty.
    """
    if palette is None:
        palette = _FRAGMENT_PALETTE
    if len(palette) == 0:
        raise ValueError("palette must contain at least one colour")

    def fid_colour(fid):
        if fid.startswith("F") and fid[1:].isdigit():
            return matplotlib.colors.to_rgb(palette[int(fid[1:]) % len(palette)])
        return matplotlib.colors.to_rgb(new_color)

    # Fragment ids per atom; atoms without annotations are skipped.
    atom_fids = {
        a.GetIdx(): a.GetProp("frag_id").split(",")
        for a in mol.GetAtoms()
        if a.HasProp("frag_id")
    }
    athighlights = {idx: [fid_colour(f) for f in fids] for idx, fids in atom_fids.items()}
    arads = {a.GetIdx(): 0.5 for a in mol.GetAtoms()}

    # A bond is coloured by every fragment shared by both of its end atoms.
    bndhighlights = {}
    for b in mol.GetBonds():
        shared = set(atom_fids.get(b.GetBeginAtomIdx(), ())) & set(
            atom_fids.get(b.GetEndAtomIdx(), ())
        )
        if shared:
            bndhighlights[b.GetIdx()] = [fid_colour(f) for f in sorted(shared)]

    d = rdMolDraw2D.MolDraw2DCairo(*size)
    d.DrawMoleculeWithHighlights(mol, "", athighlights, bndhighlights, arads, {})
    d.FinishDrawing()
    display(Image.open(io.BytesIO(d.GetDrawingText())))

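To see this in action, start with a minimal toy example: two small fragments joined by a single C–C coupling rule. The same bookkeeping carries over unchanged to real structures, such as a common ionizable lipid broken down at its ester linkages.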

CODE
# Toy fragments. The atom-map numbers are only labels on the fragment SMILES;
# the R1 template matches any CH3 carbon, whether or not it carries a map number.
frags = [
    "CCC[CH3:1]",     # F0 : n-butane, one terminal methyl tagged with map #1
    "O=C[CH3:2]",     # F1 : acetaldehyde, its methyl tagged with map #2
]

# R1 : join one methyl carbon from each reactant with a single C-C bond
rxn_smarts = {
    "R1": "[CH3:1].[CH3:2]>>[C:1]-[C:2]"
}


expr = "F0_R1_F1"
smis, mols = evaluate_expression(expr, frags, rxn_smarts)

# Either terminal methyl of F0 can react, but both choices give the same
# product, which evaluate_expression de-duplicates by canonical SMILES.
print("canonical SMILES:", smis[0])
for a in mols[0].GetAtoms():
    print(f"Atom {a.GetIdx():>2}: frag_id = {a.GetProp('frag_id')}")

highlight_fragments(mols[0])

On larger structures, the same atom-level colouring cleanly delineates the fragment-based decomposition of the whole molecule.

Tree-based images

A complementary way to proceed is to render fragments individually and use the inherent tree-based structure of the decomposition as a guide to visualization.

CODE
# Imports

from datetime import datetime
import numpy as np

import networkx as nx
from networkx.drawing.nx_pydot import graphviz_layout
from networkx.readwrite import json_graph

from rdkit import Chem
from rdkit.Chem import Draw

import matplotlib.pyplot as plt
import matplotlib.lines as mlines
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
from PIL import Image

print('Last modified:', datetime.now())
Last modified: 2025-08-22 02:26:20.542314

Style config

We want to color each reaction differently to better illustrate the skeleton of the molecule.

CODE
# Background colour of the tree figure, as an (R, G, B) tuple of 0-255 ints.
BG_COL = (255, 255, 255)

# Line colour for each reaction type drawn as an edge of the decomposition tree
# (insertion order is also the order used in the figure legend).
REACTION_COLORS = dict(
    esterification='blue',
    amide_coupling='green',
    alcohols_to_ether='red',
    thiol_to_disulfide='purple',
    epoxide_ring_opening='orange',
    epoxide_ring_opening_saturated='cyan',
    michael_addition_acrylate_1o='magenta',
    michael_addition_acrylate_2o='yellow',
    michael_addition_acrylamide_1o='pink',
    michael_addition_acrylamide_2o='lime',
)

Graph construction functions

CODE
file_pfx = "../../files/chem/"

decomp_tools_path = file_pfx + "combinatorial_chemistry_decomposition_tools.py"

# Load the decomposition helpers directly from a file path. SourceFileLoader
# .load_module() is deprecated, so use the spec + exec_module recipe instead.
# The module is registered in sys.modules (as load_module did) and removed
# again if executing it fails.
from importlib.machinery import SourceFileLoader
import importlib.util
import sys

_ccd_loader = SourceFileLoader("ccd_tools", decomp_tools_path)
_ccd_spec = importlib.util.spec_from_loader("ccd_tools", _ccd_loader)
ccd_tools = importlib.util.module_from_spec(_ccd_spec)
sys.modules["ccd_tools"] = ccd_tools
try:
    _ccd_loader.exec_module(ccd_tools)
except BaseException:
    sys.modules.pop("ccd_tools", None)
    raise


class TreeNode:
    """One node of a decomposition tree.

    Attributes
    ----------
    label : the string this node represents (a SMILES in this post)
    children : list of (reaction_type, TreeNode) pairs, in insertion order
    """

    def __init__(self, label):
        self.label = label
        self.children = list()


def construct_tree(molecule_info, root_smiles):
    """Build a TreeNode hierarchy rooted at ``root_smiles``.

    ``molecule_info`` maps a SMILES to a record with ``'reaction_type'`` and
    ``'components'``. Every listed component becomes a child of that SMILES'
    node, tagged with the record's reaction type. Components with their own
    record are expanded recursively; SMILES without a record become leaves.
    """
    def expand(node):
        if node.label not in molecule_info:
            return
        entry = molecule_info[node.label]
        reaction = entry['reaction_type']
        for part in entry['components']:
            child = TreeNode(part)
            node.children.append((reaction, child))
            expand(child)

    root = TreeNode(root_smiles)
    expand(root)
    return root


def get_tree(input_molecule_smiles, rxn_smarts_dict, rxn_codes,
             reaction_types=("esterification",)):
    """Decompose a molecule and return its decomposition tree as a DiGraph.

    Parameters
    ----------
    input_molecule_smiles : str
        Molecule to decompose; it is canonicalised first.
    rxn_smarts_dict, rxn_codes :
        Passed straight through to ``ccd_tools.decompose_dataset``.
    reaction_types : iterable of str
        Reaction types the decomposition may use (default: esterification only).

    Returns
    -------
    (networkx.DiGraph, list of dict, list of dict)
        The tree; the node records ``{"name": ...}``; and the edge records
        ``{"source", "target", "name"}``, where ``name`` is the reaction type.
        Node names are ``"<SMILES>_<k>"``, where ``k`` is a unique preorder
        counter, so the root is ``"<canonical SMILES>_0"``.
    """
    input_molecule_smiles = Chem.CanonSmiles(input_molecule_smiles)
    molecule_info, _, _, _ = ccd_tools.decompose_dataset(
        [input_molecule_smiles],
        rxn_smarts_dict,
        rxn_codes,
        reaction_types=list(reaction_types),
        print_debug=False
    )
    # NOTE(review): the earlier version also called
    # ccd_tools.compute_fragment_counts(..., frags_program, ...) and
    # db_construction.decompose_lipid_set(...). Both names are undefined in
    # this post and their results were never used to build the tree, so the
    # calls were dropped.
    root_node = construct_tree(molecule_info, input_molecule_smiles)

    nodes = []
    links = []
    next_id = 0

    def traverse_tree(node):
        # A running counter (not the depth) keeps names unique even when
        # identical fragments appear more than once at the same level.
        nonlocal next_id
        name = f"{node.label}_{next_id}"
        next_id += 1
        nodes.append({"name": name})
        for reaction_type, child_node in node.children:
            # The child will be named with the current counter value.
            links.append({"source": name, "target": f"{child_node.label}_{next_id}",
                          "name": reaction_type})
            traverse_tree(child_node)

    traverse_tree(root_node)

    T = nx.DiGraph()
    T.add_nodes_from(node["name"] for node in nodes)
    for link in links:
        T.add_edge(link["source"], link["target"], **link)
    return T, nodes, links

Tree display function

CODE
def display_lipid_tree(input_molecule_smiles, rxn_smarts_dict=None, rxn_codes=None):
    """Plot a molecule's decomposition tree with a 2D depiction at every node.

    Parameters
    ----------
    input_molecule_smiles : str
        Molecule to decompose and plot.
    rxn_smarts_dict, rxn_codes :
        Reaction definitions forwarded to ``get_tree``; both are required.

    Edges are coloured via ``REACTION_COLORS`` (black for unknown reaction
    types) and labelled with the first six characters of the reaction name.
    The layout comes from Graphviz ``dot``.

    Raises
    ------
    ValueError
        If ``rxn_smarts_dict`` or ``rxn_codes`` is not supplied.
    """
    if rxn_smarts_dict is None or rxn_codes is None:
        raise ValueError(
            "display_lipid_tree needs rxn_smarts_dict and rxn_codes to decompose the molecule"
        )
    input_molecule_smiles = Chem.CanonSmiles(input_molecule_smiles)
    T, nodes, links = get_tree(input_molecule_smiles, rxn_smarts_dict, rxn_codes)
    pos = graphviz_layout(T, prog="dot")  # node name -> (x, y)

    bg_col_normalized = tuple(value / 255 for value in BG_COL)
    fig, ax = plt.subplots(figsize=(25, 50))
    ax.set_aspect('auto')
    ax.set_zorder(1)  # axes (edges and text) below the legend
    ax.set_facecolor('#%02x%02x%02x' % BG_COL)

    for e in links:
        x1, y1 = pos[e["source"]]
        x2, y2 = pos[e["target"]]
        colour = REACTION_COLORS.get(e["name"], "black")
        ax.add_line(plt.Line2D((x1, x2), (y1, y2), linewidth=5, alpha=0.9, color=colour))
        ax.text(x2, (y1 + y2) / 2, e["name"][0:6], fontsize=20,
                ha='center', va='center', color='black')

    for node in T.nodes():
        # Node names are "<SMILES>_<k>"; strip the uniqueness suffix.
        mol = Chem.MolFromSmiles(node.rsplit('_', 1)[0])
        if mol is None:
            continue
        # Force RGB so the 3-channel background fill works on RGBA output too.
        img = np.array(Draw.MolToImage(mol, size=(100, 100)).convert("RGB"))
        img[np.all(img == 255, axis=-1)] = BG_COL  # white pixels -> background
        ab = AnnotationBbox(OffsetImage(Image.fromarray(img)), pos[node],
                            frameon=False, pad=0.0, bboxprops=dict(edgecolor='none'))
        ab.set_zorder(2)  # structure images above the edges
        ax.add_artist(ab)

    ax.axis('off')
    ax.set_title('Tree-based chemical decomposition', fontsize=50, pad=50)
    legend_handles = [
        mlines.Line2D([], [], color=color, label=reaction, linewidth=5)
        for reaction, color in REACTION_COLORS.items()
    ]
    legend = ax.legend(handles=legend_handles, title='Reactions', fontsize=25, title_fontsize=35,
                       loc='upper right', bbox_to_anchor=(1.1, 1))
    legend.set_zorder(3)

    # Frame the actual node extent with half a range of padding on each side.
    # Fall back to a fixed pad when all nodes share an x (or y) coordinate.
    xs = [pos[n][0] for n in T.nodes()]
    ys = [pos[n][1] for n in T.nodes()]
    xmin, xmax, ymin, ymax = min(xs), max(xs), min(ys), max(ys)
    x_pad = (xmax - xmin) / 2 or 50.0
    y_pad = (ymax - ymin) / 2 or 50.0
    ax.set_xlim([xmin - x_pad, xmax + x_pad])
    ax.set_ylim([ymin - y_pad, ymax + y_pad])
    fig.patch.set_facecolor(bg_col_normalized)
    plt.show()

Visualize Lipid Decomposition

CODE
from rdkit import RDLogger

# Suppress RDKit warnings
RDLogger.DisableLog('rdApp.warning') # Disable RDKit warnings

# NOTE(review): the decomposition behind display_lipid_tree needs a reaction
# SMARTS dictionary and reaction codes. Neither is defined or passed in this
# cell, so as written the call cannot succeed (the traceback below shows one
# such failure, a NameError for rxn_smarts_dict). Supply them before running.
input_smiles = "CCCCC(=O)OCCCCCCCCOC(=O)CC(CC(=O)OCCCCCCCCOC(=O)CCCC)(C(=O)OCCCCCCCCOC(=O)CCCC)OC(=O)CCN(C)C"
display_lipid_tree(input_smiles)
---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
Cell In[15], line 7
      4 RDLogger.DisableLog('rdApp.warning') # Disable RDKit warnings
      6 input_smiles = "CCCCC(=O)OCCCCCCCCOC(=O)CC(CC(=O)OCCCCCCCCOC(=O)CCCC)(C(=O)OCCCCCCCCOC(=O)CCCC)OC(=O)CCN(C)C"
----> 7 display_lipid_tree(input_smiles)

Cell In[14], line 3, in display_lipid_tree(input_molecule_smiles)
      1 def display_lipid_tree(input_molecule_smiles): 
      2     input_molecule_smiles = Chem.CanonSmiles(input_molecule_smiles)        
----> 3     T, nodes, links = get_tree(input_molecule_smiles) 
      4     pos = graphviz_layout(T, prog="dot") # Nodes coordinates
      6     fig, ax = plt.subplots(figsize=(25, 50))

Cell In[13], line 32, in get_tree(input_molecule_smiles)
     28 def get_tree(input_molecule_smiles):
     29     input_molecule_smiles = Chem.CanonSmiles((input_molecule_smiles))
     30     molecule_info, _, _, _ = ccd_tools.decompose_dataset(
     31         [input_molecule_smiles], 
---> 32         rxn_smarts_dict, 
     33         rxn_codes, 
     34         reaction_types=["esterification"],
     35         print_debug=False
     36     )
     39     db_construction.decompose_lipid_set(
     40         [input_molecule_smiles],
     41         default_pfx = 'all_chems/'
     42     )
     43     root_node = construct_tree(molecule_info, input_molecule_smiles)

NameError: name 'rxn_smarts_dict' is not defined
CODE
import itertools


def evaluate_expression_mod(
    infix_str, frags_program, rxn_smarts_dict, rxn_codes
):
    """
    Evaluate an infix decomposition expression, tracking fragment origins.

    ``infix_str`` is a '_'-separated token string. The token kinds are:

    * ``F<i>`` - fragment ``frags_program[i]``; its atoms get ``frag_id="F<i>"``
    * any other operand - a literal SMILES string
    * ``R...`` - a reaction code, resolved through ``rxn_codes`` (name -> code)
      to a name, and through ``rxn_smarts_dict`` (name -> SMARTS) to a reaction
    * ``(`` and ``)`` - grouping; each must be its own token

    Without parentheses, chained reactions associate to the left:
    ``F0_R1_F1_R2_F2`` means ``(F0 R1 F1) R2 F2``.

    Returns
    -------
    (list of str, list of RDKit Mol)
        Unique canonical product SMILES and, in parallel, the corresponding
        molecules whose atoms carry a ``frag_id`` property.

    Raises
    ------
    ValueError
        For malformed expressions (empty operands, unbalanced parentheses,
        operands with no reaction between them) or unparsable SMILES.
    """
    rxn_names = {code: name for name, code in rxn_codes.items()}
    return _evaluate_tokens(infix_str.split('_'), frags_program, rxn_smarts_dict, rxn_names)


def _is_reaction_token(token):
    """True for a reaction-code token (operands never start with 'R')."""
    return isinstance(token, str) and token.startswith('R')


def _evaluate_operand(token, frags_program):
    """Evaluate one operand: a pre-evaluated group, a fragment id, or a SMILES."""
    if isinstance(token, tuple):  # already-evaluated parenthesised group
        return token
    if not token:
        raise ValueError("Empty operand in infix expression")
    if token.startswith('F') and token[1:].isdigit():
        idx = int(token[1:])
        smiles, frag_id = frags_program[idx], f"F{idx}"
    else:
        smiles, frag_id = token, None
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError(f"Could not parse SMILES {smiles!r} (token {token!r})")
    if frag_id is not None:
        for atom in mol.GetAtoms():
            atom.SetProp("frag_id", frag_id)
    return [Chem.CanonSmiles(smiles)], [mol]


def _evaluate_tokens(tokens, frags_program, rxn_smarts_dict, rxn_names):
    """Recursively evaluate a token list into (smiles_list, annotated_mol_list)."""
    if not tokens:
        raise ValueError("Empty sub-expression in infix expression")

    # Innermost parentheses first. The group's (smiles, mols) result is
    # substituted as a single operand, so its fragment annotations survive.
    if ')' in tokens:
        end = tokens.index(')')
        opens = [i for i in range(end) if tokens[i] == '(']
        if not opens:
            raise ValueError("Unbalanced ')' in infix expression")
        start = opens[-1]
        inner = _evaluate_tokens(tokens[start + 1:end], frags_program, rxn_smarts_dict, rxn_names)
        return _evaluate_tokens(tokens[:start] + [inner] + tokens[end + 1:],
                                frags_program, rxn_smarts_dict, rxn_names)
    if '(' in tokens:
        raise ValueError("Unbalanced '(' in infix expression")

    if len(tokens) == 1:
        return _evaluate_operand(tokens[0], frags_program)

    # Split at the right-most reaction so that chains associate to the left.
    split_at = next((i for i in reversed(range(len(tokens))) if _is_reaction_token(tokens[i])), None)
    if split_at is None:
        raise ValueError("Expected a reaction code between operands")

    _, left_mols = _evaluate_tokens(tokens[:split_at], frags_program, rxn_smarts_dict, rxn_names)
    _, right_mols = _evaluate_tokens(tokens[split_at + 1:], frags_program, rxn_smarts_dict, rxn_names)
    rxn = AllChem.ReactionFromSmarts(rxn_smarts_dict[rxn_names[tokens[split_at]]])

    products, annotated, seen = [], [], set()
    for l_mol in left_mols:
        for r_mol in right_mols:
            for prods in rxn.RunReactants((l_mol, r_mol)):
                for prod in prods:
                    smi = Chem.CanonSmiles(Chem.MolToSmiles(prod, True))
                    if smi in seen:  # de-dup across all matches and operand pairs
                        continue
                    seen.add(smi)
                    mol = _propagate_frag_ids(prod, (l_mol, r_mol))
                    # Best-effort sanitisation so the product can be a reactant
                    # in an enclosing reaction; errors are not raised here.
                    Chem.SanitizeMol(mol, catchErrors=True)
                    products.append(smi)
                    annotated.append(mol)
    return products, annotated



from rdkit import Chem
from rdkit.Chem import Draw
from rdkit.Chem.Draw import rdMolDraw2D
from IPython.display import SVG, display
import matplotlib.pyplot as plt
import numpy as np

def _average_rgb(colours):
    """Blend ≥2 RGB tuples by arithmetic mean."""
    return tuple(float(np.mean([c[i] for c in colours])) for i in range(3))

def highlight_fragments(mol, size=(600, 400), cmap="tab20"):
    """
    Render an SVG of `mol` where each fragment ID (mol atom prop 'frag_id')
    gets its own colour, following the RDKit Cookbook recipe for multiple
    highlight colours.

    Atoms (and bonds) that belong to several fragments are shown in the
    equal-weight average of those fragments' colours. Atoms without a
    'frag_id' property are not highlighted.

    Parameters
    ----------
    mol  : RDKit Mol whose atoms may carry a comma-separated `frag_id` property
    size : (w,h) tuple in px
    cmap : Matplotlib colormap name, resampled to one colour per fragment ID;
           a qualitative palette with >= #fragments colours avoids repeats
    """
    mol = Chem.Mol(mol)                        # defensive copy

    # 1) fragment IDs per atom, and atoms per fragment ID
    atom_to_frags, frag_to_atoms = {}, {}
    for atom in mol.GetAtoms():
        if not atom.HasProp("frag_id"):
            continue
        fids = sorted(set(atom.GetProp("frag_id").split(",")))
        atom_to_frags[atom.GetIdx()] = fids
        for fid in fids:
            frag_to_atoms.setdefault(fid, set()).add(atom.GetIdx())

    # 2) a distinct base colour per fragment ID. pyplot.get_cmap is used
    #    because matplotlib.cm.get_cmap was removed in Matplotlib 3.9; the
    #    lut is at least 1 so a molecule with no annotations still works.
    palette = plt.get_cmap(cmap, max(len(frag_to_atoms), 1))
    fid2colour = {fid: tuple(palette(i)[:3]) for i, fid in enumerate(sorted(frag_to_atoms))}

    # 3) blend all of an atom's fragment colours at once (equal weights);
    #    repeated pairwise averaging would over-weight the last fragment.
    atom_highlights = {
        idx: _average_rgb([fid2colour[f] for f in fids])
        for idx, fids in atom_to_frags.items()
    }

    # bonds are coloured by the fragments shared by both end atoms
    bond_highlights = {}
    for bond in mol.GetBonds():
        shared = set(atom_to_frags.get(bond.GetBeginAtomIdx(), ())) & set(
            atom_to_frags.get(bond.GetEndAtomIdx(), ())
        )
        if shared:
            bond_highlights[bond.GetIdx()] = _average_rgb([fid2colour[f] for f in sorted(shared)])

    # 4) RDKit Cookbook draw -> SVG (PrepareAndDrawMolecule does the
    #    drawing preparation on an internal copy)
    drawer = rdMolDraw2D.MolDraw2DSVG(*size)
    rdMolDraw2D.PrepareAndDrawMolecule(
        drawer, mol,
        highlightAtoms=list(atom_highlights),
        highlightAtomColors=atom_highlights,
        highlightBonds=list(bond_highlights),
        highlightBondColors=bond_highlights,
    )
    drawer.FinishDrawing()
    display(SVG(drawer.GetDrawingText()))

Reuse

Citation

BibTeX citation:
@online{balsubramani,
  author = {Balsubramani, Akshay},
  title = {Visualizing in Silico Chemical Decompositions},
  langid = {en}
}
For attribution, please cite this work as:
Balsubramani, Akshay. n.d. “Visualizing in Silico Chemical Decompositions.”