Skip to content
Snippets Groups Projects
data_spark.py 1.23 KiB
Newer Older
# Copyright (C) 2024, UChicago Argonne, LLC
# Licensed under the 3-clause BSD license.  See accompanying LICENSE.txt file
# in the top-level directory.

from pyspark.sql import functions as F
from pyspark.sql.types import StructType, StructField, LongType, ArrayType, BooleanType, Row

from Octeres.bitmask.bitmask_lite import LiteBitmask

# Spark SQL struct schema used as the return type of the bitmask UDFs below.
# The four scalar fields are non-nullable 64-bit integers.  "intervals" is a
# nullable array of nullable arrays of 64-bit integers.
# NOTE(review): presumably this mirrors the row produced by
# LiteBitmask.to_spark(); confirm field order against that method.
_BITMASK_SCALAR_FIELDS = ("length", "bit_count", "bit_idx_first", "bit_idx_last")

bitmaskType = StructType(
    [StructField(field_name, LongType(), False) for field_name in _BITMASK_SCALAR_FIELDS]
    + [
        StructField(
            "intervals",
            ArrayType(ArrayType(LongType(), True), True),
            True,
        ),
    ]
)


@F.udf(returnType=bitmaskType)
def logical_or(a_bitmask: Row, b_bitmask: Row):
    """Spark UDF combining two bitmask struct values with a logical OR.

    Null handling: a null operand contributes nothing, so the other operand is
    returned unchanged, and null is returned when both are null.  Only when
    both operands are present are they converted to ``LiteBitmask``, combined
    with ``LiteBitmask.logical_or``, and converted back with ``to_spark()``.

    Args:
        a_bitmask: first bitmask struct row, or None.
        b_bitmask: second bitmask struct row, or None.

    Returns:
        The combined bitmask in Spark form (schema ``bitmaskType``), or None.
    """
    if a_bitmask is None:
        # Handles both "only b present" and "both null" (b is then None).
        return b_bitmask
    if b_bitmask is None:
        return a_bitmask
    left = LiteBitmask.from_spark(a_bitmask)
    right = LiteBitmask.from_spark(b_bitmask)
    return LiteBitmask.logical_or(left, right).to_spark()