from src.augmentations.cutout import Cutout
from src.methods.base_method import TrainBaseMethod
from src.datasets import WaterbirdsDataset
from src.utils import load_checkpoint
from src.utils import AverageMeter, calculate_accuracy
from src.utils import change_column_value_of_existing_row
from tqdm import tqdm

import torchvision.transforms as transforms
import torch.nn.functional as F
import numpy as np

import torch
import os
import shutil


class WaterbirdsTrain(TrainBaseMethod):
    """Training/evaluation driver for the Waterbirds dataset.

    Builds the train / validation / "data to mask" / test loaders and reports
    per-group accuracy, where a sample's group id is ``2 * aux_label + label``
    (four groups, ids 0-3).
    """

    # Per-channel normalization statistics shared by every Normalize transform
    # (these are the commonly used ImageNet mean/std values).
    _NORMALIZE_MEAN = (0.485, 0.456, 0.406)
    _NORMALIZE_STD = (0.229, 0.224, 0.225)
    # Group id = 2 * aux_label + label, with both labels binary -> ids 0..3.
    _NUM_GROUPS = 4

    def __init__(self, args) -> None:
        # Set before the base initializer runs, presumably because
        # TrainBaseMethod.__init__ may read them -- TODO confirm.
        images_dir = os.path.join(
            args.base_dir, "datasets", "Waterbirds", "images")
        self.clean_train_data_dir = os.path.join(images_dir, "train")
        self.clean_val_data_dir = os.path.join(images_dir, "val")
        super().__init__(args)

    def _build_eval_transform(self, target_resolution, scale):
        """Return a deterministic resize -> center-crop -> tensor -> normalize pipeline."""
        resize_to = (int(target_resolution[0] * scale),
                     int(target_resolution[1] * scale))
        return transforms.Compose([
            transforms.Resize(resize_to),
            transforms.CenterCrop(target_resolution),
            transforms.ToTensor(),
            transforms.Normalize(self._NORMALIZE_MEAN, self._NORMALIZE_STD),
        ])

    def prepare_data_loaders(self, train=True) -> None:
        """Create the Waterbirds datasets and their DataLoaders.

        Args:
            train: when True, build the train, validation and "data to mask"
                datasets/loaders. When False, build only the test
                dataset/loader.
        """
        # (3, 1, 1) arrays broadcast over a (3, H, W) image; presumably read
        # elsewhere (e.g. to undo normalization) -- TODO confirm.
        self.std = np.reshape(self._NORMALIZE_STD, [3, 1, 1])
        self.mean = np.reshape(self._NORMALIZE_MEAN, [3, 1, 1])
        # Resize by 256/224 before center-cropping to the target resolution.
        scale = 256.0 / 224.0
        target_resolution = (224, 224)
        dataset_root = os.path.join(
            self.args.base_dir, 'datasets', 'Waterbirds')
        self.transform_test = self._build_eval_transform(
            target_resolution, scale)

        if not train:
            self.test_dataset = WaterbirdsDataset(
                raw_data_path=self.args.dataset_dir, root=dataset_root,
                split='test', transform=self.transform_test,
                return_places=True)
            self.test_loader = torch.utils.data.DataLoader(
                self.test_dataset, batch_size=self.args.test_batch,
                shuffle=False, num_workers=self.args.workers)
            return

        self.transform_train = transforms.Compose([
            transforms.RandomResizedCrop(
                target_resolution,
                scale=(0.7, 1.0),
                ratio=(0.75, 1.3333333333333333),
                interpolation=2),  # 2 == PIL bilinear
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(self._NORMALIZE_MEAN, self._NORMALIZE_STD),
        ])
        # Non-augmented view of the training split (same pipeline as test),
        # built as its own Compose so it can be modified independently.
        self.transform_data_to_mask = self._build_eval_transform(
            target_resolution, scale)
        if self.args.cutout:
            self.transform_train.transforms.append(
                Cutout(n_holes=self.args.n_holes, length=self.args.length))

        self.train_dataset = WaterbirdsDataset(
            raw_data_path=self.args.dataset_dir, root=dataset_root,
            split='train', transform=self.transform_train)
        self.train_loader = torch.utils.data.DataLoader(
            self.train_dataset, batch_size=self.args.train_batch,
            shuffle=True, num_workers=self.args.workers)
        # Place labels are only requested when worst-group accuracy is used,
        # since run_an_epoch_with_group reads the last batch element as the
        # auxiliary label.
        self.val_dataset = WaterbirdsDataset(
            raw_data_path=self.args.dataset_dir, root=dataset_root,
            split='val', transform=self.transform_test,
            return_places=self.args.use_worst_group_acc)
        self.data_to_mask_dataset = WaterbirdsDataset(
            raw_data_path=self.args.dataset_dir, root=dataset_root,
            split='train', transform=self.transform_data_to_mask)
        self.val_loader = torch.utils.data.DataLoader(
            self.val_dataset, batch_size=self.args.test_batch,
            shuffle=False, num_workers=self.args.workers)
        self.data_to_mask_loader = torch.utils.data.DataLoader(
            self.data_to_mask_dataset,
            batch_size=self.args.masking_batch_size,
            shuffle=True, num_workers=self.args.workers)

    def run_an_epoch_with_group(self, data_loader, epoch, train=False, val_or_test="val"):
        """Run one pass over ``data_loader`` and report per-group accuracy.

        Each batch is indexed as ``data[0]`` (inputs), ``data[2]`` (class
        labels) and ``data[-1]`` (auxiliary label). A sample's group id is
        ``2 * aux_label + label``, i.e. one of groups 0-3.

        Args:
            data_loader: iterable of batches laid out as described above.
            epoch: epoch number. Only used in the progress-bar description
                when ``train`` is True.
            train: if True, put the model in train mode, enable gradients and
                call ``self.optimize`` after every batch. Otherwise run in
                eval mode without gradients.
            val_or_test: progress-bar description when ``train`` is False.

        Returns:
            The worst (minimum) accuracy over the non-empty groups. Empty
            groups are printed as having no samples and are excluded.

        Raises:
            ValueError: if ``data_loader`` yields no batches.
            KeyError: if a sample's group id falls outside 0..3.
        """
        if train:
            self.model.train()
            description = 'Epoch ' + str(epoch)
        else:
            self.model.eval()
            description = val_or_test
        loss_meter = AverageMeter()
        accuracy_meter = AverageMeter()
        prediction_batches = []
        aux_label_batches = []
        label_batches = []
        with torch.set_grad_enabled(train):
            progress_bar = tqdm(data_loader, desc=description)
            for data in progress_bar:
                inputs, labels, aux_labels = data[0], data[2], data[-1]
                inputs, labels = inputs.to(self.device), labels.to(self.device)
                outputs = self.model(inputs)
                loss = self.loss_function(outputs, labels)
                loss_meter.update(loss.item(), inputs.size(0))
                # Only the arg-max class is used; its probability is discarded.
                _, predictions = F.softmax(outputs, dim=1).data.max(1)
                accuracy_meter.update(
                    calculate_accuracy(labels, predictions), 1)
                prediction_batches.append(predictions.detach().cpu())
                aux_label_batches.append(aux_labels)
                label_batches.append(labels.detach().cpu())
                if train:
                    self.optimize(loss=loss)
                progress_bar.set_postfix(
                    {
                        "loss": loss_meter.avg,
                        "accuracy": accuracy_meter.avg,
                    }
                )

        if not prediction_batches:
            raise ValueError(
                f"{description}: data_loader yielded no batches")
        # Flatten so an auxiliary label shaped (B, 1) cannot broadcast
        # against labels shaped (B,).
        all_predictions = torch.cat(prediction_batches).reshape(-1)
        all_aux_labels = torch.cat(aux_label_batches).reshape(-1)
        all_labels = torch.cat(label_batches).reshape(-1)

        group_ids = (2 * all_aux_labels + all_labels).tolist()
        is_correct = (all_labels == all_predictions).tolist()
        group_outcomes = {
            group_id: [] for group_id in range(self._NUM_GROUPS)}
        for group_id, correct in zip(group_ids, is_correct):
            # Unexpected label values produce an unknown id -> KeyError.
            group_outcomes[group_id].append(correct)

        group_accuracies = []
        weighted_acc = 0
        for group_id, outcomes in group_outcomes.items():
            if not outcomes:
                # No samples: accuracy is undefined. Skip it instead of
                # dividing by zero, and keep it out of the worst-group minimum.
                print(f"accuracy of group {group_id+1}: ", "no samples")
                continue
            accuracy = sum(outcomes) / len(outcomes)
            group_accuracies.append(accuracy)
            print(f"accuracy of group {group_id+1}: ", accuracy)
            weighted_acc += accuracy * len(outcomes)
        weighted_acc /= all_predictions.numel()
        print("average accuracy", weighted_acc)
        return min(group_accuracies)

    def test(self, checkpoint_path=None):
        """Evaluate a checkpoint on the test split and record its worst-group accuracy.

        Args:
            checkpoint_path: checkpoint file to load. When None, uses
                ``<model_save_dir>/<args.checkpoint_name>``, first copying it
                from ``args.saved_checkpoint_dir`` if it is not already there.
        """
        self.prepare_data_loaders(train=False)
        self.prepare_model(arch=self.args.arch)
        self.model = self.model.to(self.device)
        print("-" * 10, "testing the model", "-" * 10)
        if checkpoint_path is None:
            checkpoint_path = os.path.join(
                self.model_save_dir, self.args.checkpoint_name)
            if not os.path.isfile(checkpoint_path):
                shutil.copy(
                    os.path.join(
                        self.args.saved_checkpoint_dir,
                        self.args.checkpoint_name,
                    ),
                    checkpoint_path,
                )
        self.model, _, _, _ = load_checkpoint(
            model=self.model,
            optimizer=None,
            lr_scheduler=None,
            checkpoint_path=checkpoint_path,
        )
        self.model.eval()
        worst_group_accuracy = self.run_an_epoch_with_group(
            self.test_loader, epoch=0, train=False, val_or_test="test")
        # The worst-group accuracy is what gets stored in the "accuracy" column.
        change_column_value_of_existing_row(
            "accuracy",
            worst_group_accuracy,
            self.run_configs_file_path,
            self.run_id,
        )
