1"""
2A demonstration of the hindsight instruction relabeling techinque discussed in the AMAGO paper -
3a generalization of Hindsight Experience Replay (HER) to sequences of multiple goals.
4The ability to relabel data is another good reason to prefer off-policy RL^2 to on-policy.
5
6This example uses the "MazeRunner" environment. MazeRunner is an adapted version of Memory Maze
7(https://arxiv.org/abs/2210.13383) that does not require the DM Lab simulator or learning from pixels.
8For more information please refer to the AMAGO Appendix C.4.
9
10There are three steps to using hindsight relabeling:
11
121. Make the env's observations a dict with keys for the intended goal and alternative goals achieved
13 at that timestep. Goals are often subsets of the state space. In this example, observations consist of:
14
15 `obs` : the regular observation from the maze environment (LIDAR-ish depth sensors to the walls, timer, etc.)
16 `goals` : a sequence of k goal positions to navigate to.
17 `achieved`: the agent's current position.
18
192. Let the policy network take the observation and goals as input, but ignore the `achieved` data.
20
213. Use the `achieved` key to relabel data with alternative goal sequences that would lead to higher returns.
22"""

from argparse import ArgumentParser
import random
from functools import partial

import gymnasium as gym
import wandb
import numpy as np

import amago
from amago.envs.builtin.mazerunner import MazeRunnerAMAGOEnv
from amago.hindsight import Relabeler, FrozenTraj
from amago.loading import DiskTrajDataset
from amago.cli_utils import *


def add_cli(parser):
    parser.add_argument(
        "--maze_size",
        type=int,
        default=11,
        help="Dimension of the randomly generated mazes (n x n). n must be odd.",
    )
    parser.add_argument(
        "--goals",
        type=int,
        default=3,
        help="Length of the sequence of goal positions to reach during the episode.",
    )
    parser.add_argument(
        "--time_limit", type=int, default=250, help="Episode time limit."
    )
    parser.add_argument(
        "--relabel",
        choices=["some", "all", "none"],
        default="some",
        help="`none` skips relabeling, `all` relabels every trajectory into a success, and `some` creates a mixture of varying returns.",
    )
    parser.add_argument(
        "--randomized_actions",
        action="store_true",
        help="Randomize the directions of movement each episode (requires context-based identification of the current controls).",
    )
    return parser


class HindsightInstructionReplay(Relabeler):
    """
    Hindsight Experience Replay extended to "instructions": sequences of multiple goals.

    Relabelers are passed RL trajectory data before it is padded + batched and sent to the agent
    for training.
    """

    def __init__(self, num_goals: int, strategy: str = "all"):
        assert strategy in ["some", "all", "none"]
        self.strategy = strategy
        self.k = num_goals

    def relabel(self, traj: FrozenTraj) -> FrozenTraj:
        """
        Assume observations are a dict with three keys:
        1. `obs`: the regular observation from the env
        2. `achieved`: candidate goals for relabeling (in this case: the current (x, y) position)
        3. `goals`: a (k, n) array of the k goals we want to reach

        The agent receives `obs` and `goals` as input, and we use `achieved` to relabel failed
        trajectories with new goals that lead to more reward signal.

        Variable names follow the pseudocode in AMAGO Appendix B, Alg. 1 (arXiv page 20).
        """
        if self.strategy == "none":
            del traj.obs["achieved"]
            return traj

        # 1. Find the timesteps where original goals were completed
        k = self.k
        length = len(traj.rews)
        tsteps_with_goals = np.nonzero(traj.rews[:, 0])[0].tolist()
        n = len(tsteps_with_goals)
        goals_completed = [
            traj.obs["goals"][t][i] for i, t in enumerate(tsteps_with_goals)
        ]

        # 2. Determine how many relabeled goals we'll add
        h = k - n if self.strategy == "all" else random.randint(0, k - n)
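        # With "all", every remaining goal slot gets filled with an achieved alternative, so the
        # relabeled trajectory becomes a full success; with "some", a random number of slots is
        # filled, keeping a mixture of returns in the buffer.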

        # It's important that this relabeler can recreate the exact reward function
        # and terminal signals of the original env. The best way to check that is to
        # let successful trajs (n == self.k, traj.rews.sum() == self.k) or those
        # where we're not adding goals (h == 0) carry on through relabeling,
        # then check that the "relabeled" version is the same as the original.
        # Since we've already done this, we'll save the relabel time:
        if h == 0:
            del traj.obs["achieved"]
            return traj

        # 3. Pick h goals that were achieved, as replacements for relabeling.
        # In this env, every timestep has an alternative goal,
        alternative_tsteps = set(range(1, length))
        # but don't sample a timestep that already achieved a real goal.
        candidate_tsteps = list(alternative_tsteps - set(tsteps_with_goals))
        h = min(h, len(candidate_tsteps))
        tsteps_with_alt_goals = random.sample(candidate_tsteps, k=h)
        alternatives = [g[0] for g in traj.obs["achieved"][tsteps_with_alt_goals]]

        # 4. Sort the (new) "alternative" goals and the completed (real) goals in the order
        # they'd occur in the trajectory. Leave uncompleted real goals at the end
        # (in their original order).
        combined_goals = alternatives + goals_completed
        tsteps_with_combined_goals = tsteps_with_alt_goals + tsteps_with_goals
        # sort combined_goals by timestep
        r = [g for _, g in sorted(zip(tsteps_with_combined_goals, combined_goals))]
        new_goals = traj.obs["goals"][0].copy()
        for i, new_goal in enumerate(r):
            new_goals[i, :] = new_goal

        # 5. Replay the trajectory as if rewards were computed with new_goals.
        # This step requires knowledge of the reward function (like HER). In this
        # case: a binary check that "achieved" == "goals"[current_goal_idx].
        traj.obs["goals"] = np.repeat(new_goals[np.newaxis, ...], length + 1, axis=0)
        traj.rews[:] = 0.0
        # note that the `rl2s` array combines the previous action and previous reward;
        # the reward sits at index 0.
        traj.rl2s[:, 0] = 0.0
        active_goal_idx = 0
        for t in range(length + 1):
            # the format of arrays in the `traj` object is:
            #   traj.obs   = {o_0, o_1, ..., o_length}
            #   traj.rl2s  = {rl2_0, rl2_1, ..., rl2_length}
            #   traj.rews  = {r_1, r_2, ..., missing}
            #   traj.dones = {d_1, d_2, ..., missing}
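            # rews/dones are offset by one relative to obs: reaching a goal at obs index t means
            # the reward for the transition into o_t is stored at rews index t - 1.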
            # some envs might have multiple achieved goals per timestep; here we take the only one
            achieved_this_turn = traj.obs["achieved"][t][0]
            if np.array_equal(achieved_this_turn, new_goals[active_goal_idx]):
                traj.rews[t - 1] = 1.0
                traj.rl2s[t][0] = 1.0
                active_goal_idx += 1
            # -1 marks a goal as "accomplished"... the convention would change by env, but we
            # need some way to keep the `goals` array a consistent shape.
            traj.obs["goals"][t][:active_goal_idx, ...] = -1
            if active_goal_idx >= k:
                traj.dones[t - 1] = True
                break
        # enforce early termination: if the relabeled goals were all completed at step t,
        # truncate the trajectory there so it matches what the env would have produced.
        del traj.obs["achieved"]
        traj.obs = {key: v[: t + 1] for key, v in traj.obs.items()}
        traj.rl2s = traj.rl2s[: t + 1]
        traj.time_idxs = traj.time_idxs[: t + 1]
        traj.rews = traj.rews[:t]
        traj.dones = traj.dones[:t]
        traj.actions = traj.actions[:t]
        return traj


if __name__ == "__main__":
    parser = ArgumentParser()
    add_common_cli(parser)
    add_cli(parser)
    args = parser.parse_args()

    # set up the environment
    # (the AMAGO wrapper adds the relabeling info to the obs dict)
    make_train_env = lambda: MazeRunnerAMAGOEnv(
        maze_dim=args.maze_size,
        num_goals=args.goals,
        time_limit=args.time_limit,
        randomized_action_space=args.randomized_actions,
    )
    config = {}
    # switch the sequence model
    traj_encoder_type = switch_traj_encoder(
        config,
        arch=args.traj_encoder,
        memory_size=args.memory_size,
        layers=args.memory_layers,
    )
    # switch the agent
    agent_type = switch_agent(
        config,
        args.agent_type,
        # reward_multiplier=100.,
    )
    tstep_encoder_type = switch_tstep_encoder(
        config,
        arch="ff",
        n_layers=2,
        d_hidden=128,
        # ignore the "achieved" key of the obs dict in the policy net
        specify_obs_keys=["obs", "goals"],
    )
    exploration_type = switch_exploration(
        config,
        strategy="egreedy",
        steps_anneal=500_000,  # needs re-tuning; the paper used an older version
    )
    use_config(config, args.configs)

    group_name = args.run_name
    for trial in range(args.trials):
        run_name = group_name + f"_trial_{trial}"

        dataset = DiskTrajDataset(
            dset_root=args.buffer_dir,
            dset_name=run_name,
            dset_max_size=args.dset_max_size,
            relabeler=HindsightInstructionReplay(
                num_goals=args.goals,
                strategy=args.relabel,
            ),
        )

        experiment = create_experiment_from_cli(
            args,
            make_train_env=make_train_env,
            make_val_env=make_train_env,
            # the paper made a point of using the maximum context length; in practice this can be shortened with similar results
            max_seq_len=args.time_limit,
            # make sure the entire trajectory is contained in one file that will be sent to the relabeler
            traj_save_len=args.time_limit + 1,
            stagger_traj_file_lengths=False,
            # provide the dataset explicitly to use our relabeler instead of the default;
            # create_experiment_from_cli creates the default dataset otherwise.
            dataset=dataset,
            run_name=run_name,
            tstep_encoder_type=tstep_encoder_type,
            traj_encoder_type=traj_encoder_type,
            exploration_wrapper_type=exploration_type,
            agent_type=agent_type,
            group_name=group_name,
            val_timesteps_per_epoch=args.time_limit * 5,
        )
        experiment = switch_async_mode(experiment, args.mode)
        experiment.start()
        if args.ckpt is not None:
            experiment.load_checkpoint(args.ckpt)
        experiment.learn()
        experiment.evaluate_test(
            make_train_env, timesteps=args.time_limit * 20, render=False
        )
        experiment.delete_buffer_from_disk()
        wandb.finish()