Skip to content
Snippets Groups Projects
Commit 44395573 authored by Eric Pershey's avatar Eric Pershey
Browse files

isolating the bitmask functions from spark by moving under Octeres.Bitmask

parent aad56ac8
No related branches found
No related tags found
No related merge requests found
/*
* Copyright (C) 2024, UChicago Argonne, LLC
* Licensed under the 3-clause BSD license. See accompanying LICENSE.txt file
* in the top-level directory.
*/
package Octeres.Bitmask
import scala.collection.immutable.ArraySeq
import scala.collection.mutable.{ArrayBuffer, ListBuffer}
object BitmaskLite {
case class LiteBitmaskStruct(length: Long, bit_count: Long, bit_idx_first: Long, bit_idx_last: Long,
intervals: ArraySeq[ArraySeq[Long]])
def getLiteBitmaskZeros(length: Long): LiteBitmaskStruct = {
LiteBitmaskStruct(length, 0L, -1L, -1L, ArraySeq())
}
def getLiteBitmaskZeros(length: Int): LiteBitmaskStruct = {
LiteBitmaskStruct(length.toLong, 0L, -1L, -1L, ArraySeq())
}
def getCombinedIntervalsUntilAllConsumed(alst: ArraySeq[ArraySeq[Long]], blst: ArraySeq[ArraySeq[Long]]): Seq[List[Long]] = {
var aidx: Long = 0
var bidx: Long = 0
val exit_list = List[Long](-2, -2)
LazyList.continually {
val amin: Long = if (aidx < alst.length) alst(aidx.toInt).head else -2L
val amax: Long = if (aidx < alst.length) alst(aidx.toInt).last else -2L
val bmin: Long = if (bidx < blst.length) blst(bidx.toInt).head else -2L
val bmax: Long = if (bidx < blst.length) blst(bidx.toInt).last else -2L
(amin, bmin) match {
case (-2L, -2L) =>
exit_list
case (_, -2L) =>
aidx += 1L
List(amin, amax)
case (-2L, _) =>
bidx += 1L
List(bmin, bmax)
case (_, _) if (amin != -2 & amin <= bmin) =>
aidx += 1L
List(amin, amax)
case _ =>
bidx += 1L
List(bmin, bmax)
}
}.takeWhile(_ != exit_list)
}
class BitmaskError(message: String) extends Exception(message) {
def this(message: String, cause: Throwable = null) = {
this(message)
initCause(cause)
}
def this(cause: Throwable) = {
this(Option(cause).map(_.toString).orNull, cause)
}
def this() = {
this(null: String)
}
}
def validateLiteBitmaskSlotsLike(obj: LiteBitmaskStruct): Unit = {
val intervals = obj.intervals
if (obj.bit_count > 0) {
if (obj.bit_idx_last >= obj.length) {
throw new BitmaskError(s"bit_idx_last is greater than length")
}
if (obj.bit_idx_first != intervals.head.head) {
throw new BitmaskError(s"bit_idx_first:${obj.bit_idx_first} != intervals.head.head:${intervals.head.head}")
}
if (obj.bit_idx_last != intervals.last.last) {
throw new BitmaskError(s"bit_idx_last:${obj.bit_idx_last} != intervals.last.last:${intervals.last.last}")
}
} else {
if (obj.bit_idx_first != -1) {
throw new BitmaskError(s"bit_idx_first:${obj.bit_idx_first} != -1")
}
if (obj.bit_idx_last != -1) {
throw new BitmaskError(s"bit_idx_last:${obj.bit_idx_last} != -1")
}
}
var bit_count: Long = 0
for ((alar, i) <- intervals.zipWithIndex) {
val al: Long = alar.head
val ar: Long = alar.last
val sub_bit_count: Long = (ar - al) + 1L
if (sub_bit_count < 1) {
throw new BitmaskError(s"negative interval: [$al, $ar]")
}
bit_count += sub_bit_count
if (al > ar) {
throw new BitmaskError(s"negative interval: [$al, $ar]")
}
val blbr: Option[ArraySeq[Long]] = intervals.lift(i + 1)
blbr match {
case Some(x: ArraySeq[Long]) => {
val bl: Long = x.head
val br: Long = x.last
if ((al == bl) & (ar == br)) {
throw new BitmaskError(s"duplicate interval: [$al, $ar]")
}
if ((ar + 1) == bl) {
throw new BitmaskError(s"interval not merged: ${ar}+1 == ${bl} for $alar->$blbr")
}
if (ar > bl) {
throw new BitmaskError(s"interval out of order or not merged: ${ar} > ${bl} for $alar->$blbr")
}
}
case None => Nil
}
}
if (bit_count != obj.bit_count) {
throw new BitmaskError(s"bit_count:${bit_count} != obj.bit_count:${obj.bit_count}")
}
}
def logical_or_v0(a_bitmask: LiteBitmaskStruct, b_bitmask: LiteBitmaskStruct): LiteBitmaskStruct = {
/* does a logical or of two bit intervals
* From Spark ([8,3,ArraySeq(ArraySeq(0, 2)),0,2],[8,3,ArraySeq(ArraySeq(0, 2)),0,2]) */
val intervalGen: Seq[List[Long]] = getCombinedIntervalsUntilAllConsumed(
a_bitmask.intervals,
b_bitmask.intervals)
val intervals = ArrayBuffer.empty[ArraySeq[Long]]
var bitCount: Long = 0
var bitIdxFirst: Long = -1
var bitIdxLast: Long = -1
var prev_start: Long = -1
var prev_end: Long = -1
val itr = intervalGen.iterator
for (List(next_start, next_end) <- itr) {
(next_start, next_end) match {
case (-2, -2) => // case where there is nothing
bitCount = 0
bitIdxFirst = -1
bitIdxLast = -1
case _ if (prev_start == -1 & prev_end == -1) => // Initial variable, setting previous
prev_start = next_start
prev_end = next_end
intervals += ArraySeq(prev_start, prev_end)
bitCount = prev_end - prev_start + 1
bitIdxFirst = intervals.head.head
bitIdxLast = intervals.last.last
case _ =>
if (next_start <= prev_end + 1) {
val new_end = Math.max(prev_end, next_end)
intervals(intervals.length - 1) = ArraySeq(prev_start, new_end)
bitCount += new_end - prev_end
prev_end = new_end
} else {
intervals += ArraySeq(next_start, next_end)
bitCount += next_end - next_start + 1
prev_start = next_start
prev_end = next_end
}
bitIdxFirst = intervals.head.head
bitIdxLast = intervals.last.last
}
}
val result_intervals: ArraySeq[ArraySeq[Long]] = ArraySeq.from(intervals)
// hopefully this is a view not a copy
// val result_intervals: ArraySeq[ArraySeq[Long]] = ArraySeq.unsafeWrapArray(intervals.toArray)
LiteBitmaskStruct(a_bitmask.length, bitCount, bitIdxFirst, bitIdxLast, result_intervals)
}
def logical_or(a_bitmask: LiteBitmaskStruct, b_bitmask: LiteBitmaskStruct): LiteBitmaskStruct = {
logical_or_v0(a_bitmask, b_bitmask)
}
def intersection_possible(a_bitmask: LiteBitmaskStruct, b_bitmask: LiteBitmaskStruct): Boolean = {
var result = true
if ((a_bitmask.bit_count == 0) | (b_bitmask.bit_count == 0)) {
result = false
} else if (a_bitmask.bit_idx_first > b_bitmask.bit_idx_last) {
result = false
} else if (a_bitmask.bit_idx_last < b_bitmask.bit_idx_first) {
result = false
} else {
result = true
}
result
}
def getIntervalUntilAllConsumed(alst: List[List[Long]], blst: List[List[Long]]): Seq[List[Long]] = {
/* from get_intervals_until_all_consumed */
var aidx: Long = 0
var bidx: Long = 0
val exit_list: List[Long] = List(-2, -2)
LazyList.continually {
val amin: Long = if (aidx < alst.length) alst(aidx.toInt).head else -2
val amax: Long = if (aidx < alst.length) alst(aidx.toInt).last else -2
val bmin: Long = if (bidx < blst.length) blst(bidx.toInt).head else -2
val bmax: Long = if (bidx < blst.length) blst(bidx.toInt).last else -2
// println(s"aidx:${aidx} bidx:${bidx} ${amin},${bmin} amin:${amin} amax:${amax} bmin:${bmin} bmax:${bmax}")
(amin, bmin) match {
case (-2, -2) =>
exit_list
case (_, -2) =>
aidx += 1
List(amin, amax)
case (-2, _) =>
bidx += 1
List(bmin, bmax)
case (_, _) if (amin != -2 & amin <= bmin) =>
aidx += 1
List(amin, amax)
case _ =>
bidx += 1
List(bmin, bmax)
}
}.takeWhile(_ != exit_list)
}
class IntervalIterator(intervals: ArraySeq[ArraySeq[Long]], bit_idx_first: Long, bit_idx_last: Long) {
var cur_interval_idx: Long = 0
var cur_interval: ArraySeq[Long] = intervals(cur_interval_idx.toInt)
var cur_range_idx: Long = 0
var cur_range: ArraySeq[Long] = ArraySeq.range(cur_interval.head, cur_interval.last+1)
_jumpForward(bit_idx_first)
def printState(): Unit = {
println(s"bit_idx_first: ${bit_idx_first} bit_idx_last: ${bit_idx_last}")
println(s"cur_interval_idx: ${cur_interval_idx} cur_interval: ${cur_interval}")
println(s"cur_range_idx: ${cur_range_idx} cur_range: ${cur_range}")
}
def _jumpForward(i: Long): Unit = {
var found = false
while (!isEmpty & !found) {
if (bit_idx_first > cur_interval.last) {
_moveForwardInterval()
} else {
if (bit_idx_first > cur_range(cur_range_idx.toInt)){
_moveForwardRange()
} else {
found = true
}
}
}
}
def _moveEnd(): Unit = {
cur_interval_idx = -2
cur_range_idx = -2
}
def _moveForwardInterval(): Unit = {
cur_interval_idx += 1
cur_range_idx = 0
cur_interval = intervals(cur_interval_idx.toInt)
cur_range = ArraySeq.range(cur_interval.head, cur_interval.last+1)
}
def _moveForwardRange(): Unit = {
cur_range_idx += 1
}
def isEmpty: Boolean = {
if (cur_interval_idx == -2 & cur_range_idx == -2){
true
} else if (cur_interval_idx >= intervals.length & cur_range_idx >= cur_range.length){
true
} else {
false
}
}
def nonEmpty: Boolean = {
!isEmpty
}
def _move(): Unit = {
if (cur_range_idx+1 >= cur_range.length){
if (cur_interval_idx+1 >= cur_interval.length) {
_moveEnd()
} else {
_moveForwardInterval()
}
} else {
_moveForwardRange()
}
}
def get(): Long = {
val i = cur_range(cur_range_idx.toInt)
if (i >= bit_idx_last & bit_idx_last >= 0){
_moveEnd()
} else {
_move()
}
i
}
}
def getIntervalBits(a_bitmask: LiteBitmaskStruct, bit_idx_first: Long = 0, bit_idx_last: Long = -2): List[Long] = {
val bit_lst: ListBuffer[Long] = ListBuffer()
a_bitmask.intervals.foreach(elem => {
for (i <- elem.head to elem.last) { // inclusive
if (i < bit_idx_first){
} else if (bit_idx_last != -2 & (i > bit_idx_last)) {
} else {
bit_lst += i
}
}
})
bit_lst.toList
}
def intersection_any_v0(a_bitmask: LiteBitmaskStruct, b_bitmask: LiteBitmaskStruct): Boolean = {
var result = false
if (!intersection_possible(a_bitmask, b_bitmask)) {
result = false
} else {
val first_a = a_bitmask.bit_idx_first
val last_a = a_bitmask.bit_idx_last
val first_b = b_bitmask.bit_idx_first
val last_b = b_bitmask.bit_idx_last
if ((first_a == first_b) | (last_a == last_b)) {
result = true
} else if ((a_bitmask.bit_idx_last == b_bitmask.bit_idx_first) | (a_bitmask.bit_idx_last == b_bitmask.bit_idx_last)) {
result = true
} else {
val bit_idx_first = if (first_a >= first_b) first_a else first_b
val bit_idx_last = if (last_a < last_b) last_a else last_b
var a_lst = getIntervalBits(a_bitmask, bit_idx_first=bit_idx_first, bit_idx_last=bit_idx_last)
var b_lst = getIntervalBits(b_bitmask, bit_idx_first=bit_idx_first, bit_idx_last=bit_idx_last)
if (a_lst.isEmpty | b_lst.isEmpty){
} else {
var a_idx = a_lst.head
a_lst = a_lst.tail
var b_idx = b_lst.head
b_lst = b_lst.tail
while (a_idx >= 0 & b_idx >= 0) {
if (a_idx == b_idx) {
result = true
a_idx = -2
b_idx = -2
} else if (a_idx > b_idx){
// move b_idx
if (b_lst.nonEmpty) {
b_idx = b_lst.head
b_lst = b_lst.tail
} else {
b_idx = -2
}
} else {
// move a_idx
if (a_lst.nonEmpty){
a_idx = a_lst.head
a_lst = a_lst.tail
} else {
a_idx = -2
}
}
}
}
}
}
result
}
def intersection_any_v1(a_bitmask: LiteBitmaskStruct, b_bitmask: LiteBitmaskStruct): Boolean = {
var result = false
if (!intersection_possible(a_bitmask, b_bitmask)) {
result = false
} else {
val first_a = a_bitmask.bit_idx_first
val last_a = a_bitmask.bit_idx_last
val first_b = b_bitmask.bit_idx_first
val last_b = b_bitmask.bit_idx_last
if ((first_a == first_b) | (last_a == last_b)) {
result = true
} else if ((a_bitmask.bit_idx_last == b_bitmask.bit_idx_first) | (a_bitmask.bit_idx_last == b_bitmask.bit_idx_last)) {
result = true
} else {
val bit_idx_first = if (first_a >= first_b) first_a else first_b
val bit_idx_last = if (last_a < last_b) last_a else last_b
val a_ii = new IntervalIterator(a_bitmask.intervals, bit_idx_first, bit_idx_last)
val b_ii = new IntervalIterator(b_bitmask.intervals, bit_idx_first, bit_idx_last)
if (a_ii.isEmpty | b_ii.isEmpty){
} else {
var a_idx = a_ii.get()
var b_idx = b_ii.get()
while (a_idx >= 0 & b_idx >= 0) {
if (a_idx == b_idx) {
result = true
a_idx = -2
b_idx = -2
} else if (a_idx > b_idx){
// move b_idx
if (b_ii.nonEmpty) {
b_idx = b_ii.get()
} else {
b_idx = -2
}
} else {
// move a_idx
if (a_ii.nonEmpty){
a_idx = a_ii.get()
} else {
a_idx = -2
}
}
}
}
}
}
result
}
def intersection_any(a_bitmask: LiteBitmaskStruct, b_bitmask: LiteBitmaskStruct): Boolean = {
/* this directs the data to the current version of the function */
intersection_any_v1(a_bitmask, b_bitmask)
}
}
......@@ -10,19 +10,16 @@ Todo: the .toInt are not the best for indexing a ArraySeq, this may require a re
*/
package Octeres
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types._
import scala.collection.immutable.ArraySeq
import scala.collection.mutable.{ArrayBuffer, ListBuffer}
import Octeres.Bitmask.BitmaskLite.LiteBitmaskStruct
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.types.StructType
import scala.collection.immutable.ArraySeq
object DataUDF {
case class LiteBitmaskStruct(length: Long, bit_count: Long, bit_idx_first: Long, bit_idx_last: Long,
intervals: ArraySeq[ArraySeq[Long]])
val LiteBitmaskSchema: StructType = Encoders.product[LiteBitmaskStruct].schema
val LiteBitmaskEncoder = Encoders.bean(LiteBitmaskStruct.getClass)
......@@ -38,159 +35,9 @@ object DataUDF {
new GenericRowWithSchema(Array(length, bit_count, bit_idx_first, bit_idx_last, intervals), LiteBitmaskSchema)
}
def getCombinedIntervalsUntilAllConsumed(alst: ArraySeq[ArraySeq[Long]], blst: ArraySeq[ArraySeq[Long]]): Seq[List[Long]] = {
var aidx: Long = 0
var bidx: Long = 0
val exit_list = List[Long](-2, -2)
LazyList.continually {
val amin: Long = if (aidx < alst.length) alst(aidx.toInt).head else -2L
val amax: Long = if (aidx < alst.length) alst(aidx.toInt).last else -2L
val bmin: Long = if (bidx < blst.length) blst(bidx.toInt).head else -2L
val bmax: Long = if (bidx < blst.length) blst(bidx.toInt).last else -2L
(amin, bmin) match {
case (-2L, -2L) =>
exit_list
case (_, -2L) =>
aidx += 1L
List(amin, amax)
case (-2L, _) =>
bidx += 1L
List(bmin, bmax)
case (_, _) if (amin != -2 & amin <= bmin) =>
aidx += 1L
List(amin, amax)
case _ =>
bidx += 1L
List(bmin, bmax)
}
}.takeWhile(_ != exit_list)
}
class BitmaskError(message: String) extends Exception(message) {
def this(message: String, cause: Throwable = null) = {
this(message)
initCause(cause)
}
def this(cause: Throwable) = {
this(Option(cause).map(_.toString).orNull, cause)
}
def this() = {
this(null: String)
}
}
def validateLiteBitmaskSlotsLike(row: Row): Unit = {
val obj = rowToLiteBitmaskStruct(row)
validateLiteBitmaskSlotsLike(obj)
}
def validateLiteBitmaskSlotsLike(obj: LiteBitmaskStruct): Unit = {
val intervals = obj.intervals
if (obj.bit_count > 0) {
if (obj.bit_idx_last >= obj.length) {
throw new BitmaskError(s"bit_idx_last is greater than length")
}
if (obj.bit_idx_first != intervals.head.head) {
throw new BitmaskError(s"bit_idx_first:${obj.bit_idx_first} != intervals.head.head:${intervals.head.head}")
}
if (obj.bit_idx_last != intervals.last.last) {
throw new BitmaskError(s"bit_idx_last:${obj.bit_idx_last} != intervals.last.last:${intervals.last.last}")
}
} else {
if (obj.bit_idx_first != -1) {
throw new BitmaskError(s"bit_idx_first:${obj.bit_idx_first} != -1")
}
if (obj.bit_idx_last != -1) {
throw new BitmaskError(s"bit_idx_last:${obj.bit_idx_last} != -1")
}
}
var bit_count: Long = 0
for ((alar, i) <- intervals.zipWithIndex) {
val al: Long = alar.head
val ar: Long = alar.last
val sub_bit_count: Long = (ar - al) + 1L
if (sub_bit_count < 1) {
throw new BitmaskError(s"negative interval: [$al, $ar]")
}
bit_count += sub_bit_count
if (al > ar) {
throw new BitmaskError(s"negative interval: [$al, $ar]")
}
val blbr: Option[ArraySeq[Long]] = intervals.lift(i + 1)
blbr match {
case Some(x: ArraySeq[Long]) => {
val bl: Long = x.head
val br: Long = x.last
if ((al == bl) & (ar == br)) {
throw new BitmaskError(s"duplicate interval: [$al, $ar]")
}
if ((ar + 1) == bl) {
throw new BitmaskError(s"interval not merged: ${ar}+1 == ${bl} for $alar->$blbr")
}
if (ar > bl) {
throw new BitmaskError(s"interval out of order or not merged: ${ar} > ${bl} for $alar->$blbr")
}
}
case None => Nil
}
}
if (bit_count != obj.bit_count) {
throw new BitmaskError(s"bit_count:${bit_count} != obj.bit_count:${obj.bit_count}")
}
}
def logical_or_v0(abi: Row, bbi: Row): LiteBitmaskStruct = {
logical_or_v0(rowToLiteBitmaskStruct(abi), rowToLiteBitmaskStruct(bbi))
}
def logical_or_v0(a_bitmask: LiteBitmaskStruct, b_bitmask: LiteBitmaskStruct): LiteBitmaskStruct = {
/* does a logical or of two bit intervals
* From Spark ([8,3,ArraySeq(ArraySeq(0, 2)),0,2],[8,3,ArraySeq(ArraySeq(0, 2)),0,2]) */
val intervalGen: Seq[List[Long]] = getCombinedIntervalsUntilAllConsumed(
a_bitmask.intervals,
b_bitmask.intervals)
val intervals = ArrayBuffer.empty[ArraySeq[Long]]
var bitCount: Long = 0
var bitIdxFirst: Long = -1
var bitIdxLast: Long = -1
var prev_start: Long = -1
var prev_end: Long = -1
val itr = intervalGen.iterator
for (List(next_start, next_end) <- itr) {
(next_start, next_end) match {
case (-2, -2) => // case where there is nothing
bitCount = 0
bitIdxFirst = -1
bitIdxLast = -1
case _ if (prev_start == -1 & prev_end == -1) => // Initial variable, setting previous
prev_start = next_start
prev_end = next_end
intervals += ArraySeq(prev_start, prev_end)
bitCount = prev_end - prev_start + 1
bitIdxFirst = intervals.head.head
bitIdxLast = intervals.last.last
case _ =>
if (next_start <= prev_end + 1) {
val new_end = Math.max(prev_end, next_end)
intervals(intervals.length - 1) = ArraySeq(prev_start, new_end)
bitCount += new_end - prev_end
prev_end = new_end
} else {
intervals += ArraySeq(next_start, next_end)
bitCount += next_end - next_start + 1
prev_start = next_start
prev_end = next_end
}
bitIdxFirst = intervals.head.head
bitIdxLast = intervals.last.last
}
}
val result_intervals: ArraySeq[ArraySeq[Long]] = ArraySeq.from(intervals)
// hopefully this is a view not a copy
// val result_intervals: ArraySeq[ArraySeq[Long]] = ArraySeq.unsafeWrapArray(intervals.toArray)
LiteBitmaskStruct(a_bitmask.length, bitCount, bitIdxFirst, bitIdxLast, result_intervals)
Octeres.Bitmask.BitmaskLite.validateLiteBitmaskSlotsLike(obj)
}
def liteBitmaskStructToRow(bitmaskStruct: LiteBitmaskStruct): Row = {
......@@ -216,285 +63,17 @@ object DataUDF {
case (null, null) => null
case (null, _) => rowToLiteBitmaskStruct(b_bitmask)
case (_, null) => rowToLiteBitmaskStruct(a_bitmask)
case (_, _) => logical_or_v0(a_bitmask, b_bitmask)
case (_, _) => Octeres.Bitmask.BitmaskLite.logical_or(rowToLiteBitmaskStruct(a_bitmask), rowToLiteBitmaskStruct(b_bitmask))
}
}
def intersection_possible(a_bitmask: LiteBitmaskStruct, b_bitmask: LiteBitmaskStruct): Boolean = {
var result = true
if ((a_bitmask.bit_count == 0) | (b_bitmask.bit_count == 0)) {
result = false
} else if (a_bitmask.bit_idx_first > b_bitmask.bit_idx_last) {
result = false
} else if (a_bitmask.bit_idx_last < b_bitmask.bit_idx_first) {
result = false
} else {
result = true
}
result
}
def intersection_possible(abi: Row, bbi: Row): Boolean = {
val bitmask_a = rowToLiteBitmaskStruct(abi)
val bitmask_b = rowToLiteBitmaskStruct(bbi)
intersection_possible(bitmask_a, bitmask_b)
}
def getIntervalUntilAllConsumed(alst: List[List[Long]], blst: List[List[Long]]): Seq[List[Long]] = {
/* from get_intervals_until_all_consumed */
var aidx: Long = 0
var bidx: Long = 0
val exit_list: List[Long] = List(-2, -2)
LazyList.continually {
val amin: Long = if (aidx < alst.length) alst(aidx.toInt).head else -2
val amax: Long = if (aidx < alst.length) alst(aidx.toInt).last else -2
val bmin: Long = if (bidx < blst.length) blst(bidx.toInt).head else -2
val bmax: Long = if (bidx < blst.length) blst(bidx.toInt).last else -2
// println(s"aidx:${aidx} bidx:${bidx} ${amin},${bmin} amin:${amin} amax:${amax} bmin:${bmin} bmax:${bmax}")
(amin, bmin) match {
case (-2, -2) =>
exit_list
case (_, -2) =>
aidx += 1
List(amin, amax)
case (-2, _) =>
bidx += 1
List(bmin, bmax)
case (_, _) if (amin != -2 & amin <= bmin) =>
aidx += 1
List(amin, amax)
case _ =>
bidx += 1
List(bmin, bmax)
}
}.takeWhile(_ != exit_list)
}
class IntervalIterator(intervals: ArraySeq[ArraySeq[Long]], bit_idx_first: Long, bit_idx_last: Long) {
var cur_interval_idx: Long = 0
var cur_interval: ArraySeq[Long] = intervals(cur_interval_idx.toInt)
var cur_range_idx: Long = 0
var cur_range: ArraySeq[Long] = ArraySeq.range(cur_interval.head, cur_interval.last+1)
_jumpForward(bit_idx_first)
def printState(): Unit = {
println(s"bit_idx_first: ${bit_idx_first} bit_idx_last: ${bit_idx_last}")
println(s"cur_interval_idx: ${cur_interval_idx} cur_interval: ${cur_interval}")
println(s"cur_range_idx: ${cur_range_idx} cur_range: ${cur_range}")
}
def _jumpForward(i: Long): Unit = {
var found = false
while (!isEmpty & !found) {
if (bit_idx_first > cur_interval.last) {
_moveForwardInterval()
} else {
if (bit_idx_first > cur_range(cur_range_idx.toInt)){
_moveForwardRange()
} else {
found = true
}
}
}
}
def _moveEnd(): Unit = {
cur_interval_idx = -2
cur_range_idx = -2
}
def _moveForwardInterval(): Unit = {
cur_interval_idx += 1
cur_range_idx = 0
cur_interval = intervals(cur_interval_idx.toInt)
cur_range = ArraySeq.range(cur_interval.head, cur_interval.last+1)
}
def _moveForwardRange(): Unit = {
cur_range_idx += 1
}
def isEmpty: Boolean = {
if (cur_interval_idx == -2 & cur_range_idx == -2){
true
} else if (cur_interval_idx >= intervals.length & cur_range_idx >= cur_range.length){
true
} else {
false
}
}
def nonEmpty: Boolean = {
!isEmpty
}
def _move(): Unit = {
if (cur_range_idx+1 >= cur_range.length){
if (cur_interval_idx+1 >= cur_interval.length) {
_moveEnd()
} else {
_moveForwardInterval()
}
} else {
_moveForwardRange()
}
}
def get(): Long = {
val i = cur_range(cur_range_idx.toInt)
if (i >= bit_idx_last & bit_idx_last >= 0){
_moveEnd()
} else {
_move()
}
i
}
}
// def getIntervalBitsGenerator(intervals: List[List[Long]], bit_idx_first: Long, bit_idx_last: Long): Iterator[Long] = {
// for (lst <- intervals) {
// val start = lst(0)
// val end = lst.last + 1
// LazyList.range(start, end)
// }
// }
def getIntervalBits(a_bitmask: LiteBitmaskStruct, bit_idx_first: Long = 0, bit_idx_last: Long = -2): List[Long] = {
val bit_lst: ListBuffer[Long] = ListBuffer()
a_bitmask.intervals.foreach(elem => {
for (i <- elem.head to elem.last) { // inclusive
if (i < bit_idx_first){
} else if (bit_idx_last != -2 & (i > bit_idx_last)) {
} else {
bit_lst += i
}
}
})
bit_lst.toList
}
def intersection_any_v0(a_bitmask: LiteBitmaskStruct, b_bitmask: LiteBitmaskStruct): Boolean = {
var result = false
if (!intersection_possible(a_bitmask, b_bitmask)) {
result = false
} else {
val first_a = a_bitmask.bit_idx_first
val last_a = a_bitmask.bit_idx_last
val first_b = b_bitmask.bit_idx_first
val last_b = b_bitmask.bit_idx_last
if ((first_a == first_b) | (last_a == last_b)) {
result = true
} else if ((a_bitmask.bit_idx_last == b_bitmask.bit_idx_first) | (a_bitmask.bit_idx_last == b_bitmask.bit_idx_last)) {
result = true
} else {
val bit_idx_first = if (first_a >= first_b) first_a else first_b
val bit_idx_last = if (last_a < last_b) last_a else last_b
var a_lst = getIntervalBits(a_bitmask, bit_idx_first=bit_idx_first, bit_idx_last=bit_idx_last)
var b_lst = getIntervalBits(b_bitmask, bit_idx_first=bit_idx_first, bit_idx_last=bit_idx_last)
if (a_lst.isEmpty | b_lst.isEmpty){
} else {
var a_idx = a_lst.head
a_lst = a_lst.tail
var b_idx = b_lst.head
b_lst = b_lst.tail
while (a_idx >= 0 & b_idx >= 0) {
if (a_idx == b_idx) {
result = true
a_idx = -2
b_idx = -2
} else if (a_idx > b_idx){
// move b_idx
if (b_lst.nonEmpty) {
b_idx = b_lst.head
b_lst = b_lst.tail
} else {
b_idx = -2
}
} else {
// move a_idx
if (a_lst.nonEmpty){
a_idx = a_lst.head
a_lst = a_lst.tail
} else {
a_idx = -2
}
}
}
}
}
}
result
}
def intersection_any_v0(abi: Row, bbi: Row): Boolean = {
val bitmask_a = rowToLiteBitmaskStruct(abi)
val bitmask_b = rowToLiteBitmaskStruct(bbi)
intersection_any_v0(bitmask_a, bitmask_b)
}
def intersection_any_v1(a_bitmask: LiteBitmaskStruct, b_bitmask: LiteBitmaskStruct): Boolean = {
var result = false
if (!intersection_possible(a_bitmask, b_bitmask)) {
result = false
} else {
val first_a = a_bitmask.bit_idx_first
val last_a = a_bitmask.bit_idx_last
val first_b = b_bitmask.bit_idx_first
val last_b = b_bitmask.bit_idx_last
if ((first_a == first_b) | (last_a == last_b)) {
result = true
} else if ((a_bitmask.bit_idx_last == b_bitmask.bit_idx_first) | (a_bitmask.bit_idx_last == b_bitmask.bit_idx_last)) {
result = true
} else {
val bit_idx_first = if (first_a >= first_b) first_a else first_b
val bit_idx_last = if (last_a < last_b) last_a else last_b
val a_ii = new IntervalIterator(a_bitmask.intervals, bit_idx_first, bit_idx_last)
val b_ii = new IntervalIterator(b_bitmask.intervals, bit_idx_first, bit_idx_last)
if (a_ii.isEmpty | b_ii.isEmpty){
} else {
var a_idx = a_ii.get()
var b_idx = b_ii.get()
while (a_idx >= 0 & b_idx >= 0) {
if (a_idx == b_idx) {
result = true
a_idx = -2
b_idx = -2
} else if (a_idx > b_idx){
// move b_idx
if (b_ii.nonEmpty) {
b_idx = b_ii.get()
} else {
b_idx = -2
}
} else {
// move a_idx
if (a_ii.nonEmpty){
a_idx = a_ii.get()
} else {
a_idx = -2
}
}
}
}
}
}
result
}
def intersection_any_v1(abi: Row, bbi: Row): Boolean = {
val bitmask_a = rowToLiteBitmaskStruct(abi)
val bitmask_b = rowToLiteBitmaskStruct(bbi)
intersection_any_v1(bitmask_a, bitmask_b)
}
def intersection_any(a_bitmask: Row, b_bitmask: Row): Boolean = {
val result = (a_bitmask, b_bitmask) match {
case (null, null) => false
case (null, _) => false
case (_, null) => false
case (_, _) => intersection_any_v1(a_bitmask, b_bitmask)
case (_, _) => Octeres.Bitmask.BitmaskLite.intersection_any(rowToLiteBitmaskStruct(a_bitmask), rowToLiteBitmaskStruct(b_bitmask))
}
result
}
......
......@@ -25,6 +25,7 @@ from pyspark.sql.types import StructType, StructField, TimestampType, StringType
from Ocean.schema_lookup import schema_bitmask
from Octeres.bitmask import BASE_DTYPE_MAX
from Octeres.bitmask import eb
from Octeres.bitmask.bitmask_globals import __boxed_dtype__
from Octeres.bitmask.bitmask_lite import LiteBitmask, LiteBitmaskSlots
from Octeres.data_generation import EventGeneration
from Octeres.forthwith import FORMAT_DATE_DAY
......@@ -47,6 +48,11 @@ try:
except ImportError:
dask = None
try:
import swifter
except ImportError:
swifter = None
pd.set_option("display.max_rows", None)
pd.set_option("display.max_columns", None)
pd.set_option("display.width", 512)
......@@ -753,8 +759,9 @@ def test_get_mask_sum_03b():
direction=Event_Direction.NEGATIVE,
)
pit.negative_add(dct)
with pytest.raises(Underflow):
value = pit.get_mask_sum()
if __boxed_dtype__ == np.int8:
with pytest.raises(Underflow):
value = pit.get_mask_sum()
def test_get_mask_sum_03c():
......@@ -772,13 +779,103 @@ def test_get_mask_sum_03c():
direction=Event_Direction.NEGATIVE,
)
pit.negative_add(dct)
with pytest.raises(Overflow):
value = pit.get_mask_sum()
if __boxed_dtype__ == np.int8:
with pytest.raises(Overflow):
value = pit.get_mask_sum()
# FIXME: todo
# @pytest.mark.skipif(swifter is None, reason="could not import swifter")
# def test_the_gauntlet_swifter():
# # requires dask, dask[distributed]
# machine_name = "test_machine"
# dtype = "U2"
# empty_value = " "
# test_value = ".."
# bit_total = 32 # 100000
# bitmask_class = LiteBitmask
# time_seconds = 3600 # 86400
# range_start = datetime.datetime.strptime("2024-01-01", FORMAT_DATE_DAY)
# range_end = range_start + datetime.timedelta(seconds=time_seconds)
# eg = EventGeneration(
# time_seconds,
# bit_total,
# dtype=dtype,
# empty_value=empty_value,
# test_value=test_value,
# # visualize=True,
# # visualize_sleep=0.25,
# bitmask_class=bitmask_class,
# seed=42,
# )
# characters = []
# characters.extend(list(range(97, 122 + 1)))
# characters.extend(list(range(65, 90 + 1)))
# characters.extend(list(range(48, 57 + 1)))
# characters = list(map(chr, characters))
# event_names = eg.generate_names_n_level(characters, 3)
# events = eg.generate_non_overlapping_box_events_v1(event_names)
# event_lst = events.event_lst
# correct_unit_seconds = time_seconds * bit_total
# correct_possible_unit_seconds = (range_end - range_start).total_seconds() * bit_total
#
# for event_dct in event_lst:
# ts = event_dct['ts']
# te = event_dct['te']
# event_dct['time_start'] = pd.to_datetime(range_start +
# datetime.timedelta(seconds=ts, microseconds=int(random.uniform(0, 1) * 1000000)), utc=False)
# event_dct['time_end'] = pd.to_datetime(range_start +
# datetime.timedelta(seconds=te, microseconds=int(random.uniform(0, 1) * 1000000)), utc=False)
# event_dct['pk'] = event_dct['name']
# event_dct['event_type_name'] = 'job'
#
# pdf = pd.DataFrame(event_lst)
# pdf['time_start'] = pdf['ts'].swifter.apply(lambda tsi: pd.to_datetime(range_start + datetime.timedelta(seconds=tsi, microseconds=int(random.uniform(0, 1) * 1000000)), utc=False))
# pdf['time_end'] = pdf['te'].swifter.apply(lambda tsi: pd.to_datetime(range_start + datetime.timedelta(seconds=tsi, microseconds=int(random.uniform(0, 1) * 1000000)), utc=False))
# pdf['pk'] = pdf['name']
# pdf['event_type_name'] = 'job'
#
# # ddf['bitmask'] = ddf['bitmask'].apply(func, meta=('bitmask', object))
# # event_lst = df_to_lstofdct(pdf2)
# # for event_dct in event_lst:
# # event_dct['bitmask'] = LiteBitmask.from_dict(event_dct['bitmask'])
#
# # base_mask = bitmask_class.zeros(bit_total)
# # tl = Timeline(machine_name, range_start, range_end, base_mask)
# # possible_unit_seconds = tl.possible_unit_seconds
# # with cProfile.Profile() as pr:
# # timeline_lst = tl.prepare_event_timeline(event_lst)
# # timeline_dct = tl.group_timeline(timeline_lst)
# # timeline_sorted = tl.sort_timeline(timeline_dct)
# # mask_timeline = tl.normalize_timeline_pit(timeline_sorted)
# # mask_timeline = tl.sum_timeline(mask_timeline, test_negative=False)
# # # tl.print_timeline(mask_timeline, binary=True)
# # mask_timeline = tl.flatten_timeline(mask_timeline)
# # mask_timeline = tl.isolate_timeline_range(mask_timeline)
# # # tl.print_timeline(mask_timeline)
# # mag_timeline = tl.mask_timeline_to_mag_timeline(mask_timeline)
# # consumed_unit_seconds = tl.calculate_area(mag_timeline)
# # assert round(abs(consumed_unit_seconds - correct_unit_seconds), 4) == 0
# # assert possible_unit_seconds == correct_possible_unit_seconds
# # profile_result = pstats.Stats(pr)
# # profile_result.sort_stats(SortKey.CUMULATIVE).print_stats(4)
# # print(f"{len(event_lst)=} {len(timeline_lst)=}")
#
# base_mask = bitmask_class.zeros(bit_total)
# tl = TimelineParallel(machine_name, range_start, range_end, base_mask, processes=16)
# possible_unit_seconds = tl.possible_unit_seconds
# with cProfile.Profile() as pr:
# tl.load(event_lst)
# tl.run()
# consumed_unit_seconds = tl.calculate_area()
# # assert round(abs(consumed_unit_seconds - correct_unit_seconds), 4) == 0
# # assert possible_unit_seconds == correct_possible_unit_seconds
# profile_result = pstats.Stats(pr)
# profile_result.sort_stats(SortKey.CUMULATIVE).print_stats(16)
#
#
# # FIXME: todo
# @pytest.mark.skipif(dask is None, reason="could not import dask")
# def test_the_gauntlet():
# def test_the_gauntlet_dask():
# # requires dask, dask[distributed]
# machine_name = "test_machine"
# dtype = "U2"
......
/*
* Copyright (C) 2024, UChicago Argonne, LLC
* Licensed under the 3-clause BSD license. See accompanying LICENSE.txt file
* in the top-level directory.
*/
package Octeres.Bitmask
import Octeres.Bitmask.BitmaskLite.{LiteBitmaskStruct, getLiteBitmaskZeros, IntervalIterator, validateLiteBitmaskSlotsLike,
intersection_any_v0, intersection_any_v1, logical_or, intersection_any}
import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.prop.TableDrivenPropertyChecks._
import org.scalatest.matchers.should.Matchers._
import org.scalatest.prop.TableFor3
import scala.collection.immutable.ArraySeq
class BitmaskLiteTest extends AnyFunSuite{
test("test_bitmask_logical_or") {
val bitmaskTable: TableFor3[LiteBitmaskStruct, LiteBitmaskStruct, LiteBitmaskStruct] = Table(
("bitmask0", "bitmask1", "bitmask_or"),
(
LiteBitmaskStruct(8, 3, 0, 2, ArraySeq(ArraySeq[Long](0, 2))),
LiteBitmaskStruct(8, 3, 1, 3, ArraySeq(ArraySeq[Long](1, 3))),
LiteBitmaskStruct(8, 4, 0, 3, ArraySeq(ArraySeq[Long](0, 3)))
), (
LiteBitmaskStruct(8, 5, 0, 4, ArraySeq(ArraySeq[Long](0, 4))),
getLiteBitmaskZeros(8),
LiteBitmaskStruct(8, 5, 0, 4, ArraySeq(ArraySeq[Long](0, 4)))
), (
getLiteBitmaskZeros(8),
getLiteBitmaskZeros(8),
getLiteBitmaskZeros(8),
), (
LiteBitmaskStruct(8, 4, 0, 4, ArraySeq(ArraySeq[Long](0, 1), ArraySeq[Long](3, 4))),
LiteBitmaskStruct(8, 2, 1, 2, ArraySeq(ArraySeq[Long](1, 2))),
LiteBitmaskStruct(8, 5, 0, 4, ArraySeq(ArraySeq[Long](0, 4)))
), (
LiteBitmaskStruct(16, 4, 1, 6, ArraySeq(ArraySeq[Long](1, 2), ArraySeq[Long](5, 6))),
LiteBitmaskStruct(16, 4, 3, 8, ArraySeq(ArraySeq[Long](3, 4), ArraySeq[Long](7, 8))),
LiteBitmaskStruct(16, 8, 1, 8, ArraySeq(ArraySeq[Long](1, 8)))
)
)
forAll(bitmaskTable) { (bitmask0, bitmask1, bitmask_or: LiteBitmaskStruct) => {
println(s"bitmask0: ${bitmask0}")
println(s"bitmask1: ${bitmask1}")
println(s"bitmask_or: ${bitmask_or}")
validateLiteBitmaskSlotsLike(bitmask0)
validateLiteBitmaskSlotsLike(bitmask1)
validateLiteBitmaskSlotsLike(bitmask_or)
val bitmask_r = logical_or(bitmask0, bitmask1)
bitmask_r should equal(bitmask_or)
println(intersection_any(bitmask0, bitmask1))
}
}
}
test("test_getIntervalBitsGenerator") {
val bitmaskTable: TableFor3[ArraySeq[ArraySeq[Long]], Long, Long] = Table(
("intervals", "bit_idx_first", "bit_idx_last"),
(ArraySeq(ArraySeq[Long](1, 2), ArraySeq[Long](4, 5)), 0, -1),
(ArraySeq(ArraySeq[Long](2, 4), ArraySeq[Long](8, 10)), 4, 9),
(ArraySeq(ArraySeq[Long](1, 2), ArraySeq[Long](5, 6)), 3, 6),
(ArraySeq(ArraySeq[Long](3, 4), ArraySeq[Long](7, 8)), 3, 6)
)
forAll(bitmaskTable) { (intervals, bit_idx_first, bit_idx_last) => {
{
println(intervals)
val ii = new IntervalIterator(intervals, bit_idx_first, bit_idx_last)
ii.printState()
while (!ii.isEmpty) {
println(s"i: ${ii.get()}")
}
}
}}
}
test("test_bitmask_intersection_any") {
val bitmaskTable: TableFor3[LiteBitmaskStruct, LiteBitmaskStruct, Boolean] = Table(
("bitmask0", "bitmask1", "logical_truth"),
(
LiteBitmaskStruct(8, 3, 0, 2, ArraySeq(ArraySeq[Long](0, 2))),
LiteBitmaskStruct(8, 3, 1, 3, ArraySeq(ArraySeq[Long](1, 3))),
true
), (
LiteBitmaskStruct(8, 5, 0, 4, ArraySeq(ArraySeq[Long](0, 4))),
getLiteBitmaskZeros(8),
false
), (
getLiteBitmaskZeros(8),
getLiteBitmaskZeros(8),
false
), (
LiteBitmaskStruct(8, 4, 0, 4, ArraySeq(ArraySeq[Long](0, 1), ArraySeq[Long](3, 4))),
LiteBitmaskStruct(8, 2, 1, 2, ArraySeq(ArraySeq[Long](1, 2))),
true
), (
LiteBitmaskStruct(16, 4, 1, 6, ArraySeq(ArraySeq[Long](1, 2), ArraySeq[Long](5, 6))),
LiteBitmaskStruct(16, 4, 3, 8, ArraySeq(ArraySeq[Long](3, 4), ArraySeq[Long](7, 8))),
false
)
)
forAll(bitmaskTable) { (bitmask0, bitmask1, logical_truth: Boolean) => {
println(s"bitmask0: ${bitmask0}")
println(s"bitmask1: ${bitmask1}")
validateLiteBitmaskSlotsLike(bitmask0)
validateLiteBitmaskSlotsLike(bitmask1)
val truth = intersection_any(bitmask0, bitmask1)
println(s"truth: ${truth}")
truth should equal(logical_truth)
intersection_any_v0(bitmask0, bitmask1) should equal(intersection_any_v1(bitmask0, bitmask1))
}
}
}
}
......@@ -6,7 +6,8 @@
package Octeres
import Octeres.DataUDF.{LiteBitmaskSchema, intersection_any, intersection_any_v0, intersection_any_v1, liteBitmaskStructToRow, logical_or, validateLiteBitmaskSlotsLike}
import Octeres.DataUDF.{LiteBitmaskSchema, intersection_any, liteBitmaskStructToRow, logical_or, validateLiteBitmaskSlotsLike}
import Octeres.Bitmask.BitmaskLite.{intersection_any_v0, intersection_any_v1, IntervalIterator}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Encoders, Row, SparkSession}
......@@ -199,7 +200,7 @@ class DataUDFTest extends AnyFunSuite {
forAll(bitmaskTable) { (intervals, bit_idx_first, bit_idx_last) => {
{
println(intervals)
val ii = new DataUDF.IntervalIterator(intervals, bit_idx_first, bit_idx_last)
val ii = new IntervalIterator(intervals, bit_idx_first, bit_idx_last)
ii.printState()
while (!ii.isEmpty) {
println(s"i: ${ii.get()}")
......@@ -242,7 +243,6 @@ class DataUDFTest extends AnyFunSuite {
val truth = intersection_any(bitmask0, bitmask1)
println(s"truth: ${truth}")
truth should equal(logical_truth)
intersection_any_v0(bitmask0, bitmask1) should equal(intersection_any_v1(bitmask0, bitmask1))
}
}
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment