MazeRunner HER

  1"""
  2A demonstration of the hindsight instruction relabeling techinque discussed in the AMAGO paper -
  3a generalization of Hindsight Experience Replay (HER) to sequences of multiple goals.
  4The ability to relabel data is another good reason to prefer off-policy RL^2 to on-policy.
  5
  6This example uses the "MazeRunner" environment. MazeRunner is an adapted version of Memory Maze 
  7(https://arxiv.org/abs/2210.13383) that does not require the DM Lab simulator or learning from pixels.
  8For more information please refer to the AMAGO Appendix C.4.
  9
 10There are three steps to using hindsight relabeling:
 11
 121. Make the env's observations a dict with keys for the intended goal and alternative goals achieved
 13   at that timestep. Goals are often subsets of the state space. In this example, observations consist of:
 14   
 15   `obs` : the regular observation from the maze environment (LIDAR-ish depth sensors to the walls, timer, etc.)
 16   `goals` : a sequence of k goal positions to navigate to.
 17   `achieved`: the agent's current position.
 18
 192. Let the policy network take the observation and goals as input, but ignore the `achieved` data.
 20
 213. Use the `achieved` key to relabel data with alternative goal sequences that would lead to higher returns.
 22"""

from argparse import ArgumentParser
import random
from functools import partial

import gymnasium as gym
import wandb
import numpy as np

import amago
from amago.envs.builtin.mazerunner import MazeRunnerAMAGOEnv
from amago.hindsight import Relabeler, FrozenTraj
from amago.loading import DiskTrajDataset
from amago.cli_utils import *


def add_cli(parser):
    parser.add_argument(
        "--maze_size",
        type=int,
        default=11,
        help="Dimension of randomly generated mazes (n x n). n must be odd.",
    )
    parser.add_argument(
        "--goals",
        type=int,
        default=3,
        help="Length of the sequence of goal positions to reach during the episode.",
    )
    parser.add_argument(
        "--time_limit", type=int, default=250, help="Episode time limit."
    )
    parser.add_argument(
        "--relabel",
        choices=["some", "all", "none"],
        default="some",
        help="`none` skips relabeling, `all` relabels every trajectory to a success, and `some` creates a mixture of varying returns.",
    )
    parser.add_argument(
        "--randomized_actions",
        action="store_true",
        help="Randomize the directions of movement each episode (requires context-based identification of the current controls).",
    )
    return parser
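
# Example launch command (hypothetical script name; the remaining flags, e.g.
# --run_name, --buffer_dir, and --trials, are added by add_common_cli from
# amago.cli_utils, so their exact spellings and defaults may differ):
#
#   python mazerunner_her.py --maze_size 11 --goals 3 --time_limit 250 \
#       --relabel some --run_name her_demo --buffer_dir ./buffers --trials 1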


class HindsightInstructionReplay(Relabeler):
    """
    Hindsight Experience Replay extended to "instructions" or sequences of multiple goals.

    Relabelers are passed RL trajectory data before it is padded + batched and sent to the agent
    for training.
    """

    def __init__(self, num_goals: int, strategy: str = "all"):
        assert strategy in ["some", "all", "none"]
        self.strategy = strategy
        self.k = num_goals

    def relabel(self, traj: FrozenTraj) -> FrozenTraj:
        """
        Assume observations are a dict with three keys:
        1. `obs` : the regular observation from the env
        2. `achieved` : candidate goals for relabeling (in this case: the current (x, y) position)
        3. `goals` : a (k, n) array of the k goals we want to reach

        The agent receives `obs` and `goals` as input, and we use `achieved` to relabel failed
        trajectories with new goals that lead to more reward signal.

        Variable names follow the pseudocode in AMAGO Appendix B, Alg. 1 (arXiv page 20).
        """
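        # A tiny worked example (hypothetical numbers) of what the steps below do:
        # suppose k = 3, the episode lasted 100 steps, and only one real goal was
        # reached, at t = 40. With strategy "all" we sample h = 2 positions the agent
        # happened to visit anyway (say at t = 12 and t = 71), sort them together with
        # the real completed goal by timestep, and overwrite `goals` with
        # [pos@12, pos@40, pos@71]. Replaying the rewards against this new goal
        # sequence turns the failed trajectory into a 3-of-3 success.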
        if self.strategy == "none":
            del traj.obs["achieved"]
            return traj

        # 1. Find timesteps where original goals were completed
        k = self.k
        length = len(traj.rews)
        tsteps_with_goals = np.nonzero(traj.rews[:, 0])[0].tolist()
        n = len(tsteps_with_goals)
        goals_completed = [
            traj.obs["goals"][t][i] for i, t in enumerate(tsteps_with_goals)
        ]

        # 2. Determine how many relabeled goals we'll add
        h = k - n if self.strategy == "all" else random.randint(0, k - n)

        # It's important that this relabeler can recreate the exact reward function
        # and terminal signals of the original env. The best way to check that is to
        # let successful trajs (n == self.k, traj.rews.sum() == self.k) or those
        # where we're not adding goals (h == 0) carry on through relabeling,
        # then verify that the "relabeled" version is identical to the original.
        # Since we've already done this, we'll save the relabel time:
        if h == 0:
            del traj.obs["achieved"]
            return traj

        # 3. Pick h goals that were achieved as replacements for relabeling.
        # In this env, every timestep has an alternative goal,
        alternative_tsteps = set(range(1, length))
        # but don't sample a timestep that achieved a real goal already.
        candidate_tsteps = list(alternative_tsteps - set(tsteps_with_goals))
        h = min(h, len(candidate_tsteps))
        tsteps_with_alt_goals = random.sample(candidate_tsteps, k=h)
        alternatives = [g[0] for g in traj.obs["achieved"][tsteps_with_alt_goals]]

        # 4. Sort the (new) "alternative" goals and completed (real) goals in the order
        # they'd occur in the trajectory. Leave uncompleted real goals at the end
        # (in their original order).
        combined_goals = alternatives + goals_completed
        tsteps_with_combined_goals = tsteps_with_alt_goals + tsteps_with_goals
        # sort combined_goals by timestep
        r = [g for _, g in sorted(zip(tsteps_with_combined_goals, combined_goals))]
        new_goals = traj.obs["goals"][0].copy()
        for i, new_goal in enumerate(r):
            new_goals[i, :] = new_goal

        # 5. Replay the trajectory as if rewards were computed with new_goals.
        # This step requires knowledge of the reward function (like HER). In this
        # case: a binary check that "achieved" == "goals"[active_goal_idx].
        traj.obs["goals"] = np.repeat(new_goals[np.newaxis, ...], length + 1, axis=0)
        traj.rews[:] = 0.0
        # note that the `rl2s` array holds the previous action and previous reward; the reward is at index 0
        traj.rl2s[:, 0] = 0.0
        active_goal_idx = 0
        for t in range(length + 1):
            # the format of arrays in the `traj` object is:
            # traj.obs = {o_0, o_1, ..., o_length}
            # traj.rl2s = {rl2_0, rl2_1, ..., rl2_length}
            # traj.rews = {r_1, r_2, ..., missing}
            # traj.dones = {d_1, d_2, ..., missing}
            achieved_this_turn = traj.obs["achieved"][t][0]
            # some envs might have multiple goals per timestep
            if np.array_equal(achieved_this_turn, new_goals[active_goal_idx]):
                traj.rews[t - 1] = 1.0
                traj.rl2s[t][0] = 1.0
                active_goal_idx += 1
            # -1 marks a goal as "accomplished"; the convention would change by env,
            # but we need some way to keep the goals array a consistent shape
            traj.obs["goals"][t][:active_goal_idx, ...] = -1
            if active_goal_idx >= k:
                traj.dones[t - 1] = True
                break
        # enforce early termination
        del traj.obs["achieved"]
        traj.obs = {k: v[: t + 1] for k, v in traj.obs.items()}
        traj.rl2s = traj.rl2s[: t + 1]
        traj.time_idxs = traj.time_idxs[: t + 1]
        traj.rews = traj.rews[:t]
        traj.dones = traj.dones[:t]
        traj.actions = traj.actions[:t]
        return traj


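# A quick way to sanity-check a relabeler, following the note near step 2 of relabel()
# above: pass a fully successful trajectory through it and confirm nothing changes.
# Hypothetical sketch (`saved_traj` stands for any successful FrozenTraj loaded from
# the replay buffer):
#
#   import copy
#   relabeler = HindsightInstructionReplay(num_goals=3, strategy="all")
#   original = copy.deepcopy(saved_traj)
#   relabeled = relabeler.relabel(saved_traj)
#   assert np.array_equal(relabeled.rews, original.rews)
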
if __name__ == "__main__":
    parser = ArgumentParser()
    add_common_cli(parser)
    add_cli(parser)
    args = parser.parse_args()

    # setup environment
    # the AMAGO wrapper adds the relabeling info to the obs dict
    make_train_env = lambda: MazeRunnerAMAGOEnv(
        maze_dim=args.maze_size,
        num_goals=args.goals,
        time_limit=args.time_limit,
        randomized_action_space=args.randomized_actions,
    )
    config = {}
    # switch sequence model
    traj_encoder_type = switch_traj_encoder(
        config,
        arch=args.traj_encoder,
        memory_size=args.memory_size,
        layers=args.memory_layers,
    )
    # switch agent
    agent_type = switch_agent(
        config,
        args.agent_type,
        # reward_multiplier=100.,
    )
    tstep_encoder_type = switch_tstep_encoder(
        config,
        arch="ff",
        n_layers=2,
        d_hidden=128,
        # ignore the "achieved" key of the obs dict in the policy net
        specify_obs_keys=["obs", "goals"],
    )
    exploration_type = switch_exploration(
        config,
        strategy="egreedy",
        steps_anneal=500_000,  # needs re-tuning; the paper used an older version
    )
    use_config(config, args.configs)

    group_name = args.run_name
    for trial in range(args.trials):
        run_name = group_name + f"_trial_{trial}"

        dataset = DiskTrajDataset(
            dset_root=args.buffer_dir,
            dset_name=run_name,
            dset_max_size=args.dset_max_size,
            relabeler=HindsightInstructionReplay(
                num_goals=args.goals,
                strategy=args.relabel,
            ),
        )

        experiment = create_experiment_from_cli(
            args,
            make_train_env=make_train_env,
            make_val_env=make_train_env,
            # the paper made a point of using the maximum context length; in practice
            # this can be shortened with similar results
            max_seq_len=args.time_limit,
            # make sure the entire trajectory is contained in one file that will be sent to the relabeler
            traj_save_len=args.time_limit + 1,
            stagger_traj_file_lengths=False,
            # provide the dataset explicitly to use our relabeler instead of the default.
            # create_experiment_from_cli creates the default dataset otherwise.
            dataset=dataset,
            run_name=run_name,
            tstep_encoder_type=tstep_encoder_type,
            traj_encoder_type=traj_encoder_type,
            exploration_wrapper_type=exploration_type,
            agent_type=agent_type,
            group_name=group_name,
            val_timesteps_per_epoch=args.time_limit * 5,
        )
        experiment = switch_async_mode(experiment, args.mode)
        experiment.start()
        if args.ckpt is not None:
            experiment.load_checkpoint(args.ckpt)
        experiment.learn()
        experiment.evaluate_test(
            make_train_env, timesteps=args.time_limit * 20, render=False
        )
        experiment.delete_buffer_from_disk()
        wandb.finish()