Ranking Proteins with Reinforcement Learning

This was a great project, but it took way longer than I anticipated: four days of execution versus a one-day estimate. I read part of Sutton’s book last year and thought the concept of computers learning autonomously was pretty powerful, since that is exactly how lifeforms do it (albeit a simplistic model of it). If you could have another model create the variables for the reward function, so that the reward description keeps getting richer over time, that would make a very, very scalable business. I tried that with the reward shaper in this project.

I tried a lot of stuff in this project, probably too much. I created scripts to pull data from UniProt and ChEMBL, pushed those amino acid sequences through ESM to get embeddings, then combined the two into a single data structure. Then I tuned the environment and the reward function, ultimately deciding to test different RL architectures. I tested PPO, DQN, and A2C via the Stable-Baselines3 package and expanded the metrics to be far more complete. I then turned it into an API that can be used by startups like ProFound Therapeutics.

The API ingests real experimental decisions (like which proteins ProFound triaged, tested, or validated), learns a policy that mimics or improves on those decisions, then ranks unseen proteins to recommend next-best targets. The model is validated against proteins that we currently have working knowledge of.
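To make the API shape concrete, here is a minimal sketch of what a ranking endpoint could look like, assuming a FastAPI service; the route, request fields, and the rank_proteins helper are hypothetical, not the project's actual interface.

```python
# Hypothetical sketch of the ranking endpoint (FastAPI assumed; names illustrative).
from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI()

class RankRequest(BaseModel):
    accessions: list[str]   # UniProt accessions to score
    top_k: int = 10         # how many recommendations to return

def rank_proteins(accessions: list[str]) -> list[tuple[str, float]]:
    # Placeholder: the real service would embed each sequence with ESM-2
    # and score it with the trained RL policy.
    return sorted(((acc, 0.0) for acc in accessions), key=lambda pair: -pair[1])

@app.post("/rank")
def rank(req: RankRequest):
    ranked = rank_proteins(req.accessions)
    return {"recommendations": ranked[: req.top_k]}
```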

Intro

I drive a Tesla with Autopilot. It’s incredibly useful and getting better, though it struggles in complex situations, like being cut off, getting boxed in while needing to exit in 0.1 miles, or parking in tight lots. Still, it’s >99% flawless.

But one thing consistently irritates me: it favors short-term rewards over long-term ones.

Every day, I leave the gym, merge onto a five-lane freeway, then exit in under a mile. When I drive, I just stay in the exit lane. Autopilot merges into that first lane, cuts across four lanes to the far left, then immediately works its way back to the first lane, since the exit is less than a mile away. It does this every time, even in heavy traffic. It’s dumb.

Why? I think its neural net is myopic. The reward function seems tuned for short-term exploitation, not long-term planning. This may be for safety, but it is more likely an incomplete RL design. The car behaves as if it’s optimizing over horizons of less than a mile.

Fixes could include increasing gamma (a longer planning horizon), using a more global reward function with context (next exits, traffic, comfort), and penalizing unnecessary lane changes.

So what does this have to do with protein folding RL agents?

Project

Alright, so I didn’t realize I was going to be building a model when I started this. It turns out RL is not just applying self-directed fine-tuning to a big model (like a foundation model). You are actually creating a whole new abstraction layer on top of the foundation model, using that model’s output embeddings as the input to your RL model.

When AlphaFold, Boltz2, ESM-2, GO, etc. consume an input amino acid sequence, they run a series of transformations on that sequence, then output vectors/matrices of float16s/float32s. These become the input embeddings for our model.

| Source/model | Output for ML/RL agent | Example Models |
| --- | --- | --- |
| Embedding model | 1D vector of floats | UniRep, ESM, ProtBERT, SeqVec, TAPE |
| Structure model | 2D matrix (coords) → 1D/2D | AlphaFold (3D coords), Rosetta |
| Feature vector | 1D vector of floats | Hand-crafted features (AA composition, pI, hydrophobicity), GO term embeddings, Pfam/InterPro domain vectors, EC class embeddings |
| Contact map | 2D matrix (can be flattened) | AlphaFold (contact map) |
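To make that abstraction concrete, here is a minimal sketch of a Gymnasium environment whose observations are precomputed embedding vectors; the class name, shapes, and reward lookup are illustrative, not the project's actual environment.

```python
import numpy as np
import gymnasium as gym
from gymnasium import spaces

class ProteinRankEnv(gym.Env):
    """Toy env: each step shows one protein embedding; action 1 = select, 0 = skip."""

    def __init__(self, embeddings: np.ndarray, rewards: np.ndarray):
        self.embeddings = embeddings.astype(np.float32)  # (N, D) embedding vectors
        self.rewards = rewards.astype(np.float32)        # (N,) precomputed per-protein rewards
        self.observation_space = spaces.Box(-np.inf, np.inf, shape=(embeddings.shape[1],), dtype=np.float32)
        self.action_space = spaces.Discrete(2)
        self.i = 0

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        self.i = 0
        return self.embeddings[self.i], {}

    def step(self, action):
        reward = float(self.rewards[self.i]) if action == 1 else 0.0
        self.i += 1
        terminated = self.i >= len(self.embeddings)
        # On termination the returned observation is ignored by SB3/SBX.
        obs = self.embeddings[min(self.i, len(self.embeddings) - 1)]
        return obs, reward, terminated, False, {}
```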

I trained a reinforcement learning agent to rank proteins using ESM-2 embeddings and experimental data. Training ran for 2 million timesteps on an RTX 4070 Ti using the SBX library (JAX-based).

ESM Embeddings

I added embeddings from ESM, since it seemed like the easiest option to run on a Mac; AlphaFold was way too heavy. Each embedding is a high-dimensional, biologically meaningful representation of a protein sequence, and it serves as the input to the RL model. This means we don’t need structural data, just the sequence embeddings. By decoupling feature extraction from training, we avoid the computational cost of fine-tuning the foundation model while the agent learns. The agent learns to prioritize protein targets based on these embeddings.
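For illustration, here is roughly how per-protein embeddings can be pulled with the fair-esm package; I'm assuming the small esm2_t6_8M_UR50D checkpoint here, and the project may well use a larger ESM-2 model.

```python
import torch
import esm  # the fair-esm package

# Small ESM-2 checkpoint for illustration (320-dim embeddings, 6 layers).
model, alphabet = esm.pretrained.esm2_t6_8M_UR50D()
batch_converter = alphabet.get_batch_converter()
model.eval()

sequences = [("P69905", "MVLSPADKTNVKAAWGKVGAHAGEYGAEALERMFLSFPTTKTYFPHF")]
_, _, tokens = batch_converter(sequences)

with torch.no_grad():
    out = model(tokens, repr_layers=[6])   # layer 6 = final layer of this checkpoint
reps = out["representations"][6]           # (batch, seq_len, 320)

# Mean-pool over residues (dropping BOS/EOS tokens) to get one vector per protein.
embedding = reps[0, 1:-1].mean(dim=0)      # shape: (320,)
```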

ChEMBL Data

Our goal here was to replace random “hits” with real labels from experimental data (binding assays, functional screens). The ChEMBL download process took too long for this project’s timeline (it was supposed to be done in 24 hours; I ended up working on this for four days). There is definitely an API opportunity for one of the big experimental biology companies (maybe even a CRO like Ginkgo) to build a validation dataset like this for assay data. That is going to be a very important pillar for AI applied to molecular biology.

Adding UniProt + BindingDB to replace ChEMBL

| Endpoint | Purpose |
| --- | --- |
| /target | Map UniProt → ChEMBL target, get metadata |
| /activity | Get binding/functional data for target |
| /molecule | Get ligand/compound info |
| /assay | Get experiment/assay context |
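For reference, target and activity data can be pulled with the chembl_webresource_client package along these lines; the UniProt accession and filters below are just an example, not the project's actual query.

```python
from itertools import islice
from chembl_webresource_client.new_client import new_client

target = new_client.target
activity = new_client.activity

# Map a UniProt accession (EGFR here, as an example) to a ChEMBL target ID.
hits = target.filter(target_components__accession="P00533")
target_id = hits[0]["target_chembl_id"]

# Pull binding-style measurements for that target.
records = activity.filter(
    target_chembl_id=target_id,
    standard_type__in=["Kd", "Ki", "IC50"],
).only(["molecule_chembl_id", "standard_type", "standard_value", "standard_units"])

for rec in islice(records, 5):
    print(rec["molecule_chembl_id"], rec["standard_type"], rec["standard_value"], rec["standard_units"])
```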

Unifying ChEMBL data + ESM embeddings

Before: the RL agent learned from random noise (fake features + fake hits) and couldn’t find real patterns.

After: the RL agent learns from real biological relationships (protein structure → experimental activity) and can discover meaningful protein prioritization strategies.
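Conceptually, the unification is a join keyed on UniProt accession: embeddings on one side, assay-derived labels on the other. A minimal sketch with made-up column names:

```python
import numpy as np
import pandas as pd

# Hypothetical inputs: per-protein ESM embeddings and per-protein assay summaries.
emb = pd.DataFrame({
    "accession": ["P00533", "P69905"],
    "embedding": [np.random.rand(320).astype(np.float32) for _ in range(2)],
})
labels = pd.DataFrame({
    "accession": ["P00533", "P69905"],
    "best_kd_nM": [12.0, 850.0],  # strongest measured binding affinity
    "is_hit": [1, 0],             # binary hit label derived from assay thresholds
})

# One row per protein: embedding (the observation) + experimental labels (the reward signal).
dataset = emb.merge(labels, on="accession", how="inner")
observations = np.stack(dataset["embedding"].tolist())  # (N, 320) matrix for the env
```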

Optimizing the Reward Function

The system balances multiple objectives to calculate the reward:

  • Activity strength: how well the protein binds to targets
  • Diversity: how different the selected proteins are from each other
  • Novelty: how understudied the protein is
  • Cost: how expensive the protein would be to study experimentally

Old Reward function

$reward = y + \max(0,\ 5 - \log_{10}(K_d \times 10^9)) + 5A - 2T + 2E$

  • $y$: binary hit label (1 if hit, 0 if not)
  • $K_d$: binding affinity in molar (the $\times 10^9$ converts it to nM before log-scaling)
  • $A$: functional activity (normalized 0–1, rewarded with weight 5)
  • $T$: toxicity score (normalized 0–1, penalized with weight 2)
  • $E$: expression level (normalized 0–1, rewarded with weight 2)

New Reward Function

$R_{improved} = w_1 \cdot y + w_2 \cdot \mathrm{softplus}(c_1 - \log_{10}(K_d \times 10^9)) + \ldots$

  • $\mathrm{softplus}(x) = \ln(1 + e^x)$
  • $c_1 = 7$ (tunable constant to center the affinity reward)
  • $w_1 = 1,\ w_2 = 2,\ w_3 = 2,\ w_4 = 2,\ w_5 = 1$ (tunable weights; $w_3$–$w_5$ weight the remaining objective terms)
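A minimal sketch of the hit and affinity terms of the improved reward in code; the softplus and constants mirror the definitions above, while the handling of the remaining weighted terms is an assumption, not the project's exact implementation.

```python
import numpy as np

def softplus(x: float) -> float:
    # Numerically stable softplus: ln(1 + e^x).
    return float(np.logaddexp(0.0, x))

def improved_reward(y: int, kd_molar: float,
                    w1: float = 1.0, w2: float = 2.0, c1: float = 7.0) -> float:
    """Hit term plus smooth affinity term.

    y        : binary hit label (1 if hit, 0 otherwise)
    kd_molar : K_d in molar; the 1e9 factor converts to nM before log-scaling
    """
    affinity = softplus(c1 - np.log10(kd_molar * 1e9))
    # Assumption: the remaining weighted terms (w3-w5) would add normalized
    # activity, toxicity, and expression contributions in the same pattern.
    return float(w1 * y + w2 * affinity)

print(improved_reward(y=1, kd_molar=5e-9))  # e.g. a confirmed hit with a 5 nM K_d
```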

Multi-objective reward function improvement strategies:

  • Normalization: all variables should be normalized to comparable scales (0–1)
  • Pareto weighting: allow for tunable weights for each objective
  • Nonlinear transforms: use sigmoid/softplus for smoother gradients
  • Penalize missing/invalid data: avoid bias from missing values
  • Domain knowledge: use biologically meaningful thresholds

Training

We trained several runs; the ppo_2M_best_ checkpoint turned out to be the best.

| Timesteps | Final Eval Mean Reward |
| --- | --- |
| 500k | 43.3 |
| 1M | 44.3 |
| 2M | 44.9 |
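For reference, the training runs were roughly this shape: a sketch using SBX's SB3-style API, reusing the toy ProteinRankEnv sketched earlier, with stand-in data and illustrative hyperparameters rather than the project's actual setup.

```python
import numpy as np
from sbx import PPO  # JAX-based Stable-Baselines implementation

# Stand-in data: N proteins with D-dimensional embeddings and precomputed rewards.
N, D = 1000, 320
embeddings = np.random.rand(N, D).astype(np.float32)
rewards = np.random.rand(N).astype(np.float32)

env = ProteinRankEnv(embeddings, rewards)  # toy env sketched earlier in this post

model = PPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=2_000_000)     # roughly an hour on an RTX 4070 Ti here
model.save("ppo_2M_best_")
```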

Training ran on an RTX 4070 Ti with the JAX-based SBX library.

| Metric | Result |
| --- | --- |
| Mean Reward | 44.92 (vs. ~40.9 for random) |
| Performance Gain | +9.9% over random policy |
| Reward Std Dev | 0.00 (perfectly stable policy) |
| Convergence Time | ~1 hour (on GPU) |
| Episode Length | 50 proteins (fixed; reached every time) |

Model Quality Summary

  • High-Reward Policy: Agent consistently selected high-value proteins
  • Fast & Stable Training: Learning curve rose quickly (~200k steps) and converged cleanly
  • Strong Generalization (within set): No catastrophic forgetting or instability observed

Shortcomings + Improvements

| Issue | Details |
| --- | --- |
| Low Diversity | Only 3 unique proteins selected repeatedly (diversity score: 0.006) |
| Overfitting Risk | Possible reliance on narrow strategy; may not generalize |
| Reward Function Limit | Lacked explicit diversity/novelty terms |
| Validation Gap | Model was only tested on the training set |
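One way to close the diversity gap would be an explicit diversity term, e.g. penalizing cosine similarity between a newly selected protein's embedding and those already selected. A sketch of such a term (not something the trained reward actually included):

```python
import numpy as np

def diversity_bonus(candidate: np.ndarray, selected: list[np.ndarray]) -> float:
    """Reward embeddings that are dissimilar to everything already selected."""
    if not selected:
        return 1.0
    sims = [
        float(np.dot(candidate, s) / (np.linalg.norm(candidate) * np.linalg.norm(s) + 1e-8))
        for s in selected
    ]
    # 1 - max cosine similarity: near 1 for novel picks, near 0 for repeats.
    return 1.0 - max(sims)
```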

Testing Different RL Architectures (CPU)

We stuck with the PPO agent for the main model, but I was curious to see how other RL architectures would perform.

First, I created a single abstraction layer for each RL agent architecture (PPO, DQN, actor-critic). Then I realized we could just import SB3 and use its classes in our program.

So I did just that and created a test file (test_sb3_algorithms.py) that also outputs some useful data. I also updated the metrics to be more complete. These metrics measure how well the agent selects proteins.

Hit Rate: Number of hits found / total proteins evaluated

Average Reward: Mean reward per episode

Cumulative Reward: Total reward across all episodes

Precision@k: Fraction of top-k selected proteins that are actual hits

Recall@k: Fraction of all hits found in top-k selections

F1-Score: Harmonic mean of precision and recall

AUC-ROC: Area under the receiver operating characteristic curve

Normalized Discounted Cumulative Gain (NDCG): Measures how well the predicted ranking matches the ideal ranking by actual protein quality
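The comparison itself is a straightforward loop over SB3 classes plus these ranking metrics. Below is a condensed sketch of the kind of thing test_sb3_algorithms.py does, reusing the toy ProteinRankEnv from earlier with stand-in data; the summary printed after this comes from the project's actual run, not from this sketch.

```python
import numpy as np
from stable_baselines3 import PPO, DQN, A2C
from stable_baselines3.common.evaluation import evaluate_policy
from sklearn.metrics import roc_auc_score, ndcg_score

embeddings = np.random.rand(200, 320).astype(np.float32)  # stand-in features
labels = (np.random.rand(200) > 0.8).astype(np.float32)   # stand-in binary hit labels
env = ProteinRankEnv(embeddings, labels)                   # toy env sketched earlier

results = {}
for name, algo in [("PPO", PPO), ("DQN", DQN), ("A2C", A2C)]:
    model = algo("MlpPolicy", env, verbose=0)
    model.learn(total_timesteps=10_000)                    # short run, for illustration only
    results[name] = evaluate_policy(model, env, n_eval_episodes=5)  # (mean, std) reward

# Ranking metrics against the true labels (scores here are stand-ins for model outputs).
scores, k = np.random.rand(200), 5
top_k = np.argsort(scores)[::-1][:k]
precision_at_k = labels[top_k].mean()
recall_at_k = labels[top_k].sum() / max(labels.sum(), 1)
auc = roc_auc_score(labels, scores)
ndcg = ndcg_score(labels.reshape(1, -1), scores.reshape(1, -1), k=k)
```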

# after execution
Summary:
PPO: Mean Reward = 3.18, Rewards = [np.float32(3.1789813), np.float32(3.1789813), np.float32(3.1789813), np.float32(3.1789813), np.float32(3.1789813)]
  Metrics: {'hit_rate': 1.0, 'avg_reward': np.float32(3.1789813), 'cum_reward': np.float32(15.894907), 'precision@k': 1.0, 'recall@k': 0.1, 'f1@k': 0.18181818181818182, 'auc_roc': nan, 'ndcg': 1.0}
DQN: Mean Reward = 10.89, Rewards = [np.float32(10.893245), np.float32(10.893245), np.float32(10.893245), np.float32(10.893245), np.float32(10.893245)]
  Metrics: {'hit_rate': 0.7, 'avg_reward': np.float32(10.893245), 'cum_reward': np.float32(54.466225), 'precision@k': 0.6, 'recall@k': 0.08571428571428572, 'f1@k': 0.15, 'auc_roc': 0.5, 'ndcg': 1.0}
A2C: Mean Reward = 12.04, Rewards = [np.float32(12.038507), np.float32(12.038507), np.float32(12.038507), np.float32(12.038507), np.float32(12.038507)]
  Metrics: {'hit_rate': 1.0, 'avg_reward': np.float32(12.038507), 'cum_reward': np.float32(60.192535), 'precision@k': 1.0, 'recall@k': 0.1, 'f1@k': 0.18181818181818182, 'auc_roc': nan, 'ndcg': 1.0}

Conclusion

In benchmarking PPO, DQN, and A2C agents on the protein triage task, DQN and A2C consistently achieved higher mean and cumulative rewards than PPO, with A2C slightly outperforming DQN. In the table below, higher values are better.

| Agent | Mean Reward | Hit Rate | Precision@5 | Recall@5 | F1@5 | NDCG |
| --- | --- | --- | --- | --- | --- | --- |
| PPO | 3.18 | 1.00 | 1.00 | 0.10 | 0.18 | 1.00 |
| DQN | 10.89 | 0.70 | 0.60 | 0.09 | 0.15 | 1.00 |
| A2C | 12.04 | 1.00 | 1.00 | 0.10 | 0.18 | 1.00 |

A2C: achieved the highest mean reward and perfect hit rate, meaning it consistently selects true hits while optimizing for diversity, novelty, and cost (via our custom reward function)

DQN: slightly lower hit rate but a high mean reward, meaning it balances finding hits with other objectives in the reward function

PPO: perfect hit rate but much lower mean reward, meaning it focuses only on finding hits and neglects diversity, novelty, and cost

Appendix

Definitions

PPO (Proximal Policy Optimization): essentially gradient descent, but for RL instead of DL. In DL, gradients come from a known loss (MSE, cross-entropy). In RL, the “loss” is noisy and indirect; it’s the expected future reward.

Learned representations: the “eyes” of our RL agent; they determine what information is available for decision-making.

| Component | Definition | In This Project | Math Variables |
| --- | --- | --- | --- |
| Policy (π) | Chooses action given a state | Maps protein features to triage rank | $\pi(a \mid s)$ |
| State (s) | Environment snapshot | UniProt/AlphaFold/KG feature vector | $s \in \mathbb{R}^D$ |
| Action (a) | Decision made by policy | Select triage rank or protein index | $a \in \mathcal{A}$ |
| Reward (r) | Feedback for action quality | Composite: hit + GO + diversity − latency | $r_t = r(s_t, a_t)$ |
| Transition (T) | Environment’s next state rule | Step through batch of proteins | $s_{t+1} \sim T(s_t, a_t)$ |
| Value Function | Future return from a state | Guides learning, estimates total triage gain | $V(s) = \mathbb{E}\left[\sum_t \gamma^t r_t\right]$ |
| Update Rule | Policy improvement mechanism | PPO with clipping and advantage estimate | $L = \mathbb{E}\left[\min(r_t(\theta)\hat{A}_t,\ \text{clip}(r_t(\theta), \ldots))\right]$ |
| Advantage (Âₜ) | Relative value of an action | Stabilizes PPO policy updates | $\hat{A}_t = Q(s_t, a_t) - V(s_t)$ |
| Exploration | Encourages diverse behavior | Entropy bonus to prevent premature convergence | $H\left(\pi(\cdot \mid s)\right)$ |
| Discount (γ) | Weight on future rewards | Tune between 0.9–0.99 for triage horizon | $\gamma \in (0, 1)$ |
| Q-Function | Value from state-action pair | Optional: Q-learning variant | $Q(s,a) = \mathbb{E}\left[\sum_t \gamma^t r_t \mid s_0 = s, a_0 = a\right]$ |
| Replay Buffer | Memory of transitions | Optional: for offline RL or TD-error prioritization | $\mathcal{D} = \{(s_t, a_t, r_t, s_{t+1})\}$ |
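To connect the update rule back to code, here is a minimal PyTorch computation of the clipped surrogate loss from the table above; the log-probabilities, advantages, and clip range are made-up illustrative values, not anything from the trained model.

```python
import torch

def ppo_clip_loss(log_probs_new: torch.Tensor,
                  log_probs_old: torch.Tensor,
                  advantages: torch.Tensor,
                  clip_eps: float = 0.2) -> torch.Tensor:
    """Clipped surrogate: L = -E[min(r_t * A_t, clip(r_t, 1-eps, 1+eps) * A_t)]."""
    ratio = torch.exp(log_probs_new - log_probs_old)               # r_t(theta)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    return -torch.min(unclipped, clipped).mean()                   # negate: we minimize

# Tiny example with made-up numbers.
loss = ppo_clip_loss(
    log_probs_new=torch.tensor([-0.9, -1.1]),
    log_probs_old=torch.tensor([-1.0, -1.0]),
    advantages=torch.tensor([0.5, -0.2]),
)
print(loss.item())
```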