1from argparse import ArgumentParser
2
3import wandb
4import gin
5
6from amago.envs import AMAGOEnv
7from amago.envs.builtin.tmaze import TMazeAltPassive, TMazeAltActive
8from amago.envs.exploration import EpsilonGreedy
9from amago import cli_utils
10
11
12def add_cli(parser):
13 parser.add_argument("--horizon", type=int, required=True)
14 return parser
15
16
17@gin.configurable
18class TMazeExploration(EpsilonGreedy):
19 """
20 The Tmaze environment is meant to evaluate recall over long context lengths without
21 testing exploration, but it does this by requiring horizon - 1 deterministic actions
22 to create a gap between the timestep that reveals the correct action and the timestep
23 it is taken. This unintentionally creates a worst-case scenario for epsilon greedy
24 exploration. We use this epsilon greedy exploration schedule to answer the
25 central memory question while fixing the sample efficiency problems it creates.
26
27 https://github.com/twni2016/Memory-RL/issues/1
28 """
29
30 def __init__(
31 self,
32 env,
33 start_window=0,
34 end_window=3,
35 horizon: int = gin.REQUIRED,
36 eps_start=1.0,
37 eps_end=0.01,
38 steps_anneal=100_000,
39 ):
40 self.start_window = start_window
41 self.end_window = end_window
42 self.horizon = horizon
43 super().__init__(
44 env,
45 eps_start=eps_start,
46 eps_end=eps_end,
47 steps_anneal=steps_anneal,
48 )
49
50 def current_eps(self, local_step: int):
51 current = super().current_eps(local_step)
52 if (
53 local_step > self.start_window
54 and local_step < self.horizon - self.end_window
55 ):
56 # low exploration during the easy corridor section; regular during the
57 # interesting early and late timesteps.
58 current[:] = 0.5 / self.horizon
59 return current
60
61
62if __name__ == "__main__":
63 parser = ArgumentParser()
64 cli_utils.add_common_cli(parser)
65 add_cli(parser)
66 args = parser.parse_args()
67
68 config = {
69 "TMazeExploration.horizon": args.horizon,
70 }
71 traj_encoder_type = cli_utils.switch_traj_encoder(
72 config,
73 arch=args.traj_encoder,
74 memory_size=args.memory_size,
75 layers=args.memory_layers,
76 )
77 tstep_encoder_type = cli_utils.switch_tstep_encoder(
78 config,
79 arch="ff",
80 n_layers=2,
81 d_hidden=128,
82 d_output=128,
83 normalize_inputs=False,
84 )
85 agent_type = cli_utils.switch_agent(
86 config, args.agent_type, reward_multiplier=100.0, gamma=0.9999
87 )
88 cli_utils.use_config(config, args.configs)
89
90 group_name = f"{args.run_name}_TMazePassive_H{args.horizon}"
91 for trial in range(args.trials):
92 run_name = group_name + f"_trial_{trial}"
93 make_env = lambda: AMAGOEnv(
94 env=TMazeAltPassive(
95 corridor_length=args.horizon, penalty=-1.0 / args.horizon
96 ),
97 env_name=f"TMazePassive-H{args.horizon}",
98 )
99 experiment = cli_utils.create_experiment_from_cli(
100 args,
101 make_train_env=make_env,
102 make_val_env=make_env,
103 max_seq_len=args.horizon + 1,
104 traj_save_len=args.horizon + 1,
105 group_name=group_name,
106 run_name=run_name,
107 tstep_encoder_type=tstep_encoder_type,
108 traj_encoder_type=traj_encoder_type,
109 agent_type=agent_type,
110 val_timesteps_per_epoch=args.horizon + 1,
111 sample_actions=False, # even softmax prob .999 isn't good enough for this env...
112 exploration_wrapper_type=TMazeExploration,
113 )
114 experiment = cli_utils.switch_async_mode(experiment, args.mode)
115 experiment.start()
116 if args.ckpt is not None:
117 experiment.load_checkpoint(args.ckpt)
118 experiment.learn()
119 experiment.evaluate_test(make_env, timesteps=args.horizon * 5, render=False)
120 experiment.delete_buffer_from_disk()
121 wandb.finish()