10/12/25
In this project, we build a model that can annotate cells by their type. Kinda like YOLOv8, but for labeling cells with their names.
A VAE is a sub-type of neural network that combines encoding data into a compressed latent space with approximating probability distributions over that space using bell-shaped (Gaussian) curves. Those distributions capture the patterns in whatever data has been digitized (encoded), and after training and deployment they become the predictive engine.
The model first compresses the input data into a smaller summary (encodes), finds some patterns in latent space (pattern-finding), then outputs a decoded version (the prediction).
We will also build a transformer model to compare against the VAE.
Building a model involves the same five steps every time:
- acquire the data
- process the data
- create the model
- test the models' predictive power
- deploy the trained weights + API to use them
Data + prep
The scRNA-seq dataset we are using is from PBMCpedia (available in two sizes: 1.4GB and 23.4GB). It has standardized labels, aligned features, and unified batch metadata, all of which yield cleaner labels, less bias, and more reliable results.
When we download the data, we get `.h5ad` files. I used the 1.4GB version on my M2 MacBook Pro. The dataset has 4.3M cells and 2000 preselected highly-variable genes; it is already normalized and log-transformed, and comes with existing UMAP embeddings. This saves us a lot of time.
To download the data, we run this in the root folder of our repo:

```bash
curl -L "https://web.ccb.uni-saarland.de/downloads/pbmcpedia/pbmcpedia-v20250915-full.h5ad" -o pbmcpedia-v20250915-full.h5ad
```
We then load the data with data_adapter.py::load_h5ad_dataset, where:
- `anndata.read_h5ad()` loads the data from the downloaded file
- `sklearn.train_test_split()` samples the data (100k cells)
- `scanpy.pp.calculate_qc_metrics()` computes `pct_counts_mt`, which is the percentage of a cell's RNA counts coming from mitochondrial genes
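A minimal sketch of that loading step (the file name comes from the download above; the mitochondrial-gene prefix and random seed are my own assumptions, not necessarily what `load_h5ad_dataset` uses):

```python
import anndata
import numpy as np
import scanpy as sc
from sklearn.model_selection import train_test_split

# Load the full AnnData object from the downloaded file
adata = anndata.read_h5ad("pbmcpedia-v20250915-full.h5ad")

# Subsample 100k cells; train_test_split doubles as a seeded random sampler
idx, _ = train_test_split(np.arange(adata.n_obs), train_size=100_000, random_state=42)
adata = adata[idx].copy()

# Flag mitochondrial genes, then compute QC metrics, including pct_counts_mt
adata.var["mt"] = adata.var_names.str.startswith("MT-")
sc.pp.calculate_qc_metrics(adata, qc_vars=["mt"], inplace=True)
```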
We then apply batch correction, which removes technical noise introduced by whatever tools and batches the experiment used. We use scanpy's `combat` to correct batch effects on the already log-normalized data. Combat works with normalized data and modifies the expression matrix in-place. The downstream embeddings are basically learned patterns, where similar cells get similar numbers.
Finally, we run PCA (50 components) on the corrected expression matrix to reduce dimensionality, build a neighbor graph, and then compute a 2D UMAP on top of it.
Next, we apply QC filtering, which removes low-quality cells based on gene and mitochondrial gene expression levels.
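Those steps in sketch form (the `batch` key and the QC thresholds are my assumptions, not the repo's exact values):

```python
import scanpy as sc

# Batch correction: Combat modifies the log-normalized matrix adata.X in place
sc.pp.combat(adata, key="batch")

# 2000 genes -> 50 PCA components -> neighbor graph -> 2D UMAP
sc.pp.pca(adata, n_comps=50)
sc.pp.neighbors(adata, use_rep="X_pca")
sc.tl.umap(adata)

# QC filtering: drop low-quality cells by gene counts and mitochondrial fraction
keep = (adata.obs["n_genes_by_counts"] > 200) & (adata.obs["pct_counts_mt"] < 20)
adata = adata[keep].copy()
```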
Lastly, we visualize the data with matplotlib and seaborn for violin plots, and scanpy.pp.pca() + scanpy.tl.umap() for dimensionality reduction. We first reduce from 2000 genes to 50 PCA components, then to 2D with UMAP for visualization. To generate these, we run this command from the root of the repo:
```bash
uv run python -m src.app.cli --config config/default.yaml prepare_data
```

For the filtered violin plot, we get:
And for the UMAP plot, we get:
Nice, first part done!
Models
As mentioned earlier, we want to train two models, a VAE and a transformer, to compare the quality of their outputs. Both models follow the same pattern (a skeleton sketch follows the list):
- define a PyTorch `nn.Module` class, configuring dimensions, layers, and dropout from the hardware and model settings in the config file
- handle hardware placement and management with the `ModelAdapter` classes, which own the optimizer, scheduler, and training state
- load the processed `.h5ad` data and split it into train/val/test datasets
- train the models, iterating over epochs with `train_step()` and `eval_step()` while tracking metrics
- validate on the test set via F1 score
- save the model, the metrics file (`metrics.json`), training curves, and the confusion matrix
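As a hypothetical skeleton of that shared pattern (the device logic and `lr` default are illustrative, not copied from the repo):

```python
import torch
from torch import nn

class ModelAdapter:
    """Owns device placement, the optimizer, the scheduler, and training state."""

    def __init__(self, model: nn.Module, lr: float = 1e-3):
        self.device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
        self.model = model.to(self.device)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer)

    def train_step(self, batch):
        ...  # forward pass, loss, backward, optimizer step

    def eval_step(self, batch):
        ...  # forward pass and metrics, no gradient updates
```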
VAE Model
The core architecture for this model is in src/app/core/vae_adapter.py. We have the encoder (VAEEncoder), decoder (VAEDecoder), and the VAE class (VAE).
Encoder
The VAEEncoder is a stack of linear layers that learns mu and log_var for the probability distribution we mentioned earlier. This is the "probability cloud" we pull patterns out of. To train it, we pass a single cell's 2000 gene expression values (batch_size, 2000) through layers 2000→128→64, outputting mu and log_var vectors (both with shape [batch_size, 10]). Each cell is represented as a Gaussian distribution N(mu, sigma²) in 10D latent space (an even tighter squeeze than the 50 PCA components above).
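In PyTorch, that encoder looks roughly like this (layer sizes from the text; the ReLU activations are my assumption):

```python
import torch
from torch import nn

class VAEEncoder(nn.Module):
    def __init__(self, n_genes=2000, hidden=(128, 64), latent_dim=10):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_genes, hidden[0]), nn.ReLU(),
            nn.Linear(hidden[0], hidden[1]), nn.ReLU(),
        )
        # Two heads: one for the cloud's center, one for its (log) spread
        self.mu = nn.Linear(hidden[1], latent_dim)
        self.log_var = nn.Linear(hidden[1], latent_dim)

    def forward(self, x):  # x: (batch_size, 2000)
        h = self.net(x)
        return self.mu(h), self.log_var(h)  # each (batch_size, 10)
```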
Reparameterization
The VAE class combines the encoder and decoder, then implements the reparameterization trick, which lets gradients flow through mu (the center of the probability cloud) and sigma (the standard deviation, how spread out the cloud is). We need this because training draws random samples from that cloud to see how we are doing, but backpropagation needs a smooth, differentiable path (like water flowing, and raw random sampling is a stick stuck in the stream that breaks the flow). The trick (VAE.reparameterize()) samples z = mu + exp(0.5*log_var)*epsilon where epsilon ~ N(0,1), pushing the randomness into epsilon so gradients can flow through mu and log_var.
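The trick itself is only a few lines, matching the formula above:

```python
def reparameterize(self, mu, log_var):
    # z = mu + sigma * epsilon: the randomness lives entirely in epsilon,
    # so gradients can flow back through mu and log_var untouched
    std = torch.exp(0.5 * log_var)
    epsilon = torch.randn_like(std)
    return mu + std * epsilon
```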
Decoder
The VAEDecoder is made up of linear layers that reconstruct gene expression from a sampled point z in latent space via layers 10→64→128→2000. This is where we get back meaningful biological information we can actually use.
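The decoder mirrors the encoder (same hedges as above):

```python
class VAEDecoder(nn.Module):
    def __init__(self, latent_dim=10, hidden=(64, 128), n_genes=2000):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim, hidden[0]), nn.ReLU(),
            nn.Linear(hidden[0], hidden[1]), nn.ReLU(),
            nn.Linear(hidden[1], n_genes),  # reconstructed expression values
        )

    def forward(self, z):  # z: (batch_size, 10)
        return self.net(z)  # (batch_size, 2000)
```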
Loss
Loss combines reconstruction MSE (mean squared error: how well the decoder rebuilds the genes) and KL divergence (which regularizes the latent distribution toward N(0,1)).
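In sketch form, using the closed-form KL between N(mu, sigma²) and N(0, 1); the per-batch averaging is my assumption:

```python
import torch
import torch.nn.functional as F

def vae_loss(x, x_hat, mu, log_var):
    # Reconstruction term: how closely the decoder rebuilds the 2000 genes
    recon = F.mse_loss(x_hat, x, reduction="sum") / x.size(0)
    # KL term: -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2),
    # pulling the latent cloud toward a standard Gaussian
    kl = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp()) / x.size(0)
    return recon + kl, recon, kl
```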
Training
Training (VAEModelAdapter.train_step()) optimizes both losses.
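Roughly what that step has to do, reusing `vae_loss` from the sketch above:

```python
def train_step(self, x):
    self.model.train()
    x = x.to(self.device)
    self.optimizer.zero_grad()
    x_hat, mu, log_var = self.model(x)  # encode -> reparameterize -> decode
    loss, recon, kl = vae_loss(x, x_hat, mu, log_var)
    loss.backward()
    self.optimizer.step()
    return {"loss": loss.item(), "recon": recon.item(), "kl": kl.item()}
```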
Inference
For inference, predict() returns only mu (the mean/average of the learned distribution, ignoring variance) for downstream clustering and visualization.
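In sketch form (assuming the adapter holds the trained model and its encoder):

```python
@torch.no_grad()
def predict(self, x):
    mu, _ = self.model.encoder(x.to(self.device))
    return mu  # the distribution's mean; the variance is ignored downstream
```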
Transformer Model
The core architecture for this model is in src/app/core/transformer_adapter.py. We have the TransformerClassifier class that handles the full architecture, and TransformerModelAdapter for training and inference.
Input Embeddings
The model takes a single cell's 2000 gene expression values (batch_size, 2000) and passes them through an input projection layer (2000→d_model=256). A positional encoding is added to help the model understand the sequence structure of the genes.
Transformer Encoder
The embedded input flows through 4 stacked transformer encoder layers (TransformerEncoderLayer in PyTorch). Each layer has multi-head self-attention (nhead=8) that lets the model learn relationships between genes, followed by a feedforward network (dim_feedforward=512). This is where the model figures out which genes co-express and matter for cell type classification.
Latent Representations
After the transformer layers, we apply mean pooling across the gene dimension to get a single vector per cell. This is projected to a latent_dim=64 representation that captures the cell's identity.
Classification Head
The latent vector passes through a final linear layer (64→num_classes=10) to predict cell type probabilities via softmax.
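Here is one plausible reading of the whole architecture in PyTorch. One caveat: the post describes a single 2000→256 input projection, but for attention "between genes" and mean pooling "across the gene dimension" to make sense, this sketch embeds each gene as its own token, which may differ from the exact tokenization in transformer_adapter.py:

```python
import torch
from torch import nn

class TransformerClassifier(nn.Module):
    def __init__(self, n_genes=2000, d_model=256, nhead=8, num_layers=4,
                 dim_feedforward=512, latent_dim=64, num_classes=10):
        super().__init__()
        self.input_proj = nn.Linear(1, d_model)  # each gene's value -> a token
        # Learned positional encoding, one vector per gene position (an assumption;
        # a fixed sinusoidal encoding would also fit the description)
        self.pos_encoding = nn.Parameter(torch.zeros(1, n_genes, d_model))
        layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead,
            dim_feedforward=dim_feedforward, batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)
        self.to_latent = nn.Linear(d_model, latent_dim)
        self.head = nn.Linear(latent_dim, num_classes)

    def forward(self, x):                          # x: (batch_size, 2000)
        h = self.input_proj(x.unsqueeze(-1))       # (batch_size, 2000, d_model)
        h = self.encoder(h + self.pos_encoding)    # self-attention between genes
        h = h.mean(dim=1)                          # mean-pool across the gene dim
        latent = self.to_latent(h)                 # (batch_size, 64)
        return self.head(latent)                   # logits; softmax at inference
```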
Loss
Cross-entropy loss measures how well predicted probabilities match true cell type labels.
Training
Training (TransformerModelAdapter.train_step()) computes loss, backpropagates gradients, and uses a learning rate scheduler (ReduceLROnPlateau) that reduces learning rate when validation loss plateaus. Early stopping prevents overfitting (cmd_train_transformer() in cli.py).
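The moving parts, sketched (the optimizer choice and `patience` value are assumptions; `TransformerClassifier` is from the sketch above):

```python
import torch
from torch import nn

model = TransformerClassifier()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", patience=5)
criterion = nn.CrossEntropyLoss()

def train_step(x, y):
    model.train()
    optimizer.zero_grad()
    loss = criterion(model(x), y)  # logits vs. true cell type labels
    loss.backward()
    optimizer.step()
    return loss.item()

# Once per epoch, after computing val_loss:
#   scheduler.step(val_loss)  # cuts the learning rate when validation loss plateaus
```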
Inference
For inference, predict() returns predicted class probabilities for each cell, used to generate confusion matrices and evaluate classification accuracy (F1-macro score).
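Evaluation in sketch form, with scikit-learn computing the F1-macro score and the confusion matrix:

```python
import torch
from sklearn.metrics import confusion_matrix, f1_score

@torch.no_grad()
def evaluate(model, x, y_true):
    model.eval()
    probs = torch.softmax(model(x), dim=-1)      # predicted class probabilities
    y_pred = probs.argmax(dim=-1).cpu().numpy()
    # F1-macro averages per-class F1 equally, so it punishes ignoring rare classes
    return f1_score(y_true, y_pred, average="macro"), confusion_matrix(y_true, y_pred)
```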
Results
The results for the VAE are good, but the initial run for the transformer has terrible results. It predicts mostly one class (likely the majority cell type), collapsing onto it instead of learning the rest.
VAE
- Best epoch: 65 (early stopped at 75)
- Test total loss: 269.64
- Reconstruction loss: 265.36
- KL divergence: 4.28
- Training time: 152 seconds (~2.5 min)
Transformer
- Best epoch: 21 (early stopped at 31)
- Test accuracy: 68.03%
- Test F1-macro: 0.0809 (bad, severe class imbalance)
- Training time: 132 seconds (~2.2 min)
We implemented multiple separate changes, such as adding class weights before training, doubling the data size to 100k cells, and doubling the size of the model (deeper, more parameters), but these yielded results that were worse than or no better than the original scores above (the class-weighting recipe is sketched after this list):
- ✗ Class-weighted loss (made it worse: 25% acc, 0.04 F1)
- ✗ More training data (100k vs 50k cells: no improvement)
- ✗ Larger model (6L/512d vs 4L/256d: no improvement)
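If you want to reproduce the class-weighting attempt, a common recipe looks like this (not necessarily the exact code in the repo; `y_train` is assumed to be the integer training labels):

```python
import numpy as np
import torch
from torch import nn
from sklearn.utils.class_weight import compute_class_weight

# Weight each class inversely to its frequency so rare cell types count more
weights = compute_class_weight("balanced", classes=np.unique(y_train), y=y_train)
criterion = nn.CrossEntropyLoss(weight=torch.tensor(weights, dtype=torch.float32))
```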
Fundamentally, the transformer learns to predict only the majority class (~68% of samples), achieving the same accuracy by always guessing the most common cell type. This most likely happens because the raw 2000-gene input is too noisy compared to the VAE's learned 10D representations, causing the model to settle into the local minimum of "always predict majority class" due to severe class imbalance.
There are a lot of actions we can take to improve this, but I must move forward with other projects. An interesting experiment, nonetheless!
Recreating
If you wish to try this out for yourself, feel free to either clone the codebase and work through it with your favorite AI, or use the tutorial notebook. You can reach out to me through any means on the front page of the site if you have problems.