amago.utils#

Miscellaneous training loop utilities, optimizers, and schedulers.

Functions

amago_warning(msg[, category])

Print a warning message in green, usually to warn about unintuitive hyperparameter settings at startup.

avg_over_accelerate(data)

Average a dictionary of ints or floats over all devices.

call_async_env(env, method_name, *args, **kwargs)

Convenience function that calls a method on (async) parallel environments and waits for the results.

count_params(model)

Count the number of trainable parameters in a PyTorch module.

get_constant_schedule_with_warmup(optimizer, ...)

Get a constant learning rate schedule with a warmup period.

get_grad_norm(model)

Get the (L2) norm of the gradients for a PyTorch module.

gin_as_wandb_config()

Convert the active gin config to a dictionary for convenient logging to wandb.

masked_avg(tensor, mask)

Average a tensor over a mask.

retry_load_checkpoint(ckpt_path, map_location)

Load a model checkpoint with a retry loop in case of async read/write issues.

split_batch(arr, axis)

Split a numpy array into a list of arrays along the given axis.

split_dict(dict_[, axis])

Split a dictionary of numpy arrays into a list of dictionaries of numpy arrays.

stack_list_array_dicts(list_[, axis, cat])

Stack a list of dictionaries of numpy arrays.

sum_over_accelerate(data)

Sum a dictionary of ints or floats over all devices.

Classes

AdamWRel(params[, reset_interval, lr, ...])

A variant of AdamW with timestep resets.

Exceptions

AmagoWarning(*args, **kwargs)

Warning category used by amago_warning.

class AdamWRel(params, reset_interval=gin.REQUIRED, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)[source]#

Bases: AdamW

A variant of AdamW with timestep resets.

Tip

This class is @gin.configurable. Default values of kwargs can be overridden using gin.

Implementation of the optimizer discussed in “Adam on Local Time: Addressing Nonstationarity in RL with Relative Adam Timesteps”, Ellis et al., 2024 (https://openreview.net/pdf?id=biAqUbAuG7). Treats optimization of an RL policy as a sequence of stationary supervised learning stages and resets Adam’s timestep variable accordingly.

Other keyword arguments follow torch.optim.AdamW.

Parameters:

reset_interval (int) – Number of gradient steps between resets of Adam’s time / step count tracker. Must be configured by gin.

step(closure=None)[source]#

Perform a single optimization step.

Parameters:

closure (Callable, optional) – A closure that reevaluates the model and returns the loss.
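
A minimal usage sketch: the gin binding name below assumes the configurable is registered under its class name, and the reset_interval value of 1000 is purely illustrative:

    import gin
    import torch
    from torch import nn

    from amago.utils import AdamWRel

    # reset_interval is gin.REQUIRED, so bind it before constructing the optimizer.
    gin.parse_config("AdamWRel.reset_interval = 1000")  # assumed binding name

    model = nn.Linear(8, 1)
    opt = AdamWRel(model.parameters(), lr=1e-3, weight_decay=1e-2)

    for _ in range(10):
        loss = model(torch.randn(4, 8)).pow(2).mean()
        loss.backward()
        opt.step()       # Adam's internal timestep resets every reset_interval steps
        opt.zero_grad()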

exception AmagoWarning(*args, **kwargs)[source]#

Bases: Warning

Warning category used by amago_warning.

amago_warning(msg, category=<class 'amago.utils.AmagoWarning'>)[source]#

Print a warning message in green, usually to warn about unintuitive hyperparameter settings at startup.

avg_over_accelerate(data)[source]#

Average a dictionary of ints or floats over all devices.

Parameters:

data (dict[str, int | float]) – Dictionary of ints or floats.

Return type:

dict[str, int | float]
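
A hedged usage sketch; the keys and values are placeholders, and it assumes the helper also works in a single-process run:

    from amago.utils import avg_over_accelerate

    # Per-process metrics collected during training (placeholder values).
    local_stats = {"loss": 0.42, "grad_norm": 1.7}
    # Each entry is averaged across all accelerate processes before logging.
    global_stats = avg_over_accelerate(local_stats)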

call_async_env(env, method_name, *args, **kwargs)[source]#

Convenience function that calls a method on (async) parallel environments and waits for the results.

count_params(model)[source]#

Count the number of trainable parameters in a PyTorch module.

Parameters:

model (Module) – PyTorch module.

Return type:

int

get_constant_schedule_with_warmup(optimizer, num_warmup_steps, last_epoch=-1)[source]#

Get a constant learning rate schedule with a warmup period.
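
A sketch of pairing the schedule with a standard optimizer; the model, learning rate, and 500-step warmup are illustrative values:

    import torch
    from torch import nn

    from amago.utils import get_constant_schedule_with_warmup

    model = nn.Linear(8, 1)
    opt = torch.optim.AdamW(model.parameters(), lr=3e-4)
    sched = get_constant_schedule_with_warmup(opt, num_warmup_steps=500)

    for _ in range(1000):
        loss = model(torch.randn(4, 8)).pow(2).mean()
        loss.backward()
        opt.step()
        opt.zero_grad()
        sched.step()  # lr warms up over the first 500 steps, then stays at 3e-4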

get_grad_norm(model)[source]#

Get the (L2) norm of the gradients for a PyTorch module.

Parameters:

model (Module) – PyTorch module.

Return type:

float
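
A small sketch that also uses count_params above; the toy model is illustrative:

    import torch
    from torch import nn

    from amago.utils import count_params, get_grad_norm

    model = nn.Linear(8, 1)
    print(count_params(model))  # 9 trainable parameters: 8 weights + 1 bias

    # Produce some gradients, then measure their L2 norm.
    model(torch.randn(4, 8)).pow(2).mean().backward()
    print(get_grad_norm(model))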

gin_as_wandb_config()[source]#

Convert the active gin config to a dictionary for convenient logging to wandb.

Return type:

dict
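
A usage sketch; the project and run names are placeholders, and the call assumes all gin files and bindings have already been parsed:

    import wandb

    from amago.utils import gin_as_wandb_config

    run = wandb.init(
        project="amago-example",       # placeholder project name
        name="demo-run",               # placeholder run name
        config=gin_as_wandb_config(),  # log the active gin bindings with the run
    )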

masked_avg(tensor, mask)[source]#

Average a tensor over a mask.

Parameters:
  • tensor (Tensor) – Tensor to average.

  • mask (Tensor) – Mask to average over; False marks entries that should be ignored.

Return type:

Tensor
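
A minimal sketch; the shapes are illustrative, and in practice the mask typically marks valid (non-padding) timesteps:

    import torch

    from amago.utils import masked_avg

    per_step_loss = torch.rand(2, 5)       # (batch, time) loss terms
    mask = torch.ones(2, 5, dtype=torch.bool)
    mask[:, -2:] = False                   # ignore padded timesteps
    avg = masked_avg(per_step_loss, mask)  # average over unmasked entries only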

retry_load_checkpoint(ckpt_path, map_location, tries=10)[source]#

Load a model checkpoint with a retry loop in case of async read/write issues.

Parameters:
  • ckpt_path – Path to the checkpoint file.

  • map_location – Device map location for the checkpoint.

  • tries (int) – Number of tries to load the checkpoint before giving up.

Returns:

The torch.load() result, or None if the load failed after all retries.
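
A usage sketch; the path is a placeholder, and map_location follows torch.load conventions:

    from amago.utils import retry_load_checkpoint

    ckpt = retry_load_checkpoint("ckpts/policy_step_10000.pt", map_location="cpu")
    if ckpt is None:
        raise RuntimeError("checkpoint was unreadable after all retries")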

split_batch(arr, axis)[source]#

Split a numpy array into a list of arrays along the given axis.

Return type:

list[ndarray]

split_dict(dict_, axis=0)[source]#

Split a dictionary of numpy arrays into a list of dictionaries of numpy arrays.

Inverse of stack_list_array_dicts; a round-trip sketch appears after that entry.

Parameters:
  • dict_ – Dictionary of numpy arrays.

  • axis – Axis to split along.

Return type:

list[dict[str, ndarray]]

stack_list_array_dicts(list_, axis=0, cat=False)[source]#

Stack a list of dictionaries of numpy arrays.

Parameters:
  • list_ – List of dictionaries of numpy arrays.

  • axis – Axis to stack along.

  • cat (bool) – Whether to concatenate along an existing axis instead of stacking along a new one.

Return type:

dict[str, ndarray]
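
A round-trip sketch covering stack_list_array_dicts and split_dict; the keys and shapes are placeholders:

    import numpy as np

    from amago.utils import split_dict, stack_list_array_dicts

    # Four per-env transition dicts with matching keys and shapes.
    per_env = [{"obs": np.zeros(3), "reward": np.zeros(1)} for _ in range(4)]

    batched = stack_list_array_dicts(per_env, axis=0)  # e.g. "obs" stacks to shape (4, 3)
    recovered = split_dict(batched, axis=0)            # back to a list of per-env dicts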

sum_over_accelerate(data)[source]#

Sum a dictionary of ints or floats over all devices.

Parameters:

data (dict[str, int | float]) – Dictionary of ints or floats.

Return type:

dict[str, int | float]