/*
 * Copyright (C) 2024, UChicago Argonne, LLC
 * Licensed under the 3-clause BSD license. See accompanying LICENSE.txt file
 * in the top-level directory.
 */
package Octeres

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.types._

import scala.annotation.tailrec
import scala.collection.immutable.ArraySeq
import scala.collection.mutable.ListBuffer
import scala.io.Source

/** Spark UDFs and helpers for "lite" bitmasks: bit sets stored as inclusive `[start, end]` intervals. */
object DataUDF {

  /** A bitmask of `length` bits whose set bits are the union of the inclusive `intervals`.
    *
    * `bit_count` is the number of set bits; `bit_idx_first` / `bit_idx_last` are the lowest / highest set
    * bit, or -1 when no bit is set (the invariants checked by `validateLiteBitmaskSlotsLike`).
    */
  case class LiteBitmaskStruct(length: Int, bit_count: Int, intervals: List[List[Int]], bit_idx_first: Int,
                               bit_idx_last: Int)

  /** Spark schema derived from [[LiteBitmaskStruct]]; used as the struct return type of the UDFs. */
  val LiteBitmaskSchema: StructType = Encoders.product[LiteBitmaskStruct].schema
  // val LiteBitmaskSchema: StructType = ScalaReflection.schemaFor[LiteBitmaskStruct].dataType.asInstanceOf[StructType]

  /** Encoder for [[LiteBitmaskStruct]].
    *
    * `LiteBitmaskStruct.getClass` is the class of the companion *object*, so a bean encoder built from it does
    * not describe the case class; a case class is a `Product`, so the product encoder is the matching one.
    */
  val LiteBitmaskEncoder: Encoder[LiteBitmaskStruct] = Encoders.product[LiteBitmaskStruct]

  /** Builds a [[Row]] carrying [[LiteBitmaskSchema]] from the individual bitmask fields. */
  def getLiteBitmaskRow(length: Int, bit_count: Int, intervals: List[List[Int]], bit_idx_first: Int,
                        bit_idx_last: Int): Row =
    new GenericRowWithSchema(Array(length, bit_count, intervals, bit_idx_first, bit_idx_last), LiteBitmaskSchema)

  /** Reduces an interval to its `List(start, end)` endpoints. */
  private def endpoints(interval: List[Int]): List[Int] = List(interval.head, interval.last)

  /** Merges two interval lists into one list ordered by interval start, without coalescing.
    *
    * Assumes each input is already sorted by start (as `validateLiteBitmaskSlotsLike` requires). On equal
    * starts the interval from `alst` comes first. Every output element is `List(start, end)`.
    */
  def getCombinedIntervalsUntilAllConsumed(alst: List[List[Int]], blst: List[List[Int]]): Seq[List[Int]] = {
    // Head/tail walk keeps the merge linear in the total number of intervals.
    @tailrec
    def merge(left: List[List[Int]], right: List[List[Int]], acc: List[List[Int]]): List[List[Int]] =
      (left, right) match {
        case (a :: aRest, b :: _) if a.head <= b.head => merge(aRest, right, endpoints(a) :: acc)
        case (_, b :: bRest)                          => merge(left, bRest, endpoints(b) :: acc)
        case (a :: aRest, Nil)                        => merge(aRest, Nil, endpoints(a) :: acc)
        case (Nil, Nil)                               => acc.reverse
      }

    merge(alst, blst, Nil)
  }

  /** Raised when a [[LiteBitmaskStruct]] violates its invariants or a row cannot be decoded into one. */
  class BitmaskError(message: String) extends Exception(message) {
    def this(message: String, cause: Throwable = null) = {
      this(message)
      initCause(cause)
    }

    def this(cause: Throwable) = {
      this(Option(cause).map(_.toString).orNull, cause)
    }

    def this() = {
      this(null: String)
    }
  }

  /** Decodes `row` with `rowToLiteBitmaskStruct` and validates the result. */
  def validateLiteBitmaskSlotsLike(row: Row): Unit =
    validateLiteBitmaskSlotsLike(rowToLiteBitmaskStruct(row))

  /** Checks the invariants of a bitmask, throwing [[BitmaskError]] on the first violation.
    *
    * When `bit_count > 0`: `bit_idx_last < length`, intervals are non-empty, and `bit_idx_first` /
    * `bit_idx_last` equal the first interval's start / last interval's end; otherwise both must be -1.
    * In all cases: each interval has `start <= end`, consecutive intervals are distinct, ordered, and
    * neither overlapping nor adjacent (i.e. fully merged), and the summed widths equal `bit_count`.
    */
  def validateLiteBitmaskSlotsLike(obj: LiteBitmaskStruct): Unit = {
    val intervals = obj.intervals
    if (obj.bit_count > 0) {
      if (obj.bit_idx_last >= obj.length) {
        throw new BitmaskError(s"bit_idx_last is greater than length")
      }
      // Without this guard, intervals.head below would fail with NoSuchElementException instead.
      if (intervals.isEmpty) {
        throw new BitmaskError(s"bit_count:${obj.bit_count} > 0 but intervals is empty")
      }
      if (obj.bit_idx_first != intervals.head.head) {
        throw new BitmaskError(s"bit_idx_first:${obj.bit_idx_first} != intervals.head.head:${intervals.head.head}")
      }
      if (obj.bit_idx_last != intervals.last.last) {
        throw new BitmaskError(s"bit_idx_last:${obj.bit_idx_last} != intervals.last.last:${intervals.last.last}")
      }
    } else {
      if (obj.bit_idx_first != -1) {
        throw new BitmaskError(s"bit_idx_first:${obj.bit_idx_first} != -1")
      }
      if (obj.bit_idx_last != -1) {
        throw new BitmaskError(s"bit_idx_last:${obj.bit_idx_last} != -1")
      }
    }

    // Pair each interval with its successor (None for the last one) in a single linear pass.
    val successors: List[Option[List[Int]]] = intervals.drop(1).map(Some(_)) :+ None
    val bit_count = intervals.zip(successors).foldLeft(0) { case (count, (alar, blbr)) =>
      val al = alar.head
      val ar = alar.last
      val sub_bit_count = (ar - al) + 1
      if (sub_bit_count < 1) {
        throw new BitmaskError(s"negative interval: [$al, $ar]")
      }
      blbr.foreach { succ =>
        val bl = succ.head
        val br = succ.last
        if ((al == bl) && (ar == br)) {
          throw new BitmaskError(s"duplicate interval: [$al, $ar]")
        }
        if ((ar + 1) == bl) {
          throw new BitmaskError(s"interval not merged: ${ar}+1 == ${bl} for $alar->$blbr")
        }
        if (ar > bl) {
          throw new BitmaskError(s"interval out of order or not merged: ${ar} > ${bl} for $alar->$blbr")
        }
      }
      count + sub_bit_count
    }
    if (bit_count != obj.bit_count) {
      throw new BitmaskError(s"bit_count:${bit_count} != obj.bit_count:${obj.bit_count}")
    }
  }

  /** Row overload of `intervals_or_v0`: decodes both rows, then ORs them. */
  def intervals_or_v0(length: Int, abi: Row, bbi: Row): LiteBitmaskStruct =
    intervals_or_v0(length, rowToLiteBitmaskStruct(abi), rowToLiteBitmaskStruct(bbi))

  /** Logical OR of two bitmasks: the union of their intervals, coalescing overlapping or adjacent ones.
    *
    * Example Spark input: ([8,3,ArraySeq(ArraySeq(0, 2)),0,2],[8,3,ArraySeq(ArraySeq(0, 2)),0,2]).
    * `length` is copied into the result as given. When both inputs are empty the result has no intervals,
    * `bit_count` 0 and first/last -1.
    */
  def intervals_or_v0(length: Int, a_bitmask: LiteBitmaskStruct, b_bitmask: LiteBitmaskStruct): LiteBitmaskStruct = {
    val byStart = getCombinedIntervalsUntilAllConsumed(a_bitmask.intervals, b_bitmask.intervals)
    // Accumulate newest-first so the interval being extended is always at the head.
    val merged: List[List[Int]] = byStart.foldLeft(List.empty[List[Int]]) {
      case (List(prevStart, prevEnd) :: done, List(nextStart, nextEnd)) if nextStart <= prevEnd + 1 =>
        List(prevStart, math.max(prevEnd, nextEnd)) :: done
      case (acc, next) => next :: acc
    }.reverse
    val bitCount = merged.map(interval => interval.last - interval.head + 1).sum
    val (bitIdxFirst, bitIdxLast) =
      if (merged.isEmpty) (-1, -1) else (merged.head.head, merged.last.last)
    LiteBitmaskStruct(length, bitCount, merged, bitIdxFirst, bitIdxLast)
  }

  /** Converts a [[LiteBitmaskStruct]] to a [[Row]] carrying [[LiteBitmaskSchema]]. */
  def liteBitmaskStructToRow(bitmaskStruct: LiteBitmaskStruct): Row =
    getLiteBitmaskRow(bitmaskStruct.length, bitmaskStruct.bit_count, bitmaskStruct.intervals,
      bitmaskStruct.bit_idx_first, bitmaskStruct.bit_idx_last)

  /** Converts a [[Row]] laid out like [[LiteBitmaskSchema]] into a [[LiteBitmaskStruct]].
    *
    * Fields are read by position: length, bit_count, intervals, bit_idx_first, bit_idx_last. `intervals`
    * may arrive as any `scala.collection.Seq` of Seqs (a `List`, or a mutable / immutable `ArraySeq`,
    * depending on the Spark and Scala build); it is normalised to `List[List[Int]]`, and null becomes `Nil`.
    *
    * @throws BitmaskError if `intervals` is not a sequence of sequences
    */
  def rowToLiteBitmaskStruct(row: Row): LiteBitmaskStruct = {
    val intervals: List[List[Int]] = row.get(2) match {
      case null => Nil
      case outer: scala.collection.Seq[_] =>
        outer.iterator.map {
          case inner: scala.collection.Seq[_] => inner.iterator.map(_.asInstanceOf[Int]).toList
          case other => throw new BitmaskError(s"interval is not a sequence: $other")
        }.toList
      case other => throw new BitmaskError(s"intervals is not a sequence: $other")
    }
    LiteBitmaskStruct(row.getInt(0), row.getInt(1), intervals, row.getInt(3), row.getInt(4))
  }

  /** Null-tolerant OR of two bitmask rows (registered as the `logical_or` UDF).
    *
    * Returns null when both inputs are null, the decoded other input when exactly one is null, and otherwise
    * `intervals_or_v0` using the `length` field of `a_bitmask`.
    */
  val logical_or: (Row, Row) => LiteBitmaskStruct = (a_bitmask: Row, b_bitmask: Row) =>
    (a_bitmask, b_bitmask) match {
      case (null, null) => null
      case (null, _)    => rowToLiteBitmaskStruct(b_bitmask)
      case (_, null)    => rowToLiteBitmaskStruct(a_bitmask)
      case (_, _)       => intervals_or_v0(a_bitmask.getAs[Int]("length"), a_bitmask, b_bitmask)
    }

  /** Row overload of `intersection_possible`. */
  def intersection_possible(abi: Row, bbi: Row): Boolean =
    intersection_possible(rowToLiteBitmaskStruct(abi), rowToLiteBitmaskStruct(bbi))

  /** Metadata-only pre-check: false when either mask has `bit_count == 0` or their
    * `[bit_idx_first, bit_idx_last]` spans are disjoint. True means an intersection is possible, not certain.
    */
  def intersection_possible(a_bitmask: LiteBitmaskStruct, b_bitmask: LiteBitmaskStruct): Boolean =
    a_bitmask.bit_count != 0 && b_bitmask.bit_count != 0 &&
      a_bitmask.bit_idx_first <= b_bitmask.bit_idx_last &&
      a_bitmask.bit_idx_last >= b_bitmask.bit_idx_first

  /** Same start-ordered merge as `getCombinedIntervalsUntilAllConsumed` (from get_intervals_until_all_consumed). */
  def getIntervalUntilAllConsumed(alst: List[List[Int]], blst: List[List[Int]]): Seq[List[Int]] =
    getCombinedIntervalsUntilAllConsumed(alst, blst)

  /** Struct overload of `getIntervalBitsGenerator(intervals, ...)`. */
  def getIntervalBitsGenerator(a_bitmask: LiteBitmaskStruct, bit_idx_first: Int, bit_idx_last: Int): List[Int] =
    getIntervalBitsGenerator(a_bitmask.intervals, bit_idx_first, bit_idx_last)

  /** Expands `intervals` into every bit index they cover, in interval order.
    *
    * NOTE: `bit_idx_first` and `bit_idx_last` are currently ignored (this helper was flagged "does not
    * work"); use `getIntervalBits` for a bounded expansion.
    */
  def getIntervalBitsGenerator(intervals: List[List[Int]], bit_idx_first: Int, bit_idx_last: Int): List[Int] =
    intervals.flatMap(interval => Range(interval.head, interval.last + 1))

  /** Cursor over the bits covered by `intervals`, positioned on the first bit >= `bit_idx_first`.
    *
    * `get()` returns the current bit and advances. The cursor ends after the last interval, or right after
    * returning a bit >= `bit_idx_last` when `bit_idx_last >= 0`. Check `nonEmpty` before each `get()`.
    * The end state is `cur_interval_idx == -2 && cur_range_idx == -2`.
    */
  class IntervalIterator(intervals: List[List[Int]], bit_idx_first: Int, bit_idx_last: Int) {
    // Indexed copy: positional access on a List is O(n) per step.
    private val indexedIntervals: IndexedSeq[List[Int]] = intervals.toIndexedSeq

    var cur_interval_idx: Int = 0
    var cur_interval: List[Int] = indexedIntervals.headOption.getOrElse(Nil)
    var cur_range_idx: Int = 0
    // A Range instead of a materialised List, so cur_range(idx) is O(1) and nothing is allocated per bit.
    var cur_range: Seq[Int] = bitsIn(cur_interval)

    if (indexedIntervals.isEmpty) _moveEnd() else _jumpForward(bit_idx_first)

    private def bitsIn(interval: List[Int]): Seq[Int] =
      if (interval.isEmpty) Range(0, 0) else interval.head to interval.last

    def printState(): Unit = {
      println(s"bit_idx_first: ${bit_idx_first} bit_idx_last: ${bit_idx_last}")
      println(s"cur_interval_idx: ${cur_interval_idx} cur_interval: ${cur_interval}")
      println(s"cur_range_idx: ${cur_range_idx} cur_range: ${cur_range}")
    }

    /** Moves forward to the first bit >= `i` (never backwards), ending the cursor if no such bit exists. */
    def _jumpForward(i: Int): Unit = {
      while (nonEmpty && i > cur_interval.last) {
        _moveForwardInterval()
      }
      if (nonEmpty) {
        cur_range_idx = math.max(cur_range_idx, i - cur_interval.head)
      }
    }

    def _moveEnd(): Unit = {
      cur_interval_idx = -2
      cur_range_idx = -2
    }

    /** Moves to the first bit of the next interval, or ends the cursor when there is none. */
    def _moveForwardInterval(): Unit = {
      if (isEmpty || cur_interval_idx + 1 >= indexedIntervals.length) {
        _moveEnd()
      } else {
        cur_interval_idx += 1
        cur_range_idx = 0
        cur_interval = indexedIntervals(cur_interval_idx)
        cur_range = bitsIn(cur_interval)
      }
    }

    def _moveForwardRange(): Unit = {
      cur_range_idx += 1
    }

    def isEmpty: Boolean =
      (cur_interval_idx == -2 && cur_range_idx == -2) ||
        (cur_interval_idx >= indexedIntervals.length && cur_range_idx >= cur_range.length)

    def nonEmpty: Boolean = !isEmpty

    /** Steps to the next bit: within the current interval if possible, otherwise into the next interval. */
    def _move(): Unit = {
      // The interval bound is the number of intervals (checked in _moveForwardInterval), not
      // cur_interval.length, which is just the two endpoints.
      if (cur_range_idx + 1 < cur_range.length) _moveForwardRange() else _moveForwardInterval()
    }

    /** Returns the current bit and advances; ends after returning a bit >= `bit_idx_last` (if >= 0). */
    def get(): Int = {
      val bit = cur_range(cur_range_idx)
      if (bit >= bit_idx_last && bit_idx_last >= 0) _moveEnd() else _move()
      bit
    }
  }

  /** Lists the bits of `a_bitmask` inside `[bit_idx_first, bit_idx_last]`, in interval order.
    * `bit_idx_last == -2` (the default) means there is no upper bound.
    */
  def getIntervalBits(a_bitmask: LiteBitmaskStruct, bit_idx_first: Int = 0, bit_idx_last: Int = -2): List[Int] =
    a_bitmask.intervals.flatMap { interval =>
      // Clip each inclusive interval to the window rather than walking and discarding out-of-range bits.
      val lo = math.max(interval.head, bit_idx_first)
      val hi = if (bit_idx_last == -2) interval.last else math.min(interval.last, bit_idx_last)
      lo to hi
    }

  /** Metadata-only verdict: `Some(false)` when `intersection_possible` rules an intersection out,
    * `Some(true)` when the masks share a first/last set bit, `None` when the bits must be scanned.
    */
  private def intersectionFromEndpoints(a: LiteBitmaskStruct, b: LiteBitmaskStruct): Option[Boolean] =
    if (!intersection_possible(a, b)) Some(false)
    else if (a.bit_idx_first == b.bit_idx_first || a.bit_idx_last == b.bit_idx_last ||
             a.bit_idx_last == b.bit_idx_first || a.bit_idx_first == b.bit_idx_last) Some(true)
    else None

  /** The window `[max(firsts), min(lasts)]` outside of which two masks cannot share a bit. */
  private def overlapWindow(a: LiteBitmaskStruct, b: LiteBitmaskStruct): (Int, Int) =
    (math.max(a.bit_idx_first, b.bit_idx_first), math.min(a.bit_idx_last, b.bit_idx_last))

  /** True if two bit streams share a value; assumes both are ascending and advances the smaller head. */
  @tailrec
  private def sortedBitsIntersect(a: BufferedIterator[Int], b: BufferedIterator[Int]): Boolean =
    if (!a.hasNext || !b.hasNext) false
    else if (a.head == b.head) true
    else {
      if (a.head < b.head) a.next() else b.next()
      sortedBitsIntersect(a, b)
    }

  /** Adapts an [[IntervalIterator]] cursor to a standard [[Iterator]]. */
  private def bitsOf(cursor: IntervalIterator): Iterator[Int] = new Iterator[Int] {
    override def hasNext: Boolean = cursor.nonEmpty
    override def next(): Int = cursor.get()
  }

  /** Row overload of `intersection_any_v0`. */
  def intersection_any_v0(abi: Row, bbi: Row): Boolean =
    intersection_any_v0(rowToLiteBitmaskStruct(abi), rowToLiteBitmaskStruct(bbi))

  /** True if the masks share a set bit; scans materialised bit lists restricted to the overlap window. */
  def intersection_any_v0(a_bitmask: LiteBitmaskStruct, b_bitmask: LiteBitmaskStruct): Boolean =
    intersectionFromEndpoints(a_bitmask, b_bitmask).getOrElse {
      val (lo, hi) = overlapWindow(a_bitmask, b_bitmask)
      sortedBitsIntersect(
        getIntervalBits(a_bitmask, bit_idx_first = lo, bit_idx_last = hi).iterator.buffered,
        getIntervalBits(b_bitmask, bit_idx_first = lo, bit_idx_last = hi).iterator.buffered)
    }

  /** Row overload of `intersection_any_v1`. */
  def intersection_any_v1(abi: Row, bbi: Row): Boolean =
    intersection_any_v1(rowToLiteBitmaskStruct(abi), rowToLiteBitmaskStruct(bbi))

  /** True if the masks share a set bit; streams bits lazily via [[IntervalIterator]] over the overlap window. */
  def intersection_any_v1(a_bitmask: LiteBitmaskStruct, b_bitmask: LiteBitmaskStruct): Boolean =
    intersectionFromEndpoints(a_bitmask, b_bitmask).getOrElse {
      val (lo, hi) = overlapWindow(a_bitmask, b_bitmask)
      sortedBitsIntersect(
        bitsOf(new IntervalIterator(a_bitmask.intervals, lo, hi)).buffered,
        bitsOf(new IntervalIterator(b_bitmask.intervals, lo, hi)).buffered)
    }

  /** Null-tolerant intersection test (registered as the `intersection_any` UDF); any null input gives false. */
  val intersection_any: (Row, Row) => Boolean = (a_bitmask: Row, b_bitmask: Row) =>
    (a_bitmask, b_bitmask) match {
      case (null, null) => false
      case (null, _)    => false
      case (_, null)    => false
      // case (_, _) => intersection_any_v0(a_bitmask, b_bitmask)
      case (_, _)       => intersection_any_v1(a_bitmask, b_bitmask)
    }

  /** Registers `logical_or` (struct result) and `intersection_any` (boolean result) as Spark SQL UDFs. */
  def registerAll(sparkSession: SparkSession): Unit = {
    sparkSession.udf.register("logical_or", logical_or(_: Row, _: Row), LiteBitmaskSchema)
    // intersection_any yields a Boolean, so its declared return type must be BooleanType, not the struct.
    sparkSession.udf.register("intersection_any", intersection_any(_: Row, _: Row), BooleanType)
  }
}