multistage.Stage2
- class multistage.Stage2(s1, epsilon, kappa, width_size=20, depth=4, activation=jnp.tanh, params=None, params_are_trainable=False, key=None, *, chebyshev=False, feature_map='separable', **kwargs)
Initializes the Stage 2 PINN model.
Examples
See tests/test_burgers.py.
- Parameters:
s1 (Stage1) – The frozen model from stage 1.
epsilon (float) – Approximate magnitude of output.
kappa (jax.Array) – Approximate angular frequency of this stage’s output in normalized coordinates. For feature_map="separable", each entry is the target frequency for one input direction. For feature_map="random", the vector gives anisotropic component scales; if all entries are equal, the random wave-vector RMS norm matches that common value. Shape (s1.in_size,).
width_size (int) – Size of each hidden layer.
depth (int) – The number of hidden layers, including the output layer.
activation (callable) – The activation function for each hidden layer after the first. Default is jnp.tanh.
params (dict[str, jax.Array]) – Dictionary of parameter corrections to learn, together with their initial guesses. The automatic multistage constructors initialize the corrections so that the total PDE parameters are unchanged at stage creation. Manually transformed parameters are allowed, but must be initialized in their transformed coordinates. E.g. for Burgers:
{
    "lambda_1": jax.random.normal(l1_key, (1,)) * 0.1,
    "log_lambda_2": jnp.log(0.5),
}
params_are_trainable (bool) – Whether the params values are optimizable (True) or frozen (False). Default is False (frozen).
key (jax.Array) – PRNG key for reproducible initialization.
chebyshev (bool) – Whether the frequency kappa is associated with a Chebyshev feature mapping instead of a Fourier one. Default is False.
feature_map ({"separable", "random"}) – First-layer Fourier feature geometry. "separable" assigns each sinusoidal feature to one input coordinate; "random" preserves the original dense random plane-wave mapping and interprets kappa as an isotropic wave-vector norm when all entries are equal. Default is "separable".
kwargs (dict) – Keyword arguments passed to equinox.nn.MLP.
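The difference between the two feature_map geometries, and how kappa sets their frequency scale, can be illustrated with a minimal sketch. This is not the multistage implementation: the helpers separable_frequencies, random_frequencies and fourier_features are hypothetical, and the sine/cosine pairing, round-robin axis assignment and Gaussian sampling are assumptions chosen so that the documented scaling holds (each separable row oscillates along a single input axis at a scale set by that axis's kappa entry; the dense random wave vectors have RMS norm equal to kappa when all entries are equal).

```python
import jax
import jax.numpy as jnp


def separable_frequencies(key, kappa, width_size):
    """Hypothetical sketch: each row has one nonzero entry (one input axis)."""
    in_size = kappa.shape[0]
    n = width_size // 2  # half the features are sines, half are cosines
    dims = jnp.arange(n) % in_size  # round-robin assignment of rows to input axes
    z = jax.random.normal(key, (n,))  # unit-RMS random scale per row
    rows = jnp.arange(n)
    return jnp.zeros((n, in_size)).at[rows, dims].set(z * kappa[dims])


def random_frequencies(key, kappa, width_size):
    """Hypothetical sketch: dense Gaussian wave vectors with per-axis scale kappa."""
    in_size = kappa.shape[0]
    n = width_size // 2
    z = jax.random.normal(key, (n, in_size))
    # E[||w||^2] = mean(kappa**2), so the RMS norm equals kappa when all
    # entries of kappa are equal.
    return z * kappa / jnp.sqrt(in_size)


def fourier_features(W, x):
    """Map normalized coordinates x, shape (in_size,), to 2 * W.shape[0] features."""
    z = W @ x
    return jnp.concatenate([jnp.sin(z), jnp.cos(z)])


key = jax.random.PRNGKey(0)
kappa = jnp.array([4.0, 1.0])  # illustrative target frequencies along x and t
W = separable_frequencies(key, kappa, width_size=20)
phi = fourier_features(W, jnp.array([0.5, -0.2]))
print(W.shape, phi.shape)  # (10, 2) (20,)
```

Dividing the dense Gaussian draws by sqrt(in_size) is what ties the RMS wave-vector norm to kappa in this sketch; without it, the norm would grow with the number of input dimensions.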
- __init__(s1, epsilon, kappa, width_size=20, depth=4, activation=jnp.tanh, params=None, params_are_trainable=False, key=None, *, chebyshev=False, feature_map='separable', **kwargs)
Methods
__init__(s1, epsilon, kappa[, width_size, ...])
compute_s2(*args) – Compute just this stage of the output.
get_param(key[, default]) – Return self.params[key] if it exists and is not None, else default.
print_frozen_params() – Print the frozen parameters of this network.
print_params() – Print the params of this network.
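The lookup rule of get_param, and how a parameter stored in transformed coordinates (such as log_lambda_2 in the params example) maps back to a physical value, can be sketched as follows. The real method lives on Stage2 and reads self.params; the free function get_param below is a hypothetical stand-in that only reproduces the documented rule. Recovering lambda_2 with jnp.exp is an assumption based on the log_ prefix of the key, and the "nu" key is purely illustrative.

```python
import jax.numpy as jnp


def get_param(params, key, default=None):
    """Hypothetical stand-in: params[key] if present and not None, else default."""
    if params is None:  # params=None is the constructor default
        return default
    value = params.get(key)
    return default if value is None else value


# Illustrative Burgers-style entries; lambda_2 is stored as its logarithm.
params = {"lambda_1": jnp.array([0.1]), "log_lambda_2": jnp.log(0.5)}

# Assumed inverse transform: exponentiate to return to physical units.
lambda_2 = jnp.exp(get_param(params, "log_lambda_2"))
# A missing key falls back to the supplied default.
nu = get_param(params, "nu", default=0.01)
print(float(lambda_2), nu)
```

Storing a strictly positive coefficient as its logarithm keeps it positive under unconstrained gradient updates, which is one reason the params docs require transformed parameters to be initialized in their transformed coordinates.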
Attributes
epsilon – Estimated magnitude scale of output for this stage.
in_size – Number of input dimensions.
kappa – Estimated dominant frequency of output for this stage.
lb – Lower bound of input coordinates.
out_size – Number of output dimensions.
params – Params for this stage.
s1 – Returns the previous stage network.
ub – Upper bound of input coordinates.
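The attributes lb, ub and epsilon suggest how a stage-2 correction could be combined with the frozen stage 1. The sketch below is only an illustration under two stated assumptions: that lb and ub map inputs onto [-1, 1] (the "normalized coordinates" in which kappa is expressed), and that the total output is the stage-1 output plus epsilon times the stage-2 network output. The helpers normalize, two_stage_output, toy_u1 and toy_u2 are hypothetical and are not part of the multistage API.

```python
import jax.numpy as jnp


def normalize(x, lb, ub):
    # Assumed convention: map each coordinate from [lb, ub] onto [-1, 1].
    return 2.0 * (x - lb) / (ub - lb) - 1.0


def two_stage_output(u1, u2, epsilon, x, lb, ub):
    # Assumed composition: frozen stage-1 output plus an epsilon-scaled stage-2 term.
    return u1(x) + epsilon * u2(normalize(x, lb, ub))


def toy_u1(x):
    # Toy stand-in for the frozen stage-1 network, evaluated in raw coordinates.
    return -jnp.sin(jnp.pi * x[:1])


def toy_u2(xh):
    # Toy stand-in for the stage-2 network, evaluated in normalized coordinates.
    return jnp.sin(4.0 * xh[:1]) * jnp.cos(xh[1:])


lb = jnp.array([-1.0, 0.0])  # illustrative (x, t) lower bounds
ub = jnp.array([1.0, 1.0])  # illustrative (x, t) upper bounds
u = two_stage_output(toy_u1, toy_u2, 1e-2, jnp.array([0.5, 0.5]), lb, ub)
print(u)
```

Under this convention, epsilon carries the magnitude of the correction, so the stage-2 network itself can produce outputs of order one.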