Multi-Game Atari

Multi-Game Atari

  1from argparse import ArgumentParser
  2from functools import partial
  3
  4import wandb
  5
  6from amago.envs.builtin.ale_retro import AtariAMAGOWrapper, AtariGame
  7from amago.nets.cnn import NatureishCNN, IMPALAishCNN
  8from amago.cli_utils import *
  9
 10
 11def add_cli(parser):
 12    parser.add_argument("--games", nargs="+", default=None)
 13    parser.add_argument("--max_seq_len", type=int, default=80)
 14    parser.add_argument(
 15        "--cnn", type=str, choices=["nature", "impala"], default="impala"
 16    )
 17    return parser
 18
 19
 20DEFAULT_MULTIGAME_LIST = [
 21    "Pong",
 22    "Boxing",
 23    "Breakout",
 24    "Gopher",
 25    "MsPacman",
 26    "ChopperCommand",
 27    "CrazyClimber",
 28    "BattleZone",
 29    "Qbert",
 30    "Seaquest",
 31]
 32
 33ATARI_TIME_LIMIT = (30 * 60 * 60) // 5  # (30 minutes of game time)
 34
 35
 36def make_atari_game(game_name):
 37    return AtariAMAGOWrapper(
 38        AtariGame(
 39            game=game_name,
 40            action_space="discrete",
 41            terminal_on_life_loss=False,
 42            version="v5",
 43            frame_skip=5,
 44            grayscale=False,
 45            sticky_action_prob=0.25,
 46            clip_rewards=False,
 47        ),
 48    )
 49
 50
 51if __name__ == "__main__":
 52    parser = ArgumentParser()
 53    add_cli(parser)
 54    add_common_cli(parser)
 55    args = parser.parse_args()
 56
 57    config = {
 58        "amago.agent.Agent.reward_multiplier": 0.25,
 59        "amago.agent.Agent.offline_coeff": (
 60            1.0 if args.agent_type == "multitask" else 0.0
 61        ),
 62    }
 63    traj_encoder_type = switch_traj_encoder(
 64        config,
 65        arch=args.traj_encoder,
 66        memory_size=args.memory_size,
 67        layers=args.memory_layers,
 68    )
 69
 70    if args.cnn == "nature":
 71        cnn_type = NatureishCNN
 72    elif args.cnn == "impala":
 73        cnn_type = IMPALAishCNN
 74    tstep_encoder_type = switch_tstep_encoder(
 75        config,
 76        arch="cnn",
 77        cnn_type=cnn_type,
 78        channels_first=True,
 79        drqv2_aug=True,
 80    )
 81
 82    agent_type = switch_agent(config, args.agent_type)
 83    use_config(config, args.configs)
 84
 85    # Episode lengths in Atari vary widely across games, so we manually set actors
 86    # to a specific game so that all games are always played in parallel.
 87    games = args.games or DEFAULT_MULTIGAME_LIST
 88    assert (
 89        args.parallel_actors % len(games) == 0
 90    ), "Number of actors must be divisible by number of games."
 91    env_funcs = []
 92    for actor in range(args.parallel_actors):
 93        game_name = games[actor % len(games)]
 94        env_funcs.append(partial(make_atari_game, game_name))
 95
 96    group_name = f"{args.run_name}_atari_l_{args.max_seq_len}_cnn_{args.cnn}"
 97    for trial in range(args.trials):
 98        run_name = group_name + f"_trial_{trial}"
 99        experiment = create_experiment_from_cli(
100            args,
101            make_train_env=env_funcs,
102            make_val_env=env_funcs,
103            max_seq_len=args.max_seq_len,
104            traj_save_len=args.max_seq_len * 3,
105            run_name=run_name,
106            tstep_encoder_type=tstep_encoder_type,
107            traj_encoder_type=traj_encoder_type,
108            agent_type=agent_type,
109            group_name=group_name,
110            val_timesteps_per_epoch=ATARI_TIME_LIMIT,
111            save_trajs_as="npz-compressed",
112        )
113        switch_async_mode(experiment, args.mode)
114        experiment.start()
115        if args.ckpt is not None:
116            experiment.load_checkpoint(args.ckpt)
117        experiment.learn()
118        experiment.evaluate_test(
119            env_funcs, timesteps=ATARI_TIME_LIMIT * 5, render=False
120        )
121        experiment.delete_buffer_from_disk()
122        wandb.finish()