amago.nets.actor_critic#
Actor and critic output modules.
Classes

Actor – Actor output head architecture.
BaseActorHead – Abstract base class for actor output heads.
BaseCriticHead – Abstract base class for critic output heads.
NCritics – Critic output head architecture.
NCriticsTwoHot – Critic output head architecture with a categorical (two-hot) value distribution.
PopArtLayer – PopArt value normalization.
ResidualActor – Actor output head with residual blocks.
- class Actor(state_dim, action_dim, discrete, gammas, n_layers=2, d_hidden=256, activation='leaky_relu', dropout_p=0.0, continuous_dist_type=<class 'amago.nets.policy_dists.TanhGaussian'>)[source]#
Bases:
BaseActorHead
Actor output head architecture.
Tip
This class is @gin.configurable. Default values of kwargs can be overridden using gin.

A (small) MLP that maps the output of the TrajEncoder to a distribution over actions.
- Parameters:
state_dim (int) – Dimension of the “state” space (which is the output of the TrajEncoder)
action_dim (int) – Dimension of the action space
discrete (bool) – Whether the action space is discrete
gammas (Tensor) – List of gamma values to use for the multi-gamma actor
n_layers (int) – Number of layers in the MLP. Defaults to 2.
d_hidden (int) – Dimension of hidden layers in the MLP. Defaults to 256.
activation (str) – Activation function to use in the MLP. Defaults to “leaky_relu”.
dropout_p (float) – Dropout rate to use in the MLP. Defaults to 0.0.
continuous_dist_type (Type[PolicyOutput]) – Type of continuous distribution to use if applicable. Must be a PolicyOutput. Defaults to TanhGaussian.
- class BaseActorHead(state_dim, action_dim, discrete, gammas, continuous_dist_type)[source]#
Bases: Module, ABC
- forward(state, log_dict=None)[source]#
Compute an action distribution from a state representation.
- Parameters:
state (Tensor) – The “state” sequence (the output of the TrajEncoder) with shape (Batch, Length, state_dim)
- Return type:
Distribution
- Returns:
The action distribution. Type varies according to the output of PolicyOutput (e.g. Discrete or TanhGaussian). Always a pytorch distribution (e.g., Categorical) where sampled actions would have shape (Batch, Length, Gammas, action_dim).
- class BaseCriticHead(state_dim, action_dim, discrete, gammas, num_critics)[source]#
Bases: Module, ABC
- forward(state, action, log_dict=None)[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
- Return type:
Tensor
Note
Although the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- class NCritics(state_dim, action_dim, discrete, gammas, num_critics, d_hidden=256, n_layers=2, dropout_p=0.0, activation='leaky_relu')[source]#
Bases:
BaseCriticHead
Critic output head architecture.
Tip
This class is @gin.configurable. Default values of kwargs can be overridden using gin.

A (small) ensemble of MLPs that maps the state and action to a value estimate.
- Parameters:
state_dim (int) – Dimension of the “state” space (which is the output of the TrajEncoder)
action_dim (int) – Dimension of the action space
discrete (bool) – Whether the action space is discrete
gammas (Tensor) – List of gamma values to use for the multi-gamma critic
num_critics (int) – Number of critics in the ensemble. Defaults to 4.
d_hidden (int) – Dimension of hidden layers in the MLP. Defaults to 256.
n_layers (int) – Number of layers in the MLP. Defaults to 2.
dropout_p (float) – Dropout rate to use in the MLP. Defaults to 0.0.
activation (str) – Activation function to use in the MLP. Defaults to “leaky_relu”.
- critic_network_forward(state, action, log_dict=None)[source]#
Compute a value estimate from a state and action.
- Parameters:
state (Tensor) – The “state” sequence (the output of the TrajEncoder). Has shape (Batch, Length, state_dim).
action (Tensor) – The action sequence. Has shape (K, Batch, Length, Gammas, action_dim), where K is a dimension denoting multiple action samples from the same state (can be 1, but must exist). Discrete actions are expected to be one-hot vectors.
- Return type:
Tensor
- Returns:
The value estimate with shape (K, Batch, Length, num_critics, Gammas, 1).
- class NCriticsTwoHot(state_dim, action_dim, gammas, discrete, num_critics, d_hidden=256, n_layers=2, dropout_p=0.0, activation='leaky_relu', min_return=None, max_return=None, output_bins=128, use_symlog=True)[source]#
Bases:
BaseCriticHead
Critic output head architecture.
Tip
This class is @gin.configurable. Default values of kwargs can be overridden using gin.

A (small) ensemble of MLPs that maps the state and action to a value estimate in the form of a categorical distribution over bins.
- Parameters:
state_dim (int) – Dimension of the “state” space (which is the output of the TrajEncoder)
action_dim (int) – Dimension of the action space
gammas (Tensor) – List of gamma values to use for the multi-gamma critic
num_critics (int) – Number of critics in the ensemble. Defaults to 4.
d_hidden (int) – Dimension of hidden layers in the MLP. Defaults to 256.
n_layers (int) – Number of layers in the MLP. Defaults to 2.
dropout_p (float) – Dropout rate to use in the MLP. Defaults to 0.0.
activation (str) – Activation function to use in the MLP. Defaults to “leaky_relu”.
min_return (float | None) – Minimum return value. If not set, defaults to a very negative value (-100_000).
max_return (float | None) – Maximum return value. If not set, defaults to a very positive value (100_000).
output_bins (int) – Number of bins in the categorical distribution. Defaults to 128.
use_symlog (bool) – Whether to use a symlog transformation on the value estimates. Defaults to True.
Note
The default bin settings (wide range, lots of bins, symlog transformation) follow Dreamer-V3 in picking a range that does not demand domain-specific tuning. It may be more sample efficient to use tighter bounds, in which case the unintuitive spacing created by use_symlog may be turned off. However, note that the min_return and max_return do not compensate for Agent.reward_multiplier. For example, if the highest possible return in an env is 1, and reward multiplier is 10, then a tuned max_return might be 10 but should never be 1. More discussion on bin settings in AMAGO-2 Appendix A.
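The symlog transform mentioned above compresses large-magnitude returns so that a fixed bin range can cover many scales without per-environment tuning. A minimal sketch of the transform and its inverse (hypothetical helper names, not the library's API; the real code operates on batched tensors):

```python
import math

def symlog(x: float) -> float:
    # sign(x) * log(1 + |x|): compresses large magnitudes, near-identity around 0
    return math.copysign(math.log1p(abs(x)), x)

def symexp(x: float) -> float:
    # Exact inverse of symlog: sign(x) * (exp(|x|) - 1)
    return math.copysign(math.expm1(abs(x)), x)
```

Bins spaced uniformly in symlog space therefore place more resolution near zero, which is the "unintuitive spacing" referred to above.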
- bin_dist_to_raw_vals(bin_dist)[source]#
Convert a categorical distribution over bins to a scalar.
- Parameters:
bin_dist (Categorical) – The categorical distribution over bins (output of forward).
- Return type:
Tensor
- Returns:
The scalar value.
- critic_network_forward(state, action, log_dict=None)[source]#
Compute a categorical distribution over bins from a state and action.
- Parameters:
state (Tensor) – The “state” sequence (the output of the TrajEncoder). Has shape (Batch, Length, state_dim).
action (Tensor) – The action sequence. Has shape (K, Batch, Length, Gammas, action_dim), where K is a dimension denoting multiple action samples from the same state (can be 1, but must exist). Discrete actions are expected to be one-hot vectors.
- Return type:
Categorical
- Returns:
The categorical distribution over bins with shape (K, Batch, Length, num_critics, output_bins).
- raw_vals_to_labels(raw_td_target)[source]#
Convert a scalar to a categorical distribution over bins.
Just a torch port of the dreamerv3/jaxutils.py implementation.
- Parameters:
raw_td_target (Tensor) – The scalar value.
- Return type:
Tensor
- Returns:
A two-hot encoded categorical distribution over bins.
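The idea behind the two-hot encoding can be sketched in plain Python (hypothetical helpers for a single scalar; the real implementation is batched and works in symlog space when use_symlog is enabled). Probability mass is split between the two bins neighboring the target so that the distribution's expectation equals the target:

```python
def two_hot(value: float, bins: list[float]) -> list[float]:
    # Clamp into the supported range, then split probability mass between
    # the two neighboring bins so the encoding's expectation equals `value`.
    value = min(max(value, bins[0]), bins[-1])
    for k in range(len(bins) - 1):
        lo, hi = bins[k], bins[k + 1]
        if lo <= value <= hi:
            w_hi = (value - lo) / (hi - lo)
            probs = [0.0] * len(bins)
            probs[k] = 1.0 - w_hi
            probs[k + 1] = w_hi
            return probs
    raise ValueError("bins must be sorted")

def decode(probs: list[float], bins: list[float]) -> float:
    # Expected bin value under the distribution recovers the scalar
    # (the role played by bin_dist_to_raw_vals above).
    return sum(p * b for p, b in zip(probs, bins))
```

Round-tripping any in-range scalar through `two_hot` and `decode` returns the original value, which is what makes the categorical parameterization a drop-in replacement for a scalar regression target.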
- class PopArtLayer(gammas, beta=0.0005, init_nu=100.0, enabled=True)[source]#
Bases:
Module
PopArt value normalization.
Tip
This class is @gin.configurable. Default values of kwargs can be overridden using gin.

Shifts value estimates according to a moving average, helping the outputs of the critic compensate for the resulting distribution shift (https://arxiv.org/abs/1809.04474).
- Parameters:
gammas (int) – Number of gamma values in the critic.
beta (float) – The beta parameter for the moving average. Defaults to 5e-4.
init_nu (float) – The initial nu parameter. Defaults to 100.0, following a recommendation in the PopArt paper.
enabled (no gin) – If False, this layer is a no-op. Defaults to True. Cannot be configured by gin; use Agent.use_popart instead.
- forward(x, normalized=True)[source]#
Modify the value estimate according to the PopArt layer.
Applies normalization or denormalization to value estimates using PopArt’s moving average statistics. When normalized=True, scales and shifts values using the current statistics to normalize them. When normalized=False, maps normalized values back to the original scale of the environment.
- Parameters:
x (Tensor) – Value estimate to modify
normalized (bool) – Whether to normalize (True) or denormalize (False) the values
- Return type:
Tensor
- Returns:
Modified value estimate in either normalized or denormalized form
- property sigma: Tensor#
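The statistics behind this layer can be sketched in plain Python (a simplified scalar version under the standard PopArt update rule; the actual layer tracks per-gamma tensors and the sketch omits the output-weight rescaling that keeps the critic's predictions consistent after each statistics update):

```python
class PopArtSketch:
    def __init__(self, beta: float = 5e-4, init_nu: float = 100.0):
        self.beta = beta
        self.mu = 0.0      # running mean of value targets
        self.nu = init_nu  # running second moment of value targets

    @property
    def sigma(self) -> float:
        # Standard deviation from the running first and second moments.
        return max(self.nu - self.mu ** 2, 1e-4) ** 0.5

    def update(self, target: float) -> None:
        # Exponential moving averages controlled by beta.
        self.mu = (1 - self.beta) * self.mu + self.beta * target
        self.nu = (1 - self.beta) * self.nu + self.beta * target ** 2

    def forward(self, x: float, normalized: bool = True) -> float:
        if normalized:
            return (x - self.mu) / self.sigma  # raw value -> normalized
        return x * self.sigma + self.mu        # normalized -> raw value
```

The two directions of `forward` are exact inverses given fixed statistics, matching the normalize/denormalize behavior documented above.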
- class ResidualActor(state_dim, action_dim, discrete, gammas, feature_dim=256, residual_ff_dim=512, residual_blocks=2, activation='leaky_relu', normalization='layer', dropout_p=0.0, continuous_dist_type=<class 'amago.nets.policy_dists.TanhGaussian'>)[source]#
Bases:
BaseActorHead
Actor output head with residual blocks.
Tip
This class is @gin.configurable. Default values of kwargs can be overridden using gin.

Based on BRO (https://arxiv.org/pdf/2405.16158v1), which recommends similar hparams to our existing defaults.
- Parameters:
state_dim (int) – Dimension of the “state” space (which is the output of the TrajEncoder)
action_dim (int) – Dimension of the action space
discrete (bool) – Whether the action space is discrete
gammas (Tensor) – List of gamma values to use for the multi-gamma actor
feature_dim (int) – Dimension of the embedding between residual blocks (analogous to d_model in a Transformer). Defaults to 256.
residual_ff_dim (int) – Hidden dimension of residual blocks (analogous to d_ff in a Transformer). Defaults to 512.
residual_blocks (int) – Number of residual blocks. Defaults to 2.
activation (str) – Activation function to use in the MLPs. Defaults to “leaky_relu”.
normalization (str) – Normalization to use in the residual blocks. Defaults to “layer” (LayerNorm).
dropout_p (float) – Dropout rate to use in the initial linear layers. Defaults to 0.0.
continuous_dist_type (Type[PolicyOutput]) – Type of continuous distribution to use if applicable. Must be a PolicyOutput. Defaults to TanhGaussian.