#!/usr/bin/env python3

import array
import struct
import uuid
import os
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives import padding
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.ciphers import Cipher
from cryptography.hazmat.primitives.ciphers import algorithms
from cryptography.hazmat.primitives.ciphers import modes
from cryptography.hazmat.primitives.asymmetric.padding import PKCS1v15


class MasterKeyEncryptor:
	"""Encrypt payloads under a random AES-256 key wrapped with a master RSA key.

	A fresh 32-byte AES key is drawn from os.urandom() per instance.  Each call
	to encrypt() RSA-encrypts that key with the master public key and
	AES-CBC-encrypts the payload with the *actual* AES key, which is the random
	key itself unless set_ta_uuid() has replaced it with a derived key.
	"""

	def __init__(self, master_key_pubkey):
		"""Create an encryptor bound to *master_key_pubkey*.

		:param master_key_pubkey: public key object whose
			``encrypt(plaintext, PKCS1v15())`` is used to wrap the AES key
			(main() passes an RSA public key loaded from PEM).
		"""
		self.master_key_pubkey = master_key_pubkey
		# random key that gets RSA-wrapped into every output blob
		self.aes_key = os.urandom(32)
		# key actually used for AES-CBC; set_ta_uuid() replaces it
		self.actual_aes_key = self.aes_key

	def encrypt(self, data):
		"""Encrypt *data* and return the serialized blob.

		Output layout:
			AES IV (one AES block, 16 bytes)
			|| RSA-encrypted *original* AES key (self.aes_key)
			|| AES-CBC ciphertext of the PKCS#7-padded data

		The wrapped key is always the original random key, even after
		set_ta_uuid(), so a receiver also needs the TA UUID to recover the
		key that encrypted the payload.

		:param data: bytes-like payload; converted with bytes() before padding.
		:returns: the blob as bytes.
		"""
		# apply PKCS#7 padding up to the AES block size
		pkcs7_padder = padding.PKCS7(algorithms.AES.block_size).padder()
		padded_data = pkcs7_padder.update(bytes(data)) + pkcs7_padder.finalize()

		# wrap the original AES key with the master public key
		# NOTE(review): PKCS#1 v1.5 encryption padding and unauthenticated CBC are
		# presumably dictated by the consumer of this format; OAEP and an AEAD
		# mode would be preferable if the format can ever change.
		encrypted_aes_key = self.master_key_pubkey.encrypt(self.aes_key, PKCS1v15())

		# fresh random IV per call; block_size is expressed in bits
		aes_iv = os.urandom(algorithms.AES.block_size // 8)

		# encrypt with the *actual* AES key (differs from self.aes_key after set_ta_uuid)
		aes_encryptor = Cipher(algorithms.AES(self.actual_aes_key), modes.CBC(aes_iv), default_backend()).encryptor()
		encrypted_data = aes_encryptor.update(padded_data) + aes_encryptor.finalize()

		return aes_iv + encrypted_aes_key + encrypted_data

	def set_ta_uuid(self, ta_uuid):
		"""Bind subsequent encryptions to a TA identity.

		Replaces the working AES key with SHA-256(aes_key || ta_uuid.bytes_le).
		The random key that encrypt() wraps is left unchanged.

		:param ta_uuid: uuid.UUID identifying the TA (its bytes_le is hashed).
		"""
		digest = hashes.Hash(hashes.SHA256(), backend=default_backend())
		digest.update(self.aes_key)
		digest.update(ta_uuid.bytes_le)
		self.actual_aes_key = digest.finalize()


def serialize_rsa_key(key):
	"""Serialize RSA private-key numbers into a flat length-prefixed buffer.

	The components n, e, d, p, q, dmp1, dmq1 and iqmp are written in that
	order.  Each one is a 4-byte little-endian length followed by the integer's
	minimal big-endian magnitude bytes; zero encodes as an empty field.

	:param key: object exposing the RSA private numbers (public_numbers.n,
		public_numbers.e, d, p, q, dmp1, dmq1, iqmp) -- main() passes the
		result of private_numbers() on a loaded private key.
	:returns: the serialized bytearray.
	:raises ValueError: if any component is negative.
	"""
	def add_to_buffer(buff, rsa_long):
		# a negative value has no unsigned magnitude encoding
		if rsa_long < 0:
			raise ValueError('RSA key component must be non-negative')
		raw = rsa_long.to_bytes((rsa_long.bit_length() + 7) // 8, 'big')
		# bytearray += extends the caller's buffer in place
		buff += struct.pack('<I', len(raw))
		buff += raw

	buff = bytearray()
	for component in (
		key.public_numbers.n,
		key.public_numbers.e,
		key.d,
		key.p,
		key.q,
		key.dmp1,
		key.dmq1,
		key.iqmp,
	):
		add_to_buffer(buff, component)

	return buff


def get_args(argv=None):
	"""Parse the command line.

	:param argv: argument list to parse, without the program name; when None,
		ArgumentParser.parse_args falls back to sys.argv[1:].
	:returns: argparse.Namespace with attributes key, inKey, inData, out and ta
		(unset optional values are None).
	"""
	from argparse import ArgumentParser

	parser = ArgumentParser()
	parser.add_argument('--key', required=True, help='master key used for encryption (PEM format)')

	# exactly one input source must be supplied
	group = parser.add_mutually_exclusive_group(required=True)
	group.add_argument('--inKey', '--in', dest='inKey', help='private RSA key that will be encrypted (PEM format)')
	group.add_argument('--inData', help='path to a file whose contents will be encrypted')

	parser.add_argument('--out', required=True, help='destination path for the encrypted output')
	parser.add_argument('--ta', required=False, help='TA UUID; if given, the AES key is derived from it')
	return parser.parse_args(argv)


def _load_encrypting_public_key(path):
	"""Return the public key from the PEM file at *path* (public or unencrypted private key)."""
	with open(path, 'rb') as f:
		pem = f.read()
	try:
		return serialization.load_pem_public_key(pem, backend=default_backend())
	except ValueError:
		# not a public key -- treat it as an unencrypted private key instead
		private_key = serialization.load_pem_private_key(pem, password=None, backend=default_backend())
		return private_key.public_key()


def _read_plaintext(args):
	"""Return the bytes to encrypt: a serialized private key (--inKey) or a raw file (--inData)."""
	if args.inKey is not None:
		with open(args.inKey, 'rb') as f:
			prv_key = serialization.load_pem_private_key(f.read(), password=None, backend=default_backend())
		return serialize_rsa_key(prv_key.private_numbers())
	# argparse's required mutually exclusive group guarantees --inData is set here
	with open(args.inData, 'rb') as f:
		return f.read()


def main():
	"""Encrypt a private RSA key or a data file for the holder of the master key.

	Reads the command line via get_args(), encrypts the selected input with a
	MasterKeyEncryptor built on the master public key (optionally bound to a
	TA UUID), and writes the resulting blob to --out.  The output file is only
	opened once encryption has succeeded, so a failure does not leave a
	truncated output behind.
	"""
	args = get_args()

	# parse the UUID first so a malformed --ta fails before any file is touched;
	# an explicitly empty --ta is an error rather than being silently ignored
	ta_uuid = uuid.UUID(args.ta) if args.ta is not None else None

	pub_key = _load_encrypting_public_key(args.key)
	data_to_encrypt = _read_plaintext(args)

	encryptor = MasterKeyEncryptor(pub_key)
	if ta_uuid is not None:
		encryptor.set_ta_uuid(ta_uuid)

	encrypted = encryptor.encrypt(data_to_encrypt)
	with open(args.out, 'wb') as f:
		f.write(encrypted)


if __name__ == "__main__":
	main()
