# Source code for gt4sd.training_pipelines.regression_transformer.core
#
# MIT License
#
# Copyright (c) 2022 GT4SD team
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
"""Regression Transformer training utilities."""
import os
from dataclasses import dataclass, field
from typing import Optional
from ...configuration import gt4sd_configuration_instance
from ..core import TrainingPipelineArguments
from .utils import TransformersTrainingArgumentsCLI
# Local cache directory where RegressionTransformer training data is staged.
# Rooted at the GT4SD-wide cache path so all pipelines share one cache tree.
DATA_ROOT = os.path.join(
    gt4sd_configuration_instance.gt4sd_local_cache_path,
    "data",
    "RegressionTransformer",
)
# Create the directory eagerly at import time; idempotent if it already exists.
os.makedirs(DATA_ROOT, exist_ok=True)
@dataclass
class RegressionTransformerTrainingArguments(
    TrainingPipelineArguments, TransformersTrainingArgumentsCLI
):
    """
    Arguments related to the RegressionTransformer trainer.

    NOTE: All arguments from `transformers.training_args.TrainingArguments`
    can be used. Only the additional ones are specified below.
    """

    # Identifier consumed by the training-pipeline argument dispatcher.
    __name__ = "training_args"

    training_name: str = field(
        default="rt_training", metadata={"help": "Name used to identify the training."}
    )
    num_train_epochs: int = field(default=10, metadata={"help": "Number of epochs."})
    batch_size: int = field(default=16, metadata={"help": "Size of the batch."})
    log_interval: int = field(
        default=100, metadata={"help": "Number of steps between log intervals."}
    )
    gradient_interval: int = field(
        default=1, metadata={"help": "Gradient accumulation steps"}
    )
    eval_steps: int = field(
        default=1000,
        metadata={"help": "The time interval at which validation is performed."},
    )
    max_span_length: int = field(
        default=5, metadata={"help": "Max length of a span of masked tokens for PLM."}
    )
    plm_probability: float = field(
        default=1 / 6,
        metadata={
            "help": "Ratio of length of a span of masked tokens to surrounding context length for PLM."
        },
    )
    alternate_steps: int = field(
        default=50,
        metadata={
            "help": "Per default, training alternates between property prediction and "
            "conditional generation. This argument specifies the alternation frequency."
            "If you set it to 0, no alternation occurs and we fall back to vanilla "
            "permutation language modeling (PLM). Default: 50."
        },
    )
    cc_loss: bool = field(
        default=False,
        metadata={
            "help": "Whether the cycle-consistency loss is computed during the conditional "
            "generation task. Defaults to False."
        },
    )
    cg_collator: str = field(
        default="vanilla_cg",
        metadata={
            "help": "The collator class. Following options are implemented: "
            "'vanilla_cg': Collator class that does not mask the properties but anything else as a regular DataCollatorForPermutationLanguageModeling. Can optionally replace the properties with sampled values. "
            "NOTE: This collator can deal with multiple properties. "
            "'multientity_cg': A training collator for the conditional-generation task that can handle multiple entities. "
            "Default: vanilla_cg."
        },
    )
    entity_to_mask: int = field(
        default=-1,
        metadata={
            "help": "Only applies if `cg_collator='multientity_cg'`. The entity that is being masked during training. 0 corresponds to first entity and so on. -1 corresponds to "
            "a random sampling scheme where the entity-to-be-masked is determined "
            "at runtime in the collator. NOTE: If 'mask_entity_separator' is true, "
            "this argument will not have any effect. Defaults to -1."
        },
    )
    entity_separator_token: str = field(
        default=".",
        metadata={
            "help": "Only applies if `cg_collator='multientity_cg'`. The token that is used to separate "
            "entities in the input. Defaults to '.' (applicable to SMILES & SELFIES)"
        },
    )
    mask_entity_separator: bool = field(
        default=False,
        metadata={
            "help": "Only applies if `cg_collator='multientity_cg'`. Whether or not the entity separator token can be masked. If True, *all* textual tokens can be masked and "
            "the collator behaves like the `vanilla_cg` even though it is a `multientity_cg`. If False, the exact behavior "
            "depends on the entity_to_mask argument. Defaults to False."
        },
    )
@dataclass
class RegressionTransformerDataArguments(TrainingPipelineArguments):
    """Arguments related to RegressionTransformer data loading."""

    # Identifier consumed by the training-pipeline argument dispatcher.
    __name__ = "dataset_args"

    train_data_path: str = field(
        metadata={
            "help": "Path to a `.csv` file with the input training data. The file has to "
            "contain a `text` column (with the string input, e.g., SMILES, AAS, natural "
            "text) and an arbitrary number of numerical columns."
        },
    )
    test_data_path: str = field(
        metadata={
            "help": "Path to a `.csv` file with the input testing data. The file has to "
            "contain a `text` column (with the string input, e.g., SMILES, AAS, natural "
            "text) and an arbitrary number of numerical columns."
        },
    )
    augment: Optional[int] = field(
        default=0,
        metadata={
            "help": "Factor by which the training data is augmented. The data modality "
            "(SMILES, SELFIES, AAS, natural text) is inferred from the tokenizer. "
            "NOTE: For natural text, no augmentation is supported. Defaults to 0, "
            "meaning no augmentation. "
        },
    )
    save_datasets: Optional[bool] = field(
        default=False,
        metadata={
            "help": "Whether to save the datasets to disk. Datasets will be saved as `.txt` file to "
            "the same location where `train_data_path` and `test_data_path` live. Defaults to False."
        },
    )
@dataclass
class RegressionTransformerSavingArguments(TrainingPipelineArguments):
    """Saving arguments related to RegressionTransformer trainer."""

    # Identifier consumed by the training-pipeline argument dispatcher.
    __name__ = "saving_args"

    model_path: str = field(
        metadata={"help": "Path where the model artifacts are stored."}
    )
    checkpoint_name: str = field(
        # Empty string (idiomatic literal instead of `str()`) keeps the
        # documented fall-back of reading artifacts directly from `model_path`.
        default="",
        metadata={
            "help": "Name for the checkpoint that should be copied to inference model. "
            "Has to be a subfolder of `model_path`. Defaults to empty string meaning that "
            "files are taken from `model_path` (i.e., after training finished)."
        },
    )