# Copyright (C) 2024, UChicago Argonne, LLC
# Licensed under the 3-clause BSD license.  See accompanying LICENSE.txt file
# in the top-level directory.

"""
Unit tests for timeline.
"""
from _decimal import Underflow, Overflow
from pprint import pformat
from typing import List

import numpy as np
import numpy.testing as npt
import pandas as pd
from dateutil.parser import parse as date_parse

from Octeres.bitmask import BASE_DTYPE_MAX
from Octeres.bitmask import eb
from Octeres.timeline import Dependency_Funcs, TLCollisionOverlap, TimelinePit
from Octeres.timeline import Event_Handler
from Octeres.timeline import (
    Timeline,
    sum_array,
    TLEvent,
    TLPointInTime,
    Event_Direction,
)
from Octeres.timeline import TimelineDict, TimelineMask
from Octeres.timeline import reduce_holes, Point_in_Time
from Octeres.util import superprint
import pytest


class Test_Timeline:
    """Tests for ``Timeline`` over 2015-01-01 .. 2015-02-15 with an 8-slot base mask."""

    @classmethod
    def setup_class(cls):
        """Create the single ``Timeline`` instance shared by every test in this class."""
        machine_name = "test"
        range_start = date_parse("2015-01-01")
        range_end = date_parse("2015-02-15")
        base_mask = eb.zeros(8)
        cls.tl = Timeline(machine_name, range_start, range_end, base_mask)

    def test_get_possible_unit_seconds(self):
        """Possible unit-seconds equal the range length in seconds times the mask width."""
        tl = self.tl
        possible_unit_seconds = tl.get_possible_unit_seconds()
        correct_unit_seconds = (tl.range_end - tl.range_start).total_seconds() * len(tl.base_mask)
        assert possible_unit_seconds == correct_unit_seconds

    # TODO: add dedicated tests for prepare_event_timeline, group_timeline,
    #       normalize_timeline, mask_timeline_to_mag_timeline and calculate_area.

    def test_search_timeline_for_holes_00(self):
        """A single interval with slot 1 cleared is reported as exactly one hole."""
        mask_timeline = TimelineMask()
        mask_timeline.append((date_parse("2017-01-01"), eb.ones(8)))
        mask_timeline.append((date_parse("2017-01-02"), eb.ones(8)))
        mask_timeline.append((date_parse("2017-01-03"), eb.array([1, 0, 1, 1, 1, 1, 1, 1])))
        mask_timeline.append((date_parse("2017-01-04"), eb.ones(8)))

        holes = self.tl.search_timeline_for_holes(mask_timeline)

        assert len(holes) == 1
        correct_ts = date_parse("2017-01-03")
        correct_te = date_parse("2017-01-04")
        correct_delta = eb.array([0, 1, 0, 0, 0, 0, 0, 0])
        # Each hole is a (start, end, delta_mask) triple.
        hole_ts, hole_te, hole_delta = holes[0]
        superprint(correct_delta)
        superprint(hole_delta)
        assert correct_ts == hole_ts
        assert correct_te == hole_te
        assert (correct_delta == hole_delta).all()

    def test_search_timeline_for_holes_01(self):
        """A slot cleared for the whole timeline yields one hole spanning first to last point."""
        mask_timeline = TimelineMask()
        mask_timeline.append((date_parse("2017-01-01"), eb.array([1, 0, 1, 1, 1, 1, 1, 1])))
        mask_timeline.append((date_parse("2017-01-02"), eb.array([1, 0, 1, 1, 1, 1, 1, 1])))
        mask_timeline.append((date_parse("2017-01-03"), eb.array([1, 0, 1, 1, 1, 1, 1, 1])))
        mask_timeline.append((date_parse("2017-01-04"), eb.array([1, 0, 1, 1, 1, 1, 1, 1])))
        mask_timeline.append((date_parse("2017-01-05"), eb.array([1, 0, 1, 1, 1, 1, 1, 1])))
        mask_timeline.append((date_parse("2017-01-06"), eb.array([1, 0, 1, 1, 1, 1, 1, 1])))

        holes = self.tl.search_timeline_for_holes(mask_timeline)

        assert len(holes) == 1
        correct_delta = eb.array([0, 1, 0, 0, 0, 0, 0, 0])
        ts, te, hole_delta = holes[0]
        assert ts == date_parse("2017-01-01")
        assert te == date_parse("2017-01-06")
        assert (correct_delta == hole_delta).all()

    def test_collision_detection_00(self):
        """Two overlapping 'job' events that share slot 1 produce exactly one collision."""
        base_mask = self.tl.base_mask
        tl = self.tl

        event_lst = list()
        event_mask = base_mask.copy()
        event_mask[0:3] = 1
        dct = dict(
            pk="an_event",
            event_type_name="job",
            bitmask=event_mask,
            time_start=date_parse("2015-02-03 00:00:00"),
            time_end=date_parse("2015-02-04 00:00:00"),
        )
        event_lst.append(dct)
        # Second event lies inside the first one in time and reuses slot 1.
        event_mask = base_mask.copy()
        event_mask[1] = 1
        dct = dict(
            pk="an_event2",
            event_type_name="job",
            bitmask=event_mask,
            time_start=date_parse("2015-02-03 01:00:00"),
            time_end=date_parse("2015-02-03 02:00:00"),
        )
        event_lst.append(dct)
        timeline_lst = tl.prepare_event_timeline(event_lst)

        timeline_dct = tl.group_timeline(timeline_lst)
        timeline_sorted = tl.sort_timeline(timeline_dct)
        collisions = tl.find_collisions_timeline(timeline_sorted)
        for collision in collisions:
            superprint(collision)
        assert len(collisions) == 1

    def test_collision_detection_01(self):
        """Same scenario as ``test_collision_detection_00`` but with ``pd.Timestamp`` times."""
        # https://pandas.pydata.org/pandas-docs/version/0.21.1/generated/pandas.Timestamp.to_datetime.html
        base_mask = self.tl.base_mask
        tl = self.tl

        event_lst = list()
        event_mask = base_mask.copy()
        event_mask[0:3] = 1
        dct: TLEvent = dict(
            pk="an_event",
            event_type_name="job",
            bitmask=event_mask,
            time_start=pd.Timestamp(date_parse("2015-02-03 00:00:00")),
            time_end=pd.Timestamp(date_parse("2015-02-04 00:00:00")),
        )
        event_lst.append(dct)
        event_mask = base_mask.copy()
        event_mask[1] = 1
        dct: TLEvent = dict(
            pk="an_event2",
            event_type_name="job",
            bitmask=event_mask,
            time_start=pd.Timestamp(date_parse("2015-02-03 01:00:00")),
            time_end=pd.Timestamp(date_parse("2015-02-03 02:00:00")),
        )
        event_lst.append(dct)
        timeline_lst: List[TLPointInTime] = tl.prepare_event_timeline(event_lst)

        timeline_dct: TimelineDict = tl.group_timeline(timeline_lst)
        timeline_sorted: TimelinePit = tl.sort_timeline(timeline_dct)
        tl.print_timeline(timeline_sorted)
        collisions: List[TLCollisionOverlap] = tl.find_collisions_timeline(timeline_sorted)
        for collision in collisions:
            superprint(collision)
        assert len(collisions) == 1

    def test_full_stack(self):
        """Run one event through the full pipeline and check the consumed area."""
        base_mask = self.tl.base_mask
        tl = self.tl

        possible_unit_seconds = tl.get_possible_unit_seconds()
        event_handler = Event_Handler()

        event_lst = list()
        event_mask = base_mask.copy()
        event_mask[0:3] = 1
        event_dct = event_handler.get_event(
            event_type_name="job",
            bitmask=event_mask,
            time_start=date_parse("2015-02-03"),
            time_end=date_parse("2015-02-04"),
        )
        event_lst.append(event_dct)
        # Expected area: event duration in seconds times the number of set slots.
        correct_unit_seconds = (event_dct["time_end"] - event_dct["time_start"]).total_seconds() * sum_array(event_mask)
        correct_possible_unit_seconds = (tl.range_end - tl.range_start).total_seconds() * len(tl.base_mask)

        timeline_lst = tl.prepare_event_timeline(event_lst)

        timeline_dct = tl.group_timeline(timeline_lst)
        timeline_sorted = tl.sort_timeline(timeline_dct)

        mask_timeline = tl.normalize_timeline_pit(timeline_sorted)
        mask_timeline = tl.sum_timeline(mask_timeline, test_negative=False)
        tl.print_timeline(mask_timeline, binary=True)
        mask_timeline = tl.flatten_timeline(mask_timeline)

        mask_timeline = tl.isolate_timeline_range(mask_timeline)

        mag_timeline = tl.mask_timeline_to_mag_timeline(mask_timeline)

        consumed_unit_seconds = tl.calculate_area(mag_timeline)

        assert round(abs(consumed_unit_seconds - correct_unit_seconds), 4) == 0
        assert possible_unit_seconds == correct_possible_unit_seconds

    def _assert_overlapping_events_depend(self, first_mask_value):
        """Check that two time-overlapping events on slot 0 are mutual dependencies.

        Event A spans 2015-02-03 00:00 .. 2015-02-04 00:00 with ``first_mask_value``
        stored in slot 0. Event B spans 2015-02-03 12:00 .. 2015-02-05 00:00 with 1
        in slot 0. Under ``dep_time`` + ``dep_space_all`` each must list the other.

        :param first_mask_value: value written into slot 0 of event A's bitmask.
        """
        tl = self.tl
        base_mask = tl.base_mask
        event_handler = Event_Handler()

        def make_event(slot0_value, time_start, time_end):
            # Fresh copy per event so the two bitmasks never alias.
            event_mask = base_mask.copy()
            event_mask[0] = slot0_value
            return event_handler.get_event(
                event_type_name="event",
                bitmask=event_mask,
                time_start=date_parse(time_start),
                time_end=date_parse(time_end),
            )

        event_a = make_event(first_mask_value, "2015-02-03 00:00:00", "2015-02-04 00:00:00")
        event_a_pk = event_a["pk"]
        # Event B overlaps event A in time (joint events).
        event_b = make_event(1, "2015-02-03 12:00:00", "2015-02-05 00:00:00")
        event_b_pk = event_b["pk"]
        event_lst = [event_a, event_b]

        dependency_functions = [Dependency_Funcs.dep_time, Dependency_Funcs.dep_space_all]
        event_dependencies = tl.find_event_dependencies(event_lst, dependency_functions)
        assert len(event_dependencies) == 2
        grouped_dependencies = tl.merge_event_dependencies(event_dependencies)
        assert len(grouped_dependencies) == 2
        assert event_a_pk in grouped_dependencies
        assert event_b_pk in grouped_dependencies[event_a_pk]
        assert event_b_pk in grouped_dependencies
        assert event_a_pk in grouped_dependencies[event_b_pk]

    def test_find_event_dependencies_00(self):
        """Overlapping events both holding 1 in slot 0 depend on each other."""
        self._assert_overlapping_events_depend(1)

    def test_find_event_dependencies_01(self):
        """A slot-0 value of 2 in the first event still yields a mutual dependency."""
        self._assert_overlapping_events_depend(2)

    def test_find_event_dependencies_02(self):
        """A slot-0 value of -1 in the first event still yields a mutual dependency."""
        self._assert_overlapping_events_depend(-1)

    def test_find_event_dependencies_03(self):
        """Disjoint events A and B are linked only through C, which overlaps both in time.

        With ``dep_time`` + ``dep_space_all`` there are no dependencies (C's mask
        differs from A's and B's). With ``dep_time`` alone C depends on A and B.
        """
        base_mask = self.tl.base_mask
        tl = self.tl

        event_handler = Event_Handler()

        event_lst = list()

        event_mask = base_mask.copy()
        event_mask[0] = 1
        event_dct = event_handler.get_event(
            event_type_name="event",
            bitmask=event_mask,
            time_start=date_parse("2015-02-03 00:00:00"),
            time_end=date_parse("2015-02-04 00:00:00"),
        )
        event_a_pk = event_dct["pk"]
        event_lst.append(event_dct)

        # Event B is disjoint from event A in time.
        event_mask = base_mask.copy()
        event_mask[0] = 1
        event_dct = event_handler.get_event(
            event_type_name="event",
            bitmask=event_mask,
            time_start=date_parse("2015-02-05 00:00:00"),
            time_end=date_parse("2015-02-06 00:00:00"),
        )
        event_b_pk = event_dct["pk"]
        event_lst.append(event_dct)

        # Event C overlaps both A and B in time but uses a larger mask.
        event_mask = base_mask.copy()
        event_mask[0:2] = 1
        event_dct = event_handler.get_event(
            event_type_name="event",
            bitmask=event_mask,
            time_start=date_parse("2015-02-03 12:00:00"),
            time_end=date_parse("2015-02-05 12:00:00"),
        )
        event_c_pk = event_dct["pk"]
        event_lst.append(event_dct)

        # Events carry time_start, time_end, bitmask and event_type_name.
        dependency_functions = [Dependency_Funcs.dep_time, Dependency_Funcs.dep_space_all]
        event_dependencies = tl.find_event_dependencies(event_lst, dependency_functions)
        assert len(event_dependencies) == 0

        dependency_functions = [Dependency_Funcs.dep_time]
        event_dependencies = tl.find_event_dependencies(event_lst, dependency_functions)
        superprint(pformat(event_dependencies))
        assert len(event_dependencies) == 4

        # Name the generated graph after this test (was a copy-pasted "_01").
        tl.generate_dependency_graph(
            event_handler.nodes,
            event_dependencies,
            filename="test_find_event_dependencies_03",
        )

        grouped_dependencies = tl.merge_event_dependencies(event_dependencies)
        assert len(grouped_dependencies) == 3
        superprint(pformat(grouped_dependencies))

        assert event_a_pk in grouped_dependencies
        assert event_b_pk in grouped_dependencies
        assert event_c_pk in grouped_dependencies

def test_reduce_holes_00():
    """Ordered, abutting holes: consecutive holes with equal masks are merged."""
    timeline_holes = [
        (1, 2, np.array([1, 0, 1, 1])),
        (2, 3, np.array([1, 0, 1, 1])),
        (3, 4, np.array([1, 0, 0, 1])),
        (4, 5, np.array([1, 1, 1, 1])),
        (5, 6, np.array([1, 1, 1, 1])),
        (6, 7, np.array([1, 1, 1, 1])),
    ]

    correct_holes = [
        (1, 3, np.array([1, 0, 1, 1])),
        (3, 4, np.array([1, 0, 0, 1])),
        (4, 7, np.array([1, 1, 1, 1])),
    ]

    reductions, timeline_holes = reduce_holes(timeline_holes)
    superprint("\n", pformat(timeline_holes))
    # Check the count first; otherwise missing or extra holes would go unnoticed.
    assert len(timeline_holes) == len(correct_holes)
    for (ts, te, mask), (cts, cte, cmask) in zip(timeline_holes, correct_holes):
        assert ts == cts
        assert te == cte
        npt.assert_equal(mask, cmask)
    assert reductions == 3


def test_reduce_holes_01():
    """Ordered, abutting holes with a leading full-mask hole that must not merge."""
    timeline_holes = [
        (0, 1, np.array([1, 1, 1, 1])),
        (1, 2, np.array([1, 0, 1, 1])),
        (2, 3, np.array([1, 0, 1, 1])),
        (3, 4, np.array([1, 0, 0, 1])),
        (4, 5, np.array([1, 1, 1, 1])),
        (5, 6, np.array([1, 1, 1, 1])),
        (6, 7, np.array([1, 1, 1, 1])),
    ]

    correct_holes = [
        (0, 1, np.array([1, 1, 1, 1])),
        (1, 3, np.array([1, 0, 1, 1])),
        (3, 4, np.array([1, 0, 0, 1])),
        (4, 7, np.array([1, 1, 1, 1])),
    ]

    reductions, timeline_holes = reduce_holes(timeline_holes)
    superprint("\n", pformat(timeline_holes))
    # Check the count first; otherwise missing or extra holes would go unnoticed.
    assert len(timeline_holes) == len(correct_holes)
    for (ts, te, mask), (cts, cte, cmask) in zip(timeline_holes, correct_holes):
        assert ts == cts
        assert te == cte
        npt.assert_equal(mask, cmask)
    assert reductions == 3


def test_reduce_holes_02():
    """Ordered holes where no abutting pair shares a mask, so nothing is merged."""
    timeline_holes = [
        (0, 1, np.array([1, 1, 1, 1])),
        (1, 2, np.array([1, 0, 1, 1])),
        (3, 4, np.array([1, 0, 0, 1])),
        (4, 5, np.array([1, 1, 1, 1])),
        (6, 7, np.array([1, 1, 1, 1])),
    ]

    correct_holes = [
        (0, 1, np.array([1, 1, 1, 1])),
        (1, 2, np.array([1, 0, 1, 1])),
        (3, 4, np.array([1, 0, 0, 1])),
        (4, 5, np.array([1, 1, 1, 1])),
        (6, 7, np.array([1, 1, 1, 1])),
    ]

    reductions, timeline_holes = reduce_holes(timeline_holes)
    superprint("\n", pformat(timeline_holes))
    # Check the count first; otherwise missing or extra holes would go unnoticed.
    assert len(timeline_holes) == len(correct_holes)
    for (ts, te, mask), (cts, cte, cmask) in zip(timeline_holes, correct_holes):
        assert ts == cts
        assert te == cte
        npt.assert_equal(mask, cmask)
    assert reductions == 0


def test_reduce_holes_03():
    """Mixed gaps and abutting holes: only the trailing equal-mask pair (4-6, 6-7) merges."""
    timeline_holes = [
        (0, 1, np.array([1, 1, 1, 1])),
        (1, 2, np.array([1, 0, 1, 1])),
        (3, 4, np.array([1, 0, 0, 1])),
        (4, 6, np.array([1, 1, 1, 1])),
        (6, 7, np.array([1, 1, 1, 1])),
    ]

    correct_holes = [
        (0, 1, np.array([1, 1, 1, 1])),
        (1, 2, np.array([1, 0, 1, 1])),
        (3, 4, np.array([1, 0, 0, 1])),
        (4, 7, np.array([1, 1, 1, 1])),
    ]

    reductions, timeline_holes = reduce_holes(timeline_holes)
    superprint("\n", pformat(timeline_holes))
    # Check the count first; otherwise missing or extra holes would go unnoticed.
    assert len(timeline_holes) == len(correct_holes)
    for (ts, te, mask), (cts, cte, cmask) in zip(timeline_holes, correct_holes):
        assert ts == cts
        assert te == cte
        npt.assert_equal(mask, cmask)
    assert reductions == 1


def test_tlevent_00():
    """A plain ``dict`` literal annotated as ``TLEvent`` is exactly a ``dict`` at runtime."""
    event: TLEvent = dict(
        pk="1",
        event_type_name="fish",
        bitmask=eb.zeros(10),
        time_start=date_parse("2020-01-01"),
        time_end=date_parse("2020-01-02"),
    )
    superprint(type(event), event)
    # Exact-type identity check (not isinstance): the literal must be a bare dict.
    assert type(event) is dict


def test_tlevent_01():
    """Constructing a ``TLEvent`` directly keeps every supplied field."""
    time_start = date_parse("2020-01-01")
    time_end = date_parse("2020-01-02")
    event = TLEvent(
        pk="1",
        event_type_name="fish",
        bitmask=eb.zeros(10),
        time_start=time_start,
        time_end=time_end,
    )
    superprint(type(event), event)
    # NOTE(review): assumes TLEvent is a TypedDict (subscript access), as the
    # dict-literal annotations elsewhere in this module suggest.
    assert event["pk"] == "1"
    assert event["event_type_name"] == "fish"
    assert len(event["bitmask"]) == 10
    assert event["time_start"] == time_start
    assert event["time_end"] == time_end


def test_tlevent_02():
    """``TLEvent`` accepts ``pd.Timestamp`` start/end times and stores them unchanged."""
    time_start = pd.Timestamp(date_parse("2020-01-01"))
    time_end = pd.Timestamp(date_parse("2020-01-02"))
    event = TLEvent(
        pk="1",
        event_type_name="fish",
        bitmask=eb.zeros(10),
        time_start=time_start,
        time_end=time_end,
    )
    superprint(type(event), event)
    # NOTE(review): assumes TLEvent is a TypedDict (subscript access).
    assert event["pk"] == "1"
    assert event["event_type_name"] == "fish"
    assert event["time_start"] == time_start
    assert event["time_end"] == time_end


def test_tlpit_00():
    """Constructing a positive ``TLPointInTime`` keeps every supplied field."""
    ts = date_parse("2020-01-01")
    pit = TLPointInTime(
        pk="1",
        name="2",
        category="aaaa",
        bitmask=eb.zeros(10),
        ts=ts,
        direction=Event_Direction.POSITIVE,
    )
    superprint(pit)
    # NOTE(review): assumes TLPointInTime is a TypedDict (subscript access).
    assert pit["pk"] == "1"
    assert pit["name"] == "2"
    assert pit["category"] == "aaaa"
    assert pit["ts"] == ts
    assert pit["direction"] == Event_Direction.POSITIVE


def test_tlpit_01():
    """``TLPointInTime`` keeps its direction for both POSITIVE and NEGATIVE events."""
    ts = date_parse("2020-01-01")
    # Previously an exact duplicate of test_tlpit_00; now also covers NEGATIVE.
    for direction in (Event_Direction.POSITIVE, Event_Direction.NEGATIVE):
        pit = TLPointInTime(
            pk="1",
            name="2",
            category="aaaa",
            bitmask=eb.zeros(10),
            ts=ts,
            direction=direction,
        )
        superprint(pit)
        # NOTE(review): assumes TLPointInTime is a TypedDict (subscript access).
        assert pit["pk"] == "1"
        assert pit["ts"] == ts
        assert pit["direction"] == direction


def test_tlpit_02():
    """A plain dict (with a stringified direction) can stand in for ``TLPointInTime``."""
    # noinspection PyTypeChecker
    pit: TLPointInTime = dict(
        pk="1",
        name="2",
        category="aaaa",
        bitmask=eb.zeros(10),
        ts=date_parse("2020-01-01"),
        direction=str(Event_Direction.POSITIVE),  # type: ignore
    )
    superprint(pit)
    assert isinstance(pit, dict)
    assert pit["pk"] == "1"
    assert pit["direction"] == str(Event_Direction.POSITIVE)


def test_get_mask_sum_00():
    """Four positive masks sum slot-wise into per-slot occupancy counts."""
    pit = Point_in_Time(eb.zeros(8))

    positive_masks = (
        eb.array([0, 0, 0, 1, 1, 0, 0, 0]),
        eb.array([0, 0, 0, 1, 1, 0, 1, 0]),
        eb.array([0, 1, 1, 1, 1, 0, 0, 0]),
        eb.array([1, 0, 0, 1, 1, 0, 0, 0]),
    )
    for position, mask in enumerate(positive_masks):
        label = str(position)
        point: TLPointInTime = TLPointInTime(
            pk=label,
            name=label,
            category="a",
            bitmask=mask,
            ts=date_parse("2020-01-01"),
            direction=Event_Direction.POSITIVE,
        )
        pit.positive_add(point)

    expected = eb.array([1, 1, 1, 4, 4, 0, 1, 0])
    npt.assert_array_equal(pit.get_mask_sum(), expected)


def test_get_mask_sum_01():
    """A negative mask subtracted after four positives lowers only its own slots."""
    pit = Point_in_Time(eb.zeros(8))

    def feed(masks, direction, add):
        # Wrap each mask in a TLPointInTime and hand it to the given add method.
        for position, mask in enumerate(masks):
            label = str(position)
            point: TLPointInTime = TLPointInTime(
                pk=label,
                name=label,
                category="a",
                bitmask=mask,
                ts=date_parse("2020-01-01"),
                direction=direction,
            )
            add(point)

    feed(
        [
            eb.array([0, 0, 0, 1, 1, 0, 0, 0]),
            eb.array([0, 0, 0, 1, 1, 0, 1, 0]),
            eb.array([0, 1, 1, 1, 1, 0, 0, 0]),
            eb.array([1, 0, 0, 1, 1, 0, 0, 0]),
        ],
        Event_Direction.POSITIVE,
        pit.positive_add,
    )
    npt.assert_array_equal(pit.get_mask_sum(), eb.array([1, 1, 1, 4, 4, 0, 1, 0]))

    feed(
        [eb.array([0, 0, 1, 1, 0, 0, 0, 0])],
        Event_Direction.NEGATIVE,
        pit.negative_add,
    )
    npt.assert_array_equal(pit.get_mask_sum(), eb.array([1, 1, 0, 3, 4, 0, 1, 0]))


def test_get_mask_sum_02():
    """Two negative masks after four positives; slot 6 returns to zero."""
    pit = Point_in_Time(eb.zeros(8))

    def feed(masks, direction, add):
        # Wrap each mask in a TLPointInTime and hand it to the given add method.
        for position, mask in enumerate(masks):
            label = str(position)
            point: TLPointInTime = TLPointInTime(
                pk=label,
                name=label,
                category="a",
                bitmask=mask,
                ts=date_parse("2020-01-01"),
                direction=direction,
            )
            add(point)

    feed(
        [
            eb.array([0, 0, 0, 1, 1, 0, 0, 0]),
            eb.array([0, 0, 0, 1, 1, 0, 1, 0]),
            eb.array([0, 1, 1, 1, 1, 0, 0, 0]),
            eb.array([1, 0, 0, 1, 1, 0, 0, 0]),
        ],
        Event_Direction.POSITIVE,
        pit.positive_add,
    )
    npt.assert_array_equal(pit.get_mask_sum(), eb.array([1, 1, 1, 4, 4, 0, 1, 0]))

    feed(
        [
            eb.array([0, 0, 1, 1, 0, 0, 0, 0]),
            eb.array([0, 0, 0, 0, 0, 0, 1, 0]),
        ],
        Event_Direction.NEGATIVE,
        pit.negative_add,
    )
    npt.assert_array_equal(pit.get_mask_sum(), eb.array([1, 1, 0, 3, 4, 0, 0, 0]))


def test_get_mask_sum_03():
    """With only negative masks, get_mask_sum yields -1 in every occupied slot."""
    pit = Point_in_Time(eb.zeros(8))

    negative_masks = (
        eb.array([0, 0, 1, 1, 0, 0, 0, 0]),
        eb.array([0, 0, 0, 0, 0, 0, 1, 0]),
    )
    for position, mask in enumerate(negative_masks):
        label = str(position)
        point: TLPointInTime = TLPointInTime(
            pk=label,
            name=label,
            category="a",
            bitmask=mask,
            ts=date_parse("2020-01-01"),
            direction=Event_Direction.NEGATIVE,
        )
        pit.negative_add(point)

    npt.assert_array_equal(pit.get_mask_sum(), eb.array([0, 0, -1, -1, 0, 0, -1, 0]))


def test_get_mask_sum_03b():
    """Many negative events whose masks hold -1 make get_mask_sum raise ``Underflow``."""
    base_mask = eb.zeros(8)
    pit = Point_in_Time(base_mask)
    # BASE_DTYPE_MAX * 2 events are presumably enough to push the accumulated
    # value outside the bitmask dtype's range (the original note cited the
    # int8 limits 127 / -128) — TODO confirm against Octeres.bitmask.
    for idx in range(BASE_DTYPE_MAX * 2):
        bitmask = eb.array([0, 0, -1, -1, 0, 0, 0, 0])
        dct: TLPointInTime = TLPointInTime(
            pk=str(idx),
            name=str(idx),
            category="a",
            bitmask=bitmask,
            ts=date_parse("2020-01-01"),
            direction=Event_Direction.NEGATIVE,
        )
        pit.negative_add(dct)
    with pytest.raises(Underflow):
        pit.get_mask_sum()


def test_get_mask_sum_03c():
    """Many negative events whose masks hold +1 make get_mask_sum raise ``Overflow``."""
    # Overflow case (the previous comment said "underflow", contradicting the
    # expected exception below).
    base_mask = eb.zeros(8)
    pit = Point_in_Time(base_mask)
    # BASE_DTYPE_MAX * 2 events are presumably enough to push the accumulated
    # value outside the bitmask dtype's range — TODO confirm against Octeres.bitmask.
    for idx in range(BASE_DTYPE_MAX * 2):
        bitmask = eb.array([0, 0, 1, 1, 0, 0, 0, 0])
        dct: TLPointInTime = TLPointInTime(
            pk=str(idx),
            name=str(idx),
            category="a",
            bitmask=bitmask,
            ts=date_parse("2020-01-01"),
            direction=Event_Direction.NEGATIVE,
        )
        pit.negative_add(dct)
    with pytest.raises(Overflow):
        pit.get_mask_sum()


# def test_get_mask_sum_04():
#     pit = Point_in_Time()
#
#     bitmasks = [
#         eb.array([0, 0, 0, 1, 1, 0, 0, 0]),
#         eb.array([0, 0, 0, 1, 1, 0, 1, 0]),
#         eb.array([0, 1, 1, 1, 1, 0, 0, 0]),
#         eb.array([1, 0, 0, 1, 1, 0, 0, 0]),
#     ]
#     for idx, bitmask in enumerate(bitmasks):
#         dct: TLPointInTime = TLPointInTime(
#             pk=str(idx),
#             name=str(idx),
#             category="a",
#             bitmask=bitmask,
#             ts=date_parse('2020-01-01'),
#             direction=Event_Direction.POSITIVE,
#         )
#         pit.positive_add(dct)
#     value = pit.get_mask_sum()
#     npt.assert_array_equal(value, eb.array([1, 1, 1, 4, 4, 0, 1, 0]))
#
# def test_get_mask_sum_05():
#     pit = Point_in_Time()
#
#     bitmasks = [
#         eb.array([0, 0, 0, 1, 1, 0, 0, 0]),
#         eb.array([0, 0, 0, 1, 1, 0, 1, 0]),
#         eb.array([0, 1, 1, 1, 1, 0, 0, 0]),
#         eb.array([1, 0, 0, 1, 1, 0, 0, 0]),
#     ]
#     for idx, bitmask in enumerate(bitmasks):
#         dct: TLPointInTime = TLPointInTime(
#             pk=str(idx),
#             name=str(idx),
#             category="a",
#             bitmask=bitmask,
#             ts=date_parse('2020-01-01'),
#             direction=Event_Direction.POSITIVE,
#         )
#         pit.positive_add(dct)
#     value = pit.get_mask_sum()
#     npt.assert_array_equal(value, eb.array([1, 1, 1, 4, 4, 0, 1, 0]))
#
#     bitmasks = [
#         eb.array([0, 0, 1, 1, 0, 0, 0, 0]),
#     ]
#     for idx, bitmask in enumerate(bitmasks):
#         dct: TLPointInTime = TLPointInTime(
#             pk=str(idx),
#             name=str(idx),
#             category="a",
#             bitmask=bitmask,
#             ts=date_parse('2020-01-01'),
#             direction=Event_Direction.NEGATIVE,
#         )
#         pit.negative_add(dct)
#     value = pit.get_mask_sum()
#     npt.assert_array_equal(value, eb.array([1, 1, 0, 3, 4, 0, 1, 0]))
#
# def test_get_mask_sum_06():
#     pit = Point_in_Time()
#
#     bitmasks = [
#         eb.array([0, 0, 0, 1, 1, 0, 0, 0]),
#         eb.array([0, 0, 0, 1, 1, 0, 1, 0]),
#         eb.array([0, 1, 1, 1, 1, 0, 0, 0]),
#         eb.array([1, 0, 0, 1, 1, 0, 0, 0]),
#     ]
#     for idx, bitmask in enumerate(bitmasks):
#         dct: TLPointInTime = TLPointInTime(
#             pk=str(idx),
#             name=str(idx),
#             category="a",
#             bitmask=bitmask,
#             ts=date_parse('2020-01-01'),
#             direction=Event_Direction.POSITIVE,
#         )
#         pit.positive_add(dct)
#     value = pit.get_mask_sum()
#     npt.assert_array_equal(value, eb.array([1, 1, 1, 4, 4, 0, 1, 0]))
#
#     bitmasks = [
#         eb.array([0, 0, 1, 1, 0, 0, 0, 0]),
#         eb.array([0, 0, 0, 0, 0, 0, 1, 0]),
#     ]
#     for idx, bitmask in enumerate(bitmasks):
#         dct: TLPointInTime = TLPointInTime(
#             pk=str(idx),
#             name=str(idx),
#             category="a",
#             bitmask=bitmask,
#             ts=date_parse('2020-01-01'),
#             direction=Event_Direction.NEGATIVE,
#         )
#         pit.negative_add(dct)
#     value = pit.get_mask_sum()
#     npt.assert_array_equal(value, eb.array([1, 1, 0, 3, 4, 0, 0, 0]))
#
# def test_get_mask_sum_07():
#     pit = Point_in_Time()
#
#     bitmasks = [
#         eb.array([0, 0, 1, 1, 0, 0, 0, 0]),
#         eb.array([0, 0, 0, 0, 0, 0, 1, 0]),
#     ]
#     for idx, bitmask in enumerate(bitmasks):
#         dct: TLPointInTime = TLPointInTime(
#             pk=str(idx),
#             name=str(idx),
#             category="a",
#             bitmask=bitmask,
#             ts=date_parse('2020-01-01'),
#             direction=Event_Direction.NEGATIVE,
#         )
#         pit.negative_add(dct)
#     value = pit.get_mask_sum()
#     # underflow!!
#     npt.assert_array_equal(value, eb.array([0, 0, 255, 255, 0, 0, 255, 0]))