Newer
Older
# Copyright (C) 2024, UChicago Argonne, LLC
# Licensed under the 3-clause BSD license. See accompanying LICENSE.txt file
# in the top-level directory.

Eric Pershey
committed
from types import SimpleNamespace
from typing import List

Eric Pershey
committed
from pyspark.sql import functions as F
from pyspark.sql.types import StructType, StructField, LongType, ArrayType, BooleanType, Row

Eric Pershey
committed
from Octeres.bitmask.bitmask_lite import logical_or, LiteBitmaskSlots, logical_subtract, intersection_any, intersection_all, \
join_bitmasks, LiteBitmaskSlotsLike, LiteBitmask

Eric Pershey
committed
# Spark SQL schema used as the return type of the bitmask-producing UDFs
# registered in register_all(). Field names presumably mirror the attributes
# of the bitmask_lite bitmask objects -- NOTE(review): confirm against
# Octeres.bitmask.bitmask_lite.
bitmaskType = (
    StructType()
    # Scalar summary fields; all are non-nullable 64-bit integers.
    .add("length", LongType(), nullable=False)
    .add("bit_count", LongType(), nullable=False)
    .add("bit_idx_first", LongType(), nullable=False)
    .add("bit_idx_last", LongType(), nullable=False)
    # Nullable array of nullable long arrays; presumably [start, end] pairs
    # of set-bit runs -- assumption, verify against bitmask_lite.
    .add(
        "intervals",
        ArrayType(ArrayType(LongType(), containsNull=True), containsNull=True),
        nullable=True,
    )
)

Eric Pershey
committed
def spark_logical_or(a_bitmask: Row, b_bitmask: Row) -> LiteBitmaskSlots:
    """Null-tolerant union (logical OR) of two bitmasks for use as a Spark UDF.

    A missing (None) operand is treated as "no bitmask":
    - if one side is None, the other side is returned unchanged (the same
      object that was passed in);
    - if both are None, None is returned;
    - otherwise ``logical_or`` from bitmask_lite computes the result.

    :param a_bitmask: left bitmask, or None.
    :param b_bitmask: right bitmask, or None.
    :return: the combined bitmask, the single non-None input, or None.
    """
    if a_bitmask is None:
        # Handles both "only b present" and "both missing" (b is None then).
        return b_bitmask
    if b_bitmask is None:
        return a_bitmask
    return logical_or(a_bitmask, b_bitmask)

Eric Pershey
committed
def spark_logical_subtract(a_bitmask: Row, b_bitmask: Row) -> LiteBitmaskSlots:
    """Null-tolerant difference (a minus b) of two bitmasks for use as a Spark UDF.

    - If ``a_bitmask`` is None the result is None, whatever ``b_bitmask`` is.
    - If only ``b_bitmask`` is None, ``a_bitmask`` is returned unchanged.
    - Otherwise ``logical_subtract`` from bitmask_lite computes the result.

    :param a_bitmask: bitmask to subtract from, or None.
    :param b_bitmask: bitmask to remove, or None.
    :return: the difference, ``a_bitmask`` itself, or None.
    """
    if a_bitmask is None:
        # Nothing to subtract from: None - b (including None - None) is None.
        return None
    if b_bitmask is None:
        # Removing nothing leaves a as it was.
        return a_bitmask
    return logical_subtract(a_bitmask, b_bitmask)

Eric Pershey
committed
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
def spark_intersection_any(a_bitmask: Row, b_bitmask: Row) -> bool:
    """Null-tolerant "do these bitmasks overlap at all?" test for a Spark UDF.

    If either operand is None the answer is False; otherwise the result of
    ``intersection_any`` from bitmask_lite is returned as-is.

    :param a_bitmask: first bitmask, or None.
    :param b_bitmask: second bitmask, or None.
    :return: False when an operand is missing, else intersection_any(a, b).
    """
    operand_missing = a_bitmask is None or b_bitmask is None
    return False if operand_missing else intersection_any(a_bitmask, b_bitmask)
def spark_intersection_all(a_bitmask: Row, b_bitmask: Row) -> bool:
    """Null-tolerant wrapper around ``intersection_all`` for a Spark UDF.

    If either operand is None the answer is False; otherwise the result of
    ``intersection_all`` from bitmask_lite is returned as-is. The exact
    containment direction is defined by bitmask_lite.intersection_all.

    :param a_bitmask: first bitmask, or None.
    :param b_bitmask: second bitmask, or None.
    :return: False when an operand is missing, else intersection_all(a, b).
    """
    if a_bitmask is None or b_bitmask is None:
        return False
    return intersection_all(a_bitmask, b_bitmask)
def spark_hostname_fold(hostnames):
    """Fold a list of hostnames into a compact range expression, as ``nodeset -f`` does.

    Not implemented yet.

    :param hostnames: list of hostnames to fold.
    :raises NotImplementedError: always, until folding is implemented.
    """
    # A bare ``pass`` here used to return None silently, which a caller could
    # mistake for a real result; fail loudly instead.
    raise NotImplementedError("spark_hostname_fold is not implemented yet")
def spark_hostnames_to_bitmask(hostnames):
    """Convert a list of hostnames into a bitmask.

    Not implemented yet.

    :param hostnames: list of hostnames to convert.
    :raises NotImplementedError: always, until the conversion is implemented.
    """
    # A bare ``pass`` here used to return None silently, which is
    # indistinguishable from "no bitmask" elsewhere in this module.
    raise NotImplementedError("spark_hostnames_to_bitmask is not implemented yet")
def spark_join_bitmasks(bitmask_lst: list, operation="or") -> LiteBitmaskSlots:
    """Combine a list of bitmasks into one using ``join_bitmasks``.

    :param bitmask_lst: the bitmasks to combine; must be a list (subclasses
        of list are accepted).
    :param operation: how to combine them; only ``"or"`` is supported.
    :return: whatever ``join_bitmasks(bitmask_lst, operation=operation)`` returns.
    :raises NotImplementedError: if ``operation`` is not ``"or"``, or if
        ``bitmask_lst`` is not a list.
    """
    # The operation is validated before the input type so that a call that is
    # wrong in both respects reports the same error as before.
    if operation != 'or':
        raise NotImplementedError(f"operation {operation} not supported")
    # isinstance (not ``type(...) == list``) so list subclasses are accepted.
    if not isinstance(bitmask_lst, list):
        raise NotImplementedError(f"type {type(bitmask_lst)} not supported")
    return join_bitmasks(bitmask_lst, operation=operation)
# Reference: PySpark Python UDTFs (user-defined table functions)
# https://spark.apache.org/docs/latest/api/python/user_guide/sql/python_udtf.html
def register_all(spark):
    """Wrap the bitmask helpers as Spark UDFs and register them by name.

    The helpers are registered with ``spark.udf.register`` under these names:
    ``logical_or``, ``logical_subtract`` (both returning ``bitmaskType``),
    ``intersection_any`` and ``intersection_all`` (both returning a boolean).

    :param spark: session object exposing ``udf.register`` (presumably a
        SparkSession).
    :return: a SimpleNamespace whose attributes are the same UDF objects,
        keyed by the registered names, for DataFrame-API use.
    """
    # (registered name, Python implementation, Spark return type), kept in
    # the order the UDFs are created and registered.
    udf_specs = (
        ("logical_or", spark_logical_or, bitmaskType),
        ("logical_subtract", spark_logical_subtract, bitmaskType),
        ("intersection_any", spark_intersection_any, BooleanType()),
        ("intersection_all", spark_intersection_all, BooleanType()),
    )
    # NOTE: spark_join_bitmasks is deliberately not registered; its UDF
    # wrapper (returning bitmaskType) was left disabled.
    wrapped_by_name = {
        sql_name: F.udf(py_func, return_type)
        for sql_name, py_func, return_type in udf_specs
    }
    for sql_name, wrapped in wrapped_by_name.items():
        spark.udf.register(sql_name, wrapped)
    return SimpleNamespace(**wrapped_by_name)