Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
O
octeres
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Deploy
Package Registry
Operate
Terraform modules
Analyze
Value stream analytics
Contributor analytics
Repository analytics
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Terms and privacy
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
AIG-public
octeres
Commits
57c950e9
Commit
57c950e9
authored
1 month ago
by
Eric Pershey
Browse files
Options
Downloads
Patches
Plain Diff
Updating the Scala tests; still failing, but progress.
parent
6d63aa6c
No related branches found
Branches containing commit
No related tags found
No related merge requests found
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
build.sbt
+1
-0
1 addition, 0 deletions
build.sbt
src/main/scala/Octeres/DataUDF.scala
+98
-6
98 additions, 6 deletions
src/main/scala/Octeres/DataUDF.scala
src/test/scala/Octeres/DataUDFTest.scala
+48
-2
48 additions, 2 deletions
src/test/scala/Octeres/DataUDFTest.scala
with
147 additions
and
8 deletions
build.sbt
+
1
−
0
View file @
57c950e9
...
...
@@ -14,6 +14,7 @@ resolvers += "Typesafe Repository" at "https://repo.typesafe.com/typesafe/releas
conflictManager
:=
ConflictManager
.
latestRevision
libraryDependencies
+=
"org.scalatest"
%%
"scalatest"
%
"3.2.19"
%
Test
libraryDependencies
+=
"org.scalatest"
%%
"scalatest-flatspec"
%
"3.2.19"
%
Test
val
sparkVersion
=
"3.5.1"
libraryDependencies
++=
Seq
(
...
...
This diff is collapsed.
Click to expand it.
src/main/scala/Octeres/DataUDF.scala
+
98
−
6
View file @
57c950e9
...
...
@@ -8,7 +8,9 @@ package Octeres
import
org.apache.spark.sql.SparkSession
import
org.apache.spark.sql._
import
org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import
org.apache.spark.sql.types._
import
scala.collection.immutable.ArraySeq
import
scala.collection.mutable.ListBuffer
...
...
@@ -20,6 +22,10 @@ object DataUDF {
// val LiteBitmaskSchema: StructType = ScalaReflection.schemaFor[LiteBitmaskStruct].dataType.asInstanceOf[StructType]
// Spark Encoder for LiteBitmaskStruct rows.
// BUG FIX: the previous `Encoders.bean(LiteBitmaskStruct.getClass)` passed the
// *companion object's* class to a JavaBean encoder; a bean encoder cannot encode
// Scala case-class instances. `Encoders.product` is the correct encoder for a
// case class (assumes LiteBitmaskStruct is a case class — its apply(...) usage
// below supports this; TODO confirm against its declaration).
val LiteBitmaskEncoder: Encoder[LiteBitmaskStruct] = Encoders.product[LiteBitmaskStruct]
/** Builds a schema-attached Spark Row for a lite bitmask.
  *
  * Field order must match LiteBitmaskSchema:
  * (length, bit_count, intervals, bit_idx_first, bit_idx_last).
  */
def getLiteBitmaskRow(length: Int, bit_count: Int, intervals: List[List[Int]], bit_idx_first: Int, bit_idx_last: Int): Row = {
  val values: Array[Any] = Array(length, bit_count, intervals, bit_idx_first, bit_idx_last)
  new GenericRowWithSchema(values, LiteBitmaskSchema)
}
def
getCombinedIntervalsUntilAllConsumed
(
alst
:
List
[
List
[
Int
]],
blst
:
List
[
List
[
Int
]])
:
Seq
[
List
[
Int
]]
=
{
var
aidx
=
0
var
bidx
=
0
...
...
@@ -47,11 +53,90 @@ object DataUDF {
}.
takeWhile
(
_
!=
exit_list
)
}
/** Error raised when a lite-bitmask row/struct violates its structural invariants.
  *
  * Mirrors the standard four-constructor exception pattern:
  * (message), (message, cause), (cause), ().
  */
class BitmaskError(message: String) extends Exception(message) {
  // BUG FIX: the original used deprecated procedure syntax here
  // (`def this(...) { ... }` with no `=`), which is a hard error under
  // Scala 3 and a deprecation warning under 2.13.
  def this(message: String, cause: Throwable = null) = {
    this(message)
    initCause(cause)
  }

  // Cause-only constructor: derives the message from the cause, null-safely.
  def this(cause: Throwable) = {
    this(Option(cause).map(_.toString).orNull, cause)
  }

  // No-arg constructor: no message, no cause. The `null: String` ascription
  // disambiguates between the String and Throwable single-arg constructors.
  def this() = {
    this(null: String)
  }
}
/** Validates a Row-encoded lite bitmask.
  *
  * Decodes the Row into a LiteBitmaskStruct and delegates to the struct
  * overload; throws BitmaskError if any invariant is violated.
  */
def validateLiteBitmaskSlotsLike(row: Row): Unit =
  validateLiteBitmaskSlotsLike(rowToLiteBitmaskStruct(row))
/** Validates the internal consistency of a LiteBitmaskStruct.
  *
  * Checks that:
  *  - bit_idx_first / bit_idx_last agree with the first/last interval bounds
  *    (or are both -1 when the bitmask is empty);
  *  - every interval [al, ar] is non-negative (al <= ar);
  *  - consecutive intervals are deduplicated, merged, and in ascending order;
  *  - the summed interval widths equal obj.bit_count.
  *
  * @throws BitmaskError if any invariant is violated
  */
def validateLiteBitmaskSlotsLike(obj: LiteBitmaskStruct): Unit = {
  val intervals = obj.intervals
  if (obj.bit_count > 0) {
    if (obj.bit_idx_first != intervals.head.head) {
      throw new BitmaskError(s"bit_idx_first:${obj.bit_idx_first} != intervals.head.head:${intervals.head.head}")
    }
    if (obj.bit_idx_last != intervals.last.last) {
      // BUG FIX: this message previously reported bit_idx_first / intervals.head.head
      // (copy-paste from the branch above), hiding which check actually failed.
      throw new BitmaskError(s"bit_idx_last:${obj.bit_idx_last} != intervals.last.last:${intervals.last.last}")
    }
  } else {
    if (obj.bit_idx_first != -1) {
      throw new BitmaskError(s"bit_idx_first:${obj.bit_idx_first} != -1")
    }
    if (obj.bit_idx_last != -1) {
      // BUG FIX: this message previously reported bit_idx_first.
      throw new BitmaskError(s"bit_idx_last:${obj.bit_idx_last} != -1")
    }
  }
  var bit_count = 0
  for ((alar, i) <- intervals.zipWithIndex) {
    val al = alar.head
    val ar = alar.last
    // Intervals are inclusive, so [al, ar] covers (ar - al) + 1 bits.
    val sub_bit_count = (ar - al) + 1
    // `sub_bit_count < 1` and `al > ar` are the same condition; the original
    // checked both with identical messages — a single check suffices.
    if (al > ar) {
      throw new BitmaskError(s"negative interval: [$al, $ar]")
    }
    bit_count += sub_bit_count
    // Pairwise check against the following interval, if any (lift is total).
    val blbr: Option[List[Int]] = intervals.lift(i + 1)
    blbr match {
      case Some(next) =>
        // Matching on `Some(next)` instead of the original type-erased
        // `Some(x: List[Int])`, which compiled to an unchecked pattern.
        val bl: Int = next.head
        val br: Int = next.last
        // BUG FIX: `&&` instead of the original non-short-circuiting `&`.
        if ((al == bl) && (ar == br)) {
          throw new BitmaskError(s"duplicate interval: [$al, $ar]")
        }
        if ((ar + 1) == bl) {
          throw new BitmaskError(s"interval not merged: ${ar}+1 == ${bl} for $alar->$blbr")
        }
        if (ar > bl) {
          throw new BitmaskError(s"interval out of order or not merged: ${ar} > ${bl} for $alar->$blbr")
        }
      case None => ()
    }
  }
  if (bit_count != obj.bit_count) {
    throw new BitmaskError(s"bit_count:${bit_count} != obj.bit_count:${obj.bit_count}")
  }
}
/** Logical OR of two Row-encoded lite bitmasks.
  *
  * Decodes both Rows to LiteBitmaskStruct and delegates to the struct overload.
  * Example Spark input: ([8,3,ArraySeq(ArraySeq(0, 2)),0,2],[8,3,ArraySeq(ArraySeq(0, 2)),0,2]).
  */
def intervals_or_v0(length: Int, abi: Row, bbi: Row): LiteBitmaskStruct =
  intervals_or_v0(length, rowToLiteBitmaskStruct(abi), rowToLiteBitmaskStruct(bbi))
def
intervals_or_v0
(
length
:
Int
,
bitmask_a
:
LiteBitmaskStruct
,
bitmask_b
:
LiteBitmaskStruct
)
:
LiteBitmaskStruct
=
{
/* does a logical or of two bit intervals
* From Spark ([8,3,ArraySeq(ArraySeq(0, 2)),0,2],[8,3,ArraySeq(ArraySeq(0, 2)),0,2]) */
val
intervalGen
:
Seq
[
List
[
Int
]]
=
getCombinedIntervalsUntilAllConsumed
(
bitmask_a
.
intervals
,
bitmask_b
.
intervals
)
...
...
@@ -94,8 +179,17 @@ object DataUDF {
LiteBitmaskStruct
(
length
,
bitCount
,
intervals
.
map
(
_
.
toList
).
toList
,
bitIdxFirst
,
bitIdxLast
)
}
/** Converts a LiteBitmaskStruct to a schema-attached Spark Row.
  *
  * Field order must match LiteBitmaskSchema:
  * (length, bit_count, intervals, bit_idx_first, bit_idx_last).
  */
def liteBitmaskStructToRow(bitmaskStruct: LiteBitmaskStruct): Row = {
  // might need this somehow: import sparkSession.implicits._
  // row.asInstanceOf[LiteBitmaskStruct] // does not work
  val fields: Array[Any] = Array(
    bitmaskStruct.length,
    bitmaskStruct.bit_count,
    bitmaskStruct.intervals,
    bitmaskStruct.bit_idx_first,
    bitmaskStruct.bit_idx_last
  )
  new GenericRowWithSchema(fields, LiteBitmaskSchema)
}
def
rowToLiteBitmaskStruct
(
row
:
Row
)
:
LiteBitmaskStruct
=
{
/* convert a
r
ow to a LiteBitmaskStruct */
/* convert a
R
ow to a LiteBitmaskStruct */
// might need this somehow: import sparkSession.implicits._
// row.asInstanceOf[LiteBitmaskStruct] // does not work
val
bitmaskStruct
:
LiteBitmaskStruct
=
row
match
{
...
...
@@ -103,12 +197,10 @@ object DataUDF {
LiteBitmaskStruct
(
length
,
bit_count
,
intervals
,
bit_idx_first
,
bit_idx_last
)
case
Row
(
length
:
Int
,
bit_count
:
Int
,
intervals
:
ArraySeq
[
ArraySeq
[
Int
]],
bit_idx_first
:
Int
,
bit_idx_last
:
Int
)
=>
LiteBitmaskStruct
(
length
,
bit_count
,
intervals
.
toList
.
map
(
_
.
toList
),
bit_idx_first
,
bit_idx_last
)
case
_
=>
LiteBitmaskStruct
(
row
.
getInt
(
0
),
row
.
getInt
(
1
),
row
.
getAs
[
List
[
List
[
Int
]]](
"intervals"
),
row
.
getInt
(
3
),
row
.
getInt
(
4
))
case
_
=>
LiteBitmaskStruct
(
row
.
getInt
(
0
),
row
.
getInt
(
1
),
row
.
getAs
[
List
[
List
[
Int
]]](
"intervals"
),
row
.
getInt
(
3
),
row
.
getInt
(
4
))
}
bitmaskStruct
}
val
logical_or
:
(
Row
,
Row
)
=>
LiteBitmaskStruct
=
(
a_bitmask
:
Row
,
b_bitmask
:
Row
)
=>
(
a_bitmask
,
b_bitmask
)
match
{
case
(
null
,
null
)
=>
null
case
(
null
,
_
)
=>
rowToLiteBitmaskStruct
(
b_bitmask
)
...
...
This diff is collapsed.
Click to expand it.
src/test/scala/Octeres/DataUDFTest.scala
+
48
−
2
View file @
57c950e9
...
...
@@ -6,12 +6,15 @@
package
Octeres
import
Octeres.DataUDF.
{
LiteBitmaskSchema
,
l
ogical_or
}
import
Octeres.DataUDF.
{
LiteBitmaskSchema
,
l
iteBitmaskStructToRow
,
logical_or
,
validateLiteBitmaskSlotsLike
}
import
org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import
org.apache.spark.sql.functions._
import
org.apache.spark.sql.types._
import
org.apache.spark.sql.
{
Encoders
,
Row
,
SparkSession
}
import
org.scalatest.funsuite.AnyFunSuite
import
org.scalatest.prop.TableDrivenPropertyChecks._
import
org.scalatest.matchers.should.Matchers._
object
ExampleUDF
{
...
...
@@ -94,7 +97,7 @@ class DataUDFTest extends AnyFunSuite {
Row
(
5
,
32
,
e_bitmask
),
)
val
sparkSession
=
SparkSession
.
builder
().
appName
(
"DataUDFTest"
).
master
(
"local"
).
getOrCreate
()
va
r
df
=
sparkSession
.
createDataFrame
(
sparkSession
.
sparkContext
.
parallelize
(
deck
),
schema
=
schema
)
va
l
df
=
sparkSession
.
createDataFrame
(
sparkSession
.
sparkContext
.
parallelize
(
deck
),
schema
=
schema
)
DataUDF
.
registerAll
(
sparkSession
)
df
.
createOrReplaceTempView
(
"temp_view_a"
)
df
.
createOrReplaceTempView
(
"temp_view_b"
)
...
...
@@ -116,4 +119,47 @@ class DataUDFTest extends AnyFunSuite {
df3
.
show
(
32
,
truncate
=
false
)
df3
.
printSchema
()
}
// Table-driven check of the bitmask OR pipeline: for each (bitmask0, bitmask1,
// expected OR) triple, validate all three rows, compute logical_or, convert the
// result back to a Row, validate it, and compare against the expected Row.
test("test_bitmask_functions") {
  val bitmaskTable = Table(
    ("bitmask0", "bitmask1", "bitmask_or"),
    // overlapping intervals: [0,2] | [1,3] -> [0,3]
    (DataUDF.getLiteBitmaskRow(8, 3, List(List(0, 2)), 0, 2),
      DataUDF.getLiteBitmaskRow(8, 3, List(List(1, 3)), 1, 3),
      DataUDF.getLiteBitmaskRow(8, 4, List(List(0, 3)), 0, 3)),
    // OR with an empty bitmask is the identity
    (DataUDF.getLiteBitmaskRow(8, 5, List(List(0, 4)), 0, 4),
      DataUDF.getLiteBitmaskRow(8, 0, List(), -1, -1),
      DataUDF.getLiteBitmaskRow(8, 5, List(List(0, 4)), 0, 4)),
    // empty | empty -> empty (indices stay -1)
    (DataUDF.getLiteBitmaskRow(16, 0, List(), -1, -1),
      DataUDF.getLiteBitmaskRow(16, 0, List(), -1, -1),
      DataUDF.getLiteBitmaskRow(16, 0, List(), -1, -1)),
    // bridging interval merges two disjoint runs: [0,1]+[3,4] | [1,2] -> [0,4]
    (DataUDF.getLiteBitmaskRow(8, 4, List(List(0, 1), List(3, 4)), 0, 4),
      DataUDF.getLiteBitmaskRow(8, 2, List(List(1, 2)), 1, 2),
      DataUDF.getLiteBitmaskRow(8, 5, List(List(0, 4)), 0, 4)),
    // interleaved runs merge into one span
    // NOTE(review): interval [7, 8] (and expected [1, 8]) uses bit index 8 on a
    // length-8 bitmask, which looks out of range — TODO confirm; the commit
    // message says these tests are currently failing.
    (DataUDF.getLiteBitmaskRow(8, 4, List(List(1, 2), List(5, 6)), 1, 6),
      DataUDF.getLiteBitmaskRow(8, 4, List(List(3, 4), List(7, 8)), 3, 4),
      DataUDF.getLiteBitmaskRow(8, 8, List(List(1, 8)), 1, 8))
  )
  forAll(bitmaskTable) { (bitmask0, bitmask1, bitmask_or: Row) => {
    println(s"bitmask0: ${bitmask0}")
    println(s"bitmask1: ${bitmask1}")
    println(s"bitmask_or: ${bitmask_or}")
    // Inputs and expected output must themselves be well-formed.
    validateLiteBitmaskSlotsLike(bitmask0)
    validateLiteBitmaskSlotsLike(bitmask1)
    validateLiteBitmaskSlotsLike(bitmask_or)
    // logical_or returns a LiteBitmaskStruct; round-trip it back to a Row
    // so it can be compared against the expected Row.
    val bitmask_r = logical_or(bitmask0, bitmask1)
    val bitmask_rr = liteBitmaskStructToRow(bitmask_r)
    println(s"bitmask_rr: ${bitmask_rr}")
    validateLiteBitmaskSlotsLike(bitmask_rr)
    bitmask_rr should equal (bitmask_or)
  }
  }
}
}
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment