import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm
from typing import Optional 

from .lfm import PETALModule, LinearizedForwardModel
from .networks import MLP

import pytorch_lightning as pl
import torch.nn.functional as F
import os

class SGDNeuralAdjointModule(pl.LightningModule):
    """Lightning wrapper that inverts ``self.net`` by gradient descent on its input.

    This base class defines no ``__init__``. Subclasses must set ``self.net``
    and the optimisation settings read here: ``num_epochs``, ``lr``,
    ``early_cutoff``, ``verbose``, ``probs``, ``bound_loss_lambda``,
    ``bound_loss_upper``, ``bound_loss_lower``, ``regularizer_type``,
    ``lambda_r``, ``grad_regularizer_type``, ``lambda_grad_r``, ``num_range``,
    ``num_depth``, ``save_path``, ``save_name`` and ``ssp_transform``.
    """

    def regularizer(self, xhat):
        """Return the per-sample magnitude penalty on ``xhat`` (shape ``(bs,)``).

        ``'l2'`` gives ``lambda_r * sum(xhat**2)``, ``'l1'`` gives
        ``lambda_r * sum(|xhat|)``, and any other setting gives zeros.
        """
        bs = xhat.size(0)
        # reshape (not view) so non-contiguous inputs, e.g. decoder outputs, work
        flat = xhat.reshape(bs, -1)
        if self.regularizer_type == 'l2':
            return self.lambda_r * torch.pow(flat, 2).sum(-1)
        if self.regularizer_type == 'l1':
            return self.lambda_r * torch.abs(flat).sum(-1)
        return torch.zeros(bs, device=xhat.device)

    def grad_regularizer(self, xhat):
        """Return the per-sample finite-difference smoothness penalty on ``xhat``.

        Differences are taken along dim 1 and along the last dim. ``'l2'`` sums
        their squares and ``'l1'`` their absolute values, scaled by
        ``lambda_grad_r``. Any other setting gives zeros of shape ``(bs,)``.
        """
        bs = xhat.size(0)
        if self.grad_regularizer_type not in ('l1', 'l2'):
            return torch.zeros(bs, device=xhat.device)
        penalty = torch.abs if self.grad_regularizer_type == 'l1' else torch.square
        dx = penalty(xhat[:, 1:] - xhat[:, :-1]).reshape(bs, -1).sum(-1)
        dy = penalty(xhat[..., 1:] - xhat[..., :-1]).reshape(bs, -1).sum(-1)
        return self.lambda_grad_r * (dx + dy)

    @torch.enable_grad()
    def forward(
        self,
        y,
        xinit=None,
    ):
        """Estimate ``x`` such that ``net(x)`` matches ``y`` by running ``num_epochs`` steps.

        Args:
            y: observations with batch size ``bs``; flattened to ``(bs, -1)``.
            xinit: optional starting estimate (batch size ``bs``). Defaults to
                zeros of shape ``(bs, num_range, num_depth)``. It is cloned and
                moved to ``self.device``.

        Returns:
            ``(xhat, errors)``: the detached estimate and the list of scalar
            losses returned by :meth:`step`.
        """
        bs = y.size(0)
        with torch.no_grad():
            if xinit is None:
                xhat = torch.zeros(bs, self.num_range, self.num_depth).to(self.device)
            else:
                xhat = xinit.clone().to(self.device)
            assert xhat.size(0) == bs

            # xhat is the leaf tensor updated by the optimiser
            xhat.requires_grad_(True)
            y = y.view(bs, -1)

        self.prepare_optimizer(xhat, self.lr)
        self.pbar = tqdm(range(self.num_epochs), leave=self.verbose, disable=not self.verbose)

        errors = [self.step(y, xhat) for _ in self.pbar]
        return xhat.detach(), errors

    def nandiff(self, a, b):
        """Return ``(a - b, nan_mask)``, with NaNs in either input treated as 0.

        ``nan_mask`` is True wherever ``a`` or ``b`` is NaN. Neither input is
        modified.
        """
        # Build the masks before filling, and fill out of place, so that the
        # caller's tensors (e.g. the observations y) are left untouched.
        a_nan = torch.isnan(a)
        b_nan = torch.isnan(b)
        diff = a.masked_fill(a_nan, 0.0) - b.masked_fill(b_nan, 0.0)
        return diff, a_nan | b_nan

    def validation_epoch_end(self, validation_step_outputs):
        self._shared_epoch_end(validation_step_outputs, "val")

    def validation_step(self, batch, batch_idx):
        return self._shared_eval(batch, batch_idx, "val")

    def test_step(self, batch, batch_idx):
        return self._shared_eval(batch, batch_idx, "test")

    def test_epoch_end(self, test_step_outputs):
        self._shared_epoch_end(test_step_outputs, "test")

    def _shared_eval(self, batch, batch_idx, prefix):
        """Invert one ``(x, y, init)`` batch and log its error metrics.

        Logs ``{prefix}_loss`` (MSE between ``x`` and the estimate). When
        ``ssp_transform`` is set, it also logs ``{prefix}_unnorm_rmse``: the
        RMSE over the last two dims in unnormalised units, averaged over the
        batch.
        """
        x, y, init = batch
        xhat, _ = self(y, init)
        loss = F.mse_loss(x, xhat).item()
        self.log(f"{prefix}_loss", loss, prog_bar=True)
        out = {'pred': xhat, 'mse': loss}

        if self.ssp_transform is not None:
            un_x = self.ssp_transform.unnormalize(x)
            un_xhat = self.ssp_transform.unnormalize(xhat)
            mse_last2 = F.mse_loss(un_xhat, un_x, reduction='none').mean(dim=(-1, -2))
            un_loss = torch.sqrt(mse_last2).mean().item()
            self.log(f"{prefix}_unnorm_rmse", un_loss)
            out["unnorm_loss"] = un_loss
        return out

    def _shared_epoch_end(self, step_outputs, prefix):
        """Concatenate the epoch's predictions and save them if ``save_name`` is set.

        The file is written to ``save_path/save_name`` with ``{prefix}_``
        prepended to the file name. Missing directories are created.
        """
        if self.save_name is None:
            return
        path, name = os.path.split(os.path.join(self.save_path, self.save_name))
        # os.makedirs('') raises, so skip it when there is no directory part
        if path:
            os.makedirs(path, exist_ok=True)
        all_preds = torch.concat([output['pred'] for output in step_outputs])
        torch.save(all_preds, os.path.join(path, f'{prefix}_{name}'))

    def prepare_optimizer(self, xhat, lr):
        """Create the SGD optimiser over the single tensor ``xhat``."""
        self.optim = torch.optim.SGD([xhat], lr=lr)

    @torch.enable_grad()
    def step(self, y, xhat):
        """Take one optimiser step on ``xhat`` and return the total loss as a float.

        The per-sample loss is the sum of:
        - the mean ``|net(xhat) - y|`` over entries where neither side is NaN;
        - ``bound_loss_lambda`` times the out-of-range penalty on ``net(xhat)``;
        - the magnitude and smoothness regularisers on ``xhat``.

        Only samples whose data loss exceeds ``early_cutoff`` contribute (all
        samples when it is None).
        """
        bs = xhat.size(0)
        self.optim.zero_grad()
        yhat = self.net(xhat)
        if self.probs:
            # with probs the network output is indexed and element 0 used
            yhat = yhat[0]

        diff, nanmask = self.nandiff(yhat.reshape(bs, -1), y)
        # Mean |diff| over valid entries only. A clamped count keeps rows with
        # no valid entries at 0 instead of NaN, which would poison the gradient.
        valid_count = (~nanmask).sum(-1).clamp_min(1)
        obs_loss = torch.abs(diff).masked_fill(nanmask, 0.0).sum(-1) / valid_count

        if self.early_cutoff is not None:
            # keep optimising only the samples that have not converged yet
            mask = obs_loss > self.early_cutoff
        else:
            mask = torch.ones_like(obs_loss, dtype=torch.bool)

        yhat_clean = torch.nan_to_num(yhat)
        boundary_loss = (
            F.relu(yhat_clean - self.bound_loss_upper)
            + F.relu(self.bound_loss_lower - yhat_clean)
        ).reshape(bs, -1).sum(-1)  # sum over all non-batch dims

        reg_loss = self.regularizer(xhat)
        grad_reg_loss = self.grad_regularizer(xhat)
        loss = obs_loss + self.bound_loss_lambda * boundary_loss + reg_loss + grad_reg_loss
        loss = (loss * mask).sum()
        loss.backward()
        self.optim.step()

        self.pbar.set_postfix(
            {'obs_loss': f'{obs_loss.sum().item():.2e}',
             'boundary_loss': f'{boundary_loss.sum().item():.2e}',
             f'{self.regularizer_type}': f'{reg_loss.sum().item():.2e}',
             f'grad{self.grad_regularizer_type}': f'{grad_reg_loss.sum().item():.2e}',
             'num_opt': mask.sum().item()})

        return loss.item()

class SGDLFMModule(SGDNeuralAdjointModule):
    """SGD input inversion through a ``LinearizedForwardModel`` built at init.

    The forward model is constructed from ``lfm_basepath``, ``slice_num`` and
    ``time_idx`` with ``nan_to_zero=True``. The ``probs`` argument is accepted
    but ignored: ``self.probs`` is always ``False``.
    """

    def __init__(
        self,
        num_epochs: int = 1000,
        early_cutoff = 1e-2,
        verbose: bool = False,
        probs: bool = False,
        bound_loss_lambda: float = 1.,
        bound_loss_upper: float = 7.,
        bound_loss_lower: float = -5.,
        regularizer: str = 'none',
        lambda_r: float = 0.0,
        grad_regularizer: str = 'none',
        lambda_grad_r: float = 0.0,
        num_range: int = 11,
        num_depth: int = 231,
        lr: float = 1.0,
        save_path: str = './results/',
        save_name: Optional[str] = None,
        ssp_transform: Optional[nn.Module] = None,
        at_transform: Optional[nn.Module] = None,
        lfm_basepath: str = './data/flat_earth',
        slice_num: int = 1,
        time_idx: int = 1000,
    ):
        super().__init__()

        # Optimisation loop.
        self.num_epochs, self.lr = num_epochs, lr
        self.early_cutoff = early_cutoff
        self.verbose = verbose
        # Forced off whatever `probs` says, so step() uses the net output as-is.
        self.probs = False

        # Penalty on predictions outside [lower, upper].
        self.bound_loss_lambda = bound_loss_lambda
        self.bound_loss_upper, self.bound_loss_lower = bound_loss_upper, bound_loss_lower

        # Regularisers on the estimate.
        self.regularizer_type, self.lambda_r = regularizer, lambda_r
        self.grad_regularizer_type, self.lambda_grad_r = grad_regularizer, lambda_grad_r

        # Shape of the estimate when no initial guess is supplied.
        self.num_range, self.num_depth = num_range, num_depth

        # Where epoch-end predictions are written.
        self.save_path, self.save_name = save_path, save_name

        # Normalisation transforms, also handed to the forward model.
        self.ssp_transform = ssp_transform
        self.at_transform = at_transform

        # Settings for the linearised forward model.
        self.lfm_basepath = lfm_basepath
        self.slice_num = slice_num
        self.time_idx = time_idx

        self.load_net()

    def load_net(self):
        """Build ``self.net`` as a ``LinearizedForwardModel`` from the stored LFM settings."""
        lfm_kwargs = dict(
            basepath=self.lfm_basepath,
            slice_num=self.slice_num,
            time_idx=self.time_idx,
            nan_to_zero=True,
            ssp_transform=self.ssp_transform,
            at_transform=self.at_transform,
        )
        self.net = LinearizedForwardModel(**lfm_kwargs)

             
class TrainedSGDNeuralAdjointModule(SGDNeuralAdjointModule):
    """SGD input inversion whose forward model is restored from a checkpoint.

    Subclasses provide ``load_net(ckpt_path)``, which must assign ``self.net``.
    ``at_transform`` is accepted but not stored.
    """

    def __init__(
        self,
        ckpt_path: str = '',
        num_epochs: int = 1000,
        early_cutoff = 1e-2,
        verbose: bool = False,
        probs: bool = False,
        bound_loss_lambda: float = 1.,
        bound_loss_upper: float = 7.,
        bound_loss_lower: float = -5.,
        regularizer: str = 'none',
        lambda_r: float = 0.0,
        grad_regularizer: str = 'none',
        lambda_grad_r: float = 0.0,
        num_range: int = 11,
        num_depth: int = 231,
        lr: float = 1.0,
        save_path: str = './results/',
        save_name: Optional[str] = None,
        ssp_transform: Optional[nn.Module] = None,
        at_transform: Optional[nn.Module] = None,
    ):
        super().__init__()
        # Subclass hook: restore the network before configuring the optimiser.
        self.load_net(ckpt_path)

        # Optimisation loop.
        self.num_epochs, self.early_cutoff = num_epochs, early_cutoff
        self.verbose, self.probs = verbose, probs
        self.lr = lr

        # Penalty on predictions outside [lower, upper].
        self.bound_loss_lambda = bound_loss_lambda
        self.bound_loss_upper, self.bound_loss_lower = bound_loss_upper, bound_loss_lower

        # Regularisers on the estimate.
        self.regularizer_type, self.lambda_r = regularizer, lambda_r
        self.grad_regularizer_type, self.lambda_grad_r = grad_regularizer, lambda_grad_r

        # Shape of the default (all-zero) initial estimate.
        self.num_range, self.num_depth = num_range, num_depth

        # Where epoch-end predictions are written.
        self.save_path, self.save_name = save_path, save_name

        # Used to report unnormalised errors during evaluation.
        self.ssp_transform = ssp_transform

class SGDWANModule(TrainedSGDNeuralAdjointModule):
    """Trained SGD adjoint whose forward model is a ``PETALModule`` checkpoint."""

    def load_net(self, ckpt_path):
        """Restore ``self.net`` via ``PETALModule.load_from_checkpoint(ckpt_path)``."""
        restored = PETALModule.load_from_checkpoint(checkpoint_path=ckpt_path)
        self.net = restored


    
class SGDMLPModule(TrainedSGDNeuralAdjointModule):
    """Trained SGD adjoint whose forward model is an ``MLP`` checkpoint."""

    def load_net(self, ckpt_path):
        """Restore ``self.net`` via ``MLP.load_from_checkpoint(ckpt_path)``."""
        restored = MLP.load_from_checkpoint(checkpoint_path=ckpt_path)
        self.net = restored

        
        
class SGDPETALModule(SGDWANModule):
    """Variant of ``SGDWANModule`` that optimises in the latent space of ``self.net``.

    The initial estimate is encoded with ``self.net.encoder``. Each step
    decodes the latent code with ``self.net.decoder`` and applies the parent
    ``step`` loss to the decoded estimate.
    """

    def forward(
        self,
        y,
        xinit=None,
    ):
        """Optimise a latent code ``zhat`` so that the decoded estimate explains ``y``.

        Args:
            y: observations with batch size ``bs``; flattened to ``(bs, -1)``.
            xinit: optional initial estimate in x-space. Defaults to zeros of
                shape ``(bs, num_range, num_depth)``. It is moved to
                ``self.device``.

        Returns:
            ``(decoder(zhat), errors)``: the decoded estimate (no autograd
            history) and the per-step losses.
        """
        bs = y.size(0)
        with torch.no_grad():
            if xinit is None:
                xhat = torch.zeros(bs, self.num_range, self.num_depth).to(self.device)
            else:
                # same device handling as the parent forward
                xhat = xinit.clone().to(self.device)

            zhat = self.net.encoder(xhat)
            assert zhat.size(0) == bs, f"bs is {zhat.size()}, should be {bs}"

            # zhat is the leaf tensor updated by the optimiser
            zhat.requires_grad_(True)
            y = y.view(bs, -1)

        self.prepare_optimizer(zhat, self.lr)
        self.pbar = tqdm(range(self.num_epochs), leave=self.verbose, disable=not self.verbose)

        errors = [self.step(y, zhat) for _ in self.pbar]

        # The final decode is only for output; build no autograd graph.
        with torch.no_grad():
            xhat = self.net.decoder(zhat)
        return xhat.detach(), errors

    @torch.enable_grad()
    def step(self, y, zhat):
        """Decode ``zhat`` and take one parent ``step`` on the decoded estimate.

        Gradients flow back through the decoder into ``zhat``, which is the
        tensor the optimiser updates.
        """
        xhat = self.net.decoder(zhat)
        return super().step(y, xhat)
      
class NeuralAdjoint(nn.Module):
    """Base class for inverting ``net`` by optimising its input.

    Subclasses implement ``step(y, xhat)``, which performs one optimisation
    step and returns a scalar loss. This class provides the regularisers, the
    NaN-aware residual and the outer optimisation loop.
    """

    def __init__(
        self,
        net,
        num_epochs: int = 300,
        early_cutoff=1.5e-5,
        verbose: bool = True,
        probs: bool = False,
        bound_loss_lambda: float = 1.,
        bound_loss_upper: float = 10.,
        bound_loss_lower: float = -2.,
        regularizer: str = 'none',
        lambda_r: float = 0.0,
        grad_regularizer: str = 'none',
        lambda_grad_r: float = 0.0,
    ):
        super().__init__()
        self.net = net
        self.num_epochs = num_epochs
        self.early_cutoff = early_cutoff
        self.verbose = verbose
        self.probs = probs

        self.bound_loss_lambda = bound_loss_lambda
        self.bound_loss_upper = bound_loss_upper
        self.bound_loss_lower = bound_loss_lower

        self.regularizer_type = regularizer
        self.lambda_r = lambda_r
        self.grad_regularizer_type = grad_regularizer
        self.lambda_grad_r = lambda_grad_r

    def regularizer(self, xhat):
        """Return the per-sample magnitude penalty on ``xhat`` (shape ``(bs,)``).

        ``'l2'`` gives ``lambda_r * sum(xhat**2)``, ``'l1'`` gives
        ``lambda_r * sum(|xhat|)``, and any other setting gives zeros.
        """
        bs = xhat.size(0)
        # reshape (not view) so non-contiguous inputs are accepted
        flat = xhat.reshape(bs, -1)
        if self.regularizer_type == 'l2':
            return self.lambda_r * torch.pow(flat, 2).sum(-1)
        if self.regularizer_type == 'l1':
            return self.lambda_r * torch.abs(flat).sum(-1)
        return torch.zeros(bs, device=xhat.device)

    def grad_regularizer(self, xhat):
        """Return the per-sample finite-difference smoothness penalty on ``xhat``.

        Differences are taken along dim 1 and along the last dim. ``'l2'`` sums
        their squares and ``'l1'`` their absolute values, scaled by
        ``lambda_grad_r``. Any other setting gives zeros of shape ``(bs,)``.
        """
        bs = xhat.size(0)
        if self.grad_regularizer_type not in ('l1', 'l2'):
            return torch.zeros(bs, device=xhat.device)
        penalty = torch.abs if self.grad_regularizer_type == 'l1' else torch.square
        dx = penalty(xhat[:, 1:] - xhat[:, :-1]).reshape(bs, -1).sum(-1)
        dy = penalty(xhat[..., 1:] - xhat[..., :-1]).reshape(bs, -1).sum(-1)
        return self.lambda_grad_r * (dx + dy)

    def prepare_optimizer(self, xhat, lr):
        """Create the SGD optimiser over the single tensor ``xhat``."""
        self.optim = torch.optim.SGD([xhat], lr=lr)

    def forward(
        self,
        y,
        xinit=None,
        lr=0.001,
        device=1,
    ):
        """Estimate ``x`` such that ``net(x)`` matches ``y`` by running ``num_epochs`` steps.

        Args:
            y: observations with batch size ``bs``; moved to ``device`` and
                flattened to ``(bs, -1)``.
            xinit: optional starting estimate. Defaults to zeros of shape
                ``(bs, 11, 231)``.
            lr: learning rate passed to :meth:`prepare_optimizer`.
            device: device for ``y``, the estimate and ``self.net``.

        Returns:
            ``(xhat, errors)``: the detached estimate on CPU and the list of
            values returned by ``self.step`` (defined by subclasses).
        """
        bs = y.size(0)
        with torch.no_grad():
            if xinit is None:
                xhat = torch.zeros(bs, 11, 231).to(device)
            else:
                xhat = xinit.clone().to(device)
            assert xhat.size(0) == bs

            # xhat is the leaf tensor updated by the optimiser
            xhat.requires_grad_(True)
            y = y.to(device).view(bs, -1)
        self.net = self.net.eval().to(device)

        self.prepare_optimizer(xhat, lr)
        self.pbar = tqdm(range(self.num_epochs), leave=self.verbose, disable=not self.verbose)

        errors = [self.step(y, xhat) for _ in self.pbar]
        return xhat.detach().cpu(), errors

    def nandiff(self, a, b):
        """Return ``(a - b, nan_mask)``, with NaNs in either input treated as 0.

        ``nan_mask`` is True wherever ``a`` or ``b`` is NaN. Neither input is
        modified.
        """
        # Build the masks before filling, and fill out of place, so that the
        # caller's tensors (e.g. the observations y) are left untouched.
        a_nan = torch.isnan(a)
        b_nan = torch.isnan(b)
        diff = a.masked_fill(a_nan, 0.0) - b.masked_fill(b_nan, 0.0)
        return diff, a_nan | b_nan
class SGDNeuralAdjoint(NeuralAdjoint):
    """Neural-adjoint inversion using the SGD optimiser from ``NeuralAdjoint``."""

    def step(self, y, xhat):
        """Take one optimiser step on ``xhat`` and return the total loss as a float.

        The per-sample loss is the sum of:
        - the mean ``|net(xhat) - y|`` over entries not flagged by ``nandiff``;
        - ``bound_loss_lambda`` times the out-of-range penalty on ``net(xhat)``;
        - the magnitude and smoothness regularisers on ``xhat``.

        Only samples whose data loss exceeds ``early_cutoff`` contribute (all
        samples when it is None).
        """
        bs = xhat.size(0)
        self.optim.zero_grad()
        yhat = self.net(xhat)
        if self.probs:
            # with probs the network output is indexed and element 0 used
            yhat = yhat[0]

        diff, nanmask = self.nandiff(yhat.reshape(bs, -1), y)
        # Mean |diff| over valid entries. The clamped count keeps rows with no
        # valid entries at 0 rather than NaN.
        valid_count = (~nanmask).sum(-1).clamp_min(1)
        obs_loss = torch.abs(diff).masked_fill(nanmask, 0.0).sum(-1) / valid_count

        if self.early_cutoff is not None:
            # keep optimising only the samples that have not converged yet
            mask = obs_loss > self.early_cutoff
        else:
            mask = torch.ones_like(obs_loss, dtype=torch.bool)

        yhat_clean = torch.nan_to_num(yhat)
        boundary_loss = (
            nn.functional.relu(yhat_clean - self.bound_loss_upper)
            + nn.functional.relu(self.bound_loss_lower - yhat_clean)
        ).reshape(bs, -1).sum(-1)  # sum over all non-batch dims

        reg_loss = self.regularizer(xhat)
        grad_reg_loss = self.grad_regularizer(xhat)
        loss = obs_loss + self.bound_loss_lambda * boundary_loss + reg_loss + grad_reg_loss
        loss = (loss * mask).sum()
        loss.backward()
        self.optim.step()

        self.pbar.set_postfix(
            {'obs_loss': f'{obs_loss.sum().item():.2e}',
             'boundary_loss': f'{boundary_loss.sum().item():.2e}',
             f'{self.regularizer_type}': f'{reg_loss.sum().item():.2e}',
             f'grad{self.grad_regularizer_type}': f'{grad_reg_loss.sum().item():.2e}',
             'num_opt': mask.sum().item()})

        return loss.item()
class LBFGSNeuralAdjoint(NeuralAdjoint):
    """Neural-adjoint inversion driven by ``torch.optim.LBFGS``.

    The data term is the mean squared residual over all entries. Entries
    flagged by ``nandiff`` contribute zero but still count in the mean.
    """

    def prepare_optimizer(self, xhat, lr):
        """Create an L-BFGS optimiser over the single tensor ``xhat``.

        The inherited ``prepare_optimizer`` builds plain SGD. Overriding it
        here makes ``step``, which already follows the closure protocol that
        L-BFGS requires, actually run L-BFGS.
        """
        self.optim = torch.optim.LBFGS([xhat], lr=lr)

    def step(self, y, xhat):
        """Run one ``optim.step`` with a loss closure and return its loss as a float.

        L-BFGS may evaluate the closure several times per step. The progress
        bar is updated on every evaluation.
        """
        def closure():
            bs = y.size(0)
            self.optim.zero_grad()
            yhat = self.net(xhat)
            if self.probs:
                # with probs the network output is indexed and element 0 used
                yhat = yhat[0]

            diff, nanmask = self.nandiff(yhat.reshape(bs, -1), y)
            # Flagged entries contribute zero but still count in the mean.
            obs_loss = torch.pow(diff.masked_fill(nanmask, 0.0), 2).mean(dim=-1)

            if self.early_cutoff is not None:
                # keep optimising only the samples that have not converged yet
                mask = obs_loss > self.early_cutoff
            else:
                mask = torch.ones_like(obs_loss, dtype=torch.bool)

            yhat_clean = torch.nan_to_num(yhat)
            boundary_loss = (
                nn.functional.relu(yhat_clean - self.bound_loss_upper)
                + nn.functional.relu(self.bound_loss_lower - yhat_clean)
            ).reshape(bs, -1).sum(-1)  # sum over all non-batch dims

            reg_loss = self.regularizer(xhat)
            grad_reg_loss = self.grad_regularizer(xhat)
            loss = obs_loss + self.bound_loss_lambda * boundary_loss + reg_loss + grad_reg_loss
            loss = (loss * mask).sum()
            loss.backward()
            self.pbar.set_postfix(
                {'obs_loss': f'{obs_loss.sum().item():.2e}',
                 'boundary_loss': f'{boundary_loss.sum().item():.2e}',
                 f'{self.regularizer_type}': f'{reg_loss.sum().item():.2e}',
                 f'grad{self.grad_regularizer_type}': f'{grad_reg_loss.sum().item():.2e}',
                 'num_opt': mask.sum().item()})
            return loss

        loss = self.optim.step(closure)
        return loss.item()
        


class SGDDecoderNeuralAdjoint(NeuralAdjoint):
    """Neural-adjoint inversion in the latent space of ``self.net``.

    A latent code ``zhat`` of size ``self.net.embed_dim`` is optimised.
    ``self.net.na_forward(zhat, decode=True)`` supplies both the predicted
    observations and the decoded estimate.
    """

    def forward(
        self,
        y,
        zinit=None,
        lr=0.001,
        device=1,
    ):
        """Optimise ``zhat`` and return the decoded estimate.

        Args:
            y: observations with batch size ``bs``; moved to ``device`` and
                flattened to ``(bs, -1)``.
            zinit: optional initial latent code. Defaults to zeros of shape
                ``(bs, self.net.embed_dim)``.
            lr: learning rate passed to :meth:`prepare_optimizer`.
            device: device for ``y``, ``zhat`` and ``self.net``.

        Returns:
            ``(xhat, errors)``: ``self.net.net.decode(zhat)`` reshaped to
            ``(bs, 11, 231)`` on CPU, and the per-step losses.
        """
        bs = y.size(0)
        with torch.no_grad():
            if zinit is None:
                zhat = torch.zeros(bs, self.net.embed_dim).to(device)
            else:
                zhat = zinit.clone().to(device)
            assert zhat.size(0) == bs, f"bs is {zhat.size()}, should be {bs}"

            # zhat is the leaf tensor updated by the optimiser
            zhat.requires_grad_(True)
            y = y.to(device).view(bs, -1)
        self.net = self.net.eval().to(device)

        self.prepare_optimizer(zhat, lr)
        self.pbar = tqdm(range(self.num_epochs), leave=self.verbose, disable=not self.verbose)

        errors = [self.step(y, zhat) for _ in self.pbar]

        # The final decode is only for output; build no autograd graph.
        with torch.no_grad():
            xhat = self.net.net.decode(zhat)
        return xhat.reshape(bs, 11, 231).detach().cpu(), errors

    def step(self, y, zhat):
        """Take one optimiser step on ``zhat`` and return the total loss as a float.

        The per-sample loss is the sum of:
        - the mean ``|yhat - y|`` over entries not flagged by ``nandiff``;
        - ``bound_loss_lambda`` times the out-of-range penalty on ``yhat``;
        - the regularisers on the decoded ``xhat``.

        Only samples whose data loss exceeds ``early_cutoff`` contribute (all
        samples when it is None).
        """
        bs = zhat.size(0)
        self.optim.zero_grad()
        yhat, _, xhat = self.net.na_forward(zhat, decode=True)

        diff, nanmask = self.nandiff(yhat.reshape(bs, -1), y)
        # Mean |diff| over valid entries. The clamped count keeps rows with no
        # valid entries at 0 rather than NaN.
        valid_count = (~nanmask).sum(-1).clamp_min(1)
        obs_loss = torch.abs(diff).masked_fill(nanmask, 0.0).sum(-1) / valid_count

        if self.early_cutoff is not None:
            # keep optimising only the samples that have not converged yet
            mask = obs_loss > self.early_cutoff
        else:
            mask = torch.ones_like(obs_loss, dtype=torch.bool)

        yhat_clean = torch.nan_to_num(yhat)
        boundary_loss = (
            nn.functional.relu(yhat_clean - self.bound_loss_upper)
            + nn.functional.relu(self.bound_loss_lower - yhat_clean)
        ).reshape(bs, -1).sum(-1)  # sum over all non-batch dims

        # NOTE(review): the smoothness regulariser differences dim 1 and the
        # last dim; confirm na_forward's decoded xhat is (bs, H, W), not flat.
        reg_loss = self.regularizer(xhat)
        grad_reg_loss = self.grad_regularizer(xhat)
        loss = obs_loss + self.bound_loss_lambda * boundary_loss + reg_loss + grad_reg_loss
        loss = (loss * mask).sum()
        loss.backward()
        self.optim.step()

        self.pbar.set_postfix(
            {'obs_loss': f'{obs_loss.sum().item():.2e}',
             'boundary_loss': f'{boundary_loss.sum().item():.2e}',
             f'{self.regularizer_type}': f'{reg_loss.sum().item():.2e}',
             f'grad{self.grad_regularizer_type}': f'{grad_reg_loss.sum().item():.2e}',
             'num_opt': mask.sum().item()})

        return loss.item()