GitHub repo:
In this project, I fine-tuned ESM2 on labeled CAFA3 proteins to annotate proteins with Gene Ontology (GO) labels. The goal is to predict a protein's GO-annotated functions from an input amino acid sequence. There is a published paper that follows a similar methodology and reaches similar conclusions. The hardware used was an NVIDIA RTX 4070 Ti GPU with 12GB of VRAM and an Intel Core i7-13700KF CPU with 24 threads, running Ubuntu 24.04. The pipeline is as follows:
input any arbitrary protein sequence → ESM-2 tokenizer → classifier model → classification profile of protein with GO annotation
General outline
- prepare the data
- load protein sequences from CAFA3 (CAFA3 provides labeled examples: the sequence + experimentally verified GO functions)
- load GO term annotations
- create synthetic data for improved training
- split data into train/test/validation sets
- extract ESM2 features
- load pretrained ESM2 into RAM
- tokenize protein sequences from CAFA3 using the ESM2 tokenizer
- extract mean embeddings from ESM2’s last hidden layer
- store embeddings as features for each protein
- create classifier model architecture (see the sketch after this outline)
- MLP on top of ESM2 embeddings
- input: 640-dimensional ESM2-150M embeddings
- hidden layers: 512 → 256 → num_targets
- output: binary predictions for each GO term (multi-label classification)
- train
- loss function: sigmoid binary cross-entropy
- optimizer: Adam with learning rate 0.001
- training loop: 2000 steps with an eval every 100 steps
- batch size: 32
- run evaluations
- metrics: accuracy, precision, auPRC, auROC
- per-function evaluation: calculate metrics for each GO term
- final test: evaluate on held-out set
- iterate, if needed
- extrapolate to new datasets
- tbd, but arguably the most important part to see if it is actually valuable
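For concreteness, here is a minimal sketch of the classifier head described in the outline, assuming 640-dimensional ESM2-150M embeddings and the 303 GO targets (layer sizes mirror the 512 → 256 → num_targets design):

```python
import torch.nn as nn

class GOClassifier(nn.Module):
    """MLP head on top of pooled ESM2 embeddings for multi-label GO prediction."""
    def __init__(self, embed_dim: int = 640, num_targets: int = 303):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(embed_dim, 512), nn.ReLU(),
            nn.Linear(512, 256), nn.ReLU(),
            nn.Linear(256, num_targets),  # raw logits; sigmoid is applied in the loss / at inference
        )

    def forward(self, x):
        return self.net(x)
```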
Data preprocessing
We start with the protein sequence dataset (in this case, CAFA3), then pass those inputs to the ESM2 tokenizer.
After running the protein sequences through the ESM2 tokenizer and the ESM2 model, we need to pool the resulting per-token embeddings into a single sequence embedding to pass into our hand-built classifier.
Once each protein sequence is reduced to a concise, fixed-size embedding that our classifier can work with, we use these embeddings to predict the protein's functions. A sketch of the tokenize → embed → pool step follows.
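A minimal sketch of this step, assuming the Hugging Face checkpoint facebook/esm2_t30_150M_UR50D; the repo uses a Triton kernel for the pooling, so the plain PyTorch masked mean pooling shown here is just an equivalent reference implementation:

```python
import torch
from transformers import AutoTokenizer, AutoModel

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t30_150M_UR50D")
esm = AutoModel.from_pretrained("facebook/esm2_t30_150M_UR50D").to(device).eval()

@torch.no_grad()
def embed_sequences(sequences: list[str]) -> torch.Tensor:
    """Tokenize protein sequences and mean-pool ESM2's last hidden layer."""
    batch = tokenizer(sequences, padding=True, truncation=True,
                      max_length=1024, return_tensors="pt").to(device)
    hidden = esm(**batch).last_hidden_state               # (B, L, 640)
    mask = batch["attention_mask"].unsqueeze(-1).float()  # (B, L, 1)
    # average only over real residues, ignoring padding positions
    return (hidden * mask).sum(dim=1) / mask.sum(dim=1)   # one 640-d vector per protein
```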
*notes on data amount
number of CAFA3 protein sequences: ~1,800-2,000
- train split: ~1,440-1,600 proteins (80%)
- validation split: ~180-200 proteins (10%)
- test split: ~180-200 proteins (10%)
total number of GO terms: 303
but what really matters is the signal we can extract from the dataset rather than its raw size in bytes
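Before training, each protein's GO annotation set is turned into a 303-dimensional multi-hot vector and the data is split as above. A minimal sketch, assuming a hypothetical `protein_terms` dict mapping CAFA3 protein IDs to their experimentally verified GO terms, and an 80/10/10 split:

```python
import numpy as np
from sklearn.model_selection import train_test_split

# protein_terms: dict[str, set[str]] mapping protein ID -> GO terms (assumed to exist)
go_terms = sorted({t for terms in protein_terms.values() for t in terms})  # 303 terms
term_index = {t: j for j, t in enumerate(go_terms)}

protein_ids = list(protein_terms)
labels = np.zeros((len(protein_ids), len(go_terms)), dtype=np.float32)
for i, pid in enumerate(protein_ids):
    for t in protein_terms[pid]:
        labels[i, term_index[t]] = 1.0  # 1 for every GO term annotated on this protein

# 80% train, then split the remaining 20% evenly into validation and test
train_ids, rest_ids, y_train, y_rest = train_test_split(protein_ids, labels,
                                                        test_size=0.2, random_state=0)
val_ids, test_ids, y_val, y_test = train_test_split(rest_ids, y_rest,
                                                    test_size=0.5, random_state=0)
```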
Training steps
- prepare data
- take CAFA3 protein sequences, map each to its GO function terms
- convert terms into multi-hot vectors
- split data into train/validation/test
- tokenize
- convert CAFA3 sequences into ESM2 token IDs (the tokenizer handles padding/truncation; the Triton pooling kernel masks out padding later)
- embed
- feed tokenized sequences into ESM2 and get per-sequence embeddings
- train head (see the training sketch after this list):
- pass embeddings into a linear head
- compare outputs against GO vectors using BCEWithLogits
- update head weights
- validate and save
- evaluate on CAFA3 validation split (F-max metric, auPRC, auROC)
- store trained head, tokenizer, label map
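A minimal sketch of that head-training loop, assuming `train_embs` (N × 640 precomputed ESM2 embeddings) and `train_labels` (N × 303 multi-hot GO vectors) produced by the earlier steps:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

head = GOClassifier(embed_dim=640, num_targets=303)
optimizer = torch.optim.Adam(head.parameters(), lr=1e-3)
criterion = torch.nn.BCEWithLogitsLoss()
loader = DataLoader(TensorDataset(train_embs, train_labels), batch_size=32, shuffle=True)

step = 0
while step < 2000:
    for embs, targets in loader:
        optimizer.zero_grad()
        loss = criterion(head(embs), targets)  # sigmoid + BCE on the raw logits
        loss.backward()
        optimizer.step()
        step += 1
        if step % 100 == 0:
            pass  # run validation here (AUPRC / AUROC / F-max) and checkpoint the best head
        if step >= 2000:
            break
```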
Inference steps
- tokenization
- the raw user-provided string (“MKT…”) is converted into tokens the ESM2 model understands
- embedding
- ESM2 turns those tokens into a vector representation (hidden states learned during that model’s training)
- head projection
- the trained linear head maps that embedding to scores (logits) over GO terms
- prediction
- apply sigmoid to get probabilities for each GO term (see the sketch after this list)
- interpretation
- the top terms are the protein’s predicted functions
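Putting the inference steps together, a minimal sketch that reuses `embed_sequences` and the trained `GOClassifier` head from the earlier sketches, plus the ordered `go_terms` list used to build the label vectors:

```python
import torch

@torch.no_grad()
def predict_functions(sequence: str, head: GOClassifier, go_terms: list[str], top_k: int = 10):
    """Score one raw protein sequence against all GO terms and return the top hits."""
    emb = embed_sequences([sequence])                        # (1, 640) pooled ESM2 embedding
    logits = head(emb.to(next(head.parameters()).device))    # (1, 303) GO term logits
    probs = torch.sigmoid(logits)[0]                         # per-GO-term probabilities
    top = torch.topk(probs, k=top_k)
    return [(go_terms[int(i)], round(float(p), 3)) for p, i in zip(top.values, top.indices)]
```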
Results
In summary, we fine-tuned ESM2-150M for function prediction on the CAFA3/GO term data by building an MLP classifier on top of its protein-sequence embeddings. We implemented a Triton kernel for optimized sequence-embedding pooling, kept the ESM2-150M backbone frozen for transfer learning, and trained the MLP head over the 303 GO function terms.
We used 2,000 training steps and achieved decent performance: validation AUPRC 0.583, test AUPRC 0.591, and AUROC ~0.94 across both splits. The model excels at specific functions like deubiquitinase activity (AUPRC 0.976) and G protein-coupled receptor activity (AUPRC 0.915). We had 98.2% overall accuracy with consistent generalization between validation and test sets.
These results need validation on substantially more data before they can be relied on in a business or research context.
Cloud deployment
The trained classifier head ended up being relatively small at 2.05 MB, which makes it efficient to deploy. This means it can be deployed on AWS Lambda, which is extremely cost-effective.
If we want to deploy both the ESM2-150M model (~600MB) for embedding generation and the trained MLP classifier (2.05MB) for function prediction, we could use ECS Fargate, specifically with 2 vCPUs and 8GB RAM, which would cost ~$50-80/month. The customer would send a request to POST /predict with {"sequences": ["MKLLILTCLVAV..."]}.
It would be better to host both, since that allows raw protein sequences → function predictions in one API call, even for unseen proteins. No preprocessing would be required since users just send protein sequences, caching could make subsequent calls much faster, and the MLP could be retrained without changing the API. As an MVP, though, we could host just the Lambda function.
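As a sketch of what the ECS-hosted endpoint could look like (FastAPI is an assumption here, not necessarily what the repo ships; `predict_functions`, `head`, and `go_terms` come from the earlier sketches):

```python
from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI()

class PredictRequest(BaseModel):
    sequences: list[str]

@app.post("/predict")
def predict(req: PredictRequest):
    # embed each raw sequence with ESM2-150M, then score it with the trained MLP head
    return {"predictions": [predict_functions(seq, head, go_terms) for seq in req.sequences]}
```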
Figures
| Term | Description | AUPRC | AUROC |
| --- | --- | --- | --- |
| GO:0101005 | deubiquitinase activity | 0.97619 | 0.999859 |
| GO:0004930 | G protein-coupled receptor activity | 0.914855 | 0.990439 |
| GO:0003824 | catalytic activity | 0.901898 | 0.942317 |
| GO:0004888 | transmembrane signaling receptor activity | 0.888111 | 0.971792 |
| GO:1990837 | sequence-specific double-stranded DNA binding | 0.884775 | 0.983483 |
| GO:0060089 | molecular transducer activity | 0.884156 | 0.966763 |
| GO:0038023 | signaling receptor activity | 0.882269 | 0.968464 |
| GO:0019783 | ubiquitin-like protein peptidase activity | 0.871003 | 0.991653 |
| GO:0043565 | sequence-specific DNA binding | 0.867001 | 0.975598 |
| GO:0004672 | protein kinase activity | 0.863945 | 0.937293 |
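The per-term values above can be reproduced with scikit-learn; a minimal sketch, assuming `y_true` (N × 303 multi-hot labels for the evaluated split) and `y_prob` (the matching sigmoid outputs):

```python
import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score

def per_term_metrics(y_true: np.ndarray, y_prob: np.ndarray, go_terms: list[str]):
    """Compute AUPRC and AUROC for each GO term, skipping single-class columns."""
    rows = []
    for j, term in enumerate(go_terms):
        positives = y_true[:, j].sum()
        if positives == 0 or positives == len(y_true):
            continue  # metrics are undefined without both positives and negatives
        rows.append({
            "term": term,
            "auprc": average_precision_score(y_true[:, j], y_prob[:, j]),
            "auroc": roc_auc_score(y_true[:, j], y_prob[:, j]),
        })
    return sorted(rows, key=lambda r: r["auprc"], reverse=True)
```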