gt4sd.training_pipelines.regression_transformer.implementation module
Regression Transformer training implementation.
Summary
Classes:
- RegressionTransformerModelArguments – Arguments pertaining to model instantiation.
- RegressionTransformerTrainingPipeline – RegressionTransformer training pipeline.
Reference
- class RegressionTransformerTrainingPipeline
  Bases: TrainingPipeline
  RegressionTransformer training pipeline.
- train(training_args, model_args, dataset_args)
  Generic training function for training a Regression Transformer (RT) model.
  For details see:
  Born, J., & Manica, M. (2023). Regression Transformer enables concurrent sequence regression and generation for molecular language modelling. Nature Machine Intelligence, 5(4), 432-444.
  - Parameters
    - training_args (Dict[str, Any]) – training arguments passed to the configuration.
    - model_args (Dict[str, Any]) – model arguments passed to the configuration.
    - dataset_args (Dict[str, Any]) – dataset arguments passed to the configuration.
  - Return type
    None
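A minimal sketch of how the three argument dictionaries might be assembled before calling train. The specific keys shown (output_dir, num_train_epochs, per_device_train_batch_size) are assumptions based on typical Hugging Face training configurations, not a verified list for this pipeline; the dataset keys mirror the setup_dataset parameters documented below.

```python
# Illustrative argument dictionaries; the exact keys accepted by the
# configuration are assumptions, not taken from this page.
training_args = {
    "output_dir": "rt_output",        # where checkpoints go (assumed key)
    "num_train_epochs": 10,
    "per_device_train_batch_size": 16,
}
model_args = {
    "model_type": "xlnet",            # default per RegressionTransformerModelArguments
    "model_path": None,               # None -> train from scratch
}
dataset_args = {
    "train_data_path": "train.csv",   # text column + >=1 numerical property column
    "test_data_path": "test.csv",
    "augment": 2,
}

# With gt4sd installed, training would be launched as (hypothetical usage):
# from gt4sd.training_pipelines.regression_transformer.implementation import (
#     RegressionTransformerTrainingPipeline,
# )
# RegressionTransformerTrainingPipeline().train(training_args, model_args, dataset_args)
```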
- setup_model(params)
  Error handling and training setup routine.
  - Parameters
    - params (Dict[str, Any]) – A dictionary with all parameters to launch training.
  - Raises
    ValueError – If flawed values are passed.
- setup_dataset(train_data_path, test_data_path, augment=0, save_datasets=False, *args, **kwargs)
  Constructs the dataset objects.
  - Parameters
    - train_data_path (str) – Path to a .csv file. Has to have a text column and at least one column of numerical properties.
    - test_data_path (str) – Path to a .csv file. Has to have a text column and at least one column of numerical properties.
    - augment (int) – How many times each training sample is augmented.
    - save_datasets (bool) – Whether to save the datasets to disk (stored in the same locations as train_data_path and test_data_path).
  - Returns
    A tuple of train and test datasets.
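The expected .csv layout can be sketched with the standard library: one text column plus at least one numerical property column, as the parameter description requires. The column names "text" and "qed" are illustrative, not mandated by this page.

```python
import csv
import tempfile

# Minimal sketch of a .csv file that setup_dataset could consume:
# a text column and one numerical property column (names are illustrative).
rows = [
    {"text": "CCO", "qed": 0.41},
    {"text": "c1ccccc1", "qed": 0.36},
]
with tempfile.NamedTemporaryFile(
    "w", suffix=".csv", delete=False, newline=""
) as f:
    writer = csv.DictWriter(f, fieldnames=["text", "qed"])
    writer.writeheader()
    writer.writerows(rows)
    train_data_path = f.name

# With gt4sd installed (hypothetical usage):
# train_ds, test_ds = pipeline.setup_dataset(
#     train_data_path, test_data_path, augment=2, save_datasets=False
# )
```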
- create_dataset_from_list(data, save_path=None)
  Creates a LineByLineTextDataset from a list of strings.
  - Parameters
    - data (List[str]) – List of strings with the samples.
    - save_path (Optional[str]) – Path to save the dataset to. Defaults to None, meaning the dataset will not be saved.
  - Return type
    LineByLineTextDataset
  - Returns
    The dataset.
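A sketch of the inputs to create_dataset_from_list. The samples are plain strings; the property-token layout used here ("<qed>0.41|CCO") is an assumption about the RT sample format, not something stated on this page. Writing one sample per line mirrors the natural on-disk form of a line-by-line dataset.

```python
import tempfile
from pathlib import Path

# Illustrative sample strings; the "<prop>value|sequence" layout is an
# ASSUMPTION about the RT line format, not taken from this page.
data = ["<qed>0.41|CCO", "<qed>0.36|c1ccccc1"]

# One sample per line, as a line-by-line dataset would consume (sketch only,
# not the gt4sd internals).
save_path = Path(tempfile.mkdtemp()) / "samples.txt"
save_path.write_text("\n".join(data))

# With gt4sd installed (hypothetical usage):
# dataset = pipeline.create_dataset_from_list(data, save_path=str(save_path))
```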
- class RegressionTransformerModelArguments(model_path=None, tokenizer_name=None, config_name=None, model_type='xlnet', cache_dir=None)
  Bases: TrainingPipelineArguments
  Arguments pertaining to model instantiation.
- model_path: Optional[str] = None
  Path where the pretrained model artifacts are stored.
- tokenizer_name: Optional[str] = None
  Pretrained tokenizer name or path. If not provided, it is inferred from `model_path`. If `model_path` is not provided either, you have to pass a tokenizer.
- config_name: Optional[str] = None
  Pretrained config name or path; `model_path` takes preference.
- model_type: Optional[str] = 'xlnet'
  If training from scratch, pass a model type from the list: albert, bart, bert, big_bird, bigbird_pegasus, blenderbot-small, bloom, camembert, codegen, convbert, ctrl, data2vec-text, deberta, deberta-v2, distilbert, electra, encoder-decoder, ernie, esm, flaubert, fnet, fsmt, funnel, gpt2, gpt_neo, gpt_neox, gpt_neox_japanese, gptj, ibert, layoutlm, led, longformer, longt5, luke, m2m_100, marian, megatron-bert, mobilebert, mpnet, mvp, nezha, nystromformer, openai-gpt, pegasus_x, plbart, qdqbert, reformer, rembert, roberta, roformer, speech_to_text, squeezebert, t5, tapas, transfo-xl, wav2vec2, whisper, xlm, xlm-roberta, xlm-roberta-xl, xlnet, yoso. If `model_path` is also provided, `model_path` takes preference.
- cache_dir: Optional[str] = None
  Where to store the pretrained models downloaded from s3.
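The dataclass fields and defaults above can be mirrored locally to show how the arguments are constructed. This stand-in replicates only the documented fields and defaults; the real class lives in gt4sd and additionally derives from TrainingPipelineArguments.

```python
from dataclasses import dataclass
from typing import Optional

# Local stand-in mirroring RegressionTransformerModelArguments' documented
# fields and defaults (a sketch; the real class is defined in gt4sd).
@dataclass
class RTModelArguments:
    model_path: Optional[str] = None
    tokenizer_name: Optional[str] = None
    config_name: Optional[str] = None
    model_type: Optional[str] = "xlnet"
    cache_dir: Optional[str] = None

# Overriding two fields while the rest keep their defaults.
args = RTModelArguments(model_type="bert", cache_dir="/tmp/models")
```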