Source code for jaxga.ops.dual
import jax
import jax.numpy as jnp
from ..jaxga import dual_blade_index
from functools import cache
@cache
def get_mv_dual(a_blade_indices, dims):
    """Build a jitted function that applies the dual's signs to multivector values.

    For every entry of ``a_blade_indices`` this calls
    ``dual_blade_index(blade_index, dims)``, which yields a ``(sign, dual_index)``
    pair. The signs are gathered into a single float32 array and the dual
    indices into a tuple, both in the same order as ``a_blade_indices``.

    The result is memoised with ``functools.cache``, so both arguments must be
    hashable (e.g. pass ``a_blade_indices`` as a tuple, not a list).

    Args:
        a_blade_indices: Sequence of blade indices of the input multivector.
        dims: Passed through unchanged to ``dual_blade_index``.

    Returns:
        A pair ``(values_fn, out_blade_indices)`` where ``values_fn`` is a
        ``jax.jit``-compiled callable returning ``a_values * signs`` and
        ``out_blade_indices`` is the tuple of dual blade indices.
    """
    # Resolve each blade's dual once, preserving input order.
    duals = [dual_blade_index(blade_index, dims) for blade_index in a_blade_indices]

    # Build the sign vector in one shot. JAX arrays are immutable, so filling
    # an array element-by-element with ``.at[i].set`` creates a full copy (and
    # dispatches a separate op) on every iteration.
    signs = jnp.array([sign for sign, _ in duals], dtype=jnp.float32)
    out_blade_indices = tuple(dual_index for _, dual_index in duals)

    def _values_mv_dual(a_values):
        # Elementwise product; ``signs`` has length len(a_blade_indices).
        # NOTE(review): presumably the last axis of a_values indexes blades in
        # the same order as a_blade_indices - confirm against callers.
        return a_values * signs

    _values_mv_dual_jit = jax.jit(_values_mv_dual)
    return _values_mv_dual_jit, out_blade_indices