import numpy as np
import os

import matplotlib.pyplot as plt


class Evaluator:
    """Accumulate impulse-response metrics per sample and summarize them.

    Each ``evaluate_*`` call compares a predicted impulse response (or
    spectrogram) against its ground truth and appends the resulting error to
    the matching list attribute; ``simple_summarize`` / ``summarize`` average
    those lists into a dict of metrics.
    """

    def __init__(self, cfg, seq_name):
        # ``cfg`` must expose ``result_dir`` (read by the summarize methods).
        self.cfg = cfg
        self.seq_name = seq_name
        # Per-sample metric accumulators.
        self.mse = []
        self.psnr = []
        self.ssim = []
        self.t60_error = []
        self.clarity_error = []
        self.edt_error = []
        self.spec_mse = []
        # Read by ``summarize``. Nothing in this class appends to them
        # (presumably filled by external code -- verify against callers); they
        # are initialized here so ``summarize`` no longer raises AttributeError.
        self.sample_t60_error = []
        self.sample_edt_error = []
        self.c50_error = []
        self.sample_c50_error = []
        # Number of ``evaluate_t60`` calls whose RT60 could not be estimated.
        self.invalid = 0
        self.fig = plt.figure()
        self.figplot = self.fig.add_subplot(2, 1, 1)
        self.figplot1 = self.fig.add_subplot(2, 1, 2)

    @staticmethod
    def _safe_mean(values):
        """Return ``np.mean(values)``, or NaN without a RuntimeWarning when empty."""
        if len(values) == 0:
            return np.float64(np.nan)
        return np.mean(values)

    @staticmethod
    def _schroeder_curve(h):
        """Return ``(power, energy, energy_db)`` for impulse response ``h``.

        ``power`` is ``h ** 2``; ``energy`` is its Schroeder backward integral
        with the all-zero tail removed (the last non-zero sample is kept);
        ``energy_db`` is ``energy`` in dB, shifted so ``energy_db[0] == 0``.

        Raises:
            ValueError: if ``h`` contains no non-zero sample.
        """
        power = np.asarray(h) ** 2
        energy = np.cumsum(power[::-1])[::-1]  # Integration according to Schroeder

        nonzero = np.flatnonzero(energy > 0)
        if nonzero.size == 0:
            raise ValueError("impulse response has no energy")
        # Remove the all-zero tail but keep the last non-zero sample (slicing
        # with ``[:i_nz]`` would drop it too).
        energy = energy[: nonzero[-1] + 1]
        energy_db = 10 * np.log10(energy)
        energy_db -= energy_db[0]
        return power, energy, energy_db

    @staticmethod
    def _first_index_below(energy_db, level_db):
        """Return the first index where ``energy_db`` is strictly below ``level_db``.

        Raises:
            ValueError: if the curve never drops below ``level_db``.
        """
        below = np.flatnonzero(energy_db < level_db)
        if below.size == 0:
            raise ValueError("decay curve never drops below {} dB".format(level_db))
        return below[0]

    def psnr_metric(self, img_pred, img_gt):
        """Return the PSNR in dB, assuming a peak signal value of 1."""
        mse = np.mean((img_pred - img_gt) ** 2)
        return -10 * np.log10(mse)

    def measure_edt(self, h, fs=22050, decay_db=10):
        """Estimate the early decay time (EDT) of an impulse response, in seconds.

        Measures how long the normalized Schroeder decay curve takes to drop
        ``decay_db`` dB below its start and extrapolates linearly to 60 dB.

        Raises:
            ValueError: if ``h`` is all zeros or never decays by ``decay_db`` dB.
        """
        fs = float(fs)
        _, _, energy_db = self._schroeder_curve(h)
        i_decay = self._first_index_below(energy_db, -decay_db)
        decay_time = i_decay / fs
        return (60 / decay_db) * decay_time

    def measure_rt60(self, h, fs=22050, decay_db=30, plot=False, rt60_tgt=None):
        """
        Analyze the RT60 of an impulse response. Optionally plots some useful information.

        Parameters
        ----------
        h: array_like
            The impulse response.
        fs: float or int, optional
            The sampling frequency of h (defaults to 22050).
        decay_db: float or int, optional
            The decay in decibels for which we actually estimate the time. Although
            we would like to estimate the RT60, it might not be practical. Instead,
            we measure the RT20 or RT30 and extrapolate to RT60. The measured span
            starts 5 dB below the initial level.
        plot: bool, optional
            If set to ``True``, the power decay and different estimated values will
            be plotted (default False).
        rt60_tgt: float
            This parameter can be used to indicate a target RT60 to which we want
            to compare the estimated value.

        Returns
        -------
        float
            The estimated RT60 in seconds.

        Raises
        ------
        ValueError
            If ``h`` is all zeros or its decay curve never drops ``5 + decay_db`` dB.
        """
        fs = float(fs)
        power, energy, energy_db = self._schroeder_curve(h)

        # -5 dB headroom: the linear fit starts at the first sample below -5 dB.
        i_5db = self._first_index_below(energy_db, -5)
        e_5db = energy_db[i_5db]
        t_5db = i_5db / fs

        # ... and ends ``decay_db`` dB further down.
        i_decay = self._first_index_below(energy_db, -5 - decay_db)
        t_decay = i_decay / fs

        # Extrapolate the measured span to a 60 dB decay.
        decay_time = t_decay - t_5db
        est_rt60 = (60 / decay_db) * decay_time

        if plot:
            import matplotlib.pyplot as plt

            # Clip power below the minimum energy (for plotting purposes mostly)
            energy_min = energy[-1]
            energy_db_min = energy_db[-1]
            power[power < energy[-1]] = energy_min
            power_db = 10 * np.log10(power)
            power_db -= np.max(power_db)

            # time vector, with t = 0 at the -5 dB point
            def get_time(x, fs):
                return np.arange(x.shape[0]) / fs - i_5db / fs

            T = get_time(power_db, fs)

            # plot energy
            plt.plot(get_time(energy_db, fs), energy_db, label="Energy")

            # now the linear fit
            plt.plot([0, est_rt60], [e_5db, -65], "--", label="Linear Fit")
            plt.plot(T, np.ones_like(T) * -60, "--", label="-60 dB")
            plt.vlines(
                est_rt60, energy_db_min, 0, linestyles="dashed", label="Estimated RT60"
            )

            if rt60_tgt is not None:
                plt.vlines(rt60_tgt, energy_db_min, 0, label="Target RT60")

            plt.legend()

        return est_rt60

    def measure_clarity(self, signal, time=50, fs=22050):
        """Return the clarity index C_time (C50 by default) in dB.

        Computes ``10*log10(E_early / E_late)`` where the energy of ``signal``
        is split at index ``int(time/1000*fs + 1)``.
        NOTE(review): the split is one sample past the nominal ``time`` ms
        boundary; kept as-is -- confirm intent.
        """
        h2 = np.asarray(signal) ** 2
        t = int((time / 1000) * fs + 1)
        return 10 * np.log10(np.sum(h2[:t]) / np.sum(h2[t:]))

    def compute_energy_db(self, h):
        """Return the normalized Schroeder decay curve of ``h`` in dB (starts at 0 dB).

        Raises:
            ValueError: if ``h`` is all zeros.
        """
        _, _, energy_db = self._schroeder_curve(h)
        return energy_db

    def evaluate_edt(self, pred_ir, gt_ir):
        """Append the absolute EDT difference (seconds) between prediction and ground truth.

        Errors from ``measure_edt`` propagate to the caller.
        """
        pred_edt = self.measure_edt(pred_ir)
        gt_edt = self.measure_edt(gt_ir)
        self.edt_error.append(abs(pred_edt - gt_edt))

    def evaluate_clarity(self, pred_ir, gt_ir):
        """Append the absolute C50 difference (dB) between prediction and ground truth."""
        pred_clarity = self.measure_clarity(pred_ir)
        gt_clarity = self.measure_clarity(gt_ir)
        self.clarity_error.append(abs(pred_clarity - gt_clarity))

    def evaluate_t60(self, pred_ir, gt_ir):
        """Record waveform MSE/PSNR and, when measurable, the relative RT60 error.

        MSE and PSNR are always appended. If either RT60 cannot be estimated
        (``measure_rt60`` raises ``ValueError``/``IndexError``), ``self.invalid``
        is incremented instead of appending to ``t60_error``.
        """
        mse = np.mean((pred_ir - gt_ir) ** 2)
        self.mse.append(mse)
        self.psnr.append(self.psnr_metric(pred_ir, gt_ir))
        try:
            pred_t60 = self.measure_rt60(pred_ir)
            gt_t60 = self.measure_rt60(gt_ir)
        except (ValueError, IndexError):
            self.invalid += 1
            return
        self.t60_error.append(abs(pred_t60 - gt_t60) / gt_t60)

    def evaluate_energy_db(self, pred_ir, gt_ir):
        """Return the normalized decay curves (dB) of prediction and ground truth."""
        pred_db = self.compute_energy_db(pred_ir)
        gt_db = self.compute_energy_db(gt_ir)
        return pred_db, gt_db

    def evaluate_spec_mse(self, pred_ir_spec, gt_ir_spec):
        """Append the mean squared error between two spectrograms."""
        # Square before averaging: mean((a - b)**2), not mean(a - b)**2.
        self.spec_mse.append(np.mean((pred_ir_spec - gt_ir_spec) ** 2))

    def _ensure_result_dir(self):
        """Create ``<cfg.result_dir>/<seq_name>`` if missing (no shell involved)."""
        result_path = os.path.join(self.cfg.result_dir, self.seq_name, 'metrics.npy')
        result_dir = os.path.dirname(result_path)
        if result_dir:
            os.makedirs(result_dir, exist_ok=True)
        # NOTE: only the directory is created; nothing is written to result_path.
        return result_dir

    def simple_summarize(self):
        """Return the mean of each accumulated metric; empty metrics map to NaN.

        Also creates the per-sequence result directory. Accumulators are not reset.
        """
        self._ensure_result_dir()
        metrics = {
            'mse': self._safe_mean(self.mse),
            'psnr': self._safe_mean(self.psnr),
            'spec_mse': self._safe_mean(self.spec_mse),
            't60_error': self._safe_mean(self.t60_error),
            'clarity_error': self._safe_mean(self.clarity_error),
            'edt_error': self._safe_mean(self.edt_error),
        }
        return metrics

    def summarize(self):
        """Return mean metrics including the ``sample_*`` and C50 lists.

        Empty metrics map to NaN. Creates the per-sequence result directory and
        then resets only ``mse``, ``psnr`` and ``ssim``; other accumulators keep
        their contents.
        """
        self._ensure_result_dir()
        metrics = {
            'mse': self._safe_mean(self.mse),
            'psnr': self._safe_mean(self.psnr),
            't60_error': self._safe_mean(self.t60_error),
            'sample_t60_error': self._safe_mean(self.sample_t60_error),
            'edt_error': self._safe_mean(self.edt_error),
            'sample_edt_error': self._safe_mean(self.sample_edt_error),
            'c50_error': self._safe_mean(self.c50_error),
            'sample_c50_error': self._safe_mean(self.sample_c50_error),
        }

        self.mse = []
        self.psnr = []
        self.ssim = []
        return metrics