
#
# MIT License
#
# Copyright (c) 2022 GT4SD team
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
"""PaccMann VAE training utilities."""

import json
import logging
import os
from dataclasses import dataclass, field
from time import time
from typing import Any, Dict, Optional, cast

import torch
from paccmann_chemistry.models.training import train_vae
from paccmann_chemistry.models.vae import StackGRUDecoder, StackGRUEncoder, TeacherVAE
from paccmann_chemistry.utils import collate_fn, disable_rdkit_logging
from paccmann_chemistry.utils.hyperparams import SEARCH_FACTORY
from pytoda.datasets import SMILESDataset
from pytoda.smiles.smiles_language import SMILESLanguage
from torch.utils.tensorboard import SummaryWriter

from ....frameworks.torch import get_device
from ...core import TrainingPipelineArguments
from ..core import PaccMannTrainingPipeline

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())


class PaccMannVAETrainingPipeline(PaccMannTrainingPipeline):
    """Language modeling training pipelines."""

    def train(  # type: ignore
        self,
        training_args: Dict[str, Any],
        model_args: Dict[str, Any],
        dataset_args: Dict[str, Any],
    ) -> None:
        """Generic training function for PaccMann training.

        Args:
            training_args: training arguments passed to the configuration.
            model_args: model arguments passed to the configuration.
            dataset_args: dataset arguments passed to the configuration.
        """
        try:
            device = get_device()
            disable_rdkit_logging()
            params = {**training_args, **dataset_args, **model_args}
            train_smiles_filepath = params["train_smiles_filepath"]
            test_smiles_filepath = params["test_smiles_filepath"]
            smiles_language_filepath = params.get("smiles_language_filepath", None)
            params["batch_mode"] = "Padded"

            model_path = params["model_path"]
            training_name = params["training_name"]
            checkpoint = params["checkpoint_path"]

            writer = SummaryWriter(f"logs/{training_name}")
            logger.info(f"Model with name {training_name} starts.")

            model_dir = os.path.join(model_path, training_name)
            log_path = os.path.join(model_dir, "logs")
            val_dir = os.path.join(log_path, "val_logs")
            os.makedirs(os.path.join(model_dir, "weights"), exist_ok=True)
            os.makedirs(os.path.join(model_dir, "results"), exist_ok=True)
            os.makedirs(log_path, exist_ok=True)
            os.makedirs(val_dir, exist_ok=True)

            # load a pre-built SMILES language if one was provided
            smiles_language: Optional[SMILESLanguage] = None
            if smiles_language_filepath is not None:
                smiles_language = SMILESLanguage.load(smiles_language_filepath)

            logger.info(f"Smiles filepath: {train_smiles_filepath}")

            # create lazy SMILES datasets
            smiles_train_data = SMILESDataset(
                train_smiles_filepath,
                smiles_language=smiles_language,
                padding=False,
                selfies=params.get("selfies", False),
                add_start_and_stop=params.get("add_start_stop_token", True),
                augment=params.get("augment_smiles", False),
                canonical=params.get("canonical", False),
                kekulize=params.get("kekulize", False),
                all_bonds_explicit=params.get("all_bonds_explicit", False),
                all_hs_explicit=params.get("all_hs_explicit", False),
                remove_bonddir=params.get("remove_bonddir", False),
                remove_chirality=params.get("remove_chirality", False),
                backend="lazy",
                device=device,
            )
            smiles_test_data = SMILESDataset(
                test_smiles_filepath,
                smiles_language=smiles_language,
                padding=False,
                selfies=params.get("selfies", False),
                add_start_and_stop=params.get("add_start_stop_token", True),
                augment=params.get("augment_smiles", False),
                canonical=params.get("canonical", False),
                kekulize=params.get("kekulize", False),
                all_bonds_explicit=params.get("all_bonds_explicit", False),
                all_hs_explicit=params.get("all_hs_explicit", False),
                remove_bonddir=params.get("remove_bonddir", False),
                remove_chirality=params.get("remove_chirality", False),
                backend="lazy",
                device=device,
            )

            if smiles_language_filepath is None:
                # no language given: persist the one built from the training data
                smiles_language = smiles_train_data.smiles_language
                smiles_language.save(
                    os.path.join(model_path, f"{training_name}.lang")
                )
            else:
                smiles_language_filename = os.path.basename(smiles_language_filepath)
                cast(SMILESLanguage, smiles_language).save(
                    os.path.join(model_dir, smiles_language_filename)
                )

            params.update(
                {
                    "vocab_size": cast(
                        SMILESLanguage, smiles_language
                    ).number_of_tokens,
                    "pad_index": cast(SMILESLanguage, smiles_language).padding_index,
                }
            )

            vocab_dict = cast(SMILESLanguage, smiles_language).index_to_token
            params.update(
                {
                    "start_index": list(vocab_dict.keys())[
                        list(vocab_dict.values()).index("<START>")
                    ],
                    "end_index": list(vocab_dict.keys())[
                        list(vocab_dict.values()).index("<STOP>")
                    ],
                }
            )

            if params.get("embedding", "learned") == "one_hot":
                params.update({"embedding_size": params["vocab_size"]})

            with open(os.path.join(model_dir, "model_params.json"), "w") as fp:
                json.dump(params, fp)

            # create DataLoaders
            train_data_loader = torch.utils.data.DataLoader(
                smiles_train_data,
                batch_size=params.get("batch_size", 64),
                collate_fn=collate_fn,
                drop_last=True,
                shuffle=True,
                pin_memory=params.get("pin_memory", True),
                num_workers=params.get("num_workers", 8),
            )
            test_data_loader = torch.utils.data.DataLoader(
                smiles_test_data,
                batch_size=params.get("batch_size", 64),
                collate_fn=collate_fn,
                drop_last=True,
                shuffle=True,
                pin_memory=params.get("pin_memory", True),
                num_workers=params.get("num_workers", 8),
            )

            # initialize encoder and decoder
            gru_encoder = StackGRUEncoder(params).to(device)
            gru_decoder = StackGRUDecoder(params).to(device)
            gru_vae = TeacherVAE(gru_encoder, gru_decoder).to(device)
            if checkpoint:
                gru_vae.load(checkpoint)
                logger.info(f"Restored existing model from {checkpoint}")

            logger.info("Model summary:")
            for name, parameter in gru_vae.named_parameters():
                logger.info(f"Param {name}, shape:\t{parameter.shape}")
            total_params = sum(p.numel() for p in gru_vae.parameters())
            logger.info(f"Total # params: {total_params}")

            loss_tracker = {
                "test_loss_a": 10e4,
                "test_rec_a": 10e4,
                "test_kld_a": 10e4,
                "ep_loss": 0,
                "ep_rec": 0,
                "ep_kld": 0,
            }

            # train for n_epoch epochs
            logger.info("Model creation and data processing done, training starts.")
            decoder_search = SEARCH_FACTORY[params.get("decoder_search", "sampling")](
                temperature=params.get("temperature", 1.0),
                beam_width=params.get("beam_width", 3),
                top_tokens=params.get("top_tokens", 5),
            )

            if writer:
                pparams = params.copy()
                pparams["training_file"] = train_smiles_filepath
                pparams["test_file"] = test_smiles_filepath
                pparams["language_file"] = smiles_language_filepath
                pparams["model_path"] = model_path
                # iterate over pparams (not params) so the run-specific keys
                # added above survive the None -> "N.A." sanitization
                pparams = {
                    k: v if v is not None else "N.A." for k, v in pparams.items()
                }
                pparams["training_name"] = training_name
                from pprint import pprint

                pprint(pparams)
                writer.add_hparams(hparam_dict=pparams, metric_dict={})

            for epoch in range(params["epochs"] + 1):
                t = time()
                loss_tracker = train_vae(
                    epoch,
                    gru_vae,
                    train_data_loader,
                    test_data_loader,
                    smiles_language,
                    model_dir,
                    search=decoder_search,
                    optimizer=params.get("optimizer", "adadelta"),
                    lr=params["learning_rate"],
                    kl_growth=params["kl_growth"],
                    input_keep=params["input_keep"],
                    test_input_keep=params["test_input_keep"],
                    generate_len=params["generate_len"],
                    log_interval=params["log_interval"],
                    save_interval=params["save_interval"],
                    eval_interval=params["eval_interval"],
                    loss_tracker=loss_tracker,
                    logger=logger,
                    batch_mode=params["batch_mode"],
                )
                logger.info(f"Epoch {epoch}, took {time() - t:.1f} s.")
                logger.info(
                    "Overall:\tBest loss = {0:.4f} in Ep {1}, "
                    "best Rec = {2:.4f} in Ep {3}, "
                    "best KLD = {4:.4f} in Ep {5}".format(
                        loss_tracker["test_loss_a"],
                        loss_tracker["ep_loss"],
                        loss_tracker["test_rec_a"],
                        loss_tracker["ep_rec"],
                        loss_tracker["test_kld_a"],
                        loss_tracker["ep_kld"],
                    )
                )
            logger.info("Training done, shutting down.")
        except Exception:
            logger.exception(
                "Exception occurred while running PaccMannVAETrainingPipeline"
            )


@dataclass
class PaccMannVAEModelArguments(TrainingPipelineArguments):
    """Arguments pertaining to model instantiation."""

    __name__ = "model_args"

    n_layers: int = field(
        default=2, metadata={"help": "Number of layers for the RNNs."}
    )
    bidirectional: bool = field(
        default=False, metadata={"help": "Whether the RNN cells are bidirectional."}
    )
    rnn_cell_size: int = field(
        default=512, metadata={"help": "Size of the RNN cells."}
    )
    latent_dim: int = field(
        default=256, metadata={"help": "Size of the latent space."}
    )
    stack_width: int = field(
        default=50, metadata={"help": "Width of the memory stack for the RNN cell."}
    )
    stack_depth: int = field(
        default=50, metadata={"help": "Depth of the memory stack for the RNN cell."}
    )
    decoder_search: str = field(
        default="sampling", metadata={"help": "Decoder search strategy."}
    )
    dropout: float = field(default=0.2, metadata={"help": "Dropout rate to apply."})
    generate_len: int = field(
        default=100, metadata={"help": "Length in tokens of the generated molecules."}
    )
    kl_growth: float = field(
        default=0.003, metadata={"help": "Growth of the KL term weight in the loss."}
    )
    input_keep: float = field(
        default=0.85, metadata={"help": "Probability to keep input tokens in train."}
    )
    test_input_keep: float = field(
        default=1.0, metadata={"help": "Probability to keep input tokens in test."}
    )
    temperature: float = field(
        default=0.8, metadata={"help": "Temperature for the sampling."}
    )
    embedding: str = field(
        default="one_hot",
        metadata={
            "help": "Embedding technique for the tokens. 'one_hot' or 'learned'."
        },
    )
    vocab_size: int = field(
        default=380, metadata={"help": "Size of the vocabulary of chemical tokens."}
    )
    pad_index: int = field(
        default=0, metadata={"help": "Index for the padding token."}
    )
    embedding_size: int = field(
        default=380, metadata={"help": "Size of the embedding vectors."}
    )
    beam_width: int = field(default=3, metadata={"help": "Width of the beam search."})
    top_tokens: int = field(
        default=5, metadata={"help": "Number of tokens to consider in the beam search."}
    )
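

# Minimal usage sketch (kept as a comment so the module stays import-clean).
# The file paths and training-argument values below are hypothetical
# placeholders; in practice these dictionaries come from the pipeline's
# argument dataclasses, and `dataclasses.asdict` turns an argument dataclass
# into the plain dictionary that `train` merges into `params`.
#
#     from dataclasses import asdict
#
#     pipeline = PaccMannVAETrainingPipeline()
#     pipeline.train(
#         training_args={
#             "model_path": "/tmp/models",           # hypothetical output root
#             "training_name": "paccmann-vae-demo",
#             "checkpoint_path": "",                 # falsy: train from scratch
#             "epochs": 10,
#             "learning_rate": 0.0005,
#             "log_interval": 100,
#             "save_interval": 1000,
#             "eval_interval": 500,
#         },
#         model_args=asdict(PaccMannVAEModelArguments()),
#         dataset_args={
#             "train_smiles_filepath": "train.smi",  # hypothetical SMILES files
#             "test_smiles_filepath": "test.smi",
#         },
#     )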