amago.utils
Miscellaneous training loop utilities, optimizers, and schedulers.
Functions
| Function | Description |
|---|---|
| amago_warning | Print a warning message in green, usually to warn about unintuitive hparam settings at startup. |
| avg_over_accelerate | Average a dictionary of ints or floats over all devices. |
| call_async_env | Convenience that calls a method over (async) parallel envs and waits for the results. |
| count_params | Count the number of trainable parameters in a PyTorch module. |
| get_constant_schedule_with_warmup | Get a constant learning rate schedule with a warmup period. |
| get_grad_norm | Get the (L2) norm of the gradients for a PyTorch module. |
| gin_as_wandb_config | Convert the active gin config to a dictionary for convenient logging to wandb. |
| masked_avg | Average a tensor over a mask. |
| retry_load_checkpoint | Load a model checkpoint with a retry loop in case of async read/write issues. |
| split_dict | Split a dictionary of numpy arrays into a list of dictionaries of numpy arrays. |
| stack_list_array_dicts | Stack a list of dictionaries of numpy arrays. |
| | Sum a dictionary of ints or floats over all devices. |
Classes
| Class | Description |
|---|---|
| AdamWRel | A variant of AdamW with timestep resets. |
Exceptions
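| Exception | Description |
|---|---|
| AmagoWarning | Default warning category used by amago_warning. |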
- class AdamWRel(params, reset_interval=gin.REQUIRED, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
Bases: AdamW
A variant of AdamW with timestep resets.
Tip
This class is @gin.configurable. Default values of kwargs can be overridden using gin.
Implementation of the optimizer discussed in “Adam on Local Time: Addressing Nonstationarity in RL with Relative Adam Timesteps”, Ellis et al., 2024 (https://openreview.net/pdf?id=biAqUbAuG7). It treats optimization of an RL policy as a sequence of stationary supervised learning stages and resets Adam’s timestep variable accordingly.
All other keyword arguments are the same as for AdamW.
- Parameters:
reset_interval (int) – Number of gradient steps between resets of Adam’s time / step count tracker. Must be configured by gin.
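To illustrate the reset mechanism, here is a minimal sketch, not AMAGO's actual AdamWRel code. It assumes PyTorch's torch.optim.AdamW keeps a per-parameter "step" entry in optimizer.state, zeroes that counter every reset_interval calls to step(), and leaves the moment estimates (exp_avg, exp_avg_sq) untouched, so only Adam's bias correction restarts. The class name AdamWWithTimestepResets is hypothetical.

```python
import torch


class AdamWWithTimestepResets(torch.optim.AdamW):
    """Hypothetical sketch: AdamW whose per-parameter step counters are
    zeroed every `reset_interval` optimizer steps. The first and second
    moment estimates are kept; only the timestep used for bias correction
    starts over."""

    def __init__(self, params, reset_interval: int, **adamw_kwargs):
        if reset_interval < 1:
            raise ValueError("reset_interval must be a positive integer")
        super().__init__(params, **adamw_kwargs)
        self.reset_interval = reset_interval
        self._num_updates = 0  # optimizer steps taken since construction

    def step(self, closure=None):
        loss = super().step(closure)
        self._num_updates += 1
        if self._num_updates % self.reset_interval == 0:
            for state in self.state.values():
                if "step" not in state:
                    continue
                if torch.is_tensor(state["step"]):
                    state["step"].zero_()  # recent PyTorch stores step as a tensor
                else:
                    state["step"] = 0  # older PyTorch stored a plain number
        return loss
```

In the setting described above, reset_interval would be the number of gradient steps that make up one stationary training stage.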
- amago_warning(msg, category=amago.utils.AmagoWarning)
Print a warning message in green, usually to warn about unintuitive hparam settings at startup.
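A minimal sketch of this behavior, assuming it wraps Python's standard warnings.warn and colors the text with ANSI escape codes. The escape codes, the Warning base class chosen for AmagoWarning, and the function name green_warning are assumptions, not taken from AMAGO.

```python
import warnings


class AmagoWarning(Warning):
    """Assumed: a Warning subclass so these messages can be filtered by category."""


GREEN = "\033[92m"  # ANSI bright green
RESET = "\033[0m"   # ANSI reset to the terminal's default color


def green_warning(msg: str, category: type[Warning] = AmagoWarning) -> None:
    # stacklevel=2 attributes the warning to the caller, not to this helper
    warnings.warn(f"{GREEN}{msg}{RESET}", category=category, stacklevel=2)
```

Because the category is a dedicated subclass, these warnings can be silenced with warnings.filterwarnings("ignore", category=AmagoWarning).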
- avg_over_accelerate(data)
Average a dictionary of ints or floats over all devices.
- Parameters:
data (dict[str, int | float]) – Dictionary of ints or floats.
- Return type:
dict[str, int | float]
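Averaging across devices can be sketched with a single all-reduce. This version swaps in torch.distributed directly instead of Hugging Face accelerate (which the function name suggests AMAGO uses), and the helper name all_reduce_scalars is hypothetical. With average=False the same routine sums instead, like the summing counterpart in the Functions table. Without an initialized process group it treats the current process as the only device.

```python
from __future__ import annotations

import torch
import torch.distributed as dist


def all_reduce_scalars(
    data: dict[str, int | float],
    average: bool = True,
    device: torch.device | str = "cpu",  # use the local CUDA device with NCCL
) -> dict[str, float]:
    """Sum (or average) each scalar in `data` across all processes.
    Every process must pass the same set of keys."""
    keys = sorted(data)  # identical packing order on every rank
    values = torch.tensor(
        [float(data[k]) for k in keys], dtype=torch.float64, device=device
    )
    world_size = 1
    if dist.is_available() and dist.is_initialized():
        world_size = dist.get_world_size()
        dist.all_reduce(values, op=dist.ReduceOp.SUM)  # in-place sum over ranks
    if average:
        values /= world_size
    return {k: v.item() for k, v in zip(keys, values)}
```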
- call_async_env(env, method_name, *args, **kwargs)
Convenience that calls a method over (async) parallel envs and waits for the results.
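Assuming env follows Gymnasium's AsyncVectorEnv interface, which splits a remote method call into call_async (send the request to every worker) and call_wait (block until all results arrive), the helper reduces to two calls. This is a sketch, not AMAGO's source, and the name call_and_wait is hypothetical.

```python
from typing import Any


def call_and_wait(env: Any, method_name: str, *args: Any, **kwargs: Any) -> Any:
    # Dispatch the method call to every sub-environment without blocking...
    env.call_async(method_name, *args, **kwargs)
    # ...then block until each worker has returned, one result per sub-environment.
    return env.call_wait()
```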
- count_params(model)
Count the number of trainable parameters in a PyTorch module.
- Parameters:
model (Module) – PyTorch module.
- Return type:
int
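The usual definition is the total number of scalar elements across parameters that require gradients. A minimal sketch (not AMAGO's source; the name count_trainable is hypothetical):

```python
import torch.nn as nn


def count_trainable(model: nn.Module) -> int:
    # numel() is the number of scalars in each parameter tensor; frozen
    # parameters (requires_grad=False) are excluded from the count.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
```

For example, nn.Linear(3, 2) has a 2×3 weight and 2 biases, so the count is 8.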
- get_constant_schedule_with_warmup(optimizer, num_warmup_steps, last_epoch=-1)
Get a constant learning rate schedule with a warmup period.
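One common way to build this schedule, and the behavior of the identically named function in Hugging Face transformers, is a LambdaLR whose multiplier rises linearly from 0 to 1 over num_warmup_steps and then stays at 1. Whether AMAGO's version matches it exactly is an assumption; the sketch uses the hypothetical name constant_with_warmup.

```python
import torch
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR


def constant_with_warmup(
    optimizer: Optimizer, num_warmup_steps: int, last_epoch: int = -1
) -> LambdaLR:
    def lr_multiplier(step: int) -> float:
        # Linear ramp 0 -> 1 during warmup, then hold the base learning rate.
        if step < num_warmup_steps:
            return step / max(1, num_warmup_steps)
        return 1.0

    return LambdaLR(optimizer, lr_multiplier, last_epoch=last_epoch)
```

With num_warmup_steps=4 and a base learning rate of 1.0, the scheduler yields 0.0, 0.25, 0.5, 0.75, and then 1.0 from step 4 onward.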
- get_grad_norm(model)
Get the (L2) norm of the gradients for a PyTorch module.
- Parameters:
model (Module) – PyTorch module.
- Return type:
float
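The L2 norm over a whole module is the square root of the sum of squared entries of every parameter's gradient, i.e. the norm of all gradients flattened into one vector. A minimal sketch (hypothetical name; parameters without a gradient are skipped):

```python
import torch.nn as nn


def total_grad_norm(model: nn.Module) -> float:
    squared_sum = 0.0
    for p in model.parameters():
        if p.grad is not None:
            # Accumulate squared entries in float32 for low-precision grads.
            squared_sum += p.grad.detach().float().pow(2).sum().item()
    return squared_sum ** 0.5
```

torch.nn.utils.clip_grad_norm_ also returns this total norm, but it additionally rescales the gradients in place when they exceed max_norm.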
- gin_as_wandb_config()
Convert the active gin config to a dictionary for convenient logging to wandb.
- Return type:
dict
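gin can render the active bindings as text (for example via gin.config_str() or gin.operative_config_str()), so one way to obtain a loggable dictionary is to parse the `name = value` lines. The parser below is a hypothetical sketch that assumes one binding per line, with long values continued on indented lines, and keeps values as strings; it is not necessarily how AMAGO builds its dictionary.

```python
def gin_bindings_to_dict(config_str: str) -> dict[str, str]:
    """Parse `name = value` lines from a gin config string into a flat dict."""
    bindings: dict[str, str] = {}
    last_key = None
    for line in config_str.splitlines():
        stripped = line.strip()
        # Skip blank lines, comments, and import statements.
        if not stripped or stripped.startswith(("#", "import ", "from ")):
            continue
        if line[0].isspace() and last_key is not None:
            # Indented line: continuation of the previous binding's value.
            bindings[last_key] += " " + stripped
            continue
        key, sep, value = stripped.partition("=")
        if not sep:
            continue  # not a binding line
        last_key = key.strip()
        bindings[last_key] = value.strip()
    return bindings
```

The resulting dict can then be passed as the config argument of wandb.init.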
- masked_avg(tensor, mask)
Average a tensor over a mask.
- Parameters:
tensor (Tensor) – Tensor to average.
mask (Tensor) – Mask to average over. False where indices should be ignored.
- Return type:
Tensor
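A masked average divides the sum of the kept entries by the number of kept entries. A minimal sketch, assuming mask is boolean (or 0/1) with the same shape as tensor; clamping the count to at least 1, so an all-False mask returns 0 instead of NaN, is a choice made here and not documented behavior. The name masked_mean is hypothetical.

```python
import torch


def masked_mean(tensor: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    mask = mask.to(tensor.dtype)  # True/1 -> 1.0 (keep), False/0 -> 0.0 (ignore)
    total = (tensor * mask).sum()
    count = mask.sum().clamp(min=1.0)  # avoid 0/0 when every entry is masked out
    return total / count
```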
- retry_load_checkpoint(ckpt_path, map_location, tries=10)
Load a model checkpoint with a retry loop in case of async read/write issues.
- Parameters:
ckpt_path – Path to the checkpoint file.
map_location – Device map location for the checkpoint.
tries (int) – Number of tries to load the checkpoint before giving up.
- Returns:
The torch.load() result, or None if loading failed.
- Return type:
ckpt
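A retry loop of this kind can be sketched as follows. The fixed wait_s pause between attempts and the function name load_with_retries are assumptions. Any exception is treated as a transient failure (for example, a file still being written by another process), and None is returned once all tries are exhausted, as documented above.

```python
import time
from typing import Any, Optional

import torch


def load_with_retries(
    ckpt_path: str, map_location: Any, tries: int = 10, wait_s: float = 1.0
) -> Optional[Any]:
    for attempt in range(tries):
        try:
            return torch.load(ckpt_path, map_location=map_location)
        except Exception:
            # Possibly a partially written or locked file; back off and retry.
            if attempt < tries - 1:
                time.sleep(wait_s)
    return None
```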
- split_dict(dict_, axis=0)
Split a dictionary of numpy arrays into a list of dictionaries of numpy arrays. Inverse of stack_list_array_dicts.
- Parameters:
dict_ – Dictionary of numpy arrays.
axis – Axis to split along.
- Return type:
list[dict[str, ndarray]]
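Splitting can be sketched by indexing every array at each position along axis. Integer indexing drops that axis, which undoes stacking along a new axis (the cat=False case of stack_list_array_dicts). The name split_array_dict is hypothetical, and the sketch assumes all arrays share the same length along axis.

```python
import numpy as np


def split_array_dict(
    dict_: dict[str, np.ndarray], axis: int = 0
) -> list[dict[str, np.ndarray]]:
    lengths = {arr.shape[axis] for arr in dict_.values()}
    if len(lengths) != 1:
        raise ValueError("dict_ must be non-empty with equal sizes along axis")
    (n,) = lengths
    # np.take with an integer index removes `axis`, undoing np.stack(..., axis=axis).
    return [{k: np.take(arr, i, axis=axis) for k, arr in dict_.items()} for i in range(n)]
```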
- stack_list_array_dicts(list_, axis=0, cat=False)
Stack a list of dictionaries of numpy arrays.
- Parameters:
list_ – List of dictionaries of numpy arrays.
axis – Axis to stack along.
cat (bool) – Whether to concatenate along an existing axis instead of stacking along a new one.
- Return type:
dict[str, ndarray]
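np.stack joins arrays along a new axis while np.concatenate joins them along an existing one, which is the distinction the cat flag selects. A minimal sketch (hypothetical name; assumes a non-empty list whose dictionaries all have the keys of the first one):

```python
import numpy as np


def stack_array_dicts(
    list_: list[dict[str, np.ndarray]], axis: int = 0, cat: bool = False
) -> dict[str, np.ndarray]:
    if not list_:
        raise ValueError("need at least one dictionary to stack")
    combine = np.concatenate if cat else np.stack
    # Keys come from the first dict; a key missing elsewhere raises KeyError.
    return {k: combine([d[k] for d in list_], axis=axis) for k in list_[0]}
```

For two dictionaries holding arrays of shape (2,), stacking gives shape (2, 2) and cat=True gives shape (4,).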