Topic Modeling with MIRA+CODAL#
What is a topic model?#
Topic models, like Latent Dirichlet Allocation (LDA), have traditionally been used to decompose a corpus of text into topics - or themes - composed of words that often appear together in documents. Documents, in turn, are modeled as a mixture of topics based on the words they contain.
MIRA extends these ideas to single-cell genomics data, where topics are groups of genes that are co-expressed or cis-regulatory elements that are co-accessible, and cells are a mixture of these regulatory modules. The topics can be used for enrichment and pathway analysis, while the cells’ topic mixtures can be used to embed the cells in an informative, interpretable latent space.
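To make the "mixture of topics" idea concrete, here is a minimal sketch in the classic LDA sense. The gene names, topic names, and probabilities are all made up for illustration, and the simple weighted sum is an assumption used for teaching purposes; it is not meant to reproduce MIRA's own model.

```python
# Toy illustration only: two hypothetical topics over four made-up genes.
# Each topic is a probability distribution over genes.
topics = {
    "topic_0": {"geneA": 0.6, "geneB": 0.3, "geneC": 0.1, "geneD": 0.0},
    "topic_1": {"geneA": 0.0, "geneB": 0.1, "geneC": 0.4, "geneD": 0.5},
}


def expected_expression(mixture, topics):
    """Blend topic-level gene distributions using a cell's topic mixture."""
    genes = next(iter(topics.values())).keys()
    return {
        gene: sum(weight * topics[name][gene] for name, weight in mixture.items())
        for gene in genes
    }


# A hypothetical cell that is 75% topic_0 and 25% topic_1.
cell_mixture = {"topic_0": 0.75, "topic_1": 0.25}
cell_profile = expected_expression(cell_mixture, topics)
print(cell_profile)
```

Because both the topics and the mixture are probability distributions, the blended profile is one as well, which is why the topic mixture can serve as a compact description of the cell.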
Topic modeling of batched single-cell data is challenging because these models cannot typically distinguish between biological and technical effects of the assay. CODAL (COvariate Disentangling Augmented Loss) uses a novel mutual information regularization technique to explicitly disentangle these two sources of variation.
Tuning and training#
In this tutorial, we will cover tuning and training MIRA topic models (which now run the CODAL algorithm underneath) to find the best hyperparameters for a given dataset. Single-cell datasets vary widely in terms of quality and complexity, and as such, no set of hyperparameters will ensure an optimal model in all cases. The most important parameter to optimize for a given dataset is the number of topics captured by the model, where each topic represents a unit of covarying genes or cis-regulatory regions. The number of topics determines the quality of the embedding manifold and the veracity of the topics’ functional enrichments, so we recommend rigorous tuning to provide an accurate and informative analysis.
MIRA uses a memory-efficient streaming minibatch stochastic gradient descent algorithm for parameter inference. This enables parallelized and fast Bayesian optimization of key hyperparameters of the model. Hyperparameter tuning proceeds in three stages:
Feature selection and data cleaning
Model instantiation
Hyperparameter optimization
Let’s start by importing some packages:
[1]:
import mira
import anndata
import scanpy as sc
import numpy as np
import matplotlib.pyplot as plt
plt.rcParams.update({'font.size': 14})
INFO:numexpr.utils:Note: NumExpr detected 12 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
INFO:numexpr.utils:NumExpr defaulting to 8 threads.
And we’ll load some data. To make the tutorial quick, I’ve imported a synthetic single-cell dataset used as part of benchmarking CODAL’s ability to disentangle technical and biological effects. This synthetic gene expression dataset contains two batches with 2000 cells each. One batch is “WT”, and the other is “KO”, or perturbed. Importantly, these batches have different cell-state distributions.
[2]:
mira.datasets.CodalFrankencellTutorial()
ko = anndata.read_h5ad('mira-datasets/CODAL_tutorial/perturbation.h5ad')
wt = anndata.read_h5ad('mira-datasets/CODAL_tutorial/wild-type.h5ad')
data = anndata.concat({'ko' : ko, 'wt' : wt}, label='batch', index_unique=':') # combine into one AnnData object
INFO:mira.datasets.datasets:Dataset already on disk.
INFO:mira.datasets.datasets:Dataset contents:
* mira-datasets/CODAL_tutorial
* perturbation.h5ad
* wild-type.h5ad
If you’re only working with one batch of data, you can proceed with this same tutorial by omitting the above concatenation step.
A note on data modalities#
All of MIRA’s models have the same API for instantiating, tuning, and training, regardless of the data mode. This tutorial covers tuning of an expression topic model. When tuning an accessibility model using scATAC-seq data, make sure that you are training on a GPU to maximize speed. You can check that a GPU is available by using:
import torch
assert torch.cuda.is_available()
Step 1: Feature selection and Preprocessing - Gene Expression#
First, we must perform feature selection to find highly variable genes. We recommend finding highly variable genes over the concatenation of all batches you intend to merge, which is easy to accomplish using Scanpy. First, filter very rare genes, and freeze the raw counts:
[3]:
sc.pp.filter_genes(data, min_cells=15)
rawdata = data.X.copy()
Normalize the read depths of each cell, then logarithmize the data:
[4]:
sc.pp.normalize_total(data, target_sum=1e4)
sc.pp.log1p(data)
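As a small worked example of what these two steps do to a single cell, the sketch below rescales a cell's counts to a common total and then applies log(1 + x). The helper name and the toy counts are hypothetical; the real work in this tutorial is done by the Scanpy calls above.

```python
import math


def normalize_and_log(counts, target_sum=1e4):
    """Scale one cell's counts to sum to target_sum, then apply log(1 + x)."""
    total = sum(counts)
    if total == 0:
        # An empty cell has nothing to rescale.
        return [0.0 for _ in counts]
    scale = target_sum / total
    return [math.log1p(c * scale) for c in counts]


# Two made-up cells with the same expression proportions but different depths.
shallow_cell = [1, 3, 0, 6]      # 10 reads in total
deep_cell = [10, 30, 0, 60]      # 100 reads in total

shallow_norm = normalize_and_log(shallow_cell)
deep_norm = normalize_and_log(deep_cell)
print(shallow_norm)
print(deep_norm)
```

After normalization the two cells have the same values, so differences in sequencing depth no longer dominate the comparison between cells.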
And calculate highly variable genes. This set of genes will be our “exogenous” genes, or those we include in our statistical model of the data. As a rule of thumb, using the top 2500-5000 highly variable genes works well for describing the data manifold and finding interesting enrichments in topics.
[5]:
sc.pp.highly_variable_genes(data, min_disp = 0.5)
Finally, restore the raw counts to the counts layer.
[6]:
data.layers['counts'] = rawdata
Let’s see what this dataset looks like in a UMAP computed from its PCA embedding:
[7]:
sc.tl.pca(data)
sc.pp.neighbors(data, n_pcs=6)
sc.tl.umap(data, min_dist = 0.2, negative_sample_rate=0.2)
sc.pl.umap(data, color = 'batch', palette= ['#8f7eadff', '#c1e1e2ff'], frameon=False)
... storing 'edge' as categorical
The two batches show strong separation by technical effects, but also some distinct structural similarities. Let’s see what happens with CODAL integration.
ATAC preprocessing
In this tutorial, we covered the basics of preprocessing GEX data for modeling using MIRA. For a guide on preprocessing scATAC-seq data, see the next tutorial.
Step 2: Model instantiation#
Next, we will instantiate an expression topic model. The hyperparameters will be tuned in the following stage, so all we have to worry about is telling the topic model how to access some information in our AnnData object.
In particular, and at a minimum, we must tell the model:
feature_type: what type of features we are working with (either “expression” or “accessibility”)
highly_variable_key: which .var key to find our highly variable genes
counts_layer: which layer to get the raw counts from
categorical_covariates, continuous_covariates: technical variables influencing the generative process of the data. For example, a categorical technical factor may be the cells’ batch of origin, as shown here. A continuous technical factor might be the % of mitochondrial reads. For unbatched data, ignore these parameters.
[8]:
model = mira.topics.make_model(
    data.n_obs, data.n_vars, # helps MIRA choose reasonable values for some hyperparameters which are not tuned.
    feature_type = 'expression',
    highly_variable_key='highly_variable',
    counts_layer='counts',
    categorical_covariates='batch'
)
Next, we have to set minimum and maximum bounds on the learning rates that work for the model. The learning rate will be annealed during training between these values to ensure the best fit. Run the learning rate range test below to collect data on which learning rates work given the model setup:
[9]:
model.get_learning_rate_bounds(data)
INFO:mira.adata_interface.topic_model:Predicting expression from genes from col: highly_variable
WARNING:mira.topic_model.base:Cuda unavailable. Will not use GPU speedup while training.
INFO:mira.topic_model.base:Set learning rates to: (0.001102177009946407, 0.1901675701910509)
[9]:
(0.001102177009946407, 0.1901675701910509)
Now, use the set_learning_rates function to bound the part of the loss curve which contains the steepest decrease.
If you push the upper bound too high, the model is likely to experience gradient overflows. The upper bound works best at or before the point where the slope starts to level off.
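The idea of bracketing the steepest part of the loss curve can be sketched as follows. This is only an illustration under stated assumptions: the helper `steepest_descent_window` is hypothetical, the learning rates and losses are made up, and this is not how MIRA itself chooses its bounds. In practice, the plot produced by the model remains the reference.

```python
import math


def steepest_descent_window(learning_rates, losses):
    """Return (lower, upper) learning rates bracketing the steepest loss decrease.

    Slopes are measured against log10(learning rate), since range tests
    sweep the learning rate on a log scale.
    """
    slopes = [
        (losses[i + 1] - losses[i])
        / (math.log10(learning_rates[i + 1]) - math.log10(learning_rates[i]))
        for i in range(len(losses) - 1)
    ]
    # Index of the segment where the loss falls fastest.
    steepest = min(range(len(slopes)), key=slopes.__getitem__)
    # Extend the window in both directions while the loss keeps decreasing.
    start = steepest
    while start > 0 and slopes[start - 1] < 0:
        start -= 1
    end = steepest
    while end + 1 < len(slopes) and slopes[end + 1] < 0:
        end += 1
    return learning_rates[start], learning_rates[end + 1]


# Made-up range-test results: flat, then decreasing, then diverging.
toy_lrs = [1e-4, 1e-3, 1e-2, 1e-1, 1.0]
toy_losses = [900.0, 901.0, 820.0, 740.0, 1200.0]
print(steepest_descent_window(toy_lrs, toy_losses))
```

The upper value returned here is the last point before the loss turns upward, which matches the advice above to stop at or before the point where the curve levels off.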
[11]:
model.set_learning_rates(1e-3, 0.25) # for larger datasets, the default of 1e-3, 0.1 usually works well.
model.plot_learning_rate_bounds(figsize=(7,3))
[11]:
<AxesSubplot:xlabel='Learning Rate', ylabel='Loss'>
Step 3: Hyperparameter Optimization#
We offer two methods for choosing the number of topics to represent a dataset. The first, gradient-based topic selection using a Dirichlet Process CODAL model, is faster and works better for larger datasets (>= 10,000 cells).
The second method, Bayesian optimization, takes far more resources, but can be parallelized for speed and additionally tunes the regularization of the model, which can produce better results, especially for smaller datasets (<= 10,000 cells). If you have the time/patience/resources, we suggest using the Bayesian optimization method, which we used for all of our benchmarking tests.
Method 1: Gradient based#
We’ll start with the Gradient-based method. Pass the model and the data:
[12]:
topic_contributions = mira.topics.gradient_tune(model, data)
This method works by instantiating a special version of the CODAL model with far too many topics, which are gradually pruned when they are not needed to describe the data. The function returns the maximum contribution of each topic to any cell in the dataset. The predicted number of topics is given by the elbow of the maximum contribution curve, minus 1. A rule of thumb is that the last valid topic to include in the model is followed by a drop-off, after which all subsequent topics hover between 0 and 0.05 maximum contribution.
This dropoff is particularly steep. For another example, see the next tutorial.
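The rule of thumb above can be written down as a short sketch. The helper name, the 0.05 cutoff as a hard threshold, and the toy contribution values are assumptions for illustration; they are not output from this dataset, and reading the plot by eye is still the recommended check.

```python
def count_needed_topics(max_contributions, floor=0.05):
    """Count topics that come before the near-zero tail.

    Assumes topics are ordered as in the contribution plot, so that
    unneeded topics form a trailing run with values at or below `floor`.
    """
    num_topics = len(max_contributions)
    # Walk backwards, dropping topics while they sit in the near-zero tail.
    while num_topics > 0 and max_contributions[num_topics - 1] <= floor:
        num_topics -= 1
    return num_topics


# Made-up maximum contributions: six useful topics, then a flat tail.
toy_contributions = [0.98, 0.95, 0.91, 0.88, 0.80, 0.72, 0.04, 0.03, 0.01, 0.02]
print(count_needed_topics(toy_contributions))  # prints 6
```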
[15]:
NUM_TOPICS = 6
mira.pl.plot_topic_contributions(topic_contributions, NUM_TOPICS)
[15]:
<AxesSubplot:xlabel='Topic Number', ylabel='Max contribution'>
You can then train your final topic model like so:
[15]:
model = model.set_params(num_topics = NUM_TOPICS).fit(data)
INFO:mira.adata_interface.topic_model:Predicting expression from genes from col: highly_variable
WARNING:mira.topic_model.base:Cuda unavailable. Will not use GPU speedup while training.
INFO:mira.topic_model.base:Moving model to device: cpu
(Optional) Method 2: Bayesian Optimization#
The Bayesian optimization method is much more comprehensive, jointly tuning the number of topics and the regularization of the model. First, instantiate a BayesianTuner object. This class takes as arguments:
The model
save_name, which allows us to resume or reload a training session. This is a unique filepath at which the tuner will save its progress
n_jobs, the number of parallel processes to run
min_topics and max_topics, bounds on the number of topics which may be contained in a dataset
In single-core mode, tuning is deterministic (if nothing changes/no trials drop due to environmental reasons). Multicore mode is also deterministic, barring race conditions. If trials finish out of order due to environmental differences, the tuning will not reproduce exactly.
Notes on n_jobs#
During the fitting stage, MIRA caches the dataset to disk, then streams minibatches into memory for gradient descent steps. This means each training instance has a very small memory footprint which is irrespective of the size of the dataset. Therefore, one can conservatively allocate one process per ~1 GB of available memory.
However, parallelization requires that the tuner has a backend database to coordinate all of those processes. By default, MIRA uses an SQLite table which requires no setup, but this only works up to 5 concurrent jobs (n_jobs=5). For more processes, you will have to use the Redis backend database. To do this, start a Redis server running in the background - this can be configured however you like. Then, pass the argument storage = mira.topics.Redis() to the tuner object.
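The two constraints described in these notes, roughly one process per GB of available memory and at most 5 jobs with the default SQLite backend, can be combined in a small sketch. The helper `suggest_n_jobs` is hypothetical and is not part of MIRA.

```python
SQLITE_MAX_JOBS = 5  # limit stated above for the default SQLite backend


def suggest_n_jobs(available_memory_gb, use_redis=False):
    """Allocate roughly one process per GB, capped at 5 unless Redis is used."""
    by_memory = max(1, int(available_memory_gb))
    return by_memory if use_redis else min(by_memory, SQLITE_MAX_JOBS)


print(suggest_n_jobs(16))                  # memory allows 16, SQLite caps it
print(suggest_n_jobs(16, use_redis=True))  # Redis backend lifts the cap
```

The number of CPU cores is another practical limit that this sketch does not model.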
Notes on min_topics, max_topics#
Setting reasonable bounds on the range of topics one might encounter can help speed up convergence of the tuner. For example, typical PBMC datasets contain 10-15 topics, bone marrow differentiation contains 20-30 topics, and the embryonic differentiation dataset analyzed in our manuscript contained 70-80 topics. One option is to get a quick estimate of how many cell states are in your dataset via PCA or clustering, then provide bounds which overlap this estimate. Alternatively, you can use bounds around an estimate from Method 1 above. Worst case, one can provide very permissive bounds, though this will make the search take longer before converging.
[9]:
tuner = mira.topics.BayesianTuner(
    model = model,
    n_jobs=5,
    save_name = 'tutorial/0',
    #### IMPORTANT
    min_topics = 3, max_topics = 20, # tailor for your dataset!!!!
    #### See "Notes on min_topics, max_topics" above
    #storage = mira.topics.Redis() # if using REDIS backend for more (>5) processes
)
After instantiating the tuner, call the fit function and provide your dataset:
[15]:
tuner.fit(data)
Trials finished: 43 | Best trial: 17 | Best score: 7.2553e+02
Press ctrl+C,ctrl+C or esc,I+I,I+I in Jupyter notebook to stop early.
Tensorboard logidr: runs/tutorial/0
#Topics | #Trials
3 | ■ ■
4 | ■
6 | ■ ■ ■ ■ ■
7 | ■ ■ ■ ■ ■ ■ ■ ■
8 | ■ ■ ■
9 | ■
10 | ■ ■
12 | ■
13 | ■
16 | ■ ■
17 | ■
18 | ■ ■
19 | ■ ■
20 | ■ ■
Trial | Result (● = best so far) | Params
#0 | | pruned at step: 8 | {'decoder_dropout': 0.1120, 'num_topics': 16}
#1 | | pruned at step: 16 | {'decoder_dropout': 0.1397, 'num_topics': 12}
#2 | | pruned at step: 8 | {'decoder_dropout': 0.0583, 'num_topics': 19}
#3 | ● | completed, score: 7.2636e+02 | {'decoder_dropout': 0.1389, 'num_topics': 6}
#4 | | pruned at step: 8 | {'decoder_dropout': 0.0515, 'num_topics': 16}
#5 | | pruned at step: 8 | {'decoder_dropout': 0.0794, 'num_topics': 17}
#6 | | pruned at step: 16 | {'decoder_dropout': 0.0547, 'num_topics': 18}
#7 | ● | completed, score: 7.2566e+02 | {'decoder_dropout': 0.1304, 'num_topics': 6}
#8 | | pruned at step: 8 | {'decoder_dropout': 0.0554, 'num_topics': 20}
#9 | | pruned at step: 8 | {'decoder_dropout': 0.0738, 'num_topics': 20}
#10 | | pruned at step: 16 | {'decoder_dropout': 0.1297, 'num_topics': 13}
#11 | | pruned at step: 8 | {'decoder_dropout': 0.0802, 'num_topics': 18}
#12 | | pruned at step: 8 | {'decoder_dropout': 0.0809, 'num_topics': 19}
#13 | | completed, score: 7.2571e+02 | {'decoder_dropout': 0.0754, 'num_topics': 8}
#14 | | pruned at step: 16 | {'decoder_dropout': 0.0764, 'num_topics': 3}
#15 | | pruned at step: 8 | {'decoder_dropout': 0.1373, 'num_topics': 3}
#16 | | pruned at step: 8 | {'decoder_dropout': 0.1204, 'num_topics': 4}
#17 | ● | completed, score: 7.2553e+02 | {'decoder_dropout': 0.1196, 'num_topics': 7}
#18 | | completed, score: 7.2628e+02 | {'decoder_dropout': 0.1196, 'num_topics': 7}
#19 | | completed, score: 7.2693e+02 | {'decoder_dropout': 0.1234, 'num_topics': 6}
#20 | | completed, score: 7.2654e+02 | {'decoder_dropout': 0.1484, 'num_topics': 6}
#21 | | pruned at step: 16 | {'decoder_dropout': 0.1257, 'num_topics': 7}
#22 | | pruned at step: 8 | {'decoder_dropout': 0.0825, 'num_topics': 10}
#23 | | pruned at step: 8 | {'decoder_dropout': 0.0825, 'num_topics': 10}
#24 | | pruned at step: 8 | {'decoder_dropout': 0.0819, 'num_topics': 8}
#25 | | pruned at step: 16 | {'decoder_dropout': 0.1078, 'num_topics': 8}
#26 | | completed, score: 7.2593e+02 | {'decoder_dropout': 0.0703, 'num_topics': 9}
#27 | | completed, score: 7.2622e+02 | {'decoder_dropout': 0.1421, 'num_topics': 7}
#28 | | completed, score: 7.2637e+02 | {'decoder_dropout': 0.1421, 'num_topics': 7}
#29 | | completed, score: 7.2568e+02 | {'decoder_dropout': 0.1196, 'num_topics': 7}
#30 | | completed, score: 7.2729e+02 | {'decoder_dropout': 0.1484, 'num_topics': 6}
#33 | | pruned at step: 16 | {'decoder_dropout': 0.0716, 'num_topics': 7}
#34 | | pruned at step: 16 | {'decoder_dropout': 0.1257, 'num_topics': 7}
Running trials:
Trial | Progress | Params
#31 |■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■| {'decoder_dropout': 0.0683, 'num_topics': 9}
#32 |■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■| {'decoder_dropout': 0.0674, 'num_topics': 9}
#35 |■■■■■■■■■■■■■■■■■■■ | {'decoder_dropout': 0.1421, 'num_topics': 7}
#36 |■■■■■■■■■ | {'decoder_dropout': 0.0716, 'num_topics': 7}
#37 |■■■■■■■■ | {'decoder_dropout': 0.1421, 'num_topics': 7}
#38 | | {'decoder_dropout': 0.0716, 'num_topics': 7}
#39 | | {'decoder_dropout': 0.0716, 'num_topics': 7}
#40 | | {'decoder_dropout': 0.0716, 'num_topics': 7}
#41 | | {'decoder_dropout': 0.0716, 'num_topics': 7}
#42 | | {'decoder_dropout': 0.0716, 'num_topics': 7}
After running some optimization trials with random hyperparameters, the tuner will home in on the number of topics that best represents the dataset. A good indicator of convergence is that the histogram on the dashboard shows the model is preferentially choosing topic numbers over some limited range.
You can stop tuning at any time with Ctrl-C (or esc-I-I in a Jupyter notebook) without issue.
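One way to put a number on "preferentially choosing topic numbers over some limited range" is to measure how much of the histogram sits near its peak. The sketch below copies the #Topics | #Trials histogram from the dashboard output above; the `concentration` helper and the choice of a window of plus or minus one topic are assumptions made for illustration.

```python
# Histogram copied from the dashboard output above: {num_topics: num_trials}.
trials_per_topic = {3: 2, 4: 1, 6: 5, 7: 8, 8: 3, 9: 1, 10: 2, 12: 1,
                    13: 1, 16: 2, 17: 1, 18: 2, 19: 2, 20: 2}


def concentration(histogram, width=1):
    """Return the most-sampled topic count and the share of trials near it."""
    mode = max(histogram, key=histogram.get)
    total = sum(histogram.values())
    nearby = sum(n for k, n in histogram.items() if abs(k - mode) <= width)
    return mode, nearby / total


mode, share = concentration(trials_per_topic)
print(f"most sampled: {mode} topics; share of trials within +/-1: {share:.2f}")
```

A share that keeps growing as more trials finish suggests the tuner has settled on a narrow range.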
After tuning, you can assess the tuner fit by plotting the intermediate loss values recorded during each trial with plot_intermediate_values. Here, we just want to ensure that model training is converging to a stable solution.
[10]:
ax = tuner.plot_intermediate_values(palette='Spectral_r',
log_hue=True, figsize=(7,3))
ax.set(ylim = (7e2, 7.7e2))
[10]:
[(700.0, 770.0)]
Next, we can check that the losses achieved by various models are convex with respect to the number of topics. This check ensures that a reasonable number of topics was chosen for the model and that the tuner converged on that estimate:
[11]:
tuner.plot_pareto_front(include_pruned_trials=False, label_pareto_front=True,
figsize = (5,5))
[11]:
<AxesSubplot:xlabel='Num_topics', ylabel='Elbo'>
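The same convexity check can be done numerically on the completed trials listed in the tuner output above. This is a sketch rather than a MIRA feature: the helper functions are hypothetical, lower scores are taken to be better (as the "Best score" line of the dashboard suggests), and the convexity test assumes the topic numbers are evenly spaced.

```python
# (num_topics, score) for the completed trials listed in the tuner output above.
completed = [(6, 726.36), (6, 725.66), (8, 725.71), (7, 725.53), (7, 726.28),
             (6, 726.93), (6, 726.54), (9, 725.93), (7, 726.22), (7, 726.37),
             (7, 725.68), (6, 727.29)]


def best_score_per_topic(trials):
    """Keep the lowest score seen for each number of topics, sorted by topics."""
    best = {}
    for num_topics, score in trials:
        best[num_topics] = min(score, best.get(num_topics, float("inf")))
    return dict(sorted(best.items()))


def is_convex(values):
    """True if successive differences never decrease (discrete convexity)."""
    diffs = [b - a for a, b in zip(values, values[1:])]
    return all(later >= earlier for earlier, later in zip(diffs, diffs[1:]))


best = best_score_per_topic(completed)
best_num_topics = min(best, key=best.get)
print(best)
print(best_num_topics, is_convex(list(best.values())))
```

For these trials the best score per topic number falls to a single minimum and rises again on either side, which is the shape the plot is meant to confirm.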
Finally, you can access the model which is the best fit for your dataset using the fetch_best_weights command. The weights and hyperparameter choices for this model were saved during the tuning phase, so it is ready to be used for your analysis.
[10]:
model = tuner.fetch_best_weights()
INFO:mira.topic_model.base:Moving model to CPU for inference.
INFO:mira.topic_model.base:Moving model to device: cpu
Tuning persistence#
Tuning results/trials/parameters are saved to a SQLite database in the working directory called “mira-tuning.db” (which can be changed through the storage parameter of the tuning object). Specific optimization runs are saved to their own tables, named by the save_name parameter. Thus, to resume tuning after an interruption, one must simply instantiate a tuning object with the same storage and save_name parameters as the previous process. The new object will reference the saved results and pick up where the last process left off. This means if you start a new tuning process by running a script - which is then interrupted - rerunning that script will resume the original training session.
Step 4: Analysis#
One can continue to this step using either of the models trained above.
The first thing you may want to do with your model is embed the cells in some low-dimensional space for visualization. The predict function infers the topic composition of each cell in the dataset, then transforms the topic compositions to Euclidean space for nearest-neighbors analysis. From there, we can use the scanpy workflow to generate a 2-dimensional UMAP visualization of the dataset.
For more on analyzing embeddings, projections, and topics, please see the MIRA joint representation tutorial and MIRA topic analysis tutorial. The APIs are unchanged, except those outlined in the Model Persistence section.
In the next step, make sure to use “X_umap_features” as the representation! - Euclidean distances and nearest neighbors graphs from topics alone are nonsensical.
[11]:
model.predict(data)
# scanpy workflow #
sc.pp.neighbors(data, use_rep = 'X_umap_features', metric = 'manhattan')
sc.tl.umap(data, min_dist=0.1, negative_sample_rate=0.05)
INFO:mira.adata_interface.core:Added key to obsm: X_topic_compositions
INFO:mira.adata_interface.core:Added key to obsm: X_umap_features
INFO:mira.adata_interface.topic_model:Added cols: topic_0, topic_1, topic_2, topic_3, topic_4, topic_5, topic_6
INFO:mira.adata_interface.core:Added key to varm: topic_feature_compositions
INFO:mira.adata_interface.core:Added key to varm: topic_feature_activations
INFO:mira.adata_interface.topic_model:Added key to uns: topic_dendogram
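To illustrate why topic compositions are transformed before computing distances, the sketch below applies a centered log-ratio transform, a standard way to move proportions that sum to one into unconstrained coordinates. This is a stand-in chosen for illustration: this tutorial does not specify which transform MIRA uses to build X_umap_features, and the helper, pseudocount, and toy compositions are assumptions.

```python
import math


def centered_log_ratio(composition, pseudocount=1e-6):
    """Map a composition (non-negative, sums to 1) to unconstrained coordinates."""
    # The pseudocount guards against log(0) for topics with zero weight.
    logs = [math.log(p + pseudocount) for p in composition]
    mean_log = sum(logs) / len(logs)
    return [value - mean_log for value in logs]


# Two made-up cells with mirrored topic compositions.
cell_a = [0.70, 0.20, 0.10]
cell_b = [0.10, 0.20, 0.70]

clr_a = centered_log_ratio(cell_a)
clr_b = centered_log_ratio(cell_b)
print(clr_a)
print(clr_b)
```

Log-ratio coordinates compare topics by their relative abundance, so distances between cells in this space are not distorted by the constraint that compositions must sum to one.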
Voilà, plotting this dataset, we can see that the batches have been successfully merged! We also see that CODAL was able to determine that the KO batch (purple) contained a new cell type unseen in the control batch (blue). As we show in our manuscript, correcting for technical effects in this situation is quite challenging for other methods.
[12]:
ax = sc.pl.umap(data[np.random.choice(len(data), len(data))], frameon=False, color = 'batch',
title = '', palette= ['#8f7eadff', '#c1e1e2ff'], show = False)
ax.set_title('UMAP projection')
/Users/alynch/opt/miniconda3/envs/codal/lib/python3.7/site-packages/anndata/compat/_overloaded_dict.py:106: ImplicitModificationWarning: Trying to modify attribute `._uns` of view, initializing view as actual.
self.data[key] = value
/Users/alynch/opt/miniconda3/envs/codal/lib/python3.7/site-packages/anndata/_core/anndata.py:1828: UserWarning: Observation names are not unique. To make them unique, call `.obs_names_make_unique`.
utils.warn_names_duplicates("obs")
[12]:
Text(0.5, 1.0, 'UMAP projection')
One can plot the distribution of topics across cells to see how the latent space reflects changes in cell state:
[13]:
sc.pl.umap(data, color = model.topic_cols, cmap='BuPu', ncols=3,
add_outline=True, outline_width=(0.1,0), frameon=False)
Assessing disentanglement#
We can visually assess the extent of disentanglement using mira.pl.plot_disentanglement. First, use model.impute and model.get_batch_effect to estimate the biological and batch effects influencing each gene in each cell:
[14]:
model.impute(data)
model.get_batch_effect(data)
INFO:mira.adata_interface.topic_model:Fetching key X_topic_compositions from obsm
INFO:mira.adata_interface.core:Added layer: imputed
INFO:mira.adata_interface.topic_model:Fetching key X_topic_compositions from obsm
INFO:mira.adata_interface.core:Added layer: batch_effect
Then plot, coloring by batch and by expression counts to see how the model explains technical variation in the data:
[15]:
fig, ax = plt.subplots(1,2,figsize=(10,4.5), sharey=True)
gene = '112'
mira.pl.plot_disentanglement(data, gene = gene, hue = 'batch', palette=['#8f7eadff', '#c1e1e2ff'], ax = ax[0])
mira.pl.plot_disentanglement(data, gene = gene, palette='Greys', vmin = -1, ax = ax[1])
ax[0].set(title = 'Colored by batch', xlim = (-0.5,0.5))
ax[1].set(title = 'Colored by counts', xlim = (-0.5,0.5))
[15]:
[Text(0.5, 1.0, 'Colored by counts'), (-0.5, 0.5)]
Model persistence#
Once you have a trained topic model, it is a good idea to save it to disk so you can access it again later. We recommend the .pth extension, which is the extension PyTorch uses when saving weights.
[16]:
model.save('mira-datasets/tutorial_model.pth')
That model can be reloaded using:
[17]:
model = mira.topic_model.load_model('mira-datasets/tutorial_model.pth')
INFO:mira.topic_model.base:Moving model to CPU for inference.
INFO:mira.topic_model.base:Moving model to device: cpu
Please note, this API was changed from previous MIRA versions.