#!/usr/bin/env python3
# -*- coding:utf-8 -*-
"""
Copyright (c): 2021, Huawei Tech. Co., Ltd.
Description  : Consists of common utility used by other scripts.
"""
try:
    import os
    import sys
    import time
    import socket
    import getpass
    import subprocess

    sys.path.append(os.path.split(os.path.realpath(__file__))[0] + "/../../script/")
    import paramiko
    import netifaces
    from gspylib.common.ErrorCode import ErrorCode
    from gspylib.common.Common import DefaultValue
    from gspylib.common.GaussLog import GaussLog
    from multiprocessing.dummy import Pool as ThreadPool
    from multiprocessing import TimeoutError
except ImportError as e:
    sys.exit("[GAUSS-52200] : Unable to import module: %s." % str(e))


class GDSUtils:
    """
    Common utility functions, used by install/check/ctl scripts.

    Every helper is a staticmethod; the class is a namespace plus a few
    shared constants.
    """

    def __init__(self):
        pass

    # constants
    ENV_NAME_GDS_INSTALL_DIR = "GDS_INSTALL_DIR"
    ENV_NAME_PYTHONPATH = "PYTHONPATH"
    ENV_NAME_LIB_PATH = "LD_LIBRARY_PATH"
    ENV_NAME_PATH = "PATH"
    EXEC_ENV_FILE = "~/.bashrc"
    # Directory that contains this module; ../bin, ../lib and ../script are
    # resolved relative to it.
    WORK_PATH = os.path.dirname(os.path.realpath(__file__))
    CHECK_LOG_FILE = "logs/os_check.log"
    # Seconds. Used as the paramiko exec_command timeout and as the wait limit
    # of the parallel password check.
    TIMEOUT = 20

    @staticmethod
    def get_version_info(app_name):
        """
        function: print the version information of the package.
        input: app_name: executable name; "../bin/<app_name>_env" is sourced
               before running "<app_name> -V".
        output: NA. The command output is printed. If the command fails,
                GaussLog.exitWithError is called with GAUSS_51623
                (presumably terminating the process).
        """
        bin_path = os.path.join(GDSUtils.WORK_PATH, "../bin")
        env_path = "%s/%s_env" % (bin_path, app_name)
        cmd = "source %s ; %s -V" % (env_path, app_name)
        status, output = subprocess.getstatusoutput(cmd)
        if status != 0:
            GaussLog.exitWithError(ErrorCode.GAUSS_516["GAUSS_51623"])
        GaussLog.printMessage(str(output))

    @staticmethod
    def exec_command_remote_parallelly(handler, hosts, timeout=None):
        """
        function: execute command parallelly by ThreadPool:
                (1) through root using paramiko or
                (2) through trusted user using ssh
        input:
            handler: callable invoked once per host. Each call must return an
                     object with .items() (e.g. a dict). All results are
                     merged into a single dict.
            hosts: iterable of hosts, each passed to handler.
            timeout: number of seconds to wait for all workers, or None to
                     wait without a time limit.
        output: dict merged from every handler result, in host order (a later
                host overwrites an earlier one on duplicate keys).
        raises: Exception(GAUSS_53501) when the workers do not finish within
                timeout; Exception(GAUSS_53504) when a result cannot be
                merged. Exceptions raised by handler propagate unchanged.
        """
        with ThreadPool(DefaultValue.getCpuSet()) as pool:
            async_result = pool.map_async(handler, hosts)
            try:
                exec_result = async_result.get(timeout=timeout)
            except TimeoutError as err:
                raise Exception(ErrorCode.GAUSS_535["GAUSS_53501"]) from err

        parallel_result = {}
        try:
            for host_result in exec_result:
                for key, value in host_result.items():
                    parallel_result[key] = value
        except Exception as err:
            raise Exception(ErrorCode.GAUSS_535["GAUSS_53504"] % str(err)) from err
        return parallel_result

    @staticmethod
    def connect_paramiko_ssh_host(host, username, passwd):
        """
            function: Connect to a node, return node connection.
            input: host: ip address of the node
                   username, passwd: credentials for password authentication
            output: paramiko SSHClient wrapping an authenticated transport to
                    port 22 of the node
            raises: Exception(GAUSS_53500) when the connection cannot be
                    established. The transport is closed before raising.
            The caller owns the returned client; please stop it manually.
        """
        transport = None
        try:
            transport = paramiko.Transport((host, 22))
            transport.connect(username=username, password=passwd)
            client = paramiko.SSHClient()
            # Attach the already-authenticated transport directly instead of
            # going through SSHClient.connect().
            client._transport = transport
            if client.get_transport().active:
                return client
        except Exception as err:
            if transport is not None:
                transport.close()
            raise Exception(ErrorCode.GAUSS_535["GAUSS_53500"] % host) from err
        # Authenticated but the transport is not active: release it before failing.
        transport.close()
        raise Exception(ErrorCode.GAUSS_535["GAUSS_53500"] % host)

    @staticmethod
    def paramiko_cmd_raw(client, cmd, get_pty=False):
        """
            function: execute command through paramiko.
            input:
                client: node connection through paramiko (SSHClient)
                cmd: command line; None or "" is a no-op
                get_pty: passed through to client.exec_command
            output:
                status: whether execute the command line successfully
                        (False whenever anything was written to stderr)
                output: output of the command line or error message of the failed execution
            raises: Exception(GAUSS_53504) when the execution itself fails
                    (e.g. the GDSUtils.TIMEOUT channel timeout)
        """
        if cmd is None or len(cmd) == 0:
            return True, ""
        try:
            stdin, stdout, stderr = client.exec_command(cmd, get_pty=get_pty, timeout=GDSUtils.TIMEOUT)
            # NOTE(review): stderr is drained before stdout. A command with very
            # large stdout may stall until the channel timeout; consider reading
            # both streams concurrently if that becomes an issue.
            cmd_error = stderr.read().decode('utf-8')
            if cmd_error:
                status = False
                output = ErrorCode.GAUSS_514["GAUSS_51400"] % (cmd + " , error: " + cmd_error)
            else:
                status = True
                output = stdout.read().decode('utf-8')
            return status, output
        except Exception as err:
            raise Exception(ErrorCode.GAUSS_535["GAUSS_53504"] % str(err)) from err

    @staticmethod
    def exec_command_by_paramiko_on_one_node(ip, cmd, user, password):
        """
        function:
            Connect to a node, then execute cmd and return result.
        precondition:
            root's password is correct on each node.
        postcondition:
            The transport to the node is closed, on success and on failure.
        input: ip, cmd, user, password
        output: {ip: first chunk (at most 9999 bytes) received on the
                 channel, as stripped bytes}
        raises: Exception(GAUSS_51220) when the transport cannot be created;
                Exception(GAUSS_50306) on authentication failure;
                Exception(GAUSS_53501 / GAUSS_53504) when execution fails.
        """
        result = {}
        try:
            ssh = paramiko.Transport((ip, 22))
        except paramiko.SSHException as err:
            raise Exception(ErrorCode.GAUSS_512["GAUSS_51220"] % ip + " Error: %s" % str(err)) from err

        try:
            try:
                ssh.connect(username=user, password=password)
            except paramiko.AuthenticationException as err:
                raise Exception(ErrorCode.GAUSS_503["GAUSS_50306"] % ip) from err

            try:
                channel = ssh.open_session()
                channel.exec_command(cmd)
                # NOTE(review): a single recv() may return only part of the
                # output; kept as-is so commands that leave stdout open
                # (e.g. daemons) do not block.
                result[ip] = channel.recv(9999).strip()
            except TimeoutError as err:
                raise Exception(ErrorCode.GAUSS_535["GAUSS_53501"]) from err
            except Exception as err:
                raise Exception(ErrorCode.GAUSS_535["GAUSS_53504"] % str(err)) from err
        finally:
            # Release the transport (and its channel) whatever happened above.
            ssh.close()
        return result

    @staticmethod
    def is_ipv4(ip):
        """Return True if ip is a valid dotted-quad IPv4 address string."""
        try:
            socket.inet_pton(socket.AF_INET, ip)
        except AttributeError:
            # Platforms without inet_pton: inet_aton also accepts short forms
            # like "127.1", so additionally require exactly four octets.
            try:
                socket.inet_aton(ip)
            except socket.error:
                return False
            return ip.count('.') == 3
        except socket.error:
            return False
        return True

    @staticmethod
    def is_ipv6(ip):
        """Return True if ip is a valid IPv6 address string."""
        try:
            socket.inet_pton(socket.AF_INET6, ip)
        except socket.error:
            return False
        return True

    @staticmethod
    def check_ip(ip):
        """Return True if ip is a valid IPv4 or IPv6 address string."""
        return GDSUtils.is_ipv4(ip) or GDSUtils.is_ipv6(ip)

    @staticmethod
    def divide_local_remote(host_list):
        """
            function: parse the host_list, find out whether there is local IP in the list,
                      find out all remote IPs.
            input: ip list
            output:
                local_ip: last local ip in the host_list, if there's none, return '0.0.0.0'
                local_flag: Boolean, whether there is local IP in the list
                remote_hosts: all remote hosts in the host list, in input order
            raises: Exception(GAUSS_50616) on any failure
        """
        try:
            # '0.0.0.0' always counts as local. Only IPv4 interface addresses
            # are collected.
            # NOTE(review): IPv6 hosts (accepted by check_ip) are therefore
            # always classified as remote; confirm this is intended.
            local_ips = {'0.0.0.0'}
            for interface in netifaces.interfaces():
                addresses = netifaces.ifaddresses(interface)
                for item in addresses.get(netifaces.AF_INET, []):
                    ipv4 = item.get('addr')
                    if ipv4:
                        local_ips.add(ipv4)

            remote_hosts = []
            local_flag = False
            local_ip = '0.0.0.0'
            for host in host_list:
                if host in local_ips:
                    local_flag = True
                    local_ip = host
                else:
                    remote_hosts.append(host)
        except Exception as err:
            raise Exception(ErrorCode.GAUSS_506["GAUSS_50616"]) from err
        return local_ip, local_flag, remote_hosts

    @staticmethod
    def two_more_chances_for_passwd(hosts, user, passwd):
        """
        function:for better user experience, if the root password is wrong, two more chances should be given
        input:
            hosts: list of all ips
            user: the check target
            passwd: the check target
        output: the password that authenticated successfully on every host
        raises: the last failure after the initial attempt plus two retries
        """
        check_user = user
        check_password = passwd

        def check_user_passwd(ip):
            """Authenticate to ip:22 with the current candidate password."""
            transport = None
            try:
                transport = paramiko.Transport((ip, 22))
                transport.connect(username=check_user, password=check_password)
            except Exception as err:
                raise Exception(ErrorCode.GAUSS_503["GAUSS_50306"] % ip
                                + " Maybe communication is exception, please check the password and "
                                  "communication. Wrong password or communication is abnormal.") from err
            finally:
                # Close in every case so a failed authentication does not leave
                # the transport running.
                if transport is not None:
                    transport.close()

        times = 0
        while True:
            try:
                # check user password in async mode; check_user_passwd reads
                # the current value of check_password on every attempt
                with ThreadPool(DefaultValue.getCpuSet()) as pool:
                    async_result = pool.map_async(check_user_passwd, hosts)
                    try:
                        async_result.get(timeout=GDSUtils.TIMEOUT)
                    except TimeoutError:
                        raise Exception(ErrorCode.GAUSS_535["GAUSS_53501"])
                return check_password
            except Exception:
                if times == 2:
                    raise
                GaussLog.printMessage("Password authentication failed, please try again.")
                check_password = getpass.getpass()
                times += 1

    @staticmethod
    def get_user_password(name, point=""):
        """
        function: get user password
        input: name: username
               point: tips appended to the prompt
        output: password, after DefaultValue.checkPasswordVaild has validated it
        """
        if point == "":
            GaussLog.printMessage("Please enter password for %s." % name)
        else:
            GaussLog.printMessage("Please enter password for %s %s." % (name, point))
        passwdone = getpass.getpass()
        DefaultValue.checkPasswordVaild(passwdone)
        return passwdone

    @staticmethod
    def get_hosts_from_param(host_parm):
        """
        function: get host list from parameters no matter the type of parameter is file or string.
        input: value of parameter '--hosts' or '--ping-host'. If its first
               comma-separated token is an IP, it is treated as a comma list;
               otherwise it is treated as a file path with one IP per line
               (blank lines and '#' comments are skipped).
        output: de-duplicated host list, in first-seen order
        raises: Exception(GAUSS_50603) on an invalid IP;
                Exception(GAUSS_53503) when the value is neither an IP list nor a file
        """
        parsed_hosts = []
        if not host_parm:
            return parsed_hosts
        hosts = host_parm.split(',')
        if GDSUtils.check_ip(hosts[0]):
            for param in hosts:
                if param and param not in parsed_hosts:
                    if not GDSUtils.check_ip(param):
                        raise Exception(ErrorCode.GAUSS_506["GAUSS_50603"] + " Invalid ip input: %s" % param)
                    parsed_hosts.append(param)
            return parsed_hosts

        if not os.path.isfile(host_parm):
            raise Exception(ErrorCode.GAUSS_535["GAUSS_53503"] % host_parm
                            + "Invalid ip address or invalid file path.")
        with open(host_parm) as f:
            f_lines = [x.strip() for x in f.readlines()]
        for line in f_lines:
            if not line or line.startswith("#"):
                continue
            if not GDSUtils.check_ip(line):
                raise Exception(ErrorCode.GAUSS_506["GAUSS_50603"] + " Invalid ip input: %s" % line)
            if line not in parsed_hosts:
                parsed_hosts.append(line)
        return parsed_hosts

    @staticmethod
    def generate_log_uuid():
        """
        function: generate a 36-character log uuid:
                  "<local time %Y-%m-%dT%H-%M-%S>-<pid>-", right-padded with '0'.
        """
        return (time.strftime("%Y-%m-%dT%H-%M-%S", time.localtime(time.time())) + "-" +
                str(os.getpid()) + "-").ljust(36, '0')

    @staticmethod
    def load_gds_env():
        """
        function: load environment path.
                  Builds "export" statements that prepend ../lib to
                  LD_LIBRARY_PATH, the bin/script directories to PATH, and
                  ../lib to PYTHONPATH, then runs them in a shell. On failure,
                  GaussLog.exitWithError is called with GAUSS_53513.
        """
        # NOTE(review): the exports run in a child shell started by
        # subprocess, so they do not change os.environ of this process or of
        # later subprocess calls. Confirm whether callers rely on that before
        # switching to os.environ updates.
        try:
            status, output = subprocess.getstatusoutput("echo $LD_LIBRARY_PATH")
            ld_library_path_export = os.path.realpath(os.path.join(GDSUtils.WORK_PATH, "../lib"))
            if status == 0 and output:
                ld_library_path_export += ":$LD_LIBRARY_PATH"
            ld_library_path_export = "export LD_LIBRARY_PATH=%s" % ld_library_path_export
            path_export = "export PATH=%s/../bin:%s/../script:%s/../script/gspylib/pssh/bin:" \
                          "/usr/sbin:/usr/local/sbin:/sbin:$PATH" % \
                          (GDSUtils.WORK_PATH, GDSUtils.WORK_PATH, GDSUtils.WORK_PATH)
            pythonpath_export = "export PYTHONPATH=%s/../lib:$PYTHONPATH" % GDSUtils.WORK_PATH
            status, output = subprocess.getstatusoutput("%s; %s; %s" %
                                                        (ld_library_path_export, path_export, pythonpath_export))
            if status != 0:
                raise Exception(output)
        except Exception as e:
            GaussLog.exitWithError(ErrorCode.GAUSS_535["GAUSS_53513"], str(e))
