import jax
import jax.numpy as jnp
from models.relu_mlp import ReLU_MLP, VarPro_MLP
from models.cnn import CNN, VarPro_CNN

def init_model(model_params, x, key):
    """Build a model from a config, initialize its parameters and create an MSE loss.

    Args:
        model_params: Config mapping; only ``model_params['type']`` is read.
            Supported values: ``'relu-mlp'``, ``'varpro-mlp'``, ``'cnn'``,
            ``'varpro-cnn'``.
        x: Example input used for parameter initialization. MLP models are
            initialized on ``x`` as given. CNN models are initialized on a
            ones-array of shape ``(1, x.shape[0], x.shape[1], x.shape[2])``,
            so ``x`` is presumably a single unbatched 3-D sample
            (e.g. H, W, C) -- TODO confirm against callers.
        key: Random key forwarded to ``model.init``.

    Returns:
        Tuple ``(params, model, loss)``. ``loss(params, data_batch,
        data_labels)`` returns the mean of the squared difference between
        ``model.apply(params, data_batch)`` and ``data_labels``.

    Raises:
        KeyError: If ``model_params`` is a dict without a ``'type'`` entry.
        ValueError: If ``model_params['type']`` is not a supported model type.
    """
    model_type = model_params['type']
    if model_type == 'relu-mlp':
        model = ReLU_MLP()
        init_input = x
    elif model_type == 'varpro-mlp':
        model = VarPro_MLP()
        init_input = x
    elif model_type in ('cnn', 'varpro-cnn'):
        model = CNN() if model_type == 'cnn' else VarPro_CNN()
        # Prepend a batch axis of size 1. A ones-array is used instead of x
        # itself, presumably because only the input shape matters for init.
        init_input = jnp.ones((1, x.shape[0], x.shape[1], x.shape[2]))
    else:
        # Name the offending type so config typos are easy to spot.
        raise ValueError(
            f"Model type {model_type!r} is currently not implemented."
        )
    params = model.init(key, init_input)

    def loss(params, data_batch, data_labels):
        """Mean squared error of the model's predictions on ``data_batch``."""
        preds = model.apply(params, data_batch)
        return ((preds - data_labels) ** 2).mean()

    return params, model, loss

