Newer
Older
/*
* Copyright (C) 2024, UChicago Argonne, LLC
* Licensed under the 3-clause BSD license. See accompanying LICENSE.txt file
* in the top-level directory.
*/

Eric Pershey
committed
/*
TODO: converting Long indices with .toInt to address an ArraySeq is not ideal; this may require a rewrite
using a different, larger structure than ArraySeq[ArraySeq[Int]].
*/

Eric Pershey
committed
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions.udf

Eric Pershey
committed
import org.apache.spark.sql.types._
import scala.collection.immutable.ArraySeq

Eric Pershey
committed
import scala.collection.mutable.{ArrayBuffer, ListBuffer}

Eric Pershey
committed

Eric Pershey
committed
/**
 * Interval-encoded bitmask.
 *
 * @param length        size of the mask; validation requires `bit_idx_last < length`
 * @param bit_count     number of set bits (sum of the inclusive interval sizes)
 * @param bit_idx_first start of the first interval, or -1 when no bit is set
 * @param bit_idx_last  end of the last interval, or -1 when no bit is set
 * @param intervals     inclusive `[start, end]` pairs describing runs of set bits
 */
case class LiteBitmaskStruct(
  length: Long,
  bit_count: Long,
  bit_idx_first: Long,
  bit_idx_last: Long,
  intervals: ArraySeq[ArraySeq[Long]]
)

Eric Pershey
committed
// Encoder for the case class itself. `LiteBitmaskStruct.getClass` is the class of the
// companion object, so a bean encoder built from it does not describe LiteBitmaskStruct;
// the product encoder is the one that matches the schema used throughout this file.
val LiteBitmaskEncoder: Encoder[LiteBitmaskStruct] = Encoders.product[LiteBitmaskStruct]
// Spark schema of LiteBitmaskStruct, used to build schema-carrying rows.
val LiteBitmaskSchema: StructType = LiteBitmaskEncoder.schema

Eric Pershey
committed
/**
 * Builds a row for a bitmask of the given length with no bits set:
 * bit_count 0, both bit indexes -1 and an empty interval list.
 */
def getLiteBitmaskZeros(length: Long): Row = {
  val fields: Array[Any] = Array(length, 0L, -1L, -1L, ArraySeq())
  new GenericRowWithSchema(fields, LiteBitmaskSchema)
}
/** Int convenience overload: widens the length and builds the same all-zero row. */
def getLiteBitmaskZeros(length: Int): Row =
  getLiteBitmaskZeros(length.toLong)
/**
 * Builds a schema-carrying row from the individual bitmask fields.
 *
 * Note that the parameter order (intervals before the bit indexes) differs from
 * the schema order; the values are placed in schema order here:
 * length, bit_count, bit_idx_first, bit_idx_last, intervals.
 */
def getLiteBitmaskRow(length: Long, bit_count: Long, intervals: ArraySeq[ArraySeq[Long]], bit_idx_first: Long, bit_idx_last: Long): Row = {
  new GenericRowWithSchema(Array(length, bit_count, bit_idx_first, bit_idx_last, intervals), LiteBitmaskSchema)
}

Eric Pershey
committed
/**
 * Lazily interleaves two interval lists, always emitting next the interval with
 * the smaller start (ties go to `alst`), until both inputs are consumed.
 *
 * Each emitted element is `List(start, end)` taken from the head and last value
 * of the source interval. A start value of -2 is treated as "input exhausted".
 */
def getCombinedIntervalsUntilAllConsumed(alst: ArraySeq[ArraySeq[Long]], blst: ArraySeq[ArraySeq[Long]]): Seq[List[Long]] = {
  val Exhausted = -2L
  // State is the pair of read positions into alst and blst.
  LazyList.unfold((0, 0)) { case (ai, bi) =>
    val aStart: Long = if (ai < alst.length) alst(ai).head else Exhausted
    val aEnd: Long = if (ai < alst.length) alst(ai).last else Exhausted
    val bStart: Long = if (bi < blst.length) blst(bi).head else Exhausted
    val bEnd: Long = if (bi < blst.length) blst(bi).last else Exhausted
    val aDone = aStart == Exhausted
    val bDone = bStart == Exhausted

    if (aDone && bDone) {
      None
    } else if (bDone || (!aDone && aStart <= bStart)) {
      Some((List(aStart, aEnd), (ai + 1, bi)))
    } else {
      Some((List(bStart, bEnd), (ai, bi + 1)))
    }
  }
}
/**
 * Exception raised when a bitmask fails validation.
 *
 * Offers the usual exception constructors: message only, message plus cause,
 * cause only (the message becomes the cause's `toString`) and no arguments.
 */
class BitmaskError(message: String) extends Exception(message) {

  /** Message plus an optional underlying cause. */
  def this(message: String, cause: Throwable = null) = {
    this(message)
    initCause(cause)
  }

  /** Cause only; a null cause yields a null message. */
  def this(cause: Throwable) =
    this(if (cause == null) null else cause.toString, cause)

  /** No message and no cause. */
  def this() =
    this(null: String)
}
/** Row overload: converts the row to a LiteBitmaskStruct and validates that. */
def validateLiteBitmaskSlotsLike(row: Row): Unit =
  validateLiteBitmaskSlotsLike(rowToLiteBitmaskStruct(row))
/**
 * Checks the internal consistency of a bitmask and throws on the first violation.
 *
 * Verified invariants:
 *  - with bit_count > 0: intervals is non-empty, bit_idx_last < length, and
 *    bit_idx_first / bit_idx_last equal the first interval start / last interval end;
 *  - with bit_count <= 0: bit_idx_first and bit_idx_last are both -1;
 *  - every interval satisfies start <= end;
 *  - consecutive intervals are not duplicates, not adjacent (end + 1 == next start)
 *    and not overlapping or out of order (end > next start);
 *  - the summed interval sizes equal bit_count.
 *
 * @throws BitmaskError when any invariant is violated
 */
def validateLiteBitmaskSlotsLike(obj: LiteBitmaskStruct): Unit = {
  val intervals = obj.intervals
  if (obj.bit_count > 0) {
    // Guard first: the head/last lookups below would otherwise fail with a
    // NoSuchElementException instead of a validation error.
    if (intervals.isEmpty) {
      throw new BitmaskError(s"bit_count:${obj.bit_count} > 0 but intervals is empty")
    }
    if (obj.bit_idx_last >= obj.length) {
      throw new BitmaskError(s"bit_idx_last is greater than length")
    }
    if (obj.bit_idx_first != intervals.head.head) {
      throw new BitmaskError(s"bit_idx_first:${obj.bit_idx_first} != intervals.head.head:${intervals.head.head}")
    }
    if (obj.bit_idx_last != intervals.last.last) {
      throw new BitmaskError(s"bit_idx_last:${obj.bit_idx_last} != intervals.last.last:${intervals.last.last}")
    }
  } else {
    if (obj.bit_idx_first != -1) {
      throw new BitmaskError(s"bit_idx_first:${obj.bit_idx_first} != -1")
    }
    if (obj.bit_idx_last != -1) {
      throw new BitmaskError(s"bit_idx_last:${obj.bit_idx_last} != -1")
    }
  }

  var bit_count: Long = 0
  for ((alar, i) <- intervals.zipWithIndex) {
    val al: Long = alar.head
    val ar: Long = alar.last
    val sub_bit_count: Long = (ar - al) + 1L
    // sub_bit_count < 1 is the same condition as al > ar, so one check covers both.
    if (sub_bit_count < 1) {
      throw new BitmaskError(s"negative interval: [$al, $ar]")
    }
    bit_count += sub_bit_count

    // Compare against the following interval, if there is one.
    val blbr: Option[ArraySeq[Long]] = intervals.lift(i + 1)
    blbr match {
      case Some(x) =>
        val bl: Long = x.head
        val br: Long = x.last
        if ((al == bl) && (ar == br)) {
          throw new BitmaskError(s"duplicate interval: [$al, $ar]")
        }
        if ((ar + 1) == bl) {
          throw new BitmaskError(s"interval not merged: ${ar}+1 == ${bl} for $alar->$blbr")
        }
        if (ar > bl) {
          throw new BitmaskError(s"interval out of order or not merged: ${ar} > ${bl} for $alar->$blbr")
        }
      case None => // last interval, nothing to compare against
    }
  }
  if (bit_count != obj.bit_count) {
    throw new BitmaskError(s"bit_count:${bit_count} != obj.bit_count:${obj.bit_count}")
  }
}
/** Row overload: converts both rows and ORs the resulting structs. */
def logical_or_v0(abi: Row, bbi: Row): LiteBitmaskStruct = {
  logical_or_v0(rowToLiteBitmaskStruct(abi), rowToLiteBitmaskStruct(bbi))
}

/**
 * Logical OR of two interval-encoded bitmasks.
 *
 * The intervals of both inputs are consumed in start order and merged whenever
 * the next interval overlaps or directly follows the previous one
 * (next start <= previous end + 1).
 *
 * Example input as seen from Spark:
 * ([8,3,ArraySeq(ArraySeq(0, 2)),0,2],[8,3,ArraySeq(ArraySeq(0, 2)),0,2])
 *
 * @return a bitmask whose length is taken from `a_bitmask`; bit_idx_first and
 *         bit_idx_last are -1 and bit_count is 0 when neither input has intervals
 */
def logical_or_v0(a_bitmask: LiteBitmaskStruct, b_bitmask: LiteBitmaskStruct): LiteBitmaskStruct = {
  val merged = ArrayBuffer.empty[ArraySeq[Long]]
  var bitCount: Long = 0

  for (next <- getCombinedIntervalsUntilAllConsumed(a_bitmask.intervals, b_bitmask.intervals)) {
    val nextStart: Long = next.head
    val nextEnd: Long = next.last
    merged.lastOption match {
      case Some(prev) if nextStart <= prev.last + 1 =>
        // Overlapping or adjacent: extend the previous interval in place and
        // count only the bits that are newly covered.
        val newEnd = Math.max(prev.last, nextEnd)
        merged(merged.length - 1) = ArraySeq(prev.head, newEnd)
        bitCount += newEnd - prev.last
      case _ =>
        // First interval, or a gap after the previous one: start a new interval.
        merged += ArraySeq(nextStart, nextEnd)
        bitCount += nextEnd - nextStart + 1
    }
  }

  val bitIdxFirst: Long = merged.headOption.fold(-1L)(_.head)
  val bitIdxLast: Long = merged.lastOption.fold(-1L)(_.last)
  LiteBitmaskStruct(a_bitmask.length, bitCount, bitIdxFirst, bitIdxLast, ArraySeq.from(merged))
}
/**
 * Converts a LiteBitmaskStruct into a schema-carrying Row, with the fields in
 * schema order: length, bit_count, bit_idx_first, bit_idx_last, intervals.
 */
def liteBitmaskStructToRow(bitmaskStruct: LiteBitmaskStruct): Row = {
  // might need this somehow: import sparkSession.implicits._
  val fields: Array[Any] = Array(
    bitmaskStruct.length,
    bitmaskStruct.bit_count,
    bitmaskStruct.bit_idx_first,
    bitmaskStruct.bit_idx_last,
    bitmaskStruct.intervals)
  new GenericRowWithSchema(fields, LiteBitmaskSchema)
}

Eric Pershey
committed
/**
 * Converts a Row into a LiteBitmaskStruct by reading the fields
 * "length", "bit_count", "bit_idx_first", "bit_idx_last" and "intervals" by name.
 */
def rowToLiteBitmaskStruct(row: Row): LiteBitmaskStruct = {
  def longField(name: String): Long = row.getAs[Long](name)

  LiteBitmaskStruct(
    length = longField("length"),
    bit_count = longField("bit_count"),
    bit_idx_first = longField("bit_idx_first"),
    bit_idx_last = longField("bit_idx_last"),
    intervals = row.getAs[ArraySeq[ArraySeq[Long]]]("intervals"))
}

Eric Pershey
committed
/**
 * Null-tolerant logical OR of two bitmask rows.
 *
 * Returns null when both rows are null, the converted non-null row when only
 * one is null, and the OR of both otherwise.
 */
def logical_or(a_bitmask: Row, b_bitmask: Row): LiteBitmaskStruct = {
  val aMissing = a_bitmask == null
  val bMissing = b_bitmask == null
  if (aMissing && bMissing) null
  else if (aMissing) rowToLiteBitmaskStruct(b_bitmask)
  else if (bMissing) rowToLiteBitmaskStruct(a_bitmask)
  else logical_or_v0(a_bitmask, b_bitmask)
}
/**
 * Cheap pre-check for an intersection: false when either mask has no bits set
 * or when the [bit_idx_first, bit_idx_last] spans of the two masks do not overlap.
 */
def intersection_possible(a_bitmask: LiteBitmaskStruct, b_bitmask: LiteBitmaskStruct): Boolean = {
  val eitherEmpty = a_bitmask.bit_count == 0 || b_bitmask.bit_count == 0
  val aStartsAfterB = a_bitmask.bit_idx_first > b_bitmask.bit_idx_last
  val aEndsBeforeB = a_bitmask.bit_idx_last < b_bitmask.bit_idx_first
  !(eitherEmpty || aStartsAfterB || aEndsBeforeB)
}

Eric Pershey
committed
/** Row overload: converts both rows and runs the struct-based pre-check. */
def intersection_possible(abi: Row, bbi: Row): Boolean =
  intersection_possible(rowToLiteBitmaskStruct(abi), rowToLiteBitmaskStruct(bbi))
/**
 * Lazily interleaves two interval lists, always emitting next the interval with
 * the smaller start (ties go to `alst`), until both inputs are consumed.
 * (from get_intervals_until_all_consumed)
 *
 * Each emitted element is `List(start, end)` taken from the head and last value
 * of the source interval. A start value of -2 is treated as "input exhausted".
 *
 * The lists are walked by their tails rather than by index, so each step is
 * constant time instead of a linear `List.apply` / `List.length`.
 */
def getIntervalUntilAllConsumed(alst: List[List[Long]], blst: List[List[Long]]): Seq[List[Long]] = {
  val Exhausted = -2L
  // State is the pair of not-yet-consumed suffixes of alst and blst.
  LazyList.unfold((alst, blst)) { case (aRest, bRest) =>
    val aStart: Long = aRest.headOption.fold(Exhausted)(_.head)
    val bStart: Long = bRest.headOption.fold(Exhausted)(_.head)
    val aDone = aStart == Exhausted
    val bDone = bStart == Exhausted

    if (aDone && bDone) {
      None
    } else if (bDone || (!aDone && aStart <= bStart)) {
      // aStart is a real value here, so aRest is non-empty.
      Some((List(aStart, aRest.head.last), (aRest.tail, bRest)))
    } else {
      // bStart is a real value here, so bRest is non-empty.
      Some((List(bStart, bRest.head.last), (aRest, bRest.tail)))
    }
  }
}

Eric Pershey
committed
/**
 * Cursor over the individual bit indexes described by a list of inclusive
 * `[start, end]` intervals.
 *
 * On construction the cursor is positioned on the first bit index that is
 * >= `bit_idx_first`. `get()` returns the current bit index and advances; once a
 * returned index is >= `bit_idx_last` (and `bit_idx_last` is non-negative) the
 * cursor is marked as finished.
 *
 * `intervals` must be non-empty: the first interval is read during construction.
 * The current interval is expanded into `cur_range`, i.e. one element per bit.
 *
 * @param intervals     inclusive `[start, end]` pairs in ascending order
 * @param bit_idx_first first bit index of interest
 * @param bit_idx_last  last bit index of interest; a negative value disables the upper bound
 */
class IntervalIterator(intervals: ArraySeq[ArraySeq[Long]], bit_idx_first: Long, bit_idx_last: Long) {
  var cur_interval_idx: Long = 0
  var cur_interval: ArraySeq[Long] = intervals(cur_interval_idx.toInt)
  var cur_range_idx: Long = 0
  var cur_range: ArraySeq[Long] = ArraySeq.range(cur_interval.head, cur_interval.last + 1)

  _jumpForward(bit_idx_first)

  /** Prints the cursor state for debugging. */
  def printState(): Unit = {
    println(s"bit_idx_first: ${bit_idx_first} bit_idx_last: ${bit_idx_last}")
    println(s"cur_interval_idx: ${cur_interval_idx} cur_interval: ${cur_interval}")
    println(s"cur_range_idx: ${cur_range_idx} cur_range: ${cur_range}")
  }

  /**
   * Advances the cursor to the first bit index >= `i`. If every interval ends
   * before `i`, the cursor is marked as finished instead of running past the
   * last interval.
   */
  def _jumpForward(i: Long): Unit = {
    var found = false
    while (!isEmpty && !found) {
      if (i > cur_interval.last) {
        if (cur_interval_idx + 1 >= intervals.length) {
          _moveEnd()
        } else {
          _moveForwardInterval()
        }
      } else if (i > cur_range(cur_range_idx.toInt)) {
        _moveForwardRange()
      } else {
        found = true
      }
    }
  }

  /** Marks the cursor as finished (both indexes set to the -2 sentinel). */
  def _moveEnd(): Unit = {
    cur_interval_idx = -2
    cur_range_idx = -2
  }

  /** Moves to the start of the next interval and expands it into `cur_range`. */
  def _moveForwardInterval(): Unit = {
    cur_interval_idx += 1
    cur_range_idx = 0
    cur_interval = intervals(cur_interval_idx.toInt)
    cur_range = ArraySeq.range(cur_interval.head, cur_interval.last + 1)
  }

  /** Moves to the next bit index inside the current interval. */
  def _moveForwardRange(): Unit = {
    cur_range_idx += 1
  }

  /** True once the cursor has been marked finished or has run past all data. */
  def isEmpty: Boolean = {
    val markedFinished = cur_interval_idx == -2 && cur_range_idx == -2
    val pastEverything = cur_interval_idx >= intervals.length && cur_range_idx >= cur_range.length
    markedFinished || pastEverything
  }

  def nonEmpty: Boolean = {
    !isEmpty
  }

  /** Advances by one bit index, stepping into the next interval when needed. */
  def _move(): Unit = {
    if (cur_range_idx + 1 >= cur_range.length) {
      // Compare against the number of intervals, not the size of the current
      // [start, end] pair, so every interval is visited and the last one ends cleanly.
      if (cur_interval_idx + 1 >= intervals.length) {
        _moveEnd()
      } else {
        _moveForwardInterval()
      }
    } else {
      _moveForwardRange()
    }
  }

  /**
   * Returns the current bit index and advances the cursor.
   * Must only be called while `nonEmpty` is true.
   */
  def get(): Long = {
    val i = cur_range(cur_range_idx.toInt)
    if (i >= bit_idx_last && bit_idx_last >= 0) {
      _moveEnd()
    } else {
      _move()
    }
    i
  }
}

Eric Pershey
committed
// def getIntervalBitsGenerator(intervals: List[List[Long]], bit_idx_first: Long, bit_idx_last: Long): Iterator[Long] = {

Eric Pershey
committed
// for (lst <- intervals) {
// val start = lst(0)
// val end = lst.last + 1
// LazyList.range(start, end)
// }
// }

Eric Pershey
committed
/**
 * Lists the set bit indexes of a bitmask that fall inside a window.
 *
 * Each interval is clamped to the window before it is expanded, so bits outside
 * the window are never enumerated.
 *
 * @param a_bitmask     bitmask whose intervals are expanded
 * @param bit_idx_first smallest bit index to include
 * @param bit_idx_last  largest bit index to include; -2 means no upper bound
 * @return the selected bit indexes in interval order
 */
def getIntervalBits(a_bitmask: LiteBitmaskStruct, bit_idx_first: Long = 0, bit_idx_last: Long = -2): List[Long] = {
  val unbounded = bit_idx_last == -2
  a_bitmask.intervals.iterator.flatMap { interval =>
    val lo = Math.max(interval.head, bit_idx_first)
    val hi = if (unbounded) interval.last else Math.min(interval.last, bit_idx_last)
    lo to hi // inclusive; empty when the interval lies outside the window
  }.toList
}
/**
 * Returns true when the two bitmasks share at least one set bit.
 *
 * Short-circuits on the cheap checks first (no possible overlap, or matching
 * first/last bit indexes), then expands both masks to explicit bit lists within
 * the overlapping window and walks them in lock step.
 */
def intersection_any_v0(a_bitmask: LiteBitmaskStruct, b_bitmask: LiteBitmaskStruct): Boolean = {
  // Two-pointer walk over ascending bit lists; stops on a negative value as well.
  @scala.annotation.tailrec
  def sharesBit(xs: List[Long], ys: List[Long]): Boolean = (xs, ys) match {
    case (x :: xt, y :: yt) if x >= 0 && y >= 0 =>
      if (x == y) true
      else if (x > y) sharesBit(xs, yt)
      else sharesBit(xt, ys)
    case _ => false
  }

  if (!intersection_possible(a_bitmask, b_bitmask)) {
    false
  } else {
    val aFirst = a_bitmask.bit_idx_first
    val aLast = a_bitmask.bit_idx_last
    val bFirst = b_bitmask.bit_idx_first
    val bLast = b_bitmask.bit_idx_last
    val endpointsMatch = aFirst == bFirst || aLast == bLast || aLast == bFirst
    if (endpointsMatch) {
      true
    } else {
      val windowFirst = Math.max(aFirst, bFirst)
      val windowLast = Math.min(aLast, bLast)
      val aBits = getIntervalBits(a_bitmask, bit_idx_first = windowFirst, bit_idx_last = windowLast)
      val bBits = getIntervalBits(b_bitmask, bit_idx_first = windowFirst, bit_idx_last = windowLast)
      sharesBit(aBits, bBits)
    }
  }
}

Eric Pershey
committed
/** Row overload: converts both rows and runs the list-based intersection test. */
def intersection_any_v0(abi: Row, bbi: Row): Boolean =
  intersection_any_v0(rowToLiteBitmaskStruct(abi), rowToLiteBitmaskStruct(bbi))
/**
 * Returns true when the two bitmasks share at least one set bit.
 *
 * Same decision procedure as the v0 variant, but the bits inside the overlapping
 * window are produced by two IntervalIterator cursors instead of explicit lists.
 */
def intersection_any_v1(a_bitmask: LiteBitmaskStruct, b_bitmask: LiteBitmaskStruct): Boolean = {
  if (!intersection_possible(a_bitmask, b_bitmask)) {
    false
  } else {
    val aFirst = a_bitmask.bit_idx_first
    val aLast = a_bitmask.bit_idx_last
    val bFirst = b_bitmask.bit_idx_first
    val bLast = b_bitmask.bit_idx_last
    val endpointsMatch = aFirst == bFirst || aLast == bLast || aLast == bFirst
    if (endpointsMatch) {
      true
    } else {
      val windowFirst = Math.max(aFirst, bFirst)
      val windowLast = Math.min(aLast, bLast)
      val aCursor = new IntervalIterator(a_bitmask.intervals, windowFirst, windowLast)
      val bCursor = new IntervalIterator(b_bitmask.intervals, windowFirst, windowLast)
      var hit = false
      if (aCursor.nonEmpty && bCursor.nonEmpty) {
        var aBit = aCursor.get()
        var bBit = bCursor.get()
        // Advance whichever cursor is behind; -2 marks an exhausted cursor.
        while (!hit && aBit >= 0 && bBit >= 0) {
          if (aBit == bBit) {
            hit = true
          } else if (aBit > bBit) {
            bBit = if (bCursor.nonEmpty) bCursor.get() else -2L
          } else {
            aBit = if (aCursor.nonEmpty) aCursor.get() else -2L
          }
        }
      }
      hit
    }
  }
}

Eric Pershey
committed
/** Row overload: converts both rows and runs the cursor-based intersection test. */
def intersection_any_v1(abi: Row, bbi: Row): Boolean =
  intersection_any_v1(rowToLiteBitmaskStruct(abi), rowToLiteBitmaskStruct(bbi))
/**
 * Null-tolerant intersection test: false when either row is null, otherwise the
 * result of intersection_any_v1 on the two rows.
 */
def intersection_any(a_bitmask: Row, b_bitmask: Row): Boolean = {
  if (a_bitmask == null || b_bitmask == null) false
  else intersection_any_v1(a_bitmask, b_bitmask)
}
// Spark UDF wrapper around intersection_any (two bitmask rows in, Boolean out).
val udf_intersection_any: UserDefinedFunction =
  udf((lhs: Row, rhs: Row) => intersection_any(lhs, rhs))
// Spark UDF wrapper around logical_or (two bitmask rows in, bitmask struct out).
val udf_logical_or: UserDefinedFunction =
  udf((lhs: Row, rhs: Row) => logical_or(lhs, rhs))

Eric Pershey
committed
/** Smoke test for calling into Scala: prints a fixed message. */
def testCall(): Unit =
  println("from Scala: nothing to see here!")
/** Smoke test for calling into Scala with arguments: prints and returns a + b. */
def testCall(a: Long, b: Long): Long = {
  val total = a + b
  println(s"from Scala: ${a} + ${b} = ${total}")
  total
}
/**
 * Registers the bitmask UDFs on the current (or a newly created) SparkSession
 * under the SQL names "logical_or" and "intersection_any".
 */
def registerAll(): Unit = {
  val session: SparkSession = SparkSession.builder().getOrCreate()
  println("Registering Scala UDF functions using sparkSession.")
  val named = Seq(
    "logical_or" -> udf_logical_or,
    "intersection_any" -> udf_intersection_any)
  named.foreach { case (sqlName, fn) => session.udf.register(sqlName, fn) }
}