#!/usr/bin/env python3

import argparse
import base64
import struct
import sys

# python3 -m pip install pycryptodomex
from Cryptodome.Hash import HMAC, SHA1, SHA256
from Cryptodome.PublicKey import RSA
from Cryptodome.Signature import pkcs1_15

# Runtime configuration shared by the whole script. 'DEBUG' is set from the
# -d/--debug command line flag in main() and read by debug().
_CONFIG = {'DEBUG': False}
# Size limit handed to ImageFile by the create_*_image functions. The limit is
# compared against the image length in bytes, i.e. 110 KiB.
MAX_BIN_SIZE = 110 * 1024

def debug(msg):
    """
    Writes msg, followed by a newline, to stderr when DEBUG is enabled
    (via -d). Nothing is emitted otherwise.
    """
    if not _CONFIG['DEBUG']:
        return
    print(str(msg), file=sys.stderr)

class ImageFile:
    """
    Signed binary image

    The image is assembled in memory by the append_* methods and is only sent
    to the destination when write() is called. Every append_* method logs the
    offset, type and size of the record it adds through debug().
    """
    def __init__(self, filename, max_size_kb):
        """
        filename:    path of the output file, or None to emit the image on
                     stdout as base64 text.
        max_size_kb: largest image that write() accepts. NOTE: despite the
                     name, the value is compared against the image length in
                     bytes.
        """
        self._filename = filename
        if self._filename is None:
            self._of = sys.stdout
        else:
            self._of = open(self._filename, "wb")
        self._max_size_kb = max_size_kb
        self._bytes = bytearray()

        debug("%8s %20s: [%6s] %s" % ('OFFSET', 'TYPE', 'SIZE', 'DESCRIPTION'))
        debug("")

    def append(self, data):
        """
        Appends a blob of binary data to the image
        """
        self._bytes.extend(data)

    def append_file(self, source_file):
        """
        Appends the binary contents of source_file to the current image. If
        source_file is None then a base64 encoded blob is read from stdin.
        """
        if source_file is None:
            file_bytes = base64.b64decode(sys.stdin.read())
        else:
            # Context manager so the input handle is released immediately.
            with open(source_file, 'rb') as src:
                file_bytes = src.read()
        size = len(file_bytes)
        debug("%08x %20s: [%6d] %s" % (self.pos(), 'FILE', size, source_file))
        self.append(file_bytes)

    def append_keynum(self, keynum):
        """
        Appends a given key number as a 32-bit LE integer.

        Accepted values are 0-4 and 16; anything else raises an Exception.
        """
        if (keynum < 0 or keynum > 4) and keynum != 16:
            raise Exception("Bad key number %d" % keynum)
        debug("%08x %20s: [%6d] %d" % (self.pos(), "KEYNUM", 4, keynum))
        self.append(struct.pack('<i', keynum))

    def append_version(self, version):
        """
        Appends a version number, 0-32 to avoid firmware rollback, a Raspberry Pi
        with OTP bit n set will not execute a firmware without bit n set.

        The value is stored as a 32-bit LE integer. Raises an Exception when
        the version is outside 0-32.
        NOTE(review): the command line help text says 0-31 while this check
        accepts 32 as well - confirm which bound is intended.
        """
        if version < 0 or version > 32:
            raise Exception("Bad version number %d must be between 0-32" % version)
        debug("%08x %20s: [%6d] %d" % (self.pos(), "VERSION", 4, version))
        self.append(struct.pack('<i', version))

    def append_length(self):
        """
        Appends the current length to the image as a 32-bit LE integer
        """
        length = len(self._bytes)
        debug("%08x %20s: [%6d] %d" % (self.pos(), "LEN", 4, length))
        self.append(struct.pack('<i', length))

    def append_public_key(self, pem_file):
        """
        Converts an RSA public key into the format expected by the ROM
        and appends it to the image.

        If a private key is passed then only the public key part is extracted.

        The appended record is the modulus N as 256 little endian bytes
        followed by the exponent E as 8 little endian bytes. Raises an
        Exception when the key is not 2048 bits.
        """
        arr = bytearray()
        with open(pem_file, 'r') as key_file:
            key = RSA.importKey(key_file.read())

        if key.size_in_bits() != 2048:
            raise Exception("RSA key size must be 2048")

        # Export N and E in little endian format
        arr.extend(key.n.to_bytes(256, byteorder='little'))
        arr.extend(key.e.to_bytes(8, byteorder='little'))
        debug("%08x %20s: [%6d] %s" % (self.pos(), 'RSA', len(arr), pem_file))
        self.append(arr)

    def append_rsa_signature(self, digest_alg, private_pem):
        """
        Append a RSA 2048 PKCS#1 v1.5 signature of the digest of the data so far.

        digest_alg:  'sha256' or 'sha1'; any other value raises an Exception.
        private_pem: path of the RSA private key (PEM); must be 2048 bits.
        """
        # Reject unknown algorithms up front instead of failing later on an
        # unbound local variable.
        if digest_alg not in ('sha256', 'sha1'):
            raise Exception("Digest not supported %s" % (digest_alg))

        with open(private_pem, 'r') as key_file:
            key = RSA.importKey(key_file.read())
        if key.size_in_bits() != 2048:
            raise Exception("RSA key size must be 2048")

        hash_module = SHA256 if digest_alg == 'sha256' else SHA1
        digest = hash_module.new(self._bytes)

        signature = pkcs1_15.new(key).sign(digest)
        # Log before appending so the offset is where the signature starts,
        # matching the other append_* methods.
        debug("%08x %20s: [%6d] digest %s signature %s" % (self.pos(), 'RSA2048 - ' + digest_alg.upper(), len(signature), digest.hexdigest(), signature.hex()))
        self.append(signature)

    def append_digest(self, digest_alg, hmac_keyfile):
        """
        Appends the hash/digest to the image

        digest_alg:   'hmac-sha256', 'hmac-sha1', 'sha256' or 'sha1'.
        hmac_keyfile: path of a text file holding the HMAC key as 40 hex
                      characters, or None. Required for the hmac-* algorithms.

        Raises an Exception for an unsupported algorithm, a key of the wrong
        length, or an HMAC algorithm requested without a key file.
        """
        hmac_key = None
        if hmac_keyfile is not None:
            with open(hmac_keyfile, 'r') as key_file:
                hmac_key = key_file.read().strip()
            expected_keylen = 40
            if len(hmac_key) != expected_keylen:
                raise Exception("Bad key length %d expected %d" % (len(hmac_key), expected_keylen))

        if digest_alg in ('hmac-sha256', 'hmac-sha1'):
            if hmac_key is None:
                raise Exception("HMAC key file required for digest %s" % (digest_alg))
            digestmod = SHA256 if digest_alg == 'hmac-sha256' else SHA1
            # The key file is hex text; casefold=True accepts lower case digits.
            digest = HMAC.new(base64.b16decode(hmac_key, True), self._bytes, digestmod=digestmod)
        elif digest_alg == 'sha256':
            digest = SHA256.new(self._bytes)
        elif digest_alg == 'sha1':
            digest = SHA1.new(self._bytes)
        else:
            raise Exception("Digest not supported %s" % (digest_alg))

        debug("%08x %20s: [%6d] %s" % (self.pos(), digest_alg, len(digest.digest()), digest.hexdigest()))
        self.append(digest.digest())

    def pos(self):
        """
        Returns the current image length in bytes, i.e. the offset at which
        the next appended record will start.
        """
        return len(self._bytes)

    def write(self):
        """
        Writes the assembled image to the destination: raw bytes to the
        output file, or base64 text to stdout when no filename was given.

        Raises an Exception when the image exceeds the configured size limit.
        """
        size = len(self._bytes)
        if size > self._max_size_kb:
            # Report the limit that was actually applied to this image.
            raise Exception("Signed binary size %d is too large. Max size %d" % (size, self._max_size_kb))
        debug("Image size %d" % size)
        if self._filename is None:
            self._of.buffer.write(base64.b64encode(self._bytes))
        else:
            self._of.write(self._bytes)

    def close(self):
        """
        Releases the destination. An output file is closed; stdout is only
        flushed because this object does not own it.
        """
        if self._filename is None:
            self._of.flush()
        else:
            self._of.close()

def create_2711_image(output, bootcode, private_key, private_keynum, hmac):
    """
    Create a 2711 C0 secure-boot compatible second stage signed binary.

    The image consists of, in order: the bootcode, its length, the key
    number, an RSA signature over the SHA1 of everything so far and an
    HMAC-SHA1 digest of everything so far.

    output:         output filename, or None to write base64 to stdout.
    bootcode:       input filename, or None to read base64 from stdin.
    private_key:    path of the RSA private key (PEM).
    private_keynum: key number stored in the image.
    hmac:           path of the HMAC key file.
    """
    image = ImageFile(output, MAX_BIN_SIZE)
    try:
        image.append_file(bootcode)
        image.append_length()
        image.append_keynum(private_keynum)
        image.append_rsa_signature('sha1', private_key)
        image.append_digest('hmac-sha1', hmac)
        image.write()
    finally:
        # Release the output even when one of the signing steps fails.
        image.close()

def create_2712_image(output, bootcode, private_key, private_keynum, private_version):
    """
    Create 2712 signed bootloader. The HMAC is removed and the full public key is appended.

    The image consists of, in order: the bootcode, its length, the key
    number, the version, an RSA signature over the SHA256 of everything so
    far and the public part of the signing key.

    output:          output filename, or None to write base64 to stdout.
    bootcode:        input filename, or None to read base64 from stdin.
    private_key:     path of the RSA private key (PEM); also the source of
                     the appended public key.
    private_keynum:  key number stored in the image.
    private_version: firmware version stored in the image.
    """
    image = ImageFile(output, MAX_BIN_SIZE)
    try:
        image.append_file(bootcode)
        image.append_length()
        image.append_keynum(private_keynum)
        image.append_version(private_version)
        image.append_rsa_signature('sha256', private_key)
        image.append_public_key(private_key)
        image.write()
    finally:
        # Release the output even when one of the signing steps fails.
        image.close()

def main(argv=None):
    """
    Command line entry point: parses the arguments and creates the signed
    image for the selected chip.

    argv: optional list of argument strings; when None the arguments are
          taken from sys.argv by argparse.

    Invalid argument combinations (unsupported chip, missing HMAC key for
    2711, missing version for 2712) are reported through argparse, which
    prints the usage and exits with status 2.
    """
    help_text = """
    Signs a second stage bootloader image.

    Examples:
        2711 mode:
        rpi-sign-bootcode --debug -c 2711 -i bootcode.bin.clr -o bootcode.bin -k 2711_rsa_priv_0.pem -n 0 -m bootcode-production.key

        2712 C1 and D0 mode:
          * HMAC not included on 2712
          * RSA public key included - ROM just contains the hashes of the RPi public keys.

        Customer counter-signed signed:
          * Exactly the same as Raspberry Pi signing but the input is the Raspberry Pi signed bootcode.bin
          * The key number will probably always be 16 to indicate a customer signing

        rpi-sign-bootcode --debug -c 2712 -i bootcode.bin.sign2 -o bootcode.bin -k customer.pem -v 0
    """
    # The text must be passed as the description (the first positional
    # parameter of ArgumentParser is the program name); the raw formatter
    # keeps the layout of the examples.
    parser = argparse.ArgumentParser(description=help_text, formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument('-o', '--output', required=False, help='Output filename . If not specified the signed images is written to stdout in base64 format')
    parser.add_argument('-c', '--chip', required=True, type=int, choices=(2711, 2712), help='Chip number')
    parser.add_argument('-i', '--input', required=False, help='Path of the unsigned bootcode.bin file OR RPi signed bootcode file sign with the customer key. If NULLL the binary is read from stdin in base64 format')
    parser.add_argument('-m', '--hmac', required=False, help='Path of the HMAC key file')
    parser.add_argument('-k', '--private-key', dest='private_key', required=True, help='Path of RSA private key (PEM format)')
    parser.add_argument('-n', '--private-keynum', dest='private_keynum', required=False, default=0, type=int, help='ROM key index for RPi signing stage')
    parser.add_argument('-d', '--debug', action='store_true')
    # Only the 2712 image contains a version, so the option is validated per
    # chip below instead of being required unconditionally.
    parser.add_argument('-v', '--private-version', dest='private_version', required=False, default=None, type=int, help='Version of firmware, stops firmware rollback, only valid 0-31 (required for 2712)')

    args = parser.parse_args(argv)
    if args.chip == 2711 and args.hmac is None:
        parser.error("HMAC key required for 2711")
    if args.chip == 2712 and args.private_version is None:
        parser.error("-v/--private-version is required for 2712")

    _CONFIG['DEBUG'] = args.debug
    if args.chip == 2711:
        create_2711_image(args.output, args.input, args.private_key, args.private_keynum, args.hmac)
    elif args.chip == 2712:
        create_2712_image(args.output, args.input, args.private_key, args.private_keynum, args.private_version)

# Run the command line interface only when executed as a script, not when
# the module is imported.
if __name__ == '__main__':
    main()
