Introduction: data-driven gifs

Nothing beats a short, gif-length video for drawing attention to ML work on social media. Even when the work involves complex algorithms on data, a few plots of the algorithm's output across different inputs can be worth an entire abstract.

In this notebook, I'll document how I make a data-driven gif from a sequence of data-driven images, i.e., plots.

Code

Most of the code is wrapped in a single function that draws scatter plots with matplotlib.

The wonderful library celluloid strings the frames (images) together. The function below returns the animation as a matplotlib "ArtistAnimation" object.
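For readers curious what celluloid is doing, here is a rough sketch of the same idea in plain matplotlib. Each frame is a list of the artists drawn for it, and `ArtistAnimation` shows one list at a time. This is roughly the bookkeeping that celluloid's `Camera.snap()` and `Camera.animate()` automate, not celluloid's actual code; the random points and the names `frames` and `anim` are made up for the illustration.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.animation import ArtistAnimation

rng = np.random.default_rng(0)
coords = rng.normal(size=(200, 2))  # made-up 2-D points standing in for an embedding

fig, ax = plt.subplots(figsize=(4, 4))
ax.set_xticks([])
ax.set_yticks([])

frames = []  # one list of artists per frame
for it in range(5):
    # Hypothetical per-frame colors: a score that shifts with the frame index.
    colors = coords[:, 0] + 0.5 * it * coords[:, 1]
    scat = ax.scatter(coords[:, 0], coords[:, 1], c=colors, s=3)
    frames.append([scat])  # record the artists belonging to this frame

# ArtistAnimation toggles visibility so that only one frame's artists show at a time.
anim = ArtistAnimation(fig, frames, interval=50)
```

Because all frames share one set of axes, every scatter stays attached to them; the animation only changes which one is visible.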

import matplotlib.pyplot as plt
import scipy.stats
from IPython.display import HTML
from celluloid import Camera

def create_animation(coords, clrarr, niter=None, qnorm=True):
    """
    Params:
        coords: An array with each row representing an observation (point in scatterplot). First and second columns are treated as x and y coordinates.
        clrarr: One of two possibilities:
            - List of arrays, representing color values at each frame.
            - Function taking one argument, the iteration number. In this case niter must be specified.
        niter: Number of frames. Required if clrarr is a function; if None, one frame is drawn per entry of clrarr.
        qnorm: If True, color points by the ranks of the values (scipy.stats.rankdata) instead of the raw values.

    Returns:
        animation: A matplotlib ArtistAnimation object
    """
    fig, axes = plt.subplots(1, 1, figsize=(6, 6))
    plt.tight_layout()
    camera = Camera(fig)

    # Hide the ticks; only the point cloud matters.
    axes.set_xticks([])
    axes.set_yticks([])
    plt.subplots_adjust(wspace=0, hspace=0)

    if niter is not None and not callable(clrarr):
        raise TypeError("When niter is given, clrarr must be a function of the iteration number.")
    num_iterations = len(clrarr) if niter is None else niter
    for it in range(num_iterations):
        if it % 10 == 0:
            print(f"iter: [{it}/{num_iterations}]")
        # Colors for this frame: taken from the list, or computed by the function.
        new_colors = clrarr[it] if niter is None else clrarr(it)
        plot_arr = scipy.stats.rankdata(new_colors) if qnorm else new_colors
        axes.scatter(coords[:, 0], coords[:, 1], c=plot_arr, s=0.3)
        camera.snap()  # record this frame's artists (no plt.show() needed per frame)
    animation = camera.animate()
    return animation
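The `qnorm=True` default replaces the color values with their ranks, so the colormap is spread evenly over the points even when a few values dwarf the rest, which is plausible for scores that decay away from a handful of seed points. A small illustration with made-up scores:

```python
import numpy as np
from scipy.stats import rankdata

# Made-up scores: most are tiny and one is huge (a heavy-tailed spread).
scores = np.array([1e-6, 3e-4, 2e-5, 0.9, 5e-3])

ranks = rankdata(scores)  # 1 = smallest; ties would receive averaged ranks

# Raw values scaled to [0, 1]: every point but one sits at the bottom of the colormap.
raw_frac = (scores - scores.min()) / (scores.max() - scores.min())
# Ranks scaled to [0, 1]: colors are spaced evenly from lowest to highest.
rank_frac = (ranks - ranks.min()) / (ranks.max() - ranks.min())

print(np.round(raw_frac, 3))
print(rank_frac)
```

The trade-off is that ranks discard the magnitudes, so the frames show the ordering of the scores rather than their absolute sizes.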

Example: ranking on a data manifold

Single-cell data

An important application of graph methods in the life sciences is the analysis of single-cell data. A single-cell experiment can measure the biochemical activity of thousands of cells at once, yielding a relatively comprehensive sample of a wide range of cell types, including some that cannot be physically dissociated from each other.

Because extensive cell-type heterogeneity is both typical and interesting, single-cell researchers want to squeeze as much as possible out of each dataset without averaging over cell-to-cell variation. As a result, they favor analyses that make extensive use of the nearest-neighbor graph [1] between cells as a "data manifold." This makes single-cell datasets a good testbed for these methods.
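To make the "data manifold" concrete: a k-nearest-neighbor graph links each cell to its k closest cells in some feature space, and graph algorithms then work on that sparse adjacency matrix. The sketch below builds a plain, unweighted, symmetrized kNN graph on random points with SciPy. It is a simplified stand-in, not how the dataset's graph was built: `obsp['connectivities']` is where scanpy's `sc.pp.neighbors` stores its neighbor graph, and by default that function weights edges with a UMAP-style kernel rather than 0/1. The helper name `knn_graph` is hypothetical.

```python
import numpy as np
import scipy.sparse as sp
from scipy.spatial import cKDTree

def knn_graph(X, k=10):
    """Unweighted, symmetrized k-nearest-neighbor adjacency matrix (CSR)."""
    n = X.shape[0]
    # Query k+1 neighbors because each point's nearest neighbor is itself.
    _, idx = cKDTree(X).query(X, k=k + 1)
    rows = np.repeat(np.arange(n), k)
    cols = idx[:, 1:].ravel()  # drop the self-match in column 0
    A = sp.csr_matrix((np.ones(n * k), (rows, cols)), shape=(n, n))
    # Keep an edge if either endpoint lists the other as a neighbor.
    return A.maximum(A.T).tocsr()

rng = np.random.default_rng(0)
X = rng.normal(size=(300, 5))  # made-up data: 300 "cells", 5 features
A = knn_graph(X, k=10)
```

Symmetrizing with `maximum` makes every cell have at least k neighbors, and some have more, since a cell can be in another cell's neighbor list without the reverse holding.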

I'll demonstrate label propagation on single-cell gene expression data from the GTEx consortium, as I've written about previously.

import sys, numpy as np, time, scipy
import matplotlib.pyplot as plt
%matplotlib inline

import anndata, scanpy as sc

print('Packages imported.')


# Download the data once (uncomment to run):
# !curl https://storage.googleapis.com/gtex_analysis_v9/snrna_seq_data/GTEx_8_tissues_snRNAseq_atlas_071421.public_obs.h5ad -o GTEx_8_tissues_snRNAseq_atlas_071421.public_obs.h5ad

gtex9_all_adta_name = 'GTEx_8_tissues_snRNAseq_atlas_071421.public_obs.h5ad'
fname = gtex9_all_adta_name
adta = sc.read(fname)
Packages imported.

Import algorithm

I'll use a brief implementation of the algorithm from my earlier post; see that post for a description of how it works.
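In case the earlier post isn't handy, here is a minimal sketch of one standard iterative formulation of label propagation, the "local and global consistency" update of Zhou et al.: repeatedly set F ← αSF + (1 − α)Y, where S = D^(-1/2) W D^(-1/2) is the symmetrically normalized graph and Y marks the seeds. This is an assumption about the general technique, not the code in `label_propagation.py`, whose internals, defaults, and stopping rule may differ; `propagate_scores` and its parameters are hypothetical. Returning every iterate makes the output usable directly as a list of animation frames.

```python
import numpy as np
import scipy.sparse as sp

def propagate_scores(W, seeds, alpha=0.99, tol=1e-2, max_iter=200):
    """Hypothetical sketch: iterate F <- alpha*S@F + (1-alpha)*Y, keeping every iterate.

    W: symmetric (n x n) nonnegative adjacency matrix (dense or sparse).
    seeds: indices of the query points.
    Returns a list of length-n score vectors, one per iteration.
    """
    W = sp.csr_matrix(W, dtype=float)
    deg = np.asarray(W.sum(axis=1)).ravel()
    # D^{-1/2}, leaving isolated nodes (degree 0) at zero.
    d_inv_sqrt = np.zeros_like(deg)
    d_inv_sqrt[deg > 0] = 1.0 / np.sqrt(deg[deg > 0])
    D = sp.diags(d_inv_sqrt)
    S = D @ W @ D  # symmetrically normalized adjacency

    Y = np.zeros(W.shape[0])
    Y[seeds] = 1.0
    F = Y.copy()
    history = [F.copy()]
    for _ in range(max_iter):
        F_new = alpha * (S @ F) + (1 - alpha) * Y
        change = np.linalg.norm(F_new - F)
        F = F_new
        history.append(F.copy())
        if change < tol:  # stop once an update barely moves the scores
            break
    return history

# Usage on a 5-node path graph 0-1-2-3-4, seeded at node 0.
path = sp.diags([np.ones(4), np.ones(4)], [-1, 1], shape=(5, 5))
iterates = propagate_scores(path, seeds=[0], alpha=0.5)
```

In this sketch, α close to 1 lets scores spread farther from the seeds but slows convergence, since each update shrinks the remaining error by a factor of at most α.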

import requests
url = 'https://raw.githubusercontent.com/b-akshay/blog-tools/main/label_propagation.py'
r = requests.get(url)
r.raise_for_status()  # fail loudly if the download did not succeed

# Save the file under the module name we want to import from.
with open('label_propagation.py', 'w') as f:
    f.write(r.text)

# Now the module can be imported from the working directory.
from label_propagation import *

Run algorithm and generate animation

We'll pick two points as the query and highlight how the ranks diffuse outward from these two seed points.
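The seed setup is easy to check at small scale. This sketch repeats it on a made-up population of 10 points with hypothetical seed indices 0 and 7, to show the shapes involved: `carr` is an n×1 column with ones at the seeds, and `np.ravel(carr == 1)` flattens it into a length-n boolean mask of the query points.

```python
import numpy as np

n_points = 10       # made-up size; the real dataset has far more cells
seed_idx = [0, 7]   # hypothetical seed positions

carr = np.zeros((n_points, 1))
carr[seed_idx, 0] = 1  # a single column with 1 at each seed point

seed_mask = np.ravel(carr == 1)  # shape (n_points,), True at the seeds
print(carr.shape, seed_mask.shape, np.flatnonzero(seed_mask))
```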

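# Seed vector: one column, with 1 at the two query cells.
carr = np.zeros((adta.shape[0], 1))
carr[0, 0] = 1
carr[100000, 0] = 1  # assumes the dataset has more than 100,000 cells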
itime = time.time()
cumuvals = label_propagation(carr, adta.obsp['connectivities'], np.ravel(carr == 1), param_alpha=0.999, method='iterative')
print("Time taken for label propagation: {}".format(time.time() - itime))
7.453010006368602
2.8609513970753886
1.8050233011499106
1.1854938842580347
0.7845198224322854
0.5154911327134338
0.3414808632840092
0.22661226968256298
0.15236782319569356
0.10311849590080815
0.07074001697455569
0.0489321664681581
0.034294447731728184
0.024280293350557564
0.017433353971035893
0.012681724358227134
0.009377004553010345
Time taken for label propagation: 0.15927386283874512

itime = time.time()

coords = adta.obsm['X_umap']
animation = create_animation(coords, cumuvals)

print("Time taken to generate frames: {}".format(time.time() - itime))

Render animation

Now we can render the animation inline in the browser. There are other ways to do this, but they ran into more painful dependency issues when I tried them.
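One way to sidestep the dependency problem is to check whether ffmpeg is available and fall back to `to_jshtml()`, which needs no external encoder. A minimal sketch assuming only matplotlib; `render_html` and the tiny demo animation are hypothetical.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for the sketch
import matplotlib.pyplot as plt
from matplotlib import animation as mpl_animation

def render_html(anim):
    """Return HTML for `anim`: an HTML5 <video> if ffmpeg is available, else JS frames."""
    if mpl_animation.writers.is_available("ffmpeg"):
        return anim.to_html5_video()  # encodes the frames with ffmpeg
    return anim.to_jshtml()           # embeds PNG frames plus a JavaScript player

# Tiny demo animation: three frames, each a single line.
fig, ax = plt.subplots(figsize=(2, 2))
demo_frames = [[ax.plot([0, 1], [0, i], color="C0")[0]] for i in range(3)]
demo_anim = mpl_animation.ArtistAnimation(fig, demo_frames, interval=200)
html = render_html(demo_anim)
```

In a notebook, wrapping the result as `HTML(html)` displays it either way.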

itime = time.time()

# to_html5_video() is an alternative, but it requires ffmpeg.
#HTML(animation.to_html5_video())
html_anim = animation.to_jshtml()  # build the HTML once and reuse it below

print("Time taken to produce html animation: {}".format(time.time() - itime))
Time taken to produce html animation: 57.97186589241028
HTML(html_anim)

Save animation

If needed, the animation can also be saved to disk. The easiest way to save it as an mp4 is to set up ffmpeg for matplotlib's ffmpeg MovieWriter and then run the following.

animation.save("ranking_search.mp4", dpi=200, fps=20)
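Since the goal here is a gif, note that matplotlib can also write gifs with its Pillow-based `PillowWriter`, which does not need ffmpeg (Pillow is already a matplotlib dependency in recent versions). A sketch on a small made-up animation; for the real one, the analogous call would be `animation.save("ranking_search.gif", writer=PillowWriter(fps=20))`, where the file name is a hypothetical choice.

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # headless backend for the sketch
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.animation import ArtistAnimation, PillowWriter

rng = np.random.default_rng(1)
pts = rng.normal(size=(100, 2))  # made-up points

fig, ax = plt.subplots(figsize=(3, 3))
frames = []
for it in range(4):
    # Hypothetical per-frame colors; each frame adds one scatter artist.
    scat = ax.scatter(pts[:, 0], pts[:, 1], c=np.roll(pts[:, 0], it), s=4)
    frames.append([scat])
anim = ArtistAnimation(fig, frames, interval=50)

# Write the gif to a temporary directory.
out_path = os.path.join(tempfile.mkdtemp(), "ranking_demo.gif")
anim.save(out_path, writer=PillowWriter(fps=20), dpi=80)
```

Gifs of scatter plots with many points can get large, so a lower `dpi` than the mp4 call above may be worth trying.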

References