#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################################
# Copyright (c): 2010-2020, Huawei Tech. Co., Ltd.
# Description  : gsnetwork.py is a utility to do something for network information.
#############################################################################
try:
    import subprocess
    import sys
    import _thread as thread
    import re
    import socket
    import binascii

    sys.path.append(sys.path[0] + "/../../")

    from gspylib.common.ErrorCode import ErrorCode
    from gspylib.os.gsplatform import g_Platform
    from gspylib.threads.parallelTool import parallelTool
except Exception as ie:
    sys.exit("[GAUSS-52200] : Unable to import module: %s." % str(ie))

# Addresses that failed the ping check in the most recent
# Network.checkIpAddressList() call. checkIpAddressList() resets it before
# each run, and Network.executePingCmd() appends to it.
g_failedAddressList = []
# Guards appends to g_failedAddressList; executePingCmd is run through
# parallelTool.parallelExecute, so several workers may append at once.
g_lock = thread.allocate_lock()

"""
Requirements:
"""


class networkInfo():
    """
    Plain record describing one network interface card (NIC) and the
    address bound to it.

    Every field defaults to "". Network.getAllNetworkIp fills only NICNum
    and ipAddress. Network.getAllNetworkInfo fills every field, using "" for
    any value whose platform query failed.
    """

    def __init__(self):
        # NIC name as reported by g_Platform.getIpAddressAndNICList().
        self.NICNum = ""
        # IP address bound to the NIC.
        self.ipAddress = ""
        # Network mask / prefix of the address.
        self.networkMask = ""
        # MTU of the NIC.
        self.MTUValue = ""

        # TX / RX values reported by the platform for this NIC.
        self.TXValue = ""
        self.RXValue = ""
        # Link speed of the NIC.
        self.networkSpeed = ""
        # Path of the NIC's network configuration file.
        self.networkConfigFile = ""
        # Bonding mode information derived from the configuration file.
        self.networkBondModeInfo = ""
        # Host name resolved from ipAddress. Initialised here so that every
        # instance has the attribute. It is not part of __str__, which keeps
        # the historical output format unchanged.
        self.hostName = ""

    def __str__(self):
        """
        Return a comma-separated "key=value" summary of the NIC fields.
        networkBondModeInfo is double-quoted because it may contain commas.
        """
        return "NICNum=%s,ipAddress=%s,networkMask=%s,MTUValue=%s,TXValue=%s,RXValue=%s," \
               "networkSpeed=%s,networkConfigFile=%s,networkBondModeInfo=\"%s\"" % \
            (self.NICNum, self.ipAddress, self.networkMask, self.MTUValue, self.TXValue, self.RXValue,
             self.networkSpeed, self.networkConfigFile, self.networkBondModeInfo)


class Network():
    """
    function: Init the Network options
    Helpers for IP address validation and formatting, NIC information
    queries (delegated to g_Platform) and simple connectivity checks.
    """

    # Dotted-quad IPv4 pattern used by isIp4Valid; it must match the whole
    # string. It accepts 3-digit octets with a leading 0/1 (e.g. "010",
    # "000"). It rejects a lone "0" as the first octet but allows "0" in the
    # other octets.
    _IPV4_REGEX = re.compile(r"(25[0-5]|2[0-4][0-9]|[0-1][0-9]{2}|[1-9][0-9]|[1-9])\."
                             r"(25[0-5]|2[0-4][0-9]|[0-1][0-9]{2}|[1-9][0-9]|[1-9]|0)\."
                             r"(25[0-5]|2[0-4][0-9]|[0-1][0-9]{2}|[1-9][0-9]|[1-9]|0)\."
                             r"(25[0-5]|2[0-4][0-9]|[0-1][0-9]{2}|[1-9][0-9]|[0-9])")

    def __init__(self):
        pass

    @staticmethod
    def isIp4Valid(ipAddress):
        """
        function : check if the input ip address is valid IPv4 format.
        input : String
        output : True if the whole string is a dotted-quad IPv4 address,
                 otherwise False (a trailing newline is rejected).
                 A non-string input raises TypeError.
        """
        return Network._IPV4_REGEX.fullmatch(ipAddress) is not None

    def isIpValid(self, ip, checkScope=True):
        """
        function :
            Check if the input ip address is a valid IP.
            Valid IP:
            1. For IPv4, it's a valid IP address.
            2. For IPv6, it's a valid IP address, and the IP is not a link-local
               one ('fe80:' prefix) when "checkScope" is set.
            3. Names ending with 'svc.cluster.local' are accepted as-is.
            4. Any other name that getaddrinfo can resolve is accepted.
        input :
            ip: IP address in plain text.
            checkScope: Whether to reject IPv6 link-local addresses.
        output : Boolean
        """
        if ip.strip().endswith('svc.cluster.local'):
            return True

        iptype = self.getIPType(ip)
        if iptype == 4:
            return True
        if iptype == 6:
            return not (checkScope and self.isIPv6LocalLink(ip))

        # Not an IP literal: accept it if it resolves as a host name.
        return self.getHostProtoVersion(ip) != 0

    def check_port_connected(self, address, port, timeout=1):
        """
        function: check whether a TCP connection to (address, port) succeeds.
        input :
            address: IPv4/IPv6 literal or host name. Host names are connected
                     over IPv4, as before.
            port: TCP port number.
            timeout: connection timeout in seconds.
        output : True if the connection was established, otherwise False.
        """
        family = socket.AF_INET6 if self.getIPType(address) == 6 else socket.AF_INET
        try:
            # The context manager closes the socket even if connect_ex raises
            # (e.g. name resolution failure).
            with socket.socket(family, socket.SOCK_STREAM) as sock:
                sock.settimeout(timeout)
                # connect_ex returns 0 on success and an errno value otherwise.
                return sock.connect_ex((address, port)) == 0
        except socket.error:
            return False

    @staticmethod
    def executePingCmd(ipAddress):
        """
        function : Ping ipAddress and record it in g_failedAddressList when
                   no reply line is seen.
        input : String
        output : NA
        """
        # NOTE(review): ipAddress is interpolated unquoted into a shell
        # command, so callers must only pass trusted addresses.
        pingCmd = g_Platform.getPingCmd(ipAddress, "5", "1")
        # Count output lines containing "ttl" (presumably one per echo reply).
        cmd = "%s | %s ttl | %s -l" % (pingCmd, g_Platform.getGrepCmd(), g_Platform.getWcCmd())
        (status, output) = subprocess.getstatusoutput(cmd)
        if str(output) == '0' or status != 0:
            # checkIpAddressList runs this in parallel; serialize the append.
            with g_lock:
                g_failedAddressList.append(ipAddress)

    def checkIpAddressList(self, ipAddressList):
        """
        function : Check the connection status of network.
        input : list of IP addresses to ping.
        output : list of the addresses that failed the ping check.
        """
        global g_failedAddressList
        g_failedAddressList = []
        parallelTool.parallelExecute(self.executePingCmd, ipAddressList)
        return g_failedAddressList

    @staticmethod
    def getAllNetworkIp():
        """
        function: list the NIC/address pairs of this host.
        output: list of networkInfo with only NICNum and ipAddress filled,
                taken from g_Platform.getIpAddressAndNICList().
        """
        networkInfoList = []
        for onelist in g_Platform.getIpAddressAndNICList():
            data = networkInfo()
            # NIC number
            data.NICNum = onelist[0]
            # ip address
            data.ipAddress = onelist[1]
            networkInfoList.append(data)
        return networkInfoList

    @staticmethod
    def _valueOrBlank(getter):
        """
        function: return getter(), or "" if it raises any Exception.
        It keeps one failing g_Platform query from aborting the others.
        """
        try:
            return getter()
        except Exception:
            return ""

    def getAllNetworkInfo(self):
        """
        function: collect detailed information for every NIC/address pair
                  returned by g_Platform.getIpAddressAndNICList('all').
        output: list of networkInfo. Any field whose g_Platform query raised
                is set to "".
        """
        networkInfoList = []
        for oneList in g_Platform.getIpAddressAndNICList('all'):
            data = networkInfo()
            # NIC number
            data.NICNum = oneList[0]
            # ip address
            data.ipAddress = oneList[1]
            nic = data.NICNum

            # Each lambda below is invoked immediately by _valueOrBlank, so
            # capturing the loop variables is safe.
            # host name
            data.hostName = self._valueOrBlank(
                lambda: g_Platform.getHostNameByIPAddr(data.ipAddress))
            # network mask (queried per IP family)
            iptype = "ipv6" if self.getIPType(data.ipAddress) == 6 else "ipv4"
            data.networkMask = self._valueOrBlank(
                lambda: g_Platform.getNetworkMaskByNICNum(nic, iptype))
            # MTU value
            data.MTUValue = self._valueOrBlank(
                lambda: g_Platform.getNetworkMTUValueByNICNum(nic))
            # TX value
            data.TXValue = self._valueOrBlank(
                lambda: g_Platform.getNetworkRXTXValueByNICNum(nic, 'tx'))
            # RX value
            data.RXValue = self._valueOrBlank(
                lambda: g_Platform.getNetworkRXTXValueByNICNum(nic, 'rx'))
            # network speed
            data.networkSpeed = self._valueOrBlank(
                lambda: g_Platform.getNetworkSpeedByNICNum(nic))
            # network config file
            data.networkConfigFile = self._valueOrBlank(
                lambda: g_Platform.getNetworkConfigFileByNICNum(nic))
            # network bond mode info; uses the config file found above,
            # which is "" if that lookup failed.
            data.networkBondModeInfo = self._valueOrBlank(
                lambda: g_Platform.getNetworkBondModeInfo(data.networkConfigFile, nic))

            networkInfoList.append(data)
        return networkInfoList

    @staticmethod
    def setNetworkInterruptByNIC(networkCardNum):
        """
        function: delegate to g_Platform.setNetworkInterruptByNIC.
        """
        return g_Platform.setNetworkInterruptByNIC(networkCardNum)

    @staticmethod
    def checkNetworkInterruptByNIC(networkCardNum):
        """
        function: delegate to g_Platform.checkNetworkInterruptByNIC.
        """
        return g_Platform.checkNetworkInterruptByNIC(networkCardNum)

    @staticmethod
    def getNetmask(baseIp, user="", hostName=""):
        """
        function: obtain the netmask (prefix length) bound to baseIp.
        param: baseIp: the IP whose `ip addr` line is searched for.
               user: OS user used for `su` when querying a remote host.
               hostName: remote host to query over ssh; "" means local host.
        output: netmask, i.e. the text after '/' in the second column of the
                first matching `ip addr` line (e.g. "24" for "10.0.0.1/24").
        raise: Exception(GAUSS_50614) if the command fails or nothing matched.
        """
        if hostName == "":
            # 1.get netmask from local ip
            cmd = r"source /etc/profile;ip addr | grep -E '\<%s\>' | awk '{print $2}'" % baseIp
        else:
            from gspylib.common.Common import DefaultValue
            # 2.get netmask from remote ip. The escaping survives two shell
            # levels (su -c "..." and then ssh "...").
            cmd = "export LD_LIBRARY_PATH=/lib64:$LD_LIBRARY_PATH; ssh %s %s \\\"source /etc/profile;%s addr | " \
                  "grep -E '\\<%s\\>' | awk '{print \\\\\\$2}'\\\" 2>/dev/null" % \
                  (hostName,
                   DefaultValue.SSH_OPTION,
                   g_Platform.getIpCmd(),
                   baseIp)
            cmd = "su - %s -c \"%s\"" % (user, cmd)
        (status, output) = subprocess.getstatusoutput(cmd)
        if status != 0:
            raise Exception(ErrorCode.GAUSS_506["GAUSS_50614"])
        fields = output.strip().split("\n")[0].split("/")
        if len(fields) < 2:
            # No "<ip>/<prefix>" line was printed (grep matched nothing).
            raise Exception(ErrorCode.GAUSS_506["GAUSS_50614"])
        return fields[1].strip()

    def getIPType(self, ipAddr):
        """
        function: classify ipAddr as an IPv4 address, an IPv6 address or neither.
        param: ipAddr: IP address in plain text.
        output:
            4: IPv4 address
            6: IPv6 address
            0: Invalid IP address. Host names also give 0, because only
               strings containing ':' are passed to getaddrinfo.
        """
        try:
            if self.isIp4Valid(ipAddr):
                return 4
            # If an address contains any ':', then it may be a v6 address,
            # thus we will test it by 'getaddrinfo'.
            if ':' not in str(ipAddr):
                return 0

            results = socket.getaddrinfo(ipAddr, None)
            if not results:
                return 0
            family = results[0][0]
            if family == socket.AF_INET:
                return 4
            if family == socket.AF_INET6:
                return 6
            return 0
        except Exception:
            # Non-string input, or getaddrinfo could not parse the string.
            return 0

    def isSameIP(self, ip1, ip2):
        """
        function:
            Detect if the pair of IPs are the same one.
            It is very useful for IPv6 format.
        input:
            ip1: IP in plain text.
            ip2: IP in plain text.
        output:
            True: They are the same.
            False: Either of the IPs is not a valid IP address or they are
                   not the same.
        """
        try:
            return self.formatIP(ip1) == self.formatIP(ip2)
        except Exception:
            return False

    def formatIP(self, ip):
        """
        function:
            Format the ipv6 address into a regular one.
            Such as:
                '2002:0000:0000:0000:0000:0000:00AB:CD00' -> '2002::ab:cd00'
            Here we SHOULD NOT treat IPv4, because it is confused:
                '1.02.03.040' -> '1.2.3.32'
                As the '040' is an octal number, 32 in decimal.
        input:
            ip: IP in plain text.
        output:
            The formatted IP in plain text. Non-IPv6 input is returned unchanged.
        """
        iptype = self.getIPType(ip)

        try:
            # no need to format IPv4 or hostname.
            if 6 != iptype:
                return ip

            # sr[0] is the first runnable protocol.
            # sr[0][4] is the tuple of host.sockaddr, (ip, port),
            # and ip is always in lowercase, abbreviated.
            sr = socket.getaddrinfo(ip, None)
            return str(sr[0][4][0])
        except Exception:
            if iptype == 6:
                return str(ip).lower()
            return ip

    def makeSCPHost(self, host):
        """
        function:
            The same as "makeSquareBracketIP", but it is easier to understand.
        :param host:
            hostname or IP address.
        :return:
            "[host]" in case host is an IPv6 address.
            "host" itself in other cases.
        """
        return self.makeSquareBracketIP(host)

    def makeSquareBracketIP(self, ip):
        """
        function:
            Add [] to IPv6 addresses.
            ONLY IPv6 addresses will be translated.
        :param ip:
            IP address.
        :return:
            '[IPv6]' string, or ip itself.
        """
        if 6 == self.getIPType(ip):
            return '[%s]' % ip
        return ip

    def isInSameNetSegment(self, cidr1, cidr2):
        """
        function:
            Are the two IPs in the same network segment?
            Each IP is masked with its own prefix length and the resulting
            network addresses are compared.
        :param cidr1:
            IP1 in CIDR format, e.g. '192.168.0.1/24'.
        :param cidr2:
            IP2 in CIDR format.
        :return:
            boolean. False for a malformed CIDR (missing or repeated '/',
            non-numeric or out-of-range prefix), an invalid IP, or a mix of
            IPv4 and IPv6.
        """
        try:
            ip1, mask1 = str(cidr1).split('/')
            ip2, mask2 = str(cidr2).split('/')
            mask1 = int(mask1, 10)
            mask2 = int(mask2, 10)
        except ValueError:
            # Invalid CIDR format: wrong number of '/' or non-numeric prefix.
            return False

        iptype1 = self.getIPType(ip1)
        iptype2 = self.getIPType(ip2)

        # IPv4 and IPv6 are incompatible.
        if iptype1 != iptype2:
            return False

        if iptype1 == 6:
            family, totallen = socket.AF_INET6, 128
        elif iptype1 == 4:
            family, totallen = socket.AF_INET, 32
        else:
            # Invalid IP string.
            return False

        # Invalid CIDR value.
        if not (0 <= mask1 <= totallen and 0 <= mask2 <= totallen):
            return False

        try:
            # Convert each IP into an integer so it can be bit-and'ed with its mask.
            iplong1 = int.from_bytes(socket.inet_pton(family, ip1), 'big')
            iplong2 = int.from_bytes(socket.inet_pton(family, ip2), 'big')
        except (OSError, ValueError):
            return False

        # Prefix masks: 'mask' leading one-bits followed by zero-bits.
        masklong1 = ((1 << mask1) - 1) << (totallen - mask1)
        masklong2 = ((1 << mask2) - 1) << (totallen - mask2)

        # If they are in same network segment, the bit-and result should be the same.
        return (iplong1 & masklong1) == (iplong2 & masklong2)

    @staticmethod
    def getHostProtoVersion(host, proto=""):
        """
        function:
            Get the IP version of the host.
            If IPv4 and IPv6 are all available, IPv4 will be returned.
        :param host: The hostname or IP to detect.
        :param proto:
            "":     Blank string means either IPv4 or IPv6 is OK.
            "ipv4": means if the host can be connected through IPv4.
            "ipv6": as above.
        :return:
            4: IPv4 is available.
            6: IPv6 is available.
            0:
                If proto is blank string or None, 0 means the target host is not available.
                If proto is specified, 0 means the target host is not available under the specified IP version.
        """
        try:
            families = {entry[0] for entry in socket.getaddrinfo(host, None)}
        except Exception:
            return 0

        if proto is None or proto == "":
            # Prefer IPv4 whatever order getaddrinfo returned the entries in.
            if socket.AF_INET in families:
                return 4
            if socket.AF_INET6 in families:
                return 6
            return 0
        if proto == "ipv4":
            return 4 if socket.AF_INET in families else 0
        if proto == "ipv6":
            return 6 if socket.AF_INET6 in families else 0
        return 0

    def isIPv6LocalLink(self, ip):
        """
        function:
            Is the IPv6 ip string of local link scope one (normalized form
            starts with 'fe80:', i.e. fe80::/16)?
        :param ip:
            The IP string.
        :return:
            Boolean.
        """
        if self.getIPType(ip) != 6:
            return False

        # formatIP lowercases/compresses the address before the prefix check.
        return self.formatIP(ip).startswith('fe80:')

# Shared module-level Network instance. Network.__init__ stores no state,
# so a single instance can be reused by importers.
g_network = Network()
