# Copyright (C) 2024, UChicago Argonne, LLC # Licensed under the 3-clause BSD license. See accompanying LICENSE.txt file # in the top-level directory. import abc import copy import operator import pickle import sys from dataclasses import dataclass, asdict from functools import reduce from operator import itemgetter import numpy as np from itertools import groupby from pprint import pformat from typing import Union, Tuple, Generator, List, Set import simplejson from Octeres.bitmask import EnhancedBitmaskLike, EnhancedBitmask, bool_True, bool_False from Octeres.serialization import DataEncoder, DataDecoder """ invarients: all BitIntervals must have ordered and merged intervals. Evaluate LiteBitmask should be immutable. changing bit_idx_first and last to a property that's immutable. do a pass later an drop underscores on functions not needed outside. """ def intersection_possible(bitmaska, bitmaskb): if bitmaska.bit_count == 0 or bitmaskb.bit_count == 0: result = False elif bitmaska.bit_idx_first > bitmaskb.bit_idx_last: result = False elif bitmaska.bit_idx_last < bitmaskb.bit_idx_first: result = False else: result = True return result def intersection_any(bitmaska, bitmaskb): if not intersection_possible(bitmaska, bitmaskb): result = False else: if bitmaska.bit_idx_first == bitmaskb.bit_idx_first or bitmaska.bit_idx_first == bitmaskb.bit_idx_last: result = True elif bitmaska.bit_idx_last == bitmaskb.bit_idx_first or bitmaska.bit_idx_last == bitmaskb.bit_idx_last: result = True else: result = False ni = _get_combined_intervals_until_all_comsumed(bitmaska.bit_intervals, bitmaskb.bit_intervals) intervals = [] try: prev_int = next(ni) for next_int in ni: max_low = max(prev_int[0], next_int[0]) min_high = min(prev_int[1], next_int[1]) max_high = max(prev_int[1], next_int[1]) # print(f"{prev_int=}, {next_int=}, {max_low=}, {min_high=}, {max_high=}") if max_low <= min_high: intervals.append([max_low, min_high]) if max_high > min_high: prev_int = [min_high + 1, max_high] else: prev_int = next(ni) except StopIteration: pass for bit_idx in bitmaska.get_bits(bit_idx_first=bitmaskb.bit_idx_first): # todo: use first and last bits. # if bit_idx < other.bit_idx_first: # continue # elif bit_idx > other.bit_idx_last: # continue if bitmaskb[bit_idx]: result = True break return result IntervalList = List[List[int]] class BitIntervals: def __init__(self, bit_count, intervals, bit_idx_first, bit_idx_last, no_copy=False): self.bit_count = bit_count if no_copy: self.intervals: List[List] = intervals else: self.intervals: List[List] = copy.deepcopy(intervals) self.bit_idx_first = bit_idx_first self.bit_idx_last = bit_idx_last self.validate() def to_dict(self): dct = dict( bit_count=self.bit_count, intervals=self.intervals, bit_idx_first=self.bit_idx_first, bit_idx_last=self.bit_idx_last, ) return dct @staticmethod def from_dict(dct): bit_intervals = BitIntervals( dct["bit_count"], dct["intervals"], dct["bit_idx_first"], dct["bit_idx_last"], ) bit_intervals.validate() return bit_intervals @staticmethod def from_intervals_list(intervals: IntervalList, validate=False) -> "BitIntervals": # intervals, and thus intervals. should be immutable, sorted and merged # assume sorted and no overlaps or consecutive intervals (i.e., [5,6],[7,8] should be [5,8]) bit_count = 0 idx_first = -1 idx_last = -1 # FIXME: priority one, create test to detect problems in interval ids intervals = copy.deepcopy(intervals) for start_end in intervals: if len(start_end) != 2: raise IntervalError(f"interval must be a length of 2:{start_end}") start, end = start_end assert end >= start bit_count += end - start + 1 if bit_count: idx_first = intervals[0][0] idx_last = intervals[-1][1] bit_intervals = BitIntervals(bit_count, intervals, idx_first, idx_last, no_copy=True) if validate: bit_intervals.validate() return bit_intervals def __eq__(self, other: "BitIntervals") -> bool: result = ( self.intervals == other.intervals and self.bit_count == other.bit_count and self.bit_idx_first == other.bit_idx_first and self.bit_idx_last == other.bit_idx_last ) return result def __repr__(self): result = f"BitIntervals({self.bit_count}, {self.intervals}, {self.bit_idx_first}, {self.bit_idx_last})" return result # def __getitem__(self, item): # result = self.intervals[item] # return result # # def __sub__(self, other): # return intervals_subtract(self.intervals, other.intervals) # # def __add__(self, other): # return intervals_or(self.intervals, other.intervals) # # def __xor__(self, other): # return intervals_xor(self.intervals, other.intervals) def validate(self): intervals = self.intervals if self.bit_count > 0: assert self.bit_idx_first == intervals[0][0] assert self.bit_idx_last == intervals[-1][1] else: assert self.bit_idx_first == -1 assert self.bit_idx_last == -1 len_intervals = len(intervals) def get_delta(bir): a = bir[0] b = bir[1] if a > b: raise IntervalError(f"negative interval: {bir}") return (b - a) + 1 if intervals: bit_count = reduce(lambda a, b: a + b, map(get_delta, intervals)) else: bit_count = 0 for i, alar in enumerate(self.intervals): al, ar = alar if al > ar: raise IntervalError(f"negative interval: {alar}") if (i + 1) < len_intervals: blbr = intervals[i + 1] if alar == blbr: raise IntervalError(f"duplicate interval: {alar}") bl, br = blbr if ar > bl: raise IntervalError(f"interval out of order or not merged: {ar} > {bl} for {alar}->{blbr}") assert bit_count == self.bit_count def get_interval_bits(bit_intervals: BitIntervals, bit_idx_first=0, bit_idx_last=None) -> List: """get the indexes that are set. use start if you are looking for indexes above a count. This is slow and there are probably other better ways to do this. All calls should evaluate this. """ intervals: List[List] = bit_intervals.intervals bit_lst = [] for interval in intervals: for i in range(interval[0], interval[1] + 1): if bit_idx_first and i >= bit_idx_first: continue if ( bit_idx_last and i >= bit_idx_last ): # FIXME: priority 1, evaulate >= and make sure it's right. use slice methodology x[1:4]. At unit tests continue bit_lst.append(i) return bit_lst def get_interval_bits_set_v0(bit_intervals: BitIntervals) -> set: """get the indexes that are set.""" # FIXME: priority 1: stop using this, it's slow return {idx for interval in bit_intervals.intervals for idx in [i for i in range(interval[0], interval[1] + 1)]} def get_interval_bits_set(bit_intervals: BitIntervals) -> Set[int]: bit_set: Set[int] = set() for interval in bit_intervals.intervals: bit_set.update(range(interval[0], interval[1] + 1)) return bit_set def intervals_subtract_v0(bit_intervals_a: BitIntervals, bit_intervals_b: BitIntervals) -> BitIntervals: """subtract the indexes of a list of intervals from another""" # FIXME: this needs work, it may be very slow. set_a = get_interval_bits_set(bit_intervals_a) set_b = get_interval_bits_set(bit_intervals_b) set_c = set_a - set_b result = index_list_to_bitintervals(set_c) return result def intervals_subtract_v1(a_bit_intervals: BitIntervals, b_bit_intervals: BitIntervals) -> BitIntervals: """Directly perform subtract function on the intervals""" a_intervals = a_bit_intervals.intervals b_intervals = b_bit_intervals.intervals result_intervals = [] result_intervals_append = result_intervals.append a_intervals_len = len(a_intervals) b_intervals_len = len(b_intervals) i, j = 0, 0 while i < a_intervals_len: a_start, a_end = a_intervals[i] while j < b_intervals_len: b_start, b_end = b_intervals[j] if b_end < a_start: j += 1 continue if b_start > a_end: break if b_start <= a_start and b_end >= a_end: a_start = a_end + 1 break elif b_start <= a_start: a_start = b_end + 1 elif b_end >= a_end: result_intervals_append([a_start, b_start - 1]) a_start = a_end + 1 break else: result_intervals_append([a_start, b_start - 1]) a_start = b_end + 1 j += 1 if a_start <= a_end: result_intervals_append([a_start, a_end]) i += 1 if not result_intervals: return BitIntervals(0, [], -1, -1) merged_intervals = [] merged_intervals_append = merged_intervals.append start, end = result_intervals[0] result_intervals_len = len(result_intervals) for i in range(1, result_intervals_len): current_start, current_end = result_intervals[i] if current_start > end + 1: merged_intervals_append([start, end]) start, end = current_start, current_end else: end = max(end, current_end) merged_intervals_append([start, end]) bit_count = sum(end - start + 1 for start, end in merged_intervals) bit_idx_first = merged_intervals[0][0] bit_idx_last = merged_intervals[-1][1] bit_intervals = BitIntervals(bit_count, merged_intervals, bit_idx_first, no_copy=True) bit_intervals.validate() return bit_intervals def intervals_subtract_v2(a_bit_intervals: BitIntervals, b_bit_intervals: BitIntervals) -> BitIntervals: """Refine the logic during the interval walkthrough""" a_intervals = a_bit_intervals.intervals b_intervals = b_bit_intervals.intervals result_intervals = [] result_intervals_append = result_intervals.append i, j = 0, 0 a_intervals_len = len(a_intervals) b_intervals_len = len(b_intervals) while i < a_intervals_len: a_start, a_end = a_intervals[i] # Skip b_intervals that end before the current a_start while j < b_intervals_len and b_intervals[j][1] < a_start: j += 1 while j < b_intervals_len: b_start, b_end = b_intervals[j] if b_start > a_end: break # Handle overlapping intervals if b_start <= a_start and b_end >= a_end: a_start = a_end + 1 break elif b_start <= a_start: a_start = b_end + 1 elif b_end >= a_end: result_intervals_append([a_start, b_start - 1]) a_start = a_end + 1 break else: result_intervals_append([a_start, b_start - 1]) a_start = b_end + 1 j += 1 if a_start <= a_end: result_intervals_append([a_start, a_end]) i += 1 if not result_intervals: return BitIntervals(0, [], -1, -1) merged_intervals = [] start, end = result_intervals[0] result_intervals_len = len(result_intervals) bit_count = 0 for i in range(1, result_intervals_len): current_start, current_end = result_intervals[i] if current_start > end + 1: merged_intervals.append([start, end]) bit_count += end - start + 1 start, end = current_start, current_end else: end = max(end, current_end) merged_intervals.append([start, end]) bit_count += end - start + 1 bit_idx_first = merged_intervals[0][0] bit_idx_last = merged_intervals[-1][1] bit_intervals = BitIntervals(bit_count, merged_intervals, bit_idx_first, bit_idx_last, no_copy=True) bit_intervals.validate() return bit_intervals def intervals_subtract_v3(a_bit_intervals: BitIntervals, b_bit_intervals: BitIntervals) -> BitIntervals: """An attempt to cut the merge interval process, directly generate the result during the walkthrough""" a_intervals = a_bit_intervals.intervals b_intervals = b_bit_intervals.intervals final_intervals = [] final_intervals_append = final_intervals.append a_intervals_len = len(a_intervals) b_intervals_len = len(b_intervals) i, j = 0, 0 start, end = None, None while i < a_intervals_len: a_start, a_end = a_intervals[i] while j < b_intervals_len: b_start, b_end = b_intervals[j] if b_end < a_start: j += 1 continue if b_start > a_end: break if b_start <= a_start and b_end >= a_end: a_start = a_end + 1 break elif b_start <= a_start: a_start = b_end + 1 elif b_end >= a_end: if a_start < b_start: if start is None: start, end = a_start, b_start - 1 else: if a_start > end + 1: final_intervals_append([start, end]) start, end = a_start, b_start - 1 else: end = b_start - 1 a_start = a_end + 1 break else: if start is None: start, end = a_start, b_start - 1 else: if a_start > end + 1: final_intervals_append([start, end]) start, end = a_start, b_start - 1 else: end = b_start - 1 a_start = b_end + 1 j += 1 if a_start <= a_end: if start is None: start, end = a_start, a_end else: if a_start > end + 1: final_intervals_append([start, end]) start, end = a_start, a_end else: end = a_end i += 1 if start is not None: final_intervals_append([start, end]) if not final_intervals: return BitIntervals(0, [], -1, -1) bit_count = sum(end - start + 1 for start, end in final_intervals) bit_idx_first = final_intervals[0][0] bit_idx_last = final_intervals[-1][1] bit_intervals = BitIntervals(bit_count, final_intervals, bit_idx_first, bit_idx_last) bit_intervals.validate() return bit_intervals def intervals_subtract_v4(a_bit_intervals: BitIntervals, b_bit_intervals: BitIntervals) -> BitIntervals: """Anothe attempt to cut the merge interval process, refining the walkthrough logic""" a_intervals = a_bit_intervals.intervals b_intervals = b_bit_intervals.intervals final_intervals = [] final_intervals_append = final_intervals.append a_intervals_len = len(a_intervals) b_intervals_len = len(b_intervals) i, j = 0, 0 start, end = None, None bit_count = 0 while i < a_intervals_len: a_start, a_end = a_intervals[i] while j < b_intervals_len: b_start, b_end = b_intervals[j] if b_end < a_start: j += 1 continue if b_start > a_end: break if b_start <= a_start and b_end >= a_end: a_start = a_end + 1 break elif b_start <= a_start: a_start = b_end + 1 elif b_end >= a_end: if a_start < b_start: if start is None: start, end = a_start, b_start - 1 else: if a_start > end + 1: final_intervals_append([start, end]) bit_count += end - start + 1 start, end = a_start, b_start - 1 else: end = b_start - 1 a_start = a_end + 1 break else: if start is None: start, end = a_start, b_start - 1 else: if a_start > end + 1: final_intervals_append([start, end]) bit_count += end - start + 1 start, end = a_start, b_start - 1 else: end = b_start - 1 a_start = b_end + 1 j += 1 if a_start <= a_end: if start is None: start, end = a_start, a_end else: if a_start > end + 1: final_intervals_append([start, end]) bit_count += end - start + 1 start, end = a_start, a_end else: end = a_end i += 1 if start is not None: final_intervals_append([start, end]) bit_count += end - start + 1 if not final_intervals: return BitIntervals(0, [], -1, -1) bit_idx_first = final_intervals[0][0] bit_idx_last = final_intervals[-1][1] bit_intervals = BitIntervals(bit_count, final_intervals, bit_idx_first, bit_idx_last) bit_intervals.validate() return bit_intervals def _get_combined_intervals_until_all_comsumed(abi: BitIntervals, bbi: BitIntervals) -> Generator[IntervalList, None, None]: alst = abi.intervals blst = bbi.intervals aidx = 0 bidx = 0 while True: try: amin, amax = alst[aidx] except IndexError: amin, amax = None, None try: bmin, bmax = blst[bidx] except IndexError: bmin, bmax = None, None if amin is None and bmin is None: break if bmin is None or (amin is not None and amin <= bmin): r = [amin, amax] aidx += 1 else: r = [bmin, bmax] bidx += 1 yield r def _get_combined_intervals_until_overlap_impossible(abi: BitIntervals, bbi: BitIntervals) -> Generator[IntervalList, None, None]: alst = abi.intervals blst = bbi.intervals aidx = 0 bidx = 0 while True: try: amin, amax = alst[aidx] except IndexError: amin, amax = None, alst[-1][1] if len(alst) else -1 try: bmin, bmax = blst[bidx] except IndexError: bmin, bmax = None, blst[-1][1] if len(blst) else -1 if (amin is None and bmin is None) or (amin is None and bmin > amax) or (bmin is None and amin > bmax): break if bmin is None or (amin is not None and amin <= bmin): r = [amin, amax] aidx += 1 else: r = [bmin, bmax] bidx += 1 yield r def intervals_or_v0(abi: BitIntervals, bbi: BitIntervals) -> BitIntervals: """does a logical or of two bit intervals""" interval_gen = _get_combined_intervals_until_all_comsumed(abi, bbi) try: prev_start, prev_end = next(interval_gen) intervals = [[prev_start, prev_end]] bit_count = prev_end - prev_start + 1 except StopIteration: # both of the BitIntervals lists were empty bit_count = 0 intervals = [] bit_idx_first = -1 bit_idx_last = -1 else: for next_start, next_end in interval_gen: if next_start <= prev_end + 1: new_end = max(prev_end, next_end) intervals[-1] = [prev_start, new_end] bit_count += new_end - prev_end prev_end = new_end else: intervals.append([next_start, next_end]) bit_count += next_end - next_start + 1 prev_start, prev_end = next_start, next_end bit_idx_first = intervals[0][0] bit_idx_last = intervals[-1][1] bit_intervals = BitIntervals(bit_count, intervals, bit_idx_first, bit_idx_last, no_copy=True) return bit_intervals def intervals_or_v4(abi: BitIntervals, bbi: BitIntervals) -> BitIntervals: """Fixed the over merge problem occurred in v1 and v2""" a_intervals = abi.intervals b_intervals = bbi.intervals final_intervals = [] final_intervals_append = final_intervals.append i, j = 0, 0 a_intervals_len = len(a_intervals) b_intervals_len = len(b_intervals) start, end = None, None def append_interval(start, end): if start is not None: if final_intervals: last_start, last_end = final_intervals[-1] if start <= last_end + 1: final_intervals[-1][1] = max(last_end, end) else: final_intervals_append([start, end]) else: final_intervals_append([start, end]) while i < a_intervals_len and j < b_intervals_len: a_start, a_end = a_intervals[i] b_start, b_end = b_intervals[j] if a_end < b_start: append_interval(start, end) start, end = a_start, a_end i += 1 elif b_end < a_start: append_interval(start, end) start, end = b_start, b_end j += 1 else: append_interval(start, end) start = min(a_start, b_start) end = max(a_end, b_end) if a_end > b_end: j += 1 elif a_end < b_end: i += 1 else: i += 1 j += 1 while i < a_intervals_len: a_start, a_end = a_intervals[i] append_interval(start, end) start, end = a_start, a_end i += 1 while j < b_intervals_len: b_start, b_end = b_intervals[j] append_interval(start, end) start, end = b_start, b_end j += 1 append_interval(start, end) if not final_intervals: return BitIntervals(0, [], -1, -1) bit_count = sum(end - start + 1 for start, end in final_intervals) bit_idx_first = final_intervals[0][0] bit_idx_last = final_intervals[-1][1] bit_intervals = BitIntervals(bit_count, final_intervals, bit_idx_first, bit_idx_last, no_copy=True) bit_intervals.validate() return bit_intervals def intervals_or_v5(abi: BitIntervals, bbi: BitIntervals) -> BitIntervals: """Optimized version of intervals_or with minimized condition checks""" a_intervals = abi.intervals b_intervals = bbi.intervals final_intervals = [] final_intervals_append = final_intervals.append i, j = 0, 0 a_intervals_len = len(a_intervals) b_intervals_len = len(b_intervals) bit_count = 0 def merge_or_append_interval_v5(start, end, bit_count): if final_intervals and final_intervals[-1][1] >= start - 1: previous_end = final_intervals[-1][1] final_intervals[-1][1] = max(final_intervals[-1][1], end) bit_count += final_intervals[-1][1] - previous_end else: final_intervals_append([start, end]) bit_count += end - start + 1 return bit_count while i < a_intervals_len and j < b_intervals_len: if a_intervals[i][0] <= b_intervals[j][0]: start, end = a_intervals[i] i += 1 else: start, end = b_intervals[j] j += 1 while i < a_intervals_len and a_intervals[i][0] <= end + 1: end = max(end, a_intervals[i][1]) i += 1 while j < b_intervals_len and b_intervals[j][0] <= end + 1: end = max(end, b_intervals[j][1]) j += 1 bit_count = merge_or_append_interval_v5(start, end, bit_count) while i < a_intervals_len: bit_count = merge_or_append_interval_v5(a_intervals[i][0], a_intervals[i][1], bit_count) i += 1 while j < b_intervals_len: bit_count = merge_or_append_interval_v5(b_intervals[j][0], b_intervals[j][1], bit_count) j += 1 if not final_intervals: return BitIntervals(0, [], -1, -1) bit_idx_first = final_intervals[0][0] bit_idx_last = final_intervals[-1][1] bit_intervals = BitIntervals(bit_count, final_intervals, bit_idx_first, bit_idx_last, no_copy=True) bit_intervals.validate() return bit_intervals def intervals_or_v6(abi: BitIntervals, bbi: BitIntervals) -> BitIntervals: """Optimized version of intervals_or_v4 with cut the redundant overlap and margin check""" a_intervals = abi.intervals b_intervals = bbi.intervals final_intervals = [] final_intervals_append = final_intervals.append i, j = 0, 0 a_intervals_len = len(a_intervals) b_intervals_len = len(b_intervals) while i < a_intervals_len and j < b_intervals_len: if a_intervals[i][1] < b_intervals[j][0]: if final_intervals and final_intervals[-1][1] >= a_intervals[i][0] - 1: final_intervals[-1][1] = max(final_intervals[-1][1], a_intervals[i][1]) else: final_intervals_append(a_intervals[i]) i += 1 elif b_intervals[j][1] < a_intervals[i][0]: if final_intervals and final_intervals[-1][1] >= b_intervals[j][0] - 1: final_intervals[-1][1] = max(final_intervals[-1][1], b_intervals[j][1]) else: final_intervals_append(b_intervals[j]) j += 1 else: start = min(a_intervals[i][0], b_intervals[j][0]) end = max(a_intervals[i][1], b_intervals[j][1]) while i < a_intervals_len and a_intervals[i][0] <= end: end = max(end, a_intervals[i][1]) i += 1 while j < b_intervals_len and b_intervals[j][0] <= end: end = max(end, b_intervals[j][1]) j += 1 if final_intervals and final_intervals[-1][1] >= start - 1: final_intervals[-1][1] = max(final_intervals[-1][1], end) else: final_intervals_append([start, end]) while i < a_intervals_len: if final_intervals and final_intervals[-1][1] >= a_intervals[i][0] - 1: final_intervals[-1][1] = max(final_intervals[-1][1], a_intervals[i][1]) else: final_intervals_append(a_intervals[i]) i += 1 while j < b_intervals_len: if final_intervals and final_intervals[-1][1] >= b_intervals[j][0] - 1: final_intervals[-1][1] = max(final_intervals[-1][1], b_intervals[j][1]) else: final_intervals_append(b_intervals[j]) j += 1 if not final_intervals: return BitIntervals(0, [], -1, -1) bit_count = sum(end - start + 1 for start, end in final_intervals) bit_idx_first = final_intervals[0][0] bit_idx_last = final_intervals[-1][1] bit_intervals = BitIntervals(bit_count, final_intervals, bit_idx_first, bit_idx_last, no_copy=True) bit_intervals.validate() return bit_intervals def _merge_or_append_interval_v7(start, end, final_intervals, final_intervals_append, bit_count): """Deal with the interval merging boundary for v7 and calculate the bit_count""" if final_intervals and final_intervals[-1][1] >= start - 1: previous_end = final_intervals[-1][1] final_intervals[-1][1] = max(final_intervals[-1][1], end) bit_count += final_intervals[-1][1] - previous_end else: final_intervals_append([start, end]) bit_count += end - start + 1 return bit_count def intervals_or_v7(abi: BitIntervals, bbi: BitIntervals) -> BitIntervals: """Try to take merging function outside the or function to see if it could be faster""" a_intervals = abi.intervals b_intervals = bbi.intervals final_intervals = [] final_intervals_append = final_intervals.append i, j = 0, 0 a_intervals_len = len(a_intervals) b_intervals_len = len(b_intervals) bit_count = 0 while i < a_intervals_len and j < b_intervals_len: if a_intervals[i][0] <= b_intervals[j][0]: start, end = a_intervals[i] i += 1 else: start, end = b_intervals[j] j += 1 while i < a_intervals_len and a_intervals[i][0] <= end + 1: end = max(end, a_intervals[i][1]) i += 1 while j < b_intervals_len and b_intervals[j][0] <= end + 1: end = max(end, b_intervals[j][1]) j += 1 bit_count = _merge_or_append_interval_v7(start, end, final_intervals, final_intervals_append, bit_count) while i < a_intervals_len: bit_count = _merge_or_append_interval_v7( a_intervals[i][0], a_intervals[i][1], final_intervals, final_intervals_append, bit_count ) i += 1 while j < b_intervals_len: bit_count = _merge_or_append_interval_v7( b_intervals[j][0], b_intervals[j][1], final_intervals, final_intervals_append, bit_count ) j += 1 if not final_intervals: return BitIntervals(0, [], -1, -1) # bit_count = sum(end - start + 1 for start, end in final_intervals) bit_idx_first = final_intervals[0][0] bit_idx_last = final_intervals[-1][1] bit_intervals = BitIntervals(bit_count, final_intervals, bit_idx_first, bit_idx_last, no_copy=True) bit_intervals.validate() return bit_intervals def logical_not(bitmask: "LiteBitmask") -> "LiteBitmask": """negate the intervals""" if bitmask.all(): bitmask = LiteBitmask.zeros(bitmask.length) elif bitmask.none(): bitmask = LiteBitmask.ones(bitmask.length) else: bi = bitmask.bit_intervals length = bitmask.length if bi.intervals and bi.intervals[-1][1] >= length: raise IndexError("last interval exceeds specified length") intervals = [] interval_append = intervals.append bit_count = 0 not_start = 0 for bi_start, bi_end in bi.intervals: if bi_start >= length: break not_end = bi_start - 1 if not_end >= 0: interval_append([not_start, not_end]) bit_count += not_end - not_start + 1 not_start = bi_end + 1 if not_start < length: interval_append([not_start, length - 1]) bit_count += length - not_start bitmask = LiteBitmask.zeros_and_set_intervals(bitmask.length, intervals) return bitmask def intervals_xor_set_v0(a_bit_intervals: BitIntervals, b_bit_intervals: BitIntervals) -> BitIntervals: """negate the intervals""" # FIXME: possibly use (a-b)|(b-a), but probably better to use iteration of one list (b?) and comparison with an iteration of # the other list (a?). performance comparisons needed for the size of the two intervals. # FIXME: this needs work, it may be very slow. set_a = get_interval_bits_set(a_bit_intervals) set_b = get_interval_bits_set(b_bit_intervals) set_c = set_a.symmetric_difference(set_b) result = index_list_to_bitintervals_v0(set_c) return result def intervals_xor_set_v1(a_bit_intervals: BitIntervals, b_bit_intervals: BitIntervals) -> BitIntervals: """Test a new version index_list_to_bitintervals""" set_a = get_interval_bits_set(a_bit_intervals) set_b = get_interval_bits_set(b_bit_intervals) set_c = set_a.symmetric_difference(set_b) result = index_list_to_bitintervals_v1(set_c) return result def intervals_xor_set_v2(a_bit_intervals: BitIntervals, b_bit_intervals: BitIntervals) -> BitIntervals: """Test another version of index_list_to_bitintervals""" set_a = get_interval_bits_set(a_bit_intervals) set_b = get_interval_bits_set(b_bit_intervals) set_c = set_a.symmetric_difference(set_b) result = index_list_to_bitintervals_v2(set_c) return result def intervals_xor_v0(a_bit_intervals: BitIntervals, b_bit_intervals: BitIntervals) -> BitIntervals: """Directly do xor on the intervals to avoid data type change""" a_intervals = a_bit_intervals.intervals b_intervals = b_bit_intervals.intervals result_intervals = [] i, j = 0, 0 while i < len(a_intervals) and j < len(b_intervals): a_start, a_end = a_intervals[i] b_start, b_end = b_intervals[j] if a_end < b_start: result_intervals.append([a_start, a_end]) i += 1 elif b_end < a_start: result_intervals.append([b_start, b_end]) j += 1 else: if a_start < b_start: result_intervals.append([a_start, b_start - 1]) if a_end > b_end: result_intervals.append([b_end + 1, a_end]) j += 1 elif a_end < b_end: result_intervals.append([a_end + 1, b_end]) i += 1 else: i += 1 j += 1 while i < len(a_intervals): result_intervals.append(a_intervals[i]) i += 1 while j < len(b_intervals): result_intervals.append(b_intervals[j]) j += 1 merged_intervals = [] for start, end in result_intervals: if not merged_intervals or merged_intervals[-1][1] < start - 1: merged_intervals.append([start, end]) else: merged_intervals[-1][1] = max(merged_intervals[-1][1], end) bit_count = sum(interval[1] - interval[0] + 1 for interval in merged_intervals) bit_idx_first = merged_intervals[0][0] if merged_intervals else -1 bit_idx_last = merged_intervals[-1][1] if merged_intervals else -1 return BitIntervals(bit_count, merged_intervals, bit_idx_first, bit_idx_last) def intervals_xor_v4(a_bit_intervals: BitIntervals, b_bit_intervals: BitIntervals) -> BitIntervals: """Using copy to avoid the address problem when updating the value on original intervals""" a_intervals = a_bit_intervals.intervals.copy() b_intervals = b_bit_intervals.intervals.copy() result_intervals = [] result_intervals_append = result_intervals.append a_intervals_len = len(a_intervals) b_intervals_len = len(b_intervals) i, j = 0, 0 while i < a_intervals_len and j < b_intervals_len: a_start, a_end = a_intervals[i] b_start, b_end = b_intervals[j] if a_end < b_start: result_intervals_append([a_start, a_end]) i += 1 elif b_end < a_start: result_intervals_append([b_start, b_end]) j += 1 else: if a_start < b_start: result_intervals_append([a_start, b_start - 1]) a_start = b_start if b_start < a_start: result_intervals_append([b_start, a_start - 1]) if a_end > b_end: a_intervals[i] = [b_end + 1, a_end] j += 1 elif a_end < b_end: b_intervals[j] = [a_end + 1, b_end] i += 1 else: i += 1 j += 1 while i < a_intervals_len: start, end = a_intervals[i] result_intervals_append([start, end]) i += 1 while j < b_intervals_len: start, end = b_intervals[j] result_intervals_append([start, end]) j += 1 if not result_intervals: return BitIntervals(0, [], -1, -1) merged_intervals = [] merged_intervals_append = merged_intervals.append start, end = result_intervals[0] result_intervals_len = len(result_intervals) for i in range(1, result_intervals_len): current_start, current_end = result_intervals[i] if current_start > end + 1: merged_intervals_append([start, end]) start, end = current_start, current_end else: end = max(end, current_end) merged_intervals_append([start, end]) bit_count = sum(end - start + 1 for start, end in merged_intervals) bit_idx_first = merged_intervals[0][0] bit_idx_last = merged_intervals[-1][1] bit_intervals = BitIntervals(bit_count, merged_intervals, bit_idx_first, bit_idx_last, no_copy=True) bit_intervals.validate() return bit_intervals def intervals_xor_v5(a_bit_intervals: BitIntervals, b_bit_intervals: BitIntervals) -> BitIntervals: """Using temp intervals to avoid using copy for fast xor function""" a_intervals = a_bit_intervals.intervals b_intervals = b_bit_intervals.intervals result_intervals = [] result_intervals_append = result_intervals.append a_intervals_len = len(a_intervals) b_intervals_len = len(b_intervals) i, j = 0, 0 a_temp_interval = None b_temp_interval = None while i < a_intervals_len and j < b_intervals_len: if a_temp_interval is None: a_start, a_end = a_intervals[i] else: a_start, a_end = a_temp_interval if b_temp_interval is None: b_start, b_end = b_intervals[j] else: b_start, b_end = b_temp_interval if a_end < b_start: result_intervals_append([a_start, a_end]) a_temp_interval = None i += 1 elif b_end < a_start: result_intervals_append([b_start, b_end]) b_temp_interval = None j += 1 else: if a_start < b_start: result_intervals_append([a_start, b_start - 1]) a_start = b_start if b_start < a_start: result_intervals_append([b_start, a_start - 1]) b_start = a_start if a_end > b_end: a_temp_interval = [b_end + 1, a_end] b_temp_interval = None j += 1 elif a_end < b_end: b_temp_interval = [a_end + 1, b_end] a_temp_interval = None i += 1 else: a_temp_interval = None b_temp_interval = None i += 1 j += 1 while i < a_intervals_len: if a_temp_interval is None: start, end = a_intervals[i] else: start, end = a_temp_interval a_temp_interval = None result_intervals_append([start, end]) i += 1 while j < b_intervals_len: if b_temp_interval is None: start, end = b_intervals[j] else: start, end = b_temp_interval b_temp_interval = None result_intervals_append([start, end]) j += 1 if not result_intervals: return BitIntervals(0, [], -1, -1) merged_intervals = [] merged_intervals_append = merged_intervals.append bit_count = 0 start, end = result_intervals[0] result_intervals_len = len(result_intervals) for i in range(1, result_intervals_len): current_start, current_end = result_intervals[i] if current_start > end + 1: merged_intervals_append([start, end]) bit_count += end - start + 1 start, end = current_start, current_end else: end = max(end, current_end) merged_intervals_append([start, end]) bit_count += end - start + 1 bit_idx_first = merged_intervals[0][0] bit_idx_last = merged_intervals[-1][1] bit_intervals = BitIntervals(bit_count, merged_intervals, bit_idx_first, bit_idx_last, no_copy=True) bit_intervals.validate() return bit_intervals def intervals_xor_v6(a_bit_intervals: BitIntervals, b_bit_intervals: BitIntervals) -> BitIntervals: """The usage temp intervals could abandon the temp array of merged intervals""" a_intervals = a_bit_intervals.intervals b_intervals = b_bit_intervals.intervals final_intervals = [] final_intervals_append = final_intervals.append a_intervals_len = len(a_intervals) b_intervals_len = len(b_intervals) i, j = 0, 0 a_temp_interval = None b_temp_interval = None start, end = None, None bit_count = 0 while i < a_intervals_len and j < b_intervals_len: if a_temp_interval is None: a_start, a_end = a_intervals[i] else: a_start, a_end = a_temp_interval if b_temp_interval is None: b_start, b_end = b_intervals[j] else: b_start, b_end = b_temp_interval if a_end < b_start: if start is None: start, end = a_start, a_end else: if a_start > end + 1: final_intervals_append([start, end]) bit_count += end - start + 1 start, end = a_start, a_end else: end = a_end a_temp_interval = None i += 1 elif b_end < a_start: if start is None: start, end = b_start, b_end else: if b_start > end + 1: final_intervals_append([start, end]) bit_count += end - start + 1 start, end = b_start, b_end else: end = b_end b_temp_interval = None j += 1 else: if a_start < b_start: if start is None: start, end = a_start, b_start - 1 else: if a_start > end + 1: final_intervals_append([start, end]) bit_count += end - start + 1 start, end = a_start, b_start - 1 else: end = b_start - 1 a_start = b_start if b_start < a_start: if start is None: start, end = b_start, a_start - 1 else: if b_start > end + 1: final_intervals_append([start, end]) bit_count += end - start + 1 start, end = b_start, a_start - 1 else: end = a_start - 1 b_start = a_start if a_end > b_end: a_temp_interval = [b_end + 1, a_end] b_temp_interval = None j += 1 elif a_end < b_end: b_temp_interval = [a_end + 1, b_end] a_temp_interval = None i += 1 else: a_temp_interval = None b_temp_interval = None i += 1 j += 1 while i < a_intervals_len: if a_temp_interval is None: start_interval, end_interval = a_intervals[i] else: start_interval, end_interval = a_temp_interval a_temp_interval = None if start is None: start, end = start_interval, end_interval else: if start_interval > end + 1: final_intervals_append([start, end]) bit_count += end - start + 1 start, end = start_interval, end_interval else: end = end_interval i += 1 while j < b_intervals_len: if b_temp_interval is None: start_interval, end_interval = b_intervals[j] else: start_interval, end_interval = b_temp_interval b_temp_interval = None if start is None: start, end = start_interval, end_interval else: if start_interval > end + 1: final_intervals_append([start, end]) bit_count += end - start + 1 start, end = start_interval, end_interval else: end = end_interval j += 1 if start is not None: final_intervals_append([start, end]) bit_count += end - start + 1 if not final_intervals: return BitIntervals(0, [], -1, -1) bit_idx_first = final_intervals[0][0] bit_idx_last = final_intervals[-1][1] bit_intervals = BitIntervals(bit_count, final_intervals, bit_idx_first, bit_idx_last, no_copy=True) bit_intervals.validate() return bit_intervals def handle_non_overlapping(start, end, bit_count, interval_start, interval_end, final_intervals_append): """Appending the non_overlapping intervals into the result, and calculate the bit_count""" if start is None: start, end = interval_start, interval_end else: if interval_start > end + 1: final_intervals_append([start, end]) bit_count += end - start + 1 start, end = interval_start, interval_end else: end = interval_end return start, end, bit_count def process_remaining_intervals(start, end, bit_count, intervals, idx, temp_interval, final_intervals_append): """Deal with the remaining intervals not processed during the first walkthrough""" while idx < len(intervals): if temp_interval is None: start_interval, end_interval = intervals[idx] else: start_interval, end_interval = temp_interval temp_interval = None start, end, bit_count = handle_non_overlapping(start, end, bit_count, start_interval, end_interval, final_intervals_append) idx += 1 return start, end, bit_count def intervals_xor_v7(a_bit_intervals: BitIntervals, b_bit_intervals: BitIntervals) -> BitIntervals: """Take out similar conditions and boundary check as a function, might be slower""" a_intervals = a_bit_intervals.intervals b_intervals = b_bit_intervals.intervals final_intervals = [] final_intervals_append = final_intervals.append a_intervals_len = len(a_intervals) b_intervals_len = len(b_intervals) i, j = 0, 0 a_temp_interval = None b_temp_interval = None start, end = None, None bit_count = 0 while i < a_intervals_len and j < b_intervals_len: if a_temp_interval is None: a_start, a_end = a_intervals[i] else: a_start, a_end = a_temp_interval if b_temp_interval is None: b_start, b_end = b_intervals[j] else: b_start, b_end = b_temp_interval if a_end < b_start: start, end, bit_count = handle_non_overlapping(start, end, bit_count, a_start, a_end, final_intervals_append) a_temp_interval = None i += 1 elif b_end < a_start: start, end, bit_count = handle_non_overlapping(start, end, bit_count, b_start, b_end, final_intervals_append) b_temp_interval = None j += 1 else: if a_start < b_start: start, end, bit_count = handle_non_overlapping(start, end, bit_count, a_start, b_start - 1, final_intervals_append) a_start = b_start if b_start < a_start: start, end, bit_count = handle_non_overlapping(start, end, bit_count, b_start, a_start - 1, final_intervals_append) b_start = a_start if a_end > b_end: a_temp_interval = [b_end + 1, a_end] b_temp_interval = None j += 1 elif a_end < b_end: b_temp_interval = [a_end + 1, b_end] a_temp_interval = None i += 1 else: a_temp_interval = None b_temp_interval = None i += 1 j += 1 start, end, bit_count = process_remaining_intervals( start, end, bit_count, a_intervals, i, a_temp_interval, final_intervals_append ) start, end, bit_count = process_remaining_intervals( start, end, bit_count, b_intervals, j, b_temp_interval, final_intervals_append ) if start is not None: final_intervals_append([start, end]) bit_count += end - start + 1 if not final_intervals: return BitIntervals(0, [], -1, -1) bit_idx_first = final_intervals[0][0] bit_idx_last = final_intervals[-1][1] bit_intervals = BitIntervals(bit_count, final_intervals, bit_idx_first, bit_idx_last, no_copy=True) bit_intervals.validate() return bit_intervals def intervals_and_v0(a_bit_intervals: BitIntervals, b_bit_intervals: BitIntervals) -> BitIntervals: """does an and of two intervals""" ni = _get_combined_intervals_until_overlap_impossible(a_bit_intervals, b_bit_intervals) intervals = [] bit_count = 0 try: prev_int = next(ni) next_int: list[int] for next_int in ni: max_low: int = max(prev_int[0], next_int[0]) min_high: int = min(prev_int[1], next_int[1]) max_high: int = max(prev_int[1], next_int[1]) # print(f"{prev_int=}, {next_int=}, {max_low=}, {min_high=}, {max_high=}") if max_low <= min_high: bit_count += (min_high - max_low) + 1 # intervals need the +1 intervals.append([max_low, min_high]) if max_high > min_high: prev_int = [min_high + 1, max_high] else: prev_int = next(ni) except StopIteration: pass try: bit_idx_first = intervals[0][0] bit_idx_last = intervals[-1][1] except IndexError: bit_idx_first = -1 bit_idx_last = -1 bit_intervals = BitIntervals(bit_count, intervals, bit_idx_first, bit_idx_last) return bit_intervals def intervals_and_v1(a_bit_intervals: BitIntervals, b_bit_intervals: BitIntervals) -> BitIntervals: """Use a better function to check the overlap""" # Check for full overlap if ( a_bit_intervals.bit_count == b_bit_intervals.bit_count and a_bit_intervals.bit_idx_first == b_bit_intervals.bit_idx_first and a_bit_intervals.bit_idx_last == b_bit_intervals.bit_idx_last and len(a_bit_intervals.intervals) == len(b_bit_intervals.intervals) ): return BitIntervals( a_bit_intervals.bit_count, a_bit_intervals.intervals, a_bit_intervals.bit_idx_first, a_bit_intervals.bit_idx_last, ) ni = _get_combined_intervals_until_overlap_impossible(a_bit_intervals, b_bit_intervals) intervals = [] bit_count = 0 try: prev_int = next(ni) next_int: list[int] for next_int in ni: max_low: int = max(prev_int[0], next_int[0]) min_high: int = min(prev_int[1], next_int[1]) max_high: int = max(prev_int[1], next_int[1]) # print(f"{prev_int=}, {next_int=}, {max_low=}, {min_high=}, {max_high=}") if max_low <= min_high: bit_count += (min_high - max_low) + 1 intervals.append([max_low, min_high]) if max_high > min_high: prev_int = [min_high + 1, max_high] else: prev_int = next(ni) except StopIteration: pass try: bit_idx_first = intervals[0][0] bit_idx_last = intervals[-1][1] except IndexError: bit_idx_first = -1 bit_idx_last = -1 bit_intervals = BitIntervals(bit_count, intervals, bit_idx_first, bit_idx_last, no_copy=True) return bit_intervals def logical_subtract(a_bitmask: "LiteBitmask", b_bitmask: "LiteBitmask") -> "LiteBitmask": """Perform logical_subtract directly on bitmasks""" a_bit_intervals = a_bitmask.bit_intervals b_bit_intervals = b_bitmask.bit_intervals a_intervals = a_bit_intervals.intervals b_intervals = b_bit_intervals.intervals result_intervals = [] result_intervals_append = result_intervals.append i, j = 0, 0 while i < len(a_intervals) and j < len(b_intervals): a_start, a_end = a_intervals[i] b_start, b_end = b_intervals[j] if a_end < b_start: result_intervals_append([a_start, a_end]) i += 1 elif b_end < a_start: j += 1 else: if a_start < b_start: result_intervals_append([a_start, b_start - 1]) if a_end > b_end: a_intervals[i] = [b_end + 1, a_end] j += 1 elif a_end < b_end: i += 1 else: i += 1 j += 1 while i < len(a_intervals): result_intervals_append(a_intervals[i]) i += 1 # Merge overlapping intervals if not result_intervals: result = BitIntervals(0, [], -1, -1) else: merged_intervals = [] merged_intervals_append = merged_intervals.append bit_count = 0 current_start, current_end = result_intervals[0] for start, end in result_intervals[1:]: if start > current_end + 1: merged_intervals_append([current_start, current_end]) bit_count += current_end - current_start + 1 current_start, current_end = start, end else: current_end = max(current_end, end) merged_intervals_append([current_start, current_end]) bit_count += current_end - current_start + 1 bit_idx_first = merged_intervals[0][0] bit_idx_last = merged_intervals[-1][1] result = BitIntervals(bit_count, merged_intervals, bit_idx_first, bit_idx_last) bitmask = LiteBitmask.zeros_and_set_intervals(a_bitmask.length, result.intervals) return bitmask def logical_xnor_v0(a_bitmask: "LiteBitmask", b_bitmask: "LiteBitmask") -> "LiteBitmask": """Call the intervals_xor and logical_not to implement the xor function""" bit_intervals = intervals_xor(a_bitmask.bit_intervals, b_bitmask.bit_intervals) bitmask = LiteBitmask(a_bitmask.length, bit_intervals) bitmask = logical_not(bitmask) return bitmask def logical_xnor_v1(a_bitmask: "LiteBitmask", b_bitmask: "LiteBitmask") -> "LiteBitmask": """Try to perform xor function in two walkthrough""" a_bit_intervals = a_bitmask.bit_intervals b_bit_intervals = b_bitmask.bit_intervals a_intervals = a_bit_intervals.intervals b_intervals = b_bit_intervals.intervals result_intervals = [] result_intervals_append = result_intervals.append i, j = 0, 0 last_position = -1 # Track the last position we processed while i < len(a_intervals) and j < len(b_intervals): a_start, a_end = a_intervals[i] b_start, b_end = b_intervals[j] if a_end < b_start: # Append intervals where both are 0s (i.e., gaps) if last_position + 1 < a_start: result_intervals_append([last_position + 1, a_start - 1]) last_position = a_end i += 1 elif b_end < a_start: if last_position + 1 < b_start: result_intervals_append([last_position + 1, b_start - 1]) last_position = b_end j += 1 else: start = max(a_start, b_start) end = min(a_end, b_end) result_intervals_append([start, end]) if a_end > b_end: last_position = b_end j += 1 elif b_end > a_end: last_position = a_end i += 1 else: last_position = a_end i += 1 j += 1 # Append any remaining intervals where both are 0s while i < len(a_intervals): a_start, a_end = a_intervals[i] if last_position + 1 < a_start: result_intervals_append([last_position + 1, a_start - 1]) last_position = a_end i += 1 while j < len(b_intervals): b_start, b_end = b_intervals[j] if last_position + 1 < b_start: result_intervals_append([last_position + 1, b_start - 1]) last_position = b_end j += 1 last_bit = max(a_bitmask.length, b_bitmask.length) - 1 if last_position < last_bit: result_intervals_append([last_position + 1, last_bit]) # Merge overlapping intervals if not result_intervals: result = BitIntervals(0, [], -1, -1) else: merged_intervals = [] merged_intervals_append = merged_intervals.append bit_count = 0 current_start, current_end = result_intervals[0] for start, end in result_intervals[1:]: if start > current_end + 1: merged_intervals_append([current_start, current_end]) bit_count += current_end - current_start + 1 current_start, current_end = start, end else: current_end = max(current_end, end) merged_intervals_append([current_start, current_end]) bit_count += current_end - current_start + 1 bit_idx_first = merged_intervals[0][0] bit_idx_last = merged_intervals[-1][1] result = BitIntervals(bit_count, merged_intervals, bit_idx_first, bit_idx_last) bitmask = LiteBitmask.zeros_and_set_intervals(a_bitmask.length, result.intervals) return bitmask def logical_xnor_v2(a_bitmask: "LiteBitmask", b_bitmask: "LiteBitmask") -> "LiteBitmask": """Refining the logic of the start and end of each sub_intervals""" a_bit_intervals = a_bitmask.bit_intervals b_bit_intervals = b_bitmask.bit_intervals a_intervals = a_bit_intervals.intervals.copy() b_intervals = b_bit_intervals.intervals.copy() result_intervals = [] result_intervals_append = result_intervals.append i, j = 0, 0 while i < len(a_intervals) and j < len(b_intervals): a_start, a_end = a_intervals[i] b_start, b_end = b_intervals[j] if a_end < b_start: result_intervals_append([a_start, a_end]) i += 1 elif b_end < a_start: result_intervals_append([b_start, b_end]) j += 1 else: start = min(a_start, b_start) end = max(a_end, b_end) result_intervals_append([start, end]) if a_end > b_end: a_intervals[i] = [b_end + 1, a_end] j += 1 elif b_end > a_end: b_intervals[j] = [a_end + 1, b_end] i += 1 else: i += 1 j += 1 while i < len(a_intervals): result_intervals_append(a_intervals[i]) i += 1 while j < len(b_intervals): result_intervals_append(b_intervals[j]) j += 1 last_bit = max(a_bitmask.length, b_bitmask.length) - 1 if result_intervals: last_result_bit = result_intervals[-1][1] if last_result_bit < last_bit: result_intervals_append([last_result_bit + 1, last_bit]) else: result_intervals_append([0, last_bit]) # Merge overlapping intervals and handle gaps if not result_intervals: return LiteBitmask.zeros(a_bitmask.length) merged_intervals = [] merged_intervals_append = merged_intervals.append start, end = result_intervals[0] result_intervals_len = len(result_intervals) for i in range(1, result_intervals_len): current_start, current_end = result_intervals[i] if current_start > end + 1: # Include the gap merged_intervals_append([start, end]) merged_intervals_append([end + 1, current_start - 1]) start, end = current_start, current_end else: end = max(end, current_end) merged_intervals_append([start, end]) bit_count = sum(end - start + 1 for start, end in merged_intervals) bit_idx_first = merged_intervals[0][0] bit_idx_last = merged_intervals[-1][1] result = BitIntervals(bit_count, merged_intervals, bit_idx_first, bit_idx_last) result.validate() bitmask = LiteBitmask.zeros_and_set_intervals(a_bitmask.length, result.intervals) return bitmask def logical_xnor_v3(a_bitmask: "LiteBitmask", b_bitmask: "LiteBitmask") -> "LiteBitmask": """Another effort dealing with the overlap in both intervals""" a_intervals = a_bitmask.bit_intervals.intervals b_intervals = b_bitmask.bit_intervals.intervals length = a_bitmask.length final_intervals = [] a_intervals_len = len(a_intervals) b_intervals_len = len(b_intervals) final_intervals_append = final_intervals.append i, j = 0, 0 last_end = -1 while i < a_intervals_len and j < b_intervals_len: a_start, a_end = a_intervals[i] b_start, b_end = b_intervals[j] if a_end < b_start: if last_end < a_start - 1: final_intervals_append([last_end + 1, a_start - 1]) last_end = a_end i += 1 elif b_end < a_start: if last_end < b_start - 1: final_intervals_append([last_end + 1, b_start - 1]) last_end = b_end j += 1 else: if last_end < min(a_start, b_start) - 1: final_intervals_append([last_end + 1, min(a_start, b_start) - 1]) overlap_start = max(a_start, b_start) overlap_end = min(a_end, b_end) final_intervals_append([overlap_start, overlap_end]) last_end = overlap_end if a_end > b_end: j += 1 elif a_end < b_end: i += 1 else: i += 1 j += 1 while i < a_intervals_len: a_start, a_end = a_intervals[i] if last_end < a_start - 1: final_intervals_append([last_end + 1, a_start - 1]) last_end = a_end i += 1 while j < b_intervals_len: b_start, b_end = b_intervals[j] if last_end < b_start - 1: final_intervals_append([last_end + 1, b_start - 1]) last_end = b_end j += 1 if last_end < length - 1: final_intervals_append([last_end + 1, length - 1]) if not final_intervals and not a_intervals and not b_intervals: final_intervals = [[0, length - 1]] merged_intervals = [] for interval in final_intervals: if not merged_intervals or merged_intervals[-1][1] < interval[0] - 1: merged_intervals.append(interval) else: merged_intervals[-1][1] = max(merged_intervals[-1][1], interval[1]) if not merged_intervals: bit_count = 0 bit_idx_first = -1 bit_idx_last = -1 else: bit_count = sum(end - start + 1 for start, end in merged_intervals) bit_idx_first = merged_intervals[0][0] bit_idx_last = merged_intervals[-1][1] bit_intervals = BitIntervals(bit_count, merged_intervals, bit_idx_first, bit_idx_last, no_copy=True) bit_intervals.validate() return LiteBitmask(length, bit_intervals) def logical_xnor_v4(a_bitmask: "LiteBitmask", b_bitmask: "LiteBitmask") -> "LiteBitmask": """Usage of add_intervals to both speed up the append process and avoid merge process""" a_intervals = a_bitmask.bit_intervals.intervals b_intervals = b_bitmask.bit_intervals.intervals length = a_bitmask.length merged_intervals = [] merged_intervals_append = merged_intervals.append a_intervals_len = len(a_intervals) b_intervals_len = len(b_intervals) i, j = 0, 0 last_end = -1 def add_interval(start, end): if not merged_intervals or merged_intervals[-1][1] < start - 1: merged_intervals_append([start, end]) else: merged_intervals[-1][1] = max(merged_intervals[-1][1], end) while i < a_intervals_len or j < b_intervals_len: if i < a_intervals_len: a_start, a_end = a_intervals[i] else: a_start, a_end = length, length if j < b_intervals_len: b_start, b_end = b_intervals[j] else: b_start, b_end = length, length if a_end < b_start: if last_end < a_start - 1: add_interval(last_end + 1, a_start - 1) last_end = a_end i += 1 elif b_end < a_start: if last_end < b_start - 1: add_interval(last_end + 1, b_start - 1) last_end = b_end j += 1 else: if last_end < min(a_start, b_start) - 1: add_interval(last_end + 1, min(a_start, b_start) - 1) overlap_start = max(a_start, b_start) overlap_end = min(a_end, b_end) add_interval(overlap_start, overlap_end) last_end = overlap_end if a_end > b_end: j += 1 elif a_end < b_end: i += 1 else: i += 1 j += 1 if last_end < length - 1: add_interval(last_end + 1, length - 1) if not merged_intervals: bit_count = 0 bit_idx_first = -1 bit_idx_last = -1 else: bit_count = sum(end - start + 1 for start, end in merged_intervals) bit_idx_first = merged_intervals[0][0] bit_idx_last = merged_intervals[-1][1] bit_intervals = BitIntervals(bit_count, merged_intervals, bit_idx_first, bit_idx_last, no_copy=True) bit_intervals.validate() return LiteBitmask(length, bit_intervals) def logical_xnor_v5(a_bitmask: "LiteBitmask", b_bitmask: "LiteBitmask") -> "LiteBitmask": """Avoid the merge process, and refine the logic in appending in merge intervals by updating the boundary""" a_intervals = a_bitmask.bit_intervals.intervals b_intervals = b_bitmask.bit_intervals.intervals length = a_bitmask.length merged_intervals = [] merged_intervals_append = merged_intervals.append a_intervals_len = len(a_intervals) b_intervals_len = len(b_intervals) i, j = 0, 0 last_end = -1 while i < a_intervals_len or j < b_intervals_len: if i < a_intervals_len: a_start, a_end = a_intervals[i] else: a_start, a_end = length, length if j < b_intervals_len: b_start, b_end = b_intervals[j] else: b_start, b_end = length, length # Append the non-overlapping intervals into the result if a_end < b_start: if last_end < a_start - 1: if not merged_intervals or merged_intervals[-1][1] < last_end: merged_intervals_append([last_end + 1, a_start - 1]) else: merged_intervals[-1][1] = a_start - 1 last_end = a_end i += 1 elif b_end < a_start: if last_end < b_start - 1: if not merged_intervals or merged_intervals[-1][1] < last_end: merged_intervals_append([last_end + 1, b_start - 1]) else: merged_intervals[-1][1] = b_start - 1 last_end = b_end j += 1 else: if last_end < min(a_start, b_start) - 1: if not merged_intervals or merged_intervals[-1][1] < last_end: merged_intervals_append([last_end + 1, min(a_start, b_start) - 1]) else: merged_intervals[-1][1] = min(a_start, b_start) - 1 overlap_start = max(a_start, b_start) overlap_end = min(a_end, b_end) if not merged_intervals or merged_intervals[-1][1] < overlap_start - 1: merged_intervals_append([overlap_start, overlap_end]) else: merged_intervals[-1][1] = overlap_end last_end = overlap_end if a_end > b_end: j += 1 elif a_end < b_end: i += 1 else: i += 1 j += 1 if last_end < length - 1: if not merged_intervals or merged_intervals[-1][1] < last_end: merged_intervals_append([last_end + 1, length - 1]) else: merged_intervals[-1][1] = length - 1 if not merged_intervals: bit_count = 0 bit_idx_first = -1 bit_idx_last = -1 else: bit_count = sum(end - start + 1 for start, end in merged_intervals) bit_idx_first = merged_intervals[0][0] bit_idx_last = merged_intervals[-1][1] bit_intervals = BitIntervals(bit_count, merged_intervals, bit_idx_first, bit_idx_last, no_copy=True) bit_intervals.validate() return LiteBitmask(length, bit_intervals) def logical_xnor_v6_append_interval(merged_intervals, merge_intervals_append, last_end, new_start, new_end, bit_count): """A function with merge interval conditions and sum of bit_count to avoid redundant code, but it is slower""" if not merged_intervals or merged_intervals[-1][1] < last_end: merge_intervals_append([new_start, new_end]) bit_count += new_end - new_start + 1 else: bit_count += new_end - merged_intervals[-1][1] merged_intervals[-1][1] = new_end return bit_count def logical_xnor_v6(a_bitmask: "LiteBitmask", b_bitmask: "LiteBitmask") -> "LiteBitmask": """Avoid the merge process, and refine the logic in appending in merge intervals by updating the boundary""" a_intervals = a_bitmask.bit_intervals.intervals b_intervals = b_bitmask.bit_intervals.intervals length = a_bitmask.length merged_intervals = [] merged_intervals_append = merged_intervals.append a_intervals_len = len(a_intervals) b_intervals_len = len(b_intervals) i, j = 0, 0 last_end = -1 bit_count = 0 while i < a_intervals_len or j < b_intervals_len: if i < a_intervals_len: a_start, a_end = a_intervals[i] else: a_start, a_end = length, length if j < b_intervals_len: b_start, b_end = b_intervals[j] else: b_start, b_end = length, length # Append the non-overlapping intervals into the result if a_end < b_start: if last_end < a_start - 1: bit_count = logical_xnor_v6_append_interval( merged_intervals, merged_intervals_append, last_end, last_end + 1, a_start - 1, bit_count ) last_end = a_end i += 1 elif b_end < a_start: if last_end < b_start - 1: bit_count = logical_xnor_v6_append_interval( merged_intervals, merged_intervals_append, last_end, last_end + 1, b_start - 1, bit_count ) last_end = b_end j += 1 else: if last_end < min(a_start, b_start) - 1: bit_count = logical_xnor_v6_append_interval( merged_intervals, merged_intervals_append, last_end, last_end + 1, min(a_start, b_start) - 1, bit_count ) overlap_start = max(a_start, b_start) overlap_end = min(a_end, b_end) bit_count = logical_xnor_v6_append_interval( merged_intervals, merged_intervals_append, overlap_start - 1, overlap_start, overlap_end, bit_count ) last_end = overlap_end if a_end > b_end: j += 1 elif a_end < b_end: i += 1 else: i += 1 j += 1 if last_end < length - 1: bit_count = logical_xnor_v6_append_interval( merged_intervals, merged_intervals_append, last_end, last_end + 1, length - 1, bit_count ) if not merged_intervals: bit_idx_first = -1 bit_idx_last = -1 else: bit_idx_first = merged_intervals[0][0] bit_idx_last = merged_intervals[-1][1] bit_intervals = BitIntervals(bit_count, merged_intervals, bit_idx_first, bit_idx_last, no_copy=True) bit_intervals.validate() return LiteBitmask(length, bit_intervals) class IntervalError(Exception): pass def convert_intervals_to_bitintervals(intervals: IntervalList) -> BitIntervals: if intervals: def get_delta(bir): a = bir[0] b = bir[1] if a > b: raise IntervalError(f"negative interval: {bir}") return (b - a) + 1 bit_count = reduce(lambda a, b: a + b, map(get_delta, intervals)) else: bit_count = 0 try: bit_idx_first = intervals[0][0] bit_idx_last = intervals[-1][1] except IndexError: bit_idx_first = -1 bit_idx_last = -1 bit_intervals = BitIntervals(bit_count, intervals, bit_idx_first, bit_idx_last) return bit_intervals def index_list_to_bitintervals_v0(idx_lst: Union[Set[int], List[int]], limit=None, sort=True) -> BitIntervals: """converting list to bit_intervals""" if idx_lst: if sort: if type(idx_lst) == set: idx_lst = sorted(idx_lst) else: idx_lst = sorted(set(idx_lst)) bit_count = len(idx_lst) bit_idx_first = idx_lst[0] bit_idx_last = idx_lst[-1] intervals = [] count = 0 # TODO: attempting rewrite. # intervals2 = [] # idx = 0 # len_idx_lst = len(idx_lst) # a = idx_lst[idx] # b = idx_lst[idx] # pb = b # while idx < len_idx_lst or (limit is not None and idx < limit): # if b > (pb + 1): # # not sequential. # intervals2.append([a, a]) # if a == b: # pb = b # else: # # a != b # if (pb + 1) == b: # # sequential # pass # else: # # start a new interval # intervals2.append([a, pb]) # a = idx_lst[idx] # idx += 1 # try: # pb = b # b = idx_lst[idx] # except IndexError: # intervals2.append([a, pb]) # VERY SLOW. for key, group in groupby(enumerate(idx_lst), lambda ix: ix[0] - ix[1]): interval = [item[1] for item in group] intervals.append([interval[0], interval[-1]]) count += 1 if limit and limit >= count: raise Exception(f"Over limit:{limit}") # assert intervals == intervals2 result = BitIntervals(bit_count, intervals, bit_idx_first, bit_idx_last, no_copy=True) else: result = BitIntervals(0, [], -1, -1) return result def index_list_to_bitintervals_v1(idx_lst: Union[Set[int], List[int]], limit=None, sort=True) -> BitIntervals: """Improved function by avoid using groupby""" if idx_lst: if sort: if type(idx_lst) == set: idx_lst = sorted(idx_lst) else: idx_lst = sorted(set(idx_lst)) bit_count = len(idx_lst) bit_idx_first = idx_lst[0] bit_idx_last = idx_lst[-1] intervals = [] count = 0 start = idx_lst[0] prev = idx_lst[0] len_idx_lst = len(idx_lst) intervals_append = intervals.append for i in range(1, len_idx_lst): current = idx_lst[i] if current != prev + 1: intervals_append([start, prev]) start = current count += 1 if limit and count >= limit: raise Exception(f"Over limit: {limit}") prev = current intervals_append([start, prev]) result = BitIntervals(bit_count, intervals, bit_idx_first, bit_idx_last, no_copy=True) else: result = BitIntervals(0, [], -1, -1) return result def index_list_to_bitintervals_v2(idx_lst: Union[Set[int], List[int]], limit=None, sort=True) -> BitIntervals: """Refine the logic with early return and sort check""" if not idx_lst: return BitIntervals(0, [], -1, -1) if sort: idx_lst = sorted(idx_lst) if isinstance(idx_lst, set) else sorted(set(idx_lst)) else: idx_lst = sorted(set(idx_lst)) intervals = [] start = prev = idx_lst[0] count = 0 len_idx_lst = len(idx_lst) intervals_append = intervals.append for i in range(1, len_idx_lst): current = idx_lst[i] if current != prev + 1: intervals_append([start, prev]) start = current count += 1 if limit and count >= limit: raise Exception(f"Over limit: {limit}") prev = current intervals_append([start, prev]) return BitIntervals(len_idx_lst, intervals, idx_lst[0], idx_lst[-1], no_copy=True) def repr_LiteBitmask_zasb(bitmask: "LiteBitmask"): bits = (str(i) for i in bitmask.get_bits()) result = f"LiteBitmask.zeros_and_set_bits({bitmask.length}, ({', '.join(bits)}))" return result def repr_LiteBitmask_zasi(bitmask: "LiteBitmask"): result = f"LiteBitmask.zeros_and_set_intervals({bitmask.length}, {str(bitmask.bit_intervals.intervals)})" return result class BitmaskFlags: writeable = False keys = ("writeable",) @staticmethod def from_dict(dct): bf = BitmaskFlags() for key, value in dct.items(): if key in BitmaskFlags.keys: setattr(bf, key, value) else: raise Exception(f"invalid key {key} for BitmaskFlags") return bf def to_dict(self): return dict(writeable=self.writeable) class LiteBitmask(EnhancedBitmaskLike): length: int = 0 # same as size bit_count: int = 0 bit_idx_first = -1 bit_idx_last = -1 flags = None repr_func = repr_LiteBitmask_zasi def __init__(self, length, bit_intervals: BitIntervals): self.length: int = length self.bit_count: int = bit_intervals.bit_count self.bit_idx_first: int = bit_intervals.bit_idx_first self.bit_idx_last: int = bit_intervals.bit_idx_last # self.intervals: list = bit_intervals.intervals self.bit_intervals: BitIntervals = bit_intervals if self.bit_idx_last > (self.length - 1): raise Exception(f"Bit set out of range: {self.bit_idx_last} length:{self.length}") self.flags = BitmaskFlags() # def __new__(cls, *args, **kwargs): def __setstate__(self, state): """used to recreate the object after pickle""" self.length = state["length"] self.bit_count = state["bit_count"] self.bit_idx_first = state["bit_idx_first"] self.bit_idx_last = state["bit_idx_last"] self.bit_intervals = BitIntervals.from_dict(state["bit_intervals"]) self.flags = BitmaskFlags.from_dict(state["flags"]) def __getstate__(self): """used to dump the state of an object for pickling""" return self.to_dict() def intersection_any(self, other: EnhancedBitmaskLike): return intersection_any(self, other) def __repr__(self): # https://github.com/pandas-dev/pandas/issues/17695 # from pandas.io.formats import printing as pd_printing # pd_printing.is_sequence = lambda obj: False return LiteBitmask.repr_func(self) def __getitem__(self, item): # fixme: this doesn't handle negatives in the slices. if not isinstance(item, slice): if item >= self.length: raise IndexError for ( s, e, ) in self.bit_intervals.intervals: # TODO: do a binary search or some neat source. if s <= item <= e: return True return False else: lst = [] c_idx = 0 # if item.start >= self.length or item.stop >= self.length: # pass # if the array is start = item.start stop = item.stop # if you ask beyond the end of the array, isolate the range to the stop. if stop >= self.length: stop = self.length step = item.step rangeish = [i for i in [start, stop, step] if i is not None] for idx in range(*rangeish): # Nlog(M) found = False for i, (s, e) in enumerate(self.bit_intervals.intervals[c_idx:]): if s <= idx <= e: found = True c_idx = c_idx + i break lst.append(found) return lst def __or__(self, other) -> "LiteBitmask": return LiteBitmask.logical_or(self, other) def __and__(self, other) -> "LiteBitmask": return LiteBitmask.logical_and(self, other) def __sub__(self, other): if intersection_possible(self, other): indexes_self = get_interval_bits_set(self.bit_intervals) indexes_other = get_interval_bits_set(other.bit_intervals) indexes = indexes_self - indexes_other # fixme: todo # indexes = intervals_subtract(self.intervals, other.intervals) # this might be slow. the set operation isn't ordered so we need order it. bitmask = LiteBitmask.zeros_and_set_bits(self.length, indexes) else: bitmask = self.copy() return bitmask def __eq__(self, o: "LiteBitmask") -> "LiteBitmask": if type(o) == LiteBitmask: lst = map(lambda a: a[0] == a[1], zip(self.tolist(), o.tolist())) bits = [idx for idx, value in enumerate(lst) if value] result = LiteBitmask.zeros_and_set_bits(self.length, bits) else: result = super().__eq__(o) return result def copy(self, *args, **kwargs): return LiteBitmask.zeros_and_set_intervals(self.length, self.bit_intervals.intervals) def validate(self): if self.bit_count > 0: assert self.bit_idx_first == self.bit_intervals.intervals[0][0], self.bit_intervals.to_dict() assert self.bit_idx_last == self.bit_intervals.intervals[-1][1] assert self.bit_idx_last <= self.length else: assert self.bit_idx_first == -1 assert self.bit_idx_last == -1 count = 0 for s, e in self.bit_intervals.intervals: for i in range(s, e + 1): count += 1 assert count == self.bit_count self.bit_intervals.validate() @staticmethod def get_bitmask_from_idx_lst(bit_count, idx_lst) -> "LiteBitmask": """create a bitmask from an index list""" return LiteBitmask.zeros_and_set_bits(bit_count, idx_lst) @staticmethod def get_idx_lst_from_bitmask(bitmask, bit_idx_first: int = None, bit_idx_last: int = None) -> list: """return the indexes set inside the bitmask""" # FIXME: priority 1, get bit_idx_last working and this is probably slow. if isinstance(bitmask, LiteBitmask): idx_lst = [_ for _ in get_interval_bits(bitmask.bit_intervals, bit_idx_first=bit_idx_first)] else: idx_lst = EnhancedBitmask.get_idx_lst_from_bitmask(bitmask, bit_idx_first=bit_idx_first, bit_idx_last=bit_idx_last) return idx_lst def get_bits(self, bit_idx_first=0) -> Union[list[str], Generator]: """return the bit index list""" # FIXME: remove, use get_idx_lst_from_bitmask return LiteBitmask.get_idx_lst_from_bitmask(self) def tolist(self): """genereate the full bitmask as a list.""" bit_idx_gen = LiteBitmask.get_idx_lst_from_bitmask(self) lst: list[bool] = [False] * self.length for idx in bit_idx_gen: lst[idx] = True return lst @staticmethod def array(lst, **kwargs) -> "LiteBitmask": bits = [idx for idx, value in enumerate(lst) if value] bitmask = LiteBitmask.zeros_and_set_bits(len(lst), bits) return bitmask @staticmethod def zeros(length: int) -> "LiteBitmask": return LiteBitmask(length, BitIntervals(0, [], -1, -1)) @staticmethod def ones(length: int) -> "LiteBitmask": return LiteBitmask(length, BitIntervals(length, [[0, length - 1]], 0, length - 1)) @staticmethod def zeros_and_set_bits(length: int, bit_lst: Union[List[int], Tuple[int, ...], Set[int]], **kwargs): bit_intervals = index_list_to_bitintervals(bit_lst) return LiteBitmask(length, bit_intervals) @staticmethod def zeros_and_set_ranges(length: int, bit_ranges: list[Union[list[int, int], list[int, int, int]]]): """given a tuple of range slices, create a bitmask. range slices are a->b excluding b""" bit_indexes = sorted(set(i for bit_range in bit_ranges for i in range(*bit_range))) bit_intervals = index_list_to_bitintervals(bit_indexes) return LiteBitmask(length, bit_intervals) @staticmethod def zeros_and_set_intervals(length: int, intervals: IntervalList): """given a tuple of intervals, create a bitmask. intervals are a->b including b do_sort_merge will sort the intervals and remove duplicates""" bit_intervals = BitIntervals.from_intervals_list(intervals) return LiteBitmask(length, bit_intervals) @staticmethod def zeros_and_set_bit(length: int, bit_idx): return LiteBitmask( length, BitIntervals( 1, [ [bit_idx, bit_idx], ], bit_idx, bit_idx, ), ) @staticmethod def from_dict(dct): bit_intervals = BitIntervals.from_dict(dct["bit_intervals"]) # bit_intervals = BitIntervals(dct['bit_count'], dct['intervals'], dct['bit_idx_first'], dct['bit_idx_last']) bitmask = LiteBitmask(dct["length"], bit_intervals) bitmask.flags = BitmaskFlags.from_dict(dct["flags"]) return bitmask def to_dict(self): dct = { "length": self.length, "bit_count": self.bit_count, "bit_idx_first": self.bit_idx_first, "bit_idx_last": self.bit_idx_last, "bit_intervals": self.bit_intervals.to_dict(), "flags": self.flags.to_dict(), } return dct def to_json(self): json_str = simplejson.dumps(self, cls=DataEncoder) return json_str @staticmethod def from_json(json_str): obj = simplejson.loads(json_str, cls=DataDecoder) return obj def __len__(self): return self.length def __add__(self, other): return self | other def __xor__(self, other): return LiteBitmask.logical_xor(self, other) @staticmethod def add(x1, x2, **kwargs): return x1 | x2 @staticmethod def logical_xor(x: "LiteBitmask", y: "LiteBitmask") -> "LiteBitmask": bit_intervals = intervals_xor(x.bit_intervals, y.bit_intervals) bitmask = LiteBitmask(x.length, bit_intervals) return bitmask @staticmethod def logical_not(x: "LiteBitmask") -> "LiteBitmask": bitmask = logical_not(x) return bitmask @staticmethod def logical_xnor(x: "LiteBitmask", y: "LiteBitmask") -> "LiteBitmask": return logical_xnor(x, y) def intersection_all(self, other): xnored = self.logical_xnor(self, other) if xnored.bit_count == xnored.length: result = True else: result = False return result @staticmethod def logical_or(x: "LiteBitmask", y: "LiteBitmask"): bit_intervals = intervals_or(x.bit_intervals, y.bit_intervals) bitmask = LiteBitmask(x.length, bit_intervals) return bitmask @staticmethod def logical_and(x, y): if intersection_possible(x, y): bit_intervals = intervals_and(x.bit_intervals, y.bit_intervals) bitmask = LiteBitmask(x.length, bit_intervals) else: bitmask = LiteBitmask.zeros(x.length) return bitmask def intersection(self, other): if intersection_possible(self, other): real_intersection = LiteBitmask.logical_and(self, other) else: real_intersection = None return real_intersection @classmethod def logical_is_subset(cls, left, right): # FIXME: see EnhancedBitmask for a much faster version. result = cls.logical_and(right, cls.logical_not(left)) return result @classmethod def join_bitmasks(cls, bitmask_lst: list, operation) -> "LiteBitmask": """ given a list of bitmasks, join them """ if len(bitmask_lst) == 0: raise Exception("Must provide a bitmask to join with. The length is required.") elif len(bitmask_lst) == 1: result_bitmask = bitmask_lst[0].copy() else: length = 0 for bitmask in bitmask_lst: if length == 0: length = bitmask.length if bitmask.length != length: raise AssertionError("Cannot join bitmasks with different sizes") if operation == cls.logical_or: # ordering them by bit_count bitmask_lst_orig = sorted(bitmask_lst, key=lambda obj: -obj.bit_count) first_orig_bitmask = bitmask_lst_orig[0] if first_orig_bitmask.length == first_orig_bitmask.bit_count: result_bitmask = first_orig_bitmask.copy() else: result_bitmask = reduce(operation, bitmask_lst) else: result_bitmask = reduce(operation, bitmask_lst) return result_bitmask @staticmethod def sign(x1) -> "EnhancedBitmaskLike": return x1.copy() DataEncoder.register(LiteBitmask, LiteBitmask.to_dict) DataDecoder.register("LiteBitmask", LiteBitmask.from_dict) DataEncoder.register(BitIntervals, BitIntervals.to_dict) DataDecoder.register("BitIntervals", BitIntervals.from_dict) DataEncoder.register(BitmaskFlags, BitmaskFlags.to_dict) DataDecoder.register("BitmaskFlags", BitmaskFlags.from_dict) def convert_bitmask_to_litebitmask(bitmask: EnhancedBitmask): length: int = bitmask.length obj = LiteBitmask.zeros_and_set_bits(length, bitmask.get_idx_lst_from_bitmask(bitmask)) return obj def series_intersection_any(ser, search_bitmask: Union[EnhancedBitmaskLike, str]): """function to apply to a pandas series of bitmasks to search for an intersection.""" return [bitmask.intersection_any(search_bitmask) if bitmask != "" else False for bitmask in ser] """Best versions so far""" intervals_subtract = intervals_subtract_v2 intervals_xor = intervals_xor_v7 intervals_and = intervals_and_v1 logical_xnor = logical_xnor_v6 index_list_to_bitintervals = index_list_to_bitintervals_v2 intervals_or = intervals_or_v7 """ An example of using arrays to contain different function names intervals_subtract_funcs = [ ('intervals_subtract_v0', intervals_subtract_v0), ('intervals_subtract_v1', intervals_subtract_v1), ('intervals_subtract_v2', intervals_subtract_v2), ] intervals_subtract_funcs =['intervals_subtract_v0', 'intervals_subtract_v1', 'intervals_subtract_v2'] """