# Source code for jaxga.ops.sandwich
import itertools
import jax
import jax.numpy as jnp
from ..jaxga import reduce_bases, reverse_indices
from functools import cache
@cache
def get_mv_sandwich(a_blade_indices, b_blade_indices, signature, prod="gp"):
    """Build a jitted function computing the sandwich product ``a b ~a``.

    The blade bookkeeping (signs and output slots) is done once in Python and
    memoised with ``functools.cache``. Every argument must therefore be
    hashable; presumably the blade lists are tuples of basis indices.

    Args:
        a_blade_indices: Blade indices of the multivector ``a``. Position ``i``
            corresponds to ``a_values[i]``.
        b_blade_indices: Blade indices of the multivector ``b``. Position ``j``
            corresponds to ``b_values[j]``.
        signature: Metric signature, forwarded unchanged to ``reduce_bases``.
        prod: Which terms to keep.
            ``"gp"`` keeps all non-vanishing terms.
            ``"op"`` keeps terms whose output grade equals
            ``grade(a_blade) + grade(b_blade)``.
            ``"ip"`` keeps terms whose output grade equals
            ``|grade(a_blade) - grade(b_blade)|``.

    Returns:
        A pair ``(fn, out_blade_indices)``.

        ``fn(a_values, b_values)`` returns an array of shape
        ``(len(out_blade_indices), *broadcast(a_values.shape[1:],
        b_values.shape[1:]))``. Its dtype is the promotion of the input
        dtypes with ``float32``.

        If no term survives, ``fn`` returns a scalar zero and
        ``out_blade_indices`` is ``()``.

    Raises:
        ValueError: If ``prod`` is not one of ``"gp"``, ``"op"``, ``"ip"``.
    """
    if prod not in ("gp", "op", "ip"):
        raise ValueError(
            f"Unsupported product {prod!r}; expected 'gp', 'op' or 'ip'."
        )

    # One entry per surviving term: sign, source positions, and output slot.
    signs = []
    indices_a = []
    indices_b = []
    indices_a_r = []
    out_indices = []
    # Distinct output blades, in first-seen order; slot k <-> out_blade_indices[k].
    out_blade_indices = []
    blade_to_index = {}

    for (i_a, index_a), (i_b, index_b), (i_a_r, index_a_r) in itertools.product(
        enumerate(a_blade_indices),
        enumerate(b_blade_indices),
        enumerate(reverse_indices(a_blade_indices)),
    ):
        sign_1, index_1 = reduce_bases(index_a, index_b, signature)
        sign_2, out_index = reduce_bases(index_1, index_a_r, signature)
        out_sign = sign_1 * sign_2
        # A zero sign means the term vanishes.
        if out_sign == 0:
            continue
        # NOTE(review): the grade filters only look at the a and b factors,
        # not at ~a. Confirm this is the intended definition for a sandwich.
        if prod == "op" and len(out_index) != len(index_a) + len(index_b):
            continue
        if prod == "ip" and len(out_index) != abs(len(index_a) - len(index_b)):
            continue

        slot = blade_to_index.get(out_index)
        if slot is None:
            slot = len(out_blade_indices)
            blade_to_index[out_index] = slot
            out_blade_indices.append(out_index)

        signs.append(out_sign)
        indices_a.append(i_a)
        indices_b.append(i_b)
        indices_a_r.append(i_a_r)
        out_indices.append(slot)

    if not out_blade_indices:
        def _values_mv_sandwich(a_values, b_values):
            # No surviving terms: keep the historical scalar-zero result.
            return jnp.zeros(
                (), dtype=jnp.result_type(a_values, b_values, jnp.float32)
            )

        return jax.jit(_values_mv_sandwich), ()

    # Freeze the bookkeeping so the closure cannot be mutated later.
    signs = tuple(signs)
    indices_a = tuple(indices_a)
    indices_b = tuple(indices_b)
    indices_a_r = tuple(indices_a_r)
    out_indices = tuple(out_indices)
    out_size = len(out_blade_indices)

    def _values_mv_sandwich(a_values, b_values):
        # Promote with float32 so integer/low-precision inputs keep the old
        # float32 result, while float64/complex inputs are no longer truncated.
        out_dtype = jnp.result_type(a_values, b_values, jnp.float32)
        # Move the blade axis last: the batch axes then broadcast right-aligned,
        # exactly as the per-term product a_values[i] * b_values[j] does.
        a = jnp.moveaxis(jnp.asarray(a_values, dtype=out_dtype), 0, -1)
        b = jnp.moveaxis(jnp.asarray(b_values, dtype=out_dtype), 0, -1)
        # All terms at once. Shape is (*out_batch_shape, n_terms).
        terms = (
            jnp.asarray(signs, dtype=out_dtype)
            * a[..., jnp.asarray(indices_a)]
            * b[..., jnp.asarray(indices_b)]
            * a[..., jnp.asarray(indices_a_r)]
        )
        out_values = jnp.zeros((*terms.shape[:-1], out_size), dtype=out_dtype)
        # Scatter-add accumulates repeated output slots.
        out_values = out_values.at[..., jnp.asarray(out_indices)].add(terms)
        return jnp.moveaxis(out_values, -1, 0)

    return jax.jit(_values_mv_sandwich), tuple(out_blade_indices)