Customize#
Almost anything else can be customized by inheriting from a base class and pointing Experiment
to our custom version.
Timestep Encoder#
For example, if we want a custom TstepEncoder, we can implement the abstract methods and pass our module into the experiment:
import gymnasium as gym  # assumed: the observation spaces follow the Gymnasium API
import torch
from torch import nn
import torch.nn.functional as F

import amago
from amago import TstepEncoder
# there's no specific requirement to use AMAGO's pytorch modules, but
# we've built up a collection of common RL components that might be helpful!
from amago.nets.cnn import NatureishCNN
from amago.nets.ff import Normalization


class MultiModalRobotTstepEncoder(TstepEncoder):
    def __init__(
        self,
        obs_space: gym.spaces.Dict,
        rl2_space: gym.spaces.Box,
    ):
        super().__init__(obs_space=obs_space, rl2_space=rl2_space)
        img_space = obs_space["image"]
        joint_space = obs_space["joints"]
        self.cnn = NatureishCNN(img_shape=img_space.shape)
        # run a blank image through the CNN to find its output feature size
        cnn_out_shape = self.cnn(self.cnn.blank_img).shape[-1]
        self.joint_rl2_emb = nn.Linear(joint_space.shape[-1] + rl2_space.shape[-1], 32)
        self.merge = nn.Linear(cnn_out_shape + 32, 128)
        # we'll represent each Timestep as a 64d vector
        self.output_layer = nn.Linear(128, 64)
        self.out_norm = Normalization("layer", 64)

    @property
    def emb_dim(self):
        # tell the rest of the model what output shape to expect
        return 64

    def inner_forward(self, obs, rl2s, log_dict=None):
        img_features = self.cnn(obs["image"])
        # joint readings and the rl2 inputs share a single embedding layer
        joints_and_rl2s = torch.cat((obs["joints"], rl2s), dim=-1)
        joint_features = F.leaky_relu(self.joint_rl2_emb(joints_and_rl2s))
        merged = torch.cat((img_features, joint_features), dim=-1)
        merged = F.leaky_relu(self.merge(merged))
        out = self.out_norm(self.output_layer(merged))
        return out


experiment = amago.Experiment(
    ...,
    tstep_encoder_type=MultiModalRobotTstepEncoder,
)
Multi-Task BabyAI and XLand Mini-Grid are relevant examples.
TrajEncoder (Seq2Seq)#
Implement: TrajEncoder
Substitute: Experiment(traj_encoder_type=MyTrajEncoder, ...)
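The Implement / Substitute recipe can be tried without AMAGO installed. The block below is a minimal, library-free sketch: `ToyTrajEncoder`, `RunningMeanTrajEncoder`, and `ToyExperiment` are hypothetical stand-ins rather than AMAGO classes, and the abstract methods of the real `TrajEncoder` should be taken from its API reference, not from this toy. It only illustrates an experiment that is handed a class (not an instance) and builds the sequence model itself.

```python
from abc import ABC, abstractmethod


class ToyTrajEncoder(ABC):
    """Hypothetical stand-in for a sequence-model base class (not AMAGO's API)."""

    def __init__(self, tstep_dim: int):
        self.tstep_dim = tstep_dim

    @property
    @abstractmethod
    def emb_dim(self) -> int:
        """Size of each output vector."""

    @abstractmethod
    def forward(self, seq: list[list[float]]) -> list[list[float]]:
        """Map a sequence of timestep vectors to a sequence of equal length."""


class RunningMeanTrajEncoder(ToyTrajEncoder):
    """Causal toy 'memory': output t is the mean of inputs 0..t."""

    @property
    def emb_dim(self) -> int:
        return self.tstep_dim

    def forward(self, seq):
        totals = [0.0] * self.tstep_dim
        out = []
        for t, step in enumerate(seq, start=1):
            totals = [a + b for a, b in zip(totals, step)]
            out.append([a / t for a in totals])
        return out


class ToyExperiment:
    """Hypothetical experiment that receives a class and instantiates it."""

    def __init__(self, tstep_dim: int, traj_encoder_type: type[ToyTrajEncoder]):
        self.traj_encoder = traj_encoder_type(tstep_dim=tstep_dim)


experiment = ToyExperiment(tstep_dim=2, traj_encoder_type=RunningMeanTrajEncoder)
encoded = experiment.traj_encoder.forward([[1.0, 0.0], [3.0, 2.0]])
```

Passing the class leaves construction to the experiment, which can then supply arguments (here `tstep_dim`) that are only known once the rest of the model has been built.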
Exploration Strategy#
Implement: ExplorationWrapper
Substitute: Experiment(exploration_wrapper_type=MyExplorationWrapper, ...)
T-Maze demonstrates a custom exploration strategy.
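As a rough illustration of what an exploration strategy does, here is a library-free sketch of epsilon-greedy exploration with a linearly decaying epsilon, written as a wrapper around an environment. `ToyEnv` and `DecayingEpsilonGreedy` are hypothetical names; the sketch does not follow the `ExplorationWrapper` interface and is not the strategy used in the T-Maze example.

```python
import random


class ToyEnv:
    """Hypothetical 3-action environment that echoes the executed action."""

    n_actions = 3

    def step(self, action: int) -> int:
        return action


class DecayingEpsilonGreedy:
    """Hypothetical wrapper: with probability epsilon, replace the policy's
    action with a uniformly random one. Epsilon decays linearly from
    eps_start to eps_end over decay_steps environment steps."""

    def __init__(self, env, eps_start=1.0, eps_end=0.05, decay_steps=1000, seed=0):
        self.env = env
        self.eps_start = eps_start
        self.eps_end = eps_end
        self.decay_steps = decay_steps
        self.rng = random.Random(seed)
        self.steps_taken = 0

    @property
    def epsilon(self) -> float:
        frac = min(self.steps_taken / self.decay_steps, 1.0)
        return self.eps_start + frac * (self.eps_end - self.eps_start)

    def step(self, policy_action: int) -> int:
        if self.rng.random() < self.epsilon:
            action = self.rng.randrange(self.env.n_actions)
        else:
            action = policy_action
        self.steps_taken += 1
        return self.env.step(action)


wrapped = DecayingEpsilonGreedy(ToyEnv(), decay_steps=100)
eps_before = wrapped.epsilon
executed = [wrapped.step(policy_action=2) for _ in range(200)]
eps_after = wrapped.epsilon
```

Keeping the randomness inside the wrapper means the policy itself stays unchanged, so the same agent can be evaluated with exploration switched off.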
Agent#
Implement: Agent
Substitute: Experiment(agent_type=MyAgent, ...)
RLDataset#
Implement: RLDataset
Substitute: dset = MyDataset(); experiment = Experiment(dataset=dset, ...)
D4RL demonstrates a custom dataset.
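Unlike the other components, the dataset is handed to the experiment as an already-built object. The sketch below mimics that flow with hypothetical stand-ins (`ToyTrajectoryDataset`, `ToyExperiment`, and the `sample_window` method are invented for illustration and are not the `RLDataset` interface). It assumes the dataset's job is to return contiguous trajectory segments no longer than a maximum sequence length.

```python
import random


class ToyTrajectoryDataset:
    """Hypothetical offline dataset holding pre-collected trajectories."""

    def __init__(self, trajectories, max_seq_len, seed=0):
        # drop empty trajectories so sampling never returns an empty window
        self.trajectories = [t for t in trajectories if len(t) > 0]
        self.max_seq_len = max_seq_len
        self.rng = random.Random(seed)

    def sample_window(self):
        """Return a contiguous slice of at most max_seq_len timesteps."""
        traj = self.rng.choice(self.trajectories)
        if len(traj) <= self.max_seq_len:
            return list(traj)
        start = self.rng.randrange(len(traj) - self.max_seq_len + 1)
        return traj[start : start + self.max_seq_len]


class ToyExperiment:
    """Hypothetical experiment that receives a dataset instance."""

    def __init__(self, dataset):
        self.dataset = dataset


# integers stand in for timesteps so that contiguity is easy to check
dset = ToyTrajectoryDataset([list(range(10)), list(range(100, 103))], max_seq_len=4)
experiment = ToyExperiment(dataset=dset)
samples = [experiment.dataset.sample_window() for _ in range(50)]
```

Building the dataset first is convenient when it needs its own setup, such as loading files from disk, before the experiment exists.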
(Continuous) Action Distribution#
Implement: PolicyOutput
Substitute: config = {"amago.nets.actor_critic.Actor.continuous_dist_type" : MyPolicyOutput, ...}; use_config(config); experiment = Experiment(...)
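Here the substitution goes through a config dictionary whose keys are dotted paths ending in a keyword-argument name. The block below is a toy re-creation of that idea under stated assumptions: `use_toy_config`, `configurable`, `ToyActor`, and both output classes are hypothetical, and the real `use_config` may work differently internally. It shows a binding replacing a keyword argument's default without the call site changing.

```python
import functools

_BINDINGS = {}  # maps "scope.kwarg_name" to the value to inject


def use_toy_config(config: dict) -> None:
    """Hypothetical counterpart of use_config: register bindings."""
    _BINDINGS.update(config)


def configurable(scope: str):
    """Inject bound values as keyword arguments; explicit kwargs still win."""

    def decorate(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            prefix = scope + "."
            bound = {
                key[len(prefix):]: value
                for key, value in _BINDINGS.items()
                if key.startswith(prefix)
            }
            return fn(*args, **{**bound, **kwargs})

        return wrapper

    return decorate


class DefaultOutput:
    def __init__(self, action_dim: int):
        self.action_dim = action_dim


class MyPolicyOutput(DefaultOutput):
    """Hypothetical custom action distribution."""


class ToyActor:
    @configurable("toy.Actor")
    def __init__(self, action_dim, continuous_dist_type=DefaultOutput):
        self.dist = continuous_dist_type(action_dim)


before = ToyActor(action_dim=2)
use_toy_config({"toy.Actor.continuous_dist_type": MyPolicyOutput})
after = ToyActor(action_dim=2)  # same call, different distribution class
```

A config binding suits settings like this one that live deep inside the model, where threading an extra argument through every constructor would be awkward.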