/*
 * Copyright (C) 2024, UChicago Argonne, LLC
 * Licensed under the 3-clause BSD license. See accompanying LICENSE.txt file
 * in the top-level directory.
 */

/*
Todo: the .toInt are not the best for indexing an ArraySeq, this may require a rewrite using a
different larger structure other than ArraySeq[ArraySeq[Int]].
 */

package Octeres

import org.apache.spark.sql._
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions.udf
import Octeres.Bitmask.BitmaskLite.LiteBitmaskStruct
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.types.StructType

import scala.collection.immutable.ArraySeq

/** Helpers for moving lite bitmasks between Spark `Row`s and `LiteBitmaskStruct`,
  * plus the Spark UDFs built on top of them.
  *
  * Rows built here carry their fields in the order
  * `length, bit_count, bit_idx_first, bit_idx_last, intervals` and are tagged
  * with [[LiteBitmaskSchema]].
  */
object DataUDF {

  /** Spark schema derived from the `LiteBitmaskStruct` product encoder. */
  val LiteBitmaskSchema: StructType = Encoders.product[LiteBitmaskStruct].schema

  // NOTE(review): `LiteBitmaskStruct.getClass` appears to be the runtime class of the companion
  // object rather than of the case class itself, so this bean encoder is probably not usable for
  // LiteBitmaskStruct values. Left untouched because changing it would alter the inferred type of
  // a public member -- confirm intent (`Encoders.product[LiteBitmaskStruct]` is the likely fix).
  val LiteBitmaskEncoder = Encoders.bean(LiteBitmaskStruct.getClass)

  /** Builds a bitmask row of the given length with no bits set: bit count 0, first/last
    * indices -1 and an empty interval list.
    *
    * @param length value stored in the row's `length` field
    */
  def getLiteBitmaskZeros(length: Long): Row =
    new GenericRowWithSchema(Array[Any](length, 0L, -1L, -1L, ArraySeq()), LiteBitmaskSchema)

  /** `Int` convenience overload; widens `length` and defers to the `Long` variant. */
  def getLiteBitmaskZeros(length: Int): Row =
    getLiteBitmaskZeros(length.toLong)

  /** Builds a bitmask row from explicit field values.
    *
    * Note that the parameter order differs from the stored field order: `intervals` is the third
    * parameter but the last field of the row.
    */
  def getLiteBitmaskRow(length: Long, bit_count: Long, intervals: ArraySeq[ArraySeq[Long]],
                        bit_idx_first: Long, bit_idx_last: Long): Row = {
    val fields = Array[Any](length, bit_count, bit_idx_first, bit_idx_last, intervals)
    new GenericRowWithSchema(fields, LiteBitmaskSchema)
  }

  /** Converts `row` to a `LiteBitmaskStruct` and runs the BitmaskLite validation on it. */
  def validateLiteBitmaskSlotsLike(row: Row): Unit =
    Octeres.Bitmask.BitmaskLite.validateLiteBitmaskSlotsLike(rowToLiteBitmaskStruct(row))

  /** Converts a `LiteBitmaskStruct` to a `Row` tagged with [[LiteBitmaskSchema]]. */
  def liteBitmaskStructToRow(bitmaskStruct: LiteBitmaskStruct): Row = {
    // might need this somehow: import sparkSession.implicits._
    val fields = Array[Any](
      bitmaskStruct.length,
      bitmaskStruct.bit_count,
      bitmaskStruct.bit_idx_first,
      bitmaskStruct.bit_idx_last,
      bitmaskStruct.intervals)
    new GenericRowWithSchema(fields, LiteBitmaskSchema)
  }

  /** Converts a `Row` to a `LiteBitmaskStruct`, reading every field by name.
    *
    * The `intervals` field is read as a general `scala.collection.Seq` and then converted, so the
    * row may hold any `Seq` implementation and not only an immutable `ArraySeq`.
    *
    * @param row row exposing `length`, `bit_count`, `bit_idx_first`, `bit_idx_last` and
    *            `intervals` fields
    */
  def rowToLiteBitmaskStruct(row: Row): LiteBitmaskStruct = {
    val length: Long = row.getAs[Long]("length")
    val bit_count: Long = row.getAs[Long]("bit_count")
    val bit_idx_first: Long = row.getAs[Long]("bit_idx_first")
    val bit_idx_last: Long = row.getAs[Long]("bit_idx_last")
    // Casting straight to immutable.ArraySeq throws ClassCastException whenever the row stores
    // the array as some other Seq (e.g. mutable.ArraySeq), so read the widest type and convert.
    val rawIntervals = row.getAs[scala.collection.Seq[scala.collection.Seq[Long]]]("intervals")
    LiteBitmaskStruct(length, bit_count, bit_idx_first, bit_idx_last,
      toImmutableIntervals(rawIntervals))
  }

  /** Copies a nested sequence of longs into immutable `ArraySeq`s; `null` (outer or inner) is
    * passed through unchanged. Inner sequences that already are immutable `ArraySeq`s are reused
    * by `ArraySeq.from` rather than copied.
    */
  private def toImmutableIntervals(
      raw: scala.collection.Seq[scala.collection.Seq[Long]]): ArraySeq[ArraySeq[Long]] =
    raw match {
      case null => null
      case outer =>
        ArraySeq.from(outer.map {
          case null => null
          case inner => ArraySeq.from(inner)
        })
    }

  /** Logical OR of two bitmask rows.
    *
    * A `null` operand is treated as absent: the other operand is returned (converted to a
    * struct), and the result is `null` only when both operands are `null`.
    */
  def logical_or(a_bitmask: Row, b_bitmask: Row): LiteBitmaskStruct =
    (a_bitmask, b_bitmask) match {
      case (null, null) => null
      case (null, b) => rowToLiteBitmaskStruct(b)
      case (a, null) => rowToLiteBitmaskStruct(a)
      case (a, b) =>
        Octeres.Bitmask.BitmaskLite.logical_or(rowToLiteBitmaskStruct(a), rowToLiteBitmaskStruct(b))
    }

  /** Delegates to `BitmaskLite.intersection_any` for two bitmask rows; returns `false` without
    * delegating when either row is `null`.
    */
  def intersection_any(a_bitmask: Row, b_bitmask: Row): Boolean =
    (a_bitmask, b_bitmask) match {
      case (null, _) | (_, null) => false
      case (a, b) =>
        Octeres.Bitmask.BitmaskLite.intersection_any(
          rowToLiteBitmaskStruct(a), rowToLiteBitmaskStruct(b))
    }

  /** Spark UDF wrapper around [[intersection_any]]. */
  val udf_intersection_any: UserDefinedFunction =
    udf((a_bitmask: Row, b_bitmask: Row) => intersection_any(a_bitmask, b_bitmask))

  /** Spark UDF wrapper around [[logical_or]]. */
  val udf_logical_or: UserDefinedFunction =
    udf((a_bitmask: Row, b_bitmask: Row) => logical_or(a_bitmask, b_bitmask))

  /** Prints a fixed message to stdout. */
  def testCall(): Unit = {
    println("from Scala: nothing to see here!")
  }

  /** Prints the sum of `a` and `b` to stdout and returns it. */
  def testCall(a: Long, b: Long): Long = {
    println(s"from Scala: ${a} + ${b} = ${a + b}")
    a + b
  }

  /** Registers the UDFs under the SQL names `logical_or` and `intersection_any` on the session
    * returned by `SparkSession.builder().getOrCreate()`.
    */
  def registerAll(): Unit = {
    val sparkSession: SparkSession = SparkSession.builder().getOrCreate()
    println("Registering Scala UDF functions using sparkSession.")
    sparkSession.udf.register("logical_or", udf_logical_or)
    sparkSession.udf.register("intersection_any", udf_intersection_any)
  }
}