data_for_concept_model

hybrid_learning.concepts.analysis.analysis_handle.data_for_concept_model(concept_model=None, main_model_stump=None, concept=None, in_channels=None, transforms=None, cache_builder=None, cache_root=None, cache_in_memory=False, device=None)

Data handles that provide activation maps for the concept model and ground truth from the concept. The data of the concept model's concept is wrapped by an ActivationDatasetWrapper, whose input and ground truth are:

Input

the required activation maps of the main model

Ground truth

the segmentation masks, scaled to the size of the activation maps (currently, the scaling is done on each __getitem__ call of ActivationDatasetWrapper)
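The ground truth is described as being rescaled lazily, at item-access time, rather than up front. A minimal stdlib-only sketch of that idea follows. Assumptions: masks are plain nested lists, nearest-neighbour resampling is used purely for illustration (the interpolation actually used by ActivationDatasetWrapper is not stated here), and `scale_mask_nearest` and `LazyScaledMasks` are hypothetical names, not part of the library.

```python
from typing import List, Sequence

Mask = List[List[float]]


def scale_mask_nearest(mask: Mask, out_h: int, out_w: int) -> Mask:
    """Resize a 2D mask to (out_h, out_w) by nearest-neighbour sampling.

    Hypothetical helper; nearest-neighbour is an illustrative assumption.
    """
    in_h, in_w = len(mask), len(mask[0])
    return [
        [mask[(r * in_h) // out_h][(c * in_w) // out_w] for c in range(out_w)]
        for r in range(out_h)
    ]


class LazyScaledMasks:
    """Hypothetical wrapper: a mask is only rescaled when its item is accessed."""

    def __init__(self, masks: Sequence[Mask], out_h: int, out_w: int):
        self.masks = masks
        self.out_h, self.out_w = out_h, out_w

    def __len__(self) -> int:
        return len(self.masks)

    def __getitem__(self, idx: int) -> Mask:
        # Scaling happens here, on every access, not when the wrapper is built.
        return scale_mask_nearest(self.masks[idx], self.out_h, self.out_w)
```

Deferring the scaling to `__getitem__` keeps construction cheap, at the price of repeating the work on every access unless a cache is registered.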

Parameters
  • concept_model (Optional[ConceptDetectionModel2D]) – the concept model (with concept and main model) to generate the wrapped dataset from; if not set, main_model_stump, concept, and in_channels are used

  • main_model_stump (Optional[Module]) – the model stump that generates the activations

  • concept (Optional[Concept]) – the concept the data of which is to be wrapped

  • in_channels (Optional[int]) – (optional for validation purposes) the input channels for the concept model

  • transforms (Optional[Callable[[Tensor, Tensor], Tuple[Tensor, Tensor]]]) – the transformations to add to each wrapper instance

  • cache_builder (Optional[Callable[[BaseDataset, ConceptDetectionModel2D], Cache]]) – a builder that accepts the dataset to be wrapped and the concept model for which to wrap it, and returns a cache to be registered to the dataset wrapper; it should have no side effects; defaults to a cache tuple containing one cache cascade for the activations and one for the masks

  • cache_root (Optional[str]) – if cache_root is given instead of cache_builder, a default cache builder is defined from cache_root and default_cache_roots

  • cache_in_memory (bool) – use in-memory caches as the default cache; if cache_root is also set, the default is a cache cascade of an in-memory cache followed by a file cache

  • device (Optional[Union[str, device]]) – the device to move all dataset items to after loading
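The cache_in_memory and cache_root parameters combine into a cascade: an in-memory cache is consulted first, then a file cache under cache_root. The sketch below illustrates that lookup order with stdlib only. Assumptions: `MemCache`, `FileCache`, `Cascade`, and `get_or_compute` are hypothetical stand-ins, not the library's cache classes; `None` is used as the "miss" marker; and values are stored with `pickle`.

```python
import os
import pickle
import tempfile
from typing import Any, Callable, Dict, Optional


class MemCache:
    """Hypothetical in-memory cache backed by a dict."""

    def __init__(self) -> None:
        self._store: Dict[str, Any] = {}

    def load(self, key: str) -> Optional[Any]:
        return self._store.get(key)

    def put(self, key: str, value: Any) -> None:
        self._store[key] = value


class FileCache:
    """Hypothetical file cache: one pickle file per key below a root folder."""

    def __init__(self, root: str) -> None:
        self.root = root
        os.makedirs(root, exist_ok=True)

    def _path(self, key: str) -> str:
        return os.path.join(self.root, f"{key}.pkl")

    def load(self, key: str) -> Optional[Any]:
        path = self._path(key)
        if not os.path.exists(path):
            return None
        with open(path, "rb") as f:
            return pickle.load(f)

    def put(self, key: str, value: Any) -> None:
        with open(self._path(key), "wb") as f:
            pickle.dump(value, f)


class Cascade:
    """Hypothetical cascade: query caches in order, back-fill earlier ones on a hit."""

    def __init__(self, *caches: Any) -> None:
        self.caches = list(caches)

    def load(self, key: str) -> Optional[Any]:
        for i, cache in enumerate(self.caches):
            value = cache.load(key)
            if value is not None:
                for earlier in self.caches[:i]:
                    earlier.put(key, value)  # speed up the next lookup
                return value
        return None

    def put(self, key: str, value: Any) -> None:
        for cache in self.caches:
            cache.put(key, value)


def get_or_compute(cache: Cascade, key: str, compute: Callable[[], Any]) -> Any:
    """Return the cached value for key, computing and storing it on a miss."""
    value = cache.load(key)
    if value is None:
        value = compute()
        cache.put(key, value)
    return value


if __name__ == "__main__":
    cascade = Cascade(MemCache(), FileCache(tempfile.mkdtemp()))
    print(get_or_compute(cascade, "0", lambda: [1, 2, 3]))
```

A fresh process only finds the file cache populated; the first hit then refills the in-memory layer, which is the benefit of ordering the faster cache first.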

Raises

ValueError if the data dimensions do not fit the in_channels of the concept model’s concept layers
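The ValueError amounts to a channel-count check between the generated activation maps and the concept layers. A minimal sketch of such a check, assuming a (channels, height, width) shape layout and skipping validation when in_channels is not given; `check_in_channels` is a hypothetical helper, not the library's implementation.

```python
from typing import Optional, Sequence


def check_in_channels(activation_shape: Sequence[int],
                      in_channels: Optional[int]) -> None:
    """Raise ValueError if the activation channels do not match in_channels.

    Assumes a (channels, height, width) layout; does nothing if in_channels is None.
    """
    if in_channels is None:
        return  # nothing to validate against
    channels = activation_shape[0]
    if channels != in_channels:
        raise ValueError(
            f"Activation maps have {channels} channels, but the concept "
            f"layers expect in_channels={in_channels}"
        )
```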

Returns

a tuple of the train, test, and validation data, each yielding activation maps as inputs and the scaled masks as ground truth

Return type

DataTriple
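To illustrate the shape of the return value, here is a sketch of wrapping each split of a train/test/validation triple so that an item becomes (activation map, scaled mask). Assumptions: `DataTripleSketch` is a hypothetical stand-in whose field names (`train`, `test`, `val`) are chosen for illustration and are not taken from the real DataTriple; `ActivationWrapperSketch` and `wrap_triple` are hypothetical, and `stump` and `target_fn` stand in for the main model stump and the mask scaling.

```python
from typing import Any, Callable, NamedTuple, Sequence, Tuple


class DataTripleSketch(NamedTuple):
    """Hypothetical stand-in for DataTriple, in train, test, validation order."""
    train: Sequence[Any]
    test: Sequence[Any]
    val: Sequence[Any]


class ActivationWrapperSketch:
    """Hypothetical wrapper: item i becomes (stump(image_i), target_fn(mask_i))."""

    def __init__(self, base: Sequence[Tuple[Any, Any]],
                 stump: Callable[[Any], Any],
                 target_fn: Callable[[Any], Any]) -> None:
        self.base = base
        self.stump = stump
        self.target_fn = target_fn

    def __len__(self) -> int:
        return len(self.base)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        image, mask = self.base[idx]
        # Input: activation map of the stump; ground truth: transformed mask.
        return self.stump(image), self.target_fn(mask)


def wrap_triple(triple: DataTripleSketch, stump: Callable[[Any], Any],
                target_fn: Callable[[Any], Any]) -> DataTripleSketch:
    """Wrap every split of the triple with the same stump and mask transform."""
    return DataTripleSketch(
        *(ActivationWrapperSketch(split, stump, target_fn) for split in triple)
    )
```

Because the result is a tuple, it can be unpacked directly, e.g. `train, test, val = wrap_triple(...)`.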