Source code for gt4sd.properties.crystals.core

#
# MIT License
#
# Copyright (c) 2022 GT4SD team
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
import argparse
import os
from typing import Dict, List

import pandas as pd
import torch
from pydantic import Field
from torch.autograd import Variable
from torch.utils.data import DataLoader

from ...algorithms.core import (
    ConfigurablePropertyAlgorithmConfiguration,
    Predictor,
    PredictorAlgorithm,
)
from ...frameworks.cgcnn.data import AtomCustomJSONInitializer, CIFData, collate_pool
from ...frameworks.cgcnn.model import CrystalGraphConvNet, Normalizer
from ...frameworks.crystals_rfc.feature_engine import Features
from ...frameworks.crystals_rfc.rf_classifier import RFC
from ..core import DomainSubmodule, S3Parameters


class S3ParametersCrystals(S3Parameters):
    # Common base for every crystal property predictor in this module: pins
    # the S3 domain submodule to "crystals".
    domain: DomainSubmodule = DomainSubmodule("crystals")
class CGCNNParameters(S3ParametersCrystals):
    # Shared parameters of all CGCNN-based predictors; subclasses only set
    # `algorithm_application` to select the trained model.
    algorithm_name: str = "cgcnn"
    # Both values below are forwarded to the torch `DataLoader` used at
    # prediction time (see `_CGCNN.get_model`).
    batch_size: int = Field(default=256, description="Prediction batch size")
    workers: int = Field(default=0, description="Number of data loading workers")
class MetalNonMetalClassifierParameters(S3ParametersCrystals):
    # Selects the random-forest ("RFC") metal / non-metal classifier model.
    algorithm_name: str = "RFC"
    algorithm_application: str = "MetalNonMetalClassifier"
class FormationEnergyParameters(CGCNNParameters):
    # Selects the CGCNN formation-energy model.
    algorithm_application: str = "FormationEnergy"
class AbsoluteEnergyParameters(CGCNNParameters):
    # Selects the CGCNN absolute-energy model.
    algorithm_application: str = "AbsoluteEnergy"
class BandGapParameters(CGCNNParameters):
    # Selects the CGCNN band-gap model.
    algorithm_application: str = "BandGap"
class FermiEnergyParameters(CGCNNParameters):
    # Selects the CGCNN Fermi-energy model.
    algorithm_application: str = "FermiEnergy"
class BulkModuliParameters(CGCNNParameters):
    # Selects the CGCNN bulk-moduli model.
    algorithm_application: str = "BulkModuli"
class ShearModuliParameters(CGCNNParameters):
    # Selects the CGCNN shear-moduli model.
    algorithm_application: str = "ShearModuli"
class PoissonRatioParameters(CGCNNParameters):
    # Selects the CGCNN Poisson-ratio model.
    algorithm_application: str = "PoissonRatio"
class MetalSemiconductorClassifierParameters(CGCNNParameters):
    # Selects the CGCNN metal / semiconductor classification model.
    algorithm_application: str = "MetalSemiconductorClassifier"
class _CGCNN(PredictorAlgorithm):
    """Base class for all cgcnn-based predictive algorithms."""

    def __init__(self, parameters: CGCNNParameters):
        """Build the algorithm configuration and load the model.

        Args:
            parameters: CGCNN parameters. The ``algorithm_*`` and ``domain``
                fields identify the model; ``batch_size`` and ``workers``
                configure the prediction ``DataLoader``.
        """
        # Set up the configuration from the parameters
        configuration = ConfigurablePropertyAlgorithmConfiguration(
            algorithm_type=parameters.algorithm_type,
            domain=parameters.domain,
            algorithm_name=parameters.algorithm_name,
            algorithm_application=parameters.algorithm_application,
            algorithm_version=parameters.algorithm_version,
        )
        self.batch_size = parameters.batch_size
        self.workers = parameters.workers
        # The parent constructor calls `self.get_model`.
        super().__init__(configuration=configuration)

    def get_model(self, resources_path: str) -> Predictor:
        """Instantiate the actual model.

        Args:
            resources_path: local path to model files. It must contain exactly
                one ``*.pth.tar`` checkpoint and an ``atom_init.json`` file.

        Raises:
            ValueError: if the path contains no ``*.pth.tar`` checkpoint or
                more than one.

        Returns:
            Predictor: the model. It is a callable that takes a CIF data path
            and returns a dict with the ``cif_ids`` and ``predictions`` lists.
        """
        existing_models = [
            file for file in os.listdir(resources_path) if file.endswith(".pth.tar")
        ]
        if len(existing_models) > 1:
            raise ValueError(
                "Only one model should be located in the specified model path."
            )
        if not existing_models:
            raise ValueError("Model does not exist in the specified model path.")
        model_path = os.path.join(resources_path, existing_models[0])

        # NOTE: torch.load unpickles the checkpoint; only load trusted files.
        # `map_location` keeps every tensor on the storage it was loaded to
        # (i.e. CPU), so GPU-trained checkpoints load on CPU-only machines.
        checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage)
        model_args = argparse.Namespace(**checkpoint["args"])
        # Loop-invariant: decides the post-processing of the network output.
        is_classification = model_args.task == "classification"

        # The placeholder tensor is overwritten by the stored normalizer state.
        normalizer = Normalizer(torch.zeros(3))
        normalizer.load_state_dict(checkpoint["normalizer"])
        atom_initialization = AtomCustomJSONInitializer(
            os.path.join(resources_path, "atom_init.json")
        )

        # Wrapper to get the predictions
        def informative_model(cif_path: str) -> Dict[str, list]:
            dataset = CIFData(cif_path, atom_initialization=atom_initialization)
            # Nothing to predict. Indexing `dataset[0]` below would fail, so
            # return empty results instead.
            if len(dataset) == 0:
                return {"cif_ids": [], "predictions": []}

            test_loader = DataLoader(
                dataset,
                batch_size=self.batch_size,
                num_workers=self.workers,
                collate_fn=collate_pool,
            )

            # The network's input sizes depend on the feature dimensions, so
            # they are read from the first sample before building the model.
            structures, _, _ = dataset[0]
            orig_atom_fea_len = structures[0].shape[-1]
            nbr_fea_len = structures[1].shape[-1]  # type: ignore
            model = CrystalGraphConvNet(
                orig_atom_fea_len,
                nbr_fea_len,
                atom_fea_len=model_args.atom_fea_len,
                n_conv=model_args.n_conv,
                h_fea_len=model_args.h_fea_len,
                n_h=model_args.n_h,
                classification=is_classification,
            )
            model.load_state_dict(checkpoint["state_dict"])
            model.eval()

            predictions: List[float] = []
            cif_ids: list = []
            with torch.no_grad():
                for inputs, _, batch_cif_ids in test_loader:
                    output = model(inputs[0], inputs[1], inputs[2], inputs[3]).cpu()
                    if is_classification:
                        # `torch.exp` turns the network output (presumably
                        # log-probabilities) into probabilities; column 1 is
                        # the reported positive-class probability.
                        predictions.extend(torch.exp(output)[:, 1].tolist())
                    else:
                        # Undo the target normalization applied in training.
                        predictions.extend(normalizer.denorm(output).view(-1).tolist())
                    cif_ids.extend(batch_cif_ids)
            return {"cif_ids": cif_ids, "predictions": predictions}

        return informative_model
class MetalNonMetalClassifier(PredictorAlgorithm):
    """Metal/non-metal classifier class."""

    def __init__(self, parameters: MetalNonMetalClassifierParameters):
        """Build the algorithm configuration and load the classifier.

        Args:
            parameters: the S3 coordinates of the random-forest model.
        """
        configuration_fields = (
            "algorithm_type",
            "domain",
            "algorithm_name",
            "algorithm_application",
            "algorithm_version",
        )
        configuration = ConfigurablePropertyAlgorithmConfiguration(
            **{field: getattr(parameters, field) for field in configuration_fields}
        )
        # `get_model` is invoked from the parent constructor.
        super().__init__(configuration=configuration)

    def get_model(self, resources_path: str) -> Predictor:
        """Load the random-forest classifier and wrap it as a predictor.

        Args:
            resources_path: local directory passed to ``RFC.load_model``.

        Returns:
            Predictor: a callable that takes the path of a formula file and
            returns a dict with ``formulas`` (the file's first column) and
            the classifier's ``predictions``.
        """
        classifier = RFC()
        classifier.load_model(resources_path)

        def informative_model(formula_file: str) -> Dict[str, List[str]]:
            computed = Features(formula_file=formula_file).get_features()
            # Every column except the first is fed to the classifier
            # (the first column is presumably an identifier; confirm).
            feature_matrix = pd.DataFrame(computed).iloc[:, 1:].values
            # The first column of the header-less CSV holds the formulas.
            formula_list = (
                pd.read_csv(formula_file, header=None).iloc[:, 0].to_list()
            )
            labels = classifier.predict(pred_x=feature_matrix)
            return {"formulas": formula_list, "predictions": labels}

        return informative_model
class FormationEnergy(_CGCNN):
    """CGCNN predictor of the formation energy per atom."""

    @classmethod
    def get_description(cls) -> str:
        """Return the human-readable description of this model."""
        return (
            " This model predicts the formation energy per atom using the CGCNN"
            " framework. For more details see:"
            " https://doi.org/10.1103/PhysRevLett.120.145301. "
        )
class AbsoluteEnergy(_CGCNN):
    """CGCNN predictor of the absolute energy of crystals."""

    @classmethod
    def get_description(cls) -> str:
        """Return the human-readable description of this model."""
        return (
            " This model predicts the absolute energy of crystals using the CGCNN"
            " framework. For more details see:"
            " https://doi.org/10.1103/PhysRevLett.120.145301. "
        )
class BandGap(_CGCNN):
    """CGCNN predictor of the band gap of crystals."""

    @classmethod
    def get_description(cls) -> str:
        """Return the human-readable description of this model."""
        return (
            " This model predicts the band gap of crystals using the CGCNN"
            " framework. For more details see:"
            " https://doi.org/10.1103/PhysRevLett.120.145301. "
        )
class FermiEnergy(_CGCNN):
    """CGCNN predictor of the Fermi energy of crystals."""

    @classmethod
    def get_description(cls) -> str:
        """Return the human-readable description of this model."""
        return (
            " This model predicts the Fermi energy of crystals using the CGCNN"
            " framework. For more details see:"
            " https://doi.org/10.1103/PhysRevLett.120.145301. "
        )
class BulkModuli(_CGCNN):
    """CGCNN predictor of the bulk moduli of crystals."""

    @classmethod
    def get_description(cls) -> str:
        """Return the human-readable description of this model."""
        return (
            " This model predicts the bulk moduli of crystals using the CGCNN"
            " framework. For more details see:"
            " https://doi.org/10.1103/PhysRevLett.120.145301. "
        )
class ShearModuli(_CGCNN):
    """CGCNN predictor of the shear moduli of crystals."""

    @classmethod
    def get_description(cls) -> str:
        """Return the human-readable description of this model."""
        return (
            " This model predicts the shear moduli of crystals using the CGCNN"
            " framework. For more details see:"
            " https://doi.org/10.1103/PhysRevLett.120.145301. "
        )
class PoissonRatio(_CGCNN):
    """CGCNN predictor of the Poisson ratio of crystals."""

    @classmethod
    def get_description(cls) -> str:
        """Return the human-readable description of this model."""
        return (
            " This model predicts the Poisson ratio of crystals using the CGCNN"
            " framework. For more details see:"
            " https://doi.org/10.1103/PhysRevLett.120.145301. "
        )
class MetalSemiconductorClassifier(_CGCNN):
    """CGCNN classifier separating metals from semiconductors."""

    @classmethod
    def get_description(cls) -> str:
        """Return the human-readable description of this model."""
        return (
            " This model predicts whether a given crystal is metal or semiconductor"
            " using the CGCNN framework. For more details see:"
            " https://doi.org/10.1103/PhysRevLett.120.145301. "
        )