Skip to content
Snippets Groups Projects
Commit 57c950e9 authored by Eric Pershey's avatar Eric Pershey
Browse files

Updating the Scala tests; they still fail, but this is progress.

parent 6d63aa6c
No related branches found
No related tags found
No related merge requests found
......@@ -14,6 +14,7 @@ resolvers += "Typesafe Repository" at "https://repo.typesafe.com/typesafe/releas
conflictManager := ConflictManager.latestRevision
libraryDependencies += "org.scalatest" %% "scalatest" % "3.2.19" % Test
libraryDependencies += "org.scalatest" %% "scalatest-flatspec" % "3.2.19" % Test
val sparkVersion = "3.5.1"
libraryDependencies ++= Seq(
......
......@@ -8,7 +8,9 @@ package Octeres
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.types._
import scala.collection.immutable.ArraySeq
import scala.collection.mutable.ListBuffer
......@@ -20,6 +22,10 @@ object DataUDF {
// val LiteBitmaskSchema: StructType = ScalaReflection.schemaFor[LiteBitmaskStruct].dataType.asInstanceOf[StructType]
val LiteBitmaskEncoder = Encoders.bean(LiteBitmaskStruct.getClass)
/** Builds a Row tagged with LiteBitmaskSchema whose values are, in order:
 *  length, bit_count, intervals, bit_idx_first, bit_idx_last. */
def getLiteBitmaskRow(length: Int, bit_count: Int, intervals: List[List[Int]], bit_idx_first: Int, bit_idx_last: Int): Row = {
  // Row values are untyped, so collect the fields as Array[Any] before wrapping them.
  val fieldValues: Array[Any] = Array(length, bit_count, intervals, bit_idx_first, bit_idx_last)
  new GenericRowWithSchema(fieldValues, LiteBitmaskSchema)
}
def getCombinedIntervalsUntilAllConsumed(alst: List[List[Int]], blst: List[List[Int]]): Seq[List[Int]] = {
var aidx = 0
var bidx = 0
......@@ -47,11 +53,90 @@ object DataUDF {
}.takeWhile(_ != exit_list)
}
/**
 * Exception thrown when a bitmask fails a consistency check.
 *
 * Mirrors the usual java.lang.Exception constructor set:
 * (message), (message, cause), (cause) and ().
 *
 * @param message human-readable description of the failed check (may be null)
 */
class BitmaskError(message: String) extends Exception(message) {
  /**
   * @param message human-readable description of the failed check
   * @param cause   underlying exception; passed to initCause as-is (null allowed)
   */
  // No default value on `cause`: a one-argument call already resolves to the
  // primary constructor, so a default here could never be selected.
  def this(message: String, cause: Throwable) = {
    this(message)
    initCause(cause)
  }

  /** Uses `cause.toString` as the message, or a null message when `cause` is null. */
  def this(cause: Throwable) = {
    this(Option(cause).map(_.toString).orNull, cause)
  }

  /** Creates an error with neither message nor cause. */
  def this() = {
    // The type ascription picks the String overload; a bare null would be ambiguous.
    this(null: String)
  }
}
/** Row overload: converts the row with rowToLiteBitmaskStruct and validates the result.
 *  Any exception raised by the struct-based validator propagates unchanged. */
def validateLiteBitmaskSlotsLike(row: Row): Unit =
  validateLiteBitmaskSlotsLike(rowToLiteBitmaskStruct(row))
/**
 * Checks that the fields of a LiteBitmaskStruct agree with each other.
 *
 * Intervals are read as inclusive [left, right] pairs (head = left, last = right).
 * The checks are:
 *  - when bit_count > 0, bit_idx_first / bit_idx_last equal the left bound of the
 *    first interval / the right bound of the last interval;
 *  - otherwise bit_idx_first and bit_idx_last are both -1;
 *  - every interval is non-empty and has left <= right;
 *  - neighbouring intervals are not duplicates, not adjacent (right + 1 == next left)
 *    and do not overlap or run backwards (right >= next left);
 *  - the summed interval widths equal bit_count.
 *
 * @param obj the bitmask to validate
 * @throws BitmaskError on the first violated check
 */
def validateLiteBitmaskSlotsLike(obj: LiteBitmaskStruct): Unit = {
  val intervals = obj.intervals

  // Inclusive (left, right) bounds of one interval. An interval without elements is
  // reported as a BitmaskError rather than surfacing as a NoSuchElementException.
  def bounds(interval: List[Int]): (Int, Int) = {
    if (interval.isEmpty) {
      throw new BitmaskError(s"empty interval in intervals:${intervals}")
    }
    (interval.head, interval.last)
  }

  if (obj.bit_count > 0) {
    if (intervals.isEmpty) {
      throw new BitmaskError(s"bit_count:${obj.bit_count} > 0 but intervals is empty")
    }
    val (firstLeft, _) = bounds(intervals.head)
    if (obj.bit_idx_first != firstLeft) {
      throw new BitmaskError(s"bit_idx_first:${obj.bit_idx_first} != intervals.head.head:${firstLeft}")
    }
    val (_, lastRight) = bounds(intervals.last)
    if (obj.bit_idx_last != lastRight) {
      throw new BitmaskError(s"bit_idx_last:${obj.bit_idx_last} != intervals.last.last:${lastRight}")
    }
  } else {
    // No bits set: both indices must carry the -1 sentinel.
    if (obj.bit_idx_first != -1) {
      throw new BitmaskError(s"bit_idx_first:${obj.bit_idx_first} != -1")
    }
    if (obj.bit_idx_last != -1) {
      throw new BitmaskError(s"bit_idx_last:${obj.bit_idx_last} != -1")
    }
  }

  var bit_count = 0
  // Walking the tails yields each interval together with its successor in a single
  // pass, avoiding an indexed lookup on a List for every element.
  intervals.tails.foreach {
    case alar :: rest =>
      val (al, ar) = bounds(alar)
      if (al > ar) {
        throw new BitmaskError(s"negative interval: [$al, $ar]")
      }
      bit_count += (ar - al) + 1
      val blbr: Option[List[Int]] = rest.headOption
      blbr.foreach { next =>
        val (bl, br) = bounds(next)
        if (al == bl && ar == br) {
          throw new BitmaskError(s"duplicate interval: [$al, $ar]")
        }
        if ((ar + 1) == bl) {
          throw new BitmaskError(s"interval not merged: ${ar}+1 == ${bl} for $alar->$blbr")
        }
        // >= rather than >: bounds are inclusive, so ar == bl means both intervals
        // claim the same bit and should have been merged.
        if (ar >= bl) {
          throw new BitmaskError(s"interval out of order or not merged: ${ar} >= ${bl} for $alar->$blbr")
        }
      }
    case Nil => ()
  }

  if (bit_count != obj.bit_count) {
    throw new BitmaskError(s"bit_count:${bit_count} != obj.bit_count:${obj.bit_count}")
  }
}
/** Row overload of the logical OR of two bit-interval masks.
 *
 *  Both rows are converted with rowToLiteBitmaskStruct (abi first, then bbi) and the
 *  work is delegated to the LiteBitmaskStruct overload.
 *  Example rows as delivered by Spark:
 *  ([8,3,ArraySeq(ArraySeq(0, 2)),0,2],[8,3,ArraySeq(ArraySeq(0, 2)),0,2])
 */
def intervals_or_v0(length: Int, abi: Row, bbi: Row): LiteBitmaskStruct =
  intervals_or_v0(length, rowToLiteBitmaskStruct(abi), rowToLiteBitmaskStruct(bbi))
def intervals_or_v0(length: Int, bitmask_a: LiteBitmaskStruct, bitmask_b: LiteBitmaskStruct): LiteBitmaskStruct = {
/* does a logical or of two bit intervals
* From Spark ([8,3,ArraySeq(ArraySeq(0, 2)),0,2],[8,3,ArraySeq(ArraySeq(0, 2)),0,2]) */
val intervalGen: Seq[List[Int]] = getCombinedIntervalsUntilAllConsumed(
bitmask_a.intervals,
bitmask_b.intervals)
......@@ -94,8 +179,17 @@ object DataUDF {
LiteBitmaskStruct(length, bitCount, intervals.map(_.toList).toList, bitIdxFirst, bitIdxLast)
}
/** Converts a LiteBitmaskStruct into a Row tagged with LiteBitmaskSchema.
 *  Values are emitted in the order length, bit_count, intervals, bit_idx_first, bit_idx_last. */
def liteBitmaskStructToRow(bitmaskStruct: LiteBitmaskStruct): Row = {
  // Built by hand; the encoder route (import sparkSession.implicits._) is not used here.
  val fieldValues: Array[Any] = Array(
    bitmaskStruct.length,
    bitmaskStruct.bit_count,
    bitmaskStruct.intervals,
    bitmaskStruct.bit_idx_first,
    bitmaskStruct.bit_idx_last
  )
  new GenericRowWithSchema(fieldValues, LiteBitmaskSchema)
}
def rowToLiteBitmaskStruct(row: Row): LiteBitmaskStruct = {
/* convert a row to a LiteBitmaskStruct */
/* convert a Row to a LiteBitmaskStruct */
// might need this somehow: import sparkSession.implicits._
// row.asInstanceOf[LiteBitmaskStruct] // does not work
val bitmaskStruct: LiteBitmaskStruct = row match {
......@@ -103,12 +197,10 @@ object DataUDF {
LiteBitmaskStruct(length, bit_count, intervals, bit_idx_first, bit_idx_last)
case Row(length: Int, bit_count: Int, intervals: ArraySeq[ArraySeq[Int]], bit_idx_first: Int, bit_idx_last: Int) =>
LiteBitmaskStruct(length, bit_count, intervals.toList.map(_.toList), bit_idx_first, bit_idx_last)
case _ =>
LiteBitmaskStruct(row.getInt(0), row.getInt(1), row.getAs[List[List[Int]]]("intervals"), row.getInt(3), row.getInt(4))
case _ => LiteBitmaskStruct(row.getInt(0), row.getInt(1), row.getAs[List[List[Int]]]("intervals"), row.getInt(3), row.getInt(4))
}
bitmaskStruct
}
val logical_or: (Row, Row) => LiteBitmaskStruct = (a_bitmask: Row, b_bitmask: Row) => (a_bitmask, b_bitmask) match {
case (null, null) => null
case (null, _) => rowToLiteBitmaskStruct(b_bitmask)
......
......@@ -6,12 +6,15 @@
package Octeres
import Octeres.DataUDF.{LiteBitmaskSchema, logical_or}
import Octeres.DataUDF.{LiteBitmaskSchema, liteBitmaskStructToRow, logical_or, validateLiteBitmaskSlotsLike}
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Encoders, Row, SparkSession}
import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.prop.TableDrivenPropertyChecks._
import org.scalatest.matchers.should.Matchers._
object ExampleUDF {
......@@ -94,7 +97,7 @@ class DataUDFTest extends AnyFunSuite {
Row(5, 32, e_bitmask),
)
val sparkSession = SparkSession.builder().appName("DataUDFTest").master("local").getOrCreate()
var df = sparkSession.createDataFrame(sparkSession.sparkContext.parallelize(deck), schema = schema)
val df = sparkSession.createDataFrame(sparkSession.sparkContext.parallelize(deck), schema = schema)
DataUDF.registerAll(sparkSession)
df.createOrReplaceTempView("temp_view_a")
df.createOrReplaceTempView("temp_view_b")
......@@ -116,4 +119,47 @@ class DataUDFTest extends AnyFunSuite {
df3.show(32, truncate=false)
df3.printSchema()
}
// Table-driven check of logical_or: each table row is (left operand, right operand,
// expected OR). All three rows are validated before the OR is computed, and the
// computed result is validated as well before being compared to the expectation.
test("test_bitmask_functions") {
  val bitmaskTable = Table(
    ("bitmask0", "bitmask1", "bitmask_or"),
    (
      // overlapping intervals collapse into one
      DataUDF.getLiteBitmaskRow(8, 3, List(List(0, 2)), 0, 2),
      DataUDF.getLiteBitmaskRow(8, 3, List(List(1, 3)), 1, 3),
      DataUDF.getLiteBitmaskRow(8, 4, List(List(0, 3)), 0, 3)
    ), (
      // OR with an empty mask is the identity
      DataUDF.getLiteBitmaskRow(8, 5, List(List(0, 4)), 0, 4),
      DataUDF.getLiteBitmaskRow(8, 0, List(), -1, -1),
      DataUDF.getLiteBitmaskRow(8, 5, List(List(0, 4)), 0, 4)
    ), (
      // both masks empty
      DataUDF.getLiteBitmaskRow(16, 0, List(), -1, -1),
      DataUDF.getLiteBitmaskRow(16, 0, List(), -1, -1),
      DataUDF.getLiteBitmaskRow(16, 0, List(), -1, -1)
    ), (
      // the right operand bridges the gap between two left intervals
      DataUDF.getLiteBitmaskRow(8, 4, List(List(0, 1), List(3, 4)), 0, 4),
      DataUDF.getLiteBitmaskRow(8, 2, List(List(1, 2)), 1, 2),
      DataUDF.getLiteBitmaskRow(8, 5, List(List(0, 4)), 0, 4)
    ), (
      // interleaved, adjacent intervals merge into a single run
      // NOTE(review): bit index 8 with length 8 looks out of range if indices are
      // 0-based; the validator does not check length, so this is left as-is -- confirm.
      DataUDF.getLiteBitmaskRow(8, 4, List(List(1, 2), List(5, 6)), 1, 6),
      // bit_idx_last must equal the right bound of the last interval (8); the previous
      // value 4 made the input itself fail validation before logical_or ran.
      DataUDF.getLiteBitmaskRow(8, 4, List(List(3, 4), List(7, 8)), 3, 8),
      DataUDF.getLiteBitmaskRow(8, 8, List(List(1, 8)), 1, 8)
    )
  )
  forAll(bitmaskTable) { (bitmask0: Row, bitmask1: Row, bitmask_or: Row) =>
    println(s"bitmask0: ${bitmask0}")
    println(s"bitmask1: ${bitmask1}")
    println(s"bitmask_or: ${bitmask_or}")
    // Fixtures must be well-formed before they are used as inputs / expectations.
    validateLiteBitmaskSlotsLike(bitmask0)
    validateLiteBitmaskSlotsLike(bitmask1)
    validateLiteBitmaskSlotsLike(bitmask_or)
    val bitmask_r = logical_or(bitmask0, bitmask1)
    val bitmask_rr = liteBitmaskStructToRow(bitmask_r)
    println(s"bitmask_rr: ${bitmask_rr}")
    validateLiteBitmaskSlotsLike(bitmask_rr)
    bitmask_rr should equal(bitmask_or)
  }
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment