import datetime
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.sparse import load_npz
from sklearn import preprocessing, metrics
from tqdm import tqdm, trange
from scipy.stats import norm
from scipy import stats


def read_csv(filename, time_idx):
    """Load a CSV file and parse one of its columns as datetimes.

    Args:
        filename: path or buffer accepted by ``pd.read_csv``.
        time_idx: name of the column converted with ``pd.to_datetime``.

    Returns:
        The loaded DataFrame, with ``time_idx`` replaced by its parsed values.
    """
    frame = pd.read_csv(filename)
    parsed_times = pd.to_datetime(frame[time_idx])
    frame[time_idx] = parsed_times
    return frame


def list2tensor(df, time_idx, aspects, target, freq):
    """Aggregate an event table into a dense, time-binned count tensor.

    Every column in ``aspects`` is label-encoded to consecutive integers.  For
    each observed combination of encoded aspect values, ``target`` is summed
    into ``freq``-sized bins and aligned onto a shared date range.

    Args:
        df: event table containing ``time_idx``, every column in ``aspects``
            and ``target``.  It is not modified.
        time_idx: name of the datetime column used for binning.
        aspects: list of categorical columns; each becomes one tensor mode.
        target: numeric column whose values are summed per bin.
        freq: pandas offset alias (e.g. ``"W"``) for the time bins.

    Returns:
        ``np.ndarray`` of shape ``(n_aspect_0, ..., n_aspect_k, n_bins)``, where
        ``n_bins == len(pd.date_range(min_time, max_time, freq=freq))``.
        Aspect combinations that never occur stay all-zero.
    """
    # Encode on a private copy.  Assigning the encoded columns in place would
    # mutate the caller's frame, and it triggers SettingWithCopyWarning when
    # ``df`` is a filtered slice of another DataFrame.
    df = df.copy()

    date_range = pd.date_range(
        df[time_idx].min(), df[time_idx].max(), freq=freq)

    label_encoders = []
    for col in aspects:
        le = preprocessing.LabelEncoder()
        df[col] = le.fit_transform(df[col])
        label_encoders.append(le)

    # e.g., (n_loc, n_key, n_duration)
    tensor = np.zeros([len(le.classes_) for le in label_encoders] + [len(date_range)])
    print("Converting to tensor...")
    for key, grouped in tqdm(df.groupby(aspects)):
        # Left-join each group's resampled series onto the common date range
        # so all fibres have the same length; bins without events become 0.
        # NOTE(review): resample may label a trailing partial bin after
        # ``df[time_idx].max()``; such a label is absent from ``date_range``
        # and its counts are dropped by the left join -- confirm inputs are
        # bin-aligned.
        tensor[key] = pd.DataFrame(index=date_range).join(
            grouped.set_index(time_idx)[target].resample(freq).sum(),
            how="left").fillna(0).to_numpy().ravel()

    return tensor


def list2tensor_from_index(data, timestamp, n_attributes):
    """Scatter ``(index..., value)`` rows of ``data`` into a dense tensor.

    Time slice ``t`` of the result is accumulated from the rows
    ``data[timestamp[t]:timestamp[t + 1]]``; the last slice runs to the end of
    ``data``.  In every row, the first ``len(n_attributes)`` columns are cell
    indices (cast to ``int``) and the last column is the value added to that
    cell.  Any columns in between are ignored.

    Args:
        data: DataFrame of index columns followed by a value column.
        timestamp: segment boundaries used to slice ``data``; presumably sorted
            row offsets -- TODO confirm against callers.
        n_attributes: size of each attribute mode, e.g. a ``pd.Series`` (its
            values are used) or any 1-D array-like of ints.

    Returns:
        ``np.ndarray`` of shape ``(*n_attributes, len(timestamp))``.

    Raises:
        ValueError: if ``data`` has fewer index columns than attribute modes.
    """
    mode_sizes = tuple(np.asarray(n_attributes).ravel())
    n_modes = len(mode_sizes)
    if data.shape[1] - 1 < n_modes:
        raise ValueError(
            f"data has {data.shape[1] - 1} index columns but "
            f"{n_modes} attribute modes were given")

    n_sample = len(timestamp)
    tensor = np.zeros((*mode_sizes, n_sample))
    print(data)

    for t in trange(n_sample, desc="list2tensor"):
        # The final segment is open-ended (``data[start:None]`` == ``data[start:]``).
        stop = timestamp[t + 1] if t < n_sample - 1 else None
        segment = data[timestamp[t]:stop]
        # One array conversion per segment instead of building a Series per
        # row with iterrows.  The array has the frame's common dtype, so the
        # indices are cast to int explicitly (a float value column would
        # otherwise make them floats and break indexing).
        for row in segment.to_numpy():
            cell = tuple(int(i) for i in row[:n_modes])
            tensor[cell + (t,)] += row[-1]

    return tensor


def load_tycho(filename, as_tensor=False, cache_path="../dat/project_tycho.npy"):
    """Load the Project Tycho dataset as a table or as a cached count tensor.

    Args:
        filename: CSV file read with ``read_csv``, parsing ``from_date``.
        as_tensor: when truthy, return the (state, disease, week) tensor
            instead of the raw table.
        cache_path: ``.npy`` file caching the tensor (should end in ``.npy``,
            since ``np.save`` appends that suffix otherwise).  If it exists, it
            is loaded as-is and ``filename`` is ignored.  Otherwise the tensor
            is built from the rows of ``filename`` dated on/after 1950-01-01
            and saved here.  Relative paths resolve against the current
            working directory.

    Returns:
        ``np.ndarray`` when ``as_tensor`` is truthy, else the parsed DataFrame.
    """
    # Default setting for TYCHO dataset
    time_idx = "from_date"
    aspects = ["state", "disease"]
    target = "number"
    freq = "W"

    if not as_tensor:
        return read_csv(filename, time_idx)

    if os.path.isfile(cache_path):
        return np.load(cache_path)

    data = read_csv(filename, time_idx)
    data = data[data[time_idx] >= "1950-01-01"]
    tensor = list2tensor(data, time_idx, aspects, target, freq)
    np.save(cache_path, tensor)
    return tensor


def load_citibike(filepath, key="stationid", freq="H", as_tensor=True):
    """Load Citi Bike trip counts as a dense tensor or as an indexed event table.

    Args:
        filepath: directory prefix, including the trailing separator, holding
            ``citibike_{key}.npz`` and the monthly ``tripdata_YYYYMM.csv`` files.
        key: ``"stationid"`` (start station / end station pairs) or
            ``"userage"`` (start station / user age pairs).
        freq: pandas offset alias used to bin trip start times; only used
            when ``as_tensor`` is falsy.
        as_tensor: when truthy, load the precomputed sparse matrix and reshape
            it into a dense 3-D array; otherwise rebuild the table from CSVs.

    Returns:
        Truthy ``as_tensor``: ``np.ndarray`` of shape ``(409, 409, T)`` for
        ``"stationid"`` or ``(409, 64, T)`` for ``"userage"``, where ``T`` is
        the last dimension of the stored matrix.
        Falsy ``as_tensor``: DataFrame indexed by time bin with columns
        ``index_x``, ``index_y`` and ``count``.

    Raises:
        ValueError: if ``key`` is not one of the supported values.
    """
    if key not in ("stationid", "userage"):
        raise ValueError(f"key must be 'stationid' or 'userage', got {key!r}")

    if as_tensor:
        return _load_citibike_tensor(filepath, key)

    trip_data = _read_citibike_trips(filepath, key, freq)
    return _index_citibike_trips(trip_data, key)


def _load_citibike_tensor(filepath, key):
    """Load ``citibike_{key}.npz`` and reshape it into a dense 3-D ndarray."""
    data = load_npz(f"{filepath}citibike_{key}.npz")
    # ``toarray()`` always yields an ndarray.  ``todense()`` on a sparse
    # *matrix* returns ``np.matrix``, which is strictly 2-D and cannot be
    # reshaped to 3-D.
    data = data.toarray()
    second_dim = 409 if key == "stationid" else 64
    data = data.reshape((409, second_dim, data.shape[-1]))
    print(data.shape)
    return data


def _read_citibike_trips(filepath, key, freq):
    """Read the monthly trip CSVs and count trips per (time bin, pair)."""
    start_date = "2017-03"
    end_date = "2021-04"
    date_range = pd.date_range(start=start_date, end=end_date, freq="m")
    if key == "stationid":
        columns = ["start_station_dim", "end_station_dim"]
    else:
        columns = ["start_station_dim", "userage"]

    trip_data = []
    for date in tqdm(date_range):
        df = pd.read_csv(filepath + "tripdata_" + date.strftime("%Y%m") + ".csv")
        df = df.rename(columns={"starttime": "timestamp"})
        df["timestamp"] = pd.to_datetime(df.timestamp)
        df = df.set_index("timestamp")
        df = df[columns].astype(int)
        # Count trips per (time bin, first column, second column), then turn
        # the two pair levels back into columns, keeping the time bin as index.
        df = df.groupby([pd.Grouper(freq=freq), *df.keys()]).size()
        df = df.reset_index(level=[1, 2])
        df = df.rename(columns={0: "count"})
        trip_data.append(df)

    trip_data = pd.concat(trip_data)
    print(trip_data.head())
    print(trip_data.shape)
    return trip_data


def _index_citibike_trips(trip_data, key):
    """Keep busy start stations and map both pair columns to 0-based indices."""
    # Start stations with more than 50,000 trips in total, busiest first; a
    # station's position in this ranking becomes its dense index.
    count_data = trip_data.groupby("start_station_dim").sum()
    count_data = count_data.sort_values("count", ascending=False)
    target = count_data.query("count>50000").reset_index().reset_index()
    station_index = target.set_index("start_station_dim")["index"]

    new_data = trip_data[trip_data["start_station_dim"].isin(target.start_station_dim)]
    if key == "stationid":
        new_data = new_data[new_data["end_station_dim"].isin(target.start_station_dim)]
    new_data = new_data.join(station_index, on="start_station_dim")
    new_data = new_data.rename(columns={"index": "index_x"})

    if key == "stationid":
        # End stations share the start-station ranking.
        second_index = station_index
        second_col = "end_station_dim"
    else:
        # Ages map to offsets from the youngest observed age.
        age_index = pd.DataFrame()
        age_index["userage"] = list(range(trip_data.userage.min(), trip_data.userage.max() + 1))
        second_index = age_index.reset_index().set_index("userage")
        second_col = "userage"

    new_data = new_data.join(second_index, on=second_col)
    new_data = new_data.rename(columns={"index": "index_y"})
    new_data = new_data[["index_x", "index_y", "count"]]

    print(new_data.head())
    return new_data


def load_nytaxi(filepath, as_tensor=False, n_zones=265):
    """Stack per-zone NYC Yellow Taxi matrices into one dense tensor.

    Zone ``i`` (1-based) is read from ``{filepath}mat_{i}.npz``, densified and
    transposed.  Zones whose file does not exist are filled with zeros shaped
    like the zones that were loaded.

    Args:
        filepath: directory prefix, including the trailing separator.
        as_tensor: accepted for interface compatibility; unused, because a
            tensor is always returned.
        n_zones: number of zone files ``mat_1.npz`` .. ``mat_{n_zones}.npz``.

    Returns:
        ``np.ndarray`` stacking the ``n_zones`` transposed matrices on axis 0.

    Raises:
        FileNotFoundError: if none of the zone files exists.
    """
    matrices = [None] * n_zones
    for i in trange(n_zones, desc="NYC Yellow Taxi"):
        try:
            matrices[i] = load_npz(filepath + f'mat_{i+1}.npz').toarray().T
        except FileNotFoundError:
            # Missing zone: zero-filled below once a loaded zone tells us the
            # shape (the first zone may itself be missing).
            pass

    loaded_shapes = [m.shape for m in matrices if m is not None]
    if not loaded_shapes:
        raise FileNotFoundError(f"no mat_*.npz zone files found under {filepath!r}")
    fill_shape = loaded_shapes[0]

    tensor = [m if m is not None else np.zeros(fill_shape) for m in matrices]
    return np.array(tensor)


def load_olist(filepath):
    """Load the Olist array saved as ``olist.npy`` under ``filepath`` and print its shape.

    ``filepath`` is concatenated directly with the file name, so it should end
    with a path separator.
    """
    npy_path = filepath + "olist.npy"
    tensor = np.load(npy_path)
    print(tensor.shape)
    return tensor

def load_gtrends(filename, as_tensor=False, data_dir="../dat"):
    """Load a Google Trends tensor and swap its second and third axes.

    Reads ``{data_dir}/{filename}/tensor.npy`` and moves axis 1 to position 2,
    i.e. ``(a, b, c, ...) -> (a, c, b, ...)``, then prints the new shape.

    Args:
        filename: sub-directory of ``data_dir`` containing ``tensor.npy``.
        as_tensor: accepted for interface compatibility; unused.
        data_dir: base directory; the default is relative to the current
            working directory.

    Returns:
        The loaded ``np.ndarray`` with axes 1 and 2 swapped.
    """
    data = np.load(f"{data_dir}/{filename}/tensor.npy")
    data = np.moveaxis(data, 1, 2)
    print(data.shape)
    return data

def compute_model_cost(X, n_bits=32, zero=1e-10):
    """Return the MDL model cost of storing the non-zero entries of ``X``.

    Each entry greater than ``zero`` is charged for its position (the natural
    log of every dimension size) plus ``n_bits`` for its value.  The number of
    such entries is charged ``log(count)``.

    Args:
        X: array of any rank (typically a 2-D factor matrix).
        n_bits: cost charged per stored value.
        zero: entries ``<= zero`` are treated as zero and cost nothing.

    Returns:
        The cost as a float; ``0.0`` when no entry of ``X`` exceeds ``zero``.
    """
    # NOTE(review): the position and count terms use natural logs while
    # ``n_bits`` is presumably in bits (compute_coding_cost reports bits) --
    # confirm the units are meant to be mixed.
    X = np.asarray(X)
    n_nonzero = np.count_nonzero(X > zero)
    if n_nonzero == 0:
        # log(0) would yield -inf and make an empty model look infinitely cheap.
        return 0.0
    # For a 2-D X this equals log(k) + log(l).
    position_cost = np.log(X.shape).sum()
    return n_nonzero * (position_cost + n_bits) + np.log(n_nonzero)


def compute_coding_cost(X, Y):
    """Return the Gaussian coding cost, in bits, of ``X`` given reconstruction ``Y``.

    Only cells where ``X > 0`` are encoded.  Their residuals ``X - Y`` (as
    float32) are modelled by a normal distribution with the residuals' own
    mean and standard deviation.  The cost is the negative log-likelihood
    converted from nats to bits.

    Args:
        X: observed array.
        Y: reconstruction, indexable by the same boolean mask as ``X``.

    Returns:
        ``0`` if ``X`` has no positive entry; ``0.0`` if all residuals are
        identical (zero spread); otherwise the cost in bits.
    """
    mask = X > 0
    if mask.sum() == 0:
        return 0
    diff = (X[mask] - Y[mask]).flatten().astype("float32")

    scale = diff.std()
    if scale == 0:
        # A zero-width Gaussian makes norm.logpdf return NaN, which would
        # poison any cost comparison.  Identical residuals are fully
        # predictable from their mean, so they cost nothing to encode.
        return 0.0

    logprob = norm.logpdf(diff, loc=diff.mean(), scale=scale)
    return -1 * logprob.sum() / np.log(2.)


def compute_error(X, Y):
    """Root-mean-squared error between ``X`` and ``Y`` over the cells where ``X > 0``.

    Returns ``0`` when ``X`` has no positive entry.
    """
    observed = X > 0
    if not observed.any():
        return 0
    mse = metrics.mean_squared_error(X[observed].flatten(), Y[observed].flatten())
    return np.sqrt(mse)