import numpy as np
import torch as t
import torchvision.transforms as tr
from torchvision.datasets import MNIST,FashionMNIST, KMNIST, EMNIST
from utils import *


# Base palette: one RGB triple (components in [0, 1]) per class index 0-9.
# Used as-is to tint the MNIST digit.
color_set_mnist = t.tensor([
    [0,           0.50196078,  0.50196078 ],  # 0
    [0.99215686,  0.91372549,  0.0627451  ],  # 1
    [0,           0.58431373,  0.71372549 ],  # 2
    [0.929411765, 0.568627451, 0.129411765],  # 3
    [0.568627451, 0.117647059, 0.737254902],  # 4
    [0.274509804, 0.941176471, 0.941176471],  # 5
    [0.980392157, 0.77254902,  0.733333333],  # 6
    [0.823529412, 0.960784314, 0.235294118],  # 7
    [0.501960784, 0,           0          ],  # 8
    [0.8627451,   0.07843137,  0.23529412 ],  # 9
])

# The remaining palettes contain the same ten colours, cyclically shifted so
# that a given class index maps to a different colour in each palette:
# row i of the palette shifted by k is row (i + k) % 10 of the base palette.
color_set_fmn = t.roll(color_set_mnist, shifts=-1, dims=0)  # FashionMNIST: shifted by 1
color_set_kmn = t.roll(color_set_mnist, shifts=-2, dims=0)  # KMNIST:       shifted by 2
color_set_emn = t.roll(color_set_mnist, shifts=-3, dims=0)  # EMNIST:       shifted by 3





        

def mark_sampling(b_data,b_label,label):
    """Pick one random sample of ``b_data`` whose label equals ``label``.

    Args:
        b_data: tensor of samples, indexed along its first axis.
        b_label: tensor of labels, compared element-wise against ``label``.
        label: the class to draw from (anything comparable with ``b_label``).

    Returns:
        The chosen sample, indexed with a one-element index tensor, so the
        result keeps a leading axis of length 1.
    """
    # Positions (along the first axis) of every sample carrying this label.
    matches = (b_label == label).nonzero(as_tuple=True)[0]
    # Draw a single position uniformly among the matches.
    choice = t.randint(0, matches.shape[0], (1,))
    return b_data[matches[choice]]
    


def _draw_label(true_label, flip):
    """Return ``true_label``, or a random different class in 0-9 when ``flip`` is set."""
    if not flip:
        return true_label
    # Rejection-sample until the drawn class differs from the true one.
    while True:
        candidate = int(t.randint(0, 10, (1,)))
        if candidate != true_label:
            return candidate


def _tinted_mark(src_data, src_label, palette, true_label, flip_obj, flip_col):
    """Sample one mark image from an auxiliary dataset and tint it.

    Returns ``(tile, obj_label, col_label)`` where ``tile`` is the 3x28x28
    coloured mark, ``obj_label`` the class the mark was sampled from and
    ``col_label`` the class whose palette colour was applied.
    """
    # Object class first, then the sample, then the colour class: this keeps
    # the order in which random numbers are consumed stable.
    obj_label = _draw_label(true_label, flip_obj)
    mark = mark_sampling(src_data, src_label, obj_label)
    col_label = _draw_label(true_label, flip_col)

    color = palette[col_label].repeat(28, 28, 1).permute((2, 0, 1))
    tile = t.clamp((mark.repeat(3, 1, 1).float() * color) / 255., 0., 1.)
    return tile, obj_label, col_label


def biasing(dataset, b_ratio):
    """Build 56x56 RGB composites of an MNIST digit plus three coloured marks.

    Each output image is a 2x2 grid of 28x28 tiles: the MNIST digit
    (top-left) and one mark each from FashionMNIST (top-right), KMNIST
    (bottom-left) and EMNIST (bottom-right). Every tile has an object class
    and a colour class. Each of these seven attributes independently equals
    the digit's true label, unless a uniform random draw falls below
    ``b_ratio``, in which case a different class in 0-9 is used instead.

    Args:
        dataset: mapping with keys ``'mn'``, ``'fmn'``, ``'kmn'`` and
            ``'emn'``; each value exposes ``.data`` and ``.targets`` tensors.
            Assumes every image is 28x28 with values in 0-255 -- TODO confirm.
        b_ratio: per-attribute probability of replacing the true label with
            a different one.

    Returns:
        ``(data, label, b_label)``: ``data`` is a float tensor of shape
        ``(N, 3, 56, 56)``; ``label`` is ``dataset['mn'].targets`` itself
        (not a copy); ``b_label`` maps ``'obj_bias0..2'``, ``'col_bias'`` and
        ``'col_bias0..2'`` to per-sample label tensors.
    """
    _data = dataset['mn'].data
    _label = dataset['mn'].targets

    # One entry per mark: (images, labels, palette, tile rows, tile columns).
    # EMNIST targets are shifted down by one (presumably because the
    # 'letters' split is 1-based -- verify against torchvision).
    marks = (
        (dataset['fmn'].data, dataset['fmn'].targets, color_set_fmn,
         slice(None, 28), slice(28, None)),
        (dataset['kmn'].data, dataset['kmn'].targets, color_set_kmn,
         slice(28, None), slice(None, 28)),
        (dataset['emn'].data, dataset['emn'].targets - 1, color_set_emn,
         slice(28, None), slice(28, None)),
    )

    n_samples = len(_label)

    # True where the attribute should be misaligned with the digit label.
    # midx: object class of each mark; cidx: colour of the digit (column 0)
    # and of each mark (columns 1-3).
    midx = t.rand((n_samples, 3)) < b_ratio
    cidx = t.rand((n_samples, 4)) < b_ratio

    obj_labels = [t.zeros_like(_label) for _ in marks]
    col_labels = [t.zeros_like(_label) for _ in marks]
    c_label = t.zeros_like(_label)

    # Report progress roughly every 10%; never let the period be zero, which
    # would raise ZeroDivisionError for inputs with fewer than 10 samples.
    report_every = max(1, int(0.1 * n_samples))

    data = t.zeros((n_samples, 3, 56, 56))
    for idx in range(n_samples):
        if (idx + 1) % report_every == 0:
            print("%3d / %3d Done..."%(idx+1, n_samples))

        true_label = int(_label[idx])

        # Digit tile (top-left): only its colour can be biased.
        digit_color = _draw_label(true_label, cidx[idx, 0])
        c_label[idx] = digit_color
        color = color_set_mnist[digit_color]
        digit_rgb = _data[idx].unsqueeze(2).repeat(1, 1, 3).float() * color
        data[idx, :, :28, :28] = t.clamp(digit_rgb / 255., 0., 1.).permute((2, 0, 1))

        # Mark tiles: each one is sampled from its OWN dataset, both in the
        # aligned and in the misaligned case.
        for k, (src_data, src_label, palette, rows, cols) in enumerate(marks):
            tile, obj_label, col_label = _tinted_mark(
                src_data, src_label, palette, true_label,
                midx[idx, k], cidx[idx, k + 1],
            )
            obj_labels[k][idx] = obj_label
            col_labels[k][idx] = col_label
            data[idx, :, rows, cols] = tile

    b_label = {}
    b_label['obj_bias0'] = obj_labels[0]
    b_label['obj_bias1'] = obj_labels[1]
    b_label['obj_bias2'] = obj_labels[2]
    b_label['col_bias']  = c_label
    b_label['col_bias0'] = col_labels[0]
    b_label['col_bias1'] = col_labels[1]
    b_label['col_bias2'] = col_labels[2]

    print(_label)
    print(b_label)
    return data, _label, b_label





def _load_sources(root, train):
    """Load (downloading if needed) the four source datasets for one split."""
    return {
        'mn': MNIST(root, train=train, download=True),
        'fmn': FashionMNIST(root, train=train, download=True),
        'kmn': KMNIST(root, train=train, download=True),
        'emn': EMNIST(root, split='letters', train=train, download=True),
    }


def _subset(data, label, b_label, rows):
    """Slice the images, labels and every bias-label tensor with ``rows``."""
    return {
        'data': data[rows],
        'label': label[rows],
        # Each bias key is sliced from its own tensor.
        'b_label': {key: values[rows] for key, values in b_label.items()},
    }


def biased_mnist_gen(args):
    """Generate and save the biased train/valid sets and the test set.

    For every ratio in ``args.bias_ratio`` the training sources are passed
    through ``biasing``; the first 90% of the samples become the ``'train'``
    split and the rest the ``'valid'`` split, and the result is saved under
    ``args.save_dir + args.data + '_bias_' + str(ratio)``. The test sources
    are then biased once with a fixed ratio of 0.9 and saved under
    ``args.save_dir + args.data + '_test'``.

    Args:
        args: object providing ``data_storage`` (dataset root directory),
            ``bias_ratio`` (iterable of ratios), ``data`` (output name
            prefix) and ``save_dir`` (output path prefix, concatenated
            directly with the file name).

    Returns:
        None. Results are handed to ``save_data`` (expected to be provided
        by ``utils``).
    """
    train_valid_split = 0.9

    dset = _load_sources(args.data_storage, train=True)

    for b_ratio in args.bias_ratio:
        data, label, b_label = biasing(dset,b_ratio)
        n_train = int(len(label)*train_valid_split)

        ret = {
            'train': _subset(data, label, b_label, slice(None, n_train)),
            'valid': _subset(data, label, b_label, slice(n_train, None)),
        }

        data_name = args.data+'_bias_'+str(b_ratio)
        save_data(ret, args.save_dir+data_name)

    dset = _load_sources(args.data_storage, train=False)

    data, label, b_label = biasing(dset,0.9)
    # biasing returns the dataset's own targets tensor; store a copy instead.
    label = label.clone()
    ret = {
        'data': data,
        'label': label,
        'b_label': b_label,
    }

    data_name = args.data + '_test'
    save_data(ret, args.save_dir+data_name)

    

