import jax.numpy as jnp
from jax.lax import stop_gradient
import flax.linen as nn
import functools
from typing import Iterable, Union



# Number of residual blocks in each of the four ResNet stages, per architecture.
# ResNet-18/34 are assembled from BasicBlock, ResNet-50/101/152 from Bottleneck
# (see the factory functions further down).
LAYERS = dict(
    resnet18=[2, 2, 2, 2],
    resnet34=[3, 4, 6, 3],
    resnet50=[3, 4, 6, 3],
    resnet101=[3, 4, 23, 3],
    resnet152=[3, 8, 36, 3],
)


class BasicBlock(nn.Module):
    """Two-convolution residual block used by ResNet-18 and ResNet-34.

    Attributes:
        features: Number of output channels of both main-path convolutions
            (and of the projection shortcut, when present).
        kernel_size: Kernel size of both main-path convolutions. The fixed
            padding of 1 on each side assumes a 3x3 kernel.
        downsample: If True, the first convolution uses stride 2 and the
            shortcut is replaced by a stride-2 1x1 convolution + batch norm.
        stride: Unused here; present only so ``BasicBlock`` accepts the same
            keyword arguments as ``Bottleneck``.
        kernel_init: Initializer for every convolution kernel.
        bias_init: Unused (all convolutions are created with ``use_bias=False``).
        block_name: Label for this block; only referenced by a commented-out
            activation-recording line.
        dtype: Computation dtype forwarded to every submodule.
    """
    features: int
    kernel_size: Union[int, Iterable[int]]=(3, 3)
    downsample: bool=False
    stride: bool=True  # only here for compatibility with Bottleneck
    kernel_init: functools.partial=nn.initializers.lecun_normal()
    bias_init: functools.partial=nn.initializers.zeros
    block_name: str=None
    dtype: str='float32'

    @nn.compact
    def __call__(self, x, act, train=True):
        """Apply conv-BN-ReLU, conv-BN, add the shortcut, then ReLU.

        Args:
            x: Input feature map. Flax ``nn.Conv`` treats the last axis as
                channels (presumably NHWC here -- confirm with the caller).
            act: Activation dict supplied by ``ResNet``; not used by this block.
            train: If True, batch norm uses batch statistics and updates its
                running averages; if False, it uses the stored averages.

        Returns:
            The block output with ``features`` channels.
        """
        # Flax's BatchNorm ``momentum`` weights the *old* running average
        # (ra = momentum * ra + (1 - momentum) * batch_stat), the inverse of
        # PyTorch's convention. 0.9 here reproduces PyTorch's momentum=0.1.
        # Passing 0.1 directly would make the running statistics follow
        # almost only the most recent batch.
        bn = functools.partial(batch_norm,
                               train=train,
                               epsilon=1e-05,
                               momentum=0.9,
                               params=None,
                               dtype=self.dtype)

        residual = x

        # Submodules are created in a fixed order (Conv, BN, Conv, BN and,
        # when downsampling, Conv, BN); Flax derives parameter names from
        # that order, so it must not be rearranged.
        x = nn.Conv(features=self.features,
                    kernel_size=self.kernel_size,
                    strides=(2, 2) if self.downsample else (1, 1),
                    padding=((1, 1), (1, 1)),
                    kernel_init=self.kernel_init,
                    use_bias=False,
                    dtype=self.dtype)(x)
        x = nn.relu(bn(x))

        x = nn.Conv(features=self.features,
                    kernel_size=self.kernel_size,
                    strides=(1, 1),
                    padding=((1, 1), (1, 1)),
                    kernel_init=self.kernel_init,
                    use_bias=False,
                    dtype=self.dtype)(x)
        x = bn(x)

        if self.downsample:
            # Project the shortcut so its resolution and channel count match
            # the strided main path.
            residual = nn.Conv(features=self.features,
                               kernel_size=(1, 1),
                               strides=(2, 2),
                               kernel_init=self.kernel_init,
                               use_bias=False,
                               dtype=self.dtype)(residual)
            residual = bn(residual)

        # act[self.block_name] = x
        return nn.relu(x + residual)


class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 bottleneck residual block used by ResNet-50/101/152.

    The block outputs ``features * expansion`` channels.

    Attributes:
        features: Width of the two inner convolutions.
        kernel_size: Unused; the middle convolution is always 3x3. Kept so
            both block types accept the same keyword arguments.
        downsample: If True, the shortcut is projected with a 1x1 convolution
            + batch norm to ``features * expansion`` channels.
        stride: Only consulted when ``downsample`` is True. It selects a
            spatial stride of 2 (True) or 1 (False) for the 3x3 convolution
            and for the shortcut projection.
        kernel_init: Initializer for every convolution kernel.
        bias_init: Unused (all convolutions are created with ``use_bias=False``).
        block_name: Label for this block; only referenced by a commented-out
            activation-recording line.
        expansion: Output-channel multiplier of the final 1x1 convolution.
        dtype: Computation dtype forwarded to every submodule.
    """
    features: int
    kernel_size: Union[int, Iterable[int]]=(3, 3)
    downsample: bool=False
    stride: bool=True
    kernel_init: functools.partial=nn.initializers.lecun_normal()
    bias_init: functools.partial=nn.initializers.zeros
    block_name: str=None
    expansion: int=4
    dtype: str='float32'

    @nn.compact
    def __call__(self, x, act, train=True):
        """Apply the 1x1/3x3/1x1 bottleneck, add the shortcut, then ReLU.

        Args:
            x: Input feature map. Flax ``nn.Conv`` treats the last axis as
                channels (presumably NHWC here -- confirm with the caller).
            act: Activation dict supplied by ``ResNet``; not used by this block.
            train: If True, batch norm uses batch statistics and updates its
                running averages; if False, it uses the stored averages.

        Returns:
            The block output with ``features * expansion`` channels.
        """
        # Flax's BatchNorm ``momentum`` weights the *old* running average
        # (ra = momentum * ra + (1 - momentum) * batch_stat), the inverse of
        # PyTorch's convention. 0.9 here reproduces PyTorch's momentum=0.1.
        bn = functools.partial(batch_norm,
                               train=train,
                               epsilon=1e-05,
                               momentum=0.9,
                               params=None,
                               dtype=self.dtype)

        # The same stride is used by the 3x3 convolution and the projection,
        # so both paths end up with the same spatial size.
        strides = (2, 2) if self.downsample and self.stride else (1, 1)
        out_features = self.features * self.expansion

        residual = x

        # Submodule creation order (Conv, BN, Conv, BN, Conv, BN, [Conv, BN])
        # determines Flax's auto-generated parameter names; keep it fixed.
        x = nn.Conv(features=self.features,
                    kernel_size=(1, 1),
                    strides=(1, 1),
                    kernel_init=self.kernel_init,
                    use_bias=False,
                    dtype=self.dtype)(x)
        x = nn.relu(bn(x))

        x = nn.Conv(features=self.features,
                    kernel_size=(3, 3),
                    strides=strides,
                    padding=((1, 1), (1, 1)),
                    kernel_init=self.kernel_init,
                    use_bias=False,
                    dtype=self.dtype)(x)
        x = nn.relu(bn(x))

        x = nn.Conv(features=out_features,
                    kernel_size=(1, 1),
                    strides=(1, 1),
                    kernel_init=self.kernel_init,
                    use_bias=False,
                    dtype=self.dtype)(x)
        x = bn(x)

        if self.downsample:
            residual = nn.Conv(features=out_features,
                               kernel_size=(1, 1),
                               strides=strides,
                               kernel_init=self.kernel_init,
                               use_bias=False,
                               dtype=self.dtype)(residual)
            residual = bn(residual)

        # act[self.block_name] = x
        return nn.relu(x + residual)

    
    
    

class ResNet(nn.Module):
    """ResNet backbone (stem + four residual stages) without a classifier head.

    Attributes:
        architecture: Key into ``LAYERS`` giving the number of blocks per stage.
        block: Residual block class (``BasicBlock`` or ``Bottleneck``).
        kernel_init: Initializer for the stem convolution and, forwarded via
            each block's ``kernel_init`` field, for every block convolution.
        bias_init: Stored but not used by this module.
        dtype: Computation dtype forwarded to every submodule.
        greedy: If True, the input to every residual block is recorded in the
            returned dict and gradients are stopped at that input.
        dataset: Selects the stem. 'cifar10', 'imagenet32', 'cifar100' and
            'tinyimagenet' get a 3x3 stride-1 convolution; any other value
            gets a 7x7 stride-2 convolution. The stem max-pool is skipped only
            for 'cifar10', 'imagenet32' and 'cifar100'.
    """
    architecture: str='resnet50'
    block: nn.Module=BasicBlock
    kernel_init: functools.partial=nn.initializers.lecun_normal()
    bias_init: functools.partial=nn.initializers.zeros
    dtype: str='float32'
    greedy: bool=False
    dataset: str='cifar10'

    def setup(self):
        # NOTE(review): ``param_dict`` is not read anywhere in this class;
        # presumably a leftover hook for loading pre-trained weights -- confirm.
        self.param_dict = None

    @nn.compact
    def __call__(self, x, train=True):
        """Run the backbone.

        Args:
            x: Input images. Flax ``nn.Conv`` treats the last axis as channels.
            train: If True, batch norm uses batch statistics and updates its
                running averages; if False, it uses the stored averages.

        Returns:
            dict: ``'out'`` holds the final feature map. When ``greedy`` is
            True it also holds ``'block{stage}_{i}'`` (stage 1-4, ``i``
            0-based within the stage) with the input of each residual block,
            taken before ``stop_gradient``.
        """
        act = {}

        # Stem: small-image datasets keep full resolution with a 3x3 conv.
        if self.dataset in ('cifar10', 'imagenet32', 'cifar100', 'tinyimagenet'):
            stem_kernel, stem_strides, stem_pad = (3, 3), (1, 1), 1
        else:
            stem_kernel, stem_strides, stem_pad = (7, 7), (2, 2), 3
        x = nn.Conv(features=64,
                    kernel_size=stem_kernel,
                    kernel_init=self.kernel_init,
                    strides=stem_strides,
                    padding=((stem_pad, stem_pad), (stem_pad, stem_pad)),
                    use_bias=False,
                    dtype=self.dtype)(x)

        # act['conv1'] = x

        # Flax's BatchNorm ``momentum`` weights the *old* running average
        # (ra = momentum * ra + (1 - momentum) * batch_stat); 0.9 reproduces
        # PyTorch's momentum=0.1.
        x = batch_norm(x,
                       train=train,
                       epsilon=1e-05,
                       momentum=0.9,
                       params=None,
                       dtype=self.dtype)
        x = nn.relu(x)

        # NOTE(review): 'tinyimagenet' uses the 3x3 stem above but, unlike the
        # other small-image datasets, still goes through this max-pool.
        # Possibly deliberate for its larger images -- confirm.
        if self.dataset not in ('cifar10', 'imagenet32', 'cifar100'):
            x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2), padding=((1, 1), (1, 1)))

        is_bottleneck = self.block.__name__ == 'Bottleneck'
        depths = LAYERS[self.architecture]

        # Blocks are created in stage order, so Flax's auto-generated block
        # names (e.g. BasicBlock_0, BasicBlock_1, ...) follow the same order.
        for stage, features in enumerate((64, 128, 256, 512)):
            for i in range(depths[stage]):
                name = f'block{stage + 1}_{i}'

                if self.greedy:
                    act[name] = x
                    x = stop_gradient(x)

                if stage == 0:
                    # Stage 1 keeps the spatial resolution. A Bottleneck still
                    # needs a (stride-1) projection in its first block to go
                    # from 64 to 64 * expansion channels; BasicBlock needs none.
                    stage_kwargs = dict(downsample=i == 0 and is_bottleneck,
                                        stride=i != 0)
                else:
                    # The first block of stages 2-4 uses stride 2 and a
                    # projected shortcut.
                    stage_kwargs = dict(downsample=i == 0)

                x = self.block(features=features,
                               kernel_size=(3, 3),
                               kernel_init=self.kernel_init,
                               block_name=name,
                               dtype=self.dtype,
                               **stage_kwargs)(x, act, train)

        act['out'] = x
        return act
    

    
    
    
    
    
    

def ResNet18(kernel_init=nn.initializers.lecun_normal(),
             bias_init=nn.initializers.zeros,
             dtype='float32',
             greedy=False,
             dataset='cifar10'):
    """Construct a ResNet-18 backbone.

    Uses ``BasicBlock`` units with the stage depths in ``LAYERS['resnet18']``.
    Every argument is forwarded unchanged to the ``ResNet`` field of the same
    name.

    Returns:
        A ``ResNet`` module instance (not yet initialized).
    """
    shared = dict(kernel_init=kernel_init,
                  bias_init=bias_init,
                  dtype=dtype,
                  greedy=greedy,
                  dataset=dataset)
    return ResNet(architecture='resnet18', block=BasicBlock, **shared)




def ResNet34(kernel_init=nn.initializers.lecun_normal(),
             bias_init=nn.initializers.zeros,
             dtype='float32',
             greedy=False,
             dataset='cifar10'):
    """Construct a ResNet-34 backbone.

    Uses ``BasicBlock`` units with the stage depths in ``LAYERS['resnet34']``.
    Every argument is forwarded unchanged to the ``ResNet`` field of the same
    name.

    Returns:
        A ``ResNet`` module instance (not yet initialized).
    """
    shared = dict(kernel_init=kernel_init,
                  bias_init=bias_init,
                  dtype=dtype,
                  greedy=greedy,
                  dataset=dataset)
    return ResNet(architecture='resnet34', block=BasicBlock, **shared)




def ResNet50(kernel_init=nn.initializers.lecun_normal(),
             bias_init=nn.initializers.zeros,
             dtype='float32',
             greedy=False,
             dataset='cifar10'):
    """Construct a ResNet-50 backbone.

    Uses ``Bottleneck`` units with the stage depths in ``LAYERS['resnet50']``.
    Every argument is forwarded unchanged to the ``ResNet`` field of the same
    name.

    Returns:
        A ``ResNet`` module instance (not yet initialized).
    """
    shared = dict(kernel_init=kernel_init,
                  bias_init=bias_init,
                  dtype=dtype,
                  greedy=greedy,
                  dataset=dataset)
    return ResNet(architecture='resnet50', block=Bottleneck, **shared)




def ResNet101(kernel_init=nn.initializers.lecun_normal(),
              bias_init=nn.initializers.zeros,
              dtype='float32',
              greedy=False,
              dataset='cifar10'):
    """Construct a ResNet-101 backbone.

    Uses ``Bottleneck`` units with the stage depths in ``LAYERS['resnet101']``.
    Every argument is forwarded unchanged to the ``ResNet`` field of the same
    name.

    Returns:
        A ``ResNet`` module instance (not yet initialized).
    """
    shared = dict(kernel_init=kernel_init,
                  bias_init=bias_init,
                  dtype=dtype,
                  greedy=greedy,
                  dataset=dataset)
    return ResNet(architecture='resnet101', block=Bottleneck, **shared)




def ResNet152(kernel_init=nn.initializers.lecun_normal(),
              bias_init=nn.initializers.zeros,
              dtype='float32',
              greedy=False,
              dataset='cifar10'):
    """Construct a ResNet-152 backbone.

    Uses ``Bottleneck`` units with the stage depths in ``LAYERS['resnet152']``.
    Every argument is forwarded unchanged to the ``ResNet`` field of the same
    name.

    Returns:
        A ``ResNet`` module instance (not yet initialized).
    """
    shared = dict(kernel_init=kernel_init,
                  bias_init=bias_init,
                  dtype=dtype,
                  greedy=greedy,
                  dataset=dataset)
    return ResNet(architecture='resnet152', block=Bottleneck, **shared)




def batch_norm(x, train, epsilon=1e-05, momentum=0.99, params=None, dtype='float32'):
    """Create an ``nn.BatchNorm`` submodule and apply it to ``x``.

    Must be called from inside a Flax ``@nn.compact`` method: the BatchNorm
    is constructed here and attaches to the module currently being called.

    Args:
        x: Input array; Flax's BatchNorm normalizes features on the last axis.
        train: If True, normalize with batch statistics and update the running
            averages; if False, normalize with the stored running averages.
        epsilon: Small constant added to the variance for numerical stability.
        momentum: Flax convention: running = momentum * running +
            (1 - momentum) * batch_value. PyTorch's ``momentum=0.1``
            therefore corresponds to ``0.9`` here.
        params: Must be None (see Raises).
        dtype: Computation dtype of the BatchNorm.

    Returns:
        The normalized array.

    Raises:
        NotImplementedError: If ``params`` is not None.
            ``flax.linen.BatchNorm`` has no ``mean_init``/``var_init``
            arguments, so pre-trained running statistics cannot be injected
            through its constructor. Load ``scale``/``bias`` into the
            'params' collection and ``mean``/``var`` into the 'batch_stats'
            collection of the variables dict instead.
    """
    if params is not None:
        raise NotImplementedError(
            "batch_norm(params=...) is not supported: flax.linen.BatchNorm has "
            "no mean_init/var_init arguments. Load pre-trained scale/bias into "
            "the 'params' collection and mean/var into the 'batch_stats' "
            "collection of the variables dict instead.")
    return nn.BatchNorm(epsilon=epsilon,
                        momentum=momentum,
                        use_running_average=not train,
                        dtype=dtype)(x)