from functools import partial
import torch
import pyro
import pyro.distributions as dist
from pyro.infer.autoguide import AutoDelta
from pyro.infer import TraceMeanField_ELBO
from pyro.infer.autoguide.initialization import init_to_mean, init_to_value, init_to_sample
from pyro import poutine
import numpy as np
import logging
from pyro.contrib.autoname import scope
from mira.topic_model.base import EarlyStopping
import warnings
from mira.rp_model.optim import LBFGS as stochastic_LBFGS
from scipy.stats import nbinom
from scipy.sparse import isspmatrix
from mira.adata_interface.rp_model import wraps_rp_func, add_isd_results, \
add_predictions, fetch_TSS_from_adata
import mira.adata_interface.rp_model as rpi
from mira.adata_interface.core import add_layer, wraps_modelfunc
import mira.adata_interface.core as adi
import h5py as h5
from tqdm.auto import tqdm
import os
import glob
logger = logging.getLogger(__name__)
class SaveCallback:
def __init__(self, prefix):
self.prefix = prefix
def __call__(self, model):
if model.was_fit:
model.save(self.prefix)
def mean_default_init_to_value(
site = None,
values = {},
*,
fallback = init_to_mean,
):
if site is None:
return partial(mean_default_init_to_value, values=values, fallback=fallback)
if site["name"] in values:
return values[site["name"]]
if fallback is not None:
return fallback(site)
raise ValueError(f"No init strategy specified for site {repr(site['name'])}")
class BaseModel:
@classmethod
def load_dir(cls,counts_layer = None,*,expr_model, accessibility_model, prefix):
'''
Load directory of RP models. Adds all available RP models into a container.
Parameters
----------
expr_model: mira.topics.ExpressionTopicModel
Trained MIRA expression topic model.
accessibility_model : mira.topics.AccessibilityTopicModel
Trained MIRA accessibility topic model.
counts_layer : str, default=None
Layer in AnnData that countains raw counts for modeling.
prefix : str
Prefix under which RP models were saved.
Examples
--------
.. code-block :: python
>>> litemodel = mira.rp.LITE_Model.load_dir(
... counts_layer = 'counts',
... expr_model = rna_model,
... accessibility_model = atac_model,
... prefix = 'path/to/rpmodels/'
... )
'''
paths = glob.glob(prefix + cls.prefix + '*.pth')
if len(paths) == 0:
if len(glob.glob(prefix + cls.old_prefix + '*.pth')) > 0:
logger.error('''
Cannot load models, but found a models using older file conventions.
Please use "mira.rp.{}Model.convert_models(<prefix>) to convert old
models to the new format.
'''.format(cls.prefix))
raise ValueError('No models found at {}'.format(str(prefix)))
genes = [os.path.basename(x.split('_')[-1].split('.')[0])
for x in paths]
model = cls(expr_model = expr_model, accessibility_model = accessibility_model,
counts_layer = counts_layer, genes = genes).load(prefix)
return model
@classmethod
def convert_models(cls, prefix):
paths = glob.glob(prefix + cls.old_prefix + '*.pth')
if len(paths) == 0:
raise ValueError('No models found at {}'.format(str(prefix)))
for path in tqdm(paths, desc = 'Reformatting models'):
old_model = torch.load(path)
old_model['guide'] = {
old_key.replace(cls.old_prefix, cls.prefix).replace('logdistance','distance') : v
for old_key,v in old_model['guide'].items()
}
gene = os.path.basename(path).split('_')[-1].split('.')[0]
torch.save(old_model,
os.path.join(prefix,
'{}{}.pth'.format(cls.prefix, gene))
)
@classmethod
def _make(cls, expr_model, accessibility_model, counts_layer, models, learning_rate, use_NITE_features):
self = BaseModel.__new__(cls)
self.expr_model = expr_model
self.accessibility_model = accessibility_model
self.learning_rate = learning_rate
self.use_NITE_features = use_NITE_features
self.counts_layer = counts_layer
self.models = models
return self
def __init__(self,*,
expr_model,
accessibility_model,
genes,
learning_rate = 1,
counts_layer = None,
search_reps = 1,
initialization_model = None):
'''
Parameters
----------
expr_model: mira.topics.ExpressionTopicModel
Trained MIRA expression topic model.
accessibility_model : mira.topics.AccessibilityTopicModel
Trained MIRA accessibility topic model.
genes : np.ndarray[str], list[str]
List of genes for which to learn RP models.
learning_rate : float>0
Learning rate for L-BGFS optimizer.
counts_layer : str, default=None
Layer in AnnData that countains raw counts for modeling.
initialization_model : mira.rp.LITE_Model, mira.rp.NITE_Model, None
Initialize parameters of RP model using the provided model before
further optimization with L-BGFS. This is used when training the NITE
model, which is initialized with the LITE model parameters learned
for the same genes, then retrained to optimized the NITE model's
extra parameters. This procedure speeds training.
Attributes
----------
genes : np.ndarray[str]
Array of gene names for models
features : np.ndarray[str]
Array of gene names for models
models : list[mira.rp.GeneModel]
List of trained RP models
model_type : {"NITE", "LITE"}
Examples
--------
Setup requires RNA and ATAC AnnData objects with shared cell barcodes
and trained topic models for both modes:
.. code-block:: python
>>> rp_args = dict(expr_adata = rna_data, atac_adata = atac_data)
Instantiating a LITE model (local chromatin accessibility only):
.. code-block:: python
>>> litemodel = mira.rp.LITE_Model(
... expr_model = rna_model,
... accessibility_model = atac_model,
... counts_layer = 'counts',
... genes = ['LEF1','WNT3','EDA','NOTCH1'],
... )
>>> litemodel.fit(**rp_args)
Instantiating a NITE model (local chromatin accessibility only):
>>> nitemodel = mira.rp.NITE_Model(
... expr_model = rna_model,
... accessibility_model = atac_model,
... counts_layer = 'counts',
... genes = litemodel.genes,
... instantiation_model = litemodel
... )
>>> nitemodel.fit(**rp_args)
'''
self.expr_model = expr_model
self.accessibility_model = accessibility_model
self.learning_rate = learning_rate
self.counts_layer = counts_layer
assert(isinstance(search_reps, int) and search_reps > 0)
self.search_reps = search_reps
self.models = []
for gene in genes:
init_params = None
try:
init_params = initialization_model.get_model(gene).posterior_map
except (IndexError, AttributeError):
pass
self.models.append(
GeneModel(
gene = gene,
learning_rate = learning_rate,
use_NITE_features = self.use_NITE_features,
init_params= init_params,
search_reps = search_reps,
)
)
def subset(self, genes):
'''
Return a subset container of RP models.
Parameters
----------
genes : np.ndarray[str], list[str]
List of genes to subset from RP model
Examples
--------
.. code-block :: python
>>> less_models = litemodel.subset(['LEF1','WNT3'])
'''
assert(isinstance(genes, (list, np.ndarray)))
for gene in genes:
if not gene in self.genes:
raise ValueError('Gene {} is not in RP model'.format(str(gene)))
return self._make(
expr_model = self.expr_model,
accessibility_model = self.accessibility_model, counts_layer=self.counts_layer,
learning_rate = self.learning_rate, use_NITE_features = self.use_NITE_features,
models = [model for model in self.models if model.gene in genes]
)
def join(self, rp_model):
'''
Merge RP models from two model containers.
Parameters
----------
rp_model : mira.rp.LITE_Model, mira.rp.NITE_Model
RP model container from which to append new RP models
Examples
--------
.. code-block :: python
>>> model1.genes
... ['LEF1','WNT3']
>>> model2.genes
... ['CTSC','EDAR']
>>> merged_model = model1.join(model2)
>>> merged_model.genes
... ['LEF1','WNT3','CTSC','EDAR']
'''
assert(isinstance(rp_model, BaseModel))
assert(rp_model.use_NITE_features == self.use_NITE_features), 'Cannot join LITE model with NITE model'
add_models = np.setdiff1d(rp_model.genes, self.genes)
for add_gene in add_models:
self.models.append(
rp_model.get_model(add_gene)
)
return self
def __getitem__(self, gene):
'''
Alias for `get_model(gene)`.
Examples
--------
>>> rp_model["LEF1"]
... <mira.rp_model.rp_model.GeneModel at 0x7fa07af1cf10>
'''
return self.get_model(gene)
@property
def genes(self):
return np.array([model.gene for model in self.models])
@property
def features(self):
return self.genes
@property
def model_type(self):
if self.use_NITE_features:
return 'NITE'
else:
return 'LITE'
def _get_masks(self, tss_distance):
promoter_mask = np.abs(tss_distance) <= 1500
upstream_mask = np.logical_and(tss_distance < 0, ~promoter_mask)
downstream_mask = np.logical_and(tss_distance > 0, ~promoter_mask)
return promoter_mask, upstream_mask, downstream_mask
@staticmethod
def bn(x, mu, var, eps):
return (x - mu)/np.sqrt( var + eps)
#def _get_batcheffects(self, model, batcheffect_embeddings, idx):
#
# lin_output = np.dot(batcheffect_embeddings, model._get_covariates_linear_layer()[:, idx])
# return self.bn(
# )
def _get_region_weights(self, NITE_features, softmax_denom, idx):
model = self.accessibility_model
rate = model._get_gamma()[idx] * self.bn(
NITE_features.dot(model._get_beta()[:, idx]),
model._get_bn_mean()[idx],
model._get_bn_var()[idx],
model.decoder.bn.eps
) + model._get_bias()[idx]
region_probabilities = np.exp(rate)/softmax_denom[:, np.newaxis]
return region_probabilities
def _get_features_for_model(self,*, gene_expr, read_depth, correction_vector,
expr_softmax_denom, NITE_features, atac_softmax_denom, upstream_idx, downstream_idx,
promoter_idx, upstream_distances, downstream_distances, include_factor_data = False):
features = dict(
gene_expr = gene_expr,
read_depth = read_depth,
softmax_denom = expr_softmax_denom,
NITE_features = NITE_features,
upstream_distances = upstream_distances,
downstream_distances = downstream_distances,
correction_vector = correction_vector,
)
if include_factor_data:
features.update(dict(
promoter_idx = promoter_idx,
upstream_idx = upstream_idx,
downstream_idx = downstream_idx
))
for k, idx in zip(['upstream_weights', 'downstream_weights', 'promoter_weights'],
[upstream_idx, downstream_idx, promoter_idx]):
features[k] = self._get_region_weights(NITE_features, atac_softmax_denom, idx) * 1e4
return features
def save(self, prefix):
'''
Save RP models.
Parameters
----------
prefix : str
Prefix under which to save RP models. May be filename prefix
or directory. RP models will save with format:
**{prefix}_{LITE/NITE}_{gene}.pth**
'''
for model in self.models:
model.save(prefix)
def get_model(self, gene):
'''
Gets model for gene
Parameters
----------
gene : str
Fetch RP model for this gene
'''
try:
return self.models[np.argwhere(self.genes == gene)[0,0]]
except IndexError:
raise IndexError('Model for gene {} does not exist'.format(gene))
def load(self, prefix):
'''
Load RP models saved with *prefix*.
Parameters
----------
prefix : str
Prefix under which RP models were saved.
'''
genes = self.genes
self.models = []
for gene in genes:
try:
model = GeneModel(gene = gene, use_NITE_features = self.use_NITE_features)
model.load(prefix)
self.models.append(model)
except FileNotFoundError:
old_filename = prefix + self.old_prefix + gene + '.pth'
if os.path.isfile(old_filename):
logger.warn('''
Cannot load {} model, but found a model using older file conventions: {}.
Please use "mira.rp.{}_Model.convert_models(<prefix>) to convert old
models to the new format.
'''.format(
gene, old_filename, self.model_type
))
else:
logger.warn('Cannot load {} model. File not found.'.format(gene))
if len(self.models) == 0:
raise ValueError('No models loaded.')
return self
def subset_fit_models(self, models):
self.models = []
for model in models:
if not model.was_fit:
logger.warn('{} model failed to fit.'.format(model.gene))
else:
self.models.append(model)
return self
@wraps_rp_func(lambda self, expr_adata, atac_data, output, **kwargs : self.subset_fit_models(output), bar_desc = 'Fitting models')
def fit(self, model, features, callback = None):
'''
Optimize parameters of RP models to learn *cis*-regulatory relationships.
Parameters
----------
expr_adata : anndata.AnnData
AnnData of expression features
atac_adata : anndata.AnnData
AnnData of accessibility features. Must be annotated with
mira.tl.get_distance_to_TSS.
Returns
-------
rp_model : mira.rp.LITE_Model, mira.rp.NITE_Model
RP model with optimized parameters
'''
try:
for key in features:
if features[key].dtype == np.float64:
features[key] = features[key].astype(np.float32)
model.fit(features)
except ValueError:
pass
if not callback is None:
callback(model)
return model
@wraps_rp_func(lambda self, expr_adata, atac_data, output, **kwargs: np.array(output).sum(), bar_desc = 'Scoring')
def score(self, model, features):
return model.score(features)
@wraps_rp_func(lambda self, expr_adata, atac_data, output, **kwargs: \
add_predictions(expr_adata, (self.features, output), model_type = self.model_type, sparse = True),
bar_desc = 'Predicting expression')
def predict(self, model, features):
'''
Predicts the expression of genes given their *cis*-accessibility state.
Also evaluates the probability of that prediction for LITE/NITE evaluation.
Parameters
----------
expr_adata : anndata.AnnData
AnnData of expression features
atac_adata : anndata.AnnData
AnnData of accessibility features. Must be annotated with
mira.tl.get_distance_to_TSS.
Returns
-------
anndata.AnnData
`.layers['LITE_prediction']` or `.layers['NITE_prediction']`: np.ndarray[float] of shape (n_cells, n_features)
Predicted relative frequencies of features using LITE or NITE model, respectively
`.layers['LTIE_logp']` or `.layers['NITE_logp']`Â : np.ndarray[float] of shape (n_cells, n_features)
Probability of observed expression given posterior predictive estimate of LITE or
NITE model, respectively.
'''
try:
return True, model.predict(features)
except ValueError:
return False, None
@wraps_rp_func(lambda self, expr_adata, atac_data, output, **kwargs: \
add_layer(expr_adata, (self.features, np.hstack(output)), add_layer = self.model_type + '_logp', sparse = True),
bar_desc = 'Getting logp(Data)')
def get_logp(self, model, features):
return model.get_logp(features)
'''@wraps_rp_func(lambda self, expr_adata, atac_data, output, **kwargs: \
add_layer(expr_adata, (self.features, np.hstack(output)), add_layer = self.model_type + '_samples', sparse = True)
)
def _sample_posterior(self, model, features, site = 'prediction'):
return model.to_numpy(model.get_posterior_sample(features, site))[:, np.newaxis]'''
@wraps_rp_func(lambda self, expr_adata, atac_adata, output, **kwargs : output, bar_desc = 'Formatting features')
def get_features(self, model, features):
return features
@wraps_rp_func(lambda self, expr_adata, atac_adata, output, **kwargs : output,
bar_desc = 'Formatting features', include_factor_data = True)
def get_isd_features(self, model, features,*,hits_matrix, metadata):
return features
@wraps_rp_func(add_isd_results,
bar_desc = 'Predicting TF influence', include_factor_data = True)
def probabilistic_isd(self, model, features, n_samples = 1500, checkpoint = None,
*,hits_matrix, metadata):
'''
For each gene, calcuate association scores with each transcription factor.
Association scores detect when a TF binds within *cis*-regulatory
elements (CREs) that are influential to expression predictions for that gene.
CREs that influence the RP model expression prediction are nearby a
gene's TSS and have accessibility that correlates with expression. This
model assumes these attributes indicate a factor is more likely to
regulate a gene.
Parameters
----------
expr_adata : anndata.AnnData
AnnData of expression features
atac_adata : anndata.AnnData
AnnData of accessibility features. Must be annotated with TSS and factor
binding data using mira.tl.get_distance_to_TSS **and**
mira.tl.get_motif_hits_in_peaks/mira.tl.get_CHIP_hits_in_peaks.
n_samples : int>0, default=1500
Downsample cells to this amount for calculations. Speeds up computation
time. Cells are sampled by stratifying over expression levels.
checkpoint : str, default = None
Path to checkpoint h5 file. pISD calculations can be slow, and saving
a checkpoint ensures progress is not lost if calculations are
interrupted. To resume from a checkpoint, just pass the path to the h5.
Returns
-------
anndata.AnnData
`.varm['motifs-prob_deletion']` or `.varm['chip-prob_deletion']`: np.ndarray[float] of shape (n_genes, n_factors)
Association scores for each gene-TF combination. Higher scores indicate
greater predicted association/regulatory influence.
'''
already_calculated = False
if not checkpoint is None:
if not os.path.isfile(checkpoint):
h5.File(checkpoint, 'w').close()
with h5.File(checkpoint, 'r') as h:
try:
h[model.gene]
already_calculated = True
except KeyError:
pass
if checkpoint is None or not already_calculated:
result = model.probabilistic_isd(features, hits_matrix, n_samples = n_samples)
if not checkpoint is None:
with h5.File(checkpoint, 'a') as h:
g = h.create_group(model.gene)
g.create_dataset('samples_mask', data = result[1])
g.create_dataset('isd', data = result[0])
return result
else:
with h5.File(checkpoint, 'r') as h:
g = h[model.gene]
result = g['isd'][...], g['samples_mask'][...]
return result
@property
def parameters_(self):
'''
Returns parameters of all contained RP models.
'''
return {
gene : self[gene].parameters_
for gene in self.features
}
class LITE_Model(BaseModel):
use_NITE_features = False
prefix = 'LITE_'
old_prefix = 'cis_'
def __init__(self,*, expr_model, accessibility_model, genes, learning_rate = 1,
counts_layer = None, initialization_model = None, search_reps = 1):
'''
Container for multiple regulatory potential (RP) LITE models. LITE models
learn a relationship between a gene's expression and accessibility in
nearby *cis*-regulatory elements (CRE). The MIRA model assumes the regulatory
influence of a CRE on a gene decays with respect to distance from that
gene. MIRA learns this distance using variational Bayesian inference.
With a trained RP model, one may assess the
* LITE/NITE characteristics of a gene: whether that gene's expression is decoupled from changes in local chromatin.
* Chromatin differential: the relative levels of nearby accessibility versus gene expression.
* *Insilico*-deletion: predicts transcription factor regulators based on a model of nearby binding in influential CREs, as determined by the RP model.
Parameters
----------
expr_model: mira.topics.ExpressionTopicModel
Trained MIRA expression topic model.
accessibility_model : mira.topics.AccessibilityTopicModel
Trained MIRA accessibility topic model.
genes : np.ndarray[str], list[str]
List of genes for which to learn RP models.
learning_rate : float>0
Learning rate for L-BGFS optimizer.
counts_layer : str, default=None
Layer in AnnData that countains raw counts for modeling.
initialization_model : mira.rp.LITE_Model, mira.rp.NITE_Model, None
Initialize parameters of RP model using the provided model before
further optimization with L-BGFS. This is used when training the NITE
model, which is initialized with the LITE model parameters learned
for the same genes, then retrained to optimized the NITE model's
extra parameters. This procedure speeds training.
Attributes
----------
genes : np.ndarray[str]
Array of gene names for models
features : np.ndarray[str]
Array of gene names for models
models : list[mira.rp.GeneModel]
List of trained RP models
model_type : {"NITE", "LITE"}
Examples
--------
Setup requires RNA and ATAC AnnData objects with shared cell barcodes
and trained topic models for both modes:
.. code-block:: python
>>> rp_args = dict(expr_adata = rna_data, atac_adata = atac_data)
Instantiating a LITE model (local chromatin accessibility only):
.. code-block:: python
>>> litemodel = mira.rp.LITE_Model(
... expr_model = rna_model,
... accessibility_model = atac_model,
... counts_layer = 'counts',
... genes = ['LEF1','WNT3','EDA','NOTCH1'],
... )
>>> litemodel.fit(**rp_args)
'''
super().__init__(
expr_model = expr_model,
accessibility_model = accessibility_model,
genes = genes,
learning_rate = learning_rate,
initialization_model = initialization_model,
counts_layer=counts_layer,
search_reps = search_reps,
)
def spawn_NITE_model(self):
'''
Returns a NITE model seeded with the LITE model's
parameters.
'''
return NITE_Model(
expr_model= self.expr_model,
accessibility_model=self.accessibility_model,
genes = self.genes,
learning_rate=self.learning_rate,
counts_layer=self.counts_layer,
initialization_model=self,
search_reps=self.search_reps
)
class NITE_Model(BaseModel):
use_NITE_features = True
prefix = 'NITE_'
old_prefix = 'trans_'
def __init__(self,*, expr_model, accessibility_model, genes, learning_rate = 1,
counts_layer = None, initialization_model = None, search_reps = 1):
'''
Container for multiple regulatory potential (RP) NITE models. NITE models
learn a relationship between a gene's expression and accessibility in
nearby *cis*-regulatory elements (CRE), **and** the cell-wide chromatin landscape.
The predictive capacity of local vs. cell-wide chromatin in predicting a gene's
expression state determines a gene's *NITE Score*, and edulicates whether that
gene is primarily regulated by local or nonlocal mechanisms.
Parameters
----------
expr_model: mira.topics.ExpressionTopicModel
Trained MIRA expression topic model.
accessibility_model : mira.topics.AccessibilityTopicModel
Trained MIRA accessibility topic model.
genes : np.ndarray[str], list[str]
List of genes for which to learn RP models.
learning_rate : float>0
Learning rate for L-BGFS optimizer.
counts_layer : str, default=None
Layer in AnnData that countains raw counts for modeling.
initialization_model : mira.rp.LITE_Model, mira.rp.NITE_Model, None
Initialize parameters of RP model using the provided model before
further optimization with L-BGFS. This is used when training the NITE
model, which is initialized with the LITE model parameters learned
for the same genes, then retrained to optimized the NITE model's
extra parameters. This procedure speeds training.
Attributes
----------
genes : np.ndarray[str]
Array of gene names for models
features : np.ndarray[str]
Array of gene names for models
models : list[mira.rp.GeneModel]
List of trained RP models
model_type : {"NITE", "LITE"}
Examples
--------
Setup requires RNA and ATAC AnnData objects with shared cell barcodes
and trained topic models for both modes:
.. code-block:: python
>>> rp_args = dict(expr_adata = rna_data, atac_adata = atac_data)
Instantiating a NITE model (local chromatin accessibility only):
>>> nitemodel = mira.rp.NITE_Model(
... expr_model = rna_model,
... accessibility_model = atac_model,
... counts_layer = 'counts',
... genes = litemodel.genes,
... instantiation_model = litemodel
... )
>>> nitemodel.fit(**rp_args)
'''
super().__init__(
expr_model = expr_model,
accessibility_model = accessibility_model,
genes = genes,
learning_rate = learning_rate,
initialization_model = initialization_model,
counts_layer=counts_layer,
search_reps = search_reps,
)
[docs]class GeneModel:
'''
Gene-level RP model object.
'''
def __init__(self,*,
gene,
learning_rate = 1.,
use_NITE_features = False,
init_params = None,
search_reps = 1,
):
self.gene = gene
self.learning_rate = learning_rate
self.use_NITE_features = use_NITE_features
self.was_fit = False
self.search_reps= search_reps
self.init_params = init_params
def _get_weights(self, loading = False, seed = None):
if not seed is None:
pyro.set_rng_seed(seed)
pyro.clear_param_store()
self.bn = torch.nn.BatchNorm1d(1, momentum = 1.0, affine = False)
if self.init_params is None:
if self.use_NITE_features and not loading:
logger.warn('\nTraining NITE regulation model for {} without providing pre-trained LITE models may cause divergence in statistical testing.'\
.format(self.gene))
if seed is None:
self.guide = AutoDelta(self.model, init_loc_fn = init_to_mean)
else:
self.guide = AutoDelta(self.model, init_loc_fn = init_to_sample)
else:
self.seed_params = {self.prefix + '/' + k.split('/')[-1] : v.detach().clone() for k,v in self.init_params.items()}
self.guide = AutoDelta(self.model, init_loc_fn = mean_default_init_to_value(values = self.seed_params))
def get_prefix(self):
return ('NITE' if self.use_NITE_features else 'LITE') + '_' + self.gene
def RP(self, weights, distances, d):
return (weights * torch.pow(0.5, distances/(1e3 * d))).sum(-1)
def model(self,
gene_expr,
correction_vector,
softmax_denom,
read_depth,
upstream_weights,
upstream_distances,
downstream_weights,
downstream_distances,
promoter_weights,
NITE_features):
with scope(prefix = self.get_prefix()):
with pyro.plate("spans", 3):
a = pyro.sample("a", dist.HalfNormal(1.))
with pyro.plate("upstream-downstream", 2):
d = pyro.sample('distance', dist.LogNormal(np.log(15), 1.2))
if self.use_NITE_features and hasattr(self, 'seed_params'):
theta = self.seed_params[self.prefix + '/theta']
theta.requires_grad = False
else:
theta = pyro.sample('theta', dist.Gamma(2., 0.5))
gamma = pyro.sample('gamma', dist.LogNormal(0., 0.5))
bias = pyro.sample('bias', dist.Normal(0, 5.))
if self.use_NITE_features:
with pyro.plate('NITE_coefs', NITE_features.shape[-1]):
a_NITE = pyro.sample('a_NITE', dist.Normal(0.,1.))
with pyro.plate('data', len(upstream_weights)):
f_Z = a[0] * self.RP(upstream_weights, upstream_distances, d[0])\
+ a[1] * self.RP(downstream_weights, downstream_distances, d[1]) \
+ a[2] * promoter_weights.sum(-1)
if self.use_NITE_features:
f_Z = f_Z + torch.matmul(NITE_features, torch.unsqueeze(a_NITE, 0).T).reshape(-1)
expr_prediction = gamma* self.bn(f_Z.reshape((-1,1)).float()).reshape(-1) + bias
pyro.deterministic('prediction', expr_prediction.exp()/softmax_denom) # assuming batch effects are 0!
independent_rate = (expr_prediction + correction_vector).exp()
rate = independent_rate/softmax_denom
mu = read_depth.exp() * rate
pyro.deterministic('mu', mu)
p = mu / (mu + theta)
pyro.deterministic('prob_success', p)
NB = dist.NegativeBinomial(total_count = theta, probs = p)
pyro.sample('obs', NB, obs = gene_expr)
def _t(self, X):
return torch.tensor(X, requires_grad=False)
@staticmethod
def get_loss_fn():
return TraceMeanField_ELBO().differentiable_loss
def get_optimizer(self, params):
#return torch.optim.LBFGS(params, lr=self.learning_rate, line_search_fn = 'strong_wolfe')
return stochastic_LBFGS(params, lr = self.learning_rate, history_size = 5,
line_search = 'Armijo')
def get_loss_and_grads(self, optimizer, features):
optimizer.zero_grad()
loss = self.get_loss_fn()(self.model, self.guide, **features)
loss.backward()
grads = optimizer._gather_flat_grad()
return loss, grads
def armijo_step(self, optimizer, features, update_curvature = True):
def closure():
optimizer.zero_grad()
loss = self.get_loss_fn()(self.model, self.guide, **features)
return loss
obj_loss, grad = self.get_loss_and_grads(optimizer, features)
# compute initial gradient and objective
p = optimizer.two_loop_recursion(-grad)
p/=torch.norm(p)
# perform line search step
options = {'closure': closure, 'current_loss': obj_loss, 'interpolate': True}
obj_loss, lr, _, _, _, _ = optimizer.step(p, grad, options=options)
# compute gradient
obj_loss.backward()
grad = optimizer._gather_flat_grad()
# curvature update
if update_curvature:
optimizer.curvature_update(grad, eps=0.2, damping=True)
return obj_loss.detach().item()
def fit(self, features):
features = {k : self._t(v) for k, v in features.items()}
with warnings.catch_warnings():
warnings.simplefilter("ignore")
def find_init_point(seed):
self._get_weights(seed = seed)
with poutine.trace(param_only=True) as param_capture:
loss = self.get_loss_fn()(self.model, self.guide, **features)
return seed, loss
best_seed = None
if self.search_reps > 1:
best_seed = sorted(map(find_init_point, [None, *range(self.search_reps - 1)]), key = lambda x : x[1])[0][0]
self._get_weights(seed = best_seed)
N = len(features['upstream_weights'])
with poutine.trace(param_only=True) as param_capture:
loss = self.get_loss_fn()(self.model, self.guide, **features)
params = {site["value"].unconstrained() for site in param_capture.trace.nodes.values()}
optimizer = self.get_optimizer(params)
early_stopper = EarlyStopping(patience = 3, tolerance = 1e-4)
update_curvature = False
self.loss = []
self.bn.train()
for i in range(100):
self.loss.append(
float(self.armijo_step(optimizer, features, update_curvature = update_curvature)/N)
)
update_curvature = not update_curvature
if early_stopper(self.loss[-1]):
break
self.was_fit = True
self.posterior_map = self.guide()
if self.use_NITE_features and hasattr(self, 'seed_params'):
theta_name = self.prefix + '/theta'
self.posterior_map[theta_name] = self.seed_params[theta_name]
del optimizer
del features
del self.guide
return self
def get_posterior_sample(self, features):
features = {k : self._t(v) for k, v in features.items()}
self.bn.eval()
guide = AutoDelta(self.model, init_loc_fn = init_to_value(values = self.posterior_map))
guide_trace = poutine.trace(guide).get_trace(**features)
#print(guide_trace)
model_trace = poutine.trace(poutine.replay(self.model, guide_trace))\
.get_trace(**features)
return model_trace
@property
def prefix(self):
return self.get_prefix()
def __getitem__(self, gene):
return self.get_model(gene)
def predict(self, features):
class RPLogpTracker(TraceMeanField_ELBO):
def rate_distortion(self, model_trace, guide_trace):
model_logp = 0.
data_logp = 0.
for name, model_site in model_trace.nodes.items():
if model_site["type"] == "sample":
if model_site["is_observed"]:
data_logp += model_site["log_prob_sum"].clone().detach().numpy()
else:
guide_site = guide_trace.nodes[name]
model_logp += model_site['fn'].log_prob(
guide_site['fn']()
).sum().clone().detach().numpy()
return model_logp, data_logp
features = {k : self._t(v) for k, v in features.items()}
self.bn.eval()
guide = AutoDelta(self.model, init_loc_fn = init_to_value(values = self.posterior_map))
loss_fn = RPLogpTracker()
trace, guide_trace = loss_fn._get_trace(self.model, guide, [], features)
rate, distortion = loss_fn.rate_distortion(trace, guide_trace)
expression_prediction = self.to_numpy(
trace.nodes[self.prefix + '/prediction']['value'])[:, np.newaxis]
logp_data = self._get_logp(features['gene_expr'], trace)
return expression_prediction, logp_data, rate
def score(self, features):
trace = self.get_posterior_sample(features)
return self._get_logp(features['gene_expr'], trace).sum()
def _get_logp(self, gene_expr, trace):
p = trace.nodes[self.prefix + '/prob_success']['value']
theta = self.posterior_map[self.prefix + '/theta']
logp = dist.NegativeBinomial(total_count = theta, probs = p)\
.log_prob(self._t(gene_expr))
logp_data = self.to_numpy(logp)[:, np.newaxis]
return logp_data
def get_logp(self, features):
raise DeprecationWarning('As of MIRA 0.2.0, "get_logp" is deprecated and no longer necessary.The "predict" method now returns logp(data) information')
@staticmethod
def to_numpy(X):
return X.clone().detach().cpu().numpy()
def get_savename(self, prefix):
return prefix + self.prefix + '.pth'
def _get_save_data(self):
return dict(bn = self.bn.state_dict(), guide = self.posterior_map)
def _load_save_data(self, state):
self._get_weights(loading = True)
self.bn.load_state_dict(state['bn'])
self.posterior_map = state['guide']
def save(self, prefix):
torch.save(self._get_save_data(), self.get_savename(prefix))
def load(self, prefix):
state = torch.load(self.get_savename(prefix))
self._load_save_data(state)
def _get_normalized_params(self):
d = {
k[len(self.prefix) + 1:] : v.detach().cpu().numpy()
for k, v in self.posterior_map.items()
}
d['bn_mean'] = self.to_numpy(self.bn.running_mean)
d['bn_var'] = self.to_numpy(self.bn.running_var)
d['bn_eps'] = self.bn.eps
return d
@staticmethod
def _select_informative_samples(expression, n_bins = 20, n_samples = 1500, seed = 2556):
'''
Bin based on contribution to overall expression, then take stratified sample to get most informative cells.
'''
np.random.seed(seed)
expression = np.ravel(expression)
assert(np.all(expression >= 0))
expression = np.log1p(expression)
expression += np.mean(expression)
sort_order = np.argsort(-expression)
cummulative_counts = np.cumsum(expression[sort_order])
counts_per_bin = expression.sum()/(n_bins - 1)
samples_per_bin = n_samples//n_bins
bin_num = cummulative_counts//counts_per_bin
differential = 0
informative_samples = []
samples_taken = 0
for _bin, _count in zip(*np.unique(bin_num, return_counts = True)):
if _bin == n_bins - 1:
take_samples = n_samples - samples_taken
else:
take_samples = samples_per_bin + differential
if _count < take_samples:
informative_samples.append(
sort_order[bin_num == _bin]
)
differential = take_samples - _count
samples_taken += _count
else:
differential = 0
samples_taken += take_samples
informative_samples.append(
np.random.choice(sort_order[bin_num == _bin], size = take_samples, replace = False)
)
return np.concatenate(informative_samples)
@staticmethod
def _prob_ISD(hits_matrix,*, correction_vector,
upstream_weights, downstream_weights,
promoter_weights, upstream_idx, promoter_idx, downstream_idx,
upstream_distances, downstream_distances, read_depth,
softmax_denom, gene_expr, NITE_features, params, bn_eps):
assert(isspmatrix(hits_matrix))
assert(len(hits_matrix.shape) == 2)
num_factors = hits_matrix.shape[0]
def tile(x):
x = np.expand_dims(x, -1)
return np.tile(x, num_factors+1).transpose((0,2,1))
def delete_regions(weights, region_mask):
num_regions = len(region_mask)
hits = 1 - hits_matrix[:, region_mask].toarray().astype(int) #1, factors, regions
hits = np.vstack([np.ones((1, num_regions)), hits])
hits = hits[np.newaxis, :, :].astype(int)
return np.multiply(weights, hits)
upstream_weights = delete_regions(tile(upstream_weights), upstream_idx) #cells, factors, regions
promoter_weights = delete_regions(tile(promoter_weights), promoter_idx)
downstream_weights = delete_regions(tile(downstream_weights), downstream_idx)
read_depth = read_depth[:, np.newaxis]
softmax_denom = softmax_denom[:, np.newaxis]
upstream_distances = upstream_distances[np.newaxis, np.newaxis, :]
downstream_distances = downstream_distances[np.newaxis,np.newaxis, :]
expression = gene_expr[:, np.newaxis]
def RP(weights, distances, d):
return (weights * np.power(0.5, distances/(1e3 * d))).sum(-1)
f_Z = params['a'][0] * RP(upstream_weights, upstream_distances, params['distance'][0]) \
+ params['a'][1] * RP(downstream_weights, downstream_distances, params['distance'][1]) \
+ params['a'][2] * promoter_weights.sum(-1) # cells, factors
original_data = f_Z[:,0]
sorted_first_col = np.sort(original_data).reshape(-1)
quantiles = np.argsort(f_Z, axis = 0).argsort(0)
f_Z = sorted_first_col[quantiles]
f_Z[:,0] = original_data
#f_Z = (f_Z - f_Z[:,0].mean(0,keepdims = True))/np.sqrt(f_Z[:, 0].var(0, keepdims = True) + bn_eps)
f_Z = (f_Z - params['bn_mean'])/np.sqrt(params['bn_var'] + bn_eps)
indep_rate = np.exp(params['gamma'] * f_Z + params['bias'] + \
correction_vector[:, np.newaxis])
compositional_rate = indep_rate/softmax_denom
mu = np.exp(read_depth) * compositional_rate
p = mu / (mu + params['theta'])
logp_data = nbinom(params['theta'], 1 - p).logpmf(expression)
logp_summary = logp_data.sum(0)
return logp_summary[0] - logp_summary[1:]#, f_Z, expression, logp_data
def probabilistic_isd(self, features, hits_matrix, n_samples = 1500, n_bins = 20):
np.random.seed(2556)
N = len(features['gene_expr'])
informative_samples = self._select_informative_samples(features['gene_expr'],
n_bins = n_bins, n_samples = n_samples)
for k in 'gene_expr,correction_vector,upstream_weights,downstream_weights,promoter_weights,softmax_denom,read_depth,NITE_features'.split(','):
features[k] = features[k][informative_samples]
samples_mask = np.zeros(N)
samples_mask[informative_samples] = 1
samples_mask = samples_mask.astype(bool)
return self._prob_ISD(
hits_matrix, **features,
params = self._get_normalized_params(),
bn_eps= self.bn.eps
), samples_mask
def _get_RP_model_coordinates(self, scale_height = False, bin_size = 50,
decay_periods = 20, promoter_width = 3000, *,
gene_chrom, gene_start, gene_end, gene_strand):
assert(isinstance(promoter_width, int) and promoter_width > 0)
assert(isinstance(decay_periods, int) and decay_periods > 0)
assert(isinstance(bin_size, int) and bin_size > 0)
assert(scale_height in [True, False])
rp_params = self._get_normalized_params()
upstream, downstream = 1e3 * rp_params['distance']
left_decay, right_decay, start_pos = upstream, downstream, gene_start
left_a, promoter_a, right_a = rp_params['a']
if gene_strand == '-':
left_decay, right_decay, start_pos = downstream, upstream, gene_end
right_a, promoter_a, left_a = rp_params['a']
left_extent = int(decay_periods*left_decay)
left_x = np.linspace(1, left_extent, left_extent//bin_size).astype(int)
left_y = 0.5**(left_x / left_decay) * (left_a if scale_height else 1.)
right_extent = int(decay_periods*right_decay)
right_x = np.linspace(0, right_extent, right_extent//bin_size).astype(int)
right_y = 0.5**(right_x / right_decay) * (right_a if scale_height else 1.)
left_x = -left_x[::-1] - promoter_width//2 + start_pos
right_x = right_x + promoter_width//2 + start_pos
promoter_x = [-promoter_width//2 + start_pos]
promoter_y = [promoter_a if scale_height else 1.]
x = np.concatenate([left_x, promoter_x, right_x])
y = np.concatenate([left_y[::-1], promoter_y, right_y])
return x, y
@property
def parameters_(self):
'''
Returns maximum a posteriori estimate of RP model parameters
as dictionary dict[str : parameter, float : value].
'''
norm_params = self._get_normalized_params()
params = {
k : np.atleast_1d(v)[0]
for k, v in norm_params.items() if len(np.atleast_1d(v)) == 1
}
params['a_upstream'] = norm_params['a'][0]
params['a_promoter'] = norm_params['a'][1]
params['a_downstream'] = norm_params['a'][2]
params['distance_upstream'] = norm_params['distance'][0]
params['distance_downstream'] = norm_params['distance'][1]
if self.use_NITE_features:
params.update({
'a-NITE_' + str(i) : v
for i, v in enumerate(norm_params['a_NITE'])
})
return params
[docs] @adi.wraps_modelfunc(fetch_TSS_from_adata,
fill_kwargs = ['gene_chrom','gene_start','gene_end','gene_strand'])
def write_bedgraph(self, scale_height = False, bin_size = 50,
decay_periods = 20, promoter_width = 3000,*, save_name,
gene_chrom, gene_start, gene_end, gene_strand):
'''
Write bedgraph of RP model coverage. Useful for visualization with
Bedtools.
Parameters
----------
adata : anndata.AnnData
AnnData object with TSS data annotated by `mira.tl.get_distance_to_TSS`.
save_name : str
Path to saved bedgraph file.
scale_height : boolean, default = False
Write RP model tails proportional in height to their respective
multiplicative coeffecient. Useful for evaluating not only the distance
of predicted regulatory influence, but the weighted importance of regions
in terms of predicting expression.
decay_periods : int>0, default = 10
Number of decay periods to write.
promoter_width : int>0, default = 0
Width of flat region at promoter of gene in base pairs (bp). MIRA default is 3000 bp.
Returns
-------
None
'''
coord, value = self._get_RP_model_coordinates(scale_height = scale_height, bin_size = bin_size,
decay_periods = decay_periods, promoter_width = promoter_width,
gene_chrom = gene_chrom, gene_start = gene_start, gene_end = gene_end, gene_strand = gene_strand)
with open(save_name, 'w') as f:
for start, end, val in zip(coord[:-1], coord[1:], value):
print(gene_chrom, start, end, val, sep = '\t', end = '\n', file = f)
[docs] @adi.wraps_modelfunc(rpi.fetch_get_influential_local_peaks, rpi.return_peaks_by_idx,
fill_kwargs=['peak_idx','tss_distance'])
def get_influential_local_peaks(self, peak_idx, tss_distance, decay_periods = 5):
'''
Returns the `.var` field of the adata, but subset for only peaks within
the local chromatin neighborhood of a gene. The local chromatin neighborhood
is defined by the decay distance parameter for that gene's RP model.
Parameters
----------
adata : anndata.AnnData
AnnData object with ATAC features and TSS annotations.
decay_periods : int > 0, default = 5
Return peaks that are within `decay_periods*upstream_decay_distance` upstream
of gene and `decay_periods*downstream_decay_distance` downstream of gene,
where upstream and downstream decay distances are given by the parameters
of the RP model.
Returns
-------
pd.DataFrame :
subset from `adata.var` to include only features/peaks within
the gene's local chromatin neighborhood. This function adds two columns:
`distance_to_TSS` : int
Distance, in base pairs, from the gene's TSS
`is_upstream` : boolean
If peak is upstream or downstream of gene
'''
assert isinstance(decay_periods, (int, float)) and decay_periods > 0
downstream_mask = (tss_distance >= 0) \
& (tss_distance < (decay_periods * 1e3 * self.parameters_['distance_downstream']))
upstream_mask = (tss_distance < 0) \
& (np.abs(tss_distance) < (decay_periods * 1e3 * self.parameters_['distance_upstream']))
combined_mask = upstream_mask | downstream_mask
return peak_idx[combined_mask], tss_distance[combined_mask]