# data_spark.py
# Copyright (C) 2024, UChicago Argonne, LLC
# Licensed under the 3-clause BSD license.  See accompanying LICENSE.txt file
# in the top-level directory.
from types import SimpleNamespace
from typing import List

from pyspark.sql import functions as F
from pyspark.sql.types import StructType, StructField, LongType, ArrayType, BooleanType, Row

from Octeres.bitmask.bitmask_lite import logical_or, LiteBitmaskSlots, logical_subtract, intersection_any, intersection_all, \
    join_bitmasks, LiteBitmaskSlotsLike, LiteBitmask

# Spark SQL schema used as the return type of the bitmask-producing UDFs.
# The four scalar fields are non-nullable 64-bit integers; ``intervals`` is a
# nullable list of nullable integer lists (presumably [first, last] index
# pairs -- TODO confirm against the Octeres LiteBitmask layout).
bitmaskType = StructType(
    fields=[
        StructField(name="length", dataType=LongType(), nullable=False),
        StructField(name="bit_count", dataType=LongType(), nullable=False),
        StructField(name="bit_idx_first", dataType=LongType(), nullable=False),
        StructField(name="bit_idx_last", dataType=LongType(), nullable=False),
        StructField(
            name="intervals",
            dataType=ArrayType(
                elementType=ArrayType(elementType=LongType(), containsNull=True),
                containsNull=True,
            ),
            nullable=True,
        ),
    ]
)


def spark_logical_or(a_bitmask: Row, b_bitmask: Row) -> LiteBitmaskSlots:
    """Null-tolerant union (bitwise OR) of two bitmasks, for use as a Spark UDF.

    A None operand acts as the identity: OR-ing it with the other operand
    returns that operand unchanged, and OR-ing two Nones yields None.

    :param a_bitmask: first bitmask (a Spark ``Row`` shaped like
        ``bitmaskType``), or None.
    :param b_bitmask: second bitmask, or None.
    :return: ``logical_or(a_bitmask, b_bitmask)`` when both are present;
        otherwise the non-None operand as-is, or None if both are None.
    """
    if a_bitmask is None:
        # None | b -> b, which also covers None | None -> None.
        return b_bitmask
    if b_bitmask is None:
        return a_bitmask
    return logical_or(a_bitmask, b_bitmask)


def spark_logical_subtract(a_bitmask: Row, b_bitmask: Row) -> LiteBitmaskSlots:
    """Null-tolerant difference ``a - b`` of two bitmasks, for use as a Spark UDF.

    :param a_bitmask: bitmask to subtract from (a Spark ``Row`` shaped like
        ``bitmaskType``), or None.
    :param b_bitmask: bitmask whose bits are removed from ``a_bitmask``, or None.
    :return: ``logical_subtract(a_bitmask, b_bitmask)`` when both are present;
        ``a_bitmask`` unchanged when ``b_bitmask`` is None; None when
        ``a_bitmask`` is None.
    """
    if a_bitmask is None:
        # Nothing to subtract from: None - b (and None - None) stays None.
        return None
    if b_bitmask is None:
        # Subtracting nothing leaves a unchanged.
        return a_bitmask
    return logical_subtract(a_bitmask, b_bitmask)


def spark_intersection_any(a_bitmask: Row, b_bitmask: Row) -> bool:
    """Report whether two bitmasks share any set bit, for use as a Spark UDF.

    If either operand is None the answer is False; otherwise the check is
    delegated to ``intersection_any``.
    """
    if a_bitmask is None or b_bitmask is None:
        # A missing bitmask cannot overlap anything.
        return False
    return intersection_any(a_bitmask, b_bitmask)


def spark_intersection_all(a_bitmask: Row, b_bitmask: Row) -> bool:
    """Null-tolerant wrapper around ``intersection_all``, for use as a Spark UDF.

    Returns False whenever at least one operand is None; otherwise returns
    whatever ``intersection_all(a_bitmask, b_bitmask)`` returns.
    """
    missing_operand = a_bitmask is None or b_bitmask is None
    return False if missing_operand else intersection_all(a_bitmask, b_bitmask)


def spark_hostname_fold(hostnames):
    """Fold a list of hostnames into a compact expression, as ``nodeset -f`` does.

    Not implemented yet: always returns None, whatever the input.
    """
    # TODO: implement hostname folding.
    return None


def spark_hostnames_to_bitmask(hostnames):
    """Convert a list of hostnames into the corresponding bitmask.

    Not implemented yet: always returns None, whatever the input.
    """
    # TODO: implement the hostname -> bitmask mapping.
    return None


def spark_join_bitmasks(bitmask_lst: list, operation="or") -> LiteBitmaskSlots:
    """Combine a list of bitmasks into one, for use as a Spark UDF.

    :param bitmask_lst: list of bitmasks (e.g. the value of a Spark array
        column), or None for a SQL NULL array.
    :param operation: how to combine them; only ``"or"`` is supported.
    :return: ``join_bitmasks(bitmask_lst, operation=operation)``, or None
        when ``bitmask_lst`` is None.
    :raises NotImplementedError: if ``operation`` is not ``"or"``, or if
        ``bitmask_lst`` is neither None nor a list.
    """
    if operation != 'or':
        raise NotImplementedError(f"operation {operation} not supported")
    if bitmask_lst is None:
        # A NULL array yields a NULL result instead of raising inside the
        # Spark executor and failing the whole task.
        return None
    # isinstance (not type() ==) so list subclasses are accepted too.
    if not isinstance(bitmask_lst, list):
        raise NotImplementedError(f"type {type(bitmask_lst)} not supported")
    return join_bitmasks(bitmask_lst, operation=operation)


# Python UDTFs (user-defined table functions), see:
# https://spark.apache.org/docs/latest/api/python/user_guide/sql/python_udtf.html


def register_all(spark):
    """Wrap the spark_* bitmask helpers as PySpark UDFs and register them.

    Each helper is wrapped with ``F.udf``, using ``bitmaskType`` as the return
    type for bitmask-producing helpers and ``BooleanType()`` for predicates.
    Each wrapped UDF is registered on ``spark.udf`` under a short name, so it
    can be called from Spark SQL (e.g. ``SELECT logical_or(a, b) ...``).

    :param spark: the session to register on (presumably a ``SparkSession``;
        anything exposing ``.udf.register`` works).
    :return: a ``SimpleNamespace`` whose attributes ``logical_or``,
        ``logical_subtract``, ``intersection_any``, ``intersection_all`` and
        ``join_bitmasks`` are the wrapped UDFs, for DataFrame-API use.
    """
    udf_lookup = {
        "logical_or": F.udf(spark_logical_or, bitmaskType),
        "logical_subtract": F.udf(spark_logical_subtract, bitmaskType),
        "intersection_any": F.udf(spark_intersection_any, BooleanType()),
        "intersection_all": F.udf(spark_intersection_all, BooleanType()),
        "join_bitmasks": F.udf(spark_join_bitmasks, bitmaskType),
    }
    for name, udf_func in udf_lookup.items():
        spark.udf.register(name, udf_func)
    return SimpleNamespace(**udf_lookup)