Source code for jaxga.ops.add

import jax
import jax.numpy as jnp
from .reduce_same import get_mv_reduce_same
from functools import cache


[docs]@cache def get_mv_add(a_blade_indices, b_blade_indices): out_blade_indices = a_blade_indices + b_blade_indices mv_reduce_same, out_blade_indices = get_mv_reduce_same(out_blade_indices) def _values_mv_reduce_same(a_values, b_values): return mv_reduce_same(jnp.concatenate([a_values, b_values], axis=0)) _values_mv_add_jit = jax.jit(_values_mv_reduce_same) return _values_mv_add_jit, out_blade_indices