
# Retrieval-Augmented Multiple Instance Learning (RAM-MIL)

## ✓ Requirements

Use the same environment configuration as CLAM, then install the two extra packages:

```setup
conda env create -n clam -f CLAM/clam.yaml
conda activate clam

pip install pot
pip install geomloss
```
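
After installation, a quick check can confirm that the two extra packages are visible to the interpreter. The sketch below assumes the usual import names (`pot` is imported as `ot`, `geomloss` as `geomloss`); the helper `missing_packages` is hypothetical and not part of this repository. It only looks the modules up and does not import them.

```python
import importlib.util

# Map pip package name -> import name. Note that `pot` is imported as `ot`.
REQUIRED_MODULES = {"pot": "ot", "geomloss": "geomloss"}


def missing_packages(required=REQUIRED_MODULES):
    """Return the pip package names whose modules cannot be found."""
    return [
        pip_name
        for pip_name, module_name in required.items()
        if importlib.util.find_spec(module_name) is None
    ]


if __name__ == "__main__":
    missing = missing_packages()
    print("All set." if not missing else f"Missing packages: {missing}")
```
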
## ✓ Training
## CLAM Pretraining
1. Change into the `CLAM` directory.
```bash
cd CLAM/
```

2. Split the data into 10 folds. The splits are saved as `splits/xxx/splits_x.csv`.
```bash
python create_splits_seq.py --task task_1_tumor_vs_normal --seed 1 --label_frac 0.75 --k 10
```
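
To sanity-check a generated split, the fold sizes can be counted with the standard library. This is a minimal sketch that assumes each split file has `train`, `val` and `test` columns of slide IDs, with shorter columns padded by empty cells; that layout is an assumption about CLAM's output, and `count_split_sizes` is a hypothetical helper, not part of the repository.

```python
import csv
import io


def count_split_sizes(csv_file):
    """Count non-empty slide IDs in the train/val/test columns of one split file."""
    sizes = {"train": 0, "val": 0, "test": 0}
    for row in csv.DictReader(csv_file):
        for name in sizes:
            if (row.get(name) or "").strip():
                sizes[name] += 1
    return sizes


# Made-up slide IDs purely for illustration; real files come from create_splits_seq.py.
example = io.StringIO(
    ",train,val,test\n"
    "0,slide_a,slide_d,slide_e\n"
    "1,slide_b,,\n"
    "2,slide_c,,\n"
)
print(count_split_sizes(example))  # {'train': 3, 'val': 1, 'test': 1}
```

For a real split, pass an open file handle for one of the `splits_x.csv` files instead of the in-memory example.
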

3. Train the `clam_sb` model.
```bash
python main.py --drop_out --early_stopping --lr 2e-4 --k 10 --label_frac 0.75 --exp_code task_1_tumor_vs_normal_CLAM_75 --weighted_sample --bag_loss ce --inst_loss svm --task task_1_tumor_vs_normal --model_type clam_sb --log_data --data_root_dir data_root_dir --results_dir result
```

4. Save the slide-level features and attention scores.

```bash
python save_slides.py
```


## Start Neighbor Retrieval
1. Change into the `ot_retrieval` directory.
```bash
cd ot_retrieval
```

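2. Save the top 10% or 20% of patch features, ranked by attention score, together with their attention scores.
```bash
# attention.ipynb is a notebook, so run it through Jupyter rather than the Python interpreter
jupyter nbconvert --to notebook --execute attention.ipynb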
```
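
The selection in this step can be sketched in plain Python as follows. This is not the notebook's code: `top_fraction_by_attention` is a hypothetical helper, features are shown as plain lists, and rounding the patch count up is an assumption. Use `fraction=0.1` or `fraction=0.2` for the top 10% or 20%.

```python
import math


def top_fraction_by_attention(features, attention, fraction=0.2):
    """Keep the `fraction` of patches with the highest attention scores.

    Returns (kept_features, kept_attention), ordered from highest to lowest score.
    """
    if len(features) != len(attention):
        raise ValueError("features and attention must have the same length")
    if not 0.0 < fraction <= 1.0:
        raise ValueError("fraction must be in (0, 1]")
    # Assumption: round the number of kept patches up, and keep at least one.
    keep = max(1, math.ceil(fraction * len(attention)))
    # Indices of patches sorted by descending attention score.
    order = sorted(range(len(attention)), key=lambda i: attention[i], reverse=True)
    order = order[:keep]
    return [features[i] for i in order], [attention[i] for i in order]
```
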
3. Modify the file list.

`emb_c16_sort.txt` and `emb_c17_sort.txt` are sorted by the number of patches per slide.

4. Compute and save the optimal transport loss.
>📋 The tensors involved are large, so the computation can run out of memory.
Consider splitting the data and distributing the pieces across several GPUs to run in parallel.

```bash
# In-domain
python opt_in_domain.py
# Out-of-domain
python opt_in_out_domain.py
```
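
One way to follow the splitting recommendation above is to shard the slide pairs and give each shard to its own process, for example one process per GPU selected through `CUDA_VISIBLE_DEVICES`. The sketch below is an assumption about how the work could be divided, not the scripts' actual logic; `shard_pairs` is a hypothetical helper. Because the file lists are sorted by patch count, round-robin assignment spreads the large slides across workers.

```python
from itertools import product


def shard_pairs(query_ids, reference_ids, num_workers):
    """Split all (query, reference) slide pairs into `num_workers` round-robin shards."""
    if num_workers < 1:
        raise ValueError("num_workers must be at least 1")
    shards = [[] for _ in range(num_workers)]
    for index, pair in enumerate(product(query_ids, reference_ids)):
        # Pair 0 goes to worker 0, pair 1 to worker 1, and so on, wrapping around.
        shards[index % num_workers].append(pair)
    return shards
```
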
5. Retrieve nearest neighbors.

```bash
python retrieve_neighbor.py
```
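
Conceptually, retrieval picks, for each query slide, the reference slides with the smallest optimal transport loss. The following is a minimal sketch under the assumption that the saved loss matrix can be loaded as a list of rows (one row per query slide); `nearest_neighbors` is a hypothetical helper and does not reproduce `retrieve_neighbor.py`.

```python
import heapq


def nearest_neighbors(loss_matrix, k=1, exclude_self=True):
    """For each row i, return the indices of the k columns with the smallest loss.

    Set exclude_self=False when query and reference slides come from different
    sets, so that index i in a row does not refer to the query itself.
    """
    neighbors = []
    for i, row in enumerate(loss_matrix):
        candidates = [
            (loss, j) for j, loss in enumerate(row) if not (exclude_self and j == i)
        ]
        # Ties are broken in favor of the lower column index.
        neighbors.append([j for _, j in heapq.nsmallest(k, candidates)])
    return neighbors
```

With a square in-domain matrix, `exclude_self=True` prevents a slide from being returned as its own neighbor.
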


## Classifier Training
1. Change into the `Classifier` directory.
```bash
cd Classifier
```

2. Set the neighbor file `loss_matrix_1616_20att_xx.json` in `datasets1/dataset_generic.py`.
```python
# In class Generic_MIL_Dataset(Generic_WSI_Classification_Dataset),
# point to the neighbor files that match the retrieval setting:

# retrieval in-domain:
#   indomain_nebs/*.json

# retrieval in-domain and out-of-domain:
#   inout_nebs/*.json

# retrieval out-of-domain:
#   out_nebs/*.json
```

3. Choose the merge function in `models/model_clam.py`.
```bash
# CLAM_SB: simple addition
# CLAM_SB_ADD: convex combination
```
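
The two merge options can be illustrated on plain Python lists. This is a sketch of the general idea only, not the code in `models/model_clam.py`: the function names are hypothetical, and whether the mixing weight is fixed or learned is not stated here, so `weight` is an illustrative parameter.

```python
def _check_same_length(a, b):
    if len(a) != len(b):
        raise ValueError("feature vectors must have the same length")


def merge_add(slide_feature, neighbor_feature):
    """Simple addition: element-wise sum of the two feature vectors."""
    _check_same_length(slide_feature, neighbor_feature)
    return [s + n for s, n in zip(slide_feature, neighbor_feature)]


def merge_convex(slide_feature, neighbor_feature, weight=0.5):
    """Convex combination: weight * slide + (1 - weight) * neighbor."""
    _check_same_length(slide_feature, neighbor_feature)
    if not 0.0 <= weight <= 1.0:
        raise ValueError("weight must be in [0, 1]")
    return [
        weight * s + (1.0 - weight) * n
        for s, n in zip(slide_feature, neighbor_feature)
    ]
```

With `weight=1.0` the convex combination ignores the neighbor entirely, and with `weight=0.0` it returns the neighbor feature unchanged.
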

4. Train the classifier.
```bash
CUDA_VISIBLE_DEVICES=0 nohup python main.py --drop_out --early_stopping --lr 2e-4 --k 10 --label_frac 0.75 --exp_code task_1_tumor_vs_normal_CLAM_75 --weighted_sample --bag_loss ce --inst_loss svm --task task_1_tumor_vs_normal --model_type clam_sb_add --log_data --data_root_dir slide_feature_dir --results_dir ./result_add --reg 1e-4 >  result-add.log &
```

## ✓ Evaluation

>📋 To evaluate the model in-domain, run:

```eval
python -u eval_c16.py --drop_out --k 10 --models_exp_code task_1_tumor_vs_normal_CLAM_75_s1 --save_exp_code result_add --task task_1_tumor_vs_normal --model_type clam_sb_add --results_dir ./result-add --data_root_dir c16_slide_feature_dir --splits_dir CLAM/splits/task_1_tumor_vs_normal_75
```

>📋 To evaluate the model out-of-domain, run:
```bash
# First change the neighbor file from loss_matrix_1616_20att_xx.json to loss_matrix_1716_20att_xx.json or loss_matrix_1717_20att_xx.json.
python -u eval_c17.py --drop_out --k 10 --models_exp_code task_1_tumor_vs_normal_CLAM_75_s1 --save_exp_code c17_result --task task_1_tumor_vs_normal --model_type clam_sb_add --results_dir ./result-add --data_root_dir c17_slide_feature_dir --splits_dir CLAM/c17_splits/task_1_tumor_vs_normal_75
```

## ✓ Results

>📋 Our model achieves the following performance on CAMELYON16 and CAMELYON17 for in-domain classification and unsupervised domain adaptation:

| Model           | In-Domain (CAM16) AUC | In-Domain (CAM16) Accuracy | Out-of-Domain (CAM17) AUC | Out-of-Domain (CAM17) Accuracy |
|-----------------|-----------------------|----------------------------|---------------------------|--------------------------------|
| RAM-MIL Retr_I  | 0.9451±0.036          | 0.8925±0.050               |                           |                                |
| RAM-MIL Retr_IO | 0.9365±0.052          | 0.9200±0.050               | 0.7974±0.054              | 0.7433±0.073                   |
| RAM-MIL Retr_O  | 0.9419±0.048          | 0.9175±0.051               | 0.7681±0.058              | 0.7795±0.021                   |


