/*
 * Copyright (C) 2024, UChicago Argonne, LLC
 * Licensed under the 3-clause BSD license.  See accompanying LICENSE.txt file
 * in the top-level directory.
 */

/*
Todo: the .toInt calls are not ideal for indexing an ArraySeq; this may require a rewrite
    using a larger structure than ArraySeq[ArraySeq[Int]].
 */
package Octeres

import org.apache.spark.sql.{Encoder, Encoders, Row, SparkSession}
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions.udf
import Octeres.Bitmask.BitmaskLite.LiteBitmaskStruct
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.types.StructType
import scala.collection.immutable.ArraySeq

object DataUDF {
    val LiteBitmaskSchema: StructType = Encoders.product[LiteBitmaskStruct].schema
    val LiteBitmaskEncoder: Encoder[LiteBitmaskStruct] = Encoders.product[LiteBitmaskStruct]

    def getLiteBitmaskZeros(length: Long): Row = {
        /* build an all-zeros bitmask Row of the given length */
        new GenericRowWithSchema(Array(length, 0L, -1L, -1L, ArraySeq()), LiteBitmaskSchema)
    }

    def getLiteBitmaskZeros(length: Int): Row = {
        /* Int overload of getLiteBitmaskZeros */
        new GenericRowWithSchema(Array(length.toLong, 0L, -1L, -1L, ArraySeq()), LiteBitmaskSchema)
    }

    def getLiteBitmaskRow(length: Long, bit_count: Long, intervals: ArraySeq[ArraySeq[Long]], bit_idx_first: Long, bit_idx_last: Long): Row = {
        /* build a bitmask Row from its individual fields */
        new GenericRowWithSchema(Array(length, bit_count, bit_idx_first, bit_idx_last, intervals), LiteBitmaskSchema)
    }

    def validateLiteBitmaskSlotsLike(row: Row): Unit = {
        val obj = rowToLiteBitmaskStruct(row)
        Octeres.Bitmask.BitmaskLite.validateLiteBitmaskSlotsLike(obj)
    }
    def liteBitmaskStructToRow(bitmaskStruct: LiteBitmaskStruct): Row = {
        /* convert a LiteBitmaskStruct to a Row */
        new GenericRowWithSchema(Array(bitmaskStruct.length, bitmaskStruct.bit_count,
            bitmaskStruct.bit_idx_first, bitmaskStruct.bit_idx_last, bitmaskStruct.intervals), LiteBitmaskSchema)
    }
    def rowToLiteBitmaskStruct(row: Row): LiteBitmaskStruct = {
        /* convert a Row to a LiteBitmaskStruct */
        val length: Long = row.getAs[Long]("length")
        val bit_count: Long = row.getAs[Long]("bit_count")
        val bit_idx_first: Long = row.getAs[Long]("bit_idx_first")
        val bit_idx_last: Long = row.getAs[Long]("bit_idx_last")
        val intervals: ArraySeq[ArraySeq[Long]] = row.getAs[ArraySeq[ArraySeq[Long]]]("intervals")
        LiteBitmaskStruct(length, bit_count, bit_idx_first, bit_idx_last, intervals)
    }
    def logical_or(a_bitmask: Row, b_bitmask: Row): LiteBitmaskStruct = {
        /* bitwise OR of two bitmask Rows; a null input falls back to the other operand */
        (a_bitmask, b_bitmask) match {
            case (null, null) => null
            case (null, _) => rowToLiteBitmaskStruct(b_bitmask)
            case (_, null) => rowToLiteBitmaskStruct(a_bitmask)
            case (_, _) => Octeres.Bitmask.BitmaskLite.logical_or(rowToLiteBitmaskStruct(a_bitmask), rowToLiteBitmaskStruct(b_bitmask))
        }
    }
    def intersection_any(a_bitmask: Row, b_bitmask: Row): Boolean = {
        /* true when the two bitmask Rows share at least one set bit; null inputs never intersect */
        (a_bitmask, b_bitmask) match {
            case (null, null) => false
            case (null, _) => false
            case (_, null) => false
            case (_, _) => Octeres.Bitmask.BitmaskLite.intersection_any(rowToLiteBitmaskStruct(a_bitmask), rowToLiteBitmaskStruct(b_bitmask))
        }
    }
    val udf_intersection_any: UserDefinedFunction = udf((a_bitmask: Row, b_bitmask: Row) => intersection_any(a_bitmask, b_bitmask))
    val udf_logical_or: UserDefinedFunction = udf((a_bitmask: Row, b_bitmask: Row) => logical_or(a_bitmask, b_bitmask))
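
    /*
     Example of applying these UDFs through the DataFrame API (a sketch only; `df` and its
     LiteBitmask-typed struct columns `a` and `b` are assumed names, and `col` comes from
     org.apache.spark.sql.functions):

         df.withColumn("overlaps", udf_intersection_any(col("a"), col("b")))
           .withColumn("merged", udf_logical_or(col("a"), col("b")))
     */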


    def testCall(): Unit = {
        println("from Scala: nothing to see here!")
    }

    def testCall(a: Long, b: Long): Long = {
        println(s"from Scala: ${a} + ${b} = ${a + b}")
        a + b
    }

    def registerAll(): Unit = {
        /* register the UDFs with the active SparkSession so they can be called from SQL */
        val sparkSession: SparkSession = SparkSession.builder().getOrCreate()
        println("Registering Scala UDF functions using sparkSession.")
        sparkSession.udf.register("logical_or", udf_logical_or)
        sparkSession.udf.register("intersection_any", udf_intersection_any)
    }
}