"""
Helper functions for simple datasets
"""
from abc import ABC

import itertools
from typing import Optional

from pytorch_lightning import LightningDataModule
from sklearn.model_selection import train_test_split
import numpy as np
import torch
import os
from torch.utils.data import DataLoader, random_split, TensorDataset

"""
Datasets
"""


class ToyDataModule(LightningDataModule, ABC):
    """Lightning data module serving one in-memory array of samples.

    ``setup`` splits the samples into a training and a validation part; each
    part is served by a ``DataLoader`` over a single-tensor ``TensorDataset``
    cast to ``float32``. Neither loader shuffles.

    Args:
        X: Samples indexed along the first axis (e.g. a numpy array or a torch
            tensor).
        test_size: Validation share forwarded to
            ``sklearn.model_selection.train_test_split``. ``0`` disables the
            split, so training and validation both use all of ``X``.
        batch_size: Mini-batch size of both loaders.
        num_workers: Worker processes per loader. Defaults to half of the CPUs
            reported by ``os.cpu_count()``; passing ``None`` selects the same
            default.
    """

    def __init__(self, X, test_size, batch_size: int = 100,
                 num_workers: Optional[int] = (os.cpu_count() or 2) // 2):
        super().__init__()
        self.X = X
        self.test_size = test_size
        self.batch_size = batch_size
        # os.cpu_count() may return None, which made the former default
        # ``int(os.cpu_count() / 2)`` crash at import time. Callers may also pass
        # ``num_workers=None`` explicitly; DataLoader itself does not accept None.
        self.num_workers = (os.cpu_count() or 2) // 2 if num_workers is None else num_workers
        # Filled in by ``setup``.
        self.Xtrain = None
        self.Xtest = None

    def setup(self, stage=None):
        """Split ``X`` into ``Xtrain`` and ``Xtest``, only on the first call.

        Lightning can call ``setup`` more than once (e.g. for different stages).
        Re-splitting would draw a fresh random partition, so samples already
        used for training could end up in the validation set.
        """
        if self.Xtrain is not None:
            return
        if self.test_size == 0:
            self.Xtrain, self.Xtest = self.X, self.X
        else:
            self.Xtrain, self.Xtest = train_test_split(self.X, test_size=self.test_size)

    def _make_loader(self, samples):
        """Wrap ``samples`` as a float32 ``TensorDataset`` in a non-shuffling ``DataLoader``."""
        # as_tensor also accepts non-float tensors, unlike the legacy torch.Tensor(tensor).
        dataset = TensorDataset(torch.as_tensor(samples, dtype=torch.float32))
        return DataLoader(dataset, batch_size=self.batch_size, num_workers=self.num_workers)

    def train_dataloader(self):
        """Return the loader over the training split (requires ``setup`` first)."""
        return self._make_loader(self.Xtrain)

    def val_dataloader(self):
        """Return the loader over the validation split (requires ``setup`` first)."""
        return self._make_loader(self.Xtest)


def binary_frontdoor(n_samples, batch_size: Optional[int], num_workers: Optional[int], validation_size=0.):
    """Sample a binary front-door dataset ``x -> w -> y`` with hidden confounding.

    The latent ``u1`` sets the Bernoulli probability of the treatment ``x`` and
    also enters the outcome ``y``. The latent ``u2`` only perturbs the mediator
    ``w``. Neither latent is returned.

    Args:
        n_samples: Number of samples to draw.
        batch_size, num_workers, validation_size: Forwarded to ``ToyDataModule``
            (``validation_size`` as its ``test_size``).

    Returns:
        dict with ``'dm'`` (a ``ToyDataModule`` over ``data``), ``'data'`` (float32
        tensor with columns ``[one_hot(x) (2 cols), w, y]``) and ``'var_dims'``
        (columns per variable, ``[2, 1, 1]``).
    """
    shape = (n_samples, 1)
    # Draw order matters: all samples come from the global numpy RNG.
    u1 = np.random.uniform(0, 1, size=shape)
    u2 = np.random.normal(0, 1, size=shape)
    x = np.random.binomial(1, p=u1, size=shape)
    w = np.random.binomial(1, p=1 / (1 + np.exp(-x - u2)), size=shape)
    y = np.random.binomial(1, p=1 / (1 + np.exp(w - u1)), size=shape)

    observed = torch.tensor(np.hstack((x, w, y)).astype(float)).float()
    # The binary treatment is represented by a 2-column one-hot encoding.
    x_onehot = torch.nn.functional.one_hot(observed[:, 0].long(), num_classes=2)
    data = torch.cat((x_onehot, observed[:, 1:2], observed[:, 2:3]), dim=1).float()

    dm = ToyDataModule(data, test_size=validation_size, batch_size=batch_size, num_workers=num_workers)
    return dict(dm=dm, data=data, var_dims=np.array([2, 1, 1]))


def backdoor_data(n_samples, linear, batch_size, num_workers, validation_size=0.):
    """Sample a back-door dataset where observed covariates ``x`` confound ``t`` and ``y``.

    Args:
        n_samples: Number of samples to draw.
        linear: If true, use 3-dimensional covariates and linear mechanisms;
            otherwise a 1-dimensional covariate with quadratic mechanisms.
        batch_size, num_workers, validation_size: Forwarded to ``ToyDataModule``
            (``validation_size`` as its ``test_size``).

    Returns:
        dict with ``'dm'``, ``'data'`` (float32 tensor with columns ``[x, t, y]``),
        ``'var_dims'`` (columns per variable) and ``'true_atd'`` (ground-truth
        average derivative of ``y`` with respect to ``t``).
    """
    shape = (n_samples, 1)
    if linear:
        covariates = np.random.multivariate_normal(mean=np.ones(3, dtype=int), cov=np.eye(3), size=n_samples)
        treatment = covariates @ np.array([[1], [-1], [2]]) + 2 + np.random.normal(0, 3, size=shape)
        outcome = (covariates @ np.array([[3], [1], [-0.5]]) + 3 * treatment
                   + np.random.normal(0, 2, size=shape))
        # y is linear in t with slope 3.
        var_dims, true_atd = np.array([3, 1, 1]), 3.
    else:
        covariates = np.random.normal(2, 1, size=shape)
        treatment = 0.1 * covariates ** 2 - covariates + np.random.normal(1, 2, size=shape)
        outcome = 0.5 * treatment ** 2 - treatment * covariates + np.random.normal(0, 2, size=shape)
        # d/dt E[y | do(t)] = t - E[x] = t - 2; averaged over t, with
        # E[t] = 0.1 * (1 + 2**2) - 2 + 1 = -0.5, this gives -2.5.
        var_dims, true_atd = np.array([1, 1, 1]), -2.5

    data = torch.tensor(np.concatenate([covariates, treatment, outcome], axis=-1)).float()
    dm = ToyDataModule(data, test_size=validation_size, batch_size=batch_size, num_workers=num_workers)
    return {'dm': dm, 'data': data, 'var_dims': var_dims, 'true_atd': true_atd}


def binary_iv(n_samples, batch_size, num_workers, validation_size=0.):
    """Sample binary ``(z, x, y)`` triples from a strata-based IV model.

    Each ``(z, x, y)`` configuration gets probability
    ``P(z) * sum of strata_xy[i][j]`` over the strata ``i`` consistent with
    ``(z, x)`` and ``j`` consistent with ``(x, y)``. The samples are then
    drawn with ``np.random.choice``.

    Args:
        n_samples: Number of samples to draw.
        batch_size, num_workers, validation_size: Forwarded to ``ToyDataModule``
            (``validation_size`` as its ``test_size``).

    Returns:
        dict with ``'dm'``, ``'data'`` (float32 tensor with columns
        ``[z, one_hot(x) (2 cols), y]``) and ``'var_dims'`` (``[1, 2, 1]``).
    """
    # Joint probabilities of (x-stratum, y-stratum); rows x00..x11, columns y00..y11.
    strata_xy = np.asarray([
        # y00         y10       y01       y11
        [0.000757, 0.013034, 0.006125, 0.002606],  # x00
        [0.004541, 0.074105, 0.034526, 0.014387],  # x10
        [0.026040, 0.418847, 0.195419, 0.082264],  # x01
        [0.004534, 0.073950, 0.034123, 0.014742]  # x11
    ])
    # Marginal distribution of z.
    strata_z = [0.649335, 0.350665]
    # (parent value, child value) -> strata indices consistent with that pair;
    # presumably the stratum whose response to the parent value equals the child value.
    compatible = {
        (0, 0): [0, 2],
        (0, 1): [1, 3],
        (1, 0): [0, 1],
        (1, 1): [2, 3]
    }

    def joint_prob(z, x, y):
        # Sequential summation (i outer, j inner) keeps the probabilities bit-exact.
        mass = sum(strata_xy[i][j] for i in compatible[(z, x)] for j in compatible[(x, y)])
        return strata_z[z] * mass

    configurations = list(itertools.product([0, 1], repeat=3))
    probs = [joint_prob(z, x, y) for z, x, y in configurations]
    samples = np.array(configurations)[np.random.choice(len(configurations), size=n_samples, p=probs)]

    observed = torch.tensor(samples).float()
    x_onehot = torch.nn.functional.one_hot(observed[:, 1].long(), num_classes=2)
    data = torch.cat((observed[:, :1], x_onehot, observed[:, 2:3]), dim=1).float()

    dm = ToyDataModule(data, test_size=validation_size, batch_size=batch_size, num_workers=num_workers)
    return {'dm': dm, 'data': data, 'var_dims': np.array([1, 2, 1])}


def bow_data(linear, n_samples, batch_size, num_workers, validation_size=0., extra=False):
    """Sample a linear dataset in which a hidden ``c`` confounds ``t`` and ``y``.

    ``y = t + c + noise``, so the true effect of ``t`` on ``y`` is 1.

    Args:
        linear: Unused; accepted for signature compatibility.
        n_samples: Number of samples to draw.
        batch_size, num_workers, validation_size: Forwarded to ``ToyDataModule``
            (``validation_size`` as its ``test_size``).
        extra: If true, prepend an independent standard-normal column ``z``.

    Returns:
        dict with ``'dm'``, ``'data'`` (float32 tensor with columns ``[t, y]`` or
        ``[z, t, y]``), ``'var_dims'`` (all ones) and ``'true_atd'`` (1.0).
    """
    shape = (n_samples, 1)
    noise_t = np.random.normal(0, 1, size=shape)
    noise_y = np.random.normal(0, 1, size=shape)
    confounder = np.random.normal(0, 1, size=shape)
    treatment = confounder + noise_t
    outcome = treatment + confounder + noise_y

    columns = [treatment, outcome]
    if extra:
        # Drawn last so the RNG stream of the other variables is unchanged.
        columns.insert(0, np.random.normal(0, 1, size=shape))
    var_dims = np.ones(len(columns), dtype=int)

    data = torch.tensor(np.concatenate(columns, axis=-1)).float()
    dm = ToyDataModule(data, test_size=validation_size, batch_size=batch_size, num_workers=num_workers)
    return {'dm': dm, 'data': data, 'var_dims': var_dims, 'true_atd': 1.}


def iv_data(linear, n_samples, batch_size, num_workers, validation_size=0., weak=False):
    """Sample an instrumental-variable dataset ``z -> t -> y`` with hidden confounder ``c``.

    ``z`` enters only the treatment ``t``; ``c`` enters both ``t`` and ``y`` and is
    not returned.

    Args:
        linear: If true, use linear mechanisms (``weak`` is then ignored).
        n_samples: Number of samples to draw.
        batch_size, num_workers, validation_size: Forwarded to ``ToyDataModule``
            (``validation_size`` as its ``test_size``).
        weak: Nonlinear case only. If true, keep the 2-d instrument with a
            quadratic outcome; otherwise use a fresh 1-d instrument with a
            small coefficient and a ``t * c`` interaction.

    Returns:
        dict with ``'dm'``, ``'data'`` (float32 tensor with columns ``[z, t, y]``),
        ``'var_dims'`` and ``'true_atd'``.
    """
    shape = (n_samples, 1)
    noise_t = np.random.normal(0, 1, size=shape)
    noise_y = np.random.normal(0, 1, size=shape)
    confounder = np.random.normal(0, 1, size=shape)
    instrument = np.random.multivariate_normal(np.array([-1, 0]), np.eye(2), size=n_samples)

    if linear:
        treatment = instrument @ np.array([[1], [-1]]) + 0.5 * confounder + noise_t
        outcome = 0.5 * treatment - 3 * confounder + noise_y
        # y is linear in t with slope 0.5.
        true_atd, var_dims = 0.5, np.array([2, 1, 1])
    elif weak:
        treatment = instrument @ np.array([[3], [1.5]]) + 0.5 * confounder + noise_t
        outcome = 0.3 * treatment ** 2 - 1.5 * treatment + confounder + noise_y
        # d/dt = 0.6 t - 1.5 averaged at E[t] = 3 * (-1) + 1.5 * 0 = -3 gives -3.3.
        true_atd, var_dims = -3.3, np.array([2, 1, 1])
    else:
        # Replaces the 2-d draw above, which still advanced the global RNG.
        instrument = np.random.normal(0, 1, size=shape)
        treatment = 0.05 * instrument + 3 * confounder + noise_t
        outcome = treatment ** 2 - 3 * treatment * confounder + noise_y
        # d/dt = 2 t - 3 c, whose mean is 0 since E[t] = E[c] = 0.
        true_atd, var_dims = 0., np.array([1, 1, 1])

    data = torch.tensor(np.concatenate([instrument, treatment, outcome], axis=-1)).float()
    dm = ToyDataModule(data, test_size=validation_size, batch_size=batch_size, num_workers=num_workers)
    return {'dm': dm, 'data': data, 'var_dims': var_dims, 'true_atd': true_atd}


def frontdoor_data(n_samples, linear, batch_size, num_workers, validation_size=0.):
    """Sample a front-door dataset ``t -> x -> y`` with a hidden confounder ``u``.

    ``u`` enters the treatment ``t`` and the outcome ``y`` but not the mediator
    ``x``; it is not part of the returned data.

    Args:
        n_samples: Number of samples to draw.
        linear: If true, use linear mechanisms with a 2-dimensional mediator;
            otherwise a quadratic outcome with a 1-dimensional mediator.
        batch_size, num_workers, validation_size: Forwarded to ``ToyDataModule``
            (``validation_size`` as its ``test_size``).

    Returns:
        dict with ``'dm'``, ``'data'`` (float32 tensor with columns ``[t, x, y]``),
        ``'var_dims'`` (columns per variable) and ``'true_atd'``: the derivative of
        ``E[y | do(t)]`` with respect to ``t``, averaged over the distribution of
        ``t`` (the convention the other generators' constants also satisfy).
    """
    if linear:
        coeffs1 = np.array([2, -1]).reshape(-1, 1)
        coeffs2 = np.array([2, 1]).reshape(-1, 1)
        u = np.random.normal(-1, 1, size=(n_samples, 1))
        t = u + np.random.normal(0, 3, size=(n_samples, 1))
        e_x = np.random.multivariate_normal(mean=np.ones(2), cov=2 * np.eye(2), size=n_samples)
        x = t @ coeffs1.T + e_x
        y = x @ coeffs2 + u + np.random.normal(0, 2, size=(n_samples, 1))
        var_dims = np.array([1, 2, 1])
        # The effect of t on y runs only through x: coeffs1 . coeffs2 = 2*2 + (-1)*1 = 3.
        true_atd = float(coeffs1.ravel() @ coeffs2.ravel())
    else:
        u = np.random.normal(-1, 1, size=(n_samples, 1))
        t = np.random.normal(2, 2, size=(n_samples, 1)) + u
        x = 2 * t + np.random.normal(1, 2, size=(n_samples, 1))
        y = 0.25 * x ** 2 - x + u + np.random.normal(0, 2, size=(n_samples, 1))
        var_dims = np.array([1, 1, 1])
        # Under do(t): x = 2t + eps with eps ~ N(1, 2^2), hence
        #   d/dt E[y | do(t)] = 0.25 * d/dt E[x^2] - d/dt E[x] = (2t + 1) - 2 = 2t - 1.
        # With E[t] = 2 + E[u] = 1 its average is 1 (not 0: eps has mean 1).
        true_atd = 1.

    data = np.concatenate([t, x, y], axis=-1)
    data = torch.tensor(data).float()
    dm = ToyDataModule(data, test_size=validation_size, batch_size=batch_size, num_workers=num_workers)
    return {'dm': dm, 'data': data, 'var_dims': var_dims, 'true_atd': true_atd}


def leaky_data(n_samples, linear, batch_size, num_workers, validation_size=0.):
    """Sample a "leaky" front-door dataset ``t -> x -> y`` (linear case only).

    A hidden ``c`` confounds ``t`` and ``y``. A hidden 2-d ``u`` enters both the
    mediator ``x`` and ``y``. Neither latent is returned.

    Args:
        n_samples: Number of samples to draw.
        linear: Must be true; the nonlinear variant is not implemented.
        batch_size, num_workers, validation_size: Forwarded to ``ToyDataModule``
            (``validation_size`` as its ``test_size``).

    Returns:
        dict with ``'dm'``, ``'data'`` (float32 tensor with columns ``[t, x, y]``),
        ``'var_dims'`` (``[1, 2, 1]``) and ``'true_atd'``.

    Raises:
        NotImplementedError: If ``linear`` is false (before any sampling).
    """
    if not linear:
        raise NotImplementedError

    shape = (n_samples, 1)
    hidden = np.random.multivariate_normal(mean=np.array([1, -1]), cov=np.eye(2), size=n_samples)
    confounder = np.random.normal(0, 1, size=shape)
    noise_t = np.random.normal(0, 1, size=shape)
    noise_y = np.random.normal(0, 1, size=shape)
    noise_x = np.random.multivariate_normal(mean=np.array([0, 0]), cov=np.eye(2), size=n_samples)

    treatment = confounder + noise_t
    mediator = treatment @ np.array([[1, 2]]) + noise_x + hidden
    outcome = mediator @ np.array([[-1.5], [2]]) + hidden.sum(axis=1, keepdims=True) + confounder + noise_y
    # The effect of t runs through x: [1, 2] . [-1.5, 2] = 2.5.
    true_atd = 2.5

    data = torch.tensor(np.concatenate([treatment, mediator, outcome], axis=-1)).float()
    dm = ToyDataModule(data, test_size=validation_size, batch_size=batch_size, num_workers=num_workers)
    return {'dm': dm, 'data': data, 'var_dims': np.array([1, 2, 1]), 'true_atd': true_atd}


def gen_data(key, args):
    """Build one of this module's synthetic datasets by name.

    Args:
        key: Dataset name: ``'backdoor'``, ``'frontdoor'``, ``'iv'``, ``'leaky'``,
            ``'bow'``, ``'binary_frontdoor'`` or ``'binary_iv'``.
        args: Mapping of keyword arguments forwarded to the chosen generator;
            it must match that generator's signature.

    Returns:
        The generator's result dict (containing at least ``'dm'``, ``'data'``
        and ``'var_dims'``).

    Raises:
        KeyError: If ``key`` names no known dataset.
    """
    generators = {
        'backdoor': backdoor_data,
        'frontdoor': frontdoor_data,
        'iv': iv_data,
        'leaky': leaky_data,
        'bow': bow_data,
        # The binary generators were previously unreachable through this dispatcher.
        'binary_frontdoor': binary_frontdoor,
        'binary_iv': binary_iv,
    }
    try:
        generator = generators[key]
    except KeyError:
        # Same exception type as before, but with the available choices listed.
        raise KeyError(f"Unknown dataset {key!r}; expected one of {sorted(generators)}") from None
    return generator(**args)
