train_eval_funs

Description

Basic functions for training and evaluating models, with support for callbacks.

Functions

device_of(model)

Return the device of the given PyTorch model.
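As a rough illustration, a device lookup like this can be sketched by inspecting the model's tensors. The name device_of_sketch is hypothetical. The sketch assumes all parameters share one device, which holds for a model moved with model.to(device) but not for a model split across devices. It is not necessarily how device_of is implemented.

```python
import torch
from torch import nn


def device_of_sketch(model: nn.Module) -> torch.device:
    """Hypothetical sketch: report the device of the model's first tensor.

    Assumes all parameters live on a single device. Falls back to buffers,
    then to CPU, for models that have no parameters.
    """
    for tensor in model.parameters():
        return tensor.device
    for tensor in model.buffers():
        return tensor.device
    return torch.device("cpu")
```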

evaluate(model, kpi_fns, val_loader[, ...])

Evaluate the model with respect to the loss and the KPI functions (kpi_fns) on the evaluation data from val_loader.
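A minimal sketch of such an evaluation loop follows. It assumes, hypothetically, that kpi_fns is a dict mapping names to callables fn(outputs, targets) that return a scalar tensor, and that val_loader yields (inputs, targets) pairs. evaluate_sketch is an illustrative name, not the module's API. Each KPI is averaged with batch-size weights, so a mean-reduced loss matches its value over the whole dataset even when the last batch is smaller.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


@torch.no_grad()  # no gradients are needed for evaluation
def evaluate_sketch(model, kpi_fns, val_loader):
    """Hypothetical sketch: batch-size-weighted average of each KPI over a loader."""
    was_training = model.training
    model.eval()  # disable dropout, use running batch-norm statistics
    device = next(model.parameters()).device
    totals = {name: 0.0 for name in kpi_fns}
    n_seen = 0
    for inputs, targets in val_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model(inputs)
        batch_size = targets.shape[0]
        for name, fn in kpi_fns.items():
            # Weight each batch's mean KPI by its size before averaging.
            totals[name] += fn(outputs, targets).item() * batch_size
        n_seen += batch_size
    model.train(was_training)  # restore the caller's train/eval mode
    return {name: total / n_seen for name, total in totals.items()}
```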

loader([data, batch_size, shuffle, device, ...])

Prepare and return a PyTorch DataLoader with device-dependent multiprocessing settings.
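The specific settings that loader chooses are not documented here. The sketch below shows one plausible device-dependent policy, which is an assumption rather than the module's actual behavior: worker processes and pinned memory are enabled only when the target device is CUDA, and data is loaded in the main process on CPU. loader_sketch and its cap of 4 workers are hypothetical.

```python
import os

import torch
from torch.utils.data import DataLoader, Dataset, TensorDataset


def loader_sketch(data: Dataset, batch_size: int = 32, shuffle: bool = False,
                  device: str = "cpu") -> DataLoader:
    """Hypothetical sketch of device-dependent DataLoader settings."""
    use_cuda = torch.device(device).type == "cuda"
    # Assumed policy: background workers only pay off when feeding a GPU.
    num_workers = min(4, os.cpu_count() or 1) if use_cuda else 0
    return DataLoader(
        data,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=use_cuda,  # page-locked host memory speeds up host-to-GPU copies
        persistent_workers=num_workers > 0,  # keep workers alive across epochs
    )
```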

predict_laplace(model, data[, device, var0])

Perform prediction using the probit approximation of the Bayesian posterior predictive distribution.
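For context, in the binary case the probit approximation replaces the intractable integral of the sigmoid over a Gaussian-distributed logit (mean mu, variance s^2) with sigmoid(mu / sqrt(1 + pi * s^2 / 8)). A common multi-class extension divides each mean logit mu_k by sqrt(1 + pi * s_k^2 / 8) and applies a softmax. The sketch below, with the hypothetical name probit_predict, assumes the logit means and per-class variances are already available, for example from a Laplace approximation. How predict_laplace obtains them, and the role of var0, are not shown here.

```python
import math


def probit_predict(mean_logits, var_logits):
    """Probit-approximated predictive class probabilities (multi-class form).

    Each mean logit is scaled by 1 / sqrt(1 + pi/8 * variance), then a
    numerically stable softmax is applied to the scaled logits.
    """
    scaled = [m / math.sqrt(1.0 + math.pi / 8.0 * v)
              for m, v in zip(mean_logits, var_logits)]
    top = max(scaled)  # subtract the max before exponentiating for stability
    exps = [math.exp(s - top) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]
```

With zero variance this reduces to the ordinary softmax of the mean logits. With equal variance on every class, larger variance flattens the predicted distribution.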

second_stage_train(model, nll_fn, ...)

Run a second training stage on the model using the negative log-likelihood function nll_fn.

train_one_epoch(model, loss_fn, metric_fns, ...)

Train the model for one epoch, evaluate it, and return the training history and test results.
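A compact sketch of one such epoch, run on CPU for brevity, follows. The optimizer, train_loader, test_loader and callbacks arguments, the callback signature cb(batch_index, loss_value), and the name train_one_epoch_sketch are all assumptions for illustration. The real function's remaining parameters are not reproduced here.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def train_one_epoch_sketch(model, loss_fn, metric_fns, optimizer,
                           train_loader, test_loader, callbacks=()):
    """Hypothetical sketch: one pass over train_loader, then evaluation.

    Returns (history, test_results): one training-loss value per batch and
    one batch-size-weighted average per metric in metric_fns.
    """
    model.train()
    history = []
    for i, (inputs, targets) in enumerate(train_loader):
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        optimizer.step()
        value = loss.item()
        history.append(value)
        for cb in callbacks:  # assumed hook: called after every optimizer step
            cb(i, value)

    model.eval()
    totals = {name: 0.0 for name in metric_fns}
    n_seen = 0
    with torch.no_grad():
        for inputs, targets in test_loader:
            outputs = model(inputs)
            for name, fn in metric_fns.items():
                totals[name] += fn(outputs, targets).item() * len(targets)
            n_seen += len(targets)
    test_results = {name: total / n_seen for name, total in totals.items()}
    return history, test_results
```

Keeping the per-batch history separate from the end-of-epoch test results lets a caller plot training curves without re-running evaluation.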