multistage.Stage2
- class multistage.Stage2(s1, epsilon, kappa, width_size=20, depth=4, activation=jnp.tanh, params=None, params_are_trainable=False, key=None, *, chebyshev=False, **kwargs)
Initializes the Stage 2 PINN model.
Examples
See tests/test_burgers.py.
- Parameters:
s1 (Stage1) – The frozen model from stage 1.
epsilon (float) – Approximate magnitude of this stage's output.
kappa (jax.Array) – Approximate dominant frequency of this stage's output in each input direction. Shape (s1._mlp.in_size,).
width_size (int) – Size of each hidden layer.
depth (int) – The number of hidden layers, including the output layer.
activation (callable) – The activation function for each hidden layer after the first. Default is jnp.tanh.
params (dict[str, jax.Array]) – Dictionary mapping the names of the parameters to learn to their initial guesses. E.g. for Burgers:
    {
        "lambda_1": jax.random.normal(l1_key, (1,)) * 0.1,
        "log_lambda_2": -6.0 + jax.random.normal(l2_key, (1,)) * 0.1,
    }
params_are_trainable (bool) – Whether the params values are optimizable (True) or frozen (False). Default is False (frozen).
key (jax.Array) – JAX PRNG key (e.g. from jax.random.PRNGKey) for reproducibility.
chebyshev (bool) – Whether the frequency kappa is associated with a Chebyshev feature mapping instead of a Fourier one. Default is False.
kwargs (dict) – Keyword arguments passed to equinox.nn.MLP.
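Two of the arguments above must be prepared before construction: the params dict and the epsilon/kappa estimates. The sketch below is not part of the library. make_burgers_params and estimate_epsilon_kappa are hypothetical helpers. The first builds the Burgers dict from the example above by splitting one PRNG key into l1_key and l2_key. The second estimates epsilon as the largest absolute stage-1 residual and kappa from the peak of the residual's FFT on a uniform 1-D grid. Both choices are assumptions. In particular, treating kappa as an angular frequency (2π times cycles per unit length) is a guess, so check it against the library's convention.

```python
import jax
import jax.numpy as jnp


def make_burgers_params(key):
    """Hypothetical helper: build the Burgers `params` dict from the docstring
    example, splitting one PRNG key into the two keys it needs."""
    l1_key, l2_key = jax.random.split(key)
    return {
        "lambda_1": jax.random.normal(l1_key, (1,)) * 0.1,
        "log_lambda_2": -6.0 + jax.random.normal(l2_key, (1,)) * 0.1,
    }


def estimate_epsilon_kappa(residual, x):
    """Hypothetical helper: estimate the magnitude and dominant angular
    frequency of a 1-D stage-1 residual sampled on a uniform grid `x`."""
    # Magnitude scale: largest absolute residual (an assumed choice; RMS also works).
    epsilon = jnp.max(jnp.abs(residual))
    # Dominant frequency from the real FFT, with the mean (DC) component removed.
    dx = x[1] - x[0]
    spectrum = jnp.abs(jnp.fft.rfft(residual - jnp.mean(residual)))
    freqs = jnp.fft.rfftfreq(residual.shape[0], d=dx)  # cycles per unit length
    # Assumption: kappa is an angular frequency, i.e. 2*pi * cycles per unit length.
    kappa = 2.0 * jnp.pi * freqs[jnp.argmax(spectrum)]
    return epsilon, kappa
```

For a multi-dimensional input such as Burgers' (t, x), one option is to run the estimator once per input direction and stack the results, e.g. jnp.array([kappa_t, kappa_x]), to obtain the documented shape (s1._mlp.in_size,).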
- __init__(s1, epsilon, kappa, width_size=20, depth=4, activation=jnp.tanh, params=None, params_are_trainable=False, key=None, *, chebyshev=False, **kwargs)
Methods
__init__(s1, epsilon, kappa[, width_size, ...])
compute_s2(*args) – Compute just this stage of the output.
get_param(key[, default]) – Return self.params[key] if it exists and is not None, else default.
print_frozen_params() – Print the frozen parameters of this network.
print_params() – Print the parameters of this network.
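The documented behaviour of get_param can be reproduced with a small standalone function. This is a sketch, not the library's code. It takes the params dict as an explicit argument instead of reading self.params. Treating params=None (the constructor default) as "no entries" is an assumption.

```python
from typing import Any, Mapping, Optional


def get_param(params: Optional[Mapping[str, Any]], key: str, default: Any = None) -> Any:
    """Hypothetical standalone mirror of Stage2.get_param: return params[key]
    if the entry exists and is not None, otherwise return `default`."""
    if params is None:  # assumption: no params dict behaves like an empty one
        return default
    value = params.get(key)  # None when the key is missing
    return default if value is None else value
```

Because of the None check, an entry explicitly set to None falls back to default in the same way as a missing key.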
Attributes
epsilon – Estimated magnitude scale of the output for this stage.
in_size – Number of input dimensions.
kappa – Estimated dominant frequency of the output for this stage.
lb – Lower bound of the input coordinates.
out_size – Number of output dimensions.
params – Parameters for this stage.
s1 – The previous-stage network.
ub – Upper bound of the input coordinates.
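The attributes lb, ub, epsilon and s1 suggest how a stage's inputs and outputs might fit together, but this page does not confirm the details. The sketch below makes two assumptions. First, inputs are rescaled from [lb, ub] to [-1, 1], a convention common in PINN codes. Second, the combined prediction is the stage-1 output plus epsilon times an O(1) stage-2 output, following the general multistage-network idea. normalize_inputs, two_stage_output, u1 and u2_hat are hypothetical stand-ins and not the library's API.

```python
import jax.numpy as jnp


def normalize_inputs(x, lb, ub):
    """Map coordinates from [lb, ub] to [-1, 1] per input dimension.
    Assumption: a common PINN convention, not confirmed for this library."""
    return 2.0 * (x - lb) / (ub - lb) - 1.0


def two_stage_output(u1, u2_hat, epsilon, x):
    """Hypothetical composition: the stage-1 prediction plus the stage-2
    network's O(1) output rescaled by epsilon."""
    return u1(x) + epsilon * u2_hat(x)


# Toy bounds for inputs (t, x) on [0, 1] x [-1, 1], one row per sample point.
lb = jnp.array([0.0, -1.0])
ub = jnp.array([1.0, 1.0])
pts = jnp.array([[0.0, -1.0], [0.5, 0.0], [1.0, 1.0]])
pts_normalized = normalize_inputs(pts, lb, ub)
```

Because lb and ub have shape (in_size,), broadcasting rescales every column of a (num_points, in_size) batch at once.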