gt4sd.training_pipelines.paccmann.vae.core module
PaccMann VAE training utilities.
Summary
Classes:
PaccMannVAEModelArguments | Arguments pertaining to model instantiation.
PaccMannVAETrainingPipeline | Language modeling training pipelines.
Reference
- class PaccMannVAETrainingPipeline
Bases: PaccMannTrainingPipeline
Language modeling training pipelines.
- train(training_args, model_args, dataset_args)
Generic training function for PaccMann training.
- Parameters
training_args (Dict[str, Any]) – training arguments passed to the configuration.
model_args (Dict[str, Any]) – model arguments passed to the configuration.
dataset_args (Dict[str, Any]) – dataset arguments passed to the configuration.
- Return type
None
- __annotations__ = {}
- __doc__ = 'Language modeling training pipelines.'
- __module__ = 'gt4sd.training_pipelines.paccmann.vae.core'
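Since train takes plain dictionaries, one way to prepare model_args is to start from the documented PaccMannVAEModelArguments defaults and override only what changes. The sketch below is illustrative and assumes nothing beyond what this page lists: MODEL_DEFAULTS and build_model_args are hypothetical names, not part of gt4sd, and the keys expected in training_args and dataset_args are not documented in this section, so they are left as placeholders.

```python
from typing import Any, Dict

# Defaults copied from the documented PaccMannVAEModelArguments signature.
# MODEL_DEFAULTS is a hypothetical helper constant, not a gt4sd object.
MODEL_DEFAULTS: Dict[str, Any] = {
    "n_layers": 2,
    "bidirectional": False,
    "rnn_cell_size": 512,
    "latent_dim": 256,
    "stack_width": 50,
    "stack_depth": 50,
    "decode_search": "sampling",
    "dropout": 0.2,
    "generate_len": 100,
    "kl_growth": 0.003,
    "input_keep": 0.85,
    "test_input_keep": 1.0,
    "temperature": 0.8,
    "embedding": "one_hot",
    "vocab_size": 380,
    "pad_index": 0,
    "embedding_size": 380,
    "beam_width": 3,
    "top_tokens": 5,
}


def build_model_args(**overrides: Any) -> Dict[str, Any]:
    """Return the documented defaults updated with overrides (hypothetical helper).

    Unknown names are rejected so that typos do not pass silently.
    """
    unknown = set(overrides) - set(MODEL_DEFAULTS)
    if unknown:
        raise KeyError(f"unknown model arguments: {sorted(unknown)}")
    return {**MODEL_DEFAULTS, **overrides}


model_args = build_model_args(latent_dim=128, dropout=0.1)

# With gt4sd installed, the call might then look roughly like this. Whether the
# pipeline can be constructed without arguments is an assumption, and the
# contents of training_args and dataset_args are placeholders:
#
# from gt4sd.training_pipelines.paccmann.vae.core import PaccMannVAETrainingPipeline
# pipeline = PaccMannVAETrainingPipeline()
# pipeline.train(training_args=..., model_args=model_args, dataset_args=...)
```

Validating override names up front is a design choice of this sketch; the pipeline itself may handle unexpected keys differently.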
- class PaccMannVAEModelArguments(n_layers=2, bidirectional=False, rnn_cell_size=512, latent_dim=256, stack_width=50, stack_depth=50, decode_search='sampling', dropout=0.2, generate_len=100, kl_growth=0.003, input_keep=0.85, test_input_keep=1.0, temperature=0.8, embedding='one_hot', vocab_size=380, pad_index=0, embedding_size=380, beam_width=3, top_tokens=5)
Bases: TrainingPipelineArguments
Arguments pertaining to model instantiation.
- __name__ = 'PaccMannVAEModelArguments'
- n_layers: int = 2
- bidirectional: bool = False
- rnn_cell_size: int = 512
- latent_dim: int = 256
- stack_width: int = 50
- stack_depth: int = 50
- decode_search: str = 'sampling'
- dropout: float = 0.2
- generate_len: int = 100
- kl_growth: float = 0.003
- input_keep: float = 0.85
- test_input_keep: float = 1.0
- temperature: float = 0.8
- embedding: str = 'one_hot'
- vocab_size: int = 380
- pad_index: int = 0
- embedding_size: int = 380
- beam_width: int = 3
- top_tokens: int = 5
- __annotations__ = {'beam_width': <class 'int'>, 'bidirectional': <class 'bool'>, 'decode_search': <class 'str'>, 'dropout': <class 'float'>, 'embedding': <class 'str'>, 'embedding_size': <class 'int'>, 'generate_len': <class 'int'>, 'input_keep': <class 'float'>, 'kl_growth': <class 'float'>, 'latent_dim': <class 'int'>, 'n_layers': <class 'int'>, 'pad_index': <class 'int'>, 'rnn_cell_size': <class 'int'>, 'stack_depth': <class 'int'>, 'stack_width': <class 'int'>, 'temperature': <class 'float'>, 'test_input_keep': <class 'float'>, 'top_tokens': <class 'int'>, 'vocab_size': <class 'int'>}
- __dataclass_fields__
Each entry is a dataclasses.Field with init=True, repr=True, hash=None, compare=True, kw_only=False, and no default_factory. The name, type, default, and 'help' metadata of each field are:
Field | Type | Default | Help
beam_width | int | 3 | Width of the beam search.
bidirectional | bool | False | Whether the RNN cells are bidirectional.
decode_search | str | 'sampling' | Decoder search strategy.
dropout | float | 0.2 | Dropout rate to apply.
embedding | str | 'one_hot' | Embedding technique for the tokens. 'one_hot' or 'learned'.
embedding_size | int | 380 | Size of the embedding vectors.
generate_len | int | 100 | Length in tokens of the generated molecules.
input_keep | float | 0.85 | Probability to keep input tokens in train.
kl_growth | float | 0.003 | Growth of the KL term weight in the loss.
latent_dim | int | 256 | Size of the RNN cells.
n_layers | int | 2 | Number of layers for the RNNs.
pad_index | int | 0 | Index for the padding token.
rnn_cell_size | int | 512 | Size of the RNN cells.
stack_depth | int | 50 | Depth of the memory stack for the RNN cell.
stack_width | int | 50 | Width of the memory stack for the RNN cell.
temperature | float | 0.8 | Temperature for the sampling.
test_input_keep | float | 1.0 | Probability to keep input tokens in test.
top_tokens | int | 5 | Number of tokens to consider in the beam search.
vocab_size | int | 380 | Size of the vocabulary of chemical tokens.
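The 'help' strings in __dataclass_fields__ are ordinary dataclass field metadata, so they can be read back programmatically with the standard dataclasses.fields helper. The following is a minimal sketch using a stand-in class: ExampleModelArguments and describe are hypothetical names that mirror three of the documented fields so the snippet runs without gt4sd installed.

```python
from dataclasses import dataclass, field, fields
from typing import Any, Dict, Tuple


@dataclass
class ExampleModelArguments:
    """Hypothetical stand-in mirroring three documented fields."""

    n_layers: int = field(
        default=2, metadata={"help": "Number of layers for the RNNs."}
    )
    dropout: float = field(
        default=0.2, metadata={"help": "Dropout rate to apply."}
    )
    temperature: float = field(
        default=0.8, metadata={"help": "Temperature for the sampling."}
    )


def describe(cls: type) -> Dict[str, Tuple[str, Any, str]]:
    """Map each field name to (type name, default, help text)."""
    return {
        f.name: (
            getattr(f.type, "__name__", str(f.type)),
            f.default,
            f.metadata.get("help", ""),
        )
        for f in fields(cls)
    }


summary = describe(ExampleModelArguments)
for name, (type_name, default, help_text) in summary.items():
    print(f"{name} ({type_name}, default {default!r}): {help_text}")
```

The same describe function should work on any dataclass whose fields carry a 'help' entry in their metadata, which is what the dump above shows for PaccMannVAEModelArguments.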
- __dataclass_params__ = _DataclassParams(init=True,repr=True,eq=True,order=False,unsafe_hash=False,frozen=False)
- __doc__ = 'Arguments pertaining to model instantiation.'
- __eq__(other)
Return self==value.
- __hash__ = None
- __init__(n_layers=2, bidirectional=False, rnn_cell_size=512, latent_dim=256, stack_width=50, stack_depth=50, decode_search='sampling', dropout=0.2, generate_len=100, kl_growth=0.003, input_keep=0.85, test_input_keep=1.0, temperature=0.8, embedding='one_hot', vocab_size=380, pad_index=0, embedding_size=380, beam_width=3, top_tokens=5)
- __match_args__ = ('n_layers', 'bidirectional', 'rnn_cell_size', 'latent_dim', 'stack_width', 'stack_depth', 'decode_search', 'dropout', 'generate_len', 'kl_growth', 'input_keep', 'test_input_keep', 'temperature', 'embedding', 'vocab_size', 'pad_index', 'embedding_size', 'beam_width', 'top_tokens')
- __module__ = 'gt4sd.training_pipelines.paccmann.vae.core'
- __repr__()
Return repr(self).
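The generated __eq__, the __hash__ = None entry, and __repr__ all follow from the dataclass parameters listed above (eq=True, unsafe_hash=False, frozen=False): instances compare field by field, are unhashable, and print their field values. The sketch below demonstrates this standard dataclass behaviour on a small stand-in; ExampleArguments and is_hashable are hypothetical names, not gt4sd objects.

```python
from dataclasses import dataclass


@dataclass  # defaults: eq=True, frozen=False, unsafe_hash=False
class ExampleArguments:
    """Hypothetical stand-in with two of the documented fields."""

    n_layers: int = 2
    dropout: float = 0.2


def is_hashable(obj: object) -> bool:
    """Return True if hash(obj) succeeds, False if it raises TypeError."""
    try:
        hash(obj)
    except TypeError:
        return False
    return True


a = ExampleArguments()
b = ExampleArguments()
c = ExampleArguments(dropout=0.5)

same = a == b        # field-by-field comparison of equal values
different = a == c   # differs in dropout
hashable = is_hashable(a)
text = repr(a)       # lists every field as name=value
```

A practical consequence is that such argument objects cannot be used as dictionary keys or set members unless they are first converted, for example to a tuple of their field values.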