amago.nets.traj_encoders

Seq2Seq models for long-term memory.

Classes

FFTrajEncoder(tstep_dim, max_seq_len[, ...])

Feed-forward (memory-free) trajectory encoder.

GRUTrajEncoder(tstep_dim, max_seq_len[, ...])

RNN (GRU) Trajectory Encoder.

MambaTrajEncoder(tstep_dim, max_seq_len[, ...])

Mamba Trajectory Encoder.

TformerTrajEncoder(tstep_dim, max_seq_len[, ...])

Transformer Trajectory Encoder.

TrajEncoder(tstep_dim, max_seq_len)

Abstract base class for trajectory encoders.

class FFTrajEncoder(tstep_dim, max_seq_len, d_model=256, d_ff=None, n_layers=1, dropout=0.0, activation='leaky_relu', norm='layer')

Bases: TrajEncoder

Feed-forward (memory-free) trajectory encoder.

Tip

This class is @gin.configurable. Default values of kwargs can be overridden using gin.

A useful tool for applying AMAGO to standard MDPs and for benchmarking general RL details and hyperparameters on common benchmarks. The feed-forward architecture is designed to be close to an attention-less Transformer (residual blocks, norm, dropout, etc.). This makes it easy to create 1:1 ablations of “memory vs. no memory” by changing only the TrajEncoder, without touching max_seq_len, which would have the side effect of changing the effective batch size of actor-critic learning.

Parameters:
  • tstep_dim – Dimension of the input timestep representation (last dim of the input sequence). Defined by the output of the TstepEncoder.

  • max_seq_len – Maximum sequence length of the model. Any inputs will have been trimmed to this length before reaching the TrajEncoder.

  • d_model (int) – Dimension of the main residual stream and output. 1:1 with how this would be defined in a Transformer. Defaults to 256.

  • d_ff (int | None) – Hidden dim of the feed-forward network along each residual block. 1:1 with how this would be defined in a Transformer. Defaults to None, which resolves to 4 * d_model.

  • n_layers (int) – Number of residual feed-forward blocks. 1:1 with how this would be defined in a Transformer. Defaults to 1.

  • dropout (float) – Dropout rate. Equivalent to the dropout parameter of feed-forward blocks in a Transformer, but also applied to the first and last linear layers (inp -> d_model and d_model -> out). Defaults to 0.0.

  • activation (str) – Activation function. Defaults to “leaky_relu”.

  • norm (str) – Normalization function. Defaults to “layer” (LayerNorm).
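To make the “memory vs. no memory” ablation concrete, here is a minimal pure-Python sketch. It is not AMAGO code: the names encode_memory_free and encode_with_memory are hypothetical, and a running sum stands in for any recurrent or attention-based encoder. It illustrates that a memory-free encoder maps each timestep independently, so the sequence length can stay fixed while only the encoder is swapped.

```python
from typing import Callable, List

Vector = List[float]


def encode_memory_free(seq: List[Vector], step_fn: Callable[[Vector], Vector]) -> List[Vector]:
    """Apply the same function to every timestep independently (no memory)."""
    return [step_fn(x) for x in seq]


def encode_with_memory(seq: List[Vector]) -> List[Vector]:
    """Toy causal memory: each output is the running sum of all inputs so far."""
    total = [0.0] * len(seq[0])
    out = []
    for x in seq:
        total = [t + xi for t, xi in zip(total, x)]
        out.append(list(total))
    return out


def double(x: Vector) -> Vector:
    """Stand-in for a per-timestep feed-forward network."""
    return [2.0 * xi for xi in x]


seq_a = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
seq_b = [[5.0, 5.0], [0.0, 1.0], [1.0, 1.0]]  # differs only at the first timestep

ff_a = encode_memory_free(seq_a, double)
ff_b = encode_memory_free(seq_b, double)
mem_a = encode_with_memory(seq_a)
mem_b = encode_with_memory(seq_b)
```

With the memory-free encoder, changing the first timestep leaves every later output untouched; with the toy memory encoder, the change propagates forward. Both return one output per input timestep, which is what keeps the ablation 1:1.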

property emb_dim

Defines the expected output dim of this model.

Used to infer the input dim of actor/critics.

Returns:

The embedding dimension.

Return type:

int

forward(seq, time_idxs=None, hidden_state=None, log_dict=None)

Sequence model forward pass.

Parameters:
  • seq – [Batch, Num Timesteps, TstepDim]. TstepDim is defined by the output of the TstepEncoder.

  • time_idxs – [Batch, Num Timesteps, 1]. A sequence of ints tying the input seq to the number of steps that have passed since the start of the episode. Can be used to compute position embeddings or other temporal features.

  • hidden_state – Architecture-specific hidden state. Defaults to None.

Returns:

A tuple containing:
  • output_seq: [Batch, Timestep, self.emb_dim]. Output of our seq2seq model.

  • new_hidden_state: Architecture-specific hidden state. Expected to be None if input hidden_state is None. Otherwise, we assume we are at inference time and that this forward method has handled any updates to the hidden state that were needed.

Return type:

Tuple[torch.Tensor, Optional[Any]]

class GRUTrajEncoder(tstep_dim, max_seq_len, d_hidden=256, n_layers=2, d_output=256, norm='layer')

Bases: TrajEncoder

RNN (GRU) Trajectory Encoder.

Tip

This class is @gin.configurable. Default values of kwargs can be overridden using gin.

Parameters:
  • tstep_dim (int) – Dimension of the input timestep representation (last dim of the input sequence). Defined by the output of the TstepEncoder.

  • max_seq_len (int) – Maximum sequence length of the model. Any inputs will have been trimmed to this length before reaching the TrajEncoder.

  • d_hidden (int) – Dimension of the hidden state of the GRU. Defaults to 256.

  • n_layers (int) – Number of layers in the GRU. Defaults to 2.

  • d_output (int) – Dimension of the output linear layer after the GRU. Defaults to 256.

  • norm (str) – Normalization applied after the final linear layer. Defaults to “layer” (LayerNorm).

property emb_dim

Defines the expected output dim of this model.

Used to infer the input dim of actor/critics.

Returns:

The embedding dimension.

Return type:

int

forward(seq, time_idxs=None, hidden_state=None, log_dict=None)

Sequence model forward pass.

Parameters:
  • seq – [Batch, Num Timesteps, TstepDim]. TstepDim is defined by the output of the TstepEncoder.

  • time_idxs – [Batch, Num Timesteps, 1]. A sequence of ints tying the input seq to the number of steps that have passed since the start of the episode. Can be used to compute position embeddings or other temporal features.

  • hidden_state – Architecture-specific hidden state. Defaults to None.

Returns:

A tuple containing:
  • output_seq: [Batch, Timestep, self.emb_dim]. Output of our seq2seq model.

  • new_hidden_state: Architecture-specific hidden state. Expected to be None if input hidden_state is None. Otherwise, we assume we are at inference time and that this forward method has handled any updates to the hidden state that were needed.

Return type:

Tuple[torch.Tensor, Optional[Any]]

reset_hidden_state(hidden_state, dones)

Hook to implement architecture-specific hidden state reset.

Parameters:
  • hidden_state – We only expect to see hidden states that were created by self.init_hidden_state().

  • dones – A bool array of shape (num_parallel_envs,) where True indicates the agent loop has finished this episode and expects the hidden state for this batch index to be reset.

Returns:

Architecture-specific hidden state. Defaults to a no-op: new_hidden_state = hidden_state.

Return type:

Optional[Any]
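A minimal sketch of what a per-environment reset can look like for a recurrent model. It assumes the hidden state follows the (num_layers, batch, d_hidden) layout that torch.nn.GRU uses for its initial hidden state, and that a reset means zeroing. Both are assumptions: the text above only states that the hidden state at done indices is reset. Nested lists stand in for tensors, and reset_done_envs is a hypothetical name.

```python
from typing import List, Sequence

# Assumed layout: [n_layers][num_parallel_envs][d_hidden]
HiddenState = List[List[List[float]]]


def reset_done_envs(hidden_state: HiddenState, dones: Sequence[bool]) -> HiddenState:
    """Return a copy with the hidden vector zeroed for every finished environment."""
    return [
        [[0.0] * len(h) if done else list(h) for h, done in zip(layer, dones)]
        for layer in hidden_state
    ]


# Two layers, three parallel environments, d_hidden = 2.
hidden = [
    [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],
    [[1.1, 1.2], [1.3, 1.4], [1.5, 1.6]],
]
dones = [False, True, False]  # only the second environment finished its episode
new_hidden = reset_done_envs(hidden, dones)
```

The reset touches only the batch index flagged in dones, in every layer, and leaves the other environments' memory intact.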

class MambaTrajEncoder(tstep_dim, max_seq_len, d_model=256, d_state=16, d_conv=4, expand=2, n_layers=3, norm='layer')

Bases: TrajEncoder

Mamba Trajectory Encoder.

Tip

This class is @gin.configurable. Default values of kwargs can be overridden using gin.

Implementation of the Mamba architecture from “Mamba: Linear-Time Sequence Modeling with Selective State Spaces” (https://arxiv.org/abs/2312.00752).

Parameters:
  • tstep_dim (int) – Dimension of the input timestep representation (last dim of the input sequence). Defined by the output of the TstepEncoder.

  • max_seq_len (int) – Maximum sequence length of the model. The max context length of the model during training.

  • d_model (int) – Dimension of the main residual stream and output, analogous to the d_model in a Transformer. Defaults to 256.

  • d_state (int) – State dimension of the SSM in Mamba blocks. Defaults to 16.

  • d_conv (int) – Width (kernel size) of the local convolution in Mamba blocks. Defaults to 4.

  • expand (int) – Expansion factor of the SSM. Defaults to 2.

  • n_layers (int) – Number of Mamba blocks. Defaults to 3.

  • norm (str) – Normalization function. Defaults to “layer” (LayerNorm).
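The parameters above determine how much recurrent state Mamba carries per environment at inference time. The sketch below is back-of-the-envelope arithmetic that assumes the layout of the reference mamba_ssm implementation, where each block keeps a convolution state of shape (d_inner, d_conv) and an SSM state of shape (d_inner, d_state), with d_inner = expand * d_model. Whether _MambaHiddenState stores exactly these tensors is an assumption, and mamba_state_size is a hypothetical helper.

```python
def mamba_state_size(d_model=256, d_state=16, d_conv=4, expand=2, n_layers=3):
    """Number of floats of inference state per environment (assumed layout)."""
    d_inner = expand * d_model          # width of the expanded inner stream
    conv_state = d_inner * d_conv       # rolling buffer for the local convolution
    ssm_state = d_inner * d_state       # selective state-space memory
    return n_layers * (conv_state + ssm_state)


# Defaults from the signature above.
floats_per_env = mamba_state_size()
```

Under these assumptions max_seq_len does not enter the calculation: the state size stays constant no matter how many timesteps have been processed.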

property emb_dim

Defines the expected output dim of this model.

Used to infer the input dim of actor/critics.

Returns:

The embedding dimension.

Return type:

int

forward(seq, time_idxs=None, hidden_state=None, log_dict=None)

Sequence model forward pass.

Parameters:
  • seq (Tensor) – [Batch, Num Timesteps, TstepDim]. TstepDim is defined by the output of the TstepEncoder.

  • time_idxs (Tensor | None) – [Batch, Num Timesteps, 1]. A sequence of ints tying the input seq to the number of steps that have passed since the start of the episode. Can be used to compute position embeddings or other temporal features.

  • hidden_state (_MambaHiddenState | None) – Architecture-specific hidden state. Defaults to None.

Returns:

A tuple containing:
  • output_seq: [Batch, Timestep, self.emb_dim]. Output of our seq2seq model.

  • new_hidden_state: Architecture-specific hidden state. Expected to be None if input hidden_state is None. Otherwise, we assume we are at inference time and that this forward method has handled any updates to the hidden state that were needed.

Return type:

Tuple[torch.Tensor, Optional[Any]]

init_hidden_state(batch_size, device)

Hook to create an architecture-specific hidden state.

Return value is passed as TrajEncoder.forward(..., hidden_state=self.init_hidden_state(...)) when the agent begins to interact with the environment.

Parameters:
  • batch_size (int) – Number of parallel environments.

  • device (device) – Device to store hidden state tensors (if applicable).

Returns:

Some hidden state object, or None if not applicable. Defaults to None.

Return type:

Optional[Any]

reset_hidden_state(hidden_state, dones)

Hook to implement architecture-specific hidden state reset.

Parameters:
  • hidden_state (_MambaHiddenState | None) – We only expect to see hidden states that were created by self.init_hidden_state().

  • dones (ndarray) – A bool array of shape (num_parallel_envs,) where True indicates the agent loop has finished this episode and expects the hidden state for this batch index to be reset.

Returns:

Architecture-specific hidden state. Defaults to a no-op: new_hidden_state = hidden_state.

Return type:

Optional[Any]

class TformerTrajEncoder(tstep_dim, max_seq_len, d_model=256, n_heads=8, d_ff=1024, n_layers=3, dropout_ff=0.05, dropout_emb=0.05, dropout_attn=0.0, dropout_qkv=0.0, activation='leaky_relu', norm='layer', pos_emb='fixed', sigma_reparam=True, normformer_norms=True, head_scaling=True, attention_type=<class 'amago.nets.transformer.FlashAttention'>)

Bases: TrajEncoder

Transformer Trajectory Encoder.

Tip

This class is @gin.configurable. Default values of kwargs can be overridden using gin.

A pre-norm Transformer decoder-only model that processes sequences of timestep embeddings.

Parameters:
  • tstep_dim (int) – Dimension of the input timestep representation (last dim of the input sequence). Defined by the output of the TstepEncoder.

  • max_seq_len (int) – Maximum sequence length of the model. The max context length of the model during training.

  • d_model (int) – Dimension of the main residual stream and output. Defaults to 256.

  • n_heads (int) – Number of self-attention heads. Each head has dimension d_model/n_heads. Defaults to 8.

  • d_ff (int) – Dimension of feed-forward network in residual blocks. Defaults to 1024 (4 * d_model at the default d_model of 256).

  • n_layers (int) – Number of Transformer layers. Defaults to 3.

  • dropout_ff (float) – Dropout rate for linear layers within Transformer. Defaults to 0.05.

  • dropout_emb (float) – Dropout rate for input embedding (combined input sequence and position embeddings passed to Transformer). Defaults to 0.05.

  • dropout_attn (float) – Dropout rate for attention matrix. Defaults to 0.0.

  • dropout_qkv (float) – Dropout rate for query/key/value projections. Defaults to 0.0.

  • activation (str) – Activation function. Defaults to “leaky_relu”.

  • norm (str) – Normalization function. Defaults to “layer” (LayerNorm).

  • pos_emb (str) – Position embedding type. “fixed” (default) uses sinusoidal embeddings, “learned” uses trainable embeddings per timestep.

  • causal – Whether to use causal attention mask. Defaults to True.

  • sigma_reparam (bool) – Whether to use \(\sigma\)-reparam feed-forward layers from https://arxiv.org/abs/2303.06296. Defaults to True.

  • normformer_norms (bool) – Whether to use extra norm layers from NormFormer (https://arxiv.org/abs/2110.09456). The model is always a pre-norm Transformer regardless of this setting. Defaults to True.

  • head_scaling (bool) – Whether to use head scaling from NormFormer. Defaults to True.

  • attention_type (type[SelfAttention]) – Attention layer type. Defaults to transformer.FlashAttention. transformer.VanillaAttention provided as backup. New types can inherit from transformer.SelfAttention.
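For reference, \(\sigma\)-reparam as described in the cited paper rescales a weight matrix W to (γ / σ(W)) · W, where σ(W) is the spectral norm of W and γ is a learnable scalar. The following pure-Python sketch applies that formula using power iteration on nested lists. It is illustrative only: spectral_norm and sigma_reparam are hypothetical helpers, γ is a plain float rather than a trained parameter, and the layer used by TformerTrajEncoder may differ in detail.

```python
import math
from typing import List

Matrix = List[List[float]]


def _matvec(w: Matrix, v: List[float]) -> List[float]:
    return [sum(wij * vj for wij, vj in zip(row, v)) for row in w]


def _norm(v: List[float]) -> float:
    return math.sqrt(sum(x * x for x in v))


def spectral_norm(w: Matrix, n_iters: int = 50) -> float:
    """Estimate the largest singular value of w by power iteration."""
    w_t = [list(col) for col in zip(*w)]
    v = [1.0] * len(w[0])
    for _ in range(n_iters):
        u = _matvec(w, v)
        u_norm = _norm(u)
        u = [x / u_norm for x in u]
        v = _matvec(w_t, u)
        v_norm = _norm(v)
        v = [x / v_norm for x in v]
    return _norm(_matvec(w, v))


def sigma_reparam(w: Matrix, gamma: float = 1.0) -> Matrix:
    """Rescale w so that its spectral norm equals gamma."""
    sigma = spectral_norm(w)
    return [[gamma * x / sigma for x in row] for row in w]


w = [[3.0, 0.0], [0.0, 1.0]]        # singular values are 3 and 1
w_hat = sigma_reparam(w, gamma=2.0)  # spectral norm becomes 2
```

After the rescaling, the largest singular value of the effective weight is controlled by the single scalar γ instead of by the raw entries of W.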

property emb_dim: int

Defines the expected output dim of this model.

Used to infer the input dim of actor/critics.

Returns:

The embedding dimension.

Return type:

int

forward(seq, time_idxs, hidden_state=None, log_dict=None)

Sequence model forward pass.

Parameters:
  • seq (Tensor) – [Batch, Num Timesteps, TstepDim]. TstepDim is defined by the output of the TstepEncoder.

  • time_idxs (Tensor) – [Batch, Num Timesteps, 1]. A sequence of ints tying the input seq to the number of steps that have passed since the start of the episode. Can be used to compute position embeddings or other temporal features.

  • hidden_state (TformerHiddenState | None) – Architecture-specific hidden state. Defaults to None.

Returns:

A tuple containing:
  • output_seq: [Batch, Timestep, self.emb_dim]. Output of our seq2seq model.

  • new_hidden_state: Architecture-specific hidden state. Expected to be None if input hidden_state is None. Otherwise, we assume we are at inference time and that this forward method has handled any updates to the hidden state that were needed.

Return type:

Tuple[torch.Tensor, Optional[Any]]
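One way time_idxs can feed a position embedding is the standard sinusoidal scheme from “Attention Is All You Need”. The sketch below assumes that formula; the exact variant behind pos_emb="fixed" is not specified here, and sinusoidal_embedding is a hypothetical name. Because the embedding is indexed by steps since the start of the episode rather than by position inside the context window, a timestep keeps the same embedding as the window slides.

```python
import math
from typing import List


def sinusoidal_embedding(time_idx: int, d_model: int) -> List[float]:
    """Interleaved sin/cos features for a single episode timestep."""
    emb = []
    for i in range(0, d_model, 2):
        freq = 1.0 / (10000.0 ** (i / d_model))
        emb.append(math.sin(time_idx * freq))
        emb.append(math.cos(time_idx * freq))
    return emb[:d_model]


# Two overlapping context windows taken from the same episode.
window_early = [0, 1, 2, 3]
window_late = [2, 3, 4, 5]
emb_early = [sinusoidal_embedding(t, d_model=8) for t in window_early]
emb_late = [sinusoidal_embedding(t, d_model=8) for t in window_late]
```

Episode step 2 sits at index 2 of the early window and index 0 of the late window, yet receives the same embedding in both.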

init_hidden_state(batch_size, device)

Hook to create an architecture-specific hidden state.

Return value is passed as TrajEncoder.forward(..., hidden_state=self.init_hidden_state(...)) when the agent begins to interact with the environment.

Parameters:
  • batch_size (int) – Number of parallel environments.

  • device (device) – Device to store hidden state tensors (if applicable).

Returns:

Some hidden state object, or None if not applicable. Defaults to None.

Return type:

Optional[Any]

reset_hidden_state(hidden_state, dones)

Hook to implement architecture-specific hidden state reset.

Parameters:
  • hidden_state (TformerHiddenState | None) – We only expect to see hidden states that were created by self.init_hidden_state().

  • dones (ndarray) – A bool array of shape (num_parallel_envs,) where True indicates the agent loop has finished this episode and expects the hidden state for this batch index to be reset.

Returns:

Architecture-specific hidden state. Defaults to a no-op: new_hidden_state = hidden_state.

Return type:

Optional[Any]

class TrajEncoder(tstep_dim, max_seq_len)

Bases: Module, ABC

Abstract base class for trajectory encoders.

An agent’s “TrajEncoder” is the sequence model in charge of mapping the output of the “TstepEncoder” for each timestep of the trajectory to the latent dimension where actor-critic learning takes place. Because the actor and critic are feed-forward networks, this is the place to add long-term memory over previous timesteps.

Note

It would not make sense for the sequence model defined here to be bi-directional or non-causal.

Parameters:
  • tstep_dim (int) – Dimension of the input timestep representation (last dim of the input sequence). Defined by the output of the TstepEncoder.

  • max_seq_len (int) – Maximum sequence length of the model. Any inputs will have been trimmed to this length before reaching the TrajEncoder.

abstract property emb_dim: int

Defines the expected output dim of this model.

Used to infer the input dim of actor/critics.

Returns:

The embedding dimension.

Return type:

int

abstract forward(seq, time_idxs, hidden_state=None, log_dict=None)

Sequence model forward pass.

Parameters:
  • seq (Tensor) – [Batch, Num Timesteps, TstepDim]. TstepDim is defined by the output of the TstepEncoder.

  • time_idxs (Tensor) – [Batch, Num Timesteps, 1]. A sequence of ints tying the input seq to the number of steps that have passed since the start of the episode. Can be used to compute position embeddings or other temporal features.

  • hidden_state (Any | None) – Architecture-specific hidden state. Defaults to None.

Returns:

A tuple containing:
  • output_seq: [Batch, Timestep, self.emb_dim]. Output of our seq2seq model.

  • new_hidden_state: Architecture-specific hidden state. Expected to be None if input hidden_state is None. Otherwise, we assume we are at inference time and that this forward method has handled any updates to the hidden state that were needed.

Return type:

Tuple[torch.Tensor, Optional[Any]]

init_hidden_state(batch_size, device)

Hook to create an architecture-specific hidden state.

Return value is passed as TrajEncoder.forward(..., hidden_state=self.init_hidden_state(...)) when the agent begins to interact with the environment.

Parameters:
  • batch_size (int) – Number of parallel environments.

  • device (device) – Device to store hidden state tensors (if applicable).

Returns:

Some hidden state object, or None if not applicable. Defaults to None.

Return type:

Optional[Any]
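The contract between init_hidden_state and forward can be sketched with a toy encoder. This is a plain-Python stand-in, not a subclass of the real TrajEncoder (which is a torch Module): RunningMeanEncoder is a hypothetical class whose "memory" is a running mean over the timesteps seen so far, and nested lists stand in for tensors. It shows the two calling modes described above: a training-style call with hidden_state=None that returns None, and an inference-style rollout that threads the hidden state through one timestep at a time.

```python
from typing import Any, List, Optional, Tuple

Seq = List[List[List[float]]]  # [Batch][Num Timesteps][TstepDim]


class RunningMeanEncoder:
    """Toy stand-in mirroring the documented hooks (not an AMAGO class)."""

    def __init__(self, tstep_dim: int, max_seq_len: int):
        self.tstep_dim = tstep_dim
        self.max_seq_len = max_seq_len

    @property
    def emb_dim(self) -> int:
        return self.tstep_dim

    def init_hidden_state(self, batch_size: int, device: Any = None) -> list:
        # One (running sum, step count) pair per parallel environment.
        return [([0.0] * self.tstep_dim, 0) for _ in range(batch_size)]

    def forward(self, seq: Seq, time_idxs=None, hidden_state=None,
                log_dict=None) -> Tuple[Seq, Optional[list]]:
        state = hidden_state if hidden_state is not None else self.init_hidden_state(len(seq))
        out_seq, new_state = [], []
        for traj, (total, count) in zip(seq, state):
            outs = []
            for x in traj:
                total = [s + xi for s, xi in zip(total, x)]
                count += 1
                outs.append([s / count for s in total])  # causal: past and present only
            out_seq.append(outs)
            new_state.append((total, count))
        # Training: no hidden state in, none out. Inference: return the update.
        return out_seq, (new_state if hidden_state is not None else None)


enc = RunningMeanEncoder(tstep_dim=2, max_seq_len=3)
seq = [[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]]  # Batch=1, Timesteps=3

# Training-style call: whole sequence at once, hidden state stays None.
train_out, train_state = enc.forward(seq)

# Inference-style rollout: one timestep per call, threading the hidden state.
hidden = enc.init_hidden_state(batch_size=1)
step_outs = []
for t in range(3):
    out, hidden = enc.forward([[seq[0][t]]], hidden_state=hidden)
    step_outs.append(out[0][0])
```

Because the toy model is causal, the step-by-step rollout reproduces the outputs of the full-sequence call, which is the property that lets the same encoder be trained on sequences and deployed one timestep at a time.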

reset_hidden_state(hidden_state, dones)

Hook to implement architecture-specific hidden state reset.

Parameters:
  • hidden_state (Any | None) – We only expect to see hidden states that were created by self.init_hidden_state().

  • dones (ndarray) – A bool array of shape (num_parallel_envs,) where True indicates the agent loop has finished this episode and expects the hidden state for this batch index to be reset.

Returns:

Architecture-specific hidden state. Defaults to a no-op: new_hidden_state = hidden_state.

Return type:

Optional[Any]