gt4sd.frameworks.cgcnn.model module
Model module.
Summary
Classes:
ConvLayer | Convolutional operation on graphs.
CrystalGraphConvNet | Create a crystal graph convolutional neural network for predicting total material properties.
Normalizer | Normalize a Tensor and restore it later.
Reference
- class ConvLayer(atom_fea_len, nbr_fea_len)
Bases: Module
Convolutional operation on graphs.
- __init__(atom_fea_len, nbr_fea_len)
Initialize ConvLayer.
- Parameters
atom_fea_len (int) – Number of atom hidden features.
nbr_fea_len (int) – Number of bond features.
- forward(atom_in_fea, nbr_fea, nbr_fea_idx)
Forward pass.
N: Total number of atoms in the batch. M: Max number of neighbors.
- Parameters
atom_in_fea (Tensor) – shape (N, atom_fea_len). Atom hidden features before convolution.
nbr_fea (Tensor) – shape (N, M, nbr_fea_len). Bond features of each atom’s M neighbors.
nbr_fea_idx (LongTensor) – shape (N, M). Indices of the M neighbors of each atom.
- Returns
atom_out_fea – shape (N, atom_fea_len). Atom hidden features after convolution.
- Return type
Tensor
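To make the index convention above concrete, here is a minimal, dependency-free sketch in which plain Python lists stand in for tensors. The helper `gather_neighbor_features` is hypothetical and not part of gt4sd; it only illustrates how an `(N, M)` index array selects neighbor features, and it is not a reproduction of what the layer computes.

```python
from typing import List

Vector = List[float]


def gather_neighbor_features(
    atom_fea: List[Vector], nbr_fea_idx: List[List[int]]
) -> List[List[Vector]]:
    """Look up the feature vector of every neighbor of every atom.

    atom_fea has shape (N, atom_fea_len) and nbr_fea_idx has shape (N, M),
    so the result has shape (N, M, atom_fea_len).
    Hypothetical helper for illustration only.
    """
    return [[atom_fea[j] for j in neighbors] for neighbors in nbr_fea_idx]


# Three atoms (N = 3) with two hidden features each and two neighbors per atom (M = 2).
atom_fea = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
nbr_fea_idx = [[1, 2], [0, 2], [0, 1]]

nbr_atom_fea = gather_neighbor_features(atom_fea, nbr_fea_idx)
shape = (len(nbr_atom_fea), len(nbr_atom_fea[0]), len(nbr_atom_fea[0][0]))
print(shape)
```

With PyTorch tensors, the same lookup can be written with integer-array indexing, `atom_in_fea[nbr_fea_idx, :]`, which also yields a tensor of shape (N, M, atom_fea_len).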
- class CrystalGraphConvNet(orig_atom_fea_len, nbr_fea_len, atom_fea_len=64, n_conv=3, h_fea_len=128, n_h=1, classification=False)
Bases: Module
Create a crystal graph convolutional neural network for predicting total material properties.
- __init__(orig_atom_fea_len, nbr_fea_len, atom_fea_len=64, n_conv=3, h_fea_len=128, n_h=1, classification=False)
Initialize CrystalGraphConvNet.
- Parameters
orig_atom_fea_len (int) – Number of atom features in the input.
nbr_fea_len (int) – Number of bond features.
atom_fea_len (int) – Number of hidden atom features in the convolutional layers.
n_conv (int) – Number of convolutional layers.
h_fea_len (int) – Number of hidden features after pooling.
n_h (int) – Number of hidden layers after pooling.
- forward(atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx)
Forward pass.
N: Total number of atoms in the batch. M: Max number of neighbors. N0: Total number of crystals in the batch.
- Parameters
atom_fea (Tensor) – shape (N, orig_atom_fea_len). Atom features from atom type.
nbr_fea (Tensor) – shape (N, M, nbr_fea_len). Bond features of each atom’s M neighbors.
nbr_fea_idx (LongTensor) – shape (N, M). Indices of the M neighbors of each atom.
crystal_atom_idx (list of LongTensor) – length N0. Mapping from the crystal idx to atom idx.
- Returns
prediction – shape (N, ). Predicted property values.
- Return type
Tensor
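The relationship between N, N0 and `crystal_atom_idx` is easiest to see on a small batch. The sketch below is a hedged illustration: `build_crystal_atom_idx` is a hypothetical helper, not a gt4sd function, and it assumes that the atoms of each crystal are stored contiguously in the batch. Plain lists of ints stand in for the `LongTensor` entries.

```python
from typing import List


def build_crystal_atom_idx(atoms_per_crystal: List[int]) -> List[List[int]]:
    """Map each crystal to the batch-level indices of its atoms.

    Hypothetical helper; assumes atoms are laid out crystal by crystal.
    """
    crystal_atom_idx = []
    start = 0
    for n_atoms in atoms_per_crystal:
        crystal_atom_idx.append(list(range(start, start + n_atoms)))
        start += n_atoms
    return crystal_atom_idx


# A batch of two crystals with 2 and 3 atoms respectively.
crystal_atom_idx = build_crystal_atom_idx([2, 3])
n_crystals = len(crystal_atom_idx)                   # N0
n_atoms = sum(len(idx) for idx in crystal_atom_idx)  # N
print(crystal_atom_idx, n_crystals, n_atoms)
```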
- pooling(atom_fea, crystal_atom_idx)
Pool the atom features into crystal features.
N: Total number of atoms in the batch. N0: Total number of crystals in the batch.
- Parameters
atom_fea (Tensor) – shape (N, atom_fea_len). Atom feature vectors of the batch.
crystal_atom_idx (list of LongTensor) – length N0. Mapping from the crystal idx to atom idx.
- Return type
Tensor
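As a rough illustration of pooling atom features into one vector per crystal, the sketch below averages the feature vectors of the atoms that belong to each crystal. The use of the mean as the reduction is an assumption made for this example; the text above does not state which reduction `pooling` applies. `mean_pool` is a hypothetical helper, and plain lists stand in for tensors.

```python
from typing import List

Vector = List[float]


def mean_pool(atom_fea: List[Vector], crystal_atom_idx: List[List[int]]) -> List[Vector]:
    """Average the atom feature vectors of each crystal.

    atom_fea has shape (N, atom_fea_len); the result has shape (N0, atom_fea_len).
    Mean reduction is an assumption of this sketch.
    """
    pooled = []
    for idx in crystal_atom_idx:
        rows = [atom_fea[i] for i in idx]
        # zip(*rows) iterates over feature columns of this crystal's atoms.
        pooled.append([sum(col) / len(rows) for col in zip(*rows)])
    return pooled


# Five atoms (N = 5) split over two crystals (N0 = 2).
atom_fea = [[1.0, 2.0], [3.0, 4.0], [0.0, 0.0], [2.0, 2.0], [4.0, 4.0]]
crystal_atom_idx = [[0, 1], [2, 3, 4]]

crystal_fea = mean_pool(atom_fea, crystal_atom_idx)
print(crystal_fea)
```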
- class Normalizer(tensor)
Bases: object
Normalize a Tensor and restore it later.
- norm(tensor)
Normalize a tensor.
- Parameters
tensor (Tensor) – tensor to be normalized.
- Return type
Tensor
- Returns
normalized tensor.
- denorm(normed_tensor)
Denormalize a tensor.
- Parameters
normed_tensor – tensor to be denormalized.
- Return type
Tensor
- Returns
denormalized tensor.
- state_dict()
Return the state dict of normalizer.
- Return type
Dict[str, Tensor]
- Returns
dictionary including the used mean and std values.
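The three methods above can be sketched as a normalize/restore round trip. `ListNormalizer` below is a hypothetical stand-in that works on lists of floats rather than tensors; the dictionary keys `"mean"` and `"std"` and the use of the sample standard deviation are assumptions of this sketch, not confirmed details of `Normalizer`.

```python
import statistics
from typing import Dict, List


class ListNormalizer:
    """Hypothetical stand-in for Normalizer operating on lists of floats."""

    def __init__(self, values: List[float]) -> None:
        # Statistics are computed once from the reference data.
        self.mean = statistics.mean(values)
        self.std = statistics.stdev(values)  # sample standard deviation (assumption)

    def norm(self, values: List[float]) -> List[float]:
        """Shift and scale values to zero mean and unit standard deviation."""
        return [(v - self.mean) / self.std for v in values]

    def denorm(self, normed_values: List[float]) -> List[float]:
        """Invert norm using the stored statistics."""
        return [v * self.std + self.mean for v in normed_values]

    def state_dict(self) -> Dict[str, float]:
        """Expose the stored statistics; key names are assumptions."""
        return {"mean": self.mean, "std": self.std}


targets = [1.0, 2.0, 3.0, 4.0, 5.0]
normalizer = ListNormalizer(targets)
normed = normalizer.norm(targets)
restored = normalizer.denorm(normed)
print(normalizer.state_dict())
```

Keeping the statistics in a small dictionary lets the same scaling be reapplied later, for example when predictions made in normalized units need to be converted back to the original units.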
- __dict__ = mappingproxy({'__module__': 'gt4sd.frameworks.cgcnn.model', '__doc__': 'Normalize a Tensor and restore it later.', '__init__': <function Normalizer.__init__>, 'norm': <function Normalizer.norm>, 'denorm': <function Normalizer.denorm>, 'state_dict': <function Normalizer.state_dict>, 'load_state_dict': <function Normalizer.load_state_dict>, '__dict__': <attribute '__dict__' of 'Normalizer' objects>, '__weakref__': <attribute '__weakref__' of 'Normalizer' objects>, '__annotations__': {}})¶