gt4sd.cli.trainer module

Run training pipelines for GT4SD.

Summary

Classes:

TrainerArgumentParser

Argument parser using a custom help logic.

TrainerArguments

Trainer arguments.

Functions:

main

Run a training pipeline.

Reference

class TrainerArguments(training_pipeline_name, configuration_file=None)[source]

Bases: object

Trainer arguments.

__name__ = 'TrainerArguments'
training_pipeline_name: str

Training pipeline name. Supported pipelines: cgcnn, crystals-rfc, diffusion-trainer, gflownet-trainer, granular-trainer, guacamol-lstm-trainer, language-modeling-trainer, molformer, moses-organ-trainer, moses-vae-trainer, paccmann-vae-trainer, regression-transformer-trainer, torchdrug-gcpn-trainer, torchdrug-graphaf-trainer.

configuration_file: Optional[str] = None

Configuration file for the training in JSON format. It can be used to completely bypass pipeline-specific arguments.

__annotations__ = {'configuration_file': typing.Optional[str], 'training_pipeline_name': <class 'str'>}
__dataclass_params__ = _DataclassParams(init=True,repr=True,eq=True,order=False,unsafe_hash=False,frozen=False)
__doc__ = 'Trainer arguments.'
__eq__(other)

Return self==value.

__hash__ = None
__init__(training_pipeline_name, configuration_file=None)
__match_args__ = ('training_pipeline_name', 'configuration_file')
__module__ = 'gt4sd.cli.trainer'
__repr__()

Return repr(self).

__weakref__

list of weak references to the object (if defined)

class TrainerArgumentParser(dataclass_types, **kwargs)[source]

Bases: ArgumentParser

Argument parser using a custom help logic.

print_help(file=None)[source]

Print help, checking dynamically whether a specific pipeline was passed.

Parameters

file (Optional[IO[str]]) – an optional I/O stream. Defaults to None, i.e., stdout and stderr.

Return type

None
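The idea of a help message that depends on the pipeline already present on the command line can be sketched with the standard library alone. This is an illustration under stated assumptions, not the GT4SD implementation: HelpAwareParser, PIPELINE_OPTIONS, render_help and the `--batch_size` option are all hypothetical names invented for the example.

```python
import argparse
import io
from typing import IO, Dict, Optional, Sequence

# Hypothetical per-pipeline options, invented for this sketch only.
PIPELINE_OPTIONS: Dict[str, Dict[str, str]] = {
    "moses-vae-trainer": {"--batch_size": "Hypothetical batch size option."},
}


class HelpAwareParser(argparse.ArgumentParser):
    """Parser whose help output depends on the pipeline found in argv."""

    def __init__(self, argv: Sequence[str], **kwargs) -> None:
        super().__init__(**kwargs)
        self._argv = list(argv)
        self._extended = False
        self.add_argument(
            "--training_pipeline_name", type=str, help="Training pipeline name."
        )

    def _requested_pipeline(self) -> Optional[str]:
        # Scan argv by hand; parsing here would re-trigger the help action.
        for index, token in enumerate(self._argv[:-1]):
            if token == "--training_pipeline_name":
                return self._argv[index + 1]
        return None

    def print_help(self, file: Optional[IO[str]] = None) -> None:
        # Register pipeline-specific options once, then defer to argparse.
        if not self._extended:
            options = PIPELINE_OPTIONS.get(self._requested_pipeline() or "", {})
            for flag, help_text in options.items():
                self.add_argument(flag, help=help_text)
            self._extended = True
        super().print_help(file)


def render_help(argv: Sequence[str]) -> str:
    """Return the help text the parser would print for the given argv."""
    parser = HelpAwareParser(argv, prog="trainer-sketch")
    buffer = io.StringIO()
    parser.print_help(buffer)
    return buffer.getvalue()


generic_help = render_help(["--help"])
specific_help = render_help(
    ["--training_pipeline_name", "moses-vae-trainer", "--help"]
)
```

In this sketch only `specific_help` lists the pipeline-specific option; `generic_help` shows the base arguments alone.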

parse_json_file(json_file)[source]

Override the default JSON file parser.

It bypasses all command-line arguments and simply adds the training pipeline.

Parameters

json_file (str) – JSON file containing pipeline configuration parameters.

Return type

Tuple[DataClass, …]

Returns

parsed arguments in a tuple of dataclasses.
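A rough sketch of the behaviour described above, assuming the JSON keys map directly onto dataclass fields: the file is loaded, the training pipeline name is added, and each dataclass picks up the keys it declares. BaseArgs, ModelArgs, their fields and parse_json_file_sketch are hypothetical stand-ins, not GT4SD names.

```python
import json
import os
import tempfile
from dataclasses import dataclass, fields
from typing import Any, Optional, Sequence, Tuple


@dataclass
class BaseArgs:
    """Hypothetical stand-in for the base trainer arguments."""

    training_pipeline_name: str
    configuration_file: Optional[str] = None


@dataclass
class ModelArgs:
    """Hypothetical pipeline-specific arguments."""

    learning_rate: float = 1e-3
    batch_size: int = 32


def parse_json_file_sketch(
    json_file: str,
    training_pipeline_name: str,
    dataclass_types: Sequence[type],
) -> Tuple[Any, ...]:
    """Build one instance per dataclass from a JSON file, ignoring the CLI."""
    with open(json_file, encoding="utf-8") as handle:
        data = json.load(handle)
    # The pipeline name is the only value added on top of the file content.
    data["training_pipeline_name"] = training_pipeline_name
    outputs = []
    for dataclass_type in dataclass_types:
        keys = {f.name for f in fields(dataclass_type) if f.init}
        outputs.append(
            dataclass_type(**{k: v for k, v in data.items() if k in keys})
        )
    return tuple(outputs)


# Usage: write a small configuration file and parse it.
with tempfile.TemporaryDirectory() as directory:
    path = os.path.join(directory, "config.json")
    with open(path, "w", encoding="utf-8") as handle:
        json.dump({"learning_rate": 0.01, "batch_size": 8}, handle)
    base_args, model_args = parse_json_file_sketch(
        path, "moses-vae-trainer", (BaseArgs, ModelArgs)
    )
```

Keys that no dataclass declares are silently dropped here; a stricter variant could raise an error instead.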

__annotations__ = {}
__doc__ = 'Argument parser using a custom help logic.'
__module__ = 'gt4sd.cli.trainer'
main()[source]

Run a training pipeline.

Raises

ValueError – in case the provided training pipeline is not supported.

Return type

None
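The ValueError above amounts to a membership check on the pipeline name. The sketch below illustrates that check using the pipeline names listed in the TrainerArguments help text; check_pipeline_name and SUPPORTED_PIPELINES are hypothetical names, and the error message wording is an assumption.

```python
from typing import Tuple

# Names copied from the help text of training_pipeline_name.
SUPPORTED_PIPELINES: Tuple[str, ...] = (
    "cgcnn",
    "crystals-rfc",
    "diffusion-trainer",
    "gflownet-trainer",
    "granular-trainer",
    "guacamol-lstm-trainer",
    "language-modeling-trainer",
    "molformer",
    "moses-organ-trainer",
    "moses-vae-trainer",
    "paccmann-vae-trainer",
    "regression-transformer-trainer",
    "torchdrug-gcpn-trainer",
    "torchdrug-graphaf-trainer",
)


def check_pipeline_name(name: str) -> str:
    """Return the name unchanged, or raise ValueError if it is unsupported."""
    if name not in SUPPORTED_PIPELINES:
        raise ValueError(
            f"Training pipeline {name!r} is not supported. "
            f"Supported pipelines: {', '.join(SUPPORTED_PIPELINES)}."
        )
    return name


checked = check_pipeline_name("paccmann-vae-trainer")

try:
    check_pipeline_name("not-a-pipeline")
    failed = False
except ValueError:
    failed = True
```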