from argparse import ArgumentParser

import wandb

from amago.envs.builtin.toy_gym import RoomKeyDoor
from amago.envs import AMAGOEnv
from amago.cli_utils import *


def add_cli(parser):
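    """Add example-specific CLI flags; shared AMAGO flags come from add_common_cli."""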
    parser.add_argument(
        "--meta_horizon",
        type=int,
        default=500,
        help="Total meta-adaptation timestep budget for the agent to explore the same room layout.",
    )
    parser.add_argument(
        "--room_size",
        type=int,
        default=8,
        help="Size of the room. Exploration is sparse and difficulty scales quickly with room size.",
    )
    parser.add_argument(
        "--episode_length",
        type=int,
        default=50,
        help="Maximum length of a single episode in the environment.",
    )
    parser.add_argument(
        "--light_room_observation",
        action="store_true",
        help="Demonstrate how meta-RL relies on partial observability by revealing the goal location as part of the observation. This version of the environment can be solved without memory!",
    )
    parser.add_argument(
        "--randomize_actions",
        action="store_true",
        help="Randomize the agent's action space to make the task harder.",
    )
    return parser
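

# Example run (the script filename here is hypothetical; --run_name, --trials,
# --mode, and the other shared flags are registered by amago.cli_utils.add_common_cli):
#   python dark_key_to_door.py --run_name demo --room_size 9 --meta_horizon 500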


if __name__ == "__main__":
    parser = ArgumentParser()
    add_common_cli(parser)
    add_cli(parser)
    args = parser.parse_args()

    config = {}
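    # The cli_utils switch_* helpers below stage gin-style parameter overrides in
    # `config` and return the classes the experiment will use; use_config() then
    # applies the overrides (plus any extra gin config files passed via --configs).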
    tstep_encoder_type = switch_tstep_encoder(
        config, arch="ff", n_layers=2, d_hidden=128, d_output=64
    )
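    # the TstepEncoder above fuses each timestep into a single embedding (a small
    # feedforward net suits this low-dimensional toy env); the TrajEncoder below
    # is the sequence model that gives the agent memory across the meta-horizon.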
    traj_encoder_type = switch_traj_encoder(
        config,
        arch=args.traj_encoder,
        memory_size=args.memory_size,
        layers=args.memory_layers,
    )
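    # reward_multiplier has the agent rescale this task's sparse rewards by 100x,
    # a common trick when raw returns would otherwise be tiny.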
    agent_type = switch_agent(config, args.agent_type, reward_multiplier=100.0)
    # the fancier "bilevel" exploration schedule mentioned in the appendix can help
    # when the domain is a true meta-RL problem and the meta-horizon time limit
    # (above) really does reset the task.
    exploration_type = switch_exploration(
        config, "bilevel", steps_anneal=500_000, rollout_horizon=args.meta_horizon
    )
    use_config(config, args.configs)

    group_name = f"{args.run_name}_dark_key_door"
    for trial in range(args.trials):
        run_name = group_name + f"_trial_{trial}"
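        # AMAGOEnv adapts the underlying gymnasium environment to AMAGO's
        # interface; env_name labels this task in logged evaluation metrics.
        # dark=True hides the goal location unless --light_room_observation is set.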
        make_train_env = lambda: AMAGOEnv(
            env=RoomKeyDoor(
                size=args.room_size,
                max_episode_steps=args.episode_length,
                meta_rollout_horizon=args.meta_horizon,
                dark=not args.light_room_observation,
                randomize_actions=args.randomize_actions,
            ),
            env_name=f"Dark-Key-To-Door-{args.room_size}x{args.room_size}",
        )
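        # note max_seq_len = traj_save_len = meta_horizon: rollouts are saved and
        # trained on as complete meta-episodes, so the sequence model's memory can
        # span the entire adaptation budget.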
        experiment = create_experiment_from_cli(
            args,
            agent_type=agent_type,
            tstep_encoder_type=tstep_encoder_type,
            traj_encoder_type=traj_encoder_type,
            make_train_env=make_train_env,
            make_val_env=make_train_env,
            max_seq_len=args.meta_horizon,
            traj_save_len=args.meta_horizon,
            group_name=group_name,
            run_name=run_name,
            val_timesteps_per_epoch=args.meta_horizon * 4,
            exploration_wrapper_type=exploration_type,
        )
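        # switch_async_mode applies the shared --mode flag (e.g., to split
        # environment collection and learning) before the experiment starts.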
        switch_async_mode(experiment, args.mode)
        experiment.start()
        if args.ckpt is not None:
            experiment.load_checkpoint(args.ckpt)
        experiment.learn()
        experiment.evaluate_test(make_train_env, timesteps=20_000, render=False)
        experiment.delete_buffer_from_disk()
        wandb.finish()