import os
import pandas as pd
from tape.analysis import get_config, get_name, get_parent_name


def valid_accuracy_curve(data):
    """Return the validation-accuracy values recorded in a parsed metrics file.

    Newer runs store the curve under ``'valid.Acc'``; older code used
    ``'valid_accuracy'``. The newer key is tried first.

    Args:
        data: The parsed ``metrics.json`` -- a DataFrame from ``pd.read_json``
            or any mapping that can be indexed as ``data[key]['values']``.

    Returns:
        The ``'values'`` entry stored under whichever accuracy key is present.

    Raises:
        KeyError: If neither key (or its ``'values'`` entry) is present.
    """
    try:
        return data['valid.Acc']['values']
    except KeyError:
        # Older code used a different key
        return data['valid_accuracy']['values']


def main():
    """Plot the validation-accuracy curve of every run for the chosen dataset.

    Each directory matching ``results/*<dataset>`` contributes one curve,
    labelled with a model name derived from that directory's config. Exits
    with status 1 if no directory matches.
    """
    # Imported lazily so that importing this module does not require
    # matplotlib or parse the command line.
    import argparse
    from glob import glob
    import matplotlib.pyplot as plt

    parser = argparse.ArgumentParser(description='Plot results for protein tasks')
    parser.add_argument('dataset', choices=['pfam31_whole', 'secondary_structure', 'paired_scope', 'transmembrane'])
    args = parser.parse_args()

    # pfam31_whole is treated as the only non-supervised task; runs of the
    # supervised tasks are labelled by their parent model's name instead.
    supervised = args.dataset != 'pfam31_whole'

    pattern = f'results/*{args.dataset}'
    results = sorted(glob(pattern))
    if not results:
        # Otherwise plt.legend([]) warns about missing handles and an empty
        # figure is shown with no indication of what went wrong.
        parser.exit(1, f'No result directories match {pattern!r}\n')

    legend = []
    for directory in results:
        config = get_config(directory)
        model_name = get_parent_name(config) if supervised else get_name(config)

        # NOTE(review): '1' looks like the first run's sub-directory -- confirm.
        data = pd.read_json(os.path.join(directory, '1', 'metrics.json'))

        # Resolve the curve before plotting so the KeyError fallback only
        # covers the key lookup, never an error raised inside matplotlib.
        curve = valid_accuracy_curve(data)

        legend.append(model_name)
        plt.plot(curve)

    plt.legend(legend)
    plt.show()


if __name__ == '__main__':
    main()
