Stephen Malina
Protein Language Models (Part 2): Models
_I’m experimenting with Lilian Weng style review blog posts on topics I wish I’d had someone to explain to me when I got started in protein machine learning.
This is the second post in a series introducing protein language models_. The other posts are:
- Protein Language Models (Part 1): Introduction & Datasets
- Protein Language Models (Part 2): Models
- Protein Language Models (Part 3): Benchmarks & Evaluation
- Protein Language Models (Part 4): Scaling
- Protein Language Models (Part 5): FAQ & Conclusion
This post covers the different types of PLMs, with a specific focus on deviations from the standard LLM playbook.
Single Sequence Models
The most common type of PLM takes single sequences as input and learns to predict their probability, so these models are a good place to start.
Architectures
Early work on protein language models such as UniRep employed various RNN architectures. More recently, the CARP paper argued that convolutional models could perform competitively with Transformers on both pretraining metrics and downstream tasks. Other papers, such as the ESM-InverseFolding (ESM-IF) paper, have used Graph Neural Networks (GNNs). However, as of July 2023, Transformers have the most momentum behind them in the PLM world. I believe this will continue to hold unless another architecture shows better holistic performance and replaces Transformers more broadly. Models from the ESM lineage, RITA, ProGen 1 and 2, ProtTrans’ various models, and Ankh all use Transformer variants.
Side note on that momentum: The Hazy Research group at Stanford has recently put out a flurry of papers that purportedly improve on various aspects of Transformers. Benefiting from careful investigation of how attention works, the Hyena architecture overcomes attention’s quadratic bottleneck while potentially maintaining its performance. (A slew of sub-quadratic attention variants have claimed to do this, but people I trust seemingly put more stock in Hyena actually achieving it at scale than other candidates.) Combined with the more recent Monarch Mixer, which tackles the MLP efficiency bottleneck, I assign >40% probability that SoTA models no longer use attention in 5 years.
ProtTrans
Prominent PLMs use a mix of model structures and losses – encoder-only MLM, encoder-decoder seq2seq MLM, decoder-only autoregressive – first explored by the NLP community. ProtTrans investigated the impact of these choices on downstream task performance by training models with different setups using comparable parameter counts.
Differences between proteins and language might lead us to different priors about which losses make the most sense. Relative to language with its temporal structure, proteins have less inherent left-to-right structure. Ribosomes translate from left (N-terminus) to right (C-terminus), but protein function depends much more on spatial location and contiguity than on order in the sequence. For that reason, we could imagine that masked or other bidirectional losses make more sense for protein language models. Admittedly though, people made similar arguments for why autoregressive losses would be insufficient for powerful language models and were wrong, so rather than speculate, we’d like to answer these questions empirically.
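To make the difference between the two families of losses concrete, here is a minimal sketch of how training examples could be built for each. The function names, the `<mask>` and `<bos>` strings, and the 15% mask rate are illustrative assumptions rather than any particular paper's setup, and real BERT-style pipelines also sometimes replace selected positions with random or unchanged tokens.

```python
import random

MASK = "<mask>"  # assumed mask token string, for illustration only


def masked_lm_example(sequence, mask_rate=0.15, seed=0):
    """Build (inputs, targets) for a bidirectional masked objective.

    targets[i] holds the original residue at masked positions and None
    elsewhere, so the loss would be computed on masked positions only.
    """
    rng = random.Random(seed)
    inputs, targets = [], []
    for residue in sequence:
        if rng.random() < mask_rate:
            inputs.append(MASK)
            targets.append(residue)
        else:
            inputs.append(residue)
            targets.append(None)
    return inputs, targets


def autoregressive_example(sequence, bos="<bos>"):
    """Build (inputs, targets) for left-to-right next-residue prediction."""
    inputs = [bos] + list(sequence[:-1])  # model sees everything to the left
    targets = list(sequence)              # and predicts the residue at each step
    return inputs, targets


ar_inputs, ar_targets = autoregressive_example("MKTAYIAK")
mlm_inputs, mlm_targets = masked_lm_example("MKTAYIAK")
```

In the masked case every prediction can use context on both sides of the hidden residue, while in the autoregressive case each prediction only sees the N-terminal side.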
ProtTrans Figure 1: Using pretrained LMs as feature extractors for per-residue and per-protein tasks.
They looked at both a per-residue task (secondary structure prediction) and two per-protein classification tasks (membrane vs. soluble and subcellular localization). Of the models trained on the same dataset (UniRef100) with roughly equivalent parameter counts, ProtBert performs best, suggesting that simple encoder-only models combined with MLM losses work well in proteins (as they do in NLP). However, this conclusion isn’t as strong as it could be because the paper lacks a similarly sized seq2seq (T5-style) model to compare to, even though the larger T5 models consistently perform best.
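As a rough illustration of the feature-extractor setup, the sketch below shows how frozen per-residue embeddings could feed both kinds of task. Mean-pooling over the sequence is assumed here as one common way to get a per-protein vector, and the toy `embeddings` list stands in for real PLM output.

```python
def mean_pool(residue_embeddings):
    """Average per-residue vectors (L x D) into one per-protein vector (D)."""
    length = len(residue_embeddings)
    dim = len(residue_embeddings[0])
    return [sum(vec[d] for vec in residue_embeddings) / length for d in range(dim)]


# Toy stand-in for frozen PLM output: 3 residues, 2-dimensional embeddings.
embeddings = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]

# Per-residue task (e.g. secondary structure): one feature vector per residue.
per_residue_features = embeddings

# Per-protein task (e.g. localization): one pooled vector for the whole protein.
per_protein_features = mean_pool(embeddings)
```

A small supervised head would then be trained on top of these features while the language model itself stays fixed.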
Ankh
Taking insights from ProtTrans, the Ankh authors (many of whom also worked on ProtTrans) wrote a follow-up paper that used careful experiments with pretraining to find a setup that enabled them to train smaller models that performed better on average than much larger ProtTrans and ESM models on downstream tasks. For their experiments, they included the downstream tasks used by ProtTrans but also additional structure and fitness tasks.
On average, Ankh improved the PLM SOTA performance by 4.8% while Ankh base improved it by 3.4% with <10% & 3% of the training parameters and 30% & 15% of the embedding dimension, for Ankh and Ankh base respectively.
To map out the effect of architectural and loss decisions, they performed 22 numbered experiments controlled for number of parameters and evaluated them using performance across downstream tasks. I’d recommend at least skimming them as no other PLM paper shows the same level of rigor and care, but here I’ll highlight a few key takeaways. Following ProtTrans’ finding that the T5 model performed best (although as mentioned that conclusion suffers from not controlling for scale), all their experiments used a seq2seq encoder-decoder model:
- 1-gram span masking works better than full autoregressive decoding: In experiments 1-4, they find that reconstructing only the mask tokens, split by mask placeholders, rather than the full context performs better than either full autoregressive reconstruction or masking spans rather than single tokens. This is a double win because it saves resources relative to the full reconstruction option.
- An encoder 2x bigger than the decoder works just as well and saves compute: Experiments 10-12 cover their investigation of varying relative encoder vs. decoder size. They found that using a larger encoder usually produces the best performance on downstream tasks, intuitively because doing so results in better embeddings.
- In general, differences from architectural choices proved small: More of a meta point, but the tables for each of these individual experiments typically show small performance differences on downstream tasks, usually within a percent or two on any individual task. Masking strategy had the highest effect sizes, but even those generally fall in the range of 2-3% relative to the next best strategy for any given task.
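The first takeaway above is easier to picture with an example. The sketch below builds a T5-style input/target pair in which each masked residue gets its own placeholder and the target contains only placeholders and the residues they hide. The `<extra_id_N>` naming follows the common T5 convention and is an assumption here, as is the function name; Ankh's exact preprocessing may differ.

```python
def span_mask_1gram(sequence, masked_positions):
    """Build a T5-style (encoder_input, decoder_target) pair.

    Each masked residue is replaced by its own sentinel in the input, and the
    target lists only the sentinels followed by the residues they hide.
    """
    masked = set(masked_positions)
    encoder_input, decoder_target = [], []
    sentinel_id = 0
    for i, residue in enumerate(sequence):
        if i in masked:
            sentinel = f"<extra_id_{sentinel_id}>"  # assumed sentinel format
            sentinel_id += 1
            encoder_input.append(sentinel)
            decoder_target.extend([sentinel, residue])
        else:
            encoder_input.append(residue)
    return encoder_input, decoder_target


enc, dec = span_mask_1gram("MKTAYIAK", masked_positions=[2, 5])
# enc: M K <extra_id_0> A Y <extra_id_1> A K
# dec: <extra_id_0> T <extra_id_1> I
```

Because the decoder target is much shorter than the full sequence, this setup spends less decoder compute per example than reconstructing every residue.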
Other Models
- ESM-1b, ESM-1v, and ESM-2: Three sets of masked language models ranging in size from 30M parameters to 15B parameters. The ESM-1b and ESM-1v models deserve credit for demonstrating the potential of scaled up masked language models for both zero-shot fitness and contact prediction. The largest ESM-2 model was used to train ESMFold, one of the early demonstrations that a PLM could replace the MSA input and layers in a structure prediction model and perform competitively with the state-of-the-art (AlphaFold 2 at the time).
- Progen 2: Set of auto-regressive models ranging in size from 151M to 6.4B parameters. Relative to other PLM papers, this one focused more heavily on testing the model’s ability to generate proteins, although the authors also tested zero-shot fitness prediction and performed better than other existing single sequence models at the time. Progen 2 will come up again in the generation section of the evaluation post.
- RITA: Set of four auto-regressive models trained at various sizes (the largest at 1.2B parameters). This is the first PLM paper I’m aware of that actually mapped out scaling laws and used them to train compute-optimal models. RITA will feature heavily in the scaling post.
- ProtGPT2: Early example of an auto-regressive PLM trained using the GPT2 architecture. Has largely been superseded by bigger, better models but deserves credit for demonstrating that just applying the GPT recipe to proteins could work.
Structure/sequence models
Besides sequence, structure presents a natural modality for PLMs to incorporate. Structure – if you count dynamics as part of structure – directly mediates nearly all of protein function and so injecting structural knowledge into PLMs may help them perform and generalize better. PLMs have started to incorporate structure via the “inverse folding” task. Inverse folding conditions generation on structure, such that the model learns $ p(\textrm{sequence} | \textrm{structure}) $ rather than just $ p(\textrm{sequence}) $.
ESM-IF
ESM-IF was one of the early inverse folding models. ESM-IF was trained on a mix of sequence/structure pairs from AlphaFold DB (12M) and CATH (~16,000) and uses an encoder-decoder architecture that takes structures as inputs to the encoder and autoregressively decodes sequences conditioned on structural encodings. As seen below, it also uses an auxiliary span masking loss to train the encoder on structures. In order to handle structures as input, it uses rotation-equivariant layers called Geometric Vector Perceptron layers. At inference time, ESM-IF can generate sequences given a fixed backbone structure.
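Written out, the autoregressive decoding described above corresponds to the standard chain-rule factorization of the conditional distribution. The symbols $ s_i $ (residue at position $ i $) and $ L $ (sequence length) are notation introduced here for illustration:

$$ p(\textrm{sequence} \mid \textrm{structure}) = \prod_{i=1}^{L} p(s_i \mid s_{<i}, \textrm{structure}) $$

Each factor is what the decoder predicts at step $ i $, given the residues generated so far and the encoder's representation of the backbone.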
ESM-IF’s improvement on prior inverse folding models primarily came from using AlphaFold DB rather than just CATH. As discussed in Part 1 (Datasets), AlphaFold DB contains ~1000x more sequence/structure pairs than CATH, providing a much richer dataset for learning inverse folding even with the caveat that some of its predicted structures have errors. By using a larger model than prior work (both for its GNN and Transformer versions), ESM-IF benefited from this scale, as shown in the following figure.
ProstT5
ProstT5 also does inverse folding. Similar to ESM-IF, it was trained on sequence/structure pairs from AlphaFold DB (but not CATH), but it differs from ESM-IF in two ways. First, it’s a structure and sequence to sequence model rather than only a structure to sequence model. Second, the authors adopted a much simpler but clever strategy for encoding structure: they take advantage of the 1D structural encoding used by Foldseek to encode structures as sequences. Because of this, ProstT5 was trained by just fine-tuning an existing PLM (ProtT5) on structures encoded as tokens. This allows it to be used for inverse folding while maintaining its ability to embed sequences and extending it to structures. On an inverse folding benchmark, ProstT5 approaches but doesn’t yet match SoTA models’ performance. Its training process also seems to cause performance degradation on downstream sequence-based tasks, so there’s still room to improve. On the flip side, the fact that it works at all, combined with our prior that the naive approach of just feeding everything into a language model works well in other domains, suggests the general approach is still promising.
Family-based retrieval models
Multiple sequence alignments (MSAs) were arguably the earliest form of retrieval augmentation, arriving decades before retrieval-augmented language models entered the scene. Pre-language model era, protein researchers would often train Potts models on sequences from a protein family’s MSA in order to learn residue covariation statistics. (From an ML perspective, a Potts model is an unsupervised generative model trained with a maximum likelihood estimation objective using first and second-order interaction terms to learn covariance statistics.) Similar to PLMs, Potts models were then used to predict protein contacts and sometimes fitness.
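For readers who haven't met Potts models before, the sketch below shows how first- and second-order terms combine into an unnormalized log-score for one aligned sequence. The dictionary layout and the function name are assumptions chosen for readability; real implementations use dense tensors and also need a way to handle the normalizing constant during training.

```python
def potts_log_score(sequence, fields, couplings):
    """Unnormalized log-probability of an aligned sequence under a Potts model.

    fields[i][a]              : first-order term for residue a at position i
    couplings[(i, j)][(a, b)] : second-order term for residues a, b at i < j
    Missing entries are treated as zero.
    """
    score = 0.0
    # First-order (per-position) terms.
    for i, a in enumerate(sequence):
        score += fields[i].get(a, 0.0)
    # Second-order (pairwise) terms over all position pairs i < j.
    for i in range(len(sequence)):
        for j in range(i + 1, len(sequence)):
            pair_terms = couplings.get((i, j), {})
            score += pair_terms.get((sequence[i], sequence[j]), 0.0)
    return score


# Toy two-position model with a single non-zero coupling.
toy_fields = [{"A": 1.0}, {"C": 0.5}]
toy_couplings = {(0, 1): {("A", "C"): 2.0}}
score_ac = potts_log_score("AC", toy_fields, toy_couplings)
```

Large-magnitude pairwise terms between two positions are what get read off as evidence of a contact.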
With the advent of deep learning for proteins, PLMs have started to replace Potts models due to their ability to handle unaligned sequences and learn representations from proteins across the entire tree of life. There’s been interesting work showing how attention-based models relate to Potts models and proposing an attention variant, Factored Attention, that interpolates between the standard attention mechanism and Potts models. While standard PLMs aim to learn homology entirely from single sequences, researchers have experimented with setups that explicitly provide homology information as part of their input. Similar to how retrieval augmentation may allow models to perform as well while using orders-of-magnitude fewer parameters, the hope here is that providing homology information can allow PLMs to perform as well or better with fewer parameters.
MSA Transformer
The MSA Transformer does masked language modeling but augments single sequence inputs with sequences from their MSAs.
The MSA Transformer bucks the trend of PLMs not innovating on attention architectures by introducing interleaved column and row attention (subsequently used by AlphaFold 2), which lets it attend over an entire MSA without running into naive attention’s quadratic bottleneck.
Rather than attend across the entire MSA, column attention constrains itself to a single position and looks across sequences in the MSA, while row attention attends across positions for a single sequence. Naive attention would have $ \mathcal{O}(M^2 L^2) $ time complexity for sequence length $ L $ and MSA size $ M $, whereas column and row attention together have $ \mathcal{O}(LM^2 + L^2M) $ time complexity.
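To get a feel for the gap between those two complexity terms, the sketch below counts query-key pairs for each scheme. It is a back-of-the-envelope count that ignores heads, hidden dimension, and any weight sharing, and the MSA size of 128 sequences by 512 positions is an arbitrary example.

```python
def naive_attention_pairs(num_seqs, seq_len):
    """Every token attends to every token in the flattened MSA: (M*L)^2."""
    tokens = num_seqs * seq_len
    return tokens * tokens


def axial_attention_pairs(num_seqs, seq_len):
    """Row attention plus column attention: M*L^2 + L*M^2."""
    row = num_seqs * seq_len * seq_len   # each of M rows attends over L positions
    col = seq_len * num_seqs * num_seqs  # each of L columns attends over M sequences
    return row + col


M, L = 128, 512
naive = naive_attention_pairs(M, L)   # 4,294,967,296
axial = axial_attention_pairs(M, L)   # 41,943,040
ratio = naive / axial                 # about 102x fewer pairs
```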
As evidence for MSA retrieval improving performance, they compared a 100M MSA Transformer’s performance on downstream tasks to the 650M parameter ESM-1b model’s performance and found that MSA Transformer outperformed despite having ~6x fewer parameters.
On top of that, they found that even a smaller MSA Transformer model outperforms the 650M ESM-1b model at contact prediction. Subsequent results on ProteinGym zero-shot fitness prediction bolstered this result, with MSA Transformer outperforming all the single sequence models already mentioned. Combined with a bunch of careful experiments they did to show that the MSA information was actually being used, this suggests that retrieval currently provides a big boost for PLM performance, although with the cost of requiring infrastructure for doing MSA retrieval at inference time.
Transception
Transception took a different approach from MSA Transformer’s. Rather than explicitly attend to sequences in the MSA, Transception trained an autoregressive model that computes average weights on amino acids from the MSA at each position and then uses the weights at a given position combined with the left context to predict its amino acid. Transception also used a novel attention variant and positional encoding scheme but I’m not going to cover those here. The tradeoff is that the model can only use the relative frequency of amino acids at each position to inform its predictions rather than the full MSA.
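A minimal sketch of this kind of retrieval is below: compute smoothed amino-acid frequencies for one MSA column, then blend them with the autoregressive model's prediction for that position. The log-space weighted mix, the `alpha` value, and the pseudocount are assumptions for illustration and should not be read as the paper's exact weighting scheme.

```python
import math

AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"


def column_frequencies(msa, position, pseudocount=1.0):
    """Smoothed amino-acid frequencies at one aligned column (gaps ignored)."""
    counts = {aa: pseudocount for aa in AMINO_ACIDS}
    for seq in msa:
        if seq[position] in counts:
            counts[seq[position]] += 1.0
    total = sum(counts.values())
    return {aa: c / total for aa, c in counts.items()}


def combine(model_probs, retrieval_probs, alpha=0.6):
    """Weighted mix of the two distributions in log space, renormalized."""
    logits = {
        aa: (1 - alpha) * math.log(model_probs[aa]) + alpha * math.log(retrieval_probs[aa])
        for aa in AMINO_ACIDS
    }
    log_norm = math.log(sum(math.exp(v) for v in logits.values()))
    return {aa: math.exp(v - log_norm) for aa, v in logits.items()}


toy_msa = ["AKT", "AKS", "GKT"]
freqs = column_frequencies(toy_msa, position=0)
uniform = {aa: 1.0 / len(AMINO_ACIDS) for aa in AMINO_ACIDS}  # stand-in for model output
mixed = combine(uniform, freqs)
```

Note that only per-column frequencies reach the model, which is exactly the loss of information relative to attending over the full MSA.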
For evaluation, they compared Transception models to sequence-only models of various sizes and to MSA Transformer on zero-shot prediction. They use a benchmark called ProteinGym, which we’ll return to in the next post when discussing benchmarks and evaluation. At a high level, ProteinGym collects a large number of different deep mutational scan datasets and makes it easy to evaluate models’ performance on them via zero-shot prediction.
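As a sketch of what zero-shot evaluation on a deep mutational scan could look like, the code below scores each variant by its log-likelihood relative to wild type and then computes Spearman's rank correlation against measured fitness. The `log_likelihood` callable is a hypothetical stand-in for a trained model, and the scoring rule is one common choice rather than a description of ProteinGym's internals.

```python
def ranks(values):
    """1-based ranks, with tied values sharing their average rank."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    result = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        average_rank = (i + j) / 2 + 1
        for k in range(i, j + 1):
            result[order[k]] = average_rank
        i = j + 1
    return result


def spearman(xs, ys):
    """Spearman's rho: Pearson correlation computed on ranks."""
    rx, ry = ranks(xs), ranks(ys)
    n = len(xs)
    mean_x, mean_y = sum(rx) / n, sum(ry) / n
    cov = sum((a - mean_x) * (b - mean_y) for a, b in zip(rx, ry))
    var_x = sum((a - mean_x) ** 2 for a in rx)
    var_y = sum((b - mean_y) ** 2 for b in ry)
    return cov / (var_x * var_y) ** 0.5


def zero_shot_score(log_likelihood, mutant, wild_type):
    """Score a variant by how much more (or less) likely it is than wild type."""
    return log_likelihood(mutant) - log_likelihood(wild_type)


def toy_log_likelihood(sequence):
    """Hypothetical stand-in for a model: penalizes prolines."""
    return -float(sequence.count("P"))


wild_type = "MAKT"
variants = ["MPKT", "MAKS", "PPKT"]
measured_fitness = [0.4, 1.0, 0.1]  # made-up values for the toy example
scores = [zero_shot_score(toy_log_likelihood, v, wild_type) for v in variants]
rho = spearman(scores, measured_fitness)
```

No labels from the assay are used to produce the scores, which is what makes the prediction zero-shot.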
Transception Table 2: Average Spearman’s $ \rho $ on ProteinGym benchmarks bucketed by MSA depth and AUC across all datasets.
They found that of the PLMs, Transception performed best on all three MSA depth subsets and performed best of all models on the overall average. However, they also observed that for datasets with medium or high MSA depths, an alignment-based family specific model (VAE), EVE, still outperformed the PLMs. This suggests that PLMs still have a way to go in terms of competing with unsupervised models in regimes where family data is abundant.
Focusing on the PLMs, Transception (with retrieval) outperformed MSA Transformer and all other single sequence models across the board. This may come as a bit of a surprise because MSA Transformer’s retrieval component is strictly more expressive than Transception’s. The authors hypothesize, and run some experiments to test, that Transception’s retrieval mechanism and weighting scheme are more robust to shallow alignments and mutations in the sequence, which would explain why it especially outperforms on the low MSA depth subset.
Transception Figure 4: Average performance as a function of minimum % similarity, computed via filtering out increasingly large subsets of the MSA for each benchmark (left). The right figure shows performance on the p53 tumor protein specifically, and shows Transception’s relative robustness to increasingly distant, shallow MSAs.
They also argue that Transception performs better on proteins with larger sets of mutations, comparing Transception’s performance to other models’ broken down by “mutation depth” (edit distance).
Transception Table 3: Average Spearman’s $ \rho $ as a function of mutation depth.
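For substitution-only variants of the same length as wild type, mutation depth can be computed by counting mismatched positions, as in the sketch below. The function names and the pooled top bucket are illustrative assumptions, and variants with indels would need a true edit-distance computation instead.

```python
def mutation_depth(variant, wild_type):
    """Number of positions at which a substitution-only variant differs."""
    if len(variant) != len(wild_type):
        raise ValueError("this sketch only handles equal-length sequences")
    return sum(1 for a, b in zip(variant, wild_type) if a != b)


def bucket_by_depth(variants, wild_type, max_bucket=5):
    """Group variants by depth, pooling everything at or above max_bucket."""
    buckets = {}
    for variant in variants:
        depth = min(mutation_depth(variant, wild_type), max_bucket)
        buckets.setdefault(depth, []).append(variant)
    return buckets


buckets = bucket_by_depth(["MKKT", "MAKT", "AAAA"], wild_type="MKKT", max_bucket=2)
```

Per-bucket Spearman correlations can then be averaged to produce a breakdown by depth.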
Finally, Transception works on proteins with indels, whereas MSA Transformer does not.
Conditional models
The final class of PLMs is conditional models. Conditional models take some sort of auxiliary token(s) as input to guide their predictions (at training time) and generation (at inference time). Conditioning allows directly guiding models towards a specific sub-region of the overall learned distribution vs. having to do so purely by providing contextual sequence information.
Progen
The original Progen paper described autoregressive models that could do conditional generation using taxonomic labels as control tokens. To allow this, they conditioned each sequence on its taxonomic information during training as originally done for NLP in CTRL. Their paper focused much more on the application side of things, showing that they could generate functional proteins (which they actually measured in the lab). As part of this, they do demonstrate that control labels can effectively guide a model towards generating proteins that look similar to those of a specific family.
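Mechanically, conditioning of this kind can be as simple as prepending tag tokens to each training sequence, as in the sketch below. The tag format and names are hypothetical and chosen only to illustrate the idea.

```python
def add_control_tokens(sequence, tags):
    """Prepend conditioning tags so the model sees them before any residue."""
    return [f"<{tag}>" for tag in tags] + list(sequence)


# Training: each sequence is paired with its (hypothetical) taxonomic tags.
training_example = add_control_tokens("MKT", ["taxon:Bacteria"])

# Inference: supply only the tags as a prompt and let the model generate residues.
generation_prompt = add_control_tokens("", ["taxon:Bacteria"])
```

Because the tags always appear to the left of the residues, an autoregressive model can condition on them with no architectural changes.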
Conclusion
Protein language models started by experimenting with architectures and setups very similar to NLP. In doing so, they found that many of the insights from NLP largely replicated. Larger models perform better on language modeling and on downstream tasks (discussed more in the Scaling post) and seq2seq models (of the T5 flavor) work especially well. As the field’s matured, researchers have also explored more protein-specific setups such as Family-based retrieval, which for now seems to work better than single-sequence models. While at a more nascent stage, going forward it’ll be interesting to see whether Sequence/structure models also start showing better parameter-for-parameter performance and whether even more multimodal models show interesting capabilities or performance.
Appendix
Other modeling decisions
Context windows
Although context length limitations are not as acute for proteins as for language or genomes, extending models’ context windows may also provide a benefit for proteins. While a 2k context is large enough for most proteins, it precludes modeling the sequences of the longest proteins such as Titin (33k residues), complexes, and certain viral assemblies (such as Adenovirus) together. Some papers such as Ankh use relative positional embeddings, which allow models to handle longer proteins than appeared in their training dataset, but I haven’t seen rigorous evaluations of how well these models can actually deal with longer lengths. Longer context windows via improvements such as FlashAttention or entirely new architectures like Hyena could also enable naive implementations of models like the MSA Transformer that use sets of sequences as inputs.
Tokenization & embeddings
Given the ease of using existing frameworks like Huggingface’s for training, some papers (Progen2) analogize amino acids to letters and tokenize them using Byte Pair Encodings (BPEs). Others (both the RITA and ESM models for example) instead embed amino acids directly, giving them an embedding alphabet size of 20. In favor of this, the RITA authors argue that the analogy between amino acids and language fails here (and I agree). The analogy fails because proteins lack a concept of words, which provide natural boundaries for common tokenizations. Tokenizing by amino acid also makes assessing biochemical understanding easier, because we can directly evaluate how well embeddings encode known biochemical properties of each amino acid. Given this, I currently expect that tokenizing each amino acid individually will become standard, but we’ll see.
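Per-residue tokenization is simple enough to show in a few lines. The sketch below maps each of the 20 standard amino acids to its own id; the alphabetical ordering is an arbitrary assumption, and real models also reserve ids for special tokens and non-standard residues.

```python
AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"  # the 20 standard residues, alphabetical
TOKEN_TO_ID = {aa: i for i, aa in enumerate(AMINO_ACIDS)}
ID_TO_TOKEN = {i: aa for aa, i in TOKEN_TO_ID.items()}


def tokenize(sequence):
    """Map each residue to its own token id (one token per amino acid)."""
    return [TOKEN_TO_ID[aa] for aa in sequence]


def detokenize(token_ids):
    """Invert tokenize: map ids back to a residue string."""
    return "".join(ID_TO_TOKEN[i] for i in token_ids)


ids = tokenize("MKT")  # [10, 8, 16]
```

With this scheme each row of the embedding table corresponds to exactly one amino acid, which is what makes it straightforward to check whether the learned embeddings reflect known biochemical properties.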
Existing PLMs use positional embedding setups drawn from NLP such as relative positional embeddings. However, this seems like an area where PLMs might benefit from more protein-specific encodings that, for example, map to structural position rather than just sequence position.