gt4sd.frameworks.granular.dataloader.sampler module

Sampler implementation.

Reimplemented starting from: https://github.com/ncullen93/torchsample/blob/ea4d1b3975f68be0521941e733887ed667a1b46e/torchsample/samplers.py. The main reasons for the reimplementation are to avoid adding a dependency and to better control the logger.

Summary

Classes:

StratifiedSampler

Implementation of a sampler for tensors based on scikit-learn StratifiedShuffleSplit.

Reference

class StratifiedSampler(targets, batch_size, test_size=0.5)

Bases: Sampler

Implementation of a sampler for tensors based on scikit-learn StratifiedShuffleSplit.

__init__(targets, batch_size, test_size=0.5)

Construct a StratifiedSampler.

Parameters
  • targets (Tensor) – targets tensor, i.e. the labels used to stratify the sampling.

  • batch_size (int) – size of the batch.

  • test_size (float) – proportion of samples in the test set. Defaults to 0.5.
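How batch_size is used is not spelled out above. In the torchsample code this module was reimplemented from, the constructor does not batch anything itself; it only derives the number of splits handed to StratifiedShuffleSplit as int(len(targets) / batch_size). The sketch below assumes gt4sd keeps that convention; number_of_splits is a hypothetical helper, not part of the module.

```python
def number_of_splits(num_targets: int, batch_size: int) -> int:
    """Hypothetical helper: split count derived from the batch size.

    Assumes the torchsample convention n_splits = int(num_targets / batch_size).
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")
    # int() truncates toward zero, e.g. 1000 / 32 = 31.25 -> 31.
    n_splits = int(num_targets / batch_size)
    if n_splits < 1:
        # A batch size larger than the dataset would leave no split to draw.
        raise ValueError("batch_size must not exceed the number of targets")
    return n_splits


print(number_of_splits(1000, 32))  # 31
```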

gen_sample_array()

Get the sample array, i.e. the dataset indices the sampler iterates over.

Return type

ndarray

Returns

sample array.
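A minimal sketch of the stratified shuffle split step, assuming the approach of the torchsample reference linked above: draw one split with scikit-learn's StratifiedShuffleSplit and concatenate the shuffled train indices with the shuffled test indices, so every index appears exactly once and both parts approximately preserve the class proportions of the targets. Here gen_sample_array is a standalone function and the random_state argument is an addition for reproducibility; neither reflects the module's actual signature. A CPU torch tensor can be passed as targets.numpy().

```python
import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit


def gen_sample_array(targets, n_splits: int, test_size: float = 0.5, random_state=None) -> np.ndarray:
    """Sketch: one stratified shuffle split, flattened into a single index order."""
    labels = np.asarray(targets)
    # Only y drives the stratification; scikit-learn documents that
    # np.zeros(n_samples) is a valid placeholder for X.
    placeholder = np.zeros(labels.shape[0])
    splitter = StratifiedShuffleSplit(
        n_splits=n_splits, test_size=test_size, random_state=random_state
    )
    # Use only the first of the n_splits stratified splits.
    train_index, test_index = next(splitter.split(placeholder, labels))
    # Stratified "train" indices first, then the stratified "test" indices.
    return np.hstack([train_index, test_index])


order = gen_sample_array([0] * 6 + [1] * 4, n_splits=2, test_size=0.5, random_state=0)
print(order)
```

With six samples of class 0, four of class 1 and test_size=0.5, the first five indices cover three class-0 and two class-1 samples, and so do the last five.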

__iter__()

Get an iterator over the sample array.

Return type

Iterator[ndarray]

Returns

sample array iterator.

Yields

a sample array.
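Iterating over a one-dimensional NumPy array yields its elements, so, assuming the sample array holds dataset indices, __iter__ hands them out one at a time. When a sampler is passed to torch.utils.data.DataLoader together with batch_size, the DataLoader wraps it in a BatchSampler that groups consecutive indices into batches. The standard-library sketch below mimics that grouping; batch_indices is a hypothetical helper, not part of gt4sd or PyTorch.

```python
from typing import Iterable, Iterator, List


def batch_indices(indices: Iterable[int], batch_size: int, drop_last: bool = False) -> Iterator[List[int]]:
    """Hypothetical helper: group a flat index stream into batches, like BatchSampler."""
    batch: List[int] = []
    for index in indices:
        batch.append(int(index))
        if len(batch) == batch_size:
            yield batch
            batch = []
    # Emit the final, shorter batch unless drop_last is set.
    if batch and not drop_last:
        yield batch


print(list(batch_indices(iter([4, 1, 3, 0, 2]), batch_size=2)))  # [[4, 1], [3, 0], [2]]
```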

__len__()

Length of the sampler.

Return type

int

Returns

the sampler length.
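__len__ is what a DataLoader relies on when reporting how many batches an epoch has. Assuming the sampler length is the number of targets (the torchsample reference returns the length of its label vector), the batch count is ceil(length / batch_size), or floor(length / batch_size) with drop_last=True. num_batches is a hypothetical helper illustrating that arithmetic.

```python
def num_batches(sampler_length: int, batch_size: int, drop_last: bool = False) -> int:
    """Hypothetical helper: batches per epoch for a sampler of the given length."""
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")
    if drop_last:
        # The incomplete trailing batch is discarded.
        return sampler_length // batch_size
    # Ceiling division keeps the incomplete trailing batch.
    return (sampler_length + batch_size - 1) // batch_size


print(num_batches(1000, 32))  # 32
```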
