Skip to content
Snippets Groups Projects
BitmaskLiteTest.scala 5.08 KiB
Newer Older
/*
 * Copyright (C) 2024, UChicago Argonne, LLC
 * Licensed under the 3-clause BSD license.  See accompanying LICENSE.txt file
 * in the top-level directory.
 */

package Octeres.Bitmask

import Octeres.Bitmask.BitmaskLite.{LiteBitmaskStruct, getLiteBitmaskZeros, IntervalIterator, validateLiteBitmaskSlotsLike,
    intersection_any_v0, intersection_any_v1, logical_or, intersection_any}
import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.prop.TableDrivenPropertyChecks._
import org.scalatest.matchers.should.Matchers._
import org.scalatest.prop.TableFor3

import scala.collection.immutable.ArraySeq

/**
 * Table-driven tests for the `BitmaskLite` helpers `logical_or`, `IntervalIterator`
 * and the `intersection_any` family.
 *
 * NOTE(review): the `LiteBitmaskStruct` literals below look like
 * (length, bit count, first set bit, last set bit, inclusive intervals) -- confirm
 * against the definition in `BitmaskLite`.
 */
class BitmaskLiteTest extends AnyFunSuite {

    /**
     * Checks that `logical_or(bitmask0, bitmask1)` equals the expected struct for each table row.
     * Every fixture (inputs and expected result) is first passed through
     * `validateLiteBitmaskSlotsLike` (presumably it rejects malformed structs -- confirm).
     */
    test("test_bitmask_logical_or") {
        val bitmaskTable: TableFor3[LiteBitmaskStruct, LiteBitmaskStruct, LiteBitmaskStruct] = Table(
            ("bitmask0", "bitmask1", "bitmask_or"),
            (
                LiteBitmaskStruct(8, 3, 0, 2, ArraySeq(ArraySeq[Long](0, 2))),
                LiteBitmaskStruct(8, 3, 1, 3, ArraySeq(ArraySeq[Long](1, 3))),
                LiteBitmaskStruct(8, 4, 0, 3, ArraySeq(ArraySeq[Long](0, 3)))
            ), (
                LiteBitmaskStruct(8, 5, 0, 4, ArraySeq(ArraySeq[Long](0, 4))),
                getLiteBitmaskZeros(8),
                LiteBitmaskStruct(8, 5, 0, 4, ArraySeq(ArraySeq[Long](0, 4)))
            ), (
                getLiteBitmaskZeros(8),
                getLiteBitmaskZeros(8),
                getLiteBitmaskZeros(8),
            ), (
                LiteBitmaskStruct(8, 4, 0, 4, ArraySeq(ArraySeq[Long](0, 1), ArraySeq[Long](3, 4))),
                LiteBitmaskStruct(8, 2, 1, 2, ArraySeq(ArraySeq[Long](1, 2))),
                LiteBitmaskStruct(8, 5, 0, 4, ArraySeq(ArraySeq[Long](0, 4)))
            ), (
                LiteBitmaskStruct(16, 4, 1, 6, ArraySeq(ArraySeq[Long](1, 2), ArraySeq[Long](5, 6))),
                LiteBitmaskStruct(16, 4, 3, 8, ArraySeq(ArraySeq[Long](3, 4), ArraySeq[Long](7, 8))),
                LiteBitmaskStruct(16, 8, 1, 8, ArraySeq(ArraySeq[Long](1, 8)))
            )
        )

        forAll(bitmaskTable) { (bitmask0, bitmask1, bitmask_or) =>
            println(s"bitmask0: ${bitmask0}")
            println(s"bitmask1: ${bitmask1}")
            println(s"bitmask_or: ${bitmask_or}")
            validateLiteBitmaskSlotsLike(bitmask0)
            validateLiteBitmaskSlotsLike(bitmask1)
            validateLiteBitmaskSlotsLike(bitmask_or)
            val bitmask_r = logical_or(bitmask0, bitmask1)
            bitmask_r should equal(bitmask_or)
            // Informational only: the result is printed, not asserted here.
            // Expected values for these same pairs are asserted in test_bitmask_intersection_any.
            println(intersection_any(bitmask0, bitmask1))
        }
    }

    /**
     * Smoke test for `IntervalIterator`: builds an iterator per table row, prints its state,
     * then drains it with `get()` until `isEmpty` reports true, printing each value.
     *
     * NOTE(review): this test contains no assertions; it only fails if the iterator throws
     * (and would hang if `isEmpty` never becomes true). TODO: assert the exact sequence
     * produced for each row once the intended semantics of `bit_idx_first` / `bit_idx_last`
     * (e.g. the `-1` in the first row) are confirmed.
     */
    test("test_getIntervalBitsGenerator") {
        val bitmaskTable: TableFor3[ArraySeq[ArraySeq[Long]], Long, Long] = Table(
            ("intervals", "bit_idx_first", "bit_idx_last"),
            (ArraySeq(ArraySeq[Long](1, 2), ArraySeq[Long](4, 5)), 0, -1),
            (ArraySeq(ArraySeq[Long](2, 4), ArraySeq[Long](8, 10)), 4, 9),
            (ArraySeq(ArraySeq[Long](1, 2), ArraySeq[Long](5, 6)), 3, 6),
            (ArraySeq(ArraySeq[Long](3, 4), ArraySeq[Long](7, 8)), 3, 6)
        )

        forAll(bitmaskTable) { (intervals, bit_idx_first, bit_idx_last) =>
            println(intervals)
            val ii = new IntervalIterator(intervals, bit_idx_first, bit_idx_last)
            ii.printState()
            while (!ii.isEmpty) {
                println(s"i: ${ii.get()}")
            }
        }
    }

    /**
     * Checks `intersection_any` and both of its variants (`intersection_any_v0`,
     * `intersection_any_v1`) against the expected truth value of each table row.
     */
    test("test_bitmask_intersection_any") {
        val bitmaskTable: TableFor3[LiteBitmaskStruct, LiteBitmaskStruct, Boolean] = Table(
            ("bitmask0", "bitmask1", "logical_truth"),
            (
                LiteBitmaskStruct(8, 3, 0, 2, ArraySeq(ArraySeq[Long](0, 2))),
                LiteBitmaskStruct(8, 3, 1, 3, ArraySeq(ArraySeq[Long](1, 3))),
                true
            ), (
                LiteBitmaskStruct(8, 5, 0, 4, ArraySeq(ArraySeq[Long](0, 4))),
                getLiteBitmaskZeros(8),
                false
            ), (
                getLiteBitmaskZeros(8),
                getLiteBitmaskZeros(8),
                false
            ), (
                LiteBitmaskStruct(8, 4, 0, 4, ArraySeq(ArraySeq[Long](0, 1), ArraySeq[Long](3, 4))),
                LiteBitmaskStruct(8, 2, 1, 2, ArraySeq(ArraySeq[Long](1, 2))),
                true
            ), (
                LiteBitmaskStruct(16, 4, 1, 6, ArraySeq(ArraySeq[Long](1, 2), ArraySeq[Long](5, 6))),
                LiteBitmaskStruct(16, 4, 3, 8, ArraySeq(ArraySeq[Long](3, 4), ArraySeq[Long](7, 8))),
                false
            )
        )

        forAll(bitmaskTable) { (bitmask0, bitmask1, logical_truth) =>
            println(s"bitmask0: ${bitmask0}")
            println(s"bitmask1: ${bitmask1}")
            validateLiteBitmaskSlotsLike(bitmask0)
            validateLiteBitmaskSlotsLike(bitmask1)
            val truth = intersection_any(bitmask0, bitmask1)
            println(s"truth: ${truth}")
            truth should equal(logical_truth)
            // Check each variant against the expected value rather than only against each
            // other: two variants that agree on the same wrong answer must not pass.
            withClue("intersection_any_v0: ") {
                intersection_any_v0(bitmask0, bitmask1) should equal(logical_truth)
            }
            withClue("intersection_any_v1: ") {
                intersection_any_v1(bitmask0, bitmask1) should equal(logical_truth)
            }
        }
    }

}