ResettableOptimizer
- class hybrid_learning.concepts.train_eval.base_handles.resettable_optimizer.ResettableOptimizer(optim_type, lr_scheduler_type=None, lr_kwargs=None, batch_update_lr=False, **optim_kwargs)[source]
Bases: object
Wrapper around torch optimizers that enables reset and automatic learning rate handling. The initialization arguments of the optimizer and the learning rate scheduler are saved, except for the model parameters and the optimizer instance respectively. Replace
    opt = OptimizerType(params, **opt_init_args)
    # optionally lr_scheduler:
    lr_scheduler = LRSchedulerType(opt, **lr_init_args)
    for epoch in epochs:
        for batch in dataset:
            opt.zero_grad()
            ...
            opt.step()
            lr_scheduler.step()  # if batch-wise update
        lr_scheduler.step()  # if epoch-wise update
with
    opt_handle = ResettableOptimizer(OptimizerType,
                                     LRSchedulerType, lr_init_args,
                                     # if batch-wise lr updates:
                                     batch_update_lr=True)
    ...
    opt = opt_handle(params)
    for epoch in epochs:
        for batch in dataset:
            opt.zero_grad()
            ...
            opt.step()
        opt.epoch_end()
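For illustration, a concrete setup could look as follows (a minimal sketch, not taken from the library's own examples; torch.optim.Adam is just one possible optimizer type, and the toy model, data and learning rate are placeholder assumptions):

    import torch
    from hybrid_learning.concepts.train_eval.base_handles.resettable_optimizer \
        import ResettableOptimizer

    # The handle stores the optimizer settings once ...
    opt_handle = ResettableOptimizer(torch.optim.Adam, lr=1e-3)
    data_loader = [(torch.randn(16, 10), torch.randn(16, 2))]  # toy stand-in data

    for fold in range(3):                         # e.g. cross-validation folds
        model = torch.nn.Linear(10, 2)            # toy stand-in model
        opt = opt_handle(model.parameters())      # ... and yields a fresh optimizer per run
        for inputs, targets in data_loader:
            opt.zero_grad()
            loss = torch.nn.functional.mse_loss(model(inputs), targets)
            loss.backward()
            opt.step()
        opt_handle.reset()                        # drop optimizer state before the next fold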
A cast as in torch.optim.Optimizer.cast() is currently not supported.

Public Data Attributes:

lr: The used learning rate default of the optimizer (starting value for the lr scheduler).
settings: Return nice dict representation of init args and optimizer type.

Public Methods:

reset(): Reset optimizer (and lr scheduler); requires init before next use of optimizer.
init(params): Initialize optimizer and learning rate scheduler with given parameters.
zero_grad(): Wrapper around torch.optim.Optimizer.zero_grad().
add_param_group(param_group): Wrapper around torch.optim.Optimizer.add_param_group().
step([closure, epoch]): Step the optimizer and the learning rate scheduler if set.
epoch_end(): Update of learning rate scheduler after an epoch.

Special Methods:

__init__(optim_type[, lr_scheduler_type, ...]): Init.
__call__(params): Return a fresh instance of the optimizer with the saved settings.
__repr__(): Return repr(self).
__str__(): Return str(self).
- __call__(params)[source]
Return a fresh instance of the optimizer with the saved settings. Intended to wrap the call to _optim_type, such that an instance of a resettable optimizer can replace the type of its target optimizer.
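Thus, wherever a training routine expects an optimizer type or factory, the handle can be passed instead (a minimal sketch; fit_one_epoch, the toy model and data are hypothetical, and torch.optim.SGD is just one possible target type):

    import torch
    from functools import partial
    from hybrid_learning.concepts.train_eval.base_handles.resettable_optimizer \
        import ResettableOptimizer

    def fit_one_epoch(model, data, optimizer_type):
        """Hypothetical helper that only receives an optimizer type or factory."""
        opt = optimizer_type(model.parameters())
        for inputs, targets in data:
            opt.zero_grad()
            loss = torch.nn.functional.mse_loss(model(inputs), targets)
            loss.backward()
            opt.step()

    model = torch.nn.Linear(4, 1)
    data = [(torch.randn(8, 4), torch.randn(8, 1))]

    # A plain optimizer factory and a ResettableOptimizer handle are interchangeable here:
    fit_one_epoch(model, data, partial(torch.optim.SGD, lr=0.01))
    fit_one_epoch(model, data, ResettableOptimizer(torch.optim.SGD, lr=0.01))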
- __init__(optim_type, lr_scheduler_type=None, lr_kwargs=None, batch_update_lr=False, **optim_kwargs)[source]
Init.
- Parameters
  - optim_type (Callable[..., torch.optim.Optimizer]) – a callable that yields an optimizer when called with model parameters and the optim_kwargs
  - optim_kwargs – keyword arguments for creating the optimizer
  - lr_scheduler_type (Optional[Callable]) – Optional; if given, a callable that yields a learning rate scheduler at which to register the optimizer
  - lr_kwargs (Optional[Dict[str, Any]]) – arguments for the optional learning rate scheduler
  - batch_update_lr (bool) – whether to step the learning rate scheduler after each batch, i.e. on each step() call, or only on epoch_end() calls; set this e.g. for torch.optim.lr_scheduler.CyclicLR
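For instance, a handle for SGD with a cyclic learning rate could be configured as follows (a minimal sketch; the concrete learning rate bounds and momentum value are arbitrary assumptions):

    import torch
    from hybrid_learning.concepts.train_eval.base_handles.resettable_optimizer \
        import ResettableOptimizer

    # CyclicLR should be stepped after every batch, hence batch_update_lr=True:
    opt_handle = ResettableOptimizer(
        torch.optim.SGD,
        lr_scheduler_type=torch.optim.lr_scheduler.CyclicLR,
        lr_kwargs=dict(base_lr=1e-4, max_lr=1e-2),
        batch_update_lr=True,
        lr=1e-4, momentum=0.9,  # collected into optim_kwargs for SGD
    )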
- add_param_group(param_group)[source]
Wrapper around torch.optim.Optimizer.add_param_group(). Only call after init() or __call__().
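A possible use is registering the parameters of a module that was added after the optimizer was set up (a minimal sketch; backbone, new_head and the learning rates are hypothetical):

    import torch
    from hybrid_learning.concepts.train_eval.base_handles.resettable_optimizer \
        import ResettableOptimizer

    backbone = torch.nn.Linear(8, 16)   # hypothetical existing module
    new_head = torch.nn.Linear(16, 4)   # hypothetical module added later

    opt_handle = ResettableOptimizer(torch.optim.Adam, lr=1e-3)
    opt = opt_handle(backbone.parameters())   # initialize via __call__()
    # Afterwards, further parameters may be registered, here with their own lr:
    opt.add_param_group({"params": new_head.parameters(), "lr": 1e-4})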
- epoch_end()[source]
Update of the learning rate scheduler after an epoch. Only specific learning rate schedulers require updating after the epoch.
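With the default batch_update_lr=False, the scheduler only advances on epoch_end() calls, which suits epoch-wise schedulers such as torch.optim.lr_scheduler.StepLR (a minimal sketch; the toy model, data and scheduler settings are placeholder assumptions):

    import torch
    from hybrid_learning.concepts.train_eval.base_handles.resettable_optimizer \
        import ResettableOptimizer

    model = torch.nn.Linear(10, 2)                        # toy stand-in model
    data = [(torch.randn(16, 10), torch.randn(16, 2))]    # toy stand-in data

    opt_handle = ResettableOptimizer(
        torch.optim.Adam,
        lr_scheduler_type=torch.optim.lr_scheduler.StepLR,
        lr_kwargs=dict(step_size=10, gamma=0.1),          # decay lr every 10 epochs
        lr=1e-3,
    )
    opt = opt_handle(model.parameters())
    for epoch in range(30):
        for inputs, targets in data:
            opt.zero_grad()
            loss = torch.nn.functional.mse_loss(model(inputs), targets)
            loss.backward()
            opt.step()            # no lr update here, since batch_update_lr=False
        opt.epoch_end()           # epoch-wise scheduler update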
- zero_grad()[source]
Wrapper around torch.optim.Optimizer.zero_grad(). Only call after init() or __call__().
- _lr_kwargs: Dict[str, Any]
The arguments to _lr_scheduler_type besides the optimizer. Only used if _lr_scheduler_type is set.
- _lr_scheduler_type: Optional[Callable[..., Any]]
Optional learning rate scheduler builder. Used in init() to get a new scheduler.
- _optim_kwargs: Dict[str, Any]
The arguments (besides the parameters key) to _optim_type to get a new optimizer.
- _optim_type: Callable[..., Optimizer]
Type/builder for the optimizer. Used in init() to get a new optimizer.
- batch_update_lr: bool
Whether to call step on the learning rate scheduler after each batch or only after each epoch. Batch-wise updates are handled in step(), epoch-wise ones in epoch_end().
- property lr: float
The used learning rate default of the optimizer (starting value for lr scheduler)
- lr_scheduler: Optional
Reference to the optional current learning rate scheduler. None if _lr_scheduler_type is None, as well as after reset() and before init().
- optimizer: Optional[Optimizer]
Reference to the current optimizer. None after reset() and before init().
- property settings: Dict
Return nice dict representation of init args and optimizer type.