import torch
from torch import optim
import torch.nn as nn


def _mix_real_fake(G, x_s, dim_z, device, real_label, fake_label):
    """Build one shuffled batch of real and generated samples with BCE targets.

    Draws ``x_s.size(0)`` noise vectors of shape ``(dim_z, 1, 1)``, runs them
    through ``G`` without gradient tracking, concatenates real and fake
    samples, and applies one random permutation to both samples and labels.

    Returns:
        (batch, label, x_p): the shuffled combined batch, the matching float
        targets (``real_label`` for real rows, ``fake_label`` for fake rows),
        and the unshuffled generated samples ``x_p``.
    """
    b_size = x_s.size(0)
    with torch.no_grad():
        noise = torch.randn(b_size, dim_z, 1, 1, device=device)
        x_p = G(noise)
        batch = torch.cat([x_s, x_p])
        # Float targets: torch.full infers int64 from an int fill value, which
        # nn.BCELoss rejects as a target.
        l_s = torch.full((b_size, ), real_label, dtype=torch.float, device=device)
        l_p = torch.full((b_size, ), fake_label, dtype=torch.float, device=device)
        label = torch.cat([l_s, l_p])

        # Same permutation for samples and targets keeps them aligned.
        idx = torch.randperm(2 * b_size)
        label = label[idx]
        batch = batch[idx]
    return batch, label, x_p


def _save_state_cpu(model, save_name):
    """Move ``model`` to CPU (in place) and save its state_dict to ``save_name``."""
    torch.save(model.to('cpu').state_dict(), save_name)


def train_bce(D, G, dataloader, lr, epoch_n, device, verbose, ml_ls, path_f='.',
              dim_z=100):
    """Train discriminator ``D`` with BCE loss against a frozen generator ``G``.

    Each step takes a real batch (``data[0]`` of each ``dataloader`` item) and
    an equally sized batch produced by ``G`` from Gaussian noise. It shuffles
    them together and takes one Adam step on ``D`` with ``nn.BCELoss``, using
    target 1 for real samples and 0 for generated ones. ``G`` is put in eval
    mode, only run under ``torch.no_grad()``, and never updated.

    When ``verbose`` is truthy, every ``verbose``-th batch of each epoch does
    three things: it prints the loss, the mean D output on the fake and real
    batch, and the mean of ``D.ar(x_s, x_p)``; it saves ``D``'s state_dict to
    ``<path_f>/net_D_co_<samples_seen>.pth``; and it records that AR value.

    After training, these files are written under ``path_f``:
      * ``_ar_history``   -- dict mapping each periodic checkpoint path to its AR value;
      * ``net_D_co_<samples_seen>.pth`` -- final D state_dict;
      * ``_D_order``      -- list of all checkpoint paths in save order
        (periodic ones followed by the final one);
      * ``_loss_history`` -- list of ``[samples_seen, loss]`` pairs, one per batch.

    Args:
        D: discriminator module. Its output on a batch is flattened with
            ``.view(-1)`` and fed to ``nn.BCELoss``, so it must be in [0, 1].
            It must also provide ``D.ar(x_real, x_fake)``, whose result is
            averaged with ``torch.mean``.
        G: generator module taking noise of shape ``(b, dim_z, 1, 1)``.
        dataloader: iterable of items whose element 0 is a real batch; it must
            also support ``len()`` (used in the progress message).
        lr: Adam learning rate (only parameters with ``requires_grad`` are optimized).
        epoch_n: number of passes over ``dataloader``.
        device: device D, G and the batches are moved to.
        verbose: logging/checkpoint period in batches; a falsy value
            (e.g. 0) disables periodic logging and checkpoints.
        ml_ls: MultiStepLR milestones (gamma 0.1). ``scheduler.step()`` is
            called once per batch, so milestones count batches, not epochs.
        path_f: directory for all saved files.
        dim_z: latent dimension of the noise fed to ``G`` (default 100).

    Returns:
        0. Note that ``D`` is left on CPU on return.
    """
    read_names = []
    loss_history = []
    ar_history = []
    real_label = 1.0
    fake_label = 0.0
    D.to(device)
    G.to(device).eval()

    optimizer = optim.Adam(filter(lambda p: p.requires_grad, D.parameters()), lr)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=ml_ls, gamma=0.1)
    criterion = nn.BCELoss()
    count_object = 0
    print('Start')
    for epoch in range(epoch_n):
        for i, data in enumerate(dataloader, 0):
            x_s = data[0].to(device)
            count_object += x_s.size(0)

            batch, label, x_p = _mix_real_fake(
                G, x_s, dim_z, device, real_label, fake_label)

            output = D(batch).view(-1)
            # Match the target dtype to D's output (BCELoss requires equal dtypes).
            loss = criterion(output, label.to(output.dtype))

            loss_history.append([count_object, loss.item()])

            D.zero_grad()
            loss.backward()
            optimizer.step()
            # Per-batch scheduler step: ml_ls milestones are iteration counts.
            scheduler.step()

            if verbose and i % verbose == 0:
                with torch.no_grad():
                    fake_score = torch.mean(D(x_p)).item()
                    true_score = torch.mean(D(x_s)).item()
                    ar = torch.mean(D.ar(x_s, x_p)).item()
                print('[%d/%d][%d/%d]\tLoss_D: %.4f\tFake score: %.4f\tReal score: %.4f, AR: %.4f'
                      % (epoch, epoch_n-1, i, len(dataloader), loss.item(), fake_score, true_score, ar))
                save_name = '%s/net_D_co_%d.pth' % (path_f, count_object)
                read_names.append(save_name)
                # Checkpoint from CPU for portability, then move D back to keep training.
                _save_state_cpu(D, save_name)
                ar_history.append(ar)
                D.to(device)

    # Built before the final checkpoint name is appended, so only periodic
    # checkpoints have AR entries.
    ar_D = dict(zip(read_names, ar_history))
    torch.save(ar_D, '%s/_ar_history' % path_f)
    save_name = '%s/net_D_co_%d.pth' % (path_f, count_object)
    read_names.append(save_name)
    _save_state_cpu(D, save_name)
    torch.save(read_names, '%s/_D_order' % path_f)
    torch.save(loss_history, '%s/_loss_history' % path_f)
    torch.cuda.empty_cache()
    print(path_f)
    return 0


