Source code for jaxga.ops.keepnonzero
import jax
import jax.numpy as jnp
def get_mv_keep_nonzero(a_blade_indices, a_values):
    """Build a jitted function that keeps only the non-zero blades of a multivector.

    Which blades count as zero is decided once, eagerly, from the ``a_values``
    passed here. The returned function then applies that same fixed selection
    to whatever values array it is called with.

    Args:
        a_blade_indices: Iterable of blade indices. Element ``i`` labels entry
            ``i`` along the leading axis of ``a_values``.
        a_values: Reference values. Entry ``i`` along the leading axis belongs
            to blade ``a_blade_indices[i]``. Any trailing axes are treated as
            part of that entry, so an entry is dropped only if all of its
            elements are close to zero.

    Returns:
        A tuple ``(keep_fn, out_blade_indices)``:

        * ``keep_fn`` is a ``jax.jit``-compiled function. It gathers, along
          the leading axis, the entries whose reference values were not all
          close to zero.
        * ``out_blade_indices`` is a tuple of the blade indices of those
          kept entries, in their original order.

        "Close to zero" uses ``jnp.allclose``'s default tolerances.

    Raises:
        IndexError: If ``a_values`` has fewer leading entries than there are
            blade indices.
    """
    import math

    blade_indices = tuple(a_blade_indices)
    values = jnp.asarray(a_values)
    if values.shape[0] < len(blade_indices):
        raise IndexError(
            f"a_values has {values.shape[0]} entries along axis 0 but "
            f"{len(blade_indices)} blade indices were given"
        )

    # One vectorized closeness test instead of a separate jnp.allclose call
    # (and a separate device-to-host sync) per blade. allclose(x, 0) is
    # exactly all(isclose(x, 0)), so the semantics are unchanged.
    close = jnp.isclose(values, 0)
    # Collapse trailing axes so each leading entry is reduced with `all`.
    # math.prod (rather than -1) keeps the reshape valid for empty arrays.
    entry_is_zero = close.reshape(
        close.shape[0], math.prod(close.shape[1:])
    ).all(axis=1).tolist()

    kept_positions = []
    out_blade_indices = []
    # zip stops at len(blade_indices), so extra value entries are ignored
    # exactly as the original index-based loop ignored them.
    for position, (blade_index, is_zero) in enumerate(
        zip(blade_indices, entry_is_zero)
    ):
        if not is_zero:
            kept_positions.append(position)
            out_blade_indices.append(blade_index)

    out_a_indices = jnp.array(kept_positions, dtype=jnp.int32)

    def _values_mv_keep_nonzero(values_in):
        # Gather the kept entries along the leading axis.
        return values_in[out_a_indices]

    return jax.jit(_values_mv_keep_nonzero), tuple(out_blade_indices)