# Copyright (C) 2024, UChicago Argonne, LLC
# Licensed under the 3-clause BSD license.  See accompanying LICENSE.txt file
# in the top-level directory.

import abc
import copy
import operator
import pickle
import sys
from dataclasses import dataclass, asdict
from functools import reduce
from operator import itemgetter
import numpy as np
from itertools import groupby
from pprint import pformat
from typing import Union, Tuple, Generator, List, Set

import simplejson

from Octeres.bitmask import EnhancedBitmaskLike, EnhancedBitmask, bool_True, bool_False
from Octeres.serialization import DataEncoder, DataDecoder

"""
invariants:
    all BitIntervals must have ordered and merged intervals.
    To evaluate:
        whether LiteBitmask should be immutable.
        changing bit_idx_first and bit_idx_last to immutable properties.
    do a pass later and drop underscores on functions not needed outside.
"""


def intersection_possible(bitmaska, bitmaskb):
    """Cheap pre-check: could the two bitmasks share any set bit?

    Returns False when either bitmask has no set bits, or when the spans
    [bit_idx_first, bit_idx_last] of the two bitmasks are disjoint.
    Returns True otherwise; True only means an overlap is not ruled out.
    """
    if 0 in (bitmaska.bit_count, bitmaskb.bit_count):
        return False
    spans_touch = bitmaska.bit_idx_first <= bitmaskb.bit_idx_last and bitmaska.bit_idx_last >= bitmaskb.bit_idx_first
    return spans_touch


def intersection_any(bitmaska, bitmaskb):
    """Return True if the two bitmasks have at least one set bit in common.

    Both arguments must expose ``bit_count``, ``bit_idx_first``, ``bit_idx_last``
    and ``bit_intervals`` (a BitIntervals whose ``intervals`` are sorted, merged,
    inclusive ``[start, end]`` pairs).

    The check walks the two interval lists with two pointers and stops at the
    first overlapping pair. It never expands individual bits.
    """
    if not intersection_possible(bitmaska, bitmaskb):
        return False

    # The first and last indexes of a non-empty bitmask are always set bits,
    # so a shared endpoint is a guaranteed common bit.
    a_ends = (bitmaska.bit_idx_first, bitmaska.bit_idx_last)
    if bitmaskb.bit_idx_first in a_ends or bitmaskb.bit_idx_last in a_ends:
        return True

    a_intervals = bitmaska.bit_intervals.intervals
    b_intervals = bitmaskb.bit_intervals.intervals
    a_len = len(a_intervals)
    b_len = len(b_intervals)
    i = j = 0
    while i < a_len and j < b_len:
        a_start, a_end = a_intervals[i]
        b_start, b_end = b_intervals[j]
        if a_start <= b_end and b_start <= a_end:
            return True
        # The interval that ends first cannot overlap anything later in the other
        # (sorted) list, so advance past it.
        if a_end < b_end:
            i += 1
        else:
            j += 1
    return False


# Type alias: a list of inclusive [start, end] bit-index pairs, e.g. [[0, 3], [7, 7]].
IntervalList = List[List[int]]


class BitIntervals:
    """Set bits stored as sorted, non-overlapping, inclusive ``[start, end]`` intervals.

    Attributes:
        bit_count: number of set bits, i.e. the sum of ``end - start + 1`` over intervals.
        intervals: list of ``[start, end]`` pairs ordered by start.
        bit_idx_first: first set index (``intervals[0][0]``), or -1 when there are no bits.
        bit_idx_last: last set index (``intervals[-1][1]``), or -1 when there are no bits.

    Every instance is checked with :meth:`validate` during construction.
    """

    def __init__(self, bit_count, intervals, bit_idx_first, bit_idx_last, no_copy=False):
        self.bit_count = bit_count
        if no_copy:
            # store the caller's list by reference (no copy)
            self.intervals: List[List] = intervals
        else:
            self.intervals: List[List] = copy.deepcopy(intervals)
        self.bit_idx_first = bit_idx_first
        self.bit_idx_last = bit_idx_last
        self.validate()

    def to_dict(self):
        """Return the constructor arguments as a dict (``intervals`` is returned by reference)."""
        dct = dict(
            bit_count=self.bit_count,
            intervals=self.intervals,
            bit_idx_first=self.bit_idx_first,
            bit_idx_last=self.bit_idx_last,
        )
        return dct

    @staticmethod
    def from_dict(dct):
        """Inverse of :meth:`to_dict`; the intervals are deep-copied and validated."""
        # The constructor already runs validate(), so no second call is needed.
        return BitIntervals(
            dct["bit_count"],
            dct["intervals"],
            dct["bit_idx_first"],
            dct["bit_idx_last"],
        )

    @staticmethod
    def from_intervals_list(intervals: "IntervalList", validate=False) -> "BitIntervals":
        """Build a BitIntervals from ``[start, end]`` pairs, deriving the count and endpoints.

        The input is deep-copied. It is assumed to be sorted and merged: no overlaps,
        and no touching neighbours (``[5, 6], [7, 8]`` should be ``[5, 8]``).

        Args:
            intervals: inclusive ``[start, end]`` pairs.
            validate: run :meth:`validate` again after construction. The constructor
                always validates, so this is redundant.

        Raises:
            IntervalError: an element is not a pair, or validation rejects the intervals.
            AssertionError: an element has ``end < start``, or a count/endpoint check fails.
        """
        bit_count = 0
        idx_first = -1
        idx_last = -1
        # FIXME: priority one, create test to detect problems in interval ids
        intervals = copy.deepcopy(intervals)
        for start_end in intervals:
            if len(start_end) != 2:
                raise IntervalError(f"interval must be a length of 2:{start_end}")
            start, end = start_end
            assert end >= start
            bit_count += end - start + 1
        if bit_count:
            idx_first = intervals[0][0]
            idx_last = intervals[-1][1]
        bit_intervals = BitIntervals(bit_count, intervals, idx_first, idx_last, no_copy=True)
        if validate:
            bit_intervals.validate()
        return bit_intervals

    def __eq__(self, other: object) -> bool:
        # Defer to the other operand for foreign types instead of raising AttributeError.
        if not isinstance(other, BitIntervals):
            return NotImplemented
        result = (
            self.intervals == other.intervals
            and self.bit_count == other.bit_count
            and self.bit_idx_first == other.bit_idx_first
            and self.bit_idx_last == other.bit_idx_last
        )
        return result

    def __repr__(self):
        result = f"BitIntervals({self.bit_count}, {self.intervals}, {self.bit_idx_first}, {self.bit_idx_last})"
        return result

    # def __getitem__(self, item):
    #     result = self.intervals[item]
    #     return result
    #
    # def __sub__(self, other):
    #     return intervals_subtract(self.intervals, other.intervals)
    #
    # def __add__(self, other):
    #     return intervals_or(self.intervals, other.intervals)
    #
    # def __xor__(self, other):
    #     return intervals_xor(self.intervals, other.intervals)

    def validate(self):
        """Check the invariants in a single pass, raising on the first violation found.

        Raises:
            AssertionError: bit_idx_first, bit_idx_last or bit_count disagree with the intervals.
            IntervalError: an interval is negative, duplicated, out of order, or shares
                an index with the interval before it.
        """
        intervals = self.intervals
        if self.bit_count > 0:
            assert self.bit_idx_first == intervals[0][0]
            assert self.bit_idx_last == intervals[-1][1]
        else:
            assert self.bit_idx_first == -1
            assert self.bit_idx_last == -1

        bit_count = 0
        prev = None
        for interval in intervals:
            start, end = interval
            if start > end:
                raise IntervalError(f"negative interval: {interval}")
            if prev is not None:
                if prev == interval:
                    raise IntervalError(f"duplicate interval: {prev}")
                # >= (not >): [5, 6], [6, 8] share index 6, which would be double counted.
                if prev[1] >= start:
                    raise IntervalError(
                        f"interval out of order or not merged: {prev[1]} >= {start} for {prev}->{interval}"
                    )
            bit_count += end - start + 1
            prev = interval
        assert bit_count == self.bit_count


def get_interval_bits(bit_intervals: "BitIntervals", bit_idx_first=0, bit_idx_last=None) -> List:
    """Return the set bit indexes of *bit_intervals* as an ascending list.

    Args:
        bit_intervals: object whose ``intervals`` is a sorted, merged list of
            inclusive ``[start, end]`` pairs.
        bit_idx_first: inclusive lower bound; only indexes ``>= bit_idx_first`` are
            returned. A falsy value (0 or None) means no lower bound.
        bit_idx_last: exclusive upper bound, with slice semantics like ``x[first:last]``;
            only indexes ``< bit_idx_last`` are returned. None means no upper bound.

    Each interval is clamped to the requested window and expanded with ``range``,
    instead of testing every index individually.
    """
    bit_lst: List[int] = []
    for start, end in bit_intervals.intervals:
        if bit_idx_last is not None and start >= bit_idx_last:
            # Intervals are sorted, so nothing later can fall inside the window.
            break
        if bit_idx_first:
            start = max(start, bit_idx_first)
        if bit_idx_last is not None:
            end = min(end, bit_idx_last - 1)
        if start <= end:
            bit_lst.extend(range(start, end + 1))
    return bit_lst


def get_interval_bits_set_v0(bit_intervals: "BitIntervals") -> set:
    """Return the set of indexes covered by the inclusive intervals of *bit_intervals*.

    Each ``range`` is iterated directly; no intermediate list is built per interval.
    """
    # FIXME: priority 1: stop using this, it's slow (get_interval_bits_set uses set.update)
    return {idx for start, end in bit_intervals.intervals for idx in range(start, end + 1)}


def get_interval_bits_set(bit_intervals: BitIntervals) -> Set[int]:
    """Expand every inclusive [start, end] interval into one set of bit indexes."""
    spans = (range(pair[0], pair[1] + 1) for pair in bit_intervals.intervals)
    return set().union(*spans)


def intervals_subtract_v0(bit_intervals_a: BitIntervals, bit_intervals_b: BitIntervals) -> BitIntervals:
    """Return the bits set in *bit_intervals_a* but not in *bit_intervals_b*.

    Set-based reference implementation: both operands are expanded to index sets,
    the sets are differenced, and the remainder is converted back with
    index_list_to_bitintervals.
    """
    # FIXME: this needs work, it may be very slow.
    remaining = get_interval_bits_set(bit_intervals_a).difference(get_interval_bits_set(bit_intervals_b))
    return index_list_to_bitintervals(remaining)


def intervals_subtract_v1(a_bit_intervals: BitIntervals, b_bit_intervals: BitIntervals) -> BitIntervals:
    """Return the bits set in *a_bit_intervals* but not in *b_bit_intervals* (a - b).

    Performs the subtraction directly on the two sorted, merged interval lists,
    without expanding to index sets, then coalesces touching result pieces.
    """
    a_intervals = a_bit_intervals.intervals
    b_intervals = b_bit_intervals.intervals
    result_intervals = []
    result_intervals_append = result_intervals.append

    a_intervals_len = len(a_intervals)
    b_intervals_len = len(b_intervals)
    i, j = 0, 0

    while i < a_intervals_len:
        a_start, a_end = a_intervals[i]
        while j < b_intervals_len:
            b_start, b_end = b_intervals[j]
            if b_end < a_start:
                # b lies entirely before what is left of a
                j += 1
                continue
            if b_start > a_end:
                # b lies entirely after a; it may still cut the next a interval
                break
            if b_start <= a_start and b_end >= a_end:
                # b covers the rest of a
                a_start = a_end + 1
                break
            elif b_start <= a_start:
                # b trims the left side of a
                a_start = b_end + 1
            elif b_end >= a_end:
                # b trims the right side of a; keep b (it may overlap the next a)
                result_intervals_append([a_start, b_start - 1])
                a_start = a_end + 1
                break
            else:
                # b punches a hole in the middle of a
                result_intervals_append([a_start, b_start - 1])
                a_start = b_end + 1
            j += 1
        if a_start <= a_end:
            result_intervals_append([a_start, a_end])
        i += 1

    if not result_intervals:
        return BitIntervals(0, [], -1, -1)

    merged_intervals = []
    merged_intervals_append = merged_intervals.append

    start, end = result_intervals[0]
    for current_start, current_end in result_intervals[1:]:
        if current_start > end + 1:
            merged_intervals_append([start, end])
            start, end = current_start, current_end
        else:
            end = max(end, current_end)
    merged_intervals_append([start, end])

    bit_count = sum(last - first + 1 for first, last in merged_intervals)
    bit_idx_first = merged_intervals[0][0]
    bit_idx_last = merged_intervals[-1][1]
    # bit_idx_last is a required positional argument; the constructor also validates.
    return BitIntervals(bit_count, merged_intervals, bit_idx_first, bit_idx_last, no_copy=True)


def intervals_subtract_v2(a_bit_intervals: BitIntervals, b_bit_intervals: BitIntervals) -> BitIntervals:
    """Return a - b computed by walking both sorted interval lists once.

    Each interval of *a_bit_intervals* is cut by the overlapping intervals of
    *b_bit_intervals*. The surviving pieces are then coalesced, and the bit
    count is accumulated during that pass.
    """
    cutters = b_bit_intervals.intervals
    n_cutters = len(cutters)
    pieces = []
    k = 0  # first b interval that may still overlap the current a interval

    for lo, hi in a_bit_intervals.intervals:
        # Skip b intervals that end before this a interval begins.
        while k < n_cutters and cutters[k][1] < lo:
            k += 1

        while k < n_cutters:
            cut_lo, cut_hi = cutters[k]
            if cut_lo > hi:
                break
            if cut_lo > lo:
                # keep the part of [lo, hi] that precedes this cut
                pieces.append([lo, cut_lo - 1])
            if cut_hi >= hi:
                # the cut reaches the end of this a interval; leave k for the next one
                lo = hi + 1
                break
            lo = cut_hi + 1
            k += 1

        if lo <= hi:
            pieces.append([lo, hi])

    if not pieces:
        return BitIntervals(0, [], -1, -1)

    merged = []
    total = 0
    run_lo, run_hi = pieces[0]
    for piece_lo, piece_hi in pieces[1:]:
        if piece_lo <= run_hi + 1:
            run_hi = max(run_hi, piece_hi)
            continue
        merged.append([run_lo, run_hi])
        total += run_hi - run_lo + 1
        run_lo, run_hi = piece_lo, piece_hi
    merged.append([run_lo, run_hi])
    total += run_hi - run_lo + 1

    result = BitIntervals(total, merged, merged[0][0], merged[-1][1], no_copy=True)
    result.validate()
    return result


def intervals_subtract_v3(a_bit_intervals: BitIntervals, b_bit_intervals: BitIntervals) -> BitIntervals:
    """Return a - b, emitting merged result intervals during the walk.

    An attempt to cut the merge interval process: instead of collecting pieces and
    merging them in a second pass (as v1/v2 in this file do), each surviving piece is
    folded into a pending ``[start, end]`` interval. The pending interval is flushed
    to ``final_intervals`` when the next piece does not touch it.

    The result is built without ``no_copy``, so the BitIntervals constructor
    deep-copies ``final_intervals``.
    """
    a_intervals = a_bit_intervals.intervals
    b_intervals = b_bit_intervals.intervals
    final_intervals = []
    final_intervals_append = final_intervals.append

    a_intervals_len = len(a_intervals)
    b_intervals_len = len(b_intervals)
    i, j = 0, 0

    # pending output interval; None until the first surviving piece is found
    start, end = None, None

    while i < a_intervals_len:
        a_start, a_end = a_intervals[i]
        while j < b_intervals_len:
            b_start, b_end = b_intervals[j]
            if b_end < a_start:
                # b lies entirely before what is left of a
                j += 1
                continue
            if b_start > a_end:
                # b lies entirely after a
                break
            if b_start <= a_start and b_end >= a_end:
                # b covers the rest of a
                a_start = a_end + 1
                break
            elif b_start <= a_start:
                # b trims the left side of a
                a_start = b_end + 1
            elif b_end >= a_end:
                # b trims the right side of a: emit [a_start, b_start - 1]
                # (a_start < b_start always holds in this branch)
                if a_start < b_start:
                    if start is None:
                        start, end = a_start, b_start - 1
                    else:
                        if a_start > end + 1:
                            final_intervals_append([start, end])
                            start, end = a_start, b_start - 1
                        else:
                            end = b_start - 1
                a_start = a_end + 1
                # j is not advanced: b may also overlap the next a interval
                break
            else:
                # b punches a hole inside a: emit the part before the hole
                if start is None:
                    start, end = a_start, b_start - 1
                else:
                    if a_start > end + 1:
                        final_intervals_append([start, end])
                        start, end = a_start, b_start - 1
                    else:
                        end = b_start - 1
                a_start = b_end + 1
            j += 1
        # emit whatever is left of this a interval
        if a_start <= a_end:
            if start is None:
                start, end = a_start, a_end
            else:
                if a_start > end + 1:
                    final_intervals_append([start, end])
                    start, end = a_start, a_end
                else:
                    end = a_end
        i += 1

    # flush the last pending interval
    if start is not None:
        final_intervals_append([start, end])

    if not final_intervals:
        return BitIntervals(0, [], -1, -1)

    bit_count = sum(end - start + 1 for start, end in final_intervals)
    bit_idx_first = final_intervals[0][0]
    bit_idx_last = final_intervals[-1][1]
    bit_intervals = BitIntervals(bit_count, final_intervals, bit_idx_first, bit_idx_last)
    bit_intervals.validate()
    return bit_intervals


def intervals_subtract_v4(a_bit_intervals: BitIntervals, b_bit_intervals: BitIntervals) -> BitIntervals:
    """Return a - b, emitting merged result intervals and bit_count during the walk.

    Another attempt to cut the merge interval process, refining the walkthrough logic:
    same walk as intervals_subtract_v3, but ``bit_count`` is accumulated each time a
    pending interval is flushed, instead of being summed at the end.

    The result is built without ``no_copy``, so the BitIntervals constructor
    deep-copies ``final_intervals``.
    """
    a_intervals = a_bit_intervals.intervals
    b_intervals = b_bit_intervals.intervals
    final_intervals = []
    final_intervals_append = final_intervals.append

    a_intervals_len = len(a_intervals)
    b_intervals_len = len(b_intervals)
    i, j = 0, 0

    # pending output interval; None until the first surviving piece is found
    start, end = None, None
    bit_count = 0

    while i < a_intervals_len:
        a_start, a_end = a_intervals[i]
        while j < b_intervals_len:
            b_start, b_end = b_intervals[j]
            if b_end < a_start:
                # b lies entirely before what is left of a
                j += 1
                continue
            if b_start > a_end:
                # b lies entirely after a
                break
            if b_start <= a_start and b_end >= a_end:
                # b covers the rest of a
                a_start = a_end + 1
                break
            elif b_start <= a_start:
                # b trims the left side of a
                a_start = b_end + 1
            elif b_end >= a_end:
                # b trims the right side of a: emit [a_start, b_start - 1]
                if a_start < b_start:
                    if start is None:
                        start, end = a_start, b_start - 1
                    else:
                        if a_start > end + 1:
                            final_intervals_append([start, end])
                            bit_count += end - start + 1
                            start, end = a_start, b_start - 1
                        else:
                            end = b_start - 1
                a_start = a_end + 1
                # j is not advanced: b may also overlap the next a interval
                break
            else:
                # b punches a hole inside a: emit the part before the hole
                if start is None:
                    start, end = a_start, b_start - 1
                else:
                    if a_start > end + 1:
                        final_intervals_append([start, end])
                        bit_count += end - start + 1
                        start, end = a_start, b_start - 1
                    else:
                        end = b_start - 1
                a_start = b_end + 1
            j += 1
        # emit whatever is left of this a interval
        if a_start <= a_end:
            if start is None:
                start, end = a_start, a_end
            else:
                if a_start > end + 1:
                    final_intervals_append([start, end])
                    bit_count += end - start + 1
                    start, end = a_start, a_end
                else:
                    end = a_end
        i += 1

    # flush the last pending interval
    if start is not None:
        final_intervals_append([start, end])
        bit_count += end - start + 1

    if not final_intervals:
        return BitIntervals(0, [], -1, -1)

    bit_idx_first = final_intervals[0][0]
    bit_idx_last = final_intervals[-1][1]
    bit_intervals = BitIntervals(bit_count, final_intervals, bit_idx_first, bit_idx_last)
    bit_intervals.validate()
    return bit_intervals


def _get_combined_intervals_until_all_comsumed(abi: BitIntervals, bbi: BitIntervals) -> Generator[IntervalList, None, None]:
    alst = abi.intervals
    blst = bbi.intervals
    aidx = 0
    bidx = 0
    while True:
        try:
            amin, amax = alst[aidx]
        except IndexError:
            amin, amax = None, None
        try:
            bmin, bmax = blst[bidx]
        except IndexError:
            bmin, bmax = None, None

        if amin is None and bmin is None:
            break

        if bmin is None or (amin is not None and amin <= bmin):
            r = [amin, amax]
            aidx += 1
        else:
            r = [bmin, bmax]
            bidx += 1
        yield r


def _get_combined_intervals_until_overlap_impossible(abi: BitIntervals, bbi: BitIntervals) -> Generator[IntervalList, None, None]:
    alst = abi.intervals
    blst = bbi.intervals
    aidx = 0
    bidx = 0
    while True:
        try:
            amin, amax = alst[aidx]
        except IndexError:
            amin, amax = None, alst[-1][1] if len(alst) else -1
        try:
            bmin, bmax = blst[bidx]
        except IndexError:
            bmin, bmax = None, blst[-1][1] if len(blst) else -1

        if (amin is None and bmin is None) or (amin is None and bmin > amax) or (bmin is None and amin > bmax):
            break

        if bmin is None or (amin is not None and amin <= bmin):
            r = [amin, amax]
            aidx += 1
        else:
            r = [bmin, bmax]
            bidx += 1
        yield r


def intervals_or_v0(abi: BitIntervals, bbi: BitIntervals) -> BitIntervals:
    """Return the union (logical OR) of two BitIntervals.

    Consumes both interval lists in ascending start order and coalesces every
    interval that overlaps or touches the last output interval, counting bits
    as they are added.
    """
    merged = []
    bit_count = 0
    for lo, hi in _get_combined_intervals_until_all_comsumed(abi, bbi):
        if merged and lo <= merged[-1][1] + 1:
            run_lo, run_hi = merged[-1]
            grown = max(run_hi, hi)
            bit_count += grown - run_hi
            merged[-1] = [run_lo, grown]
        else:
            merged.append([lo, hi])
            bit_count += hi - lo + 1

    if merged:
        first, last = merged[0][0], merged[-1][1]
    else:
        # both inputs were empty
        first = last = -1
    return BitIntervals(bit_count, merged, first, last, no_copy=True)


def intervals_or_v4(abi: BitIntervals, bbi: BitIntervals) -> BitIntervals:
    """Return the union (logical OR) of two BitIntervals.

    Fixed the over-merge problem that occurred in v1 and v2. Walks both sorted
    interval lists with two indexes while carrying one pending ``[start, end]``
    interval. Whenever a new pending interval is chosen, the previous one is flushed
    through ``append_interval``, which coalesces it with the last output interval
    when the two overlap or touch.
    """
    a_intervals = abi.intervals
    b_intervals = bbi.intervals
    final_intervals = []
    final_intervals_append = final_intervals.append
    i, j = 0, 0
    a_intervals_len = len(a_intervals)
    b_intervals_len = len(b_intervals)
    # pending interval not yet written to final_intervals; None until the first is chosen
    start, end = None, None

    def append_interval(start, end):
        # Flush a pending interval: widen the last output interval if [start, end]
        # overlaps or is adjacent to it, otherwise append a new interval.
        # A start of None (nothing pending yet) is ignored.
        if start is not None:
            if final_intervals:
                last_start, last_end = final_intervals[-1]
                if start <= last_end + 1:
                    final_intervals[-1][1] = max(last_end, end)
                else:
                    final_intervals_append([start, end])
            else:
                final_intervals_append([start, end])

    while i < a_intervals_len and j < b_intervals_len:
        a_start, a_end = a_intervals[i]
        b_start, b_end = b_intervals[j]

        if a_end < b_start:
            # the a interval lies entirely before the b interval
            append_interval(start, end)
            start, end = a_start, a_end
            i += 1
        elif b_end < a_start:
            # the b interval lies entirely before the a interval
            append_interval(start, end)
            start, end = b_start, b_end
            j += 1
        else:
            # overlap: pending becomes their union; advance the side that ends first,
            # since the other side may still overlap the next interval
            append_interval(start, end)
            start = min(a_start, b_start)
            end = max(a_end, b_end)
            if a_end > b_end:
                j += 1
            elif a_end < b_end:
                i += 1
            else:
                i += 1
                j += 1

    # drain whichever list still has intervals
    while i < a_intervals_len:
        a_start, a_end = a_intervals[i]
        append_interval(start, end)
        start, end = a_start, a_end
        i += 1

    while j < b_intervals_len:
        b_start, b_end = b_intervals[j]
        append_interval(start, end)
        start, end = b_start, b_end
        j += 1

    # flush the final pending interval
    append_interval(start, end)

    if not final_intervals:
        return BitIntervals(0, [], -1, -1)

    bit_count = sum(end - start + 1 for start, end in final_intervals)
    bit_idx_first = final_intervals[0][0]
    bit_idx_last = final_intervals[-1][1]
    bit_intervals = BitIntervals(bit_count, final_intervals, bit_idx_first, bit_idx_last, no_copy=True)
    bit_intervals.validate()
    return bit_intervals


def intervals_or_v5(abi: BitIntervals, bbi: BitIntervals) -> BitIntervals:
    """Return the union (logical OR) of two BitIntervals, counting bits on the fly.

    Repeatedly takes the interval with the smaller start, absorbs each list's
    following intervals that start at or before ``end + 1``, and folds the result
    into the output. Leftovers from whichever list is not exhausted are folded in
    one at a time.
    """
    left = abi.intervals
    right = bbi.intervals
    n_left = len(left)
    n_right = len(right)
    out = []
    count = 0
    li = ri = 0

    def fold(lo, hi, total):
        # Widen the last output interval when [lo, hi] overlaps or touches it; otherwise append.
        if out:
            tail = out[-1]
            if tail[1] + 1 >= lo:
                widened = max(tail[1], hi)
                total += widened - tail[1]
                tail[1] = widened
                return total
        out.append([lo, hi])
        return total + hi - lo + 1

    while li < n_left and ri < n_right:
        if left[li][0] <= right[ri][0]:
            lo, hi = left[li]
            li += 1
        else:
            lo, hi = right[ri]
            ri += 1

        while li < n_left and left[li][0] <= hi + 1:
            hi = max(hi, left[li][1])
            li += 1
        while ri < n_right and right[ri][0] <= hi + 1:
            hi = max(hi, right[ri][1])
            ri += 1

        count = fold(lo, hi, count)

    for interval in left[li:]:
        count = fold(interval[0], interval[1], count)
    for interval in right[ri:]:
        count = fold(interval[0], interval[1], count)

    if not out:
        return BitIntervals(0, [], -1, -1)

    result = BitIntervals(count, out, out[0][0], out[-1][1], no_copy=True)
    result.validate()
    return result


def intervals_or_v6(abi: BitIntervals, bbi: BitIntervals) -> BitIntervals:
    """Return the union (logical OR) of two BitIntervals.

    Optimized version of intervals_or_v4 that cuts the redundant overlap and margin
    checks by merging straight into the output list.

    Every interval placed in the output is a fresh ``[start, end]`` list. The last
    output interval is widened in place while merging, and the result is built with
    ``no_copy=True``. Appending an input interval object itself would therefore let
    that widening rewrite the caller's ``abi``/``bbi`` intervals and alias them into
    the result.
    """
    a_intervals = abi.intervals
    b_intervals = bbi.intervals
    final_intervals = []
    final_intervals_append = final_intervals.append
    i, j = 0, 0
    a_intervals_len = len(a_intervals)
    b_intervals_len = len(b_intervals)

    while i < a_intervals_len and j < b_intervals_len:
        a_start, a_end = a_intervals[i]
        b_start, b_end = b_intervals[j]
        if a_end < b_start:
            # the a interval lies entirely before the b interval
            if final_intervals and final_intervals[-1][1] >= a_start - 1:
                final_intervals[-1][1] = max(final_intervals[-1][1], a_end)
            else:
                final_intervals_append([a_start, a_end])
            i += 1
        elif b_end < a_start:
            # the b interval lies entirely before the a interval
            if final_intervals and final_intervals[-1][1] >= b_start - 1:
                final_intervals[-1][1] = max(final_intervals[-1][1], b_end)
            else:
                final_intervals_append([b_start, b_end])
            j += 1
        else:
            # overlap: absorb every following interval of either list that starts inside the union
            start = min(a_start, b_start)
            end = max(a_end, b_end)
            while i < a_intervals_len and a_intervals[i][0] <= end:
                end = max(end, a_intervals[i][1])
                i += 1
            while j < b_intervals_len and b_intervals[j][0] <= end:
                end = max(end, b_intervals[j][1])
                j += 1
            if final_intervals and final_intervals[-1][1] >= start - 1:
                final_intervals[-1][1] = max(final_intervals[-1][1], end)
            else:
                final_intervals_append([start, end])

    # drain whichever list still has intervals (copying each one)
    while i < a_intervals_len:
        a_start, a_end = a_intervals[i]
        if final_intervals and final_intervals[-1][1] >= a_start - 1:
            final_intervals[-1][1] = max(final_intervals[-1][1], a_end)
        else:
            final_intervals_append([a_start, a_end])
        i += 1

    while j < b_intervals_len:
        b_start, b_end = b_intervals[j]
        if final_intervals and final_intervals[-1][1] >= b_start - 1:
            final_intervals[-1][1] = max(final_intervals[-1][1], b_end)
        else:
            final_intervals_append([b_start, b_end])
        j += 1

    if not final_intervals:
        return BitIntervals(0, [], -1, -1)

    bit_count = sum(end - start + 1 for start, end in final_intervals)
    bit_idx_first = final_intervals[0][0]
    bit_idx_last = final_intervals[-1][1]
    bit_intervals = BitIntervals(bit_count, final_intervals, bit_idx_first, bit_idx_last, no_copy=True)
    bit_intervals.validate()
    return bit_intervals


def _merge_or_append_interval_v7(start, end, final_intervals, final_intervals_append, bit_count):
    """Deal with the interval merging boundary for v7 and calculate the bit_count"""
    if final_intervals and final_intervals[-1][1] >= start - 1:
        previous_end = final_intervals[-1][1]
        final_intervals[-1][1] = max(final_intervals[-1][1], end)
        bit_count += final_intervals[-1][1] - previous_end
    else:
        final_intervals_append([start, end])
        bit_count += end - start + 1
    return bit_count


def intervals_or_v7(abi: BitIntervals, bbi: BitIntervals) -> BitIntervals:
    """Return the union (logical OR) of two BitIntervals.

    Try to take the merging function outside the or function to see if it could be
    faster. The walk is the same as in intervals_or_v5, but merging is delegated to
    the module-level _merge_or_append_interval_v7, which also maintains bit_count.
    """
    a_intervals = abi.intervals
    b_intervals = bbi.intervals
    final_intervals = []
    final_intervals_append = final_intervals.append
    i, j = 0, 0
    a_intervals_len = len(a_intervals)
    b_intervals_len = len(b_intervals)
    bit_count = 0

    while i < a_intervals_len and j < b_intervals_len:
        # take the interval with the smaller start (ties go to abi)
        if a_intervals[i][0] <= b_intervals[j][0]:
            start, end = a_intervals[i]
            i += 1
        else:
            start, end = b_intervals[j]
            j += 1

        # absorb following intervals that overlap or touch [start, end] (one pass per list);
        # anything missed is merged into the tail by _merge_or_append_interval_v7 later
        while i < a_intervals_len and a_intervals[i][0] <= end + 1:
            end = max(end, a_intervals[i][1])
            i += 1
        while j < b_intervals_len and b_intervals[j][0] <= end + 1:
            end = max(end, b_intervals[j][1])
            j += 1

        bit_count = _merge_or_append_interval_v7(start, end, final_intervals, final_intervals_append, bit_count)

    # one list is exhausted; fold in the rest of the other
    while i < a_intervals_len:
        bit_count = _merge_or_append_interval_v7(
            a_intervals[i][0], a_intervals[i][1], final_intervals, final_intervals_append, bit_count
        )
        i += 1

    while j < b_intervals_len:
        bit_count = _merge_or_append_interval_v7(
            b_intervals[j][0], b_intervals[j][1], final_intervals, final_intervals_append, bit_count
        )
        j += 1

    if not final_intervals:
        return BitIntervals(0, [], -1, -1)

    # bit_count = sum(end - start + 1 for start, end in final_intervals)
    bit_idx_first = final_intervals[0][0]
    bit_idx_last = final_intervals[-1][1]
    bit_intervals = BitIntervals(bit_count, final_intervals, bit_idx_first, bit_idx_last, no_copy=True)
    bit_intervals.validate()
    return bit_intervals


def logical_not(bitmask: "LiteBitmask") -> "LiteBitmask":
    """Return a new LiteBitmask of the same length with every bit inverted.

    The complement is built from the gaps before, between, and after the set
    intervals of ``bitmask.bit_intervals``.

    Raises:
        IndexError: if the last set interval extends past ``bitmask.length``.
    """
    length = bitmask.length
    if bitmask.all():
        return LiteBitmask.zeros(length)
    if bitmask.none():
        return LiteBitmask.ones(length)

    set_intervals = bitmask.bit_intervals.intervals
    if set_intervals and set_intervals[-1][1] >= length:
        raise IndexError("last interval exceeds specified length")

    not_intervals = []
    gap_start = 0
    for set_start, set_end in set_intervals:
        gap_end = set_start - 1
        # Skip empty gaps: a set interval starting at 0, or two touching set intervals.
        if gap_start <= gap_end:
            not_intervals.append([gap_start, gap_end])
        gap_start = set_end + 1
    if gap_start < length:
        not_intervals.append([gap_start, length - 1])
    return LiteBitmask.zeros_and_set_intervals(length, not_intervals)


def intervals_xor_set_v0(a_bit_intervals: BitIntervals, b_bit_intervals: BitIntervals) -> BitIntervals:
    """Return the bitwise XOR of two interval sets (bits set in exactly one).

    Reference implementation: both inputs are expanded with
    ``get_interval_bits_set`` into Python sets (presumably of bit indices),
    their symmetric difference is taken, and the result is converted back with
    ``index_list_to_bitintervals_v0``.
    """
    # FIXME: possibly use (a-b)|(b-a), but probably better to use iteration of one list (b?) and comparison with an iteration of
    # the other list (a?).  performance comparisons needed for the size of the two intervals.
    # FIXME: this needs work, it may be very slow.
    set_a = get_interval_bits_set(a_bit_intervals)
    set_b = get_interval_bits_set(b_bit_intervals)
    set_c = set_a.symmetric_difference(set_b)
    result = index_list_to_bitintervals_v0(set_c)
    return result


def intervals_xor_set_v1(a_bit_intervals: BitIntervals, b_bit_intervals: BitIntervals) -> BitIntervals:
    """Set-based XOR that converts back with ``index_list_to_bitintervals_v1``.

    Identical to the other ``intervals_xor_set_*`` variants except for the
    converter used to rebuild intervals, so that converter can be exercised.
    """
    bits_in_a = get_interval_bits_set(a_bit_intervals)
    bits_in_b = get_interval_bits_set(b_bit_intervals)
    exactly_one = bits_in_a.symmetric_difference(bits_in_b)
    return index_list_to_bitintervals_v1(exactly_one)


def intervals_xor_set_v2(a_bit_intervals: BitIntervals, b_bit_intervals: BitIntervals) -> BitIntervals:
    """Set-based XOR that converts back with ``index_list_to_bitintervals_v2``.

    Identical to the other ``intervals_xor_set_*`` variants except for the
    converter used to rebuild intervals, so that converter can be exercised.
    """
    bits_in_a = get_interval_bits_set(a_bit_intervals)
    bits_in_b = get_interval_bits_set(b_bit_intervals)
    return index_list_to_bitintervals_v2(bits_in_a.symmetric_difference(bits_in_b))


def intervals_xor_v0(a_bit_intervals: BitIntervals, b_bit_intervals: BitIntervals) -> BitIntervals:
    """Directly do xor on the intervals to avoid data type change.

    Computes the bitwise XOR (bits set in exactly one input) of two ordered,
    merged interval lists in a single sweep, without expanding them into sets
    of bit indices. When two intervals overlap, the tail of the longer one
    (past the overlap) becomes the current interval for its side. That tail
    is then compared against the next interval of the other side. The input
    lists are not modified.

    :param a_bit_intervals: first operand.
    :param b_bit_intervals: second operand.
    :returns: new ``BitIntervals`` with merged intervals; ``bit_idx_first`` and
        ``bit_idx_last`` are -1 when the XOR is empty.
    """
    a_intervals = a_bit_intervals.intervals
    b_intervals = b_bit_intervals.intervals
    a_len = len(a_intervals)
    b_len = len(b_intervals)

    result_intervals = []
    i, j = 0, 0
    # Current (possibly trimmed) interval on each side; None once exhausted.
    a_cur = a_intervals[0] if a_len else None
    b_cur = b_intervals[0] if b_len else None

    while a_cur is not None and b_cur is not None:
        a_start, a_end = a_cur
        b_start, b_end = b_cur

        if a_end < b_start:
            result_intervals.append([a_start, a_end])
            i += 1
            a_cur = a_intervals[i] if i < a_len else None
        elif b_end < a_start:
            result_intervals.append([b_start, b_end])
            j += 1
            b_cur = b_intervals[j] if j < b_len else None
        else:
            # Leading part of the overlap region covered by only one side.
            if a_start < b_start:
                result_intervals.append([a_start, b_start - 1])
            elif b_start < a_start:
                result_intervals.append([b_start, a_start - 1])
            # Drop the common part; the longer interval's tail stays current so
            # it is compared against the other side's next interval.
            if a_end > b_end:
                a_cur = (b_end + 1, a_end)
            else:
                i += 1
                a_cur = a_intervals[i] if i < a_len else None
            if b_end > a_end:
                b_cur = (a_end + 1, b_end)
            else:
                j += 1
                b_cur = b_intervals[j] if j < b_len else None

    # At most one side has intervals left; all of them belong to the XOR.
    if a_cur is not None:
        result_intervals.append([a_cur[0], a_cur[1]])
        result_intervals.extend([start, end] for start, end in a_intervals[i + 1:])
    if b_cur is not None:
        result_intervals.append([b_cur[0], b_cur[1]])
        result_intervals.extend([start, end] for start, end in b_intervals[j + 1:])

    # Pieces are produced in ascending order but may touch; merge them.
    merged_intervals = []
    for start, end in result_intervals:
        if not merged_intervals or merged_intervals[-1][1] < start - 1:
            merged_intervals.append([start, end])
        else:
            merged_intervals[-1][1] = max(merged_intervals[-1][1], end)

    bit_count = sum(interval[1] - interval[0] + 1 for interval in merged_intervals)
    bit_idx_first = merged_intervals[0][0] if merged_intervals else -1
    bit_idx_last = merged_intervals[-1][1] if merged_intervals else -1

    return BitIntervals(bit_count, merged_intervals, bit_idx_first, bit_idx_last)


def intervals_xor_v4(a_bit_intervals: BitIntervals, b_bit_intervals: BitIntervals) -> BitIntervals:
    """XOR two interval lists by sweeping shallow copies of them.

    When an interval is only partly consumed, its remaining tail is written
    back into the *copied* list at the same position. The caller's lists are
    left untouched. The pieces covered by exactly one side are then merged
    where they touch.

    :returns: new, validated ``BitIntervals``; ``BitIntervals(0, [], -1, -1)``
        when the XOR is empty.
    """
    lhs = a_bit_intervals.intervals.copy()
    rhs = b_bit_intervals.intervals.copy()
    lhs_len, rhs_len = len(lhs), len(rhs)
    pieces = []
    add_piece = pieces.append
    li = ri = 0

    while li < lhs_len and ri < rhs_len:
        l_lo, l_hi = lhs[li]
        r_lo, r_hi = rhs[ri]
        if l_hi < r_lo:
            add_piece([l_lo, l_hi])
            li += 1
            continue
        if r_hi < l_lo:
            add_piece([r_lo, r_hi])
            ri += 1
            continue
        # Overlap: emit whichever side starts first up to the other's start.
        if l_lo < r_lo:
            add_piece([l_lo, r_lo - 1])
            l_lo = r_lo
        if r_lo < l_lo:
            add_piece([r_lo, l_lo - 1])
        # Keep the longer side's tail in its copy; advance the other side.
        if l_hi > r_hi:
            lhs[li] = [r_hi + 1, l_hi]
            ri += 1
        elif l_hi < r_hi:
            rhs[ri] = [l_hi + 1, r_hi]
            li += 1
        else:
            li += 1
            ri += 1

    for lo, hi in lhs[li:]:
        add_piece([lo, hi])
    for lo, hi in rhs[ri:]:
        add_piece([lo, hi])

    if not pieces:
        return BitIntervals(0, [], -1, -1)

    merged = []
    cur_lo, cur_hi = pieces[0]
    for lo, hi in pieces[1:]:
        if lo > cur_hi + 1:
            merged.append([cur_lo, cur_hi])
            cur_lo, cur_hi = lo, hi
        else:
            cur_hi = max(cur_hi, hi)
    merged.append([cur_lo, cur_hi])

    bit_count = sum(hi - lo + 1 for lo, hi in merged)
    bit_intervals = BitIntervals(bit_count, merged, merged[0][0], merged[-1][1], no_copy=True)
    bit_intervals.validate()
    return bit_intervals


def intervals_xor_v5(a_bit_intervals: BitIntervals, b_bit_intervals: BitIntervals) -> BitIntervals:
    """Using temp intervals to avoid using copy for fast xor function

    Computes the bitwise XOR (bits set in exactly one input) of two ordered,
    merged interval lists with one merge-style sweep. When an interval is only
    partly consumed, its unconsumed remainder is held in ``a_temp_interval`` or
    ``b_temp_interval`` instead of being written back into the list. The input
    lists are therefore neither modified nor copied.

    :param a_bit_intervals: first operand (its ``intervals`` are only read).
    :param b_bit_intervals: second operand (its ``intervals`` are only read).
    :returns: new ``BitIntervals`` (merged, bit count computed, ``validate()``
        called); ``BitIntervals(0, [], -1, -1)`` when the XOR is empty.
    """
    a_intervals = a_bit_intervals.intervals
    b_intervals = b_bit_intervals.intervals
    result_intervals = []
    result_intervals_append = result_intervals.append

    a_intervals_len = len(a_intervals)
    b_intervals_len = len(b_intervals)
    i, j = 0, 0

    # Remainder of a_intervals[i] / b_intervals[j] after a partial overlap;
    # None means "use the list entry as-is".
    a_temp_interval = None
    b_temp_interval = None

    while i < a_intervals_len and j < b_intervals_len:
        if a_temp_interval is None:
            a_start, a_end = a_intervals[i]
        else:
            a_start, a_end = a_temp_interval

        if b_temp_interval is None:
            b_start, b_end = b_intervals[j]
        else:
            b_start, b_end = b_temp_interval

        if a_end < b_start:
            # a lies entirely before b: all of it is XOR-set.
            result_intervals_append([a_start, a_end])
            a_temp_interval = None
            i += 1
        elif b_end < a_start:
            # b lies entirely before a: all of it is XOR-set.
            result_intervals_append([b_start, b_end])
            b_temp_interval = None
            j += 1
        else:
            # Overlap: emit the leading part covered by only one side, then
            # align both starts so only the common part remains.
            if a_start < b_start:
                result_intervals_append([a_start, b_start - 1])
                a_start = b_start
            if b_start < a_start:
                result_intervals_append([b_start, a_start - 1])
                b_start = a_start

            # Drop the common part; the longer side keeps its tail as a temp
            # interval and the shorter side advances.
            if a_end > b_end:
                a_temp_interval = [b_end + 1, a_end]
                b_temp_interval = None
                j += 1
            elif a_end < b_end:
                b_temp_interval = [a_end + 1, b_end]
                a_temp_interval = None
                i += 1
            else:
                a_temp_interval = None
                b_temp_interval = None
                i += 1
                j += 1

    # Only one side can have intervals left; the first may be a temp remainder.
    while i < a_intervals_len:
        if a_temp_interval is None:
            start, end = a_intervals[i]
        else:
            start, end = a_temp_interval
            a_temp_interval = None
        result_intervals_append([start, end])
        i += 1

    while j < b_intervals_len:
        if b_temp_interval is None:
            start, end = b_intervals[j]
        else:
            start, end = b_temp_interval
            b_temp_interval = None
        result_intervals_append([start, end])
        j += 1

    if not result_intervals:
        return BitIntervals(0, [], -1, -1)

    # Pieces are in ascending order but may touch: merge them and count bits
    # in the same pass.
    merged_intervals = []
    merged_intervals_append = merged_intervals.append
    bit_count = 0
    start, end = result_intervals[0]
    result_intervals_len = len(result_intervals)
    for i in range(1, result_intervals_len):
        current_start, current_end = result_intervals[i]
        if current_start > end + 1:
            merged_intervals_append([start, end])
            bit_count += end - start + 1
            start, end = current_start, current_end
        else:
            end = max(end, current_end)

    merged_intervals_append([start, end])
    bit_count += end - start + 1
    bit_idx_first = merged_intervals[0][0]
    bit_idx_last = merged_intervals[-1][1]
    bit_intervals = BitIntervals(bit_count, merged_intervals, bit_idx_first, bit_idx_last, no_copy=True)
    bit_intervals.validate()
    return bit_intervals


def intervals_xor_v6(a_bit_intervals: BitIntervals, b_bit_intervals: BitIntervals) -> BitIntervals:
    """The usage temp intervals could abandon the temp array of merged intervals

    Same sweep as ``intervals_xor_v5``, but every XOR piece is merged on the
    fly into an open run ``[start, end]`` instead of being collected into a
    temporary list and merged afterwards. A run is flushed into
    ``final_intervals`` (and added to ``bit_count``) only when the next piece
    does not touch it. The input lists are only read.

    :returns: new ``BitIntervals`` (``validate()`` called);
        ``BitIntervals(0, [], -1, -1)`` when the XOR is empty.
    """
    a_intervals = a_bit_intervals.intervals
    b_intervals = b_bit_intervals.intervals
    final_intervals = []
    final_intervals_append = final_intervals.append

    a_intervals_len = len(a_intervals)
    b_intervals_len = len(b_intervals)
    i, j = 0, 0

    # Unconsumed remainder of the current interval on each side (None = whole).
    a_temp_interval = None
    b_temp_interval = None
    # Open run being accumulated; None until the first piece is seen.
    start, end = None, None
    bit_count = 0

    while i < a_intervals_len and j < b_intervals_len:
        if a_temp_interval is None:
            a_start, a_end = a_intervals[i]
        else:
            a_start, a_end = a_temp_interval

        if b_temp_interval is None:
            b_start, b_end = b_intervals[j]
        else:
            b_start, b_end = b_temp_interval

        if a_end < b_start:
            # Whole a piece is XOR-set: start, extend or flush-and-restart the run.
            if start is None:
                start, end = a_start, a_end
            else:
                if a_start > end + 1:
                    final_intervals_append([start, end])
                    bit_count += end - start + 1
                    start, end = a_start, a_end
                else:
                    # Pieces arrive in ascending order, so a_end extends the run.
                    end = a_end
            a_temp_interval = None
            i += 1
        elif b_end < a_start:
            # Whole b piece is XOR-set.
            if start is None:
                start, end = b_start, b_end
            else:
                if b_start > end + 1:
                    final_intervals_append([start, end])
                    bit_count += end - start + 1
                    start, end = b_start, b_end
                else:
                    end = b_end
            b_temp_interval = None
            j += 1
        else:
            # Overlap: the leading part covered only by a ...
            if a_start < b_start:
                if start is None:
                    start, end = a_start, b_start - 1
                else:
                    if a_start > end + 1:
                        final_intervals_append([start, end])
                        bit_count += end - start + 1
                        start, end = a_start, b_start - 1
                    else:
                        end = b_start - 1
                a_start = b_start
            # ... or only by b is XOR-set.
            if b_start < a_start:
                if start is None:
                    start, end = b_start, a_start - 1
                else:
                    if b_start > end + 1:
                        final_intervals_append([start, end])
                        bit_count += end - start + 1
                        start, end = b_start, a_start - 1
                    else:
                        end = a_start - 1
                b_start = a_start

            # Drop the common part; the longer side keeps its tail.
            if a_end > b_end:
                a_temp_interval = [b_end + 1, a_end]
                b_temp_interval = None
                j += 1
            elif a_end < b_end:
                b_temp_interval = [a_end + 1, b_end]
                a_temp_interval = None
                i += 1
            else:
                a_temp_interval = None
                b_temp_interval = None
                i += 1
                j += 1

    # Leftover intervals of a (first one possibly a temp remainder).
    while i < a_intervals_len:
        if a_temp_interval is None:
            start_interval, end_interval = a_intervals[i]
        else:
            start_interval, end_interval = a_temp_interval
            a_temp_interval = None
        if start is None:
            start, end = start_interval, end_interval
        else:
            if start_interval > end + 1:
                final_intervals_append([start, end])
                bit_count += end - start + 1
                start, end = start_interval, end_interval
            else:
                end = end_interval
        i += 1

    # Leftover intervals of b (first one possibly a temp remainder).
    while j < b_intervals_len:
        if b_temp_interval is None:
            start_interval, end_interval = b_intervals[j]
        else:
            start_interval, end_interval = b_temp_interval
            b_temp_interval = None
        if start is None:
            start, end = start_interval, end_interval
        else:
            if start_interval > end + 1:
                final_intervals_append([start, end])
                bit_count += end - start + 1
                start, end = start_interval, end_interval
            else:
                end = end_interval
        j += 1

    # Flush the last open run.
    if start is not None:
        final_intervals_append([start, end])
        bit_count += end - start + 1

    if not final_intervals:
        return BitIntervals(0, [], -1, -1)

    bit_idx_first = final_intervals[0][0]
    bit_idx_last = final_intervals[-1][1]
    bit_intervals = BitIntervals(bit_count, final_intervals, bit_idx_first, bit_idx_last, no_copy=True)
    bit_intervals.validate()
    return bit_intervals


def handle_non_overlapping(start, end, bit_count, interval_start, interval_end, final_intervals_append):
    """Fold one XOR piece into the open run ``[start, end]``.

    ``start is None`` means no run is open yet, and the piece opens one. A
    piece that touches the run replaces its end. Otherwise the run is flushed
    through ``final_intervals_append``, its length is added to ``bit_count``,
    and the piece becomes the new run.

    :returns: the updated ``(start, end, bit_count)``.
    """
    if start is not None and interval_start <= end + 1:
        # Adjacent/overlapping piece: extend the current run.
        return start, interval_end, bit_count
    if start is not None:
        final_intervals_append([start, end])
        bit_count += end - start + 1
    return interval_start, interval_end, bit_count


def process_remaining_intervals(start, end, bit_count, intervals, idx, temp_interval, final_intervals_append):
    """Fold ``intervals[idx:]`` into the open run after the main sweep.

    When ``temp_interval`` is not None it stands in for ``intervals[idx]``
    (the partly consumed interval at that position). Every piece goes through
    ``handle_non_overlapping``.

    :returns: the updated ``(start, end, bit_count)``.
    """
    for pos in range(idx, len(intervals)):
        if pos == idx and temp_interval is not None:
            piece_start, piece_end = temp_interval
        else:
            piece_start, piece_end = intervals[pos]
        start, end, bit_count = handle_non_overlapping(
            start, end, bit_count, piece_start, piece_end, final_intervals_append
        )
    return start, end, bit_count


def intervals_xor_v7(a_bit_intervals: BitIntervals, b_bit_intervals: BitIntervals) -> BitIntervals:
    """Take out similar conditions and boundary check as a function, might be slower

    Same algorithm as ``intervals_xor_v6``. The repeated "start / extend /
    flush the open run" logic is delegated to ``handle_non_overlapping``, and
    the leftover intervals are delegated to ``process_remaining_intervals``.
    The input lists are only read.

    :returns: new ``BitIntervals`` (``validate()`` called);
        ``BitIntervals(0, [], -1, -1)`` when the XOR is empty.
    """
    a_intervals = a_bit_intervals.intervals
    b_intervals = b_bit_intervals.intervals
    final_intervals = []
    final_intervals_append = final_intervals.append

    a_intervals_len = len(a_intervals)
    b_intervals_len = len(b_intervals)
    i, j = 0, 0

    # Unconsumed remainder of the current interval on each side (None = whole).
    a_temp_interval = None
    b_temp_interval = None
    # Open run being accumulated; None until the first piece is seen.
    start, end = None, None
    bit_count = 0

    while i < a_intervals_len and j < b_intervals_len:
        if a_temp_interval is None:
            a_start, a_end = a_intervals[i]
        else:
            a_start, a_end = a_temp_interval

        if b_temp_interval is None:
            b_start, b_end = b_intervals[j]
        else:
            b_start, b_end = b_temp_interval

        if a_end < b_start:
            # Whole a piece is XOR-set.
            start, end, bit_count = handle_non_overlapping(start, end, bit_count, a_start, a_end, final_intervals_append)
            a_temp_interval = None
            i += 1
        elif b_end < a_start:
            # Whole b piece is XOR-set.
            start, end, bit_count = handle_non_overlapping(start, end, bit_count, b_start, b_end, final_intervals_append)
            b_temp_interval = None
            j += 1
        else:
            # Overlap: the leading part covered by only one side is XOR-set.
            if a_start < b_start:
                start, end, bit_count = handle_non_overlapping(start, end, bit_count, a_start, b_start - 1, final_intervals_append)
                a_start = b_start
            if b_start < a_start:
                start, end, bit_count = handle_non_overlapping(start, end, bit_count, b_start, a_start - 1, final_intervals_append)
                b_start = a_start

            # Drop the common part; the longer side keeps its tail.
            if a_end > b_end:
                a_temp_interval = [b_end + 1, a_end]
                b_temp_interval = None
                j += 1
            elif a_end < b_end:
                b_temp_interval = [a_end + 1, b_end]
                a_temp_interval = None
                i += 1
            else:
                a_temp_interval = None
                b_temp_interval = None
                i += 1
                j += 1

    # At most one of these calls has anything left to process.
    start, end, bit_count = process_remaining_intervals(
        start, end, bit_count, a_intervals, i, a_temp_interval, final_intervals_append
    )
    start, end, bit_count = process_remaining_intervals(
        start, end, bit_count, b_intervals, j, b_temp_interval, final_intervals_append
    )

    # Flush the last open run.
    if start is not None:
        final_intervals_append([start, end])
        bit_count += end - start + 1

    if not final_intervals:
        return BitIntervals(0, [], -1, -1)

    bit_idx_first = final_intervals[0][0]
    bit_idx_last = final_intervals[-1][1]
    bit_intervals = BitIntervals(bit_count, final_intervals, bit_idx_first, bit_idx_last, no_copy=True)
    bit_intervals.validate()
    return bit_intervals


def intervals_and_v0(a_bit_intervals: BitIntervals, b_bit_intervals: BitIntervals) -> BitIntervals:
    """does an and of two intervals

    Computes the bitwise AND (intersection) of two interval sets.
    ``_get_combined_intervals_until_overlap_impossible`` is consumed as one
    stream of intervals from both inputs. The stream is presumably ordered by
    start index (confirm against its definition). ``prev_int`` holds the
    not-yet-consumed tail of the interval that currently reaches furthest.

    :returns: new ``BitIntervals``; ``bit_idx_first``/``bit_idx_last`` are -1
        when the intersection is empty.
    """
    ni = _get_combined_intervals_until_overlap_impossible(a_bit_intervals, b_bit_intervals)
    intervals = []
    bit_count = 0
    try:
        prev_int = next(ni)
        next_int: list[int]
        for next_int in ni:
            max_low: int = max(prev_int[0], next_int[0])
            min_high: int = min(prev_int[1], next_int[1])
            max_high: int = max(prev_int[1], next_int[1])
            # print(f"{prev_int=}, {next_int=}, {max_low=}, {min_high=}, {max_high=}")
            if max_low <= min_high:
                bit_count += (min_high - max_low) + 1  # intervals need the +1
                intervals.append([max_low, min_high])
            if max_high > min_high:
                # Keep only the part of the further-reaching interval beyond min_high.
                prev_int = [min_high + 1, max_high]
            else:
                # Both end at the same bit: pull the next interval from the shared
                # iterator (the for loop then continues after it).  StopIteration
                # here ends the sweep via the except below.
                prev_int = next(ni)
    except StopIteration:
        pass
    try:
        bit_idx_first = intervals[0][0]
        bit_idx_last = intervals[-1][1]
    except IndexError:
        # Empty intersection.
        bit_idx_first = -1
        bit_idx_last = -1
    bit_intervals = BitIntervals(bit_count, intervals, bit_idx_first, bit_idx_last)
    return bit_intervals


def intervals_and_v1(a_bit_intervals: BitIntervals, b_bit_intervals: BitIntervals) -> BitIntervals:
    """Use a better function to check the overlap

    Computes the bitwise AND (intersection) of two interval sets. If both
    inputs hold identical intervals, their AND is returned directly.
    Otherwise the combined interval stream from
    ``_get_combined_intervals_until_overlap_impossible`` is swept. The stream
    is presumably ordered by start (confirm against its definition).

    :returns: new ``BitIntervals``; ``bit_idx_first``/``bit_idx_last`` are -1
        when the intersection is empty.
    """
    # Identical inputs: the AND is the input itself. The cheap summary checks
    # reject most pairs before the O(n) list comparison. The summaries alone
    # are NOT sufficient: [[0, 1], [5, 5], [9, 9]] and [[0, 0], [4, 5], [9, 9]]
    # agree on count, first, last and length, yet differ.
    if (
        a_bit_intervals.bit_count == b_bit_intervals.bit_count
        and a_bit_intervals.bit_idx_first == b_bit_intervals.bit_idx_first
        and a_bit_intervals.bit_idx_last == b_bit_intervals.bit_idx_last
        and a_bit_intervals.intervals == b_bit_intervals.intervals
    ):
        return BitIntervals(
            a_bit_intervals.bit_count,
            a_bit_intervals.intervals,
            a_bit_intervals.bit_idx_first,
            a_bit_intervals.bit_idx_last,
        )

    ni = _get_combined_intervals_until_overlap_impossible(a_bit_intervals, b_bit_intervals)
    intervals = []
    bit_count = 0
    try:
        prev_int = next(ni)
        next_int: list[int]
        for next_int in ni:
            max_low: int = max(prev_int[0], next_int[0])
            min_high: int = min(prev_int[1], next_int[1])
            max_high: int = max(prev_int[1], next_int[1])
            if max_low <= min_high:
                bit_count += (min_high - max_low) + 1
                intervals.append([max_low, min_high])
            if max_high > min_high:
                # Keep only the tail of the further-reaching interval.
                prev_int = [min_high + 1, max_high]
            else:
                # Both end together: advance the shared iterator.
                prev_int = next(ni)
    except StopIteration:
        pass
    if intervals:
        bit_idx_first = intervals[0][0]
        bit_idx_last = intervals[-1][1]
    else:
        bit_idx_first = -1
        bit_idx_last = -1
    bit_intervals = BitIntervals(bit_count, intervals, bit_idx_first, bit_idx_last, no_copy=True)
    return bit_intervals


def logical_subtract(a_bitmask: "LiteBitmask", b_bitmask: "LiteBitmask") -> "LiteBitmask":
    """Perform logical_subtract directly on bitmasks.

    Returns a new bitmask, sized ``a_bitmask.length``, with the bits set in
    ``a_bitmask`` and not set in ``b_bitmask`` (``a AND NOT b``).

    Neither input is modified. A partially subtracted ``a`` interval is
    tracked in the local ``a_remainder``; it is never written back into
    ``a_bitmask``'s own interval list.

    :param a_bitmask: minuend.
    :param b_bitmask: subtrahend.
    :returns: new ``LiteBitmask`` built with ``LiteBitmask.zeros_and_set_intervals``.
    """
    a_intervals = a_bitmask.bit_intervals.intervals
    b_intervals = b_bitmask.bit_intervals.intervals
    a_len = len(a_intervals)
    b_len = len(b_intervals)

    result_intervals = []
    result_intervals_append = result_intervals.append

    i, j = 0, 0
    # Unconsumed tail of a_intervals[i] after b removed its front part;
    # None means a_intervals[i] is still whole.
    a_remainder = None

    while i < a_len and j < b_len:
        if a_remainder is None:
            a_start, a_end = a_intervals[i]
        else:
            a_start, a_end = a_remainder
        b_start, b_end = b_intervals[j]

        if a_end < b_start:
            # The a piece lies entirely before b: keep it.
            result_intervals_append([a_start, a_end])
            a_remainder = None
            i += 1
        elif b_end < a_start:
            # b lies entirely before a: it removes nothing more.
            j += 1
        else:
            # Overlap: keep the part of a in front of b, drop the common part.
            if a_start < b_start:
                result_intervals_append([a_start, b_start - 1])
            if a_end > b_end:
                # a continues past b: resume just after b.
                a_remainder = [b_end + 1, a_end]
                j += 1
            else:
                a_remainder = None
                i += 1
                if a_end == b_end:
                    j += 1

    # b is exhausted: everything left in a survives.
    if i < a_len:
        if a_remainder is None:
            start, end = a_intervals[i]
        else:
            start, end = a_remainder
        result_intervals_append([start, end])
        for start, end in a_intervals[i + 1:]:
            result_intervals_append([start, end])

    # Merge touching/overlapping pieces (defensive; the inputs are expected to
    # be ordered and merged).
    merged_intervals = []
    for start, end in result_intervals:
        if merged_intervals and start <= merged_intervals[-1][1] + 1:
            merged_intervals[-1][1] = max(merged_intervals[-1][1], end)
        else:
            merged_intervals.append([start, end])

    return LiteBitmask.zeros_and_set_intervals(a_bitmask.length, merged_intervals)


def logical_xnor_v0(a_bitmask: "LiteBitmask", b_bitmask: "LiteBitmask") -> "LiteBitmask":
    """Return the XNOR of two bitmasks, computed as ``NOT (a XOR b)``.

    The XOR is done on the interval representation with ``intervals_xor``.
    It is wrapped in a ``LiteBitmask`` of ``a_bitmask.length`` bits and then
    inverted with ``logical_not``.
    """
    xor_mask = LiteBitmask(a_bitmask.length, intervals_xor(a_bitmask.bit_intervals, b_bitmask.bit_intervals))
    return logical_not(xor_mask)


def logical_xnor_v1(a_bitmask: "LiteBitmask", b_bitmask: "LiteBitmask") -> "LiteBitmask":
    """Try to perform xor function in two walkthrough

    Computes XNOR (bits equal in both inputs). The result combines the gaps
    covered by neither input (both 0) and the overlaps covered by both (both
    1). ``last_position`` is the last bit already classified. The range ends
    at ``max(a_bitmask.length, b_bitmask.length) - 1``; the result bitmask has
    ``a_bitmask.length`` bits.
    """
    a_bit_intervals = a_bitmask.bit_intervals
    b_bit_intervals = b_bitmask.bit_intervals
    a_intervals = a_bit_intervals.intervals
    b_intervals = b_bit_intervals.intervals

    result_intervals = []
    result_intervals_append = result_intervals.append

    i, j = 0, 0
    last_position = -1  # Track the last position we processed

    while i < len(a_intervals) and j < len(b_intervals):
        a_start, a_end = a_intervals[i]
        b_start, b_end = b_intervals[j]

        if a_end < b_start:
            # Append intervals where both are 0s (i.e., gaps)
            if last_position + 1 < a_start:
                result_intervals_append([last_position + 1, a_start - 1])
            last_position = a_end
            i += 1
        elif b_end < a_start:
            if last_position + 1 < b_start:
                result_intervals_append([last_position + 1, b_start - 1])
            last_position = b_end
            j += 1
        else:
            # Bits before the overlap that neither input covers are both 0.
            # They are XNOR-true too (previously these were dropped).
            first_start = min(a_start, b_start)
            if last_position + 1 < first_start:
                result_intervals_append([last_position + 1, first_start - 1])
            start = max(a_start, b_start)
            end = min(a_end, b_end)
            result_intervals_append([start, end])
            if a_end > b_end:
                last_position = b_end
                j += 1
            elif b_end > a_end:
                last_position = a_end
                i += 1
            else:
                last_position = a_end
                i += 1
                j += 1

    # Append any remaining intervals where both are 0s
    while i < len(a_intervals):
        a_start, a_end = a_intervals[i]
        if last_position + 1 < a_start:
            result_intervals_append([last_position + 1, a_start - 1])
        last_position = a_end
        i += 1

    while j < len(b_intervals):
        b_start, b_end = b_intervals[j]
        if last_position + 1 < b_start:
            result_intervals_append([last_position + 1, b_start - 1])
        last_position = b_end
        j += 1

    last_bit = max(a_bitmask.length, b_bitmask.length) - 1
    if last_position < last_bit:
        result_intervals_append([last_position + 1, last_bit])

    # Merge touching intervals (a gap may end right where an overlap starts).
    if not result_intervals:
        result = BitIntervals(0, [], -1, -1)
    else:
        merged_intervals = []
        merged_intervals_append = merged_intervals.append
        bit_count = 0

        current_start, current_end = result_intervals[0]
        for start, end in result_intervals[1:]:
            if start > current_end + 1:
                merged_intervals_append([current_start, current_end])
                bit_count += current_end - current_start + 1
                current_start, current_end = start, end
            else:
                current_end = max(current_end, end)

        merged_intervals_append([current_start, current_end])
        bit_count += current_end - current_start + 1

        bit_idx_first = merged_intervals[0][0]
        bit_idx_last = merged_intervals[-1][1]
        result = BitIntervals(bit_count, merged_intervals, bit_idx_first, bit_idx_last)

    bitmask = LiteBitmask.zeros_and_set_intervals(a_bitmask.length, result.intervals)
    return bitmask


def logical_xnor_v2(a_bitmask: "LiteBitmask", b_bitmask: "LiteBitmask") -> "LiteBitmask":
    """Refining the logic of the start and end of each sub_intervals

    Computes XNOR (bits equal in both inputs) as the complement of XOR:

    1. Sweep shallow copies of both interval lists and collect the pieces
       covered by exactly one input. A partly consumed interval's tail is
       written back into its copy, so the caller's lists are untouched.
    2. Emit the gaps between those XOR pieces over
       ``[0, max(a_bitmask.length, b_bitmask.length) - 1]``.

    :returns: new ``LiteBitmask`` of ``a_bitmask.length`` bits;
        ``LiteBitmask.zeros`` when no bit is XNOR-set.
    """
    a_intervals = a_bitmask.bit_intervals.intervals.copy()
    b_intervals = b_bitmask.bit_intervals.intervals.copy()
    a_len = len(a_intervals)
    b_len = len(b_intervals)

    # Pass 1: XOR pieces, produced in ascending order.
    xor_intervals = []
    xor_append = xor_intervals.append
    i, j = 0, 0
    while i < a_len and j < b_len:
        a_start, a_end = a_intervals[i]
        b_start, b_end = b_intervals[j]

        if a_end < b_start:
            xor_append([a_start, a_end])
            i += 1
        elif b_end < a_start:
            xor_append([b_start, b_end])
            j += 1
        else:
            # Leading part covered by only one side.
            if a_start < b_start:
                xor_append([a_start, b_start - 1])
            elif b_start < a_start:
                xor_append([b_start, a_start - 1])
            # Drop the common part; the longer side keeps its tail in its copy.
            if a_end > b_end:
                a_intervals[i] = [b_end + 1, a_end]
                j += 1
            elif b_end > a_end:
                b_intervals[j] = [a_end + 1, b_end]
                i += 1
            else:
                i += 1
                j += 1
    xor_intervals.extend(a_intervals[i:])
    xor_intervals.extend(b_intervals[j:])

    # Pass 2: complement of the XOR pieces.  Touching XOR pieces leave no gap,
    # so the emitted gaps are already merged.
    last_bit = max(a_bitmask.length, b_bitmask.length) - 1
    merged_intervals = []
    next_free = 0
    for start, end in xor_intervals:
        if start > last_bit:
            break
        if start > next_free:
            merged_intervals.append([next_free, start - 1])
        next_free = end + 1
    if next_free <= last_bit:
        merged_intervals.append([next_free, last_bit])

    if not merged_intervals:
        return LiteBitmask.zeros(a_bitmask.length)

    bit_count = sum(end - start + 1 for start, end in merged_intervals)
    bit_idx_first = merged_intervals[0][0]
    bit_idx_last = merged_intervals[-1][1]
    result = BitIntervals(bit_count, merged_intervals, bit_idx_first, bit_idx_last)
    result.validate()
    bitmask = LiteBitmask.zeros_and_set_intervals(a_bitmask.length, result.intervals)
    return bitmask


def logical_xnor_v3(a_bitmask: "LiteBitmask", b_bitmask: "LiteBitmask") -> "LiteBitmask":
    """Another effort dealing with the overlap in both intervals

    Computes XNOR (bits equal in both inputs) over ``a_bitmask.length`` bits.
    The result is the gaps covered by neither input plus the overlaps covered
    by both. ``last_end`` is the last bit already classified.

    :returns: new ``LiteBitmask`` of ``a_bitmask.length`` bits.
    """
    a_intervals = a_bitmask.bit_intervals.intervals
    b_intervals = b_bitmask.bit_intervals.intervals
    length = a_bitmask.length
    final_intervals = []

    a_intervals_len = len(a_intervals)
    b_intervals_len = len(b_intervals)
    final_intervals_append = final_intervals.append
    i, j = 0, 0

    last_end = -1

    while i < a_intervals_len and j < b_intervals_len:
        a_start, a_end = a_intervals[i]
        b_start, b_end = b_intervals[j]

        if a_end < b_start:
            if last_end < a_start - 1:
                final_intervals_append([last_end + 1, a_start - 1])
            last_end = a_end
            i += 1
        elif b_end < a_start:
            if last_end < b_start - 1:
                final_intervals_append([last_end + 1, b_start - 1])
            last_end = b_end
            j += 1
        else:
            # Both-zero gap before the overlap, then the both-one overlap.
            if last_end < min(a_start, b_start) - 1:
                final_intervals_append([last_end + 1, min(a_start, b_start) - 1])

            overlap_start = max(a_start, b_start)
            overlap_end = min(a_end, b_end)
            final_intervals_append([overlap_start, overlap_end])
            last_end = overlap_end

            if a_end > b_end:
                j += 1
            elif a_end < b_end:
                i += 1
            else:
                i += 1
                j += 1

    while i < a_intervals_len:
        a_start, a_end = a_intervals[i]
        if last_end < a_start - 1:
            final_intervals_append([last_end + 1, a_start - 1])
        last_end = a_end
        i += 1

    while j < b_intervals_len:
        b_start, b_end = b_intervals[j]
        if last_end < b_start - 1:
            final_intervals_append([last_end + 1, b_start - 1])
        last_end = b_end
        j += 1

    # Trailing both-zero run.  When both inputs are empty this yields
    # [0, length - 1], and nothing at all when length == 0 (the old special
    # case injected an invalid [0, -1] interval for length 0).
    if last_end < length - 1:
        final_intervals_append([last_end + 1, length - 1])

    # A gap can end right where an overlap starts: merge touching intervals.
    merged_intervals = []
    for interval in final_intervals:
        if not merged_intervals or merged_intervals[-1][1] < interval[0] - 1:
            merged_intervals.append(interval)
        else:
            merged_intervals[-1][1] = max(merged_intervals[-1][1], interval[1])

    if not merged_intervals:
        bit_count = 0
        bit_idx_first = -1
        bit_idx_last = -1
    else:
        bit_count = sum(end - start + 1 for start, end in merged_intervals)
        bit_idx_first = merged_intervals[0][0]
        bit_idx_last = merged_intervals[-1][1]

    bit_intervals = BitIntervals(bit_count, merged_intervals, bit_idx_first, bit_idx_last, no_copy=True)
    bit_intervals.validate()

    return LiteBitmask(length, bit_intervals)


def logical_xnor_v4(a_bitmask: "LiteBitmask", b_bitmask: "LiteBitmask") -> "LiteBitmask":
    """Return the XNOR of two bitmasks: a bit is set where ``a`` and ``b`` agree.

    Usage of add_intervals to both speed up the append process and avoid merge process:
    every output range goes through the local ``add_interval`` helper, which coalesces it
    with the previous range when the two touch or overlap, so no separate merge pass runs.

    Both interval lists are walked once, in order. The output consists of the gaps where
    neither input has a bit set plus the overlaps where both do. The result length is taken
    from ``a_bitmask``; the two inputs are assumed (not checked) to have the same length.
    """
    a_intervals = a_bitmask.bit_intervals.intervals
    b_intervals = b_bitmask.bit_intervals.intervals
    length = a_bitmask.length
    merged_intervals = []
    merged_intervals_append = merged_intervals.append
    a_intervals_len = len(a_intervals)
    b_intervals_len = len(b_intervals)
    i, j = 0, 0

    # Highest bit index already decided; bits <= last_end are either emitted or rejected.
    last_end = -1

    def add_interval(start, end):
        # Append [start, end], or extend the previous range when it touches/overlaps it.
        if not merged_intervals or merged_intervals[-1][1] < start - 1:
            merged_intervals_append([start, end])
        else:
            merged_intervals[-1][1] = max(merged_intervals[-1][1], end)

    while i < a_intervals_len or j < b_intervals_len:
        # An exhausted list yields a sentinel interval at ``length`` (past the last valid
        # bit), so the other list's remaining intervals always compare as earlier.
        if i < a_intervals_len:
            a_start, a_end = a_intervals[i]
        else:
            a_start, a_end = length, length

        if j < b_intervals_len:
            b_start, b_end = b_intervals[j]
        else:
            b_start, b_end = length, length

        if a_end < b_start:
            # a's interval ends before b's starts: the gap leading up to it is unset in
            # both (emit it); the undecided rest of the interval is set only in a (skip).
            if last_end < a_start - 1:
                add_interval(last_end + 1, a_start - 1)
            last_end = a_end
            i += 1
        elif b_end < a_start:
            # Mirror case: b's interval lies wholly before a's.
            if last_end < b_start - 1:
                add_interval(last_end + 1, b_start - 1)
            last_end = b_end
            j += 1
        else:
            # The intervals overlap: any gap before either begins is unset in both.
            if last_end < min(a_start, b_start) - 1:
                add_interval(last_end + 1, min(a_start, b_start) - 1)

            # The overlapping part is set in both inputs.
            overlap_start = max(a_start, b_start)
            overlap_end = min(a_end, b_end)
            add_interval(overlap_start, overlap_end)
            last_end = overlap_end

            # Advance the interval that ends first (both if they end together); the other
            # may still overlap the next interval of the advanced list.
            if a_end > b_end:
                j += 1
            elif a_end < b_end:
                i += 1
            else:
                i += 1
                j += 1

    # Everything after the last processed interval is unset in both inputs.
    if last_end < length - 1:
        add_interval(last_end + 1, length - 1)

    if not merged_intervals:
        bit_count = 0
        bit_idx_first = -1
        bit_idx_last = -1
    else:
        bit_count = sum(end - start + 1 for start, end in merged_intervals)
        bit_idx_first = merged_intervals[0][0]
        bit_idx_last = merged_intervals[-1][1]

    bit_intervals = BitIntervals(bit_count, merged_intervals, bit_idx_first, bit_idx_last, no_copy=True)
    bit_intervals.validate()

    return LiteBitmask(length, bit_intervals)


def logical_xnor_v5(a_bitmask: "LiteBitmask", b_bitmask: "LiteBitmask") -> "LiteBitmask":
    """Return the XNOR of two bitmasks: a bit is set where ``a`` and ``b`` agree.

    Avoid the merge process, and refine the logic in appending in merge intervals by
    updating the boundary: instead of a helper with ``max()``, each emit site either appends
    a new range or moves the end of the previous range when the new one is adjacent to it.

    The result length is taken from ``a_bitmask``; the two inputs are assumed (not checked)
    to have the same length.
    """
    a_intervals = a_bitmask.bit_intervals.intervals
    b_intervals = b_bitmask.bit_intervals.intervals
    length = a_bitmask.length
    merged_intervals = []
    merged_intervals_append = merged_intervals.append
    a_intervals_len = len(a_intervals)
    b_intervals_len = len(b_intervals)
    i, j = 0, 0

    # Highest bit index already decided; bits <= last_end are either emitted or rejected.
    last_end = -1

    while i < a_intervals_len or j < b_intervals_len:
        # An exhausted list yields a sentinel interval at ``length`` so the other list's
        # remaining intervals always compare as earlier.
        if i < a_intervals_len:
            a_start, a_end = a_intervals[i]
        else:
            a_start, a_end = length, length

        if j < b_intervals_len:
            b_start, b_end = b_intervals[j]
        else:
            b_start, b_end = length, length
        # Append the non-overlapping intervals into the result
        if a_end < b_start:
            # a's interval lies before b's: the gap leading up to it is unset in both.
            if last_end < a_start - 1:
                # A previous range ending exactly at last_end is adjacent to the gap
                # (which starts at last_end + 1), so extend it instead of appending.
                if not merged_intervals or merged_intervals[-1][1] < last_end:
                    merged_intervals_append([last_end + 1, a_start - 1])
                else:
                    merged_intervals[-1][1] = a_start - 1
            last_end = a_end
            i += 1
        elif b_end < a_start:
            # Mirror case: b's interval lies wholly before a's.
            if last_end < b_start - 1:
                if not merged_intervals or merged_intervals[-1][1] < last_end:
                    merged_intervals_append([last_end + 1, b_start - 1])
                else:
                    merged_intervals[-1][1] = b_start - 1
            last_end = b_end
            j += 1
        else:
            # Overlapping intervals: first emit any gap before either begins.
            if last_end < min(a_start, b_start) - 1:
                if not merged_intervals or merged_intervals[-1][1] < last_end:
                    merged_intervals_append([last_end + 1, min(a_start, b_start) - 1])
                else:
                    merged_intervals[-1][1] = min(a_start, b_start) - 1

            # Then the overlap (set in both); extend the previous range only when it ends
            # right before overlap_start.
            overlap_start = max(a_start, b_start)
            overlap_end = min(a_end, b_end)
            if not merged_intervals or merged_intervals[-1][1] < overlap_start - 1:
                merged_intervals_append([overlap_start, overlap_end])
            else:
                merged_intervals[-1][1] = overlap_end
            last_end = overlap_end

            # Advance the interval that ends first (both if they end together).
            if a_end > b_end:
                j += 1
            elif a_end < b_end:
                i += 1
            else:
                i += 1
                j += 1

    # Trailing bits after the last processed interval are unset in both inputs.
    if last_end < length - 1:
        if not merged_intervals or merged_intervals[-1][1] < last_end:
            merged_intervals_append([last_end + 1, length - 1])
        else:
            merged_intervals[-1][1] = length - 1

    if not merged_intervals:
        bit_count = 0
        bit_idx_first = -1
        bit_idx_last = -1
    else:
        bit_count = sum(end - start + 1 for start, end in merged_intervals)
        bit_idx_first = merged_intervals[0][0]
        bit_idx_last = merged_intervals[-1][1]

    bit_intervals = BitIntervals(bit_count, merged_intervals, bit_idx_first, bit_idx_last, no_copy=True)
    bit_intervals.validate()

    return LiteBitmask(length, bit_intervals)


def logical_xnor_v6_append_interval(merged_intervals, merge_intervals_append, last_end, new_start, new_end, bit_count):
    """Emit ``[new_start, new_end]`` into ``merged_intervals`` and return the new bit count.

    A function with merge interval conditions and sum of bit_count to avoid redundant code,
    but it is slower than inlining the same logic.

    If the last stored range ends at (or past) ``last_end`` it is treated as adjacent: its
    end is moved to ``new_end`` and ``bit_count`` grows by the distance moved. Otherwise a
    fresh range is appended via ``merge_intervals_append`` and its full width is added.
    """
    if merged_intervals and merged_intervals[-1][1] >= last_end:
        tail = merged_intervals[-1]
        grown_by = new_end - tail[1]
        tail[1] = new_end
        return bit_count + grown_by

    merge_intervals_append([new_start, new_end])
    return bit_count + (new_end - new_start + 1)


def logical_xnor_v6(a_bitmask: "LiteBitmask", b_bitmask: "LiteBitmask") -> "LiteBitmask":
    """Return the XNOR of two bitmasks: a bit is set where ``a`` and ``b`` agree.

    Avoid the merge process, and refine the logic in appending in merge intervals by
    updating the boundary. Same sweep as v5, but every emit goes through
    ``logical_xnor_v6_append_interval``, which also keeps ``bit_count`` up to date so no
    final summation pass is needed.

    The result length is taken from ``a_bitmask``; the two inputs are assumed (not checked)
    to have the same length.
    """
    a_intervals = a_bitmask.bit_intervals.intervals
    b_intervals = b_bitmask.bit_intervals.intervals
    length = a_bitmask.length
    merged_intervals = []
    merged_intervals_append = merged_intervals.append
    a_intervals_len = len(a_intervals)
    b_intervals_len = len(b_intervals)
    i, j = 0, 0

    # Highest bit index already decided; bits <= last_end are either emitted or rejected.
    last_end = -1
    # Running count of set bits in merged_intervals, maintained by the append helper.
    bit_count = 0

    while i < a_intervals_len or j < b_intervals_len:
        # An exhausted list yields a sentinel interval at ``length`` so the other list's
        # remaining intervals always compare as earlier.
        if i < a_intervals_len:
            a_start, a_end = a_intervals[i]
        else:
            a_start, a_end = length, length

        if j < b_intervals_len:
            b_start, b_end = b_intervals[j]
        else:
            b_start, b_end = length, length
        # Append the non-overlapping intervals into the result
        if a_end < b_start:
            # a's interval lies before b's: the gap leading up to it is unset in both.
            if last_end < a_start - 1:
                bit_count = logical_xnor_v6_append_interval(
                    merged_intervals, merged_intervals_append, last_end, last_end + 1, a_start - 1, bit_count
                )
            last_end = a_end
            i += 1
        elif b_end < a_start:
            # Mirror case: b's interval lies wholly before a's.
            if last_end < b_start - 1:
                bit_count = logical_xnor_v6_append_interval(
                    merged_intervals, merged_intervals_append, last_end, last_end + 1, b_start - 1, bit_count
                )
            last_end = b_end
            j += 1
        else:
            # Overlapping intervals: first emit any gap before either begins.
            if last_end < min(a_start, b_start) - 1:
                bit_count = logical_xnor_v6_append_interval(
                    merged_intervals, merged_intervals_append, last_end, last_end + 1, min(a_start, b_start) - 1, bit_count
                )

            # Then the overlap (set in both). Passing overlap_start - 1 as the "last_end"
            # argument makes the helper extend the previous range only when that range
            # ends immediately before overlap_start.
            overlap_start = max(a_start, b_start)
            overlap_end = min(a_end, b_end)
            bit_count = logical_xnor_v6_append_interval(
                merged_intervals, merged_intervals_append, overlap_start - 1, overlap_start, overlap_end, bit_count
            )
            last_end = overlap_end
            # Advance the interval that ends first (both if they end together).
            if a_end > b_end:
                j += 1
            elif a_end < b_end:
                i += 1
            else:
                i += 1
                j += 1

    # Trailing bits after the last processed interval are unset in both inputs.
    if last_end < length - 1:
        bit_count = logical_xnor_v6_append_interval(
            merged_intervals, merged_intervals_append, last_end, last_end + 1, length - 1, bit_count
        )

    if not merged_intervals:
        bit_idx_first = -1
        bit_idx_last = -1
    else:
        bit_idx_first = merged_intervals[0][0]
        bit_idx_last = merged_intervals[-1][1]

    bit_intervals = BitIntervals(bit_count, merged_intervals, bit_idx_first, bit_idx_last, no_copy=True)
    bit_intervals.validate()

    return LiteBitmask(length, bit_intervals)


class IntervalError(Exception):
    """Raised for a malformed interval, e.g. one whose start is greater than its end."""


def convert_intervals_to_bitintervals(intervals: IntervalList) -> BitIntervals:
    """Build a ``BitIntervals`` from a list of inclusive ``[start, end]`` intervals.

    The intervals are used as given (no sorting or merging). The bit count is the sum of
    the interval widths. The first/last bit come from the first interval's start and the
    last interval's end, or are ``-1`` when the list is empty.

    :raises IntervalError: if any interval has ``start > end``.
    """

    def interval_width(pair):
        lo, hi = pair[0], pair[1]
        if lo > hi:
            raise IntervalError(f"negative interval: {pair}")
        return hi - lo + 1

    bit_count = sum(map(interval_width, intervals)) if intervals else 0

    try:
        first, last = intervals[0][0], intervals[-1][1]
    except IndexError:
        first = last = -1
    return BitIntervals(bit_count, intervals, first, last)


def index_list_to_bitintervals_v0(idx_lst: Union[Set[int], List[int]], limit=None, sort=True) -> BitIntervals:
    """converting list to bit_intervals (original ``groupby`` implementation).

    Runs of consecutive indexes become inclusive ``[start, end]`` intervals. Kept as the
    baseline variant; it is much slower than the later versions.

    :param idx_lst: indexes of the set bits.
    :param limit: if truthy, the maximum number of intervals allowed.
    :param sort: if true, de-duplicate and sort ``idx_lst`` first; if false the caller must
        pass an already sorted, duplicate-free sequence.
    :raises Exception: if more than ``limit`` intervals would be produced.
    """
    if idx_lst:
        if sort:
            if isinstance(idx_lst, set):
                idx_lst = sorted(idx_lst)
            else:
                idx_lst = sorted(set(idx_lst))
        bit_count = len(idx_lst)
        bit_idx_first = idx_lst[0]
        bit_idx_last = idx_lst[-1]
        intervals = []
        count = 0

        # VERY SLOW. Within a run of consecutive values, (position - value) is constant,
        # so grouping on that key splits the sorted indexes into runs.
        for _, group in groupby(enumerate(idx_lst), lambda ix: ix[0] - ix[1]):
            interval = [item[1] for item in group]
            intervals.append([interval[0], interval[-1]])
            count += 1
            # Reject only once the interval count exceeds the limit (same rule as v1/v2);
            # the old ``limit >= count`` test rejected every input whenever a limit was set.
            if limit and count > limit:
                raise Exception(f"Over limit:{limit}")

        result = BitIntervals(bit_count, intervals, bit_idx_first, bit_idx_last, no_copy=True)
    else:
        result = BitIntervals(0, [], -1, -1)
    return result


def index_list_to_bitintervals_v1(idx_lst: Union[Set[int], List[int]], limit=None, sort=True) -> BitIntervals:
    """Improved function by avoid using groupby: one linear scan over the sorted indexes.

    A new interval starts whenever an index is not exactly one more than the previous one.

    :param idx_lst: indexes of the set bits.
    :param limit: if truthy, raise once ``limit`` breaks between runs have been seen, i.e.
        when the result would hold more than ``limit`` intervals.
    :param sort: if true, de-duplicate and sort ``idx_lst`` first; if false the caller must
        pass an already sorted, duplicate-free sequence.
    :raises Exception: when the interval limit is exceeded.
    """
    if not idx_lst:
        return BitIntervals(0, [], -1, -1)

    if sort:
        idx_lst = sorted(idx_lst) if type(idx_lst) == set else sorted(set(idx_lst))

    bit_count = len(idx_lst)
    first = idx_lst[0]
    intervals = []
    push = intervals.append
    breaks = 0
    run_start = run_end = first
    for value in idx_lst[1:]:
        if value != run_end + 1:
            # The current run ended at run_end; close it and start a new one.
            push([run_start, run_end])
            run_start = value
            breaks += 1
            if limit and breaks >= limit:
                raise Exception(f"Over limit: {limit}")
        run_end = value
    push([run_start, run_end])

    return BitIntervals(bit_count, intervals, first, idx_lst[-1], no_copy=True)


def index_list_to_bitintervals_v2(idx_lst: Union[Set[int], List[int]], limit=None, sort=True) -> BitIntervals:
    """Refine the logic with early return and sort check.

    Converts bit indexes into ``BitIntervals`` by scanning the sorted, de-duplicated indexes
    once and starting a new inclusive interval at every break in consecutiveness.

    NOTE(review): ``sort`` currently has no effect on the result. Both branches end up
    sorted and de-duplicated; ``sort=True`` with a ``set`` merely skips the extra ``set()``.

    :raises Exception: if the number of intervals would exceed a truthy ``limit``.
    """
    if not idx_lst:
        return BitIntervals(0, [], -1, -1)

    ordered = sorted(idx_lst) if (sort and isinstance(idx_lst, set)) else sorted(set(idx_lst))

    intervals = []
    push = intervals.append
    breaks = 0
    run_start = run_end = ordered[0]
    for value in ordered[1:]:
        if value != run_end + 1:
            push([run_start, run_end])
            run_start = value
            breaks += 1
            if limit and breaks >= limit:
                raise Exception(f"Over limit: {limit}")
        run_end = value
    push([run_start, run_end])

    return BitIntervals(len(ordered), intervals, ordered[0], ordered[-1], no_copy=True)


def repr_LiteBitmask_zasb(bitmask: "LiteBitmask"):
    """Return a ``LiteBitmask.zeros_and_set_bits(...)`` expression describing ``bitmask``.

    The set bit indexes are written as a tuple literal. A single index gets a trailing comma
    (``(5,)``); without it the expression would pass a bare int, which
    ``zeros_and_set_bits`` cannot iterate.
    """
    bits = [str(i) for i in bitmask.get_bits()]
    bits_literal = ", ".join(bits)
    if len(bits) == 1:
        bits_literal += ","
    return f"LiteBitmask.zeros_and_set_bits({bitmask.length}, ({bits_literal}))"


def repr_LiteBitmask_zasi(bitmask: "LiteBitmask"):
    """Return a ``LiteBitmask.zeros_and_set_intervals(...)`` expression describing ``bitmask``."""
    interval_text = str(bitmask.bit_intervals.intervals)
    return "LiteBitmask.zeros_and_set_intervals({}, {})".format(bitmask.length, interval_text)


class BitmaskFlags:
    """Flags attached to a bitmask; ``writeable`` is the only supported key."""

    writeable = False  # class-level default; from_dict sets it per instance
    keys = ("writeable",)  # the only keys from_dict accepts

    @staticmethod
    def from_dict(dct):
        """Build flags from a dict, rejecting any key not listed in ``keys``."""
        flags = BitmaskFlags()
        for name, setting in dct.items():
            if name not in BitmaskFlags.keys:
                raise Exception(f"invalid key {name} for BitmaskFlags")
            setattr(flags, name, setting)
        return flags

    def to_dict(self):
        """Return the flags as a plain dict."""
        return {"writeable": self.writeable}


class LiteBitmask(EnhancedBitmaskLike):
    """A bitmask stored as ordered, merged, inclusive ``[start, end]`` intervals of set bits.

    Only the set bits are stored (in ``bit_intervals``), so the size of the representation
    follows the number of runs rather than ``length``. ``bit_count``, ``bit_idx_first`` and
    ``bit_idx_last`` mirror the corresponding ``BitIntervals`` fields (first/last are ``-1``
    for an empty mask).
    """

    length: int = 0  # same as size
    bit_count: int = 0
    bit_idx_first = -1
    bit_idx_last = -1
    flags = None
    repr_func = repr_LiteBitmask_zasi  # set to repr_LiteBitmask_zasb to list individual bits

    def __init__(self, length, bit_intervals: BitIntervals):
        """Wrap ``bit_intervals`` as a bitmask of ``length`` bits.

        :raises Exception: if the highest set bit is not below ``length``.
        """
        self.length: int = length
        self.bit_count: int = bit_intervals.bit_count
        self.bit_idx_first: int = bit_intervals.bit_idx_first
        self.bit_idx_last: int = bit_intervals.bit_idx_last
        self.bit_intervals: BitIntervals = bit_intervals
        if self.bit_idx_last > (self.length - 1):
            raise Exception(f"Bit set out of range: {self.bit_idx_last} length:{self.length}")
        self.flags = BitmaskFlags()

    def __setstate__(self, state):
        """used to recreate the object after pickle (``state`` is the ``to_dict`` output)"""
        self.length = state["length"]
        self.bit_count = state["bit_count"]
        self.bit_idx_first = state["bit_idx_first"]
        self.bit_idx_last = state["bit_idx_last"]
        self.bit_intervals = BitIntervals.from_dict(state["bit_intervals"])
        self.flags = BitmaskFlags.from_dict(state["flags"])

    def __getstate__(self):
        """used to dump the state of an object for pickling"""
        return self.to_dict()

    def intersection_any(self, other: EnhancedBitmaskLike):
        """Return True if ``self`` and ``other`` share at least one set bit."""
        return intersection_any(self, other)

    def __repr__(self):
        # Dispatch through the class attribute so the repr style can be switched globally.
        # https://github.com/pandas-dev/pandas/issues/17695
        return LiteBitmask.repr_func(self)

    def __getitem__(self, item):
        """Return the bit at integer ``item`` (bool) or a list of bools for a slice.

        Integer indexes may be negative (counted from the end, as for lists). Slices accept
        the full ``start:stop:step`` syntax, including omitted, negative and out-of-range
        bounds, with the same clamping rules as list slicing.

        :raises IndexError: for an integer index outside ``[-length, length)``.
        """
        if isinstance(item, slice):
            # slice.indices resolves None/negative bounds and clamps them to [0, length].
            return self._bits_at(range(*item.indices(self.length)))
        if item < 0:
            item += self.length
        if not 0 <= item < self.length:
            raise IndexError("LiteBitmask index out of range")
        for start, end in self.bit_intervals.intervals:  # TODO: do a binary search
            if start <= item <= end:
                return True
            if start > item:
                break  # intervals are ordered, so no later interval can contain item
        return False

    def _bits_at(self, indices: range) -> list:
        """Return the bit value for each index of ``indices`` (all within ``[0, length)``)."""
        if indices.step < 0:
            # Sweep in ascending order, then restore the requested (descending) order.
            return self._bits_at(indices[::-1])[::-1]
        intervals = self.bit_intervals.intervals
        n_intervals = len(intervals)
        k = 0
        bits = []
        for idx in indices:
            # Indexes only increase, so intervals ending before idx never matter again.
            while k < n_intervals and intervals[k][1] < idx:
                k += 1
            bits.append(k < n_intervals and intervals[k][0] <= idx)
        return bits

    def __or__(self, other) -> "LiteBitmask":
        return LiteBitmask.logical_or(self, other)

    def __and__(self, other) -> "LiteBitmask":
        return LiteBitmask.logical_and(self, other)

    def __sub__(self, other):
        """Return the bits set in ``self`` but not in ``other``."""
        if intersection_possible(self, other):
            indexes_self = get_interval_bits_set(self.bit_intervals)
            indexes_other = get_interval_bits_set(other.bit_intervals)
            indexes = indexes_self - indexes_other
            # fixme: todo
            # indexes = intervals_subtract(self.intervals, other.intervals)
            # this might be slow.  the set operation isn't ordered so we need order it.
            bitmask = LiteBitmask.zeros_and_set_bits(self.length, indexes)
        else:
            bitmask = self.copy()
        return bitmask

    def __eq__(self, o: "LiteBitmask") -> "LiteBitmask":
        """Element-wise equality against another LiteBitmask.

        Returns a LiteBitmask with bits set where both masks hold the same value (over the
        overlapping prefix). Comparisons with other types defer to the base class.
        """
        if type(o) == LiteBitmask:
            lst = map(lambda a: a[0] == a[1], zip(self.tolist(), o.tolist()))
            bits = [idx for idx, value in enumerate(lst) if value]
            result = LiteBitmask.zeros_and_set_bits(self.length, bits)
        else:
            result = super().__eq__(o)
        return result

    def copy(self, *args, **kwargs):
        """Return a new LiteBitmask with the same length and set bits (flags are not copied)."""
        return LiteBitmask.zeros_and_set_intervals(self.length, self.bit_intervals.intervals)

    def validate(self):
        """Assert that the cached summary fields agree with the stored intervals.

        :raises AssertionError: if any invariant is violated.
        """
        intervals = self.bit_intervals.intervals
        if self.bit_count > 0:
            assert self.bit_idx_first == intervals[0][0], self.bit_intervals.to_dict()
            assert self.bit_idx_last == intervals[-1][1]
            # Bit indexes are 0-based, so the last valid index is length - 1.
            assert self.bit_idx_last < self.length
        else:
            assert self.bit_idx_first == -1
            assert self.bit_idx_last == -1
        # len(range(...)) is O(1) and is 0 for an inverted interval, like a bit-by-bit count.
        assert sum(len(range(s, e + 1)) for s, e in intervals) == self.bit_count
        self.bit_intervals.validate()

    @staticmethod
    def get_bitmask_from_idx_lst(bit_count, idx_lst) -> "LiteBitmask":
        """create a bitmask from an index list (``bit_count`` is used as the mask length)"""
        return LiteBitmask.zeros_and_set_bits(bit_count, idx_lst)

    @staticmethod
    def get_idx_lst_from_bitmask(bitmask, bit_idx_first: int = None, bit_idx_last: int = None) -> list:
        """return the indexes set inside the bitmask"""
        # FIXME: priority 1, get bit_idx_last working (ignored for LiteBitmask) and this is probably slow.
        if isinstance(bitmask, LiteBitmask):
            idx_lst = [_ for _ in get_interval_bits(bitmask.bit_intervals, bit_idx_first=bit_idx_first)]
        else:
            idx_lst = EnhancedBitmask.get_idx_lst_from_bitmask(bitmask, bit_idx_first=bit_idx_first, bit_idx_last=bit_idx_last)
        return idx_lst

    def get_bits(self, bit_idx_first=0) -> Union[list[str], Generator]:
        """return the bit index list (``bit_idx_first`` is currently ignored)"""
        # FIXME: remove, use get_idx_lst_from_bitmask
        return LiteBitmask.get_idx_lst_from_bitmask(self)

    def tolist(self):
        """generate the full bitmask as a list of bools."""
        bit_idx_gen = LiteBitmask.get_idx_lst_from_bitmask(self)
        lst: list[bool] = [False] * self.length
        for idx in bit_idx_gen:
            lst[idx] = True
        return lst

    @staticmethod
    def array(lst, **kwargs) -> "LiteBitmask":
        """Build a bitmask from a sequence of truthy/falsy values."""
        bits = [idx for idx, value in enumerate(lst) if value]
        bitmask = LiteBitmask.zeros_and_set_bits(len(lst), bits)
        return bitmask

    @staticmethod
    def zeros(length: int) -> "LiteBitmask":
        return LiteBitmask(length, BitIntervals(0, [], -1, -1))

    @staticmethod
    def ones(length: int) -> "LiteBitmask":
        if length == 0:
            # [[0, -1]] would be an inverted interval; an empty mask has no set bits.
            return LiteBitmask.zeros(length)
        return LiteBitmask(length, BitIntervals(length, [[0, length - 1]], 0, length - 1))

    @staticmethod
    def zeros_and_set_bits(length: int, bit_lst: Union[List[int], Tuple[int, ...], Set[int]], **kwargs):
        """Create a bitmask of ``length`` bits with the indexes in ``bit_lst`` set."""
        bit_intervals = index_list_to_bitintervals(bit_lst)
        return LiteBitmask(length, bit_intervals)

    @staticmethod
    def zeros_and_set_ranges(length: int, bit_ranges: list[Union[list[int, int], list[int, int, int]]]):
        """given a tuple of range slices, create a bitmask.
        range slices are ``range()`` arguments: a->b excluding b, with an optional step"""
        bit_indexes = sorted(set(i for bit_range in bit_ranges for i in range(*bit_range)))
        bit_intervals = index_list_to_bitintervals(bit_indexes)
        return LiteBitmask(length, bit_intervals)

    @staticmethod
    def zeros_and_set_intervals(length: int, intervals: IntervalList):
        """given a list of intervals, create a bitmask.
        intervals are a->b including b; they are handed to BitIntervals.from_intervals_list"""
        bit_intervals = BitIntervals.from_intervals_list(intervals)
        return LiteBitmask(length, bit_intervals)

    @staticmethod
    def zeros_and_set_bit(length: int, bit_idx):
        """Create a bitmask of ``length`` bits with only ``bit_idx`` set."""
        return LiteBitmask(length, BitIntervals(1, [[bit_idx, bit_idx]], bit_idx, bit_idx))

    @staticmethod
    def from_dict(dct):
        """Rebuild a bitmask from the dict produced by ``to_dict``."""
        bit_intervals = BitIntervals.from_dict(dct["bit_intervals"])
        bitmask = LiteBitmask(dct["length"], bit_intervals)
        bitmask.flags = BitmaskFlags.from_dict(dct["flags"])
        return bitmask

    def to_dict(self):
        dct = {
            "length": self.length,
            "bit_count": self.bit_count,
            "bit_idx_first": self.bit_idx_first,
            "bit_idx_last": self.bit_idx_last,
            "bit_intervals": self.bit_intervals.to_dict(),
            "flags": self.flags.to_dict(),
        }
        return dct

    def to_json(self):
        json_str = simplejson.dumps(self, cls=DataEncoder)
        return json_str

    @staticmethod
    def from_json(json_str):
        obj = simplejson.loads(json_str, cls=DataDecoder)
        return obj

    def __len__(self):
        return self.length

    def __add__(self, other):
        return self | other

    def __xor__(self, other):
        return LiteBitmask.logical_xor(self, other)

    @staticmethod
    def add(x1, x2, **kwargs):
        return x1 | x2

    @staticmethod
    def logical_xor(x: "LiteBitmask", y: "LiteBitmask") -> "LiteBitmask":
        bit_intervals = intervals_xor(x.bit_intervals, y.bit_intervals)
        bitmask = LiteBitmask(x.length, bit_intervals)
        return bitmask

    @staticmethod
    def logical_not(x: "LiteBitmask") -> "LiteBitmask":
        # Resolves to the module-level logical_not, not this method.
        bitmask = logical_not(x)
        return bitmask

    @staticmethod
    def logical_xnor(x: "LiteBitmask", y: "LiteBitmask") -> "LiteBitmask":
        # Resolves to the module-level logical_xnor alias.
        return logical_xnor(x, y)

    def intersection_all(self, other):
        """Return True when every bit of ``self`` equals the corresponding bit of ``other``."""
        xnored = self.logical_xnor(self, other)
        return bool(xnored.bit_count == xnored.length)

    @staticmethod
    def logical_or(x: "LiteBitmask", y: "LiteBitmask"):
        bit_intervals = intervals_or(x.bit_intervals, y.bit_intervals)
        bitmask = LiteBitmask(x.length, bit_intervals)
        return bitmask

    @staticmethod
    def logical_and(x, y):
        # Skip the interval walk when the set-bit ranges cannot meet.
        if intersection_possible(x, y):
            bit_intervals = intervals_and(x.bit_intervals, y.bit_intervals)
            bitmask = LiteBitmask(x.length, bit_intervals)
        else:
            bitmask = LiteBitmask.zeros(x.length)
        return bitmask

    def intersection(self, other):
        """Return ``self & other``, or None when the set-bit ranges cannot overlap."""
        if intersection_possible(self, other):
            real_intersection = LiteBitmask.logical_and(self, other)
        else:
            real_intersection = None
        return real_intersection

    @classmethod
    def logical_is_subset(cls, left, right):
        # FIXME: see EnhancedBitmask for a much faster version.
        result = cls.logical_and(right, cls.logical_not(left))
        return result

    @classmethod
    def join_bitmasks(cls, bitmask_lst: list, operation) -> "LiteBitmask":
        """
        given a list of bitmasks, join them by folding ``operation`` over the list
        """
        if len(bitmask_lst) == 0:
            raise Exception("Must provide a bitmask to join with.  The length is required.")
        elif len(bitmask_lst) == 1:
            result_bitmask = bitmask_lst[0].copy()
        else:
            length = 0
            for bitmask in bitmask_lst:
                if length == 0:
                    length = bitmask.length
                if bitmask.length != length:
                    raise AssertionError("Cannot join bitmasks with different sizes")
            if operation == cls.logical_or:
                # ordering them by bit_count: a mask with every bit set already is the OR.
                bitmask_lst_orig = sorted(bitmask_lst, key=lambda obj: -obj.bit_count)
                first_orig_bitmask = bitmask_lst_orig[0]
                if first_orig_bitmask.length == first_orig_bitmask.bit_count:
                    result_bitmask = first_orig_bitmask.copy()
                else:
                    result_bitmask = reduce(operation, bitmask_lst)
            else:
                result_bitmask = reduce(operation, bitmask_lst)
        return result_bitmask

    @staticmethod
    def sign(x1) -> "EnhancedBitmaskLike":
        return x1.copy()


# Register each class's to_dict/from_dict with DataEncoder/DataDecoder; the decoder side is
# keyed by the type name string.
for _cls, _type_name in (
    (LiteBitmask, "LiteBitmask"),
    (BitIntervals, "BitIntervals"),
    (BitmaskFlags, "BitmaskFlags"),
):
    DataEncoder.register(_cls, _cls.to_dict)
    DataDecoder.register(_type_name, _cls.from_dict)
del _cls, _type_name


def convert_bitmask_to_litebitmask(bitmask: EnhancedBitmask):
    """Return a LiteBitmask with the same length and set bits as ``bitmask``."""
    size: int = bitmask.length
    return LiteBitmask.zeros_and_set_bits(size, bitmask.get_idx_lst_from_bitmask(bitmask))


def series_intersection_any(ser, search_bitmask: Union[EnhancedBitmaskLike, str]):
    """function to apply to a pandas series of bitmasks to search for an intersection.

    Returns one bool per element; empty-string entries count as no intersection.
    """
    matches = []
    for bitmask in ser:
        if bitmask != "":
            matches.append(bitmask.intersection_any(search_bitmask))
        else:
            matches.append(False)
    return matches


# Best versions so far: each public name below points at the preferred variant.
index_list_to_bitintervals = index_list_to_bitintervals_v2
intervals_and = intervals_and_v1
intervals_or = intervals_or_v7
intervals_subtract = intervals_subtract_v2
intervals_xor = intervals_xor_v7
logical_xnor = logical_xnor_v6

"""
Example of collecting the function variants in lists, either as (name, function) pairs
or by name only:
intervals_subtract_funcs = [
    ('intervals_subtract_v0', intervals_subtract_v0),
    ('intervals_subtract_v1', intervals_subtract_v1),
    ('intervals_subtract_v2', intervals_subtract_v2),
]

intervals_subtract_funcs = ['intervals_subtract_v0', 'intervals_subtract_v1', 'intervals_subtract_v2']
"""