
Parameters

FillTriangularTransform

Bases: Transform

Transform that maps a vector of length n(n+1)/2 to an n x n lower-triangular matrix. The entries are assumed to be ordered row by row through the lower triangle: (0,0), (1,0), (1,1), (2,0), (2,1), (2,2), ..., (n-1, n-1).
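This row-by-row ordering is the same order in which NumPy's `np.tril_indices` enumerates the lower triangle. A quick standalone check (plain NumPy, not GPJax code; `lower_triangular_order` is a hypothetical helper):

```python
import numpy as np


def lower_triangular_order(n):
    """Return the (row, col) pairs of an n x n lower triangle in fill order."""
    rows, cols = np.tril_indices(n)  # row-major traversal of the lower triangle
    return list(zip(rows.tolist(), cols.tolist()))


print(lower_triangular_order(3))
# [(0, 0), (1, 0), (1, 1), (2, 0), (2, 1), (2, 2)]
```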

__call__

__call__(x)

Forward transformation.

Parameters

x : array_like, shape (..., L)
    Input vector with L = n(n+1)/2 for some integer n.

Returns

y : array_like, shape (..., n, n)
    Lower-triangular matrix (with zeros in the upper triangle) filled in row-major order, i.e. [(0,0), (1,0), (1,1), ...].
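A minimal sketch of the forward fill described above, assuming only `jax.numpy`. `fill_lower_triangular` is a hypothetical helper, not GPJax's implementation. It recovers n from L by solving n(n+1)/2 = L, i.e. n = (sqrt(8L + 1) - 1) / 2, and scatters the entries into the lower triangle in the same row-by-row order, keeping any leading batch dimensions:

```python
import math

import jax.numpy as jnp


def fill_lower_triangular(x):
    """Map a (..., L) array with L = n(n+1)/2 to a (..., n, n) lower-triangular array."""
    x = jnp.asarray(x)
    L = x.shape[-1]
    # Solve n(n+1)/2 = L for n using an exact integer square root.
    n = (math.isqrt(8 * L + 1) - 1) // 2
    if n * (n + 1) // 2 != L:
        raise ValueError(f"last dimension {L} is not of the form n(n+1)/2")
    rows, cols = jnp.tril_indices(n)
    y = jnp.zeros(x.shape[:-1] + (n, n), dtype=x.dtype)
    # Scatter the vector entries into the lower triangle, row by row;
    # the upper triangle stays zero.
    return y.at[..., rows, cols].set(x)


print(fill_lower_triangular(jnp.arange(1.0, 7.0)))
# [[1. 0. 0.]
#  [2. 3. 0.]
#  [4. 5. 6.]]
```

Using an integer square root rather than a floating-point one avoids rounding surprises when checking that L really is a triangular number.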

Parameter

Parameter(value, tag, prior=None, **kwargs)

Bases: Variable[T]

Parameter base class.

All trainable parameters in GPJax should inherit from this class.

tag property

tag

Return the parameter's constraint tag.

NonNegativeReal

NonNegativeReal(value, tag='non_negative', **kwargs)

Bases: Parameter[T]

Parameter that is non-negative.

tag property

tag

Return the parameter's constraint tag.

PositiveReal

PositiveReal(value, tag='positive', **kwargs)

Bases: Parameter[T]

Parameter that is strictly positive.
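The bijector used for a tag is supplied by the caller; the `transform` example on this page maps the 'positive' tag with numpyro's `SoftplusTransform`. As a standalone sketch in plain JAX (not GPJax code), softplus and its inverse show how a value optimised on an unconstrained scale stays strictly positive:

```python
import jax
import jax.numpy as jnp


def softplus(x):
    # Forward map: any real number -> strictly positive number.
    return jax.nn.softplus(x)


def inverse_softplus(y):
    # Inverse map: positive number -> real number; only valid for y > 0.
    return jnp.log(jnp.expm1(y))


unconstrained = jnp.array([-3.0, 0.0, 1.0, 4.0])
positive = softplus(unconstrained)        # every entry is > 0
recovered = inverse_softplus(positive)    # matches `unconstrained` up to rounding
print(positive)
print(recovered)
```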

tag property

tag

Return the parameter's constraint tag.

Real

Real(value, tag='real', **kwargs)

Bases: Parameter[T]

Parameter that can take any real value.

tag property

tag

Return the parameter's constraint tag.

SigmoidBounded

SigmoidBounded(value, tag='sigmoid', **kwargs)

Bases: Parameter[T]

Parameter that is bounded between 0 and 1.
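Assuming the 'sigmoid' tag is paired with a logistic-sigmoid bijector (this page does not show that mapping, so treat it as an assumption), the sketch below illustrates in plain JAX how an unconstrained value is squashed into (0, 1) and recovered with the logit function:

```python
import jax
import jax.numpy as jnp


def sigmoid(x):
    # Forward map: real number -> (0, 1) for moderate inputs
    # (very large |x| rounds to exactly 0 or 1 in float32).
    return jax.nn.sigmoid(x)


def logit(p):
    # Inverse map: (0, 1) -> real number.
    return jnp.log(p) - jnp.log1p(-p)


unconstrained = jnp.array([-2.0, 0.0, 2.0])
bounded = sigmoid(unconstrained)   # every entry lies strictly between 0 and 1
recovered = logit(bounded)         # matches `unconstrained` up to rounding
print(bounded)
print(recovered)
```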

tag property

tag

Return the parameter's constraint tag.

LowerTriangular

LowerTriangular(value, tag='lower_triangular', **kwargs)

Bases: Parameter[T]

Parameter that is a lower triangular matrix.

tag property

tag

Return the parameter's constraint tag.

transform

transform(params, params_bijection, inverse=False)

Transforms parameters by applying, to each one, the bijector that `params_bijection` associates with its constraint tag.

Example

>>> from gpjax.parameters import PositiveReal, transform
>>> import jax.numpy as jnp
>>> import numpyro.distributions.transforms as npt
>>> from flax import nnx
>>> params = nnx.State(
...     {
...         "a": PositiveReal(jnp.array([1.0])),
...         "b": PositiveReal(jnp.array([2.0])),
...     }
... )
>>> params_bijection = {'positive': npt.SoftplusTransform()}
>>> transformed_params = transform(params, params_bijection)
>>> print(transformed_params["a"].value)
[1.3132617]

Parameters:

  • params (State) –

    A nnx.State object containing parameters to be transformed.

  • params_bijection (Dict[str, Transform]) –

    A dictionary mapping parameter constraint tags (e.g. 'positive') to bijectors.

  • inverse (bool, default: False) –

    Whether to apply the inverse transformation.

Returns:

  • State (State) –

    A new nnx.State object containing the transformed parameters.
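To make the role of `inverse` concrete, here is a hypothetical, dictionary-based sketch of tag-driven dispatch in plain JAX. It is not GPJax's implementation: it uses plain dicts instead of `nnx.State`, (forward, inverse) function pairs instead of numpyro transforms, and leaving values whose tag has no registered bijector unchanged is an assumption of this sketch. The forward direction maps unconstrained values onto the constrained set, as in the example above where softplus(1.0) gives 1.3132617.

```python
import jax
import jax.numpy as jnp

# Hypothetical registry: constraint tag -> (forward, inverse) pair.
# Forward maps unconstrained values onto the constrained set.
BIJECTORS = {
    "positive": (jax.nn.softplus, lambda y: jnp.log(jnp.expm1(y))),
}


def transform_by_tag(values, tags, bijectors, inverse=False):
    """Return a new dict with each value mapped by the bijector for its tag.

    Tags absent from `bijectors` are left unchanged (an assumption of this sketch).
    """
    out = {}
    for name, value in values.items():
        pair = bijectors.get(tags[name])
        if pair is None:
            out[name] = value
            continue
        forward, backward = pair
        out[name] = backward(value) if inverse else forward(value)
    return out


values = {"a": jnp.array([1.0]), "b": jnp.array([2.0])}
tags = {"a": "positive", "b": "real"}  # "real" has no entry, so "b" passes through

constrained = transform_by_tag(values, tags, BIJECTORS)
print(constrained["a"])  # softplus(1.0), about 1.3132617

recovered = transform_by_tag(constrained, tags, BIJECTORS, inverse=True)
print(recovered["a"])  # back to about 1.0
```

Returning a new dict rather than mutating the input mirrors the documented behaviour of `transform`, which returns a new `nnx.State`.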