Something went wrong on our end
-
Eric Pershey authored
DataUDF.scala 9.41 KiB
/*
* Copyright (C) 2024, UChicago Argonne, LLC
* Licensed under the 3-clause BSD license. See accompanying LICENSE.txt file
* in the top-level directory.
*/
package Octeres
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.types._
import scala.collection.immutable.ArraySeq
import scala.collection.mutable.ListBuffer
object DataUDF {
/**
 * Interval-based description of a bitmask.
 *
 * @param length        length recorded for the mask (presumably the total number of
 *                      bits -- TODO confirm against callers)
 * @param bit_count     number of set bits; the validator in this file expects it to equal
 *                      the summed size of all intervals
 * @param intervals     inclusive `[start, end]` pairs describing runs of set bits
 * @param bit_idx_first start of the first interval, or -1 when no bit is set
 * @param bit_idx_last  end of the last interval, or -1 when no bit is set
 */
case class LiteBitmaskStruct(
  length: Int,
  bit_count: Int,
  intervals: List[List[Int]],
  bit_idx_first: Int,
  bit_idx_last: Int
)
/** Spark SQL schema derived from the fields of [[LiteBitmaskStruct]]. */
val LiteBitmaskSchema: StructType = Encoders.product[LiteBitmaskStruct].schema
// val LiteBitmaskSchema: StructType = ScalaReflection.schemaFor[LiteBitmaskStruct].dataType.asInstanceOf[StructType]
/**
 * Encoder for [[LiteBitmaskStruct]] (inferred type: Encoder[LiteBitmaskStruct]).
 *
 * A product encoder is used because LiteBitmaskStruct is a case class.
 * `Encoders.bean(LiteBitmaskStruct.getClass)` targeted the class of the companion
 * object rather than the case class itself, so it could not describe the struct's fields.
 */
val LiteBitmaskEncoder = Encoders.product[LiteBitmaskStruct]
/**
 * Builds a schema-carrying Row from the individual bitmask fields.
 *
 * @param length        value for the `length` field
 * @param bit_count     value for the `bit_count` field
 * @param intervals     value for the `intervals` field
 * @param bit_idx_first value for the `bit_idx_first` field
 * @param bit_idx_last  value for the `bit_idx_last` field
 * @return a GenericRowWithSchema using LiteBitmaskSchema
 */
def getLiteBitmaskRow(length: Int, bit_count: Int, intervals: List[List[Int]], bit_idx_first: Int, bit_idx_last: Int): Row = {
  // Values are listed in the same order as the LiteBitmaskStruct fields the schema is derived from.
  val fieldValues = Array[Any](length, bit_count, intervals, bit_idx_first, bit_idx_last)
  new GenericRowWithSchema(fieldValues, LiteBitmaskSchema)
}
/**
 * Lazily interleaves two interval lists into one sequence ordered by interval start.
 *
 * At each step the pending interval with the smaller start is emitted (ties go to
 * `alst`); once one list is exhausted the remainder of the other is emitted. Only the
 * first two elements of each interval are read, so each emitted element is a
 * two-element `List(start, end)`. Overlapping or adjacent intervals are NOT merged here.
 *
 * The value -2 is used internally as the "list exhausted" marker, so interval starts
 * are assumed never to be -2 (presumably they are non-negative bit indices -- TODO confirm).
 *
 * @param alst first interval list; assumed to be ordered by start
 * @param blst second interval list; assumed to be ordered by start
 * @return a lazily evaluated sequence containing every interval of both inputs
 */
def getCombinedIntervalsUntilAllConsumed(alst: List[List[Int]], blst: List[List[Int]]): Seq[List[Int]] = {
  // Vector gives constant-time positional access; indexing a List on every step is O(n).
  val avec = alst.toVector
  val bvec = blst.toVector
  var aidx = 0
  var bidx = 0
  val exit_list = List(-2, -2)
  LazyList.continually {
    val amin: Int = if (aidx < avec.length) avec(aidx)(0) else -2
    val amax: Int = if (aidx < avec.length) avec(aidx)(1) else -2
    val bmin: Int = if (bidx < bvec.length) bvec(bidx)(0) else -2
    val bmax: Int = if (bidx < bvec.length) bvec(bidx)(1) else -2
    (amin, bmin) match {
      case (-2, -2) =>
        // Both inputs are exhausted: emit the marker that takeWhile stops on.
        exit_list
      case (_, -2) =>
        aidx += 1
        List(amin, amax)
      case (-2, _) =>
        bidx += 1
        List(bmin, bmax)
      case _ if amin <= bmin =>
        aidx += 1
        List(amin, amax)
      case _ =>
        // Both lists still have intervals and b starts first. This branch previously
        // returned the exit marker, which silently dropped every remaining interval.
        bidx += 1
        List(bmin, bmax)
    }
  }.takeWhile(_ != exit_list)
}
/**
 * Exception type used by the bitmask helpers in this file.
 *
 * Constructors mirror the usual Exception variants: message only, message plus cause,
 * cause only (the message becomes `cause.toString`, or null for a null cause), and
 * no arguments (null message).
 */
class BitmaskError(message: String) extends Exception(message) {
  def this(message: String, cause: Throwable = null) = {
    this(message)
    initCause(cause)
  }

  def this(cause: Throwable) = {
    this(if (cause == null) null else cause.toString, cause)
  }

  def this() = {
    this(null: String)
  }
}
/**
 * Row-based entry point: converts the row with rowToLiteBitmaskStruct and validates
 * the resulting struct.
 *
 * @param row row holding the bitmask fields
 */
def validateLiteBitmaskSlotsLike(row: Row): Unit =
  validateLiteBitmaskSlotsLike(rowToLiteBitmaskStruct(row))
/**
 * Checks that a [[LiteBitmaskStruct]] is internally consistent.
 *
 * Invariants verified, in order:
 *  - when `bit_count > 0`: `intervals` is non-empty, `bit_idx_first` equals the start of
 *    the first interval and `bit_idx_last` equals the end of the last interval;
 *  - otherwise: `bit_idx_first` and `bit_idx_last` are both -1;
 *  - every interval `[start, end]` is inclusive and covers at least one bit;
 *  - consecutive intervals are not identical, not adjacent (`end + 1 == next start`),
 *    and the earlier one ends strictly before the later one starts;
 *  - the interval sizes add up to `bit_count`.
 *
 * @param obj the bitmask to validate
 * @throws BitmaskError describing the first violated invariant
 */
def validateLiteBitmaskSlotsLike(obj: LiteBitmaskStruct): Unit = {
  val intervals = obj.intervals
  if (obj.bit_count > 0) {
    // Surface an empty interval list as a validation error rather than letting
    // intervals.head fail with NoSuchElementException.
    if (intervals.isEmpty) {
      throw new BitmaskError(s"bit_count:${obj.bit_count} > 0 but intervals is empty")
    }
    if (obj.bit_idx_first != intervals.head.head) {
      throw new BitmaskError(s"bit_idx_first:${obj.bit_idx_first} != intervals.head.head:${intervals.head.head}")
    }
    if (obj.bit_idx_last != intervals.last.last) {
      throw new BitmaskError(s"bit_idx_last:${obj.bit_idx_last} != intervals.last.last:${intervals.last.last}")
    }
  } else {
    if (obj.bit_idx_first != -1) {
      throw new BitmaskError(s"bit_idx_first:${obj.bit_idx_first} != -1")
    }
    if (obj.bit_idx_last != -1) {
      throw new BitmaskError(s"bit_idx_last:${obj.bit_idx_last} != -1")
    }
  }
  var bit_count = 0
  // Walk the suffixes of the list so each interval is seen together with its successor
  // without positional lookups into the List.
  intervals.tails.foreach {
    case alar :: rest =>
      val al = alar.head
      val ar = alar.last
      val sub_bit_count = (ar - al) + 1
      // sub_bit_count < 1 is the same condition as al > ar, so one check covers both.
      if (sub_bit_count < 1) {
        throw new BitmaskError(s"negative interval: [$al, $ar]")
      }
      bit_count += sub_bit_count
      val blbr: Option[List[Int]] = rest.headOption
      blbr.foreach { next =>
        val bl: Int = next.head
        val br: Int = next.last
        if (al == bl && ar == br) {
          throw new BitmaskError(s"duplicate interval: [$al, $ar]")
        }
        if ((ar + 1) == bl) {
          throw new BitmaskError(s"interval not merged: ${ar}+1 == ${bl} for $alar->$blbr")
        }
        // >= rather than >: two intervals sharing an endpoint overlap on that bit.
        if (ar >= bl) {
          throw new BitmaskError(s"interval out of order or not merged: ${ar} >= ${bl} for $alar->$blbr")
        }
      }
    case Nil => // the final (empty) suffix: nothing left to check
  }
  if (bit_count != obj.bit_count) {
    throw new BitmaskError(s"bit_count:${bit_count} != obj.bit_count:${obj.bit_count}")
  }
}
/**
 * Row-based overload: converts both rows with rowToLiteBitmaskStruct and delegates to
 * the struct-based `intervals_or_v0`.
 *
 * @param length length to record on the resulting bitmask
 * @param abi    row holding the first bitmask
 * @param bbi    row holding the second bitmask
 * @return the result of the struct-based overload
 */
def intervals_or_v0(length: Int, abi: Row, bbi: Row): LiteBitmaskStruct =
  intervals_or_v0(length, rowToLiteBitmaskStruct(abi), rowToLiteBitmaskStruct(bbi))
/**
 * Logical OR of two interval-encoded bitmasks.
 *
 * The intervals of both inputs are taken from getCombinedIntervalsUntilAllConsumed and
 * then coalesced: an interval that overlaps or directly follows the previous one
 * (`start <= previous end + 1`) extends it; any other interval starts a new run.
 * The inputs' own `length` fields are not consulted; the `length` argument is stored
 * on the result as-is.
 *
 * Example from Spark: ([8,3,ArraySeq(ArraySeq(0, 2)),0,2],[8,3,ArraySeq(ArraySeq(0, 2)),0,2])
 *
 * @param length    length to record on the resulting bitmask
 * @param bitmask_a first operand; its intervals are assumed to be ordered by start
 * @param bitmask_b second operand; its intervals are assumed to be ordered by start
 * @return a bitmask with coalesced intervals, their total bit count, and the first/last
 *         set bit index (-1 for both when there are no intervals)
 */
def intervals_or_v0(length: Int, bitmask_a: LiteBitmaskStruct, bitmask_b: LiteBitmaskStruct): LiteBitmaskStruct = {
  val combined: Seq[List[Int]] = getCombinedIntervalsUntilAllConsumed(
    bitmask_a.intervals,
    bitmask_b.intervals)
  // Coalesced intervals are accumulated newest-first so that extending the most recent
  // one is a constant-time head replacement instead of an indexed update on a buffer.
  val (newestFirst, bitCount) = combined.foldLeft((List.empty[List[Int]], 0)) {
    case ((List(prevStart, prevEnd) :: older, count), List(nextStart, nextEnd))
        if nextStart <= prevEnd + 1 =>
      // Overlapping or adjacent: grow the previous interval and count only the new bits.
      val newEnd = Math.max(prevEnd, nextEnd)
      (List(prevStart, newEnd) :: older, count + (newEnd - prevEnd))
    case ((acc, count), List(nextStart, nextEnd)) =>
      // First interval, or a gap after the previous one: start a new run.
      (List(nextStart, nextEnd) :: acc, count + (nextEnd - nextStart + 1))
    case (state, _) =>
      // Elements that are not a two-element interval are skipped.
      state
  }
  val intervals = newestFirst.reverse
  val bitIdxFirst = intervals.headOption.map(_.head).getOrElse(-1)
  val bitIdxLast = intervals.lastOption.map(_.last).getOrElse(-1)
  LiteBitmaskStruct(length, bitCount, intervals, bitIdxFirst, bitIdxLast)
}
/**
 * Converts a [[LiteBitmaskStruct]] into a Row that carries LiteBitmaskSchema.
 *
 * @param bitmaskStruct the struct to convert
 * @return a GenericRowWithSchema holding the struct's five fields in declaration order
 */
def liteBitmaskStructToRow(bitmaskStruct: LiteBitmaskStruct): Row = {
  // productIterator yields the case-class fields in declaration order:
  // length, bit_count, intervals, bit_idx_first, bit_idx_last.
  val fieldValues: Array[Any] = bitmaskStruct.productIterator.toArray
  new GenericRowWithSchema(fieldValues, LiteBitmaskSchema)
}
/**
 * Converts a Row into a [[LiteBitmaskStruct]].
 *
 * The `intervals` column is accepted as any nested `scala.collection.Seq` (List,
 * immutable or mutable ArraySeq, ...) and normalised to `List[List[Int]]`. Element types
 * are erased at runtime, so the nested values are checked while copying instead of
 * relying on unchecked type patterns. Which Seq implementation Spark actually delivers
 * depends on the Spark/Scala version in use -- TODO confirm.
 *
 * Rows that do not match the five-field positional shape fall back to reading the Int
 * fields by position and `intervals` by field name (which requires a row schema).
 *
 * @param row row holding length, bit_count, intervals, bit_idx_first, bit_idx_last
 * @return the equivalent struct
 * @throws IllegalArgumentException if `intervals` contains something other than
 *                                  sequences of Int
 */
def rowToLiteBitmaskStruct(row: Row): LiteBitmaskStruct = {
  def toIntList(inner: scala.collection.Seq[Any]): List[Int] =
    inner.iterator.map {
      case bit: Int => bit
      case other => throw new IllegalArgumentException(s"interval bound is not an Int: $other")
    }.toList

  def toIntervalList(outer: scala.collection.Seq[Any]): List[List[Int]] =
    outer.iterator.map {
      case inner: scala.collection.Seq[_] => toIntList(inner)
      case other => throw new IllegalArgumentException(s"interval is not a sequence: $other")
    }.toList

  row match {
    case Row(length: Int, bit_count: Int, intervals: scala.collection.Seq[_], bit_idx_first: Int, bit_idx_last: Int) =>
      LiteBitmaskStruct(length, bit_count, toIntervalList(intervals), bit_idx_first, bit_idx_last)
    case _ =>
      val length = row.getInt(0)
      val bitCount = row.getInt(1)
      val rawIntervals = row.getAs[scala.collection.Seq[Any]]("intervals")
      // A null intervals column is passed through unchanged, as before.
      val intervals: List[List[Int]] = if (rawIntervals == null) null else toIntervalList(rawIntervals)
      LiteBitmaskStruct(length, bitCount, intervals, row.getInt(3), row.getInt(4))
  }
}
/**
 * Null-tolerant OR of two bitmask rows.
 *
 * Returns null when both rows are null, the converted non-null row when exactly one is
 * null, and otherwise the result of `intervals_or_v0` using the `length` field of the
 * first row.
 */
val logical_or: (Row, Row) => LiteBitmaskStruct = (a_bitmask: Row, b_bitmask: Row) =>
  if (a_bitmask == null && b_bitmask == null) null
  else if (a_bitmask == null) rowToLiteBitmaskStruct(b_bitmask)
  else if (b_bitmask == null) rowToLiteBitmaskStruct(a_bitmask)
  else intervals_or_v0(a_bitmask.getAs[Int]("length"), a_bitmask, b_bitmask)
/**
 * Registers the UDFs defined in this object on the given session. Currently this is
 * only `logical_or`, registered with LiteBitmaskSchema as its return type.
 *
 * @param sparkSession session whose UDF registry receives the functions
 */
def registerAll(sparkSession: SparkSession): Unit = {
  // A function literal (not the function value itself) is passed, as in the original call.
  sparkSession.udf.register("logical_or", (a: Row, b: Row) => logical_or(a, b), LiteBitmaskSchema)
}
}