# Copyright (C) 2024, UChicago Argonne, LLC
# Licensed under the 3-clause BSD license.  See accompanying LICENSE.txt file
# in the top-level directory.
# cython: language_level=3, boundscheck=True

import numpy as np

from .bitmask import EnhancedBitmask, EnhancedBitmaskLike
from .bitmask import build_cache, check_eb_overunderflow
from .bitmask import eb
from .bitmask_globals import __base_dtype__


class _BitmaskSerializer:
    """Singleton helper for converting bitmasks to and from strings.

    This is a class rather than a set of imported functions so that the
    lookup tables returned by ``build_cache()`` are built once, on first
    instantiation, and then shared by every caller.
    """

    _instance = None
    base_dtype = __base_dtype__
    # Lookup tables; replaced by the results of build_cache() the first time
    # the class is instantiated.
    lookup_hexmask_to_bitmask = {}
    lookup_binary_string_to_hexmask = {}
    lookup_hexmask_to_binary_string = {}

    def __new__(cls):
        """Return the shared instance, building the lookup caches on first use.

        ``eb.__debug_flag__`` is forced to ``False`` while ``build_cache()``
        runs and is always restored afterwards, even if the build fails.

        Raises:
            Whatever ``build_cache()`` raises. In that case the singleton is
            discarded so that a later call retries the build instead of
            handing out an instance with empty lookup tables.
        """
        if cls._instance is None:
            # Published before build_cache() runs, exactly as before, so the
            # instance already exists while the caches are being built.
            cls._instance = super(_BitmaskSerializer, cls).__new__(cls)
            saved_debug_flag = eb.__debug_flag__
            eb.__debug_flag__ = False
            try:
                (
                    hexmask_to_bitmask,
                    binary_string_to_hexmask,
                    hexmask_to_binary_string,
                ) = build_cache()
            except BaseException:
                # Do not keep a half-initialised singleton around.
                cls._instance = None
                raise
            finally:
                eb.__debug_flag__ = saved_debug_flag
            cls.lookup_hexmask_to_bitmask = hexmask_to_bitmask
            cls.lookup_binary_string_to_hexmask = binary_string_to_hexmask
            cls.lookup_hexmask_to_binary_string = hexmask_to_binary_string
        return cls._instance

    def _hex_for_chunk(self, raw: bytes) -> str:
        """Look up the hex text cached for one chunk's raw bytes.

        The bytes are tried as-is first. A trailing chunk can be shorter than
        a full group (e.g. 111111 -> 00111111 -> 3f, fix for issue #7), so on
        a miss the bytes are right-padded with NUL bytes to 8, then 16, then
        32 bytes and looked up again.

        Raises:
            KeyError: if none of the candidates is in the cache (the error
                from the last, 32-byte, lookup is the one raised).
        """
        lookup = self.lookup_binary_string_to_hexmask
        last_miss = None
        # ljust(0) returns the bytes unchanged, i.e. the unpadded attempt.
        for width in (0, 8, 16, 32):
            try:
                return lookup[raw.ljust(width, b"\x00")]
            except KeyError as miss:
                last_miss = miss
        raise last_miss

    def encode_bitmask_to_hex(self, bitmask: EnhancedBitmask) -> str:
        """Given a bitmask, convert it to a hex string (was mask_encode_hex).

        Args:
            bitmask: either an ``np.ndarray`` / ``EnhancedBitmask`` whose
                dtype is ``__base_dtype__``, or an ``EnhancedBitmaskLike``.

        Returns:
            The hex string, built from one piece per group of 16 elements.

        Raises:
            AssertionError: array input whose dtype is not ``__base_dtype__``.
            KeyError: an array chunk has no entry in the cache, even padded.
            Exception: ``bitmask`` is of an unsupported type.
        """
        grouping = 16
        # Exact-type match (not isinstance) is kept on purpose: any other
        # subclass falls through to the EnhancedBitmaskLike branch below.
        if type(bitmask) in [np.ndarray, EnhancedBitmask]:
            assert bitmask.dtype == __base_dtype__, f"{bitmask.dtype=} != {__base_dtype__}"
            check_eb_overunderflow(bitmask, debug=EnhancedBitmask.__debug_flag__)
            raw_bits = bitmask.view(np.ndarray)
            pieces = [
                self._hex_for_chunk(raw_bits[offset : (offset + grouping)].tobytes())
                for offset in range(0, len(raw_bits), grouping)
            ]
        elif isinstance(bitmask, EnhancedBitmaskLike):
            # Walk the set-bit indices with an iterator: this copes with a
            # mask that has no bits set (the old pop(0) raised IndexError)
            # and leaves the returned list untouched.
            # NOTE(review): the grouping logic assumes the indices come back
            # in ascending order - confirm against get_idx_lst_from_bitmask.
            remaining = iter(bitmask.get_idx_lst_from_bitmask(bitmask))
            bit = next(remaining, None)
            pieces = []
            for offset in range(0, len(bitmask), grouping):
                extent = offset + grouping
                word = 0
                while bit is not None and offset <= bit < extent:
                    # The first bit of the group is the most significant one.
                    word += 1 << (grouping - (bit - offset) - 1)
                    bit = next(remaining, None)
                # Width must be grouping / 4 hex digits.
                # NOTE(review): this branch emits upper-case digits - confirm
                # that matches the case used by the cached lookup tables.
                pieces.append(f"{word:0>4X}")
        else:
            raise Exception(f"unsupported bitmask: {type(bitmask)}")
        return "".join(pieces)

    def encode_bitmask_to_bin(self, bitmask: EnhancedBitmask) -> str:
        """Return the bitmask as a string, one ``str(element)`` per position.

        ``check_eb_overunderflow`` is run on the bitmask first.
        """
        check_eb_overunderflow(bitmask, debug=EnhancedBitmask.__debug_flag__)
        return "".join(str(bitmask[idx]) for idx in range(len(bitmask)))

    def decode_hex_to_bitmask(self, hex_string: str, limit=None) -> EnhancedBitmask:
        """Given a hex_string (lower case), convert it back to a bitmask.

        The string is consumed four hex digits at a time; each chunk is
        looked up in ``lookup_hexmask_to_binary_string`` and reinterpreted as
        an array of ``base_dtype``. The pieces are joined with
        ``eb.concatenate``.

        Args:
            hex_string: the hex text to decode.
            limit: required; the size of the bitmask to produce. Use it to
                create a bitmask whose length is not divisible by 4. Elements
                past ``limit`` are trimmed from the end of a chunk.

        Raises:
            Exception: ``limit`` was not supplied.
            KeyError: a chunk of ``hex_string`` is not in the cache.
        """
        if limit is None:
            raise Exception("limit is required to get a correct sized bitmask, please send in limit")

        lookup = self.lookup_hexmask_to_binary_string
        ungrouping = 4
        hex_string_length = len(hex_string)
        pieces = []
        for offset in range(0, hex_string_length, ungrouping):
            hex_extent = offset + ungrouping
            chunk = hex_string[offset:hex_extent]
            # Bits covered so far at 4 bits per hex digit, clamped to the
            # real length when the last chunk is short.
            bin_extent = min(hex_extent, hex_string_length) * 4
            piece = np.frombuffer(lookup[chunk], dtype=self.base_dtype)
            if limit and bin_extent > limit:
                piece = piece[: -(bin_extent - limit)]
            pieces.append(piece)
        return eb.concatenate(pieces)

    def get_bitmask_sum(self, bitmask: EnhancedBitmask) -> int:
        """Return ``eb.sum(bitmask)`` as an int: the sum, not the count, be careful."""
        return int(eb.sum(bitmask))

    def get_bitmask_count(self, bitmask: EnhancedBitmask) -> int:
        """Return the count of the bits set, read from ``bitmask.bit_count``."""
        return bitmask.bit_count


def BitmaskSerializer():
    """Return the shared ``_BitmaskSerializer`` instance."""
    return _BitmaskSerializer()