gt4sd.training_pipelines.cgcnn.core module
CGCNN training utilities.
Summary
Classes:
AverageMeter | Computes and stores the average and current value.
CGCNNDataArguments | Data arguments related to CGCNN trainer.
CGCNNModelArguments | Model arguments related to CGCNN trainer.
CGCNNSavingArguments | Saving arguments related to CGCNN trainer.
CGCNNTrainingArguments | Training arguments related to CGCNN trainer.
CGCNNTrainingPipeline | CGCNN training pipelines for crystals.
Functions:
class_eval | Class evaluation.
mae | Computes the mean absolute error between prediction and target.
save_checkpoint | Save CGCNN checkpoint.
train | Train step for CGCNN models.
validate | Validation step for CGCNN models.
Reference
- class CGCNNTrainingPipeline
Bases: TrainingPipeline
CGCNN training pipelines for crystals.
- train(training_args, model_args, dataset_args)
Generic training function for CGCNN models.
- Parameters
training_args (Dict[str, Any]) – training arguments passed to the configuration.
model_args (Dict[str, Any]) – model arguments passed to the configuration.
dataset_args (Dict[str, Any]) – dataset arguments passed to the configuration.
- Raises
NotImplementedError – the generic trainer does not implement the pipeline.
- Return type
None
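The three dictionaries can be assembled from the field names documented for the argument classes in this module. The following is a minimal sketch under assumptions: `build_cgcnn_args` and `RecordingPipeline` are hypothetical names introduced for illustration, the stand-in class only mirrors the documented `train` signature so the block runs without gt4sd installed, and it is not confirmed here whether the real pipeline requires every field in each dictionary.

```python
from typing import Any, Dict


def build_cgcnn_args(datapath: str) -> Dict[str, Dict[str, Any]]:
    """Group documented field names and defaults into the three dictionaries."""
    return {
        "training_args": {
            "task": "regression",
            "output_path": ".",
            "epochs": 30,
            "batch_size": 256,
            "lr": 0.01,
            "optim": "SGD",
        },
        "model_args": {"atom_fea_len": 64, "h_fea_len": 128, "n_conv": 3, "n_h": 1},
        "dataset_args": {
            "datapath": datapath,
            "train_size": None,
            "val_size": None,
            "test_size": None,
        },
    }


class RecordingPipeline:
    """Hypothetical stand-in that only records what train() receives."""

    def __init__(self) -> None:
        self.calls = []

    def train(self, training_args, model_args, dataset_args) -> None:
        self.calls.append((training_args, model_args, dataset_args))


# With gt4sd installed, a CGCNNTrainingPipeline instance would take this place.
pipeline = RecordingPipeline()
pipeline.train(**build_cgcnn_args("path/to/dataset"))
```

Keeping the three groups in separate dictionaries matches the split between CGCNNTrainingArguments, CGCNNModelArguments, and CGCNNDataArguments.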
- train(train_loader, model, criterion, optimizer, epoch, normalizer, disable_cuda, task, print_freq)
Train step for CGCNN models.
- Parameters
train_loader (Union[DataLoader[Any], Any]) – Dataloader for the training set.
model (CrystalGraphConvNet) – CGCNN model.
criterion (Union[NLLLoss, MSELoss]) – Loss function.
optimizer (Union[SGD, Adam]) – Optimizer to be used.
epoch (int) – Epoch number.
normalizer (Normalizer) – Normalizer applied to the targets.
disable_cuda (bool) – Disable CUDA.
task (str) – Training task.
print_freq (int) – Print frequency.
- Return type
None
- validate(val_loader, model, criterion, normalizer, disable_cuda, task, print_freq, test=False)
Validation step for CGCNN models.
- Parameters
val_loader (Union[DataLoader[Any], Any]) – Dataloader for the validation set.
model (CrystalGraphConvNet) – CGCNN model.
criterion (Union[MSELoss, NLLLoss]) – Loss function.
normalizer (Normalizer) – Normalizer applied to the targets.
disable_cuda (bool) – Disable CUDA.
task (str) – Training task.
print_freq (int) – Print frequency.
test (bool) – Whether to test, rather than only validate, using the given dataset.
- Return type
float
- Returns
Average AUC or MAE, depending on the training task.
- mae(prediction, target)
Computes the mean absolute error between prediction and target.
- Parameters
prediction (Tensor) – Predictions, a torch.Tensor of shape (N, 1).
target (Tensor) – Targets, a torch.Tensor of shape (N, 1).
- Return type
Tensor
- Returns
the computed mean absolute error.
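The quantity computed is the average of the absolute differences over the N samples. As a minimal illustration, the sketch below uses plain Python lists in place of the (N, 1) tensors; `mean_absolute_error` is a hypothetical name and this is not the library's implementation.

```python
from typing import Sequence


def mean_absolute_error(prediction: Sequence[float], target: Sequence[float]) -> float:
    """Average of |prediction_i - target_i| over all samples."""
    if len(prediction) != len(target):
        raise ValueError("prediction and target must have the same length")
    if not prediction:
        raise ValueError("at least one sample is required")
    total = sum(abs(p - t) for p, t in zip(prediction, target))
    return total / len(prediction)


# Absolute errors are 0, 1 and 2, so the mean is 1.0.
example_mae = mean_absolute_error([1.0, 2.0, 4.0], [1.0, 3.0, 2.0])
```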
- class_eval(prediction, target)
Class evaluation.
- Parameters
prediction (Tensor) – Predictions.
target (Tensor) – Ground truth.
- Return type
Tuple[float, float, float, float, float]
- Returns
Computed accuracy, precision, recall, fscore, and auc_score.
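To make the first four returned metrics concrete, here is a minimal sketch for the binary case, computed from already-thresholded 0/1 labels. It rests on assumptions: `binary_metrics` is a hypothetical helper, the way the library turns the prediction tensor into labels is not specified here, and the AUC score is left out because it needs predicted scores rather than labels.

```python
from typing import Sequence, Tuple


def binary_metrics(
    pred_labels: Sequence[int], target: Sequence[int]
) -> Tuple[float, float, float, float]:
    """Return (accuracy, precision, recall, fscore) for 0/1 labels."""
    pairs = list(zip(pred_labels, target))
    tp = sum(1 for p, t in pairs if p == 1 and t == 1)  # true positives
    fp = sum(1 for p, t in pairs if p == 1 and t == 0)  # false positives
    fn = sum(1 for p, t in pairs if p == 0 and t == 1)  # false negatives
    tn = sum(1 for p, t in pairs if p == 0 and t == 0)  # true negatives

    accuracy = (tp + tn) / len(pairs)
    # Guard against empty denominators when a class is never predicted or present.
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    denom = precision + recall
    fscore = 2 * precision * recall / denom if denom else 0.0
    return accuracy, precision, recall, fscore


accuracy, precision, recall, fscore = binary_metrics([1, 0, 1, 1], [1, 0, 0, 1])
```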
- class AverageMeter
Bases: object
Computes and stores the average and current value.
- update(val, n=1)
Update values of the AverageMeter.
- Parameters
val (float) – value to be added.
n (int) – count.
- Return type
None
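A running-average meter of this kind is commonly written as below. This is a sketch under assumptions, not the library's code: `RunningAverage` and its attribute names are hypothetical, and `val` is taken to be a per-sample mean over a batch of `n` samples, so it is weighted by `n` when accumulated.

```python
class RunningAverage:
    """Tracks the latest value and the weighted average of all values seen."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        self.val = 0.0    # most recent value
        self.sum = 0.0    # weighted sum of all values
        self.count = 0    # total weight (number of samples)
        self.avg = 0.0    # weighted average so far

    def update(self, val: float, n: int = 1) -> None:
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


meter = RunningAverage()
meter.update(2.0, n=2)  # batch of 2 samples with mean loss 2.0
meter.update(5.0, n=1)  # batch of 1 sample with loss 5.0
```

Weighting by `n` keeps the average correct when the last batch of an epoch is smaller than the others.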
- save_checkpoint(state, is_best, path='.', filename='checkpoint.pth.tar')
Save CGCNN checkpoint.
- Parameters
state (object) – checkpoint’s object.
is_best (bool) – whether the given checkpoint has the best performance or not.
path (str) – path to save the checkpoint.
filename (str) – checkpoint’s filename.
- Return type
None
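The usual pattern behind such a function is to write the latest state to disk and keep a separate copy whenever it is the best so far. The sketch below illustrates that pattern under stated assumptions: `save_checkpoint_sketch` is a hypothetical name, `pickle` stands in for whatever serializer the library uses, and the name `model_best.pth.tar` for the copy is an assumption rather than something stated in this reference.

```python
import os
import pickle
import shutil
import tempfile


def save_checkpoint_sketch(
    state: object,
    is_best: bool,
    path: str = ".",
    filename: str = "checkpoint.pth.tar",
) -> None:
    """Write the state to path/filename and copy it when it is the best so far."""
    os.makedirs(path, exist_ok=True)
    checkpoint_file = os.path.join(path, filename)
    with open(checkpoint_file, "wb") as handle:
        pickle.dump(state, handle)  # assumption: pickle replaces the real serializer
    if is_best:
        # Assumed name for the best-model copy.
        shutil.copyfile(checkpoint_file, os.path.join(path, "model_best.pth.tar"))


with tempfile.TemporaryDirectory() as tmp:
    save_checkpoint_sketch({"epoch": 3, "best_mae": 0.12}, is_best=True, path=tmp)
    saved_files = sorted(os.listdir(tmp))
```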
- class CGCNNDataArguments(datapath, train_size=None, val_size=None, test_size=None)
Bases: TrainingPipelineArguments
Data arguments related to CGCNN trainer.
- datapath: str
Path to the dataset. The dataset should follow the directory structure described in https://github.com/txie-93/cgcnn
- train_size: Optional[int] = None
Number of training samples to load.
- val_size: Optional[int] = None
Number of validation samples to load.
- test_size: Optional[int] = None
Number of testing samples to load.
- class CGCNNModelArguments(atom_fea_len=64, h_fea_len=128, n_conv=3, n_h=1)
Bases: TrainingPipelineArguments
Model arguments related to CGCNN trainer.
- atom_fea_len: int = 64
Number of hidden atom features in conv layers.
- h_fea_len: int = 128
Number of hidden features after pooling.
- n_conv: int = 3
Number of conv layers.
- n_h: int = 1
Number of hidden layers after pooling.
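The defaults and help strings listed for this class suggest a dataclass whose fields carry a `help` entry in their metadata. The following sketch reproduces that pattern with the standard library only; `ModelArgumentsSketch` and `describe` are hypothetical names, and the class does not inherit from TrainingPipelineArguments, so it is an illustration rather than the library's definition.

```python
from dataclasses import dataclass, field, fields
from typing import Any, Dict, Tuple


@dataclass
class ModelArgumentsSketch:
    """Stand-in mirroring the documented fields, defaults and help strings."""

    atom_fea_len: int = field(
        default=64, metadata={"help": "Number of hidden atom features in conv layers."}
    )
    h_fea_len: int = field(
        default=128, metadata={"help": "Number of hidden features after pooling."}
    )
    n_conv: int = field(default=3, metadata={"help": "Number of conv layers."})
    n_h: int = field(
        default=1, metadata={"help": "Number of hidden layers after pooling."}
    )


def describe(cls: type) -> Dict[str, Tuple[Any, str]]:
    """Map each field name to its (default, help text) pair."""
    return {f.name: (f.default, f.metadata["help"]) for f in fields(cls)}


summary = describe(ModelArgumentsSketch)
deeper = ModelArgumentsSketch(n_conv=5)  # override a single default
```

Storing the help text in field metadata lets a command-line parser or a documentation tool read it without duplicating the descriptions.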
- class CGCNNTrainingArguments(task='regression', output_path='.', disable_cuda=False, workers=0, epochs=30, start_epoch=0, batch_size=256, lr=0.01, lr_milestone=100, momentum=0.9, weight_decay=0.0, print_freq=10, resume='', optim='SGD')
Bases: TrainingPipelineArguments
Training arguments related to CGCNN trainer.
- task: str = 'regression'
Select the type of the task.
- output_path: str = '.'
Path to store the checkpoints.
- disable_cuda: bool = False
Disable CUDA.
- workers: int = 0
Number of data loading workers.
- epochs: int = 30
Number of total epochs to run.
- start_epoch: int = 0
Manual epoch number (useful on restarts).
- batch_size: int = 256
Mini-batch size.
- lr: float = 0.01
Initial learning rate.
- lr_milestone: float = 100
Milestone for scheduler.
- momentum: float = 0.9
Momentum.
- weight_decay: float = 0.0
Weight decay.
- print_freq: int = 10
Print frequency.
- resume: str = ''
Path to latest checkpoint.
- optim: str = 'SGD'
Optimizer.
- class CGCNNSavingArguments
Bases: TrainingPipelineArguments
Saving arguments related to CGCNN trainer. This class defines no fields.