Source code for gt4sd.domains.materials.protein_encoding

#
# MIT License
#
# Copyright (c) 2022 GT4SD team
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
"""Data processing utilities."""

import inspect
from typing import Dict, Iterator, List, Optional, Sequence, Tuple, Union, cast

import numpy as np
import torch
from tape.datasets import pad_sequences
from tape.registry import registry
from tape.tokenizers import TAPETokenizer
from torch import nn


class PrimarySequenceEncoder(nn.Module):
    """Model-like class to create tape embeddings/encodings.

    This follows tape's implementation via `run_embed` closely, but removes any
    seed/device/cuda handling (of model and batch). This can be done in the
    training loop like for any other nn.Module.

    Example:
        An example use with a protein sequence dataset from `pytoda` (requires
        mock/rdkit and pytoda>0.2), passing ids with the primary sequences::

            import sys
            from mock import Mock
            sys.modules['rdkit'] = Mock()
            sys.modules['rdkit.Chem'] = Mock()
            from torch.utils.data import DataLoader
            from pytoda.datasets.protein_sequence_dataset import protein_sequence_dataset
            from pytoda.datasets.tests.test_protein_sequence_dataset import (
                FASTA_CONTENT_GENERIC, TestFileContent
            )
            from pytoda.datasets.utils import keyed

            with TestFileContent(FASTA_CONTENT_GENERIC) as a_test_file:
                sequence_dataset = keyed(protein_sequence_dataset(
                    a_test_file.filename, filetype='.fasta', backend='lazy'
                ))
                batch_size = 5
                dataloader = DataLoader(sequence_dataset, batch_size=batch_size)

                encoder = PrimarySequenceEncoder(
                    model_type='transformer',
                    from_pretrained='bert-base',
                    tokenizer='iupac',
                )
                # sending encoder to cuda device should work, not tested

                loaded = next(iter(dataloader))
                print(loaded)
                encoded, ids = encoder.forward(loaded)
                print(ids)
                print(encoded)

        However, the forward call also supports not passing ids; the batch then
        still has to be wrapped as a list (of length 1)::

            encoded, dummy_ids = PrimarySequenceEncoder().forward(
                [
                    ['MQNP', 'LLLLL'],  # type: Sequence[str]
                    # sequence_ids may be missing here
                ]
            )
    """
    def __init__(
        self,
        model_type: str = "transformer",
        from_pretrained: Optional[str] = "bert-base",
        model_config_file: Optional[str] = None,
        # full_sequence_embed: bool = False,
        tokenizer: str = "iupac",
    ):
        """Initialize the PrimarySequenceEncoder.

        Args:
            model_type: which type of model to create (e.g. transformer, unirep, ...).
                Defaults to 'transformer'.
            from_pretrained: either a string with the `shortcut name` of a pre-trained
                model to load from cache or download, e.g.: ``bert-base-uncased``, or
                a path to a `directory` containing model weights saved using
                :func:`tape.models.modeling_utils.ProteinConfig.save_pretrained`,
                e.g.: ``./my_model_directory/``. Defaults to 'bert-base'.
            model_config_file: a json config file that specifies hyperparameters.
                Defaults to None.
            tokenizer: vocabulary name. Defaults to 'iupac'.

        Note:
            tape's default seed would be 42 (see `tape.utils.set_random_seeds`).
        """
        super().__init__()
        # padding during forward goes through cpu (numpy)
        self.device_indicator = nn.Parameter(torch.empty(0), requires_grad=False)
        # dummy sequence_ids, so they are optional
        self.next_dummy_id = 0

        task_spec = registry.get_task_spec("embed")  # task = 'embed'
        # from tape.datasets import EmbedDataset
        self.model = registry.get_task_model(
            model_type, task_spec.name, model_config_file, from_pretrained
        )
        # to filter out batch items that aren't used in this model,
        # see `from_collated_batch` and `tape.training.ForwardRunner`
        forward_arg_keys = inspect.getfullargspec(self.model.forward).args
        self._forward_arg_keys = forward_arg_keys[1:]  # remove self argument
        assert "input_ids" in self._forward_arg_keys

        self.tokenizer = TAPETokenizer(vocab=tokenizer)
        self.full_sequence_embed = False

        self.eval()
    def train(self, mode: bool):  # type:ignore
        """Avoid any setting to train mode."""
        return super().train(False)
    def generate_tokenized(
        self, batch: List[Sequence[str]]
    ) -> Iterator[Tuple[str, np.ndarray, np.ndarray]]:
        # batch is a list of length 2 (typically tuples of str of length `batch_size`)
        for item, sequence_id in zip(*batch):
            token_ids = self.tokenizer.encode(item)
            input_mask: np.ndarray = np.ones_like(token_ids)
            yield sequence_id, token_ids, input_mask
    @classmethod
    def collate_fn(
        cls, batch: List[Tuple[str, np.ndarray, np.ndarray]]
    ) -> Dict[str, Union[List[str], torch.Tensor]]:
        # adapted from tape.datasets.EmbedDataset, where it is not a classmethod
        ids, tokens, input_mask = zip(*batch)
        ids_list: List[str] = list(ids)
        tokens_tensor: torch.Tensor = torch.from_numpy(pad_sequences(tokens))
        input_mask_tensor: torch.Tensor = torch.from_numpy(pad_sequences(input_mask))
        # on cpu now, which is unavoidable as the tokenizer and mask are in numpy
        return {
            "ids": ids_list,
            "input_ids": tokens_tensor,
            "input_mask": input_mask_tensor,
        }  # type: ignore
    def from_collated_batch(
        self, batch: Dict[str, Union[List[str], torch.Tensor]]
    ) -> Dict[str, torch.Tensor]:
        # filter arguments
        batch_tensors: Dict[str, torch.Tensor] = {
            name: tensor  # type:ignore
            for name, tensor in batch.items()
            if name in self._forward_arg_keys
        }
        device = self.device_indicator.device
        if device.type == "cuda":
            batch_tensors = {
                name: tensor.cuda(device=device, non_blocking=True)
                for name, tensor in batch_tensors.items()
            }
        return batch_tensors
    def forward(  # type:ignore
        self, batch: List[Sequence[str]]
    ) -> Tuple[torch.Tensor, List[str]]:
        # batch: List[(primary_sequences,), (sequence_ids,)] of length 2
        # keys can be passed on by pytoda via keyed(ds: Keydataset[str])
        if len(batch) == 1:
            # no sequence_ids passed
            dummy_ids = self.get_dummy_ids(length=len(batch[0]))
            batch.append(dummy_ids)
        elif len(batch) == 2:
            pass
        else:
            raise ValueError(
                "batch should be of length 1 or 2, containing `primary_sequences` "
                "and optionally `sequence_ids`."
            )

        with torch.no_grad():
            # Iterator[(sequence_id, token_ids, input_mask)]
            batch_loader_like = self.generate_tokenized(batch)
            batch_dict_with_ids: Dict[
                str, Union[List[str], torch.Tensor]
            ] = self.collate_fn(list(batch_loader_like))
            ids: List[str] = cast(List[str], batch_dict_with_ids["ids"])
            batch_dict = self.from_collated_batch(batch_dict_with_ids)

            # outputs = self.model(**batch_dict)
            # pooled_embed = outputs[1]
            sequence_embed = self.model(**batch_dict)[0]
            sequence_lengths = batch_dict["input_mask"].sum(1)

            # can variable length slicing be done on the batch?
            if not self.full_sequence_embed:
                sequences_out: torch.Tensor = sequence_embed.new_empty(
                    # dimension of sequence length will be averaged out
                    size=sequence_embed.shape[::2]
                )
            else:
                raise NotImplementedError

            for i, (seqembed, length) in enumerate(
                zip(sequence_embed, sequence_lengths)
            ):
                seqembed = seqembed[: int(length)]
                if not self.full_sequence_embed:
                    seqembed = seqembed.mean(0)
                sequences_out[i, ...] = seqembed

        return sequences_out, ids
    def get_dummy_ids(self, length: int) -> Tuple[str, ...]:
        first = self.next_dummy_id
        self.next_dummy_id += length  # before last
        return tuple(map(str, range(first, self.next_dummy_id)))
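
# A minimal usage sketch (not part of the original module): it mirrors the second
# docstring example above, where no sequence_ids are passed, and assumes the tape
# 'bert-base' weights are cached locally or can be downloaded.
if __name__ == "__main__":
    encoder = PrimarySequenceEncoder(
        model_type="transformer",
        from_pretrained="bert-base",
        tokenizer="iupac",
    )
    # the batch is wrapped as a list of length 1 (primary sequences only),
    # so `forward` generates dummy string ids via `get_dummy_ids`
    embeddings, dummy_ids = encoder.forward([["MQNP", "LLLLL"]])
    # one mean-pooled embedding per sequence: shape (batch_size, hidden_size)
    print(embeddings.shape, dummy_ids)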