Source code for gt4sd.properties.molecules.core

#
# MIT License
#
# Copyright (c) 2022 GT4SD team
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
import os
from enum import Enum
from typing import List, Optional, Union

import importlib_resources
import numpy as np
import pandas as pd
import torch
import yaml
from gt4sd_molformer.finetune.finetune_pubchem_light import (
    LightningModule as RegressionLightningModule,
)
from gt4sd_molformer.finetune.finetune_pubchem_light import (
    PropertyPredictionDataModule as RegressionDataModule,
)
from gt4sd_molformer.finetune.finetune_pubchem_light import (
    PropertyPredictionDataset as RegressionDataset,
)
from gt4sd_molformer.finetune.finetune_pubchem_light_classification import (
    LightningModule as ClassificationLightningModule,
)
from gt4sd_molformer.finetune.finetune_pubchem_light_classification import (
    PropertyPredictionDataModule as ClassificationDataModule,
)
from gt4sd_molformer.finetune.finetune_pubchem_light_classification import (
    PropertyPredictionDataset as ClassificationDataset,
)
from gt4sd_molformer.finetune.finetune_pubchem_light_classification_multitask import (
    MultitaskEmbeddingDataset,
    MultitaskModel,
    PropertyPredictionDataModule,
)
from gt4sd_molformer.finetune.ft_tokenizer.ft_tokenizer import MolTranBertTokenizer
from paccmann_generator.drug_evaluators import SIDER
from paccmann_generator.drug_evaluators import ClinTox as _ClinTox
from paccmann_generator.drug_evaluators import OrganDB as _OrganTox
from paccmann_generator.drug_evaluators import SCScore
from paccmann_generator.drug_evaluators import Tox21 as _Tox21
from pydantic import ConfigDict, Field
from tdc import Oracle
from tdc.metadata import download_receptor_oracle_name

from ...algorithms.core import (
    ConfigurablePropertyAlgorithmConfiguration,
    Predictor,
    PredictorAlgorithm,
)
from ...domains.materials import SmallMolecule
from ...frameworks.torch import device_claim
from ..core import (
    ApiTokenParameters,
    CallablePropertyPredictor,
    ConfigurableCallablePropertyPredictor,
    DomainSubmodule,
    IpAdressParameters,
    PropertyPredictorParameters,
    PropertyValue,
    S3Parameters,
)
from ..utils import (
    docking_import_check,
    get_activity_fn,
    get_similarity_fn,
    to_smiles,
    validate_api_token,
    validate_ip,
)
from .functions import (
    bertz,
    esol,
    is_scaffold,
    lipinski,
    logp,
    molecular_weight,
    number_of_aromatic_rings,
    number_of_atoms,
    number_of_h_acceptors,
    number_of_h_donors,
    number_of_heterocycles,
    number_of_large_rings,
    number_of_rings,
    number_of_rotatable_bonds,
    number_of_stereocenters,
    plogp,
    qed,
    sas,
    tpsa,
)


# NOTE: property prediction parameters
# Settings for the SCScore evaluator; `Scscore` unpacks every field of this
# model as a keyword argument of `SCScore`.
class ScscoreConfiguration(PropertyPredictorParameters):
    score_scale: int = Field(default=5)
    # presumably fingerprint length and radius -- TODO confirm against SCScore.
    fp_len: int = Field(default=1024)
    fp_rad: int = Field(default=2)
# Parameters of the `SimilaritySeed` predictor.
class SimilaritySeedParameters(PropertyPredictorParameters):
    # Required: the seed molecule, given as a SMILES string.
    smiles: str = Field(..., examples=["c1ccccc1"])
    # Fingerprint key forwarded to `get_similarity_fn`.
    fp_key: str = Field(default="ECFP4")
# Parameters of the `ActivityAgainstTarget` predictor.
class ActivityAgainstTargetParameters(PropertyPredictorParameters):
    # Required: forwarded as `target` to `get_activity_fn`.
    target: str = Field(
        ...,
        description="name of the target.",
        examples=["drd2"],
    )
# Parameters of the `Askcos` predictor, on top of the host address fields
# inherited from `IpAdressParameters`.
class AskcosParameters(IpAdressParameters):
    # Selectable return values. The member name `plausability` is misspelled
    # but is part of the public interface, so it is kept; its value is spelled
    # correctly.
    class Output(str, Enum):
        plausability: str = "plausibility"
        num_step: str = "num_step"
        synthesizability: str = "synthesizability"
        price: str = "price"

    output: Output = Field(
        default=Output.plausability,
        description="Main output return type from ASKCOS",
        examples=[Output.synthesizability],
    )

    # NOTE(review): the fields below look like ASKCOS tree-builder options
    # passed through to the oracle -- confirm against the TDC documentation.
    save_json: bool = False
    file_name: str = "tree_builder_result.json"
    num_trials: int = 5
    max_depth: int = 9
    max_branching: int = 25
    expansion_time: int = 60
    max_ppg: int = 100
    template_count: int = 1000
    max_cum_prob: float = 0.999
    chemical_property_logic: str = "none"
    max_chemprop_c: int = 0
    max_chemprop_n: int = 0
    max_chemprop_o: int = 0
    max_chemprop_h: int = 0
    chemical_popularity_logic: str = "none"
    min_chempop_reactants: int = 5
    min_chempop_products: int = 5
    filter_threshold: float = 0.1
    # Deliberately a string, not a bool.
    return_first: str = "true"

    # Store enum fields by their value rather than as enum members.
    model_config = ConfigDict(use_enum_values=True)
# Parameters of the `MoleculeOne` predictor; the API token field comes from
# `ApiTokenParameters`.
class MoleculeOneParameters(ApiTokenParameters):
    # Name handed to the TDC `Oracle`.
    oracle_name: str = Field(default="Molecule One Synthesis")
# Parameters for docking against a receptor that is defined via TDC.
class DockingTdcParameters(PropertyPredictorParameters):
    # Required: used as the TDC `Oracle` name by `DockingTdc`.
    target: str = Field(
        ...,
        description="Target for docking, provided via TDC",
        examples=download_receptor_oracle_name,
    )
# Parameters for docking against a user-provided receptor.
class DockingParameters(PropertyPredictorParameters):
    # Name handed to the TDC `Oracle`.
    name: str = Field(default="pyscreener")
    # Required (no default).
    receptor_pdb_file: str = Field(
        examples=["/tmp/2hbs.pdb"], description="Path to receptor PDB file"
    )
    # Required (no default). Typed as floats: the documented example holds
    # fractional coordinates, which an integer list would reject on validation.
    # Integer inputs are still accepted and converted to floats.
    box_center: List[float] = Field(
        examples=[[15.190, 53.903, 16.917]], description="Docking box center"
    )
    # Required (no default).
    box_size: List[float] = Field(
        examples=[[20, 20, 20]], description="Docking box size"
    )
# S3 parameters with the domain pinned to the "molecules" submodule.
class S3ParametersMolecules(S3Parameters):
    domain: DomainSubmodule = Field(default=DomainSubmodule("molecules"))
# Common parameters of the Molformer predictors.
class MolformerParameters(S3ParametersMolecules):
    algorithm_name: str = Field(default="molformer")

    batch_size: int = Field(default=128, description="Prediction batch size")
    workers: int = Field(default=8, description="Number of data loading workers")
    # `None` is passed unchanged to `device_claim` by `_Molformer`.
    device: Optional[str] = Field(
        default=None, description="Device to be used for inference"
    )
# Molformer parameters with the application fixed to "classification".
class MolformerClassificationParameters(MolformerParameters):
    algorithm_application: str = Field(default="classification")
# Molformer parameters with the application fixed to
# "multitask_classification".
class MolformerMultitaskClassificationParameters(MolformerParameters):
    algorithm_application: str = Field(default="multitask_classification")
# Molformer parameters with the application fixed to "regression".
class MolformerRegressionParameters(MolformerParameters):
    algorithm_application: str = Field(default="regression")
# Common parameters of the MCA-based predictors.
class MCAParameters(S3ParametersMolecules):
    algorithm_name: str = Field(default="MCA")
# MCA parameters with the application fixed to "Tox21".
class Tox21Parameters(MCAParameters):
    algorithm_application: str = Field(default="Tox21")
# MCA parameters with the application fixed to "ClinTox".
class ClinToxParameters(MCAParameters):
    algorithm_application: str = Field(default="ClinTox")
# MCA parameters with the application fixed to "SIDER".
class SiderParameters(MCAParameters):
    algorithm_application: str = Field(default="SIDER")
# Parameters of the `OrganTox` predictor: the target organ and the kind of
# toxicity to report.
class OrganToxParameters(MCAParameters):
    # Organs that can be selected as the target site.
    class Organs(str, Enum):
        adrenal_gland: str = "Adrenal Gland"
        bone_marrow: str = "Bone Marrow"
        brain: str = "Brain"
        eye: str = "Eye"
        heart: str = "Heart"
        kidney: str = "Kidney"
        liver: str = "Liver"
        lung: str = "Lung"
        lymph_node: str = "Lymph Node"
        mammary_gland: str = "Mammary Gland"
        ovary: str = "Ovary"
        pancreas: str = "Pancreas"
        pituitary_gland: str = "Pituitary Gland"
        spleen: str = "Spleen"
        stomach: str = "Stomach"
        testes: str = "Testes"
        thymus: str = "Thymus"
        thyroid_gland: str = "Thyroid Gland"
        urinary_bladder: str = "Urinary Bladder"
        uterus: str = "Uterus"

    # Toxicity categories that can be requested.
    class ToxType(str, Enum):
        chronic: str = "chronic"
        subchronic: str = "subchronic"
        multigenerational: str = "multigenerational"
        all: str = "all"

    algorithm_application: str = Field(default="OrganTox")

    # Required: there is no default organ.
    site: Organs = Field(
        ...,
        description="name of the target site of interest.",
        examples=[Organs.kidney],
    )
    toxicity_type: ToxType = Field(
        default=ToxType.all,
        description="type of toxicity for which predictions are made.",
        examples=[ToxType.chronic],
    )
class _Molformer(PredictorAlgorithm):
    """Shared setup for the Molformer-based predictive algorithms."""

    def __init__(self, parameters: MolformerParameters):
        """Store inference settings and initialise the algorithm.

        Args:
            parameters: Molformer parameters describing the algorithm version
                to load as well as batch size, workers and device.
        """
        configuration = ConfigurablePropertyAlgorithmConfiguration(
            algorithm_type=parameters.algorithm_type,
            domain=parameters.domain,
            algorithm_name=parameters.algorithm_name,
            algorithm_application=parameters.algorithm_application,
            algorithm_version=parameters.algorithm_version,
        )

        self.batch_size = parameters.batch_size
        self.workers = parameters.workers

        # The vocabulary file ships inside the `gt4sd_molformer` package.
        package_root = importlib_resources.files("gt4sd_molformer")
        self.tokenizer_path = package_root / "finetune/bert_vocab.txt"

        self.device = device_claim(parameters.device)

        # Must come last: the parent constructor calls `self.get_model`, which
        # reads the attributes assigned above.
        super().__init__(configuration=configuration)

    def get_resources_path_and_config(self, resources_path: str):
        """Load the hyperparameters and locate the checkpoint of a model.

        Args:
            resources_path: local directory holding `hparams.yaml` and
                `model.ckpt`.

        Returns:
            A `(config, model_path)` tuple: the parsed content of
            `hparams.yaml` and the path to `model.ckpt`.
        """
        checkpoint_file = os.path.join(resources_path, "model.ckpt")
        hparams_file = os.path.join(resources_path, "hparams.yaml")

        with open(hparams_file, "r") as handle:
            hparams = yaml.safe_load(handle)

        return hparams, checkpoint_file
class MolformerClassification(_Molformer):
    """Class for all Molformer classification algorithms."""

    def get_model(self, resources_path: str) -> Predictor:
        """Instantiate the actual model.

        Args:
            resources_path: local path to model files.

        Returns:
            Predictor: a callable taking one SMILES string or a list of them
            and returning one 0/1 label per molecule.
        """
        config, model_path = self.get_resources_path_and_config(resources_path)

        # Data loading happens in the main process during inference.
        config["num_workers"] = 0

        tokenizer = MolTranBertTokenizer(self.tokenizer_path)

        model = ClassificationLightningModule(config, tokenizer).load_from_checkpoint(
            model_path,
            strict=False,
            config=config,
            tokenizer=tokenizer,
            vocab=len(tokenizer.vocab),
        )
        model.to(self.device)
        model.eval()

        # Wrapper to get the predictions
        def informative_model(samples: Union[str, List[str]]) -> List[float]:
            if isinstance(samples, str):
                samples = [samples]

            df = pd.DataFrame.from_dict({"smiles": samples})
            dataset = ClassificationDataset(df)
            datamodule = ClassificationDataModule(config, tokenizer)
            datamodule.test_ds = dataset

            preds = []
            for batch in datamodule.test_dataloader():
                with torch.no_grad():
                    batch = [x.to(self.device) for x in batch]
                    batch_output = model.testing_step(batch, 0, 0)

                # Column 1 of "pred" is thresholded at 0.5 to get the label.
                # Move it to the CPU explicitly: NumPy cannot read a tensor
                # that lives on an accelerator. This is a no-op when the
                # output is already on the CPU.
                preds_cpu = torch.as_tensor(batch_output["pred"])[:, 1].detach().cpu()
                y_pred = np.where(preds_cpu.numpy() >= 0.5, 1, 0)
                preds += y_pred.tolist()

            return preds

        return informative_model
class MolformerMultitaskClassification(_Molformer):
    """Molformer predictor for multitask classification models."""

    def get_model(self, resources_path: str) -> Predictor:
        """Load the checkpoint and wrap it in a prediction function.

        Args:
            resources_path: local path to model files.

        Returns:
            Predictor: a callable taking one SMILES string or a list of them
            and returning, per molecule, the entry of
            `config["measure_names"]` with the highest predicted score.
        """
        hparams, checkpoint = self.get_resources_path_and_config(resources_path)

        # Data loading happens in the main process during inference.
        hparams["num_workers"] = 0

        tokenizer = MolTranBertTokenizer(self.tokenizer_path)

        network = MultitaskModel(hparams, tokenizer).load_from_checkpoint(
            checkpoint,
            strict=False,
            config=hparams,
            tokenizer=tokenizer,
            vocab=len(tokenizer.vocab),
        )
        network.to(self.device)
        network.eval()

        def informative_model(samples: Union[str, List[str]]) -> List[str]:
            # A single SMILES string is treated as a batch of one.
            smiles_list = [samples] if isinstance(samples, str) else samples

            frame = pd.DataFrame.from_dict({"smiles": smiles_list})
            datamodule = PropertyPredictionDataModule(hparams, tokenizer)
            datamodule.test_ds = MultitaskEmbeddingDataset(frame)

            labels = []
            for batch in datamodule.test_dataloader():
                with torch.no_grad():
                    on_device = [item.to(self.device) for item in batch]
                    output = network.testing_step(on_device, 0, 0)

                winners = torch.argmax(output["pred"], dim=1)
                labels.extend([hparams["measure_names"][idx] for idx in winners])

            return labels

        return informative_model
class MolformerRegression(_Molformer):
    """Molformer predictor for regression models."""

    def get_model(self, resources_path: str) -> Predictor:
        """Load the checkpoint and wrap it in a prediction function.

        Args:
            resources_path: local path to model files.

        Returns:
            Predictor: a callable taking one SMILES string or a list of them
            and returning the flattened model predictions as a list.
        """
        hparams, checkpoint = self.get_resources_path_and_config(resources_path)

        # Data loading happens in the main process during inference.
        hparams["num_workers"] = 0

        tokenizer = MolTranBertTokenizer(self.tokenizer_path)

        network = RegressionLightningModule(hparams, tokenizer).load_from_checkpoint(
            checkpoint,
            strict=False,
            config=hparams,
            tokenizer=tokenizer,
            vocab=len(tokenizer.vocab),
        )
        network.to(self.device)
        network.eval()

        def informative_model(samples: Union[str, List[str]]) -> List[float]:
            # A single SMILES string is treated as a batch of one.
            smiles_list = [samples] if isinstance(samples, str) else samples

            frame = pd.DataFrame.from_dict({"smiles": smiles_list})
            datamodule = RegressionDataModule(hparams, tokenizer)
            datamodule.test_ds = RegressionDataset(frame, False, hparams["aug"])

            values = []
            for batch in datamodule.test_dataloader():
                with torch.no_grad():
                    on_device = [item.to(self.device) for item in batch]
                    output = network.testing_step(on_device, 0, 0)

                values.extend(output["pred"].view(-1).tolist())

            return values

        return informative_model
# NOTE: property prediction classes
class Plogp(CallablePropertyPredictor):
    """Penalized logP of a molecule.

    Defined as the logP minus the number of rings with more than 6 atoms
    minus the SAS.
    """

    def __init__(
        self, parameters: PropertyPredictorParameters = PropertyPredictorParameters()
    ) -> None:
        """Bind the `plogp` function to this predictor.

        Args:
            parameters: predictor parameters, `PropertyPredictorParameters()`
                unless given.
        """
        super().__init__(parameters=parameters, callable_fn=plogp)
class Lipinski(CallablePropertyPredictor):
    """Adherence of a molecule to the Lipinski rule of five.

    This is only a crude approximation of druglikeness.
    """

    def __init__(
        self, parameters: PropertyPredictorParameters = PropertyPredictorParameters()
    ) -> None:
        """Bind the `lipinski` function to this predictor.

        Args:
            parameters: predictor parameters, `PropertyPredictorParameters()`
                unless given.
        """
        super().__init__(parameters=parameters, callable_fn=lipinski)
class Esol(CallablePropertyPredictor):
    """Estimated water solubility of a molecule."""

    def __init__(
        self, parameters: PropertyPredictorParameters = PropertyPredictorParameters()
    ) -> None:
        """Bind the `esol` function to this predictor.

        Args:
            parameters: predictor parameters, `PropertyPredictorParameters()`
                unless given.
        """
        super().__init__(parameters=parameters, callable_fn=esol)
class Scscore(CallablePropertyPredictor):
    """Calculate the synthetic complexity score (SCScore) of a molecule."""

    def __init__(
        self, parameters: ScscoreConfiguration = ScscoreConfiguration()
    ) -> None:
        """Build the SCScore evaluator from the configuration.

        Args:
            parameters: SCScore settings; every field is passed to `SCScore`
                as a keyword argument.
        """
        # `model_dump()` is the pydantic v2 replacement for the deprecated
        # `dict()` and yields the same field-name-to-value mapping.
        scorer = SCScore(**parameters.model_dump())
        super().__init__(callable_fn=scorer, parameters=parameters)
class Sas(CallablePropertyPredictor):
    """Synthetic accessibility score (SAS) of a molecule."""

    def __init__(
        self, parameters: PropertyPredictorParameters = PropertyPredictorParameters()
    ) -> None:
        """Bind the `sas` function to this predictor.

        Args:
            parameters: predictor parameters, `PropertyPredictorParameters()`
                unless given.
        """
        super().__init__(parameters=parameters, callable_fn=sas)
class Bertz(CallablePropertyPredictor):
    """Bertz index of a molecule."""

    def __init__(
        self, parameters: PropertyPredictorParameters = PropertyPredictorParameters()
    ) -> None:
        """Bind the `bertz` function to this predictor.

        Args:
            parameters: predictor parameters, `PropertyPredictorParameters()`
                unless given.
        """
        super().__init__(parameters=parameters, callable_fn=bertz)
class Tpsa(CallablePropertyPredictor):
    """Total polar surface area of a molecule."""

    def __init__(
        self, parameters: PropertyPredictorParameters = PropertyPredictorParameters()
    ) -> None:
        """Bind the `tpsa` function to this predictor.

        Args:
            parameters: predictor parameters, `PropertyPredictorParameters()`
                unless given.
        """
        super().__init__(parameters=parameters, callable_fn=tpsa)
class Logp(CallablePropertyPredictor):
    """Partition coefficient (logP) of a molecule."""

    def __init__(
        self, parameters: PropertyPredictorParameters = PropertyPredictorParameters()
    ) -> None:
        """Bind the `logp` function to this predictor.

        Args:
            parameters: predictor parameters, `PropertyPredictorParameters()`
                unless given.
        """
        super().__init__(parameters=parameters, callable_fn=logp)
class Qed(CallablePropertyPredictor):
    """Quantitative estimate of drug-likeness (QED) of a molecule."""

    def __init__(
        self, parameters: PropertyPredictorParameters = PropertyPredictorParameters()
    ) -> None:
        """Bind the `qed` function to this predictor.

        Args:
            parameters: predictor parameters, `PropertyPredictorParameters()`
                unless given.
        """
        super().__init__(parameters=parameters, callable_fn=qed)
class NumberHAcceptors(CallablePropertyPredictor):
    """Number of H acceptors of a molecule."""

    def __init__(
        self, parameters: PropertyPredictorParameters = PropertyPredictorParameters()
    ) -> None:
        """Bind the `number_of_h_acceptors` function to this predictor.

        Args:
            parameters: predictor parameters, `PropertyPredictorParameters()`
                unless given.
        """
        super().__init__(parameters=parameters, callable_fn=number_of_h_acceptors)
class NumberAtoms(CallablePropertyPredictor):
    """Number of atoms of a molecule."""

    def __init__(
        self, parameters: PropertyPredictorParameters = PropertyPredictorParameters()
    ) -> None:
        """Bind the `number_of_atoms` function to this predictor.

        Args:
            parameters: predictor parameters, `PropertyPredictorParameters()`
                unless given.
        """
        super().__init__(parameters=parameters, callable_fn=number_of_atoms)
class NumberHDonors(CallablePropertyPredictor):
    """Number of H donors of a molecule."""

    def __init__(
        self, parameters: PropertyPredictorParameters = PropertyPredictorParameters()
    ) -> None:
        """Bind the `number_of_h_donors` function to this predictor.

        Args:
            parameters: predictor parameters, `PropertyPredictorParameters()`
                unless given.
        """
        super().__init__(parameters=parameters, callable_fn=number_of_h_donors)
class NumberAromaticRings(CallablePropertyPredictor):
    """Number of aromatic rings of a molecule."""

    def __init__(
        self, parameters: PropertyPredictorParameters = PropertyPredictorParameters()
    ) -> None:
        """Bind the `number_of_aromatic_rings` function to this predictor.

        Args:
            parameters: predictor parameters, `PropertyPredictorParameters()`
                unless given.
        """
        super().__init__(parameters=parameters, callable_fn=number_of_aromatic_rings)
class NumberRings(CallablePropertyPredictor):
    """Number of rings of a molecule."""

    def __init__(
        self, parameters: PropertyPredictorParameters = PropertyPredictorParameters()
    ) -> None:
        """Bind the `number_of_rings` function to this predictor.

        Args:
            parameters: predictor parameters, `PropertyPredictorParameters()`
                unless given.
        """
        super().__init__(parameters=parameters, callable_fn=number_of_rings)
class NumberRotatableBonds(CallablePropertyPredictor):
    """Number of rotatable bonds of a molecule."""

    def __init__(
        self, parameters: PropertyPredictorParameters = PropertyPredictorParameters()
    ) -> None:
        """Bind the `number_of_rotatable_bonds` function to this predictor.

        Args:
            parameters: predictor parameters, `PropertyPredictorParameters()`
                unless given.
        """
        super().__init__(parameters=parameters, callable_fn=number_of_rotatable_bonds)
class NumberLargeRings(CallablePropertyPredictor):
    """Number of large rings (more than 6 atoms) of a molecule."""

    def __init__(
        self, parameters: PropertyPredictorParameters = PropertyPredictorParameters()
    ) -> None:
        """Bind the `number_of_large_rings` function to this predictor.

        Args:
            parameters: predictor parameters, `PropertyPredictorParameters()`
                unless given.
        """
        super().__init__(parameters=parameters, callable_fn=number_of_large_rings)
class MolecularWeight(CallablePropertyPredictor):
    """Molecular weight of a molecule."""

    def __init__(
        self, parameters: PropertyPredictorParameters = PropertyPredictorParameters()
    ) -> None:
        """Bind the `molecular_weight` function to this predictor.

        Args:
            parameters: predictor parameters, `PropertyPredictorParameters()`
                unless given.
        """
        super().__init__(parameters=parameters, callable_fn=molecular_weight)
class IsScaffold(CallablePropertyPredictor):
    """Check whether a molecule is identical to its Murcko scaffold."""

    def __init__(
        self, parameters: PropertyPredictorParameters = PropertyPredictorParameters()
    ) -> None:
        """Bind the `is_scaffold` function to this predictor.

        Args:
            parameters: predictor parameters, `PropertyPredictorParameters()`
                unless given.
        """
        super().__init__(parameters=parameters, callable_fn=is_scaffold)
class NumberHeterocycles(CallablePropertyPredictor):
    """Number of heterocycles of a molecule."""

    def __init__(
        self, parameters: PropertyPredictorParameters = PropertyPredictorParameters()
    ) -> None:
        """Bind the `number_of_heterocycles` function to this predictor.

        Args:
            parameters: predictor parameters, `PropertyPredictorParameters()`
                unless given.
        """
        super().__init__(parameters=parameters, callable_fn=number_of_heterocycles)
class NumberStereocenters(CallablePropertyPredictor):
    """Number of stereo centers of a molecule."""

    def __init__(
        self, parameters: PropertyPredictorParameters = PropertyPredictorParameters()
    ) -> None:
        """Bind the `number_of_stereocenters` function to this predictor.

        Args:
            parameters: predictor parameters, `PropertyPredictorParameters()`
                unless given.
        """
        super().__init__(parameters=parameters, callable_fn=number_of_stereocenters)
class SimilaritySeed(CallablePropertyPredictor):
    """Similarity of a molecule to a fixed seed molecule."""

    def __init__(self, parameters: SimilaritySeedParameters) -> None:
        """Create the similarity function for the configured seed.

        Args:
            parameters: holds the seed SMILES and the fingerprint key, both
                passed to `get_similarity_fn`.
        """
        similarity_fn = get_similarity_fn(
            target_mol=parameters.smiles, fp_key=parameters.fp_key
        )
        super().__init__(callable_fn=similarity_fn, parameters=parameters)
class ActivityAgainstTarget(CallablePropertyPredictor):
    """Activity of a molecule against a configured target."""

    def __init__(self, parameters: ActivityAgainstTargetParameters) -> None:
        """Create the activity function for the configured target.

        Args:
            parameters: holds the target name passed to `get_activity_fn`.
        """
        activity_fn = get_activity_fn(target=parameters.target)
        super().__init__(callable_fn=activity_fn, parameters=parameters)
class Askcos(ConfigurableCallablePropertyPredictor):
    """
    A property predictor that uses the ASKCOs API to calculate the synthesizability
    of a molecule.
    """

    def __init__(self, parameters: AskcosParameters):
        """Validate the host address and set up the ASKCOS oracle.

        Args:
            parameters: ASKCOS parameters; `host_ip` must be a string of the
                form `http://xx.xx.xxx.xxx`.

        Raises:
            AttributeError: if `parameters` has no `host_ip`.
            TypeError: if `host_ip` is not a string.
            ValueError: if `host_ip` does not carry an `http` prefix followed
                by `//`.
        """
        # Passed to `validate_ip`, which raises if the IP is not valid.
        msg = (
            "You have to point to an IP address of a running ASKCOS instance. "
            "For details on setting this up, see: https://tdcommons.ai/functions/oracles/#askcos"
        )

        # Presence is checked before anything reads `host_ip`; otherwise the
        # attribute access itself would fail first and this message would
        # never be shown.
        if not hasattr(parameters, "host_ip"):
            raise AttributeError(f"IP adress missing in {parameters}")

        if not isinstance(parameters.host_ip, str):
            raise TypeError(f"IP adress must be a string, not {parameters.host_ip}")

        # The `//` separator is required as well, since the address is taken
        # from the part that follows it.
        if "http" not in parameters.host_ip or "//" not in parameters.host_ip:
            raise ValueError(
                f"ASKCOS requires an IP prepended with a http, e.g., "
                f"'http://xx.xx.xxx.xxx' and not {parameters.host_ip}."
            )

        ip = parameters.host_ip.split("//")[1]
        validate_ip(ip, message=msg)

        super().__init__(callable_fn=Oracle(name="ASKCOS"), parameters=parameters)
class MoleculeOne(CallablePropertyPredictor):
    """Synthesizability predictor backed by the MoleculeOne API."""

    def __init__(self, parameters: MoleculeOneParameters):
        """Check the API token and set up the MoleculeOne oracle.

        Args:
            parameters: holds the oracle name and the API token.
        """
        msg = (
            "You have to provide a valid API key, for details on setting this up, see: "
            "https://tdcommons.ai/functions/oracles/#moleculeone"
        )
        # This is only a type check on the API key.
        validate_api_token(parameters, message=msg)

        oracle = Oracle(name=parameters.oracle_name, api_token=parameters.api_token)
        super().__init__(callable_fn=oracle, parameters=parameters)
class DockingTdc(ConfigurableCallablePropertyPredictor):
    """Docking score against a target provided via the TDC package.

    See: https://tdcommons.ai/functions/oracles/#docking-scores
    """

    def __init__(self, parameters: DockingTdcParameters):
        """Set up the TDC oracle for the configured target.

        Args:
            parameters: holds the target, used as the oracle name.
        """
        docking_import_check()
        oracle = Oracle(name=parameters.target)
        super().__init__(callable_fn=oracle, parameters=parameters)
class Docking(ConfigurableCallablePropertyPredictor):
    """Docking score against a user-defined target.

    Relies on the TDC backend; for setup see
    https://tdcommons.ai/functions/oracles/#docking-scores.
    """

    def __init__(self, parameters: DockingParameters):
        """Set up the TDC oracle for the user-provided receptor.

        Args:
            parameters: holds the oracle name, the receptor PDB file and the
                docking box center and size.
        """
        docking_import_check()
        oracle = Oracle(
            name=parameters.name,
            receptor_pdb_file=parameters.receptor_pdb_file,
            box_center=parameters.box_center,
            box_size=parameters.box_size,
        )
        super().__init__(callable_fn=oracle, parameters=parameters)
class _MCA(PredictorAlgorithm):
    """Shared setup for the MCA-based predictive algorithms."""

    def __init__(self, parameters: MCAParameters):
        """Derive the algorithm configuration and initialise the algorithm.

        Args:
            parameters: MCA parameters identifying the algorithm to load.
        """
        algorithm_configuration = ConfigurablePropertyAlgorithmConfiguration(
            algorithm_type=parameters.algorithm_type,
            domain=parameters.domain,
            algorithm_name=parameters.algorithm_name,
            algorithm_application=parameters.algorithm_application,
            algorithm_version=parameters.algorithm_version,
        )

        # Note that the parent constructor calls `self.get_model`.
        super().__init__(configuration=algorithm_configuration)
class Tox21(_MCA):
    """Environmental toxicity predictor for the 12 Tox21 endpoints."""

    def get_model(self, resources_path: str) -> Predictor:
        """Load the Tox21 evaluator and wrap it in a prediction function.

        Args:
            resources_path: local path to model files.

        Returns:
            Predictor: a callable returning the endpoint-level predictions of
            a molecule as a list.
        """
        evaluator = _Tox21(model_path=resources_path)

        def informative_model(x: SmallMolecule) -> List[PropertyValue]:
            smiles = to_smiles(x)
            # Calling the evaluator yields one aggregated reward, which is not
            # needed here; the per-endpoint values are read from
            # `predictions` afterwards.
            evaluator(smiles)
            return evaluator.predictions.detach().tolist()

        return informative_model

    @classmethod
    def get_description(cls) -> str:
        """Return a human-readable description of the model."""
        text = """
        This model predicts the 12 endpoints from the Tox21 challenge.
        The endpoints are: NR-AR, NR-AR-LBD, NR-AhR, NR-Aromatase, NR-ER, NR-ER-LBD, NR-PPAR-gamma, SR-ARE, SR-ATAD5, SR-HSE, SR-MMP, SR-p53
        For details on the data see: https://tripod.nih.gov/tox21/challenge/.
        """
        return text
class ClinTox(_MCA):
    """Model to predict FDA approval and failure in clinical trials (ClinTox)."""

    def get_model(self, resources_path: str) -> Predictor:
        """Instantiate the actual model.

        Args:
            resources_path: local path to model files.

        Returns:
            Predictor: the model.
        """
        # This model returns a singular reward and not a prediction for both classes.
        model = _ClinTox(model_path=resources_path)

        # Wrapper to get toxicity-endpoint-level predictions
        def informative_model(x: SmallMolecule) -> List[PropertyValue]:
            x = to_smiles(x)
            # The call fills `model.predictions`; its return value is unused.
            _ = model(x)
            return model.predictions.detach().tolist()

        return informative_model

    @classmethod
    def get_description(cls) -> str:
        """Return a human-readable description of the model."""
        text = """
        This model is a multitask classifier for two classes:
            1. Predicted probability to receive FDA approval.
            2. Predicted probability of failure in clinical trials.
        For details on the data see: https://pubs.rsc.org/en/content/articlehtml/2018/sc/c7sc02664a.
        """
        return text
class Sider(_MCA):
    """Side-effect predictor based on the SIDER evaluator."""

    def get_model(self, resources_path: str) -> Predictor:
        """Load the SIDER evaluator and wrap it in a prediction function.

        Args:
            resources_path: local path to model files.

        Returns:
            Predictor: a callable returning the class-level predictions of a
            molecule as a list.
        """
        evaluator = SIDER(model_path=resources_path)

        def informative_model(x: SmallMolecule) -> List[PropertyValue]:
            smiles = to_smiles(x)
            # Calling the evaluator yields one aggregated reward, which is not
            # needed here; the per-class values are read from `predictions`
            # afterwards.
            evaluator(smiles)
            return evaluator.predictions.detach().tolist()

        return informative_model

    @classmethod
    def get_description(cls) -> str:
        """Return a human-readable description of the model."""
        text = """
        This model is a multitask classifier to predict side effects of drugs across 27 classes.
        For details on the data see: https://pubs.rsc.org/en/content/articlehtml/2018/sc/c7sc02664a.
        """
        return text
class OrganTox(_MCA):
    """Toxicity predictor for a selected organ."""

    def __init__(self, parameters: OrganToxParameters) -> None:
        """Store the organ and toxicity type, then initialise the algorithm.

        Args:
            parameters: OrganTox parameters with the target site and the
                toxicity type.
        """
        # Assigned before the parent constructor runs, because `get_model`
        # reads both attributes.
        self.toxicity_type = parameters.toxicity_type
        self.site = parameters.site
        super().__init__(parameters=parameters)

    def get_model(self, resources_path: str) -> Predictor:
        """Load the organ toxicity evaluator and wrap it in a prediction function.

        Args:
            resources_path: local path to model files.

        Returns:
            Predictor: a callable returning, as a list, the predictions at the
            evaluator's `class_indices`.
        """
        evaluator = _OrganTox(
            model_path=resources_path,
            site=self.site,
            toxicity_type=self.toxicity_type,
        )

        def informative_model(x: SmallMolecule) -> List[PropertyValue]:
            smiles = to_smiles(x)
            # Calling the evaluator yields one aggregated reward, which is not
            # needed here; the individual values are read from `predictions`
            # afterwards.
            evaluator(smiles)
            predictions = evaluator.predictions.detach()
            # Keep only the entries the evaluator selected for this site and
            # toxicity type.
            return predictions[evaluator.class_indices].tolist()

        return informative_model

    @classmethod
    def get_description(cls) -> str:
        """Return a human-readable description of the model."""
        text = """
        This model is a multitask classifier to toxicity across different organs and
        toxicity types (`chronic`, `subchronic`, `multigenerational` or `all`). The organ
        has to specified in the constructor. The toxicity type defaults to `all` in which
        case three values are returned in the order `chronic`, `multigenerational` and `subchronic`.
        For details on the data see: https://pubs.rsc.org/en/content/articlehtml/2018/sc/c7sc02664a.
        """
        return text