import time
import torch
import sys
# Make the parent of the current working directory importable (".." is
# resolved against the working directory, not against this file's location).
# NOTE(review): presumably needed for the project-local imports below — confirm.
sys.path.append("..")
import os
# Expose only GPU 0 to this process. The variable is assigned here, ahead of
# every CUDA call made in this file.
#os.environ["CUDA_VISIBLE_DEVICES"] = "1,2,3"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import model_down
from Bio import SeqIO
from typing import List, Tuple,Any
import string
#os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
import numpy as np
# Characters stripped from aligned sequences: every lowercase ASCII letter
# plus "." and "*". Mapping a character to None in a str.translate table
# deletes it from the translated string.
deletekeys = {char: None for char in string.ascii_lowercase + ".*"}
translation = str.maketrans(deletekeys)

def read_sequence(filename: str) -> Tuple[str, str]:
    """Return the first record of a FASTA-formatted file.

    Only the first record yielded by the parser is consumed; any further
    records in the file are never read.

    Args:
        filename: Path handed to ``SeqIO.parse`` with the ``"fasta"`` format.

    Returns:
        A ``(description, sequence)`` pair, with the sequence converted to
        ``str``.

    Raises:
        StopIteration: If the parser yields no record at all.
    """
    first_record = next(SeqIO.parse(filename, "fasta"))
    header = first_record.description
    residues = str(first_record.seq)
    return header, residues

def remove_insertions(sequence: str) -> str:
    """Strip insertion markers from an aligned sequence.

    Applies the module-level ``translation`` table, which deletes lowercase
    ASCII letters, ``"."`` and ``"*"``; every other character is kept in
    place.

    Args:
        sequence: One aligned sequence string (for example a row of an MSA).

    Returns:
        The sequence with the characters above removed.
    """
    stripped = sequence.translate(translation)
    return stripped

import param_parser_esm1b
import esm
# Gradient scaler from torch.cuda.amp, created at import time.
# NOTE(review): nothing in the visible part of this script reads `scaler` —
# confirm whether it is still needed before removing it.
scaler = torch.cuda.amp.GradScaler()
def show_tnse(sen_embeddings, labels):
    """Plot 2-D t-SNE and PCA projections of the embeddings side by side.

    The left panel holds the t-SNE projection and the right panel the PCA
    projection; points in both are coloured by ``labels``. The figure is
    written to ``tsne-pca_bert_notcls_seq1.png`` at 200 dpi and then shown.

    Args:
        sen_embeddings: Array-like accepted by ``TSNE.fit_transform`` and
            ``PCA.fit_transform`` (assumed to be one row per sample — confirm
            against the caller).
        labels: Per-point values passed to ``plt.scatter`` as ``c``.
    """
    # Compute both projections up front, t-SNE first and PCA second.
    tsne_points = TSNE(n_components=2).fit_transform(sen_embeddings)
    pca_points = PCA(n_components=2).fit_transform(sen_embeddings)

    plt.figure(figsize=(10, 5))
    panels = (
        (121, tsne_points, 't-SNE'),
        (122, pca_points, 'PCA'),
    )
    for position, points, method_name in panels:
        plt.subplot(position)
        plt.scatter(points[:, 0], points[:, 1], c=labels, label=method_name)
        plt.legend()

    # Save before show(): the file is written from the still-open figure.
    plt.savefig('tsne-pca_bert_notcls_seq1.png', dpi=200)
    plt.show()

if __name__ == '__main__':
    # Run configuration. Of these, only `batch_size` is read by the code
    # visible further down in this script.
    epochs, batch_size, best_acc = 60, 16, 0

    # Build a ProteinBertModel from the project's argument parser together
    # with the ESM-1b alphabet, then restore weights from a local checkpoint.
    # (A disabled alternative previously loaded the model through
    # esm.pretrained.esm1b_t33_650M_UR50S().)
    args = param_parser_esm1b.params_parser()
    alphabet = esm.data.Alphabet.from_architecture('ESM-1b')
    model = esm.model.ProteinBertModel(args, alphabet)
    model.load_state_dict(
        torch.load('best_model_val_pretrain_1b.pt')['model_state_dict'],
        # Non-strict load: key mismatches between checkpoint and model do not
        # raise, and the report returned by load_state_dict is discarded.
        strict=False,
    )
    model = model.cuda()

    # Tokeniser used when collating batches. A converter built from the local
    # `alphabet` used to be assigned to this name first, but it was
    # overwritten by this assignment before ever being used, so only the
    # effective binding is kept.
    # NOTE(review): confirm that model_down.alphabet matches the ESM-1b
    # alphabet the model above was constructed with.
    batch_converter = model_down.alphabet.get_batch_converter()
    class Dataset(torch.utils.data.Dataset):
        """Map-style dataset over a list of items (file paths in this script)."""

        def __init__(self, list):
            """Keep a reference to the items this dataset serves."""
            self.list = list

        def __len__(self):
            """Return how many items the dataset holds."""
            return len(self.list)

        def __getitem__(self, index):
            """Return the stored item at ``index`` unchanged."""
            return self.list[index]

        def collate_fn(self, batch: List[Tuple[Any, ...]]):
            """Turn a batch of file paths into ``(tokens, labels)``.

            For every path the first record of the file is read with
            ``read_sequence`` and the records are tokenised together by the
            module-level ``batch_converter``. Each label is the first
            character of the file's base name converted to ``int``.

            Returns:
                A pair of the token batch (third value returned by
                ``batch_converter``) and the list of integer labels, in batch
                order.
            """
            records = [read_sequence(os.path.join(path)) for path in batch]
            _, _, msa_batch_token = batch_converter(records)
            # os.path.split(...)[1] is the base name; its first character
            # carries the class label.
            label = [int(os.path.split(path)[1][0]) for path in batch]
            return msa_batch_token, label
        
    def _list_a3m_files(directory):
        """Return full paths of the entries of *directory* whose extension is ``.a3m``."""
        return [
            os.path.join(directory, entry)
            for entry in os.listdir(directory)
            if os.path.splitext(entry)[-1] == '.a3m'
        ]

    # Both loaders use identical settings: shuffled batches of `batch_size`,
    # the last incomplete batch dropped, pinned memory and 8 worker processes.
    loader_settings = dict(
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        pin_memory=True,
        num_workers=8,
    )

    # Training split: every .a3m file directly inside the training directory.
    # NOTE(review): train_dataloader is not consumed by the code visible in
    # this script.
    a3m_dir = '/home/public/bigdata/my/datasets/metal/train_small'
    filenames = _list_a3m_files(a3m_dir)
    dataset = Dataset(filenames)
    train_dataloader = torch.utils.data.DataLoader(
        dataset, collate_fn=dataset.collate_fn, **loader_settings)

    # Test split: every .a3m file directly inside the test directory.
    test_dir = '/home/public/bigdata/my/datasets/metal/test'
    testnames = _list_a3m_files(test_dir)
    test_dataset = Dataset(testnames)
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset, collate_fn=test_dataset.collate_fn, **loader_settings)

    # Embed the first `max_batches` test batches with the loaded model and
    # plot the result.
    #
    # `a` collects one embedding vector per sequence and `b` the matching
    # integer labels. Both names are read again by the module-level UMAP code
    # further down in this file, so they must keep these names.
    max_batches = 19

    # Inference only: put the model in evaluation mode so that any
    # train-time-only behaviour (e.g. dropout, if the model has it) is off and
    # the embeddings are deterministic for a given input.
    model.eval()

    a = []
    b = []
    for idx, (seqs, targets) in enumerate(test_dataloader):
        # as_tensor reuses `seqs` when it already is a tensor, instead of
        # copy-constructing through torch.tensor (which also emits a warning).
        inputs = torch.as_tensor(seqs).cuda()
        with torch.no_grad():
            outputs = model(inputs, repr_layers=[33])
        # Layer-33 representation at token position 0 for every sequence
        # (presumably the BOS/CLS token — confirm against the alphabet).
        # No squeeze(): the batch axis is kept even for a batch of one, so
        # extend() below always appends whole per-sequence vectors.
        first_token = outputs['representations'][33][:, 0, :]
        a.extend(first_token.cpu().numpy())
        # Labels never need to visit the GPU; going through a tensor keeps the
        # same integer dtype as before.
        b.extend(torch.tensor(targets).numpy())
        # Stop right after the last wanted batch so no further batch is loaded.
        if idx == max_batches - 1:
            break
    show_tnse(np.array(a), np.array(b))

from anndata import AnnData
import scanpy as sc


def plot_umap(adata, namespace='flu'):
    """Draw the UMAP plot of *adata* coloured by ``label1`` and save it.

    Args:
        adata: Object passed straight to ``sc.pl.umap``.
        namespace: Passed to ``str.format`` on the save name. The name
            ``'label1.png'`` contains no replacement field, so this argument
            currently has no effect on the output.
    """
    figure_suffix = 'label1.png'.format(namespace)
    sc.pl.umap(adata, color='label1', save=figure_suffix)

# The UMAP pipeline consumes `a` (embeddings) and `b` (labels), which are only
# created inside the `__main__` section above. Guarding it the same way keeps
# script behaviour unchanged while letting the module be imported without a
# NameError on those names.
if __name__ == '__main__':
    adata = AnnData(np.array(a))

    # Per-observation annotations to attach to `adata`, keyed by column name.
    obs = {"label1": list(np.array(b))}
    for key, values in obs.items():
        adata.obs[key] = values

    # Neighbourhood graph on the raw embedding matrix (use_rep='X').
    sc.pp.neighbors(adata, n_neighbors=100, use_rep='X')
    # Louvain clusters are computed here; the plot below is coloured by
    # `label1`, not by these clusters.
    sc.tl.louvain(adata, resolution=1.)

    sc.set_figure_params(dpi_save=500)

    sc.tl.umap(adata, min_dist=1.)
    plot_umap(adata)




