# Gymnax
"""
Support for gymnax is experimental and mainly meant to test the already_vectorized
env API used by XLand MiniGrid (an unsolved environment) with classic gym envs.
Many of the gymnax envs appear to be broken by recent versions of jax.
There are a couple memory/meta-RL bsuite envs where AMAGO+Transformer
is significantly better than the gymnax reference scores though.
"""

import os

# stop jax from stealing pytorch's memory, since we're only using it for the envs
# NOTE: deliberately set before jax is imported below so the flag takes effect
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
 13
 14from argparse import ArgumentParser
 15import math
 16from functools import partial
 17
 18import gymnax
 19import torch
 20import jax
 21import wandb
 22import numpy as np
 23
 24from amago.envs import AMAGOEnv
 25from amago.envs.builtin.gymnax_envs import GymnaxCompatibility
 26from amago.nets.cnn import GridworldCNN
 27from amago.cli_utils import *
 28
 29
 30def add_cli(parser):
 31    parser.add_argument("--env", type=str, required=True)
 32    parser.add_argument("--max_seq_len", type=int, default=128)
 33    parser.add_argument("--eval_timesteps", type=int, default=1000)
 34    return parser
 35
 36
 37def make_gymnax_amago(env_name, parallel_envs):
 38    env, params = gymnax.make(env_name)
 39    vec_env = GymnaxCompatibility(env, num_envs=parallel_envs, params=params)
 40    # when the environment is already vectorized, alert the AMAGOEnv wrapper with `batched_envs`
 41    return AMAGOEnv(
 42        env=vec_env, env_name=f"gymnax_{env_name}", batched_envs=parallel_envs
 43    )
 44
 45
 46def guess_tstep_encoder(config, obs_shape):
 47    """
 48    We'll move past the somewhat random collection of gymnax envs by making up a simple
 49    timestep encoder based on a few hacks. If we really cared about gymnax performance we
 50    could tune this per environment.
 51    """
 52    if len(obs_shape) == 3:
 53        print(f"Guessing CNN for observation of shape {obs_shape}")
 54        channels_first = np.argmin(obs_shape).item() == 0
 55        return switch_tstep_encoder(
 56            config,
 57            "cnn",
 58            cnn_type=GridworldCNN,
 59            channels_first=channels_first,
 60        )
 61    else:
 62        print(f"Guessing MLP for observation of shape {obs_shape}")
 63        dim = math.prod(obs_shape)  # FFTstepEncoder will flatten the obs on input
 64        return switch_tstep_encoder(
 65            config,
 66            "ff",
 67            d_hidden=max(dim // 3, 128),
 68            n_layers=2,
 69            d_output=max(dim // 4, 96),
 70        )
 71
 72
 73if __name__ == "__main__":
 74    parser = ArgumentParser()
 75    add_common_cli(parser)
 76    add_cli(parser)
 77    args = parser.parse_args()
 78
 79    # "already_vectorized" will stop the training loop from trying spawn multiple instances of the env
 80    args.env_mode = "already_vectorized"
 81
 82    # config
 83    config = {}
 84    traj_encoder_type = switch_traj_encoder(
 85        config,
 86        arch=args.traj_encoder,
 87        memory_size=args.memory_size,
 88        layers=args.memory_layers,
 89    )
 90    with jax.default_device(jax.devices("cpu")[0]):
 91        test_env, env_params = gymnax.make(args.env)
 92        test_obs_shape = test_env.observation_space(env_params).shape
 93    tstep_encoder_type = guess_tstep_encoder(config, test_obs_shape)
 94    agent_type = switch_agent(config, args.agent_type)
 95
 96    use_config(config, args.configs)
 97    make_env = partial(
 98        make_gymnax_amago, env_name=args.env, parallel_envs=args.parallel_actors
 99    )
100    group_name = f"{args.run_name}_{args.env}"
101    for trial in range(args.trials):
102        run_name = group_name + f"_trial_{trial}"
103        experiment = create_experiment_from_cli(
104            args,
105            make_train_env=make_env,
106            make_val_env=make_env,
107            max_seq_len=args.max_seq_len,
108            traj_save_len=args.max_seq_len * 20,
109            run_name=run_name,
110            agent_type=agent_type,
111            tstep_encoder_type=tstep_encoder_type,
112            traj_encoder_type=traj_encoder_type,
113            group_name=group_name,
114            val_timesteps_per_epoch=args.eval_timesteps,
115            grad_clip=2.0,
116            l2_coeff=1e-4,
117            save_trajs_as="npz-compressed",
118        )
119        experiment = switch_async_mode(experiment, args.mode)
120        amago_device = experiment.DEVICE.index or torch.cuda.current_device()
121        env_device = jax.devices("gpu")[amago_device]
122        with jax.default_device(env_device):
123            experiment.start()
124            if args.ckpt is not None:
125                experiment.load_checkpoint(args.ckpt)
126            experiment.learn()
127            experiment.evaluate_test(make_env, timesteps=10_000, render=False)
128            experiment.delete_buffer_from_disk()
129            wandb.finish()