# Copyright (c) 2022 Continental Automotive GmbH
"""Base classes and helper functions for defining logical operations.
Main base classes are:
- :py:class:`Merge`: Base class for operating on arrays/tensors and booleans,
and for building computational trees of such operations
- :py:class:`TorchOrNumpyOperation`: Base :py:class:`Merge` class for operating
on numpy or pytorch tensors
- :py:class:`MergeBuilder`: A convenience builder class that allows to
define custom constructors for a merge class;
of interest for easily setting defaults
The logical merging operations derived from :py:class:`Merge` allow for
concatenation of operations. Using them, any operation involving
intersection (``AND``),
union (``OR``), and
inversion (``NOT``)
of masks can be modelled. Scalar values in this case are
treated as all-same-valued masks when mixed with mask tensors.
For further information have a look at the :py:class:`Merge` documentation.
"""
import abc
import inspect
from typing import Type, Union, Sequence, Optional, Any, Dict, Callable, Tuple, List, Mapping, Iterable, Set, MutableMapping, \
Collection, Literal
import numpy as np
import torch
from ...datasets.transforms.dict_transforms import DictTransform
from ...datasets.transforms.image_transforms import ToTensor
def stack_tensors(*inputs: torch.Tensor) -> torch.Tensor:
    """Broadcast and stack the inputs in dim 0 to enable pixel-wise operations.

    :param inputs: tensors that can be broadcast against each other
    :return: the broadcast inputs stacked along a new first dimension;
        an empty tensor if no inputs are given
    """
    if not inputs:
        # torch.stack needs a non-empty sequence of tensors, so the empty
        # case must be answered without calling it.
        return torch.tensor(inputs)
    return torch.stack(torch.broadcast_tensors(*inputs))
[docs]class Merge(DictTransform, abc.ABC):
"""Base class for operations and operation trees on dictionary inputs.
Merge the masks or scalars values of the dict input according to the
operation (tree) definition and store them under the specified output key.
The merge operation may recursively have child merge operations as
:py:attr:`in_keys`, which are evaluated on the given dictionary before
the parent is.
**Operation**
The actual operation is hidden in the :py:meth:`apply_to` method:
It is given a dictionary of annotations of the form ``{ID: value}`` and
will return the dict with the merged mask added as ``{out_key: value}``.
The intermediate outputs of child operations are by default only used
for caching (see :py:attr:`cache_duplicates`) and then discarded.
To include them into the final output, use the ``keep_keys`` argument
to the operation call (see :py:meth:`apply_to`).
The benefit of caching duplicates is that results may be reused
amongst different operations.
**Initialization**
During init, all non-keyword arguments serve as :py:attr:`in_keys`.
These are used when the merge operation is called on a dict:
The dict must provide items with these :py:attr:`in_keys`, and the values of
these items are fed to the actual operation.
Settings must be given as keyword arguments.
To set default keyword arguments for the init call, use a :py:class:`MergeBuilder`.
See :py:meth:`with_` for creating a :py:class:`MergeBuilder` from a :py:class:`Merge`
class.
**Example: Boolean Logic**
To get all heads, noses, and mouths (binary masks)
of real persons (binary masks)
in bathrooms (boolean labels),
call:
>>> from hybrid_learning.fuzzy_logic.tnorm_connectives.boolean import AND, OR, NOT, BooleanLogic
>>> op = AND("person", OR("head", "nose", "mouth"), NOT("bathroom"))
>>> op == BooleanLogic().parser()("person&&head||nose||mouth&&~bathroom")
True
>>> # Example with 1 pixel of a person mouth not in a bathroom:
>>> result = op({"person": 1, "head": 0, "nose": 0, "mouth": 1, "bathroom": False})
>>> result[op.out_key] == 1
True
>>> result
{'person': 1, 'head': 0, 'nose': 0, 'mouth': 1, 'bathroom': False,
'(head||mouth||nose)&&(~bathroom)&&person': 1}
To also inspect the intermediate output, use the ``keep_keys`` option:
>>> op({"person": 1, "head": 0, "nose": 0, "mouth": 1, "bathroom": False},
... keep_keys=op.all_out_keys)
{'person': 1, 'head': 0, 'nose': 0, 'mouth': 1, 'bathroom': False,
'head||mouth||nose': True,
'~bathroom': True,
'(head||mouth||nose)&&(~bathroom)&&person': 1}
Note that the input dict must feature all ``in_keys`` of operations in the formula.
**Subclassing**
To implement your own merge operation
- implement the :py:meth:`operation`
- specify your own :py:attr:`SYMB` (this must be unique within the logic you are using)
- extend the :py:attr:`settings` and :py:attr:`setting_defaults` properties by new items if necessary
**Format and String Parsing**
The (recursive) merge operation is best specified in conjunctive normal
form for uniqueness (and thus comparability) and parsing compatibility.
This is the form
.. code::
AND(..., [NOT(...), ...], [OR(..., [NOT(..), ...])])
(see https://en.wikipedia.org/wiki/Conjunctive_normal_form).
Exemplary available operations are the Boolean ones
:py:class:`~hybrid_learning.fuzzy_logic.tnorm_connectives.boolean.AND` (intersection),
:py:class:`~hybrid_learning.fuzzy_logic.tnorm_connectives.boolean.OR` (union), and
:py:class:`~hybrid_learning.fuzzy_logic.tnorm_connectives.boolean.NOT` (inversion)
that all operate pixel-wise.
Boolean classification labels are treated as all-one-masks.
One can use a :py:class:`~hybrid_learning.fuzzy_logic.logic_base.parsing.FormulaParser`
implementation to parse a string representation of an operation tree.
Check the corresponding implementation for the operator precedence and examples.
For parsing, used connectors of the logic must be encoded by their :py:attr:`SYMB`
attribute, e.g. for the examples above:
- ``AND``: a&&b
- ``OR``: a||b
- ``NOT`` (unary operation): ~a
"""
# Symbol used to render the operation in infix strings. Must be non-None once
# init has run: either a sub-class sets it, or ``symb`` is handed to init.
SYMB: Optional[str] = None
"""The string symbol of this class (override for sub-classes)."""
# Values <= 0 are not checked against the number of in_keys during init.
ARITY: int = -1
"""The arity of the operation. -1 means unlimited number of arguments possible."""
# If true, string representations sort the children before joining them.
IS_COMMUTATIVE: bool = False
"""Whether instances are equivalent to ones with permuted :py:attr:`in_keys`."""
[docs] @classmethod
def variadic_(cls, **kwargs):
"""Return an instance with variadic __call__.
It's __call__ will accept maps or iterables of arbitrary length
(for :py:attr:`ARITY` = -1) respectively of length matching :py:attr:`ARITY`.
All values/elements are passed through to the :py:meth:`operation`,
and the plain output of :py:meth:`operation` is returned
(see also :py:meth:`variadic_apply_to`).
Use this e.g. to wrap the :py:meth:`operation` into an object in a
multiprocessing-safe manner.
Example:
>>> from hybrid_learning.fuzzy_logic.tnorm_connectives.boolean import AND
>>> primitive_and = AND.variadic_()
>>> primitive_and({"a": 1, "b": True})
1
>>> primitive_and([1, True, 1.])
1
No :py:attr:`in_keys` may be given, and :py:attr:`out_key` is obsolete.
The returned instance may not be used as child element of a formula."""
return cls(**{**dict(_variadic=True), **kwargs})
@property
def is_variadic(self) -> bool:
"""Whether the instance is variadic.
See :py:meth:`variadic_`."""
return self._variadic
[docs] def __init__(self, *in_keys: Union[str, 'Merge'], out_key: str = None,
overwrite: bool = True, skip_none: bool = True,
replace_none=None, symb: str = None, cache_duplicates: bool = True,
keep_keys: Collection[str] = None,
_variadic: bool = False):
"""Init.
Hand over input keys either as str or as a Merge operation of str.
:param in_keys: sequence of either
:py:class:`~hybrid_learning.fuzzy_logic.logic_base.merge_operation.Merge` operation
instances or strings with placeholders for the input keys
:param out_key: key for the output of this operation; used to
init :py:attr:`~hybrid_learning.fuzzy_logic.logic_base.merge_operation.Merge.out_key`
:param overwrite: on call, whether to overwrite the value at
:py:attr:`~hybrid_learning.fuzzy_logic.logic_base.merge_operation.Merge.out_key`
in the given dict if the key already exists;
raise if key exists and ``overwrite`` is true; saved in
:py:attr:`~hybrid_learning.fuzzy_logic.logic_base.merge_operation.Merge.overwrite`.
:param replace_none: if not ``None``, the value to replace any
``None`` values with; see
:py:attr:`~hybrid_learning.fuzzy_logic.logic_base.merge_operation.Merge.replace_none`
:param symb: override the
:py:attr:`~hybrid_learning.fuzzy_logic.logic_base.merge_operation.Merge.SYMB`
for this instance
:param keep_keys: intermediate output keys to add to call output;
see :py:attr:`~hybrid_learning.fuzzy_logic.logic_base.merge_operation.Merge.keep_keys`
:param cache_duplicates: whether outputs of children with identical
keys should be cached and reused; see
:py:attr:`~hybrid_learning.fuzzy_logic.logic_base.merge_operation.Merge.cache_duplicates`
:param _variadic: the preferred way to specify this argument is
:py:meth:`~hybrid_learning.fuzzy_logic.logic_base.merge_operation.Merge.variadic_`;
see there for details
"""
# region Value checks
if not _variadic and len(in_keys) <= 0:
raise TypeError("Got empty list of in_keys for non-variadic operator!")
if not _variadic and 0 < self.ARITY > len(in_keys):
raise TypeError("Got too few in_keys ({}) for operation of class {} with arity {}: {}"
.format(len(in_keys), self.__class__.__name__, self.ARITY, in_keys))
if not _variadic and 0 < self.ARITY < len(in_keys):
raise TypeError("Got too many in_keys ({}) for operation of class {} with arity {}: {}"
.format(len(in_keys), self.__class__.__name__, self.ARITY, in_keys))
if _variadic and len(in_keys) != 0:
raise TypeError("Variadic instances do not accept in_keys. Either set variadic=True or give in_keys.")
for child in [c for c in in_keys if isinstance(c, Merge) and c.is_variadic]:
raise ValueError("Children operations of a formula may not be variadic, "
"but found variadic child operation {}.".format(repr(child)))
# endregion
if symb is not None:
self.SYMB = symb
if self.SYMB is None:
raise ValueError("SYMB attribute is None for object of class {}!".format(self.__class__) +
" Either set class attribute or specify during init via symb parameter.")
self._variadic: bool = _variadic
"""See :py:meth:`~hybrid_learning.fuzzy_logic.logic_base.merge_operation.Merge.is_variadic`."""
self.in_keys: Sequence[Union[str, 'Merge']] = in_keys
"""The keys of segmentation masks to unite in given order.
Keys are either constant strings or a merge operation."""
self.out_key: str = out_key or str(self)
"""The key to use to store the merge output in the annotations dict.
Take care to not accidentally overwrite existing keys
(cf. :py:attr:`overwrite`)."""
self.keep_keys: Optional[Collection[str]] = keep_keys
"""The keys of intermediate outputs in :py:attr:`all_out_keys` which should
be added to the return of a call.
Default (``None`` or empty collection): duplicate children outputs are cached
but not returned to save memory."""
self.overwrite: Union[bool, Literal['noop']] = overwrite
"""Whether to overwrite a value in the input dictionary when
applying this operation.
The operation is defined in :py:meth:`operation`.
The key that may be overwritten is stored in :py:attr:`out_key`.
An exception is raised if this is ``False`` and the key exists.
If set to ``'noop'`` and :py:attr:`out_key` is in the given
annotations dict, it is returned unchanged."""
self.skip_none: bool = skip_none
"""If set to ``True``, when a None input value is encountered simply
``None`` is returned. If ``False``, an error is raised."""
self.replace_none: Optional[Any] = replace_none
"""If not ``None``, any received ``None`` value is replaced by the
given value. This is done only for computation, the ``None`` value in the
received dict is left unchanged. Key-value pairs with ``None`` value may
come from the input or from child operations."""
self.cache_duplicates: bool = cache_duplicates
"""Whether to cache duplicate child operation outputs with duplicate out_key.
If set to false, all children and children children are evaluated and the
values of duplicate ``out_keys`` are evaluated several times and overwritten,
possibly leading to more computational time while using less memory.
Note that the order of children execution is determined by their order in
:py:attr:`in_keys`, depth first for nested operations."""
@property
def settings(self) -> Dict[str, Any]:
"""Settings to reproduce the instance.
(Mind that in_keys must be expanded! For direct reproduction use copy.)"""
return dict(in_keys=self.in_keys,
out_key=self.out_key,
overwrite=self.overwrite,
skip_none=self.skip_none,
replace_none=self.replace_none,
cache_duplicates=self.cache_duplicates,
keep_keys=self.keep_keys,
_variadic=self.is_variadic,
symb=self.SYMB)
@property
def setting_defaults(self):
"""Defaults used for :py:attr:`settings`."""
return dict(out_key=str(self), overwrite=True, skip_none=True, replace_none=None,
cache_duplicates=True, keep_keys=None, symb=self.__class__.SYMB,
_variadic=False)
@property
def pretty_op_symb(self) -> str:
"""Name of the operation symbol suitable for filenames etc."""
return self.__class__.__name__
[docs] def to_infix_notation(self, sort_key: Callable = None,
use_whitespace: bool = False,
use_pretty_op_symb: bool = False,
precedence: Sequence['Merge'] = None,
brackets: Tuple[str, str] = ('(', ')')) -> str:
"""Return an infix str encoding equal for differently sorted operations.
To define a custom sorting for children of commutative operations, hand over
the ``sort_key`` argument for the builtin ``sorted``.
If no ``precedence`` is given, brackets are set around all child operations.
:param sort_key: sort child operations by the given ``sort_key``
if the parent operation :py:attr:`IS_COMMUTATIVE`; defaults to alphabetical sorting
:param use_whitespace: separate infix operation symbols from their arguments by whitespace
:param use_pretty_op_symb: use the :py:attr:`pretty_op_symb` instead of :py:attr:`SYMB` for
representation of this operation instance
:param precedence: apply brackets according to the given ``precedence``;
if not given, assume this operation is in normal form (no brackets)
must be a list of :py:class:`Merge` operation classes or instances in
order of increasing precedence; their ``SYMB`` attribute is used to access
the operation symbol
:param brackets: tuple of the left and right bracket symbols to use if needed
"""
# Get pairs of (symbol, string_repr):
symbs_and_reprs = [
(key.SYMB, key.to_infix_notation(sort_key=sort_key, precedence=precedence,
use_whitespace=use_whitespace, use_pretty_op_symb=use_pretty_op_symb))
if isinstance(key, Merge) else (None, str(key))
for key in self.in_keys]
normalized_in_keys = self._set_brackets(symbs_and_reprs,
reference_symb=self.SYMB, precedence=precedence,
brackets=brackets)
if self.IS_COMMUTATIVE:
normalized_in_keys = sorted(normalized_in_keys, key=sort_key)
symb: str = self.pretty_op_symb if use_pretty_op_symb else self.SYMB
if len(normalized_in_keys) == 0:
return symb
if len(normalized_in_keys) == 1:
return f"{symb}{normalized_in_keys[0]}"
return (f' {symb} ' if use_whitespace else symb).join(normalized_in_keys)
[docs] def to_str(self, **infix_notation_kwargs) -> str:
"""Alias for :py:meth:`to_infix_notation`."""
return self.to_infix_notation(**infix_notation_kwargs)
[docs] def to_pretty_str(self, **infix_notation_kwargs) -> str:
"""Same as :py:meth:`to_str` but using pretty operation names suitable for
filenames etc."""
return self.to_str(**{**infix_notation_kwargs, **dict(use_pretty_op_symb=True)})
@staticmethod
def _set_brackets(symbs_and_str: Sequence[Tuple[str, str]],
reference_symb: str = None,
precedence: Sequence['Merge'] = None,
brackets: Tuple[str, str] = ('(', ')')) -> List[str]:
"""Join the strings from the symbol-string-tuples with brackets where needed wrt ``precedence``.
If no ``precedence`` is given or none is available for ``reference_symb``,
brackets are set around all strings with non-``None`` symbols (i.e. all but variables).
:param symbs_and_str: list tuples of the form ``(operation_symbol, operation_string_representation)``;
the string representations are to be joined, enclosing those in brackets that
have a lower precedence than ``reference_precedence``
:param reference_symb: set all operation strings in brackets that have a lower
precedence than that associated with the operation with ``reference_symb``;
should be set to the common parent operation symbol
:param precedence: list of Merge operation classes in order of increasing precedence;
their ``SYMB`` class or instance attribute is used to access the operation symbol
"""
symb_to_prec: Mapping[str, int] = {} if precedence is None else \
dict((precedence[i].SYMB, i) for i in range(len(precedence)))
if reference_symb not in symb_to_prec:
return [s if symb is None else f"{brackets[0]}{s}{brackets[1]}" for symb, s in symbs_and_str]
reference_prec: int = symb_to_prec[reference_symb]
# Apply brackets to string_repr for higher precedence symbols:
bracketed_str = [f"{brackets[0]}{p_str}{brackets[1]}"
if (symb is not None and symb_to_prec.get(symb, -1) <= reference_prec) else p_str
for symb, p_str in symbs_and_str]
return bracketed_str
[docs] def __str__(self):
return self.to_str()
[docs] def to_repr(self, settings: Dict[str, Any] = None, defaults: Dict[str, Any] = None,
sort_key: Callable = None, use_module_names: bool = False,
indent: str = None, indent_level: Optional[int]=None, indent_str: str = ' ',
indent_first_child: bool = None, _prepend_indent: bool = True) -> str:
"""Return str representation which can be used to reproduce and compare the instance.
.. warning::
Tautologies in the form of duplicate children are not filtered for now, e.g.
>>> from hybrid_learning.fuzzy_logic.tnorm_connectives.boolean import AND
>>> AND("a") == AND("a", "a")
False
Examples:
>>> from hybrid_learning.fuzzy_logic.tnorm_connectives.boolean import AND, OR
>>> obj = OR(AND("b", "c"), "a", symb="CustomAND", overwrite=False,)
>>> print(obj.to_repr())
OR('a', AND('b', 'c'), overwrite=False, symb='CustomAND')
>>> print(obj.to_repr(indent=True))
OR('a',
AND('b',
'c'),
overwrite=False, symb='CustomAND')
>>> print(obj.to_repr(indent_first_child=True))
OR(
'a',
AND(
'b',
'c'),
overwrite=False, symb='CustomAND')
>>> print(obj.to_repr(indent_level=1, indent_str='--'))
--OR('a',
----AND('b',
------'c'),
----overwrite=False, symb='CustomAND')
:param settings: the settings dict to include as key-value pairs;
defaults to :py:attr:`settings` (set e.g. to overwrite this method)
:param defaults: updates to :py:attr:`setting_defaults`; if a default for a key is given and
the value equals the default, it is excluded from printing
:param sort_key: sort child operations by the given ``sort_key``
if the parent operation :py:attr:`IS_COMMUTATIVE`; defaults to alphabetical sorting
:param use_module_names: whether to use both module plus class names or just the class names
:param indent: if not ``None``, print a tree-like view by putting each ``in_keys``
item in a new line with indent matching the class name length;
takes precedence over ``indent_level`` and ``indent_str`` arguments
:param indent_level: if not ``None`` and indent is ``None``, print a tree-like view
by putting each ``in_keys`` item in a new line with indent of ``indent_str``;
if >0, ``indent_level*indent_str`` is prepended to every printed line.
:param indent_str: the (whitespace) string representing one indentation level
:param indent_first_child: whether to already indent the first child or not
:param _prepend_indent: whether to prepend the given indent to the output string (default: yes)
"""
# Class name
class_name = self.__class__.__name__
if use_module_names:
class_name = f"{self.__class__.__module__}.{class_name}"
# Indentation settings
if (indent_first_child is not None and indent_level is None and indent is None) or indent is True: indent = ''
do_indent: bool = indent is not None or indent_level is not None
base_indent = indent if indent is not None else (indent_level * indent_str if indent_level else '')
key_indent = base_indent + (' ' * (len(class_name)+1) if indent is not None else (indent_str if indent_level is not None else ''))
child_sep = ',\n' + key_indent if do_indent else ', '
setting_first_sep = ',\n' + key_indent if do_indent else ', '
child_first_sep = '\n' + key_indent if (do_indent and indent_first_child) else ''
last_sep = '' # '\n' + base_indent
# Keys
key_reprs = [key.to_repr(sort_key=sort_key, use_module_names=use_module_names,
indent=key_indent if indent else None, _prepend_indent=False,
indent_first_child=indent_first_child,
indent_level=None if indent_level is None else indent_level+1, indent_str=indent_str)
if isinstance(key, Merge) else repr(key)
for key in self.in_keys]
if self.IS_COMMUTATIVE:
key_reprs = sorted(key_reprs, key=sort_key)
# Settings
defaults = {**self.setting_defaults, **(defaults or {})}
settings = settings or self.settings
setting_reprs = [f"{key}={f'{val.__module__}.{val.__name__}' if inspect.isclass(val) else repr(val)}"
for key, val in sorted(settings.items())
if key != 'in_keys' and (key not in defaults or defaults[key] != val)]
return ((base_indent if _prepend_indent else '') + class_name + "("
+ (child_first_sep + child_sep.join(key_reprs) if len(key_reprs) else '')
+ ((setting_first_sep if len(key_reprs) else '')
+ f", ".join(setting_reprs) if len(setting_reprs) else '')
+ last_sep + f")")
[docs] def __repr__(self) -> str:
"""Call :py:meth:`to_repr` without sorting."""
return self.to_repr(sort_key=lambda _: 1)
[docs] def __eq__(self, other: 'Merge') -> bool:
"""Two merge operations are considered equal, if their
normalized representations coincide. (See :py:meth:`to_repr`).
This means, they recursively have the same children up to commutation.
"""
if not isinstance(other, Merge):
return NotImplemented
return self.to_repr() == other.to_repr()
[docs] def __copy__(self) -> 'Merge':
"""Return a deep copy of self using settings."""
setts = self.settings
in_keys = [k.__copy__() if isinstance(k, Merge) else str(k)
for k in setts.pop('in_keys')]
return self.__class__(*in_keys, **setts)
[docs] def treerecurse_replace_keys(self, **replace_map: Dict[str, str]) -> 'Merge':
"""Return a new formula with all occurences of variables in ``replace_map`` replaced
and else identical settings.
The children of the new formula instance are new instances as well.
:param replace_map: mapping ``{old_var_name: new_var_name}``
"""
setts = self.settings
in_keys = [k.treerecurse_replace_keys(**replace_map) if isinstance(k, Merge)
else replace_map.get(str(k), str(k))
for k in setts.pop('in_keys')]
return self.__class__(*in_keys, **setts)
[docs] def treerecurse(self, fun: Callable[[Union['Merge', str]], Optional['Merge']]) -> 'Merge':
"""Apply the given function recursively to this and all children instances.
If ``fun`` returns ``None``, the operation is assumed to have been inline.
A non-``None`` return replaces the original root respectively ``in_keys`` item.
Acting root before children and depth first."""
fun_out: Optional[Merge] = fun(self)
curr_root = self if fun_out is None else fun_out
if isinstance(curr_root, Merge):
curr_root.in_keys = [k.treerecurse(fun) if isinstance(k, Merge) else fun(str(k))
for k in curr_root.in_keys]
return curr_root
[docs] def __call__(self, annotations: Union[Mapping[str, Any], Iterable],
keep_keys: Collection[str] = None
) -> Union[Mapping[str, Any], Any]:
"""Call method modifying a given dictionary."""
return self.apply_to(annotations, keep_keys=keep_keys)
[docs] def apply_to(self, annotations: Union[MutableMapping[str, Any], Iterable],
keep_keys: Collection[str] = None,
) -> Union[Mapping[str, Any], Any]:
"""Apply this operation to the ``annotations`` dict.
In case of a :py:meth:`variadic_` instance, also a plain iterable may be given,
see :py:meth:`variadic_apply_to` which is called in that case.
The operation of this instance is defined in :py:attr:`operation`.
First apply all child operations to the dict.
Hereby try to overwrite a value of annotations if its key correspond
to an :py:attr:`out_key` of a child operation, but do not create the
value of a key twice. Then apply :py:attr:`operation` on the
originally given and generated values now stored in ``annotations``
and store the result also in ``annotations``.
.. warning::
Annotations is inline updated.
Especially, the :py:attr:`out_key` and ``keep_keys`` items are added,
and children may apply inline operations to values!
:param annotations: dict to modify by adding values for :py:attr:`out_key` and ``keep_keys``
:param keep_keys: the output keys in :py:attr:`all_out_keys` for which
values shall be added to ``annotations`` in addition to :py:attr:`keep_keys`
:return: modified ``annotations`` dict, extended by the keys from
:py:attr:`all_out_keys` with the recursively generated values;
variadic instances return the plain output of :py:meth:`operation`
"""
if self.is_variadic:
return self.variadic_apply_to(annotations)
keep_keys: List[str] = [*(keep_keys or []), *(self.keep_keys or [])]
# region value checks
if not isinstance(annotations, MutableMapping):
raise TypeError(("Non-variadic instances of class {} only accept mutable mappings "
"as input to __call__, but got input of type {}")
.format(self.__class__, type(annotations)))
# About to overwrite a value without permission?
if self.out_key in annotations.keys():
if not self.overwrite:
raise KeyError(("out_key {} exists as key in given dict {}, and "
"overwrite is False")
.format(self.out_key, annotations))
elif self.overwrite == 'noop':
return annotations
# Any needed in_keys missing from annotations?
missing_keys: Set[str] = self.all_in_keys - annotations.keys()
if len(missing_keys) > 0:
raise ValueError(("Input keys {} for operation {} missing from "
"annotation keys {}")
.format(missing_keys, repr(self),
annotations.keys()))
# endregion
# region get and add children outputs
# collect from children besides direct output: keys needed for caching, keys in keep_keys
_seen = []
children_keep_keys: Sequence[str] = list({
*([key for key in self._all_out_keys_with_duplicates
if key in _seen or _seen.append(key)] if self.cache_duplicates else []),
*keep_keys})
# get children outputs needed for operation
children_results = dict(annotations)
keys_to_overwrite: Set[str] = self.all_out_keys.intersection(
children_results.keys())
for child_op in self.children:
# Output not yet created/existent?
if not self.cache_duplicates \
or child_op.out_key not in children_results.keys() \
or child_op.out_key in keys_to_overwrite:
children_results = child_op(children_results, keep_keys=children_keep_keys)
# Mark output as created.
if self.cache_duplicates and child_op.out_key in keys_to_overwrite:
keys_to_overwrite.remove(child_op.out_key)
# add children outputs marked for keeping
annotations.update({key: v for key, v in children_results.items()
if key in [*annotations.keys(), *keep_keys]})
# endregion
# region skip or fill None
# Any needed input is None?
if any(children_results[k] is None for k in self.operation_keys):
if self.skip_none: # Fill output with None
annotations[self.out_key] = None
return annotations
if self.replace_none is None:
raise ValueError("Received None values for keys {}"
.format([k for k in set(self.operation_keys) if annotations[k] is None]))
# endregion
# Finally execute operation:
op_inputs = (children_results[k] for k in self.operation_keys)
annotations[self.out_key] = self.operation([self.replace_none if v is None else v for v in op_inputs])
return annotations
[docs] def variadic_apply_to(self, annotations: Union[Mapping[str, Any], Iterable]) -> Any:
"""Return the result of operation on the values/items of a mapping or sequence
of arbitrary length. Performs ``None`` check/replacement and :py:attr:`ARITY` check.
In case of a :py:attr:`ARITY` of -1 and empty annotations list, or an annotations
list length not matching the arity, an :py:class:`IndexError` is raised."""
if isinstance(annotations, Mapping):
annotations: Sequence = list(annotations.values())
elif not isinstance(annotations, Sequence):
annotations: Sequence = list(annotations)
if self.ARITY != -1 and len(annotations) != self.ARITY:
raise IndexError("Length of the given annotations ({}) does not match ARITY ({})!"
.format(len(annotations), self.ARITY))
elif len(annotations) == 0:
raise IndexError("Empty annotations list provided!")
# region skip or replace None
if any(v is None for v in annotations):
if self.skip_none:
return None
if self.replace_none:
annotations = [self.replace_none if v is None else v for v in annotations]
else:
raise ValueError("Received None values in variadic input {}".format(annotations))
# endregion
return self.operation(annotations)
@property
def children(self) -> List['Merge']:
"""The input keys which are child operations.
Input keys are stored in :py:attr:`in_keys`"""
return [key for key in self.in_keys if isinstance(key, Merge)]
@property
def all_children(self) -> List['Merge']:
"""All children operations in the flattened computational tree, sorted depth first.
See :py:attr:`children` for getting only the direct children."""
direct_children: List['Merge'] = self.children
return [child for dchild in direct_children
for child in [dchild, *dchild.all_children]]
@property
def consts(self) -> Set[str]:
"""The constant string keys in the input keys.
The :py:attr:`in_keys` contains both the constant keys which are to be
directly found in a given annotations dictionary, and child
operations whose output is used. For getting the child operations
stored in :py:attr:`in_keys` refer to :py:attr:`children`.
Should preserve the order in which children occur in :py:attr:`in_keys`.
"""
return {key for key in self.in_keys if not isinstance(key, Merge)}
@property
def operation_keys(self) -> List[str]:
"""The list of keys used for this parent operation in original order
(constants and children output keys).
These are all :py:attr:`consts` and the :py:attr:`out_key` of all
:py:attr:`children` operations.
Keys may be duplicate as e.g. in
>>> from hybrid_learning.fuzzy_logic.tnorm_connectives.boolean import OR, NOT
>>> OR("a", NOT("b"), "a", NOT("c", out_key="not_c")).operation_keys
['a', '~b', 'a', 'not_c']
"""
return [key.out_key if isinstance(key, Merge) else key for key in self.in_keys]
@property
def all_in_keys(self) -> Set[str]:
"""All string input keys both of self and of all child operations.
(See :py:attr:`in_keys`.)
These are the keys that must be present in an annotation when called
on it. Should preserve the order in which keys and children occur in
:py:attr:`in_keys`.
"""
base_key_lists: List[List[str]] = \
[key.all_in_keys if isinstance(key, Merge) else [key]
for key in self.in_keys]
return {k for base_key_list in base_key_lists for k in base_key_list}
@property
def all_out_keys(self) -> Set[str]:
"""Output keys of self and all child operations.
(See :py:attr:`children`).
Should preserve the order in which children occur in :py:attr:`in_keys`.
"""
return set(self._all_out_keys_with_duplicates)
@property
def _all_out_keys_with_duplicates(self) -> List[str]:
"""Output keys of self and all child operations with duplicates.
(See :py:attr:`all_out_keys`).
"""
out_key_lists: List[Set[str]] = [c.all_out_keys for c in
self.children] + [{self.out_key}]
return [k for out_key_list in out_key_lists for k in out_key_list]
[docs] def operation(self, annotation_vals: Sequence) -> Any:
"""Actual merge operation on values of the input keys
in annotations. See :py:attr:`in_keys`.
The ``annotation_vals`` must not contain ``None`` values,
and their length must match the :py:attr:`ARITY` of this operation."""
raise NotImplementedError()
@classmethod
def with_(cls, **additional_args) -> 'MergeBuilder':
    """Wrap this class into a :py:class:`MergeBuilder` that uses the same
    symbol and hands the given additional arguments to each init call.

    Example usage (with changed symbol):

    >>> from hybrid_learning.fuzzy_logic.tnorm_connectives.boolean import AND
    >>> builder = AND.with_(skip_none=False, replace_none=0).symb_('&n&')
    >>> builder.SYMB
    '&n&'
    >>> builder("a", "b")
    AND('a', 'b', replace_none=0, skip_none=False, symb='&n&')

    :param additional_args: the default keyword arguments for the init call
    """
    builder = MergeBuilder(cls, symb=cls.SYMB, additional_args=additional_args)
    return builder
# Type of a class object deriving from Merge
_OpBuilder = Type[Merge]
# Array-like value: a pytorch tensor or a numpy array
_TensorType = Union[torch.Tensor, np.ndarray]
# Any value that is a python scalar or array-like
_NumericType = Union[bool, int, float, _TensorType]
class MergeBuilder:
    """Return a :py:class:`Merge` operation of specified class with additional settings upon call.

    Common additional init arguments can be specified and a new :py:attr:`SYMB`,
    overwriting the :py:attr:`Merge.SYMB` (or attaching a ``SYMB`` attribute to another builder).
    Attribute access is passed over to the :py:attr:`merge_class` specified.
    For easy instantiation see also :py:meth:`Merge.with_`."""

    def __init__(self, merge_class: "Union[_OpBuilder, Callable[..., Merge]]",
                 symb: Optional[str] = None,
                 additional_args: Optional[Dict[str, Any]] = None):
        """Init.

        :param merge_class: the class or builder (or other callable) to wrap
        :param symb: the symbol to use; if not given (or empty), the ``SYMB``
            of ``merge_class`` is used
        :param additional_args: keyword arguments to hand over on each call;
            the mapping is copied, so later changes to the builder do not
            affect the dictionary of the caller
        :raises AttributeError: if no ``symb`` is given and ``merge_class``
            has no ``SYMB`` attribute
        """
        # NOTE: merge_class must be set first, since __getattr__ relies on it.
        self.merge_class: Union[_OpBuilder, Callable[..., Merge]] = merge_class
        """The class or builder to use upon call."""
        # Copy to avoid that with_() mutates the dictionary owned by the caller.
        self._additional_args: Dict[str, Any] = dict(additional_args or {})
        """See :py:attr:`additional_args`."""
        self.SYMB: str = symb or merge_class.SYMB
        """The symbol representing the wrapped class."""

    def _wrapped_symb(self) -> Optional[str]:
        """Return the ``SYMB`` of :py:attr:`merge_class` or ``None`` if it has none."""
        return getattr(self.merge_class, 'SYMB', None)

    def _extend_type_error(self, err: TypeError, template: str,
                           args: tuple, kwargs: Dict[str, Any]) -> None:
        """Append a description of the failed call to the arguments of ``err``.

        Missing ``ARITY`` or ``__name__`` attributes of :py:attr:`merge_class`
        are tolerated, so that building the message never masks ``err``.
        """
        target = self.merge_class
        err.args = (*err.args,
                    template.format(target, getattr(target, 'ARITY', 'unknown'),
                                    len(args), getattr(target, '__name__', repr(target)),
                                    args, kwargs))

    @property
    def additional_args(self) -> Dict[str, Any]:
        """The additional arguments to hand over to the :py:class:`Merge` class on each call.

        Contains a ``symb`` entry only if :py:attr:`SYMB` differs from the
        ``SYMB`` of the wrapped :py:attr:`merge_class`.
        A new dictionary is returned on each access."""
        merged: Dict[str, Any] = dict(self._additional_args)
        if self.SYMB != self._wrapped_symb():
            merged['symb'] = self.SYMB
        return merged

    def symb_(self, symb: str) -> 'MergeBuilder':
        """Set SYMB and return self.
        Can be used in chain assignments."""
        self.SYMB = symb
        return self

    def with_(self, **additional_args: Any) -> 'MergeBuilder':
        """Update the additional arguments and return self."""
        self._additional_args.update(additional_args)
        return self

    def variadic_(self, *args, **kwargs) -> "Merge":
        """Return a variadic instance of the wrapped ``merge_class``.
        Calls the ``variadic_`` function of :py:attr:`merge_class`.
        The given ``kwargs`` will overwrite arguments from :py:attr:`additional_args`.

        :raises TypeError: if the wrapped call raises one; a description of
            the failed call is appended to the exception arguments
        """
        init_kwargs = {**self.additional_args, **kwargs}
        try:
            return self.merge_class.variadic_(*args, **init_kwargs)
        except TypeError as err:
            self._extend_type_error(
                err, "Building variadic instance of {} of arity {} with {} arguments failed: {}(*{}, **{})",
                args, init_kwargs)
            raise

    def __call__(self, *args, **kwargs) -> "Merge":
        """Build an instance of the specified Merge class with the additional args.
        The given ``kwargs`` will overwrite arguments from :py:attr:`additional_args`.

        :raises TypeError: if the wrapped call raises one; a description of
            the failed call is appended to the exception arguments
        """
        init_kwargs = {**self.additional_args, **kwargs}
        try:
            return self.merge_class(*args, **init_kwargs)
        except TypeError as err:
            self._extend_type_error(
                err, "The following init call to {} of arity {} with {} arguments failed: {}(*{}, **{})",
                args, init_kwargs)
            raise

    def __getattr__(self, k):
        """Pass attribute requests over to Merge class."""
        # Guard against infinite recursion as long as merge_class is not set
        # (e.g. on an instance created without running __init__).
        if 'merge_class' not in vars(self):
            raise AttributeError(f"{type(self).__name__!r} object has no attribute {k!r}")
        return getattr(self.merge_class, k)

    def __repr__(self):
        """Return a representation listing only the non-default settings."""
        if inspect.isclass(self.merge_class):
            merge_class_repr = f"{self.merge_class.__module__}.{self.merge_class.__name__}"
        else:
            merge_class_repr = repr(self.merge_class)
        parts = [merge_class_repr]
        if self.SYMB != self._wrapped_symb():
            parts.append(f"symb={self.SYMB!r}")
        if self._additional_args:
            parts.append(f"additional_args={dict(sorted(self._additional_args.items()))!r}")
        # Joining all parts at once avoids a dangling ", " if no settings are given.
        return f"{self.__class__.__name__}({', '.join(parts)})"
class TorchOperation(Merge, abc.ABC):
    """Merge operation base class working on pytorch tensors."""

    @staticmethod
    @abc.abstractmethod
    def torch_operation(*inputs: torch.Tensor) -> torch.Tensor:
        """The operation to apply to the inputs given as pytorch tensors.

        Implementations should support broadcasting where possible."""
        raise NotImplementedError()

    def operation(self, annotation_vals: Sequence) -> torch.Tensor:
        """Check the input count, turn inputs into tensors, and apply :py:meth:`torch_operation`.

        :param annotation_vals: the input values; each one is converted
            using ``torch.as_tensor``
        :return: the output of :py:meth:`torch_operation`
        :raises TypeError: if fewer than :py:attr:`ARITY` values are given, or,
            for positive :py:attr:`ARITY`, more than that
        """
        arity = self.ARITY
        num_given = len(annotation_vals)
        too_few = num_given < arity
        # An upper limit is only enforced for positive arity.
        too_many = arity > 0 and num_given > arity
        if too_few or too_many:
            raise TypeError("Operation {} of type {} and arity {} was called with {} inputs:\n{}"
                            .format(self, type(self), arity, num_given, annotation_vals))
        # NOTE(review): a negative arity presumably marks variadic operations -- confirm.
        limit = num_given if arity < 0 else arity
        tensors: List[torch.Tensor] = [torch.as_tensor(val)
                                       for val in annotation_vals[:limit]]
        return self.torch_operation(*tensors)
class TorchOrNumpyOperation(TorchOperation, abc.ABC):
    """Merge operation base class offering a torch and a numpy implementation.

    The implementation is chosen by the types of the provided annotations:
    As soon as one of them is a torch tensor, the torch operation is applied
    (and a torch tensor returned); in all other cases the numpy operation.
    """

    @staticmethod
    @abc.abstractmethod
    def numpy_operation(*inputs: np.ndarray) -> np.ndarray:
        """The operation to apply to Booleans, numpy arrays and numbers.

        Implementations should support broadcasting where possible."""
        raise NotImplementedError()

    def operation(self, annotation_vals: Sequence) -> _NumericType:
        """Dispatch to the torch or the numpy operation depending on the input types.

        :param annotation_vals: the input values
        :return: the output of ``torch_operation`` if any input is a torch
            tensor (all inputs are then passed through ``ToTensor.to_tens``),
            else the output of :py:meth:`numpy_operation` on the unchanged inputs
        """
        has_tensor = any(isinstance(val, torch.Tensor) for val in annotation_vals)
        if not has_tensor:
            return self.numpy_operation(*annotation_vals)
        return self.torch_operation(*[ToTensor.to_tens(val) for val in annotation_vals])