# Copyright (C) 2024, UChicago Argonne, LLC
# Licensed under the 3-clause BSD license. See accompanying LICENSE.txt file
# in the top-level directory.
# cython: language_level=3, boundscheck=True
import numpy as np

from typing import Optional, Tuple

from .bitmask import EnhancedBitmask, EnhancedBitmaskLike
from .bitmask import build_cache, check_eb_overunderflow
from .bitmask import eb
from .bitmask_globals import __base_dtype__

from .bitmask_lite import LiteBitmask, IntervalList

# Number of hex digits consumed per block by the block-based hex_string_to_bit_intervals variants
# when the caller does not pass hex_block_size.
default_hex_block_size = 256
# Lower-case hex digit for each nybble value 0-15 (indexed by the integer value).
hex_map = "0123456789abcdef"

def hex_string_to_bit_intervals0(hex_str: str, limit: Optional[int] = None) -> Tuple[int, IntervalList]:
    """
    Given a hex string and limit (number of bits), return the count of set bits and the list of
    inclusive [first, last] intervals of set bits that will be used to create a LiteBitmaskSlots.

    Bit 0 is the most significant bit of the first hex digit.  Only the first `limit` bits are
    considered; `limit` defaults to all 4 * len(hex_str) bits and must not exceed that value.
    An empty hex string yields (0, []).

    This variant converts the whole string to one integer and walks its binary digits as ints.
    """
    hex_str_num_bits = len(hex_str) * 4
    if limit is None:
        limit = hex_str_num_bits
    else:
        assert limit <= hex_str_num_bits
    intervals: IntervalList = []
    if not hex_str:
        # int("", 16) would raise; an empty string simply has no bits set.
        return 0, intervals
    bin_str = bin(int(hex_str, 16))[2:]
    intervals_append = intervals.append
    bit_count = 0
    # bin() drops leading zeros, so the first character of bin_str sits at this bit position.
    bit_offset = hex_str_num_bits - len(bin_str)
    start_loc = 0
    prev_val = 0
    loc = 0
    val = 0
    for loc, val in enumerate((int(bit) for bit in bin_str), bit_offset):
        if loc >= limit:
            # ">=" rather than "==": when the leading-zero offset already exceeds the limit the
            # first position seen is past the limit and must not open an interval.
            loc = limit - 1
            val = prev_val
            break
        if val == prev_val:
            continue
        if val == 0:
            # A run of ones just ended at the previous position.
            intervals_append([start_loc, loc - 1])
            bit_count += loc - start_loc
        else:
            start_loc = loc
        prev_val = val
    if val == prev_val == 1:
        # The final run of ones extends to the last position examined.
        intervals_append([start_loc, loc])
        bit_count += loc - start_loc + 1
    return bit_count, intervals
def hex_string_to_bit_intervals1(hex_str: str, limit: Optional[int] = None) -> Tuple[int, IntervalList]:
    """
    Given a hex string and limit (number of bits), return the count of set bits and the list of
    inclusive [first, last] intervals of set bits that will be used to create a LiteBitmaskSlots.

    Bit 0 is the most significant bit of the first hex digit.  Only the first `limit` bits are
    considered; `limit` defaults to all 4 * len(hex_str) bits and must not exceed that value.
    An empty hex string yields (0, []).

    This variant converts the whole string to one integer and compares the binary digits as
    characters ("0"/"1") instead of converting each one to int.
    """
    hex_str_num_bits = len(hex_str) * 4
    if limit is None:
        limit = hex_str_num_bits
    else:
        assert limit <= hex_str_num_bits
    intervals: IntervalList = []
    if not hex_str:
        # int("", 16) would raise; an empty string simply has no bits set.
        return 0, intervals
    bin_str = bin(int(hex_str, 16))[2:]
    intervals_append = intervals.append
    bit_count = 0
    # bin() drops leading zeros, so the first character of bin_str sits at this bit position.
    bit_offset = hex_str_num_bits - len(bin_str)
    start_loc = 0
    prev_val = "0"
    loc = 0
    val = "0"
    for loc, val in enumerate(bin_str, bit_offset):
        if loc >= limit:
            # Positions at or past the limit are ignored; keep the state of the last valid bit.
            loc = limit - 1
            val = prev_val
            break
        if val == prev_val:
            continue
        if val == "0":
            # A run of ones just ended at the previous position.
            intervals_append([start_loc, loc - 1])
            bit_count += loc - start_loc
        else:
            start_loc = loc
        prev_val = val
    if val == prev_val == "1":
        # The final run of ones extends to the last position examined.
        intervals_append([start_loc, loc])
        bit_count += loc - start_loc + 1
    return bit_count, intervals
def hex_string_to_bit_intervals2(hex_str: str, limit: Optional[int] = None) -> Tuple[int, IntervalList]:
    """
    Given a hex string and limit (number of bits), return the count of set bits and the list of
    inclusive [first, last] intervals of set bits that will be used to create a LiteBitmaskSlots.

    Bit 0 is the most significant bit of the first hex digit.  Only the first `limit` bits are
    considered; `limit` defaults to all 4 * len(hex_str) bits and must not exceed that value.

    This variant walks the string one hex digit (nybble) at a time, with fast paths for the
    all-zero digit "0" and the all-one digit "f"; any other digit is expanded to four bits.
    """
    hex_str_num_bits = len(hex_str) * 4
    if limit is None:
        limit = hex_str_num_bits
    else:
        assert limit <= hex_str_num_bits
    intervals: IntervalList = []
    intervals_append = intervals.append
    bit_count = 0
    nybble_loc = 0  # bit position of the first bit of the current hex digit
    start_loc = 0
    prev_val = "0"
    loc = 0
    val = "0"
    for nybble_char in hex_str:
        if nybble_loc >= limit:
            # Every bit of this and all later digits lies past the limit.  Without this guard an
            # "f" digit here would open an interval that starts after the clamped end position.
            break
        if nybble_char == "0":
            if prev_val == "1":
                # The current run of ones ended on the last bit of the previous digit.
                intervals_append([start_loc, nybble_loc - 1])
                bit_count += nybble_loc - start_loc
                prev_val = "0"
        elif nybble_char == "f":
            if prev_val == "0":
                start_loc = nybble_loc
                prev_val = "1"
            loc = nybble_loc + 3
            val = "1"
        else:
            bin_str = bin(int(nybble_char, 16))[2:].zfill(4)
            for loc, val in enumerate(bin_str, nybble_loc):
                if loc >= limit:
                    val = prev_val
                    break
                if val == prev_val:
                    continue
                if val == "0":
                    intervals_append([start_loc, loc - 1])
                    bit_count += loc - start_loc
                else:
                    start_loc = loc
                prev_val = val
        if loc >= limit:
            # The limit falls inside this digit: clamp to the last valid bit and stop.
            loc = limit - 1
            break
        nybble_loc += 4
    if val == prev_val == "1":
        # The final run of ones extends to the last position examined.
        intervals_append([start_loc, loc])
        bit_count += loc - start_loc + 1
    return bit_count, intervals
def hex_string_to_bit_intervals3(
    hex_str: str,
    limit: Optional[int] = None,
    hex_block_size: Optional[int] = None,
) -> Tuple[int, IntervalList]:
    """
    Given a hex string, a limit (number of bits) and an optional hex block size, return the count
    of set bits and the list of inclusive [first, last] intervals of set bits that will be used to
    create a LiteBitmaskSlots.

    Bit 0 is the most significant bit of the first hex digit.  Only the first `limit` bits are
    considered; `limit` defaults to all 4 * len(hex_str) bits and must not exceed that value.

    The string is processed in blocks of `hex_block_size` hex digits (default
    default_hex_block_size).  A block consisting only of "0" digits is skipped by string
    comparison; any other block is expanded to binary and scanned bit by bit.

    Raises ValueError if hex_block_size is not positive (the block loop could not advance).
    """
    hex_str_len = len(hex_str)
    hex_str_num_bits = hex_str_len * 4
    if limit is None:
        limit = hex_str_num_bits
    else:
        assert limit <= hex_str_num_bits
    if hex_block_size is None:
        # TODO: adjust block size based on the limit?
        hex_block_size = default_hex_block_size
    if hex_block_size < 1:
        raise ValueError(f"hex_block_size must be a positive integer, got {hex_block_size}")
    intervals: IntervalList = []
    intervals_append = intervals.append
    bit_count = 0
    block_bit_loc = 0  # bit position of the first bit of the current block
    start_loc = 0
    prev_val = "0"
    loc = 0
    val = "0"
    hex_block_zeros = "0" * hex_block_size
    hex_block_num_bits = hex_block_size * 4
    hex_block_start_offset = 0
    # Stop as soon as the next block would start at or past the limit.
    while hex_block_start_offset < hex_str_len and block_bit_loc < limit:
        hex_block_end_offset = hex_block_start_offset + hex_block_size
        if hex_block_end_offset > hex_str_len:
            # Final, shorter block: shrink the per-block sizes to what is left.
            hex_block_end_offset = hex_str_len
            hex_block_size = hex_block_end_offset - hex_block_start_offset
            hex_block_num_bits = hex_block_size * 4
            hex_block_zeros = "0" * hex_block_size
        hex_block_str = hex_str[hex_block_start_offset:hex_block_end_offset]
        if hex_block_str == hex_block_zeros:
            if prev_val == "1":
                # The current run of ones ended on the last bit of the previous block.
                intervals_append([start_loc, block_bit_loc - 1])
                bit_count += block_bit_loc - start_loc
                prev_val = "0"
        else:
            bin_str = bin(int(hex_block_str, 16))[2:].zfill(hex_block_num_bits)
            for loc, val in enumerate(bin_str, block_bit_loc):
                if loc >= limit:
                    # Clamp to the last valid bit and keep its state for the final check.
                    loc = limit - 1
                    val = prev_val
                    break
                if val == prev_val:
                    continue
                if val == "0":
                    intervals_append([start_loc, loc - 1])
                    bit_count += loc - start_loc
                else:
                    start_loc = loc
                prev_val = val
        block_bit_loc += hex_block_num_bits
        hex_block_start_offset += hex_block_size
    if val == prev_val == "1":
        # The final run of ones extends to the last position examined.
        intervals_append([start_loc, loc])
        bit_count += loc - start_loc + 1
    return bit_count, intervals
def hex_string_to_bit_intervals4(
    hex_str: str,
    limit: Optional[int] = None,
    hex_block_size: Optional[int] = None,
) -> Tuple[int, IntervalList]:
    """
    Given a hex string, a limit (number of bits) and an optional hex block size, return the count
    of set bits and the list of inclusive [first, last] intervals of set bits that will be used to
    create a LiteBitmaskSlots.

    Bit 0 is the most significant bit of the first hex digit.  Only the first `limit` bits are
    considered; `limit` defaults to all 4 * len(hex_str) bits and must not exceed that value.

    The string is processed in blocks of `hex_block_size` hex digits (default
    default_hex_block_size).  Each block is parsed to an integer; a zero block is skipped and any
    other block is expanded to binary and scanned bit by bit.

    Raises ValueError if hex_block_size is not positive (the block loop could not advance).
    """
    hex_str_len = len(hex_str)
    hex_str_num_bits = hex_str_len * 4
    if limit is None:
        limit = hex_str_num_bits
    else:
        assert limit <= hex_str_num_bits
    if hex_block_size is None:
        # TODO: adjust block size based on the limit?
        hex_block_size = default_hex_block_size
    if hex_block_size < 1:
        raise ValueError(f"hex_block_size must be a positive integer, got {hex_block_size}")
    intervals: IntervalList = []
    intervals_append = intervals.append
    bit_count = 0
    block_bit_loc = 0  # bit position of the first bit of the current block
    start_loc = 0
    prev_val = "0"
    loc = 0
    val = "0"
    hex_block_num_bits = hex_block_size * 4
    hex_block_start_offset = 0
    # Stop as soon as the next block would start at or past the limit.
    while hex_block_start_offset < hex_str_len and block_bit_loc < limit:
        hex_block_end_offset = hex_block_start_offset + hex_block_size
        if hex_block_end_offset > hex_str_len:
            # Final, shorter block: shrink the per-block sizes to what is left.
            hex_block_end_offset = hex_str_len
            hex_block_size = hex_block_end_offset - hex_block_start_offset
            hex_block_num_bits = hex_block_size * 4
        hex_block_int = int(hex_str[hex_block_start_offset:hex_block_end_offset], 16)
        if hex_block_int == 0:
            if prev_val == "1":
                # The current run of ones ended on the last bit of the previous block.
                intervals_append([start_loc, block_bit_loc - 1])
                bit_count += block_bit_loc - start_loc
                prev_val = "0"
        else:
            bin_str = bin(hex_block_int)[2:].zfill(hex_block_num_bits)
            for loc, val in enumerate(bin_str, block_bit_loc):
                if loc >= limit:
                    # Clamp to the last valid bit and keep its state for the final check.
                    loc = limit - 1
                    val = prev_val
                    break
                if val == prev_val:
                    continue
                if val == "0":
                    intervals_append([start_loc, loc - 1])
                    bit_count += loc - start_loc
                else:
                    start_loc = loc
                prev_val = val
        block_bit_loc += hex_block_num_bits
        hex_block_start_offset += hex_block_size
    if val == prev_val == "1":
        # The final run of ones extends to the last position examined.
        intervals_append([start_loc, loc])
        bit_count += loc - start_loc + 1
    return bit_count, intervals
def hex_string_to_bit_intervals5(
    hex_str: str,
    limit: Optional[int] = None,
    hex_block_size: Optional[int] = None,
) -> Tuple[int, IntervalList]:
    """
    Given a hex string, a limit (number of bits) and an optional hex block size, return the count
    of set bits and the list of inclusive [first, last] intervals of set bits that will be used to
    create a LiteBitmaskSlots.

    Bit 0 is the most significant bit of the first hex digit.  Only the first `limit` bits are
    considered; `limit` defaults to all 4 * len(hex_str) bits and must not exceed that value.

    The string is processed in blocks of `hex_block_size` hex digits (default
    default_hex_block_size).  Blocks made only of "0" digits or only of "f" digits are handled by
    string comparison; any other block is expanded to binary and scanned bit by bit.

    Raises ValueError if hex_block_size is not positive (the block loop could not advance).
    """
    hex_str_len = len(hex_str)
    hex_str_num_bits = hex_str_len * 4
    if limit is None:
        limit = hex_str_num_bits
    else:
        assert limit <= hex_str_num_bits
    if hex_block_size is None:
        # TODO: adjust block size based on the limit?
        hex_block_size = default_hex_block_size
    if hex_block_size < 1:
        raise ValueError(f"hex_block_size must be a positive integer, got {hex_block_size}")
    intervals: IntervalList = []
    intervals_append = intervals.append
    bit_count = 0
    block_bit_loc = 0  # bit position of the first bit of the current block
    start_loc = 0
    prev_val = "0"
    loc = 0
    val = "0"
    hex_block_0 = "0" * hex_block_size
    hex_block_f = "f" * hex_block_size
    hex_block_num_bits = hex_block_size * 4
    hex_block_start_offset = 0
    # Testing the limit before entering a block (rather than only after one) keeps limit=0 from
    # opening an interval in a leading all-"f" block.
    while hex_block_start_offset < hex_str_len and block_bit_loc < limit:
        hex_block_end_offset = hex_block_start_offset + hex_block_size
        if hex_block_end_offset > hex_str_len:
            # Final, shorter block: shrink the per-block sizes to what is left.
            hex_block_end_offset = hex_str_len
            hex_block_size = hex_block_end_offset - hex_block_start_offset
            hex_block_num_bits = hex_block_size * 4
            hex_block_0 = "0" * hex_block_size
            hex_block_f = "f" * hex_block_size
        hex_block_str = hex_str[hex_block_start_offset:hex_block_end_offset]
        if hex_block_str == hex_block_0:
            if prev_val == "1":
                # The current run of ones ended on the last bit of the previous block.
                intervals_append([start_loc, block_bit_loc - 1])
                bit_count += block_bit_loc - start_loc
                prev_val = "0"
        elif hex_block_str == hex_block_f:
            if prev_val == "0":
                start_loc = block_bit_loc
                prev_val = "1"
            loc = block_bit_loc + hex_block_num_bits - 1
            val = "1"
            if loc >= limit:
                # The limit falls inside this all-ones block: clamp and stop.
                loc = limit - 1
                break
        else:
            bin_str = bin(int(hex_block_str, 16))[2:].zfill(hex_block_num_bits)
            for loc, val in enumerate(bin_str, block_bit_loc):
                if loc >= limit:
                    # Clamp to the last valid bit and keep its state for the final check.
                    loc = limit - 1
                    val = prev_val
                    break
                if val == prev_val:
                    continue
                if val == "0":
                    intervals_append([start_loc, loc - 1])
                    bit_count += loc - start_loc
                else:
                    start_loc = loc
                prev_val = val
        block_bit_loc += hex_block_num_bits
        hex_block_start_offset += hex_block_size
    if val == prev_val == "1":
        # The final run of ones extends to the last position examined.
        intervals_append([start_loc, loc])
        bit_count += loc - start_loc + 1
    return bit_count, intervals
def intervals_to_hex_string0(length: int, intervals: IntervalList) -> str:
    """
    Given a bitmask length (number of bits) and a list of inclusive [first, last] intervals of set
    bits, return the lower-case hex string for the bitmask.

    Bit 0 is the most significant bit of the first hex digit.  The result has (length + 3) // 4
    digits; when length is not a multiple of 4 the unused bits are the low-order bits of the last
    digit, matching intervals_to_hex_string1 and the hex_string_to_bit_intervals functions.
    A length of 0 yields "".
    """
    hex_len = (length + 3) // 4
    if hex_len <= 0:
        return ""
    # Position of the last bit of the padded string; intervals are shifted relative to it so the
    # bitmask stays left-aligned when length is not a multiple of 4.
    last_bit = hex_len * 4 - 1
    bits = 0
    for bit_start, bit_end in intervals:
        bits |= ((1 << (bit_end - bit_start + 1)) - 1) << (last_bit - bit_end)
    return f"{bits:0{hex_len}x}"
def intervals_to_hex_string1(length: int, intervals: IntervalList) -> str:
    """
    Given a bitmask length (number of bits) and a list of inclusive [first, last] intervals of set
    bits, return the lower-case hex string for the bitmask.

    Bit 0 is the most significant bit of the first hex digit and the result is right-padded with
    "0" to (length + 3) // 4 digits.  The string is built digit by digit from the intervals, which
    are expected in ascending order with adjacent intervals already merged.
    """
    hex_str = ""
    hex_prev = 0  # index of the hex digit in which the previous interval ended
    prev_bits_int = 0  # value of a partially filled hex digit that has not been emitted yet
    for bit_start, bit_end in intervals:
        hex_start = bit_start // 4
        hex_end = bit_end // 4
        if prev_bits_int != 0:
            # Given a hex digit only contains 4 bits, and the constraints that contiguous intervals must be combined, only
            # two such intervals may exist for a hex digit. Therefore, we can assume if prev_bits_int is non-zero, that only
            # one additional interval can be in the same nybble and thus need to be combined.
            if hex_start == hex_prev:
                hex_start_last_bit = min((hex_start + 1) * 4 - 1, bit_end)
                prev_bits_int += ((1 << (hex_start_last_bit - bit_start + 1)) - 1) << (3 - hex_start_last_bit % 4)
                # The part of the interval inside the pending digit is consumed; continue at the next digit.
                hex_start += 1
                bit_start = hex_start * 4
            hex_str += hex_map[prev_bits_int]
            prev_bits_int = 0
        # Fill the gap of all-zero digits up to the digit where this interval starts.
        hex_str = hex_str.ljust(hex_start, "0")
        if hex_end > hex_start:
            # First digit holds the bits from bit_start to the end of that digit; the digits in
            # between are all ones.
            hex_str += hex_map[(1 << (4 - bit_start % 4)) - 1]
            hex_str = hex_str.ljust(hex_end, "f")
            bit_start = hex_end * 4
        if bit_start <= bit_end:
            # Remaining bits fall in the last digit; keep them pending in case the next interval
            # shares that digit.
            prev_bits_int = ((1 << (bit_end - bit_start + 1)) - 1) << (3 - bit_end % 4)
        hex_prev = hex_end
    if prev_bits_int != 0:
        hex_str += hex_map[prev_bits_int]
    return hex_str.ljust((length + 3) // 4, "0")
# NOTE(review): this class is shown without indentation and with "Eric Pershey" / "Brian Toonen" /
# "committed" lines interleaved; those lines are not code and source lines may be missing next to
# them. Restore the class from version control before relying on the control flow shown here.
class _BitmaskSerializer:
"""
a helper class for converting to and from bitmasks
instead of having the function imported, allow it to cache
"""
# Single shared instance, created by the first call to __new__.
_instance = None
base_dtype = __base_dtype__
# Lookup tables; replaced with the results of build_cache() when the instance is first created.
lookup_hexmask_to_bitmask = {}
lookup_binary_string_to_hexmask = {}
lookup_hexmask_to_binary_string = {}
def __new__(cls):
# Create the instance and fill the class-level lookup tables only once; later calls return
# the same object.
if cls._instance is None:
cls._instance = super(_BitmaskSerializer, cls).__new__(cls)
# eb.__debug_flag__ is cleared while build_cache() runs and restored afterwards.
df = eb.__debug_flag__
eb.__debug_flag__ = False
hexmask_to_bitmask, binary_string_to_hexmask, hexmask_to_binary_string = build_cache()
eb.__debug_flag__ = df
cls.lookup_hexmask_to_bitmask = hexmask_to_bitmask
cls.lookup_binary_string_to_hexmask = binary_string_to_hexmask
cls.lookup_hexmask_to_binary_string = hexmask_to_binary_string
return cls._instance
def encode_bitmask_to_hex(self, bitmask: EnhancedBitmask) -> str:
"""Given a bitmask, convert it to a hex string (lower)
was mask_encode_hex"""

Eric Pershey
committed
# Array path: np.ndarray / EnhancedBitmask inputs are converted 16 elements at a time through
# lookup_binary_string_to_hexmask, keyed by the raw bytes of each chunk.
if type(bitmask) in [np.ndarray, EnhancedBitmask]:
assert bitmask.dtype == __base_dtype__, f"{bitmask.dtype=} != {__base_dtype__}"
check_eb_overunderflow(bitmask, debug=EnhancedBitmask.__debug_flag__)
grouping = 16
hex_string = ""
bitmask = bitmask.view(np.ndarray)
for offset in range(0, len(bitmask), grouping):
chunk = bitmask[offset : (offset + grouping)]
binchunk = chunk.tobytes()

# NOTE(review): the four lookups below cannot all run unconditionally for one chunk; the
# conditions that select between the unpadded, 8-, 16- and 32-byte padded keys appear to have
# been lost at the residue lines — restore them from version control.
Eric Pershey
committed
hex_string += self.lookup_binary_string_to_hexmask[binchunk]

Eric Pershey
committed
# this could be an odd size like 111111
# -> 00111111
# -> 3f
# fix for issue #7
# the cache has 8 bit chunks cached

Eric Pershey
committed
hex_string += self.lookup_binary_string_to_hexmask[binchunk.ljust(8, b"\x00")]

Eric Pershey
committed
# the cache also has 16 bit chunks cached

Eric Pershey
committed
hex_string += self.lookup_binary_string_to_hexmask[binchunk.ljust(16, b"\x00")]

Eric Pershey
committed
# The try/except re-raises KeyError unchanged, so it does not alter behavior.
try:
hex_string += self.lookup_binary_string_to_hexmask[binchunk.ljust(32, b"\x00")]
except KeyError:
raise
# Interval path: EnhancedBitmaskLike inputs are converted from their list of set-bit indexes,
# 16 bits (4 hex digits) per group.
elif isinstance(bitmask, EnhancedBitmaskLike):
grouping = 16
hex_string = ""
# for offset in range(0, len(bitmask), grouping):
bits = bitmask.get_idx_lst_from_bitmask(bitmask)
# NOTE(review): if the index list is empty (no bits set) pop(0) raises IndexError before
# the `bit is not None` guard below can apply — verify and guard if needed.
bit = bits.pop(0)
for offset in range(0, len(bitmask), grouping):
byteint = 0
if bit is not None:
extent = offset + grouping
# Consume every pending index that falls inside this group; index `offset` maps to the
# most significant bit of the group.
while bit >= offset and bit < extent:
byteint += 1 << (grouping - (bit - offset) - 1)
if bits:
bit = bits.pop(0)
else:
bit = None
break

Eric Pershey
committed
# NOTE(review): the "X" format produces upper-case digits while the docstring says
# "(lower)" — confirm which is intended (a line may be missing at the residue here).
hex_group = f"{byteint:0>4X}" # must be grouping / 4

Eric Pershey
committed
hex_string += hex_group
else:
raise Exception(f"unsupported bitmask: {type(bitmask)}")
return hex_string
def encode_bitmask_to_bin(self, bitmask: EnhancedBitmask) -> str:
# Returns the concatenation of str() of every element of the bitmask, in index order.
check_eb_overunderflow(bitmask, debug=EnhancedBitmask.__debug_flag__)
bin_string = ""
for idx in range(0, len(bitmask)):
value = bitmask[idx]
bin_string += str(value)
return bin_string
def decode_hex_to_bitmask(self, hex_string: str, limit=None) -> EnhancedBitmask:
"""given a hex_string(lower case), convert it from a hex_string
to a bitmask.
Use limit to create a bitmask not divisible by 4
"""
hex_string_length = len(hex_string)

Eric Pershey
committed
# NOTE(review): as shown, limit is used in arithmetic before the `limit is None` check below,
# so limit=None would raise TypeError instead of the intended Exception — confirm against the
# original file (a line may be missing at the residue above).
required_hex_string_length = (limit // 4) + (1 if limit % 4 else 0)
if limit is None:
raise Exception("limit is required to get a correct sized bitmask, please send in limit")
elif hex_string_length < required_hex_string_length:
raise Exception(f"hex string is too short, is {hex_string_length} req:{required_hex_string_length}")
lookup = self.lookup_hexmask_to_binary_string
# Decode 4 hex digits at a time through the lookup table and concatenate the pieces.
ungrouping = 4
bitmask_lst = []
bitmask_lst_append = bitmask_lst.append
for offset in range(0, len(hex_string), ungrouping):
hex_extent = offset + ungrouping
bin_extent = hex_extent * 4
chunk = hex_string[offset:hex_extent]
if hex_extent > hex_string_length:
bin_extent = hex_string_length * 4
_bitmask = np.frombuffer(lookup[chunk], dtype=self.base_dtype)
# Drop the trailing elements of this chunk that lie past limit.
if limit and bin_extent > limit:
_bitmask = _bitmask[: -(bin_extent - limit)]
bitmask_lst_append(_bitmask)
bitmask = eb.concatenate(bitmask_lst)
return bitmask
# def encode_litebitmask_to_hex(self, bitmask: EnhancedBitmask) -> str:
# pass
def decode_hex_to_litebitmask(self, hex_string: str, limit=None) -> LiteBitmask:
"""given a hex_string(lower case), convert it from a hex_string
to a bitmask.
Use limit to create a bitmask not divisible by 4
"""

Brian Toonen
committed
# The returned bit count is discarded; only the intervals are passed on.
_, intervals = hex_string_to_bit_intervals(hex_string, limit=limit)
# bit_idx_first = intervals[0][0] if intervals else -1
# bit_idx_last = intervals[-1][1] if intervals else -1
# bitmask = LiteBitmaskSlots(limit, bit_count, intervals, bit_idx_first, bit_idx_last)
# todo: evaluate correctness of validate=True and merge=False.
# NOTE(review): limit is forwarded unchanged, so limit=None reaches zeros_and_set_intervals as
# None — confirm that is accepted there.
bitmask = LiteBitmask.zeros_and_set_intervals(limit, intervals, validate=True, merge=False)
# import ipdb; ipdb.set_trace()
return bitmask
def get_bitmask_sum(self, bitmask: EnhancedBitmask) -> int:
"""this gets the sum, not the count, be careful"""
return int(eb.sum(bitmask))
def get_bitmask_count(self, bitmask: EnhancedBitmask) -> int:
"""this gets the count of the bits set."""
return bitmask.bit_count
def BitmaskSerializer():
    """
    Return the shared _BitmaskSerializer instance.

    _BitmaskSerializer.__new__ creates the instance (and builds its lookup caches) on the first
    call and returns that same object on every later call.
    """
    return _BitmaskSerializer()

# Unsuffixed names bound to the implementation variants currently selected for use;
# decode_hex_to_litebitmask calls hex_string_to_bit_intervals.
hex_string_to_bit_intervals = hex_string_to_bit_intervals5
intervals_to_hex_string = intervals_to_hex_string1