// DataUDFTest.scala
/*
 * Copyright (C) 2024, UChicago Argonne, LLC
 * Licensed under the 3-clause BSD license.  See accompanying LICENSE.txt file
 * in the top-level directory.
 */

package Octeres

import Octeres.DataUDF.{LiteBitmaskSchema, intersection_any, liteBitmaskStructToRow, logical_or, validateLiteBitmaskSlotsLike}
import Octeres.Bitmask.BitmaskLite.{intersection_any_v0, intersection_any_v1, IntervalIterator}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Encoders, Row, SparkSession}
import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.prop.TableDrivenPropertyChecks._
import org.scalatest.matchers.should.Matchers._
import org.scalatest.prop.TableFor3

import scala.collection.immutable.ArraySeq

/** Minimal UDF fixture exercised by the `test_add` and `test_simple` tests below. */
object ExampleUDF {
    /** Row shape for the example; its field names `a` and `b` become the DataFrame column names. */
    case class SimpleAddStruct(a: Int, b: Int)

    /** Spark schema derived from [[SimpleAddStruct]] through its product encoder. */
    val SimpleAddSchema: StructType = Encoders.product[SimpleAddStruct].schema

    /** Returns `a + b` using ordinary (wrapping) `Int` arithmetic. */
    def simpleAdd(a: Int, b: Int): Int = Integer.sum(a, b)

}

class DataUDFTest extends AnyFunSuite {

    test("test_add") {
        // simpleAdd is a plain Scala function, so it can be checked without a SparkSession.
        assertResult(3)(ExampleUDF.simpleAdd(1, 2))
    }

    test("test_simple") {
        // Registers ExampleUDF.simpleAdd as a SQL function and checks it on a single (1, 2) row.
        val sparkSession = SparkSession.builder().appName("DataUDFTest").master("local").getOrCreate()
        val deck = Seq(Row(1, 2))
        val df = sparkSession.createDataFrame(sparkSession.sparkContext.parallelize(deck), schema = ExampleUDF.SimpleAddSchema)
        sparkSession.udf.register("SimpleAdd", ExampleUDF.simpleAdd(_: Int, _: Int), IntegerType)
        val dfWithSum = df.withColumn("c", expr("SimpleAdd(a, b)"))
        dfWithSum.show()
        // Previously the test only printed; pin the UDF result so a regression actually fails.
        dfWithSum.select("c").collect().map(_.getInt(0)).toSeq shouldBe Seq(3)
    }

    test("test_logical_or_00") {
        // Smoke test: print two 8-bit masks and logical_or over them, including one null operand.
        val label = "test_logical_or_00"
        println(label)
        val width = 8
        val lhs: Row = DataUDF.getLiteBitmaskRow(width, 3, ArraySeq(ArraySeq[Long](0, 2)), 0, 2)
        val rhs: Row = DataUDF.getLiteBitmaskRow(width, 3, ArraySeq(ArraySeq[Long](1, 3)), 1, 3)
        Seq("a_bitmask" -> lhs, "b_bitmask" -> rhs).foreach { case (name, mask) =>
            println(s"$name=$mask")
        }
        Seq[(Row, Row)]((lhs, rhs), (null, rhs), (lhs, null)).foreach { case (left, right) =>
            println(DataUDF.logical_or(left, right))
        }
    }

    test("test_logical_or_01") {
        // Previously printed "test_logical_or_00" and duplicated that test without asserting anything.
        println("test_logical_or_01")
        val bitmask_length = 8
        // a = bits 0..2 and b = bits 1..3 (inclusive intervals; the second argument is the set-bit
        // count), so a OR b should be the single interval 0..3 with 4 bits set.
        val a_bitmask: Row = DataUDF.getLiteBitmaskRow(bitmask_length, 3, ArraySeq(ArraySeq[Long](0, 2)), 0, 2)
        val b_bitmask: Row = DataUDF.getLiteBitmaskRow(bitmask_length, 3, ArraySeq(ArraySeq[Long](1, 3)), 1, 3)
        val expected_or: Row = DataUDF.getLiteBitmaskRow(bitmask_length, 4, ArraySeq(ArraySeq[Long](0, 3)), 0, 3)
        println(s"a_bitmask=${a_bitmask}")
        println(s"b_bitmask=${b_bitmask}")
        val ab_or = DataUDF.logical_or(a_bitmask, b_bitmask)
        println(ab_or)
        // Null-operand paths are only smoke-checked; their expected result is not pinned here.
        println(DataUDF.logical_or(null, b_bitmask))
        println(DataUDF.logical_or(a_bitmask, null))
        // Convert the logical_or result with liteBitmaskStructToRow before comparing against a Row.
        liteBitmaskStructToRow(ab_or) should equal(expected_or)
    }

    test("test_logical_or_02") {
        // Previously printed "test_logical_or_01"; now labelled to match the test name.
        println("test_logical_or_02")
        val bitmask_length = 8
        val a_bitmask = DataUDF.getLiteBitmaskRow(bitmask_length, 3, ArraySeq(ArraySeq[Long](0, 2)), 0, 2)
        val b_bitmask = DataUDF.getLiteBitmaskZeros(bitmask_length)
//        val b_bitmask = DataUDF.getLiteBitmaskRow(bitmask_length, 0, ArraySeq[Long](), -1, -1)
        println(s"a_bitmask=${a_bitmask}")
        println(s"b_bitmask=${b_bitmask}")
        val a_or_zeros = DataUDF.logical_or(a_bitmask, b_bitmask)
        println(a_or_zeros)
        // Null-operand paths are only smoke-checked; their expected result is not pinned here.
        println(DataUDF.logical_or(null, b_bitmask))
        println(DataUDF.logical_or(a_bitmask, null))
        println(DataUDF.logical_or(null, null))
        println(DataUDF.logical_or(a_bitmask, a_bitmask))
        val zeros_or_zeros = DataUDF.logical_or(b_bitmask, b_bitmask)
        println(zeros_or_zeros)
        // OR with an all-zero mask leaves the other operand unchanged, and zeros OR zeros stays zero.
        liteBitmaskStructToRow(a_or_zeros) should equal(a_bitmask)
        liteBitmaskStructToRow(zeros_or_zeros) should equal(b_bitmask)
    }

    test("test_bitmask") {
        // End-to-end check of the registered bitmask UDFs over a small DataFrame and its self-join.
        println("test_bitmask")
        val bitmask_length = 8
        // Fixtures (inclusive intervals; the second argument matches the number of set bits):
        //   a = bits 0..2, b = bits 1..3, c = bits 0..4, d = all zeros, e = bits 0..5.
        val a_bitmask = DataUDF.getLiteBitmaskRow(bitmask_length, 3, ArraySeq(ArraySeq[Long](0, 2)), 0, 2)
        val b_bitmask = DataUDF.getLiteBitmaskRow(bitmask_length, 3, ArraySeq(ArraySeq[Long](1, 3)), 1, 3)
        val c_bitmask = DataUDF.getLiteBitmaskRow(bitmask_length, 5, ArraySeq(ArraySeq[Long](0, 4)), 0, 4)
        val d_bitmask = DataUDF.getLiteBitmaskZeros(bitmask_length)
        val e_bitmask = DataUDF.getLiteBitmaskRow(bitmask_length, 6, ArraySeq(ArraySeq[Long](0, 5)), 0, 5)

        val schema = StructType(List(
            StructField("a", IntegerType, nullable = true),
            StructField("b", IntegerType, nullable = true),
            StructField("bitmask", LiteBitmaskSchema, nullable = true)
        ))

        val deck = Seq(
            // This row carries a null bitmask (the column is declared nullable).
            // NOTE(review): the original inline remark "cannot have null." contradicts this row; confirm intent.
            Row(0, 1, null),
            Row(1, 2, a_bitmask),
            Row(2, 4, b_bitmask),
            Row(3, 8, c_bitmask),
            Row(4, 16, d_bitmask),
            Row(5, 32, e_bitmask)
        )
        val sparkSession = SparkSession.builder().appName("DataUDFTest").master("local").getOrCreate()
        val df = sparkSession.createDataFrame(sparkSession.sparkContext.parallelize(deck), schema = schema)
        DataUDF.registerAll()
        df.createOrReplaceTempView("temp_view_a")
        df.createOrReplaceTempView("temp_view_b")
        df.show(32, truncate=false)
        sparkSession.sql("select bitmask.intervals from temp_view_a").show(32, truncate=false)
        val df2 = sparkSession.sql("Select * from temp_view_a")
        // https://spark.apache.org/docs/latest/sql-ref-datatypes.html
        println("direct intersection_any")
        df2.withColumn("r_bitmask", expr("intersection_any(bitmask, bitmask)")).show(32, truncate=false)
        println("direct logical_or")
        df2.withColumn("r_bitmask", expr("logical_or(bitmask, bitmask)")).show(32, truncate=false)
        println("about to select..")
        // A join with no condition pairs every row of temp_view_a with every row of temp_view_b.
        val df3 = sparkSession.sql("select tva.a, tva.b, " +
            "tva.bitmask as a_bitmask, " +
            "tvb.bitmask as b_bitmask, " +
            "logical_or(tva.bitmask, tvb.bitmask) as c_bitmask " +
            // "intersection_any(tva.bitmask, tvb.bitmask) " +
            "from temp_view_a tva join temp_view_b tvb"
        )
        // Known-bad row recorded from df3.show() (columns a | b | a_bitmask | b_bitmask | c_bitmask):
        // BROKEN |2  |4  |{8, 3, [[1, 3]], 1, 3}|{8, 6, [[0, 5]], 0, 5}|{8, 0, [], -1, -1}    |
        // bits 1..3 OR bits 0..5 should presumably give {8, 6, [[0, 5]], 0, 5}, not the empty mask.
        df3.show(32, truncate=false)
        df3.printSchema()
        // Unconditioned self-join of the deck is a cartesian product.
        df3.count() shouldBe deck.size.toLong * deck.size
    }
        // NOTE(review): this region looks damaged by the capture of this file. It reads like the body of
        // a table-driven logical_or test that has lost three things: its `test("...") {` header, the
        // `(` / `),` lines that group each row into a 3-tuple, and the closing braces after its forAll.
        // A second, IntervalIterator test (starting at the second `bitmaskTable` below) has likewise lost
        // its header and its Table heading tuple. As written this does not compile; restore it from
        // version control.
        // Rows are meant to be read three at a time as (bitmask0, bitmask1, bitmask_or). Intervals appear
        // to be inclusive [first, last]: the second getLiteBitmaskRow argument equals the set-bit count.
        val bitmaskTable: TableFor3[Row, Row, Row] = Table(
            ("bitmask0", "bitmask1", "bitmask_or"),
            (
                // bits 0..2 OR bits 1..3 -> bits 0..3
                DataUDF.getLiteBitmaskRow(8, 3, ArraySeq(ArraySeq[Long](0, 2)), 0, 2),
                DataUDF.getLiteBitmaskRow(8, 3, ArraySeq(ArraySeq[Long](1, 3)), 1, 3),
                DataUDF.getLiteBitmaskRow(8, 4, ArraySeq(ArraySeq[Long](0, 3)), 0, 3)
                // bits 0..4 OR zeros -> bits 0..4
                DataUDF.getLiteBitmaskRow(8, 5, ArraySeq(ArraySeq[Long](0, 4)), 0, 4),
                DataUDF.getLiteBitmaskZeros(8),
                DataUDF.getLiteBitmaskRow(8, 5, ArraySeq(ArraySeq[Long](0, 4)), 0, 4)
                // zeros OR zeros -> zeros
                DataUDF.getLiteBitmaskZeros(8),
                DataUDF.getLiteBitmaskZeros(8),
                DataUDF.getLiteBitmaskZeros(8),
                // bits {0, 1, 3, 4} OR bits 1..2 -> bits 0..4
                DataUDF.getLiteBitmaskRow(8, 4, ArraySeq(ArraySeq[Long](0, 1), ArraySeq[Long](3, 4)), 0, 4),
                DataUDF.getLiteBitmaskRow(8, 2, ArraySeq(ArraySeq[Long](1, 2)), 1, 2),
                DataUDF.getLiteBitmaskRow(8, 5, ArraySeq(ArraySeq[Long](0, 4)), 0, 4)
                // Disabled 8-bit variant of the 16-bit row below. In its second row, first/last (3, 4)
                // disagree with its intervals, which end at 8.
//                DataUDF.getLiteBitmaskRow(8, 4, ArraySeq(ArraySeq[Long](1, 2), ArraySeq[Long](5, 6)), 1, 6),
//                DataUDF.getLiteBitmaskRow(8, 4, ArraySeq(ArraySeq[Long](3, 4), ArraySeq[Long](7, 8)), 3, 4),
//                DataUDF.getLiteBitmaskRow(8, 8, ArraySeq(ArraySeq[Long](1, 8)), 1, 8)
                // 16-bit: bits {1, 2, 5, 6} OR bits {3, 4, 7, 8} -> bits 1..8
                DataUDF.getLiteBitmaskRow(16, 4, ArraySeq(ArraySeq[Long](1, 2), ArraySeq[Long](5, 6)), 1, 6),
                DataUDF.getLiteBitmaskRow(16, 4, ArraySeq(ArraySeq[Long](3, 4), ArraySeq[Long](7, 8)), 3, 8),
                DataUDF.getLiteBitmaskRow(16, 8, ArraySeq(ArraySeq[Long](1, 8)), 1, 8)
            )
        )

        // For each row, this loop does the following:
        //   - validates all three masks;
        //   - ORs the first two;
        //   - converts the result back to a Row with liteBitmaskStructToRow and validates it;
        //   - requires it to equal the expected Row.
        forAll(bitmaskTable) { (bitmask0, bitmask1, bitmask_or: Row) => {
            println(s"bitmask0: ${bitmask0}")
            println(s"bitmask1: ${bitmask1}")
            println(s"bitmask_or: ${bitmask_or}")
            validateLiteBitmaskSlotsLike(bitmask0)
            validateLiteBitmaskSlotsLike(bitmask1)
            validateLiteBitmaskSlotsLike(bitmask_or)
            val bitmask_r = logical_or(bitmask0, bitmask1)
            val bitmask_rr = liteBitmaskStructToRow(bitmask_r)
            println(s"bitmask_rr: ${bitmask_rr}")
            validateLiteBitmaskSlotsLike(bitmask_rr)
            bitmask_rr should equal(bitmask_or)
        // NOTE(review): the closing `}`s for the forAll above and for its enclosing test are missing here.
        // The table below holds (intervals, bit_idx_first, bit_idx_last) rows for IntervalIterator. It
        // lacks the heading tuple that Table expects as its first argument.
        val bitmaskTable: TableFor3[ArraySeq[ArraySeq[Long]], Long, Long] = Table(
            (ArraySeq(ArraySeq[Long](1, 2), ArraySeq[Long](4, 5)), 0, -1),
            (ArraySeq(ArraySeq[Long](2, 4), ArraySeq[Long](8, 10)), 4, 9),
            (ArraySeq(ArraySeq[Long](1, 2), ArraySeq[Long](5, 6)), 3, 6),
            (ArraySeq(ArraySeq[Long](3, 4), ArraySeq[Long](7, 8)), 3, 6)
        )

        // Builds an IntervalIterator for each row and prints every value it yields.
        // NOTE(review): loop termination assumes get() advances the iterator until isEmpty; TODO confirm.
        forAll(bitmaskTable) { (intervals, bit_idx_first, bit_idx_last) => {
            {
                println(intervals)
                val ii = new IntervalIterator(intervals, bit_idx_first, bit_idx_last)
                ii.printState()
                while (!ii.isEmpty) {
                    println(s"i: ${ii.get()}")
                }
            }
        }}
    }

    test("test_bitmask_intersection_any") {
        val bitmaskTable: TableFor3[Row, Row, Boolean] = Table(
                DataUDF.getLiteBitmaskRow(8, 3, ArraySeq(ArraySeq[Long](0, 2)), 0, 2),
                DataUDF.getLiteBitmaskRow(8, 3, ArraySeq(ArraySeq[Long](1, 3)), 1, 3),
                DataUDF.getLiteBitmaskRow(8, 5, ArraySeq(ArraySeq[Long](0, 4)), 0, 4),
                DataUDF.getLiteBitmaskZeros(8),
                DataUDF.getLiteBitmaskZeros(8),
                DataUDF.getLiteBitmaskZeros(8),
                DataUDF.getLiteBitmaskRow(8, 4, ArraySeq(ArraySeq[Long](0, 1), ArraySeq[Long](3, 4)), 0, 4),
                DataUDF.getLiteBitmaskRow(8, 2, ArraySeq(ArraySeq[Long](1, 2)), 1, 2),
                DataUDF.getLiteBitmaskRow(16, 4, ArraySeq(ArraySeq[Long](1, 2), ArraySeq[Long](5, 6)), 1, 6),
                DataUDF.getLiteBitmaskRow(16, 4, ArraySeq(ArraySeq[Long](3, 4), ArraySeq[Long](7, 8)), 3, 8),
        forAll(bitmaskTable) { (bitmask0, bitmask1, logical_truth: Boolean) => {
            println(s"bitmask0: ${bitmask0}")
            println(s"bitmask1: ${bitmask1}")
            validateLiteBitmaskSlotsLike(bitmask0)
            validateLiteBitmaskSlotsLike(bitmask1)
            val truth = intersection_any(bitmask0, bitmask1)
            println(s"truth: ${truth}")
            truth should equal(logical_truth)
        }
        }