Parameters
FillTriangularTransform
Bases: Transform
Transform that maps a vector of length n(n+1)/2 to an n x n lower triangular matrix. The vector entries are assumed to fill the lower triangle row by row, in the order: (0,0), (1,0), (1,1), (2,0), (2,1), (2,2), ..., (n-1, n-1)
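The packing order described above can be illustrated with a minimal sketch in plain JAX. This is not the library's implementation: the helper name `fill_lower_triangular` is hypothetical, and the sketch assumes a one-dimensional input with no batch dimensions.

```python
import math

import jax.numpy as jnp


def fill_lower_triangular(vec):
    """Pack a 1-D vector of length n(n+1)/2 into an n x n lower triangular matrix.

    Hypothetical helper for illustration only; entries are placed row by row:
    (0,0), (1,0), (1,1), (2,0), ...
    """
    m = vec.shape[0]
    # Solve n(n+1)/2 = m for n using exact integer arithmetic.
    n = (math.isqrt(8 * m + 1) - 1) // 2
    if n * (n + 1) // 2 != m:
        raise ValueError(f"Length {m} is not of the form n(n+1)/2.")
    # tril_indices enumerates the lower triangle in row-major order,
    # which matches the ordering stated in the description.
    rows, cols = jnp.tril_indices(n)
    return jnp.zeros((n, n), dtype=vec.dtype).at[rows, cols].set(vec)


L = fill_lower_triangular(jnp.arange(1.0, 7.0))
print(L)
```

With a length-6 input the result is a 3 x 3 matrix whose first row holds one entry, the second row two, and the third row three; everything above the diagonal stays zero.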
__call__
Parameter
NonNegativeReal
Bases: Parameter[T]
Parameter that is non-negative.
PositiveReal
Bases: Parameter[T]
Parameter that is strictly positive.
Real
Bases: Parameter[T]
Parameter that can take any real value.
SigmoidBounded
Bases: Parameter[T]
Parameter that is bounded between 0 and 1.
LowerTriangular
Bases: Parameter[T]
Parameter that is a lower triangular matrix.
transform
Transforms parameters using a bijector.
Example
>>> from gpjax.parameters import PositiveReal, transform
>>> import jax.numpy as jnp
>>> import numpyro.distributions.transforms as npt
>>> from flax import nnx
>>> params = nnx.State(
...     {
...         "a": PositiveReal(jnp.array([1.0])),
...         "b": PositiveReal(jnp.array([2.0])),
...     }
... )
>>> params_bijection = {'positive': npt.SoftplusTransform()}
>>> transformed_params = transform(params, params_bijection)
>>> print(transformed_params["a"].value)
[1.3132617]
Parameters:
- params (State): An nnx.State object containing the parameters to be transformed.
- params_bijection (Dict[str, Transform]): A dictionary mapping parameter types to bijectors.
- inverse (bool, default: False): Whether to apply the inverse transformation.
Returns:
- State (State): A new nnx.State object containing the transformed parameters.
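To make the roles of `params_bijection` and `inverse` concrete, here is a rough sketch of tag-based dispatch in plain JAX. It does not use GPJax, numpyro, or nnx.State: the function `transform_sketch`, the `BIJECTIONS` table, and the `(tag, value)` tuple layout are all hypothetical stand-ins, and the softplus inverse is written out by hand.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for a dictionary of bijectors:
# each tag maps to a (forward, inverse) pair of functions.
BIJECTIONS = {
    "positive": (jax.nn.softplus, lambda y: jnp.log(jnp.expm1(y))),
    "real": (lambda x: x, lambda y: y),
}


def transform_sketch(params, bijections, inverse=False):
    """Apply the bijection selected by each parameter's tag.

    `params` maps a name to a (tag, value) tuple; this layout is an
    assumption made for the sketch, not how nnx.State stores parameters.
    """
    out = {}
    for name, (tag, value) in params.items():
        forward, backward = bijections[tag]
        out[name] = (tag, backward(value) if inverse else forward(value))
    return out


params = {
    "a": ("positive", jnp.array([1.0])),
    "b": ("positive", jnp.array([2.0])),
}
constrained = transform_sketch(params, BIJECTIONS)
recovered = transform_sketch(constrained, BIJECTIONS, inverse=True)

# softplus(1.0) = log(1 + e), approximately 1.31326
print(constrained["a"][1])
# Applying the inverse should recover the original values up to rounding.
print(recovered["a"][1], recovered["b"][1])
```

Keying the lookup on a tag rather than on the parameter name means every parameter of the same kind shares one bijection, and passing `inverse=True` walks the same table in the opposite direction.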