#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################################
# Copyright (c): 2012-2017, Huawei Tech. Co., Ltd.
# Description  : Result.py is a utility to store search result from database
#############################################################################
try:
    import os
    import sys
    import math
    from multiprocessing.dummy import Pool as ThreadPool
    sys.path.append(sys.path[0] + "/../../")

    from gspylib.common.Common import DefaultValue
    from gspylib.common.ErrorCode import ErrorCode
    from gspylib.common.Common import ClusterCommand
except ImportError as ie:
    sys.exit("[GAUSS-52200] : Unable to import module: %s." % str(ie))

# Number of result lines (SQL statements) joined together and sent to the
# database as one batch by SqlBatchProcess.exec_sql.
PER_BATCH_NUMS = 2000
# Maximum number of query-and-execute passes SqlBatchProcess.run performs.
PROCESS_MAX_TIMES = 10
# Once the pass counter is greater than this value, the worker-thread count
# is halved on every further pass (see get_new_parall_jobs).
PROCESS_DEGRADE_TIMES = 5
# Floor for the initial worker-thread count chosen in SqlBatchProcess.run.
CONCURRENT_NUMBERS = 8
# Divisor applied to DefaultValue.getCpuSet() when sizing the thread pool;
# presumably the number of DN instances sharing one node -- TODO confirm.
DN_NUMS_PER_NODE = 4


class SqlBatchProcess(object):
    """
    Run a query whose output lines are themselves SQL text, then execute
    those lines against the same instance/database in concurrent batches.

    run() repeats "query -> execute batches" until the query returns no
    more lines or PROCESS_MAX_TIMES passes have been made.
    """

    def __init__(self, sql, user, db_inst, database, logger, ignore_errors=False):
        """
        function : Constructor
        input : sql           - query whose result lines are executed as SQL
                user          - user name passed to ClusterCommand.execSQLCommand
                db_inst       - instance object; its .port and .instanceId are used
                database      - database name the SQL is executed in
                logger        - logger object; only .debug() is called
                ignore_errors - when True, leftover data after all passes is
                                only logged instead of raising
        output : NA
        """
        self.sql = sql
        self.user = user
        self.db_inst = db_inst
        self.database = database
        self.logger = logger
        # Result lines of the most recent query; sliced per batch by exec_sql.
        self.sql_result_dict = None
        self.ignore_errors = ignore_errors

    def _exec_on_instance(self, sql):
        """
        function : Execute sql on this object's instance port and database.
        input : sql - SQL text to execute
        output : (status, output) tuple as returned by ClusterCommand.execSQLCommand
        """
        return ClusterCommand.execSQLCommand(sql,
                                             self.user,
                                             "",
                                             self.db_inst.port,
                                             self.database,
                                             is_inplace_upgrade=True)

    def query_sql_result(self):
        """
        function : Execute self.sql and split its output into lines.
        input : NA
        output : list of output lines; empty list when the output is blank
        raises : Exception when the command status is non-zero
        """
        (status, output) = self._exec_on_instance(self.sql)
        if status != 0:
            raise Exception(ErrorCode.GAUSS_514["GAUSS_51400"] % self.sql +
                            " DN [%s], DB [%s] Error:\n%s" % (self.db_inst.instanceId, self.database, str(output)))
        stripped_output = output.strip()
        if stripped_output == "":
            return []
        return stripped_output.split("\n")

    def exec_sql(self, process_number):
        """
        function : Execute one batch of the lines stored in self.sql_result_dict.
                   A failed batch is only logged; it is not raised, so that the
                   caller can query again and retry whatever is left.
        input : process_number - zero-based batch index; the batch covers
                                 PER_BATCH_NUMS lines starting at
                                 process_number * PER_BATCH_NUMS
        output : NA
        """
        begin_index = process_number * PER_BATCH_NUMS
        end_index = begin_index + PER_BATCH_NUMS
        sql_info = "\n".join(self.sql_result_dict[begin_index:end_index])
        (status, output) = self._exec_on_instance(sql_info)
        if status != 0:
            self.logger.debug("Warning: Batch of data that fails to be processed, batch nums:%s" % process_number)
            self.logger.debug(ErrorCode.GAUSS_514["GAUSS_51400"] % sql_info +
                              " DN [%s], DB [%s] Error:\n%s" % (self.db_inst.instanceId, self.database, str(output)))

    def get_new_parall_jobs(self, parallel_jobs, sql_nums, retry_times):
        """
        function : Work out the thread count for the next pass.
        input : parallel_jobs - thread count used so far
                sql_nums      - number of result lines to process in this pass
                retry_times   - zero-based index of the pass about to run
        output : thread count to use for this pass
        """
        # The SQL data fits in a single batch, or this is the last pass:
        # fall back to a single thread.
        if sql_nums <= PER_BATCH_NUMS or retry_times == PROCESS_MAX_TIMES - 1:
            parallel_jobs = 1

        # parallel_jobs is degraded (halved) in the last few passes.
        if parallel_jobs > 1 and retry_times > PROCESS_DEGRADE_TIMES:
            parallel_jobs = parallel_jobs // 2
        return parallel_jobs

    def run(self):
        """
        function : Query the SQL result and execute the query result concurrently.
                   The query is repeated after every pass; processing stops when
                   it returns nothing or after PROCESS_MAX_TIMES passes.
        input : NA
        output : NA
        raises : Exception when a query fails, or when data is still left after
                 all passes and ignore_errors is False
        """
        parallel_jobs = max(DefaultValue.getCpuSet() // DN_NUMS_PER_NODE, CONCURRENT_NUMBERS)
        self.logger.debug("Processing the db data of instance [%s] in the database[%s]" % (self.db_inst.instanceId,
                                                                                           self.database))
        try:
            sql_nums = 0
            for retry_times in range(PROCESS_MAX_TIMES):
                self.sql_result_dict = self.query_sql_result()
                sql_nums = len(self.sql_result_dict)
                if sql_nums == 0:
                    break
                parallel_jobs = self.get_new_parall_jobs(parallel_jobs, sql_nums, retry_times)
                self.logger.debug("Query numbers of data records "
                                  "to be process from the database [%s] on the inst [%s], data nums:%s, "
                                  "retry times:%d, jobs:%d, "
                                  " sql:\n %s" % (self.database, self.db_inst.instanceId, sql_nums,
                                                  retry_times, parallel_jobs, self.sql))
                # Integer ceiling division: number of PER_BATCH_NUMS-sized batches.
                process_nums = (sql_nums + PER_BATCH_NUMS - 1) // PER_BATCH_NUMS
                pool = ThreadPool(parallel_jobs)
                try:
                    pool.map(self.exec_sql, range(process_nums))
                finally:
                    # Always release the worker threads, even if a batch raised.
                    pool.close()
                    pool.join()
            else:
                # Every pass was used. sql_nums still holds the count from
                # BEFORE the final pass, so query once more to learn whether
                # that last pass actually finished the work.
                self.sql_result_dict = self.query_sql_result()
                sql_nums = len(self.sql_result_dict)
            if sql_nums > 0:
                if self.ignore_errors:
                    self.logger.debug("Ignore the error.")
                else:
                    raise Exception("data is not completely processed.")
        except Exception as e:
            raise Exception(ErrorCode.GAUSS_514["GAUSS_51400"] % self.sql +
                            " DN [%s], DB [%s] Error: %s" % (self.db_inst.instanceId, self.database, str(e))) from e