# Copyright (C) 2024, UChicago Argonne, LLC
# Licensed under the 3-clause BSD license.  See accompanying LICENSE.txt file
# in the top-level directory.

from collections import OrderedDict
from typing import TypeVar, Dict
from collections.abc import Callable

# Type variables for generic dictionary keys (K) and values (V).
K, V = TypeVar("K"), TypeVar("V")


def merge_dictionaries(dct_a: Dict[K, V], dct_b: Dict[K, V], default: object, operation: Callable) -> Dict[K, V]:
    """Merge two dictionaries into a new ``OrderedDict``.

    The result starts as a copy of ``dct_a``, keeping its insertion order. Then, for each
    ``key, value`` in ``dct_b`` (in iteration order), the stored value becomes
    ``operation(existing, value)``. ``existing`` is the value already present for ``key``,
    or ``default`` when the key does not appear in ``dct_a``.

    Note the asymmetry: keys found only in ``dct_a`` are copied unchanged. Keys found only
    in ``dct_b`` are still passed through ``operation``, with ``default`` as the left operand.

    :param dct_a: first dictionary; its values are copied as-is.
    :param dct_b: second dictionary; each of its values is folded in via ``operation``.
    :param default: left operand used for ``dct_b`` keys that are absent from ``dct_a``.
    :param operation: binary callable, invoked as ``operation(left, right)``.
    :return: a new ``OrderedDict``. Neither input dictionary is modified, but the copy is
        shallow: values taken from ``dct_a`` are shared, not duplicated.
    """
    merged = OrderedDict(dct_a)
    for key, value in dct_b.items():
        merged[key] = operation(merged.get(key, default), value)
    return merged


class partial_arg_kw:
    """Callable wrapper that binds leading positional and keyword arguments to a function.

    Calling an instance invokes
    ``func(*stored_args, *call_args, **call_kwargs, **stored_kwargs)``. The stored
    positional arguments come first, followed by those given at call time.

    Unlike :func:`functools.partial`, a keyword supplied both at construction time and at
    call time is not overridden. Both mappings are unpacked into the same call, so Python
    raises :class:`TypeError` for the duplicate name.
    """

    def __init__(self, func, *args, **kwargs):
        self.func = func  # the wrapped callable
        self.args = args  # positional arguments placed before the call-time ones
        self.kwargs = kwargs  # keyword arguments added to every call

    def __call__(self, *args, **kwargs):
        # Both keyword mappings go into one call; a shared name raises TypeError
        # rather than one value silently winning.
        return self.func(*self.args, *args, **kwargs, **self.kwargs)

    def __repr__(self):
        # Mirrors the constructor call, e.g. "partial_arg_kw(<func>, 1, key=2)".
        parts = [repr(self.func)]
        parts.extend(repr(arg) for arg in self.args)
        parts.extend(f"{name}={value!r}" for name, value in self.kwargs.items())
        return f"{type(self).__name__}({', '.join(parts)})"