# Copyright (C) 2024, UChicago Argonne, LLC
# Licensed under the 3-clause BSD license.  See accompanying LICENSE.txt file
# in the top-level directory.
"""Null-tolerant wrappers that expose bitmask_lite operations as Spark UDFs."""

from types import SimpleNamespace
from typing import List, Optional

from pyspark.sql import functions as F
from pyspark.sql.types import StructType, StructField, LongType, ArrayType, BooleanType, Row

from Octeres.bitmask.bitmask_lite import (
    logical_or,
    LiteBitmaskSlots,
    logical_subtract,
    intersection_any,
    intersection_all,
    join_bitmasks,
    LiteBitmaskSlotsLike,
    LiteBitmask,
)

# Spark struct schema used as the return type of every bitmask-valued UDF below.
bitmaskType = StructType(
    [
        StructField("length", LongType(), False),
        StructField("bit_count", LongType(), False),
        StructField("bit_idx_first", LongType(), False),
        StructField("bit_idx_last", LongType(), False),
        StructField("intervals", ArrayType(ArrayType(LongType(), True), True), True),
    ]
)


def spark_logical_or(a_bitmask: Optional[Row], b_bitmask: Optional[Row]) -> Optional[LiteBitmaskSlots]:
    """Return the logical OR of two bitmasks, treating a null operand as absent.

    Args:
        a_bitmask: left operand, or None for a null column value.
        b_bitmask: right operand, or None for a null column value.

    Returns:
        None when both operands are None, the non-null operand when exactly
        one is None, otherwise the result of ``logical_or(a_bitmask, b_bitmask)``.
    """
    if a_bitmask is None:
        # Covers both "None | None -> None" and "None | b -> b".
        return b_bitmask
    if b_bitmask is None:
        return a_bitmask
    return logical_or(a_bitmask, b_bitmask)


def spark_logical_subtract(a_bitmask: Optional[Row], b_bitmask: Optional[Row]) -> Optional[LiteBitmaskSlots]:
    """Return ``a_bitmask`` minus ``b_bitmask``, treating a null operand as absent.

    Args:
        a_bitmask: bitmask to subtract from, or None for a null column value.
        b_bitmask: bitmask to subtract, or None for a null column value.

    Returns:
        None when ``a_bitmask`` is None, ``a_bitmask`` unchanged when only
        ``b_bitmask`` is None, otherwise the result of
        ``logical_subtract(a_bitmask, b_bitmask)``.
    """
    if a_bitmask is None:
        return None  # None - b is still None (and None - None is None)
    if b_bitmask is None:
        return a_bitmask  # a - None is still a
    return logical_subtract(a_bitmask, b_bitmask)


def spark_intersection_any(a_bitmask: Optional[Row], b_bitmask: Optional[Row]) -> bool:
    """Return ``intersection_any(a_bitmask, b_bitmask)``, or False if either operand is None."""
    if a_bitmask is None or b_bitmask is None:
        return False
    return intersection_any(a_bitmask, b_bitmask)


def spark_intersection_all(a_bitmask: Optional[Row], b_bitmask: Optional[Row]) -> bool:
    """Return ``intersection_all(a_bitmask, b_bitmask)``, or False if either operand is None."""
    if a_bitmask is None or b_bitmask is None:
        return False
    return intersection_all(a_bitmask, b_bitmask)


def spark_hostname_fold(hostnames):
    """Given a list of hostnames, fold them as done in ``nodeset -f``.

    Not implemented yet: the body is a placeholder and currently returns None.
    """
    pass


def spark_hostnames_to_bitmask(hostnames):
    """Given a list of hostnames, return the bitmask.

    Not implemented yet: the body is a placeholder and currently returns None.
    """
    pass


def spark_join_bitmasks(bitmask_lst: list, operation="or") -> Optional[LiteBitmaskSlots]:
    """Join a list of bitmasks into one using ``join_bitmasks``.

    Args:
        bitmask_lst: list of bitmasks to join; None (a null column value) and
            an empty list are both treated as "nothing to join".
        operation: join operation; only ``"or"`` is supported.

    Returns:
        The result of ``join_bitmasks(bitmask_lst, operation=operation)``, or
        None when ``bitmask_lst`` is None or empty.

    Raises:
        NotImplementedError: if ``operation`` is not ``"or"`` or
            ``bitmask_lst`` is neither None nor a list.
    """
    if operation != 'or':
        raise NotImplementedError(f"operation {operation} not supported")
    if bitmask_lst is None:
        # A null array column yields a null bitmask, matching how the other
        # UDFs in this module treat null operands.
        return None
    if not isinstance(bitmask_lst, list):
        raise NotImplementedError(f"type {type(bitmask_lst)} not supported")
    if not bitmask_lst:
        # join_bitmasks cannot handle an empty list, so return a null bitmask.
        # TODO: consider adding a length argument to join_bitmasks so that a
        # zero bitmask can be returned here instead.
        return None
    return join_bitmasks(bitmask_lst, operation=operation)


# udtf
# https://spark.apache.org/docs/latest/api/python/user_guide/sql/python_udtf.html


def register_all(spark):
    """Wrap the functions above as UDFs and register them on ``spark``.

    Args:
        spark: object exposing ``udf.register(name, udf)``; presumably a
            SparkSession - confirm against callers.

    Returns:
        A SimpleNamespace whose attributes (``logical_or``,
        ``logical_subtract``, ``intersection_any``, ``intersection_all``,
        ``join_bitmasks``) are the UDF objects created with ``F.udf``, usable
        from the DataFrame API; the same names are registered for SQL use.
    """
    udf_lookup = {
        "logical_or": F.udf(spark_logical_or, bitmaskType),
        "logical_subtract": F.udf(spark_logical_subtract, bitmaskType),
        "intersection_any": F.udf(spark_intersection_any, BooleanType()),
        "intersection_all": F.udf(spark_intersection_all, BooleanType()),
        "join_bitmasks": F.udf(spark_join_bitmasks, bitmaskType),
    }
    for name, udf_func in udf_lookup.items():
        spark.udf.register(name, udf_func)
    return SimpleNamespace(**udf_lookup)