D4RL
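This example trains AMAGO on the offline D4RL benchmark. It defines a custom RLDataset that samples episodes directly from the prepackaged D4RL transition arrays, and a thin wrapper that converts the original gym environments to the gymnasium API while logging the D4RL normalized return during evaluation. Running it requires the legacy gym and d4rl packages in addition to gymnasium.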

from argparse import ArgumentParser
import random

import torch
import gym as og_gym
import d4rl
import gymnasium as gym
import wandb
import numpy as np

import amago
from amago.envs import AMAGOEnv
from amago import cli_utils
from amago.loading import RLData, RLDataset
from amago.nets.policy_dists import TanhGaussian, GMM, Beta
from amago.nets.actor_critic import ResidualActor, Actor


def add_cli(parser):
    parser.add_argument(
        "--env", type=str, required=True, help="Environment/Dataset name"
    )
    parser.add_argument(
        "--max_seq_len", type=int, default=32, help="Policy sequence length."
    )
    parser.add_argument(
        "--policy_dist",
        type=str,
        default="TanhGaussian",
        help="Policy distribution type",
        choices=["TanhGaussian", "GMM", "Beta"],
    )
    parser.add_argument(
        "--actor_type",
        type=str,
        default="Actor",
        help="Actor head type",
        choices=["ResidualActor", "Actor"],
    )
    parser.add_argument(
        "--eval_timesteps",
        type=int,
        default=1000,
        help="Number of timesteps to evaluate for each actor. Will be overridden if the environment has a known time limit.",
    )
    return parser

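# Example invocation (the script filename and dataset name below are illustrative; common
# flags such as the run name, number of trials, and trajectory-encoder / memory settings
# are added by cli_utils.add_common_cli in the __main__ block):
#   python d4rl.py --env hopper-medium-v2 --max_seq_len 32 \
#       --policy_dist TanhGaussian --actor_type Actor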

class D4RLDataset(RLDataset):
    def __init__(self, d4rl_dset: dict[str, np.ndarray]):
        super().__init__()
        self.d4rl_dset = d4rl_dset
        self.episode_ends = np.where(d4rl_dset["terminals"] | d4rl_dset["timeouts"])[0]
        self.ep_lens = self.episode_ends[1:] - self.episode_ends[:-1]
        self.max_ep_len = self.ep_lens.max()

    def get_description(self) -> str:
        return "D4RL"

    @property
    def save_new_trajs_to(self):
        # disables saving new amago trajectories to disk
        return None

    def sample_random_trajectory(self):
        episode_idx = random.randrange(0, len(self.episode_ends) - 1)
        return self._sample_trajectory(episode_idx)

    def _sample_trajectory(self, episode_idx: int):
        # slice the selected episode out of the flat D4RL transition arrays
        s = self.episode_ends[episode_idx] + 1
        e = self.episode_ends[episode_idx + 1] + 1
        traj_len = e - s
        obs_np = self.d4rl_dset["observations"][s : e + 1]
        actions_np = self.d4rl_dset["actions"][s:e]
        rewards_np = self.d4rl_dset["rewards"][s:e]
        terminals_np = self.d4rl_dset["terminals"][s:e]
        timeouts_np = self.d4rl_dset["timeouts"][s:e]

        # convert to torch, adding time_idxs
        obs = {"observation": torch.from_numpy(obs_np)}
        actions = torch.from_numpy(actions_np).float()
        rewards = torch.from_numpy(rewards_np).float().unsqueeze(-1)
        time_idxs = torch.arange(traj_len).unsqueeze(-1).long()
        dones = torch.from_numpy(terminals_np).bool().unsqueeze(-1)

        return RLData(
            obs=obs,
            actions=actions,
            rews=rewards,
            dones=dones,
            time_idxs=time_idxs,
        )

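# Illustrative only (not executed): if terminals | timeouts marks episode ends at flat
# indices [2, 6, 8], then episode_ends == [2, 6, 8]. Sampling episode_idx == 0 gives
# s == 3 and e == 7, so the action/reward/terminal slices cover transitions 3..6
# (the episode that ends at index 6), while the observation slice [s : e + 1] also keeps
# index 7 as the trailing observation. Transitions before the first end marker
# (indices 0..2 here) are never sampled.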

from amago.envs.env_utils import space_convert
from amago.envs.amago_env import AMAGO_ENV_LOG_PREFIX


class D4RLGymEnv(gym.Env):
    """
    Light wrapper that logs the D4RL normalized return and handles
    the gym/gymnasium conversion while we're at it.
    """

    def __init__(self, env_name: str):
        # hack: reseed numpy so parallel copies of this env don't share RNG state
        np.random.seed(random.randrange(int(1e6)))
        self.env_name = env_name
        self.env = og_gym.make(env_name)
        self.action_space = space_convert(self.env.action_space)
        self.observation_space = space_convert(self.env.observation_space)
        if isinstance(self.env, og_gym.wrappers.TimeLimit):
            # this time limit is apparently not consistent with the datasets
            self.time_limit = self.env._max_episode_steps
        else:
            self.time_limit = None
        self.max_return = d4rl.infos.REF_MAX_SCORE[self.env_name]
        self.min_return = d4rl.infos.REF_MIN_SCORE[self.env_name]

    @property
    def dset(self):
        return self.env.get_dataset()

    def reset(self, *args, **kwargs):
        self.episode_return = 0
        return self.env.reset(), {}

    def step(self, action):
        s, r, d, i = self.env.step(action)
        self.episode_return += r
        if d:
            i[f"{AMAGO_ENV_LOG_PREFIX} D4RL Normalized Return"] = (
                d4rl.get_normalized_score(self.env_name, self.episode_return)
            )
        # old gym returns (obs, reward, done, info); gymnasium expects
        # (obs, reward, terminated, truncated, info), so `d` is reused for both flags
        return s, r, d, d, i


if __name__ == "__main__":
    parser = ArgumentParser()
    cli_utils.add_common_cli(parser)
    add_cli(parser)
    args = parser.parse_args()

    # use the env to set some args
    env_name = args.env
    example_env = D4RLGymEnv(args.env)
    assert isinstance(
        example_env.action_space, gym.spaces.Box
    ), "Only supports continuous action spaces"
    if args.timesteps_per_epoch > 0:
        print("WARNING: timesteps_per_epoch is not supported for D4RL, setting to 0")
        args.timesteps_per_epoch = 0

    # create dataset
    dataset = D4RLDataset(d4rl_dset=example_env.dset)
    if example_env.time_limit is not None:
        # evaluate for one full episode at the env's registered time limit
        args.eval_timesteps = example_env.time_limit + 1

    # setup environment
    make_train_env = lambda: AMAGOEnv(
        D4RLGymEnv(args.env),
        env_name=env_name,
        batched_envs=1,
    )
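    # AMAGOEnv is AMAGO's standard wrapper interface around a gymnasium-style env;
    # batched_envs=1 because each D4RLGymEnv instance steps a single environment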

    # agent architecture: drop everything down to standard small sizes
    config = {
        "amago.nets.actor_critic.NCritics.d_hidden": 128,
        "amago.nets.actor_critic.NCriticsTwoHot.d_hidden": 256,
        "amago.nets.actor_critic.NCriticsTwoHot.output_bins": 128,
        "amago.nets.actor_critic.Actor.d_hidden": 128,
        "amago.nets.actor_critic.Actor.continuous_dist_type": eval(args.policy_dist),
        "amago.nets.actor_critic.ResidualActor.feature_dim": 128,
        "amago.nets.actor_critic.ResidualActor.residual_ff_dim": 256,
        "amago.nets.actor_critic.ResidualActor.residual_blocks": 2,
        "amago.nets.actor_critic.ResidualActor.continuous_dist_type": eval(
            args.policy_dist
        ),
    }
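    # these fully-qualified parameter names override AMAGO's configurable defaults
    # (gin-style); they are applied by cli_utils.use_config(config, args.configs) below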
    tstep_encoder_type = cli_utils.switch_tstep_encoder(
        config,
        arch="ff",
        d_hidden=128,
        d_output=128,
        n_layers=1,
    )
    traj_encoder_type = cli_utils.switch_traj_encoder(
        config,
        arch=args.traj_encoder,
        memory_size=args.memory_size,
        layers=args.memory_layers,
    )
    agent_type = cli_utils.switch_agent(
        config,
        args.agent_type,
        online_coeff=0.0,
        offline_coeff=1.0,
        gamma=0.995,
        reward_multiplier=100.0 if example_env.max_return <= 10.0 else 1,
        num_actions_for_value_in_critic_loss=2,
        num_actions_for_value_in_actor_loss=4,
        num_critics=4,
        actor_type=eval(args.actor_type),
    )
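    # NOTE: offline_coeff=1.0 with online_coeff=0.0 trains the actor purely from AMAGO's
    # offline objective, which fits the fixed-dataset D4RL setting; reward_multiplier
    # rescales rewards for datasets with small reference scores (REF_MAX_SCORE <= 10)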
    cli_utils.use_config(config, args.configs)

    group_name = f"{args.run_name}_{env_name}"
    for trial in range(args.trials):
        run_name = group_name + f"_trial_{trial}"
        experiment = cli_utils.create_experiment_from_cli(
            args,
            make_train_env=make_train_env,
            make_val_env=make_train_env,
            max_seq_len=args.max_seq_len,
            run_name=run_name,
            tstep_encoder_type=tstep_encoder_type,
            traj_encoder_type=traj_encoder_type,
            agent_type=agent_type,
            group_name=group_name,
            val_timesteps_per_epoch=args.eval_timesteps,
            learning_rate=1e-4,
            dataset=dataset,
            padded_sampling="right",
            sample_actions=False,
        )
        experiment = cli_utils.switch_async_mode(experiment, args.mode)
        experiment.start()
        if args.ckpt is not None:
            experiment.load_checkpoint(args.ckpt)
        experiment.learn()
        experiment.evaluate_test(make_train_env, timesteps=10_000, render=False)
        wandb.finish()