Screening too many chemicals to count

cheminformatics
LNP
Searching inconceivably large spaces in drug discovery
Author

Akshay Balsubramani

How to predict on chemical space without enumerating it

In modern AI-aided drug discovery, the chemical space addressed by the drug discovery funnel is vast. Virtual screening of large compound libraries has much higher throughput than traditional primary screening. However, the space that can be explored is still limited by the need to enumerate and score every compound. This does not scale to many billions of compounds: however fast each calculation is, its cost is paid once per compound.

To address this problem, we exploit the fact that we only want the highest-scoring structures from a library. The situation closely resembles online marketing or ad matching, in which recommendations must be made from relatively little data, and the methods that succeed at such tasks are well understood. One prominent example is also one of the oldest and most versatile: Thompson sampling.

We’ll explore how Thompson sampling can be used to score chemical structures without enumerating them, and how it can be applied to the problem of virtual screening in drug discovery. For this application in chemistry, we’ll use a convenient implementation from a recent paper (Klarich et al. 2024).

Defining a chemical space implicitly

The celebrated Pfizer-BioNTech and Moderna mRNA vaccines for COVID-19 use modular lipid nanoparticles (LNPs) to deliver mRNA. These formulations use different ionizable lipids, known as ALC-0315 (Pfizer-BioNTech) and SM-102 (Moderna), which are often viewed as modular assemblies of a head group, linkers, and tails.

Their “reaction skeletons” are similar (ester linkages connecting two tails with some branching to a tertiary amine head group) but have important differences (e.g. their ester groups face opposite directions with respect to the head group). Suppose we want to try varying these components from one of these basic structures.

Take SM-102 as an example; we might want to vary the head and tails, keeping the ester linkers intact. Let’s say we want to modify it precisely, using the large and diverse collection of hundreds of primary amine heads and hydrophobic tails outlined in previous posts (akinc?). We will perform the following modifications:

  • Try a number of different head groups corresponding to different primary amines.

  • Keep the alkyl chain spacers separating the acid groups from the amine nitrogen (currently each tail has 5 spacer carbons before the carbonyl group).

  • Try different tail groups, varying the length, branching, and saturation.

In silico “reaction” rule

To implement Thompson sampling, we need to write these modifications as a multi-component in silico “reaction” in an unambiguous way. This will have several components per the above description:

  • A head group, which is a primary amine.
  • The first tail (as an alcohol).
  • The second tail (as an alcohol).

The rule itself inserts the alkyl spacer carbons and the ester carbonyl on each arm of the nitrogen.

CODE
from rdkit import Chem
from rdkit.Chem.Draw import rdMolDraw2D
from rdkit.Chem import rdChemReactions
import io
from PIL import Image


def draw_reaction(smarts_in):
    rxn = rdChemReactions.ReactionFromSmarts(smarts_in)
    drawer = rdMolDraw2D.MolDraw2DCairo(700,300)
    drawer.DrawReaction(rxn)
    drawer.FinishDrawing()
    return drawer.GetDrawingText()


rxn_str = "[#6:1][NX3;H2:2].[#6:3][OX2H1:4].[#6:5][OX2H1:6]>>[#6:1][N:2](CCCCCC(=O)[OX2H1:4][#6:3])CCCCCC(=O)[OX2H1:6][#6:5]"
bio = io.BytesIO(draw_reaction(rxn_str))
Image.open(bio)

Head groups

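First, we gather a set of amines that are well studied as head groups in LNPs, from a seminal line of work on short-RNA delivery (akinc?). Not every entry carries a free primary amine (CCNCCCNCC, for example, has only secondary amines), and only entries that match the [NX3;H2] pattern of the reaction rule can take part in it.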

CODE
import mols2grid

frags_whitehead_amines = ['NCCCCCO', 'CCNCCCNCC', 'CNCCOCCNC', 'CN(C)CCCN', 'CCN(CC)CCCN', 'NCCN(CCO)CCO', 'NCCN', 'NCCNCCO', 'NCCN(CCN)CCN', 'CN(CCN)CCN', 'CCN(CC)CCNCCN', 'N/C=C1\\\\CCCN1', 'NCCNCCNCCNCCNCCN', 'NCCCOCCOCCOCCCN', 'NCCOCCO', 'CC(C)C[C@H](N)CO', 'NCCNCCN1CCN(CCN)CC1', 'NCCN1CCNCC1', 'CC(N)CNCC(C)N', 'CCN(CCCNC)CCCNC', 'CN(C)CCCNCCCN', 'CNCCN(CCNC)CCNC', 'CC(CN)N', 'CN(CCCN)CCCN', 'C(CN)CN1CCN(CCCN)CC1', 'C1=CC(=CC(=C1)CN)CN', 'C1=CC2=C(C=C1N)OCCO2', 'COCCCN', 'C1COC(O1)CCN', 'C1OC(C)(C)OC1CN', 'COCCN', 'CCOCCN', 'COC(OC)CN', 'C1CC(OC1)CN', 'C(CO)CN', 'C[C@@H](O)CN', 'C[C@H](O)CCN', 'CC(CO)(CO)N', 'C(CCO)CN', 'C(COCCO)N', 'C(O)C(C)(C)CN', 'CCC(CO)(CO)N', 'C(CCCO)CCN', 'C1C[C@H](O)CC[C@H]1N', 'N(C)CCNC', 'N(CC)CCNCC', 'N(C(C)C)CCNC(C)C', 'N(CC)CCCNCC', 'N(C)CCOCCOCCNC', 'N1CCCNCC1', 'CCCN(CC)CN', 'C1CCN(C1)CCN', 'CN(C)CCN', 'CC(CN(C)C)N', 'C1CCN(CC1)CCN', 'C1COCCN1CCN', 'C1COCCN1CCCN', 'C1=CN(C=N1)CCCN', 'CNCCN', 'CNCCCN', 'CCCNCCN', 'C(CNCCNCCN)N', 'C(CN)CN', 'C(CCN)CN', 'CCCCNCCN', 'N(CCNCCO)CCO', 'CCCCCCCCCCCCCN', 'C(CN)CNCCN', 'CCCNCCCN', 'C(CCNCCCN)CNCCCN', 'CC(C)N(C(C)C)CCN', 'C(CNCCNCCNCCN)N', 'C(CN)CNCCN', 'C(CNCCN)N', 'CCCCCCCCCCCCN(CCN)CCN', 'C1CC[C@@H]([C@H](C1)N)N', 'C1CC[C@@H]([C@@H](C1)N)N', 'C1CCN(CC1)N', 'C1C=CC[C@H]([C@@H]1N)N', 'C(CNCCN)CNCCN', 'CCNCCN', 'C=CCCN', 'C1CNC[C@H]1N', 'C1CC(CC(C1)N)N', 'C1CNCCC1N', 'N1CCNCC1', 'N1CCNC[C@H]1C', 'CC1CCC(CC1)CCN', 'CCC1CCC(CC1)CCN', 'N1CCNCC1c1ccccc1', 'N1CCCNCCCNCCC1', 'N1CCNCCNCCNCC1', 'N1CCCNCCNCCNCC1', 'N1CCCNCCNCCCNCC1', 'N1(C)CCCNCCN(C)CCCNCC1', 'N1CCCNCCCNCCCNCC1', 'N1CCNCCNCCOCC1', 'C1C[C@H](CC[C@@H]1N)N', 'C1CC(CC(C1)CN)CN', 'C1=CC(=CC(=C1)N)N', 'C1=CC(=CC=C1N)N', 'CC1=CC=CC(=C1N)N', 'CC1=C(C=C(C=C1)N)N', 'CC1=C(C=CC=C1N)N', 'CC1=CC(=C(C=C1)N)N', 'C1=CC2=C(C(=C1)N)C(=CC=C2)N', 'N1C(C)NC(C)NC1C', 'C1(=NC(=NC(=N1)N)N)N', 'C1CC(CCC1CN)CN', 'C(CCCCN)CCCN', 'C(COCCOCCN)N', 'C(COCCOCCOCCN)N', 'C(CCOCCCN)COCCCN', 'COCCCN', 'N(C)CCOCCOCCNC', 'C(CO)N', 'N(CCNCCO)CCO', 'CCCN', 'CCCCN', 
'CCCCCN', 'CCCCCCN', 'CNCCCCCCNC', 'CCCNCCCN', 'C(CN)CNCCCN', 'C(CCCNCCCCCCN)CCN', 'C(CN)CNCCCNCCCN', 'CC(C)[NH]', 'CC(C)(C)[NH]', 'CCC(CC)N', 'CCC(C)(C)N', 'CC(C)CN', 'CCC(C)N', 'CC(C)C(C)N', 'CCCC(C)N', 'CC(C)CCN', 'CC(C)(C)CCN', 'CC(C)(C)CC(C)(C)N', 'CCC(C)CN', 'CCCCC(CC)CN', 'CC(CO)N', 'CC(C)C[C@@H](CO)N', 'CC(C)(C)C(CO)N', 'C(CO)(CO)(CO)N', 'CC(C)OCCCN', 'N(C(C)(C)C)CCNC(C)(C)C', 'C(CCN)CCN', 'C(CCCN)CCN', 'C(CCCCN)CCN', 'C(CCCCCN)CCN', 'C(CCCCCCCN)CCN', 'C(CCCCCCCCCN)CCN', 'CC(CC(C)(C)CCN)CN', 'C(CCCN(C)Cc1ccccc1)N', 'CN1CCN(CC1)CCCCN', 'CCN1CCC(CC1)N', 'CN1CCN(CC1)CCN', 'CN1CCOCC1CN', 'CCCN1CCC(CC1)N', 'C(CN(C)Cc1ccccc1)N', 'CN(C)CCOCCN', 'C1CCN(CC1)CCOCCN', 'C1CCN(C1)CCOCCN', 'CN(C1=CC=CC=C1)C(=O)CN', 'CN1C=CN=C1CN', 'CN1CCCC1CN', 'CN1CCCC1C(=O)N', 'CN1CCC(CC1)CN', 'CN1CCC(CC1)N', 'CC(C)N1CCN(CC1)CCCN', 'C1CN(CCN1)C(=O)CCN', 'C(C)(CN(C)C)(C)CN', 'CC(C)NCCN', 'C1[CH]C1N', 'C1CC(C1)N', 'C1CCC(C1)N', 'CCN1CCC[C@H]1CN', 'CCN1CCC[C@@H]1CN', 'C1CCC(CC1)N', 'C1CCC(CC1)CN', 'CC(C)(CN(C)C)CN', 'C1CC2CC1CC2N', 'C1CCCC(CCC1)N', 'C1C2CC3CC1CC(C2)(C3)N', 'C1C2CC3CC1CC(C2)(C3)CN', 'C1=CC=C(C=C1)CCCN', 'C1=CC=C(C=C1)CCCCN', 'CC(C)(C)C1=CC=C(C=C1)N', 'CC(C)(C)C1=CC=CC=C1N', 'CCC1=C(C(=CC=C1)CC)N', 'COC1=CC=CC=C1N', 'COC1=CC=CC=C1CN', 'CC1=CC=C(C=C1)C(C)N', 'CC(C)C1=CC=CC=C1N', 'CC(C)C1=CC=C(C=C1)N', 'CC(C)OC1=CC=CC(=C1)N', 'COC1=CC=C(C=C1)CCN', 'CCOC1=C(C=CC=C1)N', 'CCCCCCCCOC1=CC=C(C=C1)N', 'CCOC1=CC=CC=C1CN', 'CC(C)(C)C1=CC(=C(C=C1)C(C)(C)C)N', 'COC1=CC(=C(C=C1)OC)N', 'CC(C)(C)c1ccc(OC)c(c1)N', 'C(Cc1cc(OC)c(OC)cc1)N', 'C(c1cc(C)c(c(OC)c1)OC)N', 'C1=CC=C(C=C1)NCCN', 'C1=CC=C(C=C1)CNCCN', 'CNC1=CC=CC=C1N', 'C1=CC(=CC(=C1)N)CN', 'Cc1c(N)c(C)c(C)c(c1C)N', 'C1=CC=C(C(=C1)CN)N', 'CC1=CC(=C(C=C1C)N)N', 'c1(F)cc(ccc1)N', 'c1(F)ccc(cc1)N', 'C1=CC(=C(C=C1)F)CN', 'C(c1cc(F)ccc1)N', 'C1=CC(=C(C=C1)F)CCN', 'c1(cc(F)ccc1F)N', 'C1=C(C=C(C(=C1)F)F)N', 'C1=C(C=C(C=C1F)F)N', 'C(c1cc(F)ccc1F)N', 'C1=C(C=C(C(=C1F)F)F)N', 'c1(F)cc(F)c(c(F)c1)N', 'C(F)(F)(F)c1cc(ccc1)N', 
'C(F)(F)(F)c1c(ccc(c1)F)N', 'CC1=CC=C(C=C1F)N', 'COC1=CC=C(C=C1F)N', 'c1(OC(F)F)ccc(cc1)N', 'C(F)(F)(F)Oc1ccc(cc1)N', 'CC1CC(CCC1N)CC2CCC(C(C2)C)N', 'N1C2NCCNC2NCC1', 'C1CC(C2=C1C=CC=C2)N', 'C1=CC2=C(C=C1)CC(C2)N', 'C1CC2=C(C1)C(=CC=C2)N', 'C1OC2=CC=C(C=C2O1)N', 'C1CCC2=C(C1)C=CC(=C2)N', 'C1C2=C(C=CC(=C2)N)C3=C1C=C(C=C3)N', 'c1(cc2cc3ccccc3cc2cc1)N', 'C1=CC=C2C(=C1)C3=C(C=CC=C3)C(=C2N)N', 'C1=CC=C(C=C1)C2=CC(=CC=C2)N', 'C1=CC=C(C=C1)C2=CC=C(C=C2)N', 'Cc1c(N)ccc(c1)-c1ccc(c(C)c1)N', 'C1=C(C=C(C(=C1)N)N)C2=CC=C(C(=C2)N)N', 'Cc1c(N)c(C)cc(c1)-c1cc(C)c(c(c1)C)N', 'C1=CC=C(C=C1)CC2=CC=CC=C2N', 'C1=CC(=CC=C1CC2=CC=C(C=C2)N)N', 'CNC1=CC=C(C=C1)CC2=CC=C(C=C2)N', 'C1=CC=C(C=C1)C(C2=CC=CC=C2)N', 'C(CC(c1ccccc1)c1ccccc1)N', 'C1=CC=C(C=C1)OC2=CC=CC=C2N', 'C1=CC=C(C=C1)OC2=CC=C(C=C2)N', 'c1(Oc2ccc(F)cc2)ccc(cc1)N', 'C1=CC(=CC=C1N)OC2=CC=C(C=C2)N', 'C1=CC=C(C=C1)COC2=CC(=CC=C2)N', 'C1=CC=C(C=C1)NC2=CC=C(C=C2)N', 'C1=CC=C(C=C1)NC2=CC=CC=C2N', 'C1=CC=C(C=C1)[C@H]([C@H](C2=CC=CC=C2)N)N', 'C1=CC=C(C=C1)C(C2=CC=CC=C2)(C3=CC=CC=C3)N', 'Nc1c(cc(cc1-c1ccccc1)-c1ccccc1)-c1ccccc1', 'C1=CC=C2C(=C1)C=CC(=C2C3=C(C=CC4=CC=CC=C43)N)N', 'N(c1ccccc1)c1ccc(cc1)-c1ccc(cc1)Nc1ccccc1', 'C1(c2c(cccc2)-c2ccccc21)(c1ccc(N)cc1)c1ccc(cc1)N', 'CC(NCCN(CCNC(C)C)CCNC(C)C)C']

frags_heads = []
names_heads = []
frags_heads.extend(frags_whitehead_amines)
names_heads.extend(["Whitehead_2014"] * len(frags_whitehead_amines))

print("{} heads from Whitehead et al. 2014 library.".format(len(frags_whitehead_amines)))
mols2grid.display([Chem.MolFromSmiles(x) for x in frags_whitehead_amines],mol_col="mol", n_cols=7, n_rows=4)
262 heads from Whitehead et al. 2014 library.

To this, we can add a large set of primary amines from the implementation we are using.

CODE
import pandas as pd
amines_klarich = pd.read_csv("https://raw.githubusercontent.com/PatWalters/TS/refs/heads/main/data/primary_amines_ok.smi", header=None, sep=" ", names=["SMILES", "NAME"])
print("{} amines from Klarich et al. 2024 library.".format(len(amines_klarich)))

frags_heads.extend(amines_klarich['SMILES'].tolist())
names_heads.extend(['_'.join(str(x).split()) for x in amines_klarich['NAME'].tolist()])
13842 amines from Klarich et al. 2024 library.
CODE
heads_df = pd.DataFrame([frags_heads, names_heads], index=["SMILES", "NAME"]).T
heads_df.to_csv("head_frags.smi", index=False, header=False, sep=" ")

Tail groups

Next, we gather a set of tail groups we might want to try in this variant screening. Drug designers in this field vary the tail groups in a number of ways, and we include a representative sample of these variations:

  • Varying the length of the tail (8, 10, 12, 14 carbons)

  • Varying the branching of the tail (unbranched and 2-branched tails explored)

  • Varying the saturation of the tail (each branch of each tail can contain a double bond somewhere along its length)

We can systematically generate a set of tail groups that vary in these ways. For virtual screening purposes, we’ll do this combinatorially, generating all possible combinations of these variations.

Our implementation is easy to extend here: we build each tail atom by atom with core RDKit functionality (an editable RWMol) and convert the result to SMILES, which is more robust than pure SMILES string manipulation.

CODE
from itertools import combinations, product
from rdkit import Chem
from rdkit import RDLogger

# switch off RDKit warnings/info ("rdApp.*" covers all of RDKit's log channels)
RDLogger.DisableLog("rdApp.*")


# design‑space parameters -----------------------------------------------------
lengths        = [8, 10, 12, 14]   # carbons / tail
unsat_choices  = (0, 1, 2)                 # number of C=C per tail
spacing_allowed = {2, 3}                   # gap(s) between successive C=C bonds
# -----------------------------------------------------------------------------


def spacing_ok(db_positions):
    """Return True iff every spacing between successive double bonds is allowed."""
    if len(db_positions) < 2:
        return True
    db_positions = sorted(db_positions)
    return all(b - a in spacing_allowed for a, b in zip(db_positions, db_positions[1:]))


def build_linear_tail(n_carbon, db_pos):
    """Linear n-carbon primary alcohol with the OH on the last carbon.

    Bond i in db_pos (joining carbons i and i+1) is made a C=C double bond.
    """
    mol = Chem.RWMol()

    c_atoms = [mol.AddAtom(Chem.Atom("C")) for _ in range(n_carbon)]
    o_atom  = mol.AddAtom(Chem.Atom("O"))

    for i in range(n_carbon - 1):
        bond_type = Chem.BondType.DOUBLE if i in db_pos else Chem.BondType.SINGLE
        mol.AddBond(c_atoms[i], c_atoms[i + 1], bond_type)

    mol.AddBond(c_atoms[-1], o_atom, Chem.BondType.SINGLE)
    return Chem.MolToSmiles(mol, canonical=True)

def build_branched_tail(n1, n2, db1, db2):
    """Secondary alcohol core (CH-OH) bearing two hydrophobic chains.

    Position p in db1/db2 is the p-th bond of a chain, counted out from the core.
    """
    mol  = Chem.RWMol()
    core = mol.AddAtom(Chem.Atom("C"))
    o_at = mol.AddAtom(Chem.Atom("O"))

    mol.AddBond(core, o_at, Chem.BondType.SINGLE)

    def add_chain(start_atom, length, db_pos_set):
        prev = start_atom
        for idx in range(length):
            c = mol.AddAtom(Chem.Atom("C"))
            bond_type = (
                Chem.BondType.DOUBLE if (idx + 1) in db_pos_set else Chem.BondType.SINGLE
            )
            mol.AddBond(prev, c, bond_type)
            prev = c

    add_chain(core, n1, set(db1))
    add_chain(core, n2, set(db2))

    return Chem.MolToSmiles(mol, canonical=True)


frags = set()


# unbranched tails
for L, d in product(lengths, unsat_choices):
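    # candidate C=C bonds are 1..L-2 (bond i joins carbons i and i+1), so bond 0 at
    # the chain terminus stays single; bond L-2 touches the OH-bearing carbon
    interior = range(1, L - 1)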
    if d > len(interior):
        continue

    for db in combinations(interior, d):
        if not spacing_ok(db):
            continue
        try:
            smi = build_linear_tail(L, db)
            Chem.SanitizeMol(Chem.MolFromSmiles(smi))
            frags.add(smi)
        except Exception:
            pass                           # drops impossible ones silently

# ---------- branched ------------
for L1, L2, d1, d2 in product(lengths, lengths, unsat_choices, unsat_choices):
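    # candidate positions 1..L-1 on each chain; position p is the p-th bond out from
    # the core (p = 1 is the core-C1 bond), and the last bond of each chain stays single
    interior1 = range(1, L1)
    interior2 = range(1, L2)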

    if d1 > len(interior1) or d2 > len(interior2):
        continue

    for db1 in combinations(interior1, d1):
        if not spacing_ok(db1):
            continue
        for db2 in combinations(interior2, d2):
            if not spacing_ok(db2):
                continue
            try:
                smi = build_branched_tail(L1, L2, db1, db2)
                Chem.SanitizeMol(Chem.MolFromSmiles(smi))
                frags.add(smi)
            except Exception:
                pass



frags_tails = [Chem.CanonSmiles(f) for f in frags if Chem.MolFromSmiles(f) is not None]
names_tails = ["Tail_fragment"] * len(frags_tails)
tails_df = pd.DataFrame([frags_tails, names_tails], index=["SMILES", "NAME"]).T
tails_df.to_csv("tail_frags.smi", index=False, header=False, sep=" ")
CODE
print(f"{len(frags_tails)} unique fragments written.")
5474 unique fragments written.

Note that this type of combinatorial expansion is much easier to do computationally than in an assay, playing to the strengths of virtual screening.
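
To get a feel for how the spacing rule shapes the expansion, the placements for a single linear chain can be counted without RDKit. This is a minimal sketch using only the standard library; `count_placements` is a hypothetical helper that mirrors the logic of `spacing_ok` and the `combinations` loop above, and it counts placements rather than canonical SMILES.

```python
from itertools import combinations


def count_placements(candidate_bonds, max_double_bonds, allowed_gaps):
    """Count ways to place 0..max_double_bonds C=C bonds on the candidate
    bond positions so that successive double bonds are an allowed gap apart."""
    total = 0
    for d in range(max_double_bonds + 1):
        for positions in combinations(candidate_bonds, d):
            # combinations() yields positions in increasing order
            gaps = [b - a for a, b in zip(positions, positions[1:])]
            if all(g in allowed_gaps for g in gaps):
                total += 1
    return total


# An 8-carbon chain has 7 C-C bonds (0..6); as in the linear loop above,
# bond 0 is held single, leaving bonds 1..6 as candidates.
n_linear_8 = count_placements(range(1, 7), 2, {2, 3})
print(n_linear_8)  # 1 (saturated) + 6 (one C=C) + 7 (two C=C) = 14
```

Branched tails multiply such counts across their two chains, which is why the fragment set grows so quickly with chain length.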

The implementation relies on the helper functions build_linear_tail and build_branched_tail, which generate tails of each kind. When all the combinations are expanded, we have a set of several thousand tail groups (5474 here) that can be used in the screening.

CODE
mols2grid.display([Chem.MolFromSmiles(x) for x in frags_tails],mol_col="mol", n_cols=7, n_rows=4)

This is too many structures to enumerate

These modifications are sensible for the field and apply very narrowly to only the single base structure SM-102 (we could do the same for any of the hundreds of other purchasable trade lipids of interest, for instance).

But we have already made our way into a combinatorial explosion.

CODE
# Write this in scientific notation
total_size = len(frags_heads) * len(frags_tails) * len(frags_tails)
print(f"Total size of the design space: {total_size:.2e} structures.")
Total size of the design space: 4.23e+11 structures.
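
To make the scale concrete, here is a back-of-the-envelope calculation using the fragment counts printed above. The scoring rate is a purely hypothetical assumption, chosen only to illustrate the arithmetic. The last line also notes that the two arms of the reaction rule are identical, so swapping the two tails gives the same product and the number of distinct tail pairings is roughly half the ordered count.

```python
n_heads = 262 + 13842          # head-group counts printed above
n_tails = 5474                 # tail fragments printed above

# ordered (head, tail 1, tail 2) combinations, as in total_size above
total = n_heads * n_tails ** 2

rate_per_second = 10_000       # ASSUMPTION: hypothetical scoring throughput
days = total / rate_per_second / 86_400

# unordered tail pairs, since both arms of the rule are the same
distinct_pairings = n_heads * n_tails * (n_tails + 1) // 2

print(f"{total:,} ordered combinations")
print(f"about {days:.0f} days at {rate_per_second:,} products per second")
print(f"{distinct_pairings:,} combinations up to swapping the tails")
```

Even under this generous assumed throughput, exhaustive scoring of a single base structure would take more than a year of continuous computation.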

Scoring a subset of structures

In a drug development pipeline, the overall goal is to pick structures which do well in experiments. Since experiments are often resource-intensive and costly to perform, AI in drug development seeks to eliminate the need for testing unviable structures, and more quickly focus on promising regions of chemical space for further development. Accordingly, there are many models in drug discovery that predict everything from physicochemical properties to in vivo responses and the results of high-throughput assays.

A model to score these structures

We’ll demonstrate this for LNP development, using our recurring example from previous posts. We’ll follow a popular recent paper (Xu et al. 2024) that attempts to predict the transfection efficiency of an LNP from its ionizable lipid component.

To do this fully, we’ll need to train such a model on a set of known LNPs with measured transfection efficiencies. A popular source is the AGILE study (Xu et al. 2024), whose dataset covers 1200 LNPs and their transfection efficiencies. The study also provides a model pretrained on this dataset; with some effort, this can be downloaded and used to score the giant virtual space we have generated.

For now, we’ll train a new model on the AGILE dataset for transfection efficiency on 1200 structures.

CODE
import pandas as pd
from rdkit.Chem import rdMolDescriptors
#from lightgbm import LGBMClassifier
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score
import numpy as np
import joblib

agile_dataset = pd.read_csv("https://raw.githubusercontent.com/bowang-lab/AGILE/refs/heads/main/AGILE_smiles_with_value_group.csv", index_col=0)
CODE
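# featurize each lipid as a 2048-bit Morgan fingerprint (radius 3)
agile_fps = []
for s in agile_dataset['combined_mol_SMILES']:
    agile_fps.append(np.array(rdMolDescriptors.GetMorganFingerprintAsBitVect(Chem.MolFromSmiles(s), 3, nBits=2048)))
agile_fps = np.vstack(agile_fps)
# hold out a random test split (25% by default), then fit a random forest on the rest
X_train, X_test, y_train, y_test = train_test_split(agile_fps, agile_dataset['expt_Hela'])
rgr = RandomForestRegressor()
rgr.fit(X_train, y_train)
level = rgr.predict(X_test)   # held-out predictions (not used further here)
joblib.dump(rgr, 'modl.pkl')  # saved so the TS evaluator can load it later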
['modl.pkl']
CODE
rgr_pickle = joblib.load('modl.pkl')
CODE
# import torch, yaml, argparse
# from rdkit import Chem
# from torch_geometric.data import Batch
# from models.graph_model import GraphModel          # model wrapper in AGILE
# from utils.mol_graph import mol_to_graph_data_obj  # RDKit → PyG graph

# def load_model(ckpt_path: str):
#     ckpt = torch.load(ckpt_path, map_location="cpu")
#     hparams = ckpt["hyper_parameters"]["model_params"]
#     net = GraphModel(**hparams)
#     net.load_state_dict(ckpt["state_dict"], strict=False)
#     net.eval()
#     return net

# https://github.com/bowang-lab/AGILE/raw/refs/heads/main/ckpt/pretrained_agile_60k/checkpoints/model.pth

Thompson sampling: act according to your beliefs

Thompson sampling (TS) is a decision-making method used to efficiently learn which option is best when facing uncertain outcomes. It was originally developed in the context of clinical trials, and is now widely used in machine learning and experimental design. It balances two paradigmatic goals in tension:

  • Exploration: Testing uncertain or less-studied compounds to improve overall knowledge.

  • Exploitation: Focusing on compounds already showing promise to efficiently optimize outcomes.

Thompson sampling works with a distribution over experimental outcomes representing the scientist’s current beliefs. At any point in time, these beliefs can be sampled to suggest a best course of action (e.g. compounds to test next). Thompson sampling consists of taking this action and updating the belief distribution according to Bayes’ rule.
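
The loop described here (sample from the beliefs, act on the sample, update by Bayes’ rule) can be sketched in its simplest textbook form, with a handful of options that each succeed or fail with an unknown probability. This is an illustrative toy using only the standard library, not the chemistry implementation used later; the function name and the success rates are made up for the example.

```python
import random


def thompson_bernoulli(true_rates, n_rounds, seed=0):
    """Beta-Bernoulli Thompson sampling over a few options."""
    rng = random.Random(seed)
    k = len(true_rates)
    alpha = [1.0] * k   # 1 + observed successes (uniform prior)
    beta = [1.0] * k    # 1 + observed failures
    pulls = [0] * k
    for _ in range(n_rounds):
        # sample one plausible success rate per option from the current beliefs
        draws = [rng.betavariate(alpha[i], beta[i]) for i in range(k)]
        # act as if the sampled beliefs were true
        choice = max(range(k), key=draws.__getitem__)
        # run the (simulated) experiment
        success = rng.random() < true_rates[choice]
        # Bayes' rule for a Beta prior with a Bernoulli outcome
        if success:
            alpha[choice] += 1
        else:
            beta[choice] += 1
        pulls[choice] += 1
    return pulls, alpha, beta


pulls, alpha, beta = thompson_bernoulli([0.2, 0.5, 0.8], n_rounds=2000)
print("times each option was tested:", pulls)
```

Most of the budget ends up on the option with the highest true rate, while the weaker options are tested only often enough to rule them out.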

It is well-understood, both empirically and theoretically, that such a strategy can work even if the initial belief distribution is dramatically misspecified. As more data is collected, the belief distribution converges to the true distribution of outcomes, and the sampling strategy becomes more focused on the options that are truly best.

A typical experimental context involves screening compounds to identify which yield the highest activity. Initially the experimentalist has little information, and belief distributions have high uncertainty, so Thompson sampling often samples high values even for options with lower estimated means, encouraging exploration. If a compound performs well, the experimentalist’s belief in its future performance shifts upward, increasing the likelihood that future samples will lead to selecting it again. Conversely, poor performance lowers future expectations. Over time, this naturally guides experiments toward the most promising compounds, while still periodically exploring less-understood ones to confirm or revise the current understanding. Eventually, the uncertainty ceases to dominate, and sampling tends to favor options with better estimated outcomes.

This approach naturally balances exploration and exploitation, much as method development is often approached anyway: initially trying diverse approaches before refining the most promising ones.

The TS algorithm is one of the workhorses of modern machine learning. What matters is how well it works in practice, so here we try it on the vast library of ionizable lipids constructed above, performing the virtual scoring in a realistic lipid nanoparticle engineering context. This is paradigmatic of many small-molecule drug discovery subfields, but combines very broad application with a truly vast design space.
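
What lets this avoid enumeration is that beliefs can be attached to reagents rather than to products. The sketch below is a toy illustration of that idea using only the standard library. The function name, the Gaussian belief update, and the additive toy score are all assumptions made for this example and are not taken from the Klarich et al. code; a real run would score each product with the trained model instead.

```python
import random


def reagent_thompson_search(sizes, score_fn, n_warmup, n_iter, seed=0):
    """Toy reagent-level Thompson sampling; returns {reagent index tuple: score}."""
    rng = random.Random(seed)
    count = [[0] * n for n in sizes]    # products scored per reagent
    mean = [[0.0] * n for n in sizes]   # running mean score per reagent
    scored = {}                         # only the products actually built

    def evaluate(combo):
        if combo not in scored:
            scored[combo] = score_fn(combo)
        s = scored[combo]
        for c, r in enumerate(combo):   # credit every reagent in the product
            count[c][r] += 1
            mean[c][r] += (s - mean[c][r]) / count[c][r]

    # warm-up: every reagent is tried n_warmup times with random partners
    for c, n in enumerate(sizes):
        for r in range(n):
            for _ in range(n_warmup):
                combo = [rng.randrange(m) for m in sizes]
                combo[c] = r
                evaluate(tuple(combo))

    # overall spread of warm-up scores sets the width of the beliefs
    vals = list(scored.values())
    mu = sum(vals) / len(vals)
    sigma = (sum((v - mu) ** 2 for v in vals) / len(vals)) ** 0.5 or 1.0

    for _ in range(n_iter):
        combo = []
        for c, n in enumerate(sizes):
            # sample a plausible mean score for each reagent, keep the best draw
            draws = [rng.gauss(mean[c][r], sigma / count[c][r] ** 0.5) for r in range(n)]
            combo.append(max(range(n), key=draws.__getitem__))
        evaluate(tuple(combo))
    return scored


# Toy library: 30 heads x 40 tails x 40 tails with hidden additive contributions.
toy_rng = random.Random(1)
sizes = [30, 40, 40]
hidden = [[toy_rng.gauss(0, 1) for _ in range(n)] for n in sizes]


def toy_score(combo):
    return sum(hidden[c][r] for c, r in enumerate(combo))


scored = reagent_thompson_search(sizes, toy_score, n_warmup=3, n_iter=500)
n_products = sizes[0] * sizes[1] * sizes[2]
print(f"scored {len(scored)} of {n_products} possible products")
print("best product found:", max(scored, key=scored.get))
```

The memory and compute cost scale with the number of reagents plus the number of iterations, not with the number of possible products, which is what makes a space of this size searchable at all.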

CODE
rxn_str = "[#6:1][NX3;H2:2].[#6:3][OX2H1:4].[#6:5][OX2H1:6]>>[#6:1][N:2](CCCCCC(=O)[OX2H1:4][#6:3])CCCCCC(=O)[OX2H1:6][#6:5]"

To finish setting up the TS algorithm, we need to precisely define the scoring function that will be used to evaluate the structures. The implementation expects this to be written as an evaluator class in its evaluators.py file.

CODE
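# Goes in the TS implementation's evaluators.py, where Evaluator is defined.
# These imports are needed at the top of that file if not already present:
import joblib
import numpy as np
from rdkit.Chem import rdMolDescriptors


class NewMLModelEvaluator(Evaluator):
    """Score a product molecule with the pickled regressor trained above."""

    def __init__(self, input_dict):
        self.mdl = joblib.load(input_dict["model_filename"])
        self.num_evaluations = 0

    @property
    def counter(self):
        return self.num_evaluations

    def evaluate(self, mol):
        self.num_evaluations += 1
        # same featurization as in training: radius-3 Morgan bits, length 2048
        fp = np.array(rdMolDescriptors.GetMorganFingerprintAsBitVect(mol, 3, nBits=2048))
        # predict() returns a length-1 array; hand back a plain float
        return float(self.mdl.predict(fp.reshape(1, -1))[0])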

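Finally, all this gets put together in the config for the TS algorithm, which specifies the reagent files, the reaction schema, and the scoring function to use. Note that the reagent paths below point into a data/ directory, whereas head_frags.smi and tail_frags.smi were written to the working directory above, so the files need to be moved or the paths adjusted.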

CODE
config_dict = {
    "reagent_file_list": [
        "data/head_frags.smi",
        "data/tail_frags.smi",
        "data/tail_frags.smi"
    ],
    "reaction_smarts": rxn_str, 
    "num_warmup_trials": 10, 
    "num_ts_iterations": 10000,
    "evaluator_class_name": "NewMLModelEvaluator",
    "evaluator_arg": {
        "model_filename": "modl.pkl"
    },
    "ts_mode": "maximize",
    "log_filename": "ts_logs.txt",
    "results_filename": "classification_model_out.csv"
}
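
If the TS implementation’s command-line driver reads its settings from a JSON file (an assumption here; check the repository’s README for the exact invocation), the dictionary can be written out with the standard library. The filename below is hypothetical, and the settings are copied from the config above so that the snippet runs on its own.

```python
import json
import os
import tempfile

rxn_str = (
    "[#6:1][NX3;H2:2].[#6:3][OX2H1:4].[#6:5][OX2H1:6]>>"
    "[#6:1][N:2](CCCCCC(=O)[OX2H1:4][#6:3])CCCCCC(=O)[OX2H1:6][#6:5]"
)

config_dict = {
    "reagent_file_list": [
        "data/head_frags.smi",
        "data/tail_frags.smi",
        "data/tail_frags.smi",
    ],
    "reaction_smarts": rxn_str,
    "num_warmup_trials": 10,
    "num_ts_iterations": 10000,
    "evaluator_class_name": "NewMLModelEvaluator",
    "evaluator_arg": {"model_filename": "modl.pkl"},
    "ts_mode": "maximize",
    "log_filename": "ts_logs.txt",
    "results_filename": "classification_model_out.csv",
}

# Written to a temporary directory here; "lnp_ts_config.json" is a made-up name.
config_path = os.path.join(tempfile.mkdtemp(), "lnp_ts_config.json")
with open(config_path, "w") as fh:
    json.dump(config_dict, fh, indent=4)

# Round-trip check: the file holds exactly the settings defined above.
with open(config_path) as fh:
    reloaded = json.load(fh)
print(reloaded == config_dict)  # True
```

With these settings, the search is budgeted at 10000 sampling iterations on top of the warm-up evaluations, a tiny fraction of the roughly 4.23e+11 combinations in the design space.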

References

Klarich, Kathryn, Brian Goldman, Trevor Kramer, Patrick Riley, and W Patrick Walters. 2024. “Thompson Sampling: An Efficient Method for Searching Ultralarge Synthesis on Demand Databases.” Journal of Chemical Information and Modeling 64 (4): 1158–71.
Xu, Yue, Shihao Ma, Haotian Cui, Jingan Chen, Shufen Xu, Fanglin Gong, Alex Golubovic, et al. 2024. “AGILE Platform: A Deep Learning Powered Approach to Accelerate LNP Development for mRNA Delivery.” Nature Communications 15 (1): 6305.