#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# ############################################################################
# Copyright (c): 2020, Huawei Tech. Co., Ltd.
# FileName     : pssh
# Version      : Gauss
# Date         : 2019-12-31
# Description  : Parallel ssh to the set of nodes given with -H or listed in
#                 a host file (-h). For each node, this essentially does an
#                 "ssh host command". Optionally (-o), the output from each
#                 remote node is saved in a directory; each output file in
#                 that directory is named by the corresponding remote
#                 node's hostname or IP address.
# ############################################################################
import optparse
import sys
import shlex
import os
import stat

try:
    from TaskPool import TaskPool
    from TaskPool import read_host_file, get_om_conf
except ImportError as ie:
    sys.exit("[GAUSS-52200] : Unable to import module: %s." % str(ie))

# Default for the -t option (installed by check_parse()). How 0 is
# interpreted is up to TaskPool -- presumably "no timeout"; TODO confirm.
TIME_OUT = 0
# Default for the -p option (installed by check_parse()): maximum
# parallelism handed to TaskPool.
PARALLEL_NUM = 32
# Transport selected in run(): "SSH_TRUST" wraps each command with ssh via
# get_ssh_cmd(); any other value hands the raw command to TaskPool as an
# "rpc_cmd" task. main() overrides it from the OM config key
# common/parallel_execute_command_channel when that key is set.
PARALLEL_EXECUTE_COMMAND_CHANNEL = "SSH_TRUST"


def parse_command():
    """
    Build the command-line parser for pssh.

    Option parsing stops at the first positional argument, so everything
    from the remote command onward is returned untouched in ``args``.
    conflict_handler='resolve' lets ``-h`` mean "host file" instead of help
    (help stays reachable as ``--help``).

    :return: configured optparse.OptionParser (not yet parsed)
    """
    # (flag, keyword arguments for add_option), in help-listing order.
    option_table = (
        ('-H', {'dest': 'hostname', 'action': 'append',
                'help': 'Nodes to be connected'}),
        ('-h', {'dest': 'hostfile',
                'help': 'Host file with each line per node'}),
        ('-t', {'dest': 'timeout', 'type': 'int',
                'help': 'Timeouts in seconds'}),
        ('-p', {'dest': 'parallel', 'type': 'int',
                'help': 'Maximum number of parallel'}),
        ('-o', {'dest': 'outdir', 'help': 'Output results folder'}),
        ('-e', {'dest': 'errdir', 'help': 'Error results folder'}),
        ('-P', {'dest': 'print', 'action': 'store_true',
                'help': 'Print output'}),
        ('-s', {'dest': 'shellmode', 'action': 'store_true',
                'help': 'Output only execution results'}),
        ('-x', {'dest': 'extra',
                'help': 'Extra command-line arguments'}),
        ('-i', {'dest': 'inline', 'action': 'store_true',
                'help': 'aggregated output and error for each server'}),
        ('-O', {'dest': 'opt', 'action': 'append',
                'help': 'Additional ssh parameters'}),
        ('-r', {'dest': 'region', 'help': 'chroot or Ruby'}),
    )

    cli = optparse.OptionParser(
        usage="%prog [OPTIONS] command",
        epilog="Example: pssh -H hostname 'id'",
        conflict_handler='resolve',
    )
    cli.disable_interspersed_args()
    for flag, settings in option_table:
        cli.add_option(flag, **settings)
    return cli


def check_parse(parser_info):
    """
    Parse sys.argv and validate the mandatory arguments.

    Installs PARALLEL_NUM / TIME_OUT as defaults for -p / -t before parsing.

    :param parser_info: optparse.OptionParser built by parse_command()
    :return: tuple (opts_info, args_info) -- the parsed option values and the
             remaining positional arguments (the remote command tokens)
    Exits through parser_info.error() (usage + message, status 2) when no
    remote command or no host (-H / -h) was given.
    """
    # set defaults parallel and timeout value
    defaults = dict(parallel=PARALLEL_NUM, timeout=TIME_OUT)
    parser_info.set_defaults(**defaults)
    opts_info, args_info = parser_info.parse_args()

    # parse_args() always returns a truthy Values object, so the presence of
    # the remote command has to be checked on the positional arguments.
    # parser_info.error() exits, so nothing after it in a branch would run.
    if not args_info:
        parser_info.error("The commands is request.")
    if not opts_info.hostname and not opts_info.hostfile:
        parser_info.error("The host info is request.")

    return opts_info, args_info


def log_error(err_path, err_code, err_msg):
    """
    Append one error record to <err_path>/pssh_err.log (best effort).

    :param err_path: directory that holds the log file (the -e option)
    :param err_code: exit code being reported
    :param err_msg: human-readable description of the failure
    :return: None. Any failure to write the record is swallowed on purpose
             so that callers can always go on to exit with their own code.
    """
    # One record per line; without the newline consecutive records from
    # repeated runs were concatenated onto a single line.
    log_msg = f"pssh error, error code: {err_code}, error message: {err_msg}\n"
    try:
        err_file_path = os.path.join(err_path, 'pssh_err.log')
        flags = os.O_WRONLY | os.O_CREAT | os.O_APPEND  # append, create if missing
        # rw-r----- (0o640) for a newly created file, further limited by umask.
        mode = stat.S_IWUSR | stat.S_IRUSR | stat.S_IRGRP
        with os.fdopen(os.open(err_file_path, flags, mode), 'a',
                       encoding='utf-8') as err_file:
            err_file.write(log_msg)
    except Exception:
        # Deliberately best-effort: logging must never replace the real error.
        return


def _abort_run(opts, exit_code, message):
    """Log message to the -e error directory (if set), then exit with exit_code."""
    if opts.errdir:
        log_error(opts.errdir, exit_code, message)
    sys.exit(exit_code)


def run(hosts, opts, args):
    """
    function: run the command on every host through a TaskPool and turn the
              per-host statuses into the process exit code
    input : hosts - host names / IP addresses to target
            opts  - parsed command-line options (handed to TaskPool;
                    errdir and shellmode are also read here)
            args  - remote command tokens
    output: NA. Returns normally only when every status is 0; otherwise
            exits via sys.exit() with, in order of precedence:
              - the single status, when exactly one status is returned
              - 3 when some status is negative (a process was killed)
              - 4 when some status is 255 (ssh connection error) and
                shellmode is off
              - 5 when some status is non-zero and shellmode is off
              - the first non-zero status when shellmode is on
              - 1 when an exception is raised while running/evaluating
    """
    manager = TaskPool(opts)
    # The channel cannot change during the loop, so decide it once.
    use_ssh = PARALLEL_EXECUTE_COMMAND_CHANNEL == "SSH_TRUST"
    for host in hosts:
        if use_ssh:
            manager.add_task(host, get_ssh_cmd(host, opts, args))
        else:
            manager.add_task(host, args, "rpc_cmd")
    try:
        # Materialise once: the statuses are scanned several times below and
        # a one-shot iterator would already be exhausted after the first scan.
        statuses = list(manager.start())

        if len(statuses) == 1:
            sys.exit(statuses[0])

        # sys.exit() raises SystemExit, which the handler below (Exception)
        # deliberately does not catch.
        if min(statuses) < 0:
            _abort_run(opts, 3,
                       "pssh error: At least one ssh process was killed.")
        if not opts.shellmode and 255 in statuses:
            _abort_run(opts, 4,
                       "pssh error: An error occurred with an ssh connection.")
        for status in statuses:
            if status == 0:
                continue
            if not opts.shellmode:
                _abort_run(opts, 5,
                           "pssh error: An error occurred "
                           "when the command was executed and shellmode is False.")
            _abort_run(opts, status,
                       "pssh error: An error occurred when the command was executed.")

    except Exception as ex:
        exit_code = 1
        if opts.errdir:
            log_error(opts.errdir, exit_code, f"pssh error: Exception: {str(ex)}")
        print(str(ex))
        sys.exit(exit_code)


def _ssh_option_name(option):
    """
    Return the lower-cased keyword of an ssh "-o" option string.

    Accepts both "Key=Value" and "Key Value" spellings; ssh matches keywords
    case-insensitively, so callers compare the lower-cased form.
    """
    parts = option.replace('=', ' ', 1).split(None, 1)
    return parts[0].lower() if parts else ''


def get_ssh_cmd(host, opts, args):
    """
    Build the ssh argv list that runs the remote command on one host.

    :param host: target host name or IP address
    :param opts: parsed options; reads opts.extra (-x, split with shlex and
                 inserted verbatim) and opts.opt (-O, list of ssh options)
    :param args: remote command tokens, appended last
    :return: list of strings starting with "ssh"
    Default connection options are appended only for keywords the user did
    not already supply via -O.
    """
    def_args = {"ConnectionAttempts": "10",
                "ConnectTimeout": "30",
                "ServerAliveCountMax": "10",
                "ServerAliveInterval": "30", }

    cmd = ["ssh", host, '-n',
           "-o", "BatchMode=yes",
           "-o", "NumberOfPasswordPrompts=1",
           "-o", "TCPKeepAlive=yes"]

    if opts.extra:
        cmd.extend(shlex.split(opts.extra))

    # Collect the exact keywords given with -O. A substring search over the
    # joined values was case-sensitive and also matched keywords that merely
    # appeared inside another option's value (e.g. a ProxyCommand).
    user_keys = set()
    for option in opts.opt or ():
        cmd.extend(("-o", option))
        user_keys.add(_ssh_option_name(option))

    for key, value in def_args.items():
        if key.lower() not in user_keys:
            cmd.extend(("-o", "%s=%s" % (key, value)))

    cmd.extend(args)

    return cmd


def main():
    """
    Script entry point: parse the command line, collect the target hosts,
    pick the execution channel and run the command on every host.

    Hosts come from the -h host file when given, otherwise from -H.
    The channel defaults to PARALLEL_EXECUTE_COMMAND_CHANNEL and is
    overridden by the OM config key common/parallel_execute_command_channel
    when that key yields a non-empty value.
    """
    global PARALLEL_EXECUTE_COMMAND_CHANNEL

    parsers = parse_command()
    opts, args = check_parse(parsers)
    if opts.hostfile:
        host_list = read_host_file(opts.hostfile)
    else:
        host_list = opts.hostname
    # Drop duplicates but keep first-seen order; set() ordering of str is
    # driven by per-process hash randomisation and varied between runs.
    host_list = list(dict.fromkeys(host_list))

    exec_channel = get_om_conf('common',
                               'parallel_execute_command_channel')
    if exec_channel:
        PARALLEL_EXECUTE_COMMAND_CHANNEL = exec_channel

    run(host_list, opts, args)


if __name__ == "__main__":
    # Run only when executed as a script, not when imported as a module.
    main()
