#!/usr/bin/env python3
# -*- coding:utf-8 -*-
try:
    import socket
    import sys
    import re
    import os
    import time

    sys.path.append(sys.path[0] + "/../../")
    from gspylib.common.GaussLog import GaussLog
    from gspylib.common.DbClusterInfo import dbClusterInfo
    from gspylib.common.Common import DefaultValue, ClusterCommand
    from gspylib.os.gsOSlib import g_OSlib
    from gspylib.common.ErrorCode import ErrorCode
    from gspylib.component.CM.CM_OLAP.CM_OLAP import CM_OLAP
    from gspylib.component.ETCD.ETCD_OLAP.ETCD_OLAP import ETCD_OLAP
    from gspylib.component.GTM.GTM_OLAP.GTM_OLAP import GTM_OLAP
    from gspylib.component.Kernel.CN_OLAP.CN_OLAP import CN_OLAP
    from gspylib.component.Kernel.DN_OLAP.DN_OLAP import DN_OLAP
    from gspylib.threads.SshTool import SshTool

except ImportError as ie:
    sys.exit("[GAUSS-52200] : Unable to import module: %s." % str(ie))

# Key passed to SshTool when LocalBaseOM.dropNodeGroup fans the
# scheduler-kill command out to every cluster node.
KEY: str = "local_public"


class LocalBaseOM(object):
    """
    Base class for local command

    Holds the cluster and local-node information, builds the component
    objects (CM/ETCD/GTM/CN/DN) for the instances on the current host, and
    provides helpers used by the express-cluster node group workflow and for
    toggling password-authenticated entries in the local CN's pg_hba.conf.
    """

    def __init__(self, logFile=None, user=None, clusterConf=None, dwsMode=False, initParas=None, gtmInitParas=None):
        '''
        Constructor

        input : logFile      - log file path used by initLogger()
                user         - cluster user name
                clusterConf  - XML configuration file used by readConfigInfoByXML()
                dwsMode      - copied onto every component as `dwsMode`
                initParas    - init parameters handed to CN/DN components
                gtmInitParas - init parameters handed to GTM components
        output: NA
        '''
        # Build fresh lists per instance instead of sharing a mutable default.
        if initParas is None:
            initParas = []
        if gtmInitParas is None:
            gtmInitParas = []
        self.logger = None
        # initLogger() reads self.logFile, so the argument must be kept.
        self.logFile = logFile
        self.clusterInfo = None
        self.dbNodeInfo = None
        self.clusterConfig = clusterConf
        self.user = user
        self.group = ""
        self.dws_mode = dwsMode
        self.initParas = initParas
        self.gtmInitParas = gtmInitParas
        self.etcdCons = []
        self.cmCons = []
        self.gtmCons = []
        self.cnCons = []
        self.dnCons = []
        self.gtsCons = []
        self.logAction = ""
        self.logUuid = ""
        self.logStep = 0

    def initLogger(self, module=""):
        """
        function: Init logger
        input : module - module name passed to GaussLog
        output: NA
        raises: Exception when self.logUuid is not exactly 36 characters
                made of letters, digits and hyphens
        """
        pattern = re.compile(r"^[a-zA-Z0-9-]{36}$")
        # A missing uuid is treated as empty so it is rejected with the
        # documented error rather than a TypeError from re.
        if pattern.match(self.logUuid or "") is None:
            raise Exception(ErrorCode.GAUSS_500["GAUSS_50004"] % "--uuid" +
                            " The value of the uuid must be 36 characters long and "
                            "can only contain letters, numbers, and hyphens.")

        # log level
        LOG_DEBUG = 1
        self.logger = GaussLog(self.logFile, module, LOG_DEBUG, self.logAction, self.logUuid, self.logStep)

    def initComponent(self):
        """
        function: Init all components (CM, GTM, ETCD, kernel) on current node
        input : NA
        output: NA
        """
        self.initCmComponent()
        self.initGtmComponent()
        self.initEtcdComponent()
        self.initKernelComponent()

    def initComponentAttributes(self, component):
        """
        function: Init common component attributes on current node
                  (logger, bin path under the app path, dws mode)
        input : Object component
        output: NA
        """
        component.logger = self.logger
        component.binPath = "%s/bin" % self.clusterInfo.appPath
        component.dwsMode = self.dws_mode

    def _createComponent(self, componentClass, inst):
        """
        function: Instantiate componentClass for one instance on this node
                  and fill in the attributes shared by every component
        input : componentClass - component class to instantiate
                inst           - instance info object
        output: the initialised component
        """
        component = componentClass()
        # init component cluster type
        component.clusterType = self.clusterInfo.clusterType
        component.instInfo = inst
        self.initComponentAttributes(component)
        return component

    def initCmComponent(self):
        """
        function: Init cm components (cm servers first, then cm agents)
                  on current node
        input : NA
        output: NA
        """
        for insts in (self.dbNodeInfo.cmservers, self.dbNodeInfo.cmagents):
            for inst in insts:
                self.cmCons.append(self._createComponent(CM_OLAP, inst))

    def initEtcdComponent(self):
        """
        function: Init etcd components on current node
        input : NA
        output: NA
        """
        for inst in self.dbNodeInfo.etcds:
            self.etcdCons.append(self._createComponent(ETCD_OLAP, inst))

    def initGtmComponent(self):
        """
        function: Init gtm components on current node
        input : NA
        output: NA
        """
        for inst in self.dbNodeInfo.gtms:
            component = self._createComponent(GTM_OLAP, inst)
            component.initParas = self.gtmInitParas
            self.gtmCons.append(component)

    def initKernelComponent(self):
        """
        function: Init kernel components (coordinators, then datanodes)
                  on current node
        input : NA
        output: NA
        """
        for inst in self.dbNodeInfo.coordinators:
            component = self._createComponent(CN_OLAP, inst)
            component.initParas = self.initParas
            self.cnCons.append(component)

        for inst in self.dbNodeInfo.datanodes:
            component = self._createComponent(DN_OLAP, inst)
            component.initParas = self.initParas
            self.dnCons.append(component)

    def _loadLocalNodeInfo(self):
        """
        function: Look up the current host in self.clusterInfo and store it
                  in self.dbNodeInfo; report via logger.logExit if missing
        input : NA
        output: NA
        """
        hostName = socket.gethostname()
        self.dbNodeInfo = self.clusterInfo.getDbNodeByName(hostName)
        if self.dbNodeInfo is None:
            self.logger.logExit(ErrorCode.GAUSS_516["GAUSS_51619"] % hostName)

    def readConfigInfo(self):
        """
        function: Read config from static config file
        input : NA
        output: NA
        """
        try:
            self.clusterInfo = dbClusterInfo()
            self.clusterInfo.initFromStaticConfig(self.user)
            self._loadLocalNodeInfo()
        except Exception as e:
            self.logger.logExit(str(e))

        self.logger.debug("Instance information on local node:\n%s" % str(self.dbNodeInfo))

    def readConfigInfoByXML(self):
        """
        function: Read config from xml config file (self.clusterConfig),
                  using <install dir>/bin/cluster_static_config as the
                  static config file
        input : NA
        output: NA
        """
        try:
            if self.clusterConfig is None:
                self.logger.logExit(ErrorCode.GAUSS_502["GAUSS_50201"] % "XML configuration file")
            static_config_file = "%s/bin/cluster_static_config" % DefaultValue.getInstallDir(self.user)
            self.clusterInfo = dbClusterInfo()
            self.clusterInfo.initFromXml(self.clusterConfig, static_config_file)
            self._loadLocalNodeInfo()
        except Exception as e:
            self.logger.logExit(str(e))
        self.logger.debug("Instance information on local node:\n%s" % str(self.dbNodeInfo))

    def getUserInfo(self):
        """
        Get user and group from the owner of the cluster app path;
        report via logger.logExit if either is empty.
        """
        (self.user, self.group) = g_OSlib.getPathOwner(self.clusterInfo.appPath)
        if self.user == "" or self.group == "":
            self.logger.logExit(ErrorCode.GAUSS_503["GAUSS_50308"])

    def getGTMDict(self, peerGtm, gtmInst, user=None, configItemType=None, alarm_component=None):
        """
        function: Get GTM configuration items
        input : peerGtm         - list of peer GTM instances; empty means
                                  this is a single cluster (no HA items)
                gtmInst         - the GTM instance to configure
                user            - user whose log dir hosts pg_log/gtm
                configItemType  - "ChangeIPUtility" skips log_directory;
                                  "ConfigInstance" adds alarm_component
                alarm_component - value for alarm_component
        output: dict mapping parameter name to its (quoted) value string
        """
        tmpGTMDict = {}
        tmpGTMDict["listen_addresses"] = "'localhost,%s'" % ",".join(gtmInst.listenIps)
        tmpGTMDict["port"] = str(gtmInst.port)
        tmpGTMDict["nodename"] = "'gtm_%s'" % str(gtmInst.instanceId)
        if configItemType != "ChangeIPUtility":
            tmpGTMDict["log_directory"] = "'%s/pg_log/gtm'" % (DefaultValue.getUserLogDirWithUser(user))
        if peerGtm:
            tmpGTMDict["local_host"] = "'%s'" % ",".join(gtmInst.haIps)
            tmpGTMDict["local_port"] = str(gtmInst.haPort)
            # The active peer address is only set outside single-primary
            # multi-standby clusters.
            if not self.clusterInfo.isSinglePrimaryMultiStandbyCluster():
                tmpGTMDict["active_host"] = "'%s'" % ",".join(peerGtm[0].haIps)
                tmpGTMDict["active_port"] = str(peerGtm[0].haPort)
        if configItemType == "ConfigInstance":
            tmpGTMDict["alarm_component"] = "'%s'" % alarm_component
        return tmpGTMDict

    def getNodeGroupProcessControlFile(self):
        """
        function: Get node group process control file for the express cluster
        input : NA
        output: filePath (<tmp dir of the user>/.nodegroup_process_control.dat)
        """
        fileName = ".nodegroup_process_control.dat"
        tmpDir = DefaultValue.getTmpDirFromEnv(self.user)
        return os.path.join(tmpDir, fileName)

    def createNodegroup(self, DNNameStr, CNNameStr, new_group_name):
        """
        In the express cluster, we will perform the following operations:
            Create new node group on the given DNs and, on every CN, mark it
            as the installation group while demoting the current one.

        input : DNNameStr      - comma-separated DN names for the new group
                CNNameStr      - comma-separated CN names to update
                new_group_name - name of the node group to create
        output: (old_group_name, new_group_name)
        raises: Exception if the current installation group cannot be
                determined or the SQL fails
        """
        self.logger.log("Creating new node group.")

        OBTAIN_OLD_GROUP_SQL = "SELECT group_name FROM pg_catalog.pgxc_group WHERE in_redistribution='n'" \
                               " and is_installation = TRUE;"
        self.logger.debug("Command for getting old node group: %s" % OBTAIN_OLD_GROUP_SQL)
        status, output = ClusterCommand.execSQLCommand(OBTAIN_OLD_GROUP_SQL, self.user, "",
                                                       self.cnCons[0].instInfo.port)
        if status != 0:
            raise Exception(ErrorCode.GAUSS_502["GAUSS_50219"] % "old group" + " Error:\n%s" % str(output))
        old_group_name = output.strip()
        # Without an old group name the UPDATE below would match nothing and
        # leave two installation groups behind.
        if not old_group_name:
            raise Exception(ErrorCode.GAUSS_502["GAUSS_50219"] % "old group" + " Error:\n%s" % str(output))
        sql_list = ["""
        SET xc_maintenance_mode = on;
        START TRANSACTION;
        CREATE NODE GROUP "%(newNodeGroupName)s" WITH(%(dnNameString)s);
        """ % {"newNodeGroupName": new_group_name, "dnNameString": DNNameStr}]

        for cn_name in CNNameStr.split(','):
            sql_list.append("""
        EXECUTE DIRECT ON(%(cnName)s) 'UPDATE pg_catalog.pgxc_group SET is_installation = FALSE ,
         in_redistribution = ''y'', group_kind=''n'' WHERE group_name = ''%(oldNodeGroupName)s''';
        EXECUTE DIRECT ON(%(cnName)s) 'UPDATE pg_catalog.pgxc_group SET is_installation = TRUE ,
         in_redistribution = ''n'', group_kind=''i'' WHERE group_name = ''%(newNodeGroupName)s''';
        EXECUTE DIRECT ON(%(cnName)s) 'update pg_catalog.pgxc_group set group_acl= concat(''{'',
         concat(current_user,''=UCp/'', current_user, '',=UCp/'',current_user), ''}'')::aclitem[]
         WHERE group_name = ''%(newNodeGroupName)s''';
        """ % {"cnName": cn_name, "oldNodeGroupName": old_group_name, "newNodeGroupName": new_group_name})
        sql_list.append("""
        COMMIT;
        RESET xc_maintenance_mode;
        """)
        sql = "".join(sql_list)
        self.logger.debug("Sql command for creating new node group: %s" % sql)
        # create new node group and switch the installation flag to it
        (status, output) = ClusterCommand.execSQLCommand(sql, self.user, "", self.cnCons[0].instInfo.port)
        if status != 0:
            raise Exception(ErrorCode.GAUSS_513["GAUSS_51300"] % sql + " Error: \n%s" % str(output))
        return old_group_name, new_group_name

    def get_dbms_om_table_list(self):
        """
        Return the names of the ordinary tables (relkind 'r') in the dbms_om
        schema, queried through the first local CN.

        raises: Exception if the query fails or returns no rows
        """
        sql = "SELECT relname FROM pg_catalog.pg_class c, pg_catalog.pg_namespace n " \
              "WHERE n.nspname='dbms_om' AND c.relkind='r' AND c.relnamespace=n.oid;"
        status, output = ClusterCommand.execSQLCommand(sql, self.user, "", self.cnCons[0].instInfo.port)
        # Whitespace-only output would otherwise become [''].
        if status != 0 or not output.strip():
            raise Exception(ErrorCode.GAUSS_513["GAUSS_51300"] % sql + " Error: \n%s" % str(output))
        return output.strip().split('\n')

    def updateDbmsData(self, CNNameStr, new_group_name):
        """
        Update dbms schema data to new_group_name for express cluster:
        on every CN, point the pgxc_class entries of the dbms_om tables at
        the members and name of the new node group.

        input : CNNameStr      - comma-separated CN names to update
                new_group_name - target node group
        raises: Exception if the SQL fails
        """
        sql_list = ["""
        SET xc_maintenance_mode = on;
        START TRANSACTION;
        """]
        dbms_table_list = self.get_dbms_om_table_list()
        # Joined with doubled quotes because the list is embedded inside the
        # single-quoted EXECUTE DIRECT statement.
        dbms_table_str = "'',''".join(dbms_table_list)

        for cn_name in CNNameStr.split(','):
            sql_list.append("""
        EXECUTE DIRECT ON(%(cnName)s) 'update pg_catalog.pgxc_class
           set nodeoids = (select group_members from pg_catalog.pgxc_group where group_name=''%(newNodeGroupName)s'')
           where pcrelid in (select oid from pg_catalog.pg_class where relname in (''%(dbms_table)s''))';
        EXECUTE DIRECT ON(%(cnName)s) 'update pg_catalog.pgxc_class set pgroup = ''%(newNodeGroupName)s''
           where pcrelid in (select oid from pg_catalog.pg_class where relname in (''%(dbms_table)s''))';
        """ % {"cnName": cn_name, "newNodeGroupName": new_group_name, "dbms_table": dbms_table_str})
        sql_list.append("""
        COMMIT;
        RESET xc_maintenance_mode;
        """)
        sql = "".join(sql_list)
        self.logger.debug("Sql command for updating dbms schema: %s" % sql)
        # move dbms_om tables to the new node group
        (status, output) = ClusterCommand.execSQLCommand(sql, self.user, "", self.cnCons[0].instInfo.port)
        if status != 0:
            raise Exception(ErrorCode.GAUSS_513["GAUSS_51300"] % sql + " Error: \n%s" % str(output))

    def dropDependTable(self, new_group, old_group):
        """
        Remove the tables that belong to old_group so the group can be
        dropped. An anonymous PL/SQL block walks ordinary ('r') and foreign
        ('f') tables of old_group in descending oid order:
          - dbms_om ordinary tables are truncated and moved to new_group;
          - other ordinary tables are dropped with CASCADE;
          - foreign tables are dropped with CASCADE.

        input : new_group - node group receiving the dbms_om tables
                old_group - node group whose tables are removed
        raises: Exception if the SQL fails
        """
        drop_table_sql = """
        DECLARE
            sql_stmt text;
            my_cursor REFCURSOR;
            schemaname  text;
            tablename text;
            tablekind text;
        BEGIN
            sql_stmt := 'SELECT pg_catalog.quote_ident(n.nspname) AS schemaname,
                                pg_catalog.quote_ident(c.relname) AS tablename,
                                pg_catalog.quote_ident(c.relkind) AS tablekind
                         FROM pg_catalog.pg_class c
                         JOIN pg_catalog.pgxc_class x ON x.pcrelid = c.oid
                         LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
                         WHERE x.pgroup = ''%s'' and c.relkind in (''f'', ''r'') order by c.oid desc';
            OPEN my_cursor FOR EXECUTE sql_stmt;
            FETCH FROM my_cursor INTO schemaname, tablename, tablekind;
            WHILE my_cursor %% FOUND LOOP
                IF tablekind = 'r'
                THEN
                    IF schemaname = 'dbms_om'
                    THEN
                        EXECUTE IMMEDIATE 'TRUNCATE TABLE '||schemaname||'.'||tablename||'';
                        EXECUTE IMMEDIATE 'SET enable_cluster_resize=ON;
                        ALTER TABLE '||schemaname||'.'||tablename||' TO GROUP %s;
                        SET enable_cluster_resize=OFF';
                    ELSE
                        EXECUTE IMMEDIATE 'DROP TABLE IF EXISTS '||schemaname||'.'||tablename||' CASCADE';
                    END IF;
                ELSE
                    EXECUTE IMMEDIATE 'DROP FOREIGN TABLE IF EXISTS '||schemaname||'.'||tablename||' CASCADE';
                END IF;
                FETCH FROM my_cursor INTO schemaname, tablename, tablekind;
            END LOOP;
            CLOSE my_cursor;
        END;
        """ % (old_group, new_group)
        self.logger.log("Deleting other tables which depend on the delete node group %s." % old_group)
        status, output = ClusterCommand.execSQLCommand(drop_table_sql, self.user, "", self.cnCons[0].instInfo.port)
        if status != 0:
            raise Exception(ErrorCode.GAUSS_513["GAUSS_51300"] % drop_table_sql + " Error: \n%s" % str(output))
        self.logger.log("Successfully deleted the depend tables.")

    def dropNodeGroup(self, CNNameStr, reserved_group_name, del_group_name):
        """
        drop node group for express cluster

        Steps: drop the pmk and scheduler schemas, kill gs_scheduler on every
        node, remove tables depending on del_group_name, then on every CN
        make reserved_group_name the installation group and drop
        del_group_name in one transaction. If the drop reports remaining
        dependent objects, dependent tables are removed again and the
        transaction is retried once.

        input : CNNameStr           - comma-separated CN names to update
                reserved_group_name - node group that is kept
                del_group_name      - node group that is dropped
        raises: Exception if any SQL step fails
        """
        self.logger.log("Deleting the pmk schema.")
        PMK_SQL = "SET enable_parallel_ddl = off; DROP SCHEMA IF EXISTS pmk CASCADE;"
        status, output = ClusterCommand.execSQLCommand(PMK_SQL, self.user, "", self.cnCons[0].instInfo.port)
        if status != 0:
            raise Exception(ErrorCode.GAUSS_513["GAUSS_51300"] % PMK_SQL + " Error: \n%s" % str(output))
        self.logger.log("Successfully deleted the pmk schema.")

        # In express cluster, Schema `Scheduler` is created on old node group,
        # We must re-create this schema on new node group by two steps:
        # 1. drop schema cascade
        # 2. restart gs_scheduler.
        SCHEDULER_SQL = "DROP SCHEMA IF EXISTS scheduler CASCADE;"
        self.logger.log("Deleting the scheduler schema.")
        status, output = ClusterCommand.execSQLCommand(SCHEDULER_SQL, self.user, "", self.cnCons[0].instInfo.port)
        if status != 0:
            raise Exception(ErrorCode.GAUSS_513["GAUSS_51300"] % SCHEDULER_SQL + " Error: \n%s" % str(output))
        self.logger.log("Successfully deleted the scheduler schema.")
        all_nodes_name = [dbNode.name for dbNode in self.clusterInfo.dbNodes]
        # The backslash before $2 is kept in the command text (written as
        # "\\$" to avoid an invalid Python escape); it presumably protects
        # $2 from an outer shell layer added by SshTool -- TODO confirm.
        cmd = "ps -ux | grep gs_scheduler | grep start | grep -v grep | awk -F ' ' '{print \\$2}' | xargs -r kill -9"
        self.logger.debug("Command for killing scheduler process: %s." % cmd)
        sshTool = SshTool(all_nodes_name, KEY)
        # Best effort: the kill result is only logged, never checked.
        (status, output) = sshTool.getSshStatusOutput(cmd, all_nodes_name)
        self.logger.debug("The result of kill scheduler process commands. "
                          "Status:%s, Output:%s." % (status, output))

        self.dropDependTable(reserved_group_name, del_group_name)

        sql_list = ["""
        SET xc_maintenance_mode = on;
        START TRANSACTION;
        """]
        for cn_name in CNNameStr.split(','):
            sql_list.append("""
        EXECUTE DIRECT ON(%(cnName)s) 'UPDATE pg_catalog.pgxc_group SET is_installation = FALSE ,
         in_redistribution = ''y'', group_kind=''n'' WHERE group_name = ''%(oldNodeGroupName)s''';
        EXECUTE DIRECT ON(%(cnName)s) 'UPDATE pg_catalog.pgxc_group SET is_installation = TRUE ,
         in_redistribution = ''n'', group_kind=''i'' WHERE group_name = ''%(newNodeGroupName)s''';
        EXECUTE DIRECT ON(%(cnName)s) 'update pg_catalog.pgxc_group set group_acl= concat(''{'',
         concat(current_user,''=UCp/'', current_user, '',=UCp/'',current_user), ''}'')::aclitem[]
         WHERE group_name = ''%(newNodeGroupName)s''';
        """ % {"cnName": cn_name, "oldNodeGroupName": del_group_name, "newNodeGroupName": reserved_group_name})
        sql_list.append("""
                DROP NODE GROUP %s;
                COMMIT;
                RESET xc_maintenance_mode;
              """ % del_group_name)
        sql = "".join(sql_list)
        self.logger.debug("Sql command for dropping old node group: %s" % sql)
        self.logger.log("Deleting the old node group %s." % del_group_name)
        status, output = ClusterCommand.execSQLCommand(sql, self.user, "", self.cnCons[0].instInfo.port)
        if status != 0 and output.find("because other objects depend on it") >= 0:
            # Objects may have been created in the group meanwhile; clean up
            # once more and retry.
            self.dropDependTable(reserved_group_name, del_group_name)
            status, output = ClusterCommand.execSQLCommand(sql, self.user, "", self.cnCons[0].instInfo.port)
        if status != 0:
            raise Exception(ErrorCode.GAUSS_513["GAUSS_51300"] % sql + " Error: \n%s" % str(output))
        self.logger.log("Successfully deleted the old node group.")

    def close_hba(self):
        """
        close user hba: comment out (prefix "#@#@#") every sha256/md5 entry
        of the first local CN's pg_hba.conf, if the file exists.
        Does nothing when the node has no coordinator.
        :return: None
        """
        if len(self.dbNodeInfo.coordinators) == 0:
            self.logger.debug("There is no coordinator in current node.")
            return

        self.logger.debug("Closing user hba.")

        hba_file = os.path.join(self.dbNodeInfo.coordinators[0].datadir, "pg_hba.conf")
        # NOTE: "[^(#|local)]" is a bracket expression, so it skips lines whose
        # first character is any of ( # | l o c a ) -- i.e. comments and
        # "local" entries -- while still matching "host*" entries.
        cmd = r"sed -i -e '/^[^(#|local)].*sha256.*/s/\(.*\)/#@#@#\1/g'" \
              r" -e '/^[^(#|local)].*md5.*/s/\(.*\)/#@#@#\1/g' %s" % hba_file
        if os.path.exists(hba_file):
            self.logger.debug("Command for closing user hba:%s." % cmd)
            (status, output) = DefaultValue.retryGetstatusoutput(cmd)
            if status != 0:
                raise Exception(ErrorCode.GAUSS_502["GAUSS_50205"] % cmd + " Error: \n%s" % str(output))

            self.logger.debug("Successfully closed user hba.")

    def open_hba(self):
        """
        open user hba: strip the "#@#@#" prefix added by close_hba from the
        first local CN's pg_hba.conf.
        Does nothing when the node has no coordinator.
        :return: None
        """
        if len(self.dbNodeInfo.coordinators) == 0:
            self.logger.debug("There is no coordinator in current node.")
            return

        self.logger.debug("Opening user hba.")

        hba_file = os.path.join(self.dbNodeInfo.coordinators[0].datadir, "pg_hba.conf")
        cmd = r"sed -i '/^#@#@#.*/s/#@#@#\(.*\)/\1/g' %s" % hba_file
        self.logger.debug("Command for opening user hba:%s." % cmd)
        (status, output) = DefaultValue.retryGetstatusoutput(cmd)
        if status != 0:
            raise Exception(ErrorCode.GAUSS_502["GAUSS_50205"] % cmd + " Error: \n%s" % str(output))

        self.logger.debug("Successfully opened user hba.")