amago.nets.cnn

CNN input modules.

Functions

Classes

CNN(img_shape, channels_first, activation)

Abstract base class for built-in CNN architectures.

DrQCNN(img_shape, channels_first, activation)

CNN architecture from DrQ-v2.

DrQv2Aug(pad, channels_first)

Pad+Random Crop image augmentation from DrQv2.

GridworldCNN(img_shape, channels_first, ...)

Tiny CNN architecture useful for gridworld map features.

IMPALAishCNN(img_shape, channels_first, ...)

CNN architecture from IMPALA.

NatureishCNN(img_shape, channels_first, ...)

Customizable version of the small CNN architecture from DQN.

class CNN(img_shape, channels_first, activation)

Bases: Module, ABC

Abstract base class for built-in CNN architectures.

Parameters:
  • img_shape (tuple[int, int, int]) – Shape of the image (H, W, C) or (C, H, W).

  • channels_first (bool) – Whether the image is in channels-first format.

  • activation (str) – Activation function to use. See amago.nets.utils.activation_switch.

property blank_img: Tensor

Returns an example input image of shape (1, 1) + self.img_shape with dtype uint8.
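A NumPy stand-in for `blank_img`, shown only to make the shape rule concrete. The docs fix the shape and dtype but not the pixel values, so the all-zeros content here is an assumption. The two leading axes are the (Batch, Len) axes that `forward` expects.

```python
import numpy as np

def blank_img(img_shape: tuple[int, int, int]) -> np.ndarray:
    # (1, 1) prepends a batch axis and a sequence-length axis; zeros are an assumed fill value
    return np.zeros((1, 1) + tuple(img_shape), dtype=np.uint8)

example = blank_img((64, 64, 3))  # channels-last (H, W, C) image
print(example.shape, example.dtype)  # (1, 1, 64, 64, 3) uint8
```

A dummy input like this is a convenient way to run the network once and read off the size of its output features.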

abstract conv_forward(imgs)
forward(obs, from_float=False, flatten=True)

Produce a feature map from an image.

Parameters:
  • obs (Tensor) – Image tensor of shape (Batch, Len, H, W, C) for channels_last or (Batch, Len, C, H, W) for channels_first. Can be uint8 or float dtype.

  • from_float (bool) – If False, assumes obs has pixel values in [0, 255], casts it to float, and rescales it to [-1, 1]. If True, assumes obs already has the desired range and dtype. Defaults to False.

  • flatten (bool) – If True, flatten output activations into feature array for linear layers. Defaults to True.

Return type:

FloatTensor

Returns:

Feature tensor of shape (Batch, Len, out_dim) if flatten=True; otherwise (Batch, Len, C, H, W), the unflattened output shape of the CNN.
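The input/output contract of `forward` can be sketched in NumPy. Only the shapes and the [0, 255] to [-1, 1] rescaling come from the parameter descriptions above; folding Batch and Len into one axis, moving channels first, and the exact `x / 127.5 - 1` expression are assumptions about how a module like this typically prepares images for 2D convolutions. The hypothetical `preprocess`/`restore` helpers stand on either side of `conv_forward`, which is replaced here by the identity.

```python
import numpy as np

def preprocess(obs: np.ndarray, channels_first: bool, from_float: bool = False) -> np.ndarray:
    """Fold (Batch, Len, ...) into one batch axis and return float32 (Batch*Len, C, H, W)."""
    B, L = obs.shape[:2]
    x = obs.reshape((B * L,) + obs.shape[2:])
    if not channels_first:
        # (N, H, W, C) -> (N, C, H, W), the layout 2D convolutions usually expect
        x = x.transpose(0, 3, 1, 2)
    x = x.astype(np.float32)
    if not from_float:
        # linear map of pixel values from [0, 255] onto [-1, 1]
        x = x / 127.5 - 1.0
    return x

def restore(features: np.ndarray, B: int, L: int, flatten: bool = True) -> np.ndarray:
    """Unfold the batch axis; optionally flatten (C, H, W) into a single feature axis."""
    out = features.reshape((B, L) + features.shape[1:])
    return out.reshape(B, L, -1) if flatten else out

obs = np.full((2, 5, 8, 8, 3), 255, dtype=np.uint8)  # channels-last, all white
x = preprocess(obs, channels_first=False)
print(x.shape, x.min(), x.max())  # (10, 3, 8, 8) 1.0 1.0
feats = restore(x, B=2, L=5)  # identity in place of conv_forward, for illustration
print(feats.shape)  # (2, 5, 192)
```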

class DrQCNN(img_shape, channels_first, activation)

Bases: CNN

CNN architecture from DrQ-v2.

https://arxiv.org/abs/2107.09645

Parameters:
  • img_shape (tuple[int]) – Shape of the image (H, W, C) or (C, H, W).

  • channels_first (bool) – Whether the image is in channels-first format.

  • activation (str) – Activation function to use. See amago.nets.utils.activation_switch.
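Assuming `DrQCNN` mirrors the encoder in the DrQ-v2 reference code (four unpadded 3×3 convolutions with 32 channels each, the first with stride 2 and the rest with stride 1), the flattened feature size can be worked out by hand. The layer layout is an assumption taken from the paper's released code, not from this page.

```python
def conv_out(size: int, kernel: int, stride: int, padding: int = 0) -> int:
    # standard output-size formula for a 2D convolution along one spatial axis
    return (size + 2 * padding - kernel) // stride + 1

def drqv2_feature_dim(h: int, w: int, channels: int = 32) -> int:
    # assumed layout: one 3x3 stride-2 conv, then three 3x3 stride-1 convs, no padding
    for stride in (2, 1, 1, 1):
        h, w = conv_out(h, 3, stride), conv_out(w, 3, stride)
    return channels * h * w

print(drqv2_feature_dim(84, 84))  # 39200 (= 32 * 35 * 35)
```

For an 84×84 input the spatial size goes 84, 41, 39, 37, 35, which is why this encoder produces a large feature vector compared with the DQN-style network.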

conv_forward(imgs)
Return type:

Tensor

class DrQv2Aug(pad, channels_first)

Bases: Module

Pad+Random Crop image augmentation from DrQv2.

facebookresearch/drqv2

Parameters:
  • pad (int) – Number of pixels to pad on each side of the image.

  • channels_first (bool) – Whether the image is in channels-first format.

forward(x)

Apply pad+random crop augmentation to an image.

Parameters:

x (Tensor) – Image tensor of shape (Batch, Len, H, W, C) for channels_last or (Batch, Len, C, H, W) for channels_first. Currently, H must equal W.

Return type:

Tensor

Returns:

Augmented image with the same shape as the input.
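A NumPy sketch of pad+random crop on the documented (Batch, Len, ...) layout. Replicate ("edge") padding follows the DrQ-v2 reference code; sampling an independent integer offset for every frame is an assumption. The reference code realizes the shift with bilinear `grid_sample`, whereas this sketch slices integer crops directly.

```python
import numpy as np

def pad_random_crop(x: np.ndarray, pad: int, channels_first: bool,
                    rng: np.random.Generator) -> np.ndarray:
    """Shift every frame by a random integer offset in [-pad, pad]; output shape == input shape."""
    if not channels_first:
        x = np.moveaxis(x, -1, 2)  # (B, L, H, W, C) -> (B, L, C, H, W)
    B, L, C, H, W = x.shape
    assert H == W, "square images only, as in the documented module"
    flat = x.reshape(B * L, C, H, W)
    # replicate border pixels on both spatial axes
    padded = np.pad(flat, ((0, 0), (0, 0), (pad, pad), (pad, pad)), mode="edge")
    out = np.empty_like(flat)
    for i in range(B * L):
        # independent random top-left corner for each frame (assumption)
        dy, dx = rng.integers(0, 2 * pad + 1, size=2)
        out[i] = padded[i, :, dy:dy + H, dx:dx + W]
    out = out.reshape(B, L, C, H, W)
    return out if channels_first else np.moveaxis(out, 2, -1)

rng = np.random.default_rng(0)
imgs = rng.integers(0, 256, size=(2, 3, 84, 84, 3)).astype(np.uint8)
aug = pad_random_crop(imgs, pad=4, channels_first=False, rng=rng)
print(aug.shape, aug.dtype)  # (2, 3, 84, 84, 3) uint8
```

With pad=0 the crop window can only sit at the origin, so the function returns the input unchanged, which is a handy sanity check.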

class GridworldCNN(img_shape, channels_first, activation, channels=[16, 32, 48], kernels=[2, 2, 2], strides=[1, 1, 1])

Bases: CNN

Tiny CNN architecture useful for gridworld map features.

Tip

This class is @gin.configurable. Default values of kwargs can be overridden using gin.

Parameters:
  • img_shape (tuple[int]) – Shape of the image (H, W, C) or (C, H, W).

  • channels_first (bool) – Whether the image is in channels-first format.

  • activation (str) – Activation function to use. See amago.nets.utils.activation_switch.

  • channels (list[int]) – List of 3 ints representing the number of output channels for each convolutional layer. Defaults to [16, 32, 48].

  • kernels (list[int]) – List of 3 ints representing the kernel size for each convolutional layer. Defaults to [2, 2, 2].

  • strides (list[int]) – List of 3 ints representing the stride for each convolutional layer. Defaults to [1, 1, 1].
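Because the class is gin-configurable, its kwargs can be overridden with gin binding lines of the form `Selector.kwarg = value`. The hypothetical helper below only renders such lines and checks the documented "3 ints" length; applying them with `gin.parse_config` requires gin-config and amago to be installed and imported. Using the short selector `GridworldCNN` is an assumption; a fully qualified selector such as `amago.nets.cnn.GridworldCNN` may be needed if the short name is ambiguous.

```python
def gridworld_cnn_bindings(**overrides: list[int]) -> str:
    """Render GridworldCNN kwarg overrides as gin binding lines (hypothetical helper)."""
    lines = []
    for name, value in overrides.items():
        if len(value) != 3:
            # channels, kernels, and strides each take one entry per conv layer
            raise ValueError(f"{name} needs 3 entries, got {len(value)}")
        lines.append(f"GridworldCNN.{name} = {value!r}")
    return "\n".join(lines)

config = gridworld_cnn_bindings(channels=[8, 16, 32], kernels=[3, 3, 2])
print(config)
# GridworldCNN.channels = [8, 16, 32]
# GridworldCNN.kernels = [3, 3, 2]

# With gin-config installed and amago imported, the bindings could then be applied:
# import gin
# gin.parse_config(config)
```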

conv_forward(imgs)
class IMPALAishCNN(img_shape, channels_first, activation, cnn_block_depths=[16, 32, 32], post_group_norm=True)

Bases: CNN

CNN architecture from IMPALA.

Tip

This class is @gin.configurable. Default values of kwargs can be overridden using gin.

Parameters:
  • img_shape (tuple[int]) – Shape of the image (H, W, C) or (C, H, W).

  • channels_first (bool) – Whether the image is in channels-first format.

  • activation (str) – Activation function to use. See amago.nets.utils.activation_switch.

  • cnn_block_depths (list[int]) – List of ints representing the number of output channels for each convolutional block. Length defines the number of residual blocks. Defaults to [16, 32, 32].

  • post_group_norm (bool) – Whether to use group normalization after each convolutional block. Defaults to True.
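Assuming each entry of `cnn_block_depths` builds the standard IMPALA stack (a size-preserving 3×3 convolution, a 3×3 max-pool with stride 2 and padding 1, then two size-preserving residual blocks), only the pooling step changes the spatial size, so every block roughly halves H and W. Group normalization does not affect shapes. This layout is an assumption based on the original IMPALA architecture, not a statement about amago's code.

```python
def maxpool_out(size: int, kernel: int = 3, stride: int = 2, padding: int = 1) -> int:
    # output-size formula for the assumed 3x3 / stride-2 / padding-1 max-pool
    return (size + 2 * padding - kernel) // stride + 1

def impala_feature_dim(h: int, w: int, cnn_block_depths=(16, 32, 32)) -> int:
    # one downsampling pool per block; the last block's depth sets the channel count
    for _ in cnn_block_depths:
        h, w = maxpool_out(h), maxpool_out(w)
    return cnn_block_depths[-1] * h * w

print(impala_feature_dim(64, 64))  # 2048 (= 32 * 8 * 8)
```

Under these assumptions, adding or removing an entry in `cnn_block_depths` changes the amount of downsampling as well as the depth of the network.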

conv_forward(imgs)
Return type:

Tensor

class NatureishCNN(img_shape, channels_first, activation, channels=[32, 64, 64], kernels=[8, 4, 3], strides=[4, 2, 1])

Bases: CNN

Customizable version of the small CNN architecture from DQN.

Tip

This class is @gin.configurable. Default values of kwargs can be overridden using gin.

Parameters:
  • img_shape (tuple[int]) – Shape of the image (H, W, C) or (C, H, W).

  • channels_first (bool) – Whether the image is in channels-first format.

  • activation (str) – Activation function to use. See amago.nets.utils.activation_switch.

  • channels (list[int]) – List of 3 ints representing the number of output channels for each convolutional layer. Defaults to [32, 64, 64].

  • kernels (list[int]) – List of 3 ints representing the kernel size for each convolutional layer. Defaults to [8, 4, 3].

  • strides (list[int]) – List of 3 ints representing the stride for each convolutional layer. Defaults to [4, 2, 1].
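When customizing `channels`, `kernels`, and `strides`, it helps to check that the image survives all three layers. The sketch below assumes unpadded convolutions (as in the original DQN network) and applies the usual output-size formula per layer; it raises if the spatial size collapses to zero.

```python
def conv_stack_out(h: int, w: int, channels: list[int],
                   kernels: list[int], strides: list[int]) -> int:
    """Flattened feature size after a stack of unpadded 2D convolutions (padding=0 assumed)."""
    for k, s in zip(kernels, strides):
        h, w = (h - k) // s + 1, (w - k) // s + 1
        if h <= 0 or w <= 0:
            raise ValueError(f"image too small for kernel={k}, stride={s}")
    return channels[-1] * h * w

# default NatureishCNN settings on an 84x84 frame: 84 -> 20 -> 9 -> 7
print(conv_stack_out(84, 84, [32, 64, 64], [8, 4, 3], [4, 2, 1]))  # 3136 (= 64 * 7 * 7)
```

The same arithmetic applies to the GridworldCNN defaults under the same no-padding assumption: a 9×9 grid shrinks to 6×6, giving 48 * 6 * 6 = 1728 features. With the NatureishCNN defaults, a 32×32 input shrinks to zero by the third layer and raises an error.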

conv_forward(imgs)
weight_init(m)