Extropic recently announced their new chip and its sister library, thrml. They’ve essentially built an ASIC for energy-based models: instead of training a network in the usual way, you sample from the energy landscape itself, letting the hardware physically settle into low-energy states that represent likely outcomes.
This is exactly what biology already does. Quite literally, our molecules naturally relax into their lowest free-energy state, like amino acids folding into proteins, or a drug fitting a target. Each configuration’s probability is proportional to the Boltzmann factor $e^{-E/k_B T}$, just like molecules finding low-energy conformations. The chip effectively performs that relaxation in silicon, sampling probable low-energy configurations rather than computing a single deterministic outcome. This is different from what we are used to in normal programming.
In ML, however, it is common to get multiple candidate answers with associated probabilities. This is a big reason why we still need the wet lab in biology: the model gives us a distribution of possibilities that locally minimize free energy, and we have to confirm which one actually works in the real world, in a larger system.
Likewise, when building an ML model to minimize free-energy in some biological context, we go through a ton of transformations across many abstraction layers: a generalized hardware layer, mathematical calculations, tensor transformations, and library type conversions just to reach the same minimum free-energy conformation that nature finds automatically.
So, one can imagine that if you have this whole process embedded in hardware, it would be that much more efficient. We can’t test the real chip yet, but we can simulate it with their new library thrml on a GPU. Let’s do that.
Table of Contents
- Table of Contents
- I. Potts-Style Model
- Potts-style conditional energy function
- Conditional Boltzmann distribution
- Training objective
- II. Preprocessing
- III. Training
- IV. Conclusion
- Resources
I. Potts-Style Model
For this benchmark, we are training an energy-based model (EBM) to minimize free-energy on a cellular perturbation dataset. The goal is to compare the performance and efficiency between Extropic’s energy-based hardware chip and a conventional GPU, both running the same learning objective.
Because we don’t have access to Extropic’s physical chips, we are going to use thrml, their JAX-based virtual simulator, running on top of my GPU.
Is this setup completely fair given that we’re running a virtualized layer on a general-purpose GPU? Not really. But it should be sufficient for an initial proof of concept.
I’m no physicist and am not going to re-derive equations from scratch. Instead, for the energy-based model, we’re going to use the 2025 Bialek/Callan Ising model paper:
Potts-style conditional energy function
$$E(x \mid c) \;=\; -\sum_i h_i(c)\, x_i \;-\; \sum_{i<j} J_{ij}\, x_i x_j$$
- $E(x \mid c)$: total system energy given cell state $x$ under perturbation $c$
- $x_i$: gene $i$’s response (down/neutral/up)
- $c$: the perturbation (target gene, dose, time, etc.)
- $h_i(c)$: learned bias for gene $i$, a linear projection of the perturbation embedding
- $J_{ij}$: symmetric coupling between genes $i$ and $j$ (learned, sparse, or GRN-prior)
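To make this concrete, here is a minimal sketch of the energy function in JAX. The function name `potts_energy` and the $\{-1, 0, +1\}$ encoding of the gene states are my own assumptions for illustration, not thrml’s API.

```python
import jax.numpy as jnp

def potts_energy(x, h, J):
    """Conditional Potts energy E(x | c) for one cell.

    x : (N,) gene states encoded as -1 (down), 0 (neutral), +1 (up)
    h : (N,) per-gene biases h_i(c), already projected from the perturbation c
    J : (N, N) symmetric coupling matrix with zero diagonal
    """
    bias_term = jnp.dot(h, x)        # sum_i h_i(c) x_i
    pair_term = 0.5 * x @ J @ x      # sum_{i<j} J_ij x_i x_j (J symmetric, zero diagonal)
    return -(bias_term + pair_term)
```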
Conditional Boltzmann distribution
$$P(x \mid c) \;=\; \frac{e^{-E(x \mid c)}}{Z(c)}, \qquad Z(c) = \sum_{x'} e^{-E(x' \mid c)}$$
- $P(x \mid c)$: probability of observing cell state $x$ given perturbation $c$
- $Z(c)$: partition function ensuring all probabilities sum to 1 across possible gene states
- states with lower energy correspond to higher probability
- the exponential term ensures that improbable (high-energy) states contribute less
- the denominator normalizes the distribution so that $\sum_x P(x \mid c) = 1$
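For a toy system with a handful of genes, $Z(c)$ can be computed by brute-force enumeration, which makes the normalization tangible. This sketch reuses the hypothetical `potts_energy` above; real training never enumerates states.

```python
import itertools
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def log_prob(x, h, J):
    """log P(x | c) for a tiny system, normalizing by enumerating all 3^N states."""
    n = h.shape[0]
    states = jnp.array(list(itertools.product((-1.0, 0.0, 1.0), repeat=n)))
    log_weights = -jax.vmap(lambda s: potts_energy(s, h, J))(states)  # -E for every state
    log_Z = logsumexp(log_weights)                                    # log Z(c)
    return -potts_energy(x, h, J) - log_Z                             # log P(x | c)
```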
Training objective
$$\nabla_\theta \mathcal{L} \;=\; \mathbb{E}_{x \sim \text{data}}\big[\nabla_\theta E(x \mid c)\big] \;-\; \mathbb{E}_{x \sim p_\theta}\big[\nabla_\theta E(x \mid c)\big]$$
- positive phase $\mathbb{E}_{\text{data}}[\cdot]$ → lowers energy on observed (post-perturbation) states
- negative phase $\mathbb{E}_{\text{model}}[\cdot]$ → raises energy on sampled (model) states
- sampling is performed via block-Gibbs (genes/pathways as blocks)
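In code, a convenient way to get this two-phase gradient is to differentiate a surrogate loss that is the mean data energy minus the mean sample energy. The sketch below is a standard maximum-likelihood EBM update, again assuming the hypothetical `potts_energy` above; `contrastive_loss` is my own name.

```python
import jax

def contrastive_loss(params, x_data, x_model):
    """Surrogate loss whose gradient matches the two-phase update above.

    params  : dict with 'h' (N,) biases and 'J' (N, N) couplings
    x_data  : (B, N) observed post-perturbation states (positive phase)
    x_model : (B, N) block-Gibbs samples from the current model (negative phase)
    """
    energy = lambda x: potts_energy(x, params["h"], params["J"])
    return jax.vmap(energy)(x_data).mean() - jax.vmap(energy)(x_model).mean()

# grads = jax.grad(contrastive_loss)(params, x_batch, gibbs_batch)  # hypothetical usage
```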
II. Preprocessing
We are using the dataset from the Arc Virtual Cell Challenge. This includes about 15GB of cell perturbation data from 221,273 cells and 18,080 genes. Each record captures how a cell’s gene expression changes after a perturbation such as a CRISPR knockout, a drug treatment, or a viral infection. The data is stored in a Compressed Sparse Row (CSR) matrix with 48 experimental batches and an imbalanced set of 151 target genes.
To handle this efficiently, the .h5ad file was loaded in backed mode, allowing data to stream directly from disk rather than fully loading into memory. I then identified the top 2,000 highly variable genes (HVGs) per batch using the Seurat v3 method and took the union across all batches. After this reduction, I performed library-size normalization to 10,000 reads per cell, applied a log1p transform, and standardized each gene to zero mean and unit variance based on the training set only.
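For reference, here is a rough sketch of those steps in scanpy. This is not the exact `src.data.preprocess_arc` pipeline; the file path and the "batch" column name are assumptions, and the real pipeline fits the scaling statistics on the training split only.

```python
import numpy as np
import scanpy as sc

adata = sc.read_h5ad("path/to/arc.h5ad", backed="r")           # stream from disk

# Seurat v3 HVGs computed per experimental batch, combined across batches
sc.pp.highly_variable_genes(adata, n_top_genes=2000,
                            flavor="seurat_v3", batch_key="batch")
adata = adata[:, adata.var["highly_variable"]].to_memory()     # keep only the HVG subset

sc.pp.normalize_total(adata, target_sum=1e4)                   # 10k reads per cell
sc.pp.log1p(adata)                                             # log1p transform
sc.pp.scale(adata)                                             # zero mean, unit variance per gene
adata.X = np.clip(adata.X, -8, 8)                              # numerical-stability clip
```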
Because the dataset turned out to be highly imbalanced, with some perturbations represented by tens of thousands of cells and others by only a few dozen, I created stratified train/validation/test splits by target gene. I’ll probably use class-weighted sampling during training to mitigate bias toward dominant targets such as “non-targeting” controls. Instead of performing heavy batch correction, I kept the batch identifier as a categorical variable embedded in the perturbation conditioning $c$, allowing the model to learn batch context directly.
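Continuing the sketch above, the stratified splits can be produced with scikit-learn. The split fractions and the "target_gene" column name are illustrative assumptions.

```python
import json
import numpy as np
from sklearn.model_selection import train_test_split

targets = adata.obs["target_gene"].to_numpy()   # one perturbation label per cell
idx = np.arange(adata.n_obs)

# Hold out a test set, then carve a validation set out of the remainder,
# stratifying both splits by target gene so rare perturbations appear everywhere.
train_idx, test_idx = train_test_split(idx, test_size=0.15,
                                       stratify=targets, random_state=0)
train_idx, val_idx = train_test_split(train_idx, test_size=0.15,
                                      stratify=targets[train_idx], random_state=0)

with open("artifacts/splits.json", "w") as f:
    json.dump({"train": train_idx.tolist(),
               "val": val_idx.tolist(),
               "test": test_idx.tolist()}, f)
```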
After HVG selection, the dataset was reduced to a dense ~221k × 2k matrix (≈1.8 GB in float32), small enough for in-memory GPU training. All processed outputs were saved to the artifacts/ directory, including:
- tensors.pt (expression data)
- conditions.pt (target and batch encodings)
- vocab.json (categorical mappings)
- splits.json (train/validation/test indices)
To ensure numerical stability, z-scores were clipped to [-8, 8], affecting only about 0.1% of all values, while no genes were dropped for low variance (all 2,000 HVGs passed the threshold). These preprocessed tensors now serve as the standardized input for both the Gaussian and Potts energy-based models.
The preprocessing pipeline can be fully reproduced via:
```
python -m src.data.preprocess_arc --input path/to/arc.h5ad --mode {gaussian|potts}
```

III. Training
We are training 3 models on an NVIDIA GPU with 12GB VRAM and 32GB of memory:
- Potts model with jax
- Potts model with thrml
- MLP baseline with torch
Training the MLP is boring, so I’m going to skip explaining it. It just serves as our baseline for general comparison.
Training the Potts model is also pretty standard. It’s spiced up with jax because I like using high-performance tools and thrml is built on it. Before we get into training the Potts model on jax vs thrml, below is some metadata about the data, the train/validation/test splits, and the metrics.
Everything went smoothly for training and sampling with the MLP model and the jax-based Potts model, but we ran into an error when training the thrml-based Potts model.
The way thrml works under the hood is by building a fully connected energy graph where, in our case, each gene is a node and every pair of genes has a connection weight in a coupling matrix $J$. If there are $N$ genes, thrml creates roughly $N(N-1)/2$ unique edges to represent all pairwise interactions. That’s fine for small systems, but $N = 2000$ means roughly 2M connections, which is too much for my system to sample from. Each training step in thrml computes the total system energy:
$$E(x) \;=\; -\sum_i h_i x_i \;-\; \sum_{i<j} J_{ij} x_i x_j$$
where $h_i$ are the per-gene biases and $J_{ij}$ captures how genes influence each other. Because this dense structure has to be stored and differentiated through, it very quickly exceeded my GPU memory.
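To make the scaling concrete, here is a quick back-of-the-envelope count of the couplings (my own arithmetic, not thrml internals). The coupling matrix itself is modest, but the buffers replicated per batch element and per Gibbs step multiply this quickly.

```python
def pairwise_edges(n_genes: int) -> int:
    """Number of unique gene-gene couplings in a fully connected graph."""
    return n_genes * (n_genes - 1) // 2

print(pairwise_edges(2000))  # 1,999,000 -> the ~2M connections mentioned above
print(pairwise_edges(100))   # 4,950     -> the reduced system we actually train
```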
To keep training practical and simple, we reduce the number of active genes from 2000 to 100, letting thrml operate on a smaller, fully connected subset while keeping the same mechanics for testing speed and convergence.
This hurts accuracy, which in turn affects sampling quality, but our goal is to test efficiency, not accuracy. So, to cut training time (and because I need to get this done), we:
- reduce the number of genes being sampled from 2000 to 100
- train each model for only 10 epochs
- accept that training is memory-heavy and compute-light, which is apparently typical for Gibbs sampling on dense Potts models
> console output when training the thrml model
When training of the thrml-based EBM initializes, our GPU’s memory is heavily utilized but there is very little compute. Under the hood, jax is tracing the computational graph our model will use. It pre-allocates buffers for our 100 genes, including pre-compiled kernels, gradients, and intermediates. Essentially, everything is compiled on the first step, which takes longer, but every subsequent step of training is very fast. So you get better returns on time the more epochs you run.
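You can see this compile-once behavior with a tiny standalone example (a stand-in computation, not the actual training step):

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def step(J, x):
    # Stand-in for one training step: traced and compiled on the first call, cached after.
    return jnp.tanh(J @ x).sum()

J = jnp.zeros((100, 100))
x = jnp.ones((100,))

t0 = time.perf_counter(); step(J, x).block_until_ready()
print("first call (trace + compile):", time.perf_counter() - t0)

t0 = time.perf_counter(); step(J, x).block_until_ready()
print("subsequent call (cached kernel):", time.perf_counter() - t0)
```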
> GPU utilization from the command watch -n 1 nvidia-smi
That said, I had to make a lot of optimizations, mostly due to the underlying graph data structure. This included applying JIT compilation only to loss_fn and removing nested @jax.jit calls, dropping a few Gibbs steps per epoch, using a batch size of 64 instead of 256, and eliminating CPU-bound work via JAX vmap and lax.scan (vectorization) for parallel processing. This sped training up significantly, from about 500 seconds/epoch to about 70 seconds/epoch.
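The vectorization is the interesting part. Below is a hedged sketch of how a Gibbs sweep can be expressed with lax.scan and batched across chains with vmap; the single-site update, function names, and state encoding are my own simplification of the block updates described earlier, not thrml’s actual API.

```python
import jax
import jax.numpy as jnp

STATES = jnp.array([-1.0, 0.0, 1.0])  # down / neutral / up

def gibbs_sweep(key, x, h, J):
    """One full Gibbs sweep over all genes, expressed with lax.scan."""
    def update_one(carry, i):
        x, key = carry
        key, sub = jax.random.split(key)
        field = h[i] + J[i] @ x - J[i, i] * x[i]          # local field on gene i
        logits = STATES * field                            # log-odds of each state given the rest
        new_state = STATES[jax.random.categorical(sub, logits)]
        return (x.at[i].set(new_state), key), None

    (x, _), _ = jax.lax.scan(update_one, (x, key), jnp.arange(x.shape[0]))
    return x

# Run 64 chains in parallel: keys has shape (64, ...), x has shape (64, N).
batched_sweep = jax.vmap(gibbs_sweep, in_axes=(0, 0, None, None))
# keys = jax.random.split(jax.random.PRNGKey(0), 64)  # hypothetical usage
```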
> JAX: higher energy baseline, less stable configurations
> thrml: multiple local minima reveal discrete gene expression patterns
IV. Conclusion
My problem choice was not the best, as I should have done a direct energy-minimization problem like protein folding or drug-ligand modeling. However, after retraining with the original 2000 genes and 100 epochs, the difference is pretty obvious. The manifolds show a lot more blue for the thrml-based model.
The use case for the chip makes total sense for sampling from vast probability distributions. That applies to large-scale inference, and even more so to edge devices running a Software 2.0-native stack. As consumer AI morphs into hardware, this could become the chip that serves those devices.
Resources
thrml on GitHub (here)
Extropic on thermodynamic computing/sampling (here)
Algorithms to Live By, chapter 9 on sampling (here)
GitHub repo (here)