
Data module.

get_train_val_test_loader(dataset, collate_fn=<function default_collate>, batch_size=64, train_ratio=None, val_ratio=0.1, test_ratio=0.1, return_test=False, num_workers=1, pin_memory=False, **kwargs)

Divide a dataset into training, validation, and test subsets and return a DataLoader for each. The test loader is returned only when return_test is True.

!!! The dataset must be shuffled before calling this function !!!

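Parameters

  • dataset (Dataset) – The full dataset to be divided.

  • collate_fn (Callable[[List[Any]], Any]) – Function that merges a list of samples into a mini-batch.

  • batch_size (int) – Number of samples per mini-batch.

  • train_ratio (Optional[float]) – Fraction of the dataset used for training.

  • val_ratio (float) – Fraction of the dataset used for validation.

  • test_ratio (float) – Fraction of the dataset used for testing.

  • return_test (bool) – Whether to return the test data loader. If False, the last test_size samples are held out and no test loader is returned.

  • num_workers (int) – Number of subprocesses used for data loading.

  • pin_memory (bool) – If True, tensors are copied into pinned (page-locked) memory before being returned.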


Returns

  • DataLoader that randomly samples the training data.

  • DataLoader that randomly samples the validation data.

  • DataLoader that randomly samples the test data. It is returned only if return_test is True.




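The following is a minimal sketch of the positional split described above. It assumes the three subsets are contiguous index ranges of an already-shuffled dataset. The test set is taken as the last test_size positions and the validation set sits immediately before it. The helper name split_indices is hypothetical, and so is the rule that train_ratio=None means "use everything not reserved for validation and test". Neither is this module's documented behavior.

```python
from typing import List, Optional, Tuple


def split_indices(
    total_size: int,
    train_ratio: Optional[float] = None,
    val_ratio: float = 0.1,
    test_ratio: float = 0.1,
) -> Tuple[List[int], List[int], List[int]]:
    """Hypothetical helper: split positions 0..total_size-1 into train/val/test.

    Assumes the dataset was shuffled beforehand, so contiguous ranges are random.
    """
    if train_ratio is None:
        # Assumption: train on everything not reserved for validation and test.
        assert val_ratio + test_ratio < 1
        train_ratio = 1 - val_ratio - test_ratio
    # Small tolerance for floating-point rounding in the ratio sum.
    assert train_ratio + val_ratio + test_ratio <= 1 + 1e-9

    indices = list(range(total_size))
    train_size = int(train_ratio * total_size)
    val_size = int(val_ratio * total_size)
    test_size = int(test_ratio * total_size)

    train_idx = indices[:train_size]
    # The test set is the last test_size positions.
    test_idx = indices[total_size - test_size:]
    # The validation set sits immediately before the test set.
    val_idx = indices[total_size - test_size - val_size:total_size - test_size]
    return train_idx, val_idx, test_idx
```

Each index list could then be wrapped in torch.utils.data.SubsetRandomSampler. That sampler is passed as the sampler argument of torch.utils.data.DataLoader, together with batch_size, collate_fn, num_workers, and pin_memory. The slice bounds are computed from total_size rather than from a negative -test_size slice. This keeps the split correct when test_size is 0.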

Collate a list of data points into a single batch for predicting crystal properties.


Parameters

dataset_list (List[Any]) – List of tuples, one per data point, containing the items below. Here n_i is the number of atoms in the i-th crystal and M is the number of neighbors per atom.

atom_fea: torch.Tensor of shape (n_i, atom_fea_len).

nbr_fea: torch.Tensor of shape (n_i, M, nbr_fea_len).

nbr_fea_idx: torch.LongTensor of shape (n_i, M).

target: torch.Tensor of shape (1, ).

cif_id: str or int.

Return type

Tuple[Tuple[Tensor, Tensor, Tensor, List[LongTensor]], Tensor, List[Any]]


Returns

N = sum(n_i) is the total number of atoms in the batch, and N0 is the number of crystals in the batch.

batch_atom_fea: torch.Tensor of shape (N, orig_atom_fea_len)

Atom features from atom type.

batch_nbr_fea: torch.Tensor of shape (N, M, nbr_fea_len)

Bond features of each atom’s M neighbors.

batch_nbr_fea_idx: torch.LongTensor of shape (N, M)

Indices of the M neighbors of each atom.

crystal_atom_idx: list of torch.LongTensor of length N0

Mapping from each crystal index to the indices of its atoms.

target: torch.Tensor of shape (N0, 1)

Target value for prediction, one per crystal.

batch_cif_ids: list

IDs of the crystals in the batch.
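The batching described above can be sketched with NumPy arrays standing in for torch tensors. The key step is offsetting each crystal's local neighbor indices by the number of atoms already in the batch. That way, batch_nbr_fea_idx indexes into the concatenated atom arrays. The function name collate_sketch is hypothetical. The flat five-item tuple layout per data point is also an assumption: the real function may group the inputs differently.

```python
from typing import Any, List, Tuple

import numpy as np


def collate_sketch(dataset_list: List[Tuple[Any, ...]]):
    """Hypothetical NumPy stand-in for the collate function documented above.

    Each item is assumed to be (atom_fea, nbr_fea, nbr_fea_idx, target, cif_id).
    """
    atom_feas, nbr_feas, nbr_idxs = [], [], []
    crystal_atom_idx, targets, cif_ids = [], [], []
    base_idx = 0  # number of atoms already placed in the batch
    for atom_fea, nbr_fea, nbr_fea_idx, target, cif_id in dataset_list:
        n_i = atom_fea.shape[0]
        atom_feas.append(atom_fea)
        nbr_feas.append(nbr_fea)
        # Shift crystal-local neighbor indices into batch-wide atom indices.
        nbr_idxs.append(nbr_fea_idx + base_idx)
        # Remember which rows of the batch belong to this crystal.
        crystal_atom_idx.append(np.arange(n_i) + base_idx)
        targets.append(target)
        cif_ids.append(cif_id)
        base_idx += n_i
    batch_inputs = (
        np.concatenate(atom_feas, axis=0),  # (N, atom_fea_len)
        np.concatenate(nbr_feas, axis=0),   # (N, M, nbr_fea_len)
        np.concatenate(nbr_idxs, axis=0),   # (N, M)
        crystal_atom_idx,                   # list of length N0
    )
    # Stacking N0 targets of shape (1,) gives shape (N0, 1).
    return batch_inputs, np.stack(targets, axis=0), cif_ids
```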

class GaussianDistance(dmin, dmax, step, var=None)

Bases: object

Expands the distance by Gaussian basis.

Unit: angstrom

__init__(dmin, dmax, step, var=None)

Parameters

  • dmin (float) – Minimum interatomic distance.

  • dmax (float) – Maximum interatomic distance.

  • step (float) – Step size for the Gaussian filter.


expand(distance)

Apply a Gaussian distance filter to a NumPy distance array.


Parameters

distance (np.ndarray) – An n-d distance array of any shape.


Returns

(n+1)-d array: the expanded distance matrix, whose last dimension has length len(self.filter).



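The following is a minimal sketch of the Gaussian basis expansion. It assumes evenly spaced centers from dmin to dmax, and a Gaussian width that defaults to step when var is None. Both choices follow common CGCNN-style code and are assumptions here, as is the class name GaussianDistanceSketch. Each distance d is mapped to exp(-(d - mu_k)^2 / var^2) for every center mu_k, which adds one trailing dimension of length len(self.filter).

```python
import numpy as np


class GaussianDistanceSketch:
    """Hypothetical re-implementation of a Gaussian distance expansion."""

    def __init__(self, dmin, dmax, step, var=None):
        assert dmin < dmax
        assert dmax - dmin > step
        # Evenly spaced Gaussian centers from dmin up to dmax.
        self.filter = np.arange(dmin, dmax + step, step)
        # Assumption: the Gaussian width defaults to the step size.
        self.var = step if var is None else var

    def expand(self, distance):
        # Add a trailing axis so an n-d array broadcasts against the centers,
        # giving an (n+1)-d array whose last axis has length len(self.filter).
        distance = np.asarray(distance, dtype=float)
        diff = distance[..., np.newaxis] - self.filter
        return np.exp(-(diff ** 2) / self.var ** 2)
```

np.arange with a floating-point step can gain or lose a center near dmax through rounding. If an exact number of centers matters, np.linspace(dmin, dmax, num) avoids that.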

class AtomInitializer(atom_types)

Bases: object

Base class for initializing the vector representation of atoms.

!!! Use one AtomInitializer per dataset !!!


class AtomCustomJSONInitializer(elem_embedding_file)

Bases: AtomInitializer

Initialize atom feature vectors from a JSON file. The file contains a dictionary that maps each element's atomic number to a list representing that element's feature vector.


Parameters

elem_embedding_file (str) – Path to the .json file.
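The following is a minimal sketch of reading such an embedding file. It assumes the JSON object uses atomic numbers as keys. Because JSON keys are always strings, they are converted back to int. The helper load_atom_embeddings is hypothetical and is not part of this module's API.

```python
import json

import numpy as np


def load_atom_embeddings(path):
    """Hypothetical loader: {"1": [...], "8": [...]} -> {1: array, 8: array}."""
    with open(path) as f:
        raw = json.load(f)
    # JSON object keys are always strings; convert them back to atomic numbers.
    return {int(z): np.asarray(vec, dtype=float) for z, vec in raw.items()}
```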


class CIFData(root_dir, max_num_nbr=12, radius=8, dmin=0, step=0.2, random_seed=123, atom_initialization=None)

Bases: Dataset

The CIFData dataset is a wrapper for a dataset whose crystal structures are stored as CIF files. The dataset should have the following directory structure:

root_dir
├── id_prop.csv
├── atom_init.json
├── id0.cif
├── id1.cif
├── …

id_prop.csv: a CSV file with two columns. The first column records a unique ID for each crystal, and the second column records the value of the target property.

atom_init.json: a JSON file that stores the initialization vector for each element.

ID.cif: a CIF file that records the crystal structure, where ID is the unique ID of the crystal.
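The following is a minimal sketch of reading id_prop.csv from the layout above. It assumes the file has no header row and that the target column is numeric. It shuffles the rows with random_seed, which matters because the train/validation/test loader utility expects a shuffled dataset. The helper read_id_prop is hypothetical.

```python
import csv
import os
import random


def read_id_prop(root_dir, random_seed=123):
    """Hypothetical helper returning shuffled (cif_id, target, cif_path) tuples."""
    id_prop_file = os.path.join(root_dir, "id_prop.csv")
    assert os.path.exists(id_prop_file), "id_prop.csv does not exist!"
    with open(id_prop_file, newline="") as f:
        # Assumption: no header row; skip blank lines.
        rows = [row for row in csv.reader(f) if row]
    # A seeded RNG makes the shuffle, and hence any positional split, reproducible.
    random.Random(random_seed).shuffle(rows)
    return [
        (cif_id, float(target), os.path.join(root_dir, f"{cif_id}.cif"))
        for cif_id, target in rows
    ]
```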

__init__(root_dir, max_num_nbr=12, radius=8, dmin=0, step=0.2, random_seed=123, atom_initialization=None)

Parameters

  • root_dir (str) – Path to the root directory of the dataset.

  • max_num_nbr (int) – Maximum number of neighbors per atom when constructing the crystal graph.

  • radius (int) – Cutoff radius for searching neighbors (documented as float).

  • dmin (int) – Minimum distance for constructing GaussianDistance (documented as float).

  • step (float) – Step size for constructing GaussianDistance.

  • random_seed (int) – Random seed for shuffling the dataset.

  • atom_initialization (Optional[AtomCustomJSONInitializer]) – Atom initializer used to build the atom feature vectors. Defaults to None, in which case an atom_init.json file must be present in root_dir.
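The following is a simplified sketch of how max_num_nbr and radius could bound each atom's neighbor list, producing the (n_i, M) arrays returned per crystal. It assumes non-periodic Cartesian coordinates. A real CIF-based graph must also consider periodic images. Missing neighbors are padded with index 0 and distance radius + 1. The function nearest_neighbors and this padding convention are assumptions, not this module's documented behavior.

```python
import numpy as np


def nearest_neighbors(positions, radius=8.0, max_num_nbr=12):
    """Hypothetical sketch: per-atom neighbor indices and distances, shape (n, M).

    Ignores periodic images (a real CIF-based graph must include them).
    """
    positions = np.asarray(positions, dtype=float)
    n = positions.shape[0]
    # Pairwise Euclidean distances, shape (n, n).
    dist = np.linalg.norm(positions[:, None, :] - positions[None, :, :], axis=-1)
    nbr_idx = np.zeros((n, max_num_nbr), dtype=np.int64)
    # Assumption: pad missing neighbors with index 0 and distance radius + 1.
    nbr_dist = np.full((n, max_num_nbr), radius + 1.0)
    for i in range(n):
        # Candidates within the cutoff, excluding the atom itself.
        cand = [j for j in range(n) if j != i and dist[i, j] <= radius]
        # Keep only the max_num_nbr closest candidates.
        cand.sort(key=lambda j: dist[i, j])
        cand = cand[:max_num_nbr]
        nbr_idx[i, :len(cand)] = cand
        nbr_dist[i, :len(cand)] = dist[i, cand]
    return nbr_idx, nbr_dist
```

Passing nbr_dist through the GaussianDistance filter would then give bond features of shape (n_i, M, nbr_fea_len).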


Parameters

idx (int) – Index of the data point.

Returns

atom_fea: torch.Tensor of shape (n_i, atom_fea_len).

nbr_fea: torch.Tensor of shape (n_i, M, nbr_fea_len).

nbr_fea_idx: torch.LongTensor of shape (n_i, M).

target: torch.Tensor of shape (1, ).

cif_id: str or int.

Return type
