import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from functools import reduce

# Per-model scores from two raw result files. Column 0 holds the model identifier;
# the remaining columns are numeric scores (presumably repeated runs — TODO confirm).
t1 = pd.read_csv('timm_models_raw.csv', header=None)
t2 = pd.read_csv('timm_models2_raw.csv', header=None)
# numeric_only=True excludes the string identifier column from the row average;
# pandas >= 2.0 raises TypeError on mixed-dtype rows without it.
t1['mean'] = t1.mean(axis=1, numeric_only=True)
t2['mean'] = t2.mean(axis=1, numeric_only=True)
# Column 0: model identifier, column 1: embedding dimensionality (presumably).
t_dims = pd.read_csv('timm_model_dims.csv', header=None)

# The three tables are combined by *position* after sorting on the identifier, so
# they must describe the same models; fail loudly instead of misaligning rows.
t1 = t1.sort_values(by=[0])
t2 = t2.sort_values(by=[0])
t_dims = t_dims.sort_values(by=[0])
if not len(t1) == len(t2) == len(t_dims):
    raise ValueError(
        f"row counts differ: timm_models_raw={len(t1)}, "
        f"timm_models2_raw={len(t2)}, timm_model_dims={len(t_dims)}"
    )

timm_means = np.mean([t1['mean'], t2['mean']], axis=0)
# Strip the 5-character prefix from each raw identifier.
timm_names = [t[5:] for t in t1[0].values]
timm_dims = t_dims[1].values

# One row per model: name, mean score scaled by 100, and dimensionality.
t3 = pd.DataFrame({
    'model': timm_names,
    'mean': timm_means.astype('float64') * 100,
    'dims': timm_dims.astype('float64'),
})

# Keep only models with a strictly positive score.
t3 = t3[t3['mean'] > 0]



# LOAD THE EVALUATION DATASET
import timm
from timm.data import create_dataset, create_loader

# Dataset name, passed to timm's create_dataset as "torch/<name>"; it is also used to
# name the embedding output directory.
torch_dataset = 'FashionMNIST'
# Only the test split is embedded; download=True allows fetching it into /data/datasets.
test_data = create_dataset(f"torch/{torch_dataset}", "/data/datasets", download=True, split='test')





#EXTRACT EMBEDDINGS 
from tqdm import tqdm
batch_size=256
def get_embedding(model,dataset, workers=1, filename=None):
    # device = "cuda" if torch.cuda.is_available() else "cpu"
    device="cpu"
    if 'mnist' in torch_dataset:
        in_chans=1
    else:
        in_chans=3
    m = timm.create_model(model, pretrained=True, num_classes=0, in_chans=in_chans).to(device)
    m.eval()
    data_config=m.default_cfg
    insize=list(data_config['input_size'])
    mean=data_config['mean']
    std=data_config['std']
    if 'mnist' in torch_dataset:
        insize[0]=1
        mean=[0.5]
        std=[0.5]
    if 'interpolation' in data_config.keys() and 'crop_pct' in data_config.keys():
        loader = create_loader(
            dataset,
            input_size=insize,
            batch_size=batch_size,
            interpolation=data_config['interpolation'],
            mean=mean,
            std=std,
            num_workers=workers,
            crop_pct=data_config['crop_pct'],
            use_prefetcher=False
        )
    else:
        loader = create_loader(
            dataset,
            input_size=insize,
            batch_size=batch_size,
            mean=mean,
            std=std,
            num_workers=workers,
            use_prefetcher=False
        )
    vectors=[]
    targets=[]
    for i,(inp,target) in enumerate(tqdm(loader)):
        if i%10==0:
            print(f"{model}: {i}/{len(loader)}")
        inp=inp.to(device)
        target=target.to(device)
        vectors.append(m(inp).detach().cpu().numpy())
        targets.append(target)
    del(m)
    vectors=np.concatenate(vectors, axis=0)
    targets=np.concatenate(targets, axis=0)
    if filename:
        np.save(filename,vectors)
        np.save(filename+'_targets',targets)
    return vectors

import os
import traceback

# Output directory for this dataset's embeddings. Its listing is taken once and used
# to skip models whose '<model>.npy' already exists -- failed models also write
# '<model>.npy' below, so they are skipped on reruns too.
embed_dir = f'/data/HAMFSL/embeds/{torch_dataset}'
os.makedirs(embed_dir, exist_ok=True)
cache = os.listdir(embed_dir)

# Models that are never attempted (reported as too large).
infeasible = ['tf_efficientnet_l2_ns_475', 'beit_large_patch16_512',
       'beit_large_patch16_384', 'vit_large_patch16_384',
       'convnext_xlarge_384_in22ft1k', 'beit_large_patch16_224',
       'tf_efficientnetv2_xl_in21ft1k', 'convnext_xlarge_in22ft1k',
       'cait_m48_448', 'vit_large_r50_s32_384', 'ig_resnext101_32x48d',
       'resnetv2_152x4_bitm', 'dm_nfnet_f6', 'dm_nfnet_f5', 'dm_nfnet_f4',
       'ig_resnext101_32x32d', 'cait_m36_384', 'dm_nfnet_f3',
       'vit_large_patch16_224', 'resnetv2_101x3_bitm',
       'resnetv2_152x2_bitm', 'resnetv2_152x2_bit_teacher_384',
       'resnetv2_50x3_bitm', 'vit_large_r50_s32_224',
       'vit_large_patch32_384', 'resnetv2_152x2_bit_teacher',
       'mixer_l16_224']

for model in tqdm(t3.model.values):
    print(model)
    filename = f'{embed_dir}/{model}'
    if model + '.npy' in cache:
        print('Already processed.')
    elif model in infeasible:
        print('Model too large.')
    else:
        try:
            get_embedding(model, test_data, filename=filename)
        except Exception as e:
            # Best effort: report the failure instead of hiding it, then write the
            # marker files so later runs skip this model.
            print(f'{model} failed: {e!r}')
            traceback.print_exc()
            np.save(filename, [model, e])
            np.save(filename + '_error', [model, e])
