gt4sd.training_pipelines.cgcnn.core module

CGCNN training utilities.

Summary

Classes:

AverageMeter

Computes and stores the average and current value.

CGCNNDataArguments

Data arguments related to CGCNN trainer.

CGCNNModelArguments

Model arguments related to CGCNN trainer.

CGCNNSavingArguments

Saving arguments related to CGCNN trainer.

CGCNNTrainingArguments

Training arguments related to CGCNN trainer.

CGCNNTrainingPipeline

CGCNN training pipelines for crystals.

Functions:

class_eval

Evaluates classification predictions.

mae

Computes the mean absolute error between prediction and target.

save_checkpoint

Save CGCNN checkpoint.

train

Train step for CGCNN models.

validate

Validation step for CGCNN models.

Reference

class CGCNNTrainingPipeline[source]

Bases: TrainingPipeline

CGCNN training pipelines for crystals.

train(training_args, model_args, dataset_args)[source]

Generic training function for CGCNN models.

Parameters
  • training_args (Dict[str, Any]) – training arguments passed to the configuration.

  • model_args (Dict[str, Any]) – model arguments passed to the configuration.

  • dataset_args (Dict[str, Any]) – dataset arguments passed to the configuration.

Raises

NotImplementedError – the generic trainer does not implement the pipeline.

Return type

None
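
The three parameters are plain dictionaries. As a minimal sketch, they could be assembled from the field names and defaults listed for the argument classes in this module. The `datapath` value is a placeholder, and the call itself is shown only as a comment because it needs gt4sd installed and a dataset on disk.

```python
from typing import Any, Dict

# Keys and defaults mirror the CGCNNTrainingArguments signature in this module.
training_args: Dict[str, Any] = {
    "task": "regression",
    "output_path": ".",
    "disable_cuda": False,
    "workers": 0,
    "epochs": 30,
    "start_epoch": 0,
    "batch_size": 256,
    "lr": 0.01,
    "lr_milestone": 100,
    "momentum": 0.9,
    "weight_decay": 0.0,
    "print_freq": 10,
    "resume": "",
    "optim": "SGD",
}

# Keys and defaults mirror the CGCNNModelArguments signature.
model_args: Dict[str, Any] = {
    "atom_fea_len": 64,
    "h_fea_len": 128,
    "n_conv": 3,
    "n_h": 1,
}

# Keys mirror CGCNNDataArguments; "path/to/dataset" is a placeholder.
dataset_args: Dict[str, Any] = {
    "datapath": "path/to/dataset",
    "train_size": None,
    "val_size": None,
    "test_size": None,
}

# Assumed usage, not executed here:
# CGCNNTrainingPipeline().train(training_args, model_args, dataset_args)
```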

train(train_loader, model, criterion, optimizer, epoch, normalizer, disable_cuda, task, print_freq)[source]

Train step for CGCNN models.

Parameters
  • train_loader (Union[DataLoader[Any], Any]) – Dataloader for the training set.

  • model (CrystalGraphConvNet) – CGCNN model.

  • criterion (Union[NLLLoss, MSELoss]) – loss function.

  • optimizer (Union[SGD, Adam]) – Optimizer to be used.

  • epoch (int) – Epoch number.

  • normalizer (Normalizer) – normalizer for the target values.

  • disable_cuda (bool) – Disable CUDA.

  • task (str) – training task, for example 'regression'.

  • print_freq (int) – Print frequency.

Return type

None
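
The `normalizer` parameter is described only briefly. As a rough illustration, and assuming it follows the common pattern of standardizing regression targets with a mean and a standard deviation, a plain-Python stand-in might look like the sketch below. The class name `SimpleNormalizer` and its methods are hypothetical, it works on lists rather than tensors, and the choice of the population standard deviation is an assumption.

```python
from statistics import mean, pstdev
from typing import List


class SimpleNormalizer:
    """Hypothetical stand-in: standardizes values with a fixed mean and std."""

    def __init__(self, samples: List[float]) -> None:
        # Statistics are estimated once from a sample of training targets.
        self.mean = mean(samples)
        self.std = pstdev(samples)

    def norm(self, values: List[float]) -> List[float]:
        # Map raw targets to zero mean and unit variance.
        return [(v - self.mean) / self.std for v in values]

    def denorm(self, values: List[float]) -> List[float]:
        # Map normalized predictions back to the original scale.
        return [v * self.std + self.mean for v in values]


normalizer = SimpleNormalizer([1.0, 2.0, 3.0, 4.0])
normed = normalizer.norm([2.5, 4.0])
restored = normalizer.denorm(normed)
```

In a setup like this, the loss would be computed on normalized targets while error metrics would be reported on denormalized predictions.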

validate(val_loader, model, criterion, normalizer, disable_cuda, task, print_freq, test=False)[source]

Validation step for CGCNN models.

Parameters
  • val_loader (Union[DataLoader[Any], Any]) – Dataloader for the validation set.

  • model (CrystalGraphConvNet) – CGCNN model.

  • criterion (Union[MSELoss, NLLLoss]) – loss function.

  • normalizer (Normalizer) – normalizer for the target values.

  • disable_cuda (bool) – Disable CUDA.

  • task (str) – training task, for example 'regression'.

  • print_freq (int) – Print frequency.

  • test (bool) – whether the given dataset is used for testing rather than validation.

Return type

float

Returns

Average AUC or MAE, depending on the training task.
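
Because the returned metric is an AUC for one task and an MAE for the other, a caller has to compare it in opposite directions. The sketch below shows one plausible way to track the best epoch, assuming MAE is returned for 'regression' and AUC otherwise. The helper names `is_improvement` and `track_best` are hypothetical and not part of this module.

```python
from typing import List, Tuple


def is_improvement(task: str, metric: float, best: float) -> bool:
    """Lower MAE is better for regression; higher AUC is better otherwise."""
    if task == "regression":
        return metric < best
    return metric > best


def track_best(task: str, metrics: List[float]) -> Tuple[float, List[bool]]:
    # Start from the worst possible value for the chosen task.
    best = float("inf") if task == "regression" else 0.0
    flags = []
    for metric in metrics:
        improved = is_improvement(task, metric, best)
        if improved:
            best = metric
        flags.append(improved)
    return best, flags


best_mae, mae_flags = track_best("regression", [0.9, 0.7, 0.8])
best_auc, auc_flags = track_best("classification", [0.6, 0.8, 0.7])
```

A flag computed this way is the kind of value that could be passed as `is_best` to `save_checkpoint`.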

mae(prediction, target)[source]

Computes the mean absolute error between prediction and target.

Parameters
  • prediction (Tensor) – predicted values, a torch.Tensor of shape (N, 1).

  • target (Tensor) – target values, a torch.Tensor of shape (N, 1).

Return type

Tensor

Returns

the computed mean absolute error.
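
For reference, the quantity itself is the average of the absolute differences between paired entries. The sketch below computes it on plain Python sequences rather than tensors. The function name `mean_absolute_error` is chosen for illustration and is not the module's `mae`.

```python
from typing import Sequence


def mean_absolute_error(prediction: Sequence[float], target: Sequence[float]) -> float:
    """Average of |prediction_i - target_i| over all N entries."""
    if len(prediction) != len(target):
        raise ValueError("prediction and target must have the same length")
    return sum(abs(p - t) for p, t in zip(prediction, target)) / len(prediction)


# Absolute differences are 0.5, 0.0 and 1.0.
error = mean_absolute_error([1.0, 2.0, 4.0], [1.5, 2.0, 3.0])
```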

class_eval(prediction, target)[source]

Evaluates classification predictions.

Parameters
  • prediction (Tensor) – model predictions.

  • target (Tensor) – ground-truth labels.

Return type

Tuple[float, float, float, float, float]

Returns

Computed accuracy, precision, recall, fscore, and auc_score.
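
To make the five returned metrics concrete, here is a plain-Python sketch for the binary case. It assumes the predictions are positive-class scores thresholded at 0.5, which may differ from how the module derives labels from its tensors. The function name `binary_class_eval` is hypothetical.

```python
from typing import List, Tuple


def binary_class_eval(
    scores: List[float], labels: List[int]
) -> Tuple[float, float, float, float, float]:
    """Returns accuracy, precision, recall, F-score and AUC for binary labels."""
    predicted = [1 if s >= 0.5 else 0 for s in scores]
    tp = sum(1 for p, y in zip(predicted, labels) if p == 1 and y == 1)
    fp = sum(1 for p, y in zip(predicted, labels) if p == 1 and y == 0)
    fn = sum(1 for p, y in zip(predicted, labels) if p == 0 and y == 1)
    correct = sum(1 for p, y in zip(predicted, labels) if p == y)

    accuracy = correct / len(labels)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    denom = precision + recall
    fscore = 2 * precision * recall / denom if denom else 0.0

    # AUC as the fraction of (positive, negative) pairs ranked correctly,
    # counting ties as half a correct pair.
    positives = [s for s, y in zip(scores, labels) if y == 1]
    negatives = [s for s, y in zip(scores, labels) if y == 0]
    pairs = len(positives) * len(negatives)
    wins = sum(
        1.0 if p > n else (0.5 if p == n else 0.0)
        for p in positives
        for n in negatives
    )
    auc = wins / pairs if pairs else 0.0
    return accuracy, precision, recall, fscore, auc


metrics = binary_class_eval([0.9, 0.4, 0.6, 0.2], [1, 1, 0, 0])
```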

class AverageMeter[source]

Bases: object

Computes and stores the average and current value.

__init__()[source]

Initialize an AverageMeter object.

reset()[source]

Reset values to 0.

Return type

None

update(val, n=1)[source]

Update values of the AverageMeter.

Parameters
  • val (float) – value to be added.

  • n (int) – number of samples that val represents.

Return type

None
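
The following is a minimal sketch of how such a meter can work, assuming `val` is the mean over `n` samples (for example a batch loss) and the average is weighted by `n`. The class name `RunningAverage` and its attribute names are hypothetical and not taken from the module.

```python
class RunningAverage:
    """Hypothetical stand-in that tracks the latest value and a weighted mean."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        # Clear the latest value, the weighted sum, the count and the average.
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val: float, n: int = 1) -> None:
        # `val` is treated as the mean of `n` samples.
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


meter = RunningAverage()
meter.update(2.0, n=4)  # batch of 4 with mean loss 2.0
meter.update(1.0, n=2)  # batch of 2 with mean loss 1.0
```

Weighting by `n` keeps the average correct when the last batch of an epoch is smaller than the others.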


save_checkpoint(state, is_best, path='.', filename='checkpoint.pth.tar')[source]

Save CGCNN checkpoint.

Parameters
  • state (object) – checkpoint object to save.

  • is_best (bool) – whether the given checkpoint has the best performance or not.

  • path (str) – directory in which to save the checkpoint.

  • filename (str) – checkpoint’s filename.

Return type

None
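
A common pattern behind a signature like this is to always write the latest checkpoint and to keep a separate copy when `is_best` is true. The sketch below illustrates that pattern under stated assumptions: `pickle` stands in for a torch-based serializer, the best-checkpoint name "model_best.pth.tar" is assumed, and `save_checkpoint_sketch` is a hypothetical name.

```python
import os
import pickle
import shutil
import tempfile


def save_checkpoint_sketch(state, is_best, path=".", filename="checkpoint.pth.tar"):
    """Write `state` to path/filename and keep a copy of the best one."""
    target = os.path.join(path, filename)
    with open(target, "wb") as handle:
        pickle.dump(state, handle)  # stand-in for a torch-based serializer
    if is_best:
        # "model_best.pth.tar" is an assumed name for the best checkpoint.
        shutil.copyfile(target, os.path.join(path, "model_best.pth.tar"))


with tempfile.TemporaryDirectory() as tmp:
    save_checkpoint_sketch({"epoch": 1, "best_mae": 0.7}, is_best=True, path=tmp)
    saved_files = sorted(os.listdir(tmp))
```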

class CGCNNDataArguments(datapath, train_size=None, val_size=None, test_size=None)[source]

Bases: TrainingPipelineArguments

Data arguments related to CGCNN trainer.

datapath: str

Path to the dataset. The dataset should follow the directory structure described in https://github.com/txie-93/cgcnn

train_size: Optional[int] = None

Number of training data to be loaded.

val_size: Optional[int] = None

Number of validation data to be loaded.

test_size: Optional[int] = None

Number of testing data to be loaded.

class CGCNNModelArguments(atom_fea_len=64, h_fea_len=128, n_conv=3, n_h=1)[source]

Bases: TrainingPipelineArguments

Model arguments related to CGCNN trainer.

atom_fea_len: int = 64

Number of hidden atom features in conv layers.

h_fea_len: int = 128

Number of hidden features after pooling.

n_conv: int = 3

Number of conv layers.

n_h: int = 1

Number of hidden layers after pooling.

class CGCNNTrainingArguments(task='regression', output_path='.', disable_cuda=False, workers=0, epochs=30, start_epoch=0, batch_size=256, lr=0.01, lr_milestone=100, momentum=0.9, weight_decay=0.0, print_freq=10, resume='', optim='SGD')[source]

Bases: TrainingPipelineArguments

Training arguments related to CGCNN trainer.

task: str = 'regression'

Select the type of the task.

output_path: str = '.'

Path to store the checkpoints.

disable_cuda: bool = False

Disable CUDA.

workers: int = 0

Number of data loading workers.

epochs: int = 30

Number of total epochs to run.

start_epoch: int = 0

Manual epoch number (useful on restarts).

batch_size: int = 256

Mini-batch size.

lr: float = 0.01

Initial learning rate.

lr_milestone: float = 100

Milestone for scheduler.

momentum: float = 0.9

Momentum.

weight_decay: float = 0.0

Weight decay.

print_freq: int = 10

Print frequency.

resume: str = ''

Path to latest checkpoint.

optim: str = 'SGD'

Optimizer.
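
The `lr` and `lr_milestone` fields suggest a step-style learning-rate schedule. As a sketch only, and assuming a single milestone after which the rate is multiplied by a decay factor, the effective rate per epoch could be computed as below. The function `lr_at_epoch` and the factor `gamma=0.1` are assumptions and are not confirmed by this page.

```python
def lr_at_epoch(initial_lr: float, milestone: int, epoch: int, gamma: float = 0.1) -> float:
    """Learning rate under an assumed single-milestone step schedule."""
    # Before the milestone the rate is unchanged; from the milestone on it is
    # scaled by gamma.
    return initial_lr * gamma if epoch >= milestone else initial_lr


schedule = [lr_at_epoch(0.01, milestone=100, epoch=e) for e in (0, 99, 100, 150)]
```

Under this assumption, the documented defaults (`epochs=30`, `lr_milestone=100`) would never reach the milestone, so the rate would stay at its initial value for the whole run.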

class CGCNNSavingArguments[source]

Bases: TrainingPipelineArguments

Saving arguments related to CGCNN trainer.

This dataclass defines no fields.