DataTriple

class hybrid_learning.datasets.base.DataTriple(data=None, *, train=None, val=None, test=None, train_val=None, validator=None, **split_kwargs)[source]

Bases: object

Tuple of train/test/validation datasets (with automatic splitting if necessary). The splitting is conducted on init. This data structure is considered immutable, so to re-do the splitting, create a new instance with the old specification.

To access the held splits either use

  • the corresponding attributes,

  • the dict-like getter functionality, or

  • the dictionary representation of the tuple via as_dict().
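
The three access styles can be illustrated with a minimal stdlib stand-in that mirrors the documented interface (the `DatasetSplit` member names and the `MiniTriple` internals are assumptions for illustration, not the library's implementation):

```python
from enum import Enum

class DatasetSplit(Enum):
    # Member names are assumed; the real enum lives in the library.
    TRAIN = "train"
    VAL = "val"
    TEST = "test"

class MiniTriple:
    """Hypothetical stand-in mirroring the documented DataTriple accessors."""

    def __init__(self, train, val, test):
        self.train, self.val, self.test = train, val, test

    def as_dict(self):
        # Dict representation keyed by split identifier, as in as_dict()
        return {DatasetSplit.TRAIN: self.train,
                DatasetSplit.VAL: self.val,
                DatasetSplit.TEST: self.test}

    def __getitem__(self, key):
        # Dict-like getter by split identifier, as in __getitem__()
        return self.as_dict()[key]

triple = MiniTriple(train=[1, 2, 3], val=[4], test=[5, 6])
assert triple.train == [1, 2, 3]                       # attribute access
assert triple[DatasetSplit.VAL] == [4]                 # dict-like getter
assert triple.as_dict()[DatasetSplit.TEST] == [5, 6]   # dict representation
```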

Public Data Attributes:

DEFAULT_VAL_SPLIT

Default validation split proportion.

train

Training data set.

val

Validation data set.

val_split

Value of \(\frac{len(val)} {len(val) + len(train)}\) if none of the datasets is None or empty.

test

Testing dataset split.

test_split

Value of \(\frac{len(test)} {len(test) + len(train\_val)}\) if none of the datasets is None or empty.

train_val

Combined dataset of training and validation data.

data

Concatenation of all data (train, val, test) stored in this tuple.

info

Provide a DataFrame with some statistics on the held datasets.

Public Methods:

validate_by(validator)

Validate all data splits using validator, which raises in case of invalid format.

as_dict()

Dict of the splits (train, val, test) held in this triple.

items()

Dataset split items.

keys()

Dataset split keys.

Special Methods:

__init__([data, train, val, test, ...])

Init.

__repr__()

String representation of the held dataset splits.

__eq__(other)

Check whether all dataset splits are the same objects.

__getitem__(key)

Get dataset split by split identifier.


__eq__(other)[source]

Check whether all dataset splits are the same objects.

Return type

bool

__getitem__(key)[source]

Get dataset split by split identifier.

Parameters

key (DatasetSplit) –

__init__(data=None, *, train=None, val=None, test=None, train_val=None, validator=None, **split_kwargs)[source]

Init.

Exactly one combination of the following must be given:

  • train, test, val

  • train_val, test

  • data

Raises

ValueError – if the data specification is insufficient or ambiguous, or if the datasets do not pass the validity check
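
The exclusivity rule above can be sketched as a small check (`check_data_spec` is a hypothetical helper for illustration, not the library's actual validation code):

```python
def check_data_spec(data=None, train=None, val=None, test=None, train_val=None):
    """Sketch of the documented rule: exactly one of the combinations
    (train, val, test), (train_val, test), or (data,) may be given,
    otherwise a ValueError is raised."""
    given = {name for name, value in dict(data=data, train=train, val=val,
                                          test=test, train_val=train_val).items()
             if value is not None}
    allowed = [{"train", "val", "test"}, {"train_val", "test"}, {"data"}]
    if given not in allowed:
        raise ValueError("Data specification is insufficient or ambiguous: "
                         + str(sorted(given)))

# Valid specifications pass silently:
check_data_spec(train=[1], val=[2], test=[3])
check_data_spec(train_val=[1, 2], test=[3])
check_data_spec(data=[1, 2, 3])
```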

__repr__()[source]

String representation of the held dataset splits.

Return type

str

as_dict()[source]

Dict of the splits (train, val, test) held in this triple.

Return type

Dict[DatasetSplit, torch.utils.data.Dataset]

classmethod from_dict(splits)[source]

Create DataTriple from a dict of datasets indexed by their split.

Parameters

splits (Dict[DatasetSplit, torch.utils.data.Dataset]) –

Return type

DataTriple

items()[source]

Dataset split items. Items of as_dict() output.

Return type

ItemsView[DatasetSplit, torch.utils.data.Dataset]

keys()[source]

Dataset split keys. Keys of as_dict() output.

Return type

KeysView[DatasetSplit]

static split_dataset(dataset, indices1=None, indices2=None, len1=None, split1=None)[source]

Split dataset exhaustively into two subsets, either randomly or according to indices. Returns the resulting splits without modifying dataset. For random splitting, the length len1 or the split proportion split1 of the first split is used. For splitting by indices, the indices are validated (which may take some time …).

Parameter constraints:

  • At least one of the optional splitting specifiers must be given.

  • Only true splits of dataset are allowed, i.e. given indices must not contain duplicates.

  • Precedence of given specifiers is as follows (strongest to weakest):

    • indices

    • len

    • split

Parameters
  • dataset – the dataset to split

  • indices1 (Optional[Sequence[int]]) – Optional indices of the first data split; must be disjoint from indices2 and contain no duplicates; defaults to a random set of indices, or to the complement of indices2 if that is given

  • indices2 (Optional[Sequence[int]]) – see indices1

  • len1 (Optional[int]) – length of the desired first data split

  • split1 (Optional[float]) – proportion of the data samples in the first data split among all dataset samples

Return type

Tuple[torch.utils.data.Subset, torch.utils.data.Subset]
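
The documented semantics can be sketched with plain Python sequences (`split_sketch` is a hypothetical stand-in: it returns lists instead of torch Subsets and models the indices-over-len-over-split precedence in simplified form):

```python
import random

def split_sketch(dataset, indices1=None, indices2=None, len1=None, split1=None):
    """Stdlib sketch of the split_dataset() semantics described above."""
    n = len(dataset)
    # Strongest specifier: indices (derive indices1 from indices2 if needed)
    if indices1 is None and indices2 is not None:
        indices1 = [i for i in range(n) if i not in set(indices2)]
    # Next: explicit length, then split proportion, for random splitting
    if indices1 is None:
        if len1 is None:
            if split1 is None:
                raise ValueError("At least one splitting specifier must be given.")
            len1 = round(split1 * n)
        indices1 = random.sample(range(n), len1)
    if indices2 is None:
        indices2 = [i for i in range(n) if i not in set(indices1)]
    # Validate a true split: disjoint, exhaustive, no duplicate indices
    assert sorted(list(indices1) + list(indices2)) == list(range(n))
    return [dataset[i] for i in indices1], [dataset[i] for i in indices2]

first, second = split_sketch(list("abcdef"), indices1=[0, 1])
assert first == ["a", "b"] and second == ["c", "d", "e", "f"]
```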

classmethod split_train_val(train_val_data, train_indices=None, val_indices=None, val_len=None, val_split=None, **ignored_args)[source]

Split train_val_data either randomly or according to indices and return splits. This is a wrapper around split_dataset() with nicer parameter naming, order correction, and defaults. The same parameter constraints apply.

Returns

tuple of splits (train, val)

Return type

Tuple[torch.utils.data.Subset, torch.utils.data.Subset]

classmethod split_trainval_test(data, train_val_indices=None, test_indices=None, test_len=None, test_split=None, **ignored_args)[source]

Split data either randomly or according to indices and return splits. This is a wrapper around split_dataset() with nicer parameter naming, order correction, and defaults. The same parameter constraints apply.

Returns

tuple of splits (train_val, test)

Return type

Tuple[torch.utils.data.Subset, torch.utils.data.Subset]

validate_by(validator)[source]

Validate all data splits using validator, which raises in case of invalid format.

Parameters

validator (Callable[[torch.utils.data.Dataset, str], Any]) –

Return type

None
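
The validator contract can be illustrated with a stdlib sketch (`validate_splits` and `non_empty_validator` are hypothetical stand-ins; the real method iterates over the triple's own splits):

```python
def validate_splits(splits, validator):
    """Sketch of validate_by() semantics: apply a validator callable
    (dataset, split_name) -> Any to every held split; the validator
    is expected to raise in case of invalid format."""
    for name, dataset in splits.items():
        validator(dataset, name)

def non_empty_validator(dataset, name):
    # Example validator: treat an empty split as an invalid format
    if len(dataset) == 0:
        raise ValueError("Split %r is empty" % name)

# All splits non-empty, so this passes silently:
validate_splits({"train": [1], "val": [2], "test": [3]}, non_empty_validator)
```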

DEFAULT_VAL_SPLIT: float = 0.2

Default validation split proportion. This is the proportion of val in train_val.
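
A quick numeric check of what this proportion means (plain arithmetic, not library code):

```python
# With DEFAULT_VAL_SPLIT = 0.2 and a train_val set of 100 samples,
# the validation split receives about 0.2 * 100 = 20 samples and
# the training split the remaining 80.
DEFAULT_VAL_SPLIT = 0.2
n_train_val = 100
n_val = round(DEFAULT_VAL_SPLIT * n_train_val)
n_train = n_train_val - n_val
# val_split = len(val) / (len(val) + len(train)) recovers the proportion:
assert n_val == 20 and n_train == 80
assert n_val / (n_val + n_train) == DEFAULT_VAL_SPLIT
```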

__hash__ = None

property data: torch.utils.data.Dataset

Concatenation of all data (train, val, test) stored in this tuple.

property info: pandas.core.frame.DataFrame

Provide a DataFrame with some statistics on the held datasets.

property test: torch.utils.data.Dataset

Testing dataset split.

property test_split: Optional[float]

Value of \(\frac{len(test)} {len(test) + len(train\_val)}\) if none of the datasets is None or empty.

property train: torch.utils.data.Dataset

Training data set.

property train_val: torch.utils.data.Dataset

Combined dataset of training and validation data.

It is a concatenation of train and val or a permutation thereof.

property val: torch.utils.data.Dataset

Validation data set.

property val_split: Optional[float]

Value of \(\frac{len(val)} {len(val) + len(train)}\) if none of the datasets is None or empty.