Utils
jax_gather_nd
Slice a params array at a set of indices.
This is a reimplementation of TensorFlow's gather_nd function.
Parameters:
- params (Float[Array, ' N *rest']): an arbitrary array whose leading axis has length \(N\). Slices are taken along this axis.
- indices (Int[Array, ' M 1']): an integer array of shape \((M, 1)\) with values in the range \([0, N)\). The \(i\)-th slice of the output is params[indices[i, 0]].
Returns:
- Float[Array, ' M *rest']: an array whose leading axis has length \(M\) and whose trailing axes match those of params.
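A minimal sketch of the same gather in jax.numpy follows, assuming indices has shape (M, 1) as documented. The name gather_nd_sketch is hypothetical and this is not claimed to be the library's implementation. Each trailing column of indices indexes one leading axis of params, so with a single column the gather reduces to ordinary integer-array indexing along axis 0.

```python
import jax.numpy as jnp


def gather_nd_sketch(params, indices):
    """Hypothetical gather_nd-style helper for illustration only."""
    # Split the last axis of `indices` into one index array per leading axis of `params`.
    index_tuple = tuple(indices[..., k] for k in range(indices.shape[-1]))
    # Advanced indexing with that tuple picks out one slice of `params` per row of `indices`.
    return params[index_tuple]


params = jnp.arange(12.0).reshape(4, 3)  # N = 4, rest = (3,)
indices = jnp.array([[2], [0], [2]])     # M = 3, values in [0, 4)
out = gather_nd_sketch(params, indices)  # rows 2, 0 and 2 of params, shape (3, 3)
```

With a single index column this matches jnp.take(params, indices[:, 0], axis=0); the tuple form is kept because it mirrors how gather_nd generalises to index columns addressing several leading axes at once.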
calculate_heat_semigroup
Returns the rescaled heat semigroup, \(S\), of a graph kernel.
Parameters:
- kernel (GraphKernel): an instance of the graph kernel.
Returns:
- Float[Array, 'N M']: the rescaled heat semigroup \(S\).
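The entry above does not state how \(S\) is computed. One plausible form, assumed here rather than taken from the source, transforms the graph Laplacian eigenvalues \(\lambda_i\) as \((2\nu/\kappa^2 + \lambda_i)^{-\nu}\) (a Matérn-type spectrum with smoothness \(\nu\) and lengthscale \(\kappa\)), rescales the result to sum to the number of vertices \(N\), and multiplies by a variance \(\sigma^2\). The helper name and its parameters are hypothetical, and the sketch returns a vector rather than the documented 'N M' array. A pure diffusion (heat) kernel would use \(\exp(-\kappa^2 \lambda_i / 2)\) as the transform instead.

```python
import jax.numpy as jnp


def rescaled_spectrum_sketch(eigenvalues, smoothness, lengthscale, variance):
    """Hypothetical helper: transform and rescale graph Laplacian eigenvalues."""
    # Assumed Matern-type transform (2*nu/kappa**2 + lambda)**(-nu), decreasing in lambda.
    s = jnp.power(eigenvalues + 2.0 * smoothness / lengthscale**2, -smoothness)
    # Rescale so the transformed spectrum sums to the number of vertices N.
    s = s * (eigenvalues.shape[0] / jnp.sum(s))
    # Scale by the kernel variance, so the entries now sum to N * variance.
    return variance * s


# Path graph on three vertices: 0 - 1 - 2.
adjacency = jnp.array([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0], [0.0, 1.0, 0.0]])
laplacian = jnp.diag(adjacency.sum(axis=1)) - adjacency
eigenvalues = jnp.linalg.eigvalsh(laplacian)  # ascending, approximately 0, 1, 3
S = rescaled_spectrum_sketch(eigenvalues, smoothness=1.5, lengthscale=1.0, variance=1.5)
```

The normalisation step makes the overall scale of \(S\) depend only on the variance, so changing the smoothness or lengthscale reshapes the spectrum without inflating the kernel's total magnitude.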