gt4sd.frameworks.crystals_rfc.rf_classifier module
Model module.
Summary
Classes:
RFC: RandomForest classifier for crystals.
Reference
- class RFC(crystal_sys='all')
Bases: object
RandomForest classifier for crystals.
- __init__(crystal_sys='all')
Initialize the RandomForest classifier.
- Parameters
crystal_sys (str) – crystal system(s) to use. “all” uses every crystal system; the other seven options are “monoclinic”, “triclinic”, “orthorhombic”, “trigonal”, “hexagonal”, “cubic” and “tetragonal”.
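The parameter description lists eight accepted values for `crystal_sys`. The sketch below shows one way a caller might check an argument against that list before constructing `RFC`. `VALID_CRYSTAL_SYSTEMS` and `check_crystal_sys` are hypothetical names; whether `RFC` validates its argument itself, and how, is not documented on this page.

```python
# Hypothetical helper: the eight values accepted by RFC's crystal_sys parameter.
VALID_CRYSTAL_SYSTEMS = frozenset(
    {
        "all",
        "monoclinic",
        "triclinic",
        "orthorhombic",
        "trigonal",
        "hexagonal",
        "cubic",
        "tetragonal",
    }
)


def check_crystal_sys(crystal_sys: str) -> str:
    """Return crystal_sys unchanged if it is a documented option, else raise."""
    if crystal_sys not in VALID_CRYSTAL_SYSTEMS:
        raise ValueError(
            f"unknown crystal system {crystal_sys!r}; "
            f"expected one of {sorted(VALID_CRYSTAL_SYSTEMS)}"
        )
    return crystal_sys


print(check_crystal_sys("cubic"))  # cubic
```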
- load_data(file_name)
Load the dataset.
- Parameters
file_name (str) – path of the dataset.
- Return type
DataFrame
- Returns
DataFrame with the loaded dataset.
- split_data(df, test_size=0.2)
Split the dataset into training and testing sets.
- Parameters
df (DataFrame) – dataset’s dataframe.
test_size (float) – proportion of the dataset held out for testing.
- Return type
Tuple[ndarray, ndarray, ndarray, ndarray]
- Returns
Training and testing sets.
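A minimal, dependency-free sketch of a shuffled train/test split driven by a `test_size` fraction. It assumes the four returned sets are ordered `(train_x, test_x, train_y, test_y)`, matching the argument order of `normalize_data`, and it rounds the test count up as scikit-learn's `train_test_split` does. `split_rows` and its `seed` parameter are hypothetical; the real method operates on a DataFrame, and how it selects feature and label columns is not documented here.

```python
import math
import random
from typing import List, Sequence, Tuple


def split_rows(
    features: Sequence[Sequence[float]],
    labels: Sequence[str],
    test_size: float = 0.2,
    seed: int = 0,
) -> Tuple[List[Sequence[float]], List[Sequence[float]], List[str], List[str]]:
    """Shuffle row indices and hold out a test_size fraction for testing."""
    if len(features) != len(labels):
        raise ValueError("features and labels must have the same length")
    if not 0.0 < test_size < 1.0:
        raise ValueError("test_size must be a fraction in (0, 1)")
    indices = list(range(len(features)))
    random.Random(seed).shuffle(indices)  # seeded so the split is reproducible
    n_test = math.ceil(test_size * len(indices))  # round up, like scikit-learn
    if n_test >= len(indices):
        raise ValueError("test_size leaves no training rows")
    test_idx, train_idx = indices[:n_test], indices[n_test:]
    train_x = [features[i] for i in train_idx]
    test_x = [features[i] for i in test_idx]
    train_y = [labels[i] for i in train_idx]
    test_y = [labels[i] for i in test_idx]
    return train_x, test_x, train_y, test_y


# Ten toy rows; the single feature value doubles as the row id.
rows = [[float(i)] for i in range(10)]
labels = ["low" if i < 5 else "high" for i in range(10)]
train_x, test_x, train_y, test_y = split_rows(rows, labels, test_size=0.2)
```

Shuffling indices rather than the rows themselves keeps each feature row aligned with its label.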
- normalize_data(train_x, test_x, train_y, test_y)
Normalize the dataset.
- Parameters
train_x (ndarray) – training set’s input.
test_x (ndarray) – testing set’s input.
train_y (ndarray) – training set’s ground truth.
test_y (ndarray) – testing set’s ground truth.
- Return type
Tuple[ndarray, ndarray, ndarray, ndarray]
- Returns
Normalized training and testing sets.
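This page does not say which normalization `normalize_data` applies. The sketch below assumes max-abs scaling, one common choice: each feature column is divided by its largest absolute value on the training set, and the same factors are reused for the test set so that no test statistics leak into preprocessing. `max_scale` is a hypothetical helper, and labels are assumed to pass through unchanged.

```python
from typing import List, Sequence, Tuple

Matrix = List[List[float]]


def max_scale(
    train_x: Sequence[Sequence[float]],
    test_x: Sequence[Sequence[float]],
) -> Tuple[Matrix, Matrix, List[float]]:
    """Scale each feature by its maximum absolute value on the training set."""
    n_features = len(train_x[0])
    # Per-feature maxima are computed on the training rows only.
    feature_max = [max(abs(row[j]) for row in train_x) for j in range(n_features)]
    # Leave all-zero columns unscaled instead of dividing by zero.
    scale = [m if m != 0.0 else 1.0 for m in feature_max]
    train_scaled = [[v / s for v, s in zip(row, scale)] for row in train_x]
    test_scaled = [[v / s for v, s in zip(row, scale)] for row in test_x]
    return train_scaled, test_scaled, feature_max


train_scaled, test_scaled, feature_max = max_scale(
    [[1.0, 0.0, -4.0], [2.0, 0.0, 2.0]],
    [[4.0, 5.0, 1.0]],
)
```

Because the factors come from the training set alone, scaled test values can fall outside [-1, 1], as the first and second test features do here.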
- train(x, y)
Train a RandomForest model.
- Parameters
x (ndarray) – training set’s input.
y (ndarray) – training set’s ground truth.
- Return type
RandomForestClassifier
- Returns
Trained model.
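`train` returns a scikit-learn `RandomForestClassifier`. A minimal sketch of fitting one on toy data follows. `train_forest`, `n_estimators=100` and the fixed `random_state` are illustrative assumptions, not the hyperparameters `RFC` actually uses. String labels are used because `predict` is documented to return `List[str]`.

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier


def train_forest(x: np.ndarray, y: np.ndarray, seed: int = 0) -> RandomForestClassifier:
    """Fit a random forest; the hyperparameters here are illustrative only."""
    model = RandomForestClassifier(n_estimators=100, random_state=seed)
    model.fit(x, y)  # x: (n_samples, n_features), y: (n_samples,)
    return model


# Toy data: two well-separated clusters on a single feature, with string labels.
x = np.array([[0.0], [0.5], [1.0], [1.5], [10.0], [10.5], [11.0], [11.5]])
y = np.array(["a", "a", "a", "a", "b", "b", "b", "b"])
model = train_forest(x, y)

# Convert NumPy string scalars to plain str to match a List[str] return type.
predictions = [str(label) for label in model.predict(np.array([[-1.0], [20.0]]))]
print(predictions)
```

Fixing `random_state` makes the bootstrap sampling and feature selection reproducible across runs.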
- load_model(path)
Load a saved model.
- Parameters
path (str) – path where the model file is located.
- Return type
None
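This page does not document the file format that `load_model` reads. Assuming the model is serialized with Python's `pickle` module (an assumption, not confirmed here), the sketch below illustrates the documented contract: `load_model` takes a path, loads the model into the instance, and returns `None`. `ModelHolder` and its `save` method are hypothetical.

```python
import os
import pickle
import tempfile
from typing import Any, Optional


class ModelHolder:
    """Hypothetical stand-in for RFC's persistence; the real format is undocumented."""

    def __init__(self) -> None:
        self.model: Optional[Any] = None

    def save(self, path: str) -> None:
        # Assumption: the model object is pickled to a single file.
        with open(path, "wb") as fh:
            pickle.dump(self.model, fh)

    def load_model(self, path: str) -> None:
        # Mirrors the documented contract: load in place and return None.
        with open(path, "rb") as fh:
            self.model = pickle.load(fh)


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.pkl")
    writer = ModelHolder()
    writer.model = {"n_estimators": 100}  # placeholder for a fitted model
    writer.save(path)
    reader = ModelHolder()
    result = reader.load_model(path)
```

Unpickling can execute arbitrary code, so only load model files from trusted sources.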
- predict(pred_x)
Predict labels for the given input.
- Parameters
pred_x (ndarray) – input samples.
- Return type
List[str]
- Returns
Predictions.
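If `predict` delegates to the trained scikit-learn `RandomForestClassifier` (an assumption; this page only says it returns a list of strings), each prediction is the class with the highest probability averaged over the trees, not a hard majority vote. The pure-Python sketch below reproduces that soft-voting rule; `soft_vote` and the per-tree probabilities are hypothetical.

```python
from typing import List, Sequence


def soft_vote(
    tree_probas: Sequence[Sequence[Sequence[float]]],
    classes: Sequence[str],
) -> List[str]:
    """Average per-tree class probabilities and pick the top class per sample.

    tree_probas[t][i][k] is tree t's probability for class k on sample i.
    """
    n_trees = len(tree_probas)
    n_samples = len(tree_probas[0])
    predictions: List[str] = []
    for i in range(n_samples):
        mean = [
            sum(tree[i][k] for tree in tree_probas) / n_trees
            for k in range(len(classes))
        ]
        # max() keeps the first maximal index on ties, like numpy.argmax.
        best = max(range(len(classes)), key=lambda k: mean[k])
        predictions.append(classes[best])
    return predictions


# Three hypothetical trees scoring one sample over two classes.
probas = [[[0.6, 0.4]], [[0.6, 0.4]], [[0.0, 1.0]]]
print(soft_vote(probas, ["a", "b"]))  # ['b']
```

Two of the three trees favour "a", so a hard majority vote would return "a"; averaging gives 0.4 for "a" and 0.6 for "b", so soft voting returns "b".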