// DataUDF.scala (GitLab page chrome removed from scraped source)
/*
 * Copyright (C) 2024, UChicago Argonne, LLC
 * Licensed under the 3-clause BSD license.  See accompanying LICENSE.txt file
 * in the top-level directory.
 */

package Octeres

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql._
import org.apache.spark.sql.types._
import scala.collection.immutable.ArraySeq
import scala.collection.mutable.ListBuffer


object DataUDF {
    /** Lightweight bitmask representation.
      *
      * @param length        total number of bits in the mask
      * @param bit_count     number of set bits
      * @param intervals     sorted, inclusive [start, end] runs of set bits
      * @param bit_idx_first index of the first set bit (-1 when empty)
      * @param bit_idx_last  index of the last set bit (-1 when empty)
      */
    case class LiteBitmaskStruct(length: Int, bit_count: Int, intervals: List[List[Int]],
                                 bit_idx_first: Int, bit_idx_last: Int)
    // Spark SQL schema derived from the case class; used as the UDF return schema.
    val LiteBitmaskSchema: StructType = Encoders.product[LiteBitmaskStruct].schema
    // Fixed: was `Encoders.bean(LiteBitmaskStruct.getClass)`, which is doubly wrong —
    // `.getClass` on the companion names the object's class (LiteBitmaskStruct$), and
    // bean encoders require a zero-arg-constructor JavaBean. A Scala case class needs
    // the product encoder.
    val LiteBitmaskEncoder: Encoder[LiteBitmaskStruct] = Encoders.product[LiteBitmaskStruct]

    /** Merge two interval lists (each sorted by start index) into one sequence
      * ordered by start index. Intervals are emitted as-is (no coalescing);
      * ties go to `alst`.
      *
      * Fixed: the original's final `case _` returned the exit sentinel whenever
      * both lists were non-empty and the b-side started first (`bmin < amin`),
      * terminating the stream early and silently dropping every remaining
      * interval (e.g. a=[[5,6]], b=[[0,1]] produced an empty result).
      *
      * @param alst intervals from the first bitmask, sorted by start
      * @param blst intervals from the second bitmask, sorted by start
      * @return all intervals from both inputs, ordered by start index
      */
    def getCombinedIntervalsUntilAllConsumed(alst: List[List[Int]], blst: List[List[Int]]): Seq[List[Int]] = {
        @scala.annotation.tailrec
        def loop(as: List[List[Int]], bs: List[List[Int]], acc: List[List[Int]]): List[List[Int]] =
            (as, bs) match {
                case (Nil, Nil) => acc.reverse
                case (a :: arest, Nil) => loop(arest, Nil, a :: acc)
                case (Nil, b :: brest) => loop(Nil, brest, b :: acc)
                case (a :: arest, b :: brest) =>
                    // a.head / b.head are the interval start indices; a wins ties.
                    if (a.head <= b.head) loop(arest, bs, a :: acc)
                    else loop(as, brest, b :: acc)
            }
        loop(alst, blst, Nil)
    }

    def intervals_or_v0(length: Int, abi: Row, bbi: Row): LiteBitmaskStruct = {
        /* does a logical or of two bit intervals
         * From Spark ([8,3,ArraySeq(ArraySeq(0, 2)),0,2],[8,3,ArraySeq(ArraySeq(0, 2)),0,2]) */
        // Decode both Rows, then fold their start-sorted interval stream into a
        // coalesced union, tracking set-bit count and first/last set-bit indices.
        val bitmask_a = rowToLiteBitmaskStruct(abi)
        val bitmask_b = rowToLiteBitmaskStruct(bbi)
        // Intervals from both inputs, ordered by start index (not yet coalesced).
        val intervalGen: Seq[List[Int]] = getCombinedIntervalsUntilAllConsumed(
            bitmask_a.intervals,
            bitmask_b.intervals)
        val intervals = ListBuffer.empty[List[Int]]  // merged, non-overlapping result intervals
        var bitCount = 0        // running count of set bits in the union
        var bitIdxFirst = -1    // first set-bit index; -1 while/if empty
        var bitIdxLast = -1     // last set-bit index; -1 while/if empty
        var prev_start = -1     // start of the interval currently being extended
        var prev_end = -1       // end of the interval currently being extended
        val itr = intervalGen.iterator
        for (List(next_start, next_end) <- itr) {
            (next_start, next_end) match {
                case (-2, -2) => // case where there is nothing
                    // NOTE(review): the (-2, -2) sentinel is excluded by takeWhile in
                    // getCombinedIntervalsUntilAllConsumed, so this branch looks
                    // unreachable — kept as a defensive reset.
                    bitCount = 0
                    bitIdxFirst = -1
                    bitIdxLast = -1
                case _ if (prev_start == -1 & prev_end == -1) => // Initial variable, setting previous
                    prev_start = next_start
                    prev_end = next_end
                    intervals += List(prev_start, prev_end)
                    bitCount = prev_end - prev_start + 1  // inclusive interval width
                    bitIdxFirst = intervals.head.head
                    bitIdxLast = intervals.last.last
                case _ =>
                    if (next_start <= prev_end + 1) {
                        // Overlapping or adjacent: extend the last interval in place.
                        // Only bits beyond prev_end are newly set (0 if fully contained).
                        val new_end = Math.max(prev_end, next_end)
                        intervals(intervals.length - 1) = List(prev_start, new_end)
                        bitCount += new_end - prev_end
                        prev_end = new_end
                    } else {
                        // Disjoint: start a new interval.
                        intervals += List(next_start, next_end)
                        bitCount += next_end - next_start + 1
                        prev_start = next_start
                        prev_end = next_end
                    }
                    bitIdxFirst = intervals.head.head
                    bitIdxLast = intervals.last.last
            }
        }
        LiteBitmaskStruct(length, bitCount, intervals.map(_.toList).toList, bitIdxFirst, bitIdxLast)
    }

    /** Convert a Spark Row (shaped like LiteBitmaskSchema) to a LiteBitmaskStruct.
      *
      * Fixed: the original matched on `List[List[Int]]` and
      * `ArraySeq[ArraySeq[Int]]` — the element types are erased at runtime, so
      * those checks were unchecked and a Row carrying, say, `List[ArraySeq[Int]]`
      * silently took the wrong branch and stored mistyped intervals. We match on
      * `Seq[_]` (the only check erasure can honor) and normalize every element
      * to `List[Int]` explicitly.
      */
    def rowToLiteBitmaskStruct(row: Row): LiteBitmaskStruct = {
        // Normalizes any Seq-backed interval column (List, ArraySeq, ...) to List[List[Int]].
        def normalize(intervals: Seq[_]): List[List[Int]] =
            // Cast is unavoidable under erasure; every Spark array representation is a Seq.
            intervals.toList.map(_.asInstanceOf[Seq[Int]].toList)
        row match {
            case Row(length: Int, bit_count: Int, intervals: Seq[_], bit_idx_first: Int, bit_idx_last: Int) =>
                LiteBitmaskStruct(length, bit_count, normalize(intervals), bit_idx_first, bit_idx_last)
            case _ =>
                // Fallback: positional/named access with the same normalization.
                LiteBitmaskStruct(row.getInt(0), row.getInt(1), normalize(row.getAs[Seq[Seq[Int]]]("intervals")),
                                  row.getInt(3), row.getInt(4))
        }
    }

    /** Null-tolerant logical OR of two bitmask Rows: a null operand acts as the
      * identity, both null yields null, and the result length comes from the
      * left operand's "length" field.
      */
    val logical_or: (Row, Row) => LiteBitmaskStruct = (a_bitmask: Row, b_bitmask: Row) => {
        if (a_bitmask == null && b_bitmask == null) null
        else if (a_bitmask == null) rowToLiteBitmaskStruct(b_bitmask)
        else if (b_bitmask == null) rowToLiteBitmaskStruct(a_bitmask)
        else intervals_or_v0(a_bitmask.getAs("length"), a_bitmask, b_bitmask)
    }

    /** Register this object's UDFs ("logical_or") on the given SparkSession.
      *
      * Fixed: dropped deprecated procedure syntax (`def f(...) { }`) in favor of
      * an explicit `: Unit =` — procedure syntax is deprecated in Scala 2.13 and
      * removed in Scala 3.
      */
    def registerAll(sparkSession: SparkSession): Unit = {
        sparkSession.udf.register("logical_or", logical_or(_: Row, _: Row), LiteBitmaskSchema)
    }