gt4sd.frameworks.granular.dataloader.sampler module
Sampler implementation.
Reimplemented from: https://github.com/ncullen93/torchsample/blob/ea4d1b3975f68be0521941e733887ed667a1b46e/torchsample/samplers.py. The main reasons for the reimplementation are to avoid adding a dependency and to gain better control over the logger.
Summary
Classes:
| StratifiedSampler | Implementation of a sampler for tensors based on scikit-learn StratifiedShuffleSplit. |
Reference
- class StratifiedSampler(targets, batch_size, test_size=0.5)
Bases: Sampler
Implementation of a sampler for tensors based on scikit-learn StratifiedShuffleSplit.
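This page documents only the interface. For intuition, here is a minimal sketch of how a stratified sampler of this kind can be assembled. It assumes, as in the torchsample sampler linked above, that a single split of `StratifiedShuffleSplit` is drawn and its training indices are followed by its test indices, and that the number of splits is `len(targets) // batch_size`. The class name `SketchStratifiedSampler` and the extra `random_state` argument are hypothetical; this is not the gt4sd source.

```python
from typing import Iterator, Optional

import numpy as np
import torch
from sklearn.model_selection import StratifiedShuffleSplit
from torch.utils.data import Sampler


class SketchStratifiedSampler(Sampler):
    """Hypothetical stratified sampler; a sketch, not the gt4sd implementation."""

    def __init__(
        self,
        targets: torch.Tensor,
        batch_size: int,
        test_size: float = 0.5,
        random_state: Optional[int] = None,  # hypothetical: added for reproducibility
    ) -> None:
        self.targets = targets
        # Assumption (from torchsample): one split per batch, at least one split.
        self.number_of_splits = max(1, targets.size(0) // batch_size)
        self.test_size = test_size
        self.random_state = random_state

    def gen_sample_array(self) -> np.ndarray:
        labels = self.targets.cpu().numpy()
        # The splitter only needs X for its number of rows, so zeros suffice.
        placeholder = np.zeros((len(labels), 1))
        splitter = StratifiedShuffleSplit(
            n_splits=self.number_of_splits,
            test_size=self.test_size,
            random_state=self.random_state,
        )
        # Only the first split is used to build the index order.
        train_index, test_index = next(splitter.split(placeholder, labels))
        return np.hstack([train_index, test_index])

    def __iter__(self) -> Iterator[int]:
        # Yield plain Python ints, one dataset index at a time.
        return iter(self.gen_sample_array().tolist())

    def __len__(self) -> int:
        return len(self.targets)


targets = torch.tensor([0] * 6 + [1] * 4)
sampler = SketchStratifiedSampler(targets, batch_size=5, random_state=0)
order = list(sampler)
```

Because the train and test parts together cover every sample exactly once, each pass yields a permutation of all dataset indices, and each part keeps (approximately) the class ratio of `targets`.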
- __init__(targets, batch_size, test_size=0.5)
Construct a StratifiedSampler.
- Parameters
targets (Tensor) – tensor of target labels used to stratify the samples.
batch_size (int) – size of the batch.
test_size (float) – proportion of samples in the test set. Defaults to 0.5.
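To see what `test_size` controls, the snippet below calls scikit-learn's `StratifiedShuffleSplit` directly on an imbalanced label vector. It assumes the sampler forwards `test_size` to the splitter unchanged; the labels and `random_state=0` are illustrative only.

```python
import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit

# Imbalanced labels: 8 samples of class 0 and 4 of class 1 (a 2:1 ratio).
labels = np.array([0] * 8 + [1] * 4)
placeholder = np.zeros((len(labels), 1))  # features are irrelevant for the split

splitter = StratifiedShuffleSplit(n_splits=1, test_size=0.25, random_state=0)
train_index, test_index = next(splitter.split(placeholder, labels))

# test_size=0.25 sends ceil(0.25 * 12) = 3 samples to the test part.
print(len(train_index), len(test_index))  # 9 3
# Both parts keep the 2:1 class ratio of the full label vector.
print(np.bincount(labels[train_index]), np.bincount(labels[test_index]))  # [6 3] [2 1]
```

Per-class counts are rounded when the proportions do not divide evenly, so in general the ratios are only approximately preserved.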
- __iter__()
Get an iterator over the sample array.
- Return type
Iterator[ndarray]
- Returns
sample array iterator.
- Yields
a sample array.
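A sampler only decides the order in which dataset indices are visited, so it is handed to a `torch.utils.data.DataLoader` through the `sampler` argument (leaving `shuffle=False`, since the two options are mutually exclusive). The sketch below feeds such a loader a stratified index order built with `StratifiedShuffleSplit`, standing in for what the sampler's iterator walks over; the toy dataset, `batch_size=6`, and `random_state=0` are assumptions for illustration.

```python
import numpy as np
import torch
from sklearn.model_selection import StratifiedShuffleSplit
from torch.utils.data import DataLoader, TensorDataset

# Toy dataset: 12 samples, 8 of class 0 and 4 of class 1.
features = torch.randn(12, 3)
targets = torch.tensor([0] * 8 + [1] * 4)
dataset = TensorDataset(features, targets)

# Stand-in for the sampler's iterator: train indices followed by test indices.
splitter = StratifiedShuffleSplit(n_splits=1, test_size=0.5, random_state=0)
train_index, test_index = next(splitter.split(np.zeros((12, 1)), targets.numpy()))
index_order = np.hstack([train_index, test_index]).tolist()

# DataLoader accepts any iterable of indices as `sampler`; shuffle stays False.
loader = DataLoader(dataset, batch_size=6, sampler=index_order)
batch_labels = [labels.tolist() for _, labels in loader]
```

With `test_size=0.5` and a batch size equal to half the dataset, each batch is one part of the split, so here every batch holds 4 samples of class 0 and 2 of class 1, matching the 2:1 ratio of the labels.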