#!/usr/bin/env python3
# -*- coding:utf-8 -*-
try:
    import sys
    import importlib

    importlib.reload(sys)
    from multiprocessing.dummy import Pool as ThreadPool
    from gspylib.inspection.common import SharedFuncs
    from gspylib.inspection.common.CheckItem import BaseItem
    from gspylib.inspection.common.CheckResult import ResultStatus
    from gspylib.common.Common import DefaultValue
    from gspylib.inspection.common.Exception import CheckNAException
    from gspylib.os.gsfile import g_file
except Exception as ie:
    raise Exception("[GAUSS-52200] : Unable to import module: %s." % str(ie))

# Instance role codes, compared against ``instanceRole`` of cluster instance
# objects when labelling results.
INSTANCE_ROLE_COODINATOR = 3  # coordinator (CN); spelling kept for callers
INSTANCE_ROLE_DATANODE = 4  # data node (DN)

# NOTE(review): not referenced in this section; presumably marks a
# primary/master instance elsewhere -- confirm before relying on it.
MASTER_INSTANCE = 0


class CheckSysTable(BaseItem):
    """Report size, row count and a width figure for a fixed set of system
    catalogs on every CN and primary DN hosted on the current node.

    The ``database`` threshold names the database the catalog queries run
    against. Attributes such as ``threshold``, ``tmpPath``, ``user``,
    ``mpprcFile``, ``cluster``, ``host`` and ``result`` are presumably
    provided by ``BaseItem`` -- confirm against the base class.
    """

    # System catalogs inspected on every selected instance.
    SYS_TABLES = ["pg_attribute", "pg_class", "pg_constraint", "pg_partition",
                  "pgxc_class", "pg_index", "pg_stats"]

    def __init__(self):
        super(CheckSysTable, self).__init__(self.__class__.__name__)
        # Target database; populated by preCheck() from the threshold.
        self.database = None

    def preCheck(self):
        """Validate the threshold configuration and capture the database.

        Raises:
            Exception: if the threshold has no ``database`` entry.
        """
        if 'database' not in self.threshold:
            raise Exception("threshold database can not be empty")
        self.database = self.threshold['database']

    def _runTableQuery(self, Instance, table):
        """Run the size/count/width queries for one catalog via gsql.

        Both temporary files (SQL script and result file) are always removed,
        even when the gsql invocation fails.

        Args:
            Instance: instance object; ``port`` and ``instanceId`` are read.
            table: catalog name to query.

        Returns:
            The lines of the gsql result file, as returned by g_file.readFile.
        """
        sqlFile = "%s/sqlFile_%s_%s.sql" % (self.tmpPath, table, Instance.instanceId)
        resFile = "%s/resFile_%s_%s.out" % (self.tmpPath, table, Instance.instanceId)
        try:
            g_file.createFile(sqlFile, True, 644)
            g_file.createFile(resFile, True, 644)
            # gsql runs as the cluster user, so it must own both files.
            g_file.changeOwner(self.user, sqlFile)
            g_file.changeOwner(self.user, resFile)
            # Three statements -> three result lines (with -t -A):
            # table size, row count, column size.
            # NOTE(review): pg_column_size() is applied to the table-name
            # string literal, so the "width" appears to be the stored size of
            # that literal rather than a row width -- confirm intent.
            sql = ("select * from pg_catalog.pg_table_size('%s');"
                   "select count(*) from %s;"
                   "select * from pg_catalog.pg_column_size('%s');"
                   % (table, table, table))
            g_file.writeFile(sqlFile, [sql])

            cmd = "gsql -d %s -p %s -f %s --output %s -t -A -X" % (
                self.database, Instance.port, sqlFile, resFile)
            if self.mpprcFile:
                cmd = "source '%s' && %s" % (self.mpprcFile, cmd)
            SharedFuncs.runShellCmd(cmd, self.user)
            return g_file.readFile(resFile)
        finally:
            # Previously sqlFile was removed twice and resFile leaked.
            for tmpFile in (sqlFile, resFile):
                g_file.removeFile(tmpFile)

    def checkSingleSysTable(self, Instance):
        """Collect catalog statistics for a single instance.

        Args:
            Instance: CN or DN instance object; ``instanceRole``,
                ``instanceId`` and ``port`` are read from it.

        Returns:
            dict mapping each name in SYS_TABLES to
            ``[instanceName, size, rowCount, width]`` where the last three are
            the stripped strings printed by gsql.

        Raises:
            Exception: if gsql did not return the three expected result lines.
        """
        # The role label only depends on the instance, so compute it once.
        if Instance.instanceRole == INSTANCE_ROLE_COODINATOR:
            Role = "CN"
        elif Instance.instanceRole == INSTANCE_ROLE_DATANODE:
            Role = "DN"
        else:
            Role = ""
        instanceName = "%s_%s" % (Role, Instance.instanceId)

        resultMap = {}
        for table in self.SYS_TABLES:
            output = self._runTableQuery(Instance, table) or []
            if len(output) < 3:
                raise Exception(
                    "Failed to query system table %s on instance %s: "
                    "expected 3 result lines, got %d."
                    % (table, instanceName, len(output)))
            size, line, width = (entry.strip() for entry in output[:3])
            resultMap[table] = [instanceName, size, line, width]
        return resultMap

    def checkSysTable(self):
        """Run checkSingleSysTable on every CN and primary DN of this host.

        Returns:
            list of non-empty per-instance result maps.

        Raises:
            CheckNAException: if this node hosts neither a CN nor a primary DN.
        """
        nodeInfo = self.cluster.getDbNodeByName(self.host)
        cnList = nodeInfo.coordinators
        masterDnList = SharedFuncs.getMasterDnNum(self.user, self.mpprcFile)
        primaryDnList = [dn for dn in nodeInfo.datanodes
                         if dn.instanceId in masterDnList]
        if len(cnList) < 1 and len(primaryDnList) < 1:
            raise CheckNAException("There is no CN instance and primary DN instance in the current node.")

        # Filter once so the connection test and the parallel check operate
        # on the same set of instances.
        instances = [inst for inst in cnList + primaryDnList
                     if not (inst == "" or inst is None)]

        # test database connection on every instance before the real queries
        for Instance in instances:
            sqlcmd = "select pg_catalog.pg_sleep(1);"
            SharedFuncs.runSqlCmd(sqlcmd, self.user, "", Instance.port, self.tmpPath, self.database, self.mpprcFile)

        pool = ThreadPool(DefaultValue.getCpuSet())
        try:
            results = pool.map(self.checkSingleSysTable, instances)
        finally:
            # Release worker threads even if one of the checks raised.
            pool.close()
            pool.join()
        return [result for result in results if result]

    def doCheck(self):
        """Build the tabular report and store it on ``self.result``."""
        reportLines = ["Instance table           size            row      width row*width\n"]
        for resultMap in self.checkSysTable():
            for table, (instanceName, size, rows, width) in resultMap.items():
                reportLines.append("%s  %s %s %s %s %s\n" % (
                    instanceName, table.ljust(15), size.ljust(15),
                    rows.ljust(8), width.ljust(5), int(rows) * int(width)))
        resultStr = "".join(reportLines)

        self.result.val = resultStr
        self.result.raw = resultStr
        self.result.rst = ResultStatus.OK
