Topic Modeling with MIRA+CODAL#

What is a topic model?#

Topic models, like Latent Dirichlet Allocation (LDA), have traditionally been used to decompose a corpus of text into topics - or themes - composed of words that often appear together in documents. Documents, in turn, are modeled as a mixture of topics based on the words they contain.

MIRA extends these ideas to single-cell genomics data, where topics are groups of genes that are co-expressed or cis-regulatory elements that are co-accessible, and cells are a mixture of these regulatory modules. The topics can be used for enrichment and pathway analysis, while the cells’ topic mixtures can be used to embed the cells in an informative, interpretable latent space.

Topic modeling of batched single-cell data is challenging because these models cannot typically distinguish between biological and technical effects of the assay. CODAL (COvariate Disentangling Augmented Loss) uses a novel mutual information regularization technique to explicitly disentangle these two sources of variation.

Tuning and training#

In this tutorial, we will cover tuning and training MIRA topic models (which now run the CODAL algorithm under the hood) to find the best hyperparameters for a given dataset. Single-cell datasets vary widely in quality and complexity, so no single set of hyperparameters guarantees an optimal model in all cases. The most important parameter to optimize for a given dataset is the number of topics captured by the model, where each topic represents a unit of covarying genes or cis-regulatory regions. The number of topics determines the quality of the embedding manifold and the veracity of the topics' functional enrichments, so we recommend rigorous tuning to ensure an accurate and informative analysis.

MIRA uses a memory-efficient streaming minibatch stochastic gradient descent algorithm for parameter inference. This enables fast, parallelized Bayesian optimization of the model's key hyperparameters. Hyperparameter tuning proceeds in three stages:

  1. Feature selection and data cleaning

  2. Model instantiation

  3. Hyperparameter optimization

Let’s start by importing some packages:

[1]:
import mira

import anndata
import scanpy as sc
import numpy as np
import matplotlib.pyplot as plt
plt.rcParams.update({'font.size': 14})
INFO:numexpr.utils:Note: NumExpr detected 12 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
INFO:numexpr.utils:NumExpr defaulting to 8 threads.

And we’ll load some data. To keep the tutorial quick, I’ve imported a synthetic single-cell dataset used to benchmark CODAL’s ability to disentangle technical and biological effects. This synthetic gene expression dataset contains two batches with 2000 cells each. One batch is “WT” (wild-type), and the other is “KO” (knockout), or perturbed. Importantly, these batches have different cell-state distributions.

[2]:
mira.datasets.CodalFrankencellTutorial()

ko = anndata.read_h5ad('mira-datasets/CODAL_tutorial/perturbation.h5ad')
wt = anndata.read_h5ad('mira-datasets/CODAL_tutorial/wild-type.h5ad')

data = anndata.concat({'ko' : ko, 'wt' : wt}, label='batch', index_unique=':') # combine into one AnnData object
INFO:mira.datasets.datasets:Dataset already on disk.
INFO:mira.datasets.datasets:Dataset contents:
        * mira-datasets/CODAL_tutorial
                * perturbation.h5ad
                * wild-type.h5ad

If you’re only working with one batch of data, you can follow this same tutorial by omitting the concatenation step above.
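For example (hypothetical filename), a single-batch workflow would simply load one AnnData object in place of the concatenated one:

# Single-batch alternative (hypothetical path): load one AnnData object and skip the
# concatenation; you will also omit categorical_covariates when instantiating the model later.
data = anndata.read_h5ad('my_single_batch_experiment.h5ad')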

A note on data modalities#

All of MIRA’s models share the same API for instantiation, tuning, and training, regardless of the data modality. This tutorial covers tuning an expression topic model. When tuning an accessibility model on scATAC-seq data, make sure you are training on a GPU to maximize speed. You can check that a GPU is available with:

import torch
assert torch.cuda.is_available()

Step 1: Feature selection and Preprocessing - Gene Expression#

First, we must perform feature selection to find highly variable genes. We recommend finding highly variable genes over the concatenation of all batches you intend to merge, which is easy to accomplish using Scanpy. First, filter very rare genes, and freeze the raw counts:

[3]:
sc.pp.filter_genes(data, min_cells=15)
rawdata = data.X.copy()

Normalize the read depths of each cell, then logarithmize the data:

[4]:
sc.pp.normalize_total(data, target_sum=1e4)
sc.pp.log1p(data)

And calculate highly variable genes. This set of genes will be our “exogenous” genes, or those we include in our statistical model of the data. As a rule of thumb, using the top 2500-5000 highly variable genes works well for describing the data manifold and finding interesting enrichments in topics.

[5]:
sc.pp.highly_variable_genes(data, min_disp = 0.5)
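As an optional sanity check (our addition, not part of the original workflow), you can count how many genes were flagged and adjust min_disp if the total falls outside the recommended range:

# Optional check: count the genes flagged as highly variable.
# Aim for roughly 2500-5000; loosen or tighten min_disp if the count is far outside that range.
n_hvg = int(data.var['highly_variable'].sum())
print(f'Number of highly variable genes: {n_hvg}')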

Finally, restore the raw counts to the counts layer.

[6]:
data.layers['counts'] = rawdata

Let’s see what this dataset looks like according to PCA:

[7]:
sc.tl.pca(data)
sc.pp.neighbors(data, n_pcs=6)
sc.tl.umap(data, min_dist = 0.2, negative_sample_rate=0.2)
sc.pl.umap(data, color = 'batch', palette= ['#8f7eadff', '#c1e1e2ff'], frameon=False)
... storing 'edge' as categorical
../_images/notebooks_tutorial_CODAL_13_1.png

The two batches show strong separation due to technical effects, but also share some clear structural similarities. Let’s see what happens with CODAL integration.

ATAC preprocessing

In this tutorial, we covered the basics of preprocessing GEX data for modeling using MIRA. For a guide on preprocessing scATAC-seq data, see the next tutorial.

Step 2: Model instantiation#

Next, we will instantiate an expression topic model. The hyperparameters will be tuned in the following stage, so all we have to worry about is telling the topic model how to access some information in our AnnData object.

In particular, at a minimum, we must tell the model:

  • feature_type: what type of features we are working with (either “expression” or “accessibility”)

  • highly_variable_key: which .var key to find our highly variable genes

  • counts_layer: which layer to get the raw counts from.

  • categorical_covariates, continuous_covariates: technical variables influencing the generative process of the data. For example, a categorical technical factor may be the cells’ batch of origin, as shown here, while a continuous technical factor might be the percentage of mitochondrial reads (see the sketch after the next cell). For unbatched data, ignore these parameters.

[8]:
model = mira.topics.make_model(
    data.n_obs, data.n_vars, # helps MIRA choose reasonable values for some hyperparameters which are not tuned.
    feature_type = 'expression',
    highly_variable_key='highly_variable',
    counts_layer='counts',
    categorical_covariates='batch'
)
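If your dataset also carries continuous technical covariates, they can be passed alongside the categorical ones. The sketch below is hypothetical: it assumes a percent-mitochondrial-counts column named 'pct_counts_mt' has already been computed in .obs (for instance with scanpy's QC metrics):

# Hypothetical sketch: pass both categorical and continuous technical covariates.
# Assumes data.obs contains a precomputed 'pct_counts_mt' column.
model = mira.topics.make_model(
    data.n_obs, data.n_vars,
    feature_type = 'expression',
    highly_variable_key = 'highly_variable',
    counts_layer = 'counts',
    categorical_covariates = 'batch',
    continuous_covariates = 'pct_counts_mt',
)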

Next, we have to set minimum and maximum bounds on learning rates that work for the model. The learning rate will be annealed between these values during training to ensure the best fit. Run the learning rate range test below to collect data on which learning rates work for this model setup:

[9]:
model.get_learning_rate_bounds(data)
INFO:mira.adata_interface.topic_model:Predicting expression from genes from col: highly_variable
WARNING:mira.topic_model.base:Cuda unavailable. Will not use GPU speedup while training.
INFO:mira.topic_model.base:Set learning rates to: (0.001102177009946407, 0.1901675701910509)
[9]:
(0.001102177009946407, 0.1901675701910509)

Now, use the set_learning_rates function to bound the portion of the loss curve that contains the steepest decrease.

If you push the upper bound too high, the model is likely to experience gradient overflows. The upper bound works best at or before the point where the slope starts to level off.

[11]:
model.set_learning_rates(1e-3, 0.25) # for larger datasets, the default of 1e-3, 0.1 usually works well.
model.plot_learning_rate_bounds(figsize=(7,3))
[11]:
<AxesSubplot:xlabel='Learning Rate', ylabel='Loss'>
../_images/notebooks_tutorial_CODAL_20_1.png

Step 3: Hyperparameter Optimization#

We offer two methods for choosing the number of topics to represent a dataset. The first, gradient-based topic selection using a Dirichlet Process CODAL model, is faster and works better for larger datasets (>= 10,000 cells).

The second method, Bayesian optimization, takes far more resources, but can be parallelized for speed and additionally tunes the regularization of the model, which can produce better results, especially for smaller datasets (<= 10,000 cells). If you have the time/patience/resources, we suggest using the Bayesian optimization method, which we used for all of our benchmarking tests.

Method 1: Gradient based#

We’ll start with the Gradient-based method. Pass the model and the data:

[12]:
topic_contributions = mira.topics.gradient_tune(model, data)

This method works by instantiating a special version of the CODAL model with far too many topics, then gradually pruning topics that are not needed to describe the data. The function returns the maximum contribution of each topic to any cell in the dataset. The predicted number of topics is given by the elbow of the maximum-contribution curve, minus 1. A rule of thumb is that the last valid topic to include in the model is followed by a drop-off, after which all subsequent topics hover between 0 and 0.05 maximum contribution.

This dropoff is particularly steep. For another example, see the next tutorial.
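If you prefer a programmatic starting point, a rough heuristic (our own shortcut, not part of MIRA's API) is to count the topics whose maximum contribution clears the ~0.05 noise floor described above, then confirm the choice against the plot below:

# Rough heuristic (an assumption, not MIRA's API): count topics whose maximum
# contribution exceeds the ~0.05 noise floor, then verify against the elbow plot.
estimated_topics = int(sum(c > 0.05 for c in topic_contributions))
print(f'Estimated number of topics: {estimated_topics}')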

[15]:
NUM_TOPICS = 6

mira.pl.plot_topic_contributions(topic_contributions, NUM_TOPICS)
[15]:
<AxesSubplot:xlabel='Topic Number', ylabel='Max contribution'>
../_images/notebooks_tutorial_CODAL_24_1.png

You can then train your final topic model like so:

[15]:
model = model.set_params(num_topics = NUM_TOPICS).fit(data)
INFO:mira.adata_interface.topic_model:Predicting expression from genes from col: highly_variable
WARNING:mira.topic_model.base:Cuda unavailable. Will not use GPU speedup while training.
INFO:mira.topic_model.base:Moving model to device: cpu

(Optional) Method 2: Bayesian Optimization#

The Bayesian optimization method is much more comprehensive, jointly tuning the number of topics and the regularization of the model. First, instantiate a BayesianTuner object. This class takes as arguments:

  • The model

  • save_name, which allows us to resume or reload a training session. This is a unique filepath at which the tuner will save its progress

  • n_jobs, the number of parallel processes to run.

  • the min_topics and max_topics, bounds on the number of topics which may be contained in a dataset.


Notes on n_jobs#

During the fitting stage, MIRA caches the dataset to disk, then streams minibatches into memory for gradient descent steps. This means each training instance has a very small memory footprint, irrespective of the size of the dataset. Therefore, one can conservatively allocate one process per ~1 GB of available memory.

However, parallelization requires that the tuner has a backend database to coordinate all of those processes. By default, MIRA uses an SQLite table, which requires no setup but only works for up to 5 concurrent jobs (n_jobs=5). For more processes, you will have to use the Redis backend database. To do this, start a Redis server running in the background - it can be configured however you like - then pass the argument storage = mira.topics.Redis() to the tuner object.
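For example, a larger parallel run might look like the sketch below (assuming a Redis server is already running and reachable with its default settings; the save_name is hypothetical):

# Sketch of a run with more than 5 workers, which requires the Redis backend.
tuner = mira.topics.BayesianTuner(
    model = model,
    n_jobs = 16,
    save_name = 'tutorial/redis-run',  # hypothetical save name
    min_topics = 3, max_topics = 20,
    storage = mira.topics.Redis(),     # assumes a Redis server is running in the background
)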

In single-core mode, tuning is deterministic (as long as nothing changes and no trials drop out for environmental reasons). Multicore mode is also deterministic barring race conditions: if trials finish out of order due to environmental differences, the tuning will not reproduce exactly.

Notes on min_topics, max_topics#

Setting reasonable bounds on the number of topics one might encounter can help the tuner converge faster. For example, typical PBMC datasets contain 10-15 topics, bone marrow differentiation contains 20-30 topics, and the embryonic differentiation dataset analyzed in our manuscript contained 70-80 topics. One option is to get a quick estimate of how many cell states are in your dataset via PCA or clustering, then provide bounds that overlap this estimate. Alternatively, you can use bounds around the estimate from Method 1 above. Worst case, one can provide very permissive bounds, though the search will then take longer to converge.
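As a quick, optional way to anchor these bounds (our own suggestion, reusing the neighbors graph computed on the PCA embedding earlier; requires the leidenalg package), you can cluster the cells and use the cluster count as a rough proxy for the number of cell states:

# Optional: use Leiden clustering on the existing neighbors graph as a rough proxy
# for the number of cell states, then set the topic bounds around that estimate.
sc.tl.leiden(data, resolution = 1.0)
n_states = data.obs['leiden'].nunique()
print(f'Approximate number of cell states: {n_states}')
# e.g. min_topics = max(2, n_states // 2), max_topics = 2 * n_states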

[9]:
tuner = mira.topics.BayesianTuner(
        model = model,
        n_jobs=5,
        save_name = 'tutorial/0',
        #### IMPORTANT
        min_topics = 3, max_topics = 20, # tailor for your dataset!!!!
        #### See "Notes on min_topics, max_topics" above
        #storage = mira.topics.Redis() # if using REDIS backend for more (>5) processes
)

After instantiating the tuner, call the fit function and provide your dataset:

[15]:
tuner.fit(data)
Trials finished: 43 | Best trial: 17 | Best score: 7.2553e+02
Press ctrl+C,ctrl+C or esc,I+I,I+I in Jupyter notebook to stop early.

Tensorboard logdir: runs/tutorial/0
#Topics | #Trials

      3 | ■ ■
      4 | ■
      6 | ■ ■ ■ ■ ■
      7 | ■ ■ ■ ■ ■ ■ ■ ■
      8 | ■ ■ ■
      9 | ■
     10 | ■ ■
     12 | ■
     13 | ■
     16 | ■ ■
     17 | ■
     18 | ■ ■
     19 | ■ ■
     20 | ■ ■

Trial | Result (● = best so far)         | Params
 #0   |   | pruned at step: 8            | {'decoder_dropout': 0.1120, 'num_topics': 16}
 #1   |   | pruned at step: 16           | {'decoder_dropout': 0.1397, 'num_topics': 12}
 #2   |   | pruned at step: 8            | {'decoder_dropout': 0.0583, 'num_topics': 19}
 #3   | ● | completed, score: 7.2636e+02 | {'decoder_dropout': 0.1389, 'num_topics': 6}
 #4   |   | pruned at step: 8            | {'decoder_dropout': 0.0515, 'num_topics': 16}
 #5   |   | pruned at step: 8            | {'decoder_dropout': 0.0794, 'num_topics': 17}
 #6   |   | pruned at step: 16           | {'decoder_dropout': 0.0547, 'num_topics': 18}
 #7   | ● | completed, score: 7.2566e+02 | {'decoder_dropout': 0.1304, 'num_topics': 6}
 #8   |   | pruned at step: 8            | {'decoder_dropout': 0.0554, 'num_topics': 20}
 #9   |   | pruned at step: 8            | {'decoder_dropout': 0.0738, 'num_topics': 20}
 #10  |   | pruned at step: 16           | {'decoder_dropout': 0.1297, 'num_topics': 13}
 #11  |   | pruned at step: 8            | {'decoder_dropout': 0.0802, 'num_topics': 18}
 #12  |   | pruned at step: 8            | {'decoder_dropout': 0.0809, 'num_topics': 19}
 #13  |   | completed, score: 7.2571e+02 | {'decoder_dropout': 0.0754, 'num_topics': 8}
 #14  |   | pruned at step: 16           | {'decoder_dropout': 0.0764, 'num_topics': 3}
 #15  |   | pruned at step: 8            | {'decoder_dropout': 0.1373, 'num_topics': 3}
 #16  |   | pruned at step: 8            | {'decoder_dropout': 0.1204, 'num_topics': 4}
 #17  | ● | completed, score: 7.2553e+02 | {'decoder_dropout': 0.1196, 'num_topics': 7}
 #18  |   | completed, score: 7.2628e+02 | {'decoder_dropout': 0.1196, 'num_topics': 7}
 #19  |   | completed, score: 7.2693e+02 | {'decoder_dropout': 0.1234, 'num_topics': 6}
 #20  |   | completed, score: 7.2654e+02 | {'decoder_dropout': 0.1484, 'num_topics': 6}
 #21  |   | pruned at step: 16           | {'decoder_dropout': 0.1257, 'num_topics': 7}
 #22  |   | pruned at step: 8            | {'decoder_dropout': 0.0825, 'num_topics': 10}
 #23  |   | pruned at step: 8            | {'decoder_dropout': 0.0825, 'num_topics': 10}
 #24  |   | pruned at step: 8            | {'decoder_dropout': 0.0819, 'num_topics': 8}
 #25  |   | pruned at step: 16           | {'decoder_dropout': 0.1078, 'num_topics': 8}
 #26  |   | completed, score: 7.2593e+02 | {'decoder_dropout': 0.0703, 'num_topics': 9}
 #27  |   | completed, score: 7.2622e+02 | {'decoder_dropout': 0.1421, 'num_topics': 7}
 #28  |   | completed, score: 7.2637e+02 | {'decoder_dropout': 0.1421, 'num_topics': 7}
 #29  |   | completed, score: 7.2568e+02 | {'decoder_dropout': 0.1196, 'num_topics': 7}
 #30  |   | completed, score: 7.2729e+02 | {'decoder_dropout': 0.1484, 'num_topics': 6}
 #33  |   | pruned at step: 16           | {'decoder_dropout': 0.0716, 'num_topics': 7}
 #34  |   | pruned at step: 16           | {'decoder_dropout': 0.1257, 'num_topics': 7}

Running trials:
Trial | Progress                         | Params
 #31  |■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■| {'decoder_dropout': 0.0683, 'num_topics': 9}
 #32  |■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■| {'decoder_dropout': 0.0674, 'num_topics': 9}
 #35  |■■■■■■■■■■■■■■■■■■■               | {'decoder_dropout': 0.1421, 'num_topics': 7}
 #36  |■■■■■■■■■                         | {'decoder_dropout': 0.0716, 'num_topics': 7}
 #37  |■■■■■■■■                          | {'decoder_dropout': 0.1421, 'num_topics': 7}
 #38  |                                  | {'decoder_dropout': 0.0716, 'num_topics': 7}
 #39  |                                  | {'decoder_dropout': 0.0716, 'num_topics': 7}
 #40  |                                  | {'decoder_dropout': 0.0716, 'num_topics': 7}
 #41  |                                  | {'decoder_dropout': 0.0716, 'num_topics': 7}
 #42  |                                  | {'decoder_dropout': 0.0716, 'num_topics': 7}


After running some optimization trials with random hyperparameters, the tuner will home in on the number of topics that best represents the dataset. A good indicator of convergence is that the histogram on the dashboard shows the tuner preferentially choosing topic numbers over some limited range.

You can stop tuning at any time with Ctrl-C (or esc, I, I in a Jupyter notebook) without issue.

After tuning, you can assess the tuner fit by plotting the intermediate loss values recorded during each trial with plot_intermediate_values. Here, we just want to ensure that model training is converging to a stable solution.

[10]:
ax = tuner.plot_intermediate_values(palette='Spectral_r',
                                   log_hue=True, figsize=(7,3))
ax.set(ylim = (7e2, 7.7e2))
[10]:
[(700.0, 770.0)]
../_images/notebooks_tutorial_CODAL_32_1.png

Next, we can check that the losses achieved by the various models are convex with respect to the number of topics. This check ensures that a reasonable number of topics was chosen for the model and that the tuner converged on that estimate:

[11]:
tuner.plot_pareto_front(include_pruned_trials=False, label_pareto_front=True,
                       figsize = (5,5))
[11]:
<AxesSubplot:xlabel='Num_topics', ylabel='Elbo'>
../_images/notebooks_tutorial_CODAL_34_1.png

Finally, you can access the model that best fits your dataset using the fetch_best_weights command. The weights and hyperparameter choices for this model were saved during the tuning phase, so it is ready to be used for your analysis.

[10]:
model = tuner.fetch_best_weights()
INFO:mira.topic_model.base:Moving model to CPU for inference.
INFO:mira.topic_model.base:Moving model to device: cpu

Tuning persistence#

Tuning results, trials, and parameters are saved to an SQLite database in the working directory called “mira-tuning.db” (which can be changed through the storage parameter of the tuning object). Each optimization run is saved to its own table, named by the save_name parameter. Thus, to resume tuning after an interruption, simply instantiate a tuning object with the same storage and save_name parameters as the previous process. The new object will reference the saved results and pick up where the last process left off. This means that if you start a tuning process by running a script - which is then interrupted - rerunning that script will resume the original tuning session.
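For example, resuming the run above after an interruption could look like this (same working directory, so the default mira-tuning.db storage is found, and the same save_name as before):

# Resume the earlier tuning session: identical storage defaults and save_name mean the
# new tuner reloads the previous trials and continues where the interrupted run left off.
tuner = mira.topics.BayesianTuner(
    model = model,
    n_jobs = 5,
    save_name = 'tutorial/0',
    min_topics = 3, max_topics = 20,
)
tuner.fit(data)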

Step 4: Analysis#

One can continue to this step using either of the models trained above.

The first thing you may want to do with your model is embed the cells in a low-dimensional space for visualization. The predict function infers the topic composition of each cell in the dataset, then transforms the topic compositions to Euclidean space for nearest-neighbors analysis. From there, we can use the scanpy workflow to generate a 2-dimensional UMAP visualization of the dataset.

For more on analyzing embeddings, projections, and topics, please see the MIRA joint representation tutorial and MIRA topic analysis tutorial. The APIs are unchanged, except those outlined in the Model Persistence section.

In the next step, make sure to use “X_umap_features” as the representation - Euclidean distances and nearest-neighbors graphs computed from the topics alone are nonsensical.

[11]:
model.predict(data)

# scanpy workflow #
sc.pp.neighbors(data, use_rep = 'X_umap_features', metric = 'manhattan')
sc.tl.umap(data, min_dist=0.1, negative_sample_rate=0.05)
INFO:mira.adata_interface.core:Added key to obsm: X_topic_compositions
INFO:mira.adata_interface.core:Added key to obsm: X_umap_features
INFO:mira.adata_interface.topic_model:Added cols: topic_0, topic_1, topic_2, topic_3, topic_4, topic_5, topic_6
INFO:mira.adata_interface.core:Added key to varm: topic_feature_compositions
INFO:mira.adata_interface.core:Added key to varm: topic_feature_activations
INFO:mira.adata_interface.topic_model:Added key to uns: topic_dendogram

Voilà - plotting this dataset, we can see that the batches have been successfully merged! We also see that CODAL was able to determine that the KO batch (purple) contains a cell type unseen in the control batch (blue). As we show in our manuscript, correcting for technical effects in this situation is quite challenging for other methods.

[12]:
ax = sc.pl.umap(data[np.random.choice(len(data), len(data))], frameon=False, color = 'batch',
               title = '', palette= ['#8f7eadff', '#c1e1e2ff'], show = False)
ax.set_title('UMAP projection')
/Users/alynch/opt/miniconda3/envs/codal/lib/python3.7/site-packages/anndata/compat/_overloaded_dict.py:106: ImplicitModificationWarning: Trying to modify attribute `._uns` of view, initializing view as actual.
  self.data[key] = value
/Users/alynch/opt/miniconda3/envs/codal/lib/python3.7/site-packages/anndata/_core/anndata.py:1828: UserWarning: Observation names are not unique. To make them unique, call `.obs_names_make_unique`.
  utils.warn_names_duplicates("obs")
[12]:
Text(0.5, 1.0, 'UMAP projection')
../_images/notebooks_tutorial_CODAL_41_2.png

One can plot the distribution of topics across cells to see how the latent space reflects changes in cell state:

[13]:
sc.pl.umap(data, color = model.topic_cols, cmap='BuPu', ncols=3,
           add_outline=True, outline_width=(0.1,0), frameon=False)
../_images/notebooks_tutorial_CODAL_43_0.png

Assessing disentanglement#

We can visually assess the extent of disentanglement using mira.pl.plot_disentanglement. First, use model.get_batch_effect and model.impute to estimate the biological and batch effects influencing each gene in each cell:

[14]:
model.impute(data)
model.get_batch_effect(data)
INFO:mira.adata_interface.topic_model:Fetching key X_topic_compositions from obsm
INFO:mira.adata_interface.core:Added layer: imputed
INFO:mira.adata_interface.topic_model:Fetching key X_topic_compositions from obsm
INFO:mira.adata_interface.core:Added layer: batch_effect

Then plot, coloring by batch and by expression counts, to see how the model explains technical variation in the data:

[15]:
fig, ax = plt.subplots(1,2,figsize=(10,4.5), sharey=True)

gene = '112'
mira.pl.plot_disentanglement(data, gene = gene, hue = 'batch', palette=['#8f7eadff', '#c1e1e2ff'], ax = ax[0])
mira.pl.plot_disentanglement(data, gene = gene, palette='Greys', vmin = -1, ax = ax[1])

ax[0].set(title = 'Colored by batch', xlim = (-0.5,0.5))
ax[1].set(title = 'Colored by counts', xlim = (-0.5,0.5))
[15]:
[Text(0.5, 1.0, 'Colored by counts'), (-0.5, 0.5)]
../_images/notebooks_tutorial_CODAL_47_1.png

Model persistence#

Once you have a trained topic model, it is a good idea to save it to disk so you can access it again later. We recommend the .pth extension, which is the extension PyTorch uses when saving weights.

[16]:
model.save('mira-datasets/tutorial_model.pth')

That model can be reloaded using:

[17]:
model = mira.topic_model.load_model('mira-datasets/tutorial_model.pth')
INFO:mira.topic_model.base:Moving model to CPU for inference.
INFO:mira.topic_model.base:Moving model to device: cpu

Please note, this API was changed from previous MIRA versions.