import torch
from torch import nn
import torch.nn.functional as F
from base_classes import BaseGNN
from model_configurations import set_block, set_function

from torch_geometric.nn import MLP, GINConv, global_add_pool
# Define the GNN model.
class GNN_LARGE(nn.Module):
  """Node-level model: linear encoder -> ODE block -> linear decoder.

  The forward pass applies, in order: input dropout, a linear encoder, an
  optional two-layer residual MLP, optional batch norm, the ODE block built
  from ``opt``, ReLU, an optional fully connected layer, dropout and a linear
  decoder.

  Keys read from ``opt``: ``'time'``, ``'hidden_dim'``, ``'use_mlp'``,
  ``'fc_out'``, ``'batch_norm'``, ``'input_dropout'``, ``'dropout'``,
  ``'function'`` and (only when ``'term'`` is in ``opt['function']``)
  ``'num_terms'``.
  """

  def __init__(self, in_channels, out_channels, opt, device=torch.device('cpu')):
    """Build the encoder, ODE block, decoder and the optional layers.

    Args:
      in_channels: Size of the last dimension of the input node features.
      out_channels: Size of the last dimension of the decoder output.
      opt: Option mapping; see the class docstring for the keys that are read.
        It is also forwarded to ``set_function`` / ``set_block``.
      device: Device the integration-time tensor and ODE block are moved to.
    """
    super().__init__()
    self.opt = opt
    self.f = set_function(opt)
    block = set_block(opt)
    # Integration interval handed to the ODE block: from 0 to opt['time'].
    time_tensor = torch.tensor([0, opt['time']]).to(device)
    self.regularization_fns = []

    self.odeblock = block(self.f, self.regularization_fns, opt, data=None, device=device, t=time_tensor).to(device)
    self.decoder = nn.Linear(opt['hidden_dim'], out_channels)
    self.encoder = nn.Linear(in_channels, opt['hidden_dim'])

    if opt['use_mlp']:
      self.m11 = nn.Linear(opt['hidden_dim'], opt['hidden_dim'])
      self.m12 = nn.Linear(opt['hidden_dim'], opt['hidden_dim'])
    if opt['fc_out']:
      self.fc = nn.Linear(opt['hidden_dim'], opt['hidden_dim'])
    if opt['batch_norm']:
      self.bn_in = torch.nn.BatchNorm1d(opt['hidden_dim'])
      # bn_out is not used by forward(); it is kept so the set of registered
      # parameters/buffers (and therefore saved checkpoints) is unchanged.
      self.bn_out = torch.nn.BatchNorm1d(opt['hidden_dim'])

  def _solve_ode(self, x, edge_index):
    """Set ``x`` as the ODE block's initial state and run the block on it.

    When the module is in training mode and ``self.odeblock.nreg > 0`` the
    block is expected to return ``(z, reg_states)``; ``reg_states`` is stored
    on ``self.reg_states`` and only ``z`` is returned.
    """
    self.odeblock.set_x0(x)
    if self.training and self.odeblock.nreg > 0:
      # NOTE(review): this call does not pass edge_index, unlike the call
      # below -- confirm the ODE block accepts that when regularising.
      z, self.reg_states = self.odeblock(x)
      return z
    return self.odeblock(x, edge_index, edge_weight=None)

  def forward(self, data):
    """Compute per-node outputs for ``data``.

    Args:
      data: Object whose ``graph`` mapping provides ``'node_feat'`` (input
        features) and ``'edge_index'`` (passed through to the ODE block).

    Returns:
      The decoder output for the activated ODE-block result.
    """
    x = data.graph['node_feat']
    edge_index = data.graph['edge_index']

    # Encode each node based on its feature.
    x = F.dropout(x, self.opt['input_dropout'], training=self.training)
    x = self.encoder(x)

    if self.opt['use_mlp']:
      # Two residual layers, each followed by dropout.
      x = F.dropout(x, self.opt['dropout'], training=self.training)
      x = F.dropout(x + self.m11(F.relu(x)), self.opt['dropout'], training=self.training)
      x = F.dropout(x + self.m12(F.relu(x)), self.opt['dropout'], training=self.training)
    # todo investigate if some input non-linearity solves the problem with smooth deformations identified in the ANODE paper

    if self.opt['batch_norm']:
      x = self.bn_in(x)

    # Solve the initial value problem of the ODE.
    function_name = self.opt['function']
    if 'graphcon' in function_name:
      # The state is the encoding duplicated along the last dimension; the
      # columns from hidden_dim onwards of the result are kept.
      z = self._solve_ode(torch.cat([x, x], dim=-1), edge_index)
      z = z[:, self.opt['hidden_dim']:]
    elif 'term' in function_name:
      # The state is num_terms copies of the encoding side by side (left
      # untouched when num_terms <= 1); the first hidden_dim columns of the
      # result are kept.
      copies = max(self.opt['num_terms'], 1)
      z = self._solve_ode(torch.cat([x] * copies, dim=1), edge_index)
      z = z[:, 0:self.opt['hidden_dim']]
    else:
      z = self._solve_ode(x, edge_index)

    # Activation.
    z = F.relu(z)

    if self.opt['fc_out']:
      z = self.fc(z)
      z = F.relu(z)

    # Dropout.
    z = F.dropout(z, self.opt['dropout'], training=self.training)

    # Decode each node embedding to get node label.
    return self.decoder(z)
