gt4sd.training_pipelines.paccmann.vae.core module

PaccMann VAE training utilities.

Summary

Classes:

PaccMannVAEModelArguments

Arguments pertaining to model instantiation.

PaccMannVAETrainingPipeline

Language modeling training pipelines.

Reference

class PaccMannVAETrainingPipeline[source]

Bases: PaccMannTrainingPipeline

Language modeling training pipelines.

train(training_args, model_args, dataset_args)[source]

Generic training function for PaccMann training.

Parameters
  • training_args (Dict[str, Any]) – training arguments passed to the configuration.

  • model_args (Dict[str, Any]) – model arguments passed to the configuration.

  • dataset_args (Dict[str, Any]) – dataset arguments passed to the configuration.

Return type

None

__annotations__ = {}
__doc__ = 'Language modeling training pipelines.'
__module__ = 'gt4sd.training_pipelines.paccmann.vae.core'
class PaccMannVAEModelArguments(n_layers=2, bidirectional=False, rnn_cell_size=512, latent_dim=256, stack_width=50, stack_depth=50, decode_search='sampling', dropout=0.2, generate_len=100, kl_growth=0.003, input_keep=0.85, test_input_keep=1.0, temperature=0.8, embedding='one_hot', vocab_size=380, pad_index=0, embedding_size=380, beam_width=3, top_tokens=5)[source]

Bases: TrainingPipelineArguments

Arguments pertaining to model instantiation.

__name__ = 'PaccMannVAEModelArguments'
n_layers: int = 2
bidirectional: bool = False
rnn_cell_size: int = 512
latent_dim: int = 256
stack_width: int = 50
stack_depth: int = 50
decode_search: str = 'sampling'
dropout: float = 0.2
generate_len: int = 100
kl_growth: float = 0.003
input_keep: float = 0.85
test_input_keep: float = 1.0
temperature: float = 0.8
embedding: str = 'one_hot'
vocab_size: int = 380
pad_index: int = 0
embedding_size: int = 380
beam_width: int = 3
top_tokens: int = 5
__annotations__ = {'beam_width': <class 'int'>, 'bidirectional': <class 'bool'>, 'decode_search': <class 'str'>, 'dropout': <class 'float'>, 'embedding': <class 'str'>, 'embedding_size': <class 'int'>, 'generate_len': <class 'int'>, 'input_keep': <class 'float'>, 'kl_growth': <class 'float'>, 'latent_dim': <class 'int'>, 'n_layers': <class 'int'>, 'pad_index': <class 'int'>, 'rnn_cell_size': <class 'int'>, 'stack_depth': <class 'int'>, 'stack_width': <class 'int'>, 'temperature': <class 'float'>, 'test_input_keep': <class 'float'>, 'top_tokens': <class 'int'>, 'vocab_size': <class 'int'>}
__dataclass_fields__ = {'beam_width': Field(name='beam_width',type=<class 'int'>,default=3,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Width of the beam search.'}),kw_only=False,_field_type=_FIELD), 'bidirectional': Field(name='bidirectional',type=<class 'bool'>,default=False,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Whether the RNN cells are bidirectional.'}),kw_only=False,_field_type=_FIELD), 'decode_search': Field(name='decode_search',type=<class 'str'>,default='sampling',default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Decoder search strategy.'}),kw_only=False,_field_type=_FIELD), 'dropout': Field(name='dropout',type=<class 'float'>,default=0.2,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Dropout rate to apply.'}),kw_only=False,_field_type=_FIELD), 'embedding': Field(name='embedding',type=<class 'str'>,default='one_hot',default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': "Embedding technique for the tokens. 'one_hot' or 'learned'."}),kw_only=False,_field_type=_FIELD), 'embedding_size': Field(name='embedding_size',type=<class 'int'>,default=380,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Size of the embedding vectors.'}),kw_only=False,_field_type=_FIELD), 'generate_len': Field(name='generate_len',type=<class 'int'>,default=100,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Length in tokens of the generated molecules.'}),kw_only=False,_field_type=_FIELD), 'input_keep': Field(name='input_keep',type=<class 'float'>,default=0.85,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Probability to keep input tokens in train.'}),kw_only=False,_field_type=_FIELD), 'kl_growth': Field(name='kl_growth',type=<class 'float'>,default=0.003,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Growth of the KL term weight in the loss.'}),kw_only=False,_field_type=_FIELD), 'latent_dim': Field(name='latent_dim',type=<class 'int'>,default=256,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Size of the RNN cells.'}),kw_only=False,_field_type=_FIELD), 'n_layers': Field(name='n_layers',type=<class 'int'>,default=2,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Number of layers for the RNNs.'}),kw_only=False,_field_type=_FIELD), 'pad_index': Field(name='pad_index',type=<class 'int'>,default=0,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Index for the padding token.'}),kw_only=False,_field_type=_FIELD), 'rnn_cell_size': Field(name='rnn_cell_size',type=<class 'int'>,default=512,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Size of the RNN cells.'}),kw_only=False,_field_type=_FIELD), 'stack_depth': Field(name='stack_depth',type=<class 'int'>,default=50,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Depth of the memory stack for the RNN cell.'}),kw_only=False,_field_type=_FIELD), 'stack_width': Field(name='stack_width',type=<class 'int'>,default=50,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Width of the memory stack for the RNN cell.'}),kw_only=False,_field_type=_FIELD), 'temperature': Field(name='temperature',type=<class 'float'>,default=0.8,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Temperature for the sampling.'}),kw_only=False,_field_type=_FIELD), 'test_input_keep': Field(name='test_input_keep',type=<class 'float'>,default=1.0,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Probability to keep input tokens in test.'}),kw_only=False,_field_type=_FIELD), 'top_tokens': Field(name='top_tokens',type=<class 'int'>,default=5,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Number of tokens to consider in the beam search.'}),kw_only=False,_field_type=_FIELD), 'vocab_size': Field(name='vocab_size',type=<class 'int'>,default=380,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Size of the vocabulary of chemical tokens.'}),kw_only=False,_field_type=_FIELD)}
__dataclass_params__ = _DataclassParams(init=True,repr=True,eq=True,order=False,unsafe_hash=False,frozen=False)
__doc__ = 'Arguments pertaining to model instantiation.'
__eq__(other)

Return self==value.

__hash__ = None
__init__(n_layers=2, bidirectional=False, rnn_cell_size=512, latent_dim=256, stack_width=50, stack_depth=50, decode_search='sampling', dropout=0.2, generate_len=100, kl_growth=0.003, input_keep=0.85, test_input_keep=1.0, temperature=0.8, embedding='one_hot', vocab_size=380, pad_index=0, embedding_size=380, beam_width=3, top_tokens=5)
__match_args__ = ('n_layers', 'bidirectional', 'rnn_cell_size', 'latent_dim', 'stack_width', 'stack_depth', 'decode_search', 'dropout', 'generate_len', 'kl_growth', 'input_keep', 'test_input_keep', 'temperature', 'embedding', 'vocab_size', 'pad_index', 'embedding_size', 'beam_width', 'top_tokens')
__module__ = 'gt4sd.training_pipelines.paccmann.vae.core'
__repr__()

Return repr(self).