#
# MIT License
#
# Copyright (c) 2022 GT4SD team
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
"""Implementation of PaccMann^RL conditional generators."""
import json
import logging
import os
from abc import ABC, abstractmethod
from typing import Any, List, Optional, Tuple, Union
import numpy as np
import pandas as pd
import torch
from paccmann_chemistry.models import StackGRUDecoder, StackGRUEncoder, TeacherVAE
from paccmann_chemistry.utils.search import SamplingSearch
from paccmann_omics.encoders import ENCODER_FACTORY
from pytoda.smiles.smiles_language import SMILESLanguage
from rdkit import Chem
from ....domains.materials import MoleculeFormat, validate_molecules
from ....domains.materials.protein_encoding import PrimarySequenceEncoder
from ....frameworks.torch import device_claim
from ....frameworks.torch.vae import reparameterize
# Module-level logger; a NullHandler prevents "no handler found" warnings
# when the library is used without logging configured by the application.
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())
class ConditionalGenerator(ABC):
    """Abstract interface for a conditional generator."""

    #: device where the inference is running.
    device: torch.device
    #: temperature for the sampling.
    temperature: float
    #: maximum length of the generated molecules.
    generated_length: int
    #: parameters for the SELFIES generator.
    selfies_conditional_generator_params: dict
    #: SELFIES generator.
    selfies_conditional_generator: TeacherVAE
    #: SMILES language instance.
    smiles_language: SMILESLanguage
    #: latent size expected by the generator decoder.
    generator_latent_size: int
    #: latent size produced by the condition encoder.
    encoder_latent_size: int

    def get_smiles_from_latent(self, latent: torch.Tensor) -> List[str]:
        """Take samples from the latent space.

        Args:
            latent: latent vector tensor.

        Returns:
            list of generated SMILES strings (converted from SELFIES when the
            language is SELFIES-based). No validity filtering is applied here;
            use :meth:`validate_molecules` for that.
        """
        # When the decoder latent is twice the encoder latent, duplicate the
        # latent along the feature dimension — presumably the decoder was
        # trained on a concatenation of two encodings (TODO confirm upstream).
        if self.generator_latent_size == 2 * self.encoder_latent_size:
            latent = latent.repeat(1, 1, 2)
        # generate molecules as token-index sequences
        generated_molecules = self.selfies_conditional_generator.generate(
            latent,
            prime_input=torch.tensor(
                [self.smiles_language.start_index], device=self.device
            ).long(),
            end_token=torch.tensor(
                [self.smiles_language.stop_index], device=self.device
            ).long(),
            generate_len=self.generated_length,
            search=SamplingSearch(temperature=self.temperature),
        )
        # decode token indexes into strings
        molecules = [
            self.smiles_language.token_indexes_to_smiles(generated_molecule.tolist())
            for generated_molecule in generated_molecules
        ]
        # convert SELFIES to SMILES when the language is SELFIES-based
        if "selfies" in self.smiles_language.name:
            molecules = [
                self.smiles_language.selfies_to_smiles(a_selfies)
                for a_selfies in molecules
            ]
        return molecules

    @staticmethod
    def validate_molecules(smiles) -> Tuple[List[Chem.rdchem.Mol], List[int]]:
        """Validate generated SMILES strings.

        Args:
            smiles: list of SMILES strings.

        Returns:
            a tuple of the valid RDKit molecules and the indexes of the
            valid entries in the input list.
        """
        return validate_molecules(pattern_list=smiles, input_type=MoleculeFormat.smiles)

    @abstractmethod
    def get_latent(self, condition: Any) -> torch.Tensor:
        """Encode a condition into a latent representation.

        Args:
            condition: the context used to condition the generation.

        Returns:
            latent representation tensor.
        """
        pass

    def generate_batch(self, condition: Any) -> List[str]:
        """Generate a batch of molecules for the given condition.

        Args:
            condition: the context used to condition the generation.

        Returns:
            list of generated SMILES strings.
        """
        logger.info("embedding condition and getting reparametrized latent samples")
        latent = self.get_latent(condition)
        logger.info("starting generation of molecules")
        # generate the molecules
        return self.get_smiles_from_latent(latent)
class ProteinSequenceConditionalGenerator(ConditionalGenerator):
    """
    Protein conditional generator as implemented in https://doi.org/10.1088/2632-2153/abe808
    (originally https://arxiv.org/abs/2005.13285).

    It generates highly binding and low toxic ligands.

    Attributes:
        samples_per_protein: number of points sampled per protein.
            It has to be greater than 1.
        protein_embedding_encoder_params: parameters for the protein embedding encoder.
        protein_embedding_encoder: protein embedding encoder.
    """

    def __init__(
        self,
        resources_path: str,
        temperature: float = 1.4,
        generated_length: int = 100,
        samples_per_protein: int = 100,
        device: Optional[Union[torch.device, str]] = None,
    ) -> None:
        """
        Initialize the generator.

        Args:
            resources_path: directory where to find models and parameters.
            temperature: temperature for the sampling. Defaults to 1.4.
            generated_length: maximum length of the generated molecules.
                Defaults to 100.
            samples_per_protein: number of points sampled per protein.
                It has to be greater than 1. Defaults to 100.
            device: device where the inference
                is running either as a dedicated class or a string. If not provided is inferred.
        """
        # device
        self.device = device_claim(device)
        # setting sampling parameters
        self.temperature = temperature
        self.generated_length = generated_length
        self.samples_per_protein = samples_per_protein
        # instantiate protein embedding encoder and load its weights
        with open(os.path.join(resources_path, "protein_embedding_params.json")) as fp:
            self.protein_embedding_encoder_params = json.load(fp)
        self.protein_embedding_encoder = ENCODER_FACTORY["dense"](
            self.protein_embedding_encoder_params
        ).to(self.device)
        self.protein_embedding_encoder.load(
            os.path.join(resources_path, "protein_embedding_encoder.pt"),
            map_location=self.device,
        )
        self.protein_embedding_encoder.eval()
        self.encoder_latent_size = self.protein_embedding_encoder.latent_size
        # instantiate SELFIES conditional generator (VAE) and load its weights
        with open(
            os.path.join(resources_path, "selfies_conditional_generator.json")
        ) as fp:
            self.selfies_conditional_generator_params = json.load(fp)
        self.selfies_conditional_generator = TeacherVAE(
            StackGRUEncoder(self.selfies_conditional_generator_params),
            StackGRUDecoder(self.selfies_conditional_generator_params),
        ).to(self.device)
        self.selfies_conditional_generator.load(
            os.path.join(resources_path, "selfies_conditional_generator.pt"),
            map_location=self.device,
        )
        self.selfies_conditional_generator.eval()
        self.generator_latent_size = (
            self.selfies_conditional_generator.decoder.latent_dim
        )
        # loading SMILES language for decoding and conversion of SELFIES to SMILES
        self.smiles_language = SMILESLanguage.load(
            os.path.join(resources_path, "selfies_language.pkl")
        )
        # protein embedding from primary sequence (via tape)
        self.primary_sequence_embedder = PrimarySequenceEncoder(
            model_type="transformer",
            from_pretrained="bert-base",
            model_config_file=None,
            tokenizer="iupac",
        ).to(self.device)

    def get_latent(self, protein: str) -> torch.Tensor:
        """
        Given a protein generate the latent representation.

        Args:
            protein: the protein used as context/condition.

        Returns:
            the latent representation for the given context. It contains
            self.samples_per_protein repeats.
        """
        # encode embedded sequence once, ignore the returned dummy ids
        embeddings, _ = self.primary_sequence_embedder.forward([[protein]])
        protein_mu, protein_logvar = self.protein_embedding_encoder(
            embeddings.to(self.device)
        )
        # stack the single encoding as a batch so each row yields a distinct
        # sample after reparameterization
        proteins_mu = torch.cat([protein_mu] * self.samples_per_protein, dim=0)
        proteins_logvar = torch.cat([protein_logvar] * self.samples_per_protein, dim=0)
        # get latent representation (leading dim added for the decoder)
        return torch.unsqueeze(reparameterize(proteins_mu, proteins_logvar), 0)

    def generate_batch(self, protein: str) -> List[str]:
        """Generate a batch of ligands conditioned on a protein sequence.

        Args:
            protein: the protein used as context/condition.

        Returns:
            list of generated SMILES strings.
        """
        return super().generate_batch(condition=protein)
class TranscriptomicConditionalGenerator(ConditionalGenerator):
    """
    Transcriptomic conditional generator as implemented in https://doi.org/10.1016/j.isci.2021.102269
    (originally https://doi.org/10.1007/978-3-030-45257-5_18, https://arxiv.org/abs/1909.05114).

    It generates highly effective small molecules against transcriptomic profiles.

    Attributes:
        samples_per_profile: number of points sampled per profile.
            It has to be greater than 1.
        transcriptomic_encoder_params: parameters for the transcriptomic encoder.
        transcriptomic_encoder: transcriptomic encoder.
    """

    def __init__(
        self,
        resources_path: str,
        temperature: float = 1.4,
        generated_length: int = 100,
        samples_per_profile: int = 100,
        device: Optional[Union[torch.device, str]] = None,
    ) -> None:
        """
        Initialize the generator.

        Args:
            resources_path: directory where to find models and parameters.
            temperature: temperature for the sampling. Defaults to 1.4.
            generated_length: maximum length of the generated molecules.
                Defaults to 100.
            samples_per_profile: number of points sampled per profile.
                It has to be greater than 1. Defaults to 100.
            device: device where the inference
                is running either as a dedicated class or a string. If not provided is inferred.
        """
        # device
        self.device = device_claim(device)
        # setting sampling parameters
        self.temperature = temperature
        self.generated_length = generated_length
        self.samples_per_profile = samples_per_profile
        # gene panel defining the expected profile ordering; skip blank lines
        # (a bare `if gene` would never filter, since file lines keep "\n")
        with open(os.path.join(resources_path, "genes.txt")) as fp:
            self.genes = [gene.strip() for gene in fp if gene.strip()]
        # instantiate transcriptomic encoder and load its weights
        with open(os.path.join(resources_path, "transcriptomic_params.json")) as fp:
            self.transcriptomic_encoder_params = json.load(fp)
        self.transcriptomic_encoder = ENCODER_FACTORY["dense"](
            self.transcriptomic_encoder_params
        ).to(self.device)
        self.transcriptomic_encoder.load(
            os.path.join(resources_path, "transcriptomic_encoder.pt"),
            map_location=self.device,
        )
        self.transcriptomic_encoder.eval()
        self.encoder_latent_size = self.transcriptomic_encoder.latent_size
        # instantiate SELFIES conditional generator (VAE) and load its weights
        with open(
            os.path.join(resources_path, "selfies_conditional_generator.json")
        ) as fp:
            self.selfies_conditional_generator_params = json.load(fp)
        self.selfies_conditional_generator = TeacherVAE(
            StackGRUEncoder(self.selfies_conditional_generator_params),
            StackGRUDecoder(self.selfies_conditional_generator_params),
        ).to(self.device)
        self.selfies_conditional_generator.load(
            os.path.join(resources_path, "selfies_conditional_generator.pt"),
            map_location=self.device,
        )
        self.selfies_conditional_generator.eval()
        self.generator_latent_size = (
            self.selfies_conditional_generator.decoder.latent_dim
        )
        # loading SMILES language for decoding and conversion of SELFIES to SMILES
        self.smiles_language = SMILESLanguage.load(
            os.path.join(resources_path, "selfies_language.pkl")
        )

    def get_latent(self, profile: Union[np.ndarray, pd.Series, str]) -> torch.Tensor:
        """
        Given a profile generate the latent representation.

        Args:
            profile: the profile used as context/condition. A pd.Series is
                reindexed by the gene panel; a string is parsed as a JSON list.

        Raises:
            ValueError: in case the profile has a size mismatch with the genes panel.

        Returns:
            the latent representation for the given context. It contains
            self.samples_per_profile repeats.
        """
        if isinstance(profile, pd.Series):
            # make sure genes are sorted
            profile = profile[self.genes].values
        elif isinstance(profile, str):
            logger.warning("profile passed as string, deserializing it to an array")
            profile = np.array(json.loads(profile))
        if profile.size != len(self.genes):
            raise ValueError(
                f"provided profile size ({profile.size}) does not match required size {len(self.genes)}"
            )
        # encode the profile, stacked as a batch so each row yields a distinct
        # sample after reparameterization
        transcriptomic_mu, transcriptomic_logvar = self.transcriptomic_encoder(
            torch.from_numpy(
                np.vstack([profile] * self.samples_per_profile),
            )
            .float()
            .to(self.device)
        )
        # get latent representation (leading dim added for the decoder)
        return torch.unsqueeze(
            reparameterize(transcriptomic_mu, transcriptomic_logvar), 0
        )

    def generate_batch(self, profile: Union[np.ndarray, pd.Series]) -> List[str]:
        """Generate a batch of molecules conditioned on a transcriptomic profile.

        Args:
            profile: the profile used as context/condition.

        Returns:
            list of generated SMILES strings.
        """
        return super().generate_batch(condition=profile)