import PrivMRF
import PrivMRF.utils.tools as tools
from PrivMRF.domain import Domain
import numpy as np
import pandas as pd
import json

def get_domain_by_attrs(dom_dict, columns):
    """Build a ``Domain`` whose attributes are renamed to column positions.

    Args:
        dom_dict: mapping from attribute name to that attribute's domain
            entry (passed through unchanged to ``Domain``).
        columns: indexable sequence of attribute names; ``columns[i]``
            becomes attribute ``i`` of the returned domain.

    Returns:
        A ``Domain`` over the attributes ``0 .. len(columns) - 1``.

    Raises:
        KeyError: if a name in ``columns`` is not a key of ``dom_dict``.
    """
    # Normalize to a plain dict first so the positional lookups below
    # behave like ordinary dict lookups whatever mapping type was passed in.
    plain = {}
    for attr in dom_dict:
        plain[attr] = dom_dict[attr]

    by_position = {}
    for position in range(len(columns)):
        by_position[position] = plain[columns[position]]

    attr_ids = list(range(len(by_position)))
    return Domain(by_position, attr_ids)

if __name__ == '__main__':
    # Train PrivMRF on the discretized fb-comments dataset and write a
    # synthetic copy of it to OUTPUT_PATH.
    #
    # Input requirements:
    #   * the data should be integer-valued;
    #   * the domain of each attribute should be [0, 1, ..., max_value-1];
    #   * attribute names should be 0, ..., column_num-1.
    #
    # Alternative input (adult dataset):
    #   data, _ = tools.read_csv('./preprocess/adult.csv')
    #   data = np.array(data, dtype=int)
    #   json_domain = tools.read_json_domain('./preprocess/adult.json')
    #   domain = Domain(json_domain, list(range(data.shape[1])))

    DATA_DIR = '/home/w3pang/guided_tab_ddpm/PrivMRF_old/data'
    DATA_PATH = f'{DATA_DIR}/fb-comments-discrete.csv'
    DOMAIN_PATH = f'{DATA_DIR}/fb-comments-domain.json'
    HIERARCHY_PATH = f'{DATA_DIR}/fb-comments-hierarchy.json'
    OUTPUT_PATH = './fb-comments-out.csv'

    # Single source of truth: the same value goes into config and run().
    EPSILON = 999

    data = np.array(pd.read_csv(DATA_PATH), dtype=int)

    # Context manager so the file handle is closed as soon as it is parsed.
    with open(DOMAIN_PATH, encoding='utf-8') as fh:
        str_domain = json.load(fh)
    # JSON object keys are always strings; convert them back to the integer
    # column indices that get_domain_by_attrs() looks up below.
    int_domain = {int(key): val.copy() for key, val in str_domain.items()}
    domain = get_domain_by_attrs(int_domain, list(range(data.shape[1])))
    attr_hierarchy = PrivMRF.read_hierarchy(HIERARCHY_PATH)

    # Hyperparameters / other settings for PrivMRF.
    # NOTE(review): 'data_name' and 'exp_name' still say 'adult' although
    # this run uses fb-comments, and PrivMRF.run() below is given
    # exp_name='exp'. Left unchanged because how PrivMRF consumes these
    # labels is not visible here -- confirm and rename if they only label
    # logs/outputs.
    config = {
        'data_name':        'adult',
        'epsilon':          EPSILON,
        'exp_name':         'adult',

        'budget':           999,

        'print_interval':   200,
        'max_measure_attr_num': 5,
        'sensitivity':      6,
        'enable_attribute_hierarchy':   True,

        'save_model':       False,
    }

    # Train a PrivMRF with delta=1e-5. For another delta, compute the
    # privacy budget with cal_privacy_budget() in ./PrivMRF/utils/tools.py
    # and hard code it in privacy_budget() in the same file.
    model = PrivMRF.run(data, domain, attr_hierarchy=attr_hierarchy,
                        exp_name='exp', epsilon=EPSILON, p_config=config)

    # Generate synthetic data; OUTPUT_PATH is handed to synthetic_data().
    model.synthetic_data(OUTPUT_PATH)