Most deep learning tools for sequence modeling target Natural Language Processing. I mean, I love NLP, but I’m doing genomics! Why doesn’t Google or Facebook make a utility library for my highly specific use-case?

Anyway, I found myself needing to ingest protein sequences into PyTorch. There are many ways to represent them: finding the “best” computational representation of a protein polymer (e.g., GloVe embeddings, k-mer frequencies) was the goal of my research project. The goal of this post is more modest: a simple but complete system for ingesting biological sequences from FASTA files.

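The absolute first step is to load the sequence data. For convenience, I load it all into memory. (It is possible to stream sequences from disk as needed using an iterable-style dataset. The DataLoader can still batch such a stream, but you lose random access, so DataLoader shuffling and random_split are off the table.)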

from Bio import SeqIO
with open("foo.fasta") as f:
    records = list(SeqIO.parse(f, "fasta"))
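A quick sketch of the iterable-style alternative, to show that batching itself still works. The dataset below yields one encoded sequence at a time, standing in for a lazy FASTA reader. SEQUENCES, TO_IX (built here over the 20 standard amino acids), StreamingSequenceDataset, and pad_batch are hypothetical names for illustration only. What it gives up is random access: the DataLoader rejects shuffle=True for iterable-style datasets, and random_split needs a dataset with a length and indexing.

```python
import torch
from torch.utils.data import DataLoader, IterableDataset
from torch.nn.utils.rnn import pad_sequence

# Hypothetical stand-ins: a few sequences and a residue mapping with <pad> at 0.
SEQUENCES = ["MNRG", "PPL", "MNRGPPLRS", "GP", "SAF"]
TO_IX = {"<pad>": 0, **{aa: i for i, aa in enumerate("ACDEFGHIKLMNPQRSTVWY", start=1)}}


class StreamingSequenceDataset(IterableDataset):
    """Yields one encoded sequence at a time (a real version would read lazily from disk)."""

    def __init__(self, sequences):
        self.sequences = sequences

    def __iter__(self):
        for seq in self.sequences:
            yield torch.tensor([TO_IX[residue] for residue in seq])


def pad_batch(batch):
    # Pad every sequence in the batch to the length of the longest one.
    return pad_sequence(batch, batch_first=True, padding_value=TO_IX["<pad>"])


loader = DataLoader(StreamingSequenceDataset(SEQUENCES), batch_size=2, collate_fn=pad_batch)
batches = list(loader)
for b in batches:
    print(b.shape)
```

The stream is consumed in file order, so the last batch simply holds whatever is left over.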

Next, we run through the sequences to find out which residues are present, so that each one can be assigned a unique integer. We also add a special "<pad>" token, which we will need later for batching.

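residues = set()
for record in records:
    residues.update(str(record.seq))
# Sort so the mapping is the same on every run; reserve index 0 for padding
vocab = ["<pad>"] + sorted(residues)
to_ix = {char: i for i, char in enumerate(vocab)}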
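Note that enumerating a set directly gives an order that can change between runs: Python randomizes string hashing per process (see PYTHONHASHSEED), so the same residue can receive a different index each time the script starts, and a model trained in one session would see scrambled inputs in the next. Sorting the residues first avoids this. Below is a minimal sketch of a stable mapping plus its inverse, using a few made-up sequences in place of records; encode, decode, and ix_to_char are illustrative names, not part of any library.

```python
# Made-up sequences standing in for the records read from the FASTA file.
sequences = ["MNRGPPL", "SAFPGW", "MVHLTPEEK"]

residues = set()
for seq in sequences:
    residues.update(seq)

# Sorting makes the mapping identical on every run; <pad> is pinned to index 0.
vocab = ["<pad>"] + sorted(residues)
to_ix = {token: i for i, token in enumerate(vocab)}
ix_to_char = {i: token for token, i in to_ix.items()}


def encode(seq):
    return [to_ix[residue] for residue in seq]


def decode(indices):
    # Skip padding so a padded row decodes back to the original sequence.
    return "".join(ix_to_char[i] for i in indices if i != to_ix["<pad>"])


print(encode("MNRG"))  # → [8, 9, 11, 4]
print(decode(encode("SAFPGW") + [0, 0]))  # → SAFPGW
```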

For example, a randomly chosen protein from the Grapevine Fleck Virus is encoded like this (the exact indices depend on which residues occur in your FASTA file):

>>> [to_ix[residue] for residue in "MNRGPPLRSRPPSSPPPASAFPGPSPFPSPSPANSLPSASPPPPTCTPSSPVSRPFASARLRTSHPPRCPHRSAPPSAPSPPFTPPHPLPTPTPSSSPRSPWLSLAPLPTSSASLASFPPPPSSFSSPSSPSTSPLSPSSSSFPSSSSFSFLVPSNS"]
[12, 7, 13, 3, 1, 1, 9, 13, 11, 13, 1, 1, 11, 11, 1, 1, 1, 2, 11, 2, 8, 1, 3, 1, 11, 1, 8, 1, 11, 1, 11, 1, 2, 7, 11, 9, 1, 11, 2, 11, 1, 1, 1, 1, 6, 10, 6, 1, 11, 11, 1, 0, 11, 13, 1, 8, 2, 11, 2, 13, 9, 13, 6, 11, 14, 1, 1, 13, 10, 1, 14, 13, 11, 2, 1, 1, 11, 2, 1, 11, 1, 1, 8, 6, 1, 1, 14, 1, 9, 1, 6, 1, 6, 1, 11, 11, 11, 1, 13, 11, 1, 4, 9, 11, 9, 2, 1, 9, 1, 6, 11, 11, 2, 11, 9, 2, 11, 8, 1, 1, 1, 1, 11, 11, 8, 11, 11, 1, 11, 11, 1, 11, 6, 11, 1, 9, 11, 1, 11, 11, 11, 11, 8, 1, 11, 11, 11, 11, 8, 11, 8, 9, 0, 1, 11, 7, 11]

This is the simplest possible encoding: one integer per residue. How do we get it into PyTorch? We define a dataset, which is just a class that implements __getitem__ and __len__, and pass it to a torch.utils.data.DataLoader. For the unaware, these are “dunder” or “magic” methods, which you can read about here. We can write a simplistic (read: flawed) implementation like so:

import torch

class BiologicalSequenceDataset(torch.utils.data.Dataset):
    def __init__(self, records):
        self.records = records

    def __len__(self):
        return len(self.records)

    def __getitem__(self, i):
        seq = self.records[i].seq
        return torch.tensor([to_ix[residue] for residue in seq])

training_data = torch.utils.data.DataLoader(
    BiologicalSequenceDataset(records),
    batch_size=1,
)

This would actually work, but only because I have set the batch size to 1. With any larger batch size, the default collate function tries to stack sequences of different lengths into one tensor and raises an error about mismatched tensor sizes. To address this, we implement a collate_fn as described here. It takes a batch (a list) of variable-length sequences and pads the shorter ones with the <pad> index so that they all match the longest sequence in the batch.
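To see the failure that batch_size=1 sidesteps, here is a small sketch that hands two sequences of different lengths to a DataLoader using the default collation. The two tensors are made-up encodings, and a plain list of tensors serves as the map-style dataset.

```python
import torch
from torch.utils.data import DataLoader

# Two made-up encoded sequences of different lengths (4 and 3 residues).
toy_dataset = [torch.tensor([12, 7, 13, 3]), torch.tensor([1, 1, 9])]

try:
    next(iter(DataLoader(toy_dataset, batch_size=2)))
    error = None
except RuntimeError as exc:
    # The default collate function stacks the tensors, which requires equal shapes.
    error = exc

print(type(error).__name__, error)
```

Padding every sequence in the batch to a common length first is what makes the stacking possible.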

def collate_fn(batch):
    return torch.nn.utils.rnn.pad_sequence(
        batch,
        batch_first=True,
        padding_value=to_ix["<pad>"]
    )
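Applied to a toy batch, the padding looks like this. The indices are made up, and this sketch assumes <pad> maps to index 0.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

PAD_IX = 0  # assumption: <pad> is index 0


def collate_fn(batch):
    return pad_sequence(batch, batch_first=True, padding_value=PAD_IX)


batch = [torch.tensor([8, 9, 11, 4]), torch.tensor([10, 10]), torch.tensor([12, 1, 3])]
padded = collate_fn(batch)
print(padded.tolist())  # → [[8, 9, 11, 4], [10, 10, 0, 0], [12, 1, 3, 0]]
```

Each row is padded only up to the longest sequence in its own batch, not the longest in the whole file.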

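Then, the DataLoader would be modified like so. Note that batch_size now has to be set explicitly, because it defaults to 1:

training_data = torch.utils.data.DataLoader(
    BiologicalSequenceDataset(records),
    batch_size=32,  # any batch size now works; 32 is an arbitrary choice
    shuffle=True,
    collate_fn=collate_fn,
)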
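If the model is a recurrent network, it can also help for the collate function to return each sequence's true length, so that padded positions are skipped via torch.nn.utils.rnn.pack_padded_sequence. A minimal sketch, assuming <pad> is index 0; collate_with_lengths and the embedding and GRU sizes are hypothetical choices, not part of the pipeline above.

```python
import torch
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence

PAD_IX = 0  # assumption: index reserved for <pad>


def collate_with_lengths(batch):
    """Pad a batch and also return each sequence's true (unpadded) length."""
    lengths = torch.tensor([len(seq) for seq in batch])
    padded = pad_sequence(batch, batch_first=True, padding_value=PAD_IX)
    return padded, lengths


batch = [torch.tensor([8, 9, 11, 4]), torch.tensor([10, 10]), torch.tensor([12, 1, 3])]
padded, lengths = collate_with_lengths(batch)

# Hypothetical model pieces: an embedding followed by a single-layer GRU.
embed = torch.nn.Embedding(num_embeddings=16, embedding_dim=8, padding_idx=PAD_IX)
gru = torch.nn.GRU(input_size=8, hidden_size=5, batch_first=True)

# enforce_sorted=False lets the batch stay in its original (unsorted) order.
packed = pack_padded_sequence(embed(padded), lengths, batch_first=True, enforce_sorted=False)
_, h_n = gru(packed)  # h_n: (num_layers, batch, hidden_size)
print(lengths.tolist(), tuple(h_n.shape))  # → [4, 2, 3] (1, 3, 5)
```

A DataLoader built with collate_fn=collate_with_lengths would then yield (padded, lengths) pairs instead of bare tensors.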

The complete code

With one more adjustment, a random split of roughly 75/25 into training and test sets, the complete code becomes:

import torch
from Bio import SeqIO

with open("foo.fasta") as f:
    records = list(SeqIO.parse(f, "fasta"))

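residues = set()
for record in records:
    residues.update(str(record.seq))
# Sort so the mapping is the same on every run; reserve index 0 for padding
vocab = ["<pad>"] + sorted(residues)
to_ix = {char: i for i, char in enumerate(vocab)}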


class BiologicalSequenceDataset(torch.utils.data.Dataset):
    def __init__(self, records):
        self.records = records

    def __len__(self):
        return len(self.records)

    def __getitem__(self, i):
        seq = self.records[i].seq
        return torch.tensor([to_ix[residue] for residue in seq])


def collate_fn(batch):
    return torch.nn.utils.rnn.pad_sequence(
        batch,
        batch_first=True,
        padding_value=to_ix["<pad>"]
    )

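n_examples = len(records)
n_test = n_examples // 4  # hold out roughly a quarter for testing
ds_train, ds_test = torch.utils.data.random_split(
    BiologicalSequenceDataset(records),
    lengths=[n_examples - n_test, n_test],
    generator=torch.Generator().manual_seed(42),  # reproducible split
)
dl = {
    "train": torch.utils.data.DataLoader(
        ds_train, batch_size=32, shuffle=True, collate_fn=collate_fn
    ),
    "test": torch.utils.data.DataLoader(ds_test, batch_size=32, collate_fn=collate_fn),
}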
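As a usage sketch, here is one way a padded batch drawn from dl["train"] might be consumed: a torch.nn.Embedding with padding_idx, followed by a mean over the real (non-pad) residues to get one fixed-size vector per protein. The hard-coded batch, the embedding sizes, and the assumption that <pad> is index 0 are all illustrative.

```python
import torch

PAD_IX = 0  # assumption: <pad> is index 0

# Stand-in for one batch from dl["train"]: three padded sequences.
batch = torch.tensor([[8, 9, 11, 4], [10, 10, 0, 0], [12, 1, 3, 0]])

embed = torch.nn.Embedding(num_embeddings=16, embedding_dim=8, padding_idx=PAD_IX)

mask = (batch != PAD_IX).unsqueeze(-1).float()  # (batch, length, 1): 1 for real residues
vectors = embed(batch) * mask                   # zero out padded positions
pooled = vectors.sum(dim=1) / mask.sum(dim=1)   # mean over real residues only
print(pooled.shape)  # → torch.Size([3, 8])
```

Dividing by the number of real residues, rather than the padded length, keeps a short protein's vector from being diluted by padding.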