import logging
from collections import OrderedDict
import torch
import torch.nn as nn
import os
import model.networks as networks
from .base_model import BaseModel
# Name of the logger this module writes to; presumably shared with the rest
# of the project (it is not derived from __name__).
_LOGGER_NAME = 'base'
logger = logging.getLogger(_LOGGER_NAME)


class DDPM(BaseModel):
    """Diffusion model wrapper: training steps, inference and checkpoint I/O.

    Up to two generator networks are built from the same options:

    * ``self.netG`` -- the network that is optimised during training and
      used for inference.
    * ``self.netGU`` -- an "uncertainty model", created only when
      ``opt['uncertainty_train']`` is falsy. In the train phase its weights are
      loaded from ``<resume_state>_gen.pth``, it is not part of the optimizer,
      and its ``denoise_fn`` is passed to ``netG`` on every training step.

    When ``opt['uncertainty_train']`` is truthy only ``netG`` exists, it is
    trained on a pixel loss alone, and checkpoints are additionally mirrored
    to ``./checkpoints/uncertainty/``.
    """

    def __init__(self, opt):
        super(DDPM, self).__init__(opt)

        # Define networks. Both come from the same options, so they share an
        # architecture.
        self.netG = self.set_device(networks.define_G(opt))
        if not opt['uncertainty_train']:
            self.netGU = self.set_device(networks.define_G(opt))  # uncertainty model

        self.schedule_phase = None

        # Configure loss and noise schedule on every network that exists.
        self.set_loss()
        self.set_new_noise_schedule(
            opt['model']['beta_schedule']['train'], schedule_phase='train')

        if self.opt['phase'] == 'train':
            self.netG.train()
            # Find the parameters to optimize.
            if opt['model']['finetune_norm']:
                # Freeze everything except parameters whose name contains
                # 'transformer'; those are zero-initialised and optimised.
                optim_params = []
                for k, v in self.netG.named_parameters():
                    v.requires_grad = False
                    if 'transformer' in k:
                        v.requires_grad = True
                        v.data.zero_()
                        optim_params.append(v)
                        logger.info(
                            'Params [{:s}] initialized to 0 and will optimize.'.format(k))
            else:
                optim_params = list(self.netG.parameters())

            self.optG = torch.optim.Adam(
                optim_params, lr=opt['train']["optimizer"]["lr"])
            self.log_dict = OrderedDict()

            if not opt['uncertainty_train']:
                # The uncertainty model is initialised from the checkpoint
                # being resumed.
                self.netGU.load_state_dict(
                    torch.load(self.opt['path']['resume_state'] + '_gen.pth'),
                    strict=True)

        if self.opt['phase'] == 'test':
            # Inference wraps the uncertainty model when it exists. netGU is
            # only created when uncertainty_train is falsy; otherwise fall back
            # to netG, which was built (and configured) from the same options.
            base_net = self.netG if opt['uncertainty_train'] else self.netGU
            self.netG = nn.DataParallel(base_net)
            # NOTE(review): unlike load_network(), resume_state is used as a
            # full file path here (no '_gen.pth' suffix), and the keys must
            # match the DataParallel-wrapped model -- confirm with test configs.
            self.netG.load_state_dict(
                torch.load(self.opt['path']['resume_state']), strict=True)
        else:
            self.load_network()

    @staticmethod
    def _unwrap(net):
        """Return the module inside an ``nn.DataParallel`` wrapper, else ``net`` itself."""
        return net.module if isinstance(net, nn.DataParallel) else net

    def feed_data(self, data):
        """Store the batch's 'LQ' and 'GT' entries, passed through ``set_device``.

        Any other keys in ``data`` are dropped.
        """
        self.data = self.set_device({'LQ': data['LQ'], 'GT': data['GT']})

    def optimize_parameters(self):
        """Run one optimisation step on ``netG`` and record losses in ``log_dict``.

        Without uncertainty training, ``netG`` also receives
        ``netGU.denoise_fn`` and returns a pixel and an SVD loss, which are
        summed. With uncertainty training only the pixel loss is used.
        """
        self.optG.zero_grad()

        # Losses come back un-reduced (need to average in multi-gpu), so they
        # are summed here and normalised by the input size.
        b, c, h, w = self.data['LQ'].shape

        if not self.opt['uncertainty_train']:
            l_pix, l_svd = self.netG(self.data, self._unwrap(self.netGU).denoise_fn)
            l_pix = l_pix.sum() / (b * c * h * w)
            # NOTE(review): the per-sample divisor 6 is not explained by the
            # visible code -- confirm against the SVD loss definition.
            l_svd = l_svd.sum() / (b * 6)
            loss = l_pix + l_svd
            loss.backward()
            self.optG.step()

            self.log_dict['loss'] = loss.item()
            self.log_dict['l_pix'] = l_pix.item()
            self.log_dict['l_svd'] = l_svd.item()
        else:
            l_pix = self.netG(self.data)
            l_pix = l_pix.sum() / (b * c * h * w)
            l_pix.backward()
            self.optG.step()

            self.log_dict['l_pix'] = l_pix.item()

    def test(self, continous=False):
        """Run ``super_resolution`` on the current 'LQ' batch into ``self.SR``.

        The call runs in eval mode without gradients; the network is switched
        back to train mode afterwards.
        """
        self.netG.eval()
        with torch.no_grad():
            self.SR = self._unwrap(self.netG).super_resolution(
                self.data['LQ'], continous)
        self.netG.train()

    def sample(self, batch_size=1, continous=False):
        """Draw ``batch_size`` samples from ``netG`` into ``self.SR``.

        Runs in eval mode without gradients, then restores train mode.
        """
        self.netG.eval()
        with torch.no_grad():
            self.SR = self._unwrap(self.netG).sample(batch_size, continous)
        self.netG.train()

    def set_loss(self):
        """Call ``set_loss(self.device)`` on netG and, when it exists, netGU."""
        self._unwrap(self.netG).set_loss(self.device)
        if not self.opt['uncertainty_train']:
            self._unwrap(self.netGU).set_loss(self.device)

    def set_new_noise_schedule(self, schedule_opt, schedule_phase='train'):
        """Apply ``schedule_opt`` to netG and, when it exists, netGU.

        The schedule is re-applied unconditionally, even if ``schedule_phase``
        equals the phase already recorded in ``self.schedule_phase``.
        """
        self.schedule_phase = schedule_phase
        self._unwrap(self.netG).set_new_noise_schedule(schedule_opt, self.device)
        if not self.opt['uncertainty_train']:
            self._unwrap(self.netGU).set_new_noise_schedule(
                schedule_opt, self.device)

    def get_current_log(self):
        """Return the loss dictionary filled by ``optimize_parameters``."""
        return self.log_dict

    def get_current_visuals(self, need_LR=True, sample=False):
        """Collect result and input tensors as detached float CPU tensors.

        Args:
            need_LR: when true, 'LQ' holds its own detached copy of the input;
                otherwise it is the same object as 'INF'.
            sample: when true only the sampled output is returned, under 'SAM'.

        Returns:
            OrderedDict with 'SAM', or with 'HQ', 'INF', 'GT' and 'LQ'.
        """
        out_dict = OrderedDict()
        if sample:
            out_dict['SAM'] = self.SR.detach().float().cpu()
            return out_dict

        out_dict['HQ'] = self.SR.detach().float().cpu()
        out_dict['INF'] = self.data['LQ'].detach().float().cpu()
        # Only the first element of the ground-truth batch is returned.
        out_dict['GT'] = self.data['GT'].detach()[0].float().cpu()
        # feed_data stores the input under 'LQ' (there is never an 'LR' key).
        if need_LR and 'LQ' in self.data:
            out_dict['LQ'] = self.data['LQ'].detach().float().cpu()
        else:
            out_dict['LQ'] = out_dict['INF']
        return out_dict

    def print_network(self):
        """Log netG's structure string and the count reported by ``get_network_description``."""
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel):
            net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
                                             self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)

        logger.info(
            'Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
        logger.info(s)

    def save_network(self, epoch, iter_step):
        """Save netG weights and optimizer state under ``opt['path']['checkpoint']``.

        Files are named ``I{iter_step}_E{epoch}_gen.pth`` and
        ``I{iter_step}_E{epoch}_opt.pth``. In uncertainty-training mode the
        same objects are also written to
        ``./checkpoints/uncertainty/latest_{gen,opt}.pth``; that directory is
        created if missing.
        """
        gen_path = os.path.join(
            self.opt['path']['checkpoint'], 'I{}_E{}_gen.pth'.format(iter_step, epoch))
        opt_path = os.path.join(
            self.opt['path']['checkpoint'], 'I{}_E{}_opt.pth'.format(iter_step, epoch))

        # gen: move tensors to CPU in place rather than rebuilding the mapping,
        # so the exact object returned by state_dict() is what gets saved.
        state_dict = self._unwrap(self.netG).state_dict()
        for key, param in state_dict.items():
            state_dict[key] = param.cpu()
        torch.save(state_dict, gen_path)

        # opt
        opt_state = {'epoch': epoch, 'iter': iter_step,
                     'scheduler': None, 'optimizer': self.optG.state_dict()}
        torch.save(opt_state, opt_path)

        if self.opt['uncertainty_train']:
            ut_dir = './checkpoints/uncertainty/'
            # This location is fixed rather than configured, so make sure it
            # exists before writing into it.
            os.makedirs(ut_dir, exist_ok=True)
            torch.save(state_dict, os.path.join(ut_dir, 'latest_gen.pth'))
            torch.save(opt_state, os.path.join(ut_dir, 'latest_opt.pth'))

        logger.info(
            'Saved model in [{:s}] ...'.format(gen_path))

    def load_network(self):
        """Restore netG (and, in the train phase, the optimizer) from ``resume_state``.

        Reads ``<resume_state>_gen.pth`` and ``<resume_state>_opt.pth``. It does
        nothing when ``resume_state`` is None. Weight loading is non-strict
        when ``finetune_norm`` is enabled. The stored iteration and epoch are
        not restored: ``begin_step`` and ``begin_epoch`` are reset to 0.
        """
        load_path = self.opt['path']['resume_state']
        if load_path is None:
            return

        logger.info(
            'Loading pretrained model for G [{:s}] ...'.format(load_path))
        gen_path = '{}_gen.pth'.format(load_path)
        opt_path = '{}_opt.pth'.format(load_path)

        # gen
        network = self._unwrap(self.netG)
        network.load_state_dict(torch.load(
            gen_path), strict=(not self.opt['model']['finetune_norm']))

        if self.opt['phase'] == 'train':
            # optimizer
            opt = torch.load(opt_path)
            self.optG.load_state_dict(opt['optimizer'])
            # Training counters restart from zero instead of opt['iter'] /
            # opt['epoch'].
            self.begin_step = 0
            self.begin_epoch = 0
