#!/usr/bin/env python3
# coding: UTF-8
try:
    import sys
    import socket
    import importlib
    importlib.reload(sys)
    from gspylib.inspection.common import SharedFuncs
    from gspylib.inspection.common.CheckItem import BaseItem
    from gspylib.inspection.common.CheckResult import ResultStatus
except ImportError as ie:
    raise Exception("[GAUSS-52200] : Unable to import module: %s." % str(ie))

# Number of attempts CheckDnWait.executeSqlWithRetry makes before it raises
# the last error it saw.
g_retry = 3


class CheckDnWait(BaseItem):
    """Inspection item reporting the nodes that many queries are waiting on.

    The check reads the thread wait status of every coordinator and the
    cluster-wide pgxc_thread_wait_status view, then reports, grouped by host,
    the nodes that most often appear as the target of "wait node" /
    "flush data" waits.
    """

    def __init__(self):
        super(CheckDnWait, self).__init__(self.__class__.__name__)

    def _report_failure(self, val):
        """Mark the check NG with message *val* and remember that it failed.

        Later steps consult _has_failed() so they stop instead of overwriting
        the NG result with OK.
        """
        self.result.val = val
        self.result.rst = ResultStatus.NG
        self._failed = True

    def _has_failed(self):
        """Return True if a failure was reported during the current check."""
        return getattr(self, "_failed", False)

    def getInstInformation(self, masterDnlist):
        """Collect host information about the cluster instances.

        :param masterDnlist: instance ids of the master (primary) DNs.
        :return: tuple (CN_HOST, DN_HOST, noMasterDNNode):
            CN_HOST maps str(CN instance id) and every CN listen IP to the
            name of the node hosting that CN; DN_HOST maps master DN instance
            ids to their hostname; noMasterDNNode lists the nodes that host
            no master DN.
        """
        DN_HOST = {}
        CN_HOST = {}
        for dbNode in self.cluster.dbNodes:
            for datanode in dbNode.datanodes:
                if datanode.instanceId in masterDnlist:
                    DN_HOST[datanode.instanceId] = datanode.hostname
            # Both the CN instance id and each of its listen IPs must resolve
            # to the node that actually hosts this CN.
            for cn in dbNode.coordinators:
                CN_HOST[str(cn.instanceId)] = dbNode.name
                for ip in cn.listenIps:
                    CN_HOST[ip] = dbNode.name

        noMasterDNNode = [
            dbNode for dbNode in self.cluster.dbNodes
            if not any(dn.instanceId in masterDnlist for dn in dbNode.datanodes)
        ]
        return (CN_HOST, DN_HOST, noMasterDNNode)

    def parseCnAndSlave(self, totalWait, noMasterDNNode, dnlist, dicDN, dicHost):
        """Format the CN-side wait report.

        :param totalWait: number of "wait node" lines counted on the CNs.
        :param noMasterDNNode: nodes hosting no master DN.
        :param dnlist: unused; kept for interface compatibility.
        :param dicDN: node name -> number of CN threads waiting on it.
        :param dicHost: hostname -> list of node names (DNs and CNs) on it.
        :return: report text, or "" when there is nothing to report.
        """
        tmpresult = ""
        if totalWait > 0:
            for host, hostNodes in dicHost.items():
                waited = ["%s:%s" % (node, dicDN[node])
                          for node in hostNodes if node in dicDN]
                if waited:
                    tmpresult += "%s %s\n" % (host, ",".join(waited))
        if tmpresult:
            tmpresult = "monitor CN:\n%s" % tmpresult

        slaveDN = ""
        for dbNode in noMasterDNNode:
            tmplist = []
            wait = 0
            for datanode in dbNode.datanodes:
                for dnName, count in dicDN.items():
                    # Substring match: the node name embeds the instance id.
                    if str(datanode.instanceId) in dnName:
                        tmplist.append("dn_%s:%s" % (datanode.instanceId, count))
                        wait += count
            if wait:
                slaveDN += "%s %s\n" % (dbNode.name, ",".join(tmplist))
        if slaveDN:
            # "SalveDN" (sic) is kept byte-for-byte in case consumers match on it.
            tmpresult += "\nmonitor SalveDN:\n%s" % slaveDN
        return tmpresult

    def parseDn(self, dicHost):
        """Format the DN-side wait report from pgxc_thread_wait_status.

        Sums, per target node, the wait counts returned by the query below.
        Only the 8 nodes with the highest totals are then listed, grouped by
        host.

        :param dicHost: hostname -> list of node names on that host.
        :return: report text, or "" when nothing qualifies or the query fails
            (a failure is reported as NG).
        """
        sqlMonitorDN = """with para as \
        (select count(*) from pg_catalog.pgxc_stat_activity where state = 'active' and usename != 'omm') \
        select node, wait_num, query_num from (select query_id, node, sum(total_num) \
        over(partition by node ) as wait_num, count(query_id) over(partition by node) \
        as query_num, row_number() over(partition by node) as rownum \
        from (select query_id, case when wait_status like 'wait node:%' \
        then pg_catalog.split_part(pg_catalog.split_part(substr(wait_status, 12), ',', 1), '(', 1) \
        when wait_status like 'flush data:%' then \
        pg_catalog.split_part(pg_catalog.split_part(substr(wait_status, 13), ',', 1), '(', 1) end \
        as node, sum(num) as total_num \
        from (select query_id, wait_status, count(*) as num \
        from pg_catalog.pgxc_thread_wait_status where wait_status like 'wait node:%' \
        or wait_status like 'flush data:%' group by 1, 2) group by 1, 2 \
        order by 3 desc)) where 3 * 4 * query_num > (select * from para) \
        and rownum = 1 order by wait_num desc,  query_num desc
"""
        try:
            outputMonitorDN = self.executeSqlWithRetry(sqlMonitorDN)
        except Exception as e:
            self._report_failure("%s %s" % (self.host, str(e)))
            return ""

        dicDNCount = {}
        for line in (outputMonitorDN or "").splitlines():
            fields = line.split("|")
            if len(fields) < 2:
                # Skip blank or malformed rows instead of raising IndexError.
                continue
            dnName = fields[0].strip()
            dicDNCount[dnName] = dicDNCount.get(dnName, 0) + int(fields[1])
        if not dicDNCount:
            return ""

        # Ascending sort by count, so the last 8 entries are the busiest nodes.
        dnMapList = sorted(dicDNCount.items(), key=lambda item: item[1])[-8:]
        tmpDnResult = ""
        for host, hostNodes in dicHost.items():
            tmplist = ["%s:%s" % (dn, count)
                       for dn, count in dnMapList if dn in hostNodes]
            if tmplist:
                tmpDnResult += "%s %s\n" % (host, ",".join(tmplist))
        if not tmpDnResult:
            return ""
        return "\n monitor DN:\n%s" % tmpDnResult

    def doCheck(self):
        """Run the check: map nodes to hosts, then build and store the report."""
        self._failed = False
        dicHost = {}
        dnlist = []
        CN_HOST, noMasterDNNode = self.get_master_dn_list(dnlist, dicHost)
        if self._has_failed():
            # Keep the NG result reported by get_master_dn_list.
            return
        self.get_slow_cn_dn_result(dnlist, dicHost, CN_HOST, noMasterDNNode)

    def executeSqlWithRetry(self, sql):
        """Run *sql* against the "postgres" database, retrying on exceptions.

        Up to g_retry attempts are made. The output of the first attempt that
        does not raise is returned unchanged; callers inspect it for
        "ERROR"/"TIMEOUT" prefixes themselves.

        :raises Exception: carrying the last error's message (chained to it)
            when every attempt raised.
        """
        lastError = None
        for _ in range(g_retry):
            try:
                return SharedFuncs.runSqlCmdWithTimeOut(sql, self.user, "", self.port, self.tmpPath,
                                                        "postgres", self.mpprcFile)
            except Exception as e:
                lastError = e
        raise Exception(str(lastError)) from lastError

    def get_master_dn_list(self, dnlist, dicHost):
        """Register every DN from pgxc_node in *dnlist* and *dicHost*.

        :param dnlist: list extended in place with DN node names.
        :param dicHost: dict updated in place, hostname -> node names. For a
            master DN the hostname comes from the cluster configuration
            (DN_HOST); otherwise pgxc_node's node_host value is used.
        :return: (CN_HOST, noMasterDNNode) as built by getInstInformation.
        """
        masterDnlist = SharedFuncs.getMasterDnNum(self.user, self.mpprcFile)
        (CN_HOST, DN_HOST, noMasterDNNode) = self.getInstInformation(masterDnlist)
        sqlgetDNHost = "SELECT node_name, node_host FROM pg_catalog.pgxc_node WHERE node_type='D' order by node_name;"
        output = self.executeSqlWithRetry(sqlgetDNHost)
        if output.startswith(("ERROR", "TIMEOUT")):
            self._report_failure("%s %s" % (self.host, output))
            return CN_HOST, noMasterDNNode
        # NOTE(review): the last output line is skipped here but not for the CN
        # query in get_slow_cn_dn_result; confirm what runSqlCmdWithTimeOut
        # appends to its output before changing this.
        for line in output.splitlines()[:-1]:
            dnName, hostName = (field.strip() for field in line.split('|'))
            dnlist.append(dnName)
            for masterDn in masterDnlist:
                if str(masterDn) in dnName:
                    hostName = DN_HOST[masterDn]
            dicHost.setdefault(hostName, []).append(dnName)
        return CN_HOST, noMasterDNNode

    def get_slow_cn_dn_result(self, dnlist, dicHost, CN_HOST, noMasterDNNode):
        """Register the CNs, collect wait statistics and set the check result.

        CN names are appended to *dnlist* and grouped per host in *dicHost*
        (both updated in place). On success result.rst is OK and result.val
        holds the CN report followed by the DN report. If any step reports a
        failure, its NG result is left untouched.
        """
        cnlist = []
        dicDN = {}
        sqlgetCNHost = "SELECT node_name, node_host FROM pg_catalog.pgxc_node WHERE node_type='C' order by node_name;"
        output = self.executeSqlWithRetry(sqlgetCNHost)
        if output.startswith(("ERROR", "TIMEOUT")):
            self._report_failure("%s %s" % (self.host, output))
            return
        for line in output.splitlines():
            if not line.strip():
                continue
            cnName, hostName = (field.strip() for field in line.split('|'))
            dnlist.append(cnName)
            cnlist.append(cnName)
            if hostName == "localhost":
                hostName = socket.gethostname()
            else:
                # Fall back to the raw node_host if it is not a known CN address.
                hostName = CN_HOST.get(hostName, hostName)
            dicHost.setdefault(hostName, []).append(cnName)

        totalWait = self.execute_cn_wait_status_sql(cnlist, CN_HOST, dnlist, dicDN)
        if self._has_failed():
            return
        cnResult = self.parseCnAndSlave(totalWait, noMasterDNNode, dnlist, dicDN, dicHost)
        dnResult = self.parseDn(dicHost)
        if self._has_failed():
            return

        self.result.rst = ResultStatus.OK
        self.result.val = cnResult + dnResult

    def execute_cn_wait_status_sql(self, cnlist, CN_HOST, dnlist, dicDN):
        """Count, per node, the CN threads in a "wait node" state.

        For every CN in *cnlist* the thread wait status is fetched with
        EXECUTE DIRECT. Each line of the form "wait node: <name>, ..." whose
        <name> is in *dnlist* increments dicDN[<name>] (updated in place).

        :return: total number of matching wait lines. If a CN query fails, the
            count so far is returned and the failure is reported as NG.
        """
        knownNodes = set(dnlist)
        totalWait = 0
        for cn in cnlist:
            tmpsql = "execute direct on(%s) 'select wait_status from pg_catalog.pg_thread_wait_status';" % (cn)
            try:
                output = self.executeSqlWithRetry(tmpsql)
            except Exception as e:
                self._report_failure("%s %s" % (self.host, str(e)))
                return totalWait
            if output.startswith(("ERROR", "TIMEOUT")):
                # cn[3:] drops a 3-char prefix (presumably "cn_") to obtain the
                # instance id key used in CN_HOST.
                self._report_failure("%s %s" % (CN_HOST.get(cn[3:], cn), output))
                return totalWait
            for line in output.splitlines():
                if not line.startswith("wait node"):
                    continue
                parts = line.split(",")[0].split(":")
                if len(parts) != 2:
                    continue
                tmpDn = parts[1].strip()
                if tmpDn in knownNodes:
                    totalWait += 1
                    dicDN[tmpDn] = dicDN.get(tmpDn, 0) + 1
        return totalWait
