/*
 * Copyright (C) 2024, UChicago Argonne, LLC
 * Licensed under the 3-clause BSD license.  See accompanying LICENSE.txt file
 * in the top-level directory.
 */

package Octeres

import Octeres.DataUDF.{LiteBitmaskSchema, intersection_any, liteBitmaskStructToRow, logical_or, validateLiteBitmaskSlotsLike}
import Octeres.Bitmask.BitmaskLite.{intersection_any_v0, intersection_any_v1, IntervalIterator}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Encoders, Row, SparkSession}
import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.prop.TableDrivenPropertyChecks._
import org.scalatest.matchers.should.Matchers._
import org.scalatest.prop.TableFor3

import scala.collection.immutable.ArraySeq

/** A very simple UDF used as a minimal example of registering a function with Spark. */
object ExampleUDF {
  /** Shape of the two-integer input rows; only used to derive [[SimpleAddSchema]]. */
  case class SimpleAddStruct(a: Int, b: Int)

  /** Spark schema derived from [[SimpleAddStruct]] (columns `a` and `b`). */
  val SimpleAddSchema: StructType = Encoders.product[SimpleAddStruct].schema

  /** Returns `a + b`. */
  def simpleAdd(a: Int, b: Int): Int = a + b
}

/**
 * Tests for the example UDF and for the `DataUDF` lite-bitmask functions
 * (`logical_or`, `intersection_any`) plus `BitmaskLite.IntervalIterator`.
 *
 * Several tests are print-only smoke tests: they pass as long as nothing throws.
 */
class DataUDFTest extends AnyFunSuite {

  /** Gets (or creates) the local SparkSession shared by the Spark-backed tests. */
  private def localSparkSession(): SparkSession =
    SparkSession.builder().appName("DataUDFTest").master("local").getOrCreate()

  /**
   * Prints both bitmasks, then `logical_or` of the pair and of each bitmask
   * combined with a `null` operand on the other side.
   */
  private def printLogicalOrWithNulls(a_bitmask: Row, b_bitmask: Row): Unit = {
    println(s"a_bitmask=${a_bitmask}")
    println(s"b_bitmask=${b_bitmask}")
    println(DataUDF.logical_or(a_bitmask, b_bitmask))
    println(DataUDF.logical_or(null, b_bitmask))
    println(DataUDF.logical_or(a_bitmask, null))
  }

  test("test_add") {
    val result = ExampleUDF.simpleAdd(1, 2)
    assert(result == 3)
  }

  // Registers simpleAdd as the SQL function "SimpleAdd" and applies it to a one-row frame.
  test("test_simple") {
    val sparkSession = localSparkSession()
    val deck = Seq(
      Row(1, 2)
    )
    val df = sparkSession.createDataFrame(
      sparkSession.sparkContext.parallelize(deck), schema = ExampleUDF.SimpleAddSchema)
    sparkSession.udf.register("SimpleAdd", ExampleUDF.simpleAdd(_: Int, _: Int), IntegerType)
    val withSum = df.withColumn("c", expr("SimpleAdd(a, b)"))
    withSum.show()
    // Verify the UDF output instead of only printing it: 1 + 2 must land in column "c".
    withSum.collect().map(_.getAs[Int]("c")).toSeq should equal(Seq(3))
  }

  // Smoke test: logical_or on two overlapping bitmasks and with null operands.
  test("test_logical_or_00") {
    println("test_logical_or_00")
    val bitmask_length = 8
    val a_bitmask: Row = DataUDF.getLiteBitmaskRow(bitmask_length, 3, ArraySeq(ArraySeq[Long](0, 2)), 0, 2)
    val b_bitmask: Row = DataUDF.getLiteBitmaskRow(bitmask_length, 3, ArraySeq(ArraySeq[Long](1, 3)), 1, 3)
    printLogicalOrWithNulls(a_bitmask, b_bitmask)
  }

  // NOTE(review): same inputs as test_logical_or_00; kept so the suite's test names are unchanged.
  test("test_logical_or_01") {
    println("test_logical_or_01")
    val bitmask_length = 8
    val a_bitmask: Row = DataUDF.getLiteBitmaskRow(bitmask_length, 3, ArraySeq(ArraySeq[Long](0, 2)), 0, 2)
    val b_bitmask: Row = DataUDF.getLiteBitmaskRow(bitmask_length, 3, ArraySeq(ArraySeq[Long](1, 3)), 1, 3)
    printLogicalOrWithNulls(a_bitmask, b_bitmask)
  }

  // Smoke test: logical_or against a getLiteBitmaskZeros bitmask, two nulls, and a bitmask with itself.
  test("test_logical_or_02") {
    println("test_logical_or_02")
    val bitmask_length = 8
    val a_bitmask = DataUDF.getLiteBitmaskRow(bitmask_length, 3, ArraySeq(ArraySeq[Long](0, 2)), 0, 2)
    val b_bitmask = DataUDF.getLiteBitmaskZeros(bitmask_length)
    // val b_bitmask = DataUDF.getLiteBitmaskRow(bitmask_length, 0, ArraySeq[Long](), -1, -1)
    printLogicalOrWithNulls(a_bitmask, b_bitmask)
    println(DataUDF.logical_or(null, null))
    println(DataUDF.logical_or(a_bitmask, a_bitmask))
    println(DataUDF.logical_or(b_bitmask, b_bitmask))
  }

  // Smoke test: runs the registered bitmask UDFs through Spark SQL over a small frame.
  test("test_bitmask") {
    println("test_bitmask")
    val bitmask_length = 8
    // val a_bitmask = DataUDF.getLiteBitmaskRow(bitmask_length, 3, Array(Array(0, 2)), 0, 2)
    // val b_bitmask = DataUDF.getLiteBitmaskRow(bitmask_length, 3, Array(Array(1, 3)), 1, 3)
    val a_bitmask = DataUDF.getLiteBitmaskRow(bitmask_length, 3, ArraySeq(ArraySeq[Long](0, 2)), 0, 2)
    val b_bitmask = DataUDF.getLiteBitmaskRow(bitmask_length, 3, ArraySeq(ArraySeq[Long](1, 3)), 1, 3)
    val c_bitmask = DataUDF.getLiteBitmaskRow(bitmask_length, 5, ArraySeq(ArraySeq[Long](0, 4)), 0, 4)
    val d_bitmask = DataUDF.getLiteBitmaskZeros(bitmask_length)
    val e_bitmask = DataUDF.getLiteBitmaskRow(bitmask_length, 6, ArraySeq(ArraySeq[Long](0, 5)), 0, 5)
    val schema = StructType(List(
      StructField("a", IntegerType, nullable = true),
      StructField("b", IntegerType, nullable = true),
      StructField("bitmask", LiteBitmaskSchema, nullable = true)
    ))
    val deck = Seq(
      Row(0, 1, null), // cannot have null.
      Row(1, 2, a_bitmask),
      Row(2, 4, b_bitmask),
      Row(3, 8, c_bitmask),
      Row(4, 16, d_bitmask),
      Row(5, 32, e_bitmask)
    )
    val sparkSession = localSparkSession()
    val df = sparkSession.createDataFrame(sparkSession.sparkContext.parallelize(deck), schema = schema)
    DataUDF.registerAll()
    // The same frame is exposed under two view names so it can be joined with itself below.
    df.createOrReplaceTempView("temp_view_a")
    df.createOrReplaceTempView("temp_view_b")
    df.show(32, truncate = false)

    println("selecting from view")
    sparkSession.sql("select bitmask.intervals from temp_view_a").show(32, truncate = false)

    val df2 = sparkSession.sql("Select * from temp_view_a")
    // https://spark.apache.org/docs/latest/sql-ref-datatypes.html
    println("selecting from view again")
    df2.show(32, truncate = false)

    println("direct intersection_any")
    df2.withColumn("r_bitmask", expr("intersection_any(bitmask, bitmask)")).show(32, truncate = false)

    println("direct logical_or")
    df2.withColumn("r_bitmask", expr("logical_or(bitmask, bitmask)")).show(32, truncate = false)

    println("about to select..")
    // The join has no ON clause, so every row of temp_view_a is paired with every row of temp_view_b.
    val df3 = sparkSession.sql("select tva.a, tva.b, " +
      "tva.bitmask as a_bitmask, " +
      "tvb.bitmask as b_bitmask, " +
      "logical_or(tva.bitmask, tvb.bitmask) as c_bitmask " +
      // "intersection_any(tva.bitmask, tvb.bitmask) " +
      "from temp_view_a tva join temp_view_b tvb"
    )
    // BROKEN |2 |4 |{8, 3, [[1, 3]], 1, 3}|{8, 6, [[0, 5]], 0, 5}|{8, 0, [], -1, -1} |
    df3.show(32, truncate = false)
    df3.printSchema()
  }

  // Table-driven check: logical_or(bitmask0, bitmask1), converted to a Row, must equal bitmask_or.
  test("test_bitmask_logical_or") {
    val bitmaskTable: TableFor3[Row, Row, Row] = Table(
      ("bitmask0", "bitmask1", "bitmask_or"),
      (
        DataUDF.getLiteBitmaskRow(8, 3, ArraySeq(ArraySeq[Long](0, 2)), 0, 2),
        DataUDF.getLiteBitmaskRow(8, 3, ArraySeq(ArraySeq[Long](1, 3)), 1, 3),
        DataUDF.getLiteBitmaskRow(8, 4, ArraySeq(ArraySeq[Long](0, 3)), 0, 3)
      ), (
        DataUDF.getLiteBitmaskRow(8, 5, ArraySeq(ArraySeq[Long](0, 4)), 0, 4),
        DataUDF.getLiteBitmaskZeros(8),
        DataUDF.getLiteBitmaskRow(8, 5, ArraySeq(ArraySeq[Long](0, 4)), 0, 4)
      ), (
        DataUDF.getLiteBitmaskZeros(8),
        DataUDF.getLiteBitmaskZeros(8),
        DataUDF.getLiteBitmaskZeros(8)
      ), (
        DataUDF.getLiteBitmaskRow(8, 4, ArraySeq(ArraySeq[Long](0, 1), ArraySeq[Long](3, 4)), 0, 4),
        DataUDF.getLiteBitmaskRow(8, 2, ArraySeq(ArraySeq[Long](1, 2)), 1, 2),
        DataUDF.getLiteBitmaskRow(8, 5, ArraySeq(ArraySeq[Long](0, 4)), 0, 4)
      // ), ( // this will be a problem 8 is the max.
      //   DataUDF.getLiteBitmaskRow(8, 4, ArraySeq(ArraySeq[Long](1, 2), ArraySeq[Long](5, 6)), 1, 6),
      //   DataUDF.getLiteBitmaskRow(8, 4, ArraySeq(ArraySeq[Long](3, 4), ArraySeq[Long](7, 8)), 3, 4),
      //   DataUDF.getLiteBitmaskRow(8, 8, ArraySeq(ArraySeq[Long](1, 8)), 1, 8)
      ), (
        DataUDF.getLiteBitmaskRow(16, 4, ArraySeq(ArraySeq[Long](1, 2), ArraySeq[Long](5, 6)), 1, 6),
        DataUDF.getLiteBitmaskRow(16, 4, ArraySeq(ArraySeq[Long](3, 4), ArraySeq[Long](7, 8)), 3, 8),
        DataUDF.getLiteBitmaskRow(16, 8, ArraySeq(ArraySeq[Long](1, 8)), 1, 8)
      )
    )
    forAll(bitmaskTable) { (bitmask0, bitmask1, bitmask_or) =>
      println(s"bitmask0: ${bitmask0}")
      println(s"bitmask1: ${bitmask1}")
      println(s"bitmask_or: ${bitmask_or}")
      // Validate the fixtures themselves before exercising logical_or.
      validateLiteBitmaskSlotsLike(bitmask0)
      validateLiteBitmaskSlotsLike(bitmask1)
      validateLiteBitmaskSlotsLike(bitmask_or)
      val bitmask_r = logical_or(bitmask0, bitmask1)
      val bitmask_rr = liteBitmaskStructToRow(bitmask_r)
      println(s"bitmask_rr: ${bitmask_rr}")
      validateLiteBitmaskSlotsLike(bitmask_rr)
      bitmask_rr should equal(bitmask_or)
      println(intersection_any(bitmask0, bitmask1))
    }
  }

  // Smoke test: builds an IntervalIterator per row and prints get() until isEmpty reports true.
  test("test_getIntervalBitsGenerator") {
    val bitmaskTable: TableFor3[ArraySeq[ArraySeq[Long]], Long, Long] = Table(
      ("intervals", "bit_idx_first", "bit_idx_last"),
      (ArraySeq(ArraySeq[Long](1, 2), ArraySeq[Long](4, 5)), 0, -1),
      (ArraySeq(ArraySeq[Long](2, 4), ArraySeq[Long](8, 10)), 4, 9),
      (ArraySeq(ArraySeq[Long](1, 2), ArraySeq[Long](5, 6)), 3, 6),
      (ArraySeq(ArraySeq[Long](3, 4), ArraySeq[Long](7, 8)), 3, 6)
    )
    forAll(bitmaskTable) { (intervals, bit_idx_first, bit_idx_last) =>
      println(intervals)
      val ii = new IntervalIterator(intervals, bit_idx_first, bit_idx_last)
      ii.printState()
      // NOTE(review): presumably get() advances the iterator, otherwise this loop never ends — confirm.
      while (!ii.isEmpty) {
        println(s"i: ${ii.get()}")
      }
    }
  }

  // Table-driven check: intersection_any(bitmask0, bitmask1) must equal the expected boolean.
  test("test_bitmask_intersection_any") {
    val bitmaskTable: TableFor3[Row, Row, Boolean] = Table(
      ("bitmask0", "bitmask1", "logical_truth"),
      (
        DataUDF.getLiteBitmaskRow(8, 3, ArraySeq(ArraySeq[Long](0, 2)), 0, 2),
        DataUDF.getLiteBitmaskRow(8, 3, ArraySeq(ArraySeq[Long](1, 3)), 1, 3),
        true
      ), (
        DataUDF.getLiteBitmaskRow(8, 5, ArraySeq(ArraySeq[Long](0, 4)), 0, 4),
        DataUDF.getLiteBitmaskZeros(8),
        false
      ), (
        DataUDF.getLiteBitmaskZeros(8),
        DataUDF.getLiteBitmaskZeros(8),
        false
      ), (
        DataUDF.getLiteBitmaskRow(8, 4, ArraySeq(ArraySeq[Long](0, 1), ArraySeq[Long](3, 4)), 0, 4),
        DataUDF.getLiteBitmaskRow(8, 2, ArraySeq(ArraySeq[Long](1, 2)), 1, 2),
        true
      ), (
        DataUDF.getLiteBitmaskRow(16, 4, ArraySeq(ArraySeq[Long](1, 2), ArraySeq[Long](5, 6)), 1, 6),
        DataUDF.getLiteBitmaskRow(16, 4, ArraySeq(ArraySeq[Long](3, 4), ArraySeq[Long](7, 8)), 3, 8),
        false
      )
    )
    forAll(bitmaskTable) { (bitmask0, bitmask1, logical_truth) =>
      println(s"bitmask0: ${bitmask0}")
      println(s"bitmask1: ${bitmask1}")
      validateLiteBitmaskSlotsLike(bitmask0)
      validateLiteBitmaskSlotsLike(bitmask1)
      val truth = intersection_any(bitmask0, bitmask1)
      println(s"truth: ${truth}")
      truth should equal(logical_truth)
    }
  }
}