gt4sd.algorithms.prediction.topics_zero_shot.implementation module

Implementation of the zero-shot classifier.

Summary

Classes:

ZeroShotClassifier

Zero-shot classifier based on the HuggingFace pipeline leveraging MNLI.

Reference

class ZeroShotClassifier(resources_path, model_name, device=None)

Bases: object

Zero-shot classifier based on the HuggingFace pipeline leveraging MNLI.
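NLI-based zero-shot classification can be sketched as follows: each candidate label is inserted into a hypothesis template, the (text, hypothesis) pair is scored by an NLI model, and the entailment evidence becomes that label's score. This is a minimal, non-authoritative sketch, not the GT4SD code. It assumes a hypothetical `fake_nli_logits` stub standing in for a real MNLI model, and a hypothetical template "This text is about {}.". It follows the multi-label style, where each label is scored independently by a softmax over its contradiction and entailment logits.

```python
import math
from typing import Callable, Dict, List, Tuple

# Hypothetical NLI scorer type: maps (premise, hypothesis) to
# (contradiction_logit, entailment_logit). A real setup would call an MNLI model.
NliScorer = Callable[[str, str], Tuple[float, float]]


def zero_shot_scores(
    text: str,
    candidate_labels: List[str],
    nli: NliScorer,
    hypothesis_template: str = "This text is about {}.",  # assumed template
) -> Dict[str, float]:
    """Score each label independently as P(entailment) of its hypothesis."""
    scores = {}
    for label in candidate_labels:
        hypothesis = hypothesis_template.format(label)
        contradiction, entailment = nli(text, hypothesis)
        # Two-way softmax over [contradiction, entailment], written as a
        # sigmoid of the logit difference.
        scores[label] = 1.0 / (1.0 + math.exp(contradiction - entailment))
    return scores


def fake_nli_logits(premise: str, hypothesis: str) -> Tuple[float, float]:
    """Toy stand-in for an MNLI model (hypothetical, for illustration only).

    Treats the last word of the hypothesis as the label and "entails" when that
    word occurs in the premise. Multi-word labels are not handled.
    """
    label = hypothesis.rsplit(" ", 1)[-1].rstrip(".")
    if label.lower() in premise.lower():
        return (-2.0, 2.0)
    return (2.0, -2.0)
```

For example, `zero_shot_scores("New battery chemistry for energy storage", ["chemistry", "sports"], fake_nli_logits)` gives "chemistry" a score close to 1 and "sports" a score close to 0. Because each label is scored on its own, the scores need not sum to one across labels in general.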

__init__(resources_path, model_name, device=None)

Initialize ZeroShotClassifier.

Parameters
  • resources_path (str) – path from which to load the hypothesis template, the candidate labels and, optionally, the model.

  • model_name (str) – name of the model to load from the cache or download from HuggingFace.

  • device (Union[device, str, None]) – device on which inference runs, given either as a device object or as a string (e.g. "cpu" or "cuda"). If not provided, it is inferred.
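The parameters above imply two small resolution steps: use a model stored under resources_path when one is present, otherwise fall back to model_name (resolved from the HuggingFace cache or downloaded), and infer a device when none is given. The sketch below illustrates those steps under stated assumptions; `resolve_model_source` and `resolve_device` are hypothetical helpers, the local model is assumed to live in a subdirectory named after model_name, and the `cuda_available` flag stands in for a runtime check such as `torch.cuda.is_available()`. Only string devices are handled here.

```python
import os
from typing import Optional


def resolve_model_source(resources_path: str, model_name: str) -> str:
    """Prefer a model directory under resources_path, else return model_name.

    Assumption: a locally stored model sits in resources_path/model_name.
    A bare model name is what a HuggingFace loader would look up in its cache
    or download from the Hub.
    """
    local_path = os.path.join(resources_path, model_name)
    return local_path if os.path.isdir(local_path) else model_name


def resolve_device(device: Optional[str] = None, cuda_available: bool = False) -> str:
    """Return the requested device, or infer one when device is None.

    cuda_available is a hypothetical stand-in for a real GPU availability check.
    """
    if device is not None:
        return device
    return "cuda" if cuda_available else "cpu"
```

Keeping the fallback to a bare model name means the same call works both for bundled resources and for models fetched on demand.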

load_pipeline()

Load the zero-shot classification MNLI pipeline.

Return type

None

predict(text)

Get sorted classification labels.

Parameters

text (str) – text to classify.

Return type

List[str]

Returns

labels sorted by score from highest to lowest.
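The HuggingFace zero-shot-classification pipeline returns a dictionary with parallel "labels" and "scores" lists, already ordered by descending score. The sketch below shows how such an output maps to the return value described above; `labels_sorted_by_score` is a hypothetical helper, not part of GT4SD, and its explicit sort is only a safeguard for outputs that are not pre-sorted.

```python
from typing import Dict, List


def labels_sorted_by_score(output: Dict[str, List]) -> List[str]:
    """Return labels ordered from highest to lowest score.

    output mimics a zero-shot pipeline result with parallel "labels" and
    "scores" lists.
    """
    pairs = zip(output["labels"], output["scores"])
    # Sort (label, score) pairs by score, highest first; ties keep input order.
    ranked = sorted(pairs, key=lambda pair: pair[1], reverse=True)
    return [label for label, _ in ranked]
```

For example, `labels_sorted_by_score({"labels": ["a", "b", "c"], "scores": [0.1, 0.7, 0.2]})` returns `["b", "c", "a"]`.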
