#!/usr/bin/env python

import argparse
import glob
import pandas as pd
import numpy as np
import os


def reindex_dict(id_list):
    """Return a dict mapping each original id to its position in ``id_list``.

    Positions start at 0 and follow the iteration order of ``id_list``.
    If an id occurs more than once, its last position is kept.
    """
    return {original: position for position, original in enumerate(id_list)}


def preprocess_movie_lens(path_to_data):
    """Load MovieLens ratings and keep only the positive (rating >= 4) pairs.

    Reads ``ratings.csv`` from the directory ``path_to_data`` (with or
    without a trailing separator), drops the ``timestamp`` column, discards
    ratings below 4, and renumbers the remaining movie ids and user ids to
    contiguous 0-based indices in ascending order of the original ids.

    Returns:
        A tuple ``(df, header_inf)`` where ``df`` has the columns
        ``['movieId', 'userId']`` (re-indexed, sorted by movie then user)
        and ``header_inf`` is ``(number of users, number of movies,
        number of kept ratings)``.
    """
    ratings = pd.read_csv(
        os.path.join(path_to_data, 'ratings.csv')).drop('timestamp', axis=1)

    # Take an explicit copy of the filtered rows so that the column
    # assignments below never write into a view of ``ratings``.
    liked = ratings.loc[ratings['rating'] >= 4, ['movieId', 'userId']].copy()

    # np.unique returns the sorted distinct ids plus, for every row, the
    # position of its id in that sorted array, i.e. the new 0-based index.
    movie_ids, movie_codes = np.unique(
        liked['movieId'].to_numpy(), return_inverse=True)
    user_ids, user_codes = np.unique(
        liked['userId'].to_numpy(), return_inverse=True)
    liked['movieId'] = movie_codes
    liked['userId'] = user_codes

    header_inf = (len(user_ids), len(movie_ids), len(liked))
    return liked.sort_values(['movieId', 'userId']), header_inf


def preprocess_netflix(path_to_data):
    """Load the Netflix ``combined*.txt`` files and keep positive pairs.

    Every file matching ``combined*.txt`` inside the directory
    ``path_to_data`` (with or without a trailing separator) is read in
    sorted filename order. A line without a comma (e.g. ``"1:"``) starts the
    next movie; any other line is ``"user_id,rating,date"`` and belongs to
    the current movie. Only ratings >= 4 are kept. Blank lines are ignored.

    Both movie and user ids are renumbered to contiguous 0-based indices:
    users in ascending order of their original id, movies in order of
    appearance (movies without any kept rating get no index), so every
    index is smaller than the matching count in ``header_inf``.

    Returns:
        A tuple ``(df, header_inf)`` where ``df`` has the columns
        ``['movieId', 'userId']`` (re-indexed, sorted by movie then user)
        and ``header_inf`` is ``(number of users, number of movies,
        number of kept ratings)``.
    """
    movie_id = -1
    user_ids = []
    movie_ids = []
    pattern = os.path.join(path_to_data, 'combined*.txt')
    for path in sorted(glob.glob(pattern)):
        with open(path) as f:
            # Stream the file line by line instead of loading it whole.
            for line in f:
                line = line.strip()
                if not line:
                    continue
                user, sep, rest = line.partition(',')
                if not sep:
                    # Header line "movie:" -> following rows are a new movie.
                    movie_id += 1
                    continue
                rating = rest.split(',', 1)[0]
                if float(rating) >= 4:
                    movie_ids.append(movie_id)
                    user_ids.append(int(user))

    # np.unique yields the sorted distinct ids and, per row, the position of
    # its id in that sorted array, which is the new contiguous index.
    unique_users, user_codes = np.unique(
        np.asarray(user_ids, dtype=np.int64), return_inverse=True)
    unique_movies, movie_codes = np.unique(
        np.asarray(movie_ids, dtype=np.int64), return_inverse=True)

    df = pd.DataFrame({'movieId': movie_codes, 'userId': user_codes})
    header_inf = (len(unique_users), len(unique_movies), len(df))

    return df.sort_values(['movieId', 'userId']), header_inf


def output_csv(df, path_to_output, header_inf):
    """Write the index pairs in ``df`` to ``B.txt`` inside ``path_to_output``.

    The directory is created if it does not exist; it may be given with or
    without a trailing separator. The first line of the file is
    ``"n d nnz"`` taken from ``header_inf``; each following line holds the
    two column values of one row of ``df`` separated by a single space.

    Args:
        df: DataFrame with exactly two columns; its index is not written.
        path_to_output: Directory that receives ``B.txt``.
        header_inf: Tuple ``(n, d, nnz)`` written as the header line.
    """
    os.makedirs(path_to_output, exist_ok=True)
    n, d, nnz = header_inf
    # Join instead of concatenating so the file always lands inside the
    # directory created above, even without a trailing separator.
    with open(os.path.join(path_to_output, 'B.txt'), 'w') as f:
        f.write(f'{n} {d} {nnz}\n')
        f.writelines(
            f'{m} {u}\n' for m, u in df.itertuples(index=False, name=None))


def main(argv=None):
    """Command-line entry point: preprocess a dataset and write ``B.txt``.

    Args:
        argv: Optional list of command-line arguments (without the program
            name). When ``None``, argparse reads ``sys.argv[1:]``.

    The positional ``dataset_name`` must be exactly ``movie_lens`` or
    ``netflix``; any other value makes argparse print a usage error and
    exit. The result is written to ``<output>/<dataset_name>/B.txt``.
    """
    dataset_names = ('movie_lens', 'netflix')

    parser = argparse.ArgumentParser()
    # ``choices`` rejects unknown names up front. This replaces a substring
    # test that also accepted fragments such as 'movie' or 'net'.
    parser.add_argument(
        'dataset_name', help='the name of the dataset: movie_lens or netflix',
        choices=dataset_names, type=str)
    parser.add_argument(
        'input', help='path to the directory where the raw data is located', type=str)
    parser.add_argument(
        '--output', help='path to the directory where the output .txt file is generated. default: Lazy_DPP/cpp/data/[dataset_name]', default='../data/', type=str)
    args = parser.parse_args(argv)
    dataset_name = args.dataset_name

    # Joining with '' appends a trailing separator only when one is missing
    # and, unlike indexing path[-1], does not fail on an empty string.
    path_to_data = os.path.join(args.input, '')
    path_to_output = os.path.join(args.output, dataset_name, '')

    if dataset_name == 'movie_lens':
        df, header_inf = preprocess_movie_lens(path_to_data)
    else:
        df, header_inf = preprocess_netflix(path_to_data)

    output_csv(df, path_to_output, header_inf)


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
