import numpy as np
import tensorflow as tf
import argparse

import scipy.io.wavfile as wav

import time
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '2'
import sys
from collections import namedtuple
sys.path.append("../../new/DeepSpeech")
import DeepSpeech

from tf_logits import get_logits


# These are the tokens that we're allowed to use: the character at
# position i of this string is what main() prints for decoded class index i.
# The trailing '-' token is special: it corresponds to the epsilon
# value in CTC decoding, and cannot occur in the phrase.
toks = " abcdefghijklmnopqrstuvwxyz'-"



def _load_audio(path):
    """Read the audio file at *path* and return its samples as a NumPy array.

    The format is chosen from the (case-insensitive) text after the last
    '.' in *path*: 'wav' files are read with scipy, 'mp3' files are decoded
    with pydub.

    Raises:
        ValueError: if the extension is neither 'wav' nor 'mp3'.
    """
    extension = path.split(".")[-1].lower()
    if extension == 'mp3':
        # Imported lazily so pydub is only required when an mp3 is supplied.
        import pydub
        segment = pydub.AudioSegment.from_mp3(path)
        # Reinterpret the raw bytes as little-endian signed 16-bit samples.
        # NOTE(review): assumes the decoded stream is 16-bit mono PCM -- confirm.
        return np.frombuffer(segment.raw_data, dtype="<i2")
    if extension == 'wav':
        _, samples = wav.read(path)
        return samples
    raise ValueError("Unknown file format")


def main():
    """Transcribe an audio file, and each of its two halves, and print the text.

    Parses the ``--in`` command-line argument, loads the audio, builds the
    logits graph with ``get_logits``, restores the DeepSpeech 0.4.1
    checkpoint, and runs CTC beam-search decoding on three windows of the
    signal: the whole clip, its first half, and its second half.  Each
    window is zero-padded at the end back to the full clip length so that
    it fits the fixed-size input placeholder.

    Raises:
        ValueError: if the input file is neither a .wav nor an .mp3 file.
    """
    parser = argparse.ArgumentParser(description=None)
    parser.add_argument('--in', type=str, dest="input",
                        required=True,
                        help="Input audio file (.wav or .mp3), at 16KHz")
    args = parser.parse_args()
    # Drop our own arguments from sys.argv, keeping only the program name.
    # NOTE(review): presumably this stops flag parsing inside the imported
    # DeepSpeech code from seeing them -- confirm.
    del sys.argv[1:]

    # Load (and validate) the input before any TensorFlow state is created.
    audio = _load_audio(args.input)
    N = len(audio)

    with tf.Session() as sess:
        new_input = tf.placeholder(tf.float32, [1, N])
        lengths = tf.placeholder(tf.int32, [1])

        with tf.variable_scope("", reuse=tf.AUTO_REUSE):
            logits = get_logits(new_input, lengths)

        saver = tf.train.Saver()
        saver.restore(sess, "deepspeech-0.4.1-checkpoint/model.v0.4.1")

        decoded, _ = tf.nn.ctc_beam_search_decoder(logits, lengths, merge_repeated=False, beam_width=500)

        # Every window is padded back to N samples, so the length fed to the
        # decoder is the same for all of them.
        # NOTE(review): 320 is presumably the number of samples per logit
        # frame produced by get_logits -- confirm.
        length = (N - 1) // 320

        windows = [(0, N),
                   (0, N // 2),
                   (N // 2, N)]
        for start, end in windows:
            print('logits shape', logits.shape)
            # Copy the window to the front of a zero buffer of the full length.
            padded = np.zeros(N, dtype=audio.dtype)
            padded[:end - start] = audio[start:end]
            r = sess.run(decoded, {new_input: [padded],
                                   lengths: [length]})

            print("-"*80)
            print("-"*80)

            print("Classification:")
            print("".join([toks[x] for x in r[0].values]))
            print("-"*80)
            print("-"*80)
            
            
# Run the classification as soon as this file is executed.
# NOTE(review): there is no `if __name__ == "__main__":` guard, so importing
# this module also parses the command line and runs the model.
main()


