import torch
import torch.nn as nn
import random
import numpy as np
import torch.optim as optim
import sys
import os



# Device string used for every model and tensor in this module:
# "cuda" when a CUDA device is available, otherwise "cpu".
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"


def get_model(model_type, input_dim, n_class):
    """Build the classifier described by ``model_type``.

    Supported specifications:

    * ``'linear'``               -- a single ``nn.Linear(input_dim, n_class)``.
    * ``'MLP+<hidden>'``         -- Linear -> ReLU -> Linear with ``<hidden>``
      units in the hidden layer.
    * ``'MLPdropout+<hidden>'``  -- same as ``MLP`` but with ``Dropout(p=0.1)``
      in front of each linear layer.

    Anything after a second ``'+'`` is ignored, as it always has been.

    Args:
        model_type: Architecture specification string (see above).
        input_dim: Number of input features.
        n_class: Number of output classes (logits).

    Returns:
        A freshly initialised ``nn.Module``.

    Raises:
        ValueError: If the architecture name is unknown, or if an MLP
            specification lacks a valid integer hidden width.
    """
    if model_type == 'linear':
        return nn.Linear(input_dim, n_class)

    arch, _, remainder = model_type.partition('+')
    if arch not in ('MLP', 'MLPdropout'):
        raise ValueError('check model_type: unsupported value %r' % (model_type,))

    # Only the first field after the architecture name carries the width.
    hidden_spec = remainder.split('+')[0]
    try:
        inner_dim = int(hidden_spec)
    except ValueError as err:
        raise ValueError(
            "check model_type: expected '%s+<hidden_dim>' with an integer "
            "hidden_dim, got %r" % (arch, model_type)
        ) from err

    with_dropout = arch == 'MLPdropout'

    # Layers are created in the same order as before so that seeded weight
    # initialisation and state_dict indices stay unchanged.
    layers = []
    if with_dropout:
        layers.append(nn.Dropout(p=0.1))
    layers.append(nn.Linear(input_dim, inner_dim))
    layers.append(nn.ReLU())
    if with_dropout:
        layers.append(nn.Dropout(p=0.1))
    layers.append(nn.Linear(inner_dim, n_class))
    return nn.Sequential(*layers)
    


class Learner():
    """Train and evaluate classifiers on pre-computed feature tensors.

    Every trained model is stored under a caller-chosen ``key`` together with
    the per-feature mean and standard deviation measured on its training set,
    so that ``eval`` can rebuild the model and apply the same normalisation.

    The datasets passed to ``train`` / ``eval`` must yield
    ``(headline, short_description, targets)`` triples. Each selected input
    is flattened per sample; the concatenated width is expected to equal
    ``input_dim * number_of_inputs`` (this is what the model is built for).
    """

    def __init__(self):
        # key -> state_dict of the model trained under that key
        self.model_state_dict = {}

        # key -> per-feature mean / standard deviation of the training set
        self.feat_mean = {}
        self.feat_std = {}
        # key -> int tensor of predicted classes (filled by eval(save=True))
        self.predictions = {}

    @staticmethod
    def _assemble_features(headline, short_description, input_names, batch_size):
        """Flatten the selected inputs per sample and concatenate them.

        Any name other than ``'headline'`` selects ``short_description``;
        this permissive mapping is kept for backward compatibility.
        """
        parts = [
            (headline if name == 'headline' else short_description).view(batch_size, -1)
            for name in input_names
        ]
        if len(parts) == 1:
            return parts[0]
        return torch.cat(parts, dim = 1)

    @staticmethod
    def _safe_std(feat_std):
        """Return ``feat_std`` with unusable entries replaced by 1.

        A constant feature has zero standard deviation; dividing by it would
        turn the whole batch into NaN/inf. NaN entries compare False against
        0 and are therefore replaced as well.
        """
        return torch.where(feat_std > 0, feat_std, torch.ones_like(feat_std))

    def train(self, dataset, key = "", init_key = "", seed = 0, input_dim = 768, n_class = 41, inputs = 'headline+short_description',
        batch_size = 256, num_workers = 4, epochs = 50, lr = 0.01, wd = 5e-4, optimizer_name = 'SGD', debug = False, normalize = True, model_type = 'linear'):
        """Train a model on ``dataset`` and store it under ``key``.

        Two passes over the data compute the per-feature mean / standard
        deviation and the class frequencies; the latter are turned into
        inverse-frequency weights for the cross-entropy loss. The model is
        then optimised for ``epochs`` epochs, printing loss and accuracy
        after each one.

        Args:
            dataset: Dataset yielding (headline, short_description, targets).
            key: Name under which the trained weights and statistics are kept.
            init_key: If non-empty, start from the weights stored under it.
            seed: Seed for torch, random and numpy.
            input_dim: Feature width of a single input.
            n_class: Number of target classes.
            inputs: '+'-separated names of the inputs to concatenate.
            batch_size, num_workers: Passed to the DataLoader.
            epochs, lr, wd: Epoch count, learning rate and weight decay.
            optimizer_name: 'SGD' (momentum 0.9), 'AdamW' or 'Adam'.
            debug: If true, print the loss of every batch.
            normalize: If true, standardise features with the training stats.
            model_type: Architecture specification forwarded to get_model.

        Raises:
            Exception: If ``optimizer_name`` is not one of the supported names.
            KeyError: If ``init_key`` is non-empty but was never trained.
        """
        # Request deterministic kernels. The cuBLAS workspace setting is what
        # PyTorch asks for when deterministic algorithms are enabled on CUDA.
        os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        torch.use_deterministic_algorithms(True)

        torch.manual_seed(seed)
        random.seed(seed)
        np.random.seed(seed)

        input_names = inputs.split('+')
        feat_dim = input_dim * len(input_names)

        model = get_model(model_type, input_dim = feat_dim, n_class = n_class)
        model = model.to(device)

        if init_key:
            model.load_state_dict(self.model_state_dict[init_key])

        model.train()

        trainloader = torch.utils.data.DataLoader(dataset, batch_size = batch_size, shuffle = True, drop_last = False, num_workers = num_workers)

        # ---- pass 1: feature sums and class counts -------------------------
        feat_sum = torch.zeros(feat_dim, device = device)
        class_cnt = torch.zeros(n_class)
        tot_cnt = 0

        for headline, short_description, targets in trainloader:
            _bs = targets.size(0)
            features = self._assemble_features(headline, short_description, input_names, _bs)

            # Vectorised count instead of a Python loop over samples; done on
            # the CPU so it is unaffected by the deterministic-algorithms flag.
            class_cnt += torch.bincount(targets.reshape(-1).cpu(), minlength = n_class)

            feat_sum = feat_sum + features.to(device).sum(dim = 0)
            tot_cnt += _bs

        # Inverse-frequency class weights; unseen classes count as 1 to avoid
        # dividing by zero.
        class_weight = class_cnt.mean() / class_cnt.clamp(min = 1.)

        feat_mean = feat_sum / max(tot_cnt, 1)

        # ---- pass 2: squared deviations from the mean ----------------------
        sq_dev = torch.zeros(feat_dim, device = device)

        for headline, short_description, targets in trainloader:
            _bs = targets.size(0)
            features = self._assemble_features(headline, short_description, input_names, _bs)

            sq_dev = sq_dev + ((features.to(device) - feat_mean.view(1, -1)) ** 2).sum(dim = 0)

        # Unbiased estimate; the denominator is guarded so that a
        # single-sample dataset yields 0 rather than NaN.
        feat_std = torch.sqrt(sq_dev / max(tot_cnt - 1, 1))
        # Divisor actually used for normalisation (constant features -> 1).
        std_divisor = self._safe_std(feat_std)

        criterion = nn.CrossEntropyLoss(weight = class_weight.to(device))

        print (class_cnt)
        print (class_weight)

        if optimizer_name == 'SGD':
            optimizer = optim.SGD(model.parameters(), lr = lr, momentum = 0.9, weight_decay = wd)
        elif optimizer_name == 'AdamW':
            optimizer = optim.AdamW(model.parameters(), lr = lr, weight_decay = wd)
        elif optimizer_name == 'Adam':
            optimizer = optim.Adam(model.parameters(), lr = lr, weight_decay = wd)
        else:
            raise Exception("check optimizer_name")

        for epoch in range(epochs):
            tot_loss = 0.
            correct = 0
            total = 0
            total_batch = 0

            for batch_idx, (headline, short_description, targets) in enumerate(trainloader):
                _bs = targets.size(0)
                features = self._assemble_features(headline, short_description, input_names, _bs)

                features, targets = features.to(device), targets.to(device)

                if normalize:
                    features = (features - feat_mean) / std_divisor

                outputs = model(features)
                loss = criterion(outputs, targets)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                _, predicted = torch.max(outputs, 1)
                batch_loss = loss.item()
                correct += predicted.eq(targets).sum().item()
                tot_loss += batch_loss
                total += _bs
                total_batch += 1

                if debug:
                    print (batch_idx, batch_loss, flush=True)

            # max(..., 1) keeps the report well-defined if no batch was seen.
            print ('[train] epoch: %d | Loss: %.3f | Acc: %.3f%% (%d/%d)' % (epoch + 1, tot_loss / max(total_batch, 1), 100. * float(correct) / max(total, 1), correct, total), flush = True)

        self.model_state_dict[key] = model.state_dict()
        self.feat_mean[key] = feat_mean
        # The raw standard deviation is stored; zero entries are neutralised
        # at normalisation time.
        self.feat_std[key] = feat_std

    def eval(self, dataset, key = "", input_dim=768, n_class = 41, inputs = 'headline+short_description', batch_size = 256, num_workers = 4, normalize = True, model_type = 'linear', save = False):
        """Evaluate the model stored under ``key`` on ``dataset``.

        The model is rebuilt from ``model_type`` and the stored weights, and
        features are normalised with the statistics saved by ``train``. The
        reported loss is the unweighted cross-entropy averaged over batches.

        Args:
            dataset: Dataset yielding (headline, short_description, targets).
            key: Name of a model previously stored by ``train``.
            input_dim, n_class, inputs, model_type: Must describe the same
                architecture and inputs that were used for training.
            batch_size, num_workers: Passed to the DataLoader.
            normalize: If true, standardise features with the training stats.
            save: If true, keep the predicted classes in
                ``self.predictions[key]`` in dataset order.

        Returns:
            A summary line 'Loss: ... | Acc: ...% (correct/total)\\n'.

        Raises:
            KeyError: If nothing has been trained under ``key``.
        """
        input_names = inputs.split('+')

        model = get_model(model_type, input_dim = input_dim * len(input_names), n_class = n_class)

        model = model.to(device)
        model.load_state_dict(self.model_state_dict[key])
        model.eval()

        feat_mean = self.feat_mean[key]
        # Constant training features have std 0; divide by 1 there instead.
        std_divisor = self._safe_std(self.feat_std[key])

        tot_loss = 0.
        correct = 0
        total = 0
        total_batch = 0

        criterion = nn.CrossEntropyLoss()

        # No shuffling, so saved predictions line up with dataset indices.
        testloader = torch.utils.data.DataLoader(dataset, batch_size = batch_size, shuffle = False, drop_last = False, num_workers = num_workers)

        if save:
            predictions = torch.zeros(len(dataset), dtype = torch.int)
            start_idx = 0

        with torch.no_grad():
            for headline, short_description, targets in testloader:
                _bs = targets.size(0)
                features = self._assemble_features(headline, short_description, input_names, _bs)

                features, targets = features.to(device), targets.to(device)

                if normalize:
                    features = (features - feat_mean) / std_divisor

                outputs = model(features)
                _, predicted = torch.max(outputs, 1)

                loss = criterion(outputs, targets)

                if save:
                    predictions[start_idx : start_idx + _bs] = predicted.cpu()
                    start_idx += _bs

                correct += predicted.eq(targets).sum().item()
                tot_loss += loss.item()
                total += _bs
                total_batch += 1

        if save:
            self.predictions[key] = predictions

        # max(..., 1) keeps the summary well-defined for an empty dataset.
        return 'Loss: %.3f | Acc: %.3f%% (%d/%d)\n' % (tot_loss / max(total_batch, 1), 100. * float(correct) / max(total, 1), correct, total)

        


                
        

        


    
    

        
