import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from scipy.linalg import orth
from numpy.linalg import svd

# Number of coordinates per sample (the genomic data uses 20 principal
# components) and the target noise fraction. generate_noise adds
# eps/(1-eps) * n points to n clean points, so noise makes up about an
# EPS share of the final dataset.
D = 20
EPS = 0.1

# Inputs for the 2-D projection plot and the sample colouring.
GENOMIC_PCA_FILE_PATH = "PCA.txt"                    # tab-separated; needs ID, RotNS, RotEW columns
COUNTRY_COLORING_FILE_PATH = "POPRESID_Color.txt"    # one "<ID> <color name>" pair per line
COLOR_TO_RGB_FILE_PATH = "colors.txt"                # tab-separated: color name, r, g, b (0-255)

# Per-sample PCA scores and the matching eigenvalues, read by
# perform_pca_genomic_data_2D.
PCA_SCORE_FILE_PATH = 'POPRES_08_24_01.EuroThinFinal.LD_0.8.exLD.out0-PCA.eigs'
EIGEN_VALUES_FILE_PATH = 'POPRES_08_24_01.EuroThinFinal.LD_0.8.exLD.out0-PCA.eval'

def convert_color_string_to_rgb(path=None):
    """Load the color-name -> RGB lookup table.

    The file is headerless and tab-separated, one ``color<TAB>r<TAB>g<TAB>b``
    row per color, with channel values on a 0-255 scale. Each channel is
    divided by 255 and the three are packed into a single tuple.

    Args:
        path: File to read. Defaults to ``COLOR_TO_RGB_FILE_PATH`` when None.

    Returns:
        pandas.DataFrame with columns ``['color', 'rgb_code']``, where
        ``rgb_code`` holds ``(r, g, b)`` float tuples in [0, 1] suitable for
        matplotlib's ``c=`` argument.
    """
    if path is None:
        path = COLOR_TO_RGB_FILE_PATH
    table = pd.read_csv(path, sep="\t", lineterminator="\n", names=['color', 'r', 'g', 'b'])
    # Rescale 0-255 channels to the 0-1 floats matplotlib expects.
    scaled = table[['r', 'g', 'b']] / 255.0
    table['rgb_code'] = list(zip(scaled['r'], scaled['g'], scaled['b']))
    return table.drop(columns=['r', 'g', 'b'])

def read_ID_color_matches(path=None):
    """Load the sample-ID -> color-name table.

    Each non-blank line is split at its last space. The text before it is
    parsed as an integer sample ID, and the text after it is the color name.
    Blank or whitespace-only lines, such as a trailing empty line, are
    skipped instead of producing an unparsable empty ID.

    Args:
        path: File to read. Defaults to ``COUNTRY_COLORING_FILE_PATH`` when None.

    Returns:
        pandas.DataFrame with columns ``['ID', 'color']``; ``ID`` is int64.

    Raises:
        ValueError: propagated from the int64 conversion if an ID field is
            not an integer.
    """
    if path is None:
        path = COUNTRY_COLORING_FILE_PATH
    rows = []
    with open(path, 'r') as handle:
        for raw_line in handle:
            stripped = raw_line.strip()
            if not stripped:
                continue  # an empty row would otherwise crash the int64 cast below
            rows.append(stripped.rsplit(" ", 1))
    matches = pd.DataFrame(rows, columns=['ID', 'color'])
    matches['ID'] = matches['ID'].astype('int64')
    return matches

def generate_noise(data, eps):
    """Generate synthetic noise points, and their plot colors, for ``data``.

    With ``m = data.shape[0]`` clean rows, ``q = int(eps * (m / (1 - eps)))``
    noise rows are produced. After they are stacked under ``data``, roughly an
    ``eps`` fraction of the combined rows is noise.

    Each noise row has ``d = data.shape[1]`` coordinates:
      * the first ``d // 2`` are i.i.d. uniform integers in {0, 1, 2} (P1);
      * the remaining ``d - d // 2`` are i.i.d. uniform over {2, 3} (P2).
    Every coordinate is then scaled by a factor of 1/24.

    NOTE(review): an earlier version of this docstring described the noise as
    ``Q = 0.5*P1 + 0.5*P2``, which reads like a 50/50 mixture of two
    distributions. This code instead concatenates P1 and P2 coordinate-wise.
    Confirm which construction is intended.

    Args:
        data: 2-D array of clean samples, one row per sample. The noise
            width follows ``data.shape[1]`` so it can always be vstacked
            under ``data``.
        eps: Target noise fraction; expected in [0, 1).

    Returns:
        ``(noise, noise_colors)``: a ``(q, d)`` float array and a ``(q, 3)``
        array of zeros (black RGB) used to color the noise points.
    """
    n_clean, n_dims = data.shape[0], data.shape[1]
    n_total = n_clean / (1 - eps)  # size of the combined (clean + noise) dataset
    q = int(eps * n_total)
    first_half = n_dims // 2
    p1 = np.random.randint(0, 3, size=(q, first_half))
    # d - d//2 (not d//2) so an odd width still yields exactly d columns.
    p2 = np.random.randint(2, 4, size=(q, n_dims - first_half))
    noise = 1/24 * np.concatenate([p1, p2], axis=1)

    noise_colors = np.zeros((q, 3))

    return noise, noise_colors

def project_genomic_data_2D():
    """Scatter-plot the samples in ``GENOMIC_PCA_FILE_PATH`` by RotNS / RotEW.

    Each row's ``ID`` is left-joined to the ID -> color-name table, and that
    color name is left-joined to the color-name -> RGB table. The resulting
    ``rgb_code`` tuple colors each point. The figure is displayed with
    ``plt.show()`` and nothing is returned.
    """
    # Read the tab-separated table, then attach color names and RGB codes in turn.
    annotated = (
        pd.read_csv(GENOMIC_PCA_FILE_PATH, sep="\t", lineterminator="\n")
        .merge(read_ID_color_matches(), on='ID', how='left')
        .merge(convert_color_string_to_rgb(), on='color', how='left')
    )
    point_colors = annotated['rgb_code'].tolist()

    plt.figure(figsize=(8, 6))
    plt.scatter(annotated['RotNS'], annotated['RotEW'], c=point_colors, alpha=0.6)
    plt.title('Genomic Data projected onto top 2 PCs')
    plt.show()

def _load_scaled_pca_scores():
    """Return ``(data, sample_ids)``: eigenvalue-scaled PC scores and the per-row sample IDs."""
    # Whitespace-separated with the first line skipped (presumably a header
    # line in the score file -- confirm). Raw string: '\s' is an invalid
    # escape in a normal string literal.
    pca_data = pd.read_csv(PCA_SCORE_FILE_PATH, sep=r'\s+', header=None, skiprows=1)
    # Column 0 holds the sample ID; columns 0, 1 and 22 are not score columns.
    score_frame = pca_data.drop(
        [pca_data.columns[0], pca_data.columns[1], pca_data.columns[22]], axis=1
    )
    n_pcs = score_frame.shape[1]

    # Eigenvalues: one per line, representing the variance explained by the
    # corresponding principal component; keep one per score column.
    eigenvalues = pd.read_csv(EIGEN_VALUES_FILE_PATH, header=None).iloc[:n_pcs, 0].to_numpy()

    # Scale each score column by its eigenvalue to obtain the dataset.
    data = score_frame.to_numpy() * eigenvalues
    return data, pca_data[pca_data.columns[0]]


def _lookup_sample_colors(sample_ids):
    """Map each sample ID to its ``(r, g, b)`` tuple via the ID and color-name tables."""
    rgb_table = convert_color_string_to_rgb()
    id_table = read_ID_color_matches()
    # setdefault keeps the FIRST occurrence of a key, matching a
    # "first matching row" lookup when a table contains duplicates.
    name_to_rgb = {}
    for color_name, rgb in zip(rgb_table['color'], rgb_table['rgb_code']):
        name_to_rgb.setdefault(color_name, rgb)
    id_to_name = {}
    for sample_id, color_name in zip(id_table['ID'], id_table['color']):
        id_to_name.setdefault(sample_id, color_name)
    # KeyError here means a sample ID or color name is missing from its table.
    return [name_to_rgb[id_to_name[sample_id]] for sample_id in sample_ids]


def perform_pca_genomic_data_2D(noisy=False):
    """Plot the samples, optionally with injected noise, on their top 2 PCs.

    Steps:
      1. Load the per-sample PCA scores (``PCA_SCORE_FILE_PATH``) and the
         eigenvalues (``EIGEN_VALUES_FILE_PATH``), and scale each score
         column by its eigenvalue.
      2. Look up each sample's RGB color through the ID -> color-name ->
         RGB tables.
      3. If ``noisy`` is true, append ``generate_noise(data, EPS)`` points,
         colored black.
      4. Fit a 2-component sklearn PCA, project the data, and show a
         scatter plot.

    Args:
        noisy: When true, add an ``EPS`` fraction of synthetic noise points
            before running PCA.

    Raises:
        KeyError: if a sample ID has no color entry, or a color name has no
            RGB entry.
    """
    data, sample_ids = _load_scaled_pca_scores()
    colors = _lookup_sample_colors(sample_ids)

    ### Add noise to dataset
    if noisy:
        noise, noise_colors = generate_noise(data, EPS)
        data = np.vstack((data, noise))
        colors = np.vstack((colors, noise_colors))

    ### Perform PCA using sklearn (local import: only this path needs sklearn)
    from sklearn.decomposition import PCA
    pca = PCA(n_components=2)
    pca.fit(data)
    projected_data = pca.transform(data)

    plt.figure(figsize=(8, 6))
    # PC2 on the x-axis, negated PC1 on the y-axis (orientation kept as-is).
    plt.scatter(projected_data[:, 1], -projected_data[:, 0], c=colors, label='Original Data', alpha=0.6)
    plt.title('Genomic Data projected onto top 2 PCs')
    plt.show()