BaseDataset

class hybrid_learning.datasets.base.BaseDataset(split=None, dataset_root=None, transforms=None, transforms_cache=None, after_cache_transforms=None, device='cpu')

Bases: Dataset

Abstract base class for tuple datasets with storage location.

Derived datasets should yield tuples of (input, target). The transformation applied to data tuples before they are returned from __getitem__() can be controlled via transforms. The default for transforms is given by _default_transforms; override it in sub-classes if necessary. The default combination of collected dataset tuples and _default_transforms should yield a tuple of torch.Tensor or dicts thereof.
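The interplay of getitem(), transforms and _default_transforms can be illustrated with a minimal sketch. SketchBaseDataset, SquaresDataset and identity_pair are hypothetical stand-ins, not library classes. The sketch assumes that a tuple transformation is called with input and target as two positional arguments, and that _default_transforms is a property that sub-classes override.

```python
from typing import Any, Callable, Optional, Tuple


def identity_pair(inp: Any, target: Any) -> Tuple[Any, Any]:
    """Tuple transformation that returns (input, target) unchanged."""
    return inp, target


class SketchBaseDataset:
    """Hypothetical stand-in for BaseDataset; not the library implementation."""

    def __init__(self, transforms: Optional[Callable] = None):
        # Fall back to the sub-class-specific default if no transforms are given.
        self.transforms = transforms if transforms is not None else self._default_transforms

    @property
    def _default_transforms(self) -> Callable:
        return identity_pair

    def getitem(self, idx: int) -> Tuple[Any, Any]:
        raise NotImplementedError

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        # One tuple transformation receives input and target together.
        inp, target = self.getitem(idx)
        return self.transforms(inp, target)


class SquaresDataset(SketchBaseDataset):
    """Hypothetical subclass yielding (i, i**2) and converting both values to float."""

    @property
    def _default_transforms(self) -> Callable:
        return lambda inp, target: (float(inp), float(target))

    def getitem(self, idx: int) -> Tuple[int, int]:
        return idx, idx ** 2


print(SquaresDataset()[3])                          # (3.0, 9.0)
print(SquaresDataset(transforms=identity_pair)[3])  # (3, 9)
```

Overriding only the default keeps callers free to pass their own transforms, while a plain instantiation still yields the sub-class's preferred output types.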

The attribute hybrid_learning.datasets.base.BaseDataset.dataset_root is assumed to provide the storage location of the dataset. Ideally, all components (input data, annotations, etc.) are stored relative to this root location.

If transforms_cache is given, the transformed tuple values are cached there. Values are then only collected and transformed if they cannot be loaded from the cache. The cache descriptor of an entry is obtained from the descriptor() method, so make sure to override it appropriately (e.g., to return the image ID or image file name).
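The cache lookup keyed by descriptor() can be sketched as follows. DictCache and CachedSketchDataset are hypothetical; the put()/load() method names and the convention that load() returns None on a miss are assumptions for illustration, not a confirmed description of the library's Cache interface.

```python
from typing import Any, Dict, Hashable, List, Optional, Tuple


class DictCache:
    """Hypothetical in-memory cache; load() returns None on a cache miss."""

    def __init__(self) -> None:
        self._store: Dict[Hashable, Any] = {}

    def put(self, descriptor: Hashable, obj: Any) -> None:
        self._store[descriptor] = obj

    def load(self, descriptor: Hashable) -> Optional[Any]:
        return self._store.get(descriptor)


class CachedSketchDataset:
    """Hypothetical stand-in illustrating the cache lookup in __getitem__()."""

    def __init__(self, file_names: List[str], transforms, transforms_cache=None):
        self.file_names = list(file_names)
        self.transforms = transforms
        self.transforms_cache = transforms_cache
        self.load_count = 0  # how often getitem() actually had to collect data

    def __len__(self) -> int:
        return len(self.file_names)

    def descriptor(self, i: int) -> Hashable:
        # Unique per item, here the (hypothetical) image file name.
        return self.file_names[i]

    def getitem(self, idx: int) -> Tuple[str, int]:
        self.load_count += 1
        return self.file_names[idx], len(self.file_names[idx])

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        desc = self.descriptor(idx)
        cached = None if self.transforms_cache is None else self.transforms_cache.load(desc)
        if cached is not None:
            return cached
        # Cache miss: collect and transform, then store the transformed tuple.
        item = self.transforms(*self.getitem(idx))
        if self.transforms_cache is not None:
            self.transforms_cache.put(desc, item)
        return item


ds = CachedSketchDataset(["a.png", "bb.png"], lambda i, t: (i.upper(), t * 10), DictCache())
print(ds[1], ds[1], ds.load_count)  # ('BB.PNG', 60) ('BB.PNG', 60) 1
```

The second access to index 1 is served from the cache, so getitem() and transforms run only once for that item.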

Note

If a CacheTuple is used, make sure that None is returned if any tuple value is None.
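The reason for this rule is that a partial hit must be reported as a miss; otherwise a tuple with a missing component would be returned. The following sketch is in the spirit of a tuple of caches, but TupleOfCaches and DictCache (with assumed put()/load() methods) are hypothetical names, not the library's CacheTuple.

```python
from typing import Any, Dict, Hashable, Optional, Tuple


class DictCache:
    """Hypothetical in-memory cache; load() returns None on a miss."""

    def __init__(self) -> None:
        self._store: Dict[Hashable, Any] = {}

    def put(self, descriptor: Hashable, obj: Any) -> None:
        self._store[descriptor] = obj

    def load(self, descriptor: Hashable) -> Optional[Any]:
        return self._store.get(descriptor)


class TupleOfCaches:
    """Hypothetical tuple cache: one sub-cache per tuple component."""

    def __init__(self, *caches: DictCache) -> None:
        self.caches = caches

    def put(self, descriptor: Hashable, obj: Tuple[Any, ...]) -> None:
        # Store each component in its own sub-cache under the same descriptor.
        for cache, value in zip(self.caches, obj):
            cache.put(descriptor, value)

    def load(self, descriptor: Hashable) -> Optional[Tuple[Any, ...]]:
        values = tuple(cache.load(descriptor) for cache in self.caches)
        # A partial hit counts as a miss, so the caller recomputes the whole tuple.
        if any(value is None for value in values):
            return None
        return values


inputs, targets = DictCache(), DictCache()
cache = TupleOfCaches(inputs, targets)
inputs.put("img_0", "pixels")       # only the input component is cached
print(cache.load("img_0"))          # None
cache.put("img_1", ("pixels", "mask"))
print(cache.load("img_1"))          # ('pixels', 'mask')
```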

Public Data Attributes:

settings

Settings of the instance.

Public Methods:

getitem(idx)

Get data item tuple from idx in this dataset.

descriptor(i)

Return a unique descriptor for the item at position i.

Special Methods:

__init__([split, dataset_root, transforms, ...])

Init.

__len__()

Number of data points in the dataset; to be implemented in subclasses.

__getitem__(idx)

Get item from idx in dataset with transformations applied.

__repr__()

Nice printing function.

Inherited from Dataset:

__getitem__(idx)

Get item from idx in dataset with transformations applied.

__add__(other)


__getitem__(idx)

Get item from idx in dataset with transformations applied. Transformations must be stored as a single tuple transformation in transforms.

Returns

tuple output of getitem() transformed by transforms

Parameters

idx (int) –

__init__(split=None, dataset_root=None, transforms=None, transforms_cache=None, after_cache_transforms=None, device='cpu')

Init.

Parameters
  • split (Optional[DatasetSplit]) – The split of the dataset (e.g. DatasetSplit.TRAIN, DatasetSplit.VAL, DatasetSplit.TEST).

  • dataset_root (Optional[str]) – The root location where the dataset is stored.

  • transforms (Optional[Callable]) – The transformations to be applied to the data when loaded; defaults to _default_transforms

  • transforms_cache (Optional[Cache]) – optional cache instance for caching transformed tuples; must return None in case one of the tuple values has not been cached yet; see transforms_cache

  • after_cache_transforms (Optional[Callable]) – transformations applied after consulting the cache (whether or not the tuple was retrieved from the cache); by default, tensor gradients are disabled and tensors are moved to a common device

  • device (Optional[Union[str, device]]) – device to use in the default after_cache_transforms

abstract __len__()

Number of data points in the dataset; to be implemented in subclasses.

__repr__()

Nice printing function.

Return type

str

abstract descriptor(i)

Return a unique descriptor for the item at position i. This can e.g. be an image ID or the image file name. It is used for caching.
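Because descriptors key the cache, two items sharing a descriptor would silently receive each other's cached tuples. A small check like the following can catch that early. ImageIdDataset and check_descriptors are hypothetical helpers, not part of the library.

```python
from collections.abc import Hashable
from typing import Dict, List


class ImageIdDataset:
    """Hypothetical dataset whose items are identified by image IDs."""

    def __init__(self, image_ids: List[int]) -> None:
        self.image_ids = image_ids

    def __len__(self) -> int:
        return len(self.image_ids)

    def descriptor(self, i: int) -> Hashable:
        # The image ID is unique per item and hashable, so it can key a cache.
        return self.image_ids[i]


def check_descriptors(dataset) -> None:
    """Raise if descriptors are unhashable or collide (which would corrupt caching)."""
    seen: Dict[Hashable, int] = {}
    for i in range(len(dataset)):
        desc = dataset.descriptor(i)
        try:
            hash(desc)
        except TypeError as err:
            raise TypeError(f"descriptor of item {i} is not hashable: {desc!r}") from err
        if desc in seen:
            raise ValueError(f"items {seen[desc]} and {i} share descriptor {desc!r}")
        seen[desc] = i


check_descriptors(ImageIdDataset([7, 42, 3]))  # passes silently
```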

Parameters

i (int) –

Return type

Hashable

abstract getitem(idx)

Get data item tuple from idx in this dataset.

Parameters

idx (int) – index to retrieve data point from

Returns

tuple (input, label) with

  • input: one of image (as PIL.Image.Image) or Radar/Lidar point cloud

  • label: one of

    • None

    • class label (as torch.Tensor or bool)

    • semantic segmentation map (as PIL.Image.Image or torch.Tensor compatible with torchvision transforms)

    • bounding box

    • string-indexed dict of combinations

Return type

Tuple[Union[Tensor, Image], Union[Tensor, Image, Dict[Tensor, Image]]]
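As an illustration of a string-indexed label dict that combines a class label and a bounding box, consider the following sketch. It uses plain-Python stand-ins where a real dataset would return PIL images or torch tensors. PersonDataset, the key names and the (x_min, y_min, x_max, y_max) box format are all hypothetical choices.

```python
from typing import Any, Dict, List, Optional, Tuple

# Plain-Python stand-in for an image; a real dataset would return PIL images or tensors.
Image = List[List[int]]
Box = Tuple[int, int, int, int]  # hypothetical (x_min, y_min, x_max, y_max) format


class PersonDataset:
    """Hypothetical dataset whose getitem() returns (image, dict label)."""

    def __init__(self, images: List[Image], boxes: List[Optional[Box]]) -> None:
        self.images = images
        self.boxes = boxes

    def getitem(self, idx: int) -> Tuple[Image, Dict[str, Any]]:
        box = self.boxes[idx]
        # String-indexed label dict combining a class label and a bounding box.
        label = {"has_person": box is not None, "bbox": box}
        return self.images[idx], label


ds = PersonDataset([[[0, 255], [255, 0]], [[0, 0], [0, 0]]], [(0, 0, 1, 1), None])
print(ds.getitem(0))  # ([[0, 255], [255, 0]], {'has_person': True, 'bbox': (0, 0, 1, 1)})
```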

__parameters__ = ()
after_cache_transforms: Callable

Transformation function applied after consulting the cache, whether or not the tuple was retrieved from it. Use these transformations instead of transforms to ensure that a transformation is always applied, regardless of caching. By default, tensor gradients are disabled and tensors are moved to a common device (see _get_default_after_cache_trafo()).
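The order of the two transformation steps can be sketched as a single function. get_item is a hypothetical helper that uses a plain dict as the cache. It shows that transforms runs only on a cache miss, while after_cache_transforms runs on every access.

```python
from typing import Any, Callable, Dict, Hashable, Tuple


def get_item(
    idx: int,
    descriptor: Callable[[int], Hashable],
    getitem: Callable[[int], Tuple[Any, Any]],
    transforms: Callable[[Any, Any], Tuple[Any, Any]],
    cache: Dict[Hashable, Tuple[Any, Any]],
    after_cache_transforms: Callable[[Any, Any], Tuple[Any, Any]],
) -> Tuple[Any, Any]:
    """Hypothetical helper showing where each transformation step sits."""
    desc = descriptor(idx)
    item = cache.get(desc)
    if item is None:
        # Only run on a cache miss; the result of transforms is what gets cached.
        item = transforms(*getitem(idx))
        cache[desc] = item
    # Runs on every access, whether the item came from the cache or not.
    return after_cache_transforms(*item)


calls = {"transforms": 0, "after": 0}


def count_transforms(inp, tgt):
    calls["transforms"] += 1
    return inp * 2, tgt


def count_after(inp, tgt):
    calls["after"] += 1
    return inp, tgt


cache = {}
for _ in range(3):
    get_item(0, str, lambda i: (i + 1, "label"), count_transforms, cache, count_after)
print(calls)  # {'transforms': 1, 'after': 3}
```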

dataset_root: str

Root from which to navigate to the dataset information, assuming the dataset is saved in some storage location.

property settings: Dict[str, Any]

Settings of the instance. Information on transforms is skipped if transforms is set to the default.
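The following sketch shows one way a settings property can omit default transforms. SettingsSketch covers only split, dataset_root and transforms. Detecting "default" by object identity is an assumption, and split is given as a plain string rather than a DatasetSplit.

```python
from typing import Any, Callable, Dict, Optional, Tuple


def identity_pair(inp: Any, target: Any) -> Tuple[Any, Any]:
    """Hypothetical default tuple transformation."""
    return inp, target


class SettingsSketch:
    """Hypothetical stand-in covering only split, dataset_root and transforms."""

    def __init__(self, split: Optional[str] = None, dataset_root: Optional[str] = None,
                 transforms: Optional[Callable] = None) -> None:
        self.split = split
        self.dataset_root = dataset_root
        self.transforms = transforms if transforms is not None else self._default_transforms

    @property
    def _default_transforms(self) -> Callable:
        return identity_pair

    @property
    def settings(self) -> Dict[str, Any]:
        settings: Dict[str, Any] = dict(split=self.split, dataset_root=self.dataset_root)
        # Only report transforms if they differ from the default.
        if self.transforms is not self._default_transforms:
            settings["transforms"] = self.transforms
        return settings


print(SettingsSketch(split="train", dataset_root="/data").settings)
# {'split': 'train', 'dataset_root': '/data'}
```

Since the keys mirror the constructor arguments, an equivalent instance can be re-created with SettingsSketch(**ds.settings).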

split: Optional[DatasetSplit]

Optional specification of the use case this dataset is meant to represent, e.g., training, validation, or testing.

transforms: Callable

Transformation function applied to each item tuple before it is returned. Applied in __getitem__(). Default transformations are sub-class-specific. Items transformed using transforms can be cached by setting transforms_cache. To apply a transformation always, regardless of caching, use after_cache_transforms.

transforms_cache: Optional[Cache]

Cache for the transformed (input, target) tuples. If set, __getitem__() first tries to load the tuple from the cache before loading and transforming it normally. Items not yet in the cache are added to it after transforms has been applied.