from utils import *
import torch
import os
import numpy as np
import scipy.io as scio
# Expose only GPU 0 to this process. CUDA reads this variable when it is first
# initialised, so it must be set before any CUDA call is made.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
torch.backends.cudnn.enabled = True
# Let cuDNN benchmark and pick the fastest convolution algorithms; this pays
# off when input shapes stay the same between calls.
torch.backends.cudnn.benchmark = True
if not torch.cuda.is_available():
    # Everything below moves tensors/models with .cuda(), so fail fast here.
    raise RuntimeError('NO GPU!')

batch_size = 1  # NOTE(review): not used by test()/main() in this file
mask_path = "../datasets/TSA_simu_data/"
test_path = "../datasets/TSA_simu_data/Truth/"
# The whole test set is fed to the model in a single forward pass (see test()),
# so the masks are generated for a batch as large as the test set.
# NOTE(review): LoadTest presumably returns one tensor stacking all scenes.
test_data = LoadTest(test_path)
batch_size_test = len(test_data)
Phi_batch, Phi_s_batch = generate_shift_masks(mask_path, batch_size_test)

# The checkpoint holds a whole pickled model object (it is used directly as the
# network below), not a state_dict. Since PyTorch 2.6, torch.load defaults to
# weights_only=True and refuses such files, so opt out explicitly.
# SECURITY: unpickling can execute arbitrary code -- only load trusted files.
pretrained_model_path = './pretrained_model/DAUHST_9stg.pth'
try:
    model = torch.load(pretrained_model_path, weights_only=False)
except TypeError:
    # PyTorch < 1.13 has no weights_only parameter (and never restricts
    # unpickling), so fall back to the plain call.
    model = torch.load(pretrained_model_path)

result_path = './testing_result/'
# exist_ok makes this a no-op when the directory already exists and avoids the
# race between a separate existence check and the creation call.
os.makedirs(result_path, exist_ok=True)

def test():
    """Reconstruct the whole test set with the pretrained model in one pass.

    A measurement is produced from the ground-truth data and the masks via
    ``gen_2D_meas``. It is then passed, with the masks and the first shifted
    mask, through ``model`` in evaluation mode without gradient tracking.

    Returns:
        (pred, truth): float32 NumPy arrays holding the model output and the
        ground truth. Axis 1 of each 4-D tensor is moved to the last position
        (presumably channels-last, (N, H, W, C) -- TODO confirm).
        ``model`` is left in training mode afterwards.
    """
    def _to_host_channels_last(tensor):
        # Copy a 4-D tensor to host memory, move axis 1 to the end and cast
        # the result to float32.
        host_array = tensor.detach().cpu().numpy()
        return np.transpose(host_array, (0, 2, 3, 1)).astype(np.float32)

    gt_cube = test_data.cuda().float()
    masks = Phi_batch.cuda().float()
    # Only the first shifted mask is used; presumably the model broadcasts it
    # over the batch -- TODO confirm.
    shifted_mask = Phi_s_batch[0:1, :, :].cuda().float()
    measurement = gen_2D_meas(gt_cube, masks)

    model.eval()
    with torch.no_grad():
        reconstruction = model(measurement, masks, shifted_mask)

    pred = _to_host_channels_last(reconstruction)
    truth = _to_host_channels_last(gt_cube)
    model.train()
    return pred, truth


def main():
    """Run test() and store the reconstruction next to the ground truth.

    Writes a MATLAB .mat file containing the variables ``truth`` and ``pred``
    into ``result_path``.
    """
    output_file = os.path.join(result_path, 'Test_result.mat')
    pred, truth = test()
    print(f'Save reconstructed HSIs as {output_file}.')
    scio.savemat(output_file, {'truth': truth, 'pred': pred})

if __name__ == "__main__":
    # Run the evaluation only when this file is executed directly.
    main()