/*
 * Copyright (C) 2024, UChicago Argonne, LLC
 * Licensed under the 3-clause BSD license. See accompanying LICENSE.txt file
 * in the top-level directory.
 */
package Octeres.Bitmask

import scala.collection.IterableOnce.iterableOnceExtensionMethods
import scala.collection.immutable.ArraySeq
import scala.collection.mutable.{ArrayBuffer, ListBuffer}

/** Interval-encoded ("lite") bitmask utilities.
  *
  * A bitmask is stored as an ordered sequence of inclusive `[start, end]` intervals of set bit
  * indices together with cached summary fields. An empty bitmask uses `-1` for both
  * `bit_idx_first` and `bit_idx_last`; several helpers use `-2` to mean "no upper bound".
  */
object BitmaskLite {

  /** Interval-encoded bitmask.
    *
    * @param length        number of addressable bits; set-bit indices must be below this
    * @param bit_count     number of set bits (the sum of the interval widths)
    * @param bit_idx_first first set bit, or -1 when no bit is set
    * @param bit_idx_last  last set bit, or -1 when no bit is set
    * @param intervals     inclusive `[start, end]` pairs, sorted, non-overlapping and
    *                      non-adjacent (the invariants checked by validateLiteBitmaskSlotsLike)
    */
  case class LiteBitmaskStruct(length: Long, bit_count: Long, bit_idx_first: Long, bit_idx_last: Long,
                               intervals: ArraySeq[ArraySeq[Long]])

  /** Returns a bitmask of `length` bits with no bit set. */
  def getLiteBitmaskZeros(length: Long): LiteBitmaskStruct =
    LiteBitmaskStruct(length, 0L, -1L, -1L, ArraySeq())

  /** `Int` overload of getLiteBitmaskZeros; the length is widened to `Long`. */
  def getLiteBitmaskZeros(length: Int): LiteBitmaskStruct =
    getLiteBitmaskZeros(length.toLong)

  /** Merges two interval sequences, each already sorted by start, into one sequence sorted by
    * start. On equal starts the interval from `alst` comes first. Intervals are NOT coalesced;
    * each output element is `List(interval.head, interval.last)`. Runs in O(|alst| + |blst|).
    */
  private def mergeByStart(alst: IndexedSeq[Seq[Long]], blst: IndexedSeq[Seq[Long]]): List[List[Long]] = {
    val out = ListBuffer.empty[List[Long]]
    var ai = 0
    var bi = 0
    while (ai < alst.length || bi < blst.length) {
      val takeA = bi >= blst.length || (ai < alst.length && alst(ai).head <= blst(bi).head)
      val picked = if (takeA) { ai += 1; alst(ai - 1) } else { bi += 1; blst(bi - 1) }
      out += List(picked.head, picked.last)
    }
    out.toList
  }

  /** Interleaves the intervals of two start-sorted interval sequences, ordered by start
    * (ties take `alst` first). The result still needs coalescing; see logical_or_v0.
    *
    * @return every interval of both inputs as `List(start, end)`
    */
  def getCombinedIntervalsUntilAllConsumed(alst: ArraySeq[ArraySeq[Long]],
                                           blst: ArraySeq[ArraySeq[Long]]): Seq[List[Long]] =
    mergeByStart(alst, blst)

  /** Raised when a LiteBitmaskStruct violates its invariants. */
  class BitmaskError(message: String) extends Exception(message) {
    def this(message: String, cause: Throwable = null) = {
      this(message)
      initCause(cause)
    }

    def this(cause: Throwable) = {
      this(Option(cause).map(_.toString).orNull, cause)
    }

    def this() = {
      this(null: String)
    }
  }

  /** Checks the invariants of `obj` and throws on the first violation found.
    *
    * Verified: the summary fields agree with the intervals (-1/-1 when empty), set bits are
    * non-negative and below `length`, every interval has `start <= end`, consecutive intervals
    * are strictly ordered and neither overlap nor touch, and `bit_count` equals the summed widths.
    *
    * @throws BitmaskError describing the first violated invariant
    */
  def validateLiteBitmaskSlotsLike(obj: LiteBitmaskStruct): Unit = {
    val intervals = obj.intervals
    if (obj.bit_count > 0) {
      if (obj.bit_idx_last >= obj.length) {
        throw new BitmaskError(s"bit_idx_last is greater than length")
      }
      // Guard the .head/.last accesses below so a bad struct yields BitmaskError, not NoSuchElementException.
      if (intervals.isEmpty) {
        throw new BitmaskError(s"bit_count:${obj.bit_count} > 0 but intervals is empty")
      }
      if (obj.bit_idx_first != intervals.head.head) {
        throw new BitmaskError(s"bit_idx_first:${obj.bit_idx_first} != intervals.head.head:${intervals.head.head}")
      }
      if (obj.bit_idx_last != intervals.last.last) {
        throw new BitmaskError(s"bit_idx_last:${obj.bit_idx_last} != intervals.last.last:${intervals.last.last}")
      }
      // -1 is the "no bit set" sentinel, so a non-empty mask may not start below zero.
      if (obj.bit_idx_first < 0) {
        throw new BitmaskError(s"bit_idx_first:${obj.bit_idx_first} is negative")
      }
    } else {
      if (obj.bit_idx_first != -1) {
        throw new BitmaskError(s"bit_idx_first:${obj.bit_idx_first} != -1")
      }
      if (obj.bit_idx_last != -1) {
        throw new BitmaskError(s"bit_idx_last:${obj.bit_idx_last} != -1")
      }
    }

    var bit_count: Long = 0L
    for (i <- intervals.indices) {
      val alar = intervals(i)
      val al: Long = alar.head
      val ar: Long = alar.last
      val sub_bit_count: Long = (ar - al) + 1L
      if (sub_bit_count < 1) {
        throw new BitmaskError(s"negative interval: [$al, $ar]")
      }
      bit_count += sub_bit_count

      val blbr: Option[ArraySeq[Long]] = intervals.lift(i + 1)
      blbr.foreach { x =>
        val bl: Long = x.head
        val br: Long = x.last
        if ((al == bl) && (ar == br)) {
          throw new BitmaskError(s"duplicate interval: [$al, $ar]")
        }
        if ((ar + 1) == bl) {
          throw new BitmaskError(s"interval not merged: ${ar}+1 == ${bl} for $alar->$blbr")
        }
        if (ar > bl) {
          throw new BitmaskError(s"interval out of order or not merged: ${ar} > ${bl} for $alar->$blbr")
        }
        // Two intervals sharing a single bit would otherwise slip through and double count it.
        if (ar == bl) {
          throw new BitmaskError(s"interval overlaps: ${ar} == ${bl} for $alar->$blbr")
        }
      }
    }
    if (bit_count != obj.bit_count) {
      throw new BitmaskError(s"bit_count:${bit_count} != obj.bit_count:${obj.bit_count}")
    }
  }

  /** Logical OR of two bitmasks, computed by coalescing their start-sorted intervals.
    *
    * Overlapping or adjacent intervals are merged, so the result keeps the sorted,
    * non-overlapping, non-adjacent invariant.
    * Example input seen from Spark: `([8,3,ArraySeq(ArraySeq(0, 2)),0,2],[8,3,ArraySeq(ArraySeq(0, 2)),0,2])`.
    *
    * NOTE(review): the result takes `a_bitmask.length`; inputs are presumably the same length.
    */
  def logical_or_v0(a_bitmask: LiteBitmaskStruct, b_bitmask: LiteBitmaskStruct): LiteBitmaskStruct = {
    val merged = ArrayBuffer.empty[ArraySeq[Long]]
    var bitCount: Long = 0L
    var currStart: Long = -1L
    var currEnd: Long = -1L

    getCombinedIntervalsUntilAllConsumed(a_bitmask.intervals, b_bitmask.intervals).foreach { iv =>
      val nextStart = iv.head
      val nextEnd = iv.last
      if (merged.nonEmpty && nextStart <= currEnd + 1) {
        // Overlapping or adjacent: extend the open interval; only bits beyond currEnd are new.
        if (nextEnd > currEnd) {
          bitCount += nextEnd - currEnd
          currEnd = nextEnd
          merged(merged.length - 1) = ArraySeq(currStart, currEnd)
        }
      } else {
        merged += ArraySeq(nextStart, nextEnd)
        bitCount += nextEnd - nextStart + 1
        currStart = nextStart
        currEnd = nextEnd
      }
    }

    val (bitIdxFirst, bitIdxLast) =
      if (merged.isEmpty) (-1L, -1L) else (merged.head.head, merged.last.last)
    // ArraySeq.from copies the buffer into an immutable array (it is not a view).
    LiteBitmaskStruct(a_bitmask.length, bitCount, bitIdxFirst, bitIdxLast, ArraySeq.from(merged))
  }

  /** Logical OR; dispatches to the current implementation (logical_or_v0). */
  def logical_or(a_bitmask: LiteBitmaskStruct, b_bitmask: LiteBitmaskStruct): LiteBitmaskStruct = {
    logical_or_v0(a_bitmask, b_bitmask)
  }

  /** Cheap pre-check: false when the masks certainly share no bit (either is empty, or their
    * `[bit_idx_first, bit_idx_last]` spans are disjoint); true means an overlap is possible.
    */
  def intersection_possible(a_bitmask: LiteBitmaskStruct, b_bitmask: LiteBitmaskStruct): Boolean =
    a_bitmask.bit_count != 0 && b_bitmask.bit_count != 0 &&
      a_bitmask.bit_idx_first <= b_bitmask.bit_idx_last &&
      a_bitmask.bit_idx_last >= b_bitmask.bit_idx_first

  /** List-based counterpart of getCombinedIntervalsUntilAllConsumed (from
    * `get_intervals_until_all_consumed`). Inputs are copied to indexed sequences so the merge
    * is linear rather than quadratic in the list lengths.
    */
  def getIntervalUntilAllConsumed(alst: List[List[Long]], blst: List[List[Long]]): Seq[List[Long]] =
    mergeByStart(alst.toIndexedSeq, blst.toIndexedSeq)

  /** Materializes the set bits of `a_bitmask` in ascending order, keeping only indices
    * `>= bit_idx_first` and, unless `bit_idx_last == -2`, `<= bit_idx_last`.
    * Each interval is clamped to the window rather than scanned bit by bit.
    */
  def getIntervalBits(a_bitmask: LiteBitmaskStruct, bit_idx_first: Long = 0, bit_idx_last: Long = -2): List[Long] =
    toArrayRangeGen(a_bitmask.intervals, bit_idx_first, bit_idx_last).toList

  /** True when any endpoint of `a` equals any endpoint of `b`. For valid non-empty masks every
    * endpoint is a set bit, so a match proves the masks intersect.
    */
  private def sharesEndpoint(a: LiteBitmaskStruct, b: LiteBitmaskStruct): Boolean =
    a.bit_idx_first == b.bit_idx_first || a.bit_idx_last == b.bit_idx_last ||
      a.bit_idx_last == b.bit_idx_first || a.bit_idx_first == b.bit_idx_last

  /** Two-pointer scan over two ascending bit-index iterators; stops at the first common value. */
  private def ascendingIteratorsIntersect(a: Iterator[Long], b: Iterator[Long]): Boolean = {
    if (!a.hasNext || !b.hasNext) false
    else {
      var x = a.next()
      var y = b.next()
      var found = false
      var exhausted = false
      while (!found && !exhausted) {
        if (x == y) found = true
        else if (x > y) { if (b.hasNext) y = b.next() else exhausted = true }
        else { if (a.hasNext) x = a.next() else exhausted = true }
      }
      found
    }
  }

  /** True when the masks share at least one set bit. Materializes both masks' bits inside the
    * overlapping `[max(first), min(last)]` window as lists, then scans them.
    */
  def intersection_any_v0(a_bitmask: LiteBitmaskStruct, b_bitmask: LiteBitmaskStruct): Boolean = {
    if (!intersection_possible(a_bitmask, b_bitmask)) false
    else if (sharesEndpoint(a_bitmask, b_bitmask)) true
    else {
      val lo = math.max(a_bitmask.bit_idx_first, b_bitmask.bit_idx_first)
      val hi = math.min(a_bitmask.bit_idx_last, b_bitmask.bit_idx_last)
      ascendingIteratorsIntersect(
        getIntervalBits(a_bitmask, bit_idx_first = lo, bit_idx_last = hi).iterator,
        getIntervalBits(b_bitmask, bit_idx_first = lo, bit_idx_last = hi).iterator)
    }
  }

  /** Lazily yields every integer in the inclusive ranges of `arraySeq`, in order, restricted to
    * `[idx_first, idx_last]` (`idx_last == -2` means no upper bound). Ranges are clamped to the
    * window and ranges that fall entirely outside it (or have start > end) are skipped. Values
    * are full `Long`s; `next()` on an exhausted iterator throws NoSuchElementException.
    */
  def toArrayRangeGen(arraySeq: ArraySeq[ArraySeq[Long]], idx_first: Long = 0,
                      idx_last: Long = -2): Iterator[Long] = new Iterator[Long] {
    // -2 is the "no upper bound" sentinel shared with getIntervalBits.
    private val upper: Long = if (idx_last == -2L) Long.MaxValue else idx_last
    private var outIdx: Int = 0 // next interval of arraySeq to load
    private var curr: Long = 0L // next value to emit from the loaded range
    private var end: Long = -1L // inclusive end of the loaded range
    private var loaded: Boolean = false
    loadNextRange()

    /** Loads the next interval with values inside [idx_first, upper], clamped to that window. */
    private def loadNextRange(): Unit = {
      loaded = false
      while (!loaded && outIdx < arraySeq.length) {
        val range = arraySeq(outIdx)
        outIdx += 1
        val lo = math.max(range.head, idx_first)
        val hi = math.min(range.last, upper)
        if (lo <= hi) {
          curr = lo
          end = hi
          loaded = true
        }
      }
    }

    def hasNext: Boolean = loaded

    def next(): Long = {
      if (!loaded) throw new NoSuchElementException("next on exhausted toArrayRangeGen iterator")
      val value = curr
      // Compare before incrementing so a range ending at Long.MaxValue cannot overflow.
      if (value == end) loadNextRange() else curr = value + 1
      value
    }
  }

  /** True when the masks share at least one set bit; streams every bit of both masks lazily. */
  def intersection_any_v1(a_bitmask: LiteBitmaskStruct, b_bitmask: LiteBitmaskStruct): Boolean = {
    if (!intersection_possible(a_bitmask, b_bitmask)) false
    else if (sharesEndpoint(a_bitmask, b_bitmask)) true
    else ascendingIteratorsIntersect(toArrayRangeGen(a_bitmask.intervals), toArrayRangeGen(b_bitmask.intervals))
  }

  /** Iterates the bits of those whole intervals of `arraySeq` that overlap
    * `[idx_first, idx_last]` (`idx_last == -2` means unbounded). Intervals are kept or dropped
    * as a unit; bits of a kept interval that lie outside the window are still yielded.
    */
  def toArrayRangeLimitGen(arraySeq: ArraySeq[ArraySeq[Long]], idx_first: Long = 0,
                           idx_last: Long = -2): Iterator[Long] =
    toArrayRangeGen(arraySeq.filter { a =>
      (a.head >= idx_first || a.last >= idx_first) &&
        (a.head <= idx_last || a.last <= idx_last || idx_last == -2)
    })

  /** True when the masks share at least one set bit; streams only the intervals that overlap
    * the common `[max(first), min(last)]` window.
    */
  def intersection_any_v2(a_bitmask: LiteBitmaskStruct, b_bitmask: LiteBitmaskStruct): Boolean = {
    if (!intersection_possible(a_bitmask, b_bitmask)) false
    else if (sharesEndpoint(a_bitmask, b_bitmask)) true
    else {
      val lo = math.max(a_bitmask.bit_idx_first, b_bitmask.bit_idx_first)
      val hi = math.min(a_bitmask.bit_idx_last, b_bitmask.bit_idx_last)
      ascendingIteratorsIntersect(
        toArrayRangeLimitGen(a_bitmask.intervals, lo, hi),
        toArrayRangeLimitGen(b_bitmask.intervals, lo, hi))
    }
  }

  /** True when the masks share at least one set bit; dispatches to the current version. */
  def intersection_any(a_bitmask: LiteBitmaskStruct, b_bitmask: LiteBitmaskStruct): Boolean = {
    intersection_any_v2(a_bitmask, b_bitmask)
  }
}