GitHub repo:
Goal: predict a protein's function, as Gene Ontology (GO) term labels, from its amino acid sequence
Pipeline: any protein sequence → ESM-2 tokenizer → frozen ESM-2 encoder → masked mean pooling → [classifier model] → classification profile of the protein as GO annotations
Datasets: CAFA3, Gene Ontology (GO); pretrained model: ESM2
Hardware: NVIDIA GeForce RTX 4070 Ti GPU (12 GB VRAM), Intel Core i7-13700KF CPU (24 threads)
Operating system: Ubuntu 24.04
General outline
- prepare the data
  - load protein sequences from CAFA3 (CAFA3 provides labeled examples: the sequence + experimentally verified GO functions)
  - load GO term annotations
  - create synthetic data for improved training
  - split data into train/validation/test sets
- extract ESM2 features
  - load pretrained ESM2 into RAM
  - tokenize protein sequences from CAFA3 using the ESM2 tokenizer
  - extract mean embeddings from ESM2's last hidden layer
  - store embeddings as features for each protein
- create classifier model architecture
  - MLP on top of ESM2 embeddings
  - input: 640-dimensional ESM2-150M embeddings
  - hidden layers: 512 → 256; output layer: num_targets (303 GO terms)
  - output: one sigmoid probability per GO term, thresholded into binary predictions (multi-label classification)
- train
  - loss function: sigmoid binary cross-entropy
  - optimizer: Adam with learning rate 0.001
  - training loop: 2,000 steps with an evaluation every 100 steps
  - batch size: 32
- run evaluations
  - metrics: accuracy, precision, AUPRC, AUROC
  - per-function evaluation: calculate metrics for each GO term
  - final test: evaluate on the held-out test set
- iterate, if needed
- extrapolate to new datasets
  - TBD, but arguably the most important part: it shows whether the model is actually valuable
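The per-function evaluation step can be sketched in plain Python. This is a minimal sketch under stated assumptions, not the project's evaluation code. `auroc`, `average_precision`, and `per_term_metrics` are hypothetical helpers. AUROC is computed as the probability that a random positive outranks a random negative, with ties counting half. AUPRC is estimated with average precision, the estimator behind scikit-learn's `average_precision_score`, except that here ties are broken by sort order. The toy labels and scores are made up.

```python
def auroc(labels, scores):
    """Probability that a random positive is scored above a random negative (ties count 0.5)."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        return float("nan")  # undefined when a GO term has only one class in this split
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def average_precision(labels, scores):
    """AUPRC estimate: precision at the rank of each positive, averaged over positives."""
    ranked = sorted(zip(scores, labels), key=lambda pair: pair[0], reverse=True)
    n_pos = sum(labels)
    if n_pos == 0:
        return float("nan")
    hits, total = 0, 0.0
    for rank, (_, y) in enumerate(ranked, start=1):
        if y == 1:
            hits += 1
            total += hits / rank  # precision among the top `rank` predictions
    return total / n_pos


def per_term_metrics(y_true, y_prob, term_ids):
    """y_true and y_prob are lists of rows (proteins) by columns (GO terms)."""
    results = {}
    for j, term in enumerate(term_ids):
        col_true = [row[j] for row in y_true]
        col_prob = [row[j] for row in y_prob]
        results[term] = (average_precision(col_true, col_prob), auroc(col_true, col_prob))
    return results


# toy example: 4 proteins, 2 GO terms (hypothetical labels and scores)
y_true = [[1, 0], [0, 1], [1, 0], [0, 0]]
y_prob = [[0.9, 0.2], [0.3, 0.6], [0.7, 0.1], [0.2, 0.7]]
metrics = per_term_metrics(y_true, y_prob, ["GO:0003824", "GO:0008150"])
for term, (ap, roc) in metrics.items():
    print(term, round(ap, 3), round(roc, 3))
# GO:0003824 1.0 1.0
# GO:0008150 0.5 0.667
```

A term with no positives (or no negatives) in a split returns NaN. Exclude those terms before averaging the per-term scores into a single AUPRC or AUROC.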
Data preprocessing
We start with the protein sequence dataset (in this case, CAFA3), then pass those inputs to the ESM2 tokenizer.
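```python
import torch
from transformers import AutoTokenizer, EsmModel

# Assumption: the Hugging Face checkpoint for ESM2-150M (640-dimensional hidden states)
checkpoint = "facebook/esm2_t30_150M_UR50D"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = EsmModel.from_pretrained(checkpoint).eval()  # frozen: used only for feature extraction

# example protein sequences from a dataset annotated with GO terms (CAFA3)
cafa3_sequences = [
    "MKLLILTCLVAVALARPKHPIKHQGLPQEVLNENLLRFFVAPFPEVFGKEKVNEL",  # 55 amino acids
    "MKKLVLLSLVLAFSLLASQVAAPQNQAMDDTEADYQEMTGGKQTITVEELTTRK",  # 54 amino acids
]

# tokenize/vectorize the protein sequences: adds <cls> and <eos>, pads to 64 tokens
batch = tokenizer(cafa3_sequences, padding="max_length", max_length=64,
                  truncation=True, return_tensors="pt")

# batch["input_ids"]: [2, 64]; ESM-2 IDs: <cls>=0, <pad>=1, <eos>=2, L=4, K=15, M=20, ...
# [[0, 20, 15, 4, 4, 12, 4, 11, 23, 4, 7, ..., 2, 1, 1, 1, 1, 1, 1, 1],    # 57 real tokens + 7 <pad>
#  [0, 20, 15, 15, 4, 7, 4, 4, 8, 4, 7, ..., 2, 1, 1, 1, 1, 1, 1, 1, 1]]   # 56 real tokens + 8 <pad>
# batch["attention_mask"]: [2, 64]; 1 = real token, 0 = padding
# [[1, 1, 1, ..., 1, 0, 0, 0, 0, 0, 0, 0],
#  [1, 1, 1, ..., 1, 0, 0, 0, 0, 0, 0, 0, 0]]

with torch.no_grad():
    outputs = model(**batch)
hidden = outputs.last_hidden_state  # [2, 64, 640]
# shape: [batch_size=2, sequence_length=64, embedding_dim=640] -> each token gets a 640D embedding
```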
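To see where ESM2 token IDs and the attention mask come from, the sketch below reimplements the tokenizer's behavior for plain amino acid strings. It assumes the 33-token ESM-2 alphabet used by fair-esm and the Hugging Face `facebook/esm2_*` checkpoints: `<cls>`=0, `<pad>`=1, `<eos>`=2, `<unk>`=3, then the residue letters. It pads to a fixed length of 64. `encode` is a hypothetical helper for illustration, not a replacement for the real tokenizer.

```python
# ESM-2 vocabulary order (assumed from the published fair-esm alphabet):
# 4 special tokens, 25 residue letters plus "." and "-", then <null_1> and <mask>.
ESM2_VOCAB = ["<cls>", "<pad>", "<eos>", "<unk>",
              "L", "A", "G", "V", "S", "E", "R", "T", "I", "D", "P", "K", "Q",
              "N", "F", "Y", "M", "H", "W", "C", "X", "B", "U", "Z", "O", ".", "-",
              "<null_1>", "<mask>"]
TOKEN_TO_ID = {tok: i for i, tok in enumerate(ESM2_VOCAB)}


def encode(sequences, max_length=64):
    """Hypothetical helper: <cls> + residues + <eos>, truncated/padded to max_length."""
    input_ids, attention_mask = [], []
    for seq in sequences:
        ids = [TOKEN_TO_ID["<cls>"]]
        # reserve two positions for <cls> and <eos>; unknown letters map to <unk>
        ids += [TOKEN_TO_ID.get(aa, TOKEN_TO_ID["<unk>"]) for aa in seq[: max_length - 2]]
        ids.append(TOKEN_TO_ID["<eos>"])
        n_pad = max_length - len(ids)
        input_ids.append(ids + [TOKEN_TO_ID["<pad>"]] * n_pad)
        attention_mask.append([1] * len(ids) + [0] * n_pad)
    return input_ids, attention_mask


seqs = ["MKLLILTCLVAVALARPKHPIKHQGLPQEVLNENLLRFFVAPFPEVFGKEKVNEL",
        "MKKLVLLSLVLAFSLLASQVAAPQNQAMDDTEADYQEMTGGKQTITVEELTTRK"]
ids, mask = encode(seqs)
print(ids[0][:11])                 # [0, 20, 15, 4, 4, 12, 4, 11, 23, 4, 7]
print(sum(mask[0]), sum(mask[1]))  # 57 56
```

Padding uses ID 1, not 0. ID 0 is `<cls>`, so the attention mask, not the token IDs, is the reliable way to tell real tokens from padding.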
After running the tokenized sequences through ESM2, we have one 640-dimensional embedding per token. We then need to pool those token embeddings into a single sequence embedding per protein before passing it to our hand-built classifier.
```python
import torch

# Masked mean pooling (implemented in this project as a Triton kernel) is the bridge
# between individual token embeddings and a meaningful protein-level representation.

# hidden: ESM2 token embeddings, [batch=2, seq_len=64, embed_dim=640].
# Position 0 is <cls>, then one embedding per residue (M, K, ...), then <eos>, then padding.
# Random values stand in for the real ESM2 output so this block runs on its own.
# ESM2 does NOT output zeros at padding positions, which is why the mask is required.
hidden = torch.randn(2, 64, 640)

# mask: 1 = real token, 0 = padding
# (57 and 56 real tokens: <cls> + 55 or 54 residues + <eos>)
lengths = torch.tensor([57, 56])
mask = (torch.arange(64).unsqueeze(0) < lengths.unsqueeze(1)).long()  # [2, 64]


def masked_mean_pool(hidden, mask):
    """Plain PyTorch reference for the pooling the Triton kernel performs."""
    m = mask.unsqueeze(-1).to(hidden.dtype)  # [2, 64, 1]
    summed = (hidden * m).sum(dim=1)         # [2, 640]: sum over real tokens only
    counts = m.sum(dim=1).clamp(min=1.0)     # [2, 1]: number of real tokens per sequence
    return summed / counts


pooled = masked_mean_pool(hidden, mask)  # [2, 640]
# shape: [batch_size=2, embedding_dim=640] -> one mean embedding per protein sequence
# pooled[i] is the average of the 640D vectors of the real tokens of sequence i
```
Each protein sequence is now a concise, fixed-size embedding that our classifier can work with. We then feed these embeddings through the MLP classifier to predict each protein's functions.
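```python
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from flax.training import train_state


class MLP(nn.Module):
    """MLP head: 640 -> 512 -> 256 -> num_targets (ReLU activations assumed)."""
    num_targets: int = 303

    @nn.compact
    def __call__(self, x):
        x = nn.relu(nn.Dense(512)(x))
        x = nn.relu(nn.Dense(256)(x))
        return nn.Dense(self.num_targets)(x)  # raw logits, one per GO term


model = MLP()
params = model.init(jax.random.PRNGKey(0), jnp.zeros((1, 640)))["params"]
state = train_state.TrainState.create(apply_fn=model.apply, params=params,
                                      tx=optax.adam(1e-3))  # Adam, learning rate 0.001

# our sequence-level embeddings from the pooling step
embeddings = pooled.detach().cpu().numpy()  # [2, 640]

# generate raw logits by feeding the embeddings through our MLP classifier
logits = state.apply_fn({"params": state.params}, x=embeddings)  # [2, 303]
# shape: [batch_size=2, num_functions=303]

probabilities = jax.nn.sigmoid(logits)  # [2, 303]
# an independent probability between 0 and 1 for each GO term (multi-label, not softmax)

# example output (illustrative values; columns follow the label map)
# [[0.1, 0.9, 0.3, 0.7, 0.2, ...],  # protein 1: 10% for GO:0003824, 90% for GO:0008150, ...
#  [0.8, 0.2, 0.6, 0.4, 0.1, ...]]  # protein 2: 80% for GO:0003824, 20% for GO:0008150, ...
```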
Notes on data amount
number of CAFA3 protein sequences: ~1,800-2,000
- train split: ~1,440-1,600 proteins (80%)
- validation split: ~180-200 proteins (10%)
- test split: ~180-200 proteins (10%)
total number of GO terms: 303
What really matters is the information we can extract from the dataset, not its raw byte size.
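A reproducible split can be sketched with the standard library. This sketch rests on a few assumptions. `split_indices` is hypothetical, and the 80%/10% defaults are illustrative. The test set receives whatever remains, so the three splits always partition the dataset exactly. A plain random split is the simplest option for multi-label data, but it does not guarantee that rare GO terms appear in every split.

```python
import random


def split_indices(n_items, train_frac=0.8, val_frac=0.1, seed=0):
    """Shuffle indices once, then cut into train/val/test; test gets the remainder."""
    if train_frac + val_frac >= 1.0:
        raise ValueError("train_frac + val_frac must leave room for a test split")
    idx = list(range(n_items))
    random.Random(seed).shuffle(idx)  # seeded so the split is reproducible
    n_train = int(n_items * train_frac)
    n_val = int(n_items * val_frac)
    return idx[:n_train], idx[n_train:n_train + n_val], idx[n_train + n_val:]


train, val, test = split_indices(2000)
print(len(train), len(val), len(test))  # 1600 200 200
```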
Training steps
- prepare data
  - take CAFA3 protein sequences and map each to its GO function terms
  - convert the terms into multi-hot vectors
  - split data into train/validation/test
- tokenize
  - convert CAFA3 sequences into ESM2 token IDs (the tokenizer handles padding and truncation; the Triton pooling kernel later ignores padded positions via the attention mask)
- embed
  - feed the tokenized sequences into the frozen ESM2 model and mean-pool the token embeddings into per-sequence embeddings
- train head
  - pass embeddings into the MLP head
  - compare outputs against the GO vectors using sigmoid binary cross-entropy on the logits (BCE-with-logits)
  - update head weights
- validate and save
  - evaluate on the CAFA3 validation split (F-max, AUPRC, AUROC)
  - store the trained head, tokenizer, and label map
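The label preparation, loss, and F-max from the training steps can be sketched in plain Python. This illustrative sketch is not the project's JAX training code. `multi_hot`, `bce_with_logits`, and `f_max` are hypothetical helpers, and the label map is a toy three-term example. The loss uses the standard numerically stable form of sigmoid cross-entropy on logits. That is the same per-element quantity PyTorch's `BCEWithLogitsLoss` and Optax's `sigmoid_binary_cross_entropy` compute. F-max follows the CAFA protein-centric definition: precision is averaged over proteins with at least one prediction above the threshold, recall is averaged over proteins with at least one true term, and the F1 score is maximized over thresholds.

```python
import math


def multi_hot(protein_terms, term_to_index):
    """Map a protein's GO term IDs to a 0/1 vector over the label map."""
    vec = [0] * len(term_to_index)
    for term in protein_terms:
        if term in term_to_index:  # terms outside the label map are dropped
            vec[term_to_index[term]] = 1
    return vec


def bce_with_logits(logits, targets):
    """Mean sigmoid binary cross-entropy from raw logits, in the stable form
    max(z, 0) - z * y + log(1 + exp(-|z|))."""
    losses = [max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))
              for z, y in zip(logits, targets)]
    return sum(losses) / len(losses)


def f_max(y_true, y_prob, thresholds=None):
    """CAFA-style protein-centric F-max, maximized over a threshold grid."""
    thresholds = thresholds or [t / 100 for t in range(1, 100)]
    best = 0.0
    for t in thresholds:
        precisions, recalls = [], []
        for labels, probs in zip(y_true, y_prob):
            predicted = {j for j, p in enumerate(probs) if p >= t}
            actual = {j for j, y in enumerate(labels) if y == 1}
            if predicted:  # precision: averaged over proteins with >= 1 prediction
                precisions.append(len(predicted & actual) / len(predicted))
            if actual:  # recall: averaged over proteins with >= 1 true term
                recalls.append(len(predicted & actual) / len(actual))
        if precisions and recalls:
            p = sum(precisions) / len(precisions)
            r = sum(recalls) / len(recalls)
            if p + r > 0:
                best = max(best, 2 * p * r / (p + r))
    return best


term_to_index = {"GO:0003824": 0, "GO:0008150": 1, "GO:0004672": 2}  # hypothetical label map
y = multi_hot({"GO:0003824", "GO:0004672"}, term_to_index)
print(y)                                               # [1, 0, 1]
print(round(bce_with_logits([2.0, -1.0, 0.5], y), 4))  # 0.3048
print(f_max([y], [[0.9, 0.2, 0.6]]))                   # 1.0
```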
Inference steps
- tokenization
  - the raw user-provided string ("MKT…") is converted into tokens the ESM2 model understands
- embedding
  - ESM2 turns those tokens into per-token hidden states (learned during ESM2's pretraining), which are mean-pooled into one vector per protein
- head projection
  - the trained MLP head maps that embedding to scores (logits) over the GO terms
- prediction
  - apply a sigmoid to get a probability for each GO term
- interpretation
  - the top-scoring terms (or all terms above a chosen threshold) are the protein's predicted functions
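The prediction and interpretation steps can be sketched as a small post-processing function. This is a plain-Python sketch resting on a few assumptions. `predict_terms`, the 0.5 threshold, and `top_k=5` are illustrative choices rather than values from the project. `index_to_term` stands in for the saved label map (column index → GO ID).

```python
import math


def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def predict_terms(logits, index_to_term, threshold=0.5, top_k=5):
    """Return up to top_k (GO term, probability) pairs with probability >= threshold,
    highest first."""
    probs = [sigmoid(z) for z in logits]
    ranked = sorted(enumerate(probs), key=lambda pair: pair[1], reverse=True)
    return [(index_to_term[i], p) for i, p in ranked if p >= threshold][:top_k]


# hypothetical label map (column index -> GO ID) and logits for one protein
index_to_term = ["GO:0003824", "GO:0008150", "GO:0004672", "GO:0004930"]
logits = [2.2, -1.4, 0.9, -3.0]
for term, p in predict_terms(logits, index_to_term):
    print(term, round(p, 3))
# GO:0003824 0.9
# GO:0004672 0.711
```

A single global threshold is the simplest choice. Per-term thresholds, or the threshold that achieves F-max on the validation split, are common alternatives.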
Results
In summary, we used frozen ESM2-150M embeddings of CAFA3 protein sequences (transfer learning, with no fine-tuning of ESM2 itself) and trained an MLP classifier head on them to predict 303 GO function terms. We also implemented a Triton kernel for optimized masked mean pooling of the token embeddings into sequence embeddings.
We trained for 2,000 steps and achieved decent performance: validation AUPRC 0.583, test AUPRC 0.591, and AUROC ~0.94 on both splits. The model performs best on specific functions such as deubiquitinase activity (AUPRC 0.976) and G protein-coupled receptor activity (AUPRC 0.915). Overall accuracy was 98.2%, and the metrics were consistent between the validation and test sets.
This obviously needs far more data to validate, especially in a business or research context.
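Whether 98.2% accuracy is informative depends on how accuracy is computed. Assume it is element-wise accuracy over the protein × GO-term label matrix; the write-up does not say. In that case it is dominated by true negatives, because each protein carries only a few of the 303 terms. The sketch below uses hypothetical label counts (100 proteins, 5 true terms each). It shows that a predictor that never assigns any function already scores about 98% under that definition, which is why AUPRC is the more telling number here.

```python
def elementwise_accuracy(y_true, y_pred):
    """Fraction of (protein, GO term) cells where the 0/1 prediction matches the label."""
    cells = [(t, p) for row_t, row_p in zip(y_true, y_pred) for t, p in zip(row_t, row_p)]
    return sum(t == p for t, p in cells) / len(cells)


# hypothetical: 100 proteins, 303 GO terms, 5 true terms per protein
num_terms, labels_per_protein = 303, 5
y_true = [[1] * labels_per_protein + [0] * (num_terms - labels_per_protein) for _ in range(100)]
all_negative = [[0] * num_terms for _ in range(100)]  # a "model" that predicts no functions

acc = elementwise_accuracy(y_true, all_negative)
print(round(acc, 4))  # 0.9835
```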
Cloud deployment
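The trained MLP head ended up being small, at 2.05 MB, which makes it efficient to deploy: it fits comfortably in AWS Lambda, which is very cost-effective. On its own, though, the head expects precomputed 640-dimensional ESM2 embeddings as input rather than raw protein sequences.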
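As a sanity check on the 2.05 MB figure, here is a back-of-the-envelope parameter count. It assumes float32 weights, fully connected layers with biases, and the 640 → 512 → 256 → 303 layer sizes, where 640 is the ESM2-150M embedding width. The result matches the reported size if MB means 2^20 bytes. A serialized checkpoint adds a little format overhead on top. A head fed 1280-dimensional embeddings would instead be about 3.3 MiB.

```python
# Layer sizes assumed: 640-dim ESM2-150M input, hidden 512 and 256, 303 GO-term outputs
layer_sizes = [640, 512, 256, 303]

# each Dense layer has n_in * n_out weights plus n_out biases
params = sum(n_in * n_out + n_out for n_in, n_out in zip(layer_sizes, layer_sizes[1:]))
size_bytes = params * 4  # float32 = 4 bytes per parameter

print(params)                          # 537391
print(round(size_bytes / 1024**2, 2))  # 2.05 (MiB)
```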
If we want to deploy both the ESM2-150M model (~600 MB) for embedding generation and the trained MLP classifier (2.05 MB) for function prediction, we could use ECS Fargate with 2 vCPUs and 8 GB RAM, which would cost ~$50-80/month. The customer would send a request to POST /predict with {"sequences": ["MKLLILTCLVAV..."]}.
It would be better to host both, since a single API call could then take raw protein sequences, including unseen proteins, all the way to function predictions. Users would not need any preprocessing because they just send protein sequences. Repeated requests could be cached so subsequent calls return much faster. We could also retrain the MLP without changing the API. As an MVP, though, we could host just the Lambda function.
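Here is a minimal sketch of the POST /predict contract, written as an AWS Lambda handler behind API Gateway's proxy integration. In that setup the request JSON arrives as a string in `event["body"]`, and the function returns `statusCode` and `body`. The validation rules, error messages, and `predict_fn` placeholder are hypothetical. `predict_fn` stands in for the full sequence → GO-term pipeline, and the same request/response shape would apply in the two-model ECS setup.

```python
import json

# residue letters accepted by the ESM-2 alphabet (20 standard plus X, B, U, Z, O)
VALID_RESIDUES = set("ACDEFGHIKLMNPQRSTVWYXBUZO")


def predict_fn(sequences):
    """Placeholder for tokenizer -> ESM2 -> pooling -> MLP head; returns no terms."""
    return [[] for _ in sequences]


def _response(status, payload):
    return {"statusCode": status, "body": json.dumps(payload)}


def handler(event, context):
    """Hypothetical POST /predict handler using the API Gateway proxy event format."""
    try:
        body = json.loads(event.get("body") or "{}")
        sequences = body["sequences"]
    except (json.JSONDecodeError, KeyError, TypeError):
        return _response(400, {"error": 'expected a JSON body like {"sequences": [...]}'})
    if not isinstance(sequences, list) or not sequences:
        return _response(400, {"error": "sequences must be a non-empty list"})
    bad = [i for i, s in enumerate(sequences)
           if not isinstance(s, str) or not s or not set(s.upper()) <= VALID_RESIDUES]
    if bad:
        return _response(400, {"error": "invalid amino acid sequences", "indices": bad})
    return _response(200, {"predictions": predict_fn(sequences)})


event = {"body": json.dumps({"sequences": ["MKLLILTCLVAV"]})}
print(handler(event, None))
```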
Figures
| Term | Description | AUPRC | AUROC |
|---|---|---|---|
| GO:0101005 | deubiquitinase activity | 0.97619 | 0.999859 |
| GO:0004930 | G protein-coupled receptor activity | 0.914855 | 0.990439 |
| GO:0003824 | catalytic activity | 0.901898 | 0.942317 |
| GO:0004888 | transmembrane signaling receptor activity | 0.888111 | 0.971792 |
| GO:1990837 | sequence-specific double-stranded DNA binding | 0.884775 | 0.983483 |
| GO:0060089 | molecular transducer activity | 0.884156 | 0.966763 |
| GO:0038023 | signaling receptor activity | 0.882269 | 0.968464 |
| GO:0019783 | ubiquitin-like protein peptidase activity | 0.871003 | 0.991653 |
| GO:0043565 | sequence-specific DNA binding | 0.867001 | 0.975598 |
| GO:0004672 | protein kinase activity | 0.863945 | 0.937293 |