/*
 * Copyright (C) 2024, UChicago Argonne, LLC
 * Licensed under the 3-clause BSD license. See accompanying LICENSE.txt file
 * in the top-level directory.
 */
package Octeres

import Octeres.DataUDF.{LiteBitmaskSchema, logical_or}
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Encoders, Row, SparkSession}
import org.scalatest.funsuite.AnyFunSuite

/** Minimal example UDF used to demonstrate Spark SQL UDF registration in the tests below. */
object ExampleUDF {
  // very simple UDF function as an example.
  // Row shape matching simpleAdd's two Int inputs; its Spark schema is derived
  // from the case class so tests can build matching DataFrames.
  case class SimpleAddStruct(a: Int, b: Int)
  val SimpleAddSchema: StructType = Encoders.product[SimpleAddStruct].schema

  /** Adds two ints; registered as the "SimpleAdd" SQL function in test_simple. */
  def simpleAdd(a: Int, b: Int): Int = a + b
}

/**
 * Exercises DataUDF's LiteBitmask helpers (logical_or and friends), both via
 * direct Scala calls and as registered Spark SQL functions.
 *
 * NOTE(review): the local SparkSession is obtained with getOrCreate() and is
 * deliberately never stopped per-test so the same session is reused across
 * test cases in this suite.
 */
class DataUDFTest extends AnyFunSuite {

  test("test_add") {
    val result = ExampleUDF.simpleAdd(1, 2)
    assert(result == 3)
  }

  test("test_simple") {
    val sparkSession = SparkSession.builder().appName("DataUDFTest").master("local").getOrCreate()
    // Build a one-row frame matching SimpleAddSchema (a = 1, b = 2).
    val deck = Seq(
      Row(1, 2),
    )
    val df = sparkSession.createDataFrame(
      sparkSession.sparkContext.parallelize(deck),
      schema = ExampleUDF.SimpleAddSchema)
    // Register the Scala function as a SQL-callable UDF returning IntegerType.
    sparkSession.udf.register("SimpleAdd", ExampleUDF.simpleAdd(_: Int, _: Int), IntegerType)
    val result = df.withColumn("c", expr("SimpleAdd(a, b)"))
    result.show()
    // Fixed: previously this test only printed the frame; pin the UDF output.
    assert(result.select("c").first().getInt(0) == 3)
  }

  test("test_logical_or_00") {
    println("test_logical_or_00")
    val bitmask_length = 8
    // LiteBitmask row fields appear to be (length, bit_count, intervals,
    // min_bit, max_bit) — presumed from the literals; confirm against
    // LiteBitmaskSchema in DataUDF.
    val a_bitmask = new GenericRowWithSchema(Array(bitmask_length, 3, List(List(0, 2)), 0, 2), LiteBitmaskSchema)
    val b_bitmask = new GenericRowWithSchema(Array(bitmask_length, 3, List(List(1, 3)), 1, 3), LiteBitmaskSchema)
    println(s"a_bitmask=${a_bitmask}")
    println(s"b_bitmask=${b_bitmask}")
    // Exploratory output only: logical_or's semantics live in DataUDF, so no
    // expected values are asserted here. Null arguments are exercised on
    // either side to check null tolerance.
    println(DataUDF.logical_or(a_bitmask, b_bitmask))
    println(DataUDF.logical_or(null, b_bitmask))
    println(DataUDF.logical_or(a_bitmask, null))
  }

  test("test_logical_or_01") {
    println("test_logical_or_01")
    val bitmask_length = 8
    // b_bitmask is the empty bitmask (0 bits set, no intervals, sentinel -1/-1).
    val a_bitmask = new GenericRowWithSchema(Array(bitmask_length, 3, List(List(0, 2)), 0, 2), LiteBitmaskSchema)
    val b_bitmask = new GenericRowWithSchema(Array(bitmask_length, 0, List(), -1, -1), LiteBitmaskSchema)
    println(s"a_bitmask=${a_bitmask}")
    println(s"b_bitmask=${b_bitmask}")
    // Exercises empty-operand, null-operand, and self-OR combinations.
    println(DataUDF.logical_or(a_bitmask, b_bitmask))
    println(DataUDF.logical_or(null, b_bitmask))
    println(DataUDF.logical_or(a_bitmask, null))
    println(DataUDF.logical_or(null, null))
    println(DataUDF.logical_or(a_bitmask, a_bitmask))
    println(DataUDF.logical_or(b_bitmask, b_bitmask))
  }

  test("test_bitmask") {
    println("test_bitmask")
    val bitmask_length = 8
    // Array-based interval encoding was tried previously and kept for reference:
    // val a_bitmask = new GenericRowWithSchema(Array(bitmask_length, 3, Array(Array(0, 2)), 0, 2), LiteBitmaskSchema)
    // val b_bitmask = new GenericRowWithSchema(Array(bitmask_length, 3, Array(Array(1, 3)), 1, 3), LiteBitmaskSchema)
    val a_bitmask = new GenericRowWithSchema(Array(bitmask_length, 3, List(List(0, 2)), 0, 2), LiteBitmaskSchema)
    val b_bitmask = new GenericRowWithSchema(Array(bitmask_length, 3, List(List(1, 3)), 1, 3), LiteBitmaskSchema)
    val c_bitmask = new GenericRowWithSchema(Array(bitmask_length, 5, List(List(0, 4)), 0, 4), LiteBitmaskSchema)
    val d_bitmask = new GenericRowWithSchema(Array(bitmask_length, 0, List(), -1, -1), LiteBitmaskSchema)
    val e_bitmask = new GenericRowWithSchema(Array(bitmask_length, 6, List(List(0, 5)), 0, 5), LiteBitmaskSchema)
    val schema = StructType(List(
      StructField("a", IntegerType, nullable = true),
      StructField("b", IntegerType, nullable = true),
      StructField("bitmask", LiteBitmaskSchema, nullable = true)
    ))
    val deck = Seq(
      Row(0, 1, null),
      Row(1, 2, a_bitmask),
      Row(2, 4, b_bitmask),
      Row(3, 8, c_bitmask),
      Row(4, 16, d_bitmask),
      Row(5, 32, e_bitmask),
    )
    val sparkSession = SparkSession.builder().appName("DataUDFTest").master("local").getOrCreate()
    val df = sparkSession.createDataFrame(sparkSession.sparkContext.parallelize(deck), schema = schema)
    DataUDF.registerAll(sparkSession)
    // The same frame is registered under two view names so the SQL below can
    // self-join it as tva/tvb.
    df.createOrReplaceTempView("temp_view_a")
    df.createOrReplaceTempView("temp_view_b")
    df.show(32, truncate = false)
    sparkSession.sql("select bitmask.intervals from temp_view_a").show(32, truncate = false)
    val df2 = sparkSession.sql("Select * from temp_view_a")
    // https://spark.apache.org/docs/latest/sql-ref-datatypes.html
    df2.show(32, truncate = false)
    val df2WithOr = df2.withColumn("r_bitmask", expr("logical_or(bitmask, bitmask)"))
    df2WithOr.show(32, truncate = false)
    val df3 = sparkSession.sql("select tva.a, tva.b, " +
      "tva.bitmask as a_bitmask, " +
      "tvb.bitmask as b_bitmask, " +
      "logical_or(tva.bitmask, tvb.bitmask) as c_bitmask " +
      // "intersection_any(tva.bitmask, tvb.bitmask) " +
      "from temp_view_a tva join temp_view_b tvb"
    )
    // BROKEN |2 |4 |{8, 3, [[1, 3]], 1, 3}|{8, 6, [[0, 5]], 0, 5}|{8, 0, [], -1, -1} |
    // NOTE(review): the row above documents a known-wrong logical_or result
    // observed in the join output — the OR of two non-empty bitmasks came back
    // empty. Left unasserted pending a fix in DataUDF.logical_or.
    df3.show(32, truncate = false)
    df3.printSchema()
  }
}