amago.nets.utils#

Miscellaneous utilities for neural network modules.

Functions

activation_switch(activation) – Quick switch for the activation function.

add_activation_log(root_key, activation[, ...]) – Add activation statistics to a logging dictionary.

symexp(x) – Symmetric exponential transform.

symlog(x) – Symmetric log transform.

Classes

InputNorm(dim[, beta, init_nu, skip]) – Moving-average feature normalization.

SlowAdaptiveRational([trainable]) – Adaptive Rational Activation.

class InputNorm(dim, beta=0.0001, init_nu=1.0, skip=False)[source]#

Bases: Module

Moving-average feature normalization.

Tip

This class is @gin.configurable. Default values of kwargs can be overridden using gin.

Normalizes input features using a moving average of their statistics. This helps stabilize training by keeping the input distribution relatively constant.

Parameters:
  • dim – Dimension of the input feature.

  • beta – Smoothing parameter for the moving average. Defaults to 1e-4.

  • init_nu – Initial value for the moving average of the squared feature values. Defaults to 1.0.

  • skip (no gin) – Whether to skip normalization. Defaults to False. Cannot be configured via gin; disable input normalization through the TstepEncoder config instead.
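
A minimal usage sketch (the import path follows this module's name; whether forward() also refreshes the running statistics is not specified on this page, so the sketch only exercises the documented signatures):

    import torch
    from amago.nets.utils import InputNorm

    norm = InputNorm(dim=32)               # running stats for 32 features
    x = torch.randn(8, 100, 32)            # e.g. (batch, timesteps, features)
    normed = norm(x)                       # forward(x): normalize with current stats
    restored = norm(normed, denormalize=True)  # per the signature, reverses the mapping
    norm.update_stats(x)                   # refresh the moving averages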

denormalize_values(val)[source]#
Return type:

Tensor

forward(x, denormalize=False)[source]#

Define the computation performed at every call.

Should be overridden by all subclasses.

Return type:

Tensor

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

masked_stats(val)[source]#
Return type:

tuple[Tensor, Tensor]

normalize_values(val)[source]#
Return type:

Tensor

property sigma#
update_stats(val)[source]#
Return type:

None

class SlowAdaptiveRational(trainable=True)[source]#

Bases: Module

Adaptive Rational Activation.

Tip

This class is @gin.configurable. Default values of kwargs can be overridden using gin.

A slow, non-CUDA version of “Adaptive Rational Activations to Boost Deep Reinforcement Learning”, Delfosse et al., 2021 (https://arxiv.org/pdf/2102.09407.pdf). Hardcoded to the Leaky ReLU variant.

Parameters:

trainable (bool) – Whether to train the parameters of the activation. Defaults to True.
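
For intuition, a rational activation learns y = P(x) / Q(x) with trainable polynomial coefficients. The sketch below is a simplified, hypothetical illustration of that idea; the degrees, random initialization, and safe-denominator trick follow the paper's general recipe, not this class's exact parameterization:

    import torch
    import torch.nn as nn

    class TinyRational(nn.Module):
        """Illustrative learnable rational activation P(x) / Q(x)."""

        def __init__(self, trainable: bool = True):
            super().__init__()
            # Degree-3 numerator, degree-2 denominator. A faithful version
            # would initialize these to closely fit Leaky ReLU.
            self.p = nn.Parameter(torch.randn(4) * 0.1, requires_grad=trainable)
            self.q = nn.Parameter(torch.randn(2) * 0.1, requires_grad=trainable)

        def forward(self, x):
            num = sum(c * x**i for i, c in enumerate(self.p))
            # 1 + sum of |q_j| * |x|^j keeps the denominator >= 1 (never zero)
            den = 1.0 + sum(abs(c) * x.abs() ** (i + 1) for i, c in enumerate(self.q))
            return num / den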

forward(x)[source]#

Define the computation performed at every call.

Should be overridden by all subclasses.

Return type:

Tensor

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

activation_switch(activation)[source]#

Quick switch for the activation function.

Parameters:

activation (str) – The activation function name. Options are:

  • “leaky_relu” (Leaky ReLU)

  • “relu” (ReLU)

  • “gelu” (GELU)

  • “adaptive” (SlowAdaptiveRational)

Return type:

callable

Returns:

The activation function (callable).

Raises:

ValueError – If the activation function name is not recognized.
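
A quick usage sketch, relying only on the behavior documented here:

    import torch
    from amago.nets.utils import activation_switch

    act = activation_switch("gelu")      # returns a callable
    y = act(torch.randn(4, 16))

    try:
        activation_switch("swish")       # not one of the four options
    except ValueError as err:
        print(err)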

add_activation_log(root_key, activation, log_dict=None)[source]#

Add activation statistics to a logging dictionary.

Logs the maximum, minimum, standard deviation, and mean of the activation tensor under the key prefix “activation-{root_key}-”.

Parameters:
  • root_key (str) – Prefix for the log keys.

  • activation (Tensor) – Tensor to compute statistics from.

  • log_dict (dict | None) – Dictionary to add statistics to. If None, no logging is performed.
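
A usage sketch based on the description above (only the key prefix is documented here; the exact stat-key suffixes are an assumption):

    import torch
    from amago.nets.utils import add_activation_log

    log_dict = {}
    h = torch.randn(32, 128)
    add_activation_log("encoder", h, log_dict)
    # log_dict now holds the max/min/std/mean of h under keys starting
    # with "activation-encoder-" (exact suffix names may differ).

    add_activation_log("encoder", h, None)  # log_dict=None: no logging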

symexp(x)[source]#

Symmetric exponential transform.

Applies sign(x) * (exp(|x|) - 1) to the input. This is the inverse of the symmetric log transform.

Parameters:

x (Tensor | float) – Input tensor or scalar value.

Return type:

Tensor | float

Returns:

symexp(x) as a Tensor if x is a Tensor, otherwise symexp(x) as a float.

symlog(x)[source]#

Symmetric log transform.

Applies sign(x) * log(|x| + 1) to the input. This transform is useful for rescaling values from (approximately) unbounded ranges into a range better suited to network inputs and outputs.

Parameters:

x (Tensor | float) – Input tensor or scalar value.

Return type:

Tensor | float

Returns:

symlog(x) as a Tensor if x is a Tensor, otherwise symlog(x) as a float.
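
A small sketch of the round trip implied by the two definitions above:

    import torch
    from amago.nets.utils import symexp, symlog

    x = torch.tensor([-1000.0, -1.0, 0.0, 1.0, 1000.0])
    compressed = symlog(x)               # sign(x) * log(|x| + 1)
    recovered = symexp(compressed)       # sign(x) * (exp(|x|) - 1)
    print(torch.allclose(recovered, x))  # True, up to float precision

    print(symlog(1000.0))                # plain floats return a float (~6.91)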