MNIST/CIFAR-10/CIFAR-100 experiments
====================================

Code for running the experiments from the paper
[*Logical Activation Functions: Logit-space equivalents of Probabilistic Boolean Operators*](https://openreview.net/forum?id=m6HNNpQO8dc)
on the MNIST, CIFAR-10, and CIFAR-100 datasets, using MLP, CNN, and ResNet-50 architectures.


Installation
------------

Set up a conda environment and install the dependencies:

```sh
# Example install
ENVNAME=ail
echo "Setting up environment $ENVNAME"
conda create --name "$ENVNAME" -q python=3.6.2 pip

# Activate the environment
conda activate "$ENVNAME"

# Install dependencies
conda install pytorch=1.9.0 torchvision=0.10.0 cudatoolkit=10.2 -c pytorch
pip install -r requirements.txt
```


Experiments
-----------

### MLP and CNN on MNIST

```sh
DATASET="mnist"
for ARCH in mlp cnn; do
    for ACTFUN_IDX in {0..10}; do
        for SEED in {0..39}; do
            python engine.py \
              --seed "$SEED" \
              --save_path "out" \
              --check_path "ckpt/$DATASET/$ARCH/${ACTFUN_IDX}_${SEED}" \
              --model "$ARCH" \
              --batch_size 100 \
              --actfun_idx "$ACTFUN_IDX" \
              --optim onecycle \
              --num_epochs 10 \
              --dataset "$DATASET" \
              --label "_${ACTFUN_IDX}"
        done
    done
done
```
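The nested loops above launch 2 architectures × 11 activation-function indices × 40 seeds = 880 training runs, one after another. As a minimal sketch for previewing the sweep before starting it, the function below (its name, `mnist_sweep`, is hypothetical) walks the same loops with the same flags but only prints each `engine.py` command, so nothing is trained.

```sh
#!/usr/bin/env bash
# Dry run of the MNIST sweep: print each engine.py command instead of running it.
# Loop ranges and flags are copied from the sweep above.
mnist_sweep() {
    local DATASET="mnist" ARCH ACTFUN_IDX SEED
    for ARCH in mlp cnn; do
        for ACTFUN_IDX in {0..10}; do
            for SEED in {0..39}; do
                echo python engine.py \
                  --seed "$SEED" \
                  --save_path "out" \
                  --check_path "ckpt/$DATASET/$ARCH/${ACTFUN_IDX}_${SEED}" \
                  --model "$ARCH" \
                  --batch_size 100 \
                  --actfun_idx "$ACTFUN_IDX" \
                  --optim onecycle \
                  --num_epochs 10 \
                  --dataset "$DATASET" \
                  --label "_${ACTFUN_IDX}"
            done
        done
    done
}

# Print the first command, then the total number of runs.
mnist_sweep | sed -n '1p'
mnist_sweep | wc -l
```

The second command should report 880. Removing the leading `echo` inside the function turns it back into the real sweep.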


### ResNet-50 on CIFAR-10 and CIFAR-100

```sh
for DATASET in cifar10 cifar100; do
    for RESNET_TYPE in 0.5 1 2 4; do
        for ACTFUN_IDX in {0..10}; do
            for SEED in {0..2}; do
                python engine.py \
                  --seed "$SEED" \
                  --save_path "out" \
                  --check_path "ckpt/$DATASET/resnet_${RESNET_TYPE}_${ACTFUN_IDX}_${SEED}" \
                  --model resnet \
                  --batch_size 128 \
                  --actfun_idx "$ACTFUN_IDX" \
                  --optim onecycle \
                  --num_epochs 100 \
                  --dataset "$DATASET" \
                  --aug \
                  --distributed \
                  --mix_pre_apex \
                  --resnet_type "$RESNET_TYPE" \
                  --label "_${RESNET_TYPE}_${ACTFUN_IDX}"
            done
        done
    done
done
```
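This sweep is considerably more expensive: 2 datasets × 4 `RESNET_TYPE` values × 11 activation-function indices × 3 seeds = 264 runs of 100 epochs each. On a cluster, one option is to run each configuration as a separate job. The sketch below assumes a scheduler that gives each job an integer task index (for example, `SLURM_ARRAY_TASK_ID` in SLURM job arrays); the helper `resnet_task_args` and the variables `TASK_ID`, `DATASETS`, and `RESNET_TYPES` are hypothetical. It maps an index in [0, 264) to one configuration, in the same order as the nested loops above, and prints the matching command.

```sh
#!/usr/bin/env bash
# Map a flat task index in [0, 264) onto one ResNet configuration, enumerated in
# the same order as the nested loops above (DATASET outermost, SEED innermost).
DATASETS=(cifar10 cifar100)
RESNET_TYPES=(0.5 1 2 4)
NUM_ACTFUNS=11   # ACTFUN_IDX in {0..10}
NUM_SEEDS=3      # SEED in {0..2}
NUM_TASKS=$(( ${#DATASETS[@]} * ${#RESNET_TYPES[@]} * NUM_ACTFUNS * NUM_SEEDS ))

# Print "DATASET RESNET_TYPE ACTFUN_IDX SEED" for the given task index.
resnet_task_args() {
    local task_id=$1
    if (( task_id < 0 || task_id >= NUM_TASKS )); then
        echo "task index $task_id is outside [0, $NUM_TASKS)" >&2
        return 1
    fi
    local seed=$(( task_id % NUM_SEEDS ))
    local actfun_idx=$(( (task_id / NUM_SEEDS) % NUM_ACTFUNS ))
    local type_idx=$(( (task_id / (NUM_SEEDS * NUM_ACTFUNS)) % ${#RESNET_TYPES[@]} ))
    local dataset_idx=$(( task_id / (NUM_SEEDS * NUM_ACTFUNS * ${#RESNET_TYPES[@]}) ))
    echo "${DATASETS[dataset_idx]} ${RESNET_TYPES[type_idx]} $actfun_idx $seed"
}

# Hypothetical: set TASK_ID from your scheduler, e.g. TASK_ID=$SLURM_ARRAY_TASK_ID.
TASK_ID=${TASK_ID:-0}
ARGS=$(resnet_task_args "$TASK_ID") || exit 1
read -r DATASET RESNET_TYPE ACTFUN_IDX SEED <<< "$ARGS"

# Print the command for this task; remove "echo" to launch it.
echo python engine.py \
  --seed "$SEED" \
  --save_path "out" \
  --check_path "ckpt/$DATASET/resnet_${RESNET_TYPE}_${ACTFUN_IDX}_${SEED}" \
  --model resnet \
  --batch_size 128 \
  --actfun_idx "$ACTFUN_IDX" \
  --optim onecycle \
  --num_epochs 100 \
  --dataset "$DATASET" \
  --aug \
  --distributed \
  --mix_pre_apex \
  --resnet_type "$RESNET_TYPE" \
  --label "_${RESNET_TYPE}_${ACTFUN_IDX}"
```

For instance, `TASK_ID=34` selects `cifar10`, `RESNET_TYPE=1`, `ACTFUN_IDX=0`, `SEED=1`. Out-of-range indices are rejected rather than silently wrapped, so a mis-sized job array fails loudly.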
