Loading {Gen,Prote}omic Sequences for Deep Learning: FASTA to PyTorch
Most deep learning tools for sequence modeling target Natural Language Processing. I mean, I love NLP, but I’m doing genomics! Why doesn’t Google or Facebook make a utility library for my highly specific use-case?
Anyways, I found myself in need of ingesting protein sequences into PyTorch. There are many ways to do this; in fact, finding the "best" computational representation (e.g., GloVe, k-mer frequencies, etc.) of a protein polymer was the goal of my research project. The goal of this post is to demonstrate a simple but complete pipeline for ingesting biological sequences from FASTA files.
The absolute first step is to load the sequence data. For convenience, I load it all into memory. (It is possible to stream these sequences from disk as needed using an iterable-style dataset, but that gives up the random access we will want later for shuffling and splitting!)
from Bio import SeqIO

with open("foo.fasta") as f:
    records = list(SeqIO.parse(f, "fasta"))
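Each element of records is a Biopython SeqRecord, so a quick sanity check looks like this (a sketch; the IDs, counts, and sequences depend on whatever is in your foo.fasta):

print(len(records))          # number of sequences in the file
print(records[0].id)         # FASTA header ID of the first record
print(str(records[0].seq))   # the raw residue string itself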
The next step is to run through the sequence data so we know which residues are present; this lets us assign each one a unique integer. We also add a special <pad> token, which we will need later when padding batches.
vocab = set()
for record in records:
    vocab.update(str(record.seq))
vocab.add("<pad>")

to_ix = {char: i for i, char in enumerate(vocab)}
For example, a random protein from the Grapevine Fleck Virus is encoded thusly:
>>> [to_ix[residue] for residue in "MNRGPPLRSRPPSSPPPASAFPGPSPFPSPSPANSLPSASPPPPTCTPSSPVSRPFASARLRTSHPPRCPHRSAPPSAPSPPFTPPHPLPTPTPSSSPRSPWLSLAPLPTSSASLASFPPPPSSFSSPSSPSTSPLSPSSSSFPSSSSFSFLVPSNS"]
[12, 7, 13, 3, 1, 1, 9, 13, 11, 13, 1, 1, 11, 11, 1, 1, 1, 2, 11, 2, 8, 1, 3, 1, 11, 1, 8, 1, 11, 1, 11, 1, 2, 7, 11, 9, 1, 11, 2, 11, 1, 1, 1, 1, 6, 10, 6, 1, 11, 11, 1, 0, 11, 13, 1, 8, 2, 11, 2, 13, 9, 13, 6, 11, 14, 1, 1, 13, 10, 1, 14, 13, 11, 2, 1, 1, 11, 2, 1, 11, 1, 1, 8, 6, 1, 1, 14, 1, 9, 1, 6, 1, 6, 1, 11, 11, 11, 1, 13, 11, 1, 4, 9, 11, 9, 2, 1, 9, 1, 6, 11, 11, 2, 11, 9, 2, 11, 8, 1, 1, 1, 1, 11, 11, 8, 11, 11, 1, 11, 11, 1, 11, 6, 11, 1, 9, 11, 1, 11, 11, 11, 11, 8, 1, 11, 11, 11, 11, 8, 11, 8, 9, 0, 1, 11, 7, 11]
This is the simplest possible encoding. How do we get it into PyTorch? By defining a dataset, which is just a class that implements __getitem__ and __len__, and passing it to a torch.utils.data.DataLoader instance. For the unaware, these are "dunder" or "magic" methods, which you can read about here. We can write a simplistic (read: flawed) implementation like so:
import torch

class BiologicalSequenceDataset:
    """A map-style dataset: any class with __getitem__ and __len__ will do."""

    def __init__(self, records):
        self.records = records

    def __len__(self):
        return len(self.records)

    def __getitem__(self, i):
        # Convert the i-th sequence into a 1-D tensor of residue indices.
        seq = self.records[i].seq
        return torch.tensor([to_ix[residue] for residue in seq])

training_data = torch.utils.data.DataLoader(
    BiologicalSequenceDataset(records),
    batch_size=1,
)
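As a quick check (a sketch; the exact length depends on your foo.fasta), each batch that comes out is a (1, sequence_length) tensor of residue indices:

for batch in training_data:
    print(batch.shape)  # torch.Size([1, <length of the first sequence>])
    break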
This would work, actually, but only because I have specified a batch size of 1. Any larger and we would get an error about tensors of different sizes. To address this issue, we implement a collate_fn as described here. It takes a batch of arbitrary-length sequences and pads the shorter ones with the <pad> index, so that they all end up the same length.
def collate_fn(batch):
    # Pad every sequence in the batch to the length of the longest one.
    return torch.nn.utils.rnn.pad_sequence(
        batch,
        batch_first=True,
        padding_value=to_ix["<pad>"],
    )
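To see what pad_sequence actually does, here is a tiny standalone example with toy tensors (not from the FASTA file); the shorter tensor is filled out to the length of the longest one:

a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5])
print(torch.nn.utils.rnn.pad_sequence([a, b], batch_first=True, padding_value=0))
# tensor([[1, 2, 3],
#         [4, 5, 0]])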
Then, the dataloader would be modified like so:
training_data = torch.utils.data.DataLoader(
    BiologicalSequenceDataset(records),
    collate_fn=collate_fn,
)
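With padding in place, batch sizes larger than 1 work as well. A sketch, using a hypothetical batch size of 4:

training_data = torch.utils.data.DataLoader(
    BiologicalSequenceDataset(records),
    batch_size=4,
    collate_fn=collate_fn,
)

for batch in training_data:
    print(batch.shape)  # (4, length of the longest sequence in this batch)
    break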
The complete code
If we make one more adjustment to partition the data into training and test sets, we have:
import torch
from Bio import SeqIO

# Load every record from the FASTA file into memory.
with open("foo.fasta") as f:
    records = list(SeqIO.parse(f, "fasta"))

# Build the vocabulary of residues, plus a padding token,
# and map each symbol to a unique integer.
vocab = set()
for record in records:
    vocab.update(str(record.seq))
vocab.add("<pad>")

to_ix = {char: i for i, char in enumerate(vocab)}

class BiologicalSequenceDataset:
    def __init__(self, records):
        self.records = records

    def __len__(self):
        return len(self.records)

    def __getitem__(self, i):
        # Convert the i-th sequence into a 1-D tensor of residue indices.
        seq = self.records[i].seq
        return torch.tensor([to_ix[residue] for residue in seq])

def collate_fn(batch):
    # Pad every sequence in the batch to the length of the longest one.
    return torch.nn.utils.rnn.pad_sequence(
        batch,
        batch_first=True,
        padding_value=to_ix["<pad>"],
    )

# Hold out a quarter of the sequences for testing.
n_examples = len(records)
ds_train, ds_test = torch.utils.data.random_split(
    BiologicalSequenceDataset(records),
    lengths=[n_examples - (n_examples // 4), n_examples // 4],
)

dl = {
    "train": torch.utils.data.DataLoader(ds_train, collate_fn=collate_fn),
    "test": torch.utils.data.DataLoader(ds_test, collate_fn=collate_fn),
}
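To close the loop, here is one way these batches might be consumed downstream (a sketch, not part of the pipeline above): an nn.Embedding keyed on the same vocabulary, with padding_idx set so the <pad> token maps to a zero vector that is never updated during training.

embedding = torch.nn.Embedding(
    num_embeddings=len(vocab),
    embedding_dim=16,              # hypothetical embedding size
    padding_idx=to_ix["<pad>"],
)

for batch in dl["train"]:
    embedded = embedding(batch)    # shape: (batch_size, seq_len, 16)
    # ...feed `embedded` into whatever model you are training...
    break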
2020-05-28 19:00