import os
import os.path as osp
import numpy as np
import argparse
from tqdm import tqdm as tqdm
from sklearn.linear_model import LinearRegression, Ridge
from sklearn.metrics import mean_squared_error, r2_score
import csv

def brightness(img):
    """Return the mean intensity of ``img``.

    A 3-D array is first collapsed to one channel by averaging over axis 2
    (the caller documents images as HxWxC); any other array is averaged
    directly. The scale of the result follows the input pixel range.
    """
    gray = np.mean(img, axis=2) if img.ndim == 3 else img
    return np.mean(gray)

def rms_contrast(img):
    """Return the RMS contrast of ``img``.

    RMS contrast is the root-mean-square deviation of pixel intensities from
    their mean, i.e. the population standard deviation of the intensities.
    A 3-D array is first averaged over axis 2 to get a single channel.
    """
    gray = np.mean(img, axis=2) if img.ndim == 3 else img
    deviation = gray - np.mean(gray)
    return np.sqrt(np.mean(deviation ** 2))

def gradient_magnitude(img):
    """Return the average per-pixel gradient magnitude of ``img``.

    Gradients are taken with ``np.gradient`` along axis 1 (x) and axis 0 (y)
    of the single-channel image; the magnitude sqrt(dx**2 + dy**2) is then
    averaged over all pixels. A 3-D array is first averaged over axis 2.
    """
    gray = np.mean(img, axis=2) if img.ndim == 3 else img
    d_dx = np.gradient(gray, axis=1)
    d_dy = np.gradient(gray, axis=0)
    return np.mean(np.sqrt(d_dx ** 2 + d_dy ** 2))

def save_low_level_feature():
    """Compute the selected low-level feature for every stimulus and save it.

    Reads the module-level ``opt`` namespace (``method``, ``subject``,
    ``feature``, ``pixel_range_255``). One scalar per image is written to
    ``low_level_features/<feature>[_pixel_range_255]/<method>/<subject>/``
    as ``<feature>_tr.npy`` and ``<feature>_te.npy``. The output directory
    must already exist (``main`` creates it before calling this).

    Raises:
        ValueError: if ``opt.feature`` or ``opt.method`` is not recognised.
        NotImplementedError: for ``opt.method == 'takagi'``.
        FileNotFoundError: if a brain-diffuser stimulus file is missing.
    """
    print('Saving features...')

    # Resolve the feature function before loading any data, so a typo in
    # --feature fails immediately rather than after reading the stimuli.
    feature_fns = {
        'brightness': brightness,
        'rms_contrast': rms_contrast,
        'gradient_magnitude': gradient_magnitude,
    }
    calculate_feature = feature_fns.get(opt.feature)
    if calculate_feature is None:
        raise ValueError(f'Define valid feature (got {opt.feature!r})')

    if opt.method == 'brain_diffuser':
        sub = opt.subject[-1]
        data_dir = f'/storage/user/brain-diffuser/data/processed_data/subj0{sub}'
        imgs_tr_fp = f'{data_dir}/nsd_train_stim_sub{sub}.npy'
        imgs_te_fp = f'{data_dir}/nsd_test_stim_sub{sub}.npy'
        # Explicit check instead of `assert`, which is stripped under -O.
        for fp in (imgs_tr_fp, imgs_te_fp):
            if not osp.exists(fp):
                raise FileNotFoundError(fp)
        if opt.pixel_range_255:
            print('Skipping pixel normalization')
        imgs_tr = np.load(imgs_tr_fp).astype(np.uint8)
        imgs_te = np.load(imgs_te_fp).astype(np.uint8)
        if not opt.pixel_range_255:
            # Rescale uint8 pixels to [0, 1].
            imgs_tr = imgs_tr / 255.0
            imgs_te = imgs_te / 255.0
    elif opt.method == 'takagi':
        raise NotImplementedError('TODO implement Takagi')
    else:
        raise ValueError(f'Unknown method {opt.method!r}')

    # NOTE images must be HxWxC (or HxW) numpy arrays; iterating the stacked
    # array yields one image per index along axis 0.
    feats_tr = np.array([calculate_feature(img) for img in tqdm(imgs_tr)])
    feats_te = np.array([calculate_feature(img) for img in tqdm(imgs_te)])

    pixel_norm_str = '_pixel_range_255' if opt.pixel_range_255 else ''
    out_dir = f'low_level_features/{opt.feature}{pixel_norm_str}/{opt.method}/{opt.subject}'
    np.save(f'{out_dir}/{opt.feature}_tr.npy', feats_tr)
    np.save(f'{out_dir}/{opt.feature}_te.npy', feats_te)

    print('Done!')

def _load_intermediates():
    """Return the (train, test) regression inputs for ``opt.method``/``opt.subject``.

    With ``opt.fmri_data`` the averaged nsdgeneral fMRI arrays are loaded;
    otherwise the brain-diffuser intermediates for ``opt.bottleneck_dim``.
    """
    if opt.method == 'brain_diffuser':
        if opt.fmri_data:
            subj_idx = opt.subject[-1]
            data_dir = f'/storage/user/BrainBitsWIP/data/processed_data/{opt.subject}'
            return (
                np.load(f'{data_dir}/nsd_train_fmriavg_nsdgeneral_sub{subj_idx}.npy'),
                np.load(f'{data_dir}/nsd_test_fmriavg_nsdgeneral_sub{subj_idx}.npy'),
            )
        feat_dir = f'brain_diffuser_feats/{opt.subject}/train_single_{opt.bottleneck_dim}'
        return (
            np.load(f'{feat_dir}/train_intermediates.npy'),
            np.load(f'{feat_dir}/test_intermediates.npy'),
        )
    if opt.method in ('takagi', 'takagi_text'):
        raise NotImplementedError('TODO implement Takagi')
    raise ValueError(f'Unknown method {opt.method!r}')


def _build_regressor():
    """Return an unfitted regressor chosen by ``opt.regression_model``."""
    if opt.regression_model == 'LinearRegression':
        return LinearRegression()
    if opt.regression_model == 'Ridge':
        return Ridge(solver=opt.solver)
    raise ValueError(f'Unknown regression model {opt.regression_model!r}')


def main():
    """Regress a low-level image statistic on model/fMRI features and log test error.

    Computes (or reuses cached) per-image features, fits the selected
    regressor on the training split, predicts the test split, and appends
    ``[bottleneck_dim, mse, ser, var_se]`` to a ``results.csv`` whose
    directory encodes the model/solver/feature/method/subject options.
    Reads the module-level ``opt`` namespace.

    Raises:
        ValueError: for an unknown method or regression model.
        NotImplementedError: for the Takagi methods.
    """
    pixel_norm_str = '_pixel_range_255' if opt.pixel_range_255 else ''
    feat_dir = f'low_level_features/{opt.feature}{pixel_norm_str}/{opt.method}/{opt.subject}'
    feats_tr_fp = f'{feat_dir}/{opt.feature}_tr.npy'
    feats_te_fp = f'{feat_dir}/{opt.feature}_te.npy'
    # Check the cached files rather than only the directory: a run that dies
    # after creating the directory would otherwise leave it empty forever.
    if not (osp.exists(feats_tr_fp) and osp.exists(feats_te_fp)):
        os.makedirs(feat_dir, exist_ok=True)
        save_low_level_feature()

    feats_tr = np.load(feats_tr_fp)
    feats_te = np.load(feats_te_fp)

    intermediates_tr, intermediates_te = _load_intermediates()

    model = _build_regressor()
    model.fit(intermediates_tr, feats_tr)
    y_pred = model.predict(intermediates_te)

    squared_error = (feats_te - y_pred) ** 2
    mse = squared_error.mean()
    # Standard error of the regression, sqrt(SSE / (n - 2)). The residuals
    # summed here are the test residuals, so n is the test-set size.
    ser = (squared_error.sum() / (len(feats_te) - 2)) ** 0.5
    var_se = squared_error.var()

    model_str = '_ridge' if opt.regression_model == 'Ridge' else ''
    solver_str = '_saga' if opt.solver == 'saga' else ''
    save_fp = f'low_level_features/results{model_str}{solver_str}{opt.fp_details}/{opt.feature}{pixel_norm_str}/{opt.method}/{opt.subject}'
    print(save_fp)
    os.makedirs(save_fp, exist_ok=True)

    csv_fp = osp.join(save_fp, 'results.csv')
    write_header = not osp.exists(csv_fp)
    row = [opt.bottleneck_dim, mse, ser, var_se]
    with open(csv_fp, mode='a', newline='') as f:
        writer = csv.writer(f)
        if write_header:
            writer.writerow(['bottleneck_dim', 'mse', 'ser', 'var_se'])
        writer.writerow(row)
    print(row)

def str2bool(value):
    """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

    ``argparse`` with ``type=bool`` treats every non-empty string, including
    ``'False'``, as true, so boolean options need an explicit parser.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--subject", type=str, default=None,
                        help="subj01 or subj02  or subj05  or subj07")
    parser.add_argument("--method", type=str, default=None,
                        help="takagi or brain_diffuser")
    parser.add_argument("--feature", type=str, default=None,
                        help="rms_contrast, brightness, gradient_magnitude")
    parser.add_argument("--bottleneck_dim", type=int, default=None,
                        help="1, 5, 10, etc.")
    parser.add_argument("--regression_model", type=str, default='LinearRegression',
                        help="LinearRegression, Ridge")
    # nargs='?' + const=True: a bare flag means True, and an explicit
    # "True"/"False" value is parsed correctly.
    parser.add_argument("--pixel_range_255", type=str2bool, nargs='?', const=True, default=False,
                        help="set true to prevent normalizing pixels to range [0, 1]")
    parser.add_argument("--fmri_data", type=str2bool, nargs='?', const=True, default=False,
                        help="set true to predict image statistics directly from fmri")
    parser.add_argument("--solver", type=str, default='auto',
                        help="auto, saga")
    parser.add_argument("--fp_details", type=str, default='')
    opt = parser.parse_args()
    main()
