#!/usr/bin/env python3
# -*- coding:utf-8 -*-


try:
    import sys
    import importlib

    importlib.reload(sys)
    import os
    from gspylib.inspection.common import SharedFuncs
    from gspylib.inspection.common.CheckItem import BaseItem
    from gspylib.inspection.common.CheckResult import ResultStatus
    from gspylib.hardware.gsdisk import g_disk
    from gspylib.common.Common import ClusterCommand
    from gspylib.common.ErrorCode import ErrorCode
except ImportError as ie:
    raise Exception("[GAUSS-52200] : Unable to import module: %s." % str(ie))


class CheckSpaceForShrink(BaseItem):
    """
    Inspection item: estimate the disk usage of the remaining datanodes after
    the data of the shrunk nodes has been moved onto them, and report NG when
    any disk would exceed the Threshold_NG percentage.

    Attributes read but not set in this class (presumably provided by BaseItem
    or the inspection framework -- confirm): threshold, user, host, cluster,
    LCName (logical cluster name), ShrinkNodes (comma-separated host names),
    result.
    """

    def __init__(self):
        super(CheckSpaceForShrink, self).__init__(self.__class__.__name__)
        # Maximum allowed projected disk usage in percent; set by preCheck().
        self.Threshold_NG = None
        # Cluster info used for the check; set by initCluster_LC().
        self.cluster_LC = None

    @staticmethod
    def _escape_sql_literal(value):
        """Return str(value) with single quotes doubled so it cannot close a '...' SQL literal."""
        return str(value).replace("'", "''")

    def _get_shrink_node_names(self):
        """
        Return the set of host names listed in self.ShrinkNodes.

        ShrinkNodes is a comma-separated string. Surrounding whitespace and blank
        entries are ignored, and an empty/None value yields an empty set. Matching
        against this set (rather than a substring test on the raw string) keeps
        e.g. "node1" from matching "node10".
        """
        if not self.ShrinkNodes:
            return set()
        return {name.strip() for name in self.ShrinkNodes.split(',') if name.strip()}

    def preCheck(self):
        """
        function: Run the base pre-check and load the Threshold_NG threshold.
        input : NA
        output: NA
        raises: Exception when Threshold_NG is missing or not an integer.
        """
        # Generic pre-checks of the base item (the original comment says this
        # verifies the node contains CN instances -- not visible here).
        super(CheckSpaceForShrink, self).preCheck()
        if 'Threshold_NG' not in self.threshold:
            raise Exception("The threshold Threshold_NG can not be empty.")
        try:
            self.Threshold_NG = int(self.threshold['Threshold_NG'])
        except (TypeError, ValueError) as err:
            raise Exception("The threshold Threshold_NG must be an integer, got: %s."
                            % str(self.threshold['Threshold_NG'])) from err

    def obtainDataDir(self, nodeInfo):
        """
        function: Collect the data directories of the datanodes on nodeInfo that
                  are not being shrunk.
        input : nodeInfo - node object exposing a `datanodes` list
        output: dict mapping datadir -> ["DN", 0]
        """
        shrinkNodes = self._get_shrink_node_names()
        return {inst.datadir: ["DN", 0]
                for inst in nodeInfo.datanodes
                if inst.hostname not in shrinkNodes}

    def obtainDiskDir(self):
        """
        function: List all mount points reported by `df -h -P` on this host.
        input : NA
        output: list of mount point paths (header line and blank lines dropped)
        """
        cmd = "df -h -P | awk '{print $NF}'"
        output = SharedFuncs.runShellCmd(cmd)
        # The first line is the df header ("Mounted on" -> "on").
        return [line.strip() for line in output.split('\n')[1:] if line.strip()]

    def initCluster_LC(self):
        """
        function: Load the cluster info to check. When LCName is set, load it from
                  $GAUSSHOME/bin/<LCName>.cluster_static_config. Otherwise reuse
                  self.cluster.
        input : NA
        output: NA (sets self.cluster_LC)
        raises: Exception when GAUSSHOME is not set.
        """
        if not self.LCName:
            self.cluster_LC = self.cluster
            return
        from gspylib.common.DbClusterInfo import dbClusterInfo
        gausshome = os.getenv("GAUSSHOME")
        if not gausshome:
            raise Exception("The environment variable GAUSSHOME is not set.")
        filename = os.path.join(gausshome, "bin", "%s.cluster_static_config" % self.LCName)
        self.cluster_LC = dbClusterInfo()
        self.cluster_LC.initFromStaticConfig(self.user, filename, True)

    def checkLcGroupNameExist(self, lcName, cnport):
        """
        function: Check that the logical cluster name exists in pgxc_group
        input : lcName - logical cluster name
                cnport - CN instance (hostname/port) used to run the query
        output: NA
        raises: Exception when the query fails or the name does not exist
        """
        sql = ("SELECT count(*) FROM pg_catalog.pgxc_group "
               "WHERE group_name='%s' and group_kind = 'v';" % self._escape_sql_literal(lcName))
        (status, output) = ClusterCommand.remoteSQLCommand(sql,
                                                           self.user,
                                                           cnport.hostname,
                                                           cnport.port,
                                                           ignoreError=False)
        output = str(output).strip()
        if status != 0 or not output.isdigit():
            raise Exception(ErrorCode.GAUSS_514["GAUSS_51400"] % sql + " Error:\n%s" % output)
        if int(output) == 0:
            raise Exception(ErrorCode.GAUSS_504["GAUSS_50410"]
                            + " The specified logical cluster name does not exist in the cluster.")

    def getOneCNInst(self):
        """
        function: Return the first coordinator instance found in self.cluster.dbNodes,
                  used to run SQL commands
        input : NA
        output: CN instance
        raises: Exception (GAUSS_52602) when no node has a coordinator
        """
        for dbNode in self.cluster.dbNodes:
            if dbNode.coordinators:
                return dbNode.coordinators[0]
        raise Exception(ErrorCode.GAUSS_526["GAUSS_52602"])

    def doCheck(self):
        """
        function: Estimate the post-shrink disk usage of this host's remaining
                  datanodes and set self.result accordingly
        input : NA
        output: NA
        """
        DiskInfoDict = {}
        diskList = {}
        cn = self.getOneCNInst()
        # Exactly one logical cluster already in redistribution: the result has
        # been set to OK; stop so it is not overwritten below.
        if self.check_shrink_completed_sql(cn):
            return
        self.checkLcGroupNameExist(self.LCName, cn)
        self.initCluster_LC()
        if self.cluster:
            # data dirs of the datanodes on this host, excluding shrunk nodes
            nodeinfo = self.cluster_LC.getDbNodeByName(self.host)
            pathList = list(self.obtainDataDir(nodeinfo).keys()) if nodeinfo else []
            # no dn in this host, no need to check
            if not pathList:
                self.result.rst = ResultStatus.NA
                return
        else:
            pathList = self.obtainDiskDir()
        # a disk holding N dns receives N * deltasize additional data
        deltasize = self.get_max_dn_size(cn)
        self.get_dn_disk_and_collect_result(pathList, diskList, deltasize, DiskInfoDict)

    def check_shrink_completed_sql(self, cn=None):
        """
        function: Check whether a logical cluster is still in redistribution
        input : cn - CN instance used to run the query; looked up with
                     getOneCNInst() when None
        output: True when exactly one logical cluster is in redistribution (the
                result is set to OK and the caller should stop), else False
        raises: Exception when the query fails or more than one logical cluster
                is in redistribution
        """
        if cn is None:
            cn = self.getOneCNInst()
        sql_query = ("SELECT count(*) FROM pg_catalog.pgxc_group "
                     "WHERE group_kind='v' and in_redistribution='t'")
        (status, output) = ClusterCommand.remoteSQLCommand(sql_query,
                                                           self.user,
                                                           cn.hostname,
                                                           cn.port,
                                                           ignoreError=False)
        output = str(output).strip()
        if status != 0 or not output.isdigit():
            raise Exception(ErrorCode.GAUSS_514["GAUSS_51400"] % sql_query + " Error:\n%s" % output)
        count = int(output)
        if count == 1:
            self.result.rst = ResultStatus.OK
            return True
        if count != 0:
            raise Exception(ErrorCode.GAUSS_504["GAUSS_50410"] + " More than one cluster in redistributing.")
        return False

    def get_max_dn_size(self, cn):
        """
        function: Estimate the extra data each remaining node receives:
                  max_dn_size * num_shrink / num_remain (assumes the dns in each
                  host are the same, as the original comment states)
        input : cn - CN instance used to run pgxc_max_datanode_size
        output: float delta size (same unit as pgxc_max_datanode_size -- presumably bytes)
        raises: Exception when the query fails; ZeroDivisionError when no node
                would remain after the shrink
        """
        sql_query = "select pg_catalog.pgxc_max_datanode_size('%s')" % self._escape_sql_literal(self.LCName)
        (status, output) = ClusterCommand.remoteSQLCommand(sql_query,
                                                           self.user,
                                                           cn.hostname,
                                                           cn.port,
                                                           ignoreError=False)
        if 0 != status or "" == output:
            raise Exception(ErrorCode.GAUSS_514["GAUSS_51400"] % sql_query + " Error:\n%s" % str(output))
        dn_size_max = int(output.split("\n")[0].strip())
        num_shrink = len(self._get_shrink_node_names())
        num_remain = len(self.cluster_LC.getClusterNodeNames()) - num_shrink
        if num_remain <= 0:
            raise ZeroDivisionError("No node remains after shrinking %d node(s) from the cluster." % num_shrink)
        return float(dn_size_max) * num_shrink / num_remain

    def get_dn_disk_and_collect_result(self, pathList, diskList, deltasize, DiskInfoDict):
        """
        function: Group data directories by mount point, compute each disk's
                  projected usage after the shrink and fill self.result
        input : pathList     - data directories (or mount points) to check
                diskList     - dict filled here: mount -> [sample path, dn count]
                deltasize    - extra data per dn, from get_max_dn_size()
                DiskInfoDict - dict filled here: mount -> "<mount> <usage>%"
        output: NA (sets self.result.rst, .val and .raw)
        """
        for path in pathList:
            diskName = g_disk.getMountPathByDataDir(path)
            if diskName in diskList:
                diskList[diskName][1] += 1
            else:
                diskList[diskName] = [path, 1]

        messages = []
        usageByDisk = {}
        if not diskList:
            messages.append("error to get disk list of dns")
        # Check every disk individually, each under its own mount name.
        for diskName, (samplePath, dnCount) in diskList.items():
            usage = round(g_disk.getDiskSpaceForShrink(samplePath, deltasize * dnCount), 2)
            usageByDisk[diskName] = usage
            DiskInfoDict[diskName] = "%s %s%%" % (diskName, usage)
            if usage > self.Threshold_NG:
                messages.append("The usage of the device disk space[%s:%s%%] cannot be greater than %d%%.\n"
                                % (diskName, usage, self.Threshold_NG))

        # set result
        self.result.val = "".join(messages)
        self.result.rst = ResultStatus.NG if messages else ResultStatus.OK

        # set raw and val, ordered by projected usage
        if usageByDisk:
            ordered = sorted(usageByDisk, key=usageByDisk.get)
            # NOTE: labels kept for output compatibility; "Max" is the disk with
            # the highest projected usage, "Min" the lowest.
            self.result.val = "%s\nDisk     Filesystem spaceUsage\nMax free %s\nMin free %s" % (
                self.result.val, DiskInfoDict[ordered[-1]], DiskInfoDict[ordered[0]])
            for diskName in ordered:
                self.result.raw = "%s\n%s" % (self.result.raw, DiskInfoDict[diskName])
