gt4sd.algorithms.generation.hugging_face.implementation module¶
Implementation details for HuggingFace generation algorithms.
Parts of the implementation inspired by: https://github.com/huggingface/transformers/blob/v4.2.1/examples/text-generation/run_generation.py.
Summary¶
Classes:

Generator – Implementation of a generator.

Functions:

adjust_length_to_model – Adjust sequence length.

set_seed – Set seed for all random number generators.
Reference¶
- set_seed(seed=42)[source]¶
Set seed for all random number generators.
- Parameters
seed (int) – seed to set. Defaults to 42.
- Return type
None
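The reproducibility contract of set_seed can be sketched as follows. This is a minimal illustration assuming only Python's random module; the real implementation also seeds numpy and torch (CPU and CUDA) when available.

```python
import random


def set_seed(seed: int = 42) -> None:
    """Minimal sketch: seed Python's random module.

    Assumption: the library version additionally seeds numpy and
    torch so that sampling-based generation is reproducible.
    """
    random.seed(seed)


# Seeding twice with the same value makes subsequent draws identical.
set_seed(42)
first = random.random()
set_seed(42)
second = random.random()
```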
- adjust_length_to_model(length, maximum_sequence_length)[source]¶
Adjust sequence length.
- Parameters
length (int) – target length.
maximum_sequence_length (int) – maximum sequence length.
- Returns
the adjusted length.
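A plausible sketch of the adjustment logic (an assumption, not the exact library code): the requested length is capped at the model's maximum sequence length, and a non-positive request falls back to that maximum.

```python
def adjust_length_to_model(length: int, maximum_sequence_length: int) -> int:
    # Sketch: never generate beyond what the model supports.
    if length <= 0:
        # No usable target length was given; use the model maximum.
        return maximum_sequence_length
    return min(length, maximum_sequence_length)
```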
- class Generator(resources_path, model_type, model_name, prompt, length, stop_token, num_beams, do_sample, temperature, repetition_penalty, k, p, prefix, number_of_sequences, device=None)[source]¶
Bases:
object
Implementation of a generator.
- __init__(resources_path, model_type, model_name, prompt, length, stop_token, num_beams, do_sample, temperature, repetition_penalty, k, p, prefix, number_of_sequences, device=None)[source]¶
A HuggingFace generation algorithm.
- Parameters
resources_path (str) – path to the cache.
model_type (str) – type of the model.
model_name (str) – name of the model weights/version.
prompt (str) – prompt for text generation.
length (int) – length of the generated text.
stop_token (str) – stop token for text generation.
temperature (float) – temperature for sampling; the lower, the greedier the sampling.
repetition_penalty (float) – repetition penalty; primarily useful for the CTRL model, where 1.2 should be used.
k (int) – number of top-k probability tokens to keep.
p (float) – only tokens with cumulative probabilities summing up to this value are kept.
prefix (str) – text defining context provided prior to the prompt.
number_of_sequences (int) – number of generated sequences.
device (Union[device, str, None]) – device where the inference is running, either as a dedicated class or a string. If not provided, it is inferred.
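The interaction of the k and p parameters can be illustrated with a small self-contained sketch. The helper below is hypothetical (it is not part of this module) and operates on a plain probability dictionary: it first keeps the k most probable tokens, then keeps the smallest prefix of them whose cumulative probability reaches p, mirroring top-k followed by nucleus (top-p) filtering.

```python
def top_k_top_p_filter(token_probs: dict, k: int, p: float) -> list:
    # Hypothetical helper illustrating how k and p interact.
    # Step 1 (top-k): keep only the k most probable tokens.
    ranked = sorted(token_probs.items(), key=lambda kv: kv[1], reverse=True)[:k]
    # Step 2 (top-p): keep the smallest prefix whose cumulative
    # probability reaches p.
    kept, cumulative = [], 0.0
    for token, prob in ranked:
        kept.append(token)
        cumulative += prob
        if cumulative >= p:
            break
    return kept


probs = {"cat": 0.5, "dog": 0.3, "bird": 0.15, "fish": 0.05}
survivors = top_k_top_p_filter(probs, k=3, p=0.7)
```

In actual HuggingFace sampling the surviving logits are renormalized and sampled from, with temperature applied beforehand; this sketch only shows which tokens survive the filtering.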