multistage.Stage2

class multistage.Stage2(s1, epsilon, kappa, width_size=20, depth=4, activation=jnp.tanh, params=None, params_are_trainable=False, key=None, *, chebyshev=False, feature_map='separable', **kwargs)

Initializes the Stage 2 PINN model.

Examples

  • See tests/test_burgers.py.

Parameters:
  • s1 (Stage1) – The frozen model from stage 1.

  • epsilon (float) – Approximate magnitude of this stage's output.

  • kappa (jax.Array) – Approximate angular frequency of this stage’s output in normalized coordinates. For feature_map="separable", each entry is the target frequency for one input direction. For feature_map="random", the vector gives anisotropic component scales; if all entries are equal, the random wave-vector RMS norm matches that common value. Shape (s1.in_size,).

  • width_size (int) – Size of each hidden layer.

  • depth (int) – The number of hidden layers, including the output layer.

  • activation (callable) – The activation function for each hidden layer after the first. Default is jnp.tanh.

  • params (dict[str, jax.Array]) –

    Dictionary mapping each parameter correction to learn to its initial guess. The automatic multistage constructors initialize the corrections so that the total PDE parameters are unchanged when the stage is created. Manually transformed parameters (for example, log-scale parameters) are allowed, but must be initialized in their transformed coordinates. E.g. for Burgers:

        {
            "lambda_1": jax.random.normal(l1_key, (1,)) * 0.1,
            "log_lambda_2": jnp.log(0.5),
        }

  • params_are_trainable (bool) – Whether the values in params are trainable (optimized) or frozen. Default is False (frozen).

  • key (jax.Array) – JAX PRNG key (e.g. from jax.random.PRNGKey), for reproducibility.

  • chebyshev (bool) – Whether the frequency kappa refers to a Chebyshev feature mapping instead of a Fourier feature mapping. Default is False.

  • feature_map ({"separable", "random"}) – First-layer Fourier feature geometry. "separable" assigns each sinusoidal feature to one input coordinate; "random" preserves the original dense random plane-wave mapping and interprets kappa as an isotropic wave-vector norm when all entries are equal.

  • kwargs (dict) – Additional keyword arguments passed to equinox.nn.MLP.
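The two first-layer geometries described for kappa and feature_map can be illustrated with a minimal plain-JAX sketch. It does not use the multistage internals, which are not shown on this page: the helper names, the [sin, cos] feature layout, cycling separable features over directions, and the standard deviation kappa_j / sqrt(in_size) used for "random" are all assumptions, chosen only to reproduce the documented properties (each separable feature touches one coordinate; the RMS wave-vector norm equals kappa when all entries match). The Chebyshev helper only shows the identity T_n(x) = cos(n arccos x) on inputs normalized to [-1, 1]; how chebyshev=True maps kappa to polynomial degrees is not specified here.

```python
import jax
import jax.numpy as jnp


def separable_weights(kappa, n_features):
    """Hypothetical separable geometry: every feature row touches exactly one
    input direction j, with angular frequency kappa[j]. Rows cycle over directions."""
    in_size = kappa.shape[0]
    rows = jnp.arange(n_features)
    cols = rows % in_size
    W = jnp.zeros((n_features, in_size))
    return W.at[rows, cols].set(kappa[cols])


def random_weights(key, kappa, n_features):
    """Hypothetical random geometry: dense Gaussian plane waves. Component j has
    standard deviation kappa[j] / sqrt(in_size), so E[|w|^2] = sum(kappa**2) / in_size;
    when all entries of kappa are equal, the RMS wave-vector norm equals kappa."""
    in_size = kappa.shape[0]
    z = jax.random.normal(key, (n_features, in_size))
    return z * kappa / jnp.sqrt(in_size)


def fourier_features(W, x):
    """Map a normalized input x of shape (in_size,) to [sin(W x), cos(W x)]."""
    phase = W @ x
    return jnp.concatenate([jnp.sin(phase), jnp.cos(phase)])


def chebyshev_features(x, degrees):
    """Chebyshev alternative for x in [-1, 1]: T_n(x) = cos(n * arccos(x)),
    evaluated for every degree n and every input coordinate."""
    theta = jnp.arccos(jnp.clip(x, -1.0, 1.0))
    return jnp.cos(degrees[:, None] * theta[None, :]).ravel()


kappa = jnp.array([3.0, 3.0])
W_sep = separable_weights(kappa, 4)  # one nonzero entry per row
W_rand = random_weights(jax.random.PRNGKey(0), kappa, 20000)
rms = jnp.sqrt(jnp.mean(jnp.sum(W_rand**2, axis=1)))  # close to 3.0
feats = fourier_features(W_sep, jnp.array([0.1, -0.4]))  # shape (8,)
```

The separable layout keeps each feature's frequency tied to a single direction, which is why kappa there is read per direction; the dense layout mixes directions, so only the overall wave-vector scale is controlled.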


Methods

  • __init__(s1, epsilon, kappa[, width_size, ...]) – Initializes the Stage 2 PINN model.

  • compute_s2(*args) – Compute only this stage's contribution to the output.

  • get_param(key[, default]) – Return self.params[key] if it exists and is not None, else default.

  • print_frozen_params() – Print the frozen parameters of this network.

  • print_params() – Print the params of this network.
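The documented get_param lookup, together with the transformed-coordinate convention for params (e.g. "log_lambda_2": jnp.log(0.5) standing for lambda_2 = 0.5), can be mirrored on a plain dict. This is a sketch, not the class's implementation: get_param below only restates the documented behavior, and decode_param is a hypothetical helper whose "log_" prefix handling is an assumption, since this page does not say how the library decodes transformed keys.

```python
import jax.numpy as jnp


def get_param(params, key, default=None):
    """Mirror of the documented Stage2.get_param semantics on a plain dict:
    return params[key] if it exists and is not None, else default."""
    if params is None:  # Stage2 accepts params=None
        return default
    value = params.get(key)
    return default if value is None else value


def decode_param(params, name, default=None):
    """Hypothetical helper (not part of the documented API): read `name`
    directly, or recover it from a log-transformed entry "log_<name>" via exp."""
    direct = get_param(params, name)
    if direct is not None:
        return direct
    log_value = get_param(params, "log_" + name)
    return default if log_value is None else jnp.exp(log_value)


params = {"lambda_1": jnp.array([0.0]), "log_lambda_2": jnp.log(0.5)}
lambda_2 = decode_param(params, "lambda_2")  # exp(log(0.5)) == 0.5
```

Storing a positive coefficient as its logarithm keeps it positive during optimization, which is why such entries have to be initialized in the transformed (log) coordinate.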

Attributes

  • epsilon – Estimated magnitude scale of this stage's output.

  • in_size – Number of input dimensions.

  • kappa – Estimated dominant frequency of this stage's output.

  • lb – Lower bound of the input coordinates.

  • out_size – Number of output dimensions.

  • params – Params for this stage.

  • s1 – The previous-stage network (the frozen Stage 1 model).

  • ub – Upper bound of the input coordinates.
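kappa is specified in normalized coordinates, while lb and ub record the physical input bounds. Assuming the common PINN convention of an affine map from [lb, ub] to [-1, 1] (the normalization actually used by multistage is not stated on this page), a physical angular frequency converts to a normalized one as sketched below. Both helpers are hypothetical. Writing x = lb + (s + 1)(ub - lb)/2 gives omega * x = omega * (ub - lb)/2 * s + const, so the normalized frequency is omega * (ub - lb)/2.

```python
import jax.numpy as jnp


def normalize(x, lb, ub):
    """Assumed convention: affine map from the box [lb, ub] to [-1, 1]."""
    return 2.0 * (x - lb) / (ub - lb) - 1.0


def physical_to_normalized_frequency(omega, lb, ub):
    """A wave sin(omega * x) in physical units has angular frequency
    omega * (ub - lb) / 2 in the normalized coordinate s = normalize(x)."""
    return omega * (ub - lb) / 2.0


lb = jnp.array([0.0, -1.0])
ub = jnp.array([4.0, 1.0])
# A candidate kappa (shape (in_size,)) for a wave with physical frequency pi in each direction.
kappa = physical_to_normalized_frequency(jnp.array([jnp.pi, jnp.pi]), lb, ub)
```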