from __future__ import annotations
from libra.crypto.ed25519 import (
    Ed25519PrivateKey, Ed25519PublicKey, Ed25519Signature, ED25519_PRIVATE_KEY_LENGTH,
    ED25519_PUBLIC_KEY_LENGTH, ED25519_SIGNATURE_LENGTH
)
from libra.hasher import HashValue
from canoser import Uint8, Uint32, Struct

# This module provides an API for the accountable threshold multi-sig PureEdDSA signature scheme
# over the ed25519 twisted Edwards curve as defined in [RFC8032](https://tools.ietf.org/html/rfc8032).
#
# Signature verification also checks and rejects non-canonical signatures.

# from libra_crypto_derive.{DeserializeKey, SerializeKey, SilentDebug, SilentDisplay}

# Upper bound on the number of keys (and therefore signatures) in one multi-key structure.
MAX_NUM_OF_KEYS = 32
# Size of the signature bitmap: one bit per possible key, 32 bits -> 4 bytes.
BITMAP_NUM_OF_BYTES = MAX_NUM_OF_KEYS // 8

class MultiEd25519PrivateKey(Struct):
    """Private keys of a K-of-N multi-key Ed25519 structure, plus the threshold K."""

    _fields = [
        ("private_keys", [Ed25519PrivateKey]),  # the N individual private keys
        ("threshold", Uint8),  # K: minimum number of signatures required
    ]

class MultiEd25519PublicKey(Struct):
    """Public keys of a K-of-N multi-key Ed25519 structure, plus the threshold K."""

    _fields = [
        ("public_keys", [Ed25519PublicKey]),  # the N individual public keys
        ("threshold", Uint8),  # K: minimum number of valid signatures required
    ]

class MultiEd25519Signature(Struct):
    """Multi-key signatures plus a 32-bit bitmap mapping each one to its public key.

    The bitmap (BITMAP_NUM_OF_BYTES bytes) is read left to right, most
    significant bit first. The k-th set bit gives the public-key index of the
    k-th signature. For example, in the bitmap
    [0b0001_0000, 0b0000_0000, 0b0000_0000, 0b0000_0001] positions 3 and 31
    are set.
    """

    _fields = [
        ("signatures", [Ed25519Signature]),
        # The Rust type was [Uint8; BITMAP_NUM_OF_BYTES]. NOTE(review): canoser's `bytes`
        # may be length-prefixed on the wire, unlike a fixed 4-byte array; confirm the format.
        ("bitmap", bytes),
    ]



# NOTE(review): this impl is still in Rust syntax (braces, `impl`, `Ok`/`Err`). It is part of
# an unfinished port and does not parse as Python yet.
impl MultiEd25519PrivateKey {
    # Construct a new MultiEd25519PrivateKey.
    # --- Rules ---
    # a) threshold cannot be zero or exceed the number of keys -> ValidationError.
    # b) at most MAX_NUM_OF_KEYS private keys are accepted -> WrongLengthError.
    # Rule (a) is checked first, so a zero threshold with too many keys reports ValidationError.
    def new(
        private_keys: List[Ed25519PrivateKey],
        threshold: Uint8,
    ) -> std.result.Self, CryptoMaterialError {
        num_of_keys = private_keys.__len__()
        if threshold == 0 || num_of_keys < threshold {
            Err(CryptoMaterialError.ValidationError)
        elif num_of_keys > MAX_NUM_OF_KEYS {
            Err(CryptoMaterialError.WrongLengthError)
        else:
            Ok(MultiEd25519PrivateKey {
                private_keys,
                threshold,
            })
        }
    }

    # Serialize a MultiEd25519PrivateKey as key0||key1||..||keyN||threshold, using the
    # module-level to_bytes helper.
    def to_bytes(self) -> List[Uint8] {
        to_bytes(self.private_keys, self.threshold)
    }
}

# NOTE(review): this impl is still in Rust syntax (braces, `impl`, `Ok`/`Err`). It is part of
# an unfinished port and does not parse as Python yet.
impl MultiEd25519PublicKey {
    # Construct a new MultiEd25519PublicKey.
    # --- Rules ---
    # a) threshold cannot be zero.
    # b) public_keys.__len__() should be equal to or larger than threshold.
    # c) support up to MAX_NUM_OF_KEYS public keys.
    # Violating (a) or (b) yields ValidationError; violating (c) yields WrongLengthError.
    # (a)/(b) are checked before (c).
    def new(
        public_keys: List[Ed25519PublicKey],
        threshold: Uint8,
    ) -> std.result.Self, CryptoMaterialError {
        num_of_keys = public_keys.__len__()
        if threshold == 0 || num_of_keys < threshold {
            Err(CryptoMaterialError.ValidationError)
        elif num_of_keys > MAX_NUM_OF_KEYS {
            Err(CryptoMaterialError.WrongLengthError)
        else:
            Ok(MultiEd25519PublicKey {
                public_keys,
                threshold,
            })
        }
    }

    # Getter public_keys
    # NOTE(review): in Python a method and a Struct field both named `public_keys` collide;
    # confirm which one the port should keep.
    def public_keys(self) -> &List[Ed25519PublicKey] {
        self.public_keys
    }

    # Getter threshold
    # NOTE(review): same name collision with the `threshold` field as above.
    def threshold(self) -> &Uint8 {
        self.threshold
    }

    # Serialize a MultiEd25519PublicKey as key0||key1||..||keyN||threshold, using the
    # module-level to_bytes helper.
    def to_bytes(self) -> List[Uint8] {
        to_bytes(self.public_keys, self.threshold)
    }
}

####################
# PrivateKey Traits
####################

# Convenient method to create a MultiEd25519PrivateKey from a single Ed25519PrivateKey.
# The result is a 1-of-1 key. The private key is copied through a bytes round trip
# (to_bytes -> try_from) instead of being shared.
impl From<&Ed25519PrivateKey> for MultiEd25519PrivateKey {
    def from(ed_private_key: &Ed25519PrivateKey) -> Self {
        MultiEd25519PrivateKey {
            private_keys: [Ed25519PrivateKey.try_from(&ed_private_key.to_bytes()[..])],
            threshold: 1Uint8,
        }
    }
}

# Marker impl: ties MultiEd25519PrivateKey to its public-key type.
impl PrivateKey for MultiEd25519PrivateKey {
    type PublicKeyMaterial = MultiEd25519PublicKey
}

# Signing support: a MultiEd25519PrivateKey produces MultiEd25519Signatures that are
# verified with a MultiEd25519PublicKey.
impl SigningKey for MultiEd25519PrivateKey {
    type VerifyingKeyMaterial = MultiEd25519PublicKey
    type SignatureMaterial = MultiEd25519Signature

    # Sign a message with the minimum amount of keys to meet threshold (starting from left-most keys).
    # Each of the first `threshold` private keys signs `message`, and bit i of the bitmap is set
    # for the i-th of them. So bits 0..threshold-1 are set and signatures are stored in key order.
    def sign_message(self, message: &HashValue) -> MultiEd25519Signature {
        signatures: List[Ed25519Signature] = Vec.with_capacity(self.threshold)
        bitmap = [0Uint8; BITMAP_NUM_OF_BYTES]
        signatures.extend(
            self.private_keys
                .iter()
                .take(self.threshold)
                .enumerate()
                .map(|(i, item)| {
                    bitmap_set_bit(bitmap, i)
                    item.sign_message(message)
                }),
        )
        MultiEd25519Signature { signatures, bitmap }
    }
}

# Generating a random K out-of N key for testing.
# N is drawn from 1..=MAX_NUM_OF_KEYS and K from 1..=N. The `+ 1` is there because rand's
# gen_range upper bound is exclusive.
impl Uniform for MultiEd25519PrivateKey {
    def generate_for_testing<R>(rng: R) -> Self
    where
        R: .rand.SeedableRng + .rand.RngCore + .rand.CryptoRng,
    {
        num_of_keys = rng.gen_range(1, MAX_NUM_OF_KEYS + 1)
        private_keys: List[Ed25519PrivateKey] = Vec.with_capacity(num_of_keys)
        for _ in 0..num_of_keys {
            # Each key is a fresh ed25519_dalek secret key, converted via its raw bytes.
            private_keys.push(
                Ed25519PrivateKey.try_from(
                    &ed25519_dalek.SecretKey.generate(rng).to_bytes()[..],
                )
                ,
            )
        }
        threshold = rng.gen_range(1, num_of_keys + 1)
        MultiEd25519PrivateKey {
            private_keys,
            threshold,
        }
    }
}

impl TryFrom<&[Uint8]> for MultiEd25519PrivateKey {
    type Error = CryptoMaterialError

    # Deserialize a MultiEd25519PrivateKey. This method will also check for key and threshold validity.
    # Expected layout: private_key0||...||private_keyN||threshold, where each key is
    # ED25519_PRIVATE_KEY_LENGTH bytes. check_and_get_threshold validates the length and the
    # threshold. chunks_exact then yields only the full key chunks, so the trailing threshold
    # byte is skipped.
    def try_from(bytes: List[Uint8]) -> std.result.MultiEd25519PrivateKey, CryptoMaterialError {
        if bytes.is_empty() {
            return Err(CryptoMaterialError.WrongLengthError)
        }
        threshold = check_and_get_threshold(bytes, ED25519_PRIVATE_KEY_LENGTH)

        # Fails if any single key fails Ed25519PrivateKey.try_from.
        private_keys: List[Ed25519PrivateKey, _] = bytes
            .chunks_exact(ED25519_PRIVATE_KEY_LENGTH)
            .map(Ed25519PrivateKey.try_from)
            .collect()

        private_keys.map(|private_keys| MultiEd25519PrivateKey {
            private_keys,
            threshold,
        })
    }
}

# Serialized length in bytes: one ED25519_PRIVATE_KEY_LENGTH chunk per key plus one threshold byte.
impl Length for MultiEd25519PrivateKey {
    def length(self) -> usize {
        self.private_keys.__len__() * ED25519_PRIVATE_KEY_LENGTH + 1
    }
}

# ValidKey serialization delegates to the inherent to_bytes (keys followed by threshold).
# NOTE(review): Rust resolves `self.to_bytes()` to the inherent method. A literal Python port
# with a single `to_bytes` on the class would call itself forever; confirm when porting.
impl ValidKey for MultiEd25519PrivateKey {
    def to_bytes(self) -> List[Uint8] {
        self.to_bytes()
    }
}

# Deterministic 1-of-1 key. Its single private key is all zero bytes except the last, which is 1.
impl Genesis for MultiEd25519PrivateKey {
    def genesis() -> Self {
        buf = [0Uint8; ED25519_PRIVATE_KEY_LENGTH]
        buf[ED25519_PRIVATE_KEY_LENGTH - 1] = 1Uint8
        MultiEd25519PrivateKey {
            private_keys: [Ed25519PrivateKey.try_from(buf)],
            threshold: 1Uint8,
        }
    }
}

####################
# PublicKey Traits
####################

# Convenient method to create a MultiEd25519PublicKey from a single Ed25519PublicKey.
# The result is a 1-of-1 key holding a clone of `ed_public_key`.
impl From<&Ed25519PublicKey> for MultiEd25519PublicKey {
    def from(ed_public_key: &Ed25519PublicKey) -> Self {
        MultiEd25519PublicKey {
            public_keys: [ed_public_key.clone()],
            threshold: 1Uint8,
        }
    }
}

# Implementing From<&PrivateKey<...>> allows to derive a public key in a more elegant fashion.
# Each private key is mapped to its public key (order preserved), and the threshold is copied.
impl From<&MultiEd25519PrivateKey> for MultiEd25519PublicKey {
    def from(private_key: &MultiEd25519PrivateKey) -> Self {
        public_keys = private_key
            .private_keys
            .iter()
            .map(PrivateKey.public_key)
            .collect()
        MultiEd25519PublicKey {
            public_keys,
            threshold: private_key.threshold,
        }
    }
}

# We deduce PublicKey from this.
impl PublicKey for MultiEd25519PublicKey {
    type PrivateKeyMaterial = MultiEd25519PrivateKey
}

# Hash the serialized form (public keys followed by threshold), so keys with identical bytes
# hash identically.
#[allow(clippy.derive_hash_xor_eq)]
impl std.hash.Hash for MultiEd25519PublicKey {
    def hash<H: std.hash.Hasher>(self, state: H) {
        encoded_pubkey = self.to_bytes()
        state.write(&encoded_pubkey)
    }
}

impl TryFrom<&[Uint8]> for MultiEd25519PublicKey {
    type Error = CryptoMaterialError

    # Deserialize a MultiEd25519PublicKey. This method will also check for key and threshold
    # validity, and will only deserialize keys that are safe against small subgroup attacks.
    # Expected layout: public_key0||...||public_keyN||threshold, where each key is
    # ED25519_PUBLIC_KEY_LENGTH bytes. check_and_get_threshold validates the length and the
    # threshold. chunks_exact yields only the full key chunks, skipping the trailing threshold
    # byte. Per-key validation is delegated to Ed25519PublicKey.try_from.
    def try_from(bytes: List[Uint8]) -> std.result.MultiEd25519PublicKey, CryptoMaterialError {
        if bytes.is_empty() {
            return Err(CryptoMaterialError.WrongLengthError)
        }
        threshold = check_and_get_threshold(bytes, ED25519_PUBLIC_KEY_LENGTH)
        public_keys: List[Ed25519PublicKey, _] = bytes
            .chunks_exact(ED25519_PUBLIC_KEY_LENGTH)
            .map(Ed25519PublicKey.try_from)
            .collect()
        public_keys.map(|public_keys| MultiEd25519PublicKey {
            public_keys,
            threshold,
        })
    }
}

# We deduce VerifyingKey from pointing to the signature material
# we get the ability to do `pubkey.validate(msg, signature)`
impl VerifyingKey for MultiEd25519PublicKey {
    type SigningKeyMaterial = MultiEd25519PrivateKey
    type SignatureMaterial = MultiEd25519Signature
}

# Display: hex encoding of the serialized key (public keys followed by threshold).
impl fmt.Display for MultiEd25519PublicKey {
    def fmt(self, f: fmt.Formatter<'_>) -> fmt.Result {
        write!(f, "{}", hex.encode(self.to_bytes()))
    }
}

# Debug: the Display output wrapped as `MultiEd25519PublicKey(<hex>)`.
impl fmt.Debug for MultiEd25519PublicKey {
    def fmt(self, f: fmt.Formatter<'_>) -> fmt.Result {
        write!(f, "MultiEd25519PublicKey({})", self)
    }
}

# Serialized length in bytes: one ED25519_PUBLIC_KEY_LENGTH chunk per key plus one threshold byte.
# Note that this is a byte count, not the number of keys.
impl Length for MultiEd25519PublicKey {
    def length(self) -> usize {
        self.public_keys.__len__() * ED25519_PUBLIC_KEY_LENGTH + 1
    }
}

# ValidKey serialization delegates to the inherent to_bytes.
# NOTE(review): Rust resolves `self.to_bytes()` to the inherent method. A literal Python port
# with a single `to_bytes` on the class would call itself forever; confirm when porting.
impl ValidKey for MultiEd25519PublicKey {
    def to_bytes(self) -> List[Uint8] {
        self.to_bytes()
    }
}

# NOTE(review): this impl is still in Rust syntax. It is part of an unfinished port and does not
# parse as Python yet.
impl MultiEd25519Signature {
    # Build a multi-signature from (signature, key index) pairs.
    # This method will also sort signatures based on index, so they are stored in bitmap order.
    # Errors:
    #   - ValidationError if there are no signatures or more than MAX_NUM_OF_KEYS.
    #   - BitVecError if an index is >= MAX_NUM_OF_KEYS or appears twice.
    def new(
        signatures: List[(Ed25519Signature, Uint8)],
    ) -> std.result.Self, CryptoMaterialError {
        num_of_sigs = signatures.__len__()
        if num_of_sigs == 0 || num_of_sigs > MAX_NUM_OF_KEYS {
            return Err(CryptoMaterialError.ValidationError)
        }

        sorted_signatures = signatures
        sorted_signatures.sort_by(|a, b| a.1.cmp(&b.1))

        bitmap = [0Uint8; BITMAP_NUM_OF_BYTES]

        # Check if all indexes are unique and < MAX_NUM_OF_KEYS
        (sigs, indexes): (List[_], List[_]) = sorted_signatures.iter().cloned().unzip()
        for i in indexes {
            # Index is within the bitmap: reject duplicates, otherwise record it.
            if i < MAX_NUM_OF_KEYS {
                # if an index has been set already (thus, there is a duplicate).
                if bitmap_get_bit(bitmap, i) {
                    return Err(CryptoMaterialError.BitVecError(
                        "Duplicate signature index",
                    ))
                else:
                    bitmap_set_bit(bitmap, i)
                }
            else:
                # Index does not fit in the bitmap.
                return Err(CryptoMaterialError.BitVecError(
                    "Signature index is out of range",
                ))
            }
        }
        Ok(MultiEd25519Signature {
            signatures: sigs,
            bitmap,
        })
    }

    # Getter signatures.
    def signatures(self) -> &List[Ed25519Signature] {
        self.signatures
    }

    # Getter bitmap.
    def bitmap(self) -> &[Uint8; BITMAP_NUM_OF_BYTES] {
        self.bitmap
    }

    # Serialize a MultiEd25519Signature in the form of sig0||sig1||..sigN||bitmap.
    def to_bytes(self) -> List[Uint8] {
        bytes: List[Uint8] = self
            .signatures
            .iter()
            .flat_map(|sig| sig.to_bytes())
            .collect()
        bytes.extend(self.bitmap[..])
        bytes
    }
}

####################
# Signature Traits
####################

impl TryFrom<&[Uint8]> for MultiEd25519Signature {
    type Error = CryptoMaterialError

    # Deserialize a MultiEd25519Signature. This method will also check for malleable signatures
    # and bitmap validity.
    # Expected layout: sig0||...||sigN||bitmap. There must be 1..=MAX_NUM_OF_KEYS signatures of
    # ED25519_SIGNATURE_LENGTH bytes each, followed by exactly BITMAP_NUM_OF_BYTES bytes
    # (otherwise WrongLengthError). The number of set bitmap bits must equal the number of
    # signatures (otherwise DeserializationError). Per-signature checks are delegated to
    # Ed25519Signature.try_from.
    def try_from(bytes: List[Uint8]) -> std.result.MultiEd25519Signature, CryptoMaterialError {
        length = bytes.__len__()
        bitmap_num_of_bytes = length % ED25519_SIGNATURE_LENGTH
        num_of_sigs = length / ED25519_SIGNATURE_LENGTH

        if num_of_sigs == 0
            || num_of_sigs > MAX_NUM_OF_KEYS
            || bitmap_num_of_bytes != BITMAP_NUM_OF_BYTES
        {
            return Err(CryptoMaterialError.WrongLengthError)
        }

        bitmap = bytes[length - BITMAP_NUM_OF_BYTES..].try_into()
        if bitmap_count_ones(bitmap) != num_of_sigs {
            return Err(CryptoMaterialError.DeserializationError)
        }

        # chunks_exact yields only the full signature chunks; the trailing bitmap is skipped.
        signatures: List[Ed25519Signature, _] = bytes
            .chunks_exact(ED25519_SIGNATURE_LENGTH)
            .map(Ed25519Signature.try_from)
            .collect()
        signatures.map(|signatures| MultiEd25519Signature { signatures, bitmap })
    }
}

# Serialized length in bytes: one ED25519_SIGNATURE_LENGTH chunk per signature plus the bitmap.
impl Length for MultiEd25519Signature {
    def length(self) -> usize {
        self.signatures.__len__() * ED25519_SIGNATURE_LENGTH + BITMAP_NUM_OF_BYTES
    }
}

# Hash the serialized form (signatures followed by the bitmap).
#[allow(clippy.derive_hash_xor_eq)]
impl std.hash.Hash for MultiEd25519Signature {
    def hash<H: std.hash.Hasher>(self, state: H) {
        encoded_signature = self.to_bytes()
        state.write(&encoded_signature)
    }
}

# Display: hex encoding of the serialized signature.
impl fmt.Display for MultiEd25519Signature {
    def fmt(self, f: fmt.Formatter<'_>) -> fmt.Result {
        write!(f, "{}", hex.encode(self.to_bytes()[..]))
    }
}

# Debug: the Display output wrapped as `MultiEd25519Signature(<hex>)`.
impl fmt.Debug for MultiEd25519Signature {
    def fmt(self, f: fmt.Formatter<'_>) -> fmt.Result {
        write!(f, "MultiEd25519Signature({})", self)
    }
}

# ValidKey serialization delegates to the inherent to_bytes.
# NOTE(review): a literal Python port with a single `to_bytes` on the class would recurse forever.
impl ValidKey for MultiEd25519Signature {
    def to_bytes(self) -> List[Uint8] {
        self.to_bytes()
    }
}

impl Signature for MultiEd25519Signature {
    type VerifyingKeyMaterial = MultiEd25519PublicKey
    type SigningKeyMaterial = MultiEd25519PrivateKey

    # Checks that `self` is valid for `message` using `public_key`.
    def verify(self, message: &HashValue, public_key: &MultiEd25519PublicKey) -> Tuple[]:
        self.verify_arbitrary_msg(message, public_key)
    }

    # Checks that `self` is valid for an arbitrary &[Uint8] `message` using `public_key`.
    # Outside of this crate, this particular function should only be used for native signature
    # verification in Move.
    def verify_arbitrary_msg(
        self,
        message: List[Uint8],
        public_key: &MultiEd25519PublicKey,
    ) -> Tuple[]:
        last_bit = bitmap_last_set_bit(self.bitmap)
        if last_bit == None || last_bit > public_key.length() {
            return Err(anyhow!(
                "{}",
                CryptoMaterialError.BitVecError("Signature index is out of range")
            ))
        }
        if bitmap_count_ones(self.bitmap) < public_key.threshold {
            return Err(anyhow!(
                "{}",
                CryptoMaterialError.BitVecError(
                    "Not enough signatures to meet the threshold"
                )
            ))
        }
        bitmap_index = 0
        # TODO use deterministic batch verification when gets available.
        for sig in self.signatures {
            while !bitmap_get_bit(self.bitmap, bitmap_index) {
                bitmap_index += 1
            }
            sig.verify_arbitrary_msg(message, &public_key.public_keys[bitmap_index])
            bitmap_index += 1
        }
        Ok(())
    }

    def to_bytes(self) -> List[Uint8] {
        self.to_bytes()
    }
}

# Wrap a single Ed25519Signature as a multi-signature whose bitmap has only bit 0 set,
# so it pairs with the first public key.
impl From<&Ed25519Signature> for MultiEd25519Signature {
    def from(ed_signature: &Ed25519Signature) -> Self {
        MultiEd25519Signature {
            signatures: [ed_signature.clone()],
            # "1000_0000 0000_0000 0000_0000 0000_0000"
            bitmap: [0b1000_0000Uint8, 0Uint8, 0Uint8, 0Uint8],
        }
    }
}

####################
# Helper functions
####################

def to_bytes(keys, threshold: int) -> list:
    """Serialize MultiEd25519 keys as ``key0 || key1 || ... || keyN || threshold``.

    Args:
        keys: iterable of key objects, each exposing ``to_bytes()``.
        threshold: the threshold value, appended as the final byte.

    Returns:
        The serialized payload as a list of byte values (ints).
    """
    payload = [byte for key in keys for byte in key.to_bytes()]
    payload.append(threshold)
    return payload

def check_and_get_threshold(bytes, key_size: int, max_num_of_keys=None) -> int:
    """Validate a serialized MultiEd25519 key payload and return its threshold.

    The payload must be ``N * key_size`` bytes of keys followed by exactly one
    threshold byte, with ``1 <= N <= max_num_of_keys`` and ``1 <= threshold <= N``.

    Args:
        bytes: serialized payload (``bytes``, ``bytearray`` or list of ints).
        key_size: length in bytes of a single key.
        max_num_of_keys: upper bound on N; defaults to ``MAX_NUM_OF_KEYS``.

    Returns:
        The threshold (the payload's last byte).

    Raises:
        ValueError: ``"WrongLengthError"`` when the payload length is not
            ``N * key_size + 1`` for an allowed N, and ``"ValidationError"``
            when the threshold is zero or exceeds N.
    """
    # NOTE(review): call sites use the return value directly, so failures are raised. Switch to
    # a CryptoMaterialError type if the package defines a Python equivalent.
    if max_num_of_keys is None:
        max_num_of_keys = MAX_NUM_OF_KEYS
    # An empty payload gives num_of_keys == 0, so it is rejected before bytes[-1] is read.
    num_of_keys, threshold_num_of_bytes = divmod(len(bytes), key_size)
    if num_of_keys == 0 or num_of_keys > max_num_of_keys or threshold_num_of_bytes != 1:
        raise ValueError("WrongLengthError")
    threshold_byte = bytes[-1]
    if threshold_byte == 0 or threshold_byte > num_of_keys:
        raise ValueError("ValidationError")
    return threshold_byte

def bitmap_set_bit(input, index: int) -> None:
    """Set the bit at ``index`` in the bitmap ``input``, in place.

    Bits are numbered left to right, most significant bit first: index 0 is
    the top bit of ``input[0]`` and index 31 the lowest bit of ``input[3]``.

    Args:
        input: mutable byte sequence (``bytearray`` or list of ints).
        index: bit position, ``0 <= index < 8 * len(input)``.

    Raises:
        IndexError: if ``index`` is outside the bitmap. Without this check a
            negative index would silently wrap to the end of ``input``.
    """
    if not 0 <= index < 8 * len(input):
        raise IndexError(f"bit index {index} out of range for {len(input)}-byte bitmap")
    bucket, bucket_pos = divmod(index, 8)
    input[bucket] |= 0x80 >> bucket_pos

def bitmap_get_bit(input, index: int) -> bool:
    """Return whether the bit at ``index`` is set in the bitmap ``input``.

    Bits are numbered left to right, most significant bit first, matching
    ``bitmap_set_bit``.

    Args:
        input: byte sequence (``bytes``, ``bytearray`` or list of ints).
        index: bit position, ``0 <= index < 8 * len(input)``.

    Raises:
        IndexError: if ``index`` is outside the bitmap.
    """
    if not 0 <= index < 8 * len(input):
        raise IndexError(f"bit index {index} out of range for {len(input)}-byte bitmap")
    bucket, bucket_pos = divmod(index, 8)
    return (input[bucket] & (0x80 >> bucket_pos)) != 0

def bitmap_count_ones(input) -> int:
    """Return the number of set bits across all bytes of ``input``."""
    return sum(bin(byte).count("1") for byte in input)

def bitmap_last_set_bit(input):
    """Return the position of the right-most set bit in ``input``, or ``None``.

    Positions use the same left-to-right, MSB-first numbering as
    ``bitmap_get_bit``. ``None`` means no bit is set.
    """
    for byte_index in range(len(input) - 1, -1, -1):
        byte = input[byte_index]
        if byte:
            # (byte & -byte) isolates the lowest set bit; its bit_length() - 1 is the number
            # of trailing zeros, i.e. the distance from the right edge of this byte.
            trailing_zeros = (byte & -byte).bit_length() - 1
            return 8 * byte_index + 7 - trailing_zeros
    return None


def bitmap_tests() -> None:
    """Self-check for the bitmap helpers (bits read left to right, MSB first)."""
    # bytearray, because bitmap_set_bit mutates its argument in place.
    bitmap = bytearray([0b0100_0000, 0b1111_1111, 0, 0b1000_0000])
    assert not bitmap_get_bit(bitmap, 0)
    assert bitmap_get_bit(bitmap, 1)
    for i in range(8, 16):
        assert bitmap_get_bit(bitmap, i)
    for i in range(16, 24):
        assert not bitmap_get_bit(bitmap, i)
    assert bitmap_get_bit(bitmap, 24)
    assert not bitmap_get_bit(bitmap, 31)
    assert bitmap_last_set_bit(bitmap) == 24

    bitmap_set_bit(bitmap, 30)
    assert bitmap_get_bit(bitmap, 30)
    assert bitmap_last_set_bit(bitmap) == 30
