import sys
import os
import copy
import json
import logging
import datetime

# This script takes eight positional arguments; fail with a readable usage
# message instead of an IndexError traceback when any of them are missing.
if len(sys.argv) < 9:
    sys.exit('usage: {} DATASET TRIPLET ITERATIONS KGE_MODEL '
             'RULE_THRESHOLD_RULE RULE_THRESHOLD_TRIP RULE_UPDATE JOB_ID'
             .format(sys.argv[0]))

dataset = sys.argv[1]  # dataset directory; its train.txt seeds the run
triplet = sys.argv[2]  # file inside the run directory copied to train_rule.txt each round
path = './record'      # root directory for per-run output

iterations = int(sys.argv[3])  # number of KGE + rule-mining rounds

kge_model = sys.argv[4]  # cmd_kge supports 'TransE', 'DistMult' and 'ComplEx'
# Default KGE hyperparameters, used unless a (model, dataset) override applies.
kge_batch = 1024
kge_neg = 256
kge_dim = 100
kge_gamma = 24
kge_alpha = 1
kge_lr = 0.001
kge_iters = 10000
kge_tbatch = 16
kge_reg = 0.0

# Per-(model, dataset) KGE hyperparameters that replace the defaults above.
# Tuple order: (batch, neg, dim, gamma, alpha, lr, iters, tbatch, reg).
# reg=None keeps the default kge_reg; the TransE settings never set it.
_TRANSE_HPARAMS = {
    'FB15k': (1024, 256, 1000, 24.0, 1.0, 0.0001, 150000, 16, None),
    'FB15k-237': (1024, 256, 1000, 9.0, 1.0, 0.00005, 100000, 16, None),
    'wn18': (512, 1024, 500, 12.0, 0.5, 0.0001, 80000, 8, None),
    'wn18rr': (512, 1024, 500, 6.0, 0.5, 0.00005, 80000, 8, None),
}
# DistMult and ComplEx use exactly the same settings.
_DISTMULT_COMPLEX_HPARAMS = {
    'FB15k': (1024, 256, 1000, 500.0, 1.0, 0.001, 150000, 16, 0.000002),
    'FB15k-237': (1024, 256, 1000, 200.0, 1.0, 0.001, 100000, 16, 0.00001),
    'wn18': (512, 1024, 500, 200.0, 1.0, 0.001, 80000, 8, 0.00001),
    'wn18rr': (512, 1024, 500, 200.0, 1.0, 0.002, 80000, 8, 0.000005),
}
_KGE_HPARAMS = {
    'TransE': _TRANSE_HPARAMS,
    'DistMult': _DISTMULT_COMPLEX_HPARAMS,
    'ComplEx': _DISTMULT_COMPLEX_HPARAMS,
}

# normpath drops a trailing separator, so 'data/FB15k/' still resolves to
# 'FB15k' instead of '' (which would silently keep the defaults).
_dataset_name = os.path.basename(os.path.normpath(dataset))
_hparams = _KGE_HPARAMS.get(kge_model, {}).get(_dataset_name)
if _hparams is not None:
    (kge_batch, kge_neg, kge_dim, kge_gamma, kge_alpha,
     kge_lr, kge_iters, kge_tbatch, _reg) = _hparams
    if _reg is not None:
        kge_reg = _reg

# Rule-miner settings. The two thresholds stay as the raw command-line strings
# because they are only interpolated into commands and the saved config.
rule_threshold_rule, rule_threshold_trip = sys.argv[5], sys.argv[6]
rule_iters, rule_lr, rule_topk = 1000, 0.0001, 100
rule_update = int(sys.argv[7])  # 1 => rule miner also gets -score each round

# Last argument of ./rule/evaluate.py (presumably weights rule vs. KGE
# predictions -- confirm in evaluate.py).
weight = 2.0

# Identifier recorded in cmd.txt for this run.
job_id = sys.argv[8]

# ------------------------------------------

def ensure_dir(d):
    """Create directory *d*, including missing parents, if it does not exist.

    ``exist_ok=True`` replaces the separate ``os.path.exists`` check. That
    check-then-create sequence could crash with FileExistsError if *d*
    appeared between the two calls.

    Raises:
        FileExistsError: if *d* exists but is not a directory.
    """
    os.makedirs(d, exist_ok=True)

def cmd_kge(work_path, model):
    """Return the ``run.sh`` shell command that trains KGE *model*.

    Every model gets the same positional arguments: *model*, the
    module-level ``dataset`` and ``kge_*`` settings, *work_path* and
    ``rule_topk``. DistMult and ComplEx also pass ``kge_reg`` via ``-r``
    (presumably a regularization weight -- confirm in run.sh). ComplEx
    additionally passes ``-de -dr``.

    Raises:
        ValueError: if *model* is not 'TransE', 'DistMult' or 'ComplEx',
            instead of returning None and failing later in ``os.system``.
    """
    # Decide the model-specific suffix first so an unknown model is rejected
    # before any module-level setting is read.
    if model == 'TransE':
        suffix = ''
    elif model == 'DistMult':
        suffix = ' -r {}'.format(kge_reg)
    elif model == 'ComplEx':
        suffix = ' -de -dr -r {}'.format(kge_reg)
    else:
        raise ValueError('Unsupported KGE model: {!r}'.format(model))

    base = 'bash run.sh train {} {} 0 0 {} {} {} {} {} {} {} {} {} {}'.format(
        model, dataset, kge_batch, kge_neg, kge_dim, kge_gamma, kge_alpha,
        kge_lr, kge_iters, kge_tbatch, work_path, rule_topk)
    return base + suffix

def cmd_rule(work_path, with_score):
    """Return the shell command that runs ``./rule/rule`` in *work_path*.

    The command passes ``train_rule.txt`` as ``-triplet`` and passes
    ``pred_rule.txt``, ``rule.txt`` and ``candidate.txt`` as the ``-out*``
    targets, all inside *work_path*. It also passes the module-level
    ``rule_threshold_rule``, ``rule_iters`` and ``rule_lr``. When
    *with_score* is true, ``-score <work_path>/score.txt`` is appended.
    """
    command = ('./rule/rule -triplet {0}/train_rule.txt -outtrip {0}/pred_rule.txt '
               '-outrule {0}/rule.txt -outcand {0}/candidate.txt '
               '-threshold {1} -iterations {2} -lr {3}').format(
                   work_path, rule_threshold_rule, rule_iters, rule_lr)
    # Plain truthiness test (PEP 8: don't compare booleans to True with ==).
    if with_score:
        command += ' -score {}/score.txt'.format(work_path)
    return command

def save_cmd(save_path):
    """Write the run configuration to *save_path*, one ``name: value`` per line.

    Values come from the module-level settings. An existing file is
    overwritten.
    """
    with open(save_path, 'w') as out:
        entries = (
            ('job_id', job_id),
            ('dataset', dataset),
            ('triplet', triplet),
            ('iterations', iterations),
            ('kge_model', kge_model),
            ('kge_batch', kge_batch),
            ('kge_neg', kge_neg),
            ('kge_dim', kge_dim),
            ('kge_gamma', kge_gamma),
            ('kge_alpha', kge_alpha),
            ('kge_lr', kge_lr),
            ('kge_iters', kge_iters),
            ('kge_tbatch', kge_tbatch),
            ('kge_reg', kge_reg),
            ('rule_threshold_rule', rule_threshold_rule),
            ('rule_threshold_trip', rule_threshold_trip),
            ('rule_iters', rule_iters),
            ('rule_lr', rule_lr),
            ('rule_update', rule_update),
            ('weight', weight),
        )
        out.writelines('{}: {}\n'.format(key, val) for key, val in entries)

# Each run gets its own directory under ./record, named after the start time
# with the space between date and time replaced by '_'. The configuration is
# saved there before any work starts.
time = '_'.join(str(datetime.datetime.now()).split(' '))
path = '{}/{}'.format(path, time)
ensure_dir(path)
save_cmd(path + '/cmd.txt')

# ------------------------------------------

def _run(command):
    """Run *command* through the shell and abort the script if it fails.

    Each step consumes files produced by the previous one, so continuing
    after a failure would silently work on stale or missing data.
    ``os.system`` returns the process wait status; any nonzero value is
    treated as failure.
    """
    status = os.system(command)
    if status != 0:
        sys.exit('Command failed (status {}): {}'.format(status, command))


# Seed the run directory with the dataset's training triplets.
_run('cp {}/train.txt {}/train.txt'.format(dataset, path))
_run('cp {}/train.txt {}/train_rule.txt'.format(dataset, path))
_run('cp {}/train.txt {}/train_augment.txt'.format(dataset, path))

# Initial rule-mining pass without scores; keep a copy of its candidates.
_run(cmd_rule(path, False))
_run('cp {}/candidate.txt {}/candidate_init.txt'.format(path, path))

for k in range(iterations):
    # Each round works in its own numbered sub-directory of the run directory.
    work_path = path + '/' + str(k)
    ensure_dir(work_path)

    # Stage this round's inputs from the run directory.
    _run('cp {}/train_augment.txt {}/train_kge.txt'.format(path, work_path))
    _run('cp {}/candidate.txt {}/candidate.txt'.format(path, work_path))
    _run('cp {}/{} {}/train_rule.txt'.format(path, triplet, work_path))

    _run(cmd_kge(work_path, kge_model))

    # rule_update == 1 additionally passes <work_path>/score.txt to the miner.
    _run(cmd_rule(work_path, rule_update == 1))
    _run('python3 ./rule/add_triplet.py {}/pred_rule.txt {}/{} {}/train_augment.txt {}'.format(work_path, path, triplet, work_path, rule_threshold_trip))
    _run('python3 ./rule/evaluate.py {}/pred_rule.txt {}/pred_kge.txt {}/result_plus.txt {}'.format(work_path, work_path, work_path, weight))

    # Promote this round's outputs so the next round starts from them.
    _run('cp {}/train_augment.txt {}/train_augment.txt'.format(work_path, path))
    _run('cp {}/candidate.txt {}/candidate.txt'.format(work_path, path))

