# -*- coding: utf-8 -*-
"""
Created on Thu May  9 16:19:53 2024

@author: shara
"""

import numpy as np
from trellis import Trellis
from scipy import stats
import statsmodels.api as sm
from scipy.stats import special_ortho_group
from scipy.stats import bootstrap
import matplotlib.pyplot as plt
import pickle

def get_codebook_filename( rate, num_states ):
    """Return the codebook file name for a rate / trellis-size pair.

    The name has the form '<rate>_<num_states>_codebook.npy',
    e.g. get_codebook_filename(1, 256) -> '1_256_codebook.npy'.
    """
    parts = (str(rate), str(num_states), 'codebook.npy')
    return '_'.join(parts)

def data_store_filename( source_power, rate ):
    """Return the pickle file name used to store AWGN simulation results.

    The name has the form 'AWGN_Sim_<source_power>_<rate>.pickle',
    e.g. data_store_filename(1, 2) -> 'AWGN_Sim_1_2.pickle'.
    """
    stem = '_'.join(['AWGN_Sim', str(source_power), str(rate)])
    return stem + '.pickle'
def snr_to_mse(snr, source_power=1):
    """Convert an SNR in decibels to the matching noise power (distortion).

    Computes source_power / 10**(snr / 10).

    Args:
        snr: signal-to-noise ratio in dB.
        source_power: power (variance) of the source; defaults to 1.
    """
    linear_snr = 10 ** (snr / 10)
    return source_power / linear_snr

def MI( source_power, noise_power ):
    """Return 0.5 * log2(1 + source_power / noise_power), in bits.

    This is the Gaussian-channel mutual-information expression evaluated
    at the given signal and noise powers.
    """
    power_ratio = source_power / noise_power
    return 0.5 * np.log2(1 + power_ratio)

# ---------------------------------------------------------------------------
# Simulation configuration
# ---------------------------------------------------------------------------
operating_rates = [1, 2, 3]     # rates whose codebooks are loaded below

blocklength = 1000              # samples per simulated source block
source_power = 1                # variance of the Gaussian source

num_sims = 100                  # number of independent blocks to simulate
trellis_num_states = 256        # trellis size passed to Trellis(...)

# Per-simulation records, filled by the simulation loop.
source_sequences, reconstruction_sequences = [], []

target_mses = []                # NOTE(review): never populated in this file
achieved_mses_before_rotation, achieved_mses_after_rotation = [], []
scaling_factors = []

# Build the encoder and load one codebook file per operating rate.
# NOTE(review): m=1000 matches blocklength numerically; confirm whether it
# is meant to track blocklength.
trellis_encoder = Trellis(trellis_num_states, m=1000)
for op_rate in operating_rates:
    codebook_file = get_codebook_filename(op_rate, trellis_num_states)
    trellis_encoder.get_codebook_from_file(op_rate, codebook_file)

# Label for the simulated rate; it is used only in the output file name and
# the plot title. The encoder itself is steered by rate_pmf below.
rate = 1
# All probability mass on the first entry -- presumably rate 1, the first of
# operating_rates (TODO confirm against Trellis). The trailing "* 1.0" is a
# no-op because the array is already float.
trellis_encoder.rate_pmf = np.array( [1.0,0,0] ) * 1.0
# NOTE: results depend on the order of the global-RNG draws below
# (np.random.normal, then special_ortho_group.rvs); keep statement order.
for _ in range( num_sims ):
    # i.i.d. zero-mean Gaussian source block with variance source_power.
    source_sequence = np.random.normal( scale = np.sqrt(source_power), size = blocklength )
    # Random blocklength x blocklength rotation (a fresh one every
    # simulation -- the dominant cost of this loop). The source is scrambled
    # with it before encoding.
    random_rotation_matrix = special_ortho_group.rvs( blocklength )
    scrambled_sequence = np.dot( random_rotation_matrix, source_sequence )
    # NOTE(review): assumes evaluate() returns (_, total squared error,
    # reconstruction) for the scrambled block -- confirm against Trellis.
    _, achieved_se, scrambled_reconstruction = trellis_encoder.evaluate( scrambled_sequence )


    # Undo the scrambling: for a rotation matrix the transpose is the inverse.
    reconstruction = np.dot( random_rotation_matrix.T, scrambled_reconstruction )
    # Scale up the reconstruction.
    # Gain = ||source||^2 / <reconstruction, source>; kept in scaling_factors
    # for inspection only.
    scaling_factor = np.dot( source_sequence, source_sequence )/np.dot( reconstruction, source_sequence )
    scaling_factors.append(scaling_factor)
    # NOTE(review): the per-block gain just computed is discarded and replaced
    # by a hard-coded constant -- presumably calibrated from an earlier run.
    # Confirm its origin and consider giving it a named constant.
    scaling_factor = 1.3900101748665457
    reconstruction = reconstruction * scaling_factor
    source_sequences.append( source_sequence )
    reconstruction_sequences.append( reconstruction )
    # Per-sample error as reported for the scrambled (pre-unrotation,
    # pre-scaling) reconstruction.
    achieved_mses_before_rotation.append( achieved_se/blocklength )

# Squared-error total of each source block against its final (unrotated,
# rescaled) reconstruction.
# NOTE: despite the list's name, these are sums over the block, not per-sample
# means; the gamma Q-Q plot at the end of the file relies on that.
achieved_mses_after_rotation.extend(
    np.sum((source - recon) ** 2)
    for source, recon in zip(source_sequences, reconstruction_sequences)
)

# Mean per-block sample variance of the reconstructions (np.cov on a 1-D array
# returns the ddof=1 variance).
output_power = np.mean( [ np.cov(x) for x in reconstruction_sequences ] )
# Average per-sample squared error across all simulated blocks.
noise_power = np.mean( achieved_mses_after_rotation )/blocklength
# Set to True to pickle the simulation results to disk.
save = False
if save:
    data_store = {'Noise Power': noise_power,
                  'Source Sequences':source_sequences,
                  'Trellis Num States': trellis_num_states,
                  'Number of Simulations': num_sims,
                  'Blocklength': blocklength,
                  'Reconstruction Sequences': reconstruction_sequences,
                  'Pre Rotation MSE Levels': achieved_mses_before_rotation,
                  'Post Rotation MSE Levels': achieved_mses_after_rotation }

    # Bind the path to its own name. Assigning it to `data_store_filename`
    # would overwrite the helper function of that name with a string, so any
    # later call would fail.
    data_store_path = data_store_filename( source_power, rate )
    with open(data_store_path, 'wb') as f:
        # Pickle the results dict using the highest protocol available.
        pickle.dump(data_store, f, pickle.HIGHEST_PROTOCOL)
# Q-Q plot of the per-block squared-error totals against a gamma distribution
# with shape blocklength/2 and scale 2*noise_power, i.e. the distribution of
# a sum of blocklength squared zero-mean Gaussians of variance noise_power.
fig = sm.qqplot(
    np.array(achieved_mses_after_rotation),
    dist=stats.gamma,
    distargs=(blocklength / 2,),
    loc=0,
    scale=2 * noise_power,
    line="45",
)
plt.title("Observed vs Theoretical Distances --- Rate=" + str(rate))
plt.ylabel('Observed')
plt.xlabel('Theoretical')
plt.show()