Skip to content
Snippets Groups Projects
BitmaskLiteTest.scala 10.3 KiB
Newer Older
/*
 * Copyright (C) 2024, UChicago Argonne, LLC
 * Licensed under the 3-clause BSD license.  See accompanying LICENSE.txt file
 * in the top-level directory.
 */

package Octeres.Bitmask

import Octeres.Bitmask.BitmaskLite.{LiteBitmaskStruct, getLiteBitmaskZeros, intersection_any, intersection_any_v0, intersection_any_v1, intersection_any_v2, logical_or, toArrayRangeGen, toArrayRangeLimitGen, validateLiteBitmaskSlotsLike}
import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.prop.TableDrivenPropertyChecks._
import org.scalatest.matchers.should.Matchers._
import org.scalatest.prop.{TableFor2, TableFor3, TableFor4}
import scala.collection.immutable.ArraySeq

class BitmaskLiteTest extends AnyFunSuite{
    test("test_bitmask_logical_or") {
        // Each row: two operands and the bitmask logical_or(bitmask0, bitmask1) must produce.
        // LiteBitmaskStruct fields appear to be (length, set-bit count, first set index,
        // last set index, closed [first, last] intervals) — inferred from the rows, confirm
        // against BitmaskLite.
        val bitmaskTable: TableFor3[LiteBitmaskStruct, LiteBitmaskStruct, LiteBitmaskStruct] = Table(
            ("bitmask0", "bitmask1", "bitmask_or"),
            // overlapping runs [0, 2] | [1, 3] merge into [0, 3]
            (LiteBitmaskStruct(8, 3, 0, 2, ArraySeq(ArraySeq[Long](0, 2))),
                LiteBitmaskStruct(8, 3, 1, 3, ArraySeq(ArraySeq[Long](1, 3))),
                LiteBitmaskStruct(8, 4, 0, 3, ArraySeq(ArraySeq[Long](0, 3)))),
            // OR with an all-zero mask leaves the other operand unchanged
            (LiteBitmaskStruct(8, 5, 0, 4, ArraySeq(ArraySeq[Long](0, 4))),
                getLiteBitmaskZeros(8),
                LiteBitmaskStruct(8, 5, 0, 4, ArraySeq(ArraySeq[Long](0, 4)))),
            // zeros | zeros stays zeros
            (getLiteBitmaskZeros(8),
                getLiteBitmaskZeros(8),
                getLiteBitmaskZeros(8)),
            // [1, 2] fills the gap between [0, 1] and [3, 4]
            (LiteBitmaskStruct(8, 4, 0, 4, ArraySeq(ArraySeq[Long](0, 1), ArraySeq[Long](3, 4))),
                LiteBitmaskStruct(8, 2, 1, 2, ArraySeq(ArraySeq[Long](1, 2))),
                LiteBitmaskStruct(8, 5, 0, 4, ArraySeq(ArraySeq[Long](0, 4)))),
            // interleaved adjacent runs coalesce into a single [1, 8] run
            (LiteBitmaskStruct(16, 4, 1, 6, ArraySeq(ArraySeq[Long](1, 2), ArraySeq[Long](5, 6))),
                LiteBitmaskStruct(16, 4, 3, 8, ArraySeq(ArraySeq[Long](3, 4), ArraySeq[Long](7, 8))),
                LiteBitmaskStruct(16, 8, 1, 8, ArraySeq(ArraySeq[Long](1, 8))))
        )

        forAll(bitmaskTable) { (left, right, expectedOr: LiteBitmaskStruct) =>
            Seq("bitmask0" -> left, "bitmask1" -> right, "bitmask_or" -> expectedOr).foreach {
                case (label, mask) => println(s"${label}: ${mask}")
            }
            // every fixture, including the expected result, must itself be well-formed
            Seq(left, right, expectedOr).foreach(mask => validateLiteBitmaskSlotsLike(mask))
            logical_or(left, right) should equal(expectedOr)
            println(intersection_any(left, right))
        }
    }

//    test("test_getIntervalBitsGenerator") {
//        val bitmaskTable: TableFor3[ArraySeq[ArraySeq[Long]], Long, Long] = Table(
//            ("intervals", "bit_idx_first", "bit_idx_last"),
//            (ArraySeq(ArraySeq[Long](1, 2), ArraySeq[Long](4, 5)), 0, -1),
//            (ArraySeq(ArraySeq[Long](2, 4), ArraySeq[Long](8, 10)), 4, 9),
//            (ArraySeq(ArraySeq[Long](1, 2), ArraySeq[Long](5, 6)), 3, 6),
//            (ArraySeq(ArraySeq[Long](3, 4), ArraySeq[Long](7, 8)), 3, 6)
//        )
//
//        forAll(bitmaskTable) { (intervals, bit_idx_first, bit_idx_last) => {
//            {
//                println(intervals)
//                val ii = new IntervalIterator("G", intervals, bit_idx_first, bit_idx_last)
//                ii.printState()
//                while (!ii.isEmpty) {
//                    println(s"i: ${ii.get()}")
//                }
//            }
//        }}
//    }
    test("test_good") {
        // Hand-rolled inclusive generator over [start, end]: yields start, start + 1, ..., end.
        def toGenerator3(start: Int, end: Int): Iterator[Int] = new Iterator[Int] {
            // Tracked as Long: with an Int counter, end == Int.MaxValue would wrap current
            // to Int.MinValue after the last element and hasNext would never become false.
            private var current: Long = start.toLong

            def hasNext: Boolean = current <= end

            def next(): Int = {
                // Iterator contract: next() on an exhausted iterator must throw.
                if (!hasNext) throw new NoSuchElementException("toGenerator3: iterator exhausted")
                val value = current.toInt
                current += 1
                value
            }
        }

        val lg3a = toGenerator3(10, 11)
        val produced = List.newBuilder[Int]
        while (lg3a.hasNext) {
            val value = lg3a.next()
            println(value)
            produced += value
        }
        produced.result() should equal(List(10, 11))
        assertThrows[NoSuchElementException] {
            lg3a.next()
        }

        // start > end is an empty range
        toGenerator3(5, 4).hasNext should equal(false)
        // the upper bound Int.MaxValue terminates after one element
        toGenerator3(Int.MaxValue, Int.MaxValue).toList should equal(List(Int.MaxValue))
    }

    test("test_toArrayRangeGen") {
        // Each row: closed intervals [first, last] and every index toArrayRangeGen
        // is expected to yield, in order.
        val arrayTable: TableFor2[ArraySeq[ArraySeq[Long]], List[Long]] = Table(
            ("intervals", "expected"),
            (ArraySeq(ArraySeq[Long](0, 2)), List[Long](0, 1, 2)),
            (ArraySeq(ArraySeq[Long](0, 0)), List[Long](0)),
            (ArraySeq.empty[ArraySeq[Long]], List.empty[Long]),
            (ArraySeq(ArraySeq[Long](1, 3)), List[Long](1, 2, 3)),
            (ArraySeq(ArraySeq[Long](0, 1), ArraySeq[Long](3, 4)), List[Long](0, 1, 3, 4)),
            (ArraySeq.range(0, 10, 2).map(i => ArraySeq(i.toLong, i.toLong)), List[Long](0, 2, 4, 6, 8)),
            (ArraySeq.range(1, 11, 2).map(i => ArraySeq(i.toLong, i.toLong)), List[Long](1, 3, 5, 7, 9))
        )
        forAll(arrayTable) { (interval, expected) =>
            println(s"interval: ${interval}")
            println(s"expected: ${expected}")
            val ag = toArrayRangeGen(interval)
            // Drain using only hasNext/next(). A List builder needs no import; the previous
            // ListBuffer is not imported anywhere in this file's import block.
            val resultBuilder = List.newBuilder[Long]
            while (ag.hasNext) {
                resultBuilder += ag.next()
            }
            val resultList = resultBuilder.result()
            println(s"result:   ${resultList}")
            // actual on the left so the failure message reads "<actual> did not equal <expected>"
            resultList should equal(expected)
        }
    }

    test("test_toArrayRangeLimitGen") {
        // Each row: closed intervals, the index window passed to toArrayRangeLimitGen,
        // and the indices it is expected to yield. idx_last = -2 appears to act as
        // "no upper limit" (those rows yield everything from idx_first on) — confirm
        // against BitmaskLite.
        val arrayTable: TableFor4[ArraySeq[ArraySeq[Long]], Long, Long, List[Long]] = Table(
            ("intervals", "idx_first", "idx_last", "expected"),
            (ArraySeq(ArraySeq[Long](0, 2)), 0L, -2L, List[Long](0, 1, 2)),
            (ArraySeq(ArraySeq[Long](0, 0)), 0L, -2L, List[Long](0)),
            (ArraySeq.empty[ArraySeq[Long]], 0L, -2L, List.empty[Long]),
            (ArraySeq(ArraySeq[Long](1, 3)), 0L, -2L, List[Long](1, 2, 3)),
            (ArraySeq(ArraySeq[Long](0, 1), ArraySeq[Long](3, 4)), 0L, -2L, List[Long](0, 1, 3, 4)),
            (ArraySeq.range(0, 10, 2).map(i => ArraySeq(i.toLong, i.toLong)), 4L, -2L, List[Long](4, 6, 8)),
            (ArraySeq.range(1, 11, 2).map(i => ArraySeq(i.toLong, i.toLong)), 1L, -2L, List[Long](1, 3, 5, 7, 9)),
            (ArraySeq(ArraySeq[Long](0, 1), ArraySeq[Long](3, 4), ArraySeq[Long](6, 8)), 3L, 4L, List[Long](3, 4)),
            // FIXME: pins current behaviour — [6, 8] straddles idx_last = 6 and is emitted
            // whole, so 7 and 8 appear even though they lie past the window.
            (ArraySeq(ArraySeq[Long](0, 1), ArraySeq[Long](3, 4), ArraySeq[Long](6, 8)), 3L, 6L, List[Long](3, 4, 6, 7, 8)),
            // FIXME: pins current behaviour — [0, 1] straddles idx_first = 1 and is emitted
            // whole, so 0 appears even though it lies before the window.
            (ArraySeq(ArraySeq[Long](0, 1), ArraySeq[Long](3, 4), ArraySeq[Long](6, 8)), 1L, 6L, List[Long](0, 1, 3, 4, 6, 7, 8))
        )
        forAll(arrayTable) { (interval, idx_first, idx_last, expected) =>
            println(s"interval: ${interval} ${idx_first} ${idx_last}")
            println(s"expected: ${expected}")
            val ag = toArrayRangeLimitGen(interval, idx_first, idx_last)
            // Drain using only hasNext/next(). A List builder needs no import; the previous
            // ListBuffer is not imported anywhere in this file's import block.
            val resultBuilder = List.newBuilder[Long]
            while (ag.hasNext) {
                resultBuilder += ag.next()
            }
            val resultList = resultBuilder.result()
            println(s"result:   ${resultList}")
            // actual on the left so the failure message reads "<actual> did not equal <expected>"
            resultList should equal(expected)
        }
    }

    test("test_bitmask_intersection_any") {
        val bitmaskTable: TableFor3[LiteBitmaskStruct, LiteBitmaskStruct, Boolean] = Table(
            ("bitmask0", "bitmask1", "logical_truth"),
            (
                LiteBitmaskStruct(8, 3, 0, 2, ArraySeq(ArraySeq[Long](0, 2))),
                LiteBitmaskStruct(8, 3, 1, 3, ArraySeq(ArraySeq[Long](1, 3))),
                true
            ), (
                LiteBitmaskStruct(8, 5, 0, 4, ArraySeq(ArraySeq[Long](0, 4))),
                getLiteBitmaskZeros(8),
                false
            ), (
                getLiteBitmaskZeros(8),
                getLiteBitmaskZeros(8),
                false
            ), (
                LiteBitmaskStruct(8, 4, 0, 4, ArraySeq(ArraySeq[Long](0, 1), ArraySeq[Long](3, 4))),
                LiteBitmaskStruct(8, 2, 1, 2, ArraySeq(ArraySeq[Long](1, 2))),
                true
            ), (
                LiteBitmaskStruct(16, 4, 1, 6, ArraySeq(ArraySeq[Long](1, 2), ArraySeq[Long](5, 6))),
                LiteBitmaskStruct(16, 4, 3, 8, ArraySeq(ArraySeq[Long](3, 4), ArraySeq[Long](7, 8))),
                false
            ), (
                LiteBitmaskStruct(128, 7, 11, 106, ArraySeq(ArraySeq[Long](11, 11), ArraySeq[Long](13, 13), ArraySeq[Long](62, 63), ArraySeq[Long](104, 106))),
                LiteBitmaskStruct(128, 64, 0, 126, ArraySeq.range(0, 128, 2).map(i => ArraySeq(i.toLong, i.toLong))),
                true
            ), (
                LiteBitmaskStruct(4392, 10, 2911, 3006, ArraySeq(ArraySeq[Long](2911, 2911), ArraySeq[Long](2913, 2913), ArraySeq[Long](2962, 2963), ArraySeq[Long](2965, 2965), ArraySeq[Long](2970, 2971), ArraySeq[Long](3004, 3006))),
                LiteBitmaskStruct(4392, 2196, 0, 4390, ArraySeq.range(0, 4392, 2).map(i => ArraySeq(i.toLong, i.toLong))),
                true
            ), (
                LiteBitmaskStruct(128, 1, 0, 0, ArraySeq(ArraySeq[Long](0, 0))),
                LiteBitmaskStruct(128, 64, 1, 127, ArraySeq.range(1, 128, 2).map(i => ArraySeq(i.toLong, i.toLong))),
                false
            ), (
                LiteBitmaskStruct(128, 1, 0, 0, ArraySeq(ArraySeq[Long](0, 0))),
                LiteBitmaskStruct(128, 64, 0, 126, ArraySeq.range(0, 128, 2).map(i => ArraySeq(i.toLong, i.toLong))),
                true
            )
        )

        forAll(bitmaskTable) { (bitmask0, bitmask1, logical_truth: Boolean) => {
//            println(s"bitmask0: ${bitmask0}")
//            println(s"bitmask1: ${bitmask1}")
            validateLiteBitmaskSlotsLike(bitmask0)
            validateLiteBitmaskSlotsLike(bitmask1)
            intersection_any_v0(bitmask0, bitmask1) should equal(logical_truth)
            intersection_any_v1(bitmask0, bitmask1) should equal(logical_truth)
            intersection_any_v2(bitmask0, bitmask1) should equal(logical_truth)
            val truth = intersection_any(bitmask0, bitmask1)
            println(s"truth: ${truth}")
            truth should equal(logical_truth)
//            intersection_any_v0(bitmask0, bitmask1) should equal(intersection_any_v1(bitmask0, bitmask1))