from typing import Tuple, List, Any, Sequence
from sklearn.utils import shuffle
import os
import copy

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import tensorflow as tf
import numpy as np
import time
import pickle

import json
import queue
import threading

from collections import namedtuple, defaultdict
import gc
import random
import math
import zerorpc
from HypothesisFilter import HypothesisFilter


def glorot_init(shape):
    """Return a float32 array of ``shape`` drawn with Glorot/Xavier uniform bounds.

    The bound is ``sqrt(6 / (fan_in + fan_out))`` where the two fans are the
    last two entries of ``shape``; samples come from ``np.random.uniform``.
    """
    fan_in, fan_out = shape[-2], shape[-1]
    limit = np.sqrt(6.0 / (fan_in + fan_out))
    samples = np.random.uniform(low=-limit, high=limit, size=shape)
    return samples.astype(np.float32)

def sort_hyps(hyps):
    """Return a new list of hypotheses ordered by ``avg_log_prob``, best first.

    The input is not modified; ties keep their original relative order.
    """
    ranked = list(hyps)
    ranked.sort(key=lambda hyp: hyp.avg_log_prob, reverse=True)
    return ranked

class ThreadedIterator:
    """Consume ``original_iterator`` on a background thread through a bounded queue.

    The worker thread starts immediately on construction and pushes every
    element into a queue holding at most ``max_queue_size`` items; ``None`` is
    reserved as the end-of-stream sentinel, so source elements must not be
    ``None``.

    If the source iterator raises (or yields ``None``), the sentinel is still
    enqueued so the consumer never blocks forever, and the error is re-raised
    in the consuming thread once the buffered elements have been yielded.
    """

    def __init__(self, original_iterator, max_queue_size: int = 2):
        self.__queue = queue.Queue(maxsize=max_queue_size)
        # Set before the thread starts so the worker can never race the initialisation.
        self.__error = None
        self.__thread = threading.Thread(target=lambda: self.worker(original_iterator))
        self.__thread.start()

    def worker(self, original_iterator):
        """Producer loop: forward every element, then always enqueue the sentinel."""
        try:
            for element in original_iterator:
                # Explicit check instead of ``assert`` so it survives ``python -O``.
                if element is None:
                    raise ValueError('By convention, iterator elements must not be None')
                self.__queue.put(element, block=True)
        except Exception as exc:
            # Remember the failure; it is surfaced to the consumer in __iter__.
            self.__error = exc
        finally:
            # Without this the consumer would wait on an empty queue forever
            # whenever the producer dies early.
            self.__queue.put(None, block=True)

    def __iter__(self):
        next_element = self.__queue.get(block=True)
        while next_element is not None:
            yield next_element
            next_element = self.__queue.get(block=True)
        self.__thread.join()
        if self.__error is not None:
            raise self.__error


# --- Module-level configuration -------------------------------------------
# The values set to 0 below are placeholders that must be replaced with the
# real numbers for the vocabularies in use before the model can be built.

# Size of the API vocabulary: row count of the decoder embedding and width of
# the softmax output layer. Replace 0 with the real size.
SOFTMAX_NUMBER = 0
# Size of the whole (encoder-side) vocabulary: row count of the encoder
# embedding. Replace 0 with the real size.
WHOLE_VOCABULARY_SIZE = 0
# Fixed decoder sequence length used for decoder placeholders and batching.
MAX_DEC_LEN = 20
# Fixed encoder sequence length used for encoder placeholders and batching.
MAX_ENC_LEN = 60
MIN_DEC_LEN = 1
# Token id prepended to every decoder input sequence. Replace 0 with the real id.
SOS_ID = 0
# Token id appended to every target sequence. Replace 0 with the real id.
EOS_ID = 0
# Padding id for decoder inputs / targets. Replace 0 with the real id.
DEC_PAD = 0
# Padding id for encoder inputs. Replace 0 with the real id.
ENC_PAD = 0
# Leading (batch) dimension of the `encoder_output` placeholder.
BEAM_SIZE = 10
# Number of top-scoring entries kept per decoder position by tf.nn.top_k.
TOPK_SIZE = 20
# One token per line; the line index becomes the token id.
APIVOCABULARY_PATH = 'APIVocabulary.txt'
WHOLEVOCABULARY_PATH = 'WholeVocabulary.txt'
# Prefix prepended to the checkpoint name given via `args.restore`.
MODEL_PATH = '/model/'
class TransformerModel(object):
    @classmethod
    def default_params(cls):
        return{
            'num_epochs': 50,
            'patience': 5,
            'learning_rate': 0.005,
            'clamp_gradient_norm': 1.0,
            'layer_dropout_keep_prob': 0.75,
            'embed_size': 300,
            'encoder_hidden_size': 300,
            'decoder_hidden_size': 600,
            'softmax_size': 600,
            'batch_size': 256,
            'momentum': 0.9,
            'encoder_num_layers': 3,
            'decoder_num_layers': 1,  # todo
            'encoder_num_head': 6,
            'decoder_num_head': 6,
            'encoder_num_block': 3,
            'decoder_num_block': 1,
            'task_ids': [0],
            'random_seed': 0
        }
    
    def __init__(self, args, training_file_count, valid_file_count):
        """Build the graph and session, load or initialise weights, read both vocabularies.

        Args:
            args: Options object; ``data_dir``, ``config_file``, ``config`` and
                ``restore`` are read from it.
            training_file_count: Stored as-is; ``run_epoch`` uses it as the
                number of training files.
            valid_file_count: Stored as-is; ``run_epoch`` uses it as the
                number of validation files.
        """
        self.args = args
        self.training_file_count = training_file_count
        self.valid_file_count = valid_file_count
        self.train_data = None
        self.valid_data = None

        requested_data_dir = args.data_dir
        self.data_dir = '' if requested_data_dir is None else requested_data_dir

        # Run identifier: wall-clock timestamp plus the process id.
        timestamp = time.strftime("%Y-%m-%d-%H-%M-%S")
        self.run_id = "_".join([timestamp, str(os.getpid())])
        log_dir = '.'
        self.log_file = os.path.join(log_dir, "%s_log.pickle" % self.run_id)
        self.best_model_file = os.path.join(log_dir, "%s_model_best.pickle" % self.run_id)
        self.best_model_checkpoint = "model_best-%s" % self.run_id

        # Defaults first, then the JSON file, then the inline JSON string;
        # later sources override earlier ones.
        hyper_params = self.default_params()
        config_path = args.config_file
        if config_path is not None:
            with open(config_path, 'r') as config_handle:
                hyper_params.update(json.load(config_handle))
        inline_config = args.config
        if inline_config is not None:
            hyper_params.update(json.loads(inline_config))
        self.params = hyper_params

        print("Run %s starting with following parameters:\n%s" % (self.run_id, json.dumps(self.params)))

        seed = hyper_params['random_seed']
        random.seed(seed)
        np.random.seed(seed)

        session_config = tf.ConfigProto()
        session_config.gpu_options.allow_growth = True
        session_config.log_device_placement = True
        self.graph = tf.Graph()
        self.sess = tf.Session(graph=self.graph, config=session_config)
        with self.graph.as_default():
            tf.set_random_seed(hyper_params['random_seed'])
            self.placeholders = {}
            self.weights = {}
            self.ops = {}
            self.make_model()
            self.make_train_step()
            checkpoint_name = args.restore
            if checkpoint_name is None:
                self.initialize_model()
            else:
                self.restore_model2(MODEL_PATH + checkpoint_name)

        # Vocabulary files hold one token per line; the line index is the id.
        with open(APIVOCABULARY_PATH, 'r') as api_file:
            api_names = [line.strip() for line in api_file]
        self.api2idx = {api: idx for idx, api in enumerate(api_names)}
        self.idx2api = dict(enumerate(api_names))

        with open(WHOLEVOCABULARY_PATH, 'r') as vocab_file:
            words = [line.strip() for line in vocab_file]
        self.word2idx = {word: idx for idx, word in enumerate(words)}
        self.idx2word = dict(enumerate(words))
        print("On Serving...")
        
    def make_model(self)-> None:
        """Create placeholders, embedding/softmax variables, the per-task losses and the Saver.

        Populates ``self.placeholders``, ``self.weights`` and ``self.ops`` and
        sets ``self.encoder_output`` and ``self.saver``. Statement order is
        kept as-is: TensorFlow derives auto-generated variable names from
        creation order, and the Saver created at the end only covers variables
        that exist at that point.
        """
        # Token-id inputs with fixed sequence lengths MAX_ENC_LEN / MAX_DEC_LEN.
        self.placeholders['input_label'] = tf.placeholder(tf.int64, [None, MAX_ENC_LEN,], name='input_label')
        self.placeholders['input_seq_length'] = tf.placeholder(shape=(None,), dtype=tf.int32, name='input_seq_length')
        self.placeholders['output_seq_length'] = tf.placeholder(shape=(None,), dtype=tf.int32, name='output_seq_length')
        self.placeholders['decoder_input_label'] = tf.placeholder(tf.int64, [None,MAX_DEC_LEN,], name='decoder_input_label')
        # Masks have a trailing embed_size dimension because they are multiplied
        # element-wise with the embedded sequences in gated_regression.
        self.placeholders['decoder_input_mask'] = tf.placeholder(tf.float32, [None,MAX_DEC_LEN,self.params['embed_size']], name='decoder_input_mask')
        # Flat per-position weight applied to the loss below.
        self.placeholders['output_mask'] = tf.placeholder(tf.float32, [None], name='output_mask')
        self.placeholders['input_mask'] =  tf.placeholder(tf.float32, [None,MAX_ENC_LEN,self.params['embed_size']], name='input_mask')
        self.placeholders['layer_dropout_keep_prob'] = tf.placeholder(tf.float32, [], name='layer_dropout_keep_prob')
        # One row per task (row `internal_id` is selected below).
        self.placeholders['output_ground_truth'] = tf.placeholder(tf.int64, [None, None],name='output_ground_truth')
        self.placeholders['batch_data_count'] = tf.placeholder(tf.int64, [], name='batch_data_count')
        
        # Embedding tables, initialised uniformly in [-0.5, 0.5).
        self.weights['encoder_embedding'] = tf.Variable(dtype=tf.float32,initial_value=np.random.uniform(-0.5, 0.5, [WHOLE_VOCABULARY_SIZE, self.params['embed_size']]))
        self.weights['decoder_embedding'] = tf.Variable(dtype=tf.float32,initial_value=np.random.uniform(-0.5, 0.5, [SOFTMAX_NUMBER, self.params['embed_size']]))
        
        # NOTE: these two entries are lookup tensors, not placeholders, despite
        # being stored in self.placeholders.
        self.placeholders['input_emb'] = tf.nn.embedding_lookup(self.weights['encoder_embedding'],self.placeholders['input_label'])
        self.placeholders['decoder_input_emb'] = tf.nn.embedding_lookup(self.weights['decoder_embedding'],self.placeholders['decoder_input_label'])


        # Placeholder used as the keys of the decoder's encoder-decoder attention
        # in gated_regression; its leading dimension is fixed to BEAM_SIZE.
        self.encoder_output = tf.placeholder(tf.float32,[BEAM_SIZE, None, self.params['decoder_hidden_size']], name='encoder_output')
        
        self.ops['losses'] = []
        for (internal_id, task_id) in enumerate(self.params['task_ids']):
            with tf.variable_scope("out_layer_task%i" % task_id):
                with tf.variable_scope("softmax"):
                    self.weights['softmax_weights'] = tf.Variable(glorot_init([self.params['softmax_size'], SOFTMAX_NUMBER]))
                    self.weights['softmax_biases'] = tf.Variable(np.zeros([SOFTMAX_NUMBER]).astype(np.float32))
                    print('softmax weights: ', self.weights['softmax_weights'])
                # Logits with one row per decoder position (flattened in gated_regression).
                computed_values = self.gated_regression(self.weights['softmax_weights'],self.weights['softmax_biases'])
                groundtruth = self.placeholders['output_ground_truth'][internal_id, :]
                loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=computed_values, labels=groundtruth)
                label_weights = self.placeholders['output_mask']
                # Zero out the loss at positions whose mask entry is 0.
                union_tensor = loss * label_weights
                print("union_tensor: ", union_tensor)
                reshape_tensor = tf.reshape(union_tensor, [MAX_DEC_LEN, -1])
                print("reshape_tensor: ", reshape_tensor)
                # Sum over axis 0, then average: the total masked loss divided by
                # the size of the second dimension.
                task_loss = tf.reduce_mean(tf.reduce_sum(reshape_tensor, axis=0))
                # NOTE: the 'accuracy_task%i' op holds the task loss, not an accuracy.
                self.ops['accuracy_task%i' % task_id] = task_loss
                self.ops['losses'].append(task_loss)
        self.ops['loss'] = tf.reduce_sum(self.ops['losses'])
        self.saver = tf.train.Saver()
        
    def make_train_step(self):
        """Create the Nesterov-momentum training op with per-gradient norm clipping.

        The resulting op is stored in ``self.ops['train_step']``; local
        variables are initialised as the last step.
        """
        variables = self.sess.graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        global_step = tf.Variable(0)
        # A decay rate of 1.0 leaves the learning rate constant.
        decayed_lr = tf.train.exponential_decay(
            self.params['learning_rate'], global_step, 2000, 1.0, staircase=False)
        optimizer = tf.train.MomentumOptimizer(
            decayed_lr, momentum=self.params['momentum'], use_nesterov=True)
        gradient_pairs = optimizer.compute_gradients(self.ops['loss'], var_list=variables)
        # Variables without a gradient are passed through with their None gradient.
        clipped_pairs = [
            (gradient if gradient is None
             else tf.clip_by_norm(gradient, self.params['clamp_gradient_norm']),
             variable)
            for gradient, variable in gradient_pairs
        ]
        self.ops['train_step'] = optimizer.apply_gradients(clipped_pairs, global_step=global_step)
        self.sess.run(tf.local_variables_initializer())
    
    def gated_regression(self,softmax_weights,softmax_biases):
        """Build the encoder/decoder stacks and return the output logits.

        Args:
            softmax_weights: Weight matrix multiplied with the flattened decoder
                output (its first dimension must therefore match
                ``decoder_hidden_size``).
            softmax_biases: Bias vector added to the projected output.

        Returns:
            The logits tensor, one row per decoder position.

        Side effects: sets ``self.SOS_ID``, ``self.EOS_ID``,
        ``self.encoder_result`` (the projected encoder output) and
        ``self.output`` (top-``TOPK_SIZE`` values and indices of the logits),
        and overwrites ``self.placeholders['decoder_input_emb']``.
        """
        # NOTE(review): the four locals below are assigned but not used later
        # in this method.
        ENCODER_HIDDEN_SIZE = self.params['encoder_hidden_size']
        ENCODER_NUM_LAYERS = self.params['encoder_num_layers']
        DECODER_HIDDEN_SIZE = self.params['decoder_hidden_size']
        DECODER_NUM_LAYERS = self.params['decoder_num_layers']
        self.SOS_ID = tf.cast(SOS_ID, tf.int32)
        self.EOS_ID = tf.cast(EOS_ID, tf.int32)
        with tf.variable_scope("encoder", reuse=tf.AUTO_REUSE):
            enc = self.placeholders['input_emb']
            input_masks =  self.placeholders['input_mask']
            # Zero padded positions, scale by sqrt(embed_size), add positional
            # encodings, then mask again so padded positions stay zero.
            enc = enc * input_masks
            enc *= self.params['embed_size'] ** 0.5 
            enc += self.positional_encoding(enc, num_units=self.params['embed_size'], 
                                      zero_pad=False, 
                                      scale=False,
                                      scope="enc_pe")
            enc = enc * input_masks
            # The placeholder holds a keep probability (second argument of tf.nn.dropout).
            enc = tf.nn.dropout(enc,self.placeholders['layer_dropout_keep_prob'])
            
            print('input_masks',input_masks)
            for i in range(self.params['encoder_num_block']):
                with tf.variable_scope("num_blocks_{}".format(i), reuse=tf.AUTO_REUSE):
                    # NOTE(review): the keep-probability placeholder is passed as
                    # `dropout_rate`, but multihead_attention never reads that argument.
                    enc = self.multihead_attention(queries=enc,
                                              keys=enc,
                                              num_heads=self.params['encoder_num_head'],
                                              dropout_rate=self.placeholders['layer_dropout_keep_prob'],
                                              causality=False)
                    enc = self.feedforward(enc, num_units=[self.params['encoder_hidden_size'], self.params['encoder_hidden_size']])
        memory = enc
        # Project the encoder output to the decoder width.
        memory = tf.layers.dense(memory,self.params['decoder_hidden_size'], tf.tanh, name='encoder_dense', reuse=tf.AUTO_REUSE)
    
        with tf.variable_scope("decoder", reuse=tf.AUTO_REUSE):
            # Exposed so callers can fetch the encoder result from the session.
            self.encoder_result = memory
            decoder_input_masks =  self.placeholders['decoder_input_mask']
            self.placeholders['decoder_input_emb'] = tf.nn.embedding_lookup(self.weights['decoder_embedding'],self.placeholders['decoder_input_label'])
            dec = self.placeholders['decoder_input_emb'] 
            dec = dec * decoder_input_masks
            dec *= self.params['embed_size'] ** 0.5 

            dec += self.positional_encoding(dec, num_units=self.params['embed_size'], 
                                      zero_pad=False, 
                                      scale=False,
                                      scope="dec_pe")
            dec = tf.layers.dense(dec,self.params['decoder_hidden_size'], tf.tanh, name='decoder_dense', reuse=tf.AUTO_REUSE)

            for i in range(self.params['decoder_num_block']):
                with tf.variable_scope("num_blocks_{}".format(i), reuse=tf.AUTO_REUSE):
                    # Causal self-attention over the decoder inputs.
                    dec = self.multihead_attention(queries=dec,
                                              keys=dec,
                                              num_heads=self.params['decoder_num_head'],
                                              dropout_rate=self.placeholders['layer_dropout_keep_prob'],
                                              causality=True,
                                              scope="self_attention")

                    # Encoder-decoder attention: the keys come from the
                    # `self.encoder_output` placeholder, not from `memory`, so a
                    # value for that placeholder must be fed whenever the logits
                    # (or the loss) are evaluated.
                    dec = self.multihead_attention(queries=dec,
                                              keys=self.encoder_output,
                                              num_heads=self.params['decoder_num_head'],
                                              dropout_rate=self.placeholders['layer_dropout_keep_prob'],
                                              causality=False,
                                              scope="vanilla_attention")
                    dec = self.feedforward(dec, num_units=[self.params['decoder_hidden_size'], self.params['decoder_hidden_size']])
        # Flatten to one row per decoder position before the output projection.
        dec = tf.reshape(dec, [-1, self.params['decoder_hidden_size']])
        logits = tf.add(tf.matmul(dec, softmax_weights), softmax_biases)
        # top_k runs on raw logits, so `topk_probs` holds unnormalised scores.
        topk_probs, topk_ids = tf.nn.top_k(logits, TOPK_SIZE)
        self.output = topk_probs, topk_ids
        return logits
        
    def normalize(self, inputs,
              epsilon = 1e-8,
              scope="ln",
              reuse=tf.AUTO_REUSE):
        """Layer-normalise ``inputs`` over the last axis with a learned scale and shift.

        ``beta`` (zeros) and ``gamma`` (ones) are created with ``tf.Variable``,
        so a new pair of variables is created on every call.
        """
        with tf.variable_scope(scope, reuse=reuse):
            feature_shape = inputs.get_shape()[-1:]

            mean, variance = tf.nn.moments(inputs, [-1], keep_dims=True)
            # Creation order (beta, then gamma) is kept: it fixes the
            # auto-generated variable names.
            beta = tf.Variable(tf.zeros(feature_shape))
            gamma = tf.Variable(tf.ones(feature_shape))
            standardized = (inputs - mean) / ((variance + epsilon) ** (.5))
            return gamma * standardized + beta
    
    def positional_encoding(self, inputs,
                        num_units,
                        zero_pad=True,
                        scale=True,
                        scope="positional_encoding",
                        reuse=tf.AUTO_REUSE):
        """Return sinusoidal positional encodings for every position of ``inputs``.

        Args:
            inputs: Tensor whose first dimension is the batch and whose
                second-to-last dimension is the sequence length; the sequence
                length must be statically known.
            num_units: Width of the encoding vectors.
            zero_pad: If true, the encoding of position 0 is replaced by zeros.
            scale: If true, the result is multiplied by ``sqrt(num_units)``.
            scope: Variable scope name.
            reuse: Passed to ``tf.variable_scope``. The default mirrors the
                sibling layer methods; no variables are created here.

        Returns:
            A float32 tensor of shape ``[batch, sequence_length, num_units]``.
        """
        N, T = tf.shape(inputs)[0], inputs.get_shape().as_list()[-2]
        with tf.variable_scope(scope, reuse=reuse):
            # Position indices 0..T-1, repeated for every batch entry.
            position_ind = tf.tile(tf.expand_dims(tf.range(T), 0), [N, 1])
            position_enc = np.array([
                [pos / np.power(10000, 2.*i/num_units) for i in range(num_units)]
                for pos in range(T)])

            # Sine on even feature indices, cosine on odd ones.
            position_enc[:, 0::2] = np.sin(position_enc[:, 0::2])
            position_enc[:, 1::2] = np.cos(position_enc[:, 1::2])

            lookup_table = tf.convert_to_tensor(position_enc, tf.float32)

            if zero_pad:
                lookup_table = tf.concat((tf.zeros(shape=[1, num_units]),
                                          lookup_table[1:, :]), 0)
            outputs = tf.nn.embedding_lookup(lookup_table, position_ind)

            if scale:
                outputs = outputs * num_units**0.5

            return outputs
        
    def multihead_attention(self, queries, 
                        keys, 
                        num_units=None, 
                        num_heads=6, 
                        dropout_rate=1.0,
                        causality=False,
                        scope="multihead_attention", 
                        reuse=tf.AUTO_REUSE):
        """Multi-head scaled dot-product attention with a residual connection and layer norm.

        Args:
            queries: Rank-3 tensor (it is split along axis 2).
            keys: Rank-3 tensor; used for both the key and the value projection.
            num_units: Projection width; defaults to the last dimension of
                ``queries``. Must be divisible by ``num_heads``.
            num_heads: Number of attention heads.
            dropout_rate: Accepted for interface compatibility; not referenced
                anywhere in the body.
            causality: If true, each query position can only attend to key
                positions at or before it.
            scope: Variable scope name.
            reuse: Passed to ``tf.variable_scope``.

        Returns:
            A tensor with the shape of ``queries`` (the attention output is
            added to ``queries``, so ``num_units`` must match its last dimension).
        """
        with tf.variable_scope(scope, reuse=reuse):
            if num_units is None:
                num_units = queries.get_shape().as_list()[-1]

            # Linear projections with tanh activation; the three dense layers
            # are created in this order, which fixes their auto-generated names.
            Q = tf.layers.dense(queries, num_units, activation=tf.tanh) 
            K = tf.layers.dense(keys, num_units, activation=tf.tanh) 
            V = tf.layers.dense(keys, num_units, activation=tf.tanh) 

            # Split the feature axis into heads and stack the heads on axis 0.
            Q_ = tf.concat(tf.split(Q, num_heads, axis=2), axis=0) 
            K_ = tf.concat(tf.split(K, num_heads, axis=2), axis=0)
            V_ = tf.concat(tf.split(V, num_heads, axis=2), axis=0) 

            # Raw attention scores: Q_ times K_ transposed.
            outputs = tf.matmul(Q_, tf.transpose(K_, [0, 2, 1])) 

            # Scale by the square root of the per-head width.
            outputs = outputs / (K_.get_shape().as_list()[-1] ** 0.5)

            # Key masking: a key position counts as padding when its input vector
            # is all zeros; such positions get a very large negative score.
            key_masks = tf.sign(tf.reduce_sum(tf.abs(keys), axis=-1)) 
            key_masks = tf.tile(key_masks, [num_heads, 1]) 
            key_masks = tf.tile(tf.expand_dims(key_masks, 1), [1, tf.shape(queries)[1], 1])

            paddings = tf.ones_like(outputs)*(-2**32+1)
            outputs = tf.where(tf.equal(key_masks, 0), paddings, outputs) 

            if causality:
                # Lower-triangular mask: scores above the diagonal are replaced by
                # the same large negative value.
                diag_vals = tf.ones_like(outputs[0, :, :])
                tril = tf.linalg.LinearOperatorLowerTriangular(diag_vals).to_dense()
                masks = tf.tile(tf.expand_dims(tril, 0), [tf.shape(outputs)[0], 1, 1])

                paddings = tf.ones_like(masks)*(-2**32+1)
                outputs = tf.where(tf.equal(masks, 0), paddings, outputs) 

            outputs = tf.nn.softmax(outputs) 

            # Query masking: zero the attention rows of all-zero query vectors.
            query_masks = tf.sign(tf.reduce_sum(tf.abs(queries), axis=-1))
            query_masks = tf.tile(query_masks, [num_heads, 1]) 
            query_masks = tf.tile(tf.expand_dims(query_masks, -1), [1, 1, tf.shape(keys)[1]]) 
            outputs *= query_masks 
            # Weighted sum of the values.
            outputs = tf.matmul(outputs, V_)

            # Undo the head split: move heads back from axis 0 to the feature axis.
            outputs = tf.concat(tf.split(outputs, num_heads, axis=0), axis=2 ) 

            # Residual connection followed by layer normalisation.
            outputs += queries

            outputs = self.normalize(outputs) 

        return outputs
    
    def feedforward(self, inputs,
                num_units,
                scope="multihead_attention",
                reuse=tf.AUTO_REUSE):
        """Two tanh dense layers with a residual connection and layer normalisation.

        Args:
            inputs: Tensor fed to the first dense layer and added back to the
                output of the second (so ``num_units[1]`` must match its last
                dimension).
            num_units: Two-element sequence with the widths of the two layers.
            scope: Variable scope name (the default is shared with
                ``multihead_attention``).
            reuse: Passed to ``tf.variable_scope``.
        """
        with tf.variable_scope(scope, reuse=reuse):
            hidden = tf.layers.dense(inputs, num_units[0], activation=tf.tanh)
            projected = tf.layers.dense(hidden, num_units[1], activation=tf.tanh)
            residual = projected + inputs
            return self.normalize(residual)

    def run_epoch(self, epoch_name: str, data, is_training: bool):
        """Run one pass over the numbered JSON files of the training or validation set.

        Files ``<prefix>/0.json`` .. ``<prefix>/<file_count - 1>.json`` are
        loaded one at a time, batched and fed through the session; in training
        mode the train step is executed as well.

        Args:
            epoch_name: Label used in the progress line.
            data: Not read; every file is loaded from disk inside the loop.
            is_training: Selects the file set, the dropout keep probability and
                whether the train step runs.

        Returns:
            ``(loss, accuracies, instance_per_sec, total, processed_data_count,
            time_str, read_data_time_str)`` where ``loss`` and ``accuracies``
            are instance-weighted means, ``total`` is the number of batches
            seen and the two strings are ``HH:MM:SS`` durations.
        """
        def as_hms(seconds):
            minutes, secs = divmod(seconds, 60)
            hours, minutes = divmod(minutes, 60)
            return "%02d:%02d:%02d" % (hours, minutes, secs)

        loss = 0
        accuracies = []
        accuracy_ops = [self.ops['accuracy_task%i' % task_id] for task_id in self.params['task_ids']]
        start_time = time.time()
        read_data_time = 0
        total = 0
        processed_data_count = 0
        count = 0
        if is_training:
            file_count = self.training_file_count
            prefix_path = "/transformerTrainingData/"
            filestr = "training"
        else:
            file_count = self.valid_file_count
            prefix_path = "/transformerValidData/"
            filestr = "valid"

        while count < file_count:
            full_path = prefix_path + str(count) + ".json"
            load_started = time.time()
            file_data = self.load_data(full_path, is_training_data=is_training)
            read_data_time = time.time() - load_started + read_data_time
            # Incremented before batching, so the progress line shows a 1-based file number.
            count += 1
            batch_iterator = ThreadedIterator(self.make_batch_iterator(file_data, is_training),max_queue_size=5)
            for batch_data in batch_iterator:
                total += 1
                batch_data_count = batch_data[self.placeholders['batch_data_count']]
                if batch_data_count == 0:
                    continue
                processed_data_count += batch_data_count
                fetch_list = [self.ops['loss'], accuracy_ops]
                if is_training:
                    batch_data[self.placeholders['layer_dropout_keep_prob']] = self.params['layer_dropout_keep_prob']
                    fetch_list.append(self.ops['train_step'])
                else:
                    batch_data[self.placeholders['layer_dropout_keep_prob']] = 1.0
                result = self.sess.run(fetch_list, feed_dict=batch_data)
                batch_loss, batch_accuracies = result[0], result[1]
                loss += batch_loss * batch_data_count
                accuracies.append(np.array(batch_accuracies) * batch_data_count)
                print("Running %s, %s file %i, batch %i (has %i graphs). Loss so far: %.4f" % (epoch_name, filestr, count, total,batch_data_count, loss / processed_data_count),end='\r')
            # Drop the per-file data before loading the next file.
            del file_data
            del batch_iterator
            gc.collect()

        accuracies = np.sum(accuracies, axis=0) / processed_data_count
        loss = loss / processed_data_count
        end_time = time.time() - start_time
        time_str = as_hms(end_time)
        read_data_time_str = as_hms(read_data_time)
        instance_per_sec = processed_data_count / (time.time() - start_time - read_data_time)
        return loss, accuracies, instance_per_sec, total, processed_data_count, time_str, read_data_time_str
            
    def train(self):
        """Train with checkpoint resume, best-model saving and early stopping.

        If a checkpoint exists under ``/transformer_model/`` it is restored,
        validated once, and training continues from the epoch encoded at the
        start of the checkpoint file name. After every epoch the summed
        validation value is compared with the best one so far: a lower value
        triggers a checkpoint save, and training stops once ``patience``
        epochs have passed without improvement.

        NOTE: the values named ``*_acc`` are the per-task losses (``run_epoch``
        fetches the ``accuracy_task*`` ops, which hold the task loss), so lower
        is better. The log messages keep their original wording.
        """
        start_epoch = 0
        with self.graph.as_default():
            ckpt = tf.train.get_checkpoint_state('/transformer_model/')
            if ckpt and ckpt.model_checkpoint_path:
                self.saver.restore(self.sess, ckpt.model_checkpoint_path)
                valid_loss, valid_accs, valid_speed, valid_batch, valid_total_data_count, valid_time, valid_read_data_time = self.run_epoch(
                    "Resumed (validation)", self.valid_data, False)
                best_val_acc = np.sum(valid_accs)
                print("\r\x1b[KResumed operation, initial cum. val. acc: %.5f" % best_val_acc)
                # Checkpoints are saved as "<epoch>-model_best-<run_id>".
                start_epoch = int(ckpt.model_checkpoint_path.split('/')[-1].split('-')[0])
                # Count patience from the resumed epoch; starting from 0 would
                # stop training on the first non-improving epoch after a resume.
                best_val_acc_epoch = start_epoch
            else:
                # The tracked value is minimised, so start from +infinity.
                best_val_acc, best_val_acc_epoch = float("inf"), 0
            print("Training")
            for epoch in range(start_epoch + 1, self.params['num_epochs'] + 1):
                print("== Epoch %i" % epoch)
                train_loss, train_accs, train_speed, train_batch, train_total_data_count, train_time, train_read_data_time = self.run_epoch(
                    "epoch %i (training)" % epoch,self.train_data, True)
                accs_str = " ".join(["%i:%.5f" % (task_id, acc) for (task_id, acc) in zip(self.params['task_ids'], train_accs)])
                print(
                    "\r\x1b[K Train: loss: %.5f | acc: %s | instances/sec: %.2f | train_batch: %i | train_total_graph: %i | train_time: %s | train_read_data_time: %s" % (
                        train_loss,
                        accs_str,
                        train_speed, train_batch, train_total_data_count, train_time, train_read_data_time))
                valid_loss, valid_accs, valid_speed, valid_batch, valid_total_data_count, valid_time, valid_read_data_time = self.run_epoch(
                    "epoch %i (validation)" % epoch,
                    self.valid_data, False)
                accs_str = " ".join(["%i:%.5f" % (task_id, acc) for (task_id, acc) in zip(self.params['task_ids'], valid_accs)])
                print(
                    "\r\x1b[K Valid: loss: %.5f | acc: %s | instances/sec: %.2f | valid_batch: %i | valid_total_graph: %i | valid_time: %s | valid_read_data_time: %s" % (
                        valid_loss,
                        accs_str,
                        valid_speed, valid_batch, valid_total_data_count, valid_time, valid_read_data_time))

                val_acc = np.sum(valid_accs)
                if val_acc < best_val_acc:
                    self.save_model2(epoch_num=epoch,
                                    path='/transformer_model/')
                    print("  (Best epoch so far, cum. val. acc increased to %.5f from %.5f. Saving to '%s')" % (
                        val_acc, best_val_acc, self.best_model_file))
                    best_val_acc = val_acc
                    best_val_acc_epoch = epoch
                elif epoch - best_val_acc_epoch >= self.params['patience']:
                    print("Stopping training after %i epochs without improvement on validation accuracy." % self.params[
                        'patience'])
                    break
           
    def save_model(self, path: str) -> None:
        """Pickle the hyper-parameters and the current value of every global variable to ``path``."""
        snapshot = {}
        for var in self.sess.graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
            name = var.name
            assert name not in snapshot
            snapshot[name] = self.sess.run(var)

        with open(path, 'wb') as out_file:
            pickle.dump(
                {"params": self.params, "weights": snapshot},
                out_file,
                pickle.HIGHEST_PROTOCOL)

    def save_model2(self, epoch_num, path: str) -> None:
        """Save a TF checkpoint named ``<path><epoch_num>-<best_model_checkpoint>``."""
        checkpoint_prefix = ''.join([path, str(epoch_num), '-', self.best_model_checkpoint])
        self.saver.save(self.sess, checkpoint_prefix)


    def restore_model(self, path: str) -> None:
        """Load weights from a pickle written by ``save_model`` into the current graph.

        The saved hyper-parameters must match the current ones (except
        ``task_ids`` and ``num_epochs``). Graph variables without a saved value
        are freshly initialised; saved values without a matching variable are
        reported and ignored.
        """
        print("Restoring weights from file %s." % path)
        with open(path, 'rb') as in_file:
            # NOTE(review): unpickling can execute arbitrary code; only load trusted files.
            checkpoint = pickle.load(in_file)

        saved_params = checkpoint['params']
        assert len(self.params) == len(saved_params)
        for name, value in self.params.items():
            if name in ('task_ids', 'num_epochs'):
                continue
            assert value == saved_params[name]

        fresh_variables = []
        with tf.name_scope("restore"):
            saved_weights = checkpoint['weights']
            assign_ops = []
            graph_var_names = set()
            for var in self.sess.graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
                graph_var_names.add(var.name)
                if var.name not in saved_weights:
                    print("Freshly initializing %s since no saved value was found." % var.name)
                    fresh_variables.append(var)
                    continue
                assign_ops.append(var.assign(saved_weights[var.name]))
            for saved_name in saved_weights:
                if saved_name not in graph_var_names:
                    print("Saved weights for %s not used by model." % saved_name)

            assign_ops.append(tf.variables_initializer(fresh_variables))
            self.sess.run(assign_ops)

    def restore_model2(self, path: str) -> None:
        """Restore all Saver-managed variables from the TF checkpoint at ``path``."""
        print("Restore...")
        saver, session = self.saver, self.sess
        saver.restore(session, path)
        print("Restore done!")
            
    def load_data(self, filename, is_training_data: bool):
        """Read a JSON list of records and keep the first input and output sequence of each.

        Args:
            filename: Path of the JSON file.
            is_training_data: Not read by this method.

        Returns:
            A list of ``{'input_seq': ..., 'output_seq': ...}`` dicts, one per record.
        """
        with open(filename, 'r') as handle:
            records = json.load(handle)
        return [
            {'input_seq': record['input_seq'][0], 'output_seq': record['output_seq'][0]}
            for record in records
        ]
    
    def make_batch(self, elements):
        """Pad or truncate every example to the fixed lengths and collect the batch lists.

        Encoder inputs are cut or padded with ``ENC_PAD`` to ``MAX_ENC_LEN``.
        Targets are cut to ``MAX_DEC_LEN - 1`` tokens, terminated with
        ``EOS_ID`` and padded with ``DEC_PAD``; decoder inputs are the targets
        shifted right behind ``SOS_ID``. Masks carry one row of
        ``embed_size`` ones (real token) or zeros (padding) per position.

        NOTE: sequences shorter than the limits are modified in place — the
        example's ``input_seq`` list is padded and its ``output_seq`` list gets
        ``EOS_ID`` appended.

        Returns:
            A dict of lists keyed by ``input_seqs``, ``input_seqs_length``,
            ``input_seqs_mask``, ``output_seqs_length``, ``decoder_input_seqs``,
            ``decoder_input_seqs_mask``, ``output_ground_truth_mask`` and
            ``output_ground_truth`` (a single flat list of all targets).
        """
        def mask_rows(value, count):
            width = self.params['embed_size']
            return [[value for _ in range(width)] for _ in range(count)]

        batch_data = {
            'input_seqs': [],
            'input_seqs_length': [],
            'input_seqs_mask': [],
            'output_seqs_length': [],
            'decoder_input_seqs': [],
            'decoder_input_seqs_mask': [],
            'output_ground_truth_mask': [],
            'output_ground_truth': [],
        }
        flat_targets = []
        for example in elements:
            # ----- encoder side -----
            encoder_ids = example['input_seq']
            encoder_len = len(encoder_ids)
            if encoder_len < MAX_ENC_LEN:
                # In-place padding of the example's own list.
                encoder_ids.extend(ENC_PAD for _ in range(MAX_ENC_LEN - encoder_len))
            else:
                encoder_ids = encoder_ids[0:MAX_ENC_LEN]
                encoder_len = MAX_ENC_LEN
            batch_data['input_seqs'].append(encoder_ids)
            batch_data['input_seqs_length'].append(encoder_len)
            batch_data['input_seqs_mask'].append(
                mask_rows(1, encoder_len) + mask_rows(0, MAX_ENC_LEN - encoder_len))

            # ----- decoder side -----
            target_ids = example['output_seq']
            if len(target_ids) >= MAX_DEC_LEN:
                # Slicing copies, so only the copy receives the EOS below.
                target_ids = target_ids[0:MAX_DEC_LEN - 1]
            target_ids.append(EOS_ID)
            target_len = len(target_ids)
            pad_count = MAX_DEC_LEN - target_len

            flat_targets.extend(target_ids[:MAX_DEC_LEN])
            flat_targets.extend([DEC_PAD] * pad_count)

            # Decoder input: SOS followed by the targets without their last token.
            decoder_ids = [SOS_ID] + target_ids[:target_len - 1] + [DEC_PAD] * pad_count
            decoder_mask = mask_rows(1, target_len) + mask_rows(0, pad_count)

            batch_data['output_ground_truth_mask'].extend([1] * target_len)
            batch_data['output_ground_truth_mask'].extend([0] * pad_count)

            batch_data['decoder_input_seqs_mask'].append(decoder_mask)
            batch_data['decoder_input_seqs'].append(decoder_ids)
            batch_data['output_seqs_length'].append(target_len)
        batch_data['output_ground_truth'].append(flat_targets)
        return batch_data
        
    def make_batch_iterator(self, data, is_training: bool):
        """Yield one feed dict per consecutive slice of ``batch_size`` examples.

        The last batch may be smaller. The dropout keep probability is
        ``params['layer_dropout_keep_prob']`` when training and ``1.0`` otherwise.
        """
        total = len(data)
        batch_size = self.params['batch_size']
        whole_batches, remainder = divmod(total, batch_size)
        # One extra (partial) batch when the data does not divide evenly.
        n_batch = int(whole_batches) + (0 if remainder == 0 else 1)
        for batch_index in range(n_batch):
            lower = batch_index * batch_size
            upper = min((batch_index + 1) * batch_size, total)
            elements = data[lower:upper]
            batch = self.make_batch(elements)
            keep_prob = self.params['layer_dropout_keep_prob'] if is_training else 1.0
            yield {
                self.placeholders['input_label']: batch['input_seqs'],
                self.placeholders['decoder_input_label']: batch['decoder_input_seqs'],
                self.placeholders['input_seq_length']: batch['input_seqs_length'],
                self.placeholders['output_seq_length']: batch['output_seqs_length'],
                self.placeholders['batch_data_count']: len(elements),
                self.placeholders['output_ground_truth']: batch['output_ground_truth'],
                self.placeholders['output_mask']: batch['output_ground_truth_mask'],
                self.placeholders['decoder_input_mask']: batch['decoder_input_seqs_mask'],
                self.placeholders['input_mask']: batch['input_seqs_mask'],
                self.placeholders['layer_dropout_keep_prob']: keep_prob,
            }
        
    def initialize_model(self) -> None:
        """Run the global and local variable initializers in ``self.sess``."""
        initializers = (
            tf.global_variables_initializer(),
            tf.local_variables_initializer(),
        )
        # Group both initializers so a single session call runs them together.
        self.sess.run(tf.group(*initializers))

    def softmax(self, array):
        """Return the softmax of ``array`` taken along its last axis.

        Each row is shifted by its own maximum before exponentiation. The
        shift cancels in the ratio, so the result is mathematically unchanged,
        but large logits no longer overflow to ``inf / inf = nan`` and very
        negative rows no longer underflow to ``0 / 0 = nan``.

        Args:
            array: Array-like of numbers, typically a list of equally long
                rows (one row per hypothesis).

        Returns:
            A float64 ``np.ndarray`` with the same shape as ``array`` whose
            entries along the last axis are non-negative and sum to 1. Empty
            input is returned as an empty float64 array of the same shape.
        """
        logits = np.asarray(array, dtype=np.float64)
        if logits.size == 0:
            # Nothing to normalize; reducing over an empty axis would raise.
            return logits
        shifted = logits - np.max(logits, axis=-1, keepdims=True)
        weights = np.exp(shifted)
        return weights / np.sum(weights, axis=-1, keepdims=True)

    def predict(self, input_string, c_clazz):
        """Decode API sequences for one input token string using beam search.

        The input is split on single spaces, each token is mapped to an id via
        ``self.word2idx`` and the resulting sequence is fed through
        ``make_batch_iterator`` as a one-example dataset. The first yielded
        batch is encoded once and then decoded step by step, keeping at most
        ``BEAM_SIZE`` live hypotheses, for at most ``MAX_DEC_LEN`` steps.

        Args:
            input_string: Space-separated input tokens.
            c_clazz: Not referenced anywhere in the visible body.

        Returns:
            A string built from the final hypotheses, best first by
            ``sort_hyps``: every token id is mapped through ``self.idx2api``
            and prefixed with a space, every hypothesis is terminated with
            ``';'``, and the first and last characters of the concatenation
            are dropped. If the iterator yields nothing, or the first batch
            reports ``batch_data_count == 0``, the visible code never reaches
            the ``return`` statement.
        """
        input_seq = []
        # Placeholder target sequence; only the encoder input matters here.
        output_seq = [-1]
        clazzes = []

        # NOTE(review): context_list, type_list, scope_list and hole_scope are
        # populated/declared here but never read later in the visible code.
        context_list = []
        type_list = []
        scope_list = []
        hole_scope = []
        start_scope = 1
        current_temp_scope = []
        current_temp_scope.append(1)

        input_tokens = input_string.strip().split(" ")
        for token in input_tokens:
            # Tokens missing from word2idx fall back to id WHOLE_VOCABULARY_SIZE - 2.
            index = self.word2idx.get(token, WHOLE_VOCABULARY_SIZE - 2) 
            input_seq.append(index)
            word = token
            # Maintain a stack of scope ids: block openers push a fresh id,
            # branch keywords replace the top id, "out_control" pops.
            if word == "if" or word == "while" or word == "doWhile" or word == "for" or word == "try" or word == "switch":
                start_scope += 1
                current_temp_scope.append(start_scope)
            elif word == "elseif" or word == "else" or word == "catch" or word == "finally" or word == "case" or word == "default":
                start_scope += 1
                current_temp_scope.pop()
                current_temp_scope.append(start_scope)
            elif word == "out_control":
                current_temp_scope.pop()                
            # Copy of the current stack, rebuilt per token; not read afterwards.
            final_current_temp_scope = []
            for sc in current_temp_scope:
                final_current_temp_scope.append(sc)
        for sc in current_temp_scope:
            hole_scope.append(sc)

        # Wrap the single example as a dataset; is_training=False feeds keep-prob 1.0.
        data = [{'input_seq': input_seq, 'output_seq': output_seq}]
        one_batch = self.make_batch_iterator(data, False)
       
        # NOTE(review): output_apiSeq is assigned but not read in the visible code.
        output_apiSeq = "start"
        current_scope = [1]
        last_scope = [1]
        scopes = []
        clazz = "none"
        # The return at the bottom of this loop body exits on the first batch.
        for step, batch in enumerate(one_batch):
            batch[self.placeholders['layer_dropout_keep_prob']] = 1.0
            batch_data_count = batch[self.placeholders['batch_data_count']]
            if batch_data_count == 0:
                break
            # Evaluate self.encoder_result once; the value is reused at every decode step.
            fetch_list = self.encoder_result
            encoder_output = self.sess.run(fetch_list, feed_dict=batch)
            # Start with BEAM_SIZE identical hypotheses holding only SOS_ID.
            # The same clazzes/scopes/current_scope/last_scope list objects are
            # passed to every constructor call.
            hyps = [HypothesisFilter(tokens=[SOS_ID],
                               log_probs=[0.0],
                               state=encoder_output,
                               clazzes = clazzes,
                               scopes = scopes,
                               scope_index = 1,
                               current_scope = current_scope,
                               last_scope = last_scope,
                               control_apis = [],
                               thenbodycontrol_apis = [],
                               context_apis = []
                               ) for _ in range(BEAM_SIZE)]
            results = []
            steps = 0
            # Decode until the step budget is used up or BEAM_SIZE results are collected.
            while steps < MAX_DEC_LEN and len(results) < BEAM_SIZE:
                # get_tokens is read as an attribute (no call) and used as a sequence below.
                decoder_inputs = [h.get_tokens for h in hyps] 
                states = [h.state for h in hyps] 
                # Repeat the first encoder output once per live hypothesis.
                encoder_outputs = [encoder_output[0] for h in hyps]
                # NOTE(review): the reshaped `states` is not part of the feed
                # below and is not read again in the visible code.
                states = np.reshape(states, (-1, self.params['decoder_hidden_size']))

                # Pad every hypothesis' tokens to MAX_DEC_LEN with DEC_PAD and
                # build a matching MAX_DEC_LEN x embed_size mask (1 = real token).
                decoder_final_inputs = []
                decoder_final_masks = []
                for d_input in decoder_inputs:
                    d_final_input = []
                    d_input_length = len(d_input)
                    for i in range(MAX_DEC_LEN):
                        if i < d_input_length:
                            d_final_input.append(d_input[i])
                        else:
                            d_final_input.append(DEC_PAD)
                    decoder_final_inputs.append(d_final_input)
                    d_final_mask = []
                    for i in range(MAX_DEC_LEN):
                        if i < d_input_length:
                            d_final_mask.append([1 for _ in range(self.params['embed_size'])])
                        else:
                            d_final_mask.append([0 for _ in range(self.params['embed_size'])])
                    decoder_final_masks.append(d_final_mask)

                feed = {
                    self.encoder_output: encoder_outputs,
                    self.placeholders['decoder_input_mask']:decoder_final_masks,
                    self.placeholders['decoder_input_label']:decoder_final_inputs,
                    self.placeholders['layer_dropout_keep_prob']:1.0
                }
                original_topk_probs, original_topk_ids = self.sess.run(fetches=self.output,feed_dict=feed)
                topk_probs = []
                topk_ids =[]


                # Pick, for hypothesis i, the output row at position `steps`.
                # Assumes self.output is flattened to MAX_DEC_LEN rows per
                # hypothesis -- TODO confirm against the graph definition.
                for i in range(len(hyps)):
                    topk_probs.append([p for p in original_topk_probs[steps + i * MAX_DEC_LEN]])
                    topk_ids.append([d for d in original_topk_ids[steps + i * MAX_DEC_LEN]])
                topk_ids = np.array(topk_ids)
                

                # Renormalize the selected scores per row with self.softmax.
                topk_probs = self.softmax(topk_probs)
                all_hyps = []
                # At step 0 all hypotheses are identical, so only the first is expanded.
                num_orig_hyps = 1 if steps == 0 else len(hyps)
                for i in range(num_orig_hyps):
                    h =  hyps[i]
                    # Assumes every row holds at least TOPK_SIZE candidates -- TODO confirm.
                    for j in range(TOPK_SIZE):
                        # NOTE(review): this is a softmax probability, not a
                        # logarithm, despite the name; confirm what
                        # HypothesisFilter.extend expects.
                        log_prob = topk_probs[i, j]
                        # `old` is assigned but not used afterwards.
                        old = h.latest_token
                        new = topk_ids[i, j]
                        new_api = self.idx2api[new]
                        # Candidates naming these control tokens are never expanded.
                        if new_api == "try" or new_api == "catch" or new_api == "finally" or new_api == "switch" or new_api == "case" or new_api == "default":
                            continue
                        
                        # NOTE(review): context_api receives the literal string
                        # "new_api" rather than the new_api variable -- confirm
                        # whether that is intended.
                        new_hyp = h.extend(token = topk_ids[i, j],
                                           log_prob=log_prob,
                                           state=encoder_output,
                                           clazz = clazz,
                                           scopes = [1],
                                           scope_index = 4,
                                           current_scope = [1],
                                           last_scope = [1],
                                           control_api = new_api,
                                           thenbodycontrol_api = new_api,
                                           context_api = "new_api"
                                           )
                        all_hyps.append(new_hyp)

                # Walk candidates best-first: ones ending in EOS_ID become
                # results (once steps >= MIN_DEC_LEN), the others stay live; on
                # the final step live ones are also recorded as results.
                hyps = [] 
                for h in sort_hyps(all_hyps): 
                    if h.latest_token == EOS_ID: 
                        if steps >= MIN_DEC_LEN:
                            results.append(h)
                    else:  
                        if steps == (MAX_DEC_LEN - 1):
                            results.append(h)
                        hyps.append(h)
                    if len(hyps) == BEAM_SIZE or len(results) == BEAM_SIZE:
                        break
                results = sort_hyps(results)
                # NOTE(review): the loop above stops as soon as len(hyps)
                # reaches BEAM_SIZE, so this condition cannot be true here; it
                # may have been meant to test len(results).
                if len(hyps) > BEAM_SIZE:
                    results = results[0:BEAM_SIZE]
                steps += 1
            # Fall back to the live hypotheses when nothing was recorded as a result.
            if len(results) == 0:  
                results = hyps
            hyps_sorted = sort_hyps(results)
            topAPIidSeqList = []
            for hyp in hyps_sorted:
                topAPIidSeqList.append(hyp.get_tokens)
            # Render as " tok tok; tok tok;" and then strip the leading space
            # and the trailing ';'.
            top_10_api_sequence = ''
            for ele in topAPIidSeqList:
                for e in ele:
                    top_10_api_sequence += ' ' + self.idx2api[e]
                top_10_api_sequence += ';'
            top_10_api_sequence = top_10_api_sequence[1:-1]
            return top_10_api_sequence