Nothing beats a short gif-length video for driving awareness of ML work on social media. Even when the work involves complex algorithms on data, a few plots of the algorithm’s output ranging over different inputs can be worth an entire abstract.
Here’s a short way to make a data-driven gif from a sequence of data-driven plots.
Code
Most of the code is wrapped in a single function, which makes scatter plots with matplotlib.
The wonderful library celluloid is used to string together frames (images). The function below returns the animation as a matplotlib “ArtistAnimation” object.
CODE
import matplotlib.pyplot as plt
import scipy.stats  # needed for rankdata below
from IPython.display import HTML
from celluloid import Camera


def create_animation(coords, clrarr, niter=None, qnorm=True):
    """
    Params:
        coords: An array with each row representing an observation (point in scatterplot).
            First and second columns are treated as x and y coordinates.
        clrarr: One of two possibilities:
            - List of arrays, representing color values at each frame.
            - Function taking one argument, the iteration number. In this case niter must be specified.
        niter: Number of frames. Required when clrarr is a function.
        qnorm: If True, color points by the rank of their values instead of the raw values.

    Returns:
        animation: A matplotlib ArtistAnimation object
    """
    fig, axes = plt.subplots(1, 1, figsize=(6, 6))
    plt.tight_layout()
    camera = Camera(fig)
    axes.axes.xaxis.set_ticks([])
    axes.axes.yaxis.set_ticks([])
    plt.subplots_adjust(wspace=0, hspace=0)
    # fig.subplots_adjust(left=None, bottom=None, right=None, top=None, wspace=None, hspace=None)

    num_iterations = len(clrarr) if niter is None else niter
    for it in range(num_iterations):
        if it % 10 == 0:
            print(f"iter: [{it}/{num_iterations}]")
        if niter is None:
            new_colors = clrarr[it]
        else:
            # clrarr should be a function.
            # TBD: If clrarr is not a function, throw an exception
            # If clrarr is a function, assume it follows the spec above. Then:
            new_colors = clrarr(it)
        plot_arr = scipy.stats.rankdata(new_colors) if qnorm else new_colors
        axes.scatter(coords[:, 0], coords[:, 1], c=plot_arr, s=0.3)
        camera.snap()  # capture the current state of the figure as one frame
    plt.show()
    animation = camera.animate()
    return animation
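The `qnorm` option in `create_animation` colors points by the rank of their values rather than by the values themselves, which should keep a handful of extreme values from dominating the color scale. As a rough illustration of what `scipy.stats.rankdata` does with its default tie handling (tied values share the average of their ranks), here is a dependency-free sketch. `average_ranks` is a hypothetical helper for illustration, not part of scipy.

```python
def average_ranks(values):
    """Rank values from 1 to n; tied values get the mean of their ranks."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    start = 0
    while start < len(order):
        # Extend `end` over the run of values tied with the one at `start`.
        end = start
        while end + 1 < len(order) and values[order[end + 1]] == values[order[start]]:
            end += 1
        mean_rank = (start + end) / 2 + 1  # ranks are 1-based
        for j in range(start, end + 1):
            ranks[order[j]] = mean_rank
        start = end + 1
    return ranks


# Heavily skewed values become evenly spaced ranks.
skewed = [0.001, 0.002, 0.5, 0.002, 1000.0]
print(average_ranks(skewed))  # [1.0, 2.5, 4.0, 2.5, 5.0]
```

Passing ranks as the `c` argument of the scatter plot spreads the colormap evenly over the points, whatever the spread of the raw values.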
Example: ranking on a data manifold
I’ll demonstrate label propagation on single-cell gene expression data from the GTEx consortium, as I’ve written about previously.
CODE
import sys, numpy as np, time, scipy
import matplotlib.pyplot as plt
%matplotlib inline
import anndata, scanpy as sc
print('Packages imported.')

# !curl https://storage.googleapis.com/gtex_analysis_v9/snrna_seq_data/GTEx_8_tissues_snRNAseq_atlas_071421.public_obs.h5ad -o GTEx_8_tissues_snRNAseq_atlas_071421.public_obs.h5ad

gtex9_all_adta_name = 'GTEx_8_tissues_snRNAseq_atlas_071421.public_obs.h5ad'
fname = gtex9_all_adta_name
adta = sc.read(fname)
Packages imported.
Ranking algorithm
I’ll use a brief implementation of the algorithm written in my earlier post. Please see that post for the description.
CODE
import requests

url = 'https://raw.githubusercontent.com/b-akshay/blog-tools/main/label_propagation.py'
r = requests.get(url)

# make sure your filename is the same as how you want to import
with open('label_propagation.py', 'w') as f:
    f.write(r.text)

# now we can import
from label_propagation import *
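For readers who want a feel for this kind of algorithm without opening the earlier post, here is a toy sketch of one common iterative form of label propagation, F ← αSF + (1 − α)Y, where S is the symmetrically normalized adjacency matrix and Y marks the seed points. This is an assumption about the general technique, not a copy of `label_propagation.py`, whose normalization and stopping rule may differ. `propagate_labels` and the 5-node path graph are made up for illustration.

```python
def propagate_labels(adjacency, seeds, alpha=0.9, n_iter=5):
    """Toy label propagation on a small dense graph (hypothetical helper).

    Repeats F <- alpha * S F + (1 - alpha) * Y, where S is the adjacency
    matrix normalized as D^{-1/2} W D^{-1/2} and Y is the seed vector.
    Returns one score vector per iteration, i.e. one per animation frame.
    """
    n = len(adjacency)
    degree = [sum(row) for row in adjacency]
    # Symmetric normalization; entries with no edge stay at zero.
    norm = [
        [
            adjacency[i][j] / (degree[i] * degree[j]) ** 0.5 if adjacency[i][j] else 0.0
            for j in range(n)
        ]
        for i in range(n)
    ]
    scores = list(seeds)
    history = []
    for _ in range(n_iter):
        scores = [
            alpha * sum(norm[i][j] * scores[j] for j in range(n)) + (1 - alpha) * seeds[i]
            for i in range(n)
        ]
        history.append(scores)
    return history


# Path graph 0-1-2-3-4 with a single seed at node 0.
path = [
    [0, 1, 0, 0, 0],
    [1, 0, 1, 0, 0],
    [0, 1, 0, 1, 0],
    [0, 0, 1, 0, 1],
    [0, 0, 0, 1, 0],
]
frames = propagate_labels(path, seeds=[1, 0, 0, 0, 0])
for step, scores in enumerate(frames):
    print(step, [round(s, 3) for s in scores])
```

On this toy graph, each iteration pushes nonzero scores one hop further from the seed, which is the kind of diffusion the animation is meant to show.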
Run algorithm and generate animation
We’ll pick two points as the query and highlight how ranks diffuse outward from these two seed points.
CODE
carr = np.zeros((adta.shape[0], 1))
# Arbitrary points as the initial query
carr[0, 0] = 1
carr[100000, 0] = 1

itime = time.time()
cumuvals = label_propagation(
    carr,
    adta.obsp['connectivities'],
    np.ravel(carr == 1),
    param_alpha=0.999,
    method='iterative'
)
print("Time taken for label propagation: {}".format(time.time() - itime))
Now we generate the frames of the animation.
CODE
itime = time.time()
coords = adta.obsm['X_umap']
animation = create_animation(coords, cumuvals)
print("Time taken to generate frames: {}".format(time.time() - itime))
iter: [0/18]
iter: [10/18]
Time taken to generate frames: 2.497175931930542
Render animation
Now we can render this inline in the browser. There are other ways to do it, but they ran into more painful dependency issues when I tried them.
CODE
itime = time.time()
#HTML(animation.to_html5_video())
HTML(animation.to_jshtml())
print("Time taken to produce html animation: {}".format(time.time() - itime))
Time taken to produce html animation: 57.97186589241028
CODE
HTML(animation.to_jshtml())
Save animation
If necessary, the animation can be saved. The easiest way to save it as an mp4 is to set up matplotlib's ffmpeg MovieWriter and pass it to the animation's save method.
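As a small sketch of the ffmpeg route: matplotlib's `Animation.save` accepts `writer='ffmpeg'` for mp4 output and `writer='pillow'` for GIFs, and the ffmpeg writer needs the `ffmpeg` executable to be installed. The hypothetical helper below (`pick_writer` is not a matplotlib function) checks for the executable on the PATH and falls back to a GIF otherwise. It assumes ffmpeg is installed under its default name.

```python
import shutil


def pick_writer(basename):
    """Return (filename, writer_name), preferring mp4 when ffmpeg is on the PATH."""
    if shutil.which("ffmpeg") is not None:
        return f"{basename}.mp4", "ffmpeg"
    # No ffmpeg executable found: fall back to a GIF written with Pillow.
    return f"{basename}.gif", "pillow"


filename, writer = pick_writer("label_propagation")
print(filename, writer)
# With the object returned by create_animation, saving would then be:
# animation.save(filename, writer=writer)
```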
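GIFs can be assembled from saved images with the imageio package (Klein et al. 2023), as described at https://towardsdatascience.com/how-to-create-a-gif-from-matplotlib-plots-in-python-6bec6c0c952c. The example below takes a different route and saves a GIF directly through matplotlib's pillow writer.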
CODE
import numpy as np
import matplotlib.pyplot as plt
from celluloid import Camera
from scipy.special import comb


def probability_at_least_half_heads(k, p):
    # Binomial upper tail: sum over i from k/2 to k heads
    return sum(comb(k, i) * p**i * (1-p)**(k-i) for i in range(int(0.5*k), k+1))


# Define the range of k and biases
k_values = [2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]  # Number of coin flips
logbiases = np.linspace(-5, 0, 100)  # Log-biases from -5 to 0, i.e. biases from about 0.0067 to 1

fig, ax = plt.subplots()
camera = Camera(fig)

for i in range(len(k_values)):
    # Each frame redraws the curves for all k up to the current one
    for k in k_values[:i+1]:
        probabilities = [probability_at_least_half_heads(k, np.exp(p)) for p in logbiases]
        ax.plot(logbiases, probabilities, label=f'k={k}' if k == k_values[i] else '_nolegend_')
    ax.legend(loc='upper left')
    ax.set_xlabel('Negative loss (log Pr(single coin head))')
    ax.set_ylabel('Probability of event')
    ax.set_title('Chance of at least half heads in k coin flips')
    camera.snap()

animation = camera.animate()
animation.save('coin_flip_probability.gif', writer='pillow')
No artists with labels found to put in legend. Note that artists whose label start with an underscore are ignored when legend() is called with no argument.
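For reference, the quantity computed by `probability_at_least_half_heads(k, p)` is the upper tail of a binomial distribution, written out here to match the summation range in the code (the lower limit `int(0.5*k)` is ⌊k/2⌋, which equals k/2 for the even values of k used above):

```latex
\Pr\left[X \ge \lfloor k/2 \rfloor\right]
  = \sum_{i=\lfloor k/2 \rfloor}^{k} \binom{k}{i}\, p^{i} (1-p)^{k-i},
  \qquad X \sim \mathrm{Binomial}(k, p)
```

Each curve in the animation plots this probability against log p for one value of k.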
References
Klein, Almar, Sebastian Wallkötter, Steven Silvester, Anthony Tanbakuchi, actions-user, Paul Müller, Juan Nunez-Iglesias, et al. 2023. Imageio/Imageio: V2.31.1 (version v2.31.1). Zenodo. https://doi.org/10.5281/zenodo.8025955.