import torch
import torch.nn as nn
import torch.nn.functional as F
## -----------------------------------------------------------------------------
## Network layers
## -----------------------------------------------------------------------------


def Conv(in_channels, out_channels, padding=1):
    """Build a 3x3 2D convolution layer.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        padding: Zero-padding added to each spatial side. The default of 1,
            together with ``nn.Conv2d``'s default stride of 1, keeps the
            spatial size of the input unchanged.

    Returns:
        A newly constructed ``nn.Conv2d`` module with a 3x3 kernel.
    """
    layer = nn.Conv2d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=3,
        padding=padding,
    )
    return layer


def SimpleConv(in_channels, out_channels):
    """Build a 1x1 (pointwise) 2D convolution layer without padding.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.

    Returns:
        A newly constructed ``nn.Conv2d`` module with a 1x1 kernel. It mixes
        channels and leaves the spatial size unchanged.
    """
    pointwise = nn.Conv2d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=1,
        padding=0,
    )
    return pointwise


def relu(x):
    """Apply ReLU to ``x`` in place and return the result.

    The input tensor is overwritten, so do not call this on a tensor whose
    pre-activation values are still needed elsewhere.
    """
    activated = F.relu(input=x, inplace=True)
    return activated

def tanh(x):
    """Return the element-wise hyperbolic tangent of ``x`` as a new tensor."""
    squashed = torch.tanh(input=x)
    return squashed

def pool(x):
    """Downsample ``x`` with 2x2 max pooling at stride 2.

    Each pooled spatial dimension is halved. Odd sizes round down because
    ``ceil_mode`` is left at its default.
    """
    window = 2
    return F.max_pool2d(x, kernel_size=window, stride=window)


def upsample(x):
    """Upsample ``x`` by a factor of 2 using nearest-neighbor interpolation.

    Each spatial dimension is doubled by replicating input values, so every
    input pixel of a 2D map becomes a 2x2 block in the output.
    """
    scaled = F.interpolate(input=x, scale_factor=2, mode="nearest")
    return scaled


def concat(a, b):
    """Concatenate ``a`` followed by ``b`` along dimension 1 (the channel axis).

    All other dimensions of ``a`` and ``b`` must match.
    """
    return torch.cat([a, b], dim=1)
