amago.nets.cnn#
CNN input modules.
Classes
CNN | Abstract base class for built-in CNN architectures.
DrQCNN | CNN architecture from DrQ-v2.
DrQv2Aug | Pad+Random Crop image augmentation from DrQv2.
GridworldCNN | Tiny CNN architecture useful for gridworld map features.
IMPALAishCNN | CNN architecture from IMPALA.
NatureishCNN | Customizable version of the small CNN architecture from DQN.
- class CNN(img_shape, channels_first, activation)[source]#
Bases: Module, ABC
Abstract base class for built-in CNN architectures.
- Parameters:
img_shape (tuple[int, int, int]) – Shape of the image (H, W, C) or (C, H, W).
channels_first (bool) – Whether the image is in channels-first format.
activation (str) – Activation function to use. See amago.nets.utils.activation_switch.
- property blank_img: Tensor#
Returns an example input image of shape (1, 1) + self.img_shape with dtype uint8.
- forward(obs, from_float=False, flatten=True)[source]#
Produce a feature map from an image.
- Parameters:
obs (Tensor) – Image tensor of shape (Batch, Len, H, W, C) for channels_last or (Batch, Len, C, H, W) for channels_first. Can be uint8 or float dtype.
from_float (bool) – If False, assumes obs has pixel values in [0, 255], casts to float, and scales to [-1, 1]. If True, assumes obs is already in the desired range and dtype. Defaults to False.
flatten (bool) – If True, flattens the output activations into a feature array for linear layers. Defaults to True.
- Return type:
FloatTensor
- Returns:
Feature tensor of shape (B, L, out_dim) if flatten=True, otherwise (B, L, C, H, W) matching CNN output shape.
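The preprocessing and output shapes described above can be sketched in plain Python. This is an illustration rather than AMAGO's code: the scaling formula 2 * (x / 255) - 1, the helper names scale_pixel and forward_output_shape, and the idea that out_dim equals the flattened size of the final conv activations are all assumptions chosen to match the documented behavior.

```python
from math import prod


def scale_pixel(value: int) -> float:
    """Map a pixel value in [0, 255] to [-1, 1] (assumed formula)."""
    return 2.0 * (value / 255.0) - 1.0


def forward_output_shape(obs_shape, cnn_out_shape, flatten=True):
    """Shape of the result of forward(), following the documentation above.

    obs_shape: (Batch, Len, ...) shape of the input observation.
    cnn_out_shape: hypothetical (C, H, W) shape of the final conv activations.
    """
    batch, length = obs_shape[0], obs_shape[1]
    if flatten:
        # Assumption: out_dim is the flattened size of the conv output.
        return (batch, length, prod(cnn_out_shape))
    return (batch, length, *cnn_out_shape)


print(scale_pixel(0), scale_pixel(255))
print(forward_output_shape((8, 10, 84, 84, 3), (64, 7, 7)))
print(forward_output_shape((8, 10, 84, 84, 3), (64, 7, 7), flatten=False))
```

The leading Batch and Len dimensions pass through unchanged in both cases; only the trailing image dimensions are replaced by features.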
- class DrQCNN(img_shape, channels_first, activation)[source]#
Bases: CNN
CNN architecture from DrQ-v2.
https://arxiv.org/abs/2107.09645
- Parameters:
img_shape (tuple[int]) – Shape of the image (H, W, C) or (C, H, W).
channels_first (bool) – Whether the image is in channels-first format.
activation (str) – Activation function to use. See amago.nets.utils.activation_switch.
- class DrQv2Aug(pad, channels_first)[source]#
Bases: Module
Pad+Random Crop image augmentation from DrQv2.
- Parameters:
pad (int) – Number of pixels to pad on each side of the image.
channels_first (bool) – Whether the image is in channels-first format.
- forward(x)[source]#
Apply pad+random crop augmentation to an image.
- Parameters:
x (Tensor) – Image tensor of shape (Batch, Len, H, W, C) for channels_last or (Batch, Len, C, H, W) for channels_first. Currently, H must equal W.
- Return type:
Tensor
- Returns:
Augmented image with the same shape as the input.
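To make the pad+random crop idea concrete, here is a simplified sketch on a single square, single-channel image stored as nested lists. It assumes edge-replicating padding and whole-pixel shifts; the helper name pad_and_random_crop is hypothetical, and the library's implementation works on batched tensors and may differ in both padding mode and interpolation.

```python
import random


def pad_and_random_crop(img, pad, rng=None):
    """Pad a square image by `pad` pixels per side, then crop back at a random offset.

    img: list of rows (H == W), one channel.
    rng: optional random.Random instance for reproducibility.
    """
    rng = rng or random.Random()
    size = len(img)
    if any(len(row) != size for row in img):
        raise ValueError("H must equal W")

    def clamp(i):
        # Replicate padding: out-of-range indices reuse the nearest edge pixel.
        return min(max(i - pad, 0), size - 1)

    padded_size = size + 2 * pad
    padded = [
        [img[clamp(r)][clamp(c)] for c in range(padded_size)]
        for r in range(padded_size)
    ]

    # Any top-left corner in [0, 2 * pad] yields a crop of the original size.
    top = rng.randint(0, 2 * pad)
    left = rng.randint(0, 2 * pad)
    return [row[left:left + size] for row in padded[top:top + size]]


image = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
augmented = pad_and_random_crop(image, pad=1, rng=random.Random(0))
print(len(augmented), len(augmented[0]))
```

The output has the same height and width as the input, matching the documented return shape, and the content is shifted by at most pad pixels in each direction.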
- class GridworldCNN(img_shape, channels_first, activation, channels=[16, 32, 48], kernels=[2, 2, 2], strides=[1, 1, 1])[source]#
Bases: CNN
Tiny CNN architecture useful for gridworld map features.
Tip
This class is @gin.configurable. Default values of kwargs can be overridden using gin.
- Parameters:
img_shape (tuple[int]) – Shape of the image (H, W, C) or (C, H, W).
channels_first (bool) – Whether the image is in channels-first format.
activation (str) – Activation function to use. See amago.nets.utils.activation_switch.
channels (list[int]) – List of 3 ints representing the number of output channels for each convolutional layer. Defaults to [16, 32, 48].
kernels (list[int]) – List of 3 ints representing the kernel size for each convolutional layer. Defaults to [2, 2, 2].
strides (list[int]) – List of 3 ints representing the stride for each convolutional layer. Defaults to [1, 1, 1].
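When choosing kernels and strides for a small map, it helps to trace how the spatial size shrinks from layer to layer. The sketch below uses the standard output-size formula for a convolution and assumes no padding and no dilation, which this page does not confirm. The helper names conv_out_size and trace_spatial_size are hypothetical.

```python
def conv_out_size(size: int, kernel: int, stride: int) -> int:
    """Output length of an unpadded convolution along one spatial axis."""
    return (size - kernel) // stride + 1


def trace_spatial_size(size, kernels, strides):
    """Return the spatial size before the first layer and after each layer."""
    sizes = [size]
    for kernel, stride in zip(kernels, strides):
        if sizes[-1] < kernel:
            raise ValueError(
                f"input of size {sizes[-1]} is smaller than kernel {kernel}"
            )
        sizes.append(conv_out_size(sizes[-1], kernel, stride))
    return sizes


# Default kernels/strides from the signature above on a 7x7 map.
sizes = trace_spatial_size(7, kernels=[2, 2, 2], strides=[1, 1, 1])
print(sizes)  # [7, 6, 5, 4]

# Flattened feature size with the default final channel count of 48.
print(48 * sizes[-1] ** 2)  # 768
```

Under these assumptions each default layer trims one cell from each spatial axis, so maps smaller than 4x4 would leave nothing for the last layer to convolve over.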
- class IMPALAishCNN(img_shape, channels_first, activation, cnn_block_depths=[16, 32, 32], post_group_norm=True)[source]#
Bases: CNN
CNN architecture from IMPALA.
Tip
This class is @gin.configurable. Default values of kwargs can be overridden using gin.
- Parameters:
img_shape (tuple[int]) – Shape of the image (H, W, C) or (C, H, W).
channels_first (bool) – Whether the image is in channels-first format.
activation (str) – Activation function to use. See amago.nets.utils.activation_switch.
cnn_block_depths (list[int]) – List of ints representing the number of output channels for each convolutional block. Its length sets the number of residual blocks. Defaults to [16, 32, 32].
post_group_norm (bool) – Whether to use group normalization after each convolutional block. Defaults to True.
- class NatureishCNN(img_shape, channels_first, activation, channels=[32, 64, 64], kernels=[8, 4, 3], strides=[4, 2, 1])[source]#
Bases: CNN
Customizable version of the small CNN architecture from DQN.
Tip
This class is @gin.configurable. Default values of kwargs can be overridden using gin.
- Parameters:
img_shape (tuple[int]) – Shape of the image (H, W, C) or (C, H, W).
channels_first (bool) – Whether the image is in channels-first format.
activation (str) – Activation function to use. See amago.nets.utils.activation_switch.
channels (list[int]) – List of 3 ints representing the number of output channels for each convolutional layer. Defaults to [32, 64, 64].
kernels (list[int]) – List of 3 ints representing the kernel size for each convolutional layer. Defaults to [8, 4, 3].
strides (list[int]) – List of 3 ints representing the stride for each convolutional layer. Defaults to [4, 2, 1].
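Because channels and kernels are configurable, a quick estimate of the convolutional parameter count can help when comparing settings. The sketch below assumes three plain 2D convolutions with square kernels and bias terms and nothing else (no normalization or extra layers), which this page does not confirm. The helper name conv_param_count is hypothetical.

```python
def conv_param_count(in_channels, channels, kernels, bias=True):
    """Count weights (and optional biases) in a stack of square-kernel 2D convs."""
    total = 0
    for out_channels, kernel in zip(channels, kernels):
        # Each output channel has one kernel x kernel filter per input channel.
        total += in_channels * out_channels * kernel * kernel
        if bias:
            total += out_channels
        in_channels = out_channels
    return total


# Default channels/kernels from the signature above, for a 3-channel image.
print(conv_param_count(3, channels=[32, 64, 64], kernels=[8, 4, 3]))  # 75936
```

Strides change the output resolution but not the number of convolutional parameters, so they do not appear in this count.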