#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Jun 11 02:10:47 2020

@author: sdghosh
"""
import numpy as np
import os

from util_py.file_ops import ensure_dir_exists
from util_py.configure_logging import configure_logging
from util_py.time import get_currtime_stamp

from util_setup import get_std_gauss_4mix, get_theta_set, sample_rotation_mtx
from robust_profile import RobustProfile


def get_output_directory_name(n_emprset_samp, n_optims_per_empr, n_samp, m_samp,
                              grain_level, dimen, n_theta, maxstep, pname):
    """Build the results directory path for one experiment setup.

    Every argument is embedded, in order, into a path under ``../results/``.
    The path is handed to ``ensure_dir_exists`` (presumably creating the
    directory when it is missing -- confirm in util_py.file_ops), announced
    on stdout, and returned with its trailing slash.
    """
    # NOTE: the name carries no timestamp (a get_currtime_stamp() prefix was
    # disabled), so identical parameters always map to the same directory.
    name_template = '../results/E[{}]-R[{}]-N[{}]-M[{}]-lvl[{}]-dim[{}]-J[{}]-maxstep[{}]-Pstar[{}]/'
    name_fields = (
        n_emprset_samp, n_optims_per_empr, n_samp, m_samp,
        grain_level, dimen, n_theta, maxstep, pname,
    )
    outdir = name_template.format(*name_fields)

    ensure_dir_exists(outdir)
    print("Writing to output dir '{}'".format(outdir))
    return outdir

def write_generated_empirical_set(p_n_sampler, n_emprset_samp, outdir, n_samp=None):
    """Draw empirical sample sets from a sampler and save each one to disk.

    Each set is obtained with ``p_n_sampler.next_sample(n_samp)`` and written
    with ``np.savez_compressed`` to ``<outdir>/empr_set_NNN.npz``, where NNN
    is the zero-padded index of the set.

    Parameters
    ----------
    p_n_sampler : object exposing ``next_sample(size)``.
    n_emprset_samp : number of empirical sets to draw.
    outdir : directory the ``.npz`` files are written into.
    n_samp : size argument passed to ``next_sample``. When omitted, the
        module-level ``n_samp`` is used, which is how this function obtained
        the value before the parameter existed.

    Returns
    -------
    list with the drawn sets, in the order they were sampled.

    Raises
    ------
    NameError if ``n_samp`` is neither passed nor defined at module level.
    """
    if n_samp is None:
        # Legacy fallback: the sample size used to be read implicitly from a
        # global that only exists when this file runs as a script.
        try:
            n_samp = globals()["n_samp"]
        except KeyError:
            raise NameError(
                "n_samp was not passed and no module-level 'n_samp' is defined"
            ) from None

    empr_sets = []
    for emprn in range(n_emprset_samp):
        empr_set = p_n_sampler.next_sample(n_samp)
        # The array is passed positionally to savez_compressed.
        out_path = os.path.join(outdir, 'empr_set_{:03d}.npz'.format(emprn))
        np.savez_compressed(out_path, empr_set)
        empr_sets.append(empr_set)

    return empr_sets
    

def create_Rn_sampleset(thetas, p_star_sampler, empr_sets,
                        grain_level, m_samp, n_samp, outdir,
                        start_at_emprset_ndx=0):
    """Run the R_n estimation repeatedly for each empirical set.

    One ``RobustProfile`` is built per empirical set whose index is at least
    ``start_at_emprset_ndx``. Then, for every repetition and every such set,
    ``calculate_one_sample`` is called on the matching profile and its return
    value is printed next to the path name ``sgdpath-e[EE]-r[RR]`` (EE is the
    empirical-set index, RR the repetition index).

    Parameters
    ----------
    thetas, grain_level, m_samp : forwarded unchanged to ``RobustProfile``.
    p_star_sampler : sampler forwarded to ``RobustProfile``.
    empr_sets : sequence of empirical sets, one profile is built per entry.
    n_samp : accepted for interface compatibility; not used in this function.
    outdir : forwarded to ``calculate_one_sample``.
    start_at_emprset_ndx : index of the first empirical set to process;
        earlier sets are skipped.

    Notes
    -----
    ``n_proc``, ``n_optims_per_empr``, ``n_max_iters``, ``maxstep`` and
    ``plot_stride`` are read from module-level globals, which are defined by
    the script's ``__main__`` section.
    """
    n_emprset_samp = len(empr_sets)

    # Build one profile per empirical set. Each profile must receive ITS OWN
    # empirical set: the previous code passed empr_sets[-1], so every profile
    # was built from the last set even though results are labelled per set.
    rob_profs = [
        RobustProfile(thetas, empr_sets[emprn],
                      grain_level, p_star_sampler, m_samp, n_proc)
        for emprn in range(n_emprset_samp)
        if emprn >= start_at_emprset_ndx
    ]

    # Repeat the search n_optims_per_empr times for every selected set.
    for n in range(n_optims_per_empr):
        for emprn in range(start_at_emprset_ndx, n_emprset_samp):
            pathname = 'sgdpath-e[{:02d}]-r[{:02d}]'.format(emprn, n)
            # rob_profs only holds the non-skipped sets, hence the offset.
            endx = emprn - start_at_emprset_ndx
            val = rob_profs[endx].calculate_one_sample(
                n_max_iters, outdir, maxstep, pathname, plot_stride)
            print("{} retval {:20.16f}".format(pathname, val))
    


if __name__ == "__main__":
    
    # Script entry point: sets up the experiment parameters, samples and saves
    # the empirical sets for P_star and for a rotated alternative P_alt, then
    # runs the R_n estimation on both collections of sets.

    # control the level of info coming out of various parts of the code.
    # you will see the names used as keys here appear in the output lines.
    # Use that as the guide to turn off more stuff from that part of the code.
    configure_logging({
            'matplotlib' : {'level':'WARN'},
            'wdro.func_approx' : {'level': 'INFO'},
            'wdro.func_c_approx' : {'level': 'WARN'},
            'wdro.wavelet' : {'level': 'WARN'},
            'wdro.rob_profile' : {'level':'INFO'},
            'optim': {'level':'WARN'},
            'optim.line_search' : {'level':'ERROR'},
            'optim.iterdata':{'level':'WARN'},
            'optim.stopctrn':{'level':'WARN'},          
            'util.bisection':{'level':'ERROR'}
            })

    # parameters that define the setup of our experiments
    # NOTE: n_samp, n_optims_per_empr, n_max_iters, plot_stride, maxstep and
    # n_proc below are also read as module-level globals by the functions
    # defined above, so renaming them here breaks those functions.
    dimen, n_theta=20, 3 
    n_samp, m_samp= 25, 50
    n_emprset_samp, n_optims_per_empr = 50, 5

    # parameters that define the wavelet approximation and the algorithm
    # that searches for the R_n estimates.

    # level of granularity in wavelet basis
    grain_level=3
    # max iterations of sgd
    n_max_iters= 400
    # how frequently should we plot the theta projections? (neg - never)
    plot_stride=  -100000
    # max steplength in the outer sgd iterations
    maxstep =.050
    # num parallel processes to spawn off
    n_proc= 3
    # num restarts in solving the inner f_c problem
    # NOTE(review): n_restarts is not referenced anywhere else in the visible
    # code -- confirm whether it should be passed on to the solver.
    n_restarts= 180
    
    # lock down the seed
    np.random.seed(1548663359) # 1549630359)

    # sample the thetas. The seed being locked down ensures we can find this 
    # theta again.
    thetas = get_theta_set(n_theta, dimen)
    # rescale every theta by the square root of its dot product with itself
    # (unit Euclidean norm, assuming each thetas[nt] is a 1-D vector)
    for nt in range(n_theta):
        thetas[nt] /= np.sqrt(thetas[nt].dot(thetas[nt]))

    # this is the base distn that we are trying to distinguish
    p_star= get_std_gauss_4mix(dimen, 1.)

    pname_star='gaussmix'
    outdir_star = get_output_directory_name(n_emprset_samp, n_optims_per_empr, n_samp, m_samp, 
            grain_level, dimen, n_theta, maxstep, pname_star)
    # save this off to the output dir in case we need it later
    np.savez_compressed(os.path.join(outdir_star, 'thetas.npz'), thetas)

    # empirical sets drawn from P_star (also written to outdir_star)
    empr_sets_star = write_generated_empirical_set(p_star, n_emprset_samp, outdir_star)
    
    # P_alt = pstar + rotation (no shift is applied: shift is None)
    pname_alt='gaussmix+rotd'
    outdir_alt = get_output_directory_name(n_emprset_samp, n_optims_per_empr, n_samp, m_samp, 
            grain_level, dimen, n_theta, maxstep, pname_alt)

    rotm , shift = sample_rotation_mtx(dimen), None
    # keep the sampled rotation matrix alongside the results
    np.savez_compressed(os.path.join(outdir_alt, 'rotm.npz'), rotm)
    p_alt = get_std_gauss_4mix(dimen,1.,rotm, shift)
        
    # save this off to the output dir in case we need it later
    np.savez_compressed(os.path.join(outdir_alt, 'thetas.npz'), thetas)

    # empirical sets drawn from P_alt (also written to outdir_alt)
    empr_sets_alt = write_generated_empirical_set(p_alt, n_emprset_samp, outdir_alt)
    
    # now we do the experiments!
    # first we evaluate the distribution of R_n where P_n comes from P_star
    start_at_emprset_ndx=0
    create_Rn_sampleset(thetas, p_star, empr_sets_star,
                        grain_level, m_samp, n_samp, outdir_star,
                        start_at_emprset_ndx)

    # next we evaluate the distribution of R_n where P_n comes from P_alt;
    # p_star is still the sampler handed to create_Rn_sampleset.
    start_at_emprset_ndx=0
    create_Rn_sampleset(thetas, p_star, empr_sets_alt,
                        grain_level, m_samp, n_samp, outdir_alt,
                        start_at_emprset_ndx)


    # now use the two plot figs python files to plot the two figures.