gt4sd.training_pipelines.pytorch_lightning.granular.core module

Granular training utilities.

Summary

Classes:

GranularDataArguments

Arguments related to data.

GranularModelArguments

Arguments related to model.

GranularPytorchLightningTrainingArguments

Arguments related to pytorch lightning trainer.

GranularSavingArguments

Saving arguments related to Granular trainer.

GranularTrainingPipeline

Granular training pipeline.

Reference

class GranularTrainingPipeline[source]

Bases: PyTorchLightningTrainingPipeline

Granular training pipeline.

get_data_and_model_modules(model_args, dataset_args, **kwargs)[source]

Get data and model modules for training.

Parameters
  • model_args (Dict[str, Any]) – model arguments passed to the configuration.

  • dataset_args (Dict[str, Any]) – dataset arguments passed to the configuration.

Return type

Tuple[LightningDataModule, LightningModule]

Returns

the data and model modules.

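For orientation, a minimal usage sketch, assuming the pipeline can be instantiated without constructor arguments and that the model_args and dataset_args dictionaries mirror the fields of the argument classes documented below (both assumptions, not guarantees of this reference):

    from gt4sd.training_pipelines.pytorch_lightning.granular.core import (
        GranularTrainingPipeline,
    )

    pipeline = GranularTrainingPipeline()  # assumption: no constructor arguments needed
    # Keys are illustrative and assumed to mirror GranularModelArguments
    # and GranularDataArguments documented below.
    data_module, model_module = pipeline.get_data_and_model_modules(
        model_args={"model_list_path": "models.json", "lr": 1e-4},
        dataset_args={"batch_size": 64, "num_workers": 1},
    )
    # The returned pair can then be passed to a pytorch_lightning.Trainer.
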
class GranularPytorchLightningTrainingArguments(strategy='ddp', accumulate_grad_batches=1, val_check_interval=5000, save_dir='logs', basename='lightning_logs', gradient_clip_val=0.0, limit_val_batches=500, log_every_n_steps=500, max_epochs=3, resume_from_checkpoint=None, gpus=-1, monitor=None, save_last=None, save_top_k=1, mode='min', every_n_train_steps=None, every_n_epochs=None, check_val_every_n_epoch=5, auto_lr_find=True, profiler='simple')[source]

Bases: PytorchLightningTrainingArguments

Arguments related to pytorch lightning trainer.

Fields (check_val_every_n_epoch, auto_lr_find and profiler are defined on this class; the remaining fields are inherited from PytorchLightningTrainingArguments):

  • strategy (Optional[str], default 'ddp') – Training strategy.

  • accumulate_grad_batches (int, default 1) – Accumulate gradients every k batches or as set up in the dict.

  • val_check_interval (int, default 5000) – How often to check the validation set.

  • save_dir (Optional[str], default 'logs') – Save directory for logs and output.

  • basename (Optional[str], default 'lightning_logs') – Experiment name.

  • gradient_clip_val (float, default 0.0) – Gradient clipping value.

  • limit_val_batches (int, default 500) – How much of the validation dataset to check.

  • log_every_n_steps (int, default 500) – How often to log within steps.

  • max_epochs (int, default 3) – Stop training once this number of epochs is reached.

  • resume_from_checkpoint (Optional[str], default None) – Path/URL of the checkpoint from which training is resumed.

  • gpus (Optional[int], default -1) – Number of GPUs to train on (-1 uses all available GPUs).

  • monitor (Optional[str], default None) – Quantity to monitor in order to store a checkpoint.

  • save_last (Optional[bool], default None) – When True, always saves the model at the end of the epoch to a file named last.ckpt.

  • save_top_k (int, default 1) – The best k models according to the quantity monitored will be saved.

  • mode (str, default 'min') – Whether the monitored quantity should be minimized ('min') or maximized ('max').

  • every_n_train_steps (Optional[int], default None) – Number of training steps between checkpoints.

  • every_n_epochs (Optional[int], default None) – Number of epochs between checkpoints.

  • check_val_every_n_epoch (Optional[int], default 5) – Run validation every n training epochs.

  • auto_lr_find (bool, default True) – Whether to run a learning rate finder to optimize the initial learning rate for faster convergence.

  • profiler (Optional[str], default 'simple') – Profile individual steps during training to help identify bottlenecks.

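Since this is a plain dataclass, defaults can be overridden field by field and the result turned into a dictionary with dataclasses.asdict; a short sketch with illustrative values (the monitored metric name depends on the configured model and is an assumption here):

    from dataclasses import asdict

    from gt4sd.training_pipelines.pytorch_lightning.granular.core import (
        GranularPytorchLightningTrainingArguments,
    )

    training_args = GranularPytorchLightningTrainingArguments(
        max_epochs=10,
        gpus=1,
        monitor="val_loss",  # assumption: the metric name depends on the model setup
        mode="min",
        save_top_k=3,
    )
    trainer_kwargs = asdict(training_args)  # plain dict view of all fields
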
class GranularModelArguments(model_list_path=None, lr=0.0001, test_output_path='./test')[source]

Bases: TrainingPipelineArguments

Arguments related to model.

Fields:

  • model_list_path (Optional[str], default None) – Path to a JSON file containing a dictionary with models and their parameters. If not provided, the dictionary is searched for in the given config file.

  • lr (float, default 0.0001) – The learning rate.

  • test_output_path (Optional[str], default './test') – Path where latent encodings and predictions for the test set are saved when an epoch ends. Defaults to a folder called 'test' in the current working directory.

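The model-list file is only described as "a dictionary with models and their parameters", so its schema is not fixed by this reference; a sketch with hypothetical keys:

    import json

    from gt4sd.training_pipelines.pytorch_lightning.granular.core import (
        GranularModelArguments,
    )

    # Hypothetical content: model names mapped to their parameters.
    with open("models.json", "w") as fp:
        json.dump({"my_autoencoder": {"latent_size": 64}}, fp)

    model_args = GranularModelArguments(
        model_list_path="models.json",
        lr=1e-3,
        test_output_path="./test",
    )
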
class GranularDataArguments(batch_size=64, validation_split=None, validation_indices_file=None, stratified_batch_file=None, stratified_value_name=None, num_workers=1)[source]

Bases: TrainingPipelineArguments

Arguments related to data.

Fields:

  • batch_size (int, default 64) – Batch size for training.

  • validation_split (Optional[float], default None) – Proportion of the data used for validation. If None, the indices file is used when provided; otherwise half of the data is used for validation.

  • validation_indices_file (Optional[str], default None) – File with the indices to use for validation. If None, the validation split proportion is used; if neither is provided, half of the data is used for validation.

  • stratified_batch_file (Optional[str], default None) – Stratified batch file for sampling. If None, no stratified sampling is performed.

  • stratified_value_name (Optional[str], default None) – Stratified value name; required when a stratified batch file is provided. If None, no stratified sampling is performed.

  • num_workers (int, default 1) – Number of workers.

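The validation fields are alternatives: either a split proportion or an explicit indices file. A sketch of both (the indices file format is not specified here and its name is illustrative):

    from gt4sd.training_pipelines.pytorch_lightning.granular.core import (
        GranularDataArguments,
    )

    # Option 1: hold out a proportion of the data for validation.
    data_args = GranularDataArguments(batch_size=128, validation_split=0.2)

    # Option 2: use a precomputed indices file instead.
    data_args = GranularDataArguments(
        batch_size=128,
        validation_indices_file="validation_indices.txt",  # illustrative path
        num_workers=4,
    )
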
class GranularSavingArguments(model_path)[source]

Bases: TrainingPipelineArguments

Saving arguments related to Granular trainer.

Fields:

  • model_path (str, required) – Path to the checkpoint file to be used.
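
A sketch of pointing the saving arguments at a checkpoint produced during training; the path is illustrative, since the exact layout under save_dir depends on the logger and checkpoint callbacks:

    from gt4sd.training_pipelines.pytorch_lightning.granular.core import (
        GranularSavingArguments,
    )

    saving_args = GranularSavingArguments(
        model_path="logs/lightning_logs/version_0/checkpoints/last.ckpt"  # illustrative
    )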