from argparse import ArgumentParser
import random

import torch
import gym as og_gym
import d4rl
import gymnasium as gym
import wandb
import numpy as np

import amago
from amago.envs import AMAGOEnv
from amago import cli_utils
from amago.loading import RLData, RLDataset
from amago.nets.policy_dists import TanhGaussian, GMM, Beta
from amago.nets.actor_critic import ResidualActor, Actor


def add_cli(parser):
    parser.add_argument(
        "--env", type=str, required=True, help="Environment/Dataset name"
    )
    parser.add_argument(
        "--max_seq_len", type=int, default=32, help="Policy sequence length."
    )
    parser.add_argument(
        "--policy_dist",
        type=str,
        default="TanhGaussian",
        help="Policy distribution type",
        choices=["TanhGaussian", "GMM", "Beta"],
    )
    parser.add_argument(
        "--actor_type",
        type=str,
        default="Actor",
        help="Actor head type",
        choices=["ResidualActor", "Actor"],
    )
    parser.add_argument(
        "--eval_timesteps",
        type=int,
        default=1000,
        help="Number of timesteps to evaluate for each actor. Will be overridden if the environment has a known time limit.",
    )
    return parser


class D4RLDataset(RLDataset):
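    """RLDataset backed by a flat D4RL dataset dict.

    Splits the flat arrays returned by `env.get_dataset()` into episodes at
    terminal/timeout boundaries and serves randomly sampled episodes as
    `RLData` sequences for offline training.
    """
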
    def __init__(self, d4rl_dset: dict[str, np.ndarray]):
        super().__init__()
        self.d4rl_dset = d4rl_dset
        self.episode_ends = np.where(d4rl_dset["terminals"] | d4rl_dset["timeouts"])[0]
        self.ep_lens = self.episode_ends[1:] - self.episode_ends[:-1]
        self.max_ep_len = self.ep_lens.max()

    def get_description(self) -> str:
        return "D4RL"

    @property
    def save_new_trajs_to(self):
        # disables saving new amago trajectories to disk
        return None

    def sample_random_trajectory(self):
        episode_idx = random.randrange(0, len(self.episode_ends) - 1)
        return self._sample_trajectory(episode_idx)

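    # indices in self.episode_ends mark the final step of each episode, so episode i
    # spans flat indices (episode_ends[i] + 1) through episode_ends[i + 1] inclusive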
    def _sample_trajectory(self, episode_idx: int):
        # slice the chosen episode out of the flat arrays
        s = self.episode_ends[episode_idx] + 1
        e = self.episode_ends[episode_idx + 1] + 1
        traj_len = e - s
        obs_np = self.d4rl_dset["observations"][s : e + 1]
        actions_np = self.d4rl_dset["actions"][s:e]
        rewards_np = self.d4rl_dset["rewards"][s:e]
        terminals_np = self.d4rl_dset["terminals"][s:e]
        timeouts_np = self.d4rl_dset["timeouts"][s:e]

        # convert to torch, adding time_idxs
        obs = {"observation": torch.from_numpy(obs_np)}
        actions = torch.from_numpy(actions_np).float()
        rewards = torch.from_numpy(rewards_np).float().unsqueeze(-1)
        time_idxs = torch.arange(traj_len).unsqueeze(-1).long()
        dones = torch.from_numpy(terminals_np).bool().unsqueeze(-1)

        return RLData(
            obs=obs,
            actions=actions,
            rews=rewards,
            dones=dones,
            time_idxs=time_idxs,
        )


from amago.envs.env_utils import space_convert
from amago.envs.amago_env import AMAGO_ENV_LOG_PREFIX


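# info dict entries whose keys start with AMAGO_ENV_LOG_PREFIX are collected during
# amago's evaluation rollouts and logged as metrics (e.g. to wandb)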
class D4RLGymEnv(gym.Env):
    """
    Light wrapper that logs the D4RL normalized return and handles
    the gym/gymnasium conversion while we're at it.
    """

    def __init__(self, env_name: str):
        # hacky way to give parallel envs different numpy seeds
        np.random.seed(random.randrange(1_000_000))
        self.env_name = env_name
        self.env = og_gym.make(env_name)
        self.action_space = space_convert(self.env.action_space)
        self.observation_space = space_convert(self.env.observation_space)
        if isinstance(self.env, og_gym.wrappers.TimeLimit):
            # this time limit is apparently not consistent with the datasets
            self.time_limit = self.env._max_episode_steps
        else:
            self.time_limit = None
        self.max_return = d4rl.infos.REF_MAX_SCORE[self.env_name]
        self.min_return = d4rl.infos.REF_MIN_SCORE[self.env_name]

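    # the full offline dataset dict from d4rl ("observations", "actions", "rewards",
    # "terminals", "timeouts", ...); downloaded on first use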
    @property
    def dset(self):
        return self.env.get_dataset()

    def reset(self, *args, **kwargs):
        self.episode_return = 0
        return self.env.reset(), {}

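    # old gym returns (obs, reward, done, info); convert to the gymnasium 5-tuple.
    # D4RL does not separate termination from truncation here, so `done` fills both.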
    def step(self, action):
        s, r, d, i = self.env.step(action)
        self.episode_return += r
        if d:
            i[f"{AMAGO_ENV_LOG_PREFIX} D4RL Normalized Return"] = (
                d4rl.get_normalized_score(self.env_name, self.episode_return)
            )
        return s, r, d, d, i


if __name__ == "__main__":
    parser = ArgumentParser()
    cli_utils.add_common_cli(parser)
    add_cli(parser)
    args = parser.parse_args()
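    # Example invocation (script name is a placeholder; flags other than --env,
    # --max_seq_len, --policy_dist, --actor_type, and --eval_timesteps are assumed
    # to come from cli_utils.add_common_cli, matching the args.* names used below):
    #   python d4rl_example.py --env hopper-medium-v2 --max_seq_len 32 --policy_dist TanhGaussian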

    # use the env to set some args
    env_name = args.env
    example_env = D4RLGymEnv(args.env)
    assert isinstance(
        example_env.action_space, gym.spaces.Box
    ), "Only supports continuous action spaces"
    if args.timesteps_per_epoch > 0:
        print("WARNING: timesteps_per_epoch is not supported for D4RL, setting to 0")
        args.timesteps_per_epoch = 0

    # create dataset
    dataset = D4RLDataset(d4rl_dset=example_env.dset)
    if example_env.time_limit is not None:
        args.eval_timesteps = example_env.time_limit + 1

    # setup environment: wrap each D4RLGymEnv in an AMAGOEnv (a single, unbatched env)
    make_train_env = lambda: AMAGOEnv(
        D4RLGymEnv(args.env),
        env_name=env_name,
        batched_envs=1,
    )

    # agent architecture: drop everything down to standard small sizes
    config = {
        "amago.nets.actor_critic.NCritics.d_hidden": 128,
        "amago.nets.actor_critic.NCriticsTwoHot.d_hidden": 256,
        "amago.nets.actor_critic.NCriticsTwoHot.output_bins": 128,
        "amago.nets.actor_critic.Actor.d_hidden": 128,
        "amago.nets.actor_critic.Actor.continuous_dist_type": eval(args.policy_dist),
        "amago.nets.actor_critic.ResidualActor.feature_dim": 128,
        "amago.nets.actor_critic.ResidualActor.residual_ff_dim": 256,
        "amago.nets.actor_critic.ResidualActor.residual_blocks": 2,
        "amago.nets.actor_critic.ResidualActor.continuous_dist_type": eval(
            args.policy_dist
        ),
    }
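    # the dotted keys above are amago's configurable (gin-style) parameters and are
    # applied by cli_utils.use_config() below. eval() on the CLI strings is safe
    # because argparse restricts them to the class names imported at the top.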
    tstep_encoder_type = cli_utils.switch_tstep_encoder(
        config,
        arch="ff",
        d_hidden=128,
        d_output=128,
        n_layers=1,
    )
    traj_encoder_type = cli_utils.switch_traj_encoder(
        config,
        arch=args.traj_encoder,
        memory_size=args.memory_size,
        layers=args.memory_layers,
    )
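    # offline_coeff=1.0 / online_coeff=0.0 uses only amago's offline actor objective
    # (no online policy-gradient term), which is what we want for pure offline RL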
    agent_type = cli_utils.switch_agent(
        config,
        args.agent_type,
        online_coeff=0.0,
        offline_coeff=1.0,
        gamma=0.995,
        reward_multiplier=100.0 if example_env.max_return <= 10.0 else 1,
        num_actions_for_value_in_critic_loss=2,
        num_actions_for_value_in_actor_loss=4,
        num_critics=4,
        actor_type=eval(args.actor_type),
    )
    cli_utils.use_config(config, args.configs)

    group_name = f"{args.run_name}_{env_name}"
    for trial in range(args.trials):
        run_name = group_name + f"_trial_{trial}"
        experiment = cli_utils.create_experiment_from_cli(
            args,
            make_train_env=make_train_env,
            make_val_env=make_train_env,
            max_seq_len=args.max_seq_len,
            run_name=run_name,
            tstep_encoder_type=tstep_encoder_type,
            traj_encoder_type=traj_encoder_type,
            agent_type=agent_type,
            group_name=group_name,
            val_timesteps_per_epoch=args.eval_timesteps,
            learning_rate=1e-4,
            dataset=dataset,
            padded_sampling="right",
            sample_actions=False,
        )
        experiment = cli_utils.switch_async_mode(experiment, args.mode)
        experiment.start()
        if args.ckpt is not None:
            experiment.load_checkpoint(args.ckpt)
        experiment.learn()
        experiment.evaluate_test(make_train_env, timesteps=10_000, render=False)
        wandb.finish()