#
# MIT License
#
# Copyright (c) 2022 GT4SD team
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
"""Patent Generative Transformer (PGT) generation algorithm."""

import logging
import os
import shutil
from dataclasses import field
from typing import Any, ClassVar, Dict, Optional, TypeVar

from gt4sd_trainer.hf_pl.cli_pl_to_hf_converter import convert_pl_to_hf
from gt4sd_trainer.hf_pl.core import LanguageModelingSavingArguments
from typing_extensions import Protocol, runtime_checkable

from ....training_pipelines.core import TrainingPipelineArguments
from ...core import AlgorithmConfiguration, GeneratorAlgorithm, Untargeted
from ...registry import ApplicationsRegistry
from .implementation import (
COHERENCE_TYPES,
EDITING_TYPES,
GENERATION_PROMPTS,
CoherenceCheckGenerator,
EditGenerator,
Generator,
PartGenerator,
)

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

T = type(None)
S = TypeVar("S", bound=str)


class PGT(GeneratorAlgorithm[S, T]):
"""PGT Algorithm."""

    def __init__(
self,
configuration: AlgorithmConfiguration[S, T],
target: Optional[T] = None,
) -> None:
"""Instantiate PGT ready to generate items.
Args:
configuration: domain and application
specification defining parameters, types and validations.
target: unused since it is not a conditional generator.
Example:
An example for generating abstract from a given claim:
config = PGTGenerator(task="claim_to_abstract", input_text="My interesting claim")
generator = PGT(configuration=config)
print(list(generator.sample(1)))
"""
configuration = self.validate_configuration(configuration)
self.max_samples = configuration.num_return_sequences # type: ignore
# No validation/check on the target input here, since model is not yet loaded.
super().__init__(
configuration=configuration, # type:ignore
target=target, # type:ignore
)

    def get_generator(
self,
configuration: AlgorithmConfiguration[S, T],
target: Optional[T],
) -> Untargeted:
"""Get the function to sample with the given configuration.
Args:
configuration: helps to set up specific application of PGT.
target: context or condition for the generation. Unused in the algorithm.
Returns:
callable with target generating a batch of items.
"""
logger.info("ensure artifacts for the application are present.")
self.local_artifacts = configuration.ensure_artifacts()
implementation: Generator = configuration.get_generator( # type: ignore
self.local_artifacts
)
return implementation.generate_case # type: ignore

    def validate_configuration(
self, configuration: AlgorithmConfiguration[S, T]
) -> AlgorithmConfiguration[S, T]:
@runtime_checkable
class AnyPGTConfiguration(Protocol):
"""Protocol for PGT configurations."""
def get_generator(self, resources_path: str) -> Generator:
...
def validate_item(self, item: Any) -> S:
...
# TODO raise InvalidAlgorithmConfiguration
assert isinstance(configuration, AnyPGTConfiguration)
assert isinstance(configuration, AlgorithmConfiguration)
return configuration
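
# Note: ``validate_configuration`` relies on a runtime-checkable protocol, so any
# configuration object exposing ``get_generator`` and ``validate_item`` (such as the
# PGT application configurations registered below) is accepted via duck typing.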


@ApplicationsRegistry.register_algorithm_application(PGT)
class PGTAlgorithmConfiguration(AlgorithmConfiguration[str, None]):
"""Basic configuration for a PGT algorithm"""
algorithm_type: ClassVar[str] = "generation"
domain: ClassVar[str] = "nlp"
algorithm_version: str = "v0"
model_type: str = field(
default="",
metadata=dict(description="Type of the model."),
)
max_length: int = field(
default=512, metadata=dict(description="Maximum length of the generated text.")
)
top_k: int = field(
default=50,
metadata=dict(description="Number of top-k probability tokens to keep."),
)
top_p: float = field(
default=1.0,
metadata=dict(
description="Only tokens with cumulative probabilities summing up to this value are kept."
),
)
num_return_sequences: int = field(
default=3,
metadata=dict(description="Number of alternatives to be generated."),
)
no_repeat_ngram_size: int = field(
default=2,
metadata=dict(description="Size of n-gram to not appear twice."),
)
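
    # Note on the fields above: ``max_length``, ``top_k``, ``top_p``,
    # ``num_return_sequences`` and ``no_repeat_ngram_size`` mirror the standard
    # HuggingFace text-generation sampling controls and are passed to the generator
    # implementations created in ``get_generator``. A hypothetical instantiation
    # (illustrative values only):
    #
    #     config = PGTGenerator(top_k=30, top_p=0.95, num_return_sequences=2)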

    def get_target_description(self) -> Optional[Dict[str, str]]:
"""Get description of the target for generation.
Returns:
target description, returns None in case no target is used.
"""
return None

    def get_generator(self, resources_path: str, **kwargs) -> Generator:
"""Instantiate the actual PGT implementation.
Args:
resources_path: local path to model files.
Returns:
instance with
:meth:`generate_batch<gt4sd.algorithms.generation.pgt.implementation.Generator.generate_case>`
method for targeted generation.
"""
return Generator(
resources_path=resources_path,
model_type=self.model_type,
model_name=self.algorithm_version,
max_length=self.max_length,
top_k=self.top_k,
top_p=self.top_p,
num_return_sequences=self.num_return_sequences,
)

    @classmethod
def save_version_from_training_pipeline_arguments_postprocess(
cls,
training_pipeline_arguments: TrainingPipelineArguments,
):
"""Postprocess after saving. Remove temporarily converted hf model
if pytorch-lightning checkpoint is given.
Args:
training_pipeline_arguments: training pipeline arguments.
"""
if isinstance(training_pipeline_arguments, LanguageModelingSavingArguments):
if training_pipeline_arguments.ckpt is not None:
shutil.rmtree(training_pipeline_arguments.hf_model_path)
logger.info(
f"Cleaning up temporary files from {training_pipeline_arguments.hf_model_path}"
)
else:
return super().save_version_from_training_pipeline_arguments_postprocess(
training_pipeline_arguments
)

    @classmethod
def get_filepath_mappings_for_training_pipeline_arguments(
cls, training_pipeline_arguments: TrainingPipelineArguments
) -> Dict[str, str]:
"""Ger filepath mappings for the given training pipeline arguments.
Args:
training_pipeline_arguments: training pipeline arguments.
Returns:
a mapping between artifacts' files and training pipeline's output files.
"""
if isinstance(training_pipeline_arguments, LanguageModelingSavingArguments):
if training_pipeline_arguments.ckpt is not None:
convert_pl_to_hf(training_pipeline_arguments)
model_files = os.listdir(training_pipeline_arguments.hf_model_path)
model_files_dict = {
file: os.path.join(training_pipeline_arguments.hf_model_path, file)
for file in model_files
}
return model_files_dict
else:
return super().get_filepath_mappings_for_training_pipeline_arguments(
training_pipeline_arguments
)


@ApplicationsRegistry.register_algorithm_application(PGT)
class PGTGenerator(PGTAlgorithmConfiguration):
"""Configuration for a PGT Generator algorithm"""
input_text: str = field(
default="This is my input",
metadata=dict(description="Input text."),
)
task: str = field(
default="title-to-abstract",
metadata=dict(
description=f"Generation tasks. Supported: {', '.join(GENERATION_PROMPTS.keys())}"
),
)

    def get_generator(self, resources_path: str, **kwargs) -> Generator:
"""Instantiate the actual PGT implementation for part of patent generation.
Args:
resources_path: local path to model files.
Returns:
instance with
:meth:`generate_batch<gt4sd.algorithms.generation.pgt.implementation.Generator.generate_case>`
method for targeted generation.
"""
return PartGenerator(
resources_path=resources_path,
input_text=self.input_text,
model_type=self.model_type,
model_name=self.algorithm_version,
max_length=self.max_length,
top_k=self.top_k,
top_p=self.top_p,
num_return_sequences=self.num_return_sequences,
no_repeat_ngram_size=self.no_repeat_ngram_size,
task=self.task,
)
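
# Minimal usage sketch for PGTGenerator (hypothetical input values; assumes the
# model artifacts for ``algorithm_version`` can be resolved via the GT4SD artifact cache):
#
#     config = PGTGenerator(task="title-to-abstract", input_text="Artificial intelligence for patent drafting")
#     generator = PGT(configuration=config)
#     print(list(generator.sample(1)))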


@ApplicationsRegistry.register_algorithm_application(PGT)
class PGTEditor(PGTAlgorithmConfiguration):
"""Configuration for a PGT Editor algorithm."""
input_text: str = field(
default="This is my input",
metadata=dict(description="Input text."),
)
input_type: str = field(
default="abstract",
metadata=dict(
description=f"Part of a patent the input text belongs. Supported: {', '.join(EDITING_TYPES)}"
),
)

    def get_generator(self, resources_path: str, **kwargs) -> Generator:
"""Instantiate the actual PGT implementation for part of patent editing.
Args:
resources_path: local path to model files.
Returns:
instance with
:meth:`generate_batch<gt4sd.algorithms.generation.pgt.implementation.Generator.generate_case>`
method for targeted generation.
"""
return EditGenerator(
resources_path=resources_path,
input_text=self.input_text,
model_type=self.model_type,
model_name=self.algorithm_version,
max_length=self.max_length,
top_k=self.top_k,
top_p=self.top_p,
num_return_sequences=self.num_return_sequences,
no_repeat_ngram_size=self.no_repeat_ngram_size,
input_type=self.input_type,
)
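
# Minimal usage sketch for PGTEditor (hypothetical input values; "abstract" is the
# field default and is assumed to be listed in EDITING_TYPES):
#
#     config = PGTEditor(input_type="abstract", input_text="A crystalline salt form with improved stability.")
#     generator = PGT(configuration=config)
#     print(list(generator.sample(1)))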


@ApplicationsRegistry.register_algorithm_application(PGT)
class PGTCoherenceChecker(PGTAlgorithmConfiguration):
"""Configuration for a PGT coherence check algorithm"""
num_return_sequences: int = field(
default=1,
metadata=dict(
description="Number of alternatives should be always 1 for coherence check."
),
)
input_a: str = field(
default="I'm a stochastic parrot.",
metadata=dict(description="First input for coherence check."),
)
input_b: str = field(
default="I'm a stochastic parrot.",
metadata=dict(description="Second input for coherence check."),
)
coherence_type: str = field(
default="title-abstract",
metadata=dict(
description=f"Input types for the check. Supported: {', '.join(COHERENCE_TYPES)}"
),
)

    def get_generator(self, resources_path: str, **kwargs) -> Generator:
"""Instantiate the actual PGT implementation for patent coherence check.
Args:
resources_path: local path to model files.
Returns:
instance with
:meth:`generate_batch<gt4sd.algorithms.generation.pgt.implementation.Generator.generate_case>`
method for targeted generation.
"""
return CoherenceCheckGenerator(
resources_path=resources_path,
input_a=self.input_a,
input_b=self.input_b,
model_type=self.model_type,
model_name=self.algorithm_version,
max_length=self.max_length,
top_k=self.top_k,
top_p=self.top_p,
num_return_sequences=self.num_return_sequences,
no_repeat_ngram_size=self.no_repeat_ngram_size,
coherence_type=self.coherence_type,
)
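
# Minimal usage sketch for PGTCoherenceChecker (hypothetical input values;
# "title-abstract" is the field default and is assumed to be listed in COHERENCE_TYPES):
#
#     config = PGTCoherenceChecker(
#         coherence_type="title-abstract",
#         input_a="Method for producing a composite electrode",
#         input_b="The invention relates to a method for producing a composite electrode.",
#     )
#     generator = PGT(configuration=config)
#     print(list(generator.sample(1)))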