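''' Sample
   This script loads a pretrained BigGAN generator/discriminator pair from saved
   weights, samples with discriminator-guided Metropolis-Hastings, and reports
   Inception Score, FID and the MH acceptance rate. '''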
import functools
import os
import math
import numpy as np
import pickle
from tqdm import tqdm, trange


import torch
import torch.nn as nn
from torch.nn import init
import torch.optim as optim
import torch.nn.functional as F
from torch.nn import Parameter as P
from torch.utils.data import DataLoader
import torchvision

# Import my stuff
import inception_utils
import utils
import losses
import BigGAN as bg
import datasets


class MH:
  """Discriminator-guided Metropolis-Hastings sampler over generator outputs.

  One slot per row of the initial batch: for each slot the sampler keeps the
  discriminator score of its current state. Each step draws a fresh batch of
  proposals from G, scores them with D, and accepts proposal i when
  ``dre(D(x'_i)) / dre(D(x_i)) > u`` with ``u ~ U[0, 1)``, i.e. with
  probability ``min(1, ratio)``. Only accepted proposals are returned;
  rejected slots keep their previous score.
  """

  @torch.no_grad()
  def __init__(self, G, D, init_samples, init_labels, config):
    """Score the initial states and prepare the proposal sampler.

    Args:
      G: generator; must expose ``dim_z`` and is sampled via ``utils.sample``.
      D: discriminator called as ``D(x, y)``; wrapped in ``nn.DataParallel``.
      init_samples: initial chain states; ``shape[0]`` fixes the number of
        slots and the proposal batch size.
      init_labels: labels for ``init_samples``, passed to D.
      config: dict; reads ``'n_classes'`` and ``'G_fp16'`` here and is
        forwarded to ``utils.sample``.
    """
    self.total_accepted = 0
    self.total_rejected = 0
    self.G = G
    self.D = nn.DataParallel(D)
    # Per-slot D score of the current state, kept on CPU as a 1-D tensor.
    self.prev_preds = self.D(init_samples.cpu(), init_labels.cpu()).cpu().flatten()
    z_, y_ = utils.prepare_z_y(init_samples.shape[0], G.dim_z, config['n_classes'],
                               device='cuda', fp16=config['G_fp16'])
    self.sample_G = functools.partial(utils.sample, G=G, z_=z_, y_=y_, config=config)

  def log_dre(self, x):
    """Log of the density-ratio estimate for discriminator output ``x``.

    ``dre(x) == exp(log_dre(x))``. ``sample`` uses this log form so the
    acceptance ratio is a difference of logs rather than a quotient of
    exponentials (which overflows to inf/inf = nan for large logits).
    """
    return x

  def dre(self, x):
    """Density-ratio estimate: the exponential of the D output."""
    return torch.exp(self.log_dre(x))

  @torch.no_grad()
  def sample(self):
    """Run MH steps until at least one proposal is accepted.

    Updates ``total_accepted`` / ``total_rejected`` (plain ints) and the
    per-slot scores of accepted slots.

    Returns:
      ``(samples, labels)`` of the proposals accepted in the final step, on CPU.
    """
    num_accepted = 0
    while num_accepted == 0:
      next_samples, next_labels = self.sample_G()
      next_samples, next_labels = next_samples.cpu(), next_labels.cpu()
      next_preds = self.D(next_samples, next_labels).cpu().flatten()
      # dre(next) / dre(prev) computed in log space: exp of each side alone
      # overflows float32 above ~88 and inf/inf would be nan (never accepted).
      log_ratio = self.log_dre(next_preds) - self.log_dre(self.prev_preds)
      p = torch.exp(log_ratio)
      u = torch.rand_like(p)
      accepted = p > u
      # Convert to a Python int so the counters stay plain numbers.
      num_accepted = int(accepted.sum())
      self.total_accepted += num_accepted
      self.total_rejected += len(next_samples) - num_accepted
      self.prev_preds[accepted] = next_preds[accepted]
    return next_samples[accepted], next_labels[accepted]


def _sample_init_batch(train_set, batch_size):
  """Draw ``batch_size`` distinct items from ``train_set`` to seed the MH slots.

  Args:
    train_set: indexable dataset whose items unpack as ``(image, label)``.
    batch_size: number of items (and therefore MH slots) to draw.

  Returns:
    ``(images, labels)``: images stacked along a new leading dimension as
    float32, labels as an int64 tensor.
  """
  ids = np.random.choice(len(train_set), size=batch_size, replace=False)
  images, labels = [], []
  for idx in ids:
    # Index the dataset once per item and take both fields from that read.
    image, label = train_set[idx]
    images.append(torch.as_tensor(image))
    labels.append(int(label))
  return torch.stack(images).float(), torch.tensor(labels, dtype=torch.long)


def run(config):
  """Evaluate discriminator-guided MH sampling for one checkpoint.

  Builds G and D from ``config``, loads their weights, initialises one MH
  slot per randomly chosen real training image, and hands ``MH.sample`` to
  the Inception metrics helper together with ``config['num_inception_images']``.

  Args:
    config: BigGAN-style config dict (e.g. as built by ``prepare_config``).

  Returns:
    ``(IS_mean, IS_std, FID, AR)`` where ``AR`` is the fraction of MH
    proposals accepted across the whole evaluation.
  """
  # Prepare state dict, which holds things like epoch # and itr #
  state_dict = {'itr': 0, 'epoch': 0, 'save_num': 0, 'save_best_num': 0,
                'best_IS': 0, 'best_FID': 999999, 'config': config}

  # Seed RNG
  utils.seed_rng(config['seed'])

  # Setup cudnn.benchmark for free speed
  torch.backends.cudnn.benchmark = True

  device = torch.device('cuda:0')
  G = bg.Generator(**config).to(device)
  D = bg.Discriminator(**config).to(device)
  utils.load_weights(G if not (config['use_ema']) else None, D, state_dict,
                     config['weights_root'], config['load_weights'],
                     G if config['ema'] and config['use_ema'] else None,
                     strict=False, load_optim=False)
  G.eval()
  # D is only used to score samples here, so put it in eval mode too; this
  # disables training-only layer behaviour (for BigGAN's spectral norm,
  # presumably the power-iteration buffer updates — confirm in BigGAN.py).
  D.eval()

  # Number of MH slots, which is also the proposal batch size per MH step.
  batch_size = 1000
  data_root = '../../ImageNet/ILSVRC128.hdf5'
  train_set = datasets.ILSVRC_HDF5(root=data_root, transform=None, load_in_mem=False,
                                   index_filename='%s_imgs.npz' % config['dataset'])
  init_images, init_labels = _sample_init_batch(train_set, batch_size)

  chain_sampler = MH(G, D, init_images, init_labels, config)

  # Get Inception Score and FID
  act_path = '../../ImageNet/I128'
  get_inception_metrics = inception_utils.prepare_inception_metrics(act_path, config['parallel'], device, config['no_fid'])
  IS_mean, IS_std, FID = get_inception_metrics(chain_sampler.sample, config['num_inception_images'], num_splits=1, prints=True)
  # Prepare output string
  outstring = 'Using %s weights ' % ('ema' if config['use_ema'] else 'non-ema')
  # Report the mode G actually ran in (it is always switched to eval above).
  outstring += 'in %s mode, ' % ('training' if G.training else 'eval')
  outstring += 'over %d images, ' % config['num_inception_images']
  if config['accumulate_stats'] or not config['G_eval_mode']:
    outstring += 'with batch size %d, ' % batch_size
  if config['accumulate_stats']:
    # NOTE(review): no standing-stat accumulation is performed in this script.
    outstring += 'using %d standing stat accumulations, ' % config['num_standing_accumulations']
  outstring += 'Itr %d: Inception Score is %3.3f +/- %3.3f, FID is %5.4f' % (state_dict['itr'], IS_mean, IS_std, FID)
  print(outstring)
  AR = float(chain_sampler.total_accepted) / float(chain_sampler.total_accepted + chain_sampler.total_rejected)
  return IS_mean, IS_std, FID, AR

def prepare_config(state_dir, state_dict_name, weight_suffix, seed=0,
                   num_inception_images=50000, use_ema=False):
  """Rebuild the training config saved in a state dict and adapt it for evaluation.

  Args:
    state_dir: directory holding the state dict and weight files.
    state_dict_name: file name of the state dict inside ``state_dir``.
    weight_suffix: suffix passed as ``config['load_weights']`` to select weights.
    seed: RNG seed stored in ``config['seed']``.
    num_inception_images: number of images used for IS/FID.
    use_ema: whether to evaluate the EMA generator weights.

  Returns:
    The updated config dict.
  """
  # NOTE: torch.load unpickles the file; only use this on trusted checkpoints.
  # Only the config is needed, so load to CPU regardless of the saving device.
  state_dict = torch.load(os.path.join(state_dir, state_dict_name), map_location='cpu')
  config = state_dict['config']
  # Dataset-dependent settings derived from the saved names.
  config['resolution'] = utils.imsize_dict[config['dataset']]
  config['n_classes'] = utils.nclass_dict[config['dataset']]
  config['G_activation'] = utils.activation_dict[config['G_nl']]
  config['D_activation'] = utils.activation_dict[config['D_nl']]
  config = utils.update_config_roots(config)
  config['skip_init'] = True
  config['no_optim'] = True
  config['weights_root'] = state_dir
  config['parallel'] = True
  config['sample_inception_metrics'] = True
  config['use_ema'] = use_ema
  config['load_weights'] = weight_suffix
  config['seed'] = seed
  config['num_inception_images'] = num_inception_images
  return config

def main(state_dir='../trained-models/138k/138k_reset_head_woy_lr1e-5_long',
         seeds=(0, 1), results_path='results.pkl'):
  """Evaluate every numbered state dict in ``state_dir`` under each seed.

  Args:
    state_dir: directory of saved checkpoints.
    seeds: RNG seeds to evaluate each checkpoint with.
    results_path: pickle file rewritten with all results after every run.

  Returns:
    List of per-run result dicts (``state_dir``, ``state_dict``, ``config``,
    ``IS_mean``, ``IS_std``, ``FID``, ``AR``).
  """
  state_dicts = [s for s in os.listdir(state_dir)
                 if 'state_dict' in s and 'copy' not in s]
  # Order checkpoints by the integer after the last '#' in the file stem.
  state_dicts.sort(key=lambda s: int(s.split('.')[0].split('#')[-1]))
  results = []
  for state_dict_name in state_dicts:
    weight_suffix = state_dict_name.split('.')[0].split('_')[-1]
    for seed in seeds:
      config = prepare_config(state_dir, state_dict_name, weight_suffix, seed)
      print(config)
      IS_mean, IS_std, FID, AR = run(config)
      results.append({'state_dir': state_dir,
                      'state_dict': state_dict_name,
                      'config': config,
                      'IS_mean': IS_mean,
                      'IS_std': IS_std,
                      'FID': FID,
                      'AR': AR})
      # Rewrite after every run so partial results survive an interruption.
      with open(results_path, 'wb') as thefile:
        pickle.dump(results, thefile)
  return results


def main_():
  """Command-line entry point: build the config from argv and run one evaluation.

  ``run`` does not derive dataset-dependent settings itself (``MH`` reads
  ``config['n_classes']``). The parsed args presumably lack these derived
  keys, so they are filled in here the same way ``prepare_config`` does.

  Returns:
    ``run``'s ``(IS_mean, IS_std, FID, AR)`` tuple.
  """
  # parse command line and run
  parser = utils.prepare_parser()
  parser = utils.add_sample_parser(parser)
  config = vars(parser.parse_args())
  config['resolution'] = utils.imsize_dict[config['dataset']]
  config['n_classes'] = utils.nclass_dict[config['dataset']]
  config['G_activation'] = utils.activation_dict[config['G_nl']]
  config['D_activation'] = utils.activation_dict[config['D_nl']]
  config = utils.update_config_roots(config)
  config['skip_init'] = True
  config['no_optim'] = True
  print(config)
  return run(config)
  
if __name__ == '__main__':
  # Script entry point: runs the checkpoint sweep in main(), not the
  # argparse-driven main_().
  main()