gt4sd.training_pipelines.pytorch_lightning.granular.core module¶
Granular training utilities.
Summary¶
Classes:
- GranularDataArguments – Arguments related to data.
- GranularModelArguments – Arguments related to model.
- GranularPytorchLightningTrainingArguments – Arguments related to pytorch lightning trainer.
- GranularSavingArguments – Saving arguments related to Granular trainer.
- GranularTrainingPipeline – Granular training pipelines.
Reference¶
- class GranularTrainingPipeline[source]¶
Bases:
PyTorchLightningTrainingPipeline
Granular training pipelines.
- get_data_and_model_modules(model_args, dataset_args, **kwargs)[source]¶
Get data and model modules for training.
- Parameters
  - model_args (Dict[str, Any]) – model arguments passed to the configuration.
  - dataset_args (Dict[str, Any]) – dataset arguments passed to the configuration.
- Return type
  Tuple[LightningDataModule, LightningModule]
- Returns
  the data and model modules.
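A minimal usage sketch, assuming the pipeline can be instantiated without arguments; the argument dictionaries mirror the dataclasses documented below, and the file name and values are placeholders:

    from gt4sd.training_pipelines.pytorch_lightning.granular.core import (
        GranularTrainingPipeline,
    )

    # Placeholder argument dictionaries following GranularModelArguments
    # and GranularDataArguments documented below.
    model_args = {"model_list_path": "models.json", "lr": 1e-4, "test_output_path": "./test"}
    dataset_args = {"batch_size": 64, "validation_split": 0.2, "num_workers": 1}

    pipeline = GranularTrainingPipeline()
    data_module, model_module = pipeline.get_data_and_model_modules(
        model_args=model_args,
        dataset_args=dataset_args,
    )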
- class GranularPytorchLightningTrainingArguments(strategy='ddp', accumulate_grad_batches=1, val_check_interval=5000, save_dir='logs', basename='lightning_logs', gradient_clip_val=0.0, limit_val_batches=500, log_every_n_steps=500, max_epochs=3, resume_from_checkpoint=None, gpus=-1, monitor=None, save_last=None, save_top_k=1, mode='min', every_n_train_steps=None, every_n_epochs=None, check_val_every_n_epoch=5, auto_lr_find=True, profiler='simple')[source]¶
Bases:
PytorchLightningTrainingArguments
Arguments related to pytorch lightning trainer.
- Parameters
  - strategy (Optional[str], default 'ddp') – Training strategy.
  - accumulate_grad_batches (int, default 1) – Accumulate gradients every k batches, or as set up in a dict.
  - val_check_interval (int, default 5000) – How often to check the validation set.
  - save_dir (Optional[str], default 'logs') – Save directory for logs and output.
  - basename (Optional[str], default 'lightning_logs') – Experiment name.
  - gradient_clip_val (float, default 0.0) – Gradient clipping value.
  - limit_val_batches (int, default 500) – How much of the validation dataset to check.
  - log_every_n_steps (int, default 500) – How often to log within steps.
  - max_epochs (int, default 3) – Stop training once this number of epochs is reached.
  - resume_from_checkpoint (Optional[str], default None) – Path/URL of the checkpoint from which training is resumed.
  - gpus (Optional[int], default -1) – Number of gpus to train on.
  - monitor (Optional[str], default None) – Quantity to monitor in order to store a checkpoint.
  - save_last (Optional[bool], default None) – When True, always saves the model at the end of the epoch to a file last.ckpt.
  - save_top_k (int, default 1) – The best k models according to the quantity monitored will be saved.
  - mode (str, default 'min') – Whether the monitored quantity should be minimized ('min') or maximized ('max') when selecting checkpoints.
  - every_n_train_steps (Optional[int], default None) – Number of training steps between checkpoints.
  - every_n_epochs (Optional[int], default None) – Number of epochs between checkpoints.
  - check_val_every_n_epoch (Optional[int], default 5) – Run the validation loop every n training epochs.
  - auto_lr_find (bool, default True) – Whether to run a learning rate finder to optimize the initial learning rate for faster convergence.
  - profiler (Optional[str], default 'simple') – Profile individual steps during training to help identify bottlenecks.
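A short sketch of customizing the trainer arguments and converting them to a plain dictionary (standard dataclass behavior); the monitored metric name is an assumption that depends on the configured models:

    from dataclasses import asdict

    from gt4sd.training_pipelines.pytorch_lightning.granular.core import (
        GranularPytorchLightningTrainingArguments,
    )

    # Override a few trainer defaults; all other fields keep the documented defaults.
    training_args = GranularPytorchLightningTrainingArguments(
        max_epochs=10,
        gpus=1,
        monitor="val_loss",  # assumed metric name
        save_top_k=3,
    )

    # Dataclasses convert to plain dictionaries, convenient for kwargs-style APIs.
    print(asdict(training_args)["max_epochs"])  # 10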
- class GranularModelArguments(model_list_path=None, lr=0.0001, test_output_path='./test')[source]¶
Bases:
TrainingPipelineArguments
Arguments related to model.
- Parameters
  - model_list_path (Optional[str], default None) – Path to a JSON file that contains a dictionary with models and their parameters. If it is not provided, the dictionary is searched in the given config file.
  - lr (float, default 0.0001) – The learning rate.
  - test_output_path (Optional[str], default './test') – Path where latent encodings and predictions for the test set are saved when an epoch ends. Defaults to a folder called 'test' in the current working directory.
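A sketch of pointing the model arguments at a model list file; the JSON content below is purely illustrative, since the actual schema of the model dictionary is defined by the granular models referenced in the configuration:

    import json

    from gt4sd.training_pipelines.pytorch_lightning.granular.core import (
        GranularModelArguments,
    )

    # Illustrative model dictionary; keys and values are placeholders.
    with open("models.json", "w") as fp:
        json.dump({"autoencoder": {"type": "vae", "latent_size": 16}}, fp)

    model_args = GranularModelArguments(
        model_list_path="models.json",
        lr=5e-4,
        test_output_path="./test_outputs",
    )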
- class GranularDataArguments(batch_size=64, validation_split=None, validation_indices_file=None, stratified_batch_file=None, stratified_value_name=None, num_workers=1)[source]¶
Bases:
TrainingPipelineArguments
Arguments related to data.
- Parameters
  - batch_size (int, default 64) – Batch size for training.
  - validation_split (Optional[float], default None) – Proportion of the data used for validation. If None, the indices file is used when provided; otherwise half of the data is used for validation.
  - validation_indices_file (Optional[str], default None) – File with the indices to use for validation. If None, the validation split proportion is used; if neither is provided, half of the data is used for validation.
  - stratified_batch_file (Optional[str], default None) – Stratified batch file for sampling. If None, no stratified sampling is performed.
  - stratified_value_name (Optional[str], default None) – Name of the value used for stratification. Required when a stratified batch file is provided; if None, no stratified sampling is performed.
  - num_workers (int, default 1) – Number of data loader workers.
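A sketch contrasting the two validation configurations and the paired stratification fields; all file names and the value name are placeholders:

    from gt4sd.training_pipelines.pytorch_lightning.granular.core import (
        GranularDataArguments,
    )

    # Validation via a split proportion.
    dataset_args = GranularDataArguments(
        batch_size=128,
        validation_split=0.2,
        num_workers=4,
    )

    # Validation via an explicit indices file; stratified sampling requires both
    # stratified_batch_file and stratified_value_name to be set together.
    stratified_args = GranularDataArguments(
        validation_indices_file="validation_indices.txt",
        stratified_batch_file="stratification.csv",
        stratified_value_name="target_bin",
    )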
- class GranularSavingArguments(model_path)[source]¶
Bases:
TrainingPipelineArguments
Saving arguments related to Granular trainer.
- Parameters
  - model_path (str, required) – Path to the checkpoint file to be used.
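A sketch of the saving arguments; the checkpoint path is an assumption, as its actual location depends on save_dir and basename in the trainer arguments:

    from gt4sd.training_pipelines.pytorch_lightning.granular.core import (
        GranularSavingArguments,
    )

    # Assumed checkpoint location under the default save_dir/basename layout.
    saving_args = GranularSavingArguments(
        model_path="logs/lightning_logs/version_0/checkpoints/last.ckpt"
    )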