# Copyright (C) 2024, UChicago Argonne, LLC
# Licensed under the 3-clause BSD license. See accompanying LICENSE.txt file
# in the top-level directory.
# cython: language_level=3, boundscheck=True

import numpy as np
from typing import Optional, Tuple

from .bitmask import EnhancedBitmask, EnhancedBitmaskLike
from .bitmask import build_cache, check_eb_overunderflow
from .bitmask import eb
from .bitmask_globals import __base_dtype__
from .bitmask_lite import LiteBitmask, IntervalList

# Number of hex digits processed per block by the block-based interval decoders (v3-v5).
default_hex_block_size = 256
# Lowercase hex digit for each nybble value 0..15.
hex_map = "0123456789abcdef"


def _hex_string_to_bit_intervals_v0(hex_str: str, limit: Optional[int] = None) -> Tuple[int, IntervalList]:
    """
    Given a hex string and limit (number of bits), return a set of intervals that will be used to create a
    LiteBitmaskSlots.

    Conventions shared by every ``_hex_string_to_bit_intervals_v*`` variant in this module:

    * bit 0 is the most significant bit of the first hex digit;
    * only bits ``0 .. limit - 1`` are considered (``limit`` defaults to ``4 * len(hex_str)`` and is asserted
      not to exceed it);
    * the result is ``(bit_count, intervals)``. Each interval is an inclusive ``[start, end]`` run of set bits.
      Intervals are in ascending order, and ``bit_count`` is the total number of set bits.

    This variant converts the whole string with ``bin(int(hex_str, 16))`` and walks it one bit at a time.
    Note: an empty ``hex_str`` raises ``ValueError`` (``int("", 16)``).
    """
    hex_str_len = len(hex_str)
    hex_str_num_bits = hex_str_len * 4
    if limit is None:
        limit = hex_str_num_bits
    else:
        assert limit <= hex_str_num_bits
    bin_str = bin(int(hex_str, 16))[2:]
    intervals: IntervalList = []
    intervals_append = intervals.append
    bit_count = 0
    # bin() drops leading zero bits, so the first character of bin_str sits at this bit position.
    bit_offset = hex_str_num_bits - len(bin_str)
    start_loc = 0
    prev_val = 0
    loc = 0
    val = 0
    for loc, val in enumerate((int(bit) for bit in bin_str), bit_offset):
        # `>=` rather than `==`: enumeration starts at `bit_offset`, which may already be past `limit`.
        # With `==`, those out-of-range bits would be reported as set.
        if loc >= limit:
            loc -= 1
            val = prev_val
            break
        if val == prev_val:
            continue
        if val == 0:
            # 1 -> 0 transition closes the current run.
            intervals_append([start_loc, loc - 1])
            bit_count += loc - start_loc
        else:
            # 0 -> 1 transition opens a new run.
            start_loc = loc
        prev_val = val
    # Close a run that extends to the last considered bit.
    if val == prev_val == 1:
        intervals_append([start_loc, loc])
        bit_count += loc - start_loc + 1
    return bit_count, intervals


def _hex_string_to_bit_intervals_v1(hex_str: str, limit: Optional[int] = None) -> Tuple[int, IntervalList]:
    """
    Given a hex string and limit (number of bits), return a set of intervals that will be used to create a
    LiteBitmaskSlots.

    Same conventions and algorithm as ``_hex_string_to_bit_intervals_v0``, but compares the characters of the
    binary string directly instead of converting each one to ``int``.
    Note: an empty ``hex_str`` raises ``ValueError`` (``int("", 16)``).
    """
    hex_str_len = len(hex_str)
    hex_str_num_bits = hex_str_len * 4
    if limit is None:
        limit = hex_str_num_bits
    else:
        assert limit <= hex_str_num_bits
    bin_str = bin(int(hex_str, 16))[2:]
    intervals: IntervalList = []
    intervals_append = intervals.append
    bit_count = 0
    # bin() drops leading zero bits, so the first character of bin_str sits at this bit position.
    bit_offset = hex_str_num_bits - len(bin_str)
    start_loc = 0
    prev_val = "0"
    loc = 0
    val = "0"
    for loc, val in enumerate(bin_str, bit_offset):
        # `>=` rather than `==`: enumeration starts at `bit_offset`, which may already be past `limit`.
        if loc >= limit:
            loc -= 1
            val = prev_val
            break
        if val == prev_val:
            continue
        if val == "0":
            intervals_append([start_loc, loc - 1])
            bit_count += loc - start_loc
        else:
            start_loc = loc
        prev_val = val
    if val == prev_val == "1":
        intervals_append([start_loc, loc])
        bit_count += loc - start_loc + 1
    return bit_count, intervals


def _hex_string_to_bit_intervals_v2(hex_str: str, limit: Optional[int] = None) -> Tuple[int, IntervalList]:
    """
    Given a hex string and limit (number of bits), return a set of intervals that will be used to create a
    LiteBitmaskSlots.

    Same conventions as ``_hex_string_to_bit_intervals_v0``. This variant walks one hex digit at a time.
    Lowercase ``"0"`` and ``"f"`` digits are handled without a per-bit scan; any other digit is expanded to
    4 bits.
    """
    hex_str_len = len(hex_str)
    hex_str_num_bits = hex_str_len * 4
    if limit is None:
        limit = hex_str_num_bits
    else:
        assert limit <= hex_str_num_bits
    intervals: IntervalList = []
    intervals_append = intervals.append
    bit_count = 0
    nybble_loc = 0
    start_loc = 0
    prev_val = "0"
    loc = 0
    val = "0"
    for nybble_char in hex_str:
        # Digits that start at or past `limit` must not open or close any interval.
        if nybble_loc >= limit:
            break
        if nybble_char == "0":
            if prev_val == "1":
                intervals_append([start_loc, nybble_loc - 1])
                bit_count += nybble_loc - start_loc
            prev_val = "0"
        elif nybble_char == "f":
            if prev_val == "0":
                start_loc = nybble_loc
            prev_val = "1"
            loc = nybble_loc + 3
            val = "1"
        else:
            bin_str = bin(int(nybble_char, 16))[2:].zfill(4)
            for loc, val in enumerate(bin_str, nybble_loc):
                if loc >= limit:
                    val = prev_val
                    break
                if val == prev_val:
                    continue
                if val == "0":
                    intervals_append([start_loc, loc - 1])
                    bit_count += loc - start_loc
                else:
                    start_loc = loc
                prev_val = val
        # This digit straddles `limit`: clamp the end of any open run and stop.
        if loc >= limit:
            loc = limit - 1
            break
        nybble_loc += 4
    if val == prev_val == "1":
        intervals_append([start_loc, loc])
        bit_count += loc - start_loc + 1
    return bit_count, intervals


def _hex_string_to_bit_intervals_v3(
    hex_str: str, limit: Optional[int] = None, hex_block_size: Optional[int] = None
) -> Tuple[int, IntervalList]:
    """
    Given a hex string, a limit (number of bits) and an optional hex block size, return a set of intervals that
    will be used to create a LiteBitmaskSlots.

    Same conventions as ``_hex_string_to_bit_intervals_v0``. This variant processes ``hex_block_size`` hex
    digits at a time (default ``default_hex_block_size``). Blocks consisting only of ``"0"`` skip the
    per-bit scan.

    :raises ValueError: if ``hex_block_size`` is smaller than 1 (it would otherwise never advance).
    """
    hex_str_len = len(hex_str)
    hex_str_num_bits = hex_str_len * 4
    if limit is None:
        limit = hex_str_num_bits
    else:
        assert limit <= hex_str_num_bits
    if hex_block_size is None:
        # TODO: adjust block size based on the limit?
        hex_block_size = default_hex_block_size
    if hex_block_size < 1:
        raise ValueError("hex_block_size must be a positive integer")
    intervals: IntervalList = []
    intervals_append = intervals.append
    bit_count = 0
    block_bit_loc = 0
    start_loc = 0
    prev_val = "0"
    loc = 0
    val = "0"
    hex_block_zeros = "0" * hex_block_size
    hex_block_num_bits = hex_block_size * 4
    hex_block_start_offset = 0
    while hex_block_start_offset < hex_str_len and block_bit_loc < limit:
        hex_block_end_offset = hex_block_start_offset + hex_block_size
        if hex_block_end_offset > hex_str_len:
            # Final, shorter block: recompute the per-block constants.
            hex_block_end_offset = hex_str_len
            hex_block_size = hex_block_end_offset - hex_block_start_offset
            hex_block_num_bits = hex_block_size * 4
            hex_block_zeros = "0" * hex_block_size
        hex_block_str = hex_str[hex_block_start_offset:hex_block_end_offset]
        if hex_block_str == hex_block_zeros:
            if prev_val == "1":
                intervals_append([start_loc, block_bit_loc - 1])
                bit_count += block_bit_loc - start_loc
            prev_val = "0"
        else:
            bin_str = bin(int(hex_block_str, 16))[2:].zfill(hex_block_num_bits)
            for loc, val in enumerate(bin_str, block_bit_loc):
                if loc >= limit:
                    loc -= 1
                    val = prev_val
                    break
                if val == prev_val:
                    continue
                if val == "0":
                    intervals_append([start_loc, loc - 1])
                    bit_count += loc - start_loc
                else:
                    start_loc = loc
                prev_val = val
        block_bit_loc += hex_block_num_bits
        hex_block_start_offset += hex_block_size
    if val == prev_val == "1":
        intervals_append([start_loc, loc])
        bit_count += loc - start_loc + 1
    return bit_count, intervals


def _hex_string_to_bit_intervals_v4(
    hex_str: str, limit: Optional[int] = None, hex_block_size: Optional[int] = None
) -> Tuple[int, IntervalList]:
    """
    Given a hex string, a limit (number of bits) and an optional hex block size, return a set of intervals that
    will be used to create a LiteBitmaskSlots.

    Same as ``_hex_string_to_bit_intervals_v3``, but detects all-zero blocks by parsing each block with
    ``int(..., 16)`` and comparing with 0 instead of comparing strings.

    :raises ValueError: if ``hex_block_size`` is smaller than 1 (it would otherwise never advance).
    """
    hex_str_len = len(hex_str)
    hex_str_num_bits = hex_str_len * 4
    if limit is None:
        limit = hex_str_num_bits
    else:
        assert limit <= hex_str_num_bits
    if hex_block_size is None:
        # TODO: adjust block size based on the limit?
        hex_block_size = default_hex_block_size
    if hex_block_size < 1:
        raise ValueError("hex_block_size must be a positive integer")
    intervals: IntervalList = []
    intervals_append = intervals.append
    bit_count = 0
    block_bit_loc = 0
    start_loc = 0
    prev_val = "0"
    loc = 0
    val = "0"
    hex_block_num_bits = hex_block_size * 4
    hex_block_start_offset = 0
    while hex_block_start_offset < hex_str_len and block_bit_loc < limit:
        hex_block_end_offset = hex_block_start_offset + hex_block_size
        if hex_block_end_offset > hex_str_len:
            # Final, shorter block: recompute the per-block constants.
            hex_block_end_offset = hex_str_len
            hex_block_size = hex_block_end_offset - hex_block_start_offset
            hex_block_num_bits = hex_block_size * 4
        hex_block_int = int(hex_str[hex_block_start_offset:hex_block_end_offset], 16)
        if hex_block_int == 0:
            if prev_val == "1":
                intervals_append([start_loc, block_bit_loc - 1])
                bit_count += block_bit_loc - start_loc
            prev_val = "0"
        else:
            bin_str = bin(hex_block_int)[2:].zfill(hex_block_num_bits)
            for loc, val in enumerate(bin_str, block_bit_loc):
                if loc >= limit:
                    loc -= 1
                    val = prev_val
                    break
                if val == prev_val:
                    continue
                if val == "0":
                    intervals_append([start_loc, loc - 1])
                    bit_count += loc - start_loc
                else:
                    start_loc = loc
                prev_val = val
        block_bit_loc += hex_block_num_bits
        hex_block_start_offset += hex_block_size
    if val == prev_val == "1":
        intervals_append([start_loc, loc])
        bit_count += loc - start_loc + 1
    return bit_count, intervals


def _hex_string_to_bit_intervals_v5(
    hex_str: str, limit: Optional[int] = None, hex_block_size: Optional[int] = None
) -> Tuple[int, IntervalList]:
    """
    Given a hex string, a limit (number of bits) and an optional hex block size, return a set of intervals that
    will be used to create a LiteBitmaskSlots.

    This is the variant bound to ``_hex_string_to_bit_intervals`` at the bottom of this module.
    Conventions:

    * bit 0 is the most significant bit of the first hex digit;
    * only bits ``0 .. limit - 1`` are considered (``limit`` defaults to ``4 * len(hex_str)`` and is asserted
      not to exceed it);
    * the result is ``(bit_count, intervals)``. Each interval is an inclusive ``[start, end]`` run of set bits,
      in ascending order. Runs that cross block boundaries are kept as one interval.

    The string is processed ``hex_block_size`` hex digits at a time (default ``default_hex_block_size``).
    Blocks that are entirely lowercase ``"0"`` or entirely lowercase ``"f"`` are handled without a per-bit
    scan. Every other block is expanded with ``bin()``.

    :raises ValueError: if ``hex_block_size`` is smaller than 1 (it would otherwise never advance).
    """
    hex_str_len = len(hex_str)
    hex_str_num_bits = hex_str_len * 4
    if limit is None:
        limit = hex_str_num_bits
    else:
        assert limit <= hex_str_num_bits
    if hex_block_size is None:
        # TODO: adjust block size based on the limit?
        hex_block_size = default_hex_block_size
    if hex_block_size < 1:
        raise ValueError("hex_block_size must be a positive integer")
    intervals: IntervalList = []
    intervals_append = intervals.append
    bit_count = 0
    block_bit_loc = 0  # bit position of the first bit of the current block
    start_loc = 0  # first bit of the currently open run (meaningful while prev_val == "1")
    prev_val = "0"  # value of the last bit examined
    loc = 0  # last bit position examined
    val = "0"
    hex_block_0 = "0" * hex_block_size
    hex_block_f = "f" * hex_block_size
    hex_block_num_bits = hex_block_size * 4
    hex_block_start_offset = 0
    # Checking `block_bit_loc < limit` before each block (not only after it) also covers limit == 0.
    # Previously an all-"f" first block produced a bogus [0, -1] interval in that case.
    while hex_block_start_offset < hex_str_len and block_bit_loc < limit:
        hex_block_end_offset = hex_block_start_offset + hex_block_size
        if hex_block_end_offset > hex_str_len:
            # Final, shorter block: recompute the per-block constants.
            hex_block_end_offset = hex_str_len
            hex_block_size = hex_block_end_offset - hex_block_start_offset
            hex_block_num_bits = hex_block_size * 4
            hex_block_0 = "0" * hex_block_size
            hex_block_f = "f" * hex_block_size
        hex_block_str = hex_str[hex_block_start_offset:hex_block_end_offset]
        if hex_block_str == hex_block_0:
            # All zeros: close any open run at the block boundary.
            if prev_val == "1":
                intervals_append([start_loc, block_bit_loc - 1])
                bit_count += block_bit_loc - start_loc
            prev_val = "0"
        elif hex_block_str == hex_block_f:
            # All ones: open a run here, or extend the open one through the whole block.
            if prev_val == "0":
                start_loc = block_bit_loc
            prev_val = "1"
            loc = block_bit_loc + hex_block_num_bits - 1
            val = "1"
            if loc >= limit:
                loc = limit - 1
                break
        else:
            bin_str = bin(int(hex_block_str, 16))[2:].zfill(hex_block_num_bits)
            for loc, val in enumerate(bin_str, block_bit_loc):
                if loc >= limit:
                    loc = limit - 1
                    val = prev_val
                    break
                if val == prev_val:
                    continue
                if val == "0":
                    intervals_append([start_loc, loc - 1])
                    bit_count += loc - start_loc
                else:
                    start_loc = loc
                prev_val = val
        block_bit_loc += hex_block_num_bits
        hex_block_start_offset += hex_block_size
    # Close a run that extends to the last considered bit. An all-zero block sets prev_val but not val, so
    # both must be "1" for a run to be open here.
    if val == prev_val == "1":
        intervals_append([start_loc, loc])
        bit_count += loc - start_loc + 1
    return bit_count, intervals


def _intervals_to_hex_string_v0(length: int, intervals: IntervalList) -> str:
    """
    Encode inclusive ``[start, end]`` bit intervals of a ``length``-bit mask as a lowercase hex string.

    Bit 0 is the most significant bit of the first hex digit. A trailing partial digit is padded on the
    right with zero bits, matching ``_intervals_to_hex_string_v1`` and the hex decoders in this module.
    The result has ``ceil(length / 4)`` digits.

    This variant builds one big Python int.
    """
    num_digits = (length + 3) // 4
    if num_digits == 0:
        return ""
    # Align bit 0 with the MSB of the *padded* width. Aligning with `length - 1` would shift the whole mask
    # right whenever length % 4 != 0.
    last_bit_pos = num_digits * 4 - 1
    bits = 0
    for bit_start, bit_end in intervals:
        # `|=` so overlapping intervals cannot carry into neighbouring bits.
        bits |= ((1 << (bit_end - bit_start + 1)) - 1) << (last_bit_pos - bit_end)
    return hex(bits)[2:].zfill(num_digits)


def _intervals_to_hex_string_v1(length: int, intervals: IntervalList) -> str:
    """
    Encode inclusive ``[start, end]`` bit intervals of a ``length``-bit mask as a lowercase hex string.

    Bit 0 is the most significant bit of the first hex digit. A trailing partial digit is padded on the
    right with zero bits. The result has at least ``ceil(length / 4)`` digits. If an interval reaches
    beyond ``length``, the string is extended to cover it.

    Each interval is OR-ed into a list of nybble values, filling whole digits with ``0xF`` via slice
    assignment. The work is linear in the output size. Building the string incrementally with repeated
    ``ljust`` calls would be quadratic in the number of intervals.
    """
    num_digits = (length + 3) // 4
    if intervals:
        num_digits = max(num_digits, max(bit_end for _, bit_end in intervals) // 4 + 1)
    nybbles = [0] * num_digits
    for bit_start, bit_end in intervals:
        first_digit = bit_start // 4
        last_digit = bit_end // 4
        lo = bit_start % 4  # position of bit_start within its digit (0 = MSB)
        hi = bit_end % 4  # position of bit_end within its digit (0 = MSB)
        if first_digit == last_digit:
            # Interval lies within one digit; two intervals may share a digit (e.g. 1010).
            nybbles[first_digit] |= ((1 << (hi - lo + 1)) - 1) << (3 - hi)
        else:
            # Tail of the first digit, full middle digits, head of the last digit.
            nybbles[first_digit] |= (1 << (4 - lo)) - 1
            nybbles[first_digit + 1 : last_digit] = [0xF] * (last_digit - first_digit - 1)
            nybbles[last_digit] |= ((1 << (hi + 1)) - 1) << (3 - hi)
    return "".join(hex_map[nybble] for nybble in nybbles)


class _BitmaskSerializer:
    """
    a helper class for converting to and from bitmasks
    instead of having the function imported, allow it to cache

    The class is a singleton: the first instantiation builds the lookup tables with ``build_cache()`` and
    stores them as class attributes. Later instantiations return the same instance.
    """

    _instance = None
    base_dtype = __base_dtype__
    # Lookup tables populated by build_cache() on first instantiation.
    lookup_hexmask_to_bitmask = {}
    lookup_binary_string_to_hexmask = {}
    lookup_hexmask_to_binary_string = {}

    def __new__(cls):
        """Return the shared instance, building the lookup caches on first use."""
        if cls._instance is None:
            instance = super(_BitmaskSerializer, cls).__new__(cls)
            # Build the cache with eb's debug flag disabled. Restore the flag even if build_cache() raises.
            debug_flag = eb.__debug_flag__
            eb.__debug_flag__ = False
            try:
                hexmask_to_bitmask, binary_string_to_hexmask, hexmask_to_binary_string = build_cache()
            finally:
                eb.__debug_flag__ = debug_flag
            cls.lookup_hexmask_to_bitmask = hexmask_to_bitmask
            cls.lookup_binary_string_to_hexmask = binary_string_to_hexmask
            cls.lookup_hexmask_to_binary_string = hexmask_to_binary_string
            # Publish the singleton only once its caches exist, so a failed build is retried next time.
            cls._instance = instance
        return cls._instance

    def encode_bitmask_to_hex(self, bitmask: EnhancedBitmask) -> str:
        """Given a bitmask, convert it to a hex string (lower)
        was mask_encode_hex

        Two input families are supported:

        * ``np.ndarray`` / ``EnhancedBitmask`` (exact type match): the array is encoded 16 elements at a time
          through the ``lookup_binary_string_to_hexmask`` cache.
        * ``EnhancedBitmaskLike``: the indices returned by ``get_idx_lst_from_bitmask`` are packed into 16-bit
          groups. Each group is emitted as four lowercase hex digits, with the group's first index as the
          most significant bit.

        :raises AssertionError: if an array input's dtype is not ``__base_dtype__``.
        :raises KeyError: if an array chunk is not in the cache, even after zero padding.
        :raises Exception: for any other input type.
        """
        if type(bitmask) in (np.ndarray, EnhancedBitmask):
            assert bitmask.dtype == __base_dtype__, f"{bitmask.dtype=} != {__base_dtype__}"
            check_eb_overunderflow(bitmask, debug=EnhancedBitmask.__debug_flag__)
            grouping = 16
            lookup = self.lookup_binary_string_to_hexmask
            hex_chunks = []
            bitmask = bitmask.view(np.ndarray)
            for offset in range(0, len(bitmask), grouping):
                binchunk = bitmask[offset : (offset + grouping)].tobytes()
                # A trailing chunk may be shorter than `grouping`, an odd size such as 111111 (fix for
                # issue #7). Try the exact key first. If that misses, right-pad with zero bytes to 8, then 16,
                # then 32 entries until a cached key matches, e.g. 111111 -> 11111100.
                for pad_width in (0, 8, 16, 32):
                    try:
                        hex_chunk = lookup[binchunk.ljust(pad_width, b"\x00")]
                    except KeyError:
                        continue
                    break
                else:
                    raise KeyError(binchunk)
                hex_chunks.append(hex_chunk)
            return "".join(hex_chunks)
        elif isinstance(bitmask, EnhancedBitmaskLike):
            grouping = 16
            hex_chunks = []
            # Consume the set-bit indices through an iterator. list.pop(0) was O(n) per call, mutated the
            # returned list, and raised IndexError when no bit was set.
            # NOTE(review): assumes the indices come back in ascending order, as the grouping loop requires.
            bit_iter = iter(bitmask.get_idx_lst_from_bitmask(bitmask))
            bit = next(bit_iter, None)
            for offset in range(0, len(bitmask), grouping):
                extent = offset + grouping
                group_int = 0
                while bit is not None and offset <= bit < extent:
                    # Index `offset` maps to the most significant bit of the group.
                    group_int |= 1 << (extent - bit - 1)
                    bit = next(bit_iter, None)
                # grouping / 4 digits; lowercase to honour the "(lower)" contract and the lowercase decoders.
                hex_chunks.append(f"{group_int:04x}")
            return "".join(hex_chunks)
        else:
            raise Exception(f"unsupported bitmask: {type(bitmask)}")

    def encode_bitmask_to_bin(self, bitmask: EnhancedBitmask) -> str:
        """Return the concatenation of ``str(bitmask[idx])`` for every index of ``bitmask``."""
        check_eb_overunderflow(bitmask, debug=EnhancedBitmask.__debug_flag__)
        # Index explicitly (rather than iterating) so any __getitem__ override on the bitmask type is honoured.
        return "".join(str(bitmask[idx]) for idx in range(len(bitmask)))

    def decode_hex_to_bitmask(self, hex_string: str, limit=None) -> EnhancedBitmask:
        """given a hex_string(lower case), convert it from a hex_string to a bitmask.
        Use limit to create a bitmask not divisible by 4

        :param hex_string: hex digits, looked up 4 at a time in ``lookup_hexmask_to_binary_string``.
        :param limit: number of bits in the result (required).
        :returns: the concatenation, via ``eb.concatenate``, of the decoded chunks truncated to ``limit`` bits.
        :raises Exception: if ``limit`` is None, or ``hex_string`` has fewer than ``ceil(limit / 4)`` digits.
        """
        # Validate `limit` before any arithmetic. Otherwise a missing limit surfaces as a TypeError from
        # `limit // 4` instead of this message.
        if limit is None:
            raise Exception("limit is required to get a correct sized bitmask, please send in limit")
        hex_string_length = len(hex_string)
        required_hex_string_length = (limit // 4) + (1 if limit % 4 else 0)
        if hex_string_length < required_hex_string_length:
            raise Exception(f"hex string is too short, is {hex_string_length} req:{required_hex_string_length}")
        lookup = self.lookup_hexmask_to_binary_string
        ungrouping = 4
        bitmask_lst = []
        for offset in range(0, hex_string_length, ungrouping):
            hex_extent = offset + ungrouping
            chunk = hex_string[offset:hex_extent]
            # Bit position just past this chunk; the last chunk may hold fewer than `ungrouping` digits.
            bin_extent = min(hex_extent, hex_string_length) * 4
            _bitmask = np.frombuffer(lookup[chunk], dtype=self.base_dtype)
            # Drop bits at or beyond `limit`. Compare explicitly so that limit=0 also truncates.
            if bin_extent > limit:
                _bitmask = _bitmask[: -(bin_extent - limit)]
            bitmask_lst.append(_bitmask)
        return eb.concatenate(bitmask_lst)

    def encode_litebitmask_to_hex(self, bitmask: LiteBitmask) -> str:
        """Encode a LiteBitmask's ``length`` and ``intervals`` as a lowercase hex string."""
        return _intervals_to_hex_string(bitmask.length, bitmask.intervals)

    def decode_hex_to_litebitmask(self, hex_string: str, limit=None) -> LiteBitmask:
        """given a hex_string(lower case), convert it from a hex_string to a bitmask.
        Use limit to create a bitmask not divisible by 4

        When ``limit`` is None, every hex digit is treated as significant (``4 * len(hex_string)`` bits), the
        same default the interval decoder applies.
        """
        if limit is None:
            # Give LiteBitmask a concrete length instead of passing None through.
            limit = len(hex_string) * 4
        _, intervals = _hex_string_to_bit_intervals(hex_string, limit=limit)
        # todo: evaluate correctness of validate=True and merge=False.
        return LiteBitmask.zeros_and_set_intervals(limit, intervals, validate=True, merge=False)

    def get_bitmask_sum(self, bitmask: EnhancedBitmask) -> int:
        """this gets the sum, not the count, be careful"""
        return int(eb.sum(bitmask))

    def get_bitmask_count(self, bitmask: EnhancedBitmask) -> int:
        """this gets the count of the bits set."""
        return bitmask.bit_count


def BitmaskSerializer():
    """Return the shared ``_BitmaskSerializer`` instance (its caches are built on first call)."""
    return _BitmaskSerializer()


# Implementations currently used by _BitmaskSerializer.
_hex_string_to_bit_intervals = _hex_string_to_bit_intervals_v5
_intervals_to_hex_string = _intervals_to_hex_string_v1