# Source code for jaxga.ops.reduce_same
import jax
import jax.numpy as jnp
from jaxga.signatures import positive_signature
from ..jaxga import reduce_bases
from functools import cache
@cache
def get_mv_reduce_same(a_blade_indices):
    """Build a jitted function that merges components whose blades share basis indices.

    Entries of ``a_blade_indices`` are grouped by the *set* of basis indices
    they contain, so e.g. ``(0, 1)`` and ``(1, 0)`` fall into the same group.
    The first entry of each group becomes its representative blade. Every
    later entry is scaled by the sign returned by
    ``reduce_bases(entry, representative, positive_signature)`` and summed
    into the representative's output row.

    Args:
        a_blade_indices: Sequence of basis-index sequences, one per row along
            the leading axis of the values later passed to the returned
            function. Must be hashable (e.g. a tuple of tuples) because the
            result is memoized with ``functools.cache``.

    Returns:
        ``(fn, out_blade_indices)`` where ``out_blade_indices`` is a tuple of
        the representative blades in order of first appearance and
        ``fn(a_values)`` is a ``jax.jit``-compiled function mapping an array
        whose leading axis indexes ``a_blade_indices`` to an array whose
        leading axis indexes ``out_blade_indices``. Trailing axes are passed
        through unchanged. The output dtype is at least float32, widened to
        the input dtype when that is wider (e.g. float64 or complex).
    """
    # frozenset(blade) -> output row of that blade's group.
    blade_to_index = {}
    # Representative (first-seen) blade of each group, in output-row order.
    indices = []
    # Per output row: the input rows to sum and the sign applied to each.
    # Rows are created only when a new group appears, so the number of output
    # rows always equals the number of distinct index sets (counting distinct
    # tuples instead would over-count permuted blades like (0, 1) / (1, 0)).
    out_indices = []
    out_signs = []
    for i, blade_index in enumerate(a_blade_indices):
        # NOTE(review): assumes each blade lists distinct basis indices;
        # frozenset would merge e.g. (0, 0) with (0,).
        blade_index_set = frozenset(blade_index)
        index = blade_to_index.get(blade_index_set)
        if index is None:
            # First time this index set is seen: open a new output row.
            index = len(indices)
            sign = 1
            blade_to_index[blade_index_set] = index
            indices.append(blade_index)
            out_indices.append([])
            out_signs.append([])
        else:
            # NOTE(review): confirm that reduce_bases(blade, representative,
            # positive_signature) yields the sign relating these two orderings
            # of the same blade.
            sign, _ = reduce_bases(blade_index, indices[index], positive_signature)
        out_indices[index].append(i)
        out_signs[index].append(sign)

    def _values_mv_reduce_same(a_values):
        out_batch_shape = a_values.shape[1:]
        # At least float32 (the previous fixed output dtype, kept for integer
        # and half-precision inputs) but no narrower than the input, so
        # float64 / complex coefficients are not silently truncated.
        out_dtype = jnp.result_type(a_values.dtype, jnp.float32)
        result = jnp.empty([len(out_indices), *out_batch_shape], dtype=out_dtype)
        for row, (members, signs) in enumerate(zip(out_indices, out_signs)):
            # Every group has at least one member, so the first write always
            # initializes the row before any accumulation into it.
            for j, (m, sign) in enumerate(zip(members, signs)):
                if j == 0:
                    result = result.at[row].set(sign * a_values[m])
                else:
                    result = result.at[row].add(sign * a_values[m])
        return result

    _values_mv_reduce_same_jit = jax.jit(_values_mv_reduce_same)
    return _values_mv_reduce_same_jit, tuple(indices)