Source code for jaxga.ops.simple_exp

import jax
import jax.numpy as jnp
from .multiply import get_mv_multiply
from .add import get_mv_add
from .select import get_mv_select
from ..jaxga import is_scalar_index
from functools import cache


[docs]@cache def get_mv_simple_exp(a_blade_indices, signature): mv_multiply, a_sq_indices = get_mv_multiply( a_blade_indices, a_blade_indices, signature ) mv_select_scalar, scalar_indices = get_mv_select( a_sq_indices, is_scalar_index ) mv_add, out_indices = get_mv_add(scalar_indices, a_blade_indices) def _values_mv_simple_exp(a_values): a_sq_values = mv_select_scalar(mv_multiply(a_values, a_values)) a_sq_sqrt = jnp.sign(a_sq_values) * jnp.sqrt(jnp.abs(a_sq_values)) out_scalar = jnp.where( a_sq_sqrt < 0, jnp.cos(a_sq_sqrt), jnp.where( a_sq_sqrt > 0, jnp.cosh(a_sq_sqrt), jnp.ones_like(a_sq_sqrt) ) ) out_blade = jnp.where( a_sq_sqrt < 0, a_values / a_sq_sqrt * jnp.sin(a_sq_sqrt), jnp.where( a_sq_sqrt > 0, a_values / a_sq_sqrt * jnp.sinh(a_sq_sqrt), a_values ) ) return mv_add(out_scalar, out_blade) _values_mv_simple_exp_jit = jax.jit(_values_mv_simple_exp) return _values_mv_simple_exp_jit, out_indices