import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
import numpy as np
from PIL import Image
from utils import PUDataGenerator
from losses import HingeLoss
import os
import pickle
import argparse

# Command-line interface for the PU-learning experiments.
parser = argparse.ArgumentParser(description='pu_learning')

parser.add_argument('--SEED', default=123, type=int)
parser.add_argument('--total_epochs', default=40, type=int)
parser.add_argument('--batch_size', default=64, type=int)
parser.add_argument('--lr_0', default=1e-3, type=float, help='learning rate for w_hat in smag')
parser.add_argument('--lr_1', default=1e-3, type=float, help='learning rate for w in smag')
# Without an explicit type, values given on the command line would arrive as
# strings and break `epoch in decay_epochs` / `lr / decay_factor`.
parser.add_argument('--decay_epochs', default=[12, 24], type=int, nargs='+',
                    help='epochs at which the learning rate is divided by decay_factor, e.g. --decay_epochs 12 24')
parser.add_argument('--decay_factor', default=10, type=float,
                    help='divisor applied to the learning rate at each decay epoch')
parser.add_argument('--optimizer', default='smag', type=str, help='smag, sgd, sbcd, sdca, ssdc_spg, ssdc_adagrad')
parser.add_argument('--gamma', default=1, type=float)
parser.add_argument('--dataset', default='mnist', type=str, help='mnist, fashion_mnist, cifar10, fer2013')
parser.add_argument('--pi_p', default=0.5, type=float)
parser.add_argument('--inner_iter_num', default=10, type=int)
parser.add_argument('--double_decay', default=0, type=int)


def set_all_seeds(SEED):
    """Seed the torch and NumPy global RNGs and pin cuDNN to deterministic mode.

    Args:
        SEED: Integer seed handed to both ``torch.manual_seed`` and
            ``np.random.seed``.
    """
    # cuDNN: disable the auto-tuner and request deterministic kernels.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

    # Global random number generators.
    np.random.seed(SEED)
    torch.manual_seed(SEED)

def zero_grad(model):
    """Zero, in place, every gradient that is already allocated on ``model``.

    Parameters whose ``.grad`` is still ``None`` are left untouched.
    """
    for param in model.parameters():
        grad = param.grad
        if grad is None:
            continue
        grad.data.zero_()

class LinearModel(nn.Module):
    """Single fully connected layer mapping a flattened input to one score.

    Args:
        feature_num: Number of input features after flattening.
    """

    def __init__(self, feature_num):
        super().__init__()
        self.fc = nn.Linear(in_features=feature_num, out_features=1)

    def forward(self, x):
        """Flatten everything but the leading dimension and apply the layer."""
        batch = x.size(0)
        flat = x.view(batch, -1)
        return self.fc(flat)

class ImageDataset(Dataset):
    """In-memory image dataset with separate train / test transforms.

    Args:
        images: Array of images; it is converted to ``uint8`` with ``astype``.
        targets: Per-sample labels, indexable by sample position. Entries
            equal to ``1`` are recorded as positives.
        image_size: Side length every image is resized to.
        crop_size: Side length of the random crop applied in ``'train'`` mode.
        mode: ``'train'`` selects the augmenting transform; any other value
            selects the plain test transform.
    """

    def __init__(self, images, targets, image_size=32, crop_size=30, mode='train'):
        self.images = images.astype(np.uint8)
        self.targets = targets
        self.mode = mode
        self.transform_train = transforms.Compose([
            transforms.ToTensor(),
            transforms.RandomCrop((crop_size, crop_size), padding=None),
            transforms.RandomHorizontalFlip(),
            transforms.Resize((image_size, image_size)),
        ])
        self.transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize((image_size, image_size)),
        ])

        # for loss function: positions of the positive samples and a map from
        # dataset index to rank among the positives. They are kept as
        # attributes; __getitem__ itself does not need them.
        self.pos_indices = np.flatnonzero(targets == 1)
        self.pos_index_map = {idx: rank for rank, idx in enumerate(self.pos_indices)}

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(transformed_image, target)`` for sample ``idx``."""
        image = Image.fromarray(self.images[idx].astype('uint8'))
        target = self.targets[idx]
        # The previous per-sample `idx in self.pos_indices` scan computed a
        # value that was never returned, so it is dropped here: it cost a
        # linear pass over all positives for every fetched training sample.
        if self.mode == 'train':
            image = self.transform_train(image)
        else:
            image = self.transform_test(image)
        return image, target


def obj_evaluation(model_1, hinge_loss, train_pos_loader, train_unl_loader, pi_p):
    """Evaluate the training objective of ``model_1`` without tracking gradients.

    The positive part accumulates, per batch,
    ``hinge_loss(out, +1) - hinge_loss(out, -1)`` and is scaled by
    ``pi_p / n_pos``; the unlabeled part accumulates ``hinge_loss(out, -1)``
    and is scaled by ``1 / n_unl``. ``n_pos`` and ``n_unl`` are the sample
    counts seen in the respective loaders.

    Args:
        model_1: Callable mapping a batch of inputs to outputs.
        hinge_loss: Callable ``(outputs, labels) -> tensor``.
        train_pos_loader: Iterable of ``(inputs, labels)`` positive batches.
        train_unl_loader: Iterable of ``(inputs, labels)`` unlabeled batches.
        pi_p: Weight applied to the positive part.

    Returns:
        The sum of both parts as a Python number.
    """
    # Positive part.
    pos_sum, pos_seen = 0, 0
    with torch.no_grad():
        for batch, batch_labels in train_pos_loader:
            scores = model_1(batch)
            count = batch.size(0)
            as_positive = torch.ones(count)
            as_negative = -1 * torch.ones(count)
            pos_sum += hinge_loss(scores, as_positive) - hinge_loss(scores, as_negative)
            pos_seen += batch_labels.size(0)
    pos_term = pi_p * pos_sum / pos_seen

    # Unlabeled part, with every sample treated as negative.
    unl_sum, unl_seen = 0, 0
    with torch.no_grad():
        for batch, batch_labels in train_unl_loader:
            scores = model_1(batch)
            as_negative = -1 * torch.ones(batch.size(0))
            unl_sum += hinge_loss(scores, as_negative)
            unl_seen += batch_labels.size(0)
    unl_term = unl_sum / unl_seen

    # A term that is exactly zero contributes nothing and is skipped.
    objective = 0
    for term in (pos_term, unl_term):
        if term != 0:
            objective += term.item()
    return objective

def _acc_evaluation(model, data_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in data_loader:
            outputs = model(images)
            predicted = torch.where(outputs < 0, -1, 1)
            labels = labels.view(-1, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct / total

def acc_evaluation(model, valid_loader, test_loader):
    """Return ``(validation_accuracy, test_accuracy)`` for ``model``.

    The validation loader is evaluated first, then the test loader.
    """
    loaders = (valid_loader, test_loader)
    scores = [_acc_evaluation(model, loader) for loader in loaders]
    return scores[0], scores[1]


def _next_batch(loader, iterator):
    """Return the next batch of ``loader``, restarting it once it is exhausted.

    Args:
        loader: Object that can be re-iterated with ``iter(loader)``.
        iterator: Current iterator over ``loader``, or ``None`` to create one
            lazily on first use.

    Returns:
        A ``(batch, iterator)`` tuple; pass ``iterator`` to the next call.

    Raises:
        StopIteration: If ``loader`` yields nothing right after a restart.
    """
    if iterator is None:
        iterator = iter(loader)
    try:
        batch = next(iterator)
    except StopIteration:
        # Only exhaustion triggers a restart; any other error raised while
        # loading data propagates instead of being silently retried.
        iterator = iter(loader)
        batch = next(iterator)
    return batch, iterator


def _prepare_data(dataset, batch_size, SEED):
    """Load ``dataset``, build the PU splits and wrap them in DataLoaders.

    Args:
        dataset: One of ``'mnist'``, ``'fashion_mnist'``, ``'cifar10'``,
            ``'fer2013'``.
        batch_size: Batch size used for all four loaders.
        SEED: Seed forwarded to ``PUDataGenerator``.

    Returns:
        ``(train_pos_loader, train_unl_loader, valid_loader, test_loader,
        feature_num)`` where ``feature_num`` is the flattened input size.

    Raises:
        KeyError: If ``dataset`` is not one of the supported names.
    """
    # The same transform is handed to every torchvision dataset below; the PU
    # splits themselves are built from the raw data/targets attributes.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))  # Normalize pixel values to [-1, 1]
    ])

    if dataset in ['mnist', 'fashion_mnist']:
        dataset_cls = datasets.MNIST if dataset == 'mnist' else datasets.FashionMNIST
        train_dataset = dataset_cls(root='./data', train=True, download=True, transform=transform)
        test_dataset = dataset_cls(root='./data', train=False, download=True, transform=transform)

        generator = PUDataGenerator(shuffle=True, random_seed=SEED)
        train_x, train_y = train_dataset.data, train_dataset.targets
        # The official test split is halved into validation and test parts.
        valid_x, valid_y = test_dataset.data[:5000], test_dataset.targets[:5000]
        test_x, test_y = test_dataset.data[5000:], test_dataset.targets[5000:]
        image_size = 28
        feature_num = 28 * 28
    elif dataset == 'cifar10':
        train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
        test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

        generator = PUDataGenerator(shuffle=True, random_seed=SEED)
        train_x, train_y = train_dataset.data, train_dataset.targets
        valid_x, valid_y = test_dataset.data[:5000], test_dataset.targets[:5000]
        test_x, test_y = test_dataset.data[5000:], test_dataset.targets[5000:]
        image_size = 32
        feature_num = 3 * 32 * 32
    elif dataset == 'fer2013':
        train_dataset = datasets.FER2013(root='./data', split='train', transform=transform)
        # Only the 'train' split is loaded: its first 25709 samples are used
        # for training and the remainder is held out for validation/test.
        train_image, train_target = [], []
        test_image, test_target = [], []
        for ind, data in enumerate(train_dataset._samples):
            image, target = data
            if ind < 25709:
                train_image.append(image)
                train_target.append(target)
            else:
                test_image.append(image)
                test_target.append(target)

        generator = PUDataGenerator(shuffle=True, random_seed=SEED, dataset='fer2013')
        train_x, train_y = train_image, train_target
        valid_x, valid_y = test_image[:1500], test_target[:1500]
        test_x, test_y = test_image[1500:], test_target[1500:]
        image_size = 48
        feature_num = 48 * 48
    else:
        raise KeyError('Unknown dataset.')

    # presumably transform() returns (pos_images, pos_labels, unl_images,
    # unl_labels); with train=False only the first pair is used here.
    (train_pos_images, train_pos_labels, train_unl_images, train_unl_labels) = generator.transform(
        train_x,
        train_y,
        pu_ratio=0.5,
        train=True)
    (valid_images, valid_labels, _, _) = generator.transform(valid_x, valid_y, train=False)
    (test_images, test_labels, _, _) = generator.transform(test_x, test_y, train=False)

    # crop_size equals image_size for every supported dataset.
    size_kwargs = dict(image_size=image_size, crop_size=image_size)
    trainPosDataset = ImageDataset(train_pos_images, train_pos_labels, **size_kwargs)
    trainUnlDataset = ImageDataset(train_unl_images, train_unl_labels, **size_kwargs)
    validDataset = ImageDataset(valid_images, valid_labels, mode='test', **size_kwargs)
    testDataset = ImageDataset(test_images, test_labels, mode='test', **size_kwargs)

    # NOTE(review): the training loaders are not shuffled; any shuffling is
    # left to PUDataGenerator(shuffle=True) - confirm this is intended.
    loader_kwargs = dict(shuffle=False, num_workers=0)
    train_pos_loader = torch.utils.data.DataLoader(trainPosDataset, batch_size, **loader_kwargs)
    train_unl_loader = torch.utils.data.DataLoader(trainUnlDataset, batch_size, **loader_kwargs)
    valid_loader = torch.utils.data.DataLoader(validDataset, batch_size=batch_size, **loader_kwargs)
    test_loader = torch.utils.data.DataLoader(testDataset, batch_size=batch_size, **loader_kwargs)

    return train_pos_loader, train_unl_loader, valid_loader, test_loader, feature_num


def main():
    """Train a linear PU classifier with the selected optimizer and save metrics.

    Reads all settings from the module-level ``parser``, builds the data
    loaders, runs one of the training procedures (``sgd``, ``smag``, ``sbcd``,
    ``sdca``, ``ssdc_spg``, ``ssdc_adagrad``), evaluates ``model_1`` after
    every pass over the positive loader and finally pickles the per-epoch
    metric lists.

    Raises:
        KeyError: For an unknown dataset or optimizer name.
    """
    # Parameters
    args = parser.parse_args()
    print(args)

    dataset = args.dataset
    num_epochs = args.total_epochs
    lr_0 = args.lr_0
    lr_1 = args.lr_1
    batch_size = args.batch_size
    SEED = args.SEED
    decay_epochs = args.decay_epochs
    decay_factor = args.decay_factor
    pi_p = args.pi_p
    gamma = args.gamma
    optimizer_option = args.optimizer
    inner_iter_num = args.inner_iter_num
    double_decay = args.double_decay

    set_all_seeds(SEED)

    (train_pos_loader, train_unl_loader, valid_loader, test_loader,
     feature_num) = _prepare_data(dataset, batch_size, SEED)

    model_1 = LinearModel(feature_num=feature_num)
    model_2 = LinearModel(feature_num=feature_num)
    model_buf = LinearModel(feature_num=feature_num)
    hinge_loss = HingeLoss()

    train_loss_list = []
    valid_acc_list = []
    test_acc_list = []
    best_test_acc_list = []
    best_vali = float('-inf')
    best_test = float('-inf')

    def record_epoch(epoch_index):
        """Evaluate model_1, update the running best, then log and store metrics."""
        nonlocal best_vali, best_test
        train_loss = obj_evaluation(model_1, hinge_loss, train_pos_loader, train_unl_loader, pi_p)
        accuracy_val, accuracy_test = acc_evaluation(model_1, valid_loader, test_loader)

        # The reported "best test" is the test accuracy at the epoch with the
        # highest validation accuracy so far.
        if accuracy_val > best_vali:
            best_vali = accuracy_val
            best_test = accuracy_test

        train_loss_list.append(train_loss)
        valid_acc_list.append(accuracy_val)
        test_acc_list.append(accuracy_test)
        best_test_acc_list.append(best_test)

        print(
            f'Epoch [{epoch_index + 1}/{num_epochs}], Train Loss: {train_loss:.4f}, Validation Accuracy: {accuracy_val:.4f}, Test Accuracy: {accuracy_test:.4f}, Best Test Accuracy: {best_test:.4f}')

    # Training
    if optimizer_option in ['sgd', 'smag']:
        # Created lazily by _next_batch on first use and kept across epochs,
        # so the unlabeled stream is not restarted at every epoch.
        train_unl_iter = None
        for epoch in range(num_epochs):
            model_1.train()
            model_2.train()
            if epoch in decay_epochs:
                if double_decay:
                    lr_0 = lr_0 / decay_factor
                    print('Reducing lr_0 to %.5f @ T=%s!' % (lr_0, epoch))
                lr_1 = lr_1 / decay_factor
                print('Reducing lr_1 to %.5f @ T=%s!' % (lr_1, epoch))

            for train_pos_data, _ in train_pos_loader:
                (train_unl_data, _), train_unl_iter = _next_batch(train_unl_loader, train_unl_iter)

                labels_pos_pos = torch.ones(train_pos_data.size(0))
                labels_pos_neg = -1 * torch.ones(train_pos_data.size(0))
                labels_unl_neg = -1 * torch.ones(train_unl_data.size(0))

                if optimizer_option == 'sgd':
                    outputs_1_pos = model_1(train_pos_data)
                    outputs_1_unl = model_1(train_unl_data)
                    loss = pi_p * hinge_loss(outputs_1_pos, labels_pos_pos) + hinge_loss(outputs_1_unl,
                                                                                         labels_unl_neg) - pi_p * hinge_loss(
                        outputs_1_pos, labels_pos_neg)
                    zero_grad(model_1)
                    grad = torch.autograd.grad(loss, model_1.parameters(), retain_graph=False)
                    for g, w in zip(grad, model_1.parameters()):
                        w.data = w.data - lr_1 * g.data

                elif optimizer_option == 'smag':
                    # Step on model_1 with a 1/gamma pull toward model_buf.
                    outputs_1_pos = model_1(train_pos_data)
                    outputs_1_unl = model_1(train_unl_data)
                    loss_1 = pi_p * hinge_loss(outputs_1_pos, labels_pos_pos) + hinge_loss(outputs_1_unl,
                                                                                           labels_unl_neg)
                    zero_grad(model_1)
                    grad_1 = torch.autograd.grad(loss_1, model_1.parameters(), retain_graph=False)

                    for g1, w, w_buf in zip(grad_1, model_1.parameters(), model_buf.parameters()):
                        w.data = w.data - lr_1 * (g1.data + 1 / gamma * (w.data - w_buf.data))

                    # Same kind of step on model_2 for the subtracted loss term.
                    outputs_2_pos = model_2(train_pos_data)
                    loss_2 = pi_p * hinge_loss(outputs_2_pos, labels_pos_neg)
                    zero_grad(model_2)
                    if loss_2 != 0:
                        grad_2 = torch.autograd.grad(loss_2, model_2.parameters(), retain_graph=False)
                        for g2, w, w_buf in zip(grad_2, model_2.parameters(), model_buf.parameters()):
                            w.data = w.data - lr_1 * (g2.data + 1 / gamma * (w.data - w_buf.data))
                    else:
                        # Zero loss: only the pull toward model_buf is applied.
                        for w, w_buf in zip(model_2.parameters(), model_buf.parameters()):
                            w.data = w.data - lr_1 * (1 / gamma * (w.data - w_buf.data))

                    # Move model_buf using the gap between model_2 and model_1.
                    for w1, w2, w_buf in zip(model_1.parameters(), model_2.parameters(), model_buf.parameters()):
                        w_buf.data = w_buf.data - lr_0 / gamma * (w2.data - w1.data)

            # Evaluation
            record_epoch(epoch)

    elif optimizer_option == 'sbcd':

        epoch_count = 0
        train_pos_iter = iter(train_pos_loader)
        train_unl_iter = iter(train_unl_loader)

        inner_iter_count = 0

        while epoch_count < num_epochs:
            model_1.train()
            model_2.train()
            try:
                train_pos_data, _ = next(train_pos_iter)
            except StopIteration:
                # One pass over the positive loader marks the end of an epoch.
                train_pos_iter = iter(train_pos_loader)

                # Evaluation
                record_epoch(epoch_count)
                epoch_count += 1
                if epoch_count in decay_epochs:
                    if double_decay:
                        lr_0 = lr_0 / decay_factor
                        print('Reducing lr_0 to %.5f @ T=%s!' % (lr_0, epoch_count))
                    lr_1 = lr_1 / decay_factor
                    print('Reducing lr_1 to %.5f @ T=%s!' % (lr_1, epoch_count))
                continue

            (train_unl_data, _), train_unl_iter = _next_batch(train_unl_loader, train_unl_iter)

            if inner_iter_count < inner_iter_num:

                labels_pos_pos = torch.ones(train_pos_data.size(0))
                labels_pos_neg = -1 * torch.ones(train_pos_data.size(0))
                labels_unl_neg = -1 * torch.ones(train_unl_data.size(0))

                outputs_1_pos = model_1(train_pos_data)
                outputs_1_unl = model_1(train_unl_data)
                loss_1 = pi_p * hinge_loss(outputs_1_pos, labels_pos_pos) + hinge_loss(outputs_1_unl, labels_unl_neg)
                zero_grad(model_1)
                grad_1 = torch.autograd.grad(loss_1, model_1.parameters(), retain_graph=False)

                for g1, w, w_buf in zip(grad_1, model_1.parameters(), model_buf.parameters()):
                    w.data = (gamma * w.data + lr_1 * w_buf.data - lr_1 * gamma * g1.data) / (lr_1 + gamma)

                outputs_2_pos = model_2(train_pos_data)
                loss_2 = pi_p * hinge_loss(outputs_2_pos, labels_pos_neg)
                zero_grad(model_2)
                if loss_2 != 0:
                    grad_2 = torch.autograd.grad(loss_2, model_2.parameters(), retain_graph=False)
                    for g2, w, w_buf in zip(grad_2, model_2.parameters(), model_buf.parameters()):
                        w.data = (gamma * w.data + lr_1 * w_buf.data - lr_1 * gamma * g2.data) / (lr_1 + gamma)
                else:
                    for w, w_buf in zip(model_2.parameters(), model_buf.parameters()):
                        w.data = (gamma * w.data + lr_1 * w_buf.data) / (lr_1 + gamma)

                inner_iter_count += 1
            else:
                # After inner_iter_num inner steps, update model_buf once.
                for w1, w2, w_buf in zip(model_1.parameters(), model_2.parameters(), model_buf.parameters()):
                    w_buf.data = w_buf.data - lr_0 / gamma * (w2.data - w1.data)
                inner_iter_count = 0
    elif optimizer_option == 'sdca':
        epoch_count = 0
        train_pos_iter = iter(train_pos_loader)
        train_unl_iter = iter(train_unl_loader)

        inner_iter_count = 0

        while epoch_count < num_epochs:
            model_1.train()
            model_2.train()
            try:
                train_pos_data, _ = next(train_pos_iter)
            except StopIteration:
                train_pos_iter = iter(train_pos_loader)

                # Evaluation
                record_epoch(epoch_count)
                # NOTE(review): the decay check uses epoch_count before it is
                # incremented, so it fires one epoch later than in
                # sgd/smag/sbcd - confirm this is intended.
                if epoch_count in decay_epochs:
                    lr_1 = lr_1 / decay_factor
                    print('Reducing lr_1 to %.5f @ T=%s!' % (lr_1, epoch_count))
                epoch_count += 1
                continue

            (train_unl_data, _), train_unl_iter = _next_batch(train_unl_loader, train_unl_iter)

            if inner_iter_count < inner_iter_num:
                labels_pos_pos = torch.ones(train_pos_data.size(0))
                labels_pos_neg = -1 * torch.ones(train_pos_data.size(0))
                labels_unl_neg = -1 * torch.ones(train_unl_data.size(0))

                if inner_iter_count == 0:
                    # The gradient of the subtracted term is computed once per
                    # inner round and reused for the following inner steps.
                    outputs_2_pos = model_1(train_pos_data)
                    loss_2 = pi_p * hinge_loss(outputs_2_pos, labels_pos_neg)
                    zero_grad(model_1)
                    if loss_2 != 0:
                        grad_2 = torch.autograd.grad(loss_2, model_1.parameters(), retain_graph=False)

                outputs_1_pos = model_1(train_pos_data)
                outputs_1_unl = model_1(train_unl_data)
                loss_1 = pi_p * hinge_loss(outputs_1_pos, labels_pos_pos) + hinge_loss(outputs_1_unl, labels_unl_neg)
                zero_grad(model_1)
                grad_1 = torch.autograd.grad(loss_1, model_1.parameters(), retain_graph=False)
                if loss_2 != 0:
                    for g1, g2, w in zip(grad_1, grad_2, model_1.parameters()):
                        w.data = w.data - lr_1 * (g1.data - g2.data)
                else:
                    for g1, w in zip(grad_1, model_1.parameters()):
                        w.data = w.data - lr_1 * g1.data

                inner_iter_count += 1
            else:
                inner_iter_count = 0
    elif optimizer_option == 'ssdc_spg':
        epoch_count = 0
        train_pos_iter = iter(train_pos_loader)
        train_unl_iter = iter(train_unl_loader)

        inner_iter_count = 0
        # Sum of the weights 1..inner_iter_num used for the weighted average.
        iter_sum = inner_iter_num * (1 + inner_iter_num) / 2

        # model_buf accumulates the weighted iterates, so it starts at zero.
        model_buf.train()
        for w_buf in model_buf.parameters():
            w_buf.data = torch.zeros_like(w_buf.data)
        model_2.train()

        while epoch_count < num_epochs:
            model_1.train()
            try:
                train_pos_data, _ = next(train_pos_iter)
            except StopIteration:
                train_pos_iter = iter(train_pos_loader)

                # Evaluation
                record_epoch(epoch_count)
                # NOTE(review): decay is checked before the increment, one
                # epoch later than in sgd/smag/sbcd - confirm intended.
                if epoch_count in decay_epochs:
                    lr_1 = lr_1 / decay_factor
                    print('Reducing lr_1 to %.5f @ T=%s!' % (lr_1, epoch_count))
                epoch_count += 1
                continue

            (train_unl_data, _), train_unl_iter = _next_batch(train_unl_loader, train_unl_iter)

            if inner_iter_count < inner_iter_num:
                labels_pos_pos = torch.ones(train_pos_data.size(0))
                labels_pos_neg = -1 * torch.ones(train_pos_data.size(0))
                labels_unl_neg = -1 * torch.ones(train_unl_data.size(0))

                if inner_iter_count == 0:
                    # Gradient at model_2, fixed for the whole inner round.
                    outputs_2_pos = model_2(train_pos_data)
                    loss_2 = pi_p * hinge_loss(outputs_2_pos, labels_pos_neg)
                    zero_grad(model_2)
                    if loss_2 == 0:
                        grad_2 = []
                        for para in model_2.parameters():
                            grad_2.append(torch.zeros_like(para))
                    else:
                        grad_2 = torch.autograd.grad(loss_2, model_2.parameters(), retain_graph=False)

                outputs_1_pos = model_1(train_pos_data)
                outputs_1_unl = model_1(train_unl_data)
                loss_1 = pi_p * hinge_loss(outputs_1_pos, labels_pos_pos) + hinge_loss(outputs_1_unl, labels_unl_neg)
                zero_grad(model_1)
                grad_1 = torch.autograd.grad(loss_1, model_1.parameters(), retain_graph=False)

                for g1, g2, w, w_2, w_buf in zip(grad_1, grad_2, model_1.parameters(), model_2.parameters(),
                                                 model_buf.parameters()):
                    w.data = w.data - lr_1 * (g1.data - g2.data) + lr_1 * gamma * (w_2.data - w.data)
                    w_buf.data += (inner_iter_count + 1) * w.data

                inner_iter_count += 1
            else:
                # model_2 becomes the weighted average of the inner iterates.
                for w_2, w_buf in zip(model_2.parameters(), model_buf.parameters()):
                    w_2.data = w_buf.data / iter_sum
                    w_buf.data = torch.zeros_like(w_buf.data)
                inner_iter_count = 0

    elif optimizer_option == 'ssdc_adagrad':
        epoch_count = 0
        train_pos_iter = iter(train_pos_loader)
        train_unl_iter = iter(train_unl_loader)
        inner_iter_count = 0
        G = 10

        # model_buf accumulates the inner iterates, so it starts at zero.
        model_buf.train()
        for w_buf in model_buf.parameters():
            w_buf.data = torch.zeros_like(w_buf.data)
        model_2.train()

        while epoch_count < num_epochs:
            model_1.train()
            try:
                train_pos_data, _ = next(train_pos_iter)
            except StopIteration:
                train_pos_iter = iter(train_pos_loader)

                # Evaluation
                record_epoch(epoch_count)
                # NOTE(review): decay is checked before the increment, one
                # epoch later than in sgd/smag/sbcd - confirm intended.
                if epoch_count in decay_epochs:
                    lr_1 = lr_1 / decay_factor
                    print('Reducing lr_1 to %.5f @ T=%s!' % (lr_1, epoch_count))
                epoch_count += 1
                continue

            (train_unl_data, _), train_unl_iter = _next_batch(train_unl_loader, train_unl_iter)

            if inner_iter_count < inner_iter_num:
                labels_pos_pos = torch.ones(train_pos_data.size(0))
                labels_pos_neg = -1 * torch.ones(train_pos_data.size(0))
                labels_unl_neg = -1 * torch.ones(train_unl_data.size(0))

                if inner_iter_count == 0:
                    # Gradient at model_2, fixed for the whole inner round.
                    outputs_2_pos = model_2(train_pos_data)
                    loss_2 = pi_p * hinge_loss(outputs_2_pos, labels_pos_neg)
                    zero_grad(model_2)
                    if loss_2 == 0:
                        grad_2 = []
                        for para in model_2.parameters():
                            grad_2.append(torch.zeros_like(para))
                    else:
                        grad_2 = torch.autograd.grad(loss_2, model_2.parameters(), retain_graph=False)
                    # Per-round accumulators are reset at the first inner step.
                    accum_grad = []
                    s_t = []
                    for param in model_2.parameters():
                        accum_grad.append(torch.zeros_like(param))
                        s_t.append(torch.zeros_like(param))
                outputs_1_pos = model_1(train_pos_data)
                outputs_1_unl = model_1(train_unl_data)
                loss_1 = pi_p * hinge_loss(outputs_1_pos, labels_pos_pos) + hinge_loss(outputs_1_unl, labels_unl_neg)
                zero_grad(model_1)
                grad_1 = torch.autograd.grad(loss_1, model_1.parameters(), retain_graph=False)

                for g1, g2, accum_g, w, w_2, w_buf, s in zip(grad_1,
                                                             grad_2,
                                                             accum_grad,
                                                             model_1.parameters(),
                                                             model_2.parameters(),
                                                             model_buf.parameters(),
                                                             s_t):
                    accum_g.data += g1.data - g2.data
                    s.data = torch.sqrt(g1.data * g1.data)
                    H_t = 2 * G + s.data

                    w.data = w_2.data - lr_1 / (inner_iter_count * lr_1 * gamma + H_t) * accum_g.data
                    w_buf.data += w.data

                inner_iter_count += 1
            else:
                # model_2 becomes the plain average of the inner iterates.
                for w_2, w_buf in zip(model_2.parameters(), model_buf.parameters()):
                    w_2.data = w_buf.data / inner_iter_num
                    w_buf.data = torch.zeros_like(w_buf.data)
                inner_iter_count = 0
    else:
        raise KeyError('Unknown optimizer.')

    output_dict = {'train_loss_list': train_loss_list, 'valid_acc_list': valid_acc_list, 'test_acc_list': test_acc_list, 'best_test_acc_list': best_test_acc_list}
    file_name = 'new_file.p'
    print('creating new file: ', file_name)
    results_path = '/path/to/folder'
    # Create the target folder if needed so a finished run is not lost to a
    # missing directory.
    os.makedirs(results_path, exist_ok=True)
    new_file_path = os.path.join(results_path, file_name)
    # The with-statement closes the file; no explicit close() is needed.
    with open(new_file_path, 'wb') as handle:
        pickle.dump(output_dict, handle)


if __name__ == "__main__":
    main()
