from pipeline_utils import *

import numpy as np
import os
import pandas as pd
import json
import argparse
import pickle
import time
from tab_ddpm.utils import *
from pipeline_modules import *
from sdv.metadata import MultiTableMetadata
from tabsyn_utils import learn_to_cluster, clustering_naive
from gen_multi_report import gen_multi_report

if __name__ == '__main__':
    # Command-line interface. Every option is declared once in the table
    # below (in registration order) and registered in a single loop, so the
    # resulting namespace has the same attributes, defaults and ordering as
    # registering each option by hand.
    parser = argparse.ArgumentParser()
    _cli_spec = (
        ('--data_dir', {'type': str, 'default': 'berka/preprocessed'}),
        ('--exp_name', {'type': str, 'default': 'berka'}),
        ('--KEY_SCALE', {'type': int, 'default': 1}),
        ('--PARENT_SCALE', {'type': float, 'default': 1}),
        ('--NUM_CLUSTERS', {'type': int, 'default': 20}),
        ('--NUM_MATCHING_CLUSTERS', {'type': int, 'default': 5}),
        ('--CLASSIFIER_SCALE', {'type': float, 'default': 1}),
        ('--SAMPLE_SCALE', {'type': int, 'default': 1}),
        ('--classifier_steps', {'type': int, 'default': 10000}),
        ('--diffusion_steps', {'type': int, 'default': 100000}),
        ('--model_type', {'type': str, 'default': 'mlp'}),
        ('--gaussian_loss_type', {'type': str, 'default': 'mse'}),
        ('--num_timesteps', {'type': int, 'default': 2000}),
        ('--batch_size', {'type': int, 'default': 4096}),
        ('--lr', {'type': float, 'default': 6e-4}),
        ('--weight_decay', {'type': float, 'default': 1e-5}),
        ('--test_num_samples', {'type': int, 'default': 50000}),
        ('--scheduler', {'type': str, 'default': 'cosine'}),
        ('--learn_to_cluster_epochs', {'type': int, 'default': 500}),
        ('--working_dir', {'type': str, 'default': 'lavaDDPM_workspace/berka'}),
        ('--sample_prefix', {'type': str, 'default': ''}),
        ('--matching_batch_size', {'type': int, 'default': 1000}),
        # Boolean switches: absent -> False, present -> True.
        ('--non_unique_matching', {'action': 'store_true'}),
        ('--clustering_naive', {'action': 'store_true'}),
        (
            '--clustering_method',
            {
                'type': str,
                'default': 'kmeans',
                'choices': ['kmeans', 'gmm', 'both', 'variational', 'y'],
            },
        ),
        ('--no_matching', {'action': 'store_true'}),
    )
    for _flag, _kwargs in _cli_spec:
        parser.add_argument(_flag, **_kwargs)

    args = parser.parse_args()
    # The clustering timer starts before the inputs are loaded, so the
    # reported clustering time also includes the data loading below.
    clustering_start_time = time.time()

    # Dataset description: table list and the (parent, child) relation order.
    # Opened with a context manager so the handle is closed right after parsing.
    with open(os.path.join(args.data_dir, 'dataset_meta.json')) as meta_file:
        dataset_meta = json.load(meta_file)

    # Per-experiment output directory; the parsed arguments are stored there.
    save_dir = os.path.join(args.working_dir, args.exp_name)
    os.makedirs(save_dir, exist_ok=True)

    with open(os.path.join(save_dir, 'args'), 'w') as file:
        json.dump(vars(args), file, indent=4)

    relation_order = dataset_meta['relation_order']
    # Clustering walks the relations in the opposite order to training/sampling.
    relation_order_reversed = relation_order[::-1]

    tables = {}

    for table, meta in dataset_meta['tables'].items():
        table_df = pd.read_csv(os.path.join(args.data_dir, f'{table}.csv'))
        with open(os.path.join(args.data_dir, f'{table}_domain.json')) as domain_file:
            table_domain = json.load(domain_file)
        tables[table] = {
            # 'df' is replaced by later stages (cluster columns are added) ...
            'df': table_df,
            'domain': table_domain,
            'children': meta['children'],
            'parents': meta['parents'],
            # ... while these two keep the columns/data exactly as loaded.
            'original_cols': list(table_df.columns),
            'original_df': table_df.copy(),
        }

    # Maps (parent, child) -> group-length probability dicts returned by the
    # clustering routine; consumed again when sampling child tables.
    all_group_lengths_prob_dicts = {}

    # Extra settings that learn_to_cluster receives through `args`. They do
    # not depend on the relation being processed, so they are set once here.
    # `args.working_dir` is deliberately left untouched: overwriting it with a
    # hard-coded path would redirect every later output (final tables, report)
    # away from the directory the user asked for via --working_dir.
    if not args.clustering_naive and args.learn_to_cluster_epochs > 0:
        args.max_beta = 1e-2
        args.min_beta = 1e-5
        args.lambd = 0.7
        # NOTE(review): device is hard-coded; a CUDA device is presumably
        # required by learn_to_cluster -- confirm before running on CPU.
        args.device = 'cuda'
        args.vae_batch_size = 4096
        args.has_y = False
        args.has_test = False
        args.read_ckpt = False

    # Clustering
    for parent, child in relation_order_reversed:
        if parent is None:
            # Relations without a parent have nothing to cluster against.
            continue

        print(f'Clustering {parent} -> {child}')

        # Positional arguments shared by all three clustering back-ends.
        pair_inputs = (
            tables[child]['df'],
            tables[child]['domain'],
            tables[parent]['df'],
            tables[parent]['domain'],
            f'{child}_id',
            f'{parent}_id',
            args.NUM_CLUSTERS,
            args.PARENT_SCALE,
            args.KEY_SCALE,
            parent,
            child,
        )

        if args.clustering_naive:
            clustered = clustering_naive(*pair_inputs, args)
        elif args.learn_to_cluster_epochs > 0:
            clustered = learn_to_cluster(*pair_inputs, args)
        else:
            clustered = pair_clustering_keep_id(
                *pair_inputs,
                handle_size1=False,
                clustering_method=args.clustering_method,
            )

        parent_df_with_cluster, child_df_with_cluster, group_lengths_prob_dicts = clustered

        # Later relations (and training) see the frames returned by clustering.
        tables[parent]['df'] = parent_df_with_cluster
        tables[child]['df'] = child_df_with_cluster
        all_group_lengths_prob_dicts[(parent, child)] = group_lengths_prob_dicts

    clustering_end_time = time.time()

    clustering_time_spent = clustering_end_time - clustering_start_time

    training_start_time = time.time()

    # Training: one child_training call per (parent, child) relation; the
    # returned object is stored under the relation tuple for the sampling stage.
    models = {}
    for parent, child in relation_order:
        print(f'Training {parent} -> {child}')
        clustered_df = tables[child]['df']
        # Every column whose name contains '_id' is excluded from training.
        feature_df = clustered_df.drop(
            columns=[name for name in clustered_df.columns if '_id' in name]
        )
        models[(parent, child)] = child_training(
            feature_df,
            tables[child]['domain'],
            parent,
            child,
            args.diffusion_steps,
            args.classifier_steps,
            args.batch_size,
            args.model_type,
            args.gaussian_loss_type,
            args.num_timesteps,
            args.scheduler,
            args.lr,
            args.weight_decay,
            args.test_num_samples
        )

    # a hardcode fix: tables that appear without a parent get a running-index
    # 'placeholder' column added in place after training.
    for parent, child in relation_order:
        if parent is not None:
            continue
        root_df = tables[child]['df']
        root_df['placeholder'] = list(range(len(root_df)))

    training_end_time = time.time()
    training_time_spent = training_end_time - training_start_time

    synthesizing_start_time = time.time()

    # Synthesize
    # For each relation, sample the child table. Parent-less tables are sampled
    # directly from their diffusion model; all others are sampled conditionally
    # on the labels of an already-synthesized parent table. Results are stored
    # under the (parent, child) tuple as {'df': DataFrame, 'keys': list}.
    synthetic_tables = {}
    for parent, child in relation_order:
        print(f'Generating {parent} -> {child}')
        result = models[(parent, child)]
        df_with_cluster = tables[child]['df']
        df_without_id = get_df_without_id(df_with_cluster)

        # Column names assigned to the generated feature array.
        generated_feature_cols = (
            result['df_info']['num_cols']
            + result['df_info']['cat_cols']
            + [result['df_info']['y_col']]
        )

        if parent is None:
            _, child_generated = sample_from_diffusion(
                df=df_without_id,
                df_info=result['df_info'],
                diffusion=result['diffusion'],
                dataset=result['dataset'],
                label_encoders=result['label_encoders'],
                # Same number of rows as the training table.
                sample_size=len(df_without_id),
                model_params=result['model_params'],
                T_dict=result['T_dict'],
            )
            # Primary keys are simply the row positions of the generated table.
            child_keys = list(range(len(child_generated)))
            generated_final_arr = np.concatenate(
                [
                    np.array(child_keys).reshape(-1, 1),
                    child_generated.to_numpy()
                ],
                axis=1
            )
            generated_final_df = pd.DataFrame(
                generated_final_arr,
                columns=[f'{child}_id'] + generated_feature_cols
            )
            # Restore the column order of the (clustered) training table.
            generated_final_df = generated_final_df[tables[child]['df'].columns]
            synthetic_tables[(parent, child)] = {
                'df': generated_final_df,
                'keys': child_keys
            }
            continue

        # Locate the first synthesized relation whose child is this parent.
        # Fail loudly if there is none: otherwise the variables below would
        # either be undefined or still hold another parent's data from a
        # previous iteration.
        parent_key = next(
            (key for key in synthetic_tables if key[1] == parent),
            None
        )
        if parent_key is None:
            raise ValueError(
                f'Parent table {parent!r} has not been synthesized before '
                f'its child {child!r}; check relation_order.'
            )
        parent_synthetic_df = synthetic_tables[parent_key]['df']
        parent_keys = synthetic_tables[parent_key]['keys']
        parent_result = models[parent_key]

        child_result = result
        # Position, within the parent's column order, of the column that
        # serves as this child's label.
        parent_label_index = parent_result['column_orders'].index(
            child_result['df_info']['y_col']
        )

        parent_synthetic_df_without_id = get_df_without_id(parent_synthetic_df)

        _, child_generated, child_sampled_group_sizes = conditional_sampling_by_group_size(
            df=df_without_id,
            df_info=child_result['df_info'],
            dataset=child_result['dataset'],
            label_encoders=child_result['label_encoders'],
            classifier=child_result['classifier'],
            diffusion=child_result['diffusion'],
            group_labels=parent_synthetic_df_without_id.values[:, parent_label_index].astype(float).astype(int).tolist(),
            group_lengths_prob_dicts=all_group_lengths_prob_dicts[(parent, child)],
            sample_batch_size=40000,
            is_y_cond='none',
            classifier_scale=args.CLASSIFIER_SCALE,
        )

        # Each parent key is repeated once per child row sampled for it.
        child_foreign_keys_arr = np.repeat(
            parent_keys, child_sampled_group_sizes, axis=0
        ).reshape(-1, 1)
        child_primary_keys_arr = np.arange(
            len(child_generated)
        ).reshape(-1, 1)

        child_generated_final_arr = np.concatenate(
            [
                child_primary_keys_arr,
                child_generated.to_numpy(),
                child_foreign_keys_arr
            ],
            axis=1
        )

        child_final_columns = [f'{child}_id'] + generated_feature_cols + [f'{parent}_id']

        child_final_df = pd.DataFrame(
            child_generated_final_arr,
            columns=child_final_columns
        )
        # Keep the training table's column order, skipping columns that were
        # not generated for this relation.
        original_columns = [
            col for col in tables[child]['df'].columns
            if col in child_final_df.columns
        ]
        child_final_df = child_final_df[original_columns]
        synthetic_tables[(parent, child)] = {
            'df': child_final_df,
            'keys': child_primary_keys_arr.flatten().tolist()
        }

    synthesizing_end_time = time.time()
    synthesizing_time_spent = synthesizing_end_time - synthesizing_start_time

    # Dump every per-relation synthetic table before the matching step runs.
    # The file name embeds the (parent, child) tuple as formatted by the f-string.
    before_matching_dir = os.path.join(save_dir, 'before_matching')
    os.makedirs(before_matching_dir, exist_ok=True)
    for key, val in synthetic_tables.items():
        val['df'].to_csv(
            os.path.join(before_matching_dir, f'{key}_synthetic.csv'),
            index=False
        )

    matching_start_time = time.time()

    # Matching: produce exactly one final frame per table. A table with more
    # than one parent is handed to handle_multi_parent; any other table uses
    # the frame synthesized for its relation as-is.
    final_tables = {}
    for parent, child in relation_order:
        if child in final_tables:
            # Already resolved through an earlier relation.
            continue
        has_multiple_parents = len(tables[child]['parents']) > 1
        if not has_multiple_parents:
            final_tables[child] = synthetic_tables[(parent, child)]['df']
            continue
        final_tables[child] = handle_multi_parent(
            child,
            tables[child]['parents'],
            synthetic_tables,
            args.NUM_MATCHING_CLUSTERS,
            unique_matching=not args.non_unique_matching,
            batch_size=args.matching_batch_size,
            no_matching=args.no_matching
        )

    matching_end_time = time.time()
    matching_time_spent = matching_end_time - matching_start_time

    # Reduce every final table to the columns of the originally loaded CSV
    # (in that order), dropping anything added by earlier stages.
    cleaned_tables = {
        name: frame[tables[name]['original_cols']]
        for name, frame in final_tables.items()
    }

    # Write each cleaned table to <working_dir>/<table>/<sample_prefix>_final/.
    # NOTE: this rebinds `save_dir`, which until here pointed at the
    # experiment directory.
    final_dir_name = f'{args.sample_prefix}_final'
    for key, val in cleaned_tables.items():
        save_dir = os.path.join(args.working_dir, key, final_dir_name)
        os.makedirs(save_dir, exist_ok=True)
        target_csv = os.path.join(save_dir, f'{key}_synthetic.csv')
        val.to_csv(target_csv, index=False)

    # Eval
    # Build SDV multi-table metadata from the originally loaded tables.
    metadata = MultiTableMetadata()
    # Domain type (from <table>_domain.json) -> SDV sdtype.
    sdtype_by_domain_type = {
        'discrete': 'categorical',
        'continuous': 'numerical',
    }
    for table_name, val in tables.items():
        df = val['original_df']
        metadata.detect_table_from_dataframe(
            table_name,
            df
        )
        # Every column whose name contains '_id' is declared as an id column.
        id_cols = [col for col in df.columns if '_id' in col]
        for id_col in id_cols:
            metadata.update_column(
                table_name=table_name,
                column_name=id_col,
                sdtype='id'
            )
        # Override the detected sdtype with the one implied by the domain
        # file, for the columns that exist in the original table.
        domain = tables[table_name]['domain']
        for col, dom in domain.items():
            if col not in df.columns:
                continue
            sdtype = sdtype_by_domain_type.get(dom['type'])
            if sdtype is None:
                raise ValueError(f'Unknown domain type: {dom["type"]}')
            metadata.update_column(
                table_name=table_name,
                column_name=col,
                sdtype=sdtype,
            )
        metadata.set_primary_key(
            table_name=table_name,
            column_name=f'{table_name}_id'
        )

    # The child's foreign-key column carries the same name as the parent's
    # primary key ('<parent>_id').
    for parent, child in relation_order:
        if parent is not None:
            metadata.add_relationship(
                parent_table_name=parent,
                child_table_name=child,
                parent_primary_key=f'{parent}_id',
                child_foreign_key=f'{parent}_id'
            )

    # Read the final synthetic tables back from disk. The reloaded frames are
    # not referenced again in the visible part of this script, but the loop
    # also rebinds `save_dir`, which decides where the metadata pickle goes.
    synthetic_tables = {}
    for table, meta in dataset_meta['tables'].items():
        save_dir = os.path.join(
            args.working_dir,
            table,
            f'{args.sample_prefix}_final'
        )
        synthetic_tables[table] = pd.read_csv(os.path.join(save_dir, f'{table}_synthetic.csv'))

    gen_multi_report(
        args.data_dir,
        args.working_dir,
        'lava'
    )

    # NOTE(review): `save_dir` is the directory of the last table iterated
    # above, not the experiment directory -- confirm that this is the intended
    # location for metadata.pkl.
    # The context manager guarantees the pickle is flushed and closed.
    with open(os.path.join(save_dir, 'metadata.pkl'), 'wb') as metadata_file:
        pickle.dump(metadata, metadata_file)

    print('Time spent: ')
    print('Clustering: ', clustering_time_spent)
    print('Training: ', training_time_spent)
    print('Synthesizing: ', synthesizing_time_spent)
    print('Matching: ', matching_time_spent)
    print('Total: ', clustering_time_spent + training_time_spent + synthesizing_time_spent + matching_time_spent)
