from typing import Tuple, List, Any, Sequence
from sklearn.utils import shuffle
import os

from HypothesisFilter import HypothesisFilter
import re
import numpy.ma as ma
import copy

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import tensorflow as tf
import numpy as np
import time
import pickle

import json
import queue
import threading
import random
from collections import namedtuple, defaultdict
import gc
import zerorpc
import math

# Decoding settings. NOTE(review): BEAM_SIZE, MAX_DEC_LEN and MIN_DEC_LEN are not
# referenced in the part of the file reviewed here; presumably they are consumed
# by the beam-search code further down -- confirm.
BEAM_SIZE = 10
TOPK_SIZE = 20  # k passed to tf.nn.top_k over the decoder logits
SOS_ID = 0  # PLACEHOLDER: replace 0 with the vocabulary id that represents SOS
EOS_ID = 0  # PLACEHOLDER: replace 0 with the vocabulary id that represents EOS
MAX_DEC_LEN = 20
MIN_DEC_LEN = 1
SOFTMAX_SIZE = 0  # PLACEHOLDER: replace 0 with the size of the API vocabulary
INPUT_ORDER_SIZE = 0  # PLACEHOLDER: replace 0 with the size of the whole vocabulary
# Vocabulary files are read one entry per line; the line index becomes the id.
APIVOCABULARY_PATH = 'APIVocabulary.txt'
WHOLEVOCABULARY_PATH = 'WholeVocabulary.txt'
# Prefix that checkpoint names are appended to when saving / restoring.
MODEL_PATH = '/model/'

class MultiRNNCell_with_name(tf.nn.rnn_cell.MultiRNNCell):
    """A ``MultiRNNCell`` made of GRU layers that reports a fixed ``__name__``.

    Args:
        hidden_size: number of units of every stacked ``GRUCell``.
        num_layers: how many ``GRUCell`` instances are stacked.
    """

    def __init__(self, hidden_size, num_layers):
        # ``__init__`` always returns None, so its result is deliberately not
        # stored (the previous code kept a meaningless ``self.dec_cell = None``).
        super().__init__(
            [tf.nn.rnn_cell.GRUCell(hidden_size) for _ in range(num_layers)]
        )

    @property
    def __name__(self):
        """Name reported for this cell; always the literal 'MultiRNNCell'."""
        return 'MultiRNNCell'

class MLP(object):
    """Fully connected network whose parameters are TensorFlow variables.

    Layer sizes run ``in_size -> hid_sizes... -> out_size``. Dropout is applied
    to each weight matrix (not to the activations) with ``dropout_keep_prob``.
    """

    def __init__(self, in_size, out_size, hid_sizes, dropout_keep_prob):
        self.in_size = in_size
        self.out_size = out_size
        self.hid_sizes = hid_sizes
        self.dropout_keep_prob = dropout_keep_prob
        self.params = self.make_network_params()

    def make_network_params(self):
        """Create one weight matrix and one zero bias vector per layer.

        Returns:
            dict with the keys "weights" and "biases", each a list of
            ``tf.Variable`` ordered from the first to the last layer.
        """
        layer_dims = [self.in_size] + self.hid_sizes + [self.out_size]
        layer_shapes = list(zip(layer_dims[:-1], layer_dims[1:]))

        # All weight variables are created before any bias variable.
        weights = []
        for layer_idx, layer_shape in enumerate(layer_shapes):
            initial_w = self.init_weights(layer_shape)
            weights.append(tf.Variable(initial_w, name='MLP_W_layer%i' % layer_idx))

        biases = []
        for layer_idx, layer_shape in enumerate(layer_shapes):
            initial_b = np.zeros(layer_shape[-1]).astype(np.float32)
            biases.append(tf.Variable(initial_b, name='MLP_b_layer%i' % layer_idx))

        return {"weights": weights, "biases": biases}

    def init_weights(self, shape):
        """Draw uniform values in ``[-r, r)`` with ``r = sqrt(6 / (fan_in + fan_out))``."""
        fan_in, fan_out = shape[-2], shape[-1]
        bound = np.sqrt(6.0 / (fan_in + fan_out))
        centered = 2 * np.random.rand(*shape).astype(np.float32) - 1
        return bound * centered

    def __call__(self, inputs):
        """Apply every layer to ``inputs`` and return the last pre-activation.

        ReLU is applied between layers; the value returned is the output of
        the final affine transform, before its ReLU.
        """
        layer_input = inputs
        layers = zip(self.params['weights'], self.params['biases'])
        for weight, bias in layers:
            dropped_weight = tf.nn.dropout(weight, self.dropout_keep_prob)
            pre_activation = tf.matmul(layer_input, dropped_weight) + bias
            layer_input = tf.nn.relu(pre_activation)
        return pre_activation


class ThreadedIterator:
    """Iterate over ``original_iterator`` while a background thread prefetches.

    A worker thread pulls elements from ``original_iterator`` and pushes them
    into a bounded queue; iterating over this object pops them in order.
    ``None`` is used as the end-of-stream marker, so the wrapped iterator must
    never yield ``None``.

    If the wrapped iterator raises, the end marker is still queued and the
    exception is re-raised in the consuming thread once the already-queued
    elements have been delivered. Previously the consumer would block forever.

    Args:
        original_iterator: iterable to consume in the background.
        max_queue_size: maximum number of prefetched elements held at once.
    """

    def __init__(self, original_iterator, max_queue_size: int = 2):
        self.__queue = queue.Queue(maxsize=max_queue_size)
        self.__error = None  # exception raised by the worker, if any
        self.__thread = threading.Thread(target=lambda: self.worker(original_iterator))
        self.__thread.start()

    def worker(self, original_iterator):
        """Producer loop executed in the background thread."""
        try:
            for element in original_iterator:
                assert element is not None, 'By convention, iterator elements much not be None'
                self.__queue.put(element, block=True)
        except Exception as exc:
            # Hand the failure to the consumer instead of dying silently.
            self.__error = exc
        finally:
            # Always send the end marker so the consumer can never hang.
            self.__queue.put(None, block=True)

    def __iter__(self):
        next_element = self.__queue.get(block=True)
        while next_element is not None:
            yield next_element
            next_element = self.__queue.get(block=True)
        self.__thread.join()
        if self.__error is not None:
            raise self.__error

def glorot_init(shape):
    """Return a float32 array of ``shape`` drawn uniformly from ``[-r, r)``.

    ``r`` is ``sqrt(6 / (shape[-2] + shape[-1]))``, so ``shape`` needs at
    least two dimensions.
    """
    fan_total = shape[-2] + shape[-1]
    limit = np.sqrt(6.0 / fan_total)
    samples = np.random.uniform(low=-limit, high=limit, size=shape)
    return samples.astype(np.float32)

def sort_hyps(hyps):
    """Return a new list of ``hyps`` ordered by ``avg_log_prob``, highest first.

    The input is not modified; hypotheses with equal scores keep their
    original relative order.
    """
    ranked = list(hyps)
    ranked.sort(key=lambda hyp: hyp.avg_log_prob, reverse=True)
    return ranked

class ProgModel(object):
    @classmethod
    def default_params(cls):
        """Return a fresh dict holding the default hyper-parameters."""
        optimisation = {
            'num_epochs': 50,
            'patience': 5,
            'learning_rate': 0.005,
            'clamp_gradient_norm': 1.0,
            'out_layer_dropout_keep_prob': 0.75,
            'momentum': 0.9,
        }
        architecture = {
            'embed_size': 300,
            'hidden_size': 300,
            'softmax_size': 600,
            'num_timesteps': 5,
            'use_graph': True,
            'decoder_hidden_size': 600,
            'decoder_num_layers': 1,
        }
        miscellaneous = {
            'tie_fwd_bkwd': False,
            'task_ids': [0],
            'random_seed': 0,
        }

        # Merged in this order so the key order of the result stays the same.
        defaults = {}
        for group in (optimisation, architecture, miscellaneous):
            defaults.update(group)
        return defaults

    def __init__(self, args, training_file_count, valid_file_count):
        """Resolve hyper-parameters, build the TF graph and load or initialise weights.

        Args:
            args: object exposing ``data_dir``, ``config_file``, ``config`` and
                ``restore`` (other methods also read ``restrict`` and
                ``freeze_graph_model``).
            training_file_count: number of training files ``run_epoch`` reads per epoch.
            valid_file_count: number of validation files ``run_epoch`` reads per epoch.
        """
        self.args = args
        self.training_file_count = training_file_count
        self.valid_file_count = valid_file_count
        # Hard-coded: with False, make_model skips the loss ops and
        # make_train_step is never called from this constructor.
        self.is_training = False

        data_dir = ''
        if args.data_dir is not None:
            data_dir = args.data_dir
        self.data_dir = data_dir

        # Run id = timestamp + pid; used to name the log and checkpoint files.
        self.run_id = "_".join([time.strftime("%Y-%m-%d-%H-%M-%S"), str(os.getpid())])
        log_dir = '.'
        self.log_file = os.path.join(log_dir, "%s_log.pickle" % self.run_id)
        self.best_model_file = os.path.join(log_dir, "%s_model_best.pickle" % self.run_id)
        self.best_model_checkpoint = "model_best-%s" % self.run_id

        # Parameter precedence: class defaults < JSON config file < inline JSON string.
        params = self.default_params()
        config_file = args.config_file
        if config_file is not None:
            with open(config_file, 'r') as f:
                params.update(json.load(f))

        config = args.config
        if config is not None:
            params.update(json.loads(config))

        self.params = params

        print("Run %s starting with following parameters:\n%s" % (self.run_id, json.dumps(self.params)))

        random.seed(params['random_seed'])
        np.random.seed(params['random_seed'])

        self.max_num_vertices = 0
        # Starting value; load_data / load_minidata may raise it.
        self.num_edge_types = 8
        self.annotation_size = 0
        self.mini_train_data = None
        self.mini_valid_data = None

        # NOTE: ``config`` is rebound here from the JSON string to the TF session config.
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        config.log_device_placement = True

        self.graph = tf.Graph()
        self.sess = tf.Session(graph=self.graph, config=config)
        with self.graph.as_default():
            tf.set_random_seed(params['random_seed'])
            self.placeholders = {}
            self.weights = {}
            self.ops = {}

            self.make_model()
            if self.is_training:
                self.make_train_step()

            # Either restore a checkpoint located under MODEL_PATH or start
            # from freshly initialised variables.
            restore_file = args.restore
            if restore_file is not None:
                self.restore_model2( MODEL_PATH + restore_file)
            else:
                self.initialize_model()

    def make_model(self):
        """Create placeholders, the graph encoder, the per-task output layers and the saver.

        Loss ops are only added when ``self.is_training`` is true. The order
        in which variables are created determines their names, so statements
        here should not be reordered without checking existing checkpoints.
        """
        self.placeholders['trg_label'] = tf.placeholder(tf.int64, [None, None],
                                                        name='trg_label')
        self.placeholders['target_mask'] = tf.placeholder(tf.float32, [None, None],
                                                          name='target_mask')
        self.placeholders['num_graphs'] = tf.placeholder(tf.int64, [], name='num_graphs')
        self.placeholders['out_layer_dropout_keep_prob'] = tf.placeholder(tf.float32, [],
                                                                          name='out_layer_dropout_keep_prob')

        # NOTE: DenseGGNNProgModel.prepare_specific_graph_model later replaces
        # the 'trg_emb' entry with an embedding lookup.
        self.placeholders['trg_emb'] = tf.placeholder(tf.float32, [None, ], name='trg_emb')
        self.placeholders['trg_size'] = tf.placeholder(tf.int32, [None], name='trg_size')
        self.placeholders['trg_mask'] = tf.placeholder(tf.float32, [None], name='trg_mask')

        with tf.variable_scope("graph_model"):
            self.prepare_specific_graph_model()
            if self.params['use_graph']:
                self.ops['final_node_representations'] = self.compute_final_node_representations()
            else:
                # Without the graph encoder the node representations are all zeros.
                self.ops['final_node_representations'] = tf.zeros_like(self.placeholders['orders_embed'])

        self.ops['losses'] = []

        for (internal_id, task_id) in enumerate(self.params['task_ids']):
            with tf.variable_scope("out_layer_task%i" % task_id):
                with tf.variable_scope("regression_gate"):
                    self.weights['regression_gate_task%i' % task_id] = MLP(2 * self.params['hidden_size'],
                                                                           self.params['hidden_size'], [],
                                                                           self.placeholders[
                                                                               'out_layer_dropout_keep_prob'])
                with tf.variable_scope("regression"):
                    self.weights['regression_transform_task%i' % task_id] = MLP(self.params['hidden_size'],
                                                                                self.params['hidden_size'], [],
                                                                                self.placeholders[
                                                                                    'out_layer_dropout_keep_prob'])

                # NOTE: these two keys are not task-specific, so with several
                # task ids each iteration overwrites the previous entries.
                with tf.variable_scope("softmax"):
                    self.weights['softmax_weights'] = tf.Variable(glorot_init([self.params['softmax_size'], SOFTMAX_SIZE]))
                    self.weights['softmax_biases'] = tf.Variable(np.zeros([SOFTMAX_SIZE]).astype(np.float32))
                    print('softmax weights: ', self.weights['softmax_weights'])

                # The result is only consumed in the training branch below.
                computed_values = self.gated_regression(self.ops['final_node_representations'],
                                                        self.weights['regression_gate_task%i' % task_id],
                                                        self.weights['regression_transform_task%i' % task_id],
                                                        self.weights['softmax_weights'],
                                                        self.weights['softmax_biases'])
                if self.is_training:
                    groundtruth = self.placeholders['trg_label'][internal_id, :] 

                    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=computed_values, labels=groundtruth)
                    label_weights = self.placeholders['trg_mask']
                    union_tensor = loss * label_weights
                    # NOTE(review): 60 is hard-coded; presumably the padded
                    # target length -- confirm against the batch builder.
                    reshape_tensor = tf.reshape(union_tensor, [60, -1])
                    task_loss = tf.reduce_mean(tf.reduce_sum(reshape_tensor, axis=0))

                    # Despite the key name, this op holds the task loss.
                    self.ops['accuracy_task%i' % task_id] = task_loss
                    # NOTE(review): 'task_sample_ratios' is not part of
                    # ProgModel.default_params (DenseGGNNProgModel adds it), and
                    # it is looked up with an int key here while
                    # DenseGGNNProgModel.process_raw_graphs uses str(task_id).
                    task_loss = task_loss * (1.0 / (self.params['task_sample_ratios'].get(task_id) or 1.0))
                    self.ops['losses'].append(task_loss)
        if self.is_training:
            self.ops['loss'] = tf.reduce_sum(self.ops['losses'])
        # The saver captures the variables that exist at this point.
        self.saver = tf.train.Saver()

    def make_train_step(self):
        """Build the clipped-gradient momentum update and store it in ``self.ops['train_step']``.

        When ``args.freeze_graph_model`` is set, trainable variables found
        under the "graph_model" scope are excluded from the update.
        """
        graph = self.sess.graph
        trainable_vars = graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        if self.args.freeze_graph_model:
            frozen_vars = set(graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="graph_model"))
            kept_vars = []
            for candidate in trainable_vars:
                if candidate in frozen_vars:
                    print("Freezing weights of variable %s." % candidate.name)
                    continue
                kept_vars.append(candidate)
            trainable_vars = kept_vars

        # Step counter incremented by apply_gradients; it drives the schedule.
        global_step = tf.Variable(0)
        learning_rate = tf.train.exponential_decay(
            self.params['learning_rate'], global_step, 2000, 1.0, staircase=False)
        optimizer = tf.train.MomentumOptimizer(
            learning_rate, momentum=self.params['momentum'], use_nesterov=True)

        # Clip each gradient individually; variables without a gradient keep None.
        max_norm = self.params['clamp_gradient_norm']
        clipped_grads = []
        for grad, var in optimizer.compute_gradients(self.ops['loss'], var_list=trainable_vars):
            clipped = grad if grad is None else tf.clip_by_norm(grad, max_norm)
            clipped_grads.append((clipped, var))

        self.ops['train_step'] = optimizer.apply_gradients(clipped_grads, global_step=global_step)
        self.sess.run(tf.local_variables_initializer())

    def gated_regression(self, last_h, regreesion_gate, regression_transform, softmax_weights, softmax_biases):
        """Abstract hook: turn final node states into the model output.

        Subclasses must override this method.

        Raises:
            NotImplementedError: always. It is a subclass of ``Exception``, so
                handlers written for the previous bare ``Exception`` still match.
        """
        raise NotImplementedError("Models have to implement gated_regression!")

    def prepare_specific_graph_model(self) -> None:
        """Abstract hook: create the model-specific placeholders and weights.

        Subclasses must override this method.

        Raises:
            NotImplementedError: always. It is a subclass of ``Exception``, so
                handlers written for the previous bare ``Exception`` still match.
        """
        raise NotImplementedError("Models have to implement prepare_specific_graph_model!")

    def compute_final_node_representations(self) -> tf.Tensor:
        """Abstract hook: build the tensor holding the final node states.

        Subclasses must override this method.

        Raises:
            NotImplementedError: always. It is a subclass of ``Exception``, so
                handlers written for the previous bare ``Exception`` still match.
        """
        raise NotImplementedError("Models have to implement compute_final_node_representations!")

    def make_minibatch_iterator(self, data: Any, is_training: bool):
        """Abstract hook: yield the feed dicts used to run the model on ``data``.

        Subclasses must override this method.

        Raises:
            NotImplementedError: always. It is a subclass of ``Exception``, so
                handlers written for the previous bare ``Exception`` still match.
        """
        raise NotImplementedError("Models have to implement make_minibatch_iterator!")

    def load_data(self, file_name, is_training_data: bool):
        """Load a JSON graph file from ``self.data_dir`` and preprocess it.

        ``args.restrict`` (when positive) truncates the data set. As a side
        effect ``self.max_num_vertices`` and ``self.num_edge_types`` are raised
        to cover the loaded graphs.

        Returns:
            Whatever ``process_raw_graphs`` returns for the loaded samples.
        """
        full_path = os.path.join(self.data_dir, file_name)

        print("Loading data from %s" % full_path)
        with open(full_path, 'r') as json_file:
            data = json.load(json_file)

        limit = self.args.restrict
        if limit is not None and limit > 0:
            data = data[:limit]

        # Each edge is indexed as (source vertex, edge type, target vertex).
        highest_fwd_edge_type = 0
        for sample in data:
            vertex_ids = [endpoint for edge in sample['graph'] for endpoint in (edge[0], edge[2])]
            self.max_num_vertices = max(self.max_num_vertices, max(vertex_ids))
            edge_types = [edge[1] for edge in sample['graph']]
            highest_fwd_edge_type = max(highest_fwd_edge_type, max(edge_types))

        # Backward edges get their own types unless forward/backward are tied.
        directions = 1 if self.params['tie_fwd_bkwd'] else 2
        self.num_edge_types = max(self.num_edge_types, highest_fwd_edge_type * directions)
        return self.process_raw_graphs(data, is_training_data)

    def load_minidata(self, filename, is_training_data: bool):
        """Load the JSON graph file at ``filename`` and preprocess it.

        Unlike ``load_data`` the path is used as given and no ``restrict``
        truncation is applied. ``self.max_num_vertices`` and
        ``self.num_edge_types`` are raised to cover the loaded graphs.

        Returns:
            Whatever ``process_raw_graphs`` returns for the loaded samples.
        """
        with open(filename, 'r') as json_file:
            data = json.load(json_file)

        # Each edge is indexed as (source vertex, edge type, target vertex).
        highest_fwd_edge_type = 0
        for sample in data:
            vertex_ids = [endpoint for edge in sample['graph'] for endpoint in (edge[0], edge[2])]
            self.max_num_vertices = max(self.max_num_vertices, max(vertex_ids))
            edge_types = [edge[1] for edge in sample['graph']]
            highest_fwd_edge_type = max(highest_fwd_edge_type, max(edge_types))

        # Backward edges get their own types unless forward/backward are tied.
        directions = 1 if self.params['tie_fwd_bkwd'] else 2
        self.num_edge_types = max(self.num_edge_types, highest_fwd_edge_type * directions)
        return self.process_raw_graphs(data, is_training_data)

    @staticmethod
    def graph_string_to_array(graph_string: str) -> List[List[int]]:
        """Parse newline-separated rows of space-separated integers.

        Example: ``"1 2 3\\n4 5 6"`` becomes ``[[1, 2, 3], [4, 5, 6]]``.
        """
        rows = []
        for line in graph_string.split('\n'):
            rows.append([int(token) for token in line.split(' ')])
        return rows


    def process_raw_graphs(self, raw_data: Sequence[Any], is_training_data: bool) -> Any:
        """Abstract hook: convert raw JSON samples into the model's batch source.

        Subclasses must override this method.

        Raises:
            NotImplementedError: always. It is a subclass of ``Exception``, so
                handlers written for the previous bare ``Exception`` still match.
        """
        raise NotImplementedError("Models have to implement process_raw_graphs!")

    def run_epoch(self, epoch_name: str, data, is_training: bool):
        """Run one pass over the training or validation files.

        Files are read one at a time from ``Data/<n>.json``, where ``n`` starts
        at 4 and ``training_file_count`` / ``valid_file_count`` files are
        processed. Every minibatch is executed in ``self.sess``; the train
        step is included only when ``is_training`` is true.

        Args:
            epoch_name: label used in the progress line.
            data: unused; kept for interface compatibility (the data is always
                re-read from disk).
            is_training: selects the file count, dropout keep probability and
                whether the train step is run.

        Returns:
            Tuple ``(loss, accuracies, instances_per_sec, num_batches,
            num_graphs, elapsed_str, read_time_str)``. ``loss`` and
            ``accuracies`` are averaged over the processed graphs.
        """
        loss = 0
        accuracies = []
        accuracy_ops = [self.ops['accuracy_task%i' % task_id] for task_id in self.params['task_ids']]
        start_time = time.time()
        read_data_time = 0
        total = 0
        processed_graphs = 0
        count = 0
        index = 4  # number of the first data file
        file_count = self.training_file_count if is_training else self.valid_file_count
        filestr = "training" if is_training else "valid"
        while count < file_count:
            full_path = "Data/" + str(count + index) + ".json"
            t = time.time()
            file_data = self.load_minidata(full_path, is_training_data=is_training)
            read_data_time = time.time() - t + read_data_time
            count = count + 1
            batch_iterator = ThreadedIterator(self.make_minibatch_iterator(file_data, is_training),
                                              max_queue_size=5)
            for batch_data in batch_iterator:
                total = total + 1
                num_graphs = batch_data[self.placeholders['num_graphs']]
                processed_graphs += num_graphs
                if is_training:
                    batch_data[self.placeholders['out_layer_dropout_keep_prob']] = self.params[
                        'out_layer_dropout_keep_prob']
                    fetch_list = [self.ops['loss'], accuracy_ops, self.ops['train_step']]
                else:
                    batch_data[self.placeholders['out_layer_dropout_keep_prob']] = 1.0
                    fetch_list = [self.ops['loss'], accuracy_ops]

                # A leftover debug print of batch_data['graph'] was removed
                # here: the feed dict is keyed by placeholders, so a plain
                # string key either raised KeyError or was an invalid feed key.
                result = self.sess.run(fetch_list, feed_dict=batch_data)
                (batch_loss, batch_accuracies) = (result[0], result[1])
                loss += batch_loss * num_graphs
                accuracies.append(np.array(batch_accuracies) * num_graphs)

                print("Running %s, %s file %i, batch %i (has %i graphs). Loss so far: %.4f" % (
                    epoch_name, filestr, count, total,
                    num_graphs, loss / processed_graphs), end='\r')
            # Release the file's data before the next one is loaded.
            del file_data
            del batch_iterator
            gc.collect()
        accuracies = np.sum(accuracies, axis=0) / processed_graphs
        loss = loss / processed_graphs
        end_time = time.time() - start_time
        m, s = divmod(end_time, 60)
        h, m = divmod(m, 60)
        time_str = "%02d:%02d:%02d" % (h, m, s)
        m, s = divmod(read_data_time, 60)
        h, m = divmod(m, 60)
        read_data_time_str = "%02d:%02d:%02d" % (h, m, s)
        # Throughput excludes the time spent reading files from disk.
        instance_per_sec = processed_graphs / (time.time() - start_time - read_data_time)
        return loss, accuracies, instance_per_sec, total, processed_graphs, time_str, read_data_time_str

    def train(self):
        """Train for up to ``num_epochs`` epochs with early stopping.

        After every epoch the validation score (the sum of the per-task values
        returned by ``run_epoch``) is compared with the best seen so far. A
        lower value counts as better and triggers a checkpoint save under
        ``MODEL_PATH``. Training stops once ``patience`` epochs pass without
        improvement.

        NOTE(review): the log messages call the score "acc" and say it
        "increased", and they print ``self.best_model_file`` although the
        checkpoint is written through ``save_model2``; the texts are left
        untouched here.
        """

        def format_scores(scores):
            # "<task id>:<value>" pairs separated by spaces.
            return " ".join(["%i:%.5f" % (task_id, score)
                             for (task_id, score) in zip(self.params['task_ids'], scores)])

        with self.graph.as_default():
            if self.args.restore is not None:
                _, valid_accs, *_ = self.run_epoch("Resumed (validation)", self.mini_valid_data, False)
                best_val_acc = np.sum(valid_accs)
                best_val_acc_epoch = 0
                print("\r\x1b[KResumed operation, initial cum. val. acc: %.5f" % best_val_acc)
            else:
                # Lower is better, so start from +inf: the first epoch always
                # counts as an improvement (the previous magic 100000.0 would
                # have rejected any larger first-epoch score).
                (best_val_acc, best_val_acc_epoch) = (float("inf"), 0)
            for epoch in range(1, self.params['num_epochs'] + 1):
                print("== Epoch %i" % epoch)
                train_loss, train_accs, train_speed, train_batch, train_total_graph, train_time, train_read_data_time = self.run_epoch(
                    "epoch %i (training)" % epoch,
                    self.mini_train_data, True)
                print(
                    "\r\x1b[K Train: loss: %.5f | acc: %s | instances/sec: %.2f | train_batch: %i | train_total_graph: %i | train_time: %s | train_read_data_time: %s" % (
                        train_loss,
                        format_scores(train_accs),
                        train_speed, train_batch, train_total_graph, train_time, train_read_data_time))
                valid_loss, valid_accs, valid_speed, valid_batch, valid_total_graph, valid_time, valid_read_data_time = self.run_epoch(
                    "epoch %i (validation)" % epoch,
                    self.mini_valid_data, False)
                print(
                    "\r\x1b[K Valid: loss: %.5f | acc: %s | instances/sec: %.2f | valid_batch: %i | valid_total_graph: %i | valid_time: %s | valid_read_data_time: %s" % (
                        valid_loss,
                        format_scores(valid_accs),
                        valid_speed, valid_batch, valid_total_graph, valid_time, valid_read_data_time))

                val_acc = np.sum(valid_accs)
                if val_acc < best_val_acc:
                    self.save_model2(MODEL_PATH)
                    print("  (Best epoch so far, cum. val. acc increased to %.5f from %.5f. Saving to '%s')" % (
                        val_acc, best_val_acc, self.best_model_file))
                    best_val_acc = val_acc
                    best_val_acc_epoch = epoch
                elif epoch - best_val_acc_epoch >= self.params['patience']:
                    print("Stopping training after %i epochs without improvement on validation accuracy." % self.params[
                        'patience'])
                    break


    def initialize_model(self) -> None:
        """Run the global and local variable initialisers in ``self.sess``."""
        global_init = tf.global_variables_initializer()
        local_init = tf.local_variables_initializer()
        self.sess.run(tf.group(global_init, local_init))

    def save_model(self, path: str) -> None:
        """Pickle the hyper-parameters and all global variable values to ``path``.

        The written object is a dict with the keys "params" and "weights";
        "weights" maps each variable name to its current value.
        """
        snapshot = {}
        global_vars = self.sess.graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
        for var in global_vars:
            var_name = var.name
            # Variable names are expected to be unique within the graph.
            assert var_name not in snapshot
            snapshot[var_name] = self.sess.run(var)

        payload = {"params": self.params, "weights": snapshot}
        with open(path, 'wb') as sink:
            pickle.dump(payload, sink, pickle.HIGHEST_PROTOCOL)

    def save_model2(self, path: str) -> None:
        """Write a ``tf.train.Saver`` checkpoint named after this run.

        ``path`` is used as a plain string prefix; no path separator is added.
        """
        checkpoint_prefix = path + self.best_model_checkpoint
        self.saver.save(self.sess, checkpoint_prefix)

    def restore_model(self, path: str) -> None:
        """Load weights written by ``save_model`` into the current graph.

        The stored hyper-parameters must equal the current ones, except for
        'task_ids' and 'num_epochs'. Variables without a stored value are
        freshly initialised; stored values without a matching variable are
        reported and ignored.
        """
        print("Restoring weights from file %s." % path)
        # NOTE(review): unpickling can execute arbitrary code; only load files
        # from a trusted source.
        with open(path, 'rb') as in_file:
            data_to_load = pickle.load(in_file)

        saved_params = data_to_load['params']
        assert len(self.params) == len(saved_params)
        for name, value in self.params.items():
            if name in ('task_ids', 'num_epochs'):
                continue
            assert value == saved_params[name]

        missing_vars = []
        with tf.name_scope("restore"):
            restore_ops = []
            seen_names = set()
            for var in self.sess.graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
                seen_names.add(var.name)
                if var.name not in data_to_load['weights']:
                    print("Freshly initializing %s since no saved value was found." % var.name)
                    missing_vars.append(var)
                    continue
                restore_ops.append(var.assign(data_to_load['weights'][var.name]))

            for saved_name in data_to_load['weights']:
                if saved_name not in seen_names:
                    print("Saved weights for %s not used by model." % saved_name)

            # Assignments and fresh initialisation are run in one session call.
            restore_ops.append(tf.variables_initializer(missing_vars))
            self.sess.run(restore_ops)

    def restore_model2(self, path: str) -> None:
        """Restore the ``tf.train.Saver`` checkpoint at ``path`` into ``self.sess``."""
        print("Restore...")
        session = self.sess
        self.saver.restore(session, path)
        print("Restore done!")


def graph_to_adj_mat(graph, max_n_vertices, num_edge_types, tie_fwd_bkwd=False):
    """Build a dense adjacency tensor from ``(src, edge_type, dest)`` triples.

    Vertex ids and edge types are 1-based in ``graph``. A forward edge sets
    ``[type - 1, dest - 1, src - 1]``; the matching backward edge sets
    ``[type - 1 + offset, src - 1, dest - 1]``, where ``offset`` is
    ``num_edge_types // 2`` unless ``tie_fwd_bkwd`` is true (then 0).
    All-zero triples are skipped.

    Returns:
        Array of shape ``(num_edge_types, max_n_vertices, max_n_vertices)``.
    """
    if tie_fwd_bkwd:
        backward_shift = 0
    else:
        backward_shift = num_edge_types // 2

    adjacency = np.zeros((num_edge_types, max_n_vertices, max_n_vertices))
    for edge in graph:
        src, e, dest = edge
        is_padding = e == 0 and src == 0 and dest == 0
        if not is_padding:
            type_idx = e - 1
            adjacency[type_idx, dest - 1, src - 1] = 1
            adjacency[type_idx + backward_shift, src - 1, dest - 1] = 1
    return adjacency


class DenseGGNNProgModel(ProgModel):
    def __init__(self, args, training_file_count, valid_file_count):
        """Build the base model, then load the API and whole vocabularies.

        Each vocabulary file holds one entry per line; the zero-based line
        index is the entry's id. Both directions of the mapping are kept.
        """
        super().__init__(args, training_file_count, valid_file_count)

        with open(APIVOCABULARY_PATH, 'r') as vocab_file:
            self.api2idx, self.idx2api = {}, {}
            for position, line in enumerate(vocab_file):
                api_name = line.strip()
                self.api2idx[api_name] = position
                self.idx2api[position] = api_name

        with open(WHOLEVOCABULARY_PATH, 'r') as vocab_file:
            self.word2idx, self.idx2word = {}, {}
            for position, line in enumerate(vocab_file):
                token = line.strip()
                self.word2idx[token] = position
                self.idx2word[position] = token

        print("On Serving...")


    @classmethod
    def default_params(cls):
        """Return the base defaults extended with the dense-GGNN settings."""
        merged = dict(super().default_params())
        merged['batch_size'] = 256
        merged['graph_state_dropout_keep_prob'] = 0.75
        merged['task_sample_ratios'] = {}
        merged['use_edge_bias'] = True
        return merged

    def prepare_specific_graph_model(self) -> None:
        """Create the placeholders and weights of the dense GGNN encoder.

        Also creates the decoder-side inputs used at inference time
        (``latest_tokens``, ``dec_init_states``, ``topk_size``).

        The width of ``dec_init_states`` now follows
        ``params['decoder_hidden_size']`` instead of a hard-coded 600 (the
        default value), so non-default decoder sizes get a matching
        placeholder.

        Statement order is kept as is: unnamed variables and placeholders are
        named by creation order, which existing checkpoints depend on.
        """
        h_dim = self.params['hidden_size']
        e_dim = self.params['embed_size']

        self.placeholders['graph_state_keep_prob'] = tf.placeholder(tf.float32, None, name='graph_state_keep_prob')
        self.placeholders['edge_weight_dropout_keep_prob'] = tf.placeholder(tf.float32, None,
                                                                            name='edge_weight_dropout_keep_prob')

        # Embedding table for input node tokens (whole vocabulary).
        self.placeholders['input_orders'] = tf.placeholder(tf.int32, [None, None, ], name='input_orders')
        self.weights['index2vector'] = tf.Variable(dtype=tf.float32,
                                                   initial_value=np.random.uniform(-0.5, 0.5, [INPUT_ORDER_SIZE, e_dim]))

        # Embedding table for decoder tokens (API vocabulary).
        self.placeholders['output_orders'] = tf.placeholder(tf.int32, [None, None, ], name='output_orders')
        self.weights['output2vector'] = tf.Variable(dtype=tf.float32,
                                                    initial_value=np.random.uniform(-0.5, 0.5, [SOFTMAX_SIZE, e_dim]))

        # These two entries are lookup results rather than placeholders; the
        # 'trg_emb' entry replaces the placeholder created in make_model.
        self.placeholders['orders_embed'] = tf.nn.embedding_lookup(self.weights['index2vector'],
                                                                   self.placeholders['input_orders'])
        self.placeholders['trg_emb'] = tf.nn.embedding_lookup(self.weights['output2vector'],
                                                              self.placeholders['output_orders'])
        self.placeholders['initial_node_representation'] = tf.placeholder(tf.float32,
                                                                          [None, None, self.params['hidden_size']],
                                                                          name='node_features')
        self.placeholders['node_mask'] = tf.placeholder(tf.float32, [None, None, self.params['hidden_size']],
                                                        name='node_mask')
        self.placeholders['num_vertices'] = tf.placeholder(tf.int32, ())
        self.placeholders['adjacency_matrix'] = tf.placeholder(tf.float32,
                                                               [None, self.num_edge_types, None, None])
        # Swap the first two axes so the edge type becomes the leading axis.
        self.__adjacency_matrix = tf.transpose(self.placeholders['adjacency_matrix'], [1, 0, 2, 3])
        self.weights['edge_weights'] = tf.Variable(glorot_init([self.num_edge_types, h_dim, h_dim]),
                                                   name='edge_weights')
        if self.params['use_edge_bias']:
            self.weights['edge_biases'] = tf.Variable(np.zeros([self.num_edge_types, 1, h_dim]).astype(np.float32),
                                                      name='edge_biases')
        with tf.variable_scope("gru_scope"):
            cell = tf.contrib.rnn.GRUCell(h_dim)
            cell = tf.nn.rnn_cell.DropoutWrapper(cell, state_keep_prob=self.placeholders['graph_state_keep_prob'])
            self.weights['node_gru'] = cell
        self.latest_tokens = tf.placeholder(tf.int32, [None], name='latest_tokens')
        self.dec_init_states = tf.placeholder(tf.float32, [None, self.params['decoder_hidden_size']],
                                              name='dec_init_states')
        self.topk_size = tf.placeholder(tf.int32, [], name='topk_size')

    def compute_final_node_representations(self) -> tf.Tensor:
        """Run ``num_timesteps`` rounds of message passing over the graph.

        In every round each edge type transforms the node states with its own
        weight matrix (plus an optional bias), the adjacency matrices aggregate
        the messages, and the node GRU updates the states.

        Returns:
            The final node states reshaped to ``[-1, num_vertices, hidden_size]``.
        """
        v = self.placeholders['num_vertices']
        h_dim = self.params['hidden_size']
        e_dim = self.params['embed_size']  # not used below
        orders = self.placeholders['orders_embed']
        # NOTE(review): the embeddings have embed_size columns but are reshaped
        # with hidden_size, so this appears to assume embed_size == hidden_size
        # (true for the defaults) -- confirm.
        orders_embed = tf.reshape(orders, [-1, h_dim])
        if self.params['use_edge_bias']:
            # One bias term per edge type, weighted by the summed adjacency row.
            biases = []
            for edge_type, a in enumerate(tf.unstack(self.__adjacency_matrix, axis=0)):
                summed_a = tf.reshape(tf.reduce_sum(a, axis=-1), [-1, 1])
                biases.append(tf.matmul(summed_a, self.weights['edge_biases'][edge_type]))

        with tf.variable_scope("gru_scope") as scope:
            for i in range(self.params['num_timesteps']):
                if i > 0:
                    # The GRU variables are created in the first step and
                    # shared by all later steps.
                    tf.get_variable_scope().reuse_variables()
                for edge_type in range(self.num_edge_types):
                    # Dropout is applied to the edge weight matrix itself.
                    m = tf.matmul(orders_embed, tf.nn.dropout(self.weights['edge_weights'][edge_type],
                                                              self.placeholders['edge_weight_dropout_keep_prob']))
                    if self.params['use_edge_bias']:
                        m += biases[edge_type]
                    m = tf.reshape(m, [-1, v, h_dim])
                    # Sum the aggregated messages over all edge types.
                    if edge_type == 0:
                        acts = tf.matmul(self.__adjacency_matrix[edge_type], m)
                    else:
                        acts += tf.matmul(self.__adjacency_matrix[edge_type], m)
                acts = tf.reshape(acts, [-1, h_dim])
                # The cell returns (output, new_state); the new state is kept.
                orders_embed = self.weights['node_gru'](acts, orders_embed)[1]
            last_h = tf.reshape(orders_embed, [-1, v, h_dim])
        return last_h

    def gated_regression(self, last_h, regression_gate, regression_transform, softmax_weights, softmax_biases):
        """Pool the node states into a graph vector and attach the GRU decoder.

        Training mode (``self.is_training``): unrolls the decoder over
        ``placeholders['trg_emb']`` and returns the flattened logits.

        Inference mode: builds a single decoder step fed by
        ``self.latest_tokens`` / ``self.dec_init_states``, stores
        ``(topk_probs, topk_ids, next_state)`` in ``self.output`` and returns
        None (this branch has no return statement).

        In both modes ``self.encoder_result`` holds the decoder's initial state
        computed from the pooled graph vector.
        """
        # Gate each node state using the state concatenated with its input embedding.
        gate_input = tf.concat([last_h, self.placeholders['orders_embed']], axis=2)
        gate_input = tf.reshape(gate_input, [-1, 2 * self.params["hidden_size"]])
        last_h = tf.reshape(last_h, [-1, self.params["hidden_size"]])
        gated_outputs = tf.nn.sigmoid(regression_gate(gate_input)) * regression_transform(last_h)
        gated_outputs = tf.reshape(gated_outputs, [-1, self.placeholders['num_vertices'], self.params["hidden_size"]])
        # Mask the node outputs, then sum over the vertex axis.
        masked_gated_outputs = gated_outputs * self.placeholders['node_mask']
        graph_final_state = tf.reduce_sum(masked_gated_outputs, axis=1)
        HIDDEN_SIZE = self.params['decoder_hidden_size']
        NUM_LAYERS = self.params['decoder_num_layers']
        self.dec_cell = MultiRNNCell_with_name(HIDDEN_SIZE, NUM_LAYERS)
        if self.is_training:
            with tf.variable_scope("decoder"):
                init_state = graph_final_state
                # NOTE(review): the projection width is 'softmax_size' while the
                # cells use 'decoder_hidden_size'; both default to 600 --
                # confirm they are meant to be tied.
                init_state = tf.layers.dense(init_state, self.params['softmax_size'], tf.tanh, name='init_state')
                self.encoder_result = init_state
                # Every decoder layer starts from the same projected graph state.
                dec_outputs, _ = tf.nn.dynamic_rnn(
                    self.dec_cell, self.placeholders['trg_emb'], sequence_length=self.placeholders['trg_size'],
                    initial_state=tuple([init_state[:, :] for _ in range(NUM_LAYERS)]))
            output = tf.reshape(dec_outputs, [-1, self.params['decoder_hidden_size']])
            output = tf.matmul(output, softmax_weights) + softmax_biases
            self.output = output
            return output
        else:
            with tf.variable_scope("decoder"):
                init_state = graph_final_state
                init_state = tf.layers.dense(init_state, self.params['softmax_size'], tf.tanh,
                                             name='init_state')
                self.encoder_result = init_state

            # NOTE(review): these scope names presumably mirror the variable
            # names that dynamic_rnn creates in the training branch, so trained
            # checkpoints can be restored -- confirm before renaming.
            with tf.variable_scope("decoder/rnn"):
                trg_emb = tf.nn.embedding_lookup(self.weights['output2vector'],
                                                     self.latest_tokens)
                final_input = trg_emb
            with tf.variable_scope("decoder/rnn/multi_rnn_cell"):
                # One decoder step; the same fed state is used for every layer.
                dec_outputs, next_state = self.dec_cell.call(
                    state=tuple([self.dec_init_states for _ in range(NUM_LAYERS)]), inputs=final_input)

            output = tf.reshape(dec_outputs, [-1, self.params['decoder_hidden_size']])
            logits = (tf.matmul(output, self.weights['softmax_weights'])
                      + self.weights['softmax_biases'])
            # Uses the module constant TOPK_SIZE, not the self.topk_size
            # placeholder. The values are raw logits (no softmax is applied).
            topk_probs, topk_ids = tf.nn.top_k(logits, TOPK_SIZE)
            self.output = topk_probs, topk_ids, next_state


    def process_raw_graphs(self, raw_data: Sequence[Any], is_training_data: bool, bucket_sizes=None) -> Any:
        """Group samples into size buckets and pad them to the bucket size.

        A sample goes to the first bucket whose size exceeds its largest
        vertex id. For training data each bucket is shuffled and, for tasks
        with a sample ratio, the labels of the surplus examples are set to None.

        Returns:
            Tuple ``(bucketed, bucket_sizes, bucket_at_step)``. ``bucketed``
            maps bucket index to a list of sample dicts with the keys
            'adj_mat', 'orders', 'labels' and 'mask'. ``bucket_at_step`` lists
            each bucket index once per minibatch the bucket provides.
        """
        if bucket_sizes is None:
            bucket_sizes = np.array(list(range(1, 6200, 2)))

        bucketed = defaultdict(list)
        for d in raw_data:
            edges = d['graph']
            node_orders = d['orders'][0]
            largest_vertex = max([endpoint for edge in edges for endpoint in (edge[0], edge[2])])
            chosen_bucket_idx = np.argmax(bucket_sizes > largest_vertex)
            chosen_bucket_size = bucket_sizes[chosen_bucket_idx]
            n_active_nodes = len(node_orders)
            n_padding = chosen_bucket_size - n_active_nodes

            adj_mat = graph_to_adj_mat(edges, chosen_bucket_size, self.num_edge_types,
                                       self.params['tie_fwd_bkwd'])
            padded_orders = node_orders + [0 for _ in range(n_padding)]
            labels = d["targets"][0]
            # Mask rows are 1.0 for real nodes and 0.0 for padding nodes.
            active_rows = [[1.0 for _ in range(self.params['hidden_size'])]
                           for _ in range(n_active_nodes)]
            padding_rows = [[0. for _ in range(self.params['hidden_size'])]
                            for _ in range(n_padding)]

            bucketed[chosen_bucket_idx].append({
                'adj_mat': adj_mat,
                'orders': [padded_orders],
                'labels': labels,
                'mask': active_rows + padding_rows,
            })

        if is_training_data:
            for bucket in bucketed.values():
                np.random.shuffle(bucket)
                for task_id in self.params['task_ids']:
                    task_sample_ratio = self.params['task_sample_ratios'].get(str(task_id))
                    if task_sample_ratio is None:
                        continue
                    ex_to_sample = int(len(bucket) * task_sample_ratio)
                    for surplus_example in bucket[ex_to_sample:]:
                        surplus_example['labels'][task_id] = None

        # One entry per minibatch: floor(len / batch_size) + 1 steps per bucket.
        bucket_at_step = []
        for bucket_idx, bucket_data in bucketed.items():
            n_steps = len(bucket_data) // self.params['batch_size'] + 1
            bucket_at_step.extend(bucket_idx for _ in range(n_steps))

        return (bucketed, bucket_sizes, bucket_at_step)

    def make_batch(self, elements):
        """Assemble decoder inputs and targets for a list of bucketed samples.

        For every element the label sequence is cut to at most
        MAX_DEC_LEN - 1 tokens, terminated with EOS_ID and padded with
        EOS_ID + 1 up to MAX_DEC_LEN.

        Args:
            elements: samples carrying the keys 'adj_mat', 'orders', 'mask'
                and 'labels' (a list of target token ids). The samples are
                not modified.

        Returns:
            A dict with:
              'adj_mat', 'orders', 'node_mask': per-sample values, copied over.
              'output_orders': per sample, SOS_ID followed by the first
                  MAX_DEC_LEN - 1 padded target tokens.
              'tar_size': per sample, the target length including EOS.
              'task_masks': one flat list over the whole batch, 1 for real
                  target positions and 0 for padding.
              'labels': a list holding ONE flat list with the padded targets
                  of all samples concatenated.
        """
        batch_data = {'adj_mat': [], 'orders': [], 'output_orders': [], 'labels': [], 'node_mask': [], 'tar_size': [],
                      'task_masks': []}
        pad_id = EOS_ID + 1
        target_task_values = []
        for d in elements:
            batch_data['adj_mat'].append(d['adj_mat'])
            batch_data['orders'].append(d['orders'])
            batch_data['node_mask'].append(d['mask'])

            # Work on a copy: appending EOS to d['labels'] itself would make
            # the stored labels grow by one token every time the same sample
            # is batched again (e.g. once per epoch).
            targets = list(d['labels'][:MAX_DEC_LEN - 1])
            targets.append(EOS_ID)
            target_length = len(targets)
            n_padding = MAX_DEC_LEN - target_length
            padded_targets = targets + [pad_id] * n_padding

            target_task_values.extend(padded_targets)
            # Decoder input is the target shifted right by one position.
            batch_data['output_orders'].append([SOS_ID] + padded_targets[:MAX_DEC_LEN - 1])
            batch_data['task_masks'].extend([1] * target_length + [0] * n_padding)
            batch_data['tar_size'].append(target_length)
        batch_data['labels'].append(target_task_values)
        return batch_data

    def make_minibatch_iterator(self, data, is_training: bool):
        """Yield one feed dict per non-empty minibatch.

        Args:
            data: the (bucketed, bucket_sizes, bucket_at_step) triple produced
                by process_raw_graphs.
            is_training: when true, the step schedule and the content of every
                bucket are shuffled in place and the dropout keep probability
                is taken from params['graph_state_dropout_keep_prob'];
                otherwise the keep probability is 1.

        Yields:
            A dict keyed by the entries of self.placeholders, ready to be
            passed as a session feed dict.

        Side effects:
            Sets self.is_training.
        """
        self.is_training = is_training
        (bucketed, bucket_sizes, bucket_at_step) = data
        if is_training:
            np.random.shuffle(bucket_at_step)
            for bucketed_data in bucketed.values():
                np.random.shuffle(bucketed_data)

        bucket_counters = defaultdict(int)
        dropout_keep_prob = self.params['graph_state_dropout_keep_prob'] if is_training else 1.
        batch_size = self.params['batch_size']
        for bucket in bucket_at_step:
            start_idx = bucket_counters[bucket] * batch_size
            elements = bucketed[bucket][start_idx:start_idx + batch_size]
            # The schedule holds len(bucket) // batch_size + 1 steps per
            # bucket, so the last slice is empty whenever the bucket size is a
            # multiple of batch_size. Skip it up front: make_batch always
            # returns a one-item 'labels' list, so checking that afterwards
            # never fires, and np.squeeze(..., axis=1) fails on an empty batch.
            if len(elements) == 0:
                continue
            batch_data = self.make_batch(elements)
            num_graphs = len(batch_data['orders'])
            # Each sample stores its order as a one-row list; drop that axis.
            batch_data['orders'] = np.squeeze(batch_data['orders'], axis=1)
            if len(batch_data['orders'].shape) == 1:
                batch_data['orders'] = np.expand_dims(batch_data['orders'], axis=0)
            batch_feed_dict = {
                self.placeholders['input_orders']: batch_data['orders'],
                self.placeholders['output_orders']: batch_data['output_orders'],
                self.placeholders['trg_label']: batch_data['labels'],
                self.placeholders['num_graphs']: num_graphs,
                self.placeholders['num_vertices']: bucket_sizes[bucket],
                self.placeholders['adjacency_matrix']: batch_data['adj_mat'],
                self.placeholders['node_mask']: batch_data['node_mask'],
                self.placeholders['graph_state_keep_prob']: dropout_keep_prob,
                self.placeholders['edge_weight_dropout_keep_prob']: dropout_keep_prob,
                self.placeholders['trg_size']: batch_data['tar_size'],
                self.placeholders['trg_mask']: batch_data['task_masks']
            }

            bucket_counters[bucket] += 1
            yield batch_feed_dict

    def pad_annotations(self, annotations):
        """Zero-pad the last axis of a 3-D annotation array.

        The last axis is extended at its end by
        params['hidden_size'] - self.annotation_size entries; the first two
        axes are left untouched.
        """
        missing_width = self.params['hidden_size'] - self.annotation_size
        pad_spec = [[0, 0], [0, 0], [0, missing_width]]
        return np.pad(annotations, pad_width=pad_spec, mode='constant')

    def predict(self, graphRepresent, graphVocab, holeParentNodeString, holeSequence, c_clazz, single_targets=None):
        """Suggest API sequences for a code hole with a beam search over the decoder.

        Args:
            graphRepresent: JSON text describing the graph edges.
            graphVocab: text of a dict literal; its values (in order) are the
                words of the graph nodes, mapped to ids through self.word2idx.
            holeParentNodeString: not used by this method.
            holeSequence: space-separated tokens; the last token is dropped
                before the scope analysis.
            c_clazz: space-separated class names that seed the list of
                classes handed to the hypotheses.
            single_targets: list whose first entry is used as the sample's
                target labels. Defaults to a fresh [[-1]] per call.

        Returns:
            The decoded candidates as one string: tokens of a candidate are
            separated by ' ' and candidates by ';'. Returns None when the
            batch iterator yields nothing or the first batch holds no graph.
            Only the first batch is decoded.
        """
        # A list default would be shared between calls, and the batching code
        # may append to the label list it receives.
        if single_targets is None:
            single_targets = [[-1]]

        clazzes = c_clazz.split(" ")
        scopes = [[1] for _ in clazzes]

        # NOTE(review): both objects are unused here; kept only in case the
        # constructors have side effects -- confirm and remove.
        sub = SubString()
        repeat = RepeatSeq()

        context_list = []
        type_list = []
        scope_list = []
        start_scope = 1
        current_temp_scope = [1]

        scope_openers = ("if", "while", "doWhile", "for", "try", "switch")
        scope_switchers = ("elseif", "else", "catch", "finally", "case", "default")

        # Walk the tokens before the hole, tracking the stack of open scope
        # ids after each token together with the class each token yields.
        input_tokens = holeSequence.strip().split(" ")[:-1]
        for word in input_tokens:
            type_list.append(self.getClazz(word))
            if word in scope_openers:
                start_scope += 1
                current_temp_scope.append(start_scope)
            elif word in scope_switchers:
                # A sibling branch replaces the innermost scope with a new id.
                start_scope += 1
                current_temp_scope.pop()
                current_temp_scope.append(start_scope)
            elif word == "out_control":
                current_temp_scope.pop()
            scope_list.append(list(current_temp_scope))
            if word != "hole":
                context_list.append(self.dealAPI(word))

        # Classes of tokens whose scope stack is contained in the scope stack
        # at the hole are added to the candidate classes.
        hole_scope = set(current_temp_scope)
        for token_scope, token_types in zip(scope_list, type_list):
            if set(token_scope) <= hole_scope:
                clazzes.extend(token_types.split(";"))

        # SECURITY NOTE(review): eval() executes arbitrary code contained in
        # graphVocab. If this text can come from an untrusted caller, switch
        # to a literal-only parser (ast.literal_eval / json.loads) after
        # confirming the wire format.
        vocab = eval(graphVocab)
        cur_order = [self.word2idx.get(word, INPUT_ORDER_SIZE - 1) for word in vocab.values()]
        json_now = {'graph': json.loads(graphRepresent), 'orders': [cur_order], 'targets': single_targets}
        example = self.process_raw_graphs([json_now], is_training_data=False)
        one_batch = self.make_minibatch_iterator(example, False)

        current_scope = [1]
        last_scope = [1]
        clazz = "none"
        skipped_apis = ("try", "catch", "finally", "switch", "case", "default")
        # NOTE(review): hard-coded width used to reshape decoder states;
        # presumably the decoder state size -- confirm against self.params.
        state_width = 600
        for batch in one_batch:
            batch[self.placeholders['out_layer_dropout_keep_prob']] = 1.0
            num_graphs = batch[self.placeholders['num_graphs']]

            if num_graphs == 0:
                break
            encoder_result = self.sess.run(fetches=self.encoder_result, feed_dict=batch)
            hyps = [HypothesisFilter(tokens=[SOS_ID],
                               log_probs=[1.0],
                               state=encoder_result,
                               clazzes = clazzes,
                               scopes = scopes,
                               scope_index = 1,
                               current_scope = current_scope,
                               last_scope = last_scope,
                               control_apis = [],
                               thenbodycontrol_apis = [],
                               context_apis = context_list
                               ) for _ in range(BEAM_SIZE)]
            results = []
            steps = 0
            while steps < MAX_DEC_LEN and len(results) < BEAM_SIZE:
                latest_tokens = [h.latest_token for h in hyps]
                states = np.reshape([h.state for h in hyps], (-1, state_width))
                feed = {
                     self.latest_tokens : latest_tokens,
                     self.dec_init_states : states,
                     self.topk_size: TOPK_SIZE
                }

                topk_probs, topk_ids, next_state = self.sess.run(fetches=self.output,
                                                                 feed_dict=feed)
                # Normalised over the returned top-k scores only.
                topk_probs = self.softmax(topk_probs)
                next_state = np.reshape(next_state, (-1, state_width))
                all_hyps = []
                # All initial hypotheses are identical, so only the first one
                # is expanded at step 0.
                num_orig_hyps = 1 if steps == 0 else len(hyps)
                for i in range(num_orig_hyps):
                    h, new_state = hyps[i], next_state[i]
                    for j in range(TOPK_SIZE):
                        # NOTE(review): named log_prob by the callee, but the
                        # value is a softmax output -- confirm what
                        # HypothesisFilter / sort_hyps expect.
                        log_prob = topk_probs[i, j]
                        new_token = topk_ids[i, j]
                        new_api = self.idx2api[new_token]
                        if new_api in skipped_apis:
                            continue
                        # NOTE(review): context_api receives the literal
                        # string "new_api", not the variable -- confirm intent.
                        new_hyp = h.extend(token = new_token,
                                           log_prob=log_prob,
                                           state=new_state,
                                           clazz = clazz,
                                           scopes = [1],
                                           scope_index = 4,
                                           current_scope = [1],
                                           last_scope = [1],
                                           control_api = new_api,
                                           thenbodycontrol_api = new_api,
                                           context_api = "new_api"
                                           )
                        all_hyps.append(new_hyp)
                hyps = []
                for h in sort_hyps(all_hyps):
                    if h.latest_token == EOS_ID:
                        # Finished hypothesis: keep it once it is long enough.
                        if steps >= MIN_DEC_LEN:
                            results.append(h)
                    else:
                        # On the last step unfinished hypotheses count as results too.
                        if steps == (MAX_DEC_LEN - 1):
                            results.append(h)
                        hyps.append(h)
                    if len(hyps) == BEAM_SIZE or len(results) == BEAM_SIZE:
                        break
                steps += 1
            if len(results) == 0:
                results = hyps
            top_token_seqs = [hyp.get_tokens for hyp in sort_hyps(results)]
            # Every token is prefixed with ' ' and every candidate closed with
            # ';'; the slice drops the leading space and the trailing ';'.
            top_10_api_sequence = ''.join(
                ''.join(' ' + self.idx2api[token] for token in token_seq) + ';'
                for token_seq in top_token_seqs)
            return top_10_api_sequence[1:-1]
        return None

    def softmax(self, array):
        """Return the row-wise softmax of a 2-D collection of scores.

        Args:
            array: array-like of shape (rows, columns).

        Returns:
            A float64 numpy array of the same shape in which every row sums
            to 1. Empty input yields an empty array.
        """
        scores = np.asarray(array, dtype=np.float64)
        if scores.size == 0:
            # Nothing to normalise; keep the shape an element-wise loop over
            # the rows would produce.
            return np.array([[] for _ in scores])
        # Subtracting the row maximum leaves the result unchanged but keeps
        # exp() from overflowing to inf (and the ratio from becoming nan) for
        # large scores.
        shifted = scores - np.max(scores, axis=-1, keepdims=True)
        exp_scores = np.exp(shifted)
        return exp_scores / np.sum(exp_scores, axis=-1, keepdims=True)
    
    def getClazz(self, new_api):
        """Derive the class string produced by an API token.

        The token is matched against a fixed, ordered list of markers; the
        first marker contained in the token decides the result:

        * plain markers ('.new(', '.Declaration', ...) give the text before
          the marker;
        * array markers ('.new[]', '.ArrayInit[][]', ...) give
          '<prefix>;<prefix>[]' or '<prefix>;<prefix>[][]'.

        Tokens without a marker are looked up in self.return_type; a result
        containing '[]' is returned as '<type without brackets>;<type>'.
        Any failure of that lookup gives 'none'.
        """
        # (marker, array suffix) in precedence order; an empty suffix marks a
        # plain, non-array rule. Two-dimensional markers come before their
        # one-dimensional counterparts.
        marker_rules = (
            (".new(", ""),
            (".Declaration", ""),
            (".Null", ""),
            (".Constant", ""),
            (".Cast", ""),
            (".ArrayNull[][]", "[][]"),
            (".ArrayNull[]", "[]"),
            (".new[][]", "[][]"),
            (".new[]", "[]"),
            (".ArrayInit[][]", "[][]"),
            (".ArrayInit[]", "[]"),
            (".ArrayDeclaration[][]", "[][]"),
            (".ArrayDeclaration[]", "[]"),
            (".ArrayConstant[][]", "[][]"),
            (".ArrayConstant[]", "[]"),
        )
        for marker, array_suffix in marker_rules:
            if marker not in new_api:
                continue
            owner = new_api.split(marker)[0]
            if array_suffix:
                return owner + ";" + owner + array_suffix
            return owner

        # No marker matched: fall back to the declared return type.
        try:
            declared = self.return_type[new_api]
            if "[]" in declared:
                return declared.replace("[]", "") + ";" + declared
            return declared
        except Exception:
            # Best effort: unknown tokens (or unusable entries) have no class.
            return "none"

    def dealAPI(self, new_api):
        """Reduce an API token to its class part.

        The token is matched against a fixed, ordered list of markers; the
        first marker contained in the token decides the result:

        * plain markers ('.new(', '.Declaration', ...) give the text before
          the marker;
        * array markers ('.new[]', '.ArrayInit[][]', ...) give
          '<prefix>;<prefix>[]' or '<prefix>;<prefix>[][]'.

        A token without any marker is returned unchanged.
        """
        # (marker, array suffix) in precedence order; an empty suffix marks a
        # plain, non-array rule. Two-dimensional markers come before their
        # one-dimensional counterparts.
        marker_rules = (
            (".new(", ""),
            (".Declaration", ""),
            (".Null", ""),
            (".Constant", ""),
            (".Cast", ""),
            (".ArrayNull[][]", "[][]"),
            (".ArrayNull[]", "[]"),
            (".new[][]", "[][]"),
            (".new[]", "[]"),
            (".ArrayInit[][]", "[][]"),
            (".ArrayInit[]", "[]"),
            (".ArrayDeclaration[][]", "[][]"),
            (".ArrayDeclaration[]", "[]"),
            (".ArrayConstant[][]", "[][]"),
            (".ArrayConstant[]", "[]"),
        )
        for marker, array_suffix in marker_rules:
            if marker not in new_api:
                continue
            owner = new_api.split(marker)[0]
            if array_suffix:
                return owner + ";" + owner + array_suffix
            return owner
        return new_api