amago.nets.utils#
Miscellaneous utilities for neural network modules.
Functions
- activation_switch – Quick switch for the activation function.
- add_activation_log – Add activation statistics to a logging dictionary.
- symexp – Symmetric exponential transform.
- symlog – Symmetric log transform.
Classes
- InputNorm – Moving-average feature normalization.
- SlowAdaptiveRational – Adaptive Rational Activation.
- class InputNorm(dim, beta=0.0001, init_nu=1.0, skip=False)[source]#
Bases: Module
Moving-average feature normalization.
Tip
This class is @gin.configurable. Default values of kwargs can be overridden using gin.
Normalizes input features using a moving average of their statistics. This helps stabilize training by keeping the input distribution relatively constant.
- Parameters:
dim – Dimension of the input feature.
beta – Smoothing parameter for the moving average. Defaults to 1e-4.
init_nu – Initial value for the moving average of the squared feature values. Defaults to 1.0.
skip (no gin) – Whether to skip normalization. Defaults to False. Cannot be configured via gin (disable input norm in the TstepEncoder config).
- forward(x, denormalize=False)[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
- Return type:
Tensor
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- property sigma#
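Example (a minimal usage sketch, assuming InputNorm is imported from amago.nets.utils and that calling forward with denormalize=True inverts the normalization, per the signature above):
import torch
from amago.nets.utils import InputNorm

norm = InputNorm(dim=32, beta=1e-4, init_nu=1.0)
x = torch.randn(8, 32) * 5.0              # raw features with non-unit scale
x_norm = norm(x)                          # normalized with moving-average statistics
x_back = norm(x_norm, denormalize=True)   # assumption: inverts the normalization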
- class SlowAdaptiveRational(trainable=True)[source]#
Bases: Module
Adaptive Rational Activation.
Tip
This class is @gin.configurable. Default values of kwargs can be overridden using gin.
A slow non-CUDA version of "Adaptive Rational Activations to Boost Deep Reinforcement Learning", Delfosse et al., 2021 (https://arxiv.org/pdf/2102.09407.pdf). Hardcoded to the Leaky ReLU version.
- Parameters:
trainable (bool) – Whether to train the parameters of the activation. Defaults to True.
- forward(x)[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
- Return type:
Tensor
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
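Example (a minimal sketch, assuming the usual nn.Module call convention):
import torch
from amago.nets.utils import SlowAdaptiveRational

act = SlowAdaptiveRational(trainable=True)
x = torch.randn(4, 16)
y = act(x)                  # elementwise rational approximation of Leaky ReLU
assert y.shape == x.shape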
- activation_switch(activation)[source]#
Quick switch for the activation function.
- Parameters:
activation (str) – The activation function name. Options are: "leaky_relu" (Leaky ReLU), "relu" (ReLU), "gelu" (GeLU), "adaptive" (SlowAdaptiveRational).
- Return type:
callable
- Returns:
The activation function (callable).
- Raises:
ValueError – If the activation function name is not recognized.
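Example (a hedged sketch; whether the returned callable can be applied directly to a tensor, or must first be instantiated in the "adaptive" case, is an assumption here):
import torch
from amago.nets.utils import activation_switch

act = activation_switch("gelu")   # options: "leaky_relu", "relu", "gelu", "adaptive"
y = act(torch.randn(2, 8))        # assumption: callable operates on tensors directly

try:
    activation_switch("swish")    # unrecognized name
except ValueError as err:
    print(err)                    # raises ValueError, per the docs above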
- add_activation_log(root_key, activation, log_dict=None)[source]#
Add activation statistics to a logging dictionary.
Logs the maximum, minimum, standard deviation, and mean of the activation tensor under the key prefix “activation-{root_key}-“.
- Parameters:
root_key (str) – Prefix for the log keys.
activation (Tensor) – Tensor to compute statistics from.
log_dict (dict | None) – Dictionary to add statistics to. If None, no logging is performed.
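Example (a small sketch; the exact per-statistic key names are not spelled out above, so only the "activation-{root_key}-" prefix is assumed):
import torch
from amago.nets.utils import add_activation_log

log_dict = {}
h = torch.randn(32, 128)                       # e.g., a hidden activation from some layer
add_activation_log("encoder_layer0", h, log_dict)
print(sorted(log_dict.keys()))                 # keys start with "activation-encoder_layer0-"

add_activation_log("encoder_layer0", h, None)  # log_dict=None skips logging entirely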
- symexp(x)[source]#
Symmetric exponential transform.
Applies sign(x) * (exp(|x|) - 1) to the input. This is the inverse of the symmetric log transform.
- Parameters:
x (Tensor | float) – Input tensor or scalar value.
- Return type:
Tensor | float
- Returns:
symexp(x) as a Tensor if x is a Tensor, otherwise symexp(x) as a float.
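Example (a quick numerical check against the formula above; works on tensors and floats per the signature):
import torch
from amago.nets.utils import symexp

x = torch.tensor([-3.0, 0.0, 3.0])
expected = torch.sign(x) * (torch.exp(torch.abs(x)) - 1.0)
assert torch.allclose(symexp(x), expected)

print(symexp(1.0))  # scalar input returns a float (exp(1) - 1 ≈ 1.718)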
- symlog(x)[source]#
Symmetric log transform.
Applies sign(x) * log(|x| + 1) to the input. This transform is useful for rescaling roughly unbounded value ranges to a range suitable for network inputs/outputs.
- Parameters:
x (Tensor | float) – Input tensor or scalar value.
- Return type:
Tensor | float
- Returns:
symlog(x) as a Tensor if x is a Tensor, otherwise symlog(x) as a float.
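Example (a round-trip sketch showing that symexp inverts symlog, the usual pattern for squashing wide-range targets before a network and un-squashing its outputs):
import torch
from amago.nets.utils import symlog, symexp

returns = torch.tensor([-1000.0, -1.0, 0.0, 1.0, 1000.0])  # wide-range values
squashed = symlog(returns)      # compressed to a small, symmetric range (~[-6.9, 6.9] here)
recovered = symexp(squashed)
assert torch.allclose(recovered, returns)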