# Source code for jaxga.mv

from .ops.sandwich import get_mv_sandwich
from .ops.inverse import get_mv_inverse
from .ops.dual import get_mv_dual
from .ops.reduce_same import get_mv_reduce_same
from .ops.keepnonzero import get_mv_keep_nonzero
from .ops.multiply import get_mv_multiply
from .ops.add import get_mv_add
from .ops.simple_exp import get_mv_simple_exp
from .ops.select import get_mv_select
from .jaxga import reverse_indices, mv_repr
from .signatures import positive_signature
import jax.numpy as jnp


class MultiVector:
    """A ``values`` array paired with a tuple of ``indices`` and a ``signature``.

    The class is a thin operator-overloading wrapper: every arithmetic method
    asks one of the ``get_mv_*`` factories for a function specialised to the
    operand indices (and, where relevant, the signature), applies that
    function to the raw ``values`` arrays and wraps the result in a new
    ``MultiVector``. Instances are never mutated by the methods below.

    NOTE(review): ``e`` creates values whose leading axis has length 1 for a
    single index tuple, so the leading axis of ``values`` presumably
    enumerates the entries of ``indices`` -- confirm against the ops modules.
    """

    @staticmethod
    def e(*indices, **kwargs):
        """Build a multivector with one index tuple and all-ones values.

        Args:
            *indices: Components of the single index tuple. With no
                arguments the index tuple is ``()``.
            **kwargs: Optional ``signature`` (defaults to
                ``positive_signature``) and ``batch_shape`` (an iterable of
                ints appended after the leading axis of length 1). Any other
                keyword is ignored.

        Returns:
            A ``MultiVector`` whose values are float32 ones of shape
            ``(1,) + tuple(batch_shape)`` and whose indices are
            ``(tuple(indices),)``.
        """
        signature = kwargs.get("signature", positive_signature)
        values_shape = (1,) + tuple(kwargs.get("batch_shape", ()))
        return MultiVector(
            values=jnp.ones(values_shape, dtype=jnp.float32),
            indices=(tuple(indices),),
            signature=signature,
        )

    def __init__(self, values, indices, signature=positive_signature):
        """Store ``values`` and ``signature`` as given; ``indices`` as a tuple."""
        self.values = values
        self.indices = tuple(indices)
        self.signature = signature

    def _promote(self, other):
        """Return ``other`` as a MultiVector.

        A non-MultiVector operand is multiplied onto the unit multivector
        with the empty index tuple ``()``, carrying this instance's signature.
        """
        if isinstance(other, MultiVector):
            return other
        return MultiVector.e(signature=self.signature) * other

    def _wrap(self, values, indices):
        """Wrap an op result in a MultiVector that keeps this signature."""
        return MultiVector(
            values=values, indices=indices, signature=self.signature
        )

    def __add__(self, other):
        """Return ``self + other``; a non-MultiVector ``other`` is promoted."""
        other = self._promote(other)
        mv_add, out_indices = get_mv_add(self.indices, other.indices)
        return self._wrap(mv_add(self.values, other.values), out_indices)

    def __radd__(self, other):
        """Return ``other + self`` with ``other`` as the left operand."""
        other = self._promote(other)
        mv_add, out_indices = get_mv_add(other.indices, self.indices)
        return self._wrap(mv_add(other.values, self.values), out_indices)

    def __sub__(self, other):
        """Return ``self - other``, computed as an add of negated values."""
        other = self._promote(other)
        mv_add, out_indices = get_mv_add(self.indices, other.indices)
        return self._wrap(mv_add(self.values, -other.values), out_indices)

    def __rsub__(self, other):
        """Return ``other - self`` with ``other`` as the left operand."""
        other = self._promote(other)
        mv_add, out_indices = get_mv_add(other.indices, self.indices)
        return self._wrap(mv_add(other.values, -self.values), out_indices)

    def __mul__(self, other):
        """Return ``self * other``.

        A MultiVector operand goes through ``get_mv_multiply`` with this
        instance's signature; anything else is multiplied directly onto
        ``values`` and the indices are kept unchanged.
        """
        if isinstance(other, MultiVector):
            mv_multiply, out_indices = get_mv_multiply(
                self.indices, other.indices, self.signature
            )
            return self._wrap(
                mv_multiply(self.values, other.values), out_indices
            )
        return self._wrap(self.values * other, self.indices)

    def __rmul__(self, other):
        """Return ``other * self`` with ``other`` as the left operand."""
        if isinstance(other, MultiVector):
            mv_multiply, out_indices = get_mv_multiply(
                other.indices, self.indices, self.signature
            )
            # The signature must be propagated here as in every other
            # operator; omitting it made the result fall back to
            # positive_signature.
            return self._wrap(
                mv_multiply(other.values, self.values), out_indices
            )
        return self._wrap(self.values * other, self.indices)

    def __xor__(self, other):
        """Return ``self ^ other`` via ``get_mv_multiply`` in ``"op"`` mode.

        NOTE(review): ``"op"`` presumably selects the outer product --
        confirm against ``ops.multiply``. ``other`` must be a MultiVector.
        """
        mv_multiply, out_indices = get_mv_multiply(
            self.indices, other.indices, self.signature, "op"
        )
        return self._wrap(mv_multiply(self.values, other.values), out_indices)

    def __or__(self, other):
        """Return ``self | other`` via ``get_mv_multiply`` in ``"ip"`` mode.

        NOTE(review): ``"ip"`` presumably selects the inner product --
        confirm against ``ops.multiply``. ``other`` must be a MultiVector.
        """
        mv_multiply, out_indices = get_mv_multiply(
            self.indices, other.indices, self.signature, "ip"
        )
        return self._wrap(mv_multiply(self.values, other.values), out_indices)

    def __invert__(self):
        """Return ``~self``: same values, indices from ``reverse_indices``."""
        return self._wrap(self.values, reverse_indices(self.indices))

    def __neg__(self):
        """Return ``-self``: negated values, unchanged indices."""
        return self._wrap(-self.values, self.indices)

    def sandwich(self, other):
        """Apply the op built by ``get_mv_sandwich`` to ``self`` and ``other``.

        NOTE(review): presumably the sandwich product of ``other`` between
        ``self`` and its reverse -- confirm against ``ops.sandwich``.
        """
        mv_sandwich, out_indices = get_mv_sandwich(
            self.indices, other.indices, self.signature
        )
        return self._wrap(mv_sandwich(self.values, other.values), out_indices)

    def inverse(self):
        """Return the result of the op built by ``get_mv_inverse``."""
        mv_inv, inv_indices = get_mv_inverse(self.indices, self.signature)
        return self._wrap(mv_inv(self.values), inv_indices)

    def __truediv__(self, other):
        """Return ``self / other``.

        Division by a MultiVector multiplies by its ``inverse()``; any other
        operand is treated as a scalar and inverted with ``1 / other``.
        """
        if isinstance(other, MultiVector):
            return self * other.inverse()
        return self * (1 / other)

    def __rtruediv__(self, other):
        """Return ``other / self`` as ``other * self.inverse()``."""
        return other * self.inverse()

    def __repr__(self):
        """Return the text produced by ``mv_repr`` for indices and values."""
        return mv_repr(self.indices, self.values)

    def __getitem__(self, select_indices):
        """Return the part chosen by ``get_mv_select`` for ``select_indices``."""
        mv_select, out_indices = get_mv_select(self.indices, select_indices)
        return self._wrap(mv_select(self.values), out_indices)

    def simple_exp(self):
        """Return the result of the op built by ``get_mv_simple_exp``.

        NOTE(review): presumably an exponential restricted to "simple"
        inputs -- confirm the preconditions in ``ops.simple_exp``.
        """
        mv_simple_exp, out_indices = get_mv_simple_exp(
            self.indices, self.signature
        )
        return self._wrap(mv_simple_exp(self.values), out_indices)

    def keep_nonzero(self):
        """Return the result of the op built by ``get_mv_keep_nonzero``.

        Unlike the other factories, this one receives the concrete
        ``values`` as well as the indices when the op is built.
        """
        mv_keep_nonzero, out_indices = get_mv_keep_nonzero(
            self.indices, self.values
        )
        return self._wrap(mv_keep_nonzero(self.values), out_indices)

    def reduce_same(self):
        """Return the result of the op built by ``get_mv_reduce_same``.

        NOTE(review): presumably merges entries that share the same index
        tuple -- confirm against ``ops.reduce_same``.
        """
        mv_reduce_same, out_indices = get_mv_reduce_same(self.indices)
        return self._wrap(mv_reduce_same(self.values), out_indices)

    def dual(self, dims):
        """Return the result of the op built by ``get_mv_dual`` for ``dims``.

        NOTE(review): ``dims`` is forwarded untouched; presumably the
        dimension of the space the dual is taken in -- confirm against
        ``ops.dual``.
        """
        mv_dual, out_indices = get_mv_dual(self.indices, dims)
        return self._wrap(mv_dual(self.values), out_indices)