gt4sd.frameworks.crystals_rfc.rf_classifier module

Model module.

Summary

Classes:

RFC

RandomForest classifier for crystals.

Reference

class RFC(crystal_sys='all')

Bases: object

RandomForest classifier for crystals.

__init__(crystal_sys='all')

Initialize RandomForest classifier.

Parameters

crystal_sys (str) –

crystal systems to be used.

“all” selects all the crystal systems. The seven other options are: “monoclinic”, “triclinic”, “orthorhombic”, “trigonal”, “hexagonal”, “cubic”, “tetragonal”.
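
The accepted values can be checked before constructing the classifier. The sketch below is a hypothetical helper and not part of gt4sd. Whether RFC validates crystal_sys itself is not stated here, so the names CRYSTAL_SYSTEMS and validate_crystal_sys are assumptions made for illustration.

```python
from typing import Tuple

# Hypothetical helper, not part of gt4sd: the seven named options listed
# in the RFC documentation, plus the catch-all value "all".
CRYSTAL_SYSTEMS: Tuple[str, ...] = (
    "monoclinic",
    "triclinic",
    "orthorhombic",
    "trigonal",
    "hexagonal",
    "cubic",
    "tetragonal",
)


def validate_crystal_sys(crystal_sys: str = "all") -> str:
    """Return crystal_sys unchanged if it is a documented option, else raise."""
    if crystal_sys == "all" or crystal_sys in CRYSTAL_SYSTEMS:
        return crystal_sys
    options = ", ".join(("all",) + CRYSTAL_SYSTEMS)
    raise ValueError(f"unknown crystal_sys {crystal_sys!r}; expected one of: {options}")
```

The returned string can then be passed as RFC(crystal_sys=...).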

load_data(file_name)

Load dataset.

Parameters

file_name (str) – path of the dataset.

Return type

DataFrame

Returns

Dataframe with the loaded dataset.

split_data(df, test_size=0.2)

Split the dataset into training and testing sets.

Parameters
  • df (DataFrame) – dataset’s dataframe.

  • test_size (float) – size of the test set.

Return type

Tuple[ndarray, ndarray, ndarray, ndarray]

Returns

Training and testing sets.
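
The effect of test_size can be illustrated with a standard-library sketch. It assumes that test_size is the fraction of rows held out, as in scikit-learn's train_test_split, and that the four returned arrays are ordered (train_x, test_x, train_y, test_y) to match the parameters of normalize_data. Neither point is confirmed here, and split_rows is a hypothetical stand-in that works on lists rather than a DataFrame.

```python
import math
import random
from typing import List, Sequence, Tuple

Row = Sequence[float]


def split_rows(
    x: Sequence[Row], y: Sequence[str], test_size: float = 0.2, seed: int = 0
) -> Tuple[List[Row], List[Row], List[str], List[str]]:
    """Shuffle row indices and hold out a fraction of them for testing."""
    if len(x) != len(y):
        raise ValueError("x and y must have the same number of rows")
    if not 0.0 < test_size < 1.0:
        raise ValueError("test_size must be a fraction strictly between 0 and 1")

    # Shuffle indices rather than rows so inputs and labels stay paired.
    indices = list(range(len(x)))
    random.Random(seed).shuffle(indices)

    n_test = math.ceil(len(x) * test_size)
    test_idx, train_idx = indices[:n_test], indices[n_test:]

    train_x = [x[i] for i in train_idx]
    test_x = [x[i] for i in test_idx]
    train_y = [y[i] for i in train_idx]
    test_y = [y[i] for i in test_idx]
    return train_x, test_x, train_y, test_y
```

With ten rows and test_size=0.2, this holds out two rows and keeps eight for training.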

normalize_data(train_x, test_x, train_y, test_y)

Normalize dataset.

Parameters
  • train_x (ndarray) – training set’s input.

  • test_x (ndarray) – testing set’s input.

  • train_y (ndarray) – training set’s groundtruth.

  • test_y (ndarray) – testing set’s groundtruth.

Return type

Tuple[ndarray, ndarray, ndarray, ndarray]

Returns

Training and testing sets.
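
The normalization scheme is not described on this page. The class annotations list a maxm array, which may hint at scaling by per-feature maxima, but that is an assumption. Under that assumption, the sketch below computes the maxima on the training inputs only and reuses them for the testing inputs. fit_column_max and scale_rows are hypothetical names, and labels are left untouched.

```python
from typing import List, Sequence

Row = Sequence[float]


def fit_column_max(train_x: Sequence[Row]) -> List[float]:
    """Per-column maximum absolute value, computed on the training inputs only."""
    n_cols = len(train_x[0])
    maxima = []
    for col in range(n_cols):
        column_max = max(abs(row[col]) for row in train_x)
        # An all-zero column would divide by zero; scale it by 1.0 instead.
        maxima.append(column_max if column_max != 0 else 1.0)
    return maxima


def scale_rows(rows: Sequence[Row], maxima: Sequence[float]) -> List[List[float]]:
    """Divide every value by the stored maximum of its column."""
    return [[value / column_max for value, column_max in zip(row, maxima)] for row in rows]
```

Reusing the training maxima for the testing inputs keeps test-set statistics out of the fitted scale. Scaled test values may therefore exceed 1.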

train(x, y)

Train a RandomForest model.

Parameters
  • x (ndarray) – training set’s input.

  • y (ndarray) – training set’s groundtruth.

Return type

RandomForestClassifier

Returns

Trained model.

save(path)

Save model.

Parameters

path (str) – path to store the model.

Return type

None
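
The documented methods can be chained into a training workflow. The sketch below infers the call order from the signatures on this page, so it is an assumption rather than a confirmed recipe. Because save takes only a path, it also assumes that train keeps the fitted model on the instance. train_and_save is a hypothetical wrapper that receives the class as an argument, so it does not import gt4sd itself.

```python
from typing import Any, Tuple


def train_and_save(
    rfc_cls: Any, dataset_path: str, model_path: str, crystal_sys: str = "all"
) -> Tuple[Any, Any, Any]:
    """Load, split, normalize, train and save, in the order the signatures suggest."""
    classifier = rfc_cls(crystal_sys=crystal_sys)

    df = classifier.load_data(dataset_path)
    train_x, test_x, train_y, test_y = classifier.split_data(df, test_size=0.2)
    train_x, test_x, train_y, test_y = classifier.normalize_data(
        train_x, test_x, train_y, test_y
    )

    # Assumption: train() stores the fitted model on the instance, so that
    # save(), which only receives a path, has something to write.
    classifier.train(train_x, train_y)
    classifier.save(model_path)

    # Hand back the held-out data so the caller can evaluate the model.
    return classifier, test_x, test_y


# Intended usage (both paths are placeholders):
# from gt4sd.frameworks.crystals_rfc.rf_classifier import RFC
# classifier, test_x, test_y = train_and_save(RFC, "path/to/dataset", "path/to/model")
```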

load_model(path)

Load model.

Parameters

path (str) – path where the file is located.

Return type

None

predict(pred_x)

Predict.

Parameters

pred_x (ndarray) – input.

Return type

List[str]

Returns

Predicted labels.
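
For inference with a model previously written by save, a minimal sketch follows. Because load_model returns None, it assumes the loaded model is kept on the instance and used by predict. predict_with_saved_model is a hypothetical wrapper, and the class is passed in so the sketch does not import gt4sd itself.

```python
from typing import Any, List


def predict_with_saved_model(
    rfc_cls: Any, model_path: str, pred_x: Any, crystal_sys: str = "all"
) -> List[str]:
    """Restore a stored model and return its predictions for pred_x."""
    classifier = rfc_cls(crystal_sys=crystal_sys)
    # Assumption: load_model() attaches the stored model to the instance.
    classifier.load_model(model_path)
    return classifier.predict(pred_x)


# Intended usage (the path is a placeholder):
# from gt4sd.frameworks.crystals_rfc.rf_classifier import RFC
# labels = predict_with_saved_model(RFC, "path/to/model", pred_x)
```

The inputs in pred_x presumably need the same normalization that was applied to the training inputs. This page does not say whether predict applies it.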

__dict__ = mappingproxy({'__module__': 'gt4sd.frameworks.crystals_rfc.rf_classifier', '__doc__': 'RandomForest classifier for crystals.', '__init__': <function RFC.__init__>, 'load_data': <function RFC.load_data>, 'split_data': <function RFC.split_data>, 'normalize_data': <function RFC.normalize_data>, 'train': <function RFC.train>, 'save': <function RFC.save>, 'load_model': <function RFC.load_model>, 'predict': <function RFC.predict>, '__dict__': <attribute '__dict__' of 'RFC' objects>, '__weakref__': <attribute '__weakref__' of 'RFC' objects>, '__annotations__': {'maxm': 'np.ndarray'}})
__doc__ = 'RandomForest classifier for crystals.'
__module__ = 'gt4sd.frameworks.crystals_rfc.rf_classifier'
__weakref__

list of weak references to the object (if defined)