"""Plot results for CIFAR-10 LR tuning.

Example
-------
python plot_cifar10_lr.py
"""
import os
import pdb
import csv
from collections import defaultdict

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style('whitegrid')
sns.set_palette('muted')


def load_log(exp_dir, log_filename='iteration_log.csv'):
  """Load a CSV log file into a dict of per-column value lists.

  Args:
    exp_dir: Directory that contains the log file.
    log_filename: Name of the CSV file inside `exp_dir`.

  Returns:
    A `defaultdict(list)` keyed by CSV column name. Each list holds one entry
    per data row, in file order. The 'global_iteration', 'iteration' and
    'epoch' columns are parsed with `int`; every other column with `float`.
    A cell that cannot be parsed is stored as `None`, so all columns keep the
    same length.

  Raises:
    OSError: If the log file cannot be opened.
  """
  int_columns = ('global_iteration', 'iteration', 'epoch')
  result_dict = defaultdict(list)
  with open(os.path.join(exp_dir, log_filename), newline='') as csvfile:
    reader = csv.DictReader(csvfile)
    for row in reader:
      for key, raw_value in row.items():
        convert = int if key in int_columns else float
        try:
          value = convert(raw_value)
        except (ValueError, TypeError):
          # ValueError: empty or non-numeric text.
          # TypeError: DictReader filled a missing cell with None, or
          # collected surplus cells of a long row into a list.
          value = None
        result_dict[key].append(value)
  return result_dict


# Experiment directories for the three SGDm runs that are compared in the
# figures. Each directory is expected to hold the CSV logs read by the
# plotting sections below.
_SGDM_FIXED_LR_DIR = 'experiments/cifar10_wrn_sgdm_fixed_lr/dset:cifar10-model:wideresnet-nl:2-b:sgdmwd-m:rmsprop-bs:128-ilr:0.03-mlr:0.1-lam:0-mstp:0-mint:10-ag:1-wd-0.0-val:0-ep:200-fac:0.2-dat:60,120,160-seed:11/'
_SGDM_DECAYED_LR_DIR = 'experiments/cifar10_wrn_sgdm_decayed/dset:cifar10-model:wideresnet-nl:2-b:sgdmwd-m:rmsprop-bs:128-ilr:0.1-mlr:0.1-lam:0-mstp:0-mint:10-ag:1-wd-0.0005-val:0-ep:200-fac:0.2-dat:60,120,160-seed:11/'
_SGDM_APO_DIR = 'experiments/cifar10_wrn_sgdm_apo/dset:cifar10-model:wideresnet-nl:2-b:sgdmwd-m:rmsprop-bs:128-ilr:0.1-mlr:0.1-lam:0.1-mstp:1-mint:10-ag:1-wd-0.0001-val:0-ep:200-fac:0.2-dat:60,120,160-seed:11/'

# (legend label, experiment directory) pairs, in plotting order.
plot_dirs = [
    ('SGDm Fixed LR', _SGDM_FIXED_LR_DIR),
    ('SGDm Decayed LR', _SGDM_DECAYED_LR_DIR),
    ('SGDm-APO', _SGDM_APO_DIR),
]

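# Alternative experiment set (RMSprop on ResNet-34). To plot these runs,
# comment out the active `plot_dirs` list and uncomment this one.
# plot_dirs = [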
#     ('RMSprop Fixed LR', 'experiments/cifar10_resnet34_rmsprop_fixed_lr/dset:cifar10-model:resnet34-nl:2-b:rmsprop-m:None-bs:128-ilr:0.001-mlr:0.1-lam:0-mstp:0-mint:10-ag:1-wd-0.0-val:0-ep:200-fac:0.2-dat:60,120,160-seed:11'),
#     ('RMSprop Decayed LR', 'experiments/cifar10_resnet34_rmsprop_decayed_lr/dset:cifar10-model:resnet34-nl:2-b:rmsprop-m:None-bs:128-ilr:0.001-mlr:0.1-lam:0-mstp:0-mint:10-ag:1-wd-0.0-val:0-ep:200-fac:0.2-dat:60,120,160-seed:11'),
#     ('RMSprop-APO', 'experiments/cifar10_resnet34_rmsprop_apo/dset:cifar10-model:resnet34-nl:2-b:rmsprop-m:None-bs:128-ilr:0.0001-mlr:0.1-lam:1e-05-mstp:1-mint:10-ag:1-wd-0.0-val:0-ep:200-fac:0.2-dat:60,120,160-seed:11'),
# ]

# Create the output directory for the figures. `exist_ok=True` makes this a
# single call that is safe to repeat and avoids the check-then-create race of
# testing `os.path.exists` first.
os.makedirs('figures/cifar10_lr', exist_ok=True)

# Plot test accuracy
# ------------------
# One curve per experiment: per-epoch test accuracy read from each run's
# 'epoch_log.csv', saved as both PDF and PNG.
fig = plt.figure()
for curve_label, run_dir in plot_dirs:
  epoch_stats = load_log(run_dir, log_filename='epoch_log.csv')
  plt.plot(epoch_stats['epoch'], epoch_stats['test_acc'],
           label=curve_label, linewidth=2)

tick_fontsize, label_fontsize = 18, 20
plt.xticks(fontsize=tick_fontsize)
plt.yticks(fontsize=tick_fontsize)
plt.xlabel('Epoch', fontsize=label_fontsize)
plt.ylabel('Test Accuracy', fontsize=label_fontsize)
plt.ylim(0.86, 0.965)
plt.legend(fontsize=tick_fontsize, fancybox=True, framealpha=0.3)

# Write the same figure in both output formats.
for out_path in ('figures/cifar10_lr/cifar10_test_acc.pdf',
                 'figures/cifar10_lr/cifar10_test_acc.png'):
  plt.savefig(out_path, bbox_inches='tight', pad_inches=0)
plt.close(fig)


# Plot learning rates
# -------------------
# One curve per experiment: the '0 lr' column of each run's
# 'iteration_log.csv' against the iteration index, on a log-scaled y axis.
fig = plt.figure()
for curve_label, run_dir in plot_dirs:
  iter_stats = load_log(run_dir, log_filename='iteration_log.csv')
  plt.plot(iter_stats['iteration'], iter_stats['0 lr'],
           label=curve_label, linewidth=2)

tick_fontsize, label_fontsize = 18, 20
plt.xticks(fontsize=tick_fontsize)
plt.yticks(fontsize=tick_fontsize)
plt.xlabel('Iteration', fontsize=label_fontsize)
plt.ylabel('Learning Rate', fontsize=label_fontsize)
plt.yscale('log')
plt.legend(fontsize=tick_fontsize, fancybox=True, framealpha=0.3)

# Write the same figure in both output formats.
for out_path in ('figures/cifar10_lr/cifar10_lr.pdf',
                 'figures/cifar10_lr/cifar10_lr.png'):
  plt.savefig(out_path, bbox_inches='tight', pad_inches=0)
plt.close(fig)
