Guides
The guides package provides variational guides for use with stormi models.
AmortizedNormal
stormi.guides.AmortizedNormal(
model,
model_input,
*,
init_net_params=None,
init_loc_fn=init_to_mean,
init_seed=0,
nn_width=32,
nn_depth=2,
props_t=(0.7, 0.1, 0.1, 0.1),
props_y=(0.5, 0.3, 0.1, 0.1),
props_l=(0.5, 0.3, 0.1, 0.1),
props_pw=(0.7, 0.1, 0.1, 0.1),
hvg_n_top=2000,
rna_embeddings=None,
gene_nn_width=32,
gene_nn_depth=2,
share_gene_trunk=True,
)
Compose an AutoNormal with an amortized neural guide for cell-specific (‘local’) parameters.
This wrapper:
- infers which amortized heads to enable from model_input (RNA/ATAC presence, number of paths),
- builds a two-part AutoGuideList (AutoNormal for global parameters plus an MLP for local ones),
- optionally warms up the MLP with a prior-only SVI run and rebinds the learned weights,
- provides convenience utilities to compute warm-start predictions, save/load warm parameters, and extract posterior means (global and local, batched if needed).
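The first step, choosing heads from model_input, can be pictured with a small sketch. Everything in it is an assumption for illustration: the head names t, y, l and pw are guessed from the props_t/props_y/props_l/props_pw arguments, and the rules mapping inputs to heads are hypothetical. The only link these docs state is that an l head goes with ATAC data. infer_heads is not a stormi function.

```python
from typing import Any, Mapping, Set


def infer_heads(model_input: Mapping[str, Any]) -> Set[str]:
    """Hypothetical: decide which amortized heads to build from the inputs."""
    heads = {"t"}  # assumption: a time-like head is always enabled
    if model_input.get("data") is not None:
        heads.add("y")  # assumption: an RNA-related head when RNA counts are present
    if model_input.get("data_atac") is not None:
        heads.add("l")  # the docs tie an 'l' head to ATAC input
    if int(model_input.get("num_paths", 1)) > 1:
        heads.add("pw")  # assumption: path weights only matter with more than one path
    return heads


rna_only = infer_heads({"data": [[0.0]]})
multiome = infer_heads({"data": [[0.0]], "data_atac": [[1.0]], "num_paths": 3})
```

Deriving the heads from the inputs, rather than asking the user to list them, keeps the guide's shape consistent with whatever data the model will actually see.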
Methods
| Name | Description |
|---|---|
| extract_all_medians | Get (global_medians, local_medians) in one call, with batched local extraction. |
| load_warm | Load the cached warm parameters from a pickle file. |
| plot_warm | Plot UMAPs of warm-start predictions directly from the object. |
| save_warm | Save the cached warm parameters to a pickle file. |
| warm_predictions | Compute warm-start predictions using the cached warm parameters. |
| warm_up | Run prior-only warm-up and rebind the amortized network with learned weights. |
extract_all_medians
stormi.guides.AmortizedNormal.extract_all_medians(
model_input,
training_output,
*,
batch_size=1000,
**kw,
)
Get (global_medians, local_medians) in one call, with batched local extraction.
Parameters
- model_input : dict
  Input data dictionary used for inference. Must contain at least ‘data’, ‘obs2sample’, and ‘M_c’; ‘data_atac’ is optional.
- training_output : dict
  Dictionary returned by training, expected to include {‘guide’: AutoGuideList, ‘svi’: SVI, ‘svi_state’: Any}.
- batch_size : int, optional
  Batch size for local extraction. Default is 1000.
- **kw
  Additional keyword arguments forwarded to extract_local_means_full (e.g., num_paths).
Returns
- tuple[dict, dict]
  The (global_medians, local_medians) dictionaries.
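Only the local (per-cell) extraction is batched. Below is a minimal NumPy sketch of that pattern. It assumes each batch returns a dict of arrays whose first axis indexes cells. batched_extract and fake_extract are hypothetical names, not stormi functions, and the site name "t" is only a placeholder.

```python
from typing import Callable, Dict, List

import numpy as np


def batched_extract(
    n_cells: int,
    batch_size: int,
    extract_fn: Callable[[slice], Dict[str, np.ndarray]],
) -> Dict[str, np.ndarray]:
    """Run extract_fn over consecutive cell slices and concatenate the per-site arrays."""
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")
    parts: Dict[str, List[np.ndarray]] = {}
    for start in range(0, n_cells, batch_size):
        stop = min(start + batch_size, n_cells)
        for site, value in extract_fn(slice(start, stop)).items():
            parts.setdefault(site, []).append(np.asarray(value))
    # Assumption: axis 0 is the cell axis for every local site.
    return {site: np.concatenate(chunks, axis=0) for site, chunks in parts.items()}


# Usage with a stand-in extractor that returns the cell indices it was asked for.
calls: List[slice] = []


def fake_extract(cells: slice) -> Dict[str, np.ndarray]:
    calls.append(cells)
    return {"t": np.arange(cells.start, cells.stop, dtype=float)}


local = batched_extract(2500, 1000, fake_extract)
```

With 2,500 cells and batch_size=1000, the extractor runs three times (1000, 1000, 500 cells). The working set of each call scales with batch_size rather than with the total number of cells.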
load_warm
stormi.guides.AmortizedNormal.load_warm(path)
Load the cached warm parameters from a pickle file.
Parameters
- path : str | pathlib.Path
  File path to read from.
Returns
- dict
  The loaded parameter dict. Loading also rebinds the guide so that subsequent training starts from these weights.
plot_warm
stormi.guides.AmortizedNormal.plot_warm(
adata_rna,
model_input,
*,
data_atac=None,
day_key=None,
size=100,
ncols=4,
cmap='inferno',
return_axes=False,
)
Plot UMAPs of warm-start predictions directly from the object.
Thin wrapper around warm_predictions and plot_warm_params.
Parameters
- adata_rna : anndata.AnnData
- model_input : dict
- data_atac : jnp.ndarray or None
- day_key : str or None
- size, ncols, cmap, return_axes
  See plot_warm_params.
Returns
- list[matplotlib.axes.Axes] or None
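The ncols argument controls how the UMAP panels wrap into rows. Here is a sketch of the usual grid arithmetic. It assumes one panel per warm-start quantity and assumes the layout avoids empty columns when there are fewer panels than ncols. panel_grid is a hypothetical helper; how plot_warm_params actually lays out its panels is not documented here.

```python
import math
from typing import Tuple


def panel_grid(n_panels: int, ncols: int = 4) -> Tuple[int, int]:
    """Return (nrows, ncols) for laying out n_panels UMAP panels."""
    if n_panels < 1 or ncols < 1:
        raise ValueError("n_panels and ncols must be positive")
    cols = min(ncols, n_panels)  # assumption: no empty columns for small panel counts
    rows = math.ceil(n_panels / cols)
    return rows, cols


six_panels = panel_grid(6)  # two rows with the default ncols=4
```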
save_warm
stormi.guides.AmortizedNormal.save_warm(path)
Save the cached warm parameters to a pickle file.
Parameters
- path : str | pathlib.Path
  File path to write to.
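save_warm and load_warm store the warm parameters as a pickle. Here is a stdlib sketch of that round trip. save_params and load_params are hypothetical helpers, and the nested parameter dict is made-up sample data. Whether stormi creates parent directories or converts JAX arrays before pickling is not stated in these docs.

```python
import pickle
import tempfile
from pathlib import Path
from typing import Any, Dict, Union


def save_params(params: Dict[str, Any], path: Union[str, Path]) -> None:
    """Write a parameter dict to a pickle file (hypothetical helper)."""
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)  # assumption: create missing folders
    with path.open("wb") as fh:
        pickle.dump(params, fh, protocol=pickle.HIGHEST_PROTOCOL)


def load_params(path: Union[str, Path]) -> Dict[str, Any]:
    """Read a parameter dict back; only unpickle files you trust."""
    with Path(path).open("rb") as fh:
        return pickle.load(fh)


with tempfile.TemporaryDirectory() as tmp:
    target = Path(tmp) / "warm" / "warm_params.pkl"
    params = {"dense_0": {"w": [[0.1, 0.2]], "b": [0.0]}}
    save_params(params, target)
    restored = load_params(target)
```

Unpickling can execute arbitrary code, so load warm-parameter files only from trusted sources.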
warm_predictions
stormi.guides.AmortizedNormal.warm_predictions(model_input, data_atac=None)
Compute warm-start predictions using the cached warm parameters.
Parameters
- model_input : dict
  Must contain the arrays needed by warm_forward_predictions.
- data_atac : jnp.ndarray or None
  Required when the cached parameters include an ‘l’ head.
Returns
- dict
  See warm_forward_predictions.
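The data_atac requirement amounts to a precondition check. Here is a minimal sketch of such a guard. require_atac_if_needed is a hypothetical name, and representing the cached heads as a set of names is an assumption. stormi may detect the l head differently or raise a different error.

```python
from typing import Any, Iterable, Optional


def require_atac_if_needed(enabled_heads: Iterable[str], data_atac: Optional[Any]) -> None:
    """Hypothetical guard: an 'l' head cannot be evaluated without ATAC data."""
    if "l" in set(enabled_heads) and data_atac is None:
        raise ValueError(
            "cached warm params include an 'l' head; pass data_atac to warm_predictions"
        )


require_atac_if_needed({"t", "y"}, None)           # RNA-only heads: nothing to check
require_atac_if_needed({"t", "l"}, [[0.0, 1.0]])   # 'l' head with ATAC data: fine
```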
warm_up
stormi.guides.AmortizedNormal.warm_up(model_input, n_steps=10000, seed=0)
Run prior-only warm-up and rebind the amortized network with learned weights.
Parameters
- model_input : dict
  Inputs required by warm_up_guide.
- n_steps : int
  Number of warm-up SVI steps.
- seed : int
  Random seed for SVI initialization.
Returns
- dict
  The learned amortized-network parameters. They are also cached on the object and used as initial values when the guide is rebuilt.
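A toy, stdlib-only sketch can show what a prior-only warm-up converges to. It assumes "prior-only" means the likelihood terms are dropped, so the ELBO reduces to -KL(q || p). For a single Normal latent with a Normal guide, gradient descent on that KL drives the guide to the prior. The function names, the learning rate and the single-latent setup are illustrative assumptions. warm_up_guide trains a neural network with SVI, not two scalars.

```python
import math
import random
from typing import Tuple


def kl_normal(mu: float, log_s: float, m0: float, s0: float) -> float:
    """KL( N(mu, s^2) || N(m0, s0^2) ) with s = exp(log_s)."""
    s = math.exp(log_s)
    return math.log(s0 / s) + (s**2 + (mu - m0) ** 2) / (2 * s0**2) - 0.5


def prior_only_warm_up(
    m0: float, s0: float, n_steps: int = 10_000, seed: int = 0, lr: float = 0.1
) -> Tuple[float, float]:
    """Fit a Normal guide to a Normal prior by gradient descent on the KL (no data term)."""
    rng = random.Random(seed)
    mu = rng.gauss(0.0, 1.0)  # seeded random init, analogous to the SVI init seed
    log_s = 0.0  # guide scale starts at 1
    for _ in range(n_steps):
        # Analytic gradients of kl_normal with respect to mu and log_s.
        grad_mu = (mu - m0) / s0**2
        grad_log_s = math.exp(2 * log_s) / s0**2 - 1.0
        mu -= lr * grad_mu
        log_s -= lr * grad_log_s
    return mu, math.exp(log_s)


mu, s = prior_only_warm_up(m0=1.5, s0=2.0, n_steps=2000)
```

After warm-up the guide matches the prior (mu close to 1.5, s close to 2.0). The intuition behind rebinding such weights is that the real fit then starts from prior-like local parameters rather than arbitrary network outputs.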