Source code for jaxga.ops.inverse

import jax
from functools import cache
from .multiply import get_mv_multiply
from .select import get_mv_select
from .add import get_mv_add
from ..jaxga import is_scalar_index
import itertools


@cache
def get_mv_inverse(a_blade_indices, signature):
    """Build a jitted function that computes the inverse of a multivector.

    The inverse is computed with a Faddeev-LeVerrier-style recursion for
    multivectors in real Clifford algebras (cf. Hitzer & Sangwine).
    ``<X>_0`` denotes the scalar part, i.e. the blades selected with
    ``is_scalar_index``:

        U_1 = a
        C_k = (N / k) * <U_k>_0
        U_{k+1} = a * (U_k - C_k)
        a^{-1} = (U_{N-1} - C_{N-1}) / <U_N>_0

    Here ``N = 2 ** ((dims + 1) // 2)``, and ``dims`` is the number of
    distinct basis vectors appearing in ``a_blade_indices``. All blade-index
    bookkeeping is done once, at build time. The returned function only
    operates on value arrays.

    No invertibility check is performed. If ``<U_N>_0`` is zero, the division
    produces non-finite values.

    Args:
        a_blade_indices: Blade indices of the input multivector. Must be
            hashable, because results are memoized with ``functools.cache``.
        signature: Algebra signature, forwarded to ``get_mv_multiply``.

    Returns:
        A tuple ``(inverse_fn, inverse_blade_indices)``:
        - ``inverse_fn`` is a ``jax.jit``-compiled function mapping the
          input's values to the inverse's values.
        - ``inverse_blade_indices`` are the blade indices of that result.
    """
    dims = len(set(itertools.chain.from_iterable(a_blade_indices)))
    n = 2 ** ((dims + 1) // 2)

    if n == 1:
        # No basis vectors are present, so this is a pure scalar. The
        # recursion below would run zero steps and leave its outputs
        # undefined. The inverse of a scalar is simply its reciprocal.
        return jax.jit(lambda a_values: 1 / a_values), a_blade_indices

    # Precompute the blade-level select/add/multiply operations for each of
    # the n - 1 recursion steps. Each step's output blades feed the next step.
    last_ind = a_blade_indices
    selects = []
    adds = []
    mults = []
    for _ in range(1, n):
        select, select_ind = get_mv_select(last_ind, is_scalar_index)
        add, add_ind = get_mv_add(last_ind, select_ind)
        mult, mult_ind = get_mv_multiply(a_blade_indices, add_ind, signature)
        selects.append(select)
        adds.append(add)
        mults.append(mult)
        last_ind = mult_ind

    # The scalar part of U_N is the final denominator.
    select_last_scalar, _ = get_mv_select(last_ind, is_scalar_index)

    def _values_mv_inverse(a_values):
        u = a_values
        for k, (select, add, mult) in enumerate(zip(selects, adds, mults)):
            c = n / (k + 1) * select(u)
            u_minus_c = add(u, -c)
            u = mult(a_values, u_minus_c)
        # Because n >= 2, the loop ran at least once, so u_minus_c is bound.
        return u_minus_c / select_last_scalar(u)

    # add_ind from the final step gives the blades of u_minus_c. That is the
    # multivector the returned function outputs.
    return jax.jit(_values_mv_inverse), add_ind