D4RL

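This example trains an AMAGO agent on the D4RL offline RL benchmark. It adapts a D4RL transition dataset to AMAGO's RLDataset interface, wraps the corresponding environment for gymnasium, and can optionally mix in online data collected after a chosen epoch (--online_after_epoch).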
from argparse import ArgumentParser
import random

import torch
import gym as og_gym
import d4rl
import gymnasium as gym
import wandb
import numpy as np

import amago
from amago.envs import AMAGOEnv
from amago import cli_utils
from amago.loading import RLData, RLDataset, DiskTrajDataset, MixtureOfDatasets
from amago.nets.policy_dists import TanhGaussian, GMM, Beta
from amago.nets.actor_critic import ResidualActor, Actor
from amago.agent import binary_filter, exp_filter

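# Note: d4rl still targets the legacy `gym` API (imported here as og_gym), while
# AMAGO uses `gymnasium`; the D4RLGymEnv wrapper below converts between the two
# step()/reset() signatures.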

def add_cli(parser):
    parser.add_argument(
        "--env", type=str, required=True, help="Environment/Dataset name"
    )
    parser.add_argument(
        "--max_seq_len", type=int, default=32, help="Policy sequence length."
    )
    parser.add_argument(
        "--policy_dist",
        type=str,
        default="Beta",
        help="Policy distribution type",
        choices=["TanhGaussian", "GMM", "Beta"],
    )
    parser.add_argument(
        "--actor_type",
        type=str,
        default="Actor",
        help="Actor head type",
        choices=["ResidualActor", "Actor"],
    )
    parser.add_argument(
        "--online_after_epoch",
        type=int,
        default=float("inf"),
        help="Number of epochs after which to start collecting online data",
    )
    parser.add_argument(
        "--eval_timesteps",
        type=int,
        default=1000,
        help="Number of timesteps to evaluate for each actor. Will be overridden if the environment has a known time limit.",
    )
    return parser

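# Example invocation (a sketch -- the script filename is hypothetical, and flags
# like --run_name, --trials, and --buffer_dir are registered by
# cli_utils.add_common_cli, so their exact names may vary between AMAGO versions):
#
#   python d4rl_example.py --env hopper-medium-v2 --run_name hopper_demo --trials 1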

class D4RLDataset(RLDataset):
    def __init__(self, d4rl_dset: dict[str, np.ndarray]):
        super().__init__()
        self.d4rl_dset = d4rl_dset
        self.episode_ends = np.where(d4rl_dset["terminals"] | d4rl_dset["timeouts"])[0]
        self.ep_lens = self.episode_ends[1:] - self.episode_ends[:-1]
        self.max_ep_len = self.ep_lens.max()

    def get_description(self) -> str:
        return "D4RL"

    @property
    def save_new_trajs_to(self):
        # disables saving new amago trajectories to disk
        return None

    def sample_random_trajectory(self):
        episode_idx = random.randrange(0, len(self.episode_ends) - 1)
        return self._sample_trajectory(episode_idx)

    def _sample_trajectory(self, episode_idx: int):
        # slice the chosen episode out of the flat transition arrays
        s = self.episode_ends[episode_idx] + 1
        e = self.episode_ends[episode_idx + 1] + 1
        traj_len = e - s
        obs_np = self.d4rl_dset["observations"][s : e + 1]
        actions_np = self.d4rl_dset["actions"][s:e]
        rewards_np = self.d4rl_dset["rewards"][s:e]
        terminals_np = self.d4rl_dset["terminals"][s:e]
        timeouts_np = self.d4rl_dset["timeouts"][s:e]

        # convert to torch, adding time_idxs
        obs = {"observation": torch.from_numpy(obs_np)}
        actions = torch.from_numpy(actions_np).float()
        rewards = torch.from_numpy(rewards_np).float().unsqueeze(-1)
        time_idxs = torch.arange(traj_len).unsqueeze(-1).long()
        dones = torch.from_numpy(terminals_np).bool().unsqueeze(-1)
        return RLData(
            obs=obs,
            actions=actions,
            rews=rewards,
            dones=dones,
            time_idxs=time_idxs,
        )

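# D4RLDataset adapts D4RL's flat transition arrays to AMAGO's RLDataset interface:
# each sample_random_trajectory() call returns one full episode as an RLData object,
# and the experiment settings below (max_seq_len, padded_sampling) control how
# training sequences are built from those episodes. A quick sanity check -- a sketch
# assuming d4rl and the relevant MuJoCo task are installed, using D4RLGymEnv from
# just below:
#
#   dset = D4RLDataset(D4RLGymEnv("hopper-medium-v2").dset)
#   traj = dset.sample_random_trajectory()
#   print(traj.obs["observation"].shape, traj.actions.shape, traj.rews.shape)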

from amago.envs.env_utils import space_convert
from amago.envs.amago_env import AMAGO_ENV_LOG_PREFIX


class D4RLGymEnv(gym.Env):
    """
    Light wrapper that logs the D4RL normalized return and handles
    the gym/gymnasium conversion while we're at it.
    """

    def __init__(self, env_name: str):
        # hack: re-seed numpy so parallel env copies don't share identical RNG state
        np.random.seed(random.randrange(int(1e6)))
        self.env_name = env_name
        self.env = og_gym.make(env_name)
        self.action_space = space_convert(self.env.action_space)
        self.observation_space = space_convert(self.env.observation_space)
        if isinstance(self.env, og_gym.wrappers.TimeLimit):
            # this time limit is apparently not consistent with the datasets
            self.time_limit = self.env._max_episode_steps
        else:
            self.time_limit = None
        self.max_return = d4rl.infos.REF_MAX_SCORE[self.env_name]
        self.min_return = d4rl.infos.REF_MIN_SCORE[self.env_name]

    @property
    def dset(self):
        return self.env.get_dataset()

    def reset(self, *args, **kwargs):
        self.episode_return = 0
        return self.env.reset(), {}

    def step(self, action):
        s, r, d, i = self.env.step(action)
        truncated = i.get("TimeLimit.truncated", False)
        terminated = d and not truncated
        self.episode_return += r
        if terminated or truncated:
            i[f"{AMAGO_ENV_LOG_PREFIX} D4RL Normalized Return"] = (
                d4rl.get_normalized_score(self.env_name, self.episode_return)
            )
        return s, r, terminated, truncated, i

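# D4RLGymEnv bridges the old gym step() signature (obs, reward, done, info) to
# gymnasium's (obs, reward, terminated, truncated, info), splitting `done` on the
# "TimeLimit.truncated" flag. Writing the normalized return into `info` under a key
# prefixed with AMAGO_ENV_LOG_PREFIX is what surfaces it in AMAGO's evaluation logs.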

if __name__ == "__main__":
    parser = ArgumentParser()
    cli_utils.add_common_cli(parser)
    add_cli(parser)
    args = parser.parse_args()

    # use the example env to fill in some defaults
    env_name = args.env
    example_env = D4RLGymEnv(args.env)
    assert isinstance(
        example_env.action_space, gym.spaces.Box
    ), "Only supports continuous action spaces"

    if example_env.time_limit is not None:
        args.eval_timesteps = example_env.time_limit + 1
        args.timesteps_per_epoch = example_env.time_limit

    # setup environment
    make_train_env = lambda: AMAGOEnv(
        D4RLGymEnv(args.env),
        env_name=env_name,
        batched_envs=1,
    )

    # agent architecture: drop everything down to standard small sizes
    config = {
        "amago.nets.actor_critic.NCritics.d_hidden": 128,
        "amago.nets.actor_critic.NCriticsTwoHot.d_hidden": 128,
        "amago.nets.actor_critic.NCriticsTwoHot.output_bins": 64,
        "amago.nets.actor_critic.Actor.d_hidden": 128,
        "amago.nets.actor_critic.Actor.continuous_dist_type": eval(args.policy_dist),
        "amago.nets.actor_critic.ResidualActor.feature_dim": 128,
        "amago.nets.actor_critic.ResidualActor.residual_ff_dim": 256,
        "amago.nets.actor_critic.ResidualActor.residual_blocks": 2,
        "amago.nets.actor_critic.ResidualActor.continuous_dist_type": eval(
            args.policy_dist
        ),
    }
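    # These dotted keys are gin-style overrides of AMAGO's default network sizes and
    # are only applied once cli_utils.use_config(config, ...) runs below. Note that
    # eval(args.policy_dist) / eval(args.actor_type) just map the CLI strings onto
    # the classes imported at the top of the file (TanhGaussian, GMM, Beta, Actor,
    # ResidualActor).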
    tstep_encoder_type = cli_utils.switch_tstep_encoder(
        config,
        arch="ff",
        d_hidden=128,
        d_output=128,
        n_layers=1,
    )
    exploration_wrapper_type = cli_utils.switch_exploration(
        config,
        strategy="egreedy",
        eps_start=0.05,
        eps_end=0.01,
        steps_anneal=15_000,
    )
    traj_encoder_type = cli_utils.switch_traj_encoder(
        config,
        arch=args.traj_encoder,
        memory_size=args.memory_size,
        layers=args.memory_layers,
    )
    agent_type = cli_utils.switch_agent(
        config,
        args.agent_type,
        online_coeff=0.0,
        offline_coeff=1.0,
        gamma=0.997,
        reward_multiplier=100.0 if example_env.max_return <= 10.0 else 1,
        num_actions_for_value_in_critic_loss=3,
        num_actions_for_value_in_actor_loss=5,
        num_critics=5,
        actor_type=eval(args.actor_type),
        fbc_filter_func=exp_filter,
    )
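    # With online_coeff=0.0 and offline_coeff=1.0, the actor is trained purely with
    # the filtered behavior cloning term. exp_filter weights dataset actions by
    # (roughly) an exponential of their estimated advantage (AWAC/CRR-style), while
    # the imported binary_filter would instead keep only actions judged better than
    # the current policy.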
    cli_utils.use_config(config, args.configs)

    group_name = f"{args.run_name}_{env_name}"
    for trial in range(args.trials):
        run_name = group_name + f"_trial_{trial}"

        # create dataset
        d4rl_dataset = D4RLDataset(d4rl_dset=example_env.dset)
        online_dset = DiskTrajDataset(
            dset_root=args.buffer_dir,
            dset_name=run_name,
            dset_min_size=250,
            dset_max_size=args.dset_max_size,
        )
        combined_dset = MixtureOfDatasets(
            datasets=[d4rl_dataset, online_dset],
            # skew sampling towards the demos 80/20
            sampling_weights=[0.8, 0.2],
            # gradually increase the weight of the online dset over the first
            # 50 epochs *after online collection starts*
            smooth_sudden_starts=50,
        )
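        # Offline-to-online setup: the fixed D4RL episodes and the (initially empty)
        # on-disk buffer of collected trajectories are sampled at an 80/20 ratio.
        # Because rollouts are deferred until --online_after_epoch (see
        # start_collecting_at_epoch below), smooth_sudden_starts ramps the online
        # dataset's weight up gradually once its first trajectories arrive rather
        # than jumping straight to 20%.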

        experiment = cli_utils.create_experiment_from_cli(
            args,
            make_train_env=make_train_env,
            make_val_env=make_train_env,
            max_seq_len=args.max_seq_len,
            run_name=run_name,
            tstep_encoder_type=tstep_encoder_type,
            traj_encoder_type=traj_encoder_type,
            agent_type=agent_type,
            group_name=group_name,
            val_timesteps_per_epoch=args.eval_timesteps,
            learning_rate=1e-4,
            dataset=combined_dset,
            padded_sampling="right",
            start_collecting_at_epoch=args.online_after_epoch,
            stagger_traj_file_lengths=False,
            traj_save_len=args.eval_timesteps + 1,
            sample_actions=False,
            exploration_wrapper_type=exploration_wrapper_type,
        )
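        # Notable settings: eval_timesteps and timesteps_per_epoch were pinned to
        # the env's time limit above (when one is known), so each evaluation rollout
        # covers one full episode; start_collecting_at_epoch keeps training purely
        # offline until --online_after_epoch (the default of inf never enables
        # collection); and traj_save_len is set longer than an episode so each
        # collected episode should land in a single trajectory file.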
        experiment = cli_utils.switch_async_mode(experiment, args.mode)
        experiment.start()
        if args.ckpt is not None:
            experiment.load_checkpoint(args.ckpt)
        experiment.learn()
        experiment.evaluate_test(make_train_env, timesteps=10_000, render=False)
        wandb.finish()