/* * Copyright (C) 2024, UChicago Argonne, LLC * Licensed under the 3-clause BSD license. See accompanying LICENSE.txt file * in the top-level directory. */ package Octeres import org.apache.spark.sql.SparkSession import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema import org.apache.spark.sql.types._ import scala.collection.immutable.ArraySeq import scala.collection.mutable.ListBuffer object DataUDF { case class LiteBitmaskStruct(length: Int, bit_count: Int, intervals: List[List[Int]], bit_idx_first: Int, bit_idx_last: Int) val LiteBitmaskSchema: StructType = Encoders.product[LiteBitmaskStruct].schema // val LiteBitmaskSchema: StructType = ScalaReflection.schemaFor[LiteBitmaskStruct].dataType.asInstanceOf[StructType] val LiteBitmaskEncoder = Encoders.bean(LiteBitmaskStruct.getClass) def getLiteBitmaskRow(length: Int, bit_count: Int, intervals: List[List[Int]], bit_idx_first: Int, bit_idx_last: Int): Row = { new GenericRowWithSchema(Array(length, bit_count, intervals, bit_idx_first, bit_idx_last), LiteBitmaskSchema) } def getCombinedIntervalsUntilAllConsumed(alst: List[List[Int]], blst: List[List[Int]]): Seq[List[Int]] = { var aidx = 0 var bidx = 0 val exit_list = List(-2, -2) LazyList.continually { val amin: Int = if (aidx < alst.length) alst(aidx)(0) else -2 val amax: Int = if (aidx < alst.length) alst(aidx)(1) else -2 val bmin: Int = if (bidx < blst.length) blst(bidx)(0) else -2 val bmax: Int = if (bidx < blst.length) blst(bidx)(1) else -2 (amin, bmin) match { case (-2, -2) => exit_list case (_, -2) => aidx += 1 List(amin, amax) case (-2, _) => bidx += 1 List(bmin, bmax) case (_, _) if (amin != -2 & amin <= bmin) => aidx += 1 List(amin, amax) case _ => exit_list } }.takeWhile(_ != exit_list) } class BitmaskError(message: String) extends Exception(message) { def this(message: String, cause: Throwable = null) { this(message) initCause(cause) } def this(cause: Throwable) = { this(Option(cause).map(_.toString).orNull, cause) } def this() = { this(null: String) } } def validateLiteBitmaskSlotsLike(row: Row): Unit = { val obj = rowToLiteBitmaskStruct(row) validateLiteBitmaskSlotsLike(obj) } def validateLiteBitmaskSlotsLike(obj: LiteBitmaskStruct): Unit = { val intervals = obj.intervals if (obj.bit_count > 0) { // println(s"obj.bit_count > 0: ${intervals} ${obj.bit_idx_first} ${obj.bit_idx_last}") // println(s"obj.bit_count > 0: ${intervals.head}") // println(s"obj.bit_count > 0: ${intervals.head.head}") if (obj.bit_idx_first != intervals.head.head) { throw new BitmaskError(s"bit_idx_first:${obj.bit_idx_first} != intervals.head.head:${intervals.head.head}") } if (obj.bit_idx_last != intervals.last.last) { throw new BitmaskError(s"bit_idx_first:${obj.bit_idx_first} != intervals.head.head:${intervals.head.head}") } } else { if (obj.bit_idx_first != -1) { throw new BitmaskError(s"bit_idx_first:${obj.bit_idx_first} != -1") } if (obj.bit_idx_last != -1) { throw new BitmaskError(s"bit_idx_first:${obj.bit_idx_first} != -1") } } // val len_intervals = intervals.length var bit_count = 0 for ((alar, i) <- intervals.zipWithIndex) { val al = alar.head val ar = alar.last val sub_bit_count = (ar - al) + 1 if (sub_bit_count < 1) { throw new BitmaskError(s"negative interval: [$al, $ar]") } bit_count += sub_bit_count if (al > ar) { throw new BitmaskError(s"negative interval: [$al, $ar]") } val blbr: Option[List[Int]] = intervals.lift(i + 1) blbr match { case Some(x: List[Int]) => { val bl: Int = x.head val br: Int = x.last if ((al == bl) & (ar == br)) { throw new BitmaskError(s"duplicate interval: [$al, $ar]") } if ((ar + 1) == bl) { throw new BitmaskError(s"interval not merged: ${ar}+1 == ${bl} for $alar->$blbr") } if (ar > bl) { throw new BitmaskError(s"interval out of order or not merged: ${ar} > ${bl} for $alar->$blbr") } } case None => Nil } } if (bit_count != obj.bit_count) { throw new BitmaskError(s"bit_count:${bit_count} != obj.bit_count:${obj.bit_count}") } } def intervals_or_v0(length: Int, abi: Row, bbi: Row): LiteBitmaskStruct = { val bitmask_a = rowToLiteBitmaskStruct(abi) val bitmask_b = rowToLiteBitmaskStruct(bbi) intervals_or_v0(length, bitmask_a, bitmask_b) } def intervals_or_v0(length: Int, bitmask_a: LiteBitmaskStruct, bitmask_b: LiteBitmaskStruct): LiteBitmaskStruct = { /* does a logical or of two bit intervals * From Spark ([8,3,ArraySeq(ArraySeq(0, 2)),0,2],[8,3,ArraySeq(ArraySeq(0, 2)),0,2]) */ val intervalGen: Seq[List[Int]] = getCombinedIntervalsUntilAllConsumed( bitmask_a.intervals, bitmask_b.intervals) val intervals = ListBuffer.empty[List[Int]] var bitCount = 0 var bitIdxFirst = -1 var bitIdxLast = -1 var prev_start = -1 var prev_end = -1 val itr = intervalGen.iterator for (List(next_start, next_end) <- itr) { (next_start, next_end) match { case (-2, -2) => // case where there is nothing bitCount = 0 bitIdxFirst = -1 bitIdxLast = -1 case _ if (prev_start == -1 & prev_end == -1) => // Initial variable, setting previous prev_start = next_start prev_end = next_end intervals += List(prev_start, prev_end) bitCount = prev_end - prev_start + 1 bitIdxFirst = intervals.head.head bitIdxLast = intervals.last.last case _ => if (next_start <= prev_end + 1) { val new_end = Math.max(prev_end, next_end) intervals(intervals.length - 1) = List(prev_start, new_end) bitCount += new_end - prev_end prev_end = new_end } else { intervals += List(next_start, next_end) bitCount += next_end - next_start + 1 prev_start = next_start prev_end = next_end } bitIdxFirst = intervals.head.head bitIdxLast = intervals.last.last } } LiteBitmaskStruct(length, bitCount, intervals.map(_.toList).toList, bitIdxFirst, bitIdxLast) } def liteBitmaskStructToRow(bitmaskStruct: LiteBitmaskStruct): Row = { /* convert a LiteBitmaskStruct to a Row */ // might need this somehow: import sparkSession.implicits._ // row.asInstanceOf[LiteBitmaskStruct] // does not work val row: Row = new GenericRowWithSchema(Array(bitmaskStruct.length, bitmaskStruct.bit_count, bitmaskStruct.intervals, bitmaskStruct.bit_idx_first, bitmaskStruct.bit_idx_last), LiteBitmaskSchema) row } def rowToLiteBitmaskStruct(row: Row): LiteBitmaskStruct = { /* convert a Row to a LiteBitmaskStruct */ // might need this somehow: import sparkSession.implicits._ // row.asInstanceOf[LiteBitmaskStruct] // does not work val bitmaskStruct: LiteBitmaskStruct = row match { case Row(length: Int, bit_count: Int, intervals: List[List[Int]], bit_idx_first: Int, bit_idx_last: Int) => LiteBitmaskStruct(length, bit_count, intervals, bit_idx_first, bit_idx_last) case Row(length: Int, bit_count: Int, intervals: ArraySeq[ArraySeq[Int]], bit_idx_first: Int, bit_idx_last: Int) => LiteBitmaskStruct(length, bit_count, intervals.toList.map(_.toList), bit_idx_first, bit_idx_last) case _ => LiteBitmaskStruct(row.getInt(0), row.getInt(1), row.getAs[List[List[Int]]]("intervals"), row.getInt(3), row.getInt(4)) } bitmaskStruct } val logical_or: (Row, Row) => LiteBitmaskStruct = (a_bitmask: Row, b_bitmask: Row) => (a_bitmask, b_bitmask) match { case (null, null) => null case (null, _) => rowToLiteBitmaskStruct(b_bitmask) case (_, null) => rowToLiteBitmaskStruct(a_bitmask) case (_, _) => intervals_or_v0(a_bitmask.getAs("length"), a_bitmask, b_bitmask) } def registerAll(sparkSession: SparkSession) { sparkSession.udf.register("logical_or", logical_or(_: Row, _: Row), LiteBitmaskSchema) } }