#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# ############################################################################
# Copyright (c): 2020, Huawei Tech. Co., Ltd.
# FileName     : pssh
# Version      : Gauss
# Date         : 2019-12-31
# Description  : Parallel ssh to the set of nodes in hosts.txt.
#                 For each node, this essentially does an "ssh host command".
#                 Output from each remote node can be saved in a directory;
#                 each output file in that directory will be named
#                 by the corresponding remote node's hostname or IP address.
# ############################################################################
try:
    import optparse
    import sys
    import shlex
    from TaskPool import TaskPool
    from TaskPool import read_host_file
    import os
    import stat
except ImportError as ie:
    sys.exit("[GAUSS-52200] : Unable to import module: %s." % str(ie))

# Default for the -t option (timeout in seconds), applied in check_parse();
# presumably 0 means "no timeout" -- TODO confirm against TaskPool.
TIME_OUT = 0
# Default for the -p option (maximum number of parallel ssh tasks),
# applied in check_parse().
PARALLEL_NUM = 32


def parse_command():
    """
    Build the pssh command-line parser.

    Option parsing stops at the first positional word, so everything from
    the remote command onwards is passed through untouched. '-h' is reused
    for the host file; conflict_handler='resolve' lets it take over '-h'
    from the built-in help option (which keeps '--help').

    return: optparse.OptionParser
    """
    option_specs = (
        ('-H', dict(dest='hostname', action='append',
                    help='Nodes to be connected')),
        ('-h', dict(dest='hostfile',
                    help='Host file with each line per node')),
        ('-t', dict(dest='timeout', type='int',
                    help='Timeouts in seconds')),
        ('-p', dict(dest='parallel', type='int',
                    help='Maximum number of parallel')),
        ('-o', dict(dest='outdir', help='Output results folder')),
        ('-e', dict(dest='errdir', help='Error results folder')),
        ('-P', dict(dest='print', action='store_true',
                    help='Print output')),
        ('-s', dict(dest='shellmode', action='store_true',
                    help='Output only execution results')),
        ('-x', dict(dest='extra',
                    help='Extra command-line arguments')),
        ('-i', dict(dest='inline', action='store_true',
                    help='aggregated output and error for each server')),
        ('-O', dict(dest='opt', action='append',
                    help='Additional ssh parameters')),
    )

    cli = optparse.OptionParser(conflict_handler='resolve')
    cli.disable_interspersed_args()
    cli.usage = "%prog [OPTIONS] command"
    cli.epilog = "Example: pssh -H hostname 'id'"
    for flag, settings in option_specs:
        cli.add_option(flag, **settings)
    return cli


def check_parse(parser_info):
    """
    Parse sys.argv with the given parser and validate the result.

    :param parser_info: optparse.OptionParser (as built by parse_command())
    :return: (opts_info, args_info) -- the parsed option values and the
             remaining positional words, i.e. the remote command
    Aborts through parser_info.error() when no command or no host is given.
    """
    # set defaults parallel and timeout value
    parser_info.set_defaults(parallel=PARALLEL_NUM, timeout=TIME_OUT)
    opts_info, args_info = parser_info.parse_args()

    # opts_info is an optparse.Values object and is always truthy, so the
    # "command required" check must look at the positional arguments.
    if not args_info:
        parser_info.error("The commands is request.")
    if not opts_info.hostname and not opts_info.hostfile:
        parser_info.error("The host info is request.")

    return opts_info, args_info


def log_error(err_path, err_code, err_msg):
    """
    Append one error line to <err_path>/pssh_err.log (best effort).

    :param err_path: directory holding the log file (the -e option); nothing
                     is written when it is None or empty
    :param err_code: exit code being reported
    :param err_msg: human-readable description of the failure
    :return: NA
    """
    if not err_path:
        return
    # One entry per line so repeated appends stay readable.
    log_msg = f"pssh error, error code: {err_code}, error message: {err_msg}\n"
    err_file_path = os.path.join(err_path, 'pssh_err.log')
    flag = os.O_WRONLY | os.O_CREAT | os.O_APPEND       # open file with 'append' mode
    # 640 permissions; only applied when the file is created (and the umask
    # may narrow them further).
    mode = stat.S_IWUSR | stat.S_IRUSR | stat.S_IRGRP
    try:
        with os.fdopen(os.open(err_file_path, flag, mode), 'a') as err_file:
            err_file.write(log_msg)
    except Exception:
        # Deliberately best effort: callers are already on an error path and
        # must still reach sys.exit() with their own exit code.
        return

def _ssh_option_name(option):
    """
    function: extract the keyword of one ssh '-o' option string
    input : option, e.g. 'ConnectTimeout=5' or 'ConnectTimeout 5'
    output: the keyword lower-cased ('' for a blank option); ssh option
            keywords are case-insensitive
    """
    words = option.replace('=', ' ').split()
    return words[0].lower() if words else ''


def _build_ssh_cmd(host, extra_args, user_opts, default_opts, command):
    """
    function: assemble the ssh argv for one host
    input : host, extra_args (already shlex-split -x value),
            user_opts (list of -O values), default_opts (keyword -> value),
            command (remote command words)
    output: list of argv strings
    """
    cmd = ["ssh", host, '-n',
           "-o", "BatchMode=yes",
           "-o", "NumberOfPasswordPrompts=1",
           "-o", "TCPKeepAlive=yes"]
    cmd.extend(extra_args)
    for option in user_opts:
        cmd.extend(("-o", option))
    # Add a default only when the user did not set that keyword via -O.
    # Match on the parsed keyword rather than a substring of the joined -O
    # text, so e.g. a ProxyCommand value mentioning ConnectTimeout does not
    # suppress the outer ssh's default.
    user_keys = {_ssh_option_name(option) for option in user_opts}
    for key, value in default_opts.items():
        if key.lower() not in user_keys:
            cmd.extend(("-o", "%s=%s" % (key, value)))
    cmd.extend(command)
    return cmd


def run(hosts):
    """
    function: do run process -- queue one ssh task per host on a TaskPool,
              start them, and exit with an error code if any failed
    input : hosts, iterable of host names / IP addresses
    output: NA; calls sys.exit() on failure with:
            3 -- at least one ssh process was killed (negative status)
            4 -- a status of 255 (ssh connection error), unless -s
            5 -- any other non-zero status, unless -s
            <status> -- with -s, the first non-zero status
            1 -- an exception raised while running the tasks
    Reads the module-level `opts` and `args` set in __main__.
    """
    manager = TaskPool(opts)
    def_args = {"ConnectionAttempts": "10",
                "ConnectTimeout": "30",
                "ServerAliveCountMax": "10",
                "ServerAliveInterval": "30", }
    # Loop-invariant: parse -x and collect -O once rather than per host.
    extra_args = shlex.split(opts.extra) if opts.extra else []
    user_opts = opts.opt or []
    for host in hosts:
        manager.add_task(host, _build_ssh_cmd(host, extra_args, user_opts,
                                              def_args, args))
    try:
        statuses = manager.start()
        if min(statuses) < 0:
            # At least one process was killed.
            exit_code = 3
            log_error(opts.errdir, exit_code, "pssh error: At least one ssh process was killed.")
            sys.exit(exit_code)
        if not opts.shellmode and 255 in statuses:
            exit_code = 4
            log_error(opts.errdir, exit_code, "pssh error: An error occurred with an ssh connection.")
            sys.exit(exit_code)
        failed = [status for status in statuses if status != 0]
        if failed:
            if not opts.shellmode:
                exit_code = 5
                log_error(opts.errdir, exit_code,
                          "pssh error: An error occurred when the command was executed and shellmode is False.")
            else:
                # Shell mode propagates the first failing remote status.
                exit_code = failed[0]
                log_error(opts.errdir, exit_code, "pssh error: An error occurred when the command was executed.")
            sys.exit(exit_code)

    except Exception as ex:
        # sys.exit() raises SystemExit (not an Exception), so the exits
        # above pass through this handler untouched.
        exit_code = 1
        log_error(opts.errdir, exit_code, f"pssh error: Exception: {str(ex)}")
        print(str(ex))
        sys.exit(exit_code)


if __name__ == "__main__":
    # `opts` and `args` are module globals read by run() and log_error().
    parsers = parse_command()
    opts, args = check_parse(parsers)
    if opts.hostfile:
        host_list = read_host_file(opts.hostfile)
    else:
        host_list = opts.hostname
    # Drop duplicate hosts but keep first-seen order; list(set(...)) ordered
    # hosts by (randomized) string hash, which changed from run to run.
    host_list = list(dict.fromkeys(host_list))
    run(host_list)
