# Disentangling the Predictive Variance of Deep Ensembles through the Neural Tangent Kernel

This repository contains code to reproduce some of the experimental results presented in the paper "Disentangling the Predictive Variance of Deep Ensembles through the Neural Tangent Kernel".

## Prerequisites
Please follow the instructions at https://github.com/google/neural-tangents to install the required packages.

## Commands for reproducing Table 1

To reproduce the results for the MNIST MLP with N=1000, run:
```
python3 jax_run_sgd.py --batch=100 --dataset=mnist --hidden_depths=2 --hidden_widths=1024 --net=mlp --num_train_data=1000 --num_ensemble=10 --binary=False --activation=relu
```

Reduce `--batch` if you run into an out-of-memory error.
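If you would rather not lower the batch size by hand, a small wrapper can retry the run with a smaller batch. The sketch below is hypothetical and not part of this repository: `run_with_batch_backoff` halves `--batch` after every failed attempt. It treats *any* nonzero exit code as an out-of-memory failure, which is a simplifying assumption, so a genuine bug in the arguments will also trigger retries.

```python
import subprocess
from typing import Callable, List


def run_with_batch_backoff(
    base_args: List[str],
    batch: int = 100,
    min_batch: int = 1,
    runner: Callable[[List[str]], int] = lambda argv: subprocess.run(argv).returncode,
) -> int:
    """Run base_args plus --batch=<batch>, halving the batch after each failure.

    Hypothetical helper: any nonzero exit code is assumed to be an OOM error.
    Returns the batch size that succeeded, or raises RuntimeError if every
    batch size down to min_batch failed.
    """
    while batch >= min_batch:
        argv = base_args + [f"--batch={batch}"]
        if runner(argv) == 0:
            return batch
        batch //= 2  # e.g. 100 -> 50 -> 25 -> 12 -> ...
    raise RuntimeError("command failed at every batch size tried")


if __name__ == "__main__":
    # Same flags as the MNIST command above, minus --batch.
    mnist_args = [
        "python3", "jax_run_sgd.py", "--dataset=mnist", "--hidden_depths=2",
        "--hidden_widths=1024", "--net=mlp", "--num_train_data=1000",
        "--num_ensemble=10", "--binary=False", "--activation=relu",
    ]
    print("succeeded with batch", run_with_batch_backoff(mnist_args))
```

The `runner` argument is injectable so the retry logic can be checked without launching any training.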

To reproduce the results for the CIFAR10 CNN with N=1000, run:
```
python3 jax_run_sgd.py --batch=100 --dataset=cifar10 --hidden_depths=2 --hidden_widths=256 --net=conv --num_train_data=1000 --num_ensemble=10 --binary=False --activation=relu
```
For N=50000, replace `--num_train_data=1000` with `--num_train_data=50000` and `--num_ensemble=10` with `--num_ensemble=3` in the commands above.
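The four Table 1 runs (MNIST MLP and CIFAR10 CNN, each at N=1000 and N=50000) differ only in a handful of flags. The hypothetical helper `table1_argv` below is a sketch that assembles each command line from the settings stated in this README; it does not run anything itself, and the returned list can be passed to `subprocess.run` if you want to launch the jobs from Python.

```python
from typing import List

# Per-dataset architecture flags, copied from the Table 1 commands.
ARCH = {
    "mnist": {"net": "mlp", "hidden_widths": 1024},
    "cifar10": {"net": "conv", "hidden_widths": 256},
}
# Ensemble size per training-set size: 10 members for N=1000, 3 for N=50000.
NUM_ENSEMBLE = {1000: 10, 50000: 3}


def table1_argv(dataset: str, num_train_data: int, batch: int = 100) -> List[str]:
    """Build the jax_run_sgd.py command line for one Table 1 configuration."""
    arch = ARCH[dataset]
    return [
        "python3", "jax_run_sgd.py",
        f"--batch={batch}",
        f"--dataset={dataset}",
        "--hidden_depths=2",
        f"--hidden_widths={arch['hidden_widths']}",
        f"--net={arch['net']}",
        f"--num_train_data={num_train_data}",
        f"--num_ensemble={NUM_ENSEMBLE[num_train_data]}",
        "--binary=False",
        "--activation=relu",
    ]


if __name__ == "__main__":
    # Print all four commands; paste them into a shell or feed them to subprocess.run.
    for dataset in ("mnist", "cifar10"):
        for n in (1000, 50000):
            print(" ".join(table1_argv(dataset, n)))
```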

## Commands for reproducing Figure 2

To reproduce the results for the binarized MNIST MLP with N=100, run:

```
python3 jax_run_ntk.py --batch=50 --dataset=mnist --hidden_depths=1,2,4 --hidden_widths=64,128,256,512,1024 --net=mlp --num_train_data=100 --num_ensemble=30 --binary=True --activation=softplus
```

Likewise, for the binarized CIFAR10 CNN, run:

```
python3 jax_run_ntk.py --batch=50 --dataset=cifar10 --hidden_depths=2,4 --hidden_widths=32,64,128,256 --net=conv --num_train_data=100 --num_ensemble=30 --binary=True --activation=softplus
```
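The Figure 2 commands pass comma-separated lists to `--hidden_depths` and `--hidden_widths`. Assuming `jax_run_ntk.py` trains one ensemble for every (depth, width) pair (this README does not state that explicitly), the sketch below shows how many configurations each command covers. The helpers `parse_int_list` and `sweep_grid` are hypothetical and only illustrate the grid; they are not taken from the repository.

```python
from itertools import product
from typing import List, Tuple


def parse_int_list(flag_value: str) -> List[int]:
    """Split a comma-separated flag value such as "1,2,4" into ints."""
    return [int(v) for v in flag_value.split(",")]


def sweep_grid(depths: str, widths: str) -> List[Tuple[int, int]]:
    """All (depth, width) pairs, assuming the script sweeps the full grid."""
    return list(product(parse_int_list(depths), parse_int_list(widths)))


# Flag values copied from the two Figure 2 commands.
mlp_grid = sweep_grid("1,2,4", "64,128,256,512,1024")
cnn_grid = sweep_grid("2,4", "32,64,128,256")
print(len(mlp_grid), len(cnn_grid))  # 15 8
```

Under that assumption, with `--num_ensemble=30` the MNIST sweep trains 15 × 30 = 450 networks and the CIFAR10 sweep 8 × 30 = 240, which is useful to keep in mind when estimating run time.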
