# Supplementary code for the paper *Compositional Policy Learning in Stochastic Control Systems with Formal Guarantees*

## Requirements

Python 3.8 or newer.
For installing JAX with GPU support, see the [JAX repository](https://github.com/google/jax).

```bash
pip3 install flax optax gym numpy tqdm tensorflow seaborn
```

Note that TensorFlow is only used for the ```tf.data``` API.

## Pre-training policies with PPO

To train a policy network for 100 PPO iterations, run:

```bash
python3 rsm_loop.py --env rooms0_1_l1 --p_lip 8.0 --ppo_iters 100 --eps 0.01 --norm l1 --only_ppo
```

The trained policy is then saved to ```checkpoints/rooms0_1_l1.jax```.
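To pre-train policies for other sub-environments, the command above can be generated programmatically. The sketch below is a hypothetical helper (`ppo_pretrain_command` is not part of this repository) and rests on two assumptions: that `--norm` should match the `_l1`/`_linf` suffix of the environment name, and that `--p_lip 8.0` and `--eps 0.01` are reused unchanged for every sub-environment.

```python
import shlex
from pathlib import Path
from typing import List, Tuple


def ppo_pretrain_command(env: str, ppo_iters: int = 100) -> Tuple[List[str], Path]:
    """Build the PPO pre-training command for one sub-environment (hypothetical helper).

    Assumption: --norm matches the norm suffix of the environment name, and
    --p_lip / --eps keep the values from the README example.
    """
    norm = env.rsplit("_", 1)[-1]  # e.g. "rooms0_1_l1" -> "l1"
    if norm not in ("l1", "linf"):
        raise ValueError(f"unexpected norm suffix in {env!r}")
    cmd = [
        "python3", "rsm_loop.py",
        "--env", env,
        "--p_lip", "8.0",
        "--ppo_iters", str(ppo_iters),
        "--eps", "0.01",
        "--norm", norm,
        "--only_ppo",
    ]
    # The trained policy is saved as checkpoints/<env>.jax
    return cmd, Path("checkpoints") / f"{env}.jax"


cmd, ckpt = ppo_pretrain_command("rooms0_1_l1")
print(shlex.join(cmd))  # reproduces the command shown above
print(ckpt)
```

To actually launch training, pass `cmd` to `subprocess.run(cmd, check=True)` from the repository root.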

## Available pre-trained policies

The ```checkpoints``` directory contains the pre-trained policies used in the experiments.

## Sub-environments

The sub-environments are named ```rooms{start_room}_{target_room}_{norm}```, where ```start_room``` and ```target_room``` are integers between 0 and 8 and ```norm``` is either ```l1``` or ```linf```.
For a map of the rooms see the paper or the ```plots``` directory.
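The naming convention can be captured by a small parser, for example when scripting over many sub-environments. This is a minimal sketch; `parse_sub_env`, `sub_env_name` and `SubEnv` are hypothetical names, not part of the repository, and the parser only checks the naming scheme, not whether a given room pair actually exists as a sub-environment.

```python
import re
from typing import NamedTuple

# rooms{start_room}_{target_room}_{norm}, with rooms 0-8 and norm l1 or linf
_ENV_NAME = re.compile(r"rooms([0-8])_([0-8])_(l1|linf)")


class SubEnv(NamedTuple):
    start_room: int
    target_room: int
    norm: str


def parse_sub_env(name: str) -> SubEnv:
    """Split a sub-environment name into start room, target room and norm."""
    match = _ENV_NAME.fullmatch(name)
    if match is None:
        raise ValueError(f"not a valid sub-environment name: {name!r}")
    start, target, norm = match.groups()
    return SubEnv(int(start), int(target), norm)


def sub_env_name(start_room: int, target_room: int, norm: str) -> str:
    """Inverse of parse_sub_env; raises ValueError for invalid components."""
    name = f"rooms{start_room}_{target_room}_{norm}"
    parse_sub_env(name)  # validate against the naming scheme
    return name


print(parse_sub_env("rooms0_1_l1"))  # SubEnv(start_room=0, target_room=1, norm='l1')
```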

The overall goal is to reach room 8 (top right) starting from room 0 (bottom left).
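To illustrate how the overall task splits into sub-environments, the sketch below finds one shortest chain of rooms from room 0 to room 8 and lists the corresponding sub-environment names. It assumes the nine rooms form a 3x3 grid numbered row by row from the bottom left, with every pair of side-by-side rooms connected. These assumptions are for illustration only; the actual room map and the decomposition used in the paper may differ (see the paper or the ```plots``` directory).

```python
from collections import deque
from typing import Dict, List

GRID = 3  # assumption: rooms 0-8 form a 3x3 grid, numbered row-major from the bottom left


def neighbours(room: int) -> List[int]:
    """Rooms sharing a wall with `room` (assumes all adjacent rooms are connected)."""
    row, col = divmod(room, GRID)
    result = []
    for dr, dc in ((1, 0), (-1, 0), (0, 1), (0, -1)):  # up, down, right, left
        r, c = row + dr, col + dc
        if 0 <= r < GRID and 0 <= c < GRID:
            result.append(r * GRID + c)
    return result


def room_path(start: int, goal: int) -> List[int]:
    """One shortest sequence of rooms from start to goal, via breadth-first search."""
    parent: Dict[int, int] = {start: start}
    queue = deque([start])
    while queue:
        room = queue.popleft()
        if room == goal:
            break
        for nxt in neighbours(room):
            if nxt not in parent:
                parent[nxt] = room
                queue.append(nxt)
    if goal not in parent:
        raise ValueError(f"room {goal} is unreachable from room {start}")
    path = [goal]
    while path[-1] != start:  # walk back through the BFS tree
        path.append(parent[path[-1]])
    return path[::-1]


def sub_env_chain(start: int, goal: int, norm: str = "l1") -> List[str]:
    """Sub-environment names along the room path, one per room-to-room step."""
    rooms = room_path(start, goal)
    return [f"rooms{a}_{b}_{norm}" for a, b in zip(rooms, rooms[1:])]


print(sub_env_chain(0, 8))
```

Under these assumptions the script prints `['rooms0_3_l1', 'rooms3_6_l1', 'rooms6_7_l1', 'rooms7_8_l1']`; other shortest chains exist, and the order in which `neighbours` lists rooms decides which one BFS returns.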