This was a great project, but it took way longer than I anticipated: 4 days of execution vs. an expected 1 day. I read part of Sutton’s book last year and thought the concept of computers learning autonomously was pretty powerful, since this is exactly how lifeforms do it (albeit a simplistic model of that). If you could have another model create the variables for the reward function, so the reward terms keep getting richer in description over time, that would make a very, very scalable business. I tried that with the reward shaper in this project.
I tried a lot of stuff in this, probably too much. I created scripts to pull data from UniProt and ChEMBL, pushed those amino acid sequences through ESM to get embeddings, then combined the two into a single data structure. Then I tuned the environment and the reward function, and ultimately decided to test different RL architectures. I tested PPO, DQN, and A2C via the Stable-Baselines3 package, and made the evaluation metrics much more complete. I then turned it into an API that can be used by startups like ProFound Therapeutics.
The API ingests real experimental decisions (like which proteins ProFound triaged, tested, or validated), learns a policy that mimics or improves on those decisions, then ranks unseen proteins to recommend next-best targets. The model is validated against proteins we currently have working knowledge of.
Intro
I drive a Tesla with Autopilot. It’s incredibly useful and getting better, though it struggles in complex situations- like being cut off, boxed in, and needing to exit in 0.1 miles, or parking in tight lots. Still, it handles well over 99% of driving flawlessly.
But one thing consistently irritates me: it favors short-term rewards over long-term ones.
Every day, I leave the gym, merge onto a 5-lane freeway, then exit in under a mile. When I drive, I just stay in the exit lane. Autopilot merges onto that first lane, crosses 4 lanes to the far left, then immediately merges back to the first lane because the exit is less than a mile away. It does this every time, even in heavy traffic. It’s dumb.
Why? I think its neural net is myopic. The reward function seems tuned for short-term exploitation- not long-term planning. This may be for safety, but is more likely due to an incomplete RL design. The car behaves as if it’s optimizing <1 mile horizons.
Fixes could include increasing gamma (a longer planning horizon), using a more global reward function with context (next exits, traffic, comfort), and penalizing unnecessary lane changes.
So what does this have to do with protein folding RL agents?
Project
Alright, so I didn’t realize I was going to be building a model when I started this. It turns out RL is not just applying self-directed fine-tuning to a big model (like a foundation model). You are actually creating a whole new abstraction layer on top of the foundation model, using that model’s output embeddings as the input to your RL model.
When AlphaFold, Boltz2, ESM-2, GO, etc. consume an input amino acid sequence, they do a bunch of transformations on that sequence, then output vectors/matrices of float16s/float32s. These then become the input embeddings to our model; a minimal sketch of that abstraction layer follows the table below.
| Source/Model | Output for ML/RL Agent | Example Models |
| --- | --- | --- |
| Embedding model | 1D vector of floats | UniRep, ESM, ProtBERT, SeqVec, TAPE |
| Structure model | 2D matrix (coords) → 1D/2D | AlphaFold (3D coords), Rosetta |
| Feature vector | 1D vector of floats | Hand-crafted features (AA composition, pI, hydrophobicity), GO term embeddings, Pfam/InterPro domain vectors, EC class embeddings |
| Contact map | 2D matrix (can be flattened) | AlphaFold (contact map) |
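To make that abstraction layer concrete, here is a minimal sketch (not the project’s actual environment): a Gymnasium environment whose observations are precomputed embedding vectors and whose actions pick a protein from the pool. The ProteinTriageEnv name, the reward lookup, and the 1280-dim size (ESM-2 650M) are illustrative assumptions.

```python
import numpy as np
import gymnasium as gym
from gymnasium import spaces

class ProteinTriageEnv(gym.Env):
    """Toy env: each step, the agent picks one protein from the pool.
    Observations are precomputed embeddings; rewards come from a
    per-protein score table (e.g., derived from assay data)."""

    def __init__(self, embeddings: np.ndarray, scores: np.ndarray, episode_len: int = 50):
        super().__init__()
        self.embeddings = embeddings.astype(np.float32)  # (n_proteins, 1280) for ESM-2 650M
        self.scores = scores.astype(np.float32)          # (n_proteins,) reward per protein
        self.episode_len = episode_len
        self.action_space = spaces.Discrete(len(scores))  # pick a protein index
        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf, shape=(embeddings.shape[1],), dtype=np.float32
        )

    def reset(self, *, seed=None, options=None):
        super().reset(seed=seed)
        self.t = 0
        self.current = self.np_random.integers(len(self.scores))
        return self.embeddings[self.current], {}

    def step(self, action):
        reward = float(self.scores[action])      # look up the experimentally derived value
        self.t += 1
        self.current = action
        terminated = self.t >= self.episode_len  # fixed-length episode (50 proteins)
        return self.embeddings[self.current], reward, terminated, False, {}
```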
I trained a reinforcement learning agent to rank proteins using ESM-2 embeddings and experimental data. Training ran for 2 million timesteps on an RTX 4070 Ti using the SBX library (JAX-based).
ESM Embeddings
I added embeddings from ESM, since it seemed like the easiest model to run on a Mac; AlphaFold was way too heavy. Each embedding is a high-dimensional, biologically meaningful feature vector representing a protein sequence, and it serves as the input to the RL model. This means we don’t need structural data, just the sequence embeddings. By decoupling feature extraction from training, we avoid the computational cost of fine-tuning the embedding model. The agent learns to prioritize protein targets based on these embeddings.
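Here is a minimal sketch of how those per-sequence embeddings can be computed with the fair-esm package, mean-pooling the final-layer token representations into one fixed-length vector per protein; the example sequences are placeholders, not proteins from this project.

```python
import torch
import esm  # fair-esm package

# Load ESM-2 (650M) and its batch converter (tokenizer)
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
batch_converter = alphabet.get_batch_converter()
model.eval()

# Placeholder sequences; in practice these come from UniProt
data = [("P12345", "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"),
        ("P67890", "GSSGSSGDSLTSLDFGQ")]

_, _, tokens = batch_converter(data)
with torch.no_grad():
    out = model(tokens, repr_layers=[33])  # final layer of the 33-layer model
reps = out["representations"][33]          # (batch, seq_len, 1280)

# Mean-pool over residues (skip the BOS token) -> one 1280-dim vector per protein
embeddings = torch.stack(
    [reps[i, 1:len(seq) + 1].mean(dim=0) for i, (_, seq) in enumerate(data)]
)  # (n_proteins, 1280)
```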
ChEMBL Data
Our goal here was to replace random “hits” with real labels from experimental data (binding assays, functional screens). Downloading from ChEMBL took too long for this project’s timeline (it was supposed to be done in 24 hours; I ended up working on it for 4 days). There is definitely an API opportunity for one of the big experimental biology companies (maybe even a CRO like Ginkgo) to build a validation dataset like this for assay data. That is going to be a very important pillar for AI applied to molecular biology.
Adding UniProt + BindingDB to replace ChEMBL
| Endpoint | For |
| --- | --- |
| /target | Map UniProt → ChEMBL target, get metadata |
| /activity | Get binding/functional data for a target |
| /molecule | Get ligand/compound info |
| /assay | Get experiment/assay context |
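A sketch of how those endpoints can be queried from Python with the chembl_webresource_client package; the EGFR accession (P00533) and the IC50 filter are just example parameters, not the project’s exact queries.

```python
from chembl_webresource_client.new_client import new_client

# /target: map a UniProt accession to ChEMBL target(s) and metadata
targets = new_client.target.filter(target_components__accession="P00533")
target_id = targets[0]["target_chembl_id"]

# /activity: binding/functional measurements for that target
activities = new_client.activity.filter(
    target_chembl_id=target_id, standard_type="IC50"
).only(["molecule_chembl_id", "standard_value", "standard_units", "assay_chembl_id"])

# /molecule and /assay: ligand and assay context for one record
record = activities[0]
molecule = new_client.molecule.get(record["molecule_chembl_id"])
assay = new_client.assay.get(record["assay_chembl_id"])
```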
Unifying ChEMBL data + ESM embeddings
Before: RL agent learned from random noise (fake features + fake hits) - couldn't find real patterns
After: RL agent learns from real biological relationships (protein structure → experimental activity) - can discover meaningful protein prioritization strategies
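A rough sketch of that unification step, assuming embeddings and activities are both keyed by UniProt accession; the file names, column names, and the 1 µM hit threshold are illustrative, not this project’s exact pipeline.

```python
import numpy as np
import pandas as pd

# Assumed inputs (hypothetical files):
#   emb_df:      one row per protein -> accession + emb_* embedding columns
#   activity_df: raw ChEMBL rows     -> accession, standard_value (nM), assay metadata
emb_df = pd.read_parquet("esm2_embeddings.parquet")
activity_df = pd.read_parquet("chembl_activities.parquet")

# Collapse assay rows to one label per protein: best (lowest) affinity and a binary hit
labels = (
    activity_df.groupby("accession")["standard_value"]
    .min()
    .rename("best_affinity_nm")
    .reset_index()
)
labels["hit"] = (labels["best_affinity_nm"] < 1000).astype(int)  # <1 µM counts as a hit

# Join real labels onto real features; drop proteins with no experimental data
dataset = emb_df.merge(labels, on="accession", how="inner")

X = dataset.filter(like="emb_").to_numpy(dtype=np.float32)  # features for the RL env
y = dataset["hit"].to_numpy()                                # real hit labels
```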
Optimizing the Reward Function
The system balances multiple objectives to calculate the reward: activity strength (how well the protein binds to targets), diversity (how different selected proteins are from each other), novelty (how understudied the protein is), and cost (how expensive it would be to study experimentally).
Old Reward Function
- hit: binary hit label (1 if hit, 0 if not)
- affinity: binding affinity in nM and log-scaled (molar)
- activity: functional activity (normalized, 0–1, scaled by 5)
- toxicity: toxicity score (normalized, 0–1, penalized by -2)
- expression: expression level (normalized, 0–1, penalized by 2)
New Reward Function
- a tunable constant to center the affinity reward
- tunable weights for each objective
Multi-objective reward function improvement strategies (a combined sketch follows this list):
- Normalization: all variables should be normalized to comparable scales (0–1)
- Pareto weighting: allow for tunable weights for each objective
- Nonlinear transforms: use sigmoid/softplus for smoother gradients
- Penalize missing/invalid data: avoid bias from missing values
- Domain knowledge: use biologically meaningful thresholds
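Here is a minimal sketch of a reward function that applies these strategies: normalized inputs, tunable Pareto weights, a softplus-smoothed affinity term, and an explicit penalty for missing data. The function name, weights, and thresholds are placeholders, not this project’s tuned values.

```python
import numpy as np

def softplus(x: float) -> float:
    """Numerically stable softplus; smoother gradients than a hard threshold."""
    return float(np.log1p(np.exp(-abs(x))) + max(x, 0.0))

def protein_reward(
    affinity_nm: float | None,    # binding affinity in nM (lower is better)
    activity: float | None,       # functional activity, already normalized to 0-1
    novelty: float,               # 0-1, how understudied the protein is
    diversity: float,             # 0-1, distance from already-selected proteins
    cost: float,                  # 0-1, relative experimental cost
    w=(1.0, 1.0, 0.5, 0.5, 0.5),  # tunable Pareto weights (placeholders)
    missing_penalty: float = 0.5,
) -> float:
    # Penalize missing/invalid data instead of silently treating it as zero
    if affinity_nm is None or activity is None:
        return -missing_penalty

    # Normalize affinity to roughly 0-1 via pAffinity, centered on the
    # domain-knowledge threshold of 1 µM (pAffinity = 6)
    p_affinity = -np.log10(affinity_nm * 1e-9)        # e.g. 10 nM -> 8
    affinity_score = softplus(p_affinity - 6.0) / 6.0  # smooth, squashed to ~0-1

    w_aff, w_act, w_nov, w_div, w_cost = w
    return float(
        w_aff * affinity_score
        + w_act * activity
        + w_nov * novelty
        + w_div * diversity
        - w_cost * cost
    )
```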
Training
We trained a bunch of runs; the ppo_2M_best_ model turned out to be the best.
| Timesteps | Final Eval Mean Reward |
| --- | --- |
| 500k | 43.3 |
| 1M | 44.3 |
| 2M | 44.9 |
Training on 4070 Ti with JAX-based SBX library.
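For reference, a sketch of what that training call can look like with SBX, reusing the toy environment sketched earlier; SBX mirrors the Stable-Baselines3 API, and the hyperparameters shown are generic defaults rather than the exact ones behind ppo_2M_best_.

```python
import numpy as np
from sbx import PPO  # SBX: JAX-based implementations with the SB3 API

# Dummy data stands in for the real ESM-2 embeddings and assay-derived scores
embeddings = np.random.randn(100, 1280).astype(np.float32)
scores = np.random.rand(100).astype(np.float32)
env = ProteinTriageEnv(embeddings, scores)  # env from the earlier sketch

model = PPO("MlpPolicy", env, learning_rate=3e-4, gamma=0.99, verbose=1)
model.learn(total_timesteps=2_000_000)      # the 2M-timestep run from the table above
```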
| Metric | Result |
| --- | --- |
| Mean Reward | 44.92 (vs. ~40.9 for random) |
| Performance Gain | +9.9% over random policy |
| Reward Std Dev | 0.00 (perfectly stable policy) |
| Convergence Time | ~1 hour (on GPU) |
| Episode Length | 50 proteins (fixed; reached every time) |
Model Quality Summary
- High-Reward Policy: Agent consistently selected high-value proteins
- Fast & Stable Training: Learning curve rose quickly (~200k steps) and converged cleanly
- Strong Generalization (within set): No catastrophic forgetting or instability observed
Shortcomings + Improvements
| Issue | Details |
| --- | --- |
| Low Diversity | Only 3 unique proteins selected repeatedly (diversity score: 0.006) |
| Overfitting Risk | Possible reliance on a narrow strategy; may not generalize |
| Reward Function Limit | Lacked explicit diversity/novelty terms |
| Validation Gap | Model was only tested on the training set |
Testing Different RL Architectures (CPU)
We stuck with the PPO agent for the main model, but I was curious to see how other RL architectures would perform.
First, I created a single abstraction layer for each RL agent architecture (PPO, DQN, Actor-Critic). Then I realized we could simply import SB3 and use its classes directly.
So I did just that and created a test file (test_sb3_algorithms.py) that also outputs some useful data. I also updated the metrics to be more complete. These metrics measure how well the agent selects proteins; a sketch of how they can be computed follows the list below.
Hit Rate: Number of hits found / total proteins evaluated
Average Reward: Mean reward per episode
Cumulative Reward: Total reward across all episodes
Precision@k: Fraction of top-k selected proteins that are actual hits
Recall@k: Fraction of all hits found in top-k selections
F1-Score: Harmonic mean of precision and recall
AUC-ROC: Area under the receiver operating characteristic curve
Normalized Discounted Cumulative Gain (NDCG): Measures how well the predicted ranking of proteins matches their actual quality
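The ranking metrics above map fairly directly onto scikit-learn; here is a sketch of how precision@k, recall@k, F1@k, AUC-ROC, and NDCG can be computed for one batch of ranked proteins. The helper name and toy data are illustrative, not the contents of test_sb3_algorithms.py.

```python
import numpy as np
from sklearn.metrics import ndcg_score, roc_auc_score

def ranking_metrics(scores: np.ndarray, hits: np.ndarray, k: int = 5) -> dict:
    """Ranking metrics for a batch of proteins.
    scores: agent's predicted priority per protein (higher = better)
    hits:   binary ground-truth hit labels from the assay data
    """
    order = np.argsort(scores)[::-1]  # best-first ranking
    topk = hits[order[:k]]

    precision_k = topk.mean()                   # hits among the top-k picks
    recall_k = topk.sum() / max(hits.sum(), 1)  # share of all hits recovered in top-k
    f1_k = (2 * precision_k * recall_k / (precision_k + recall_k)
            if (precision_k + recall_k) > 0 else 0.0)

    # AUC-ROC is undefined when only one class is present (a likely cause of the nan below)
    auc = roc_auc_score(hits, scores) if len(np.unique(hits)) > 1 else float("nan")
    ndcg = ndcg_score(hits[None, :], scores[None, :])  # predicted vs. actual ordering

    return {"precision@k": float(precision_k), "recall@k": float(recall_k),
            "f1@k": float(f1_k), "auc_roc": float(auc), "ndcg": float(ndcg)}

# Example with toy scores and labels
rng = np.random.default_rng(0)
hits = rng.integers(0, 2, size=50)
scores = hits * 0.5 + rng.random(50)  # noisy scores that loosely track the labels
print(ranking_metrics(scores, hits))
```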
Output after running test_sb3_algorithms.py:

```
Summary:
PPO: Mean Reward = 3.18, Rewards = [3.1789813, 3.1789813, 3.1789813, 3.1789813, 3.1789813]
     Metrics: {'hit_rate': 1.0, 'avg_reward': 3.1789813, 'cum_reward': 15.894907, 'precision@k': 1.0, 'recall@k': 0.1, 'f1@k': 0.1818, 'auc_roc': nan, 'ndcg': 1.0}
DQN: Mean Reward = 10.89, Rewards = [10.893245, 10.893245, 10.893245, 10.893245, 10.893245]
     Metrics: {'hit_rate': 0.7, 'avg_reward': 10.893245, 'cum_reward': 54.466225, 'precision@k': 0.6, 'recall@k': 0.0857, 'f1@k': 0.15, 'auc_roc': 0.5, 'ndcg': 1.0}
A2C: Mean Reward = 12.04, Rewards = [12.038507, 12.038507, 12.038507, 12.038507, 12.038507]
     Metrics: {'hit_rate': 1.0, 'avg_reward': 12.038507, 'cum_reward': 60.192535, 'precision@k': 1.0, 'recall@k': 0.1, 'f1@k': 0.1818, 'auc_roc': nan, 'ndcg': 1.0}
```
Conclusion
In benchmarking PPO, DQN, and A2C agents on the protein triage task, DQN and A2C consistently achieved higher mean and cumulative rewards than PPO, with A2C slightly outperforming DQN. In the table below, higher values are better.
| Agent | Mean Reward | Hit Rate | Precision@5 | Recall@5 | F1@5 | NDCG |
| --- | --- | --- | --- | --- | --- | --- |
| PPO | 3.18 | 1.00 | 1.00 | 0.10 | 0.18 | 1.00 |
| DQN | 10.89 | 0.70 | 0.60 | 0.09 | 0.15 | 1.00 |
| A2C | 12.04 | 1.00 | 1.00 | 0.10 | 0.18 | 1.00 |
- A2C: achieved the highest mean reward and a perfect hit rate, meaning it consistently selects true hits while also optimizing for diversity, novelty, and cost (via our custom reward function)
- DQN: slightly lower hit rate but a high mean reward, meaning it balances finding hits with the other objectives in the reward function
- PPO: perfect hit rate but a much lower mean reward, meaning it focuses only on finding hits and neglects diversity, novelty, and cost
Appendix
Definitions
PPO (Proximal Policy Optimization): conceptually like gradient descent, but for RL instead of DL. In DL, gradients come from a known loss (MSE, cross-entropy). In RL, the “loss” is noisy and indirect: it comes from estimates of expected future reward.
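For reference, this is the standard clipped surrogate objective that PPO maximizes (the formulation from the original PPO paper, not code specific to this project); Âₜ is the advantage estimate and ε the clip range, matching the Update Rule and Advantage rows in the table below.

$$
L^{\mathrm{CLIP}}(\theta) = \hat{\mathbb{E}}_t\Big[\min\big(r_t(\theta)\,\hat{A}_t,\ \mathrm{clip}(r_t(\theta),\,1-\epsilon,\,1+\epsilon)\,\hat{A}_t\big)\Big],
\qquad
r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\mathrm{old}}}(a_t \mid s_t)}
$$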
Learned representations: the “eyes” of our RL agent; they determine what information is available for decision-making.
| Component | Definition | In This Project | Math Variables |
| --- | --- | --- | --- |
| Policy (π) | Chooses action given a state | Maps protein features to triage rank | π(sₜ) → aₜ |
| State (s) | Environment snapshot | UniProt/AlphaFold/KG feature vector | sₜ |
| Action (a) | Decision made by policy | Select triage rank or protein index | aₜ |
| Reward (r) | Feedback for action quality | Composite: hit + GO + diversity − latency | rₜ = r(sₜ, aₜ) |
| Transition (T) | Environment’s next-state rule | Step through batch of proteins | sₜ₊₁ = T(sₜ, aₜ) |
| Value Function | Future return from a state | Guides learning, estimates total triage gain | V(s) = E[Σₖ γᵏ rₜ₊ₖ] |
| Update Rule | Policy improvement mechanism | PPO with clipping and advantage estimate | θ ← θ + α∇θL(θ) |
| Advantage (Âₜ) | Relative value of an action | Stabilizes PPO policy updates | Âₜ = Q(sₜ, aₜ) − V(sₜ) |
| Exploration | Encourages diverse behavior | Entropy bonus to prevent premature convergence | +β H(π) |
| Discount (γ) | Weight on future rewards | Tune between 0.9–0.99 for triage horizon | γ ∈ [0, 1) |
| Q-Function | Value from state-action pair | Optional: Q-learning variant | Q(s, a) |
| Replay Buffer | Memory of transitions | Optional: for offline RL or TD-error prioritization | D = {(sₜ, aₜ, rₜ, sₜ₊₁)} |