#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################################
# Copyright (c): 2012-2017, Huawei Tech. Co., Ltd.
# Description  : Encryption and decryption character string
#############################################################################
from __future__ import absolute_import

try:
    import os
    import sys
    import subprocess
    import hashlib
    import shlex
    from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
    from cryptography.hazmat.backends import default_backend
except Exception as ie:
    sys.exit("[GAUSS-52200] : Unable to import module: %s." % str(ie))
#############################################################################
# Supported product branches. BRANCH_NAME holds the branch in effect; it is
# assigned at runtime by AesCbcUtil.aes_cbc_decrypt_with_path based on the
# size of the cipher file.
BRANCH_NAME_XUANYUAN, BRANCH_NAME_GAUSSDB = "XuanYuan", "GaussDB"
BRANCH_NAME = ""

# PBKDF2 iteration counts per branch; ITERATE_TIMES holds the count chosen
# at runtime together with BRANCH_NAME.
ITERATE_TIMES_XUANYUAN, ITERATE_TIMES_GAUSSDB = 10000, 50000
ITERATE_TIMES = 0

# A cipher file longer than this many bytes is handled as the XuanYuan layout.
CIPHER_FILE_LEN_GAUSSDB = 40
# Longest plaintext slice handed to a single "gs_guc generate" call.
MAX_CONTENT_LEN_FOR_GS_GUC = 15


##############################################################################
class AesCbcUtil(object):
    """
    AES-CBC helpers for the cluster user password files.

    Encryption runs ``gs_guc generate`` once per chunk of the secret, each
    chunk in its own ``key_<i>`` directory. Decryption reads the
    ``<key_name>.key.cipher`` / ``<key_name>.key.rand`` files from those
    directories and decrypts them with a PBKDF2-HMAC-SHA256 derived key.
    """
    __BLOCK_SIZE_16 = BLOCK_SIZE_16 = 16

    @staticmethod
    def check_content_key(content, key):
        """
        Validate the arguments of an AES-CBC decryption.

        :param content: cipher file contents; must be bytes, at least 32 long.
        :param key: AES key; must be bytes or str.
        :raises Exception: if a type or length requirement is not met.
        """
        if not isinstance(content, bytes):
            raise Exception("Content's type must be bytes.")

        if not isinstance(key, (bytes, str)):
            raise Exception("Key's type must be in (bytes, str).")

        iv_len = AesCbcUtil.__BLOCK_SIZE_16
        if len(content) < (iv_len + AesCbcUtil.__BLOCK_SIZE_16):
            raise Exception("Content's len must >= 32.")

    @staticmethod
    def aes_cbc_decrypt(cipher_txt, key):
        """
        Decrypt the first 16-byte block of a cipher file.

        The IV position depends on the module-level BRANCH_NAME:
          * XuanYuan: iv = cipher_txt[34:50]
          * otherwise (GaussDB): iv = cipher_txt[17:33]
        The ciphertext is always cipher_txt[:16].

        :param cipher_txt: raw bytes of a ``*.key.cipher`` file.
        :param key: AES key as bytes, or str (encoded as UTF-8).
        :return: the plaintext decoded as UTF-8 (undecodable bytes dropped)
                 with trailing NUL characters stripped.
        :raises Exception: from check_content_key on invalid arguments.
        """
        AesCbcUtil.check_content_key(cipher_txt, key)
        if isinstance(key, str):
            # bytes(str) raises TypeError on Python 3; encode explicitly.
            key = key.encode("utf-8")

        block = AesCbcUtil.__BLOCK_SIZE_16
        if BRANCH_NAME == BRANCH_NAME_XUANYUAN:
            iv_start = block + 1 + block + 1
        else:
            iv_start = block + 1
        # pre shared key iv
        iv_value = cipher_txt[iv_start:iv_start + block]
        # pre shared encrypted content
        enc_content = cipher_txt[:block]
        backend = default_backend()
        cipher = Cipher(algorithms.AES(key), modes.CBC(iv_value), backend=backend)
        decrypter = cipher.decryptor()
        dec_content = decrypter.update(enc_content) + decrypter.finalize()
        return dec_content.decode("utf-8", "ignore").rstrip('\0')

    @staticmethod
    def aes_cbc_decrypt_with_path(path, key_name="server"):
        """
        Decrypt the ``<key_name>.key.cipher`` / ``<key_name>.key.rand`` pair in ``path``.

        Steps:
          1. rand[:16] + salt -> PBKDF2-HMAC-SHA256 -> 16-byte decrypt_key
          2. decrypt_key + iv + ciphertext -> plaintext (see aes_cbc_decrypt)

        Side effect: sets the module globals BRANCH_NAME / ITERATE_TIMES.
        A cipher file longer than CIPHER_FILE_LEN_GAUSSDB bytes selects
        XuanYuan (salt = cipher[17:33]). Any other length selects GaussDB
        (salt = rand[17:33]).

        :param path: directory holding the two key files.
        :param key_name: file name prefix of the key files.
        :return: decrypted text, or None if the cipher file is empty.
        :raises OSError: if either key file cannot be read.
        """
        key_cipher = os.path.join(path, '%s.key.cipher' % key_name)
        key_rand = os.path.join(path, '%s.key.rand' % key_name)
        with open(key_cipher, 'rb') as cipher_file:
            cipher_txt = cipher_file.read()
        with open(key_rand, 'rb') as rand_file:
            rand_txt = rand_file.read()
        # Files are opened in binary mode, so an empty file reads as b"".
        if not cipher_txt:
            return None

        global BRANCH_NAME
        global ITERATE_TIMES
        if len(cipher_txt) > CIPHER_FILE_LEN_GAUSSDB:
            BRANCH_NAME = BRANCH_NAME_XUANYUAN
            ITERATE_TIMES = ITERATE_TIMES_XUANYUAN
        else:
            BRANCH_NAME = BRANCH_NAME_GAUSSDB
            ITERATE_TIMES = ITERATE_TIMES_GAUSSDB

        block = AesCbcUtil.__BLOCK_SIZE_16
        # pre shared salt: stored in the cipher file for XuanYuan, in the rand file otherwise
        salt_source = cipher_txt if BRANCH_NAME == BRANCH_NAME_XUANYUAN else rand_txt
        salt = salt_source[block + 1:block + 1 + block]
        # pre shared content
        content = rand_txt[:block]
        # worker key
        decrypt_key = hashlib.pbkdf2_hmac('sha256', content, salt, ITERATE_TIMES, block)
        return AesCbcUtil.aes_cbc_decrypt(cipher_txt, decrypt_key)

    @staticmethod
    def aes_cbc_decrypt_with_multi(root_path, key_name="server"):
        """
        Decrypt a secret stored in chunks under root_path/key_0, key_1, ...

        Directories are read in order until one is missing, decrypts to an
        empty value, or yields a chunk shorter than 15 characters (which
        marks the final chunk).

        :param root_path: directory containing the key_<i> directories.
        :param key_name: file name prefix of the key files in each directory.
        :return: the concatenated plaintext ("" if key_0 does not exist).
        """
        parts = []
        num = 0
        while True:
            path = os.path.join(root_path, 'key_%d' % num)
            if not os.path.isdir(path):
                break
            # Decrypt each directory once; every call runs a full PBKDF2 derivation.
            part = AesCbcUtil.aes_cbc_decrypt_with_path(path, key_name)
            if not part:
                break
            parts.append(part)
            if len(part) < AesCbcUtil.__BLOCK_SIZE_16 - 1:
                break
            num += 1

        return "".join(parts)

    @staticmethod
    def aes_cbc_encrypt_with_multi(content, dest_path, logger, user_profile="", key_name="server"):
        """
        Encrypt ``content`` into ``dest_path`` using ``gs_guc generate``.

        ``dest_path`` is recreated empty. ``content`` is then split into slices
        of at most MAX_CONTENT_LEN_FOR_GS_GUC characters. Slice i is written
        into ``dest_path/key_<i>`` by one gs_guc invocation.

        :param content: plaintext secret (str).
        :param dest_path: root directory for the key_<i> directories.
        :param logger: object providing ``debug`` for progress messages.
        :param user_profile: optional shell file sourced before gs_guc runs.
        :param key_name: value passed to gs_guc's ``-o`` option.
        :raises Exception: if a directory cannot be created or gs_guc fails.
        """
        logger.debug("Start to generate cluster user password files.")

        # create encrypt ca path
        AesCbcUtil.__create_empty_directory(dest_path)

        step = MAX_CONTENT_LEN_FOR_GS_GUC
        for indx, start_indx in enumerate(range(0, len(content), step)):
            part_content = content[start_indx:start_indx + step]
            encrypt_path = os.path.join(dest_path, 'key_%d' % indx)
            AesCbcUtil.__create_empty_directory(encrypt_path)
            logger.debug("Successfully created ca path %s." % encrypt_path)
            # NOTE: the secret is still visible on the gs_guc command line
            # (e.g. in a process listing) while the command runs.
            # NOTE(review): slices are counted in characters; confirm gs_guc's
            # -S limit is not in bytes for non-ASCII secrets.
            template = "gs_guc generate -o %s -S %s -D %s"
            quoted_name = shlex.quote(key_name)
            quoted_path = shlex.quote(encrypt_path)
            cmd = template % (quoted_name, shlex.quote(part_content), quoted_path)
            # Same command with the secret masked, used in error messages.
            masked_cmd = template % (quoted_name, "'******'", quoted_path)
            if user_profile != "":
                # user_profile is left unquoted so shell expansions in it still apply.
                cmd = "source %s; %s " % (user_profile, cmd)
                masked_cmd = "source %s; %s " % (user_profile, masked_cmd)
            AesCbcUtil.__exec_local_command(cmd, masked_cmd)
        logger.debug("Successfully generated cluster user password files.")

    @staticmethod
    def __create_empty_directory(path):
        """
        Delete ``path`` if it exists, then recreate it with ``mkdir -p -m 700``.

        :param path: directory to (re)create; quoted for the shell.
        :raises Exception: if the shell command fails.
        """
        quoted = shlex.quote(path)
        cmd = " if [ -e %s ]; then rm -rf %s; fi" % (quoted, quoted)
        cmd += " && mkdir -p -m 700 %s" % quoted
        AesCbcUtil.__exec_local_command(cmd)

    @staticmethod
    def __exec_local_command(cmd, display_cmd=None):
        """
        Run ``cmd`` through the shell and raise if it exits non-zero.

        :param cmd: shell command line to execute.
        :param display_cmd: text shown instead of ``cmd`` in the error message
            (e.g. with secrets masked); defaults to ``cmd``.
        :raises Exception: [GAUSS-51400] with the command and its output.
        """
        (status, output) = subprocess.getstatusoutput(cmd)
        if status != 0:
            shown = cmd if display_cmd is None else display_cmd
            raise Exception("[GAUSS-51400] : Failed to execute the command: %s." % shown + " Error: \n%s" % output)
