import torch
import logging
import models.modules.SRResNet_arch as SRResNet_arch
import models.modules.discriminator_vgg_arch as SRGAN_arch
import models.modules.RRDBNet_arch as RRDBNet_arch
import models.modules.dan_arch as DAN_arch
logger = logging.getLogger('base')


####################
# define network
####################
#### Generator
# Generator names whose classes are defined in ``models.modules.sftmd_arch``.
_SFTMD_MODELS = ('Predictor', 'Corrector', 'SFTMD', 'SRResNet', 'SFTMD_DEMO')


def _define_sftmd_G(which_model, opt_net):
    """Build a generator whose class lives in ``sftmd_arch``.

    ``sftmd_arch`` is not imported at file level, so it is imported here on
    demand. This keeps the generators from the other architecture modules
    usable even when ``sftmd_arch`` is unavailable.

    Args:
        which_model: one of ``_SFTMD_MODELS``.
        opt_net: the ``opt['network_G']`` options dict.

    Returns:
        The constructed generator network.

    Raises:
        ImportError: if ``models.modules.sftmd_arch`` cannot be imported.
        NotImplementedError: if ``which_model`` is not an SFTMD-family name.
    """
    import models.modules.sftmd_arch as sftmd_arch

    if which_model == 'Predictor':
        return sftmd_arch.Predictor(
            in_nc=opt_net['in_nc'], nf=opt_net['nf'],
            scale=opt_net['upscale'], code_len=opt_net['code_length'],
            num_blocks=opt_net['num_blocks'])
    if which_model == 'Corrector':
        return sftmd_arch.Corrector(in_nc=opt_net['in_nc'], nf=opt_net['nf'],
                                    code_len=opt_net['code_length'])
    if which_model == 'SFTMD':
        return sftmd_arch.SFTMD(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
                                nf=opt_net['nf'], nb=opt_net['nb'],
                                scale=opt_net['upscale'],
                                input_para=opt_net['code_length'])
    if which_model == 'SRResNet':
        return sftmd_arch.SRResNet()
    if which_model == 'SFTMD_DEMO':
        return sftmd_arch.SFTMD_DEMO(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
                                     nf=opt_net['nf'], nb=opt_net['nb'],
                                     scale=opt_net['upscale'],
                                     input_para=opt_net['code_length'])
    raise NotImplementedError('Generator model [{}] not recognized'.format(which_model))


def define_G(opt):
    """Build the generator network described by ``opt['network_G']``.

    Args:
        opt: full options dict. ``opt['network_G']['which_model_G']`` selects
            the architecture; the remaining ``opt['network_G']`` entries supply
            the constructor arguments for that architecture.

    Returns:
        The constructed generator network.

    Raises:
        NotImplementedError: if ``which_model_G`` names no known generator.
        ImportError: if an SFTMD-family model is requested but
            ``models.modules.sftmd_arch`` cannot be imported.
    """
    opt_net = opt['network_G']
    which_model = opt_net['which_model_G']

    if which_model == 'MSRResNet':
        # NOTE(review): this branch reads 'scale' while the others read
        # 'upscale' — presumably matching different config files; confirm.
        netG = SRResNet_arch.MSRResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
                                       nf=opt_net['nf'], nb=opt_net['nb'],
                                       upscale=opt_net['scale'])
    elif which_model == 'RRDBNet':
        netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
                                    nf=opt_net['nf'], nb=opt_net['nb'])
    elif which_model == 'DAN':
        # NOTE(review): kernel_size comes from the top-level opt, not from
        # opt['network_G'] like every other argument — confirm intended.
        netG = DAN_arch.DAN(
            nf=opt_net['nf'], nb=opt_net['nb'], scale=opt_net['upscale'],
            input_para=opt_net['code_length'], loop=opt_net['loop'],
            kernel_size=opt['kernel_size'])
    elif which_model in _SFTMD_MODELS:
        netG = _define_sftmd_G(which_model, opt_net)
    else:
        # '{}' (not '{:s}') so a non-string name, e.g. a YAML null, still
        # yields this error instead of a TypeError/ValueError from format().
        raise NotImplementedError('Generator model [{}] not recognized'.format(which_model))
    return netG


#### Discriminator
def define_D(opt):
    """Build the discriminator network described by ``opt['network_D']``.

    Args:
        opt: full options dict. ``opt['network_D']['which_model_D']`` selects
            the architecture; ``opt['network_D']['in_nc']`` and ``['nf']`` are
            passed to its constructor.

    Returns:
        The constructed discriminator network.

    Raises:
        NotImplementedError: if ``which_model_D`` names no known discriminator.
    """
    opt_net = opt['network_D']
    which_model = opt_net['which_model_D']

    if which_model != 'discriminator_vgg_128':
        # '{}' (not '{:s}') so a non-string name, e.g. a YAML null, still
        # yields this error instead of a TypeError/ValueError from format().
        raise NotImplementedError(
            'Discriminator model [{}] not recognized'.format(which_model))
    return SRGAN_arch.Discriminator_VGG_128(in_nc=opt_net['in_nc'], nf=opt_net['nf'])


#### Define Network used for Perceptual Loss
def define_F(opt, use_bn=False):
    """Build the fixed VGG19 feature extractor used for perceptual loss.

    Args:
        opt: options dict; a truthy ``opt['gpu_ids']`` places the extractor on
            ``'cuda'``, otherwise on ``'cpu'``.
        use_bn: whether to use the batch-norm VGG19 variant. This also picks
            the layer index features are read from (49 with BN, 34 without).

    Returns:
        The ``VGGFeatureExtractor``, switched to eval mode.
    """
    target_device = torch.device('cuda') if opt['gpu_ids'] else torch.device('cpu')
    # VGG19-54 (conv5_4) taken before its ReLU; the layer index differs
    # between the plain and batch-norm VGG19 variants.
    layer_index = 49 if use_bn else 34
    extractor = SRGAN_arch.VGGFeatureExtractor(
        feature_layer=layer_index,
        use_bn=use_bn,
        use_input_norm=True,
        device=target_device,
    )
    extractor.eval()  # used only as a fixed loss network, never trained
    return extractor
