"""
Author: Anonymous
Code for uncertainty-aware self-training for few-label learning.
"""


import random
import logging as logger
import numpy as np
from tensorflow.keras import backend as K
from tensorflow.keras.initializers import RandomUniform
from tensorflow.keras.layers import Embedding, Input, LSTM, Bidirectional, TimeDistributed, Dropout, Dense, Conv1D, Lambda, Concatenate,\
    RepeatVector, Activation, Flatten, Permute, Add, concatenate, MaxPooling1D, GlobalMaxPooling1D
from tensorflow.keras.models import Model
from tensorflow.keras.regularizers import l2
from numpy.random import seed
from tensorflow.keras.regularizers import l1, l2
import tensorflow as tf
from tensorflow.keras.utils import multi_gpu_model

from bert.loader import StockBertConfig, map_stock_config_to_params, load_stock_weights

from bert import BertModelLayer

import bert
import os

def gelu(x):
    """Apply the Gaussian Error Linear Unit (GELU) activation.

    GELU is a smoother alternative to ReLU. This uses the tanh approximation
    from the original paper (https://arxiv.org/abs/1606.08415):

        gelu(x) = x * 0.5 * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))

    Args:
        x: float tensor to activate.

    Returns:
        ``x`` with GELU applied element-wise.
    """
    cubic_term = x + 0.044715 * tf.pow(x, 3)
    tanh_arg = np.sqrt(2 / np.pi) * cubic_term
    gate = 0.5 * (1.0 + tf.tanh(tanh_arg))
    return x * gate


def get_H_t(X, timestep):
    """Return the slice of ``X`` at index ``timestep`` along axis 1.

    For a rank-3 input laid out as ``(batch, time, features)`` this yields the
    ``(batch, features)`` vectors at one time step. ``timestep`` is an ordinary
    index and need not be 0; negative values count from the end of axis 1.
    Any object that supports NumPy-style ``[:, i, :]`` slicing works, e.g.
    NumPy arrays or TF tensors.

    Args:
        X: rank-3 array/tensor (presumably ``(batch, time, hidden)``).
        timestep: index along axis 1 to select.

    Returns:
        ``X[:, timestep, :]``.
    """
    return X[:, timestep, :]

def construct_bert(model_dir, timesteps, classes, dense_dropout=0.5, attention_dropout=0.3, hidden_dropout=0.3, adapter_size=8):
    """Build a BERT sequence classifier and load pre-trained BERT weights.

    Architecture: BERT -> position-0 slice -> Dropout -> Dense(768, tanh)
    -> Dropout -> Dense(classes, softmax).

    Args:
        model_dir: directory containing ``bert_config.json`` and the
            ``bert_model.ckpt`` checkpoint.
        timesteps: sequence length of each input.
        classes: number of output classes.
        dense_dropout: dropout rate used before and after the 768-unit
            tanh dense layer.
        attention_dropout: copied into ``bert_params.attention_dropout``.
        hidden_dropout: copied into ``bert_params.hidden_dropout``.
        adapter_size: currently ignored. It is not forwarded to the BERT
            params, and no adapter freezing is applied. It is kept only so
            existing callers keep working.

    Returns:
        An uncompiled ``Model`` with inputs ``[input_ids, token_type_ids]``.
        Each input is int32 with shape ``(batch, timesteps)``. The output is
        softmax probabilities. The position-0 slice uses ``0:1``, which keeps
        a length-1 sequence axis. Assuming BERT returns
        ``(batch, seq, hidden)``, the output is therefore
        ``(batch, 1, classes)``.
    """
    bert_ckpt_file = os.path.join(model_dir, "bert_model.ckpt")
    bert_config_file = os.path.join(model_dir, "bert_config.json")

    # Keep the file open only for the read; everything else works on `bc`.
    with tf.io.gfile.GFile(bert_config_file, "r") as reader:
        bc = StockBertConfig.from_json_string(reader.read())

    bert_params = map_stock_config_to_params(bc)
    # adapter_size is deliberately not applied here; see the docstring.
    bert_params.attention_dropout = attention_dropout
    bert_params.hidden_dropout = hidden_dropout
    bert_params.mask_zero = True
    # Named `bert_layer` so it does not shadow the module-level `bert` import.
    bert_layer = BertModelLayer.from_params(bert_params, name="bert")
    print("DEBUG ", str(bert_params.intermediate_activation), str(bert_params.intermediate_size), str(bert_params.mask_zero))

    input_ids = Input(shape=(timesteps,), dtype='int32', name="input_ids_1")
    token_type_ids = Input(shape=(timesteps,), dtype='int32', name="token_type_ids_1")

    dense = Dense(units=768, activation="tanh", name="dense")

    output = bert_layer([input_ids, token_type_ids])  # [batch_size, max_seq_len, hidden_size]
    print("bert shape", output.shape)

    # `0:1` (rather than `0`) keeps a length-1 sequence axis on the
    # position-0 vector; all downstream tensors carry that axis too.
    cls_out = Lambda(lambda seq: seq[:, 0:1, :])(output)
    cls_out = Dropout(dense_dropout)(cls_out)
    logits = dense(cls_out)
    logits = Dropout(dense_dropout)(logits)
    logits = Dense(units=classes, activation="softmax", name="output_1")(logits)

    model = Model(inputs=[input_ids, token_type_ids], outputs=logits)
    # NOTE(review): functional Models are normally built on construction, so
    # this call is presumably redundant. It is kept as-is to avoid surprises.
    model.build(input_shape=(None, timesteps))

    # Load the pre-trained BERT checkpoint into the (now built) BERT layer.
    load_stock_weights(bert_layer, bert_ckpt_file)
    return model