# Copyright (C) 2024, UChicago Argonne, LLC
# Licensed under the 3-clause BSD license. See accompanying LICENSE.txt file
# in the top-level directory.

import pytest
from contextlib import nullcontext
from Octeres.bitmask import EnhancedBitmask
from Octeres.bitmask import BitmaskSerializer
from Octeres.bitmask.bitmask_lite import LiteBitmask, convert_bitmask_to_litebitmask, intersection_all
from Octeres.bitmask.bitmask_serial import intervals_to_hex_string

BitS = BitmaskSerializer()


def fixtures_encode_decode_bitmask_to_hex(ids=False):
    """Return the parametrization deck for the hex encode/decode test.

    Each entry supplies, in order: comment, length, hexmask_in, hexmask_out,
    eb_bitmask, l_bitmask, expectation.  A ``hexmask_out`` of ``None`` means
    the test expects the re-encoded hex string to equal ``hexmask_in``.
    ``expectation`` is the context manager the test body runs under:
    ``pytest.raises(Exception)`` for the failing case, a ``nullcontext``
    for the passing ones.

    :param ids: when true, return one zero-padded four-character id string
        per deck entry instead of the deck itself.
    :return: list of ``pytest.param`` entries, or list of id strings.
    """

    def _passes():
        # Fresh context manager per case; entering it yields True.
        return nullcontext(enter_result=True)

    cases = [
        # Limit of 10 with the two-digit mask "ff": this case must raise.
        pytest.param("bad", 10, "ff", None, EnhancedBitmask.ones(8), LiteBitmask.ones(8), pytest.raises(Exception)),
        pytest.param("good", 8, "ff", None, EnhancedBitmask.ones(8), LiteBitmask.ones(8), _passes()),
        # With a limit of 7 the expected re-encoded mask is "fe", not "ff".
        pytest.param("good", 7, "ff", "fe", EnhancedBitmask.ones(7), LiteBitmask.ones(7), _passes()),
        pytest.param(
            "good",
            36,
            "00000aa00",
            None,
            EnhancedBitmask.zeros_and_set_bits(36, [20, 22, 24, 26]),
            LiteBitmask.zeros_and_set_bits(36, [20, 22, 24, 26]),
            _passes(),
        ),
        pytest.param(
            "good",
            36,
            "fffffffff",
            None,
            EnhancedBitmask.ones(36),
            LiteBitmask.zeros_and_set_intervals(36, [[0, 35]]),
            _passes(),
        ),
        # Disabled case, kept as found:
        # pytest.param("good", 12, 'aaa', EnhancedBitmask.ones(12),
        #              LiteBitmask.zeros_and_set_intervals(36, [[0, 35]]), nullcontext(enter_result=True)),
        pytest.param(
            "good",
            16,
            "aaaa",
            None,
            EnhancedBitmask.zeros_and_set_bits(16, [0, 2, 4, 6, 8, 10, 12, 14]),
            LiteBitmask.zeros_and_set_bits(16, [0, 2, 4, 6, 8, 10, 12, 14]),
            _passes(),
        ),
        pytest.param(
            "good",
            16,
            "5555",
            None,
            EnhancedBitmask.zeros_and_set_bits(16, [1, 3, 5, 7, 9, 11, 13, 15]),
            LiteBitmask.zeros_and_set_bits(16, [1, 3, 5, 7, 9, 11, 13, 15]),
            _passes(),
        ),
    ]

    if not ids:
        return cases
    # Ids are simply the entry positions: "0000", "0001", ...
    return [f"{position:0>4}" for position in range(len(cases))]


@pytest.mark.parametrize(
    "comment,length,hexmask_in,hexmask_out,eb_bitmask,l_bitmask,expectation",
    fixtures_encode_decode_bitmask_to_hex(),
    ids=fixtures_encode_decode_bitmask_to_hex(ids=True),
)
def test_encode_decode_bitmask_to_hex(
    comment: str,
    length: int,
    hexmask_in: str,
    hexmask_out: str,
    eb_bitmask: EnhancedBitmask,
    l_bitmask: LiteBitmask,
    expectation
):
    """Decode a hex mask both ways and compare against the expected bitmasks.

    ``hexmask_in`` is decoded with ``limit=length`` once through
    ``decode_hex_to_bitmask`` (then converted to a LiteBitmask) and once
    through ``decode_hex_to_litebitmask``.  The results are checked against
    ``eb_bitmask`` / ``l_bitmask`` via ``intersection_all``, the two slot
    representations are checked against each other, and the directly decoded
    LiteBitmask is re-encoded and compared to ``hexmask_out`` (or to
    ``hexmask_in`` when ``hexmask_out`` is ``None``).

    Everything runs inside ``expectation``, so for the failing case any
    exception raised in the body satisfies the test.
    """
    with expectation:
        # None means "round-trips unchanged".
        expected_hex = hexmask_in if hexmask_out is None else hexmask_out

        decoded_eb = BitS.decode_hex_to_bitmask(hexmask_in, limit=length)
        lite_from_eb = convert_bitmask_to_litebitmask(decoded_eb)

        decoded_lite = BitS.decode_hex_to_litebitmask(hexmask_in, limit=length)
        reencoded_hex = intervals_to_hex_string(decoded_lite.length, decoded_lite.intervals)

        # todo:
        #   add code to encode it to hexmask and then back again, asserting correctness.
        #   check validation of each of the bitmasks provided.
        slots_from_eb = lite_from_eb.to_slots()
        slots_from_lite = decoded_lite.to_slots()

        print(decoded_eb)

        assert intersection_all(lite_from_eb, l_bitmask)
        assert decoded_eb.intersection_all(eb_bitmask)
        assert intersection_all(slots_from_eb, slots_from_lite)
        assert reencoded_hex == expected_hex