#!/usr/bin/env python3
# coding: UTF-8
try:
    import sys
    import importlib
    importlib.reload(sys)
    import os
    import socket
    from gspylib.inspection.common import SharedFuncs
    from gspylib.inspection.common.CheckItem import BaseItem
    from gspylib.inspection.common.CheckResult import ResultStatus
    from gspylib.common.Common import DefaultValue, ClusterCommand
except ImportError as ie:
    raise Exception("[GAUSS-52200] : Unable to import module: %s." % str(ie))


class CheckTableSkew(BaseItem):
    """Inspection item that reports tables whose data is unevenly distributed.

    doCheck() only runs the check on the first CN host (or in a single-instance
    cluster). It then picks a strategy from the security mode of the data
    directory of the local node's first CM agent:

    * security mode "on": check_dws_table() executes the "--sqlblock"
      separated statements of GetTableSkew.sql through ClusterCommand and then
      queries PUBLIC.pgxc_analyzed_skewness in every database;
    * otherwise: check_table() runs GetTableSkew.sql with gsql in every
      database and collects the result table gsql prints.
    """

    # Databases that are never checked.
    _SKIPPED_DATABASES = ("template0", "template1")

    def __init__(self):
        super(CheckTableSkew, self).__init__(self.__class__.__name__)

    @staticmethod
    def _get_sql_file():
        """Return the absolute path of lib/checkblacklist/GetTableSkew.sql,
        resolved relative to two directories above this module."""
        sqlPath = os.path.realpath(
            os.path.join(os.path.split(os.path.realpath(__file__))[0],
                         "../../lib/checkblacklist/"))
        return os.path.join(sqlPath, "GetTableSkew.sql")

    @classmethod
    def _filter_databases(cls, names):
        """Return *names* without empty entries and without the skipped
        template databases (missing template names are tolerated)."""
        return [name for name in names
                if name and name not in cls._SKIPPED_DATABASES]

    @staticmethod
    def _extract_result_table(output):
        """Return the gsql result table embedded in *output* as text.

        The kept lines start at the column-header line (the line directly
        above the LAST "---+---" separator) and stop before the LAST line
        containing "row)"/"rows)" after its first character (the
        "(N rows)" footer). Every kept line is terminated by a newline.
        Returns "" when no such footer is found.
        """
        lines = output.splitlines()
        start = 0
        end = 0
        for idx, line in enumerate(lines):
            if "---+---" in line:
                # Clamp so a separator on the very first line cannot yield
                # index -1 (which would wrongly include the last line).
                start = max(idx - 1, 0)
            if line.find("row)") > 0 or line.find("rows)") > 0:
                end = idx
        return "".join("%s\n" % line for line in lines[start:end])

    def check_table(self):
        """Run GetTableSkew.sql with gsql in every non-template database.

        Returns:
            A string with one "<db>:\\n<result table>\\n" section for every
            database whose gsql output does not contain "(0 rows)", or ""
            when no database reports rows.
        """
        sqlFileName = self._get_sql_file()
        sqldb = "select datname from pg_database;"
        output = SharedFuncs.runSqlCmd(sqldb, self.user, "", self.port,
                                       self.tmpPath, "postgres",
                                       self.mpprcFile)
        dbList = self._filter_databases(
            [line.strip() for line in output.split("\n")])
        sections = []
        for db in dbList:
            # "$" is escaped before the name is put into the shell command
            # handed to runShellCmd (presumably evaluated inside a
            # double-quoted string -- TODO confirm in SharedFuncs).
            escapedDb = db.replace("$", "\\$")
            cmd = "gsql -d %s -p %s -f %s" % (escapedDb, self.port, sqlFileName)
            output = SharedFuncs.runShellCmd(cmd, self.user, self.mpprcFile)
            if output.find("(0 rows)") < 0:
                sections.append("%s:\n%s\n" %
                                (db, self._extract_result_table(output)))
        return "".join(sections)

    def check_dws_table(self):
        """Find skewed tables via PUBLIC.pgxc_analyzed_skewness.

        For every non-template database, each "--sqlblock" separated block
        of GetTableSkew.sql is executed first (prefixed with
        "set client_min_messages='error';"), then tables with
        skewness_tuple > 100000 are selected.

        Returns:
            A string with one "<db>:\\n<schema.table lines>\\n" section per
            database that has skewed tables, or "" if none do.

        Raises:
            Exception: if the SQL file is missing or unreadable, or if the
                database list or the skewness query does not return
                status 2 (the only status treated as success here).
        """
        sqlFileName = self._get_sql_file()
        if not os.path.exists(sqlFileName):
            raise Exception("Can't find sql file:%s" % sqlFileName)
        try:
            with open(sqlFileName, "r") as fp:
                sqlList = fp.read().split("--sqlblock")
        except Exception as e:
            raise Exception(
                "Unable to read file:%s,Error:%s" % (sqlFileName, str(e)))
        # The segment after the final "--sqlblock" marker is not executed.
        sqlList.pop()

        sqldb = "select datname from pg_database;"
        (status, result, error) = ClusterCommand.excuteSqlOnLocalhost(self.port, sqldb)
        if status != 2:
            raise Exception("Execute sql:%s failed. Error:%s" % (sqldb, error))
        dbList = self._filter_databases([record[0] for record in result])

        skewSql = "SELECT  schemaname , tablename FROM PUBLIC.pgxc_analyzed_skewness WHERE skewness_tuple > 100000;"
        sections = []
        for db in dbList:
            # The preparation blocks are run best-effort: their status is
            # deliberately not checked.
            for sql in sqlList:
                ClusterCommand.excuteSqlOnLocalhost(
                    self.port, "set client_min_messages='error';\n%s" % sql, db)
            (status, result, error) = ClusterCommand.excuteSqlOnLocalhost(
                self.port, skewSql, db)
            if status != 2:
                raise Exception("Execute sql:%s failed. Error:%s" % (skewSql, error))
            schemaTable = ["%s.%s" % (row[0], row[1]) for row in result]
            if schemaTable:
                sections.append("%s:\n%s\n" % (db, "\n".join(schemaTable)))
        return "".join(sections)

    def doCheck(self):
        """Set self.result to OK/WARNING on the first CN (or in a
        single-instance cluster), and to NA on any other host."""
        flag = SharedFuncs.getFirstCNInstance(self.user, self.mpprcFile, self.tmpPath)
        if not (flag or self.cluster.isSingleInstCluster()):
            self.result.rst = ResultStatus.NA
            self.result.val = "First cn is not in this host"
            return

        # Only needed when the check actually runs on this host.
        dataDir = self.cluster.getDbNodeByName(socket.gethostname()).cmagents[0].datadir
        if DefaultValue.getSecurityMode(dataDir) == "on":
            final_result = self.check_dws_table()
            failurePrefix = "The result is not ok"
        else:
            final_result = self.check_table()
            failurePrefix = "Data is not well distributed"

        if final_result:
            self.result.rst = ResultStatus.WARNING
            self.result.val = "%s:\n%s" % (failurePrefix, final_result)
        else:
            self.result.rst = ResultStatus.OK
            self.result.val = "Data is well distributed"
