#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################################
# Copyright (c): 2012-2017, Huawei Tech. Co., Ltd.
# Description  : gs_sshexkey is a utility to create SSH trust among nodes in a cluster.
#############################################################################

try:
    import sys
    import warnings

    warnings.simplefilter('ignore', DeprecationWarning)
    sys.path.append(sys.path[0] + "/../lib")
    import time
    import os
    import subprocess
    import pwd
    import grp
    import socket
    import getpass
    import shutil
    from gspylib.common.ParallelBaseOM import ParallelBaseOM
    from gspylib.common.GaussLog import GaussLog
    from gspylib.common.ErrorCode import ErrorCode
    from gspylib.threads.parallelTool import parallelTool
    from gspylib.common.Common import DefaultValue, ClusterCommand
    from gspylib.common.ParameterParsecheck import Parameter
    from gspylib.os.gsfile import g_file
    from gspylib.os.gsOSlib import g_OSlib
    from gspylib.os.gsnetwork import g_network

    DefaultValue.doConfigForParamiko()
    import paramiko
except Exception as ie:
    sys.exit("[GAUSS-52200] : Unable to import module: %s." % str(ie))

# Maximum number of /etc/hosts lines sent in one remote "echo ... >> /etc/hosts"
# command; larger mappings are split into several commands of this size.
IP_NUM = 512
# Marker appended to every line this tool writes into /etc/hosts; the
# " #Gauss.* IP Hosts Mapping" grep -v filter later uses it to drop old entries.
HOSTS_MAPPING_FLAG = "#Gauss OM IP Hosts Mapping"
# IP/hostname mapping text shared with writeRemoteHostName(); filled in by
# writeRemoteHosts() before the remote writes start.
ipHostInfo = ""
# Full path of the step marker file (set in run() to /tmp/TMP_TRUST_FILE);
# whether it exists decides if log messages carry step tags.
tmp_files = ""
# Base name of the step marker file
TMP_TRUST_FILE = "step_preinstall_file.dat"
# Tool name, used as logger name and default log action
functionName = "gs_sshexkey"


class PrintOnScreen():
    """
    Stand-in for the file logger that writes directly to stdout.

    Used when no log file was requested. Debug messages are discarded, and
    any extra positional arguments (such as step tags) are ignored.
    """

    def __init__(self):
        '''
        function : Constructor (no state to initialise)
        input: NA
        output: NA
        '''
        pass

    def log(self, msg, *args):
        '''
        function : print msg on its own line
        input: msg: str; args: ignored
        output: NA
        '''
        print(msg)

    def debug(self, msg, *args):
        '''
        function : discard debug output
        input: msg: debug message string; args: ignored
        output: NA
        '''
        pass

    def logExit(self, msg, *args):
        '''
        function : print msg, then terminate the process with exit code 1
        input: msg: str; args: ignored
        output: NA
        '''
        self.log(msg)
        sys.exit(1)


class GaussCreateTrust(ParallelBaseOM):
    """
    class about create trust for user
    """

    def __init__(self):
        '''
        function : Constructor; set option defaults and resolve the
                   current user, group and ~/.ssh file locations
        input: NA
        output: NA
        '''
        ParallelBaseOM.__init__(self)
        # settings filled from the command line
        self.logger = None
        self.hostFile = ""
        self.hostList = []
        self.passwd = []
        self.logFile = ""
        self.skipHostnameSet = False
        self.isKeyboardPassword = False
        # runtime state
        self.localHost = ""
        self.flag = False
        self.localID = ""
        self.incorrectPasswdInfo = ""
        self.failedToAppendInfo = ""
        # identity of the invoking user
        self.user = pwd.getpwuid(os.getuid()).pw_name
        self.group = grp.getgrgid(os.getgid()).gr_name
        # SSH files inside the user's (passwd-database) home directory
        self.homeDir = os.path.expanduser("~" + self.user)
        self.sshDir = "%s/.ssh" % self.homeDir
        self.authorized_keys_fname = "%s/authorized_keys" % self.sshDir
        self.known_hosts_fname = "%s/known_hosts" % self.sshDir
        self.id_rsa_fname = "%s/id_rsa" % self.sshDir
        self.id_rsa_pub_fname = "%s.pub" % self.id_rsa_fname

    def usage(self):
        """
        function: print the gs_sshexkey help text
        input : NA
        output: NA
        """
        from help.gs_sshexkey_help import gs_sshexkey_usage as printUsage
        printUsage()

    def parseCommandLine(self):
        """
        function: Copy the parsed command-line options onto this object.
                  A help request prints usage and exits with 0; an empty
                  password list is rejected.
        input : NA
        output: NA
        """
        paraDict = Parameter().ParameterCommandLine("sshexkey")
        if "helpFlag" in paraDict:
            self.usage()
            sys.exit(0)

        # (key in paraDict, attribute receiving its value), in the original order
        optionToAttr = (
            ("hostfile", "hostFile"),
            ("passwords", "passwd"),
            ("logFile", "logFile"),
            ("skipHostnameSet", "skipHostnameSet"),
            ("uuid", "logUuid"),
            ("logAction", "logAction"),
            ("logStep", "logStep"),
        )
        for (option, attrName) in optionToAttr:
            if option not in paraDict:
                continue
            setattr(self, attrName, paraDict.get(option))
            if option == "passwords" and self.passwd == []:
                GaussLog.exitWithError(ErrorCode.GAUSS_500["GAUSS_50001"] % 'W' + ".")

    def checkParameter(self):
        """
        function: Validate the parsed options, load the host list and fill
                  in defaults (password, log action, log step).
        input : NA
        output: NA
        """
        # the host file is mandatory, must exist and be an absolute path
        if self.hostFile == "":
            self.usage()
            GaussLog.exitWithError(ErrorCode.GAUSS_500["GAUSS_50001"] % 'f' + ".")
        if not os.path.exists(self.hostFile):
            GaussLog.exitWithError(ErrorCode.GAUSS_502["GAUSS_50201"] % self.hostFile)
        if not os.path.isabs(self.hostFile):
            GaussLog.exitWithError(ErrorCode.GAUSS_502["GAUSS_50213"] % self.hostFile)

        self.readHostFile()
        if self.hostList == []:
            GaussLog.exitWithError(ErrorCode.GAUSS_500["GAUSS_50004"] % 'f' + " It cannot be empty.")

        # an optional log file must also be an absolute path
        if self.logFile != "" and not os.path.isabs(self.logFile):
            GaussLog.exitWithError(ErrorCode.GAUSS_502["GAUSS_50213"] % self.logFile)

        if not self.passwd:
            self.passwd = self.getUserPasswd()
        if self.logAction == "":
            self.logAction = functionName
        self.logStep = int(self.logStep) if self.logStep != "" else 0

    def readHostFile(self):
        """
        function: Read the host file into self.hostList.
                  Each non-blank line holds one IP address. Surrounding
                  whitespace is ignored, and repeated entries are kept once
                  in first-seen order. If any entry is not a valid IP, exits
                  with GAUSS-50603 listing every invalid entry (each once).
        input : NA
        output: NA
        raises: Exception (GAUSS-50204) when the file cannot be read
        """
        # set for O(1) duplicate detection; also covers invalid entries
        seen = set(self.hostList)
        inValidIp = []
        try:
            with open(self.hostFile, "r") as f:
                for readLine in f:
                    # strip() already removes the line terminator
                    hostname = readLine.strip()
                    if hostname == "" or hostname in seen:
                        continue
                    seen.add(hostname)
                    if not DefaultValue.isIpValid(hostname):
                        inValidIp.append(hostname)
                        continue
                    self.hostList.append(hostname)
        except Exception as e:
            raise Exception(ErrorCode.GAUSS_502["GAUSS_50204"] % "host file" +
                            " Error: \n%s" % str(e)) from e
        # report only after the file has been closed
        if inValidIp:
            GaussLog.exitWithError(ErrorCode.GAUSS_506["GAUSS_50603"] + "The IP list is:%s." % inValidIp)

    def getAllHostsName(self, ip):
        """
        function:
          Log in to one node with the first password and read its hostname
          (run for all nodes in parallel by parallelGetHosts)
        precondition:
          1.User's password is correct on each node
        postcondition:
           The SSH transport to the node is always closed
        input: ip
        output: dict {ip: hostname}; or {"Node[<ip>]": message} when running
                a trivial command ("cd") prints anything, i.e. the login
                shell emits output
        raises: Exception GAUSS-51220 when the transport cannot be created,
                GAUSS-50306 when the login fails
        """
        ipHostname = {}
        try:
            ssh = paramiko.Transport((ip, 22))
        except Exception as e:
            raise Exception(ErrorCode.GAUSS_512["GAUSS_51220"] % ip + " Error: \n%s" % str(e))
        try:
            try:
                ssh.connect(username=self.user, password=self.passwd[0])
            except Exception:
                raise Exception(ErrorCode.GAUSS_503["GAUSS_50306"] % ip)

            # Any output from "cd" comes from the shell start-up files and
            # would corrupt later command output.
            check_channel = ssh.open_session()
            check_channel.exec_command("cd")
            env_msg = check_channel.recv_stderr(9999).decode()
            while True:
                channel_read = check_channel.recv(9999).decode().strip()
                if not channel_read:
                    break
                env_msg += channel_read
            if env_msg != "":
                ipHostname["Node[%s]" % ip] = "Output: [" + env_msg + \
                                              " ] print by /etc/profile or ~/.bashrc, please check it."
                return ipHostname

            channel = ssh.open_session()
            channel.exec_command("hostname")
            ipHostname[ip] = channel.recv(9999).decode().strip()
            return ipHostname
        finally:
            # closes the transport on every path, including the early return
            ssh.close()

    def verifyPasswd(self, ssh, pswd=None):
        """
        function: Try to authenticate an already created paramiko Transport
                  as the current user with the given password.
        input : ssh - paramiko.Transport to the target node
                pswd - password to try
        output: True when the login succeeded (the transport stays open for
                the caller); False otherwise (the transport is closed)
        """
        try:
            ssh.connect(username=self.user, password=pswd)
            return True
        except Exception:
            ssh.close()
            return False

    def parallelGetHosts(self, sshIps):
        """
        function: Collect the hostname of every node in parallel via
                  getAllHostsName (first password only).
        input : sshIps - list of node IPs
        output: dict mapping IP -> hostname
        raises: Exception (GAUSS-51808) listing the nodes whose login
                printed output
        """
        perNodeResults = parallelTool.parallelExecute(self.getAllHostsName, sshIps)

        hostnames = {}
        problems = ""
        for nodeResult in perNodeResults:
            for (key, value) in nodeResult.items():
                if "Node" in key:
                    # error entries are keyed "Node[<ip>]"
                    problems += str(nodeResult)
                    continue
                hostnames[key] = value

        if problems:
            raise Exception(ErrorCode.GAUSS_518["GAUSS_51808"] % problems)
        return hostnames

    def serialGetHosts(self, sshIps):
        """
        function: Collect hostnames node by node. Every known password is
                  tried; when the password was typed interactively, the user
                  is also prompted up to 3 more times per node.
        input : sshIps - list of node IPs
        output: dict mapping IP -> hostname
        raises: Exception GAUSS-50306 when no password works for a node,
                GAUSS-51808 when a node's login prints output,
                GAUSS-51101 listing every node the transport could not reach
        """
        serialResult = {}
        invalidIP = ""
        for sshIp in sshIps:
            ssh = None
            isPasswdOK = False
            boolInvalidIp = False
            for pswd in self.passwd:
                try:
                    ssh = paramiko.Transport((sshIp, 22))
                except Exception as e:
                    self.logger.debug(str(e))
                    invalidIP += "Incorrect IP address: %s.\n" % sshIp
                    boolInvalidIp = True
                    break
                isPasswdOK = self.verifyPasswd(ssh, pswd)
                if isPasswdOK:
                    break

            if boolInvalidIp:
                # unreachable nodes are collected and reported after the loop
                continue

            if not isPasswdOK and self.isKeyboardPassword:
                GaussLog.printMessage("Please enter password for current user[%s] on the node[%s]." %
                                      (self.user, sshIp))
                # Try entering the password 3 times interactively
                for _ in range(3):
                    KeyboardPassword = getpass.getpass()
                    DefaultValue.checkPasswordVaild(KeyboardPassword)
                    ssh = paramiko.Transport((sshIp, 22))
                    isPasswdOK = self.verifyPasswd(ssh, KeyboardPassword)
                    if isPasswdOK:
                        # remember it so later nodes try it as well
                        self.passwd.append(KeyboardPassword)
                        break
            # if isKeyboardPassword is true, 3 times after the password is also wrong to throw an unusual exit
            if not isPasswdOK:
                raise Exception(ErrorCode.GAUSS_503["GAUSS_50306"] % sshIp)

            try:
                # Any output from "cd" comes from the shell start-up files.
                check_channel = ssh.open_session()
                check_channel.exec_command("cd")
                check_result = check_channel.recv_stderr(9999).decode()
                stdoutData = b""
                while True:
                    chunk = check_channel.recv(9999)
                    if not chunk:
                        break
                    stdoutData += chunk
                check_result += stdoutData.decode()
                if check_result != "":
                    raise Exception(ErrorCode.GAUSS_518["GAUSS_51808"] % check_result +
                                    "Please check %s node /etc/profile or ~/.bashrc" % sshIp)

                channel = ssh.open_session()
                channel.exec_command("hostname")
                # concatenate every chunk; a single recv may return partial output
                hostnameData = b""
                while True:
                    chunk = channel.recv(9999)
                    if not chunk:
                        break
                    hostnameData += chunk
                hostname = hostnameData.decode().strip()
                if hostname:
                    serialResult[sshIp] = hostname
            finally:
                ssh.close()

        if invalidIP:
            raise Exception(ErrorCode.GAUSS_511["GAUSS_51101"] % invalidIP.rstrip("\n"))

        return serialResult

    def getAllHosts(self, sshIps):
        """
        function:
          Resolve the hostname of every node. If no password is known yet,
          one is prompted for. With exactly one password the nodes are
          queried in parallel; otherwise they are queried serially. If an
          interactively typed password fails in parallel mode, the serial
          path (which can prompt per node) is used instead.
        precondition:
          1.User's password is correct on each node
        postcondition:
           NA
        input: sshIps - list of node IPs
        output: Dictionary mapping IP -> hostname
        hideninfo:NA
        """
        self.writelog("Get hostnames for all nodes.")

        if len(self.passwd) == 0:
            self.isKeyboardPassword = True
            GaussLog.printMessage("Please enter password for current user[%s]." % self.user)
            self.passwd.append(getpass.getpass())

        if len(self.passwd) != 1:
            result = self.serialGetHosts(sshIps)
        else:
            try:
                result = self.parallelGetHosts(sshIps)
            except Exception as e:
                wrongPasswd = self.isKeyboardPassword and \
                    str(e).startswith("[GAUSS-50306] : The password of")
                if not wrongPasswd:
                    raise Exception(str(e))
                GaussLog.printMessage("Notice :The password of some nodes is incorrect.")
                result = self.serialGetHosts(sshIps)

        if self.logFile != "":
            if os.path.exists(tmp_files):
                self.logger.debug("Successfully get hostnames for all nodes.")
            else:
                self.logger.debug("Successfully get hostnames for all nodes.", "constant")
        return result

    def writeLocalHosts(self, result):
        """
        function:
         As root, rewrite the local /etc/hosts. Every line previously added
         by this tool (matching " #Gauss.* IP Hosts Mapping") is dropped,
         then one "<ip>  <hostname>  <flag>" line per entry of result is
         written. For non-root users only the log messages are emitted.
        precondition:
          NA
        postcondition:
           The temporary copy file is removed even if copying fails
        input: Dictionary result,key is IP and value is hostname
        output: NA
        raises: Exception (GAUSS-51221) when /etc/hosts does not exist
        hideninfo:NA
        """
        self.writelog("Write local hostname and Ip into /etc/hosts.")
        if os.getuid() == 0:
            # Check if /etc/hosts exists.
            if not os.path.exists("/etc/hosts"):
                raise Exception(ErrorCode.GAUSS_512["GAUSS_51221"] + " Error: \nThe /etc/hosts does not exist.")
            tmpHostIpName = "./tmp_hostsiphostname_%d" % os.getpid()
            # NOTE(review): the grep status is ignored (as before); confirm that
            # getGrepValue cannot hand back an error text as `output`.
            (_, output) = g_OSlib.getGrepValue("-v", " #Gauss.* IP Hosts Mapping", '/etc/hosts')
            try:
                g_file.createFile(tmpHostIpName)
                g_file.changeMode(DefaultValue.KEY_FILE_MODE, tmpHostIpName)
                g_file.writeFile(tmpHostIpName, [output])
                shutil.copyfile(tmpHostIpName, '/etc/hosts')
            finally:
                if os.path.exists(tmpHostIpName):
                    g_file.removeFile(tmpHostIpName)
            hostIPInfo = "\n".join('%s  %s  %s' % (key, value, HOSTS_MAPPING_FLAG)
                                   for (key, value) in result.items())
            # presumably g_file.writeFile appends (the filtered copy above must
            # survive) - TODO confirm
            g_file.writeFile("/etc/hosts", [hostIPInfo])
        if self.logFile != "":
            if not os.path.exists(tmp_files):
                self.logger.debug("Successfully write local hostname and Ip into /etc/hosts.", "constant")
            else:
                self.logger.debug("Successfully write local hostname and Ip into /etc/hosts.")

    def writeRemoteHostName(self, ip):
        """
        function:
         On one remote node, drop the lines previously added by this tool
         from /etc/hosts, then append the module-level ipHostInfo mapping in
         chunks of at most IP_NUM lines. Run in parallel by
         writeRemoteHosts().
        precondition:
          ipHostInfo has been filled in by writeRemoteHosts()
        postcondition:
           The SSH transport is always closed
        input: ip - address or resolvable hostname of the node
        output: (True, {ip: []}) on success,
                (False, {ip: [error texts]}) if any remote command wrote to stderr
        raises: Exception GAUSS-51107 / GAUSS-50317 when the node cannot be
                reached or logged in to
        hideninfo:NA
        """
        writeResult = []
        tmpHostIpName = "./tmp_hostsiphostname_%d" % os.getpid()
        username = pwd.getpwuid(os.getuid()).pw_name
        try:
            ssh = paramiko.Transport((ip, 22))
        except Exception as e:
            raise Exception(ErrorCode.GAUSS_511["GAUSS_51107"] + " Error: \n%s" % str(e))
        try:
            try:
                ssh.connect(username=username, password=self.passwd[0])
            except Exception as e:
                raise Exception(ErrorCode.GAUSS_503["GAUSS_50317"] + " Error: \n%s" % str(e))
            cmd = "grep -v '%s' %s > %s && cp %s %s && rm -rf %s" % \
                  (" #Gauss.* IP Hosts Mapping", '/etc/hosts',
                   tmpHostIpName, tmpHostIpName, '/etc/hosts', tmpHostIpName)
            channel = ssh.open_session()
            channel.exec_command(cmd)
            ipHosts = channel.recv(9999).decode().strip()
            errInfo = channel.recv_stderr(9999).decode().strip()
            if errInfo:
                writeResult.append(errInfo)
            elif not ipHosts:
                tmp_list = ipHostInfo.split("\n")
                # one echo per IP_NUM lines keeps each command line bounded
                for index in range(0, len(tmp_list), IP_NUM):
                    cmd = "echo '%s' >> /etc/hosts" % ("\n".join(tmp_list[index:index + IP_NUM]))
                    channel = ssh.open_session()
                    channel.exec_command(cmd)
                    # decode so error reports contain text, not bytes reprs
                    errInfo = channel.recv_stderr(9999).decode().strip()
                    if errInfo:
                        writeResult.append(errInfo)
        finally:
            try:
                ssh.close()
            except Exception:
                pass
        return (len(writeResult) == 0, {ip: writeResult})

    def writeRemoteHosts(self, result, username, rootPasswd):
        """
        function:
         As root, rewrite /etc/hosts on every remote node. With a single
         password the nodes are handled in parallel by writeRemoteHostName()
         (addressed by hostname). Otherwise each node is handled serially
         (addressed by IP), trying every password in turn.
        precondition:
          NA
        postcondition:
           Every SSH transport opened here is closed
        input: Dictionary result,key is IP and value is hostname
               username - user to log in as (serial path)
               rootPasswd - list of candidate passwords
        output: NA
        raises: Exception (GAUSS-51221) with the collected per-node errors
        hideninfo:NA
        """
        self.writelog("Write remote hostname and Ip into /etc/hosts.")
        global ipHostInfo
        ipHostInfo = ""
        if os.getuid() == 0:
            mappingInfo = "\n".join('%s  %s  %s' % (key, value, HOSTS_MAPPING_FLAG)
                                    for (key, value) in result.items())
            if len(rootPasswd) == 1:
                # shared with writeRemoteHostName through the module global
                ipHostInfo = mappingInfo
                # one target per remote hostname (first IP seen wins)
                result1 = {}
                for (key, value) in result.items():
                    if value != self.localHost and value not in result1:
                        result1[value] = key
                sshIps = result1.keys()
                if sshIps:
                    ipRemoteHostname = parallelTool.parallelExecute(self.writeRemoteHostName, sshIps)
                    errorMsg = ""
                    for (key, value) in ipRemoteHostname:
                        if not key:
                            errorMsg = errorMsg + '\n' + str(value)
                    if errorMsg != "":
                        raise Exception(ErrorCode.GAUSS_512["GAUSS_51221"] + " Error: %s" % errorMsg)
            else:
                writeResult = []
                tmpHostIpName = "./tmp_hostsiphostname_%d" % os.getpid()
                cleanCmd = "grep -v '%s' %s > %s && cp %s %s && rm -rf %s" % \
                           (" #Gauss.* IP Hosts Mapping", '/etc/hosts',
                            tmpHostIpName, tmpHostIpName, '/etc/hosts', tmpHostIpName)
                tmp_list = mappingInfo.split("\n")
                for (key, value) in result.items():
                    if value == self.localHost:
                        continue
                    ssh = None
                    connected = False
                    boolInvalidIp = False
                    for pswd in rootPasswd:
                        try:
                            ssh = paramiko.Transport((key, 22))
                        except Exception as e:
                            self.logger.debug(str(e))
                            boolInvalidIp = True
                            break
                        try:
                            ssh.connect(username=username, password=pswd)
                            connected = True
                            break
                        except Exception as e:
                            self.logger.debug(str(e))
                            # release the failed transport before the next try
                            ssh.close()
                    if boolInvalidIp:
                        # unreachable nodes are skipped, as before
                        continue
                    if not connected:
                        # no password worked: record it instead of using an
                        # unauthenticated transport
                        writeResult.append(ErrorCode.GAUSS_503["GAUSS_50306"] % key)
                        continue
                    try:
                        channel = ssh.open_session()
                        channel.exec_command(cleanCmd)
                        ipHosts = channel.recv(9999).decode().strip()
                        errInfo = channel.recv_stderr(9999).decode().strip()
                        if errInfo:
                            writeResult.append(errInfo)
                        elif not ipHosts:
                            ipHostInfo = mappingInfo
                            # one echo per IP_NUM lines keeps each command bounded
                            for index in range(0, len(tmp_list), IP_NUM):
                                cmd = "echo '%s' >> /etc/hosts" % ("\n".join(tmp_list[index:index + IP_NUM]))
                                channel = ssh.open_session()
                                channel.exec_command(cmd)
                                errInfo = channel.recv_stderr(9999).decode().strip()
                                if errInfo:
                                    writeResult.append(errInfo)
                    finally:
                        try:
                            ssh.close()
                        except Exception:
                            pass

                if writeResult:
                    raise Exception(ErrorCode.GAUSS_512["GAUSS_51221"] + " Error: \n%s" % writeResult)
        self.writelog("Successfully write remote hostname and Ip into /etc/hosts.")

    def writelog(self, info=""):
        """
        function: Send info to the debug log when a log file is in use.
                  The "addStep" tag is added only while tmp_files does not
                  exist; otherwise nothing happens without a log file.
        input : info - message text
        output: NA
        """
        if self.logFile == "":
            return
        if os.path.exists(tmp_files):
            self.logger.debug(info)
        else:
            self.logger.debug(info, "addStep")

    def checkNetworkInfo(self):
        """
        function: Run DefaultValue.checkIsPing over self.hostList and exit
                  (via logger.logExit) if it reports any IPs or raises.
        input : NA
        output: NA
        """
        def stepLog(msg, tag):
            # step tags are used only with a log file and no tmp_files marker
            if self.logFile != "" and not os.path.exists(tmp_files):
                self.logger.log(msg, tag)
            else:
                self.logger.log(msg)

        stepLog("Checking network information.", "addStep")
        try:
            netWorkList = DefaultValue.checkIsPing(self.hostList)
            if netWorkList:
                self.logger.logExit(ErrorCode.GAUSS_506["GAUSS_50600"] + "The IP list is:%s." % netWorkList)
            else:
                self.logger.log("All nodes in the network are Normal.")
        except Exception as e:
            self.logger.logExit(str(e))
        stepLog("Successfully checked network information.", "constant")

    def run(self):
        """
        function: Create SSH trust for the current user among all hosts in
                  the host file. Options are parsed and validated, hostnames
                  are collected, /etc/hosts is synchronised unless skipped,
                  and then keys are exchanged and verified.
        input : NA
        output: NA
        """
        self.parseCommandLine()
        self.checkParameter()
        self.localHost = socket.gethostname()

        if self.logFile == "":
            self.logger = PrintOnScreen()
        else:
            self.initLogger(functionName)

        global tmp_files
        tmp_files = "/tmp/%s" % TMP_TRUST_FILE
        if self.logFile != "" and not os.path.exists(tmp_files):
            totalSteps = ClusterCommand.countTotalSteps("gs_sshexkey", "", self.skipHostnameSet)
            self.logger.debug("gs_sshexkey execution takes %s steps in total" % totalSteps)

        # work on a copy of the host list
        result = self.getAllHosts(list(self.hostList))
        self.checkNetworkInfo()

        if not self.skipHostnameSet:
            self.writeLocalHosts(result)
            self.writeRemoteHosts(result, self.user, self.passwd)

        self.logger.log("Creating SSH trust.")
        try:
            self.localID = self.createPublicPrivateKeyFile()
            self.addLocalAuthorized()
            self.updateKnow_hostsFile(result)
            self.addRemoteAuthorization()
            self.determinePublicAuthorityFile()
            self.synchronizationLicenseFile()
            self.verifyTrust()
            self.logger.log("Successfully created SSH trust.")
        except Exception as e:
            self.logger.logExit(str(e))

    def createPublicPrivateKeyFile(self):
        """
        function: Recreate the local RSA key pair with an empty passphrase.
                  The whole ~/.ssh directory is removed first, so existing
                  keys, authorized_keys and known_hosts are discarded.
        input : NA
        output: the first line of the new public key file, stripped
        raises: Exception (GAUSS-51108) if key generation fails or the public
                key file cannot be read
        """
        def stepLog(msg, tag):
            # step tags are used only with a log file and no tmp_files marker
            if self.logFile != "" and not os.path.exists(tmp_files):
                self.logger.log(msg, tag)
            else:
                self.logger.log(msg)

        stepLog("Creating the local key file.", "addStep")

        if os.path.exists(self.sshDir):
            g_file.removeDirectory(self.sshDir)
        # Use the resolved path rather than "~" (which the shell expands from
        # $HOME), so the key is created where chmod and the read below look.
        cmd = 'ssh-keygen -t rsa -N "" -f %s < /dev/null' % self.id_rsa_fname
        cmd += " && chmod %s %s %s" % (DefaultValue.KEY_FILE_MODE, self.id_rsa_fname, self.id_rsa_pub_fname)
        (status, output) = subprocess.getstatusoutput(cmd)
        if status != 0:
            raise Exception(ErrorCode.GAUSS_511["GAUSS_51108"] + " Error:\n%s" % output)
        try:
            with open(self.id_rsa_pub_fname, 'r') as f:
                localID = f.readline().strip()
        except IOError as e:
            self.logger.debug(str(e))
            raise Exception(ErrorCode.GAUSS_511["GAUSS_51108"] +
                            " Unable to read the generated file." + self.id_rsa_pub_fname)
        # report success only once the key was actually read
        stepLog("Successfully created the local key files.", "constant")
        return localID

    def addLocalAuthorized(self):
        """
        function: Append the local public key (self.localID) to
                  authorized_keys unless an identical line is already present,
                  then set the file mode to DefaultValue.KEY_FILE_MODE.
        input : NA
        output: NA
        """
        def stepLog(msg, tag):
            # step tags are used only with a log file and no tmp_files marker
            if self.logFile != "" and not os.path.exists(tmp_files):
                self.logger.log(msg, tag)
            else:
                self.logger.log(msg)

        stepLog("Appending local ID to authorized_keys.", "addStep")
        with open(self.authorized_keys_fname, 'a+') as f:
            # 'a+' opens positioned at end of file: rewind, otherwise the
            # existing keys are never read and duplicates get appended.
            f.seek(0)
            content = f.read()
            existingKeys = [line.strip() for line in content.split("\n")]
            if self.localID not in existingKeys:
                # keep the new key on its own line if the file lacks a final newline
                if content and not content.endswith("\n"):
                    f.write("\n")
                f.write(self.localID)
                f.write("\n")
        g_file.changeMode(DefaultValue.KEY_FILE_MODE, self.authorized_keys_fname)
        stepLog("Successfully appended local ID to authorized_keys.", "constant")

    def checkAuthentication(self, hostname):
        """
        function: Check that "ssh <hostname> true" succeeds with the current
                  SSH setup (DefaultValue.SSH_OPTION), i.e. password-less access.
        input : hostname
        output: (True, hostname) when the command exits with 0,
                (False, hostname) otherwise (details go to the debug log)
        """
        sshCmd = 'export LD_LIBRARY_PATH=/lib64:$LD_LIBRARY_PATH; ssh -n %s %s true' % (DefaultValue.SSH_OPTION, hostname)
        (exitCode, cmdOutput) = subprocess.getstatusoutput(sshCmd)
        passed = exitCode == 0
        if not passed:
            self.logger.debug("Failed to check authentication. Hostname:%s. Error: \n%s" % (hostname, cmdOutput))
        return (passed, hostname)

    def updateKnow_hostsFile(self, result):
        """
        function: Key-scan every node by IP and by hostname into
                  known_hosts, then check password-less ssh to the local host.
        input : result - dict mapping IP -> hostname
        output: NA
        raises: Exception (GAUSS-51100) when ssh to the local host fails
        """
        def stepLog(msg, tag):
            # step tags are used only with a log file and no tmp_files marker
            if self.logFile != "" and not os.path.exists(tmp_files):
                self.logger.log(msg, tag)
            else:
                self.logger.log(msg)

        stepLog("Updating the known_hosts file.", "addStep")
        # the host-file IPs followed by their resolved hostnames
        hostnameList = list(self.hostList) + list(result.values())
        # obtaining the ssh public key file and chmod
        parallelTool.parallelExecute(self.obtainSshPubKeyFile, hostnameList)
        (localOK, _) = self.checkAuthentication(self.localHost)
        if not localOK:
            raise Exception(ErrorCode.GAUSS_511["GAUSS_51100"] % self.localHost)
        stepLog("Successfully updated the known_hosts file.", "constant")

    def obtainSshPubKeyFile(self, hostname):
        """
        function: Append the RSA host key of hostname (via ssh-keyscan) to
                  known_hosts and set the file mode to KEY_FILE_MODE.
        input : hostname
        output: NA
        raises: Exception (GAUSS-51400) when the shell command fails
        """
        envPart = 'export LD_LIBRARY_PATH=/lib64:$LD_LIBRARY_PATH; '
        scanPart = 'ssh-keyscan -t rsa %s >> %s ' % (hostname, self.known_hosts_fname)
        chmodPart = "&& chmod %s %s" % (DefaultValue.KEY_FILE_MODE, self.known_hosts_fname)
        cmd = envPart + scanPart + chmodPart
        (exitCode, cmdOutput) = subprocess.getstatusoutput(cmd)
        if exitCode != 0:
            raise Exception(ErrorCode.GAUSS_514["GAUSS_51400"] % cmd + " Error:\n%s" % cmdOutput)

    def tryParamikoConnect(self, hostname, client, pswd=None, silence=False):
        """
        function: Try to log in to hostname with a paramiko SSHClient using
                  password authentication only (no agent, no key files).
                  Unknown host keys are accepted automatically.
        input : hostname - target host
                client - paramiko.SSHClient to connect
                pswd - password to use
                silence - when True, a failure is not written to the debug log
        output: True on success; False on failure (the client is closed)
        """
        try:
            client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
            client.connect(hostname, password=pswd, allow_agent=False, look_for_keys=False)
            return True
        except Exception as e:
            # `e` must be bound here: it is used in the debug message
            if not silence:
                self.logger.debug('[SSHException %s] %s' % (hostname, str(e)))
            client.close()
            return False

    def addRemoteAuthorization(self):
        """
        function: send the local public key to every host in self.hostList
                  (in parallel, via sendRemoteAuthorization) and append it to
                  that host's authorized_keys
        input : NA
        output: NA; calls self.logger.logExit when any host recorded a
                password or append failure, or when the parallel run raised
        """
        def _report(msg, stepTag):
            # The step tag is only passed when logging to a file and the
            # preinstall step file is absent; it is re-checked on every call.
            if self.logFile != "" and not os.path.exists(tmp_files):
                self.logger.log(msg, stepTag)
            else:
                self.logger.log(msg)

        _report("Appending authorized_key on the remote node.", "addStep")
        try:
            parallelTool.parallelExecute(self.sendRemoteAuthorization, self.hostList)
            # sendRemoteAuthorization accumulates per-host failures in these
            # strings rather than raising.
            if self.incorrectPasswdInfo != "":
                passwdErrors = self.incorrectPasswdInfo.rstrip("\n")
                self.logger.logExit(ErrorCode.GAUSS_511["GAUSS_51101"] % passwdErrors)
            if self.failedToAppendInfo != "":
                appendErrors = self.failedToAppendInfo.rstrip("\n")
                self.logger.logExit(ErrorCode.GAUSS_511["GAUSS_51101"] % appendErrors)
        except Exception as e:
            self.logger.logExit(ErrorCode.GAUSS_511["GAUSS_51111"] + " Error:%s." % str(e))
        _report("Successfully appended authorized_key on all remote node.", "constant")

    def sendRemoteAuthorization(self, hostname):
        """
        function: append the local public key (self.localID) to
                  .ssh/authorized_keys on a remote host over a paramiko session
        input : hostname - remote host; the local host is skipped
        output: NA. Failures are not raised: a host where no password in
                self.passwd connects is recorded in self.incorrectPasswdInfo,
                and a host where the append could not be confirmed is recorded
                in self.failedToAppendInfo.
        """
        if hostname == self.localHost:
            return
        p = None
        cin = cout = cerr = None
        try:
            # ssh Remote Connection other node: try each candidate password
            p = paramiko.SSHClient()
            ok = self.tryParamikoConnect(hostname, p, self.passwd[0], silence=True)
            if not ok:
                for pswd in self.passwd[1:]:
                    ok = self.tryParamikoConnect(hostname, p, pswd, silence=True)
                    if ok:
                        break
            if not ok:
                # NOTE(review): this method runs in parallel threads; the
                # shared '+=' on these strings is not synchronized.
                self.incorrectPasswdInfo += "Without this node[%s] of the correct password.\n" % hostname
                return
            # Create .ssh directory and ensure content meets permission requirements
            # for password-less SSH
            cmd = ('mkdir -p .ssh; ' + "chown -R %s:%s %s; " %
                   (self.user, self.group, self.sshDir) + 'chmod %s .ssh; ' % DefaultValue.KEY_DIRECTORY_MODE +
                   'touch .ssh/authorized_keys; ' + 'touch .ssh/known_hosts; ' +
                   'chmod %s .ssh/auth* .ssh/id* .ssh/known_hosts; ' % DefaultValue.KEY_FILE_MODE)
            (cin, cout, cerr) = p.exec_command(cmd)
            cin.close()
            # exec_command returns immediately; wait for the setup command to
            # exit so .ssh exists before the append below runs. Its exit
            # status is deliberately ignored (e.g. no .ssh/id* files yet).
            cout.channel.recv_exit_status()
            cout.close()
            cerr.close()

            # Append the ID to authorized_keys: one attempt plus up to three
            # retries, sleeping 0, 2 and 4 seconds before the retries.
            cmd = 'echo \"%s\" >> .ssh/authorized_keys && echo ok ok ok' % self.localID
            line = ""
            for attempt in range(4):
                if attempt:
                    time.sleep((attempt - 1) * 2)
                    # release the previous attempt's streams before reuse
                    cout.close()
                    cerr.close()
                (cin, cout, cerr) = p.exec_command(cmd)
                cin.close()
                # read() on every attempt: readline could return other output
                # that precedes the "ok ok ok" marker.
                line = cout.read().decode()
                if line.find("ok ok ok") >= 0:
                    break

            if line.find("ok ok ok") < 0:
                self.failedToAppendInfo += "...send to %s\nFailed to append local ID to " \
                                           "authorized_keys on remote node %s.\n" % \
                                           (hostname, hostname)
                return
            self.logger.debug("Send to %s\nSuccessfully appended authorized_key "
                              "on remote node %s." % (hostname, hostname))
        finally:
            # closing an already-closed stream is harmless
            if cin:
                cin.close()
            if cout:
                cout.close()
            if cerr:
                cerr.close()
            if p:
                p.close()

    def determinePublicAuthorityFile(self):
        '''
        function: de-duplicate the local known_hosts and authorized_keys files
                  by parsing each into a table and rewriting it from that table
        input : NA
        output: NA; an IOError on either file is reported via
                self.logger.logExit
        '''
        def _report(msg, stepTag):
            # The step tag is only passed when logging to a file and the
            # preinstall step file is absent; it is re-checked on every call.
            if self.logFile != "" and not os.path.exists(tmp_files):
                self.logger.log(msg, stepTag)
            else:
                self.logger.log(msg)

        _report("Checking common authentication file content.", "addStep")
        # eliminate duplicates in known_hosts first, then in authorized_keys
        cleanups = (
            (self.readKnownHosts, self.writeKnownHosts, "known hosts file"),
            (self.readAuthorizedKeys, self.writeAuthorizedKeys, "authorized keys file"),
        )
        for readTable, writeTable, fileLabel in cleanups:
            try:
                writeTable(readTable())
            except IOError as e:
                self.logger.logExit(ErrorCode.GAUSS_502["GAUSS_50230"] % fileLabel + " Error:\n%s" % str(e))
        _report("Successfully checked common authentication content.", "constant")

    def addRemoteID(self, tab, line):
        """
        function: record one authorized_keys line in tab, keyed by its third
                  whitespace-separated field
        input : tab  - dict updated in place
                line - raw line from an authorized_keys file
        output: True if the line was stored, False if it was skipped
        """
        # NOTE(review): only lines with exactly 3 fields that do not start with
        # '#' are kept. Keys with no comment, with an options prefix, or with a
        # multi-word comment are skipped, and are therefore lost when the file
        # is rewritten from this table. Confirm this is intended.
        fields = line.strip().split()
        if len(fields) != 3 or line.startswith('#'):
            return False
        tab[fields[2]] = line
        return True

    def readAuthorizedKeys(self, tab=None, keysFile=None):
        """
        function: parse an authorized_keys file into a table via addRemoteID
        input : tab      - dict to fill in place; a new dict is created only
                           when None is passed
                keysFile - file to read; defaults to self.authorized_keys_fname
        output: tab, mapping each line accepted by addRemoteID to the raw line
        """
        if not keysFile:
            keysFile = self.authorized_keys_fname
        # Test for None rather than falsiness: an empty dict supplied by the
        # caller must be populated in place, not silently replaced.
        if tab is None:
            tab = {}
        with open(keysFile, 'r') as f:
            for line in f:
                self.addRemoteID(tab, line)
        return tab

    def writeAuthorizedKeys(self, tab, keysFile=None):
        """
        function: overwrite an authorized_keys file with the values of tab
        input : tab      - dict whose values are written verbatim, in
                           iteration order (lines read by readAuthorizedKeys
                           keep their trailing newline)
                keysFile - destination; defaults to self.authorized_keys_fname
        output: NA
        """
        target = keysFile if keysFile else self.authorized_keys_fname
        with open(target, 'w') as fh:
            fh.writelines(tab[key] for key in tab)

    def addKnownHost(self, tab, line):
        """
        function: record one known_hosts line in tab, keyed by its first
                  whitespace-separated field (the host entry)
        input : tab  - dict updated in place
                line - raw line from a known_hosts file
        output: True if the line was stored, False if it was skipped
        """
        # NOTE(review): only 3-field, non-'#' lines are kept, and a later line
        # for the same host replaces an earlier one. Entries with markers or
        # extra fields are therefore dropped when the file is rewritten.
        fields = line.strip().split()
        if len(fields) != 3 or line.startswith('#'):
            return False
        tab[fields[0]] = line
        return True

    def readKnownHosts(self, tab=None, hostsFile=None):
        """
        function: parse a known_hosts file into a table via addKnownHost
        input : tab       - dict to fill in place; a new dict is created only
                            when None is passed
                hostsFile - file to read; defaults to self.known_hosts_fname
        output: tab, mapping each line accepted by addKnownHost to the raw line
        """
        if not hostsFile:
            hostsFile = self.known_hosts_fname
        # Test for None rather than falsiness: an empty dict supplied by the
        # caller must be populated in place, not silently replaced.
        if tab is None:
            tab = {}
        with open(hostsFile, 'r') as f:
            for line in f:
                self.addKnownHost(tab, line)
        return tab

    def writeKnownHosts(self, tab, hostsFile=None):
        """
        function: overwrite a known_hosts file with the values of tab
        input : tab       - dict whose values are written verbatim, in
                            iteration order (lines read by readKnownHosts
                            keep their trailing newline)
                hostsFile - destination; defaults to self.known_hosts_fname
        output: NA
        """
        target = hostsFile if hostsFile else self.known_hosts_fname
        with open(target, 'w') as fh:
            fh.writelines(tab[host] for host in tab)

    def sendTrustFile(self, hostname):
        '''
        function: scp the local authorized_keys, known_hosts and RSA key pair
                  files into .ssh/ on hostname, non-interactively
        input : hostname - destination host
        output: NA; raises Exception with scp's output when scp fails
        '''
        # For IPv6, scp mis-parses a raw address (it contains many ':'), so
        # convert the host to scp's accepted form first.
        target = g_network.makeSCPHost(hostname)
        trustFiles = "%s %s %s %s" % (self.authorized_keys_fname, self.known_hosts_fname,
                                      self.id_rsa_fname, self.id_rsa_pub_fname)
        cmd = ('export LD_LIBRARY_PATH=/lib64:$LD_LIBRARY_PATH; '
               'scp -q -o "BatchMode yes" -o "NumberOfPasswordPrompts 0" '
               '%s %s:.ssh/' % (trustFiles, target))
        status, output = subprocess.getstatusoutput(cmd)
        if status:
            raise Exception(ErrorCode.GAUSS_502["GAUSS_50223"] % "the authentication" +
                            " Node:%s. Error:\n%s" % (target, output))

    def synchronizationLicenseFile(self):
        '''
        function: run sendTrustFile for every host in self.hostList in
                  parallel, distributing the SSH trust files
        input : NA
        output: NA; any exception from the parallel run is reported via
                self.logger.logExit
        '''
        def _report(msg, stepTag):
            # The step tag is only passed when logging to a file and the
            # preinstall step file is absent; it is re-checked on every call.
            if self.logFile != "" and not os.path.exists(tmp_files):
                self.logger.log(msg, stepTag)
            else:
                self.logger.log(msg)

        _report("Distributing SSH trust file to all node.", "addStep")
        try:
            parallelTool.parallelExecute(self.sendTrustFile, self.hostList)
        except Exception as e:
            self.logger.logExit(str(e))
        _report("Successfully distributed SSH trust file to all node.", "constant")

    def verifyTrust(self):
        """
        function: run checkAuthentication on every host in self.hostList in
                  parallel and report the hosts whose check failed
        input : NA
        output: NA; failures (or any exception) are reported via
                self.logger.logExit
        """
        def _report(msg, stepTag):
            # The step tag is only passed when logging to a file and the
            # preinstall step file is absent; it is re-checked on every call.
            if self.logFile != "" and not os.path.exists(tmp_files):
                self.logger.log(msg, stepTag)
            else:
                self.logger.log(msg)

        _report("Verifying SSH trust on all hosts.", "addStep")
        try:
            results = parallelTool.parallelExecute(self.checkAuthentication, self.hostList)
            # each result is a (succeeded, value) pair; collect the values of
            # failed checks as a comma-joined list
            badHosts = ""
            for (succeeded, value) in results:
                if not succeeded:
                    badHosts += ',' + value
            if badHosts:
                raise Exception(ErrorCode.GAUSS_511["GAUSS_51100"] % badHosts.lstrip(','))
        except Exception as e:
            self.logger.logExit(str(e))
        _report("Successfully verified SSH trust on all hosts.", "constant")

    def getUserPasswd(self):
        """
        function: read the current user's password, prompting when stdin is a
                  terminal and otherwise reading the first line of stdin
        input: NA
        output: list holding the single password string that was read
        """
        user_passwd = []
        if sys.stdin.isatty():
            # interactive: prompt without echoing the password
            GaussLog.printMessage("Please enter password for current user[%s]." % self.user)
            user_passwd.append(getpass.getpass())
        else:
            # non-interactive (e.g. piped input): take one line, minus newline
            user_passwd.append(sys.stdin.readline().strip('\n'))

        # The list always holds exactly one entry, so the password itself must
        # be tested; checking the list never detected an empty password.
        if not user_passwd[0]:
            GaussLog.exitWithError("Password should not be empty")

        return user_passwd


if __name__ == '__main__':
    # Script entry point: build the trust creator and run it. Any exception is
    # passed to GaussLog.exitWithError; otherwise the process exits with 0.
    try:
        GaussCreateTrust().run()
    except Exception as err:
        GaussLog.exitWithError(str(err))
    sys.exit(0)
