#!/usr/bin/env python3
# -*- coding=utf-8 -*-
# ############################################################################
# Copyright (c): 2012-2017, Huawei Tech. Co., Ltd.
# Description  : TaskPool.py is a utility to manage tasks.
# ############################################################################
import stat
import configparser
import sys
import os
import signal
import subprocess
import threading
import time
import base64
import json
import socket
import pickle

# Mode passed to os.makedirs when the stdout/stderr result directories are
# created.
DIRECTORY_PERMISSION = 0o750
# Limit (1 GiB) given to gRPC as both the maximum send and the maximum
# receive message length.
MAX_MESSAGE_LENGTH = 1024 * 1024 * 1024
# Port appended to the resolved host IP to form the gRPC channel address.
SCHEDULER_PORT = 12019
# Default number of bytes (1 MiB) read per chunk when a file is streamed.
CHUNK_SIZE = 1 * 1024 * 1024


class WriterThread(threading.Thread):
    """
    class writer.
    Thread that persists the captured output of a finished task.

    The owner assigns the text to ``stdout`` / ``stderr`` and then starts the
    thread; each text is encoded and written to its configured file.
    """

    def __init__(self, f_out, f_std):
        """
        :param f_out: path of the file receiving stdout (falsy to skip)
        :param f_std: path of the file receiving stderr (falsy to skip)
        """
        super(WriterThread, self).__init__()
        self.out_file, self.err_file = f_out, f_std

        # Filled in by the owner before the thread is started.
        self.stdout = self.stderr = None

    def run(self):
        """
        Write stdout and then stderr to their files, skipping unset paths.
        """
        targets = (
            (self.out_file, self.stdout),
            (self.err_file, self.stderr),
        )
        for path, text in targets:
            if path:
                write_file(path, text.encode())


class TaskThread(threading.Thread):
    """
    class task
    Daemon thread that runs one task for one host and keeps its result.

    Depending on ``task_type`` the task is an RPC command ("rpc_cmd"), an RPC
    file upload ("rpc_sendfile"), an RPC file download ("rpc_getfile") or,
    for any other value, a local subprocess started from ``cmd``.
    The outcome is stored in ``status``, ``stdout``, ``stderr`` and
    ``failures``.
    """

    def __init__(self, host, cmd, task_type, f_out="", f_err="",
                 detail=False, timeout=0, shell_mode=False, inline=False, region=""):
        """
        :param host: host the task is run for; used as output prefix
        :param cmd: argument list of the command, or [src_path, dest_path]
                    for the file transfer task types
        :param task_type: "rpc_cmd", "rpc_sendfile", "rpc_getfile" or
                          anything else for a local subprocess
        :param f_out: file receiving stdout ("" to skip)
        :param f_err: file receiving stderr ("" to skip)
        :param detail: print the captured output when the task is written
        :param timeout: seconds before the task is killed; <= 0 disables it
        :param shell_mode: print raw output without the status line
        :param inline: print stdout after the status line
        :param region: region forwarded with "rpc_cmd" requests
        """
        super(TaskThread, self).__init__()
        # Thread.setDaemon() is deprecated; set the attribute directly.
        self.daemon = True

        self.task_type = task_type
        self.host = host
        self.cmd = cmd
        self.detail = bool(detail)
        self.timeout = timeout
        self.shell_mode = shell_mode
        self.inline = inline
        self.region = region

        self.status = 0
        self.stdout, self.stderr = "", ""
        self.failures = []
        self.proc = None
        # gRPC channel of the RPC call in progress, closed when it finishes.
        self.channel = None
        self.timestamp = time.time()
        self.isKill = False
        # kill() can be reached from this thread and from the pool thread.
        self._kill_lock = threading.Lock()
        self.writer = WriterThread(f_out, f_err) if (f_out or f_err) else None

    def kill(self):
        """
        Mark the task as timed out and kill its subprocess, if any.
        Calling it again after the first call has no effect.
        :return: NA
        """
        with self._kill_lock:
            if self.isKill:
                return
            self.isKill = True

        self.failures.append("Timed out")
        # kill process
        try:
            if self.proc:
                self.proc.kill()
        except (OSError, IOError):
            # If the kill fails, then just assume the process is dead.
            pass
        # Set the status. int() keeps the message "Killed by signal 9" on
        # every Python version (str() of the Signals enum member varies).
        signum = int(signal.SIGKILL)
        self.status = -signum
        self.failures.append("Killed by signal %s" % signum)

    def getElapsedTime(self):
        """
        Getting elapsed time.
        :return: seconds elapsed since the task was started
        """
        return time.time() - self.timestamp

    def checkTimeout(self):
        """
        check timed-out process.
        :return: True when a timeout is configured, has elapsed and the task
                 has not been killed yet
        """
        if self.isKill or self.timeout <= 0:
            return False
        return self.timeout - self.getElapsedTime() <= 0

    def run(self):
        """
        Execute the cmd on host.
        :return: NA
        """
        self.timestamp = time.time()
        if self.task_type == "rpc_cmd":
            self.rpc_exec_cmd()
        elif self.task_type == "rpc_sendfile":
            self.rpc_send_file()
        elif self.task_type == "rpc_getfile":
            self.rpc_get_file()
        else:
            self.exec_ssh_cmd()

    @staticmethod
    def _add_scheduler_path():
        """
        Make the scheduler package importable, adding the path only once.
        """
        scheduler_root = sys.path[0] + "/../../../"
        if scheduler_root not in sys.path:
            sys.path.append(scheduler_root)

    @staticmethod
    def _with_trailing_newline(text):
        """
        Return text as str, ending with a newline unless it is empty.
        """
        text = str(text)
        if text and text[-1] != '\n':
            text += '\n'
        return text

    def _close_channel(self):
        """
        Close the gRPC channel opened by get_rpc_stub, if there is one.
        """
        channel, self.channel = self.channel, None
        if channel is not None:
            channel.close()

    def rpc_exec_cmd(self):
        """
        Run the command through the SayHello RPC and store the result.
        Any failure is reported as status 1 with the error text in stderr.
        """
        try:
            self._add_scheduler_path()
            from scheduler.src.sche.grpc_server.ScheConn_pb2 import HelloRequest

            params = {"cmd": ' '.join(self.cmd), "region": self.region}
            params = base64.b64encode(json.dumps(params).encode()).decode()
            cmd = "cmd %s" % params

            stub = self.get_rpc_stub()
            response = stub.SayHello(HelloRequest(cmd=cmd))
            # NOTE(security): pickle.loads runs arbitrary code embedded in
            # the payload, so this is only safe with a trusted peer.
            status, msg = pickle.loads(bytes.fromhex(response.message))
            msg = self._with_trailing_newline(msg)
        except Exception as e:
            status, msg = 1, str(e)
        finally:
            self._close_channel()

        self.status = status
        if self.status == 0:
            self.stdout += msg
        else:
            self.stderr += msg

    def get_rpc_stub(self):
        """
        Open an insecure gRPC channel to the host and return a service stub.
        The channel is kept in self.channel so the caller can close it.
        :return: ScheConnServiceStub bound to the new channel
        """
        self._add_scheduler_path()
        import grpc
        from scheduler.src.sche.grpc_server.ScheConn_pb2_grpc import ScheConnServiceStub

        _, host_ip = get_host_info(self.host)
        address = '{}:{}'.format(host_ip, str(SCHEDULER_PORT))
        options = (('grpc.max_send_message_length', MAX_MESSAGE_LENGTH),
                   ('grpc.max_receive_message_length', MAX_MESSAGE_LENGTH))

        self.channel = grpc.insecure_channel(address, options=options)
        return ScheConnServiceStub(self.channel)

    def rpc_send_file(self):
        """
        Send the local file cmd[0] to path cmd[1] through the SendFile RPC.
        Any failure is reported as status 1 with the error text in stderr.
        """
        try:
            # send source_path to dest_path of remote host
            stub = self.get_rpc_stub()
            src_path, dest_path = self.cmd[0], self.cmd[1]
            response = stub.SendFile(generate_send_chunks(src_path, dest_path))

            msg = self._with_trailing_newline(response.message)
            if response.success:
                self.status = 0
                self.stdout += msg
            else:
                self.status = 1
                self.stderr += msg
        except Exception as e:
            self.status = 1
            self.stderr += str(e)
        finally:
            self._close_channel()

    def rpc_get_file(self):
        """
        Fetch cmd[0] through the GetFile RPC and store it at local cmd[1].
        Any failure is reported as status 1 with the error text in stderr.
        """
        try:
            self._add_scheduler_path()
            from scheduler.src.sche.grpc_server.ScheConn_pb2 import FileRequest

            stub = self.get_rpc_stub()
            src_path, dest_path = self.cmd[0], self.cmd[1]
            chunks = stub.GetFile(FileRequest(filename=src_path))

            with open(dest_path, "wb") as fp:
                for chunk in chunks:
                    fp.write(chunk.content)
            os.chmod(dest_path, 0o700)
            self.status = 0
        except Exception as e:
            self.status = 1
            self.stderr += str(e)
        finally:
            self._close_channel()

    def exec_ssh_cmd(self):
        """
        Run cmd as a local subprocess and capture its output and exit code.

        A command that cannot be started yields status 1 with the error in
        stderr. A command exceeding the timeout is killed and reported
        through kill().
        """
        env_dict = dict(os.environ)
        # Search /lib64 first; do not emit "/lib64:None" when the variable
        # is not set in the current environment.
        lib_path = os.environ.get("LD_LIBRARY_PATH")
        env_dict["LD_LIBRARY_PATH"] = ("/lib64:%s" % lib_path) if lib_path else "/lib64"

        try:
            self.proc = subprocess.Popen(self.cmd, shell=False,
                                         stdout=subprocess.PIPE,
                                         stderr=subprocess.PIPE,
                                         env=env_dict)
        except (OSError, ValueError) as err:
            # Without this the thread would die and leave status at 0.
            self.status = 1
            self.stderr += str(err)
            return

        # Same rule as checkTimeout(): a timeout <= 0 means "no timeout".
        timeout = self.timeout if self.timeout > 0 else None
        try:
            stdout, stderr = self.proc.communicate(timeout=timeout)
        except subprocess.TimeoutExpired:
            # communicate() does not kill the child on timeout; do it here
            # and then reap the process and collect its remaining output.
            self.kill()
            stdout, stderr = self.proc.communicate()

        # Undecodable bytes must not abort the thread.
        self.stdout += stdout.decode(errors="replace")
        self.stderr += stderr.decode(errors="replace")
        self.status = self.proc.returncode

    def __print_out(self):
        """
        Print the captured output; in shell mode raw stdout and stderr,
        otherwise only stdout prefixed with the host.
        """
        if not self.stdout and not self.stderr:
            return
        if self.shell_mode:
            sys.stderr.write("%s" % self.stderr)
            sys.stdout.write("%s" % self.stdout)
        else:
            if self.stdout:
                sys.stdout.write("%s: %s" % (self.host, self.stdout))
        # Use [-1] replace of .endswith, can avoid the problem about
        # coding inconsistencies
        if self.stdout and self.stdout[-1] != os.linesep:
            sys.stdout.write(os.linesep)
        if self.shell_mode and self.stderr and self.stderr[-1] != os.linesep:
            sys.stderr.write(os.linesep)

    def __print_result(self, index):
        """
        Print the result into sys.stdout
        :param index: sequence number shown in the status line
        :return: NA
        """
        if self.shell_mode:
            str_ = ""
        else:
            str_ = "[%s] %s [%s] %s" % (
                index,
                time.asctime().split()[3],
                "SUCCESS" if not self.status else "FAILURE",
                self.host
            )
            if self.status > 0:
                str_ += " Exited with error code %s" % self.status

        if self.failures:
            failures_msg = ", ".join(self.failures)
            str_ = str_ + " " + failures_msg

        if str_:
            print(str_)
        if self.inline:
            sys.stdout.write("%s" % self.stdout)

    def write(self, index):
        """
        Write the output into sys.stdout and files.
        :param index: sequence number shown in the status line
        :return: object of writer or None
        """
        # Print the stdout into sys.stdout
        if self.detail:
            self.__print_out()
        # Print the status
        self.__print_result(index)

        # Write the self.stdout and self.stderr into files.
        if self.writer:
            self.writer.stdout = self.stdout
            self.writer.stderr = self.stderr
            self.writer.start()
        return self.writer


class TaskPool(object):
    """
    class manager
    Runs queued tasks with a bounded number of them in parallel, kills the
    ones that time out and collects the exit status of each host.
    """

    def __init__(self, opts):
        """
        Initialize
        :param opts: options object providing outdir, errdir, parallel,
                     timeout, shellmode, inline and optionally region
        """
        self.out_path = opts.outdir
        self.err_path = opts.errdir
        self.detail = True
        self.parallel_num = opts.parallel
        self.timeout = opts.timeout
        self.shell_mode = opts.shellmode
        self.inline = opts.inline
        self.region = getattr(opts, 'region', "")

        self.tasks = []
        self.running_tasks = []
        self.writers = []
        # Exit status of every finished task, keyed by host.
        self.task_status = {}

    def __get_task_files(self, host):
        """
        Obtain the result file of the task.
        :param host: host name, used as the file name
        :return: (stdout file path, stderr file path); "" where the
                 corresponding directory is not configured
        """
        std_path = os.path.join(self.out_path, host) if self.out_path else ""
        err_path = os.path.join(self.err_path, host) if self.err_path else ""
        return std_path, err_path

    def add_task(self, host, cmd, task_type=""):
        """
        Adding a Task to the Task Pool
        :param host: host the task is run for
        :param cmd: command of the task
        :param task_type: task type forwarded to TaskThread
        """
        f_out, f_err = self.__get_task_files(host)
        task = TaskThread(host, cmd, task_type, f_out, f_err, self.detail,
                          self.timeout, self.shell_mode, self.inline, self.region)
        self.tasks.append(task)

    def __get_writing_task(self):
        """
        Check the task status and obtain the running tasks.
        """
        # Polling interval between two status checks.
        time.sleep(0.1)
        # Check whether the task times out. If the task times out,
        # stop the task.
        for task in self.running_tasks:
            if task.checkTimeout():
                task.kill()

        # filter the still running tasks and not running tasks
        still_running = []
        not_running = []
        for task in self.running_tasks:
            if task.is_alive():
                still_running.append(task)
            else:
                self.task_status[task.host] = task.status
                not_running.append(task)

        # Start the writing thread of completed tasks
        for task in not_running:
            index = len(self.writers) + 1
            writer = task.write(index)
            if writer:
                self.writers.append(writer)

        self.running_tasks = still_running

    def __start_limit_task(self):
        """
        Starts the tasks within a specified number of parallel.
        """
        # A limit that is missing or <= 0 would never start a task and make
        # start() loop forever; fall back to running one task at a time.
        limit = self.parallel_num
        if not limit or limit < 1:
            limit = 1
        while self.tasks and len(self.running_tasks) < limit:
            task = self.tasks.pop(0)
            self.running_tasks.append(task)
            task.start()

    def start(self):
        """
        Start to execute all tasks.
        :return: the exit statuses of the finished tasks (one per host)
        """
        # Create the path of stdout and stderr. exist_ok covers a directory
        # created by someone else between the check and the call.
        for path in (self.out_path, self.err_path):
            if path and not os.path.exists(path):
                os.makedirs(path, DIRECTORY_PERMISSION, exist_ok=True)

        # Do cmd
        while self.tasks or self.running_tasks:
            self.__get_writing_task()
            self.__start_limit_task()

        # Waiting for writing files complete.
        for writer in self.writers:
            writer.join()

        return self.task_status.values()


def read_host_file(host_file):
    """
    Reads the host file.
    Lines are of the form: host.
    Blank lines and lines starting with '#' are skipped.
    Writes a message to stderr and exits with status 1 when the file cannot
    be read.
    :param host_file: path of the host file
    :return: list of host names, stripped of surrounding whitespace
    """
    hosts = []
    try:
        with open(host_file) as fp:
            for line in fp:
                line = line.strip()
                # Both conditions are required: the previous "or" was always
                # true, so blank and comment lines ended up in the result.
                if line and not line.startswith('#'):
                    hosts.append(line)
    except (OSError, IOError) as err:
        sys.stderr.write('Could not open hosts file: %s\n' % err)
        sys.exit(1)

    return hosts


def get_om_conf(section, key):
    """
    Look up one option in ../../etc/conf/om.conf relative to this module.
    :param section: section name in om.conf
    :param key: option name inside the section
    :return: the stripped option value, or None when the section or option
             is not found or any error occurs
    """
    try:
        conf_path = os.path.join(os.path.dirname(__file__), "../../etc/conf/om.conf")
        parser = configparser.ConfigParser()
        parser.read(conf_path)
        found = section in parser.sections() and key in parser.options(section)
        value = parser.get(section, key).strip() if found else None
    except Exception:
        value = None

    return value


def get_host_info(host_info):
    """
    Resolve a host to its name and IP address.

    When /var/chroot/opt/config_map/initDb exists and lists an instance whose
    'name' equals host_info, that instance's 'serviceName' is resolved;
    otherwise host_info itself is resolved.
    :param host_info: host name or instance name
    :return: host_name, host_ip
    """
    config_map = "/var/chroot/opt/config_map/initDb"
    if not os.path.exists(config_map):
        return host_info, socket.gethostbyname(host_info)

    with open(config_map, 'r') as fp:
        context = json.load(fp)

    # Lazy lookup: stop at the first instance with a matching name.
    not_found = object()
    matched = next(
        (inst for inst in context['instances'] if host_info == inst['name']),
        not_found,
    )
    if matched is not_found:
        return host_info, socket.gethostbyname(host_info)
    return matched['serviceName'], socket.gethostbyname(matched['serviceName'])


def write_file(file, lines):
    """
    Create or truncate a file and write binary content into it.

    A newly created file gets owner read/write permission only.
    :param file: path of the file to write
    :param lines: bytes to write
    :return: NA
    """
    open_flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC
    owner_rw = stat.S_IWUSR | stat.S_IRUSR
    descriptor = os.open(file, open_flags, owner_rw)
    with os.fdopen(descriptor, "wb") as handle:
        handle.write(lines)


def generate_send_chunks(src_path, dest_path, chunk_size=CHUNK_SIZE):
    """
    Generate the SendChunk messages that carry a local file.

    :param src_path: local file read in binary mode
    :param dest_path: value placed in the filename field of every chunk
    :param chunk_size: number of bytes read per chunk
    :return: generator of SendChunk objects; nothing is yielded for an
             empty file
    """
    sys.path.append(sys.path[0] + "/../../../")
    from scheduler.src.sche.grpc_server.ScheConn_pb2 import SendChunk

    with open(src_path, 'rb') as source:
        # read() returns b"" at end of file, which ends the iteration.
        for piece in iter(lambda: source.read(chunk_size), b""):
            yield SendChunk(content=piece, filename=dest_path)
