Newer
Older
/*
* Copyright (C) 2024, UChicago Argonne, LLC
* Licensed under the 3-clause BSD license. See accompanying LICENSE.txt file
* in the top-level directory.
*/
package Octeres

Eric Pershey
committed
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema

Eric Pershey
committed
import org.apache.spark.sql.types._

Eric Pershey
committed
import scala.collection.immutable.ArraySeq
import scala.collection.mutable.ListBuffer

Eric Pershey
committed
/** Compact bitmask representation: the set bits are stored as a list of [start, end] intervals.
 *
 * Invariants (as enforced by validateLiteBitmaskSlotsLike):
 *  - each interval is List(start, end) with start <= end, both inclusive;
 *  - intervals are sorted, disjoint and non-adjacent (fully merged);
 *  - bit_count is the total number of bits covered by the intervals;
 *  - bit_idx_first / bit_idx_last are the first interval's start and the last interval's end,
 *    or -1 when bit_count is 0.
 * `length` is carried through unchanged by the operations here; presumably the total number of
 * bits in the mask (TODO confirm).
 */
case class LiteBitmaskStruct(
    length: Int,
    bit_count: Int,
    intervals: List[List[Int]],
    bit_idx_first: Int,
    bit_idx_last: Int
)
/** Spark SQL schema of LiteBitmaskStruct, derived from the case class fields in declaration order. */
val LiteBitmaskSchema: StructType = Encoders.product[LiteBitmaskStruct].schema
// val LiteBitmaskSchema: StructType = ScalaReflection.schemaFor[LiteBitmaskStruct].dataType.asInstanceOf[StructType]
/** Dataset encoder for LiteBitmaskStruct.
 * A case class needs the product encoder: a bean encoder built from `LiteBitmaskStruct.getClass`
 * would describe the companion object's class, not the case class itself.
 */
val LiteBitmaskEncoder: Encoder[LiteBitmaskStruct] = Encoders.product[LiteBitmaskStruct]
/** Build a Row carrying LiteBitmaskSchema from the individual bitmask fields.
 * Values are stored positionally in the schema's field order.
 */
def getLiteBitmaskRow(length: Int, bit_count: Int, intervals: List[List[Int]], bit_idx_first: Int, bit_idx_last: Int): Row = {
  val fieldValues: Array[Any] = Array(length, bit_count, intervals, bit_idx_first, bit_idx_last)
  new GenericRowWithSchema(fieldValues, LiteBitmaskSchema)
}

Eric Pershey
committed
def getCombinedIntervalsUntilAllConsumed(alst: List[List[Int]], blst: List[List[Int]]): Seq[List[Int]] = {
  /* Merge two interval lists into one list ordered by interval start.
   *
   * Each input is expected to be sorted by start already, so this is the merge step of a
   * merge sort: O(n + m), and every interval of both inputs is emitted exactly once.
   * When two starts are equal, the interval from `alst` is emitted first.
   * Overlapping or adjacent intervals are NOT coalesced here; the caller does that.
   * Each emitted interval is normalized to List(start, end) from its first two elements.
   *
   * Both inputs must be fully drained: the next `blst` interval may start before the next
   * `alst` interval, and that case must emit the `blst` interval rather than stop.
   */
  @scala.annotation.tailrec
  def merge(as: List[List[Int]], bs: List[List[Int]], acc: List[List[Int]]): List[List[Int]] =
    (as, bs) match {
      case (Nil, Nil) =>
        acc.reverse
      case (a :: aRest, Nil) =>
        merge(aRest, bs, List(a(0), a(1)) :: acc)
      case (a :: aRest, b :: _) if a(0) <= b(0) =>
        merge(aRest, bs, List(a(0), a(1)) :: acc)
      case (_, b :: bRest) =>
        // alst is exhausted, or its next interval starts after b's
        merge(as, bRest, List(b(0), b(1)) :: acc)
    }

  merge(alst, blst, Nil)
}
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
/** Raised when a bitmask is internally inconsistent or an operation on it fails.
 * Constructors mirror java.lang.Exception: (message), (message, cause), (cause), ().
 */
class BitmaskError(message: String) extends Exception(message) {
  /** Error with a message and an underlying cause (which may be null). */
  def this(message: String, cause: Throwable = null) = {
    this(message)
    initCause(cause)
  }

  /** Error wrapping `cause`; the message is `cause.toString`, or null when `cause` is null. */
  def this(cause: Throwable) = {
    this(if (cause == null) null else cause.toString, cause)
  }

  /** Error with neither message nor cause. */
  def this() = {
    this(null: String)
  }
}
/** Row overload: decodes `row` with rowToLiteBitmaskStruct, then validates the decoded struct.
 * Throws BitmaskError if the bitmask is inconsistent.
 */
def validateLiteBitmaskSlotsLike(row: Row): Unit =
  validateLiteBitmaskSlotsLike(rowToLiteBitmaskStruct(row))
def validateLiteBitmaskSlotsLike(obj: LiteBitmaskStruct): Unit = {
  /* Check that a LiteBitmaskStruct is internally consistent; throws BitmaskError on the first
   * violation found.
   *
   * Checked invariants:
   *  - bit_count > 0: intervals is non-empty, bit_idx_first == first interval's start and
   *    bit_idx_last == last interval's end;
   *  - bit_count <= 0: bit_idx_first == bit_idx_last == -1;
   *  - every interval [l, r] has l <= r;
   *  - consecutive intervals are not duplicates, not adjacent (r + 1 == next l means they should
   *    have been merged) and not overlapping / out of order;
   *  - bit_count equals the sum of (r - l + 1) over all intervals.
   */
  val intervals = obj.intervals
  if (obj.bit_count > 0) {
    if (intervals.isEmpty) {
      throw new BitmaskError(s"bit_count:${obj.bit_count} > 0 but intervals is empty")
    }
    val firstIdx = intervals.head.head
    val lastIdx = intervals.last.last
    if (obj.bit_idx_first != firstIdx) {
      throw new BitmaskError(s"bit_idx_first:${obj.bit_idx_first} != intervals.head.head:${firstIdx}")
    }
    if (obj.bit_idx_last != lastIdx) {
      throw new BitmaskError(s"bit_idx_last:${obj.bit_idx_last} != intervals.last.last:${lastIdx}")
    }
  } else {
    if (obj.bit_idx_first != -1) {
      throw new BitmaskError(s"bit_idx_first:${obj.bit_idx_first} != -1")
    }
    if (obj.bit_idx_last != -1) {
      throw new BitmaskError(s"bit_idx_last:${obj.bit_idx_last} != -1")
    }
  }

  // Pair every interval with its successor (None for the last one). Zipping keeps this O(n);
  // indexing a List with lift(i + 1) on every step would be O(n^2).
  val successors: List[Option[List[Int]]] = intervals.drop(1).map(Some(_)) :+ None
  val bit_count = intervals.zip(successors).foldLeft(0) { case (count, (alar, blbr)) =>
    val al = alar.head
    val ar = alar.last
    if (al > ar) {
      throw new BitmaskError(s"negative interval: [$al, $ar]")
    }
    blbr.foreach { x =>
      val bl: Int = x.head
      val br: Int = x.last
      if (al == bl && ar == br) {
        throw new BitmaskError(s"duplicate interval: [$al, $ar]")
      }
      if ((ar + 1) == bl) {
        throw new BitmaskError(s"interval not merged: ${ar}+1 == ${bl} for $alar->$blbr")
      }
      if (ar > bl) {
        throw new BitmaskError(s"interval out of order or not merged: ${ar} > ${bl} for $alar->$blbr")
      }
    }
    count + (ar - al + 1)
  }
  if (bit_count != obj.bit_count) {
    throw new BitmaskError(s"bit_count:${bit_count} != obj.bit_count:${obj.bit_count}")
  }
}

Eric Pershey
committed
/** Row overload of the interval OR: decodes both rows with rowToLiteBitmaskStruct (a first,
 * then b) and delegates to the struct overload.
 */
def intervals_or_v0(length: Int, abi: Row, bbi: Row): LiteBitmaskStruct =
  intervals_or_v0(length, rowToLiteBitmaskStruct(abi), rowToLiteBitmaskStruct(bbi))
def intervals_or_v0(length: Int, bitmask_a: LiteBitmaskStruct, bitmask_b: LiteBitmaskStruct): LiteBitmaskStruct = {
  /* does a logical or of two bit intervals
   * From Spark ([8,3,ArraySeq(ArraySeq(0, 2)),0,2],[8,3,ArraySeq(ArraySeq(0, 2)),0,2])
   *
   * The intervals of both inputs are put in start order by getCombinedIntervalsUntilAllConsumed,
   * then overlapping or adjacent intervals are coalesced (e.g. [0, 2] + [3, 5] -> [0, 5]).
   * The result's bit_count is the number of bits covered; bit_idx_first / bit_idx_last are the
   * first and last covered bit, or -1 when nothing is covered. `length` is copied as-is.
   */
  val ordered: Seq[List[Int]] = getCombinedIntervalsUntilAllConsumed(
    bitmask_a.intervals,
    bitmask_b.intervals)

  // Coalesce while folding; the accumulator is kept newest-first so that extending the
  // current (most recent) interval is an O(1) operation on the list head.
  val coalescedReversed: List[(Int, Int)] = ordered.foldLeft(List.empty[(Int, Int)]) {
    case ((curStart, curEnd) :: done, List(start, end)) if start <= curEnd + 1 =>
      // overlaps or touches the current interval: extend it
      (curStart, math.max(curEnd, end)) :: done
    case (acc, List(start, end)) =>
      // a gap before this interval: start a new one
      (start, end) :: acc
    case (acc, _) =>
      // not a [start, end] pair; skipped
      acc
  }
  val coalesced: List[(Int, Int)] = coalescedReversed.reverse

  val bitCount: Int = coalesced.map { case (start, end) => end - start + 1 }.sum
  val bitIdxFirst: Int = coalesced.headOption.fold(-1)(_._1)
  val bitIdxLast: Int = coalescedReversed.headOption.fold(-1)(_._2)

  LiteBitmaskStruct(
    length,
    bitCount,
    coalesced.map { case (start, end) => List(start, end) },
    bitIdxFirst,
    bitIdxLast)
}
def liteBitmaskStructToRow(bitmaskStruct: LiteBitmaskStruct): Row = {
  /* convert a LiteBitmaskStruct to a Row
   * The case-class fields are copied positionally in declaration order, which is the same order
   * LiteBitmaskSchema is derived from, and the row is tagged with that schema.
   */
  // might need this somehow: import sparkSession.implicits._
  // row.asInstanceOf[LiteBitmaskStruct] // does not work
  val fieldValues: Array[Any] = bitmaskStruct.productIterator.toArray[Any]
  new GenericRowWithSchema(fieldValues, LiteBitmaskSchema)
}

Eric Pershey
committed
def rowToLiteBitmaskStruct(row: Row): LiteBitmaskStruct = {
  /* convert a Row to a LiteBitmaskStruct
   *
   * Fields are read by position: length(0), bit_count(1), intervals(2), bit_idx_first(3),
   * bit_idx_last(4), i.e. LiteBitmaskSchema's order.
   * Spark hands array columns to Scala code as a Seq (e.g. ArraySeq), while rows built in Scala
   * may hold a List, so any scala.collection.Seq is accepted for the intervals and for each
   * [start, end] pair. Elements are checked explicitly because type patterns such as
   * List[List[Int]] are unchecked under erasure. A null intervals value becomes Nil.
   *
   * Throws IllegalArgumentException if the intervals or an interval is not a sequence, or if a
   * bound is not an Int.
   */
  def asList(value: Any, what: String): List[Any] = value match {
    case seq: scala.collection.Seq[_] => seq.toList
    case null => throw new IllegalArgumentException(s"$what: expected a sequence, got null")
    case other => throw new IllegalArgumentException(s"$what: expected a sequence, got ${other.getClass.getName}")
  }

  val rawIntervals: Any = row.get(2)
  val intervals: List[List[Int]] =
    if (rawIntervals == null) Nil
    else asList(rawIntervals, "intervals").map { interval =>
      asList(interval, "interval").map {
        case bound: Int => bound
        case other => throw new IllegalArgumentException(s"interval bound: expected Int, got $other")
      }
    }

  LiteBitmaskStruct(row.getInt(0), row.getInt(1), intervals, row.getInt(3), row.getInt(4))
}
/** Null-tolerant OR of two bitmask rows, used as the "logical_or" UDF.
 * Both null -> null; one null -> the other row decoded unchanged; otherwise the interval OR,
 * keeping the `length` of the first bitmask.
 */
val logical_or: (Row, Row) => LiteBitmaskStruct = (a_bitmask: Row, b_bitmask: Row) => (a_bitmask, b_bitmask) match {
  case (null, null) => null
  case (null, _) => rowToLiteBitmaskStruct(b_bitmask)
  case (_, null) => rowToLiteBitmaskStruct(a_bitmask)
  // Explicit [Int]: intervals_or_v0 is overloaded, so the argument is typed without an expected
  // type and getAs's type parameter could otherwise be inferred as Nothing, which fails at runtime
  // with a ClassCastException.
  case (_, _) => intervals_or_v0(a_bitmask.getAs[Int]("length"), a_bitmask, b_bitmask)
}
/** Register this file's SQL UDFs ("logical_or") on `sparkSession`; results use LiteBitmaskSchema. */
def registerAll(sparkSession: SparkSession): Unit = {
  val registry = sparkSession.udf
  registry.register("logical_or", logical_or(_: Row, _: Row), LiteBitmaskSchema)
}