Skip to content
Snippets Groups Projects
DataUDFTest.scala 5.21 KiB
Newer Older
/*
 * Copyright (C) 2024, UChicago Argonne, LLC
 * Licensed under the 3-clause BSD license.  See accompanying LICENSE.txt file
 * in the top-level directory.
 */

package Octeres

import Octeres.DataUDF.{LiteBitmaskSchema, logical_or}
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Encoders, Row, SparkSession}
import org.scalatest.funsuite.AnyFunSuite


/**
 * Tiny stand-alone UDF example: a two-Int record type, the Spark schema
 * derived from it, and the plain Scala function that the tests register
 * as a SQL UDF.
 */
object ExampleUDF {
    /** Record with two Int fields, `a` and `b`; used here to derive a Spark schema. */
    case class SimpleAddStruct(a: Int, b: Int)

    /** Spark `StructType` obtained from [[SimpleAddStruct]] through a product encoder. */
    val SimpleAddSchema: StructType = {
        val encoder = Encoders.product[SimpleAddStruct]
        encoder.schema
    }

    /** Returns `a + b` using ordinary `Int` arithmetic. */
    def simpleAdd(a: Int, b: Int): Int = {
        val total = a + b
        total
    }
}

class DataUDFTest extends AnyFunSuite {

    test("test_add") {
        // simpleAdd is plain Int addition, so 1 + 2 must yield 3.
        assertResult(3)(ExampleUDF.simpleAdd(1, 2))
    }

    test("test_simple") {
        // Registers ExampleUDF.simpleAdd as the SQL function "SimpleAdd" and checks that
        // expr("SimpleAdd(a, b)") produces a + b for each row.
        val sparkSession = SparkSession.builder().appName("DataUDFTest").master("local").getOrCreate()
        val deck = Seq(
            Row(1, 2),
        )
        val input = sparkSession.createDataFrame(sparkSession.sparkContext.parallelize(deck), schema = ExampleUDF.SimpleAddSchema)
        // IntegerType is passed explicitly as the UDF's declared return type.
        sparkSession.udf.register("SimpleAdd", ExampleUDF.simpleAdd(_: Int, _: Int), IntegerType)
        val df = input.withColumn("c", expr("SimpleAdd(a, b)"))
        df.show()
        // Pin the UDF output so a broken registration or wrong result fails the test
        // instead of only printing.
        val sums = df.select("c").collect().map(_.getInt(0)).toSeq
        assert(sums == Seq(3))
    }

    test("test_logical_or_00") {
        println("test_logical_or_00")
        val bitmaskLength = 8
        // Fixture rows laid out positionally per LiteBitmaskSchema (see that schema for field meanings).
        val lhs = new GenericRowWithSchema(Array(bitmaskLength, 3, List(List(0, 2)), 0, 2), LiteBitmaskSchema)
        val rhs = new GenericRowWithSchema(Array(bitmaskLength, 3, List(List(1, 3)), 1, 3), LiteBitmaskSchema)
        println(s"a_bitmask=$lhs")
        println(s"b_bitmask=$rhs")
        // Print logical_or with both operands present, then with each side null in turn.
        val operands: Seq[(GenericRowWithSchema, GenericRowWithSchema)] = Seq(
            (lhs, rhs),
            (null, rhs),
            (lhs, null),
        )
        operands.foreach { case (left, right) =>
            println(DataUDF.logical_or(left, right))
        }
    }

    test("test_logical_or_01") {
        println("test_logical_or_01")
        val bitmaskLength = 8
        val populated = new GenericRowWithSchema(Array(bitmaskLength, 3, List(List(0, 2)), 0, 2), LiteBitmaskSchema)
        // Empty interval list with -1 in the last two positions; presumably the encoding of an
        // all-clear bitmask - confirm against DataUDF.
        val empty = new GenericRowWithSchema(Array(bitmaskLength, 0, List(), -1, -1), LiteBitmaskSchema)
        println(s"a_bitmask=$populated")
        println(s"b_bitmask=$empty")
        // Operand pairs in print order: mixed, null on either side, both null, then self-combinations.
        val operands: Seq[(GenericRowWithSchema, GenericRowWithSchema)] = Seq(
            (populated, empty),
            (null, empty),
            (populated, null),
            (null, null),
            (populated, populated),
            (empty, empty),
        )
        operands.foreach { case (left, right) =>
            println(DataUDF.logical_or(left, right))
        }
    }

    test("test_bitmask") {
        println("test_bitmask")
        val bitmask_length = 8
        // Variant passing Scala Arrays (rather than Lists) for the third field, left commented out:
        // val a_bitmask = new GenericRowWithSchema(Array(bitmask_length, 3, Array(Array(0, 2)), 0, 2), LiteBitmaskSchema)
        // val b_bitmask = new GenericRowWithSchema(Array(bitmask_length, 3, Array(Array(1, 3)), 1, 3), LiteBitmaskSchema)
        val a_bitmask = new GenericRowWithSchema(Array(bitmask_length, 3, List(List(0, 2)), 0, 2), LiteBitmaskSchema)
        val b_bitmask = new GenericRowWithSchema(Array(bitmask_length, 3, List(List(1, 3)), 1, 3), LiteBitmaskSchema)
        val c_bitmask = new GenericRowWithSchema(Array(bitmask_length, 5, List(List(0, 4)), 0, 4), LiteBitmaskSchema)
        val d_bitmask = new GenericRowWithSchema(Array(bitmask_length, 0, List(), -1, -1), LiteBitmaskSchema)
        val e_bitmask = new GenericRowWithSchema(Array(bitmask_length, 6, List(List(0, 5)), 0, 5), LiteBitmaskSchema)

        // Two Int columns plus a nullable struct column typed by LiteBitmaskSchema.
        val schema = StructType(List(
            StructField("a", IntegerType, nullable = true),
            StructField("b", IntegerType, nullable = true),
            StructField("bitmask", LiteBitmaskSchema, nullable = true)
        ))

        // Row 0 carries a null bitmask so the SQL calls below also receive a null input.
        val deck = Seq(
            Row(0, 1, null),
            Row(1, 2, a_bitmask),
            Row(2, 4, b_bitmask),
            Row(3, 8, c_bitmask),
            Row(4, 16, d_bitmask),
            Row(5, 32, e_bitmask),
        )
        val sparkSession = SparkSession.builder().appName("DataUDFTest").master("local").getOrCreate()
        val df = sparkSession.createDataFrame(sparkSession.sparkContext.parallelize(deck), schema = schema)
        DataUDF.registerAll(sparkSession)
        df.createOrReplaceTempView("temp_view_a")
        df.createOrReplaceTempView("temp_view_b")
        df.show(32, truncate=false)
        assert(df.count() == deck.size)
        sparkSession.sql("select bitmask.intervals from temp_view_a").show(32, truncate=false)
        val df2 = sparkSession.sql("Select * from temp_view_a")
        // https://spark.apache.org/docs/latest/sql-ref-datatypes.html
        df2.show(32, truncate=false)
        // logical_or is presumably registered by DataUDF.registerAll above; apply it with both operands equal.
        val df2Or = df2.withColumn("r_bitmask", expr("logical_or(bitmask, bitmask)"))
        df2Or.show(32, truncate=false)
        // Every row of temp_view_a is paired with every row of temp_view_b. This is spelled CROSS JOIN
        // because a bare JOIN without ON relies on spark.sql.crossJoin.enabled (false by default
        // before Spark 3.0).
        val df3 = sparkSession.sql("select tva.a, tva.b, " +
            "tva.bitmask as a_bitmask, " +
            "tvb.bitmask as b_bitmask, " +
            "logical_or(tva.bitmask, tvb.bitmask) as c_bitmask " +
            // "intersection_any(tva.bitmask, tvb.bitmask) " +
            "from temp_view_a tva cross join temp_view_b tvb"
        )
        val joinedRows = df3.count()
        assert(joinedRows == deck.size.toLong * deck.size)
        // 6 x 6 = 36 rows; a fixed show(32) would silently drop the last four, so show them all.
        df3.show(joinedRows.toInt, truncate=false)
        df3.printSchema()
        assert(df3.columns.toSeq == Seq("a", "b", "a_bitmask", "b_bitmask", "c_bitmask"))
    }