/*
* Copyright (C) 2024, UChicago Argonne, LLC
* Licensed under the 3-clause BSD license. See accompanying LICENSE.txt file
* in the top-level directory.
*/
package Octeres
import Octeres.DataUDF.{LiteBitmaskSchema, intersection_any, liteBitmaskStructToRow, logical_or, validateLiteBitmaskSlotsLike}
import Octeres.Bitmask.BitmaskLite.{intersection_any_v0, intersection_any_v1, IntervalIterator}

import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Encoders, Row, SparkSession}
import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.prop.TableDrivenPropertyChecks._
import org.scalatest.matchers.should.Matchers._

import org.scalatest.prop.TableFor3
import scala.collection.immutable.ArraySeq

/** A deliberately tiny UDF, kept as a worked example for the tests in this file. */
object ExampleUDF {

  /** Pair of integer operands; its fields give the example schema its column names. */
  case class SimpleAddStruct(a: Int, b: Int)

  /** Spark schema obtained from the product encoder of [[SimpleAddStruct]]. */
  val SimpleAddSchema: StructType = Encoders.product[SimpleAddStruct].schema

  /** Returns the sum of `a` and `b`. */
  def simpleAdd(a: Int, b: Int): Int = {
    a + b
  }
}
class DataUDFTest extends AnyFunSuite {
test("test_add") {
  // Smoke test for the example adder: 1 + 2 must come back as 3.
  assert(ExampleUDF.simpleAdd(1, 2) == 3)
}
test("test_simple") {
  // Registers the example adder as the SQL function "SimpleAdd" and applies it to a
  // one-row DataFrame. The result is only printed; nothing is asserted.
  val spark = SparkSession.builder().appName("DataUDFTest").master("local").getOrCreate()
  val rows = Seq(Row(1, 2))
  val input = spark.createDataFrame(
    spark.sparkContext.parallelize(rows),
    schema = ExampleUDF.SimpleAddSchema
  )
  spark.udf.register("SimpleAdd", ExampleUDF.simpleAdd(_: Int, _: Int), IntegerType)
  // Columns `a` and `b` come from SimpleAddSchema; `c` receives the UDF output.
  input.withColumn("c", expr("SimpleAdd(a, b)")).show()
}
test("test_logical_or_00") {
  // Prints the result of DataUDF.logical_or for two bitmasks with intervals [0, 2] and
  // [1, 3], and for each one-sided null argument. Nothing is asserted: the test only
  // fails if one of the calls throws.
  println("test_logical_or_00")
  val bitmask_length = 8
  val a_bitmask: Row = DataUDF.getLiteBitmaskRow(bitmask_length, 3, ArraySeq(ArraySeq[Long](0, 2)), 0, 2)
  val b_bitmask: Row = DataUDF.getLiteBitmaskRow(bitmask_length, 3, ArraySeq(ArraySeq[Long](1, 3)), 1, 3)
  println(s"a_bitmask=${a_bitmask}")
  println(s"b_bitmask=${b_bitmask}")
  println(DataUDF.logical_or(a_bitmask, b_bitmask))
  // Null on either side is passed through deliberately to exercise null handling.
  println(DataUDF.logical_or(null, b_bitmask))
  println(DataUDF.logical_or(a_bitmask, null))
}
test("test_logical_or_01") {
  // Same inputs and calls as test_logical_or_00: prints DataUDF.logical_or for two
  // bitmasks and for each one-sided null argument. Nothing is asserted.
  // The banner previously printed "test_logical_or_00" (copy-paste slip), which made
  // the console output of the two tests indistinguishable.
  println("test_logical_or_01")
  val bitmask_length = 8
  val a_bitmask: Row = DataUDF.getLiteBitmaskRow(bitmask_length, 3, ArraySeq(ArraySeq[Long](0, 2)), 0, 2)
  val b_bitmask: Row = DataUDF.getLiteBitmaskRow(bitmask_length, 3, ArraySeq(ArraySeq[Long](1, 3)), 1, 3)
  println(s"a_bitmask=${a_bitmask}")
  println(s"b_bitmask=${b_bitmask}")
  println(DataUDF.logical_or(a_bitmask, b_bitmask))
  println(DataUDF.logical_or(null, b_bitmask))
  println(DataUDF.logical_or(a_bitmask, null))
}
test("test_logical_or_02") {
  // Prints DataUDF.logical_or for a populated bitmask against the bitmask returned by
  // getLiteBitmaskZeros, plus the null/null and self/self combinations.
  // Nothing is asserted: the test only fails if one of the calls throws.
  // The banner previously printed "test_logical_or_01" (copy-paste slip).
  println("test_logical_or_02")
  val bitmask_length = 8
  val a_bitmask = DataUDF.getLiteBitmaskRow(bitmask_length, 3, ArraySeq(ArraySeq[Long](0, 2)), 0, 2)
  val b_bitmask = DataUDF.getLiteBitmaskZeros(bitmask_length)
  // val b_bitmask = DataUDF.getLiteBitmaskRow(bitmask_length, 0, ArraySeq[Long](), -1, -1)
  println(s"a_bitmask=${a_bitmask}")
  println(s"b_bitmask=${b_bitmask}")
  println(DataUDF.logical_or(a_bitmask, b_bitmask))
  println(DataUDF.logical_or(null, b_bitmask))
  println(DataUDF.logical_or(a_bitmask, null))
  println(DataUDF.logical_or(null, null))
  println(DataUDF.logical_or(a_bitmask, a_bitmask))
  println(DataUDF.logical_or(b_bitmask, b_bitmask))
}
test("test_bitmask") {
  // Builds a five-row DataFrame with a LiteBitmask struct column, exposes it through two
  // temp views, and prints the output of the bitmask SQL functions. Nothing is asserted.
  // NOTE(review): the SQL functions `intersection_any` and `logical_or` are not registered
  // anywhere in this test; presumably they are registered elsewhere. Confirm.
  println("test_bitmask")
  val bitmask_length = 8
  // val a_bitmask = DataUDF.getLiteBitmaskRow(bitmask_length, 3, Array(Array(0, 2)), 0, 2)
  // val b_bitmask = DataUDF.getLiteBitmaskRow(bitmask_length, 3, Array(Array(1, 3)), 1, 3)
  val a_bitmask = DataUDF.getLiteBitmaskRow(bitmask_length, 3, ArraySeq(ArraySeq[Long](0, 2)), 0, 2)
  val b_bitmask = DataUDF.getLiteBitmaskRow(bitmask_length, 3, ArraySeq(ArraySeq[Long](1, 3)), 1, 3)
  val c_bitmask = DataUDF.getLiteBitmaskRow(bitmask_length, 5, ArraySeq(ArraySeq[Long](0, 4)), 0, 4)
  val d_bitmask = DataUDF.getLiteBitmaskZeros(bitmask_length)
  val e_bitmask = DataUDF.getLiteBitmaskRow(bitmask_length, 6, ArraySeq(ArraySeq[Long](0, 5)), 0, 5)
  val schema = StructType(List(
    StructField("a", IntegerType, nullable = true),
    StructField("b", IntegerType, nullable = true),
    StructField("bitmask", LiteBitmaskSchema, nullable = true)
  ))
  val deck = Seq(
    Row(1, 2, a_bitmask),
    Row(2, 4, b_bitmask),
    Row(3, 8, c_bitmask),
    Row(4, 16, d_bitmask),
    Row(5, 32, e_bitmask),
  )
  val sparkSession = SparkSession.builder().appName("DataUDFTest").master("local").getOrCreate()
  val df = sparkSession.createDataFrame(sparkSession.sparkContext.parallelize(deck), schema = schema)
  // The same DataFrame is published under two view names so it can be joined with itself below.
  df.createOrReplaceTempView("temp_view_a")
  df.createOrReplaceTempView("temp_view_b")
  df.show(32, truncate=false)
  println("selecting from view")
  sparkSession.sql("select bitmask.intervals from temp_view_a").show(32, truncate=false)
  val df2 = sparkSession.sql("Select * from temp_view_a")
  // https://spark.apache.org/docs/latest/sql-ref-datatypes.html
  println("selecting from view again")
  df2.show(32, truncate=false)
  println("direct intersection_any")
  df2.withColumn("r_bitmask", expr("intersection_any(bitmask, bitmask)")).show(32, truncate=false)
  println("direct logical_or")
  df2.withColumn("r_bitmask", expr("logical_or(bitmask, bitmask)")).show(32, truncate=false)
  println("about to select..")
  // The join below carries no ON condition.
  val df3 = sparkSession.sql("select tva.a, tva.b, " +
    "tva.bitmask as a_bitmask, " +
    "tvb.bitmask as b_bitmask, " +
    "logical_or(tva.bitmask, tvb.bitmask) as c_bitmask " +
    // "intersection_any(tva.bitmask, tvb.bitmask) " +
    "from temp_view_a tva join temp_view_b tvb"
  )
  // BROKEN |2 |4 |{8, 3, [[1, 3]], 1, 3}|{8, 6, [[0, 5]], 0, 5}|{8, 0, [], -1, -1} |
  df3.show(32, truncate=false)
  df3.printSchema()
}
test("test_bitmask_logical_or") {
  // Table-driven check of logical_or: each row is (left operand, right operand, expected OR).
  // For every row the inputs and the expected value are validated, logical_or is applied,
  // its result is converted back to a Row, validated, and compared with the expectation.
  val bitmaskTable: TableFor3[Row, Row, Row] = Table(
    ("bitmask0", "bitmask1", "bitmask_or"),
    (
      DataUDF.getLiteBitmaskRow(8, 3, ArraySeq(ArraySeq[Long](0, 2)), 0, 2),
      DataUDF.getLiteBitmaskRow(8, 3, ArraySeq(ArraySeq[Long](1, 3)), 1, 3),
      DataUDF.getLiteBitmaskRow(8, 4, ArraySeq(ArraySeq[Long](0, 3)), 0, 3)
    ), (
      DataUDF.getLiteBitmaskRow(8, 5, ArraySeq(ArraySeq[Long](0, 4)), 0, 4),
      DataUDF.getLiteBitmaskZeros(8),
      DataUDF.getLiteBitmaskRow(8, 5, ArraySeq(ArraySeq[Long](0, 4)), 0, 4)
    ), (
      DataUDF.getLiteBitmaskZeros(8),
      DataUDF.getLiteBitmaskZeros(8),
      DataUDF.getLiteBitmaskZeros(8)
    ), (
      DataUDF.getLiteBitmaskRow(8, 4, ArraySeq(ArraySeq[Long](0, 1), ArraySeq[Long](3, 4)), 0, 4),
      DataUDF.getLiteBitmaskRow(8, 2, ArraySeq(ArraySeq[Long](1, 2)), 1, 2),
      DataUDF.getLiteBitmaskRow(8, 5, ArraySeq(ArraySeq[Long](0, 4)), 0, 4)
    // ), ( // this will be a problem 8 is the max.
    //   DataUDF.getLiteBitmaskRow(8, 4, ArraySeq(ArraySeq[Long](1, 2), ArraySeq[Long](5, 6)), 1, 6),
    //   DataUDF.getLiteBitmaskRow(8, 4, ArraySeq(ArraySeq[Long](3, 4), ArraySeq[Long](7, 8)), 3, 4),
    //   DataUDF.getLiteBitmaskRow(8, 8, ArraySeq(ArraySeq[Long](1, 8)), 1, 8)
    ), (
      DataUDF.getLiteBitmaskRow(16, 4, ArraySeq(ArraySeq[Long](1, 2), ArraySeq[Long](5, 6)), 1, 6),
      DataUDF.getLiteBitmaskRow(16, 4, ArraySeq(ArraySeq[Long](3, 4), ArraySeq[Long](7, 8)), 3, 8),
      DataUDF.getLiteBitmaskRow(16, 8, ArraySeq(ArraySeq[Long](1, 8)), 1, 8)
    )
  )
  forAll(bitmaskTable) { (bitmask0, bitmask1, bitmask_or: Row) =>
    println(s"bitmask0: ${bitmask0}")
    println(s"bitmask1: ${bitmask1}")
    println(s"bitmask_or: ${bitmask_or}")
    validateLiteBitmaskSlotsLike(bitmask0)
    validateLiteBitmaskSlotsLike(bitmask1)
    validateLiteBitmaskSlotsLike(bitmask_or)
    val bitmask_r = logical_or(bitmask0, bitmask1)
    // logical_or's result is converted to a Row so it can be compared with the expected Row.
    val bitmask_rr = liteBitmaskStructToRow(bitmask_r)
    println(s"bitmask_rr: ${bitmask_rr}")
    validateLiteBitmaskSlotsLike(bitmask_rr)
    bitmask_rr should equal(bitmask_or)
    // Printed for information only; the intersection result is not asserted in this test.
    println(intersection_any(bitmask0, bitmask1))
  }
}
test("test_getIntervalBitsGenerator") {
  // Drives IntervalIterator over several interval lists and (bit_idx_first, bit_idx_last)
  // pairs, printing its state and every value it yields. Nothing is asserted: the test
  // only fails if the iterator throws.
  val bitmaskTable: TableFor3[ArraySeq[ArraySeq[Long]], Long, Long] = Table(
    ("intervals", "bit_idx_first", "bit_idx_last"),
    (ArraySeq(ArraySeq[Long](1, 2), ArraySeq[Long](4, 5)), 0, -1),
    (ArraySeq(ArraySeq[Long](2, 4), ArraySeq[Long](8, 10)), 4, 9),
    (ArraySeq(ArraySeq[Long](1, 2), ArraySeq[Long](5, 6)), 3, 6),
    (ArraySeq(ArraySeq[Long](3, 4), ArraySeq[Long](7, 8)), 3, 6)
  )
  forAll(bitmaskTable) { (intervals, bit_idx_first, bit_idx_last) =>
    println(intervals)
    val ii = new IntervalIterator(intervals, bit_idx_first, bit_idx_last)
    ii.printState()
    // NOTE(review): assumes get() advances the iterator so that isEmpty eventually
    // becomes true; otherwise this loop would not terminate. Confirm.
    while (!ii.isEmpty) {
      println(s"i: ${ii.get()}")
    }
  }
}
test("test_bitmask_intersection_any") {

Eric Pershey
committed
val bitmaskTable: TableFor3[Row, Row, Boolean] = Table(

Eric Pershey
committed
("bitmask0", "bitmask1", "logical_truth"),
(

Eric Pershey
committed
DataUDF.getLiteBitmaskRow(8, 3, ArraySeq(ArraySeq[Long](0, 2)), 0, 2),
DataUDF.getLiteBitmaskRow(8, 3, ArraySeq(ArraySeq[Long](1, 3)), 1, 3),

Eric Pershey
committed
true
), (

Eric Pershey
committed
DataUDF.getLiteBitmaskRow(8, 5, ArraySeq(ArraySeq[Long](0, 4)), 0, 4),
DataUDF.getLiteBitmaskZeros(8),

Eric Pershey
committed
false
), (

Eric Pershey
committed
DataUDF.getLiteBitmaskZeros(8),
DataUDF.getLiteBitmaskZeros(8),

Eric Pershey
committed
false
), (

Eric Pershey
committed
DataUDF.getLiteBitmaskRow(8, 4, ArraySeq(ArraySeq[Long](0, 1), ArraySeq[Long](3, 4)), 0, 4),
DataUDF.getLiteBitmaskRow(8, 2, ArraySeq(ArraySeq[Long](1, 2)), 1, 2),

Eric Pershey
committed
true
), (

Eric Pershey
committed
DataUDF.getLiteBitmaskRow(16, 4, ArraySeq(ArraySeq[Long](1, 2), ArraySeq[Long](5, 6)), 1, 6),
DataUDF.getLiteBitmaskRow(16, 4, ArraySeq(ArraySeq[Long](3, 4), ArraySeq[Long](7, 8)), 3, 8),

Eric Pershey
committed
false
)
)

Eric Pershey
committed
forAll(bitmaskTable) { (bitmask0, bitmask1, logical_truth: Boolean) => {
println(s"bitmask0: ${bitmask0}")
println(s"bitmask1: ${bitmask1}")
validateLiteBitmaskSlotsLike(bitmask0)
validateLiteBitmaskSlotsLike(bitmask1)
val truth = intersection_any(bitmask0, bitmask1)
println(s"truth: ${truth}")
truth should equal(logical_truth)
}
}