import pandas as pd
import argparse
from sdv.metadata import MultiTableMetadata
from sdv.evaluation.multi_table import evaluate_quality
import json
import os


def load_json(path):
    """Read and parse the JSON file at ``path``.

    The file is opened inside a ``with`` block so the handle is closed as
    soon as parsing finishes instead of being left to the garbage collector.

    Args:
        path: Path of the JSON file to read.

    Returns:
        The decoded JSON document.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    with open(path, encoding='utf-8') as f:
        return json.load(f)


if __name__ == '__main__':
    # Compare the synthetic tables stored in <data_dir>/save/<exp_name> with
    # the real tables in <data_dir> and save the resulting quality report.
    parser = argparse.ArgumentParser(description='Matching pipeline')
    parser.add_argument('--data_dir', type=str, default='complex_data/instacart/preprocessed')
    # The next three options are parsed but never read below; they are kept so
    # existing command lines that pass them keep working.
    parser.add_argument('--NUM_MATCHING_CLUSTERS', type=int, default=1000)
    parser.add_argument('--exp_name', type=str, default='instacart_exp')
    parser.add_argument('--batch_size', type=int, default=1000)
    parser.add_argument('--unique_matching', action='store_true')

    args = parser.parse_args()

    dataset_meta = load_json(os.path.join(args.data_dir, 'dataset_meta.json'))
    relation_order = dataset_meta['relation_order']
    save_dir = os.path.join(args.data_dir, 'save', args.exp_name)

    # Load, per table: the real data, its column-domain description, and the
    # synthetic counterpart. Each file is read exactly once.
    real_data = {}
    synthetic_data = {}
    domains = {}
    for table_name in dataset_meta['tables']:
        real_data[table_name] = pd.read_csv(
            os.path.join(args.data_dir, f'{table_name}.csv')
        )
        domains[table_name] = load_json(
            os.path.join(args.data_dir, f'{table_name}_domain.json')
        )
        synthetic_data[table_name] = pd.read_csv(
            os.path.join(save_dir, f'{table_name}_synthetic.csv')
        )

    # Eval
    metadata = MultiTableMetadata()
    for table_name, df in synthetic_data.items():
        # Column types are first auto-detected from the synthetic dataframe,
        # then overridden below.
        metadata.detect_table_from_dataframe(
            table_name,
            df
        )

        # Any column whose name contains '_id' is treated as an identifier.
        for id_col in (col for col in df.columns if '_id' in col):
            metadata.update_column(
                table_name=table_name,
                column_name=id_col,
                sdtype='id'
            )

        # Domain entries are applied after the id pass, so a column listed in
        # the domain file ends up categorical/numerical even if its name
        # contains '_id'. Every non-categorical domain type maps to numerical.
        for col, dom in domains[table_name].items():
            if col not in df.columns:
                continue
            sdtype = 'categorical' if dom['type'] == 'categorical' else 'numerical'
            metadata.update_column(
                table_name=table_name,
                column_name=col,
                sdtype=sdtype,
            )

        # Primary keys follow the '<table>_id' naming convention.
        metadata.set_primary_key(
            table_name=table_name,
            column_name=f'{table_name}_id'
        )

    # relation_order holds (parent, child) pairs; a parent of None carries no
    # relationship and is skipped. The foreign key in the child uses the same
    # column name as the parent's primary key.
    for parent, child in relation_order:
        if parent is not None:
            metadata.add_relationship(
                parent_table_name=parent,
                child_table_name=child,
                parent_primary_key=f'{parent}_id',
                child_foreign_key=f'{parent}_id'
            )

    quality = evaluate_quality(
        real_data,
        synthetic_data,
        metadata
    )

    quality.save(os.path.join(save_dir, 'quality.pkl'))
