Set Up the Environment#



AMAGO follows the gymnasium (0.26 < version < 0.30) environment API. A typical gymnasium.Env simulates a single instance of an environment. We’ll be collecting data in parallel by creating multiple independent instances. All we need to do is define a function that creates an AMAGOEnv, for example:

import gymnasium
import amago

def make_env():
    env = gymnasium.make("Pendulum-v1")
    # `env_name` is used for logging eval metrics. multi-task envs
    # will sample a new task between resets and change the name accordingly.
    env = amago.envs.AMAGOEnv(env=env, env_name="Pendulum", batched_envs=1)
    return env

sample_env = make_env()
# If the obs space is not a dict, AMAGOEnv will create a default key of 'observation':
sample_env.observation_space
# >>> Dict('observation': Box([-1. -1. -8.], [1. 1. 8.], (3,), float32))
# environments return an `amago.hindsight.Timestep`
sample_timestep, info = sample_env.reset()
# each environment has a batch dimension of 1
sample_timestep.obs["observation"].shape
# >>> (1, 3)

experiment = amago.Experiment(
    make_train_env=make_env,
    make_val_env=make_env,
    parallel_actors=36,
    env_mode="async", # or "sync" for easy debugging / reduced overhead
    ..., # we'll be walking through more arguments in the following sections
)
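
Once constructed, the Experiment manages the parallel actors for us. As a quick preview (following the pattern used in AMAGO's example scripts; the remaining constructor arguments are covered in the following sections), training is launched with:

experiment.start()  # one-time setup
experiment.learn()  # run the main training loop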

Note

We follow the infinite-bootstrapping convention: environments are reset on done = terminated or truncated, but RL training only uses done = terminated for value learning.
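
In other words, truncation ends data collection but does not zero out the bootstrapped value estimate. A minimal sketch of what this means for a one-step TD target (not AMAGO's internal code):

# bootstrap through truncation, but not through true termination
td_target = reward + gamma * (1.0 - float(terminated)) * next_value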


Vectorized Envs and jax#

Some domains already parallelize computation over many environments: step expects a batch of actions and returns a batch of observations. Examples include recent envs like gymnax that use jax and a GPU to boost their framerate:

import gymnax
from amago.envs import AMAGOEnv
from amago.envs.builtin.gymnax_envs import GymnaxCompatability

def make_env():
    env, params = gymnax.make("Pendulum-v1")
    # AMAGO expects numpy data and an unbatched observation space
    vec_env = GymnaxCompatability(env, num_envs=512, params=params)
    # vec_env.reset()[0].shape >>> (512, 3) # already vectorized!
    return AMAGOEnv(env=vec_env, env_name="gymnax_Pendulum", batched_envs=512)

experiment = amago.Experiment(
    make_train_env=make_env,
    make_val_env=make_env,
    parallel_actors=512, # match batch dim of environment
    env_mode="already_vectorized", # prevents spawning multiple async instances
    ...,
)

There are some details involved in getting the pytorch agent and jax envs to cooperate and share a GPU. See Gymnax.
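
For example, one common fix is to stop jax from preallocating most of the GPU memory before the pytorch agent is created. These are general XLA settings rather than anything AMAGO-specific:

import os
# must be set before jax (or gymnax) is imported
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# or cap jax at a fixed fraction of GPU memory
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".30"

import gymnax  # imported after the flags above so they take effect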


Meta-RL and Auto-Resets#

Most meta-RL problems involve an environment that resets itself to the same task k times. There is no consistent way to handle this across different benchmarks, so AMAGO expects the environment to handle multi-trial resets on its own: terminated and truncated indicate that the entire multi-trial interaction is finished and should be saved/logged. For example:

import gymnasium
from amago.envs import AMAGO_ENV_LOG_PREFIX

class MyMetaRLEnv(gymnasium.Wrapper):

    def reset(self, *args, **kwargs):
        self.sample_new_task_somehow()
        obs, info = self.env.reset(*args, **kwargs)
        self.current_episode = 0
        self.episode_return = 0
        return obs, info

    def step(self, action):
        next_obs, reward, terminated, truncated, info = self.env.step(action)
        self.episode_return += reward
        if terminated or truncated:
            # "trial-done"
            next_obs, info = self.reset_to_the_same_task_somehow()
            # we'll log anything in `info` that begins with `AMAGO_ENV_LOG_PREFIX`
            info[f"{AMAGO_ENV_LOG_PREFIX} Ep {self.current_episode} Return"] = self.episode_return
            self.episode_return = 0
            self.current_episode += 1
        # only indicate when the rollout is finished and the env needs to be completely reset
        done = self.current_episode >= self.k
        return next_obs, reward, done, done, info

An important limitation of this setup is that while AMAGO automatically organizes meta-RL policy inputs for the previous action and reward, it is not aware of the trial reset signal. If we need that signal, it can go in the observation: we could concatenate an extra feature or make the observation a dict with an extra reset key (see the sketch below). The amago.envs.builtin envs contain many examples.
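
As a hypothetical sketch of the concatenation option (the helper name and flag below are illustrative, not part of AMAGO), note that the wrapper's observation_space would also need to grow by one feature:

import numpy as np

def add_reset_flag(obs: np.ndarray, trial_just_reset: bool) -> np.ndarray:
    # append a 0/1 feature marking the first timestep of each new trial
    return np.concatenate([obs, np.array([float(trial_just_reset)], dtype=obs.dtype)])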


Exploration#

Exploratory action noise is implemented by a gymnasium.Wrapper (amago.envs.exploration). Env creation automatically wraps the training envs in Experiment.exploration_wrapper_type.

import amago
from amago.envs.exploration import EpsilonGreedy

experiment = amago.Experiment(
    exploration_wrapper_type=EpsilonGreedy,
    ...,
)
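
Conceptually, exploration noise is just an action wrapper. The sketch below is a plain-gymnasium illustration of epsilon-greedy for a discrete action space; AMAGO's built-in wrappers in amago.envs.exploration have their own interface and options, so treat this only as an outline of the idea:

import gymnasium
import numpy as np

class SimpleEpsilonGreedy(gymnasium.ActionWrapper):
    def __init__(self, env, epsilon: float = 0.05):
        super().__init__(env)
        self.epsilon = epsilon

    def action(self, action):
        # with probability epsilon, replace the policy's action with a random one
        if np.random.rand() < self.epsilon:
            return self.action_space.sample()
        return action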