import torch.nn as nn

pm = "zeros"  # default ``padding_mode`` for every Conv1d built by cnn_encoder


def cnn_encoder(n_step, state_dim, param_dim, *, padding_mode=None):
    """Build a 1-D CNN that maps a single-channel sequence to ``param_dim`` outputs.

    Architecture: six ``Conv1d`` layers with channel widths
    1 -> 64 -> 128 -> 256 -> 256 -> 128 -> 128, an ``ELU`` between consecutive
    convolutions, then ``Flatten`` and one ``Linear`` projection.

    Every convolution uses kernel_size=5, padding=2, stride=1, so it preserves
    the sequence length.  The final ``Linear`` expects
    ``n_step * state_dim * 128`` features.  The input must therefore have
    shape ``(batch, 1, n_step * state_dim)``.  Presumably this is a flattened
    sequence of ``n_step`` states of size ``state_dim``; confirm with callers.

    Args:
        n_step: First factor of the expected input length.
        state_dim: Second factor of the expected input length.
        param_dim: Number of output features of the final ``Linear`` layer.
        padding_mode: ``padding_mode`` passed to every ``Conv1d``.  When
            ``None`` (the default), the module-level ``pm`` is read at call
            time, matching the previous behavior.

    Returns:
        An ``nn.Sequential`` whose children, by index, are:
        Conv1d at 0, 2, 4, 6, 8, 10; ELU at 1, 3, 5, 7, 9; Flatten at 11;
        Linear at 12.  The layout is unchanged from the original hand-written
        version, so existing ``state_dict`` keys still load.

    Raises:
        ValueError: If ``n_step``, ``state_dim`` or ``param_dim`` is < 1.
    """
    if padding_mode is None:
        padding_mode = pm
    if min(n_step, state_dim, param_dim) < 1:
        raise ValueError(
            "n_step, state_dim and param_dim must all be positive, got "
            f"{n_step}, {state_dim}, {param_dim}"
        )

    channels = (1, 64, 128, 256, 256, 128, 128)
    layers = []
    for i, (c_in, c_out) in enumerate(zip(channels, channels[1:])):
        if i:
            # ELU only *between* convolutions.
            # NOTE(review): there is no activation after the final conv,
            # before Flatten.  This is presumably intentional; confirm.
            layers.append(nn.ELU())
        # kernel 5 / padding 2 / stride 1 keeps the sequence length unchanged.
        layers.append(
            nn.Conv1d(
                c_in,
                c_out,
                kernel_size=5,
                padding=2,
                stride=1,
                padding_mode=padding_mode,
            )
        )
    layers.append(nn.Flatten())
    # Length is preserved through every conv, so the flattened width is
    # (input length) * (last conv's out channels).
    layers.append(nn.Linear(n_step * state_dim * channels[-1], param_dim))
    return nn.Sequential(*layers)
