ResettableOptimizer

class hybrid_learning.concepts.train_eval.base_handles.resettable_optimizer.ResettableOptimizer(optim_type, lr_scheduler_type=None, lr_kwargs=None, batch_update_lr=False, **optim_kwargs)[source]

Bases: object

Wrapper around torch optimizers to enable reset and automatic learning rate handling. Saves the optimizer/learning rate scheduler initialization arguments other than the parameters/optimizer. Replace

opt = OptimizerType(params, **opt_init_args)
# optionally lr_scheduler:
lr_scheduler = LRSchedulerType(opt, **lr_init_args)
for epoch in epochs:
    for batch in dataset:
        opt.zero_grad()
        ...
        opt.step()
        lr_scheduler.step()  # if batch-wise update
    lr_scheduler.step()  # if epoch-wise update

with

opt_handle = ResettableOptimizer(OptimizerType,
                                 LRSchedulerType, lr_init_args,
                                 # if batch-wise lr updates:
                                 batch_update_lr=True,
                                 **opt_init_args)
...
opt = opt_handle(params)
for epoch in epochs:
    for batch in dataset:
        opt.zero_grad()
        ...
        opt.step()
    opt.epoch_end()

A cast as in torch.optim.Optimizer.cast() is currently not supported.
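Since the initialization arguments are stored, one handle can hand out fresh optimizer instances for several independent training runs, e.g. cross-validation folds. A minimal sketch using only the documented handle methods; build_model, fold_loader, loss_fn, num_folds, and num_epochs are placeholders, not part of this module:

import torch
from hybrid_learning.concepts.train_eval.base_handles.resettable_optimizer import ResettableOptimizer

opt_handle = ResettableOptimizer(torch.optim.SGD, lr=0.01, momentum=0.9)
for fold in range(num_folds):
    model = build_model()                 # fresh model per run (placeholder)
    opt = opt_handle(model.parameters())  # fresh optimizer with the saved settings
    for epoch in range(num_epochs):
        for inputs, targets in fold_loader(fold):  # placeholder data loader
            opt.zero_grad()
            loss = loss_fn(model(inputs), targets)  # placeholder loss
            loss.backward()
            opt.step()
        opt.epoch_end()
    opt_handle.reset()  # drop optimizer/scheduler; re-init needed before next use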

Public Data Attributes:

lr

The default learning rate of the optimizer (the starting value for the lr scheduler).

settings

Return a nice dict representation of the init args and the optimizer type.

Public Methods:

reset()

Reset optimizer (and lr scheduler); requires init before next use of optimizer.

init(params)

Initialize optimizer and learning rate scheduler with given parameters.

zero_grad()

Wrapper around torch.optim.Optimizer.zero_grad().

add_param_group(param_group)

Wrapper around torch.optim.Optimizer.add_param_group().

step([closure, epoch])

Step the optimizer and the learning rate scheduler if set.

epoch_end()

Update of learning rate scheduler after an epoch.

Special Methods:

__init__(optim_type[, lr_scheduler_type, ...])

Init.

__call__(params)

Return a fresh instance of the optimizer with the saved settings.

__repr__()

Return repr(self).

__str__()

Return str(self).


__call__(params)[source]

Return a fresh instance of the optimizer with the saved settings. This wraps the call to _optim_type such that an instance of ResettableOptimizer can replace the type of its target optimizer.

Parameters

params (Union[Iterable[Tensor], Iterable[Dict[str, Tensor]]]) – the model parameters (or dicts defining parameter groups) to optimize

Return type

ResettableOptimizer
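Because of this, the handle itself can stand in wherever a callable optimizer builder (e.g. the optimizer class) is expected. A short sketch under that assumption; fit_one_step is a hypothetical helper, not part of this module:

import torch
from hybrid_learning.concepts.train_eval.base_handles.resettable_optimizer import ResettableOptimizer

opt_handle = ResettableOptimizer(torch.optim.Adam, lr=1e-3)
model = torch.nn.Linear(4, 2)

def fit_one_step(model, optimizer_type):
    # optimizer_type may be a torch optimizer class or a ResettableOptimizer handle
    opt = optimizer_type(model.parameters())
    opt.zero_grad()
    loss = model(torch.randn(8, 4)).sum()  # placeholder loss on random data
    loss.backward()
    opt.step()
    return opt

opt = fit_one_step(model, optimizer_type=opt_handle)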

__init__(optim_type, lr_scheduler_type=None, lr_kwargs=None, batch_update_lr=False, **optim_kwargs)[source]

Init.

Parameters
  • optim_type (Callable[[...], torch.optim.Optimizer]) – a callable that yields an optimizer when called with model parameters and the optim_kwargs

  • optim_kwargs – keyword arguments for creating the optimizer

  • lr_scheduler_type (Optional[Callable]) – Optional; if given, a callable that yields a learning rate scheduler to which the optimizer is registered

  • lr_kwargs (Optional[Dict[str, Any]]) – keyword arguments for the optional learning rate scheduler

  • batch_update_lr (bool) – whether to step the learning rate scheduler after each batch, i.e. on each step() call, rather than only on epoch_end(); set this e.g. for torch.optim.lr_scheduler.CyclicLR (see the construction sketch after this list)
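A construction sketch assuming standard torch optimizer and scheduler types (Adam with StepLR for epoch-wise updates, SGD with CyclicLR for batch-wise updates); the keyword values are arbitrary illustrations:

import torch
from hybrid_learning.concepts.train_eval.base_handles.resettable_optimizer import ResettableOptimizer

# Epoch-wise scheduler: the lr scheduler is stepped on epoch_end()
opt_handle = ResettableOptimizer(
    torch.optim.Adam,                                    # optim_type
    lr_scheduler_type=torch.optim.lr_scheduler.StepLR,
    lr_kwargs=dict(step_size=10, gamma=0.1),
    lr=1e-3, weight_decay=1e-4,                          # forwarded as **optim_kwargs
)

# Batch-wise scheduler, e.g. CyclicLR: the lr scheduler is stepped on every step()
cyclic_handle = ResettableOptimizer(
    torch.optim.SGD,
    lr_scheduler_type=torch.optim.lr_scheduler.CyclicLR,
    lr_kwargs=dict(base_lr=1e-4, max_lr=1e-2),
    batch_update_lr=True,
    lr=1e-4, momentum=0.9,
)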

__repr__()[source]

Return repr(self).

__str__()[source]

Return str(self).

add_param_group(param_group)[source]

Wrapper around torch.optim.Optimizer.add_param_group(). Only call after init() or __call__().
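For instance, parameters that only become trainable later (say, a newly attached head) can be registered with the running optimizer, analogous to plain torch optimizers; the modules below are illustrative placeholders:

import torch
from hybrid_learning.concepts.train_eval.base_handles.resettable_optimizer import ResettableOptimizer

opt_handle = ResettableOptimizer(torch.optim.SGD, lr=1e-2)
model = torch.nn.Linear(4, 2)
opt = opt_handle(model.parameters())  # add_param_group may only be called after this

new_head = torch.nn.Linear(2, 1)      # e.g. a module attached after training started
opt.add_param_group({"params": list(new_head.parameters()), "lr": 1e-3})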

epoch_end()[source]

Update of the learning rate scheduler after an epoch. Only specific learning rate schedulers require updating after the epoch.

init(params)[source]

Initialize optimizer and learning rate scheduler with given parameters.

Parameters

params (Union[Iterable[Tensor], Iterable[Dict[str, Tensor]]]) – the model parameters (or dicts defining parameter groups) to optimize

reset()[source]

Reset optimizer (and lr scheduler); requires init before next use of optimizer.

step(closure=None, epoch=None)[source]

Step the optimizer and the learning rate scheduler if set.
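Assuming the closure argument is simply forwarded to the wrapped optimizer's step() (as for plain torch optimizers), closure-based optimizers such as torch.optim.LBFGS can be driven through the handle. A minimal sketch with placeholder data:

import torch
from hybrid_learning.concepts.train_eval.base_handles.resettable_optimizer import ResettableOptimizer

opt_handle = ResettableOptimizer(torch.optim.LBFGS, lr=0.1)
model = torch.nn.Linear(4, 1)
opt = opt_handle(model.parameters())

inputs, targets = torch.randn(8, 4), torch.randn(8, 1)  # placeholder data

def closure():
    opt.zero_grad()
    loss = torch.nn.functional.mse_loss(model(inputs), targets)
    loss.backward()
    return loss

opt.step(closure)  # LBFGS re-evaluates the closure as needed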

zero_grad()[source]

Wrapper around torch.optim.Optimizer.zero_grad(). Only call after init() or __call__().

_lr_kwargs: Dict[str, Any]

The arguments to _lr_scheduler_type besides the optimizer. Only used if _lr_scheduler_type is set.

_lr_scheduler_type: Optional[Callable[..., Any]]

Optional learning rate scheduler builder. Used in init() to get new scheduler.

_optim_kwargs: Dict[str, Any]

The arguments (besides the parameters key) to _optim_type to get new optimizer.

_optim_type: Callable[..., Optimizer]

Type/builder for optimizer. Used in init() to get new optimizer.

batch_update_lr: bool

Whether to call step on the learning rate scheduler after each batch or only after each epoch. Batch-wise updates are handled in step(), epoch-wise ones in epoch_end().

property lr: float

The default learning rate of the optimizer (the starting value for the lr scheduler).

lr_scheduler: Optional

Reference to the optional current learning rate scheduler. None if _lr_scheduler_type is None, as well as after reset() and before init().

optimizer: Optional[Optimizer]

Reference to the current optimizer. None after reset() and before init().

property settings: Dict

Return a nice dict representation of the init args and the optimizer type.