Skip to content
Snippets Groups Projects
bitmask_serial.py 7.79 KiB
Newer Older
# Copyright (C) 2024, UChicago Argonne, LLC
# Licensed under the 3-clause BSD license.  See accompanying LICENSE.txt file
# in the top-level directory.

# cython: language_level=3, boundscheck=True
import numpy as np

from .bitmask import EnhancedBitmask, EnhancedBitmaskLike
from .bitmask import build_cache, check_eb_overunderflow
from .bitmask import eb
from .bitmask_globals import __base_dtype__
from .bitmask_lite import LiteBitmask, LiteBitmaskSlots


def hex_string_to_bit_intervals(hexstr: str, limit=None) -> list:
    """given a hex string and limit, return a set of intervals that will be used
    to create a LiteBitmaskSlots.

    Each hex digit contributes 4 bits, most significant bit first, so bit 0 is
    the high bit of ``hexstr[0]``. The result is a list of inclusive
    ``(first, last)`` bit-index tuples, one per run of consecutive set bits, in
    ascending order. Bits at index ``>= limit`` are ignored.

    Brian Toonen provided this function and Eric Pershey worked to get it in place.
    New code incoming.

    :param hexstr: hex digits (any case); assumed to carry no ``0x`` prefix,
        since its length is used to place the bits.
    :param limit: number of leading bits to consider; defaults to
        ``len(hexstr) * 4``.
    :returns: list of ``(start, end)`` inclusive intervals of set bits.
    :raises ValueError: if ``hexstr`` is empty or not valid hex.
    """
    len_hexstr = len(hexstr)
    if limit is None:
        limit = len_hexstr * 4
    # bin() drops leading zeros; begin_offset restores their bit positions.
    binstr = bin(int(hexstr, 16))[2:]
    begin_offset = len_hexstr * 4 - len(binstr)
    intervals = []
    intervals_append = intervals.append
    start_loc = 0
    cur_val = '0'
    val: str
    for loc, val in enumerate(binstr, begin_offset):
        # ">=" rather than "==": begin_offset may already lie past limit when
        # every set bit is beyond it, in which case "==" would never trigger.
        if loc >= limit:
            val = cur_val
            loc = loc - 1
            break
        if val == cur_val:
            continue
        if val == '0':
            # a run of 1s ended at the previous position
            intervals_append((start_loc, loc - 1))
        else:
            start_loc = loc
        cur_val = val
    # close a run of 1s still open at the end of the string (or at limit)
    if val == cur_val == '1':
        intervals_append((start_loc, loc))
    return intervals


class _BitmaskSerializer:
    """
    a helper class for converting to and from bitmasks

    instead of having the function imported, allow it to cache

    Instantiating the class returns a single shared instance. The first call
    runs ``build_cache()`` with ``eb.__debug_flag__`` temporarily disabled and
    stores its three lookup tables on the class. Later calls return the same
    instance.
    """

    _instance = None
    base_dtype = __base_dtype__
    # Lookup tables filled in by build_cache() on first instantiation; they
    # stay empty until then.
    lookup_hexmask_to_bitmask = {}
    lookup_binary_string_to_hexmask = {}  # chunk.tobytes() -> hex text
    lookup_hexmask_to_binary_string = {}  # hex text -> bytes for np.frombuffer

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            saved_debug_flag = eb.__debug_flag__
            eb.__debug_flag__ = False
            try:
                caches = build_cache()
            except BaseException:
                # don't leave behind a singleton whose caches were never filled
                cls._instance = None
                raise
            finally:
                # always restore the caller's debug flag, even on failure
                eb.__debug_flag__ = saved_debug_flag
            (
                cls.lookup_hexmask_to_bitmask,
                cls.lookup_binary_string_to_hexmask,
                cls.lookup_hexmask_to_binary_string,
            ) = caches
        return cls._instance

    def encode_bitmask_to_hex(self, bitmask: EnhancedBitmask) -> str:
        """Given a bitmask, convert it to a hex string (lower)
        was mask_encode_hex

        Arrays whose exact type is ``np.ndarray`` or ``EnhancedBitmask`` are
        encoded 16 elements at a time through ``lookup_binary_string_to_hexmask``.
        ``EnhancedBitmaskLike`` objects are encoded from their set-bit indices,
        with 4 hex digits per 16 bits.

        :raises AssertionError: array input whose dtype is not ``__base_dtype__``.
        :raises KeyError: an array chunk is missing from the cache, even after
            zero padding.
        :raises Exception: any other bitmask type.
        """
        if type(bitmask) in (np.ndarray, EnhancedBitmask):
            return self._encode_array_to_hex(bitmask)
        if isinstance(bitmask, EnhancedBitmaskLike):
            return self._encode_bitmask_like_to_hex(bitmask)
        raise Exception(f"unsupported bitmask: {type(bitmask)}")

    def _encode_array_to_hex(self, bitmask) -> str:
        """Encode an ndarray/EnhancedBitmask by caching 16-element chunks."""
        assert bitmask.dtype == __base_dtype__, f"{bitmask.dtype=} != {__base_dtype__}"
        check_eb_overunderflow(bitmask, debug=EnhancedBitmask.__debug_flag__)
        grouping = 16
        array = bitmask.view(np.ndarray)
        lookup = self.lookup_binary_string_to_hexmask
        hex_parts = []
        for offset in range(0, len(array), grouping):
            binchunk = array[offset : (offset + grouping)].tobytes()
            try:
                hex_parts.append(lookup[binchunk])
            except KeyError:
                hex_parts.append(self._lookup_padded_chunk(binchunk))
        return "".join(hex_parts)

    def _lookup_padded_chunk(self, binchunk: bytes) -> str:
        """Look up a chunk that missed the cache by right-padding it with zero bytes.

        This handles odd-sized trailing chunks (fix for issue #7). The cache
        holds 8- and 16-bit chunks. Widths 8, 16 and then 32 are tried, and the
        KeyError from the last attempt propagates if none of them match.
        """
        lookup = self.lookup_binary_string_to_hexmask
        for width in (8, 16):
            try:
                return lookup[binchunk.ljust(width, b"\x00")]
            except KeyError:
                pass
        return lookup[binchunk.ljust(32, b"\x00")]

    def _encode_bitmask_like_to_hex(self, bitmask) -> str:
        """Encode an EnhancedBitmaskLike from its list of set-bit indices.

        Assumes the indices are ascending; bit ``offset`` maps to the most
        significant bit of its 16-bit group.
        """
        grouping = 16
        hex_width = grouping // 4  # 4 bits per hex digit
        # Walk the indices with an iterator instead of pop(0). This avoids
        # O(n^2) work, tolerates a mask with no bits set, and leaves the
        # returned list unmodified.
        bits = iter(bitmask.get_idx_lst_from_bitmask(bitmask))
        bit = next(bits, None)
        hex_parts = []
        for offset in range(0, len(bitmask), grouping):
            extent = offset + grouping
            group_value = 0
            while bit is not None and offset <= bit < extent:
                group_value |= 1 << (grouping - (bit - offset) - 1)
                bit = next(bits, None)
            # NOTE(review): uppercase digits here, while the docstring promises
            # lower case; confirm which case the decoders' cache expects.
            hex_parts.append(f"{group_value:0>{hex_width}X}")
        return "".join(hex_parts)

    def encode_bitmask_to_bin(self, bitmask: EnhancedBitmask) -> str:
        """Return the bitmask as a string made of ``str(element)`` for each position."""
        check_eb_overunderflow(bitmask, debug=EnhancedBitmask.__debug_flag__)
        # Indexed access is kept so element conversion matches bitmask[idx];
        # join avoids quadratic string concatenation.
        return "".join(str(bitmask[idx]) for idx in range(len(bitmask)))

    def decode_hex_to_bitmask(self, hex_string: str, limit=None) -> EnhancedBitmask:
        """given a hex_string(lower case), convert it from a hex_string
        to a bitmask.

        Use limit to create a bitmask not divisible by 4

        :param hex_string: hex text, decoded in chunks of up to 4 digits via
            ``lookup_hexmask_to_binary_string``.
        :param limit: required number of bits in the result; bits past it are
            trimmed off.
        :raises Exception: if ``limit`` is None, or ``hex_string`` has fewer
            than ``ceil(limit / 4)`` digits.
        :raises KeyError: if a chunk is not present in the cache.
        """
        # Validate limit before doing any arithmetic with it.
        if limit is None:
            raise Exception("limit is required to get a correct sized bitmask, please send in limit")
        hex_string_length = len(hex_string)
        required_hex_string_length = ((limit // 4) + (1 if limit % 4 else 0))
        if hex_string_length < required_hex_string_length:
            raise Exception(f"hex string is too short, is {hex_string_length} req:{required_hex_string_length}")
        lookup = self.lookup_hexmask_to_binary_string
        ungrouping = 4
        bitmask_lst = []
        for offset in range(0, hex_string_length, ungrouping):
            # the last chunk may be shorter than `ungrouping` digits
            hex_extent = min(offset + ungrouping, hex_string_length)
            bin_extent = hex_extent * 4
            _bitmask = np.frombuffer(lookup[hex_string[offset:hex_extent]], dtype=self.base_dtype)
            if bin_extent > limit:
                # drop the trailing bits of this chunk that fall past limit
                _bitmask = _bitmask[: -(bin_extent - limit)]
            bitmask_lst.append(_bitmask)
        return eb.concatenate(bitmask_lst)

    def decode_hex_to_litebitmask(self, hex_string: str, limit=None) -> LiteBitmask:
        """given a hex_string(lower case), convert it from a hex_string
        to a bitmask.

        Use limit to create a bitmask not divisible by 4. When ``limit`` is
        None it defaults to ``len(hex_string) * 4``, the same default that
        ``hex_string_to_bit_intervals`` uses, so the mask size matches the
        intervals.
        """
        if limit is None:
            limit = len(hex_string) * 4
        intervals = hex_string_to_bit_intervals(hex_string, limit=limit)
        # todo: evaluate correctness of validate=True and merge=False.
        return LiteBitmask.zeros_and_set_intervals(limit, intervals, validate=True, merge=False)

    def get_bitmask_sum(self, bitmask: EnhancedBitmask) -> int:
        """this gets the sum, not the count, be careful"""
        return int(eb.sum(bitmask))

    def get_bitmask_count(self, bitmask: EnhancedBitmask) -> int:
        """this gets the count of the bits set."""
        return bitmask.bit_count


def BitmaskSerializer():
    """Return the shared ``_BitmaskSerializer`` instance, building its caches on first use."""
    return _BitmaskSerializer()