gt4sd.training_pipelines.regression_transformer.implementation module

Regression Transformer training implementation.

Summary

Classes:

RegressionTransformerModelArguments

Arguments pertaining to model instantiation.

RegressionTransformerTrainingPipeline

RegressionTransformer training pipeline.

Reference

class RegressionTransformerTrainingPipeline[source]

Bases: TrainingPipeline

RegressionTransformer training pipeline.

train(training_args, model_args, dataset_args)[source]
Generic function for training a Regression Transformer (RT) model.
For details see:

Born, J., & Manica, M. (2023). Regression Transformer enables concurrent sequence regression and generation for molecular language modelling. Nature Machine Intelligence, 5(4), 432-444.

Parameters
  • training_args (Dict[str, Any]) – training arguments passed to the configuration.

  • model_args (Dict[str, Any]) – model arguments passed to the configuration.

  • dataset_args (Dict[str, Any]) – dataset arguments passed to the configuration.

Return type

None
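The three argument dictionaries are typically derived from `TrainingPipelineArguments` dataclasses. A minimal sketch of that conversion, assuming a hypothetical `ModelArgs` stand-in (the real dataclass is `RegressionTransformerModelArguments`); `to_config` is an illustrative helper, not part of the library:

```python
from dataclasses import asdict, dataclass
from typing import Any, Dict, Optional


@dataclass
class ModelArgs:
    # Illustrative subset of RegressionTransformerModelArguments.
    model_path: Optional[str] = None
    model_type: Optional[str] = "xlnet"


def to_config(args: ModelArgs) -> Dict[str, Any]:
    # train() receives plain dictionaries, so dataclass argument
    # objects are flattened with dataclasses.asdict().
    return asdict(args)


model_args = to_config(ModelArgs(model_type="xlnet"))
```

The same pattern applies to the training and dataset argument dataclasses.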

setup_model(params)[source]

Error handling and training setup routine.

Parameters

params (Dict[str, Any]) – A dictionary with all parameters to launch training.

Raises

ValueError – If flawed values are passed.
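A sketch of the kind of validation this implies; the checks and parameter names below are assumptions for illustration, not the actual rules enforced by `setup_model`:

```python
from typing import Any, Dict


def validate_params(params: Dict[str, Any]) -> None:
    # Illustrative validation in the spirit of setup_model:
    # raise ValueError when flawed values are passed.
    if params.get("model_path") is None and params.get("model_type") is None:
        raise ValueError("Either `model_path` or `model_type` must be set.")
    if params.get("augment", 0) < 0:
        raise ValueError("`augment` has to be a non-negative integer.")
```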

setup_dataset(train_data_path, test_data_path, augment=0, save_datasets=False, *args, **kwargs)[source]

Constructs the dataset objects.

Parameters
  • train_data_path (str) – Path to .csv file. Has to have a text column and at least one column of numerical properties.

  • test_data_path (str) – Path to .csv file. Has to have a text column and at least one column of numerical properties.

  • augment (int) – How many times each training sample is augmented.

  • save_datasets (bool) – Whether to save the datasets to disk (will be stored in same location as train_data_path and test_data_path).

Returns

A tuple of train and test dataset.
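A quick way to sanity-check that an input .csv meets the stated requirement (a text column plus at least one numerical property column). This helper and the `text` column name are assumptions for illustration; the pipeline's actual column handling may differ:

```python
import csv
import io
from typing import List


def check_rt_csv(csv_text: str, text_column: str = "text") -> List[str]:
    # The training .csv must contain a text column plus at least one
    # numerical property column; return the property column names.
    reader = csv.DictReader(io.StringIO(csv_text))
    header = reader.fieldnames or []
    if text_column not in header:
        raise ValueError(f"Missing `{text_column}` column.")
    properties = [c for c in header if c != text_column]
    if not properties:
        raise ValueError("Need at least one numerical property column.")
    return properties


cols = check_rt_csv("text,qed\nCCO,0.41\n")
```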

create_dataset_from_list(data, save_path=None)[source]

Creates a LineByLineTextDataset from a List of strings.

Parameters
  • data (List[str]) – List of strings with the samples.

  • save_path (Optional[str]) – Path to save the dataset to. Defaults to None, meaning the dataset will not be saved.

Return type

LineByLineTextDataset

Returns

The dataset.
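Conceptually, a line-by-line dataset treats each string as one sample and, when a save path is given, persists the samples one per line. A minimal sketch of that behavior (the `lines_to_file` helper is hypothetical; the real method returns a `LineByLineTextDataset`):

```python
from pathlib import Path
from typing import List, Optional


def lines_to_file(data: List[str], save_path: Optional[str] = None) -> List[str]:
    # Each non-empty string is one sample; optionally write the
    # samples to disk, one sample per line.
    samples = [s.strip() for s in data if s.strip()]
    if save_path is not None:
        Path(save_path).write_text("\n".join(samples) + "\n")
    return samples
```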

class RegressionTransformerModelArguments(model_path=None, tokenizer_name=None, config_name=None, model_type='xlnet', cache_dir=None)[source]

Bases: TrainingPipelineArguments

Arguments pertaining to model instantiation.

Parameters
  • model_path (Optional[str]) – Path where the pretrained model artifacts are stored. Defaults to None.

  • tokenizer_name (Optional[str]) – Pretrained tokenizer name or path. If not provided, it is inferred from model_path. If model_path is not provided either, you have to pass a tokenizer. Defaults to None.

  • config_name (Optional[str]) – Pretrained config name or path. model_path takes preference. Defaults to None.

  • model_type (Optional[str]) – If training from scratch, pass a model type from the list: albert, bart, bert, big_bird, bigbird_pegasus, blenderbot-small, bloom, camembert, codegen, convbert, ctrl, data2vec-text, deberta, deberta-v2, distilbert, electra, encoder-decoder, ernie, esm, flaubert, fnet, fsmt, funnel, gpt2, gpt_neo, gpt_neox, gpt_neox_japanese, gptj, ibert, layoutlm, led, longformer, longt5, luke, m2m_100, marian, megatron-bert, mobilebert, mpnet, mvp, nezha, nystromformer, openai-gpt, pegasus_x, plbart, qdqbert, reformer, rembert, roberta, roformer, speech_to_text, squeezebert, t5, tapas, transfo-xl, wav2vec2, whisper, xlm, xlm-roberta, xlm-roberta-xl, xlnet, yoso. If model_path is also provided, model_path takes preference. Defaults to 'xlnet'.

  • cache_dir (Optional[str]) – Where to store the pretrained models downloaded from s3. Defaults to None.
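Each field of this dataclass carries a `help` string in its `metadata`, which HuggingFace-style argument parsers turn into CLI help text. A self-contained sketch of that pattern, using a hypothetical `Args` stand-in rather than the real class:

```python
from dataclasses import dataclass, field, fields
from typing import Optional


@dataclass
class Args:
    # Illustrative stand-in for RegressionTransformerModelArguments:
    # each field attaches a `help` string via its metadata.
    model_path: Optional[str] = field(
        default=None,
        metadata={"help": "Path where the pretrained model artifacts are stored."},
    )
    model_type: Optional[str] = field(
        default="xlnet",
        metadata={"help": "If training from scratch, pass a model type."},
    )


# Collect the help text of every field, as an argument parser would.
help_texts = {f.name: f.metadata["help"] for f in fields(Args)}
```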