Additional helper functions

get_union

get_union(ref, target)

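Finds the particles shared by a single reference group and a single target group, allowing \(G(r,t)\) to be split into \(G_s(r,t)\) and \(G_t(r,t)\). Despite the name, the function is built on np.intersect1d, so it reports the overlap (intersection) of the two groups rather than their set union.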

PARAMETER DESCRIPTION
ref

The array containing the indices of the reference particles. Entries must be unique, because np.intersect1d is called with assume_unique=True.

TYPE: np.ndarray, dtype=int

target

The array containing the indices of the target particles. Entries must be unique, because np.intersect1d is called with assume_unique=True.

TYPE: np.ndarray, dtype=int

RETURNS DESCRIPTION
union

List of two arrays: the positions within ref and the positions within target of the particle indices contained in both groups.

TYPE: list of np.ndarray, dtype=int

Source code in speadi/common_tools/get_union.py
import numpy as np


def get_union(ref, target):
    """Finds the particles shared by a single reference and a single target group, allowing $G(r,t)$
    to be split into $G_s(r,t)$ and $G_t(r,t)$.

    Parameters
    ----------
    ref : np.ndarray, dtype=int
        The array containing the indices of the reference particles.
    target : np.ndarray, dtype=int
        The array containing the indices of the target particles.

    Returns
    -------
    union : list of np.ndarray, dtype=int
        Two arrays: the positions within `ref` and within `target` of the indices contained in both groups.

    """
    # np.intersect1d returns the shared values and their positions in each input.
    # Only the positions are kept; assume_unique=True requires duplicate-free inputs.
    _, ref_indices, target_indices = np.intersect1d(ref, target, return_indices=True, assume_unique=True)
    union = [ref_indices, target_indices]

    return union
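
As a minimal sketch of what the np.intersect1d call in this listing produces, the example below uses two made-up index arrays (the values are hypothetical, chosen so that particle indices and array positions differ). It shows that the returned arrays hold positions into ref and target, not the particle indices themselves.

```python
import numpy as np

# Hypothetical particle indices; both arrays are duplicate-free,
# as assume_unique=True requires.
ref = np.array([10, 11, 12, 13])
target = np.array([12, 13, 14, 15])

shared, ref_pos, target_pos = np.intersect1d(
    ref, target, return_indices=True, assume_unique=True
)

print(shared)      # [12 13] -> particle indices present in both groups
print(ref_pos)     # [2 3]   -> where those particles sit in ref
print(target_pos)  # [0 1]   -> where those particles sit in target
```

Indexing back with ref[ref_pos] or target[target_pos] recovers the shared particle indices.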

get_all_unions

get_all_unions(g1, g2, g1_lens, g2_lens)

Finds the particles shared by every pairing of a reference group with a target group, allowing \(G(r,t)\) to be split into \(G_s(r,t)\) and \(G_t(r,t)\). Each pairing is evaluated with get_union.

PARAMETER DESCRIPTION
g1

List object containing arrays of reference groups.

TYPE: list

g2

List object containing arrays of target groups.

TYPE: list

g1_lens

Integer number of elements in each reference group contained in g1.

TYPE: np.ndarray, dtype=int

g2_lens

Integer number of elements in each target group contained in g2.

TYPE: np.ndarray, dtype=int

RETURNS DESCRIPTION
unions

Nested dictionary keyed by the group positions as strings: unions[str(i)][str(j)] holds the get_union result for g1[i] and g2[j].

TYPE: dict

Source code in speadi/common_tools/get_union.py
def get_all_unions(g1, g2, g1_lens, g2_lens):
    """Finds the particles shared by every pairing of a reference group with a target group,
    allowing $G(r,t)$ to be split into $G_s(r,t)$ and $G_t(r,t)$.

    Parameters
    ----------
    g1 : list
        List object containing arrays of reference groups.
    g2 : list
        List object containing arrays of target groups.
    g1_lens : np.ndarray, dtype=int
        Integer number of elements in each reference group contained in g1.
    g2_lens : np.ndarray, dtype=int
        Integer number of elements in each target group contained in g2.

    Returns
    -------
    unions : dict
        Nested dictionary in which unions[str(i)][str(j)] holds the get_union result for g1[i] and g2[j].

    """
    Ng1 = g1_lens.shape[0]
    Ng2 = g2_lens.shape[0]

    unions = {}
    for i in range(Ng1):
        unions[str(i)] = {}
        for j in range(Ng2):
            unions[str(i)][str(j)] = get_union(g1[i], g2[j])

    return unions
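
The nested layout of the returned dictionary can be sketched as follows. The helper overlap_positions is a hypothetical stand-in that mirrors the get_union listing so the example runs on its own, and the group arrays are made up for illustration.

```python
import numpy as np


def overlap_positions(ref, target):
    # Hypothetical stand-in mirroring get_union: positions of shared indices.
    _, ref_pos, target_pos = np.intersect1d(
        ref, target, return_indices=True, assume_unique=True
    )
    return [ref_pos, target_pos]


# Two made-up reference groups and two made-up target groups.
g1 = [np.array([0, 1, 2]), np.array([3, 4, 5])]
g2 = [np.array([2, 3]), np.array([6, 7])]

# Same structure as the loop in get_all_unions: string keys, two levels.
unions = {
    str(i): {str(j): overlap_positions(r, t) for j, t in enumerate(g2)}
    for i, r in enumerate(g1)
}

ref_pos, target_pos = unions["0"]["0"]
print(ref_pos, target_pos)  # [2] [0] -> particle 2 is shared by g1[0] and g2[0]
print(unions["0"]["1"][0])  # []      -> g1[0] and g2[1] share nothing
```

Note that the keys are strings, so a lookup such as unions[0][0] would raise a KeyError.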

check_jax

check_jax()

Checks whether both jax and jaxlib can be imported in the current Python environment.

RETURNS DESCRIPTION
JAX_AVAILABLE

True if both jax and jaxlib import successfully, otherwise False. Other functions can use this flag to import the matching accelerated versions of code.

TYPE: bool

Source code in speadi/common_tools/check_acceleration.py
def check_jax():
    """Checks whether both `jax` and `jaxlib` can be imported in the current Python environment.

    Returns
    -------
    JAX_AVAILABLE : bool
        Boolean variable that other functions can use to import the correct accelerated versions of code.

    """
    JAX_AVAILABLE = False
    try:
        from jax import __version__ as __jax_version
        from jaxlib import __version__ as __jaxlib_version
        JAX_AVAILABLE = True
    except ImportError:
        JAX_AVAILABLE = False

    return JAX_AVAILABLE
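
A caller might use such a flag to choose between code paths. The sketch below is not the library's method: it swaps the try/except import for importlib.util.find_spec from the standard library, which only locates a package without importing it, so it cannot confirm that the import would actually succeed. The helper module_available and the backend labels are hypothetical names introduced for illustration.

```python
import importlib.util


def module_available(name):
    """Return True if a top-level module called `name` can be located."""
    return importlib.util.find_spec(name) is not None


# Both packages must be present, as in check_jax.
JAX_AVAILABLE = module_available("jax") and module_available("jaxlib")

# Hypothetical labels for the code path a caller would select.
backend = "jax" if JAX_AVAILABLE else "numpy"
print(f"Selected backend: {backend}")
```

The try/except form in the listing is the stricter test, since a package that is installed but broken still fails to import.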

check_numba

check_numba()

Checks whether numba can be imported in the current Python environment.

RETURNS DESCRIPTION
NUMBA_AVAILABLE

True if numba imports successfully, otherwise False. Other functions can use this flag to import the matching accelerated versions of code.

TYPE: bool

Source code in speadi/common_tools/check_acceleration.py
def check_numba():
    """Checks whether `numba` can be imported in the current Python environment.

    Returns
    -------
    NUMBA_AVAILABLE : bool
        Boolean variable that other functions can use to import the correct accelerated versions of code.

    """
    NUMBA_AVAILABLE = False
    try:
        from numba import __version__ as __numba_version__
        NUMBA_AVAILABLE = True
    except ImportError:
        NUMBA_AVAILABLE = False

    return NUMBA_AVAILABLE
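
One common way to act on this kind of check is to fall back to a do-nothing decorator when numba is missing, so the same function definition works in both cases. The sketch below illustrates that pattern under the assumption that the accelerated code only needs numba's njit decorator; the function sum_of_squares is a hypothetical example, not part of the library.

```python
try:
    from numba import njit
    NUMBA_AVAILABLE = True
except ImportError:
    NUMBA_AVAILABLE = False

    def njit(func):
        # Fallback decorator: return the function unchanged (plain Python).
        return func


@njit
def sum_of_squares(n):
    # Simple loop that numba can compile when it is installed.
    total = 0
    for i in range(n):
        total += i * i
    return total


print(sum_of_squares(4))  # 14
```

The result is the same either way; only the execution speed differs when numba compiles the loop.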