import torch
from .BiSRNet_3L import BiSRNet_3L

def model_generator(method, pretrained_model_path=None):
    """Build a model by name and optionally load pretrained weights into it.

    Args:
        method: Name of the architecture to build. Only ``'bisrnet'`` is
            supported. It builds a ``BiSRNet_3L`` with 28 input/output
            channels, ``n_feat=28``, ``stage=1`` and ``num_blocks=[1, 1, 1]``,
            then moves it to the GPU via ``.cuda()``.
        pretrained_model_path: Optional path to a checkpoint passed to
            ``torch.load``. The loaded object is used as a state dict (its
            ``.items()`` are iterated). A leading ``'module.'`` is stripped
            from each key, and the result is loaded with ``strict=True``.

    Returns:
        The constructed model, with weights loaded if a path was given.

    Raises:
        ValueError: If ``method`` is not a known model name.
    """
    if method == 'bisrnet':
        model = BiSRNet_3L(in_channels=28, out_channels=28, n_feat=28, stage=1,
                           num_blocks=[1, 1, 1]).cuda()
    else:
        # Previously this only printed and fell through, so `model` was never
        # bound and the call crashed later with UnboundLocalError. Fail fast.
        raise ValueError(f'Method {method} is not defined !!!!')

    if pretrained_model_path is not None:
        print(f'load model from {pretrained_model_path}')
        checkpoint = torch.load(pretrained_model_path)
        model.load_state_dict(_strip_module_prefix(checkpoint), strict=True)
    return model


def _strip_module_prefix(state_dict, prefix='module.'):
    """Return a new dict whose keys have one leading ``prefix`` removed."""
    # A 'module.' key prefix is what nn.DataParallel/DDP-wrapped models
    # typically produce. Only the leading occurrence is stripped, because a
    # blanket str.replace would also corrupt inner names such as
    # 'block.submodule.weight'.
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }