#
# MIT License
#
# Copyright (c) 2022 GT4SD team
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
"""
Implementation details for PGT algorithms.
"""
import logging
import os
import re
from typing import List, Optional, Tuple, Union
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from ....frameworks.torch import device_claim
# Module logger; the NullHandler keeps the library silent unless the
# application configures logging itself.
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

# Fallback generation length used when neither the caller nor the model
# provides a usable (positive) limit.
MAXIMUM_LENGTH = 10000

# Prompt templates keyed by generation task; "{}" is filled with the input text.
GENERATION_PROMPTS = {
    "title-to-abstract": "{} <|sep|> Given the above title, suggest an abstract <|sep|>",
    "abstract-to-claim": "{} <|sep|> Given the above abstract, suggest a claim <|sep|>",
    "claim-to-abstract": "{} <|sep|> Given the above claim, suggest an abstract <|sep|>",
    "abstract-to-title": "{} <|sep|> Given the above abstract, suggest a title <|sep|>",
}

# Accepted values for the editing input type.
EDITING_TYPES = [
    "abstract",
    "claim",
]

# Accepted values for the coherence check type.
COHERENCE_TYPES = [
    "title-abstract",
    "abstract-claim",
    "title-claim",
]

# Matches the text up to and including the last ".", "!" or ";" it can reach;
# if there is none, the final alternative matches the text as is (possibly empty).
STOPPING_PUNCTUATION_REGEX = re.compile(r"(.+(?=\.|!|;)(.|!|;)|(.*))")
def adjust_length_to_model(length: int, maximum_sequence_length: int) -> int:
    """Adjust sequence length.

    A negative ``length`` is replaced by the model limit when that limit is
    positive, otherwise by ``MAXIMUM_LENGTH``. A ``length`` above a positive
    model limit is clipped to that limit. Any other value is returned as is.

    Args:
        length: target length.
        maximum_sequence_length: maximum sequence length supported by the
            model; values <= 0 are treated as "no limit known".

    Returns:
        the adjusted length.
    """
    if length < 0:
        if maximum_sequence_length > 0:
            logger.warning(
                f"negative length, adjusting to model supported length {maximum_sequence_length}"
            )
            return maximum_sequence_length
        # no usable model limit: fall back to the module-wide maximum
        logger.warning(f"negative length, adjusting to maximal length {MAXIMUM_LENGTH}")
        return MAXIMUM_LENGTH
    if 0 < maximum_sequence_length < length:
        logger.warning(
            f"longer then model supported length, adjusting to {maximum_sequence_length}"
        )
        return maximum_sequence_length
    return length
class Generator:
    """Implementation of a generator."""

    def __init__(
        self,
        resources_path: str,
        model_type: str,
        model_name: str,
        max_length: int,
        top_k: int,
        top_p: float,
        num_return_sequences: int,
        prompt: str = "This is an interesting prompt",
        no_repeat_ngram_size: int = 2,
        device: Optional[Union[torch.device, str]] = None,
    ):
        """PGT generation algorithm.

        The model is loaded as part of construction (see ``load_model``).

        Args:
            resources_path: path to the cache.
            model_type: type of the model.
            model_name: name of the model weights/version.
            max_length: max length of the generated text.
            top_k: number of top-k probability token to keep.
            top_p: only tokens with cumulative probabilities summing up to this value are kept.
            num_return_sequences: number of generated sequences.
            prompt: prompt for text generation.
            no_repeat_ngram_size: size of n-gram to not appear twice.
            device: device where the inference
                is running either as a dedicated class or a string. If not provided is inferred.
        """
        self.device = device_claim(device)
        self.resources_path = resources_path
        self.model_type = model_type
        self.model_name = model_name
        self.prompt = prompt
        self.length = max_length
        self.k = top_k
        self.p = top_p
        self.no_repeat_ngram_size = no_repeat_ngram_size
        self.number_of_sequences = num_return_sequences
        self.load_model()

    def load_model(self) -> None:
        """Load a pretrained PGT model.

        Loads tokenizer and model from ``resources_path``, moves the model to
        ``self.device`` and adjusts ``self.length`` to the model limits.

        Raises:
            FileNotFoundError: if ``resources_path`` does not exist or is empty.
        """
        # Fail early with a clear error; falling through here used to end in
        # an UnboundLocalError because no model path was ever assigned.
        if not (
            os.path.exists(self.resources_path)
            and len(os.listdir(self.resources_path)) > 0
        ):
            raise FileNotFoundError(f"{self.resources_path} not found")
        model_name_or_path = self.resources_path

        self.tokenizer = GPT2Tokenizer.from_pretrained(  # type:ignore
            model_name_or_path,
            sep_token="<|sep|>",
            mask_token="[MASK]",
            pad_token="<|pad|>",
            additional_special_tokens=["<|mask_sep|>"],
        )
        self.model = GPT2LMHeadModel.from_pretrained(model_name_or_path)
        # the tokenizer was extended with special tokens above, so the
        # embedding matrix is resized to the tokenizer's vocabulary size
        self.model.resize_token_embeddings(len(self.tokenizer))
        self.model.to(self.device)
        # adjusting length
        self.length = adjust_length_to_model(
            self.length, self.model.config.max_position_embeddings
        )

    @staticmethod
    def _normalize_prompt(prompt: str) -> str:
        """Collapse runs of spaces and drop whitespace before punctuation.

        Args:
            prompt: raw prompt text.

        Returns:
            the cleaned prompt.
        """
        prompt = re.sub(" +", " ", prompt)
        # \u00b7 is the middle dot; it is written as an escape so the pattern
        # does not depend on the encoding of this source file.
        return re.sub(r'\s([?.!,:;\u00b7"](?:\s|$))', r"\1", prompt)

    def format_output(
        self, prompt: str, generated_sequences: List[str]
    ) -> Union[List[str], List[Tuple[str, ...]]]:
        """Format the generated sequences before they are returned.

        Default hook used by ``generate_case``: returns the sequences
        unchanged. Subclasses can override it to post-process the output.

        Args:
            prompt: the (normalized) prompt used for generation.
            generated_sequences: cleaned generated text snippets.

        Returns:
            the formatted sequences.
        """
        return generated_sequences

    def generate_case(self) -> Union[List[str], List[Tuple[str, ...]]]:
        """Sample text snippets.

        Note that ``self.prompt`` is overwritten with its normalized form.

        Returns:
            generated text snippets.
        """
        self.prompt = self._normalize_prompt(self.prompt)
        encoded_prompt = self.tokenizer.encode(self.prompt, return_tensors="pt")
        encoded_prompt = encoded_prompt.to(self.device)
        # an empty encoding is passed on as "no input ids"
        input_ids = None if encoded_prompt.size()[-1] == 0 else encoded_prompt
        output_sequences = self.model.generate(
            input_ids=input_ids,
            max_length=self.length,
            top_k=self.k,
            top_p=self.p,
            do_sample=True,
            no_repeat_ngram_size=self.no_repeat_ngram_size,
            num_return_sequences=self.number_of_sequences,
        )
        # NOTE: remove the batch dimension when returning multiple sequences
        if len(output_sequences.shape) > 2:
            output_sequences.squeeze_()
        generated_sequences: List[str] = []
        for generated_sequence in output_sequences:
            text = self.tokenizer.decode(generated_sequence.tolist())
            # strip the echoed prompt, anything after end-of-text, and padding
            text = text.replace(self.prompt, "")
            text = text.split("<|endoftext|>")[0]
            text = text.replace("<|pad|>", "")
            text = text.strip()
            # cut after the last stopping punctuation; the pattern's final
            # alternative matches any text, so search() does not return None
            text = STOPPING_PUNCTUATION_REGEX.search(  # type:ignore
                text
            ).group()
            generated_sequences.append(text)
        return self.format_output(self.prompt, generated_sequences)
 
class PartGenerator(Generator):
    """Implementation of part generator.

    Builds the prompt from one of the ``GENERATION_PROMPTS`` templates, i.e.
    generates one part of a patent (title, abstract or claim) from another.
    """

    def __init__(
        self,
        resources_path: str,
        input_text: str,
        model_type: str,
        model_name: str,
        task: str,
        max_length: int,
        top_k: int,
        top_p: float,
        num_return_sequences: int,
        no_repeat_ngram_size: int = 2,
        device: Optional[Union[torch.device, str]] = None,
    ):
        """PGT part generation algorithm.

        Args:
            resources_path: path to the cache.
            input_text: input text for generation.
            model_type: type of the model.
            model_name: name of the model weights/version.
            task: generation task, one of the keys of ``GENERATION_PROMPTS``.
            max_length: max length of the generated text.
            top_k: number of top-k probability token to keep.
            top_p: only tokens with cumulative probabilities summing up to this value are kept.
            num_return_sequences: number of generated sequences.
            no_repeat_ngram_size: size of n-gram to not appear twice.
            device: device where the inference
                is running either as a dedicated class or a string. If not provided is inferred.

        Raises:
            ValueError: if ``task`` is not a supported generation task.
        """
        # validate before the parent constructor loads the model
        if task not in GENERATION_PROMPTS:
            raise ValueError(f"{task} is not a valid option for task.")
        prompt = GENERATION_PROMPTS[task].format(input_text)
        super().__init__(
            resources_path=resources_path,
            model_type=model_type,
            model_name=model_name,
            prompt=prompt,
            max_length=max_length,
            top_k=top_k,
            top_p=top_p,
            num_return_sequences=num_return_sequences,
            no_repeat_ngram_size=no_repeat_ngram_size,
            device=device,
        )
class EditGenerator(Generator):
    """Implementation of edit generator.

    Builds a prompt asking the model to replace the ``[MASK]`` tokens in the
    given input text.
    """

    def __init__(
        self,
        resources_path: str,
        input_text: str,
        model_type: str,
        model_name: str,
        max_length: int,
        top_k: int,
        top_p: float,
        num_return_sequences: int,
        no_repeat_ngram_size: int = 2,
        device: Optional[Union[torch.device, str]] = None,
        input_type: str = "abstract",
    ):
        """PGT editing algorithm.

        Args:
            resources_path: path to the cache.
            input_text: input text for generation.
            model_type: type of the model.
            model_name: name of the model weights/version.
            max_length: max length of the generated text.
            top_k: number of top-k probability token to keep.
            top_p: only tokens with cumulative probabilities summing up to this value are kept.
            num_return_sequences: number of generated sequences.
            no_repeat_ngram_size: size of n-gram to not appear twice.
            device: device where the inference
                is running either as a dedicated class or a string. If not provided is inferred.
            input_type: part of a patent the input text belongs to, one of
                ``EDITING_TYPES``.

        Raises:
            ValueError: if ``input_type`` is not a supported editing type.
        """
        # validate before the parent constructor loads the model
        if input_type not in EDITING_TYPES:
            raise ValueError(
                f"{input_type} is not a valid option for editing input type."
            )
        prompt = f"{input_text} <|sep|> Replace the [MASK] tokens in the above {input_type} <|sep|>"
        super().__init__(
            resources_path=resources_path,
            model_type=model_type,
            model_name=model_name,
            prompt=prompt,
            max_length=max_length,
            top_k=top_k,
            top_p=top_p,
            num_return_sequences=num_return_sequences,
            no_repeat_ngram_size=no_repeat_ngram_size,
            device=device,
        )
 
[docs]class CoherenceCheckGenerator(Generator):
    """Implementation of coherence check generator."""
[docs]    def __init__(
        self,
        resources_path: str,
        input_a: str,
        input_b: str,
        model_type: str,
        model_name: str,
        max_length: int,
        top_k: int,
        top_p: float,
        num_return_sequences: int,
        no_repeat_ngram_size: int = 2,
        device: Optional[Union[torch.device, str]] = None,
        coherence_type: str = "title-abstract",
    ):
        """PGT generation algorithm.
        Args:
            resources_path: path to the cache.
            input_a: first input for coherence check.
            input_b: second input for coherence check.
            model_type: type of the model.
            model_name: name of the model weights/version.
            max_length: max length of the generated text.
            top_k: number of top-k probability token to keep.
            top_p: only tokens with cumulative probabilities summing up to this value are kept.
            num_return_sequences: number of generated sequences.
            no_repeat_ngram_size: size of n-gram to not appear twice.
            device: device where the inference
                is running either as a dedicated class or a string. If not provided is inferred.
            coherence_type: input types for the check.
        """
        type_a, type_b = self.extract_coherence_types(coherence_type)
        prompt = f"{input_a} <|sep|> {input_b} <|sep|> Do the above {type_a} and {type_b} belong to the same patent? <|sep|>"
        super().__init__(
            resources_path=resources_path,
            model_type=model_type,
            model_name=model_name,
            prompt=prompt,
            max_length=max_length,
            top_k=top_k,
            top_p=top_p,
            num_return_sequences=num_return_sequences,
            no_repeat_ngram_size=no_repeat_ngram_size,
            device=device,
        )