API Reference
Table of Contents
- data/dataset.py
- data/io.py
- model/instance.py
- model/nn.py
- model/predict.py
- model/state.py
- train/augment.py
- train/logger.py
- train/opt.py
- train/trainer.py
- utils/hvg.py
File: data/dataset.py
Class GinsengDataset
An on-disk dataset for training single-cell classifiers.
Attributes:
- path (Path): Path to the zarr dataset on disk.
- root (zarr.Group): The root zarr group object.
- n_cells (int): Total number of cells (observations) in the dataset.
- n_genes (int): Total number of genes (variables) in the dataset.
- label_names (list of str): Human-readable names for the integer labels.
- gene_names (list of str): Names of the genes stored in the dataset.
- labels (np.ndarray): Integer labels for every cell in the dataset.
- groups (np.ndarray or None): Categorical group indices (e.g., batch or donor) if provided.
- train_idx (np.ndarray or None): Indices of cells assigned to the training split.
- test_idx (np.ndarray or None): Indices of cells assigned to the test split.
Method _load_metadata
Load metadata and small arrays into memory.
Method create
Create a GinsengDataset from an AnnData object or file path.
Parameters:
- path (str | Path): Output path where the zarr dataset directory will be created.
- adata (str | Path | ad.AnnData): Input data. Can be an AnnData object, a local path to a .h5ad file, .h5 file, or 10x directory, or a URL to a supported file format.
- label_key (str): The column name in adata.obs containing the target labels (e.g., cell type).
- layer (str): The key in adata.layers to use for expression counts. If None, uses adata.X (default: None).
- genes (str | list of str | np.ndarray): Gene selection/filtering logic.
  - If a string: Assumes it is a column in adata.var containing a boolean mask (e.g., "highly_variable").
  - If a list or array: A specific set of gene names to keep. This will also reorder the output to match the provided list.
  - If None: Keeps all genes (default: None).
- group_key (str): The column name in adata.obs containing grouping metadata, such as "batch" or "donor" (default: None).
- chunk_size (int): Number of rows (cells) per zarr chunk for the expression matrix. Larger chunks improve compression but require more RAM during streaming (default: 4096).
- overwrite (bool): Whether to delete the existing directory at path if it exists (default: True).
Returns:
GinsengDataset: An initialized instance of the dataset pointing to the new zarr store.
Method make_split
Create train/test splits and store indices on disk.
Parameters:
- fraction (float): The proportion of data (or groups) to include in the test split. Must be in the range [0.0, 1.0) (default: 0.1).
- stratify_group (bool): If True, splits by groups (e.g., donors) rather than individual cells. Requires group_key to have been provided during creation (default: False).
- seed (int): Random seed for reproducibility (default: 123).
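A minimal sketch of what splitting by groups means, in plain Python. This is not the library's implementation: the helper name `split_by_group` and the rounding rule for the number of held-out groups are assumptions made for illustration.

```python
import random

def split_by_group(groups, fraction=0.1, seed=123):
    """Hypothetical helper: hold out whole groups rather than single cells."""
    unique = sorted(set(groups))
    rng = random.Random(seed)
    rng.shuffle(unique)
    # Assumed rounding rule: hold out at least one group when fraction > 0.
    n_test = max(1, round(fraction * len(unique))) if fraction > 0 else 0
    test_groups = set(unique[:n_test])
    train_idx = [i for i, g in enumerate(groups) if g not in test_groups]
    test_idx = [i for i, g in enumerate(groups) if g in test_groups]
    return train_idx, test_idx

# Eight cells from four donors; one donor ends up entirely in the test split.
groups = ["d1", "d1", "d2", "d2", "d3", "d3", "d4", "d4"]
train_idx, test_idx = split_by_group(groups, fraction=0.25)
```

Because whole groups are held out, no donor contributes cells to both splits, which is the usual reason to split by group.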
Method stream
Stream mini-batches of (expression, label) from the zarr store.
Parameters:
- batch_size (int): Number of cells to yield in each batch.
- split (Literal["train", "test", "all"]): Which data split to stream from (default: "train").
- balance_labels (bool): If True, downsamples the training split to match the frequency of the least common class (default: False).
- shuffle (bool): Whether to shuffle the indices before streaming (default: True).
Yields:
- X (np.ndarray): Expression matrix batch of shape (batch_size, n_genes).
- y (np.ndarray): Integer labels batch of shape (batch_size,).
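The label balancing and batching described above can be sketched in plain Python. The helper names `balanced_indices` and `batches` are hypothetical, and the handling of a short final batch is an assumption; the library's behavior may differ.

```python
import random
from collections import defaultdict

def balanced_indices(labels, seed=0):
    """Hypothetical helper: downsample every class to the rarest class size."""
    by_class = defaultdict(list)
    for i, y in enumerate(labels):
        by_class[y].append(i)
    n_min = min(len(v) for v in by_class.values())
    rng = random.Random(seed)
    keep = []
    for idx in by_class.values():
        keep.extend(rng.sample(idx, n_min))
    rng.shuffle(keep)
    return keep

def batches(indices, batch_size):
    """Yield consecutive slices; in this sketch the last batch may be shorter."""
    for start in range(0, len(indices), batch_size):
        yield indices[start:start + batch_size]

labels = [0, 0, 0, 0, 0, 1, 1, 2, 2, 2]  # class 1 is rarest (2 cells)
keep = balanced_indices(labels)
batch_list = list(batches(keep, batch_size=4))
```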
File: data/io.py
Function read_10x_mtx
Read 10x Genomics mtx format into an AnnData object.
Parameters:
- path (str): Path to directory containing matrix.mtx, barcodes.tsv, and genes.tsv/features.tsv.
- var_names (str): Select 'gene_symbols' or 'gene_ids' as index for var.
- make_unique (bool): If True, make var_names unique.
Returns:
AnnData:
Function read_10x_h5
Read 10x Genomics h5 format into an AnnData object.
Parameters:
- path (str): Path to 10x Genomics .h5 file.
- genome (str, optional): Genome to extract (if file contains multiple genomes). If None, will use the first genome available.
- var_names (str): Select 'gene_symbols' or 'gene_ids' as index for var.
- make_unique (bool): If True, make var_names unique.
Returns:
AnnData:
Function read_adata
Read an AnnData object from various supported file formats.
Parameters:
- path (str | Path): Path to the input count data stored in 10x .h5 format, AnnData .h5ad format, or in a 10x matrix market format folder.
- backed (bool): If True and the input is an .h5ad file, open the file in backed mode.
Returns:
AnnData: The loaded AnnData object containing gene expression data.
Function _load_pytree
Recursively load a PyTree from an HDF5 group.
Function save_model
Save a Ginseng model state to a single HDF5 file.
Parameters:
- state (GinsengClassifierState): Complete model state to save.
- path (str | Path): Path to the output HDF5 file (will add .h5 if not present).
Examples:
>>> # After training
>>> state = GinsengClassifierState(
... params=model.params,
... genes=dataset.gene_names,
... label_keys=dataset.label_names,
... label_values=np.arange(len(dataset.label_names)),
... n_genes=dataset.n_genes,
... n_classes=len(dataset.label_names),
... hidden_dim=256,
... normalize=True,
... target_sum=1e4,
... dropout_rate=0.5,
... training=False,
... )
>>> save_model(state, "./models/my_classifier.h5")
Function _save_pytree
Recursively save a PyTree to an HDF5 group.
Function load_model
Load a Ginseng model state from an HDF5 file.
Parameters:
- path (str | Path): Path to the HDF5 file.
Returns:
GinsengClassifierState: Complete model state ready for inference or continued training.
Examples:
>>> state = load_model("./models/my_classifier.h5")
>>>
>>> # Recreate model
>>> model = GinsengClassifier(
... n_genes=state.n_genes,
... n_classes=state.n_classes,
... hidden_dim=state.hidden_dim,
... dropout_rate=state.dropout_rate,
... normalize=state.normalize,
... target_sum=state.target_sum,
... )
>>> model.params = state.params
>>>
>>> # Prepare new data with correct gene order
>>> new_data_ordered = new_data[state.genes]
>>> predictions = model.predict(new_data_ordered.values, training=False)
File: model/instance.py
Class GinsengClassifier
Wrapper class for cell type annotation with automatic key management.
Attributes:
- n_genes (int): Number of genes in input data.
- n_classes (int): Number of cell type classes.
- hidden_dim (int): Hidden dimension for attention mechanism (default: 256).
- dropout_rate (float): Dropout rate during training (default: 0.5).
- normalize (bool): Whether to normalize input data (default: True).
- target_sum (None | float): Target sum for normalization (default: 1e4).
- seed (int): Random seed for reproducibility (default: 42).
Examples:
>>> model = GinsengClassifier(n_genes=2000, n_classes=10)
>>> # During training (JAX-compatible)
>>> loss = model.loss(model.params, model._get_key(), x_batch, y_batch)
>>> # For prediction (Standard usage)
>>> logits = model.predict(x_test, training=False)
Method _get_key
Get a new PRNG key and update internal state.
Returns:
Array: A new JAX PRNG key.
Method predict
Generate predictions for input data.
Parameters:
- x (Array): Input gene expression matrix.
- params (PyTree, optional): Model parameters. If None, uses internal self.params.
- key (Array, optional): PRNG key. If None, uses a new key from self._get_key().
- training (bool): Whether to use training mode (with dropout).
- return_attn (bool): Whether to return attention weights.
Returns:
Array or tuple: Logits, or (logits, attention) if return_attn=True.
Method loss
Compute cross-entropy loss for a batch.
Parameters:
- params (PyTree): Model parameters to differentiate against.
- key (Array): PRNG key for dropout randomness.
- x (Array): Input gene expression matrix.
- y (Array): True class labels.
Returns:
Array: Scalar loss value.
Method evaluate
Evaluate model on a batch and return loss and accuracy.
Parameters:
- x (Array): Input gene expression matrix.
- y (Array): True class labels.
Returns:
tuple[float, float]: (loss, accuracy) on the batch.
Function nn_normalize
Normalize counts per cell and apply log transform.
Parameters:
- x (Array): Count matrix.
- target_sum (float): Target total count per cell after normalization.
Returns:
Array: Normalized expression matrix.
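A minimal sketch of per-cell normalization followed by a log transform, written with nested lists instead of JAX arrays. It assumes the log transform is log1p and that cells with zero total counts are left as zeros; neither detail is confirmed by this reference, and `normalize_log1p` is a hypothetical name.

```python
import math

def normalize_log1p(counts, target_sum=1e4):
    """Scale each cell (row) to target_sum total counts, then apply log1p."""
    out = []
    for row in counts:
        total = sum(row)
        # Assumption: cells with zero total counts are left as zeros.
        scale = target_sum / total if total > 0 else 0.0
        out.append([math.log1p(v * scale) for v in row])
    return out

counts = [[1, 3, 0], [0, 0, 0], [5, 5, 10]]
normed = normalize_log1p(counts, target_sum=100.0)
```

Undoing the log with `math.expm1` recovers rows that sum to `target_sum`, which is a quick way to check the scaling.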
Function nn_annotate_init
Initialize parameters for the cell type annotation network.
Parameters:
- key (Array): PRNG key array for random initialization.
- n_genes (int): Number of genes.
- n_classes (int): Number of classes.
- hidden_dim (int): Dimension of hidden layers for attention mechanism.
Returns:
PyTree[Float[Array, "..."]]: Dictionary of initialized model parameters.
Function nn_annotate
Annotate cells using instance-based attention neural network.
Parameters:
- params (PyTree[Float[Array, "..."]]): Model parameters from nn_annotate_init.
- key (Array): PRNG key array for random number generation.
- x (Array): Input gene expression matrix (batch_size, n_genes).
- dropout_rate (float): Dropout probability (default: 0.5).
- normalize (bool): If True, normalize and log-transform counts (default: True).
- target_sum (None | float): Target total count per cell after normalization (default: 1e4).
- return_attn (bool): If True, also return attention weights (default: False).
- training (bool): Whether the model is in training mode (default: True).
Returns:
Float[Array, ...] | tuple[Float[Array, ...], Float[Array, ...]]: Logits for each class. If return_attn=True, gene-level attention weights are also returned as the second element.
Function nn_annotate_loss
Cross-entropy loss for cell type annotation model.
Parameters:
- params (PyTree[Float[Array, "..."]]): Model parameters from nn_annotate_init.
- key (Array): PRNG key array for random number generation.
- x (Array): Input gene expression matrix (batch_size, n_genes).
- y (Array): True cell type labels as integer class indices (batch_size,).
- dropout_rate (float): Dropout probability (default: 0.5).
- normalize (bool): If True, normalize and log-transform counts (default: True).
- target_sum (None | float): Target total count per cell after normalization. Defaults to 1e4 (standard for scRNA-seq).
Returns:
Float[Array, ""]: Scalar cross-entropy loss.
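For reference, softmax cross-entropy with integer class labels can be sketched as follows. The mean reduction over the batch is an assumption, and `cross_entropy` is a hypothetical stand-alone function operating on logits; it does not reproduce the network itself.

```python
import math

def cross_entropy(logits, labels):
    """Mean softmax cross-entropy with integer class labels."""
    total = 0.0
    for row, y in zip(logits, labels):
        m = max(row)  # subtract the max for numerical stability
        log_z = m + math.log(sum(math.exp(v - m) for v in row))
        total += log_z - row[y]  # negative log-probability of the true class
    return total / len(labels)

logits = [[2.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
labels = [0, 2]
loss = cross_entropy(logits, labels)
```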
Function nn_annotate_evaluate
Evaluate model performance on a batch.
Parameters:
- params (PyTree[Float[Array, "..."]]): Model parameters from nn_annotate_init.
- key (Array): PRNG key array for random number generation.
- x (Array): Input gene expression matrix (batch_size, n_genes).
- y (Array): True cell type labels as integer class indices (batch_size,).
- dropout_rate (float): Dropout probability (default: 0.5).
- normalize (bool): If True, normalize and log-transform counts (default: True).
- target_sum (None | float): Target total count per cell after normalization. Defaults to 1e4 (standard for scRNA-seq).
Returns:
tuple[Float[Array, ""], Int[Array, ""], Int[Array, ""]]: Tuple of (loss, total_samples, correct_predictions).
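Per-batch (loss, total_samples, correct_predictions) tuples can be combined into dataset-level metrics. The sketch below assumes the returned loss is a per-batch mean, so it is weighted by batch size; if the loss is a sum, the weighting should be dropped. The name `aggregate` is hypothetical.

```python
def aggregate(batch_results):
    """Combine per-batch (loss, total_samples, correct_predictions) tuples."""
    loss_sum = 0.0
    n_total = 0
    n_correct = 0
    for loss, total, correct in batch_results:
        # Assumption: `loss` is a per-batch mean, so weight it by batch size.
        loss_sum += loss * total
        n_total += total
        n_correct += correct
    return loss_sum / n_total, n_correct / n_total

# Two batches: 4 cells (3 correct) and 2 cells (1 correct).
mean_loss, accuracy = aggregate([(0.5, 4, 3), (1.0, 2, 1)])
```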
File: model/nn.py
Function nn_xavier_uniform
Initialize weights with Xavier uniform distribution.
Parameters:
- key (Array): PRNG key array for random number generation.
- shape (tuple of int): Shape of the weight matrix.
Returns:
Array: Initialized weight matrix.
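Xavier (Glorot) uniform initialization draws weights from U(-a, a) with a = sqrt(6 / (fan_in + fan_out)). The sketch below uses Python's `random` module and an integer seed in place of a JAX PRNG key; `xavier_uniform` is a hypothetical stand-in, and treating the shape as (fan_in, fan_out) is an assumption.

```python
import math
import random

def xavier_uniform(shape, seed=0):
    """Sample a (fan_in, fan_out) matrix from U(-a, a), a = sqrt(6 / (fan_in + fan_out))."""
    fan_in, fan_out = shape
    limit = math.sqrt(6.0 / (fan_in + fan_out))
    rng = random.Random(seed)
    return [[rng.uniform(-limit, limit) for _ in range(fan_out)] for _ in range(fan_in)]

W = xavier_uniform((4, 2))
limit = math.sqrt(6.0 / (4 + 2))  # bound for a 4 x 2 matrix
```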
Function nn_init_linear
Initialize parameters for a linear layer.
Parameters:
- key (Array): PRNG key array for random initialization.
- in_dim (int): Input dimension.
- out_dim (int): Output dimension.
Returns:
PyTree[Float[Array, "..."]]: Dictionary with weight matrix W and bias vector b.
Function nn_linear
Apply a linear transformation.
Parameters:
- params (PyTree[Float[Array, "..."]]): Layer parameters with W and b.
- x (Array): Input tensor.
Returns:
Array: Transformed output tensor.
Function nn_dropout
Apply dropout to input array with optimized implementation.
Parameters:
- key (Array): PRNG key array for dropout mask.
- x (Array): Input tensor.
- rate (float): Dropout probability.
- training (bool): If True, apply dropout. Set to False for deterministic output after training.
Returns:
Array: Tensor with dropout applied.
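A minimal sketch of dropout on a flat list, assuming the common inverted-dropout convention in which surviving values are rescaled by 1 / (1 - rate). This reference does not state whether the library rescales, and the integer seed stands in for a JAX PRNG key.

```python
import random

def dropout(x, rate, training, seed=0):
    """Zero each element with probability `rate`; rescale survivors by 1 / (1 - rate)."""
    if not training or rate == 0.0:
        return list(x)  # deterministic pass-through at inference time
    rng = random.Random(seed)
    keep = 1.0 - rate
    return [v / keep if rng.random() < keep else 0.0 for v in x]

x = [1.0] * 1000
y_train = dropout(x, rate=0.5, training=True)   # roughly half the values zeroed
y_eval = dropout(x, rate=0.5, training=False)   # unchanged
```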
File: model/predict.py
Function _get_gene_indices
Map model gene names to indices within the AnnData object.
Parameters:
- adata (AnnData): The single-cell data object.
- gene_names (np.ndarray): Array of gene names expected by the model in a specific order.
- gene_key (str | None): Column name in adata.var containing gene names. If None, uses adata.var_names.
Returns:
- available_idx (np.ndarray): The integer indices of genes found in adata that match the model's expected genes.
- out_positions (np.ndarray): The corresponding positions in the model's input vector where these genes belong.
Function annotate_iter
Batch iterator that handles gene reordering and zero-padding for missing genes.
Parameters:
- adata (AnnData): The single-cell data object.
- gene_names (np.ndarray): Ordered gene names from the trained model.
- gene_key (str | None): Column in adata.var to use for gene matching.
- layer (str | None): Key in adata.layers to use for counts. If None, uses adata.X.
- batch_size (int): Number of cells to process per batch.
Yields:
- batch_tensor (jax.numpy.ndarray): A JAX-compatible array of shape (batch_size, n_model_genes).
- start (int): Starting observation index.
- end (int): Ending observation index.
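Gene reordering with zero-padding can be sketched with plain lists. The helpers `gene_indices` and `reorder_and_pad` are hypothetical; they mirror the documented available_idx / out_positions pair but are not the library's code, and the gene names are made-up sample data.

```python
def gene_indices(data_genes, model_genes):
    """Return (available_idx, out_positions) for genes shared by data and model."""
    lookup = {g: i for i, g in enumerate(data_genes)}
    available_idx, out_positions = [], []
    for pos, g in enumerate(model_genes):
        if g in lookup:
            available_idx.append(lookup[g])  # column in the input data
            out_positions.append(pos)        # column in the model's input
    return available_idx, out_positions

def reorder_and_pad(rows, data_genes, model_genes):
    """Place shared genes at the model's positions; missing genes stay zero."""
    available_idx, out_positions = gene_indices(data_genes, model_genes)
    out = []
    for row in rows:
        padded = [0.0] * len(model_genes)
        for src, dst in zip(available_idx, out_positions):
            padded[dst] = float(row[src])
        out.append(padded)
    return out

data_genes = ["CD3E", "MS4A1", "LYZ"]
model_genes = ["LYZ", "GNLY", "CD3E"]  # GNLY is absent from the data
batch = reorder_and_pad([[1, 2, 3], [4, 5, 6]], data_genes, model_genes)
```

Genes present in the data but unknown to the model (MS4A1 here) are dropped, and genes the model expects but the data lacks are filled with zeros.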
Function classify
Annotate single-cell sequencing data using a trained ginseng classifier.
Parameters:
- model_state (GinsengClassifierState | str | Path): A loaded GinsengClassifierState or a path to a saved state file.
- adata (AnnData | str | Path): AnnData object or path to count data (.h5ad, .h5, or Matrix Market).
- gene_key (str | None): Column in .var containing gene names. If None, uses index.
- layer (str | None): Key in adata.layers to use for counts. If None, uses adata.X.
- backed (bool): If True and adata is a path, reads data in backed mode.
- normalize (bool | None): Override model normalization setting.
- target_sum (float | None): Override model target sum.
- randomness (bool): If True, enables dropout during inference.
- batch_size (int): Number of cells to process in each forward pass.
- copy (bool): If True, returns a modified copy of the AnnData object.
- store_probs (bool): If True, stores the full probability matrix in adata.obsm.
- return_table (bool): If True, returns a pandas DataFrame instead of modifying AnnData.
- seed (int): Random seed.
- silent (bool): If True, suppresses progress bar and warnings.
Returns:
AnnData | pd.DataFrame | None: Returns a DataFrame if return_table is True, a copy of AnnData if copy is True, otherwise modifies in-place and returns None.
File: model/state.py
Class GinsengClassifierState
Complete state of a trained Ginseng model.
All information needed to save, load, and use a trained model.
Attributes:
- params (PyTree): Model parameters (weights and biases).
- genes (np.ndarray): Gene names in the exact order expected by the model.
- label_keys (np.ndarray): Label names (e.g., ['T-cell', 'B-cell', 'Macrophage']).
- label_values (np.ndarray): Integer values corresponding to each label (e.g., [0, 1, 2]).
- n_genes (int): Number of genes.
- n_classes (int): Number of classes.
- hidden_dim (int): Hidden dimension used in attention mechanism.
- normalize (bool): Whether input data should be normalized.
- target_sum (float): Target sum for normalization.
- dropout_rate (float): Dropout rate used during training.
- training (bool): Whether the model is in training mode (False after training).
Function classifier_from_state
Create a GinsengClassifier from a loaded state.
Parameters:
- state (GinsengClassifierState): Loaded model state.
Returns:
GinsengClassifier: Model ready for inference.
Function state_from_classifier_trainer
Create a GinsengClassifierState from a trainer after .fit().
Parameters:
- trainer (GinsengClassifierTrainer): Trainer instance with a trained model.
Returns:
GinsengClassifierState: Complete model state.
Examples:
>>> trainer = GinsengClassifierTrainer(dataset, settings)
>>> trainer.fit(epochs=50)
>>> state = state_from_classifier_trainer(trainer)
>>> save_model(state, "./models/my_classifier.h5")
File: train/augment.py
Function augment_mask
Randomly mask out counts across cells.
Parameters:
- key (Array): PRNG key array for dropout mask.
- x (Array): Input tensor.
- rate (float): Dropout probability.
Returns:
Array: Tensor with dropout applied.
Function augment_background
Randomly add Poisson-distributed background noise to counts.
Parameters:
- key (Array): PRNG key array for noise generation.
- x (Array): Non-normalized count matrix.
- lam_max (float): Maximum mean of Poisson distribution for sampling added noise.
Returns:
Array: Counts with added Poisson noise.
Function augment_dropgene
Randomly zero out entire genes.
Parameters:
- key (Array): PRNG key array for mask generation.
- x (Array): Count matrix.
- lower (int): Minimum number of genes to mask out.
- upper (int): Maximum (exclusive) number of genes to mask out.
Returns:
Array: Filtered counts.
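A minimal sketch of zeroing entire genes, with the number of dropped genes drawn from [lower, upper). It assumes the same genes are dropped for every cell in the batch and uses an integer seed in place of a JAX PRNG key; `drop_genes` is a hypothetical name.

```python
import random

def drop_genes(counts, lower, upper, seed=0):
    """Zero out k whole columns (genes), with k drawn from [lower, upper)."""
    rng = random.Random(seed)
    n_genes = len(counts[0])
    k = rng.randrange(lower, upper)  # upper bound is exclusive
    dropped = set(rng.sample(range(n_genes), k))
    masked = [[0 if j in dropped else v for j, v in enumerate(row)] for row in counts]
    return masked, dropped

counts = [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]
masked, dropped = drop_genes(counts, lower=1, upper=3)
```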
Function augment
Apply a combination of single-cell RNA relevant augmentations.
Parameters:
- key (Array): PRNG key array for mask generation.
- x (Array): Count matrix.
- rate (float): Dropout probability.
- lam_max (float): Maximum mean of Poisson distribution for sampling added noise.
- lower (int): Minimum number of genes to mask out.
- upper (int): Maximum (exclusive) number of genes to mask out.
Returns:
Array: Augmented counts.
File: train/logger.py
Class GinsengLogger
Logger for storing training and validation metrics across epochs.
Attributes:
- epoch (list[int]): List of epoch indices.
- train_loss (list[float]): Training loss values for each epoch.
- holdout_loss (list[float]): Holdout loss values for each epoch.
- holdout_accuracy (list[float]): Holdout accuracy values for each epoch.
Method update
Update the logger with new training and validation metrics.
Parameters:
- epoch (int): Current epoch index.
- train_loss (float): Training loss at this epoch.
- holdout_loss (float): Validation loss at this epoch.
- holdout_accuracy (float): Validation accuracy at this epoch.
Method report
Print most recent result to standard output.
Parameters:
- silent (bool): If True, suppresses report output.
- flush (bool): If True, write report to output immediately.
Returns:
None: Output is printed to standard output.
File: train/opt.py
Class AdamState
State for the Adam optimizer.
Attributes:
- step (int): Current optimization step.
- m (PyTree of Array): Exponential moving average of gradients.
- v (PyTree of Array): Exponential moving average of squared gradients.
Function opt_init_adam
Initialize Adam optimizer state.
Parameters:
- params (PyTree[Float[Array, ...]]): Model parameters to be optimized.
Returns:
AdamState: Initial optimizer state with zeroed moments.
Function opt_adam_update
Perform one Adam optimization step.
Parameters:
- grads (PyTree[Float[Array, ...]]): Gradients of the loss w.r.t. parameters.
- params (PyTree[Float[Array, ...]]): Current model parameters.
- state (AdamState): Current optimizer state.
- lr (float): Learning rate.
- betas (tuple[float, float]): Exponential decay rates for first and second moment estimates.
- eps (float): Numerical stability constant.
- weight_decay (float): L2 regularization factor.
Returns:
tuple: Updated parameters and new optimizer state.
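One Adam step with bias correction can be sketched on flat lists of floats. The default values shown for lr, betas, and eps are the commonly used ones and are assumptions here, as is folding weight_decay into the gradient as an L2 penalty; the optimizer state is represented as a plain (step, m, v) tuple rather than an AdamState.

```python
import math

def adam_update(grads, params, state, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0):
    """One Adam step on flat lists; state is a (step, m, v) tuple."""
    step, m, v = state
    step += 1
    b1, b2 = betas
    new_params, new_m, new_v = [], [], []
    for g, p, m_i, v_i in zip(grads, params, m, v):
        g = g + weight_decay * p  # Assumption: L2 penalty folded into the gradient.
        m_i = b1 * m_i + (1 - b1) * g        # first-moment moving average
        v_i = b2 * v_i + (1 - b2) * g * g    # second-moment moving average
        m_hat = m_i / (1 - b1 ** step)       # bias correction
        v_hat = v_i / (1 - b2 ** step)
        new_params.append(p - lr * m_hat / (math.sqrt(v_hat) + eps))
        new_m.append(m_i)
        new_v.append(v_i)
    return new_params, (step, new_m, new_v)

params = [1.0, -2.0]
state = (0, [0.0, 0.0], [0.0, 0.0])  # zeroed moments, step 0
params, state = adam_update([0.5, -0.5], params, state, lr=0.1)
```

On the first step the bias-corrected ratio is close to the sign of the gradient, so each parameter moves by roughly lr regardless of gradient magnitude.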
File: train/trainer.py
Class GinsengClassifierTrainerSettings
Training configuration settings for GinsengClassifierTrainer.
Class GinsengClassifierTrainer
Trainer for GinsengClassifier.
Method _train_step
Internal training step to update parameters.
Method fit
Fit the model.
Parameters:
- epochs (int): Number of training epochs to run (default: 10).
- silent (bool): If True, suppresses training progress output (default: False).
Returns:
GinsengClassifier: The trained classifier with updated parameters.
Method _run_epoch
Run a single training epoch.
Method _validate
Run validation on the holdout split.
File: utils/hvg.py
Function select_hvgs
Select highly variable genes from raw or normalized counts.
Parameters:
- adata (AnnData): AnnData object with gene names stored in var.
- n_top_genes (int): Number of top highly variable genes to select.
- layer (str, optional): Key in adata.layers to use. If None, uses adata.X.
- target_sum (float, optional): If provided, scales each cell to this sum and applies log1p transformation (default: 1e4).
- min_mean (float): Lower quantile bound on mean gene expression.
- max_mean (float): Upper quantile bound on mean gene expression.
- n_bins (int): Select genes across this many gene expression bins.
- chunk_size (int): Number of cells to process in memory at once.
- silent (bool): If True, suppresses progress bar.
- copy (bool): If True, returns a copy of the AnnData.
Returns:
Optional[AnnData]: Marks highly variable genes in var['ginseng_genes'].