Source code for jaxga.ops.select

import jax
import jax.numpy as jnp
from functools import cache


[docs]@cache def get_mv_select(a_blade_indices, select_index): out_blade_indices = [] out_a_indices = [] for i, a_index in enumerate(a_blade_indices): if select_index(a_index): out_blade_indices.append(a_index) out_a_indices.append(i) out_a_indices = jnp.array(out_a_indices, dtype=jnp.int32) def _values_mv_select(a_values): return a_values[out_a_indices] _values_mv_select_jit = jax.jit(_values_mv_select) return _values_mv_select_jit, tuple(out_blade_indices)