# Copyright (C) 2024, UChicago Argonne, LLC
# Licensed under the 3-clause BSD license.  See accompanying LICENSE.txt file
# in the top-level directory.

# cython: language_level=3, boundscheck=True
import numpy as np
from typing import Optional, Tuple

from .bitmask import EnhancedBitmask, EnhancedBitmaskLike
from .bitmask import build_cache, check_eb_overunderflow
from .bitmask import eb
from .bitmask_globals import __base_dtype__
from .bitmask_lite import LiteBitmask, IntervalList


# Number of hex digits processed per block by the block-based parsers
# (`_hex_string_to_bit_intervals_v3` / `_v4` / `_v5`) when no explicit `hex_block_size` is passed.
default_hex_block_size = 256
# Lower-case hex digits indexed by nybble value (0-15); used by `_intervals_to_hex_string_v1`.
hex_map = "0123456789abcdef"


def _hex_string_to_bit_intervals_v0(hex_str: str, limit: Optional[int] = None) -> Tuple[int, IntervalList]:
    """
    Given a hex string and limit (number of bits), return the intervals of set bits.

    Bits are numbered from 0 at the most significant bit of the first hex digit.  Only bits whose index is below
    ``limit`` are considered; ``limit`` defaults to ``4 * len(hex_str)`` and must not exceed it.

    Returns a tuple ``(bit_count, intervals)``: ``intervals`` is a list of inclusive ``[start, end]`` runs of set
    bits in ascending order, and ``bit_count`` is the number of set bits they cover.  An empty ``hex_str`` yields
    ``(0, [])``.
    """
    hex_str_num_bits = len(hex_str) * 4
    if limit is None:
        limit = hex_str_num_bits
    else:
        assert limit <= hex_str_num_bits
    if not hex_str:
        # int("", 16) raises ValueError; an empty string simply has no set bits.
        return 0, []
    # bin() drops leading zero bits, so the first character of bin_str sits at index bit_offset.
    bin_str = bin(int(hex_str, 16))[2:]
    intervals: IntervalList = []
    intervals_append = intervals.append
    bit_count = 0
    bit_offset = hex_str_num_bits - len(bin_str)
    start_loc = 0
    prev_val = 0
    loc = 0
    val = 0
    for loc, val in enumerate((int(bit) for bit in bin_str), bit_offset):
        # ">=" rather than "==": when limit falls inside the stripped leading zeros the very first loc is
        # already past the limit and must not start a run.
        if loc >= limit:
            loc -= 1
            val = prev_val
            break
        if val == prev_val:
            continue
        if val == 0:
            # a run of ones just ended on the previous bit
            intervals_append([start_loc, loc - 1])
            bit_count += loc - start_loc
        else:
            start_loc = loc
        prev_val = val
    if val == prev_val == 1:
        # close the run that is still open at the end of the string / at the limit
        intervals_append([start_loc, loc])
        bit_count += loc - start_loc + 1
    return bit_count, intervals


def _hex_string_to_bit_intervals_v1(hex_str: str, limit: Optional[int] = None) -> Tuple[int, IntervalList]:
    """
    Given a hex string and limit (number of bits), return the intervals of set bits.

    Same contract as the v0 variant, but compares the characters of the binary string directly instead of
    converting every bit to an int.

    Bits are numbered from 0 at the most significant bit of the first hex digit.  Only bits whose index is below
    ``limit`` are considered; ``limit`` defaults to ``4 * len(hex_str)`` and must not exceed it.

    Returns a tuple ``(bit_count, intervals)``: ``intervals`` is a list of inclusive ``[start, end]`` runs of set
    bits in ascending order, and ``bit_count`` is the number of set bits they cover.  An empty ``hex_str`` yields
    ``(0, [])``.
    """
    hex_str_num_bits = len(hex_str) * 4
    if limit is None:
        limit = hex_str_num_bits
    else:
        assert limit <= hex_str_num_bits
    if not hex_str:
        # int("", 16) raises ValueError; an empty string simply has no set bits.
        return 0, []
    # bin() drops leading zero bits, so the first character of bin_str sits at index bit_offset.
    bin_str = bin(int(hex_str, 16))[2:]
    intervals: IntervalList = []
    intervals_append = intervals.append
    bit_count = 0
    bit_offset = hex_str_num_bits - len(bin_str)
    start_loc = 0
    prev_val = "0"
    loc = 0
    val = "0"
    for loc, val in enumerate(bin_str, bit_offset):
        # ">=" rather than "==": when limit falls inside the stripped leading zeros the very first loc is
        # already past the limit and must not start a run.
        if loc >= limit:
            loc -= 1
            val = prev_val
            break
        if val == prev_val:
            continue
        if val == "0":
            # a run of ones just ended on the previous bit
            intervals_append([start_loc, loc - 1])
            bit_count += loc - start_loc
        else:
            start_loc = loc
        prev_val = val
    if val == prev_val == "1":
        # close the run that is still open at the end of the string / at the limit
        intervals_append([start_loc, loc])
        bit_count += loc - start_loc + 1
    return bit_count, intervals


def _hex_string_to_bit_intervals_v2(hex_str: str, limit: Optional[int] = None) -> Tuple[int, IntervalList]:
    """
    Given a hex string and limit (number of bits), return the intervals of set bits.

    Walks the string one hex digit (nybble) at a time.  The digits "0" and "f" are handled without expanding them
    to bits; any other digit is expanded to its four bits and scanned.

    Bits are numbered from 0 at the most significant bit of the first hex digit.  Only bits whose index is below
    ``limit`` are considered; ``limit`` defaults to ``4 * len(hex_str)`` and must not exceed it.

    Returns a tuple ``(bit_count, intervals)``: ``intervals`` is a list of inclusive ``[start, end]`` runs of set
    bits in ascending order, and ``bit_count`` is the number of set bits they cover.
    """
    hex_str_num_bits = len(hex_str) * 4
    if limit is None:
        limit = hex_str_num_bits
    else:
        assert limit <= hex_str_num_bits
    intervals: IntervalList = []
    intervals_append = intervals.append
    bit_count = 0
    nybble_loc = 0  # bit index of the first bit of the current nybble
    start_loc = 0  # first bit of the run of ones currently open (valid while prev_val == "1")
    prev_val = "0"
    loc = 0
    val = "0"
    for nybble_char in hex_str:
        if nybble_loc >= limit:
            # This nybble lies entirely beyond the limit.  Without this guard an "f" nybble here would open a
            # run and then be clamped to limit - 1, yielding an inverted interval such as [4, 3].
            # If a run is still open it ended exactly at limit - 1, which is where loc already points.
            break
        if nybble_char == "0":
            if prev_val == "1":
                # the open run ended on the last bit of the previous nybble
                intervals_append([start_loc, nybble_loc - 1])
                bit_count += nybble_loc - start_loc
                prev_val = "0"
        elif nybble_char == "f":
            if prev_val == "0":
                start_loc = nybble_loc
                prev_val = "1"
            loc = nybble_loc + 3
            val = "1"
        else:
            bin_str = bin(int(nybble_char, 16))[2:].zfill(4)
            for loc, val in enumerate(bin_str, nybble_loc):
                if loc >= limit:
                    val = prev_val
                    break
                if val == prev_val:
                    continue
                if val == "0":
                    intervals_append([start_loc, loc - 1])
                    bit_count += loc - start_loc
                else:
                    start_loc = loc
                prev_val = val
        if loc >= limit:
            # the limit fell inside this nybble; the last bit that counts is limit - 1
            loc = limit - 1
            break
        nybble_loc += 4
    if val == prev_val == "1":
        # close the run that is still open at the end of the string / at the limit
        intervals_append([start_loc, loc])
        bit_count += loc - start_loc + 1
    return bit_count, intervals


def _hex_string_to_bit_intervals_v3(
    hex_str: str, limit: Optional[int] = None, hex_block_size: Optional[int] = None
) -> Tuple[int, IntervalList]:
    """
    Given a hex string, a limit (number of bits) and an optional hex block size, return the intervals of set bits.

    The string is processed in blocks of ``hex_block_size`` hex digits (``default_hex_block_size`` when omitted).
    A block consisting only of "0" digits is recognised by string comparison and skipped without expanding it to
    bits; every other block is expanded to a binary string and scanned bit by bit.

    Bits are numbered from 0 at the most significant bit of the first hex digit.  Only bits whose index is below
    ``limit`` are considered; ``limit`` defaults to ``4 * len(hex_str)`` and must not exceed it.

    Returns a tuple ``(bit_count, intervals)``: ``intervals`` is a list of inclusive ``[start, end]`` runs of set
    bits in ascending order, and ``bit_count`` is the number of set bits they cover.

    Raises ValueError if ``hex_block_size`` is given and is smaller than 1.
    """
    hex_str_len = len(hex_str)
    hex_str_num_bits = hex_str_len * 4
    if limit is None:
        limit = hex_str_num_bits
    else:
        assert limit <= hex_str_num_bits
    if hex_block_size is None:
        # TODO: adjust block size based on the limit?
        hex_block_size = default_hex_block_size
    elif hex_block_size < 1:
        # a non-positive block size never advances the offsets below, i.e. the loop would not terminate
        raise ValueError(f"hex_block_size must be a positive integer, got {hex_block_size}")
    intervals: IntervalList = []
    intervals_append = intervals.append
    bit_count = 0
    block_bit_loc = 0  # bit index of the first bit of the current block
    start_loc = 0  # first bit of the run of ones currently open (valid while prev_val == "1")
    prev_val = "0"
    loc = 0
    val = "0"
    hex_block_zeros = "0" * hex_block_size
    hex_block_num_bits = hex_block_size * 4
    hex_block_start_offset = 0
    while hex_block_start_offset < hex_str_len:
        hex_block_end_offset = hex_block_start_offset + hex_block_size
        if hex_block_end_offset > hex_str_len:
            # short final block: shrink the block size and everything derived from it
            hex_block_end_offset = hex_str_len
            hex_block_size = hex_block_end_offset - hex_block_start_offset
            hex_block_num_bits = hex_block_size * 4
            hex_block_zeros = "0" * hex_block_size
        hex_block_str = hex_str[hex_block_start_offset:hex_block_end_offset]
        if hex_block_str == hex_block_zeros:
            if prev_val == "1":
                # the open run ended on the last bit of the previous block
                intervals_append([start_loc, block_bit_loc - 1])
                bit_count += block_bit_loc - start_loc
                prev_val = "0"
        else:
            bin_str = bin(int(hex_block_str, 16))[2:].zfill(hex_block_num_bits)
            for loc, val in enumerate(bin_str, block_bit_loc):
                if loc >= limit:
                    loc -= 1
                    val = prev_val
                    break
                if val == prev_val:
                    continue
                if val == "0":
                    intervals_append([start_loc, loc - 1])
                    bit_count += loc - start_loc
                else:
                    start_loc = loc
                prev_val = val
        block_bit_loc += hex_block_num_bits
        hex_block_start_offset += hex_block_size
        if block_bit_loc >= limit:
            break
    if val == prev_val == "1":
        # close the run that is still open at the end of the string / at the limit
        intervals_append([start_loc, loc])
        bit_count += loc - start_loc + 1
    return bit_count, intervals


def _hex_string_to_bit_intervals_v4(
    hex_str: str, limit: Optional[int] = None, hex_block_size: Optional[int] = None
) -> Tuple[int, IntervalList]:
    """
    Given a hex string, a limit (number of bits) and an optional hex block size, return the intervals of set bits.

    The string is processed in blocks of ``hex_block_size`` hex digits (``default_hex_block_size`` when omitted).
    Each block is parsed to an int; a block whose value is 0 is skipped without expanding it to bits, every other
    block is expanded to a binary string and scanned bit by bit.

    Bits are numbered from 0 at the most significant bit of the first hex digit.  Only bits whose index is below
    ``limit`` are considered; ``limit`` defaults to ``4 * len(hex_str)`` and must not exceed it.

    Returns a tuple ``(bit_count, intervals)``: ``intervals`` is a list of inclusive ``[start, end]`` runs of set
    bits in ascending order, and ``bit_count`` is the number of set bits they cover.

    Raises ValueError if ``hex_block_size`` is given and is smaller than 1.
    """
    hex_str_len = len(hex_str)
    hex_str_num_bits = hex_str_len * 4
    if limit is None:
        limit = hex_str_num_bits
    else:
        assert limit <= hex_str_num_bits
    if hex_block_size is None:
        # TODO: adjust block size based on the limit?
        hex_block_size = default_hex_block_size
    elif hex_block_size < 1:
        # a non-positive block size never advances the offsets below, i.e. the loop would not terminate
        raise ValueError(f"hex_block_size must be a positive integer, got {hex_block_size}")
    intervals: IntervalList = []
    intervals_append = intervals.append
    bit_count = 0
    block_bit_loc = 0  # bit index of the first bit of the current block
    start_loc = 0  # first bit of the run of ones currently open (valid while prev_val == "1")
    prev_val = "0"
    loc = 0
    val = "0"
    hex_block_num_bits = hex_block_size * 4
    hex_block_start_offset = 0
    while hex_block_start_offset < hex_str_len:
        hex_block_end_offset = hex_block_start_offset + hex_block_size
        if hex_block_end_offset > hex_str_len:
            # short final block: shrink the block size and the bit width derived from it
            hex_block_end_offset = hex_str_len
            hex_block_size = hex_block_end_offset - hex_block_start_offset
            hex_block_num_bits = hex_block_size * 4
        hex_block_int = int(hex_str[hex_block_start_offset:hex_block_end_offset], 16)
        if hex_block_int == 0:
            if prev_val == "1":
                # the open run ended on the last bit of the previous block
                intervals_append([start_loc, block_bit_loc - 1])
                bit_count += block_bit_loc - start_loc
                prev_val = "0"
        else:
            bin_str = bin(hex_block_int)[2:].zfill(hex_block_num_bits)
            for loc, val in enumerate(bin_str, block_bit_loc):
                if loc >= limit:
                    loc -= 1
                    val = prev_val
                    break
                if val == prev_val:
                    continue
                if val == "0":
                    intervals_append([start_loc, loc - 1])
                    bit_count += loc - start_loc
                else:
                    start_loc = loc
                prev_val = val
        block_bit_loc += hex_block_num_bits
        hex_block_start_offset += hex_block_size
        if block_bit_loc >= limit:
            break
    if val == prev_val == "1":
        # close the run that is still open at the end of the string / at the limit
        intervals_append([start_loc, loc])
        bit_count += loc - start_loc + 1
    return bit_count, intervals


def _hex_string_to_bit_intervals_v5(
    hex_str: str, limit: Optional[int] = None, hex_block_size: Optional[int] = None
) -> Tuple[int, IntervalList]:
    """
    Given a hex string, a limit (number of bits) and an optional hex block size, return the intervals of set bits.

    The string is processed in blocks of ``hex_block_size`` hex digits (``default_hex_block_size`` when omitted).
    Blocks consisting only of "0" digits or only of "f" digits are recognised by string comparison and handled
    without expanding them to bits; every other block is expanded to a binary string and scanned bit by bit.

    Bits are numbered from 0 at the most significant bit of the first hex digit.  Only bits whose index is below
    ``limit`` are considered; ``limit`` defaults to ``4 * len(hex_str)`` and must not exceed it.  A ``limit`` of 0
    (or less) yields ``(0, [])``.

    Returns a tuple ``(bit_count, intervals)``: ``intervals`` is a list of inclusive ``[start, end]`` runs of set
    bits in ascending order, and ``bit_count`` is the number of set bits they cover.

    Raises ValueError if ``hex_block_size`` is given and is smaller than 1.
    """
    hex_str_len = len(hex_str)
    hex_str_num_bits = hex_str_len * 4
    if limit is None:
        limit = hex_str_num_bits
    else:
        assert limit <= hex_str_num_bits
    if hex_block_size is None:
        # TODO: adjust block size based on the limit?
        hex_block_size = default_hex_block_size
    elif hex_block_size < 1:
        # a non-positive block size never advances the offsets below, i.e. the loop would not terminate
        raise ValueError(f"hex_block_size must be a positive integer, got {hex_block_size}")
    intervals: IntervalList = []
    intervals_append = intervals.append
    bit_count = 0
    block_bit_loc = 0  # bit index of the first bit of the current block
    start_loc = 0  # first bit of the run of ones currently open (valid while prev_val == "1")
    prev_val = "0"
    loc = 0
    val = "0"
    hex_block_0 = "0" * hex_block_size
    hex_block_f = "f" * hex_block_size
    hex_block_num_bits = hex_block_size * 4
    hex_block_start_offset = 0
    # A block is only entered while it starts below the limit.  Checking this up front (instead of only after
    # each block) keeps an all-"f" first block from opening a run when limit <= 0, which used to produce the
    # inverted interval [0, -1].
    while hex_block_start_offset < hex_str_len and block_bit_loc < limit:
        hex_block_end_offset = hex_block_start_offset + hex_block_size
        if hex_block_end_offset > hex_str_len:
            # short final block: shrink the block size and everything derived from it
            hex_block_end_offset = hex_str_len
            hex_block_size = hex_block_end_offset - hex_block_start_offset
            hex_block_num_bits = hex_block_size * 4
            hex_block_0 = "0" * hex_block_size
            hex_block_f = "f" * hex_block_size
        hex_block_str = hex_str[hex_block_start_offset:hex_block_end_offset]
        if hex_block_str == hex_block_0:
            if prev_val == "1":
                # the open run ended on the last bit of the previous block
                intervals_append([start_loc, block_bit_loc - 1])
                bit_count += block_bit_loc - start_loc
                prev_val = "0"
        elif hex_block_str == hex_block_f:
            if prev_val == "0":
                start_loc = block_bit_loc
                prev_val = "1"
            loc = block_bit_loc + hex_block_num_bits - 1
            val = "1"
            if loc >= limit:
                # the limit falls inside this all-ones block; the last bit that counts is limit - 1
                loc = limit - 1
                break
        else:
            bin_str = bin(int(hex_block_str, 16))[2:].zfill(hex_block_num_bits)
            for loc, val in enumerate(bin_str, block_bit_loc):
                if loc >= limit:
                    loc = limit - 1
                    val = prev_val
                    break
                if val == prev_val:
                    continue
                if val == "0":
                    intervals_append([start_loc, loc - 1])
                    bit_count += loc - start_loc
                else:
                    start_loc = loc
                prev_val = val
        block_bit_loc += hex_block_num_bits
        hex_block_start_offset += hex_block_size
    if val == prev_val == "1":
        # close the run that is still open at the end of the string / at the limit
        intervals_append([start_loc, loc])
        bit_count += loc - start_loc + 1
    return bit_count, intervals


def _intervals_to_hex_string_v0(length: int, intervals: IntervalList) -> str:
    """
    Render a bitmask of ``length`` bits, described by inclusive ``[start, end]`` intervals of set bits, as a
    lower-case hex string of ``(length + 3) // 4`` digits.

    Bit 0 is the most significant bit of the first hex digit.  When ``length`` is not a multiple of 4 the unused
    low bits of the last digit are zero, matching ``_intervals_to_hex_string_v1`` and the hex parsers in this
    module (e.g. 6 set bits encode as "fc", not "3f").  A ``length`` of 0 yields the empty string.
    """
    num_hex_digits = (length + 3) // 4
    if num_hex_digits == 0:
        return ""
    # Index of the least significant bit of the last hex digit.  Shifting relative to the padded width (rather
    # than to length - 1) is what left-aligns the mask within the final digit.
    last_bit_index = num_hex_digits * 4 - 1
    bits = 0
    for bit_start, bit_end in intervals:
        bits |= ((1 << (bit_end - bit_start + 1)) - 1) << (last_bit_index - bit_end)

    return hex(bits)[2:].zfill(num_hex_digits)


def _intervals_to_hex_string_v1(length: int, intervals: IntervalList) -> str:
    """
    Render a bitmask of ``length`` bits, described by inclusive ``[start, end]`` intervals of set bits, as a hex
    string of at least ``(length + 3) // 4`` digits taken from ``hex_map``.

    Bit 0 is the most significant bit of the first hex digit.  The digits are emitted left to right; a nybble
    that is only partly covered by the current interval is kept pending so that a following interval starting
    in the same nybble can be merged into it before the digit is written.
    """

    def run_mask(first, last):
        # ones for bits first..last, positioned inside the nybble that holds `last`
        return ((1 << (last - first + 1)) - 1) << (3 - last % 4)

    digits = []
    pending = 0  # value of a partly filled nybble not yet emitted; 0 means nothing is pending
    prev_nybble = 0  # nybble index in which the previous interval ended
    for first_bit, last_bit in intervals:
        first_nybble = first_bit // 4
        final_nybble = last_bit // 4

        if pending != 0:
            # Four bits can hold at most two separate runs, so at most one further interval can share the
            # pending nybble; fold its part of that nybble in and continue from the next nybble.
            if first_nybble == prev_nybble:
                fold_end = min((first_nybble + 1) * 4 - 1, last_bit)
                pending += run_mask(first_bit, fold_end)
                first_nybble += 1
                first_bit = first_nybble * 4
            digits.append(hex_map[pending])
            pending = 0

        # zero digits for the gap before this interval
        digits.extend("0" * (first_nybble - len(digits)))
        if final_nybble > first_nybble:
            # leading (possibly partial) nybble, then fully set nybbles up to the one holding last_bit
            digits.append(hex_map[(1 << (4 - first_bit % 4)) - 1])
            digits.extend("f" * (final_nybble - len(digits)))
            first_bit = final_nybble * 4

        if first_bit <= last_bit:
            pending = run_mask(first_bit, last_bit)

        prev_nybble = final_nybble

    if pending != 0:
        digits.append(hex_map[pending])

    return "".join(digits).ljust((length + 3) // 4, "0")


class _BitmaskSerializer:
    """
    a helper class for converting to and from bitmasks

    instead of having the function imported, allow it to cache

    The class is a singleton: the first instantiation calls ``build_cache()`` and stores the resulting lookup
    tables on the class; every later instantiation returns the same instance.
    """

    _instance = None
    base_dtype = __base_dtype__
    # lookup tables, replaced by the results of build_cache() on first instantiation
    lookup_hexmask_to_bitmask = {}
    lookup_binary_string_to_hexmask = {}
    lookup_hexmask_to_binary_string = {}

    def __new__(cls):
        """Return the shared instance, building the lookup tables on first use."""
        if cls._instance is None:
            saved_debug_flag = eb.__debug_flag__
            eb.__debug_flag__ = False
            try:
                hexmask_to_bitmask, binary_string_to_hexmask, hexmask_to_binary_string = build_cache()
            finally:
                # always restore the flag, even when build_cache() fails
                eb.__debug_flag__ = saved_debug_flag
            cls.lookup_hexmask_to_bitmask = hexmask_to_bitmask
            cls.lookup_binary_string_to_hexmask = binary_string_to_hexmask
            cls.lookup_hexmask_to_binary_string = hexmask_to_binary_string
            # publish the instance only once the tables exist, so a failed build is retried on the next call
            cls._instance = super().__new__(cls)
        return cls._instance

    def encode_bitmask_to_hex(self, bitmask: EnhancedBitmask) -> str:
        """Given a bitmask, convert it to a hex string (lower)
        was mask_encode_hex

        Accepts an ``np.ndarray`` / ``EnhancedBitmask`` (encoded through the cached lookup table, 16 elements
        at a time) or any ``EnhancedBitmaskLike`` (encoded from its list of set bit indices).

        Raises KeyError if an array chunk is not in the lookup table, and Exception for unsupported types.
        """
        if type(bitmask) in [np.ndarray, EnhancedBitmask]:
            assert bitmask.dtype == __base_dtype__, f"{bitmask.dtype=} != {__base_dtype__}"
            check_eb_overunderflow(bitmask, debug=EnhancedBitmask.__debug_flag__)
            grouping = 16
            lookup = self.lookup_binary_string_to_hexmask
            hex_parts = []
            bitmask = bitmask.view(np.ndarray)
            for offset in range(0, len(bitmask), grouping):
                binchunk = bitmask[offset : (offset + grouping)].tobytes()
                # Try the chunk as-is first.  A short final chunk (an odd size like 111111, fix for issue #7)
                # is then right-padded with zero bytes to the chunk lengths the cache holds: 8, 16 and 32.
                for pad_width in (0, 8, 16, 32):
                    key = binchunk.ljust(pad_width, b"\x00")
                    if key in lookup:
                        hex_parts.append(lookup[key])
                        break
                else:
                    raise KeyError(key)
            hex_string = "".join(hex_parts)
        elif isinstance(bitmask, EnhancedBitmaskLike):
            grouping = 16
            hex_parts = []
            # Walk the set-bit indices with an iterator: no IndexError when no bit is set, and no repeated
            # O(n) pop(0) on the list.
            set_bits = iter(bitmask.get_idx_lst_from_bitmask(bitmask))
            bit = next(set_bits, None)
            for offset in range(0, len(bitmask), grouping):
                byteint = 0
                extent = offset + grouping
                while bit is not None and offset <= bit < extent:
                    # bit `offset` is the most significant bit of the group
                    byteint += 1 << (grouping - (bit - offset) - 1)
                    bit = next(set_bits, None)
                # NOTE(review): "X" emits upper-case digits while the docstring and decode_hex_to_bitmask
                # refer to lower-case hex strings - confirm which is intended.
                hex_parts.append(f"{byteint:0>4X}")  # must be grouping / 4
            hex_string = "".join(hex_parts)
        else:
            raise Exception(f"unsupported bitmask: {type(bitmask)}")
        return hex_string

    def encode_bitmask_to_bin(self, bitmask: EnhancedBitmask) -> str:
        """Return the concatenation of ``str(bitmask[idx])`` for every index of the bitmask."""
        check_eb_overunderflow(bitmask, debug=EnhancedBitmask.__debug_flag__)
        return "".join(str(bitmask[idx]) for idx in range(len(bitmask)))

    def decode_hex_to_bitmask(self, hex_string: str, limit=None) -> EnhancedBitmask:
        """given a hex_string(lower case), convert it from a hex_string
        to a bitmask.

        Use limit to create a bitmask not divisible by 4

        Raises Exception if ``limit`` is None or if ``hex_string`` has fewer digits than ``limit`` bits require.
        """
        # Validate limit before doing arithmetic with it; otherwise a missing limit surfaces as a TypeError
        # from `limit // 4` instead of the explanatory message below.
        if limit is None:
            raise Exception("limit is required to get a correct sized bitmask, please send in limit")
        hex_string_length = len(hex_string)
        required_hex_string_length = (limit // 4) + (1 if limit % 4 else 0)
        if hex_string_length < required_hex_string_length:
            raise Exception(f"hex string is too short, is {hex_string_length} req:{required_hex_string_length}")
        lookup = self.lookup_hexmask_to_binary_string
        ungrouping = 4
        bitmask_lst = []
        bitmask_lst_append = bitmask_lst.append
        for offset in range(0, hex_string_length, ungrouping):
            hex_extent = offset + ungrouping
            bin_extent = hex_extent * 4
            chunk = hex_string[offset:hex_extent]
            if hex_extent > hex_string_length:
                # short final chunk
                bin_extent = hex_string_length * 4
            _bitmask = np.frombuffer(lookup[chunk], dtype=self.base_dtype)
            if limit and bin_extent > limit:
                # drop the elements that lie beyond the requested limit
                _bitmask = _bitmask[: -(bin_extent - limit)]
            bitmask_lst_append(_bitmask)
        bitmask = eb.concatenate(bitmask_lst)
        return bitmask

    def encode_litebitmask_to_hex(self, bitmask: LiteBitmask) -> str:
        """Encode a LiteBitmask (its ``length`` and ``intervals``) as a hex string."""
        return _intervals_to_hex_string(bitmask.length, bitmask.intervals)

    def decode_hex_to_litebitmask(self, hex_string: str, limit=None) -> LiteBitmask:
        """given a hex_string(lower case), convert it from a hex_string
        to a bitmask.

        Use limit to create a bitmask not divisible by 4
        """
        # NOTE(review): limit=None is forwarded unchanged to LiteBitmask.zeros_and_set_intervals - confirm that
        # it accepts None, or whether limit should default to len(hex_string) * 4 here.
        _, intervals = _hex_string_to_bit_intervals(hex_string, limit=limit)
        # todo: evaluate correctness of validate=True and merge=False.
        bitmask = LiteBitmask.zeros_and_set_intervals(limit, intervals, validate=True, merge=False)
        return bitmask

    def get_bitmask_sum(self, bitmask: EnhancedBitmask) -> int:
        """this gets the sum, not the count, be careful"""
        return int(eb.sum(bitmask))

    def get_bitmask_count(self, bitmask: EnhancedBitmask) -> int:
        """this gets the count of the bits set."""
        return bitmask.bit_count


def BitmaskSerializer():
    """Return the shared `_BitmaskSerializer` instance; it is created on the first call."""
    return _BitmaskSerializer()


# Implementations selected for use by `_BitmaskSerializer.decode_hex_to_litebitmask` and
# `_BitmaskSerializer.encode_litebitmask_to_hex` respectively.
_hex_string_to_bit_intervals = _hex_string_to_bit_intervals_v5
_intervals_to_hex_string = _intervals_to_hex_string_v1