amago.agent

Actor-Critic agents and RL objectives.

Functions

binary_filter(adv[, threshold])

Weights policy regression data according to (adv > threshold).float()

exp_filter(adv[, beta, clip_adv_low, ...])

Weights policy regression data according to exp(beta * adv).

Classes

Agent(obs_space, rl2_space, action_space, ...)

Actor-Critic with a shared sequence model backbone.

MultiTaskAgent(obs_space, rl2_space, ...[, ...])

A variant of Agent aimed at learning from distinct reward functions.

Multigammas([discrete, continuous])

A hook for gin configuration of Multi-gamma values.

class Agent(obs_space, rl2_space, action_space, max_seq_len, tstep_encoder_type, traj_encoder_type, num_critics=4, num_critics_td=2, online_coeff=1.0, offline_coeff=0.1, gamma=0.999, reward_multiplier=10.0, tau=0.003, fake_filter=False, num_actions_for_value_in_critic_loss=1, num_actions_for_value_in_actor_loss=1, fbc_filter_func=<function binary_filter>, popart=True, use_target_actor=True, use_multigamma=True)

Bases: Module

Actor-Critic with a shared sequence model backbone.

Tip

This class is @gin.configurable. Default values of kwargs can be overridden using gin.

Agent manages the training and inference of a sequence model policy. The base learning update is a heavily parallelized/ensembled version of DDPG/TD3/REDQ/etc. + CRR/AWAC.

Given a sequence of trajectory data traj_seq, we embed and encode the sequence as follows:

emb_seq = timestep_encoder(traj_seq) # [B, L, dim]
state_emb_seq = traj_encoder(emb_seq)  # [B, L, dim]
action_dist = actor(state_emb_seq)

With a discrete action space, the critic outputs a vector of Q-values (one per action). With a continuous action space, the critic follows the (state + action) -> scalar setup.

if discrete:
    value_pred = critic(state_emb_seq)[action_dist.sample()]
else:
    value_pred = critic(state_emb_seq, action_dist.sample())

Value estimates are derived from Q-vals according to:

def Q(state, critic, action) -> float:
    if discrete:
        return critic(state)[action]
    else:
        return critic(state, action)

def V(state, critic, action_dist, k) -> float:
    if discrete:
        return (critic(state) * action_dist.probs).sum()
    else:
        return 1 / k * sum(Q(state, critic, action_dist.sample()) for _ in range(k))

k_c = num_actions_for_value_in_critic_loss
td_target = mean_or_min_over_ensemble(
    r + gamma * (1 - d) * V(next_state_emb, target_critic, target_actor(next_state_emb), k_c)
)
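The discrete and continuous value estimates and the TD target above can be traced with plain Python floats. This is an illustrative sketch, not AMAGO's implementation: the helper names (v_discrete, v_continuous, td_target) are hypothetical, each critic is reduced to one scalar per state, and mean_or_min_over_ensemble is read here as a REDQ-style minimum over a random subset of num_critics_td critics.

```python
import random
from typing import Callable, List, Sequence


def v_discrete(q_values: Sequence[float], probs: Sequence[float]) -> float:
    """Exact V(s) = sum_a pi(a|s) * Q(s, a) when the critic outputs one Q-value per action."""
    return sum(q * p for q, p in zip(q_values, probs))


def v_continuous(q_fn: Callable[[float], float],
                 sample_action: Callable[[], float], k: int) -> float:
    """Monte Carlo estimate of E_{a ~ pi}[Q(s, a)] from k sampled actions."""
    return sum(q_fn(sample_action()) for _ in range(k)) / k


def td_target(reward: float, done: bool, gamma: float,
              next_values: List[float], num_critics_td: int,
              rng: random.Random) -> float:
    """Clipped TD target: min over a random subset of the ensemble's next-state values."""
    subset = rng.sample(next_values, num_critics_td)  # REDQ-style subset (assumption)
    return reward + gamma * (1.0 - float(done)) * min(subset)


# Two actions with probabilities 0.25 / 0.75 and Q-values 1.0 / 3.0.
print(v_discrete([1.0, 3.0], [0.25, 0.75]))  # 2.5

# Hypothetical toy Gaussian policy and quadratic critic, averaged over k=8 samples.
rng = random.Random(0)
v_hat = v_continuous(lambda a: -(a - 1.0) ** 2, lambda: rng.gauss(0.0, 0.5), k=8)

# Four critics in the ensemble, two of them used for the clipped target.
target = td_target(1.0, False, 0.99, [2.0, 2.5, 3.0, 4.0], num_critics_td=2, rng=rng)
```

Replacing min(subset) with a mean over all next_values gives the "mean" variant of the ensemble reduction.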

The advantage estimate and corresponding losses are:

k_a = num_actions_for_value_in_actor_loss
advantages = Q(state_emb, critic, action) - V(state_emb, critic, action_dist, k_a)

offline_loss = -fbc_filter_func(advantages) * actor(state_emb).log_prob(action)
online_loss = -V(state_emb, critic.detach(), actor(state_emb), k_a)  # critic frozen: only the actor is updated by this term

actor_loss = offline_coeff * offline_loss + online_coeff * online_loss
critic_loss = (Q(state_emb, critic, action) - td_target) ** 2

All of this is computed in parallel across every timestep and across multiple values of the discount factor gamma.
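A per-timestep sketch of the advantage and the combined actor loss, with plain floats standing in for tensors and a single discount factor. actor_loss_terms and binary_weights are hypothetical helpers, and stop-gradient behavior (critic.detach()) cannot be expressed with floats, so the sketch only shows how offline_coeff and online_coeff weight the two terms.

```python
from typing import Callable, List, Sequence


def binary_weights(advantages: Sequence[float], threshold: float = 0.0) -> List[float]:
    """Plain-list analogue of (adv > threshold).float()."""
    return [1.0 if a > threshold else 0.0 for a in advantages]


def actor_loss_terms(q_taken: Sequence[float], v_pi: Sequence[float],
                     log_probs: Sequence[float], online_coeff: float,
                     offline_coeff: float,
                     filter_fn: Callable[[Sequence[float]], List[float]]) -> List[float]:
    """Per-timestep actor loss: offline_coeff * filtered BC + online_coeff * (-V)."""
    advantages = [q - v for q, v in zip(q_taken, v_pi)]
    weights = filter_fn(advantages)
    losses = []
    for w, log_p, v in zip(weights, log_probs, v_pi):
        offline = -w * log_p  # filtered behavior cloning on the dataset action
        online = -v           # DPG-style term; in training only the actor gets its gradient
        losses.append(offline_coeff * offline + online_coeff * online)
    return losses


def critic_loss(q_taken: Sequence[float], targets: Sequence[float]) -> List[float]:
    """Squared TD error per timestep."""
    return [(q - y) ** 2 for q, y in zip(q_taken, targets)]


# Agent defaults: online_coeff=1.0, offline_coeff=0.1.
losses = actor_loss_terms(q_taken=[2.0, 0.0], v_pi=[1.0, 1.0], log_probs=[-0.5, -2.0],
                          online_coeff=1.0, offline_coeff=0.1, filter_fn=binary_weights)
# timestep 0: advantage +1 -> weight 1 -> 0.1 * 0.5 - 1.0
# timestep 1: advantage -1 -> weight 0 -> 0.0 - 1.0
```

Only positive-advantage dataset actions contribute to the behavior-cloning term, while the online term is applied at every timestep.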

Parameters:
  • obs_space (Dict) – Environment observation space (for creating input layers).

  • rl2_space (Box) – A gymnasium space that is automatically generated by AMAGOEnv to represent the shape of extra input features for the previous action and reward.

  • action_space (Space) – Environment action space (for creating output layers).

  • max_seq_len (int) – Maximum context length of the policy (in timesteps).

  • tstep_encoder_type (Type[TstepEncoder]) – Type of TstepEncoder to use. Initialized based on provided gym spaces.

  • traj_encoder_type (Type[TrajEncoder]) – Type of TrajEncoder to use. Initialized based on provided gym spaces.

  • num_critics (int) – Number of critics in the ensemble. Defaults to 4.

  • num_critics_td (int) – Number of critics from the (larger) ensemble used to create clipped double-Q targets (as in REDQ). Defaults to 2.

  • online_coeff (float) – Weight of the “online” aka DPG/TD3-like actor loss -Q(s, a ~ pi(s)). Defaults to 1.0.

  • offline_coeff (float) – Weight of the “offline” aka advantage-weighted/“filtered” regression term (CRR/AWAC). Defaults to 0.1.

  • gamma (float) – Discount factor of the policy we sample during rollouts/evals. Defaults to 0.999.

  • reward_multiplier (float) – Scales every reward by a constant (in the loss computation only). This matters only for numerical stability in value normalization: when the reward function is known, choose a value that keeps the absolute value of returns away from very large (> 1e5) or very small (< 1) magnitudes. Defaults to 10.0.

  • tau (float) – Polyak averaging factor for target network updates (DDPG-like). Defaults to 0.003.

  • fake_filter (bool) – If True, skips computation of the advantage weights/“filter”. Speeds up pure behavior cloning. Defaults to False.

  • num_actions_for_value_in_critic_loss (int) – Number of actions used to estimate E_{a ~ π}[Q(s, a)] for continuous action spaces in the critic loss (TD targets). Defaults to 1.

  • num_actions_for_value_in_actor_loss (int) – Number of actions used to estimate E_{a ~ π}[Q(s, a)] for continuous action spaces in the actor loss. Defaults to 1.

  • fbc_filter_func (callable) – Function that takes a sequence of advantage estimates and outputs the regression weights. See binary_filter() or exp_filter(). Defaults to binary_filter().

  • popart (bool) – If True, use PopArtLayer normalization for value network outputs. Defaults to True.

  • use_target_actor (bool) – If True, use a target actor to sample actions used in TD targets. Defaults to True.

  • use_multigamma (bool) – If True, train on multiple discount horizons (Multigammas) in parallel. Defaults to True.

forward(batch, log_step)

Computes actor and critic losses from a Batch of trajectory data.

Parameters:
  • batch (Batch) – Batch object containing trajectory data including observations, actions, rewards, dones, etc.

  • log_step (bool) – If True, computes and stores additional statistics in self.update_info for wandb logging.

Returns:

A tuple containing

  • critic_loss: Tensor of shape (B, L-1, num_critics, G, 1), where B is the batch size, L is the sequence length, and G is the number of discount factors.

  • actor_loss: Tensor of shape (B, L-1, G, 1).

Return type:

tuple

get_actions(obs, rl2s, time_idxs, hidden_state=None, sample=True)

Get rollout actions from the current policy.

Note that the standard torch forward implements the training step, while get_actions implements the inference step. Most of the arguments here are easily gathered from the AMAGOEnv gymnasium wrapper. See amago.experiment.Experiment.interact for an example.

Parameters:
  • obs (dict[str, Tensor]) – Dictionary of (batched) observation tensors. AMAGOEnv makes all observations into dicts.

  • rl2s (Tensor) – Batched Tensor of previous action and reward. AMAGOEnv makes these.

  • time_idxs (Tensor) – Batched Tensor indicating the global timestep of the episode. Mainly used for position embeddings when the sequence length is much shorter than the episode length.

  • hidden_state – Hidden state of the TrajEncoder. Defaults to None.

  • sample (bool) – Whether to sample from the action distribution or take the argmax (discrete) or mean (continuous). Defaults to True.

Returns:

  • Batched Tensor of actions to take in each parallel env for the primary (“test-time”) discount factor Agent.gamma.

  • Updated hidden state of the TrajEncoder.

Return type:

tuple

hard_sync_targets()

Hard-copies the online actor/critics to the target actor/critics.

soft_sync_targets()

Moves the target actor/critics toward the online actor/critics with an exponential moving average (DDPG-style Polyak averaging controlled by tau).
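The two sync methods differ only in how far the targets move per call. A minimal sketch over a dict of named scalar parameters, assuming the standard DDPG convention target <- tau * online + (1 - tau) * target; the helper names are hypothetical, and the real methods operate on the actor and critic torch modules.

```python
from typing import Dict


def hard_sync(target: Dict[str, float], online: Dict[str, float]) -> None:
    """Copy every online parameter into the target (in place)."""
    target.update(online)


def soft_sync(target: Dict[str, float], online: Dict[str, float], tau: float) -> None:
    """Polyak / EMA step: target <- tau * online + (1 - tau) * target (in place)."""
    for name, value in online.items():
        target[name] = tau * value + (1.0 - tau) * target[name]


online = {"w": 1.0}
target = {"w": 0.0}
for _ in range(1000):
    soft_sync(target, online, tau=0.003)  # Agent's default tau
# Toward a fixed online value, the remaining gap shrinks by a factor (1 - tau) per call.
```

With tau = 0.003, the target effectively averages over roughly 1 / tau ≈ 333 recent updates, which is what keeps the TD targets slow-moving.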

property trainable_params

Iterable over all trainable parameters, which should be passed to the optimizer.

class MultiTaskAgent(obs_space, rl2_space, action_space, tstep_encoder_type, traj_encoder_type, max_seq_len, num_critics=4, num_critics_td=2, online_coeff=0.0, offline_coeff=1.0, gamma=0.999, reward_multiplier=10.0, tau=0.003, fake_filter=False, num_actions_for_value_in_critic_loss=1, num_actions_for_value_in_actor_loss=3, fbc_filter_func=<function binary_filter>, popart=True, use_target_actor=True, use_multigamma=True)

Bases: Agent

A variant of Agent aimed at learning from distinct reward functions.

Tip

This class is @gin.configurable. Default values of kwargs can be overridden using gin.

Strives to balance the training loss across tasks with different return scales without resorting to one-hot task IDs. Standard multi-task RL benchmarks (e.g., N Atari games) are good examples, but so are multi-domain meta-RL problems like Meta-World ML45. This is the agent discussed in the AMAGO-2 paper.

Follows the same learning update as Agent, with three main differences:

  1. Converts critic regression to classification of two-hot encoded labels representing bins spaced across a fixed range (see amago.nets.actor_critic.NCriticsTwoHot). The version here closely follows Dreamer-V3.

  2. Converts the discrete setup of Agent (where critics output a vector of values, one per action) to the same format as continuous actions: (state + action) -> scalar. This avoids large critic output layers but removes our ability to directly compute E_{a ~ π}[Q(s, a)].

  3. Defaults to an online_coeff of 0 and an offline_coeff of 1.0, because the “online” loss (-Q(s, a ~ pi)) scales with the magnitude of Q. The online loss is still available as long as the actor network's output uses the reparameterization trick. Discrete actions are supported via a Gumbel-softmax, but this has seen limited testing.

The combination of points 2 and 3 stresses accurate advantage estimates and motivates changing the default value of num_actions_for_value_in_actor_loss from 1 to 3. Other arguments follow amago.agent.Agent.
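Point 1 can be illustrated with a two-hot encoder and its decoder. This is a sketch, not NCriticsTwoHot: the bins here are evenly spaced over an assumed range, whereas Dreamer-V3 spaces its bins nonuniformly (symexp), and the helper names are hypothetical. A scalar target is split between its two nearest bins with weights that sum to 1, so the expected bin value recovers the scalar whenever it lies inside the range.

```python
from typing import List


def make_bins(low: float, high: float, num_bins: int) -> List[float]:
    """Evenly spaced bin centers covering [low, high]."""
    step = (high - low) / (num_bins - 1)
    return [low + i * step for i in range(num_bins)]


def two_hot(y: float, bins: List[float]) -> List[float]:
    """Split y between its two neighboring bins, proportionally to proximity."""
    y = min(max(y, bins[0]), bins[-1])  # values outside the range are clipped
    label = [0.0] * len(bins)
    for k in range(len(bins) - 1):
        lo, hi = bins[k], bins[k + 1]
        if lo <= y <= hi:
            w_hi = (y - lo) / (hi - lo)
            label[k] = 1.0 - w_hi
            label[k + 1] = w_hi
            break
    return label


def decode(label: List[float], bins: List[float]) -> float:
    """Expected value of the categorical distribution over bin centers."""
    return sum(p * b for p, b in zip(label, bins))


bins = make_bins(-10.0, 10.0, 5)  # [-10.0, -5.0, 0.0, 5.0, 10.0]
print(two_hot(2.5, bins))         # [0.0, 0.0, 0.5, 0.5, 0.0]
```

The critic is then trained with a cross-entropy loss against these labels instead of a squared error, which is why the loss scale no longer grows with the magnitude of the returns.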

forward(batch, log_step)

Computes actor and critic losses from a Batch of trajectory data.

Parameters:
  • batch (Batch) – Batch object containing trajectory data including observations, actions, rewards, dones, etc.

  • log_step (bool) – If True, computes and stores additional statistics in self.update_info for wandb logging.

Returns:

A tuple containing

  • critic_loss: Tensor of shape (B, L-1, num_critics, G, 1), where B is the batch size, L is the sequence length, and G is the number of discount factors.

  • actor_loss: Tensor of shape (B, L-1, G, 1).

Return type:

tuple

class Multigammas(discrete=[0.1, 0.9, 0.95, 0.97, 0.99, 0.995], continuous=[0.1, 0.9, 0.95, 0.97, 0.99, 0.995])

Bases: object

A hook for gin configuration of Multi-gamma values.

Tip

This class is @gin.configurable. Default values of kwargs can be overridden using gin.

Defines the list of gamma values used during training in addition to the main gamma parameter in Agent, which is the value used during rollouts/evals by default. Settings are split into discrete and continuous action space versions because adding gammas tends to be much more expensive for continuous-action critics, where each gamma multiplies the effective batch size of the actor/critic loss computation. Adding gammas has no effect on the batch size of the heavier sequence model backbone, so the relative cost of this trick decreases as the overall model size increases.

Parameters:
  • discrete (List[float]) – List of gamma values for discrete action spaces

  • continuous (List[float]) – List of gamma values for continuous action spaces
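A rough sketch of how the extra gammas relate to the main one and to training cost. It assumes the main Agent.gamma is appended to the Multigammas list to form the G discount factors in the loss shapes above (the ordering is an assumption), uses the common 1 / (1 - gamma) rule of thumb for effective horizon, and simplifies the cost to one critic input per timestep per gamma (ignoring ensemble size and sampled actions). All helper names and the batch numbers are hypothetical.

```python
from typing import List

# Defaults copied from the Multigammas signature above.
DEFAULT_DISCRETE = [0.1, 0.9, 0.95, 0.97, 0.99, 0.995]
DEFAULT_CONTINUOUS = [0.1, 0.9, 0.95, 0.97, 0.99, 0.995]


def training_gammas(main_gamma: float, extra: List[float],
                    use_multigamma: bool = True) -> List[float]:
    """Discount factors trained in parallel; the main gamma is always included."""
    gammas = list(extra) if use_multigamma else []
    return gammas + [main_gamma]


def effective_horizon(gamma: float) -> float:
    """Rule-of-thumb horizon (in timesteps) of a discount factor."""
    return 1.0 / (1.0 - gamma)


def critic_inputs(batch_size: int, seq_len: int, num_gammas: int) -> int:
    """Simplified continuous-critic workload: one (state, action) input per timestep per gamma.

    The sequence backbone runs once per timestep regardless of num_gammas.
    """
    return batch_size * seq_len * num_gammas


gammas = training_gammas(0.999, DEFAULT_CONTINUOUS)
for g in gammas:
    print(f"gamma={g}: ~{effective_horizon(g):.0f} steps")
print(critic_inputs(batch_size=24, seq_len=512, num_gammas=len(gammas)))
```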

binary_filter(adv, threshold=0.0)

Weights policy regression data according to (adv > threshold).float()

Parameters:
  • adv (Tensor) – Tensor of advantages (Batch, Length, Gammas, 1)

  • threshold (float) – Threshold for the binary filter; only advantages strictly greater than this value receive a weight of 1. Defaults to 0.0.

Return type:

Tensor

exp_filter(adv, beta=1.0, clip_adv_low=None, clip_adv_high=None, clip_weights_low=1e-07, clip_weights_high=None)

Weights policy regression data according to exp(beta * adv).

Parameters:
  • adv (Tensor) – Tensor of advantages (Batch, Length, Gammas, 1)

  • beta (float) – Scale applied to the advantages inside the exponential. Some papers define beta according to exp(1/beta * adv), so check whether you need to invert the value to match their setting. Defaults to 1.0.

  • clip_adv_low (float | None) – If provided, clip input advantages below this value. Defaults to None.

  • clip_adv_high (float | None) – If provided, clip input advantages above this value. Defaults to None.

  • clip_weights_low (float | None) – If provided, clip output weights below this value. Defaults to 1e-7.

  • clip_weights_high (float | None) – If provided, clip output weights above this value. Defaults to None.

Return type:

Tensor
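Both filters can be sketched elementwise in plain Python to show how the threshold and clipping bounds shape the weights. These are re-implementations written from the docstrings above, not the library code: they operate on lists instead of (Batch, Length, Gammas, 1) tensors, and the order of operations (clip advantages, exponentiate, then clip weights) is an assumption suggested by the parameter names.

```python
import math
from typing import List, Optional, Sequence


def _clip(x: float, low: Optional[float], high: Optional[float]) -> float:
    """Clamp x to [low, high], skipping any bound that is None."""
    if low is not None:
        x = max(x, low)
    if high is not None:
        x = min(x, high)
    return x


def binary_filter(adv: Sequence[float], threshold: float = 0.0) -> List[float]:
    """1.0 where adv > threshold (strictly), else 0.0."""
    return [1.0 if a > threshold else 0.0 for a in adv]


def exp_filter(adv: Sequence[float], beta: float = 1.0,
               clip_adv_low: Optional[float] = None,
               clip_adv_high: Optional[float] = None,
               clip_weights_low: Optional[float] = 1e-7,
               clip_weights_high: Optional[float] = None) -> List[float]:
    """exp(beta * adv) with optional clipping of the inputs and the outputs."""
    weights = []
    for a in adv:
        a = _clip(a, clip_adv_low, clip_adv_high)
        weights.append(_clip(math.exp(beta * a), clip_weights_low, clip_weights_high))
    return weights


adv = [-2.0, 0.0, 0.5, 3.0]
print(binary_filter(adv))                        # [0.0, 0.0, 1.0, 1.0]
print(exp_filter(adv, clip_weights_high=10.0))  # roughly [0.135, 1.0, 1.649, 10.0]
```

With the default threshold of 0.0, zero-advantage samples receive weight 0 under binary_filter. Under exp_filter, a few large advantages can dominate the regression because exp grows quickly, which is what clip_adv_high or clip_weights_high guard against.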