Symbolic DM Alchemy

Symbolic DM Alchemy

 1from argparse import ArgumentParser
 2
 3import wandb
 4
 5from amago.envs import AMAGOEnv
 6from amago.envs.builtin.alchemy import SymbolicAlchemy
 7from amago import cli_utils
 8
 9
10if __name__ == "__main__":
11    parser = ArgumentParser()
12    cli_utils.add_common_cli(parser)
13    args = parser.parse_args()
14
15    config = {}
16    traj_encoder_type = cli_utils.switch_traj_encoder(
17        config,
18        arch=args.traj_encoder,
19        memory_size=args.memory_size,
20        layers=args.memory_layers,
21    )
22    exploration_wrapper_type = cli_utils.switch_exploration(
23        config, "bilevel", rollout_horizon=200, steps_anneal=2_500_000
24    )
25    agent_type = cli_utils.switch_agent(
26        config, args.agent_type, reward_multiplier=100.0
27    )
28    tstep_encoder_type = cli_utils.switch_tstep_encoder(
29        config, arch="ff", n_layers=2, d_hidden=256, d_output=256
30    )
31
32    cli_utils.use_config(config, args.configs)
33    make_train_env = lambda: AMAGOEnv(
34        env=SymbolicAlchemy(),
35        env_name="dm_symbolic_alchemy",
36    )
37    group_name = f"{args.run_name}_symbolic_dm_alchemy"
38    for trial in range(args.trials):
39        run_name = group_name + f"_trial_{trial}"
40        experiment = cli_utils.create_experiment_from_cli(
41            args,
42            make_train_env=make_train_env,
43            make_val_env=make_train_env,
44            max_seq_len=201,
45            traj_save_len=201,
46            group_name=group_name,
47            run_name=run_name,
48            tstep_encoder_type=tstep_encoder_type,
49            traj_encoder_type=traj_encoder_type,
50            exploration_wrapper_type=exploration_wrapper_type,
51            agent_type=agent_type,
52            val_timesteps_per_epoch=2000,
53        )
54        experiment = cli_utils.switch_async_mode(experiment, args.mode)
55        experiment.start()
56        if args.ckpt is not None:
57            experiment.load_checkpoint(args.ckpt)
58        experiment.learn()
59        experiment.evaluate_test(make_train_env, timesteps=20_000, render=False)
60        experiment.delete_buffer_from_disk()
61        wandb.finish()