Source code for gt4sd.training_pipelines.cgcnn.core

#
# MIT License
#
# Copyright (c) 2022 GT4SD team
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
"""Cgcnn training utilities."""

import logging
import os
import shutil
import time
from dataclasses import dataclass, field
from random import sample
from typing import Any, Dict, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn import metrics
from torch import Tensor
from torch.autograd import Variable
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.data import DataLoader

from ...frameworks.cgcnn.data import CIFData, collate_pool, get_train_val_test_loader
from ...frameworks.cgcnn.model import CrystalGraphConvNet, Normalizer
from ..core import TrainingPipeline, TrainingPipelineArguments

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())


class CGCNNTrainingPipeline(TrainingPipeline):
    """CGCNN training pipelines for crystals."""

    def train(  # type: ignore
        self,
        training_args: Dict[str, Any],
        model_args: Dict[str, Any],
        dataset_args: Dict[str, Any],
    ) -> None:
        """Generic training function for CGCNN models.

        Args:
            training_args: training arguments passed to the configuration.
            model_args: model arguments passed to the configuration.
            dataset_args: dataset arguments passed to the configuration.

        Raises:
            ValueError: in case the validation error is NaN.
            NameError: in case the requested optimizer is neither SGD nor Adam.
        """
        training_args["disable_cuda"] = (
            training_args["disable_cuda"] or not torch.cuda.is_available()
        )

        if training_args["task"] == "regression":
            best_mae_error = 1e10
        else:
            best_mae_error = 0.0

        # load data
        dataset = CIFData(dataset_args["datapath"])
        collate_fn = collate_pool
        train_loader, val_loader, test_loader = get_train_val_test_loader(  # type: ignore
            dataset=dataset,
            collate_fn=collate_fn,
            batch_size=training_args["batch_size"],
            num_workers=training_args["workers"],
            # pin host memory only when CUDA is enabled
            pin_memory=not training_args["disable_cuda"],
            train_size=dataset_args["train_size"],
            val_size=dataset_args["val_size"],
            test_size=dataset_args["test_size"],
            return_test=True,
        )

        # obtain target value normalizer
        if training_args["task"] == "classification":
            normalizer = Normalizer(torch.zeros(2))
            normalizer.load_state_dict({"mean": 0.0, "std": 1.0})
        else:
            if len(dataset) < 500:
                logger.warning(
                    "Dataset has fewer than 500 data points. "
                    "Lower accuracy is expected."
                )
                sample_data_list = [dataset[i] for i in range(len(dataset))]
            else:
                sample_data_list = [
                    dataset[i] for i in sample(range(len(dataset)), 500)
                ]
            _, sample_target, _ = collate_pool(sample_data_list)
            normalizer = Normalizer(sample_target)

        # build model
        structures, _, _ = dataset[0]
        orig_atom_fea_len = structures[0].shape[-1]
        nbr_fea_len = structures[1].shape[-1]  # type: ignore
        model = CrystalGraphConvNet(
            orig_atom_fea_len,
            nbr_fea_len,
            atom_fea_len=model_args["atom_fea_len"],
            n_conv=model_args["n_conv"],
            h_fea_len=model_args["h_fea_len"],
            n_h=model_args["n_h"],
            classification=training_args["task"] == "classification",
        )
        if not training_args["disable_cuda"]:
            model.cuda()

        # define loss function and optimizer
        if training_args["task"] == "classification":
            criterion = nn.NLLLoss()
        else:
            criterion = nn.MSELoss()  # type: ignore
        if training_args["optim"] == "SGD":
            optimizer = optim.SGD(
                model.parameters(),
                training_args["lr"],
                momentum=training_args["momentum"],
                weight_decay=training_args["weight_decay"],
            )
        elif training_args["optim"] == "Adam":
            optimizer = optim.Adam(  # type: ignore
                model.parameters(),
                training_args["lr"],
                weight_decay=training_args["weight_decay"],
            )
        else:
            raise NameError("Only SGD or Adam is allowed as optimizer")

        # optionally resume from a checkpoint
        if training_args["resume"]:
            if os.path.isfile(training_args["resume"]):
                logger.info("loading checkpoint '{}'".format(training_args["resume"]))
                checkpoint = torch.load(training_args["resume"])
                training_args["start_epoch"] = checkpoint["epoch"]
                best_mae_error = checkpoint["best_mae_error"]
                model.load_state_dict(checkpoint["state_dict"])
                optimizer.load_state_dict(checkpoint["optimizer"])
                normalizer.load_state_dict(checkpoint["normalizer"])
                logger.info(
                    "loaded checkpoint '{}' (epoch {})".format(
                        training_args["resume"], checkpoint["epoch"]
                    )
                )
            else:
                logger.info(
                    "no checkpoint found at '{}'".format(training_args["resume"])
                )

        scheduler = MultiStepLR(
            optimizer, milestones=[training_args["lr_milestone"]], gamma=0.1
        )

        for epoch in range(training_args["start_epoch"], training_args["epochs"]):
            # train for one epoch
            train(
                train_loader,
                model,
                criterion,
                optimizer,
                epoch,
                normalizer,
                training_args["disable_cuda"],
                training_args["task"],
                training_args["print_freq"],
            )

            # evaluate on validation set
            mae_error = validate(
                val_loader,
                model,
                criterion,
                normalizer,
                training_args["disable_cuda"],
                training_args["task"],
                training_args["print_freq"],
                test=True,
            )

            # NaN is the only value that does not compare equal to itself
            if mae_error != mae_error:
                raise ValueError("mae_error is NaN")

            scheduler.step()

            # remember the best mae_error and save checkpoint
            if training_args["task"] == "regression":
                is_best = mae_error < best_mae_error
                best_mae_error = min(mae_error, best_mae_error)
            else:
                is_best = mae_error > best_mae_error
                best_mae_error = max(mae_error, best_mae_error)

            save_checkpoint(
                {
                    "epoch": epoch + 1,
                    "state_dict": model.state_dict(),
                    "best_mae_error": best_mae_error,
                    "optimizer": optimizer.state_dict(),
                    "normalizer": normalizer.state_dict(),
                    "training_args": training_args,
                    "model_args": model_args,
                    "dataset_args": dataset_args,
                },
                is_best,
                training_args["output_path"],
            )

        # test best model
        logger.info("Evaluate Model on Test Set")
        best_path = os.path.join(training_args["output_path"], "model_best.pth.tar")
        if os.path.exists(best_path):
            best_checkpoint = torch.load(best_path)
            model.load_state_dict(best_checkpoint["state_dict"])
        validate(
            test_loader,
            model,
            criterion,
            normalizer,
            training_args["disable_cuda"],
            training_args["task"],
            training_args["print_freq"],
            test=True,
        )

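# Hedged note (not part of the original module): training can be resumed from a saved
# checkpoint by pointing `resume` at it; the paths below are hypothetical.
#
#     from dataclasses import asdict
#     training_args = asdict(
#         CGCNNTrainingArguments(resume="ckpts/checkpoint.pth.tar", output_path="ckpts")
#     )
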
def train(
    train_loader: Union[DataLoader[Any], Any],
    model: CrystalGraphConvNet,
    criterion: Union[nn.NLLLoss, nn.MSELoss],
    optimizer: Union[optim.SGD, optim.Adam],
    epoch: int,
    normalizer: Normalizer,
    disable_cuda: bool,
    task: str,
    print_freq: int,
) -> None:
    """Train step for CGCNN models.

    Args:
        train_loader: dataloader for the training set.
        model: CGCNN model.
        criterion: loss function.
        optimizer: optimizer to be used.
        epoch: epoch number.
        normalizer: target value normalizer.
        disable_cuda: disable CUDA.
        task: training task.
        print_freq: print frequency.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    mae_errors = AverageMeter()
    accuracies = AverageMeter()
    precisions = AverageMeter()
    recalls = AverageMeter()
    fscores = AverageMeter()
    auc_scores = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()
    for i, (input, target, _) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        if not disable_cuda:
            input_var = (
                Variable(input[0].cuda(non_blocking=True)),
                Variable(input[1].cuda(non_blocking=True)),
                input[2].cuda(non_blocking=True),
                [crys_idx.cuda(non_blocking=True) for crys_idx in input[3]],
            )
        else:
            input_var = (Variable(input[0]), Variable(input[1]), input[2], input[3])
        # normalize target
        if task == "regression":
            target_normed = normalizer.norm(target)
        else:
            target_normed = target.view(-1).long()
        if not disable_cuda:
            target_var = Variable(target_normed.cuda(non_blocking=True))
        else:
            target_var = Variable(target_normed)

        # compute output
        output = model(*input_var)
        loss = criterion(output, target_var)

        # measure accuracy and record loss
        if task == "regression":
            mae_error = mae(normalizer.denorm(output.data.cpu()), target)
            losses.update(loss.data.cpu(), target.size(0))
            mae_errors.update(mae_error, target.size(0))  # type: ignore
        else:
            accuracy, precision, recall, fscore, auc_score = class_eval(
                output.data.cpu(), target
            )
            losses.update(loss.data.cpu().item(), target.size(0))
            accuracies.update(accuracy, target.size(0))
            precisions.update(precision, target.size(0))
            recalls.update(recall, target.size(0))
            fscores.update(fscore, target.size(0))
            auc_scores.update(auc_score, target.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % print_freq == 0:
            if task == "regression":
                logger.info(
                    "Epoch: [{0}][{1}/{2}]\t"
                    "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
                    "Data {data_time.val:.3f} ({data_time.avg:.3f})\t"
                    "Loss {loss.val:.4f} ({loss.avg:.4f})\t"
                    "MAE {mae_errors.val:.3f} ({mae_errors.avg:.3f})".format(
                        epoch,
                        i,
                        len(train_loader),
                        batch_time=batch_time,
                        data_time=data_time,
                        loss=losses,
                        mae_errors=mae_errors,
                    )
                )
            else:
                logger.info(
                    "Epoch: [{0}][{1}/{2}]\t"
                    "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
                    "Data {data_time.val:.3f} ({data_time.avg:.3f})\t"
                    "Loss {loss.val:.4f} ({loss.avg:.4f})\t"
                    "Accu {accu.val:.3f} ({accu.avg:.3f})\t"
                    "Precision {prec.val:.3f} ({prec.avg:.3f})\t"
                    "Recall {recall.val:.3f} ({recall.avg:.3f})\t"
                    "F1 {f1.val:.3f} ({f1.avg:.3f})\t"
                    "AUC {auc.val:.3f} ({auc.avg:.3f})".format(
                        epoch,
                        i,
                        len(train_loader),
                        batch_time=batch_time,
                        data_time=data_time,
                        loss=losses,
                        accu=accuracies,
                        prec=precisions,
                        recall=recalls,
                        f1=fscores,
                        auc=auc_scores,
                    )
                )

def validate(
    val_loader: Union[DataLoader[Any], Any],
    model: CrystalGraphConvNet,
    criterion: Union[nn.MSELoss, nn.NLLLoss],
    normalizer: Normalizer,
    disable_cuda: bool,
    task: str,
    print_freq: int,
    test: bool = False,
) -> float:
    """Validation step for CGCNN models.

    Args:
        val_loader: dataloader for the validation set.
        model: CGCNN model.
        criterion: loss function.
        normalizer: target value normalizer.
        disable_cuda: disable CUDA.
        task: training task.
        print_freq: print frequency.
        test: whether to additionally collect predictions on the given dataset.

    Returns:
        average AUC (classification) or MAE (regression), depending on the task.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    mae_errors = AverageMeter()
    accuracies = AverageMeter()
    precisions = AverageMeter()
    recalls = AverageMeter()
    fscores = AverageMeter()
    auc_scores = AverageMeter()

    test_targets = []
    test_preds = []
    test_cif_ids = []

    # switch to evaluate mode
    model.eval()

    end = time.time()
    for i, (input, target, batch_cif_ids) in enumerate(val_loader):
        if not disable_cuda:
            with torch.no_grad():
                input_var = (
                    Variable(input[0].cuda(non_blocking=True)),
                    Variable(input[1].cuda(non_blocking=True)),
                    input[2].cuda(non_blocking=True),
                    [crys_idx.cuda(non_blocking=True) for crys_idx in input[3]],
                )
        else:
            with torch.no_grad():
                input_var = (
                    Variable(input[0]),
                    Variable(input[1]),
                    input[2],
                    input[3],
                )
        if task == "regression":
            target_normed = normalizer.norm(target)
        else:
            target_normed = target.view(-1).long()
        if not disable_cuda:
            with torch.no_grad():
                target_var = Variable(target_normed.cuda(non_blocking=True))
        else:
            with torch.no_grad():
                target_var = Variable(target_normed)

        # compute output
        output = model(*input_var)
        loss = criterion(output, target_var)

        # measure accuracy and record loss
        if task == "regression":
            mae_error = mae(normalizer.denorm(output.data.cpu()), target)
            losses.update(loss.data.cpu().item(), target.size(0))
            mae_errors.update(mae_error, target.size(0))  # type: ignore
            if test:
                test_pred = normalizer.denorm(output.data.cpu())
                test_target = target
                test_preds += test_pred.view(-1).tolist()
                test_targets += test_target.view(-1).tolist()
                test_cif_ids += batch_cif_ids
        else:
            accuracy, precision, recall, fscore, auc_score = class_eval(
                output.data.cpu(), target
            )
            losses.update(loss.data.cpu().item(), target.size(0))
            accuracies.update(accuracy, target.size(0))
            precisions.update(precision, target.size(0))
            recalls.update(recall, target.size(0))
            fscores.update(fscore, target.size(0))
            auc_scores.update(auc_score, target.size(0))
            if test:
                test_pred = torch.exp(output.data.cpu())
                test_target = target
                assert test_pred.shape[1] == 2
                test_preds += test_pred[:, 1].tolist()
                test_targets += test_target.view(-1).tolist()
                test_cif_ids += batch_cif_ids

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % print_freq == 0:
            if task == "regression":
                logger.info(
                    "Test: [{0}/{1}]\t"
                    "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
                    "Loss {loss.val:.4f} ({loss.avg:.4f})\t"
                    "MAE {mae_errors.val:.3f} ({mae_errors.avg:.3f})".format(
                        i,
                        len(val_loader),
                        batch_time=batch_time,
                        loss=losses,
                        mae_errors=mae_errors,
                    )
                )
            else:
                logger.info(
                    "Test: [{0}/{1}]\t"
                    "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
                    "Loss {loss.val:.4f} ({loss.avg:.4f})\t"
                    "Accu {accu.val:.3f} ({accu.avg:.3f})\t"
                    "Precision {prec.val:.3f} ({prec.avg:.3f})\t"
                    "Recall {recall.val:.3f} ({recall.avg:.3f})\t"
                    "F1 {f1.val:.3f} ({f1.avg:.3f})\t"
                    "AUC {auc.val:.3f} ({auc.avg:.3f})".format(
                        i,
                        len(val_loader),
                        batch_time=batch_time,
                        loss=losses,
                        accu=accuracies,
                        prec=precisions,
                        recall=recalls,
                        f1=fscores,
                        auc=auc_scores,
                    )
                )

    if task == "regression":
        logger.info("MAE {mae_errors.avg:.3f}".format(mae_errors=mae_errors))
        return mae_errors.avg
    else:
        logger.info("AUC {auc.avg:.3f}".format(auc=auc_scores))
        return auc_scores.avg

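# Hedged sketch (assumed names from the surrounding context): `validate` can also
# score a trained model standalone; for task="regression" the returned scalar is the
# average MAE on the loader.
#
#     test_mae = validate(
#         test_loader,  # a DataLoader built via get_train_val_test_loader
#         model,
#         nn.MSELoss(),
#         normalizer,
#         disable_cuda=True,
#         task="regression",
#         print_freq=10,
#         test=True,
#     )
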
def mae(prediction: Tensor, target: Tensor) -> Tensor:
    """Computes the mean absolute error between prediction and target.

    Args:
        prediction: torch.Tensor (N, 1).
        target: torch.Tensor (N, 1).

    Returns:
        the computed mean absolute error.
    """
    return torch.mean(torch.abs(target - prediction))

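# Hedged example: `mae` is a plain elementwise mean absolute error, e.g.
#
#     >>> mae(torch.tensor([[1.0], [2.0]]), torch.tensor([[1.5], [2.5]]))
#     tensor(0.5000)
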
def class_eval(
    prediction: Tensor, target: Tensor
) -> Tuple[float, float, float, float, float]:
    """Class evaluation.

    Args:
        prediction: predictions.
        target: ground truth.

    Returns:
        computed accuracy, precision, recall, fscore and auc_score.
    """
    prediction = np.exp(prediction.numpy())
    target = target.numpy()
    pred_label = np.argmax(prediction, axis=1)
    target_label = np.squeeze(target)
    if not target_label.shape:
        target_label = np.asarray([target_label])
    if prediction.shape[1] == 2:
        precision, recall, fscore, _ = metrics.precision_recall_fscore_support(
            target_label, pred_label, average="binary"
        )
        try:
            auc_score = metrics.roc_auc_score(target_label, prediction[:, 1])
        except ValueError:
            auc_score = 0.0
        accuracy = metrics.accuracy_score(target_label, pred_label)
    else:
        raise NotImplementedError
    return accuracy, precision, recall, fscore, auc_score

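# Hedged example: `class_eval` expects log-probabilities over exactly two classes
# (it exponentiates before scoring), matching the NLLLoss setup used above. For
# perfect predictions all five returned metrics are 1.0:
#
#     log_probs = torch.log(torch.tensor([[0.9, 0.1], [0.2, 0.8]]))
#     class_eval(log_probs, torch.tensor([0, 1]))
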
class AverageMeter:
    """Computes and stores the average and current value."""

    def __init__(self):
        """Initialize an AverageMeter object."""
        self.reset()

    def reset(self) -> None:
        """Reset values to 0."""
        self.val = 0.0
        self.avg = 0.0
        self.sum = 0.0
        self.count = 0

    def update(self, val: float, n: int = 1) -> None:
        """Update values of the AverageMeter.

        Args:
            val: value to be added.
            n: count.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

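# Hedged usage sketch: AverageMeter keeps a running weighted mean, e.g. per-batch
# losses weighted by batch size:
#
#     meter = AverageMeter()
#     meter.update(0.5, n=32)  # batch of 32 with mean loss 0.5
#     meter.update(0.3, n=16)  # batch of 16 with mean loss 0.3
#     # meter.val == 0.3 (most recent value), meter.avg ~= 0.433 (weighted mean)
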
def save_checkpoint(
    state: object, is_best: bool, path: str = ".", filename: str = "checkpoint.pth.tar"
) -> None:
    """Save CGCNN checkpoint.

    Args:
        state: checkpoint's object.
        is_best: whether the given checkpoint has the best performance or not.
        path: path to save the checkpoint.
        filename: checkpoint's filename.
    """
    torch.save(state, os.path.join(path, filename))
    if is_best:
        shutil.copyfile(
            os.path.join(path, filename), os.path.join(path, "model_best.pth.tar")
        )

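# Hedged example (hypothetical paths): "model_best.pth.tar" is only written next to
# the regular checkpoint when `is_best` is True:
#
#     save_checkpoint(
#         {"epoch": 1, "state_dict": model.state_dict()},
#         is_best=True,
#         path="/tmp/cgcnn-run",
#     )
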
@dataclass
class CGCNNDataArguments(TrainingPipelineArguments):
    """Data arguments related to CGCNN trainer."""

    __name__ = "dataset_args"

    datapath: str = field(
        metadata={
            "help": "Path to the dataset. "
            "The dataset should follow the directory structure described in https://github.com/txie-93/cgcnn."
        },
    )
    train_size: Optional[int] = field(
        default=None, metadata={"help": "Number of training data points to be loaded."}
    )
    val_size: Optional[int] = field(
        default=None,
        metadata={"help": "Number of validation data points to be loaded."},
    )
    test_size: Optional[int] = field(
        default=None, metadata={"help": "Number of testing data points to be loaded."}
    )

@dataclass
class CGCNNModelArguments(TrainingPipelineArguments):
    """Model arguments related to CGCNN trainer."""

    __name__ = "model_args"

    atom_fea_len: int = field(
        default=64, metadata={"help": "Number of hidden atom features in conv layers."}
    )
    h_fea_len: int = field(
        default=128, metadata={"help": "Number of hidden features after pooling."}
    )
    n_conv: int = field(default=3, metadata={"help": "Number of conv layers."})
    n_h: int = field(
        default=1, metadata={"help": "Number of hidden layers after pooling."}
    )

@dataclass
class CGCNNTrainingArguments(TrainingPipelineArguments):
    """Training arguments related to CGCNN trainer."""

    __name__ = "training_args"

    task: str = field(
        default="regression",
        metadata={"help": "Type of the task: 'regression' or 'classification'."},
    )
    output_path: str = field(
        default=".",
        metadata={"help": "Path to store the checkpoints."},
    )
    disable_cuda: bool = field(default=False, metadata={"help": "Disable CUDA."})
    workers: int = field(
        default=0, metadata={"help": "Number of data loading workers."}
    )
    epochs: int = field(default=30, metadata={"help": "Number of total epochs to run."})
    start_epoch: int = field(
        default=0, metadata={"help": "Manual epoch number (useful on restarts)."}
    )
    batch_size: int = field(default=256, metadata={"help": "Mini-batch size."})
    lr: float = field(default=0.01, metadata={"help": "Initial learning rate."})
    lr_milestone: float = field(
        default=100, metadata={"help": "Milestone for the learning rate scheduler."}
    )
    momentum: float = field(default=0.9, metadata={"help": "Momentum."})
    weight_decay: float = field(default=0.0, metadata={"help": "Weight decay."})
    print_freq: int = field(default=10, metadata={"help": "Print frequency."})
    resume: str = field(default="", metadata={"help": "Path to latest checkpoint."})
    optim: str = field(default="SGD", metadata={"help": "Optimizer: 'SGD' or 'Adam'."})

@dataclass
class CGCNNSavingArguments(TrainingPipelineArguments):
    """Saving arguments related to CGCNN trainer."""

    __name__ = "saving_args"
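
# A hedged, self-contained entry-point sketch, not part of the original module: it
# builds the argument dicts from the dataclasses above and runs the pipeline.
# "data/sample-cifs" is a hypothetical dataset directory following the layout in
# https://github.com/txie-93/cgcnn.
if __name__ == "__main__":
    from dataclasses import asdict

    CGCNNTrainingPipeline().train(
        training_args=asdict(CGCNNTrainingArguments(task="regression", epochs=5)),
        model_args=asdict(CGCNNModelArguments()),
        dataset_args=asdict(CGCNNDataArguments(datapath="data/sample-cifs")),
    )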