gt4sd.training_pipelines.regression_transformer.implementation module¶
Regression Transformer training implementation.
Summary¶
Classes:
| RegressionTransformerModelArguments | Arguments pertaining to model instantiation. |
| RegressionTransformerTrainingPipeline | RegressionTransformer training pipeline. |
Reference¶
- class RegressionTransformerTrainingPipeline[source]¶
Bases:
TrainingPipeline
RegressionTransformer training pipeline.
- train(training_args, model_args, dataset_args)[source]¶
- Generic training function for training a Regression Transformer (RT) model.
- For details see:
Born, J., & Manica, M. (2023). Regression Transformer enables concurrent sequence regression and generation for molecular language modelling. Nature Machine Intelligence, 5(4), 432-444.
- Parameters
training_args (Dict[str, Any]) – training arguments passed to the configuration.
model_args (Dict[str, Any]) – model arguments passed to the configuration.
dataset_args (Dict[str, Any]) – dataset arguments passed to the configuration.
- Return type
None
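The three dictionaries above can be sketched as follows. This is an illustrative assumption about plausible keys, not the pipeline's full argument schema; the run name, batch size, and file paths are hypothetical.

```python
# Illustrative argument dictionaries for train(); keys other than those
# documented for setup_dataset (train_data_path, test_data_path, augment)
# are assumptions for the sake of the sketch.
training_args = {
    "training_name": "rt-qed-demo",  # hypothetical run name
    "batch_size": 16,
    "epochs": 10,
}
model_args = {
    "model_type": "xlnet",  # the documented default
    "model_path": None,     # None here means training from scratch
}
dataset_args = {
    "train_data_path": "train.csv",
    "test_data_path": "test.csv",
    "augment": 0,
}

# With gt4sd installed, the call documented above would look like:
# from gt4sd.training_pipelines.regression_transformer.implementation import (
#     RegressionTransformerTrainingPipeline,
# )
# pipeline = RegressionTransformerTrainingPipeline()
# pipeline.train(training_args, model_args, dataset_args)
```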
- setup_model(params)[source]¶
Error handling and training setup routine.
- Parameters
params (Dict[str, Any]) – A dictionary with all parameters to launch training.
- Raises
ValueError – If flawed values are passed.
- setup_dataset(train_data_path, test_data_path, augment=0, save_datasets=False, *args, **kwargs)[source]¶
Constructs the dataset objects.
- Parameters
train_data_path (str) – Path to a .csv file. Has to have a text column and at least one column of numerical properties.
test_data_path (str) – Path to a .csv file. Has to have a text column and at least one column of numerical properties.
augment (int) – How many times each training sample is augmented.
save_datasets (bool) – Whether to save the datasets to disk (stored in the same location as train_data_path and test_data_path).
- Returns
A tuple of train and test datasets.
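A minimal sketch of the .csv layout described above: one text column plus at least one numerical-property column. The column names "text" and "qed" are illustrative assumptions, not names the pipeline requires.

```python
import csv
import tempfile
from pathlib import Path

# Two toy SMILES samples with one numerical property each (assumed columns).
rows = [
    {"text": "CCO", "qed": 0.41},
    {"text": "c1ccccc1", "qed": 0.44},
]

# Write the file that would be passed as train_data_path.
train_data_path = Path(tempfile.mkdtemp()) / "train.csv"
with train_data_path.open("w", newline="") as handle:
    writer = csv.DictWriter(handle, fieldnames=["text", "qed"])
    writer.writeheader()
    writer.writerows(rows)
```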
- create_dataset_from_list(data, save_path=None)[source]¶
Creates a LineByLineTextDataset from a List of strings.
- Parameters
data (List[str]) – List of strings with the samples.
save_path (Optional[str]) – Path to save the dataset to. Defaults to None, meaning the dataset will not be saved.
- Return type
LineByLineTextDataset
- Returns
The dataset.
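Assembling the List[str] input can be sketched as below. The "&lt;property&gt;value|sequence" line format is an assumption about how Regression Transformer samples are serialized; consult the gt4sd examples for the exact convention.

```python
# Hypothetical helper joining a property token and a sequence into one
# training line; the serialization format is an assumption, not confirmed
# by this reference page.
def format_sample(prop: str, value: float, sequence: str) -> str:
    """Join a property token and a sequence into one sample string."""
    return f"<{prop}>{value}|{sequence}"

data = [
    format_sample("qed", 0.41, "CCO"),
    format_sample("qed", 0.44, "c1ccccc1"),
]

# With gt4sd installed, the documented call would be:
# dataset = pipeline.create_dataset_from_list(data, save_path=None)
```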
- class RegressionTransformerModelArguments(model_path=None, tokenizer_name=None, config_name=None, model_type='xlnet', cache_dir=None)[source]¶
Bases:
TrainingPipelineArguments
Arguments pertaining to model instantiation.
- model_path: Optional[str] = None¶
Path where the pretrained model artifacts are stored.
- tokenizer_name: Optional[str] = None¶
Pretrained tokenizer name or path. If not provided, will be inferred from model_path. If model_path is not provided either, you have to pass a tokenizer.
- config_name: Optional[str] = None¶
Pretrained config name or path. But model_path takes preference.
- model_type: Optional[str] = 'xlnet'¶
If training from scratch, pass a model type from the list: albert, bart, bert, big_bird, bigbird_pegasus, blenderbot-small, bloom, camembert, codegen, convbert, ctrl, data2vec-text, deberta, deberta-v2, distilbert, electra, encoder-decoder, ernie, esm, flaubert, fnet, fsmt, funnel, gpt2, gpt_neo, gpt_neox, gpt_neox_japanese, gptj, ibert, layoutlm, led, longformer, longt5, luke, m2m_100, marian, megatron-bert, mobilebert, mpnet, mvp, nezha, nystromformer, openai-gpt, pegasus_x, plbart, qdqbert, reformer, rembert, roberta, roformer, speech_to_text, squeezebert, t5, tapas, transfo-xl, wav2vec2, whisper, xlm, xlm-roberta, xlm-roberta-xl, xlnet, yoso. If model_path is also provided, model_path takes preference.
- cache_dir: Optional[str] = None¶
Where to store the pretrained models downloaded from s3.
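To illustrate the defaults, here is a local stand-in dataclass mirroring the documented fields; with gt4sd installed you would import RegressionTransformerModelArguments instead of redeclaring it, and the model path shown is hypothetical.

```python
from dataclasses import dataclass
from typing import Optional

# Local mirror of the documented fields, for illustration only.
@dataclass
class ModelArgumentsSketch:
    model_path: Optional[str] = None
    tokenizer_name: Optional[str] = None
    config_name: Optional[str] = None
    model_type: Optional[str] = "xlnet"
    cache_dir: Optional[str] = None

# Loading pretrained artifacts from a (hypothetical) local path;
# model_path takes preference over config_name and model_type.
args = ModelArgumentsSketch(model_path="/models/rt-qed")
```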