#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################################
# Copyright (c): 2012-2017, Huawei Tech. Co., Ltd.
# Description  : sql command is a utility with a lot of sql command functions
#############################################################################
try:
    import socket
    import sys
    import os
    import time

    localDirPath = os.path.dirname(os.path.realpath(__file__))
    sys.path.insert(0, localDirPath + "/../../../../lib")
    sys.path.append(localDirPath + "/../../../")
    from gspylib.common.ErrorCode import ErrorCode
    from gspylib.common.common.default_value import DefaultValue
    from gspylib.common.common.cluster_command import ClusterCommand
except ImportError as ie:
    sys.exit("[GAUSS-52200] : Unable to import module: %s." % str(ie))


class SqlCommand():
    """
    Static helpers that run SQL statements on coordinator (CN) instances
    through ClusterCommand and interpret their results.
    """

    def __init__(self):
        pass

    @staticmethod
    def _split_output_lines(output):
        """
        function: split raw SQL command output into stripped, non-empty lines
        input: output - text returned by a SQL command
        output: list of str (blank lines are dropped)
        """
        return [line.strip() for line in output.split('\n') if line.strip()]

    @staticmethod
    def _query_cn_dws_mode(coo_inst, sql, action, localhost):
        """
        function: run `sql` on one CN using the DWS-mode helpers:
                  - on the local host via ClusterCommand.excuteSqlOnLocalhost
                  - otherwise remotely, writing the result to a temporary
                    metadata_<host>_<pid>_<time>.json file that is read back
        input: coo_inst - CN instance (hostname, port)
               sql - statement to run
               action - forwarded to ClusterCommand.executeSQLOnRemoteHost
               localhost - hostname of the current machine
        output: (status, result, error_output) as returned by ClusterCommand
        """
        if localhost == coo_inst.hostname:
            return ClusterCommand.excuteSqlOnLocalhost(coo_inst.port, sql)
        current_time = time.strftime("%Y-%m-%d_%H:%M:%S")
        output_file = "metadata_%s_%s_%s.json" % (coo_inst.hostname, os.getpid(), current_time)
        filepath = os.path.join(DefaultValue.getTmpDirFromEnv(), output_file)
        ClusterCommand.executeSQLOnRemoteHost(coo_inst.hostname, coo_inst.port, sql, filepath, action=action)
        return ClusterCommand.getSQLResult(coo_inst.hostname, output_file)

    @staticmethod
    def execute_sql_for_transaction_read_only(cn_list, user, action, dws_mode=False):
        """
        function: check default_transaction_read_only on every CN in cn_list
        input: cn_list - CN instances (hostname, port)
               user - user for ClusterCommand.remoteSQLCommand (non-DWS mode)
               action - forwarded to the remote DWS query (DWS mode only)
               dws_mode - query through the DWS helpers instead of remoteSQLCommand
        output: (0, "success") when no CN is read only;
                (1, message) on the first query failure, empty result,
                or CN found in read only mode
        """
        localhost = socket.gethostname()
        sql = "show default_transaction_read_only;"
        for coo_inst in cn_list:
            if dws_mode:
                (status, result, error_output) = SqlCommand._query_cn_dws_mode(coo_inst, sql, action, localhost)
                # Only status 2 is accepted from the DWS helpers. An empty result
                # cannot be inspected, so report it instead of raising IndexError.
                # (assumes result is a sequence of rows - TODO confirm)
                if status != 2 or not result or not result[0]:
                    return 1, "[%s]: Error: %s result: %s status: %s" % \
                              (coo_inst.hostname, error_output, result, status)
                value = result[0][0]
            else:
                (status, output) = ClusterCommand.remoteSQLCommand(sql, user, coo_inst.hostname, coo_inst.port)
                lines = SqlCommand._split_output_lines(output)
                # Empty output means the setting could not be read at all.
                if status != 0 or not lines:
                    return 1, "[%s]: %s" % (coo_inst.hostname, output)
                value = lines[0]
            if value.strip().lower() == "on":
                return 1, "The database is in read only mode."
        return 0, "success"

    @staticmethod
    def execute_sql_command(user, sql, logger, cn_inst):
        """
        function: run `sql` inside one transaction on cn_inst, i.e.
                  "START TRANSACTION; <sql>COMMIT;", passing
                  DefaultValue.DEFAULT_DB_NAME as the database
        input: user - user passed to ClusterCommand.remoteSQLCommand
               sql - statement(s) to run; must end with ';' because
                     "COMMIT;" is appended directly after it
               logger - logger used for debug output
               cn_inst - CN instance (hostname, port)
        output: NA
        raises: Exception (GAUSS_51300) when the command returns a non-zero
                status or its output contains an SQL error
        """
        try:
            exec_sql = "START TRANSACTION;"
            exec_sql += " %sCOMMIT;" % sql
            logger.debug("Execute sql command: %s." % exec_sql)
            (status, output) = ClusterCommand.remoteSQLCommand(exec_sql, user,
                                                               cn_inst.hostname,
                                                               cn_inst.port, False,
                                                               DefaultValue.DEFAULT_DB_NAME)
            if status != 0 or ClusterCommand.findErrorInSql(output):
                raise Exception(ErrorCode.GAUSS_513["GAUSS_51300"] % exec_sql + " Error: \n%s" % str(output))
        except Exception as ex:
            # Keep raising a plain Exception with the same message, but chain the cause.
            raise Exception(str(ex)) from ex

    @staticmethod
    def getAllDatabase(user, logger, cn_inst):
        """
        function: list the databases that allow connections (datallowconn)
        input: user - user passed to ClusterCommand.remoteSQLCommand
               logger - logger used for debug output
               cn_inst - CN instance (hostname, port) to query
        output: list of database names, ordered by name; blank lines are dropped
        raises: Exception (GAUSS_51300) when the query fails
        """
        try:
            sql = "SELECT datname FROM pg_catalog.pg_database WHERE datallowconn ORDER BY 1"
            logger.debug("Get databases from cluster command: %s." % sql)
            (status, output) = ClusterCommand.remoteSQLCommand(sql,
                                                               user,
                                                               cn_inst.hostname,
                                                               cn_inst.port,
                                                               ignoreError=False)
            if status != 0:
                raise Exception(ErrorCode.GAUSS_513["GAUSS_51300"] % sql + " Error: \n%s" % str(output))
            logger.debug("Get databases from cluster result: %s." % output)
            # Dropping blank lines prevents callers from querying a "" database.
            return SqlCommand._split_output_lines(output)
        except Exception as ex:
            raise Exception(str(ex)) from ex

    @staticmethod
    def check_data_redis(user, logger, cn_inst):
        """
        function: make sure no connectable database has a schema named
                  'data_redis'
        input: user - user passed to ClusterCommand.remoteSQLCommand
               logger - logger used for debug output
               cn_inst - CN instance (hostname, port) to query
        output: NA
        raises: Exception (GAUSS_51300) when a query fails;
                Exception (GAUSS_51601) listing every database that has the
                schema, together with the schema owner, when any is found
        """
        database_list = SqlCommand.getAllDatabase(user, logger, cn_inst)

        sql = "select s.nspname,u.rolname from pg_namespace s,pg_authid u " \
              "where s.nspname='data_redis' and u.oid=s.nspowner;"
        logger.debug("Sql for checking data_redis: %s." % sql)

        check_result = {}
        for database in database_list:
            (status, output) = ClusterCommand.remoteSQLCommand(sql,
                                                               user,
                                                               cn_inst.hostname,
                                                               cn_inst.port,
                                                               False,
                                                               database)
            if status != 0:
                raise Exception(ErrorCode.GAUSS_513["GAUSS_51300"] % sql + " Error:\n%s" % output)

            if output.strip():
                check_result[database] = output.strip()

        if check_result:
            raise Exception(ErrorCode.GAUSS_516["GAUSS_51601"] % "expand" +
                            " ERROR: Cannot use system reserved schema 'data_redis', "
                            "which is used for cluster resize only.\n"
                            "DETAIL: %s\n"
                            "HINT: You can drop or rename it.\n" % check_result)

    @staticmethod
    def checkTransactionReadonly(user, DbclusterInfo, normalCNList=None, action=""):
        """
        function : check whether any CN has default_transaction_read_only set to on
        input : user - user for the SQL queries
                DbclusterInfo - cluster information object
                normalCNList - CNs to check; when empty, the first CN of every
                               node in DbclusterInfo is used
                action - forwarded to the DWS-mode remote query
        output : (0, "success") or (1, error message); never raises
        """
        if normalCNList is None:
            normalCNList = []
        try:
            if normalCNList:
                cnList = normalCNList
            else:
                # Take the first CN instance of every node that has one.
                cnList = [dbNode.coordinators[0] for dbNode in DbclusterInfo.dbNodes
                          if dbNode.coordinators]

            hostname = socket.gethostname()
            nodeInfo = DbclusterInfo.getDbNodeByName(hostname)
            if nodeInfo is None or not nodeInfo.cmagents:
                return 1, "Failed to find local node [%s] or its cm_agent instance in the cluster." % hostname
            # DWS mode is selected when the local cm_agent reports security mode "on".
            security_mode_value = DefaultValue.getSecurityMode(nodeInfo.cmagents[0].datadir)
            DWS_mode = security_mode_value == "on"
            return SqlCommand.execute_sql_for_transaction_read_only(cnList, user, action, DWS_mode)
        except Exception as e:
            return 1, str(e)

    @staticmethod
    def _query_dbms_om_tables(cn_inst, user, sql, **extra_kwargs):
        """
        function: run a table-listing query for schema dbms_om on DEFAULT_DB_NAME
        input: cn_inst - CN instance (hostname, port)
               user - user passed to ClusterCommand.remoteSQLCommand
               sql - query returning one table name per line
               extra_kwargs - extra keyword arguments for remoteSQLCommand
        output: list of stripped, non-empty table names
        raises: Exception when the query fails
        """
        status, output = ClusterCommand.remoteSQLCommand(sql, user, cn_inst.hostname, cn_inst.port,
                                                         False, DefaultValue.DEFAULT_DB_NAME, **extra_kwargs)
        if status != 0:
            raise Exception("Failed to get tables for schema dbms_om." + " Error:\n%s." % str(output))
        return SqlCommand._split_output_lines(output)

    @staticmethod
    def get_dbms_om_table_list(cn_inst, user):
        """
        function: list ordinary tables (relkind 'r') in schema dbms_om
        input: cn_inst - CN instance (hostname, port); user - SQL user
        output: list of table names
        raises: Exception when the query fails
        """
        sql = "SELECT relname FROM pg_catalog.pg_class c, pg_catalog.pg_namespace n " \
              "WHERE n.nspname in ('dbms_om') AND relkind = 'r' AND c.relnamespace = n.oid;"
        return SqlCommand._query_dbms_om_tables(cn_inst, user, sql)

    @staticmethod
    def get_new_dbms_om_table_list(cn_inst, user):
        """
        function: list dbms_om relations with oid < 16384 whose pgxc_class group
                  is the installation group (is_installation = 't'); queried
                  with is_inplace_upgrade=True
        input: cn_inst - CN instance (hostname, port); user - SQL user
        output: list of relation names
        raises: Exception when the query fails
        """
        sql = "SELECT c.relname FROM pg_catalog.pg_class c LEFT JOIN pg_catalog.pgxc_class x ON c.oid = x.pcrelid " \
              "WHERE c.oid < 16384 AND " \
              "x.pgroup IN (SELECT group_name FROM pg_catalog.pgxc_group WHERE is_installation = 't') AND " \
              "c.relnamespace IN (SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = 'dbms_om');"
        return SqlCommand._query_dbms_om_tables(cn_inst, user, sql, is_inplace_upgrade=True)

    @staticmethod
    def get_dbms_om_tables_update_sql(coo_insts, group_name, dbms_table_str):
        """
        function: build one EXECUTE DIRECT statement per CN that sets the
                  pgxc_class (nodeoids, pgroup) of the named tables (oid < 16384)
                  to the members and name of node group `group_name`
        input: coo_insts - CN instances (instanceId is used as cn_<id>)
               group_name - target node group name
               dbms_table_str - table names inserted verbatim inside (''...'')
                                within a quoted EXECUTE DIRECT literal; presumably
                                already joined as "t1'',''t2" by the caller
        output: concatenated SQL text, each statement preceded by a newline
                ("" when coo_insts is empty)
        """
        # Each statement is formatted exactly once, so '%' characters in the
        # substituted values are never re-interpreted as format specifiers.
        template = (
            "\nEXECUTE DIRECT ON(cn_%(cn_inst_id)s) 'UPDATE pg_catalog.pgxc_class\n"
            "    SET (nodeoids, pgroup) = \n"
            "        (SELECT group_members, group_name FROM pg_catalog.pgxc_group "
            "WHERE group_name=''%(group_name)s'') \n"
            "    WHERE pcrelid IN \n"
            "        (SELECT oid FROM pg_catalog.pg_class WHERE relname IN (''%(dbms_om_tables)s'') "
            "AND oid < 16384)';"
        )
        return "".join(template % {"cn_inst_id": cn_inst.instanceId,
                                   "group_name": group_name,
                                   "dbms_om_tables": dbms_table_str}
                       for cn_inst in coo_insts)
