amago.nets.transformer#

Custom Transformer components.

Classes

AttentionLayer(self_attention, d_model, ...)

Query, Key, Value --> Self-Attention --> Output Projection

Cache(device, dtype, layers, batch_size, ...)

A cache for key and value tensors.

FixedPosEmb(d_model)

Classic sinusoidal positional encoding.

FlashAttention(causal, dropout[, window_size])

Optimized self-attention using flash_attn.

FlexAttention(score_mod, mask_mod, causal, ...)

Experimental support for flex_attention (a recent pytorch feature).

LearnablePosEmb(d_model[, max_time_idx])

Learnable positional encoding.

SelfAttention([causal, dropout])

A base class for self-attention layers.

SigmaReparam(d_in, d_out[, bias, fast_init])

SigmaReparam nn.Linear alternative.

SlidingWindowFlexAttention(causal, dropout)

A more useful test of FlexAttention that implements a sliding window pattern for long context lengths.

TformerHiddenState(key_cache, val_cache, ...)

Helps manage the Cache hidden state during rollouts.

Transformer(inp_dim, d_model, layers[, ...])

Build a full Transformer model from a list of layers.

TransformerLayer(attention_layer, d_model, d_ff)

Pre-Norm Self-Attention Layer

VanillaAttention(causal, dropout)

Unoptimized self-attention in regular pytorch.

VanillaFlexAttention(causal, dropout)

A sanity-check test of FlexAttention that should be equivalent to VanillaAttention.

class AttentionLayer(self_attention, d_model, d_qkv, n_heads, dropout_qkv=0.0, head_scaling=True, sigma_reparam=True)[source]#

Bases: Module

Query, Key, Value --> Self-Attention --> Output Projection

forward(sequence, key_cache=None, val_cache=None, cache_seqlens=None)[source]#

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
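
A minimal construction sketch (dimensions are illustrative, and whether self_attention expects an already-constructed module is an assumption based on the signature above; VanillaAttention is documented later on this page):

    import torch
    from amago.nets.transformer import AttentionLayer, VanillaAttention

    # wrap an unoptimized self-attention module in QKV + output projections
    attn = AttentionLayer(
        self_attention=VanillaAttention(causal=True, dropout=0.0),
        d_model=256,  # model dimension of the input sequence
        d_qkv=32,     # per-head dimension of queries/keys/values
        n_heads=8,
    )
    seq = torch.randn(4, 128, 256)  # (batch, length, d_model)
    out = attn(seq)                 # full-sequence training path; k/v caches omitted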

class Cache(device, dtype, layers, batch_size, max_seq_len, n_heads, head_dim)[source]#

Bases: object

A cache for key and value tensors.

roll_back(seq_lens)[source]#
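
A construction sketch using the signature above (values are illustrative):

    import torch
    from amago.nets.transformer import Cache

    cache = Cache(
        device=torch.device("cpu"),
        dtype=torch.float32,
        layers=3,         # one key/value buffer per transformer layer
        batch_size=16,    # number of parallel rollouts
        max_seq_len=512,  # maximum context length stored per rollout
        n_heads=8,
        head_dim=32,
    )
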
class FixedPosEmb(d_model)[source]#

Bases: Module

Classic sinusoidal positional encoding.

forward(pos_idxs)[source]#

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
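
For reference, the classic encoding from "Attention Is All You Need" maps position pos and channel pair index i to:

    PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
    PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))

forward(pos_idxs) evaluates this encoding at the given position indices. (The base constant 10000 is the standard choice; treat it as an assumption about this implementation.)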

class FlashAttention(causal, dropout, window_size=(-1, -1))[source]#

Bases: SelfAttention

Optimized self-attention using flash_attn.

Tip

This class is @gin.configurable. Default values of kwargs can be overridden using gin.

Parameters:
  • causal (bool) – Whether to use a causal mask.

  • dropout (float) – The dropout rate of the attention matrix.

  • window_size (tuple[int, int]) – flash_attn’s window_size parameter, which enables sliding window attention. Defaults to (-1, -1), which is standard full-length attention.

forward(qkv, key_cache=None, val_cache=None, cache_seqlens=None)[source]#

Map queries, keys, and values to attention output.

Should implement the full training pass when key_cache/val_cache/cache_seqlens are None, and (cached) inference when they are provided.

Parameters:
  • qkv – A tensor of shape (batch_size, sequence_length, 3, num_heads, head_dim). Packed queries, keys, and values.

  • key_cache – A tensor of shape (batch_size, max_sequence_length, num_heads, head_dim).

  • val_cache – A tensor of shape (batch_size, max_sequence_length, num_heads, head_dim).

  • cache_seqlens – A tensor of shape (batch_size,) that defines the current index of the k/v cache.

Returns:

A tensor of shape (batch_size, sequence_length, num_heads, head_dim).
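
A minimal training-path sketch following the shapes documented above (flash_attn generally requires a CUDA device and fp16/bf16 tensors):

    import torch
    from amago.nets.transformer import FlashAttention

    attn = FlashAttention(causal=True, dropout=0.0)  # default full-length attention
    B, L, H, D = 4, 256, 8, 64
    qkv = torch.randn(B, L, 3, H, D, device="cuda", dtype=torch.bfloat16)
    out = attn(qkv)  # no caches passed -> full training pass
    assert out.shape == (B, L, H, D)

Passing window_size=(128, 0) instead should restrict each query to the previous 128 keys, per flash_attn's sliding-window convention.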

class FlexAttention(score_mod, mask_mod, causal, dropout)[source]#

Bases: SelfAttention

Experimental support for flex_attention (a recent pytorch feature).

Allows custom sparse attention patterns using score_mod and mask_mod functions (https://pytorch.org/blog/flexattention/, pytorch-labs/attention-gym).

The main benefit of flex_attention for our purposes is a unified implementation of key/value cache inference for more complex attention patterns.

Parameters:
  • score_mod (callable) – A function that takes the batch_idx, head_idx, q_idx, kv_idx and computes a scalar score for the attention matrix entry between these locations.

  • mask_mod (callable) – A function that takes the batch_idx, head_idx, q_idx, kv_idx and returns False if attention scores between these locations should be masked.

  • causal (bool) – Whether to use a causal mask. If True, the causal mask is applied on top of the custom mask_mod. Defaults to True.

  • dropout (float) – The dropout rate of the attention matrix. Defaults to 0.0.
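
A sketch of the two callables, following the parameter descriptions above. Argument names are illustrative, and note that torch's native flex_attention passes the running score as an extra first argument to score_mod, so check the wrapper's source for the exact convention it forwards:

    from amago.nets.transformer import FlexAttention

    def local_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
        # keep attention within 64 positions of the query; mask everything else
        return (q_idx - kv_idx) <= 64

    def distance_penalty_score_mod(batch_idx, head_idx, q_idx, kv_idx):
        # distance-based penalty in the spirit of ALiBi (illustrative only)
        return -0.05 * (q_idx - kv_idx).abs()

    attn = FlexAttention(
        score_mod=distance_penalty_score_mod,
        mask_mod=local_mask_mod,
        causal=True,  # causal mask is applied on top of local_mask_mod
        dropout=0.0,
    )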

cached_training_mask(q_len, kv_len)[source]#
flex_attention(q, k, v, score_mod, block_mask)[source]#
flex_attention_inf(q, k, v, score_mod, block_mask)[source]#
forward(qkv, key_cache=None, val_cache=None, cache_seqlens=None)[source]#

Map queries, keys, and values to attention output.

Should implement the full training pass when key_cache/val_cache/cache_seqlens are None, and (cached) inference when they are provided.

Parameters:
  • qkv – A tensor of shape (batch_size, sequence_length, 3, num_heads, head_dim). Packed queries, keys, and values.

  • key_cache – A tensor of shape (batch_size, max_sequence_length, num_heads, head_dim).

  • val_cache – A tensor of shape (batch_size, max_sequence_length, num_heads, head_dim).

  • cache_seqlens – A tensor of shape (batch_size,) that defines the current index of the k/v cache.

Returns:

A tensor of shape (batch_size, sequence_length, num_heads, head_dim).

kv_cache_mask_mod(cache_seqlens)[source]#
kv_cache_score_mod(cache_seqlens)[source]#
class LearnablePosEmb(d_model, max_time_idx=gin.REQUIRED)[source]#

Bases: Module

Learnable positional encoding.

Tip

This class is @gin.configurable. Default values of kwargs can be overridden using gin.

Creates a lookup table of d_model-sized embeddings, one for every timestep of the episode.

Parameters:
  • d_model (int) – The dimension of the embeddings.

  • max_time_idx (int) – The maximum timestep we’ll need to learn an embedding for. This is so application-specific that it is gin.REQUIRED and must therefore be configured manually in the training script or its .gin files.
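
A minimal way to supply the required value from a training script, assuming the default gin configurable name (the equivalent line in a .gin file works the same way):

    import gin

    gin.parse_config(["LearnablePosEmb.max_time_idx = 1000"])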

forward(pos_idxs)[source]#

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.

class SelfAttention(causal=True, dropout=0.0)[source]#

Bases: Module, ABC

A base class for self-attention layers.

Parameters:
  • causal (bool) – Whether to use a causal mask.

  • dropout (float) – The dropout rate of the attention matrix.

abstract forward(qkv, key_cache=None, val_cache=None, cache_seqlens=None)[source]#

Map queries, keys, and values to attention output.

Should implement the full training pass when key_cache/val_cache/cache_seqlens are None, and (cached) inference when they are provided.

Parameters:
  • qkv – A tensor of shape (batch_size, sequence_length, 3, num_heads, head_dim). Packed queries, keys, and values.

  • key_cache – A tensor of shape (batch_size, max_sequence_length, num_heads, head_dim).

  • val_cache – A tensor of shape (batch_size, max_sequence_length, num_heads, head_dim).

  • cache_seqlens – A tensor of shape (batch_size,) that defines the current index of the k/v cache.

Returns:

A tensor of shape (batch_size, sequence_length, num_heads, head_dim).
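
A sketch of a custom subclass that implements only the full training pass (cached inference is omitted; concrete subclasses such as VanillaAttention and FlashAttention also handle the key/value cache arguments). Attribute names inside the sketch are local to it rather than assumptions about the base class:

    import torch.nn.functional as F
    from amago.nets.transformer import SelfAttention

    class SDPAttention(SelfAttention):
        """Training-only attention via torch's scaled_dot_product_attention."""

        def __init__(self, causal=True, dropout=0.0):
            super().__init__(causal=causal, dropout=dropout)
            self._causal = causal
            self._dropout_p = dropout

        def forward(self, qkv, key_cache=None, val_cache=None, cache_seqlens=None):
            assert key_cache is None and val_cache is None, "cached inference omitted"
            q, k, v = qkv.unbind(dim=2)                       # each: (B, L, H, D)
            q, k, v = (x.transpose(1, 2) for x in (q, k, v))  # -> (B, H, L, D)
            out = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p=self._dropout_p if self.training else 0.0,
                is_causal=self._causal,
            )
            return out.transpose(1, 2)                        # back to (B, L, H, D)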

class SigmaReparam(d_in, d_out, bias=True, fast_init=False)[source]#

Bases: Linear

SigmaReparam nn.Linear alternative.

Tip

This class is @gin.configurable. Default values of kwargs can be overridden using gin.

Reference implementations: apple/ml-sigma-reparam, ywchan2005/sigma-reparam-pytorch.

SigmaReparam is an alternative to nn.Linear that can be used in Transformer blocks to stabilize attention scores. (https://arxiv.org/abs/2303.06296)

Parameters:
  • d_in – The input dimension of the layer.

  • d_out – The output dimension of the layer.

  • bias (bool) – Whether to use a bias in the layer. Defaults to True.

  • fast_init (bool) – Skip an SVD initialization step and use a simpler strategy. Mainly used for backward compatibility with old results, and as a hacky way to speed up initialization of large models when a pretrained checkpoint will be loaded soon anyway. Defaults to False.

forward(x)[source]#

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
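
At a high level, σ-reparam replaces the weight W of a linear layer with (γ / σ(W)) · W, where σ(W) is the spectral norm of W (tracked with a power-iteration buffer) and γ is a learnable scalar. A heavily simplified sketch of that computation, not this class's exact implementation:

    import torch
    import torch.nn.functional as F

    def sigma_reparam_forward(x, weight, bias, gamma, u, v):
        # one power-iteration step to track the top singular value of `weight`
        with torch.no_grad():
            v.copy_(F.normalize(weight.t() @ u, dim=0))
            u.copy_(F.normalize(weight @ v, dim=0))
        sigma = torch.einsum("d,de,e->", u, weight, v)  # approximate spectral norm
        w_hat = (gamma / sigma) * weight
        return F.linear(x, w_hat, bias)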

class SlidingWindowFlexAttention(causal, dropout, window_size=gin.REQUIRED)[source]#

Bases: FlexAttention

A more useful test of FlexAttention that implements a sliding window pattern for long context lengths.

Tip

This class is @gin.configurable. Default values of kwargs can be overridden using gin.

class TformerHiddenState(key_cache, val_cache, seq_lens)[source]#

Bases: object

Helps manage the Cache hidden state during rollouts.

reset(idxs)[source]#
update()[source]#
class Transformer(inp_dim, d_model, layers, dropout_emb=0.05, norm='layer', pos_emb='fixed')[source]#

Bases: Module

Build a full Transformer model from a list of layers.

property emb_dim#
forward(seq, pos_idxs, hidden_state=None)[source]#

Transformer seq2seq

Parameters:
  • seq – The input sequence of shape (batch_size, seq_len, inp_dim).

  • pos_idxs – The position indices of the input sequence of shape (batch_size, seq_len).

  • hidden_state (TformerHiddenState | None) – The hidden state of the transformer.

Returns:

The output sequence of shape (batch_size, seq_len, d_model), and the new hidden state of the transformer.

inference_forward(seq, hidden_state)[source]#
preprocess_seq(seq, pos_idxs)[source]#
training_forward(seq)[source]#
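
A minimal end-to-end sketch wiring these pieces together. Dimensions are illustrative, and two points are assumptions to verify against the source: that layers and the *_layer arguments accept already-constructed modules (rather than partials), and that forward returns an (output, hidden state) pair as the Returns section suggests:

    import torch
    from amago.nets.transformer import (
        AttentionLayer, Transformer, TransformerLayer, VanillaAttention,
    )

    d_model, n_layers = 128, 2
    layers = [
        TransformerLayer(
            attention_layer=AttentionLayer(
                self_attention=VanillaAttention(causal=True, dropout=0.0),
                d_model=d_model, d_qkv=32, n_heads=4,
            ),
            d_model=d_model,
            d_ff=4 * d_model,
        )
        for _ in range(n_layers)
    ]
    model = Transformer(inp_dim=39, d_model=d_model, layers=layers)

    seq = torch.randn(8, 64, 39)              # (batch, seq_len, inp_dim)
    pos_idxs = torch.arange(64).repeat(8, 1)  # (batch, seq_len) timestep indices
    out, hidden = model(seq, pos_idxs, hidden_state=None)  # training path: no cache
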
class TransformerLayer(attention_layer, d_model, d_ff, dropout_ff=0.1, activation='leaky_relu', norm='layer', sigma_reparam=True, normformer_norms=True)[source]#

Bases: Module

Pre-Norm Self-Attention Layer

forward(self_seq, key_cache=None, val_cache=None, cache_seqlens=None)[source]#

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.

class VanillaAttention(causal, dropout)[source]#

Bases: SelfAttention

Unoptimized self-attention in regular pytorch.

Parameters:
  • causal (bool) – Whether to use a causal mask.

  • dropout (float) – The dropout rate of the attention matrix.

forward(qkv, key_cache=None, val_cache=None, cache_seqlens=None)[source]#

Map queries, keys, and values to attention output.

Should implement the full training pass when key_cache/val_cache/cache_seqlens are None, and (cached) inference when they are provided.

Parameters:
  • qkv – A tensor of shape (batch_size, sequence_length, 3, num_heads, head_dim). Packed queries, keys, and values.

  • key_cache – A tensor of shape (batch_size, max_sequence_length, num_heads, head_dim).

  • val_cache – A tensor of shape (batch_size, max_sequence_length, num_heads, head_dim).

  • cache_seqlens – A tensor of shape (batch_size,) that defines the current index of the k/v cache.

Returns:

A tensor of shape (batch_size, sequence_length, num_heads, head_dim).

class VanillaFlexAttention(causal, dropout)[source]#

Bases: FlexAttention

A sanity-check test of FlexAttention that should be equivalent to VanillaAttention.