gt4sd.training_pipelines.crystals_crf.core module

Crystals RFC training utilities.

Summary

Classes:

CrystalsRFCDataArguments

Data arguments related to crystals RFC trainer.

CrystalsRFCModelArguments

Model arguments related to crystals RFC trainer.

CrystalsRFCSavingArguments

Saving arguments related to crystals RFC trainer.

CrystalsRFCTrainingArguments

Training arguments related to crystals RFC trainer.

CrystalsRFCTrainingPipeline

Crystals RFC training pipelines for crystals.

Reference

class CrystalsRFCTrainingPipeline[source]

Bases: TrainingPipeline

Crystals RFC training pipelines for crystals.

train(training_args, model_args, dataset_args)[source]

Generic training function for Crystals RFC models.

Parameters
  • training_args (Dict[str, Any]) – training arguments passed to the configuration.

  • model_args (Dict[str, Any]) – model arguments passed to the configuration.

  • dataset_args (Dict[str, Any]) – dataset arguments passed to the configuration.

Raises

NotImplementedError – the generic trainer does not implement the pipeline.

Return type

None

__annotations__ = {}
__doc__ = 'Crystals RFC training pipelines for crystals.'
__module__ = 'gt4sd.training_pipelines.crystals_crf.core'
class CrystalsRFCDataArguments(datapath, test_size=None)[source]

Bases: TrainingPipelineArguments

Data arguments related to crystals RFC trainer.

__name__ = 'CrystalsRFCDataArguments'
datapath: str
__annotations__ = {'datapath': <class 'str'>, 'test_size': typing.Optional[int]}
__dataclass_fields__ = {'datapath': Field(name='datapath',type=<class 'str'>,default=<dataclasses._MISSING_TYPE object>,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Path to the dataset. The dataset should follow the directory structure as described in https://github.com/dilangaem/SemiconAI.'}),kw_only=False,_field_type=_FIELD), 'test_size': Field(name='test_size',type=typing.Optional[int],default=None,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Testing set percentage.'}),kw_only=False,_field_type=_FIELD)}
__dataclass_params__ = _DataclassParams(init=True,repr=True,eq=True,order=False,unsafe_hash=False,frozen=False)
__doc__ = 'Data arguments related to crystals RFC trainer.'
__eq__(other)

Return self==value.

__hash__ = None
__init__(datapath, test_size=None)
__match_args__ = ('datapath', 'test_size')
__module__ = 'gt4sd.training_pipelines.crystals_crf.core'
__repr__()

Return repr(self).

class CrystalsRFCModelArguments(sym='all')[source]

Bases: TrainingPipelineArguments

Model arguments related to crystals RFC trainer.

__name__ = 'CrystalsRFCModelArguments'
sym: str = 'all'
__annotations__ = {'sym': <class 'str'>}
__dataclass_fields__ = {'sym': Field(name='sym',type=<class 'str'>,default='all',default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': "Crystal systems to be used. 'all' for all the crystal systems. Other seven options are: 'monoclinic', 'triclinic', 'orthorhombic', 'trigonal', 'hexagonal', 'cubic', 'tetragonal'"}),kw_only=False,_field_type=_FIELD)}
__dataclass_params__ = _DataclassParams(init=True,repr=True,eq=True,order=False,unsafe_hash=False,frozen=False)
__doc__ = 'Model arguments related to crystals RFC trainer.'
__eq__(other)

Return self==value.

__hash__ = None
__init__(sym='all')
__match_args__ = ('sym',)
__module__ = 'gt4sd.training_pipelines.crystals_crf.core'
__repr__()

Return repr(self).

class CrystalsRFCTrainingArguments(output_path='.')[source]

Bases: TrainingPipelineArguments

Training arguments related to crystals RFC trainer.

__name__ = 'CrystalsRFCTrainingArguments'
output_path: str = '.'
__annotations__ = {'output_path': <class 'str'>}
__dataclass_fields__ = {'output_path': Field(name='output_path',type=<class 'str'>,default='.',default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Path to store the checkpoints.'}),kw_only=False,_field_type=_FIELD)}
__dataclass_params__ = _DataclassParams(init=True,repr=True,eq=True,order=False,unsafe_hash=False,frozen=False)
__doc__ = 'Training arguments related to crystals RFC trainer.'
__eq__(other)

Return self==value.

__hash__ = None
__init__(output_path='.')
__match_args__ = ('output_path',)
__module__ = 'gt4sd.training_pipelines.crystals_crf.core'
__repr__()

Return repr(self).

class CrystalsRFCSavingArguments[source]

Bases: TrainingPipelineArguments

Saving arguments related to crystals RFC trainer.

__name__ = 'CrystalsRFCSavingArguments'
__annotations__ = {}
__dataclass_fields__ = {}
__dataclass_params__ = _DataclassParams(init=True,repr=True,eq=True,order=False,unsafe_hash=False,frozen=False)
__doc__ = 'Saving arguments related to crystals RFC trainer.'
__eq__(other)

Return self==value.

__hash__ = None
__init__()
__match_args__ = ()
__module__ = 'gt4sd.training_pipelines.crystals_crf.core'
__repr__()

Return repr(self).