Skip to content
Snippets Groups Projects
BitmaskLiteTest.scala 5.99 KiB
Newer Older
/*
 * Copyright (C) 2024, UChicago Argonne, LLC
 * Licensed under the 3-clause BSD license.  See accompanying LICENSE.txt file
 * in the top-level directory.
 */

package Octeres.Bitmask

import Octeres.Bitmask.BitmaskLite.{LiteBitmaskStruct, getLiteBitmaskZeros, validateLiteBitmaskSlotsLike,
    intersection_any_v0, logical_or, intersection_any
//    , IntervalIterator, intersection_any_v1
}
import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.prop.TableDrivenPropertyChecks._
import org.scalatest.matchers.should.Matchers._
import org.scalatest.prop.TableFor3

import scala.collection.immutable.ArraySeq

class BitmaskLiteTest extends AnyFunSuite{
    // Table-driven check of logical_or: every row supplies two operand bitmasks and the
    // bitmask that OR-ing them (first operand, then second) is expected to produce.
    test("test_bitmask_logical_or") {
        // Builds one two-element entry of a bitmask's interval list.
        def interval(first: Long, last: Long): ArraySeq[Long] = ArraySeq[Long](first, last)

        val orCases: TableFor3[LiteBitmaskStruct, LiteBitmaskStruct, LiteBitmaskStruct] = Table(
            ("bitmask0", "bitmask1", "bitmask_or"),
            // Two overlapping single-interval masks; one merged interval expected.
            (
                LiteBitmaskStruct(8, 3, 0, 2, ArraySeq(interval(0, 2))),
                LiteBitmaskStruct(8, 3, 1, 3, ArraySeq(interval(1, 3))),
                LiteBitmaskStruct(8, 4, 0, 3, ArraySeq(interval(0, 3)))
            ),
            // OR with an all-zeros mask; the non-empty operand is expected back.
            (
                LiteBitmaskStruct(8, 5, 0, 4, ArraySeq(interval(0, 4))),
                getLiteBitmaskZeros(8),
                LiteBitmaskStruct(8, 5, 0, 4, ArraySeq(interval(0, 4)))
            ),
            // Both operands all-zeros; all-zeros expected.
            (
                getLiteBitmaskZeros(8),
                getLiteBitmaskZeros(8),
                getLiteBitmaskZeros(8)
            ),
            // The second operand covers the gap between the first operand's two intervals.
            (
                LiteBitmaskStruct(8, 4, 0, 4, ArraySeq(interval(0, 1), interval(3, 4))),
                LiteBitmaskStruct(8, 2, 1, 2, ArraySeq(interval(1, 2))),
                LiteBitmaskStruct(8, 5, 0, 4, ArraySeq(interval(0, 4)))
            ),
            // Interleaved, non-overlapping intervals; a single interval 1..8 expected.
            (
                LiteBitmaskStruct(16, 4, 1, 6, ArraySeq(interval(1, 2), interval(5, 6))),
                LiteBitmaskStruct(16, 4, 3, 8, ArraySeq(interval(3, 4), interval(7, 8))),
                LiteBitmaskStruct(16, 8, 1, 8, ArraySeq(interval(1, 8)))
            )
        )

        forAll(orCases) { (lhs, rhs, expected) =>
            // Echo the row, then validate each struct, in the order: operands first, expectation last.
            Seq("bitmask0" -> lhs, "bitmask1" -> rhs, "bitmask_or" -> expected).foreach {
                case (label, mask) => println(s"$label: $mask")
            }
            Seq(lhs, rhs, expected).foreach(mask => validateLiteBitmaskSlotsLike(mask))

            logical_or(lhs, rhs) should equal(expected)

            // Printed for inspection only; the intersection result is not asserted in this test.
            println(intersection_any(lhs, rhs))
        }
    }

//    test("test_getIntervalBitsGenerator") {
//        val bitmaskTable: TableFor3[ArraySeq[ArraySeq[Long]], Long, Long] = Table(
//            ("intervals", "bit_idx_first", "bit_idx_last"),
//            (ArraySeq(ArraySeq[Long](1, 2), ArraySeq[Long](4, 5)), 0, -1),
//            (ArraySeq(ArraySeq[Long](2, 4), ArraySeq[Long](8, 10)), 4, 9),
//            (ArraySeq(ArraySeq[Long](1, 2), ArraySeq[Long](5, 6)), 3, 6),
//            (ArraySeq(ArraySeq[Long](3, 4), ArraySeq[Long](7, 8)), 3, 6)
//        )
//
//        forAll(bitmaskTable) { (intervals, bit_idx_first, bit_idx_last) => {
//            {
//                println(intervals)
//                val ii = new IntervalIterator("G", intervals, bit_idx_first, bit_idx_last)
//                ii.printState()
//                while (!ii.isEmpty) {
//                    println(s"i: ${ii.get()}")
//                }
//            }
//        }}
//    }

    test("test_bitmask_intersection_any") {
        val bitmaskTable: TableFor3[LiteBitmaskStruct, LiteBitmaskStruct, Boolean] = Table(
            ("bitmask0", "bitmask1", "logical_truth"),
            (
                LiteBitmaskStruct(8, 3, 0, 2, ArraySeq(ArraySeq[Long](0, 2))),
                LiteBitmaskStruct(8, 3, 1, 3, ArraySeq(ArraySeq[Long](1, 3))),
                true
            ), (
                LiteBitmaskStruct(8, 5, 0, 4, ArraySeq(ArraySeq[Long](0, 4))),
                getLiteBitmaskZeros(8),
                false
            ), (
                getLiteBitmaskZeros(8),
                getLiteBitmaskZeros(8),
                false
            ), (
                LiteBitmaskStruct(8, 4, 0, 4, ArraySeq(ArraySeq[Long](0, 1), ArraySeq[Long](3, 4))),
                LiteBitmaskStruct(8, 2, 1, 2, ArraySeq(ArraySeq[Long](1, 2))),
                true
            ), (
                LiteBitmaskStruct(16, 4, 1, 6, ArraySeq(ArraySeq[Long](1, 2), ArraySeq[Long](5, 6))),
                LiteBitmaskStruct(16, 4, 3, 8, ArraySeq(ArraySeq[Long](3, 4), ArraySeq[Long](7, 8))),
                false
            ), (
                LiteBitmaskStruct(128, 7, 11, 106, ArraySeq(ArraySeq[Long](11, 11), ArraySeq[Long](13, 13), ArraySeq[Long](62, 63), ArraySeq[Long](104, 106))),
                LiteBitmaskStruct(128, 64, 0, 126, ArraySeq.range(0, 128, 2).map(i => ArraySeq(i.toLong, i.toLong))),
                true
            ), (
                LiteBitmaskStruct(4392, 10, 2911, 3006, ArraySeq(ArraySeq[Long](2911, 2911), ArraySeq[Long](2913, 2913), ArraySeq[Long](2962, 2963), ArraySeq[Long](2965, 2965), ArraySeq[Long](2970, 2971), ArraySeq[Long](3004, 3006))),
                LiteBitmaskStruct(4392, 2196, 0, 4390, ArraySeq.range(0, 4392, 2).map(i => ArraySeq(i.toLong, i.toLong))),
                true
            )
        )

        forAll(bitmaskTable) { (bitmask0, bitmask1, logical_truth: Boolean) => {
            println(s"bitmask0: ${bitmask0}")
            println(s"bitmask1: ${bitmask1}")
            validateLiteBitmaskSlotsLike(bitmask0)
            validateLiteBitmaskSlotsLike(bitmask1)
            intersection_any_v0(bitmask0, bitmask1) should equal(logical_truth)
//            intersection_any_v1(bitmask0, bitmask1) should equal(logical_truth)
            val truth = intersection_any(bitmask0, bitmask1)
            println(s"truth: ${truth}")
            truth should equal(logical_truth)
//            intersection_any_v0(bitmask0, bitmask1) should equal(intersection_any_v1(bitmask0, bitmask1))