Interactive browser, part 2

dataviz
cheminformatics
An interactively clustered heatmap
Author

Akshay Balsubramani

Modified

June 23, 2022

Interactive clustering of data

Previous posts in the series: [Part 1]

Interactive browsers are seeing wider use in biochemical fields that routinely deal with extremely high-dimensional, structured data. I recently wrote about building the foundations of such a browser for visualizing chemical space, in which the user can define subsets of chemicals interactively.

Browsers like this need to perform general learning tasks without making unnecessary assumptions. The potential of this workflow can be realized through interactive algorithms driven by user selections. Interactive computation is the theme of this post series. In this post, I explore it by demonstrating how to add a heatmap and augment it to group subsets of chemicals based on their fingerprints.

Loading the previous browser

I’m first loading the previously created interactive browser, as described in the previous post in the series. That post covers the basics of how to build an interactive interface with subselection capabilities. The code to run the browser is imported from the relevant file produced in that previous post.

CODE
import requests
url = 'https://raw.githubusercontent.com/b-akshay/blog-tools/main/interactive-browser/part-1/scatter_chemviz.py'
r = requests.get(url)
r.raise_for_status()  # stop here if the download failed, rather than writing an error page to disk

# make sure your filename is the same as how you want to import 
with open('VE_browser.py', 'w') as f:
    f.write(r.text)

# now we can import
from VE_browser import *
CODE
import seaborn as sns

# Custom colorscales.
# From https://github.com/BIDS/colormap/blob/master/parula.py
# pc = [matplotlib.colors.to_hex(x) for x in parulac]; d = np.arange(len(pc)); d = np.round(d/max(d), 4); parula = [x for x in zip(d, pc)]
cmap_parula = [(0.0, '#352a87'), (0.0159, '#363093'), (0.0317, '#3637a0'), (0.0476, '#353dad'), (0.0635, '#3243ba'), (0.0794, '#2c4ac7'), (0.0952, '#2053d4'), (0.1111, '#0f5cdd'), (0.127, '#0363e1'), (0.1429, '#0268e1'), (0.1587, '#046de0'), (0.1746, '#0871de'), (0.1905, '#0d75dc'), (0.2063, '#1079da'), (0.2222, '#127dd8'), (0.2381, '#1481d6'), (0.254, '#1485d4'), (0.2698, '#1389d3'), (0.2857, '#108ed2'), (0.3016, '#0c93d2'), (0.3175, '#0998d1'), (0.3333, '#079ccf'), (0.3492, '#06a0cd'), (0.3651, '#06a4ca'), (0.381, '#06a7c6'), (0.3968, '#07a9c2'), (0.4127, '#0aacbe'), (0.4286, '#0faeb9'), (0.4444, '#15b1b4'), (0.4603, '#1db3af'), (0.4762, '#25b5a9'), (0.4921, '#2eb7a4'), (0.5079, '#38b99e'), (0.5238, '#42bb98'), (0.5397, '#4dbc92'), (0.5556, '#59bd8c'), (0.5714, '#65be86'), (0.5873, '#71bf80'), (0.6032, '#7cbf7b'), (0.619, '#87bf77'), (0.6349, '#92bf73'), (0.6508, '#9cbf6f'), (0.6667, '#a5be6b'), (0.6825, '#aebe67'), (0.6984, '#b7bd64'), (0.7143, '#c0bc60'), (0.7302, '#c8bc5d'), (0.746, '#d1bb59'), (0.7619, '#d9ba56'), (0.7778, '#e1b952'), (0.7937, '#e9b94e'), (0.8095, '#f1b94a'), (0.8254, '#f8bb44'), (0.8413, '#fdbe3d'), (0.8571, '#ffc337'), (0.873, '#fec832'), (0.8889, '#fcce2e'), (0.9048, '#fad32a'), (0.9206, '#f7d826'), (0.9365, '#f5de21'), (0.9524, '#f5e41d'), (0.9683, '#f5eb18'), (0.9841, '#f6f313'), (1.0, '#f9fb0e')]

# Default discrete colormap for <= 20 categories, from https://sashat.me/2017/01/11/list-of-20-simple-distinct-colors/. See also http://phrogz.net/css/distinct-colors.html and http://tools.medialab.sciences-po.fr/iwanthue/
cmap_custom_discrete = ["#bdbdbd", '#e6194b', '#3cb44b', '#ffe119', '#4363d8', '#f58231', '#911eb4', '#46f0f0', '#f032e6', '#bcf60c', '#fabebe', '#008080', '#e6beff', '#9a6324', '#fffac8', '#800000', '#aaffc3', '#808000', '#ffd8b1', '#000075', '#808080', '#7d87b9', '#bec1d4', '#d6bcc0']

# Convenient discrete colormaps for large numbers of colors.
cmap_custom_discrete_44 = ['#745745', '#568F34', '#324C20', '#FF891C', '#C9A997', '#C62026', '#F78F82', '#EF4C1F', '#FACB12', '#C19F70', '#824D18', '#CB7513', '#FBBE92', '#CEA636', '#F9DECF', '#9B645F', '#502888', '#F7F79E', '#007F76', '#00A99D', '#3EE5E1', '#65C8D0', '#3E84AA', '#8CB4CD', '#005579', '#C9EBFB', '#000000', '#959595', '#B51D8D', '#C593BF', '#6853A0', '#E8529A', '#F397C0', '#DECCE3', '#E18256', '#9BAA67', '#8ac28e', '#68926b', '#647A4F', '#CFE289', '#00C609', '#C64B55', '#953840', '#D5D5D5']
cmap_custom_discrete_74 = ['#FFFF00', '#1CE6FF', '#FF34FF', '#FF4A46', '#008941', '#006FA6', '#A30059', '#FFDBE5', '#7A4900', '#0000A6', '#63FFAC', '#B79762', '#004D43', '#8FB0FF', '#997D87', '#5A0007', '#809693', '#6A3A4C', '#1B4400', '#4FC601', '#3B5DFF', '#4A3B53', '#FF2F80', '#61615A', '#BA0900', '#6B7900', '#00C2A0', '#FFAA92', '#FF90C9', '#B903AA', '#D16100', '#DDEFFF', '#000035', '#7B4F4B', '#A1C299', '#300018', '#0AA6D8', '#013349', '#00846F', '#372101', '#FFB500', '#C2FFED', '#A079BF', '#CC0744', '#C0B9B2', '#C2FF99', '#001E09', '#00489C', '#6F0062', '#0CBD66', '#EEC3FF', '#456D75', '#B77B68', '#7A87A1', '#788D66', '#885578', '#FAD09F', '#FF8A9A', '#D157A0', '#BEC459', '#456648', '#0086ED', '#886F4C', '#34362D', '#B4A8BD', '#00A6AA', '#452C2C', '#636375', '#A3C8C9', '#FF913F', '#938A81', '#575329', '#00FECF', '#B05B6F']

# Custom red/blue diverging for black background, from https://gka.github.io/palettes
cmap_custom_rdbu_diverging = [[0.0, '#0000ff'], [0.1111, '#442dfa'], [0.2222, '#6b59e0'], [0.3333, '#6766a3'], [0.4444, '#323841'], [0.5555, '#483434'], [0.6666, '#b3635b'], [0.7777, '#ee5d49'], [0.8888, '#ff3621'], [1.0, '#ff0000']]

# Custom yellow/blue diverging for black background. From the following code:
# x = sns.diverging_palette(227, 86, s=98, l=77, n=20, center='dark').as_hex(); [s for s in zip(np.arange(len(x))/(len(x)-1), x)]
cmap_custom_ylbu_diverging = [(0.0, '#3acdfe'), (0.0526, '#37bbe6'), (0.105, '#35a9cf'), (0.157, '#3295b6'), (0.210, '#2f829e'), (0.263, '#2d6f85'), (0.315, '#2a5d6e'), 
                              (0.368, '#274954'), (0.421, '#25373d'), (0.473, '#222324'), (0.526, '#232322'), (0.578, '#363621'), (0.631, '#474720'), (0.684, '#5a5a1e'), 
                              (0.736, '#6b6b1d'), (0.789, '#7e7e1c'), (0.842, '#8f901b'), (0.894, '#a2a21a'), (0.947, '#b3b318'), (1.0, '#c4c417')]
cmap_custom_orpu_diverging = [(0.0, '#c2b5fe'), (0.0526, '#b1a5e6'), (0.105, '#a096cf'), (0.157, '#8e85b6'), (0.210, '#7c759e'), (0.263, '#6a6485'), (0.315, '#59556e'), 
                              (0.368, '#464354'), (0.421, '#35343d'), (0.473, '#232324'), (0.526, '#242323'), (0.578, '#3d332a'), (0.631, '#544132'), (0.684, '#6e523a'), 
                              (0.736, '#856041'), (0.789, '#9e7049'), (0.842, '#b67f50'), (0.894, '#cf8f58'), (0.947, '#e79d5f'), (1.0, '#feac66')]

"""
Interprets dataset to get list of colors, ordered by corresponding color values.
"""
def get_discrete_cmap(num_colors_needed):
    cmap_discrete = cmap_custom_discrete_44
    # If the provided color map has insufficiently many colors, make it cycle
    if len(cmap_discrete) < num_colors_needed:
        cmap_discrete = sns.color_palette(cmap_discrete, num_colors_needed)
        cmap_discrete = ['#%02x%02x%02x' % (int(255*red), int(255*green), int(255*blue)) for (red, green, blue) in cmap_discrete]
    return cmap_discrete


params = {}

params['title'] = "Chemical space viewer"

# Can reverse black/white color scheme
bg_scheme_list = ['black', 'white']
params['bg_color'] = 'white'
params['legend_bgcolor'] = 'white'
params['edge_color'] = 'white'
params['font_color'] = 'black'
params['legend_bordercolor'] = 'black'
params['legend_font_color'] = 'black'

params['hm_colorvar_name'] = 'Value'
params['qnorm_plot'] = False
params['hm_qnorm_plot'] = False

params['colorscale_continuous'] = [(0, "blue"), (0.5, "white"), (1, "red")]    # 'Viridis'
params['colorscale'] = cmap_custom_discrete_44

params['hover_edges'] = ""
params['edge_width'] = 1
params['bg_marker_size_factor'] = 5
params['marker_size_factor'] = 5
params['bg_marker_opacity_factor'] = 0.5
params['marker_opacity_factor'] = 1.0
params['legend_font_size'] = 16
params['hm_font_size'] = 6

legend_font_macro = { 'family': 'sans-serif', 'size': params['legend_font_size'], 'color': params['legend_font_color'] }
colorbar_font_macro = { 'family': 'sans-serif', 'size': 8, 'color': params['legend_font_color'] }
hm_font_macro = { 'family': 'sans-serif', 'size': 8, 'color': params['legend_font_color'] }

style_unselected = { 'marker': { 'size': 2.5, 'opacity': 1.0 } }
style_selected = { 'marker': { 'size': 6.0, 'opacity': 1.0 } }
style_outer_dialog_box = { 'padding': 10, 'margin': 5, 'border': 'thin lightgrey solid', # 'borderRadius': 5, 
}

style_invis_dialog_box = { 'padding': 0, 'margin': 5 }
style_hm_colorbar = { 
    'len': 0.3, 'thickness': 20, 'xanchor': 'left', 'yanchor': 'top', 'title': params['hm_colorvar_name'], 'titleside': 'top', 'ticks': 'outside', 
    'titlefont': legend_font_macro, 'tickfont': legend_font_macro 
}
style_text_box = { 'textAlign': 'center', 'width': '100%', 'color': params['font_color'] }

style_legend = {
    'font': legend_font_macro, # bgcolor=params['legend_bgcolor'], 'borderwidth': params['legend_borderwidth'], 
    'borderwidth': 0, #'border': 'thin lightgrey solid', 
    'traceorder': 'normal', 'orientation': 'h', 
    'itemsizing': 'constant'
}

# ==========================================================================================================================
# ==========================================================================================================================

import numpy as np

"""
Returns scatterplot panel with selected points annotated, using the given dataset (in data_df) and color scheme.
"""
def build_main_scatter(
    data_df, color_var, 
    discrete=False, 
    bg_marker_size=params['bg_marker_size_factor'], marker_size=params['marker_size_factor'], 
    annotated_points=[], selected_point_ids=[], 
    highlight=False, selected_style=style_selected
):
    # Put arrows to annotate points if necessary
    annots = []
    point_names = np.array(data_df.index)
    looked_up_ndces = np.where(np.in1d(point_names, annotated_points))[0]
    for point_ndx in looked_up_ndces:
        # Look up coordinates from the same 'x'/'y' columns used for plotting below
        absc = data_df['x'].iloc[point_ndx]
        ordi = data_df['y'].iloc[point_ndx]
        cname = point_names[point_ndx]
        annots.append({
            'x': absc, 'y': ordi,
            'xref': 'x', 'yref': 'y', 
            # 'text': '<b>Cell {}</b>'.format(cname), 
            'font': { 'color': 'white', 'size': 15 }, 
            'arrowcolor': '#ff69b4', 'showarrow': True, 'arrowhead': 2, 'arrowwidth': 2, 'arrowsize': 2, 
            'ax': 0, 'ay': -50 
        })
    if highlight:
        selected_style['marker']['color'] = '#ff4f00' # Golden gate bridge red
        selected_style['marker']['size'] = 10
    else:
        selected_style['marker'].pop('color', None)    # Remove color if exists
        selected_style['marker']['size'] = 10
    
    traces_list = []
    cumu_color_dict = {}
    
    # Check to see if color_var is continuous or discrete and plot points accordingly
    if not discrete:     # Color_var is continuous
        continuous_color_var = np.array(data_df[color_var])
        spoints = np.where(np.isin(point_names, selected_point_ids))[0]
        # print(time.time() - itime, spoints, selected_point_ids)
        colorbar_title = params['hm_colorvar_name']
        pt_text = ["{}<br>Value: {}".format(point_names[i], round(continuous_color_var[i], 3)) for i in range(len(point_names))]
        max_magnitude = np.percentile(np.abs(continuous_color_var), 98)
        min_magnitude = np.percentile(np.abs(continuous_color_var), 2)
        traces_list.append({ 
            'name': 'Data', 
            'x': data_df['x'], 
            'y': data_df['y'], 
            'selectedpoints': spoints, 
            'hoverinfo': 'text', 
            'hovertext': pt_text, 
            'text': point_names, 
            'mode': 'markers', 
            'marker': {
                'size': bg_marker_size, 
                'opacity': params['marker_opacity_factor'], 
                'symbol': 'circle', 
                'showscale': True, 
                'colorbar': {
                    'len': 0.3, 
                    'thickness': 20, 
                    'xanchor': 'right', 'yanchor': 'top', 
                    'title': colorbar_title,
                    'titleside': 'top',
                    'ticks': 'outside', 
                    'titlefont': colorbar_font_macro, 
                    'tickfont': colorbar_font_macro
                }, 
                'color': continuous_color_var, 
                'colorscale': 'Viridis', #cmap_parula, #[(0, "white"), (1, "blue")], 
                'cmin': min_magnitude, 
                'cmax': max_magnitude
            }, 
            'selected': selected_style, 
            'type': 'scattergl'
        })
    else:    # Categorical color scheme, one trace per color
        cnt = 0
        num_colors_needed = len(np.unique(data_df[color_var]))
        colorscale_list = get_discrete_cmap(num_colors_needed)
        for idx in np.unique(data_df[color_var]):
            val = data_df.loc[data_df[color_var] == idx, :]
            point_ids_this_trace = list(val.index)
            spoint_ndces_this_trace = np.where(np.isin(point_ids_this_trace, selected_point_ids))[0]
            if idx not in cumu_color_dict:
                trace_color = colorscale_list[cnt]
                cnt += 1
                cumu_color_dict[idx] = trace_color
            trace_opacity = 1.0
            pt_text = ["{}<br>{}".format(point_ids_this_trace[i], idx) for i in range(len(point_ids_this_trace))]
            trace_info = {
                'name': str(idx), 
                'x': val['x'], 
                'y': val['y'], 
                'selectedpoints': spoint_ndces_this_trace, 
                'hoverinfo': 'text', 
                'hovertext': pt_text, 
                'text': point_ids_this_trace, 
                'mode': 'markers', 
                'opacity': trace_opacity, 
                'marker': { 'size': bg_marker_size, 'opacity': params['marker_opacity_factor'], 'symbol': 'circle', 'color': trace_color }, 
                'selected': selected_style
            }
            trace_info.update({'type': 'scattergl'})
            if False: #params['three_dims']:
                trace_info.update({ 'type': 'scatter3d', 'z': np.zeros(val.shape[0]) })
            traces_list.append(trace_info)
    
    return { 
        'data': traces_list, 
        'layout': {
            'margin': { 'l': 0, 'r': 0, 'b': 0, 't': 20}, 
            'clickmode': 'event',  # https://github.com/plotly/plotly.js/pull/2944/
            'hovermode': 'closest', 
            'dragmode': 'select', 
            'uirevision': 'Default dataset',     # https://github.com/plotly/plotly.js/pull/3236
            'xaxis': {
                'automargin': True, 
                'showticklabels': False, 
                'showgrid': False, 'showline': False, 'zeroline': False, 'visible': False 
                #'style': {'display': 'none'}
            }, 
            'yaxis': {
                'automargin': True, 
                'showticklabels': False, 
                'showgrid': False, 'showline': False, 'zeroline': False, 'visible': False 
                #'style': {'display': 'none'}
            }, 
            'legend': style_legend, 
            'annotations': annots, 
            'plot_bgcolor': params['bg_color'], 
            'paper_bgcolor': params['bg_color']
        }
    }


def run_update_scatterplot(
    data_df, color_var, 
    annotated_points=[],      # Selected points annotated
    selected_style=style_selected, highlighted_points=[]
):
    pointIDs_to_select = highlighted_points
    num_colors_needed = len(np.unique(data_df[color_var]))
    # Any variable with at most 75 distinct values is currently treated as categorical.
    discrete_color = (num_colors_needed <= 75)
    return build_main_scatter(
        data_df, color_var, discrete=discrete_color, 
        highlight=True, 
        bg_marker_size=params['bg_marker_size_factor'], marker_size=params['marker_size_factor'], 
        annotated_points=annotated_points, selected_point_ids=pointIDs_to_select, 
        selected_style=selected_style
    )


# ==========================================================================================================================
# ==========================================================================================================================

import plotly.graph_objects as go
import scanpy as sc, anndata
import time

anndata_all = sc.read('approved_drugs.h5')

data_df = anndata_all.obs



# ==========================================================================================================================
# ==========================================================================================================================

import os, base64
from io import BytesIO

running_colab = False


import scipy

import dash
from dash import dcc
from dash import html
from dash.dependencies import Input, Output, State
if running_colab:
    from jupyter_dash import JupyterDash


def create_div_mainctrl():
    color_options = list(data_df.columns)
    default_val = color_options[0] if len(color_options) > 0 else ' '
    
    return html.Div(
        children=[
            html.Div(
                className='row', 
                children=[
                    html.Div(
                        className='row', 
                        children=[
                            html.Div(
                                children='Select color: ', 
                                style={ 'textAlign': 'center', 'color': params['font_color'], 'padding-top': '0px' }
                            ), 
                            dcc.Dropdown(
                                id='color-selection', 
                                options = [ {'value': color_options[i], 'label': color_options[i]} for i in range(len(color_options)) ], 
                                value=default_val, 
                                placeholder="Select color...", clearable=False
                            )], 
                        style={'width': '100%', 'padding-top': '10px', 'display': 'inline-block', 'float': 'center'}
                    )
                    ], 
                style={'width': '100%', 'padding-top': '10px', 'display': 'inline-block'}
            ), 
            dcc.Textarea(
                id='selected-chems',
                value=', '.join([]),
                style={'width': '100%', 'height': 100},
            ), 
            html.Div(
                className='six columns', 
                children=[
                    html.A(
                        html.Button(
                            id='download-button', 
                            children='Save', 
                            style=style_text_box, 
                            n_clicks=0, 
                            n_clicks_timestamp=0
                        ), 
                        id='download-set-link',
                        download="selected_set.csv", 
                        href="",
                        target="_blank", 
                        style={
                            'width': '100%', 
                            'textAlign': 'center', 
                            'color': params['font_color']
                        }
                    )], 
                style={'padding-top': '20px'}
            ), 
            html.Div([html.Img(id="chem2D-image")], style={'width': '50%', 'height': 100} )
        ], 
        style={'width': '29%', 'display': 'inline-block', 'fontSize': 12, 'margin': 5}
    )


def create_div_layout():
    return html.Div(
        className="container", 
        children=[
            html.Div(
                className='row', 
                children=[ html.H1(id='title', children=params['title'], style=style_text_box) ]
            ), 
            html.Div(
                className="browser-div", 
                children=[
                    create_div_mainctrl(), 
                    html.Div(
                        className='row', 
                        children=[
                            html.Div(
                                children=[
                                    dcc.Graph(
                                        id='landscape-plot',
                                        config={'displaylogo': False, 'displayModeBar': True}, 
                                        style={ 'height': '100vh'}
                                    )], 
                                style={}
                            )], 
                        style={'width': '69%', 'display': 'inline-block', 'float': 'right', 'fontSize': 12, 'margin': 5}
                    ), 
                    html.Div([ html.Pre(id='test-select-data', style={ 'color': params['font_color'], 'overflowX': 'scroll' } ) ]),     # For testing purposes only!
                    html.Div(
                        className='row', 
                        children=[ 
                            dcc.Markdown(
                                """ """
                            )], 
                        style={ 'textAlign': 'center', 'color': params['font_color'], 'padding-bottom': '10px' }
                    )],
                style={ 'width': '100vw', 'max-width': 'none' }
            )
        ],
        style={ 'backgroundColor': params['bg_color'], 'width': '100vw', 'max-width': 'none' }
    )


if running_colab:
    JupyterDash.infer_jupyter_proxy_config()
    app = JupyterDash(__name__)
else:
    external_stylesheets = ['https://codepen.io/chriddyp/pen/bWLwgP.css']
    app = dash.Dash(__name__, external_stylesheets=external_stylesheets)

# Create server variable with Flask server object for use with gunicorn
server = app.server

app.title = params['title']
app.layout = create_div_layout()


# ==========================================================================================================================
# ==========================================================================================================================

from rdkit import Chem
from rdkit.Chem.Draw import MolsToGridImage


def b64string_image(smiles_list, molecule_name_list, num_molecules=12, return_encoding=False):
    """
    This code displays the first `num_molecules` molecules given in the input lists. 
    Returns a b64encoded image of the grid of molecules.
    """
    # Only parse and label the molecules that will actually be drawn
    mol_list = [Chem.MolFromSmiles(x) for x in smiles_list[:num_molecules]]
    img = MolsToGridImage(mol_list, molsPerRow=4, subImgSize=(150,150), legends=molecule_name_list[:num_molecules], returnPNG=False)
    buffered = BytesIO()
    img.save(buffered, format="PNG")
    encoded_image = base64.b64encode(buffered.getvalue())
    if return_encoding:
        return encoded_image.decode()
    else:
        return 'data:image/png;base64,{}'.format(encoded_image.decode())


# Example usage.
from IPython import display

src_str = b64string_image(list(anndata_all.obs['SMILES']), list(anndata_all.obs['Preferred name']), return_encoding=True)
display.HTML(f"<div><img src='data:image/png;base64, {src_str}'/></div>")


# ==========================================================================================================================
# ==========================================================================================================================

"""
Update the selected chemical image
"""
@app.callback(
    Output('chem2D-image', 'src'),
    [Input('landscape-plot', 'selectedData'), 
     Input('landscape-plot', 'clickData')])
def display_selected_data(
    selected_points, 
    clicked_points, 
    num_molecules=12
):
    empty_plot = "data:image/gif;base64,R0lGODlhAQABAAAAACwAAAAAAQABAAA="
    if selected_points:
        if len(selected_points['points']) == 0:
            return empty_plot
        this_df = data_df.loc[np.array([str(p['text']) for p in selected_points['points']])]
        smiles_list = list(this_df['SMILES'])
        molecule_name_list = list(this_df['Preferred name'])
        src_str = b64string_image(smiles_list, molecule_name_list, num_molecules=num_molecules)
    else:
        return empty_plot
    return src_str


"""
Update the text box with selected chemicals.
"""
@app.callback(
    Output('selected-chems', 'value'), 
    [Input('landscape-plot', 'selectedData')])
def update_text_selection(selected_points):
    if (selected_points is not None) and ('points' in selected_points):
        selected_IDs = [str(p['text']) for p in selected_points['points']]
    else:
        selected_IDs = []
    return ', '.join(selected_IDs) + '\n\n'


def get_pointIDs(selectedData_points):
    toret = []
    if (selectedData_points is not None) and ('points' in selectedData_points):
        for p in selectedData_points['points']:
            pt_txt = p['text'].split('<br>')[0]
            toret.append(pt_txt)
        return toret
    else:
        return []


@app.callback(
    Output('download-set-link', 'href'),
    [Input('landscape-plot', 'selectedData')]
)
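def save_selection(landscape_data):
    from urllib.parse import quote
    subset_store = get_pointIDs(landscape_data)
    save_contents = '\n'.join(subset_store)
    # Percent-encode the contents so that newlines survive inside the data URI
    return "data:text/csv;charset=utf-8," + quote(save_contents)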


"""
Update the main scatterplot panel.
"""
@app.callback(
    Output('landscape-plot', 'figure'), 
    [Input('color-selection', 'value')]
)
def update_landscape(color_var):
    annotated_points = []
    lscape = run_update_scatterplot(
        data_df, 
        color_var
    )
    return lscape

A heatmap to display the features

I’m going to add another panel to the interface, with profound implications. This is a heatmap that displays feature-level values for every observation, i.e. it displays the “raw data”: in this case, the values of individual fingerprint bits for each molecule. It allows the user to define subgroups of the data based on their representations directly, rather than based on a derived 2D scatterplot.

Selecting features to view

In this case, we could use my original featurization of 2048 Morgan (extended connectivity) fingerprint bits for this chemical data.

The heatmap takes up less than half the screen horizontally, i.e. a few hundred pixels. So there are clearly too many features to view individually: more features than pixels available to display them. This is quite common in data visualization, and which features are selected for viewing matters a great deal.

As a first pass, I’ve implemented a simple function that selects the maximum-variance features in the data.

CODE
def interesting_feat_ndces(fit_data, num_feats_todisplay=500):
    # Check for missing input before touching .shape
    if fit_data is None:
        return np.arange(0)
    num_feats_todisplay = min(fit_data.shape[1], num_feats_todisplay)
    if np.prod(fit_data.shape) == 0:
        return np.arange(num_feats_todisplay)
    # Rank features by standard deviation, largest first
    feat_ndces = np.argsort(np.std(fit_data, axis=0))[::-1][:num_feats_todisplay]
    return feat_ndces

This runs at interactive speeds, and provides a guardrail against rendering astronomically many features.

Note: This ranking is uninformative if the data are normalized to have unit variance per feature, since every feature then ties. So you may need to change this for your particular purpose.
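To make the caveat concrete, here's a small sketch on synthetic data. The array `raw` and the helper `top_variance_features` are made up for illustration and are not part of the browser. After per-feature standardization every column has standard deviation 1, so a variance ranking carries no information. Ranking on the unscaled data avoids this.

```python
import numpy as np

def top_variance_features(data, k):
    """Indices of the k columns with the largest standard deviation (hypothetical helper)."""
    return np.argsort(np.std(data, axis=0))[::-1][:k]

rng = np.random.default_rng(0)
# Synthetic data: 200 rows, 6 features whose spread grows with the column index.
scales = np.array([0.1, 0.5, 1.0, 2.0, 5.0, 10.0])
raw = rng.normal(size=(200, 6)) * scales

# On the raw data, the ranking picks out the two widest columns.
picked_raw = top_variance_features(raw, 2)

# Standardize each feature to zero mean and unit variance.
scaled = (raw - raw.mean(axis=0)) / raw.std(axis=0)
stds_after = np.std(scaled, axis=0)

print(sorted(picked_raw.tolist()))   # indices of the widest columns
print(np.allclose(stds_after, 1.0))  # every feature now has the same spread
```

One way around this, if your pipeline standardizes early, is to compute the ranking before scaling and carry the selected indices forward.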

Timing

What do I mean by “interactive speeds”? Running it on this data (>2K features) takes about 0.02s on my laptop. This is far below the ~1 second needed to perceptually throw the user out of the feeling of interactivity, a hard limit on our interactive designs that I’ve written about recently. Timing this feature selection step is easy, and good practice.

CODE
%timeit interesting_feat_ndces(anndata_all.X)
20.3 ms ± 877 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
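`%timeit` is an IPython magic, so it only works in a notebook or IPython session. Outside of one, a rough equivalent can be sketched with `time.perf_counter`. The helper names `time_call` and `rank_by_std`, and the synthetic array standing in for the fingerprint matrix, are assumptions for this example.

```python
import time
import numpy as np

def time_call(fn, *args, repeats=5):
    """Return the best wall-clock time in seconds over `repeats` calls (hypothetical helper)."""
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args)
        best = min(best, time.perf_counter() - start)
    return best

def rank_by_std(data, k=500):
    # Same max-variance ranking idea as the feature selection step.
    k = min(data.shape[1], k)
    return np.argsort(np.std(data, axis=0))[::-1][:k]

# Smaller synthetic stand-in for a binary fingerprint matrix (an assumption, not the real data).
fake_bits = np.random.default_rng(0).integers(0, 2, size=(2000, 2048)).astype(np.float32)

elapsed = time_call(rank_by_std, fake_bits)
budget_s = 1.0  # the ~1 second interactivity limit
print(f"{elapsed * 1000:.1f} ms, within budget: {elapsed < budget_s}")
```

Taking the best of several runs is one common convention. The actual numbers will depend on the machine and the data.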

Displaying metadata

Though it’s possible to interactively narrow down to a manageable number of fingerprint bits, they are not easy to individually interpret. So we opt instead to display some descriptors calculated using various human-interpretable RDKit utilities, as I wrote about earlier.

These are already stored as metadata columns in the .obs dataframe; there are 208 such columns we display as a heatmap.

CODE
print("Shape of displayed metadata: {}".format(anndata_all.obs.iloc[:, :-8].values.shape))
Shape of displayed metadata: (5903, 208)

Co-clustering to group the selected data

We focus on one operation in particular: ordering the rows and columns of the heatmap to emphasize the internal structure. This can be done with co-clustering, which assigns joint cluster labels to the rows and columns of the heatmap.

This uses a function compute_coclustering, which groups the rows and columns and returns the sorted row and column orderings along with the cluster labels.
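As a toy illustration of the reordering step alone, suppose the cluster labels are already known. The matrix and labels below are made up, and using a stable sort is my choice for this example so that the result is reproducible. Sorting the labels gives a permutation that places members of the same cluster next to each other, which is what makes block structure visible in a heatmap.

```python
import numpy as np

# Toy heatmap: 4 observations x 4 features, with hypothetical cluster labels.
values = np.array([
    [9, 0, 8, 0],
    [0, 7, 0, 6],
    [8, 0, 9, 0],
    [0, 6, 0, 7],
])
row_labels = np.array([0, 1, 0, 1])
col_labels = np.array([0, 1, 0, 1])

# A stable argsort of the labels gives a permutation in which members of the
# same cluster are adjacent, preserving their original relative order.
row_order = np.argsort(row_labels, kind="stable")
col_order = np.argsort(col_labels, kind="stable")

# Apply both permutations at once; the large entries now sit in diagonal blocks.
reordered = values[np.ix_(row_order, col_order)]
print(reordered)
```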

An implementation with linear algebra

There are many methods for co-clustering, including work linking it to information theory (Dhillon, Mallela, and Modha 2003). I’ll use the most efficient, the spectral partitioning method of Dhillon (2001), which is implemented in scikit-learn. Here’s a wrapper that sets it up for use by the browser.

CODE
from sklearn.cluster import SpectralCoclustering
import scipy.sparse

def compute_coclustering(
    fit_data, 
    num_clusters=1, 
    tol_bicluster=0.005,  # sparsity otherwise annoyingly causes underflows w/ sklearn (currently unused below)
):
    # num_clusters=1 is treated as "pick a default": at most 5 clusters
    if num_clusters == 1:
        num_clusters = min(fit_data.shape[0], 5)    # = (working_object.shape[1]//5)
    # The clustering is run on a dense array here
    if scipy.sparse.issparse(fit_data):
        fit_data = fit_data.toarray()
    row_labels, col_labels = cocluster_core_sklearn(fit_data, num_clusters)
    # Return the row/column orderings that group clusters together, then the raw labels
    return (np.argsort(row_labels), np.argsort(col_labels), row_labels, col_labels)


def cocluster_core_sklearn(
    fit_data, 
    num_clusters, 
    random_state=0
):
    model = SpectralCoclustering(n_clusters=num_clusters, random_state=random_state)
    model.fit(fit_data)
    return model.row_labels_, model.column_labels_
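
To give a feel for the linear algebra underneath, here's a minimal NumPy sketch of the two-cluster case of spectral co-clustering, on a made-up matrix with a planted block structure. It assumes a nonnegative matrix with strictly positive row and column sums, and it thresholds the second singular vectors at zero instead of running k-means on them, which is a simplification. `spectral_bipartition` is a hypothetical name, not what scikit-learn does internally line for line.

```python
import numpy as np

def spectral_bipartition(A):
    """Split rows and columns of a nonnegative matrix into two co-clusters.

    Sketch of the two-cluster case only; thresholding at zero stands in for
    the k-means step, which is a simplification.
    """
    A = np.asarray(A, dtype=float)
    row_scale = 1.0 / np.sqrt(A.sum(axis=1))   # diagonal of D1^{-1/2}
    col_scale = 1.0 / np.sqrt(A.sum(axis=0))   # diagonal of D2^{-1/2}
    A_n = row_scale[:, None] * A * col_scale[None, :]
    U, _, Vt = np.linalg.svd(A_n, full_matrices=False)
    # The leading singular pair is trivial; the second pair carries the partition.
    z_rows = row_scale * U[:, 1]
    z_cols = col_scale * Vt[1, :]
    return (z_rows > 0).astype(int), (z_cols > 0).astype(int)

# Made-up data: two planted blocks plus a weak background and a little noise.
rng = np.random.default_rng(0)
planted_rows = np.array([0, 0, 0, 1, 1, 1, 1])
planted_cols = np.array([0, 0, 1, 1, 1])
A = 0.1 + 5.0 * (planted_rows[:, None] == planted_cols[None, :])
A = A + 0.05 * rng.random(A.shape)

row_lab, col_lab = spectral_bipartition(A)
print(row_lab, col_lab)
```

The labels are only defined up to swapping 0 and 1, since the sign of a singular vector is arbitrary. Rows and columns from the same block still end up with the same label.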

Timing

This is fast enough to be within the ~0.5-to-1-second perceptual barrier on my laptop, when considering the entire dataset of nearly 6,000 chemicals.

CODE
%timeit compute_coclustering(anndata_all.X[:, interesting_feat_ndces(anndata_all.X)])
300 ms ± 2.93 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Displaying and viewing the resulting heatmap

Generating the heatmap in Plotly

Putting it all together, here’s some wrapper code for translating all these configuration and style options into Plotly for a heatmap.

Details
  • The code in hm_hovertext() creates tooltip text for the heatmap. This is done pretty simply right now, but is where the performance bottleneck is.
  • The code uses several display configuration options from the params dictionary defined above. The code could be fewer lines, but everyone likes different display configurations!
  • The main implementation details here arise from the differences in handling discrete and continuous colorscales, which affect subset selection and annotation of points.
  • Further details, like the displayed name of the color variable, are set to generic defaults that are exposed in the code and can be changed.
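
Since the tooltip text is where the time goes, one possible direction is to build the per-row and per-column pieces of each string once and only format the value per cell. The sketch below is not benchmarked here, and `hm_hovertext_precomputed` is a hypothetical name. It is meant to produce tooltips in the "Observation / Feature / Value" format used in this post.

```python
import numpy as np

def hm_hovertext_precomputed(data, rownames, colnames):
    """Build tooltip strings, formatting row and column parts only once each."""
    n_rows, n_cols = data.shape
    row_parts = ["Observation: {}".format(str(rownames[r])) for r in range(n_rows)]
    col_parts = ["<br>Feature: {}<br>Value: ".format(str(colnames[c])) for c in range(n_cols)]
    return [
        [row_parts[r] + col_parts[c] + str(round(data[r][c], 3)) for c in range(n_cols)]
        for r in range(n_rows)
    ]

# Made-up example values and names.
data = np.array([[0.12345, 2.0], [3.5, 4.98765]])
text = hm_hovertext_precomputed(data, ["aspirin", "caffeine"], ["MolWt", "LogP"])
print(text[0][0])
```

Another option worth exploring is Plotly's hover templates, which can move string construction to the browser.
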
CODE
from sklearn.preprocessing import StandardScaler


def hm_hovertext(data, rownames, colnames):
    pt_text = []
    # First the rows, then the cols
    for r in range(data.shape[0]):
        pt_text.append(["Observation: {}".format(str(rownames[r])) for k in data[r, :]])
        for c in range(data.shape[1]):
            pt_text[r][c] += "<br>Feature: {}<br>Value: {}".format(str(colnames[c]), str(round(data[r][c], 3)))
    return pt_text


def display_heatmap_cb(
    data_df, 
    color_var, 
    row_annots=None, 
    show_legend=False, 
    col_alphabet=True, 
    plot_raw=True, 
    max_cols_heatmap=400, 
    xaxis_label=True, yaxis_label=True, 
    scatter_frac_domain=0.10
):
    itime = time.time()
    if data_df is None or len(data_df.shape) < 2:
        return
    working_object = data_df
    
    # Identify (interesting) features to plot. Currently: high-variance ones
    if working_object.shape[1] > 500:
        feat_ndces = interesting_feat_ndces(working_object.values)
        working_object = working_object.iloc[:, feat_ndces]
    
    # Here we subsample down to `max_rows_allowed` rows if needed, and make data prettier for printing
    max_rows_allowed = 1000
    if working_object.shape[0] > max_rows_allowed:
        ndxs = np.random.choice(np.arange(working_object.shape[0]), size=max_rows_allowed, replace=False)
        working_object = working_object.iloc[ndxs, :]
        if row_annots is not None:
            row_annots = row_annots[ndxs]
    
    # Spectral coclustering to cluster the heatmap. We always order rows (points) by spectral projection, but cols (features) can have different orderings for different viewing options.
    if (working_object.shape[0] > 1):
        fit_data = StandardScaler().fit_transform(working_object.values)
        ordered_rows, ordered_cols, row_clustIDs, col_clustIDs = compute_coclustering(fit_data)
        if row_annots is not None:
            ordered_rows = np.lexsort((row_clustIDs, row_annots))
        working_object = working_object.iloc[ordered_rows, :]
    else:
        ordered_cols = np.arange(working_object.shape[1])   # Don't reorder at all
    if col_alphabet:
        ordered_cols = np.argsort(working_object.columns)    # Order columns alphabetically by feature name
    working_object = working_object.iloc[:, ordered_cols]
    # Finished reordering rows/cols
    
    working_object = working_object.copy()
    hm_point_names = np.array(working_object.index)
    absc_labels = np.array(working_object.columns)
    
    row_scat_traces = hm_row_scatter(working_object, color_var, hm_point_names)
    if not plot_raw:
        # DataFrame.values has no setter, so assign the standardized values in place via iloc.
        working_object.iloc[:, :] = StandardScaler().fit_transform(working_object.values)
    
    pt_text = hm_hovertext(working_object.values, hm_point_names, absc_labels)
    
    hm_trace = {
        'z': working_object.values, 
        'x': absc_labels, 
        'customdata': hm_point_names, 
        'hoverinfo': 'text',
        'text': pt_text, 
        'colorscale': params['colorscale_continuous'], 
        'colorbar': {
            'len': 0.3, 'thickness': 20, 
            'xanchor': 'left', 'yanchor': 'top', 
            'title': params['hm_colorvar_name'], 'titleside': 'top', 'ticks': 'outside', 
            'titlefont': colorbar_font_macro, 
            'tickfont': colorbar_font_macro
        }, 
        'type': 'heatmap'
    }
    max_magnitude = np.percentile(np.abs(working_object.values), 98) if working_object.shape[0] > 0 else 2
    hm_trace['zmin'] = -max_magnitude
    hm_trace['zmax'] = max_magnitude
    
    return {
        'data': [ hm_trace ] + row_scat_traces, 
        'layout': {
            'xaxis': {
                'automargin': True, 
                'showticklabels': False, 
                'showgrid': False, 'showline': False, 'zeroline': False, #'visible': False, 
                #'style': {'display': 'none'}, 
                'domain': [scatter_frac_domain, 1]
            }, 
            'yaxis': {
                'automargin': True, 
                'showticklabels': False, 
                'showgrid': False, 'showline': False, 'zeroline': False, #'visible': False, 
                #'style': {'display': 'none'}
            }, 
            'annotations': [{
                    'x': 0.5, 'y': 1.10, 'showarrow': False, 
                    'font': { 'family': 'sans-serif', 'size': 15, 'color': params['legend_font_color'] }, 
                    'text': 'Features' if xaxis_label else '',
                    'xref': 'paper', 'yref': 'paper'
                }, 
                {
                    'x': 0.05, 'y': 0.5, 'showarrow': False, 
                    'font': { 'family': 'sans-serif', 'size': 15, 'color': params['legend_font_color'] }, 
                    'text': 'Observations' if yaxis_label else '', 'textangle': -90, 
                    'xref': 'paper', 'yref': 'paper'
                }
            ], 
            'margin': { 'l': 30, 'r': 0, 'b': 0, 't': 70 }, 
            'hovermode': 'closest', 'clickmode': 'event',  # https://github.com/plotly/plotly.js/pull/2944/
            'uirevision': 'Default dataset', 
            'legend': style_legend, 'showlegend': show_legend, 
            'plot_bgcolor': params['bg_color'], 'paper_bgcolor': params['bg_color'], 
            'xaxis2': {
                'showgrid': False, 'showline': False, 'zeroline': False, 'visible': False, 
                'domain': [0, scatter_frac_domain], 
                'range': [-1, 0.2]
            }
        }
    }


def hm_row_scatter(data_df, color_var, hm_point_names):
    return []

This defines a dummy function hm_row_scatter. The job of this function is to return a column of scatterplot points alongside the rows, serving as a link between the scatterplot and the heatmap. Making this link work well enough to significantly improve subset selection is a task in itself. There is also additional code above that I’ve left unexplained for now, involving customizable “row annotations” that can be used to modify the order of what’s displayed.

Delving into all this would make our discussion here very long, and is the subject of a subsequent post.
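As a small pointer in the meantime, here is a toy illustration of the np.lexsort call that display_heatmap_cb applies when row annotations are given. The two arrays are made-up stand-ins, not values from the dataset. np.lexsort treats its last key as the primary one, so rows are grouped by annotation first and ordered by cluster ID within each annotation; ties keep their original relative order.

```python
import numpy as np

# Hypothetical stand-ins for the arrays used in display_heatmap_cb.
row_clustIDs = np.array([2, 0, 1, 0])   # cluster ID of each row
row_annots = np.array([1, 0, 1, 0])     # annotation label of each row

# The last key passed to lexsort is the primary sort key.
ordered_rows = np.lexsort((row_clustIDs, row_annots))
print(ordered_rows.tolist())  # [1, 3, 2, 0]
```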

The end product

The code from this notebook is available as a file; setting it up in an environment will deploy the app.

There are a couple of interesting stories along the way which merit their own code snippets.

Viewing the heatmap statically

One bit of code here normalizes the heatmap and then shows how to render it in HTML.

CODE
import numpy as np
import scipy.stats
import plotly.graph_objects as go
from IPython.display import HTML
import scanpy as sc, anndata
import time

def custom_colwise_norm_df(data_df):
    # Define custom column-wise normalization here.
    # Drop the last 8 columns, and copy so the caller's dataframe is not modified.
    data_df = data_df.iloc[:, :-8].copy()
    for col in data_df.columns:
        newvals = np.nan_to_num(data_df[col])
        # z-score each column, unless it is constant
        if np.std(newvals) > 0:
            newvals = scipy.stats.zscore(newvals)
        data_df[col] = newvals
    return data_df

anndata_all = sc.read('approved_drugs.h5')
data_df = anndata_all.obs
data_df = custom_colwise_norm_df(data_df)

# Build the heatmap figure; uncomment the last line to render it as HTML.
fig_data = display_heatmap_cb(data_df, 'MolWt')
fig = go.Figure(**fig_data)
# HTML(fig.to_html())

Cumulative timing

Let’s time everything so far:

  • Feature selection
  • Coclustering of the heatmap
  • Putting together the Plotly figure
  • Rendering the figure (e.g. in HTML)

Recall that for a fully interactive experience, we should keep the overall computation time under ~1s when the user selects a chunk of the scatterplot.

CODE
from IPython.display import HTML

%timeit fig_data = display_heatmap_cb(data_df, 'MolWt')
%timeit fig = go.Figure(**fig_data)
%timeit HTML(fig.to_html())
1.45 s ± 79.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
427 ms ± 5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
309 ms ± 5.83 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

This is not happening with the present design: the three steps add up to roughly 2.2 s, so large heatmaps cannot be rendered without the user feeling like they have to pause their thoughts. After profiling the code in display_heatmap_cb, it turns out that most of the time is being spent in hm_hovertext. We leave optimization of this function to a later post.
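For readers who want to repeat the profiling step, here is one minimal way to do it with the standard library's cProfile. This is a sketch, not necessarily how the numbers above were obtained; profile_call is a hypothetical helper name introduced here.

```python
import cProfile
import io
import pstats

def profile_call(func, *args, top=10, **kwargs):
    """Run func(*args, **kwargs) under cProfile; return (result, report text)."""
    profiler = cProfile.Profile()
    result = profiler.runcall(func, *args, **kwargs)
    stream = io.StringIO()
    stats = pstats.Stats(profiler, stream=stream)
    # Sort by cumulative time so the most expensive call paths come first.
    stats.sort_stats("cumulative").print_stats(top)
    return result, stream.getvalue()

# In the notebook, one could then run:
# fig_data, report = profile_call(display_heatmap_cb, data_df, 'MolWt')
# print(report)
```

In the printed report, the rows with the largest cumulative time point to the functions worth optimizing first.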

Warning: The bottleneck of the interface at present is assembling the tooltip text that is displayed when cells are hovered over.
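One direction that might reduce this cost is to round the whole matrix once and to build each row's and each column's text fragment only once, instead of formatting every cell separately. The sketch below is an illustration under those assumptions, not the optimization planned for the later post, and it has not been benchmarked here. hm_hovertext_fast is a hypothetical name, the data is assumed to be floating-point, and the example names and values are made up.

```python
import numpy as np

def hm_hovertext_fast(data, rownames, colnames):
    """Tooltip text for each heatmap cell, as a list of lists of strings."""
    # Round once for the whole matrix; assumes the data is numeric.
    rounded = np.round(np.asarray(data, dtype=float), 3).tolist()
    # Build each row's and each column's string fragment only once.
    row_parts = ["Observation: {}".format(r) for r in rownames]
    col_parts = ["<br>Feature: {}<br>Value: ".format(c) for c in colnames]
    return [
        [row_part + col_part + str(val) for col_part, val in zip(col_parts, row_vals)]
        for row_part, row_vals in zip(row_parts, rounded)
    ]

# Made-up example values, for illustration only.
text = hm_hovertext_fast(np.array([[1.23456, 2.0]]), ["aspirin"], ["MolWt", "LogP"])
print(text[0][0])  # Observation: aspirin<br>Feature: MolWt<br>Value: 1.235
```

The output has the same nested-list shape as hm_hovertext, so it could be passed to the 'text' field of the heatmap trace in the same way.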

The main heatmap callback

The heatmap's central use is to help explore and show structure in whatever data the user has selected, by clustering it. So at the most basic level, the callback that updates the heatmap fires when the user selection changes. We add this Dash code to the app.

CODE
"""
Update the main heatmap panel.
"""
@app.callback(
    Output('main-heatmap', 'figure'),
    [Input('landscape-plot', 'selectedData')])
def update_heatmap(
    selected_points
):
    if (selected_points is not None) and ('points' in selected_points):
        selected_IDs = [p['text'].split('<br>')[0] for p in selected_points['points']]
    else:
        selected_IDs = []
    if len(selected_IDs) == 0:
        selected_IDs = data_df.index
    
    subsetted_data = data_df.loc[np.array(selected_IDs)]
    subsetted_data = custom_colwise_norm_df(subsetted_data)
    print(f"Subsetted data: {subsetted_data.shape}")
    row_annotations = None
    return display_heatmap_cb(
        subsetted_data, 'MolWt', 
        row_annots=row_annotations, 
        xaxis_label=True, yaxis_label=True
    )
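The ID-parsing step in this callback can be pulled out into a small helper and checked without running the app. The sketch below assumes, as the callback does, that the selection arrives as a dictionary with a 'points' list and that each point's 'text' starts with its ID followed by '<br>'. selected_ids_from_payload is a hypothetical name, and the example payload is made up.

```python
def selected_ids_from_payload(selected_points):
    """Return the list of selected IDs, or [] if nothing is selected."""
    if not selected_points or 'points' not in selected_points:
        return []
    # Each point's hover text is assumed to start with its ID, followed by '<br>'.
    return [p['text'].split('<br>')[0] for p in selected_points['points']]

# A made-up payload for illustration.
example_payload = {
    'points': [
        {'text': 'aspirin<br>MolWt: 180.16'},
        {'text': 'caffeine<br>MolWt: 194.19'},
    ]
}
print(selected_ids_from_payload(example_payload))  # ['aspirin', 'caffeine']
print(selected_ids_from_payload(None))             # []
```

An empty result corresponds to the case where the callback falls back to showing the whole dataset.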

Summary: interactive clustering

We’ve seen how to display the raw data in a heatmap, and cocluster it to visually organize information on the fly at a user-defined resolution.

Takeaways

Rule of thumb: No more than ~500K entries in the displayed heatmap.

Explanation
  • Screen capability: A screen has finite resolution, and each row/column of the heatmap needs at least one pixel to be intelligibly distinguished. So ~500K entries is the most that can be displayed if, e.g., a 1000 x 500 pixel area of the interface is devoted to the heatmap.
  • User perception: This also goes back to one of the central principles I talked about in a previous post: all computations should finish in less than a second for a human user to feel that the experience is interactive. One of the most intensive types of computations the browser performs is rendering the heatmap and its tens of thousands of individual entries, each with a different tooltip that displays upon hovering. This seems to become a bottleneck at scales around \(\sim 10^5\) entries or more. Much more than this cannot comfortably be displayed on most desktop/laptop screens, let alone mobile.
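The screen-capability arithmetic can be written out as a short sketch. The function names and the px_per_cell parameter are introduced here for illustration only; the 1000 x 500 pixel area is the example from the list above.

```python
def heatmap_entry_budget(width_px, height_px, px_per_cell=1):
    """Maximum number of cells that fit if each cell needs px_per_cell pixels per side."""
    return (width_px // px_per_cell) * (height_px // px_per_cell)

def max_rows_for_budget(n_cols, budget):
    """Largest number of rows such that rows * n_cols stays within the budget."""
    return budget // max(n_cols, 1)

budget = heatmap_entry_budget(1000, 500)
print(budget)                             # 500000
print(max_rows_for_budget(500, budget))   # 1000
print(max_rows_for_budget(2000, budget))  # 250
```

With 500 features, the budget allows 1000 rows, which matches the max_rows_allowed value used when subsampling in display_heatmap_cb.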

This turns out to be a pretty stringent requirement: even 1000 observations of 500 features each can take a while to render, as we saw in the timings above. Such limits are tied to the rendering framework, and using better heatmap renderers can give significant speedups.

But the datasets we deal with in biochemical sciences are often several orders of magnitude larger. I’ll write next about bridging that gap of scale.

Up ahead: scaling up

As the capabilities of this browser grow, we wonder: how much algorithmic functionality can we enable at interactive speeds?

  • Next up is a crucial pit stop along this journey, adding several upgrades to the browser’s abilities that allow the user to zoom into data-driven subsets.
  • After that, we’ll open up some corners of the algorithmic toolbox on these subsets, and demonstrate what is possible at these speeds.

One key thing to remember is that the process is already bottlenecked by runtime considerations. Around 10,000 points seems to be all that can be visualized at perceptually interactive speeds on a CPU-bound local machine.[1] So this sets a practical limit on the size of the datasets passed into algorithms for learning.

In the next posts, we’ll look at what algorithmic and visualization functionality we can include to make the user’s life easier.

References

Dhillon, Inderjit S. 2001. “Co-Clustering Documents and Words Using Bipartite Spectral Graph Partitioning.” In Proceedings of the Seventh ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, 269–74.
Dhillon, Inderjit S, Subramanyam Mallela, and Dharmendra S Modha. 2003. “Information-Theoretic Co-Clustering.” In Proceedings of the Ninth ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, 89–98.

Footnotes

  1. However, much more than that is possible using GPUs, with technologies like WebGL.