exact_hessian

hybrid_learning.concepts.train_eval.hessian.exact_hessian(fn, params, device=None)

Compute all second derivatives of a scalar with respect to the given parameters.

Rows and columns of the Hessian follow the order obtained by flattening each tensor in params to one dimension and concatenating the results in iteration order.
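The ordering can be illustrated with a minimal sketch, assuming the common double-backward technique: differentiate fn once with `create_graph=True`, flatten and concatenate the gradients in the order of params, then differentiate each gradient entry again to obtain one Hessian row. The function `hessian_sketch` and its helper `_flatten` are hypothetical names; this is not the library's own implementation.

```python
from typing import Iterable, List, Optional, Sequence, Union

import torch


def _flatten(tensors: Sequence[Optional[torch.Tensor]], params: List[torch.Tensor]) -> torch.Tensor:
    """Flatten each tensor to 1-D and concatenate; None (unused param) becomes zeros."""
    return torch.cat([
        (t if t is not None else torch.zeros_like(p)).reshape(-1)
        for t, p in zip(tensors, params)
    ])


def hessian_sketch(
    fn: torch.Tensor,
    params: Iterable[torch.Tensor],
    device: Optional[Union[str, torch.device]] = None,
) -> torch.Tensor:
    """Hypothetical sketch: Hessian of the scalar fn w.r.t. the flattened, concatenated params."""
    params = list(params)
    # First derivatives; create_graph=True keeps them differentiable.
    grads = torch.autograd.grad(fn, params, create_graph=True, allow_unused=True)
    # Flatten and concatenate in the iteration order of params.
    flat_grad = _flatten(grads, params)
    n = flat_grad.numel()
    target = device if device is not None else flat_grad.device
    hessian = torch.zeros(n, n, dtype=flat_grad.dtype, device=target)
    if not flat_grad.requires_grad:
        # The gradient does not depend on params, so all second derivatives are zero.
        return hessian
    for i in range(n):
        # Row i: derivatives of the i-th gradient entry w.r.t. all parameters.
        row = torch.autograd.grad(flat_grad[i], params, retain_graph=True, allow_unused=True)
        hessian[i] = _flatten(row, params).to(target)
    return hessian
```

Each row costs one extra backward pass, so a model with n parameter entries needs n backward passes and an n-by-n matrix in memory; this is only practical for small parameter counts.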

Parameters
  • fn (Tensor) – Scalar PyTorch tensor, i.e. the output of the function to be differentiated.

  • params (Iterable[Tensor]) – iterable containing all tensors that act as variables of fn

  • device (Optional[Union[str, device]]) – the torch device to use for the Hessian

Returns

Hessian of fn with respect to the concatenation of all flattened tensors in params; a square matrix whose side length is the total number of elements in params

Return type

Tensor

Note

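The tensors in params are all flattened and concatenated into one large vector \(\theta\). The returned matrix is \(d^2 E / d\theta^2\) with entries \(H_{ij} = \partial^2 E / (\partial\theta_i \, \partial\theta_j)\), where \(E\) denotes fn.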
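To make the role of \(\theta\) concrete, the following sketch builds \(\theta\) by flattening and concatenating two tensors and then uses `torch.autograd.functional.hessian` on a function of \(\theta\). The scalar function `energy`, its fixed input `x`, and the values of `W` and `b` are hypothetical illustration choices. The resulting matrix is indexed exactly like \(\theta\).

```python
import torch
from torch.autograd.functional import hessian

# Two parameter tensors with different shapes (hypothetical example values).
W = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
b = torch.tensor([0.5, -1.0])
shapes = [W.shape, b.shape]
sizes = [W.numel(), b.numel()]

# theta: every tensor flattened to 1-D, then concatenated in list order.
# Layout: [W[0,0], W[0,1], W[1,0], W[1,1], b[0], b[1]]
theta = torch.cat([W.reshape(-1), b.reshape(-1)])


def energy(theta: torch.Tensor) -> torch.Tensor:
    """Hypothetical scalar E, evaluated after unflattening theta."""
    w_flat, b_flat = torch.split(theta, sizes)
    w_mat = w_flat.reshape(shapes[0])
    bias = b_flat.reshape(shapes[1])
    x = torch.tensor([1.0, -2.0])
    return ((w_mat @ x + bias) ** 2).sum()


H = hessian(energy, theta)
print(H.shape)  # torch.Size([6, 6])
```

Because this E is quadratic in \(\theta\), its Hessian does not depend on the values of W and b. For example, `H[1, 4]`, the mixed derivative with respect to `W[0, 1]` and `b[0]`, equals 2 * x[1] = -4.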

Related work: The code is a modified version of https://discuss.pytorch.org/t/compute-the-hessian-matrix-of-a-network/15270/3