#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# ############################################################################
# Copyright (c): 2020, Huawei Tech. Co., Ltd.
# FileName     : pscp
# Version      : Gauss
# Date         : 2019-12-31
# Description  : Parallel scp to a set of nodes.
#                 For each node, run "scp [-r] local ip:remote". Note that
#                 remote must be an absolute path.
# ############################################################################
import optparse
import os
import sys

try:
    from TaskPool import TaskPool
    from TaskPool import read_host_file, get_om_conf
except ImportError as ie:
    sys.exit("[GAUSS-52200] : Unable to import module: %s." % str(ie))

# Default timeout in seconds, applied by check_parse() when -t is not given.
# NOTE(review): presumably 0 means "no timeout" in TaskPool — confirm.
TIME_OUT = 0
# Default maximum parallelism, applied by check_parse() when -p is not given.
PARALLEL_NUM = 32
# Transfer channel. "SSH_TRUST" builds an scp command per host; any other
# value queues rpc_sendfile / rpc_getfile tasks in TaskPool instead.
# main() overrides it from the OM config key
# common.parallel_execute_command_channel when that key is set.
PARALLEL_EXECUTE_COMMAND_CHANNEL = "SSH_TRUST"


def parse_command():
    """
    Build the command-line parser for pscp.

    Interspersed arguments are disabled, so every option must precede the
    positional "src_path ... dest_path" arguments. conflict_handler='resolve'
    lets '-h' be reused for the host file; help remains reachable via
    '--help'.

    :return: the configured optparse.OptionParser
    """
    parser = optparse.OptionParser(conflict_handler='resolve')
    parser.disable_interspersed_args()
    parser.usage = "%prog [OPTIONS] src_path dest_path"
    parser.epilog = "Example: pscp -H hostname test.txt /home/omm/test.txt"
    parser.add_option('-H', dest='hostname', action='append',
                      help='Nodes to be connected')
    parser.add_option('-h', dest='hostfile',
                      help='Host file with each line per node')
    parser.add_option('-t', dest='timeout', type='int',
                      help='Timeouts in seconds')
    parser.add_option('-p', dest='parallel', type='int',
                      help='Maximum number of parallel tasks')
    parser.add_option('-o', dest='outdir', help='Output results folder')
    parser.add_option('-e', dest='errdir', help='Error results folder')
    parser.add_option('-r', dest='recursive', action='store_true',
                      help='recursively copy directories')
    parser.add_option('-v', dest='verbose', action='store_true',
                      help='turn on diagnostic messages')
    parser.add_option('-s', dest='shellmode', action='store_true',
                      help='Output only execution results')
    parser.add_option('-x', dest='extra', action='append',
                      help='Additional scp parameters')
    parser.add_option('-i', dest='inline', action='store_true',
                      help='aggregated output and error for each server')
    # Each -O value is forwarded to scp as "-o <value>" (see get_scp_cmd).
    parser.add_option('-O', dest='opt', action='append',
                      help='Additional SSH options passed to scp via -o')
    parser.add_option('-g', dest='get', action='store_true',
                      help='get file from remote')

    return parser


def check_parse(parser_info, argv=None):
    """
    Apply default values, parse the command line and validate it.

    :param parser_info: optparse.OptionParser (main() passes the one built
                        by parse_command())
    :param argv: argument list to parse; None (the default) makes optparse
                 parse sys.argv[1:], exactly as before
    :return: (opts, args) tuple returned by parse_args()
    Terminates via parser_info.error() when fewer than two positional
    arguments are given or when neither -H nor -h is specified.
    """
    # Fall back to the module defaults when -p / -t are not supplied.
    parser_info.set_defaults(parallel=PARALLEL_NUM, timeout=TIME_OUT)
    opts_info, args_info = parser_info.parse_args(argv)

    # At least one source path plus the destination path are required.
    if len(args_info) < 2:
        parser_info.error('path not specified.')

    if not opts_info.hostname and not opts_info.hostfile:
        parser_info.error('Hosts not specified.')

    return opts_info, args_info


def run(hosts, opts, args):
    """
    function: validate the paths, queue one transfer task per host, start
              them and exit with a status summarising the results
    input : hosts - list of target host names
            opts  - parsed command-line options
            args  - positional arguments (source path(s), then destination)
    output: NA; on failure the process exits with
            3 if any task status is negative (presumably the task was killed),
            4 if any task status is otherwise non-zero,
            1 if starting the tasks raises or no task status is returned
    """
    src_path, dest_path = get_file_path(hosts, args, opts)

    for folder in (opts.outdir, opts.errdir):
        if folder and not os.path.exists(folder):
            # exist_ok avoids a crash if another process creates the folder
            # between the existence check and makedirs().
            os.makedirs(folder, exist_ok=True)

    # Both values are the same for every host; compute them once.
    use_scp = PARALLEL_EXECUTE_COMMAND_CHANNEL == "SSH_TRUST"
    task_type = "rpc_getfile" if opts.get else "rpc_sendfile"

    manager = TaskPool(opts)
    for host in hosts:
        if use_scp:
            manager.add_task(host, get_scp_cmd(host, opts, src_path, dest_path))
        else:
            manager.add_task(host, [src_path, dest_path], task_type)

    try:
        # Materialise once: the statuses are inspected twice below, which
        # would silently misbehave if start() returned a one-shot iterator.
        statuses = list(manager.start())
        if not statuses:
            # min() of an empty sequence would otherwise fail with a cryptic
            # message; keep the same exit code but say what happened.
            print("ERROR: No transfer task was executed.")
            sys.exit(1)
        if min(statuses) < 0:
            # At least one process was killed
            sys.exit(3)
        if any(status != 0 for status in statuses):
            sys.exit(4)
    except Exception as ex:
        # sys.exit() raises SystemExit, which is not an Exception subclass,
        # so the exits above pass straight through this handler.
        print(str(ex))
        sys.exit(1)


def get_file_path(hosts, args, opts, channel=None):
    """
    Split the positional arguments into source and destination paths and
    validate them for the transfer channel.

    :param hosts: list of target host names
    :param args: positional arguments; all but the last are source paths,
                 the last one is the destination path
    :param opts: parsed options; opts.get selects download mode on the
                 non-SSH channel
    :param channel: transfer channel to validate for; None (the default)
                    uses the module-level PARALLEL_EXECUTE_COMMAND_CHANNEL
    :return: (src_path, dest_path). On the "SSH_TRUST" channel src_path is
             the list of source paths; otherwise it is a single path string.
             A destination ending in '/' gets the source file name appended.
    Exits the process with status 3 on any validation error.
    """
    if channel is None:
        channel = PARALLEL_EXECUTE_COMMAND_CHANNEL

    src_path = args[0:-1]
    dest_path = args[-1]

    if channel == "SSH_TRUST":
        if not os.path.isabs(dest_path):
            print("ERROR: Remote path %s must be an absolute path." % dest_path)
            sys.exit(3)
        return src_path, dest_path

    if len(src_path) != 1:
        print("ERROR: Only supports transferring one file at a time.")
        sys.exit(3)
    src_path = src_path[0]

    # A trailing '/' means "put the file into this directory". Record it
    # before normalising: os.path.abspath() strips trailing separators, so
    # checking afterwards would miss it.
    dest_is_dir = dest_path.endswith('/')

    if opts.get:
        # get one file from one remote host.
        if not os.path.isabs(src_path):
            print("ERROR: Source path %s must be an absolute path." % src_path)
            sys.exit(3)

        dest_path = os.path.abspath(dest_path)

        if len(hosts) > 1:
            print("ERROR: Only support get file from one host.")
            sys.exit(3)
    else:
        # send one file to many remote hosts.
        src_path = os.path.abspath(src_path)
        if not os.path.exists(src_path):
            print("ERROR: Source path %s does not exists." % src_path)
            sys.exit(3)

        if not os.path.isabs(dest_path):
            print("ERROR: Destination path %s must be an absolute path." % dest_path)
            sys.exit(3)

    if dest_is_dir:
        dest_path = os.path.join(dest_path, os.path.basename(src_path))

    return src_path, dest_path


def get_scp_cmd(host, opts, src_path, dest_path):
    """
    Build the scp argument vector that pushes the source paths to
    host:dest_path.

    :param host: target host name or address
    :param opts: parsed options; reads recursive (-r), extra (-x) and opt (-O)
    :param src_path: iterable of local source paths
    :param dest_path: destination path on the remote host
    :return: list of command-line tokens, starting with 'scp'
    """
    # Quiet, compressed, and never prompt interactively for credentials.
    argv = ['scp', '-qC', '-o', 'BatchMode=yes']
    argv += ['-r'] if opts.recursive else []
    argv += opts.extra or []
    for ssh_option in opts.opt or []:
        argv += ['-o', ssh_option]
    argv += list(src_path)
    argv.append('%s:%s' % (host, dest_path))
    return argv


def main():
    """
    Entry point: parse the command line, resolve the target hosts and the
    transfer channel, then run the transfer.

    :return: NA
    """
    global PARALLEL_EXECUTE_COMMAND_CHANNEL

    opts, args = check_parse(parse_command())
    if opts.hostfile:
        host_list = read_host_file(opts.hostfile)
    else:
        host_list = opts.hostname
    # Drop duplicate hosts while keeping the order the user gave. A set
    # would make the order vary from run to run (str hashing is randomised).
    host_list = list(dict.fromkeys(host_list))

    # The OM configuration, when set, overrides the default channel.
    exec_channel = get_om_conf('common',
                               'parallel_execute_command_channel')
    if exec_channel:
        PARALLEL_EXECUTE_COMMAND_CHANNEL = exec_channel

    run(host_list, opts, args)


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
