gt4sd.algorithms.generation.hugging_face.implementation module

Implementation details for HuggingFace generation algorithms.

Parts of the implementation are inspired by: https://github.com/huggingface/transformers/blob/v4.2.1/examples/text-generation/run_generation.py.

Summary

Classes:

Generator

Implementation of a generator.

Functions:

adjust_length_to_model

Adjust sequence length.

prepare_ctrl_input

Prepare prompt input for a CTRL model.

prepare_prefix_input

Prepare prompt input with a prepended prefix.

set_seed

Set seed for all random number generators.

Reference

set_seed(seed=42)[source]

Set seed for all random number generators.

Parameters

seed (int) – seed to set. Defaults to 42.

Return type

None
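A helper of this kind typically seeds every random number generator the pipeline may touch. The following is a minimal sketch, not the actual gt4sd implementation; numpy and torch are seeded only when those optional packages are importable:

```python
import random

def set_seed(seed: int = 42) -> None:
    """Seed all available random number generators."""
    random.seed(seed)
    try:
        import numpy as np  # optional dependency
        np.random.seed(seed)
    except ImportError:
        pass
    try:
        import torch  # optional dependency
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
    except ImportError:
        pass

# reseeding with the same value reproduces the same draws
set_seed(42)
first = [random.random() for _ in range(3)]
set_seed(42)
second = [random.random() for _ in range(3)]
```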

prepare_ctrl_input(tokenizer, prompt, **kwargs)[source]

Prepare prompt input for a CTRL model.

prepare_prefix_input(tokenizer, prompt, **kwargs)[source]

Prepare prompt input with a prepended prefix.

adjust_length_to_model(length, maximum_sequence_length)[source]

Adjust sequence length.

Parameters
  • length (int) – target length.

  • maximum_sequence_length (int) – maximum sequence length.

Return type

int

Returns

the adjusted length.
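The referenced run_generation.py script clamps the requested length against the model maximum; the sketch below follows that logic. The MAX_LENGTH sentinel is an assumption borrowed from that script, not necessarily what gt4sd uses:

```python
MAX_LENGTH = 10000  # sentinel cap used when the model exposes no maximum

def adjust_length_to_model(length: int, maximum_sequence_length: int) -> int:
    """Clamp a target generation length to what the model supports."""
    if length < 0 and maximum_sequence_length > 0:
        # no target given: fall back to the model maximum
        length = maximum_sequence_length
    elif 0 < maximum_sequence_length < length:
        # target exceeds what the model supports: clamp it
        length = maximum_sequence_length
    elif length < 0:
        # neither bound is usable: apply the sentinel cap
        length = MAX_LENGTH
    return length
```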

class Generator(resources_path, model_type, model_name, prompt, length, stop_token, num_beams, do_sample, temperature, repetition_penalty, k, p, prefix, number_of_sequences, device=None)[source]

Bases: object

Implementation of a generator.

__init__(resources_path, model_type, model_name, prompt, length, stop_token, num_beams, do_sample, temperature, repetition_penalty, k, p, prefix, number_of_sequences, device=None)[source]

A HuggingFace generation algorithm.

Parameters
  • resources_path (str) – path to the cache.

  • model_type (str) – type of the model.

  • model_name (str) – name of the model weights/version.

  • prompt (str) – prompt for text generation.

  • length (int) – length of the generated text.

  • stop_token (str) – stop token for text generation.

  • num_beams (int) – number of beams for beam search.

  • do_sample (bool) – whether to sample instead of decoding greedily.

  • temperature (float) – temperature for sampling; the lower the value, the greedier the sampling.

  • repetition_penalty (float) – primarily useful for CTRL model, where 1.2 should be used.

  • k (int) – number of top-k probability tokens to keep.

  • p (float) – only the most probable tokens with cumulative probability up to this value are kept (nucleus sampling).

  • prefix (str) – text defining context provided prior to the prompt.

  • number_of_sequences (int) – number of generated sequences.

  • device (Union[device, str, None]) – device where inference runs, given either as a torch device or a string. If not provided, it is inferred.
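The temperature, k, and p parameters interact in the standard sampling recipe: logits are scaled by temperature, then restricted to the top-k tokens and to the smallest nucleus whose cumulative probability reaches p. Below is a pure-Python sketch of that filtering, for illustration only; the class delegates the actual decoding to the transformers library, and filter_probs is a hypothetical helper, not part of gt4sd:

```python
import math

def filter_probs(logits, temperature=1.0, k=0, p=1.0):
    """Apply temperature scaling, top-k and top-p (nucleus) filtering."""
    scaled = [logit / temperature for logit in logits]
    # numerically stable softmax
    peak = max(scaled)
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    probs = [(token, e / total) for token, e in enumerate(exps)]
    probs.sort(key=lambda pair: pair[1], reverse=True)
    if k > 0:
        probs = probs[:k]  # keep only the k most likely tokens
    if p < 1.0:
        kept, cumulative = [], 0.0
        for token, prob in probs:
            kept.append((token, prob))
            cumulative += prob
            if cumulative >= p:
                break  # smallest nucleus reaching probability p
        probs = kept
    # renormalize over the surviving tokens
    total = sum(prob for _, prob in probs)
    return {token: prob / total for token, prob in probs}
```

With k=2 only the two most likely tokens survive; with a small p only the single dominant token may remain.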

load_model()[source]

Load a pretrained HuggingFace generation model.

Return type

None

sample()[source]

Sample text snippets.

Return type

List[str]

Returns

generated text snippets.

__dict__ = mappingproxy({'__module__': 'gt4sd.algorithms.generation.hugging_face.implementation', '__doc__': 'Implementation of a generator.', '__init__': <function Generator.__init__>, 'load_model': <function Generator.load_model>, 'sample': <function Generator.sample>, '__dict__': <attribute '__dict__' of 'Generator' objects>, '__weakref__': <attribute '__weakref__' of 'Generator' objects>, '__annotations__': {}})
__doc__ = 'Implementation of a generator.'
__module__ = 'gt4sd.algorithms.generation.hugging_face.implementation'
__weakref__

list of weak references to the object (if defined)