Skip to content
Snippets Groups Projects
DataUDF.scala 20.7 KiB
Newer Older
/*
 * Copyright (C) 2024, UChicago Argonne, LLC
 * Licensed under the 3-clause BSD license.  See accompanying LICENSE.txt file
 * in the top-level directory.
 */

/*
TODO: the .toInt conversions used to index an ArraySeq with a Long index are not ideal (an index above
    Int.MaxValue silently wraps); fixing this may require a rewrite using a larger structure
    than ArraySeq[ArraySeq[Long]].
 */
package Octeres

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types._
import scala.collection.immutable.ArraySeq
import scala.collection.mutable.{ArrayBuffer, ListBuffer}
object DataUDF {
    /** Plain-Scala form of a LiteBitmask row.
     *
     * @param length        number of addressable bit positions; every set bit index must be < length
     * @param bit_count     number of set bits (the summed sizes of all intervals)
     * @param bit_idx_first first set bit, or -1 when bit_count == 0
     * @param bit_idx_last  last set bit, or -1 when bit_count == 0
     * @param intervals     inclusive [start, end] pairs of set bits, expected sorted and merged
     */
    case class LiteBitmaskStruct(length: Long, bit_count: Long, bit_idx_first: Long, bit_idx_last: Long,
                                 intervals: ArraySeq[ArraySeq[Long]])
    /** Spark schema derived from LiteBitmaskStruct; used for every Row built in this object. */
    val LiteBitmaskSchema: StructType = Encoders.product[LiteBitmaskStruct].schema
    /** Encoder for LiteBitmaskStruct values.
     * LiteBitmaskStruct is a Scala case class, so it needs a product encoder. Encoders.bean expects a
     * Java-bean class, and LiteBitmaskStruct.getClass is the class of the companion object rather than
     * of the case class itself. */
    val LiteBitmaskEncoder: Encoder[LiteBitmaskStruct] = Encoders.product[LiteBitmaskStruct]

    def getLiteBitmaskZeros(length: Long): Row = {
        // An all-zero bitmask: no set bits, so both first/last indices carry the -1 "none" marker
        // and the interval list is empty.
        val fields: Array[Any] = Array(length, 0L, -1L, -1L, ArraySeq())
        new GenericRowWithSchema(fields, LiteBitmaskSchema)
    }

    def getLiteBitmaskZeros(length: Int): Row =
        // Int convenience overload: widen and reuse the Long version so both build the same row.
        getLiteBitmaskZeros(length.toLong)

    def getLiteBitmaskRow(length: Long, bit_count: Long, intervals: ArraySeq[ArraySeq[Long]], bit_idx_first: Long, bit_idx_last: Long): Row = {
        new GenericRowWithSchema(Array(length, bit_count, bit_idx_first, bit_idx_last, intervals), LiteBitmaskSchema)
    def getCombinedIntervalsUntilAllConsumed(alst: ArraySeq[ArraySeq[Long]], blst: ArraySeq[ArraySeq[Long]]): Seq[List[Long]] = {
        var aidx: Long = 0
        var bidx: Long = 0
        val exit_list = List[Long](-2, -2)
            val amin: Long = if (aidx < alst.length) alst(aidx.toInt).head else -2L
            val amax: Long = if (aidx < alst.length) alst(aidx.toInt).last else -2L
            val bmin: Long = if (bidx < blst.length) blst(bidx.toInt).head else -2L
            val bmax: Long = if (bidx < blst.length) blst(bidx.toInt).last else -2L
                    List(bmin, bmax)
                case (_, _) if (amin != -2 & amin <= bmin) =>
    /** Raised when a LiteBitmask violates one of its structural invariants. */
    class BitmaskError(message: String) extends Exception(message) {
        /** Builds the error with a message and an explicitly recorded (possibly null) cause. */
        def this(message: String, cause: Throwable = null) = {
            this(message)
            this.initCause(cause)
        }

        /** Wraps `cause`, reusing its toString as the message (null message when cause is null). */
        def this(cause: Throwable) =
            this(if (cause == null) null else cause.toString, cause)

        /** Builds an error with neither a message nor a cause. */
        def this() =
            this(null: String)
    }
    def validateLiteBitmaskSlotsLike(row: Row): Unit =
        // Row entry point: decode into the case class and run the struct-based validation,
        // which throws BitmaskError on the first violated invariant.
        validateLiteBitmaskSlotsLike(rowToLiteBitmaskStruct(row))

    def validateLiteBitmaskSlotsLike(obj: LiteBitmaskStruct): Unit = {
        val intervals = obj.intervals
        if (obj.bit_count > 0) {
            if (obj.bit_idx_last >= obj.length) {
                throw new BitmaskError(s"bit_idx_last is greater than length")
            }

            if (obj.bit_idx_first != intervals.head.head) {
                throw new BitmaskError(s"bit_idx_first:${obj.bit_idx_first} != intervals.head.head:${intervals.head.head}")
            }
            if (obj.bit_idx_last != intervals.last.last) {
                throw new BitmaskError(s"bit_idx_last:${obj.bit_idx_last} != intervals.last.last:${intervals.last.last}")
            }
        } else {
            if (obj.bit_idx_first != -1) {
                throw new BitmaskError(s"bit_idx_first:${obj.bit_idx_first} != -1")
            }
            if (obj.bit_idx_last != -1) {
                throw new BitmaskError(s"bit_idx_last:${obj.bit_idx_last} != -1")
        for ((alar, i) <- intervals.zipWithIndex) {
            val al: Long = alar.head
            val ar: Long = alar.last
            val sub_bit_count: Long = (ar - al) + 1L
            if (sub_bit_count < 1) {
                throw new BitmaskError(s"negative interval: [$al, $ar]")
            }
            bit_count += sub_bit_count

            if (al > ar) {
                throw new BitmaskError(s"negative interval: [$al, $ar]")
            }
            val blbr: Option[ArraySeq[Long]] = intervals.lift(i + 1)
                case Some(x: ArraySeq[Long]) => {
                    val bl: Long = x.head
                    val br: Long = x.last
                    if ((al == bl) & (ar == br)) {
                        throw new BitmaskError(s"duplicate interval: [$al, $ar]")
                    }
                    if ((ar + 1) == bl) {
                        throw new BitmaskError(s"interval not merged: ${ar}+1 == ${bl} for $alar->$blbr")
                    }

                    if (ar > bl) {
                        throw new BitmaskError(s"interval out of order or not merged: ${ar} > ${bl} for $alar->$blbr")
                    }
                }
                case None => Nil
            }
        }
        if (bit_count != obj.bit_count) {
            throw new BitmaskError(s"bit_count:${bit_count} != obj.bit_count:${obj.bit_count}")
        }
    }

    def logical_or_v0(abi: Row, bbi: Row): LiteBitmaskStruct = {
        logical_or_v0(rowToLiteBitmaskStruct(abi), rowToLiteBitmaskStruct(bbi))
    def logical_or_v0(a_bitmask: LiteBitmaskStruct, b_bitmask: LiteBitmaskStruct): LiteBitmaskStruct = {
        /* does a logical or of two bit intervals
         * From Spark ([8,3,ArraySeq(ArraySeq(0, 2)),0,2],[8,3,ArraySeq(ArraySeq(0, 2)),0,2]) */
            val intervalGen: Seq[List[Long]] = getCombinedIntervalsUntilAllConsumed(
        val intervals = ArrayBuffer.empty[ArraySeq[Long]]
        var bitCount: Long = 0
        var bitIdxFirst: Long = -1
        var bitIdxLast: Long = -1
        var prev_start: Long = -1
        var prev_end: Long = -1
        val itr = intervalGen.iterator
        for (List(next_start, next_end) <- itr) {
            (next_start, next_end) match {
                case (-2, -2) => // case where there is nothing
                    bitCount = 0
                    bitIdxFirst = -1
                    bitIdxLast = -1
                case _ if (prev_start == -1 & prev_end == -1) => // Initial variable, setting previous
                    prev_start = next_start
                    prev_end = next_end
                    intervals += ArraySeq(prev_start, prev_end)
                    bitCount = prev_end - prev_start + 1
                    bitIdxFirst = intervals.head.head
                    bitIdxLast = intervals.last.last
                case _ =>
                    if (next_start <= prev_end + 1) {
                        val new_end = Math.max(prev_end, next_end)
                        intervals(intervals.length - 1) = ArraySeq(prev_start, new_end)
                        bitCount += new_end - prev_end
                        prev_end = new_end
                    } else {
                        intervals += ArraySeq(next_start, next_end)
                        bitCount += next_end - next_start + 1
                        prev_start = next_start
                        prev_end = next_end
                    }
                    bitIdxFirst = intervals.head.head
                    bitIdxLast = intervals.last.last
            }
        }
        val result_intervals: ArraySeq[ArraySeq[Long]] = ArraySeq.from(intervals)
        // hopefully this is a view not a copy
        // val result_intervals: ArraySeq[ArraySeq[Long]] = ArraySeq.unsafeWrapArray(intervals.toArray)
        LiteBitmaskStruct(a_bitmask.length, bitCount, bitIdxFirst, bitIdxLast, result_intervals)
    def liteBitmaskStructToRow(bitmaskStruct: LiteBitmaskStruct): Row = {
        /* convert a LiteBitmaskStruct to a Row */
        // might need this somehow: import sparkSession.implicits._
        val row: Row = new GenericRowWithSchema(Array(bitmaskStruct.length, bitmaskStruct.bit_count,
            bitmaskStruct.bit_idx_first, bitmaskStruct.bit_idx_last, bitmaskStruct.intervals), LiteBitmaskSchema)
    def rowToLiteBitmaskStruct(row: Row): LiteBitmaskStruct = {
        /* convert a Row to a LiteBitmaskStruct */
        val length: Long = row.getAs[Long]("length")
        val bit_count: Long = row.getAs[Long]("bit_count")
        val bit_idx_first: Long = row.getAs[Long]("bit_idx_first")
        val bit_idx_last: Long = row.getAs[Long]("bit_idx_last")
        val intervals: ArraySeq[ArraySeq[Long]] = row.getAs[ArraySeq[ArraySeq[Long]]]("intervals")
        LiteBitmaskStruct(length, bit_count, bit_idx_first, bit_idx_last, intervals)
    def logical_or(a_bitmask: Row, b_bitmask: Row): LiteBitmaskStruct = {
        /* Null-tolerant OR of two bitmask rows: null when both inputs are null, the other side
         * (converted unchanged) when exactly one is null, otherwise the merged bitmask. */
        (Option(a_bitmask), Option(b_bitmask)) match {
            case (None, None) => null
            case (None, Some(onlyB)) => rowToLiteBitmaskStruct(onlyB)
            case (Some(onlyA), None) => rowToLiteBitmaskStruct(onlyA)
            case (Some(_), Some(_)) => logical_or_v0(a_bitmask, b_bitmask)
        }
    }

    def intersection_possible(a_bitmask: LiteBitmaskStruct, b_bitmask: LiteBitmaskStruct): Boolean = {
        /* Cheap pre-check on summary fields only. Returns false when either mask has no set bits
         * or when their [bit_idx_first, bit_idx_last] spans do not overlap; true means an
         * intersection is possible, not that one exists. */
        val eitherEmpty = a_bitmask.bit_count == 0 || b_bitmask.bit_count == 0
        val spansDisjoint =
            a_bitmask.bit_idx_first > b_bitmask.bit_idx_last ||
            a_bitmask.bit_idx_last < b_bitmask.bit_idx_first
        !eitherEmpty && !spansDisjoint
    }

    def intersection_possible(abi: Row, bbi: Row): Boolean =
        // Row entry point: decode both rows (a first, then b) and delegate to the struct overload.
        intersection_possible(rowToLiteBitmaskStruct(abi), rowToLiteBitmaskStruct(bbi))

    def getIntervalUntilAllConsumed(alst: List[List[Long]], blst: List[List[Long]]): Seq[List[Long]] = {
        var aidx: Long = 0
        var bidx: Long = 0
        val exit_list: List[Long] = List(-2, -2)
            val amin: Long = if (aidx < alst.length) alst(aidx.toInt).head else -2
            val amax: Long = if (aidx < alst.length) alst(aidx.toInt).last else -2
            val bmin: Long = if (bidx < blst.length) blst(bidx.toInt).head else -2
            val bmax: Long = if (bidx < blst.length) blst(bidx.toInt).last else -2
            //            println(s"aidx:${aidx} bidx:${bidx} ${amin},${bmin} amin:${amin} amax:${amax}  bmin:${bmin} bmax:${bmax}")
            (amin, bmin) match {
                case (-2, -2) =>
                    exit_list
                case (_, -2) =>
                    aidx += 1
                    List(amin, amax)
                case (-2, _) =>
                    bidx += 1
                    List(bmin, bmax)
                case (_, _) if (amin != -2 & amin <= bmin) =>
                    aidx += 1
                    List(amin, amax)
                case _ =>
                    bidx += 1
                    List(bmin, bmax)
            }
        }.takeWhile(_ != exit_list)
    }

    class IntervalIterator(intervals: ArraySeq[ArraySeq[Long]], bit_idx_first: Long, bit_idx_last: Long) {
        var cur_interval_idx: Long = 0
        var cur_interval: ArraySeq[Long] = intervals(cur_interval_idx.toInt)
        var cur_range_idx: Long = 0
        var cur_range: ArraySeq[Long] = ArraySeq.range(cur_interval.head, cur_interval.last+1)
        _jumpForward(bit_idx_first)

        def printState(): Unit = {
            println(s"bit_idx_first: ${bit_idx_first} bit_idx_last: ${bit_idx_last}")
            println(s"cur_interval_idx: ${cur_interval_idx} cur_interval: ${cur_interval}")
            println(s"cur_range_idx: ${cur_range_idx} cur_range: ${cur_range}")
        }

            var found = false
            while (!isEmpty & !found) {
                if (bit_idx_first > cur_interval.last) {
                    _moveForwardInterval()
                } else {
                    if (bit_idx_first > cur_range(cur_range_idx.toInt)){
                        _moveForwardRange()
                    } else {
                        found = true
                    }
                }

            }
        }

        def _moveEnd(): Unit = {
            cur_interval_idx = -2
            cur_range_idx = -2
        }

        def _moveForwardInterval(): Unit = {
            cur_interval_idx += 1
            cur_range_idx = 0
            cur_interval = intervals(cur_interval_idx.toInt)
            cur_range = ArraySeq.range(cur_interval.head, cur_interval.last+1)
        }

        def _moveForwardRange(): Unit = {
            cur_range_idx += 1
        }
        def isEmpty: Boolean = {
            if (cur_interval_idx == -2 & cur_range_idx == -2){
                true
            } else if (cur_interval_idx >= intervals.length & cur_range_idx >= cur_range.length){
                true
            } else {
                false
            }
        }

        def nonEmpty: Boolean = {
            !isEmpty
        }

        def _move(): Unit = {
            if (cur_range_idx+1 >= cur_range.length){
                if (cur_interval_idx+1 >= cur_interval.length) {
                    _moveEnd()
                } else {
                    _moveForwardInterval()
                }
            } else {
                _moveForwardRange()
            }
        }

        def get(): Long = {
            val i = cur_range(cur_range_idx.toInt)
//    def getIntervalBitsGenerator(intervals: List[List[Long]], bit_idx_first: Long, bit_idx_last: Long): Iterator[Long] = {
//        for (lst <- intervals) {
//            val start = lst(0)
//            val end = lst.last + 1
//            LazyList.range(start, end)
//        }
//    }

    def getIntervalBits(a_bitmask: LiteBitmaskStruct, bit_idx_first: Long = 0, bit_idx_last: Long = -2): List[Long] = {
        /* Expands the bitmask's inclusive [start, end] intervals into individual set-bit indices,
         * keeping only indices inside [bit_idx_first, bit_idx_last], in interval order.
         * bit_idx_last == -2 means "no upper bound".
         * Each interval is clipped to the window before it is expanded, so bits outside the window
         * are never enumerated; the cost follows the window size, not the whole mask. */
        val bit_lst: ListBuffer[Long] = ListBuffer()
        a_bitmask.intervals.foreach { interval =>
            val lo = math.max(interval.head, bit_idx_first)
            val hi = if (bit_idx_last == -2) interval.last else math.min(interval.last, bit_idx_last)
            if (lo <= hi) {
                bit_lst ++= (lo to hi) // inclusive
            }
        }
        bit_lst.toList
    }

    def intersection_any_v0(a_bitmask: LiteBitmaskStruct, b_bitmask: LiteBitmaskStruct): Boolean = {
        /* Returns true when the two bitmasks share at least one set bit.
         * Bit-level version: both masks are expanded to explicit bit lists over their overlapping
         * window, and the two sorted lists are walked with a pair of cursors. */
        if (!intersection_possible(a_bitmask, b_bitmask)) {
            false
        } else {
            val aFirst = a_bitmask.bit_idx_first
            val aLast = a_bitmask.bit_idx_last
            val bFirst = b_bitmask.bit_idx_first
            val bLast = b_bitmask.bit_idx_last
            // first/last indices are themselves set bits, so a shared endpoint settles it at once
            if (aFirst == bFirst || aLast == bLast || aLast == bFirst) {
                true
            } else {
                val windowStart = math.max(aFirst, bFirst)
                val windowEnd = math.min(aLast, bLast)
                val aBits = getIntervalBits(a_bitmask, bit_idx_first = windowStart, bit_idx_last = windowEnd)
                val bBits = getIntervalBits(b_bitmask, bit_idx_first = windowStart, bit_idx_last = windowEnd)

                // Advance whichever cursor points at the smaller bit; a negative index ends the walk.
                @scala.annotation.tailrec
                def walk(xs: List[Long], ys: List[Long]): Boolean = (xs, ys) match {
                    case (x :: xRest, y :: yRest) if x >= 0 && y >= 0 =>
                        if (x == y) true
                        else if (x > y) walk(xs, yRest)
                        else walk(xRest, ys)
                    case _ => false
                }

                walk(aBits, bBits)
            }
        }
    }

    def intersection_any_v0(abi: Row, bbi: Row): Boolean =
        // Row entry point: decode both rows (a first, then b) and delegate to the struct overload.
        intersection_any_v0(rowToLiteBitmaskStruct(abi), rowToLiteBitmaskStruct(bbi))

    def intersection_any_v1(a_bitmask: LiteBitmaskStruct, b_bitmask: LiteBitmaskStruct): Boolean = {
        /* Returns true when the two bitmasks share at least one set bit.
         *
         * Walks both interval lists in step and compares whole inclusive [start, end] intervals,
         * assuming each list is sorted and non-overlapping (the invariants that
         * validateLiteBitmaskSlotsLike checks). Cost is O(#intervals) rather than O(#bits).
         *
         * This does not go through IntervalIterator. Its _move() bounds the interval index by
         * cur_interval.length (the size of a single [start, end] pair) instead of intervals.length,
         * so it can stop before later intervals or index past the last one. */
        if (!intersection_possible(a_bitmask, b_bitmask)) {
            false
        } else if (a_bitmask.bit_idx_first == b_bitmask.bit_idx_first ||
                   a_bitmask.bit_idx_last == b_bitmask.bit_idx_last ||
                   a_bitmask.bit_idx_last == b_bitmask.bit_idx_first ||
                   a_bitmask.bit_idx_first == b_bitmask.bit_idx_last) {
            // bit_idx_first / bit_idx_last are themselves set bits, so a shared endpoint is a shared bit
            true
        } else {
            val a_iv = a_bitmask.intervals
            val b_iv = b_bitmask.intervals
            var ai = 0
            var bi = 0
            var found = false
            while (!found && ai < a_iv.length && bi < b_iv.length) {
                val al = a_iv(ai).head
                val ar = a_iv(ai).last
                val bl = b_iv(bi).head
                val br = b_iv(bi).last
                if (al <= br && bl <= ar) {
                    found = true // the two inclusive intervals overlap
                } else if (ar < br) {
                    ai += 1 // a's interval ends first, so it cannot overlap any later b interval
                } else {
                    bi += 1 // b's interval ends first (or at the same point), advance b
                }
            }
            found
        }
    }

    def intersection_any_v1(abi: Row, bbi: Row): Boolean =
        // Row entry point: decode both rows (a first, then b) and delegate to the struct overload.
        intersection_any_v1(rowToLiteBitmaskStruct(abi), rowToLiteBitmaskStruct(bbi))

    def intersection_any(a_bitmask: Row, b_bitmask: Row): Boolean = {
        val result = (a_bitmask, b_bitmask) match {
            case (null, null) => false
            case (null, _) => false
            case (_, null) => false
            case (_, _) => intersection_any_v1(a_bitmask, b_bitmask)
        }
        result
    /** Spark UDF wrapping intersection_any: (Row, Row) => Boolean. */
    val udf_intersection_any: UserDefinedFunction =
        udf { (lhs: Row, rhs: Row) => intersection_any(lhs, rhs) }
    /** Spark UDF wrapping logical_or: (Row, Row) => LiteBitmaskStruct. */
    val udf_logical_or: UserDefinedFunction =
        udf { (lhs: Row, rhs: Row) => logical_or(lhs, rhs) }


    /** Smoke-test hook: prints a fixed message so callers can confirm the Scala side is reachable. */
    def testCall(): Unit =
        println("from Scala: nothing to see here!")

    /** Smoke-test hook: prints the addition being performed and returns a + b. */
    def testCall(a: Long, b: Long): Long = {
        val total = a + b
        println(s"from Scala: $a + $b = $total")
        total
    }

    def registerAll(): Unit = {
        val sparkSession: SparkSession = SparkSession.builder().getOrCreate()
        println("Registering Scala UDF functions using sparkSession.")
        sparkSession.udf.register("logical_or", udf_logical_or)
        sparkSession.udf.register("intersection_any", udf_intersection_any)