multistage.Stage2

class multistage.Stage2(s1, epsilon, kappa, width_size=20, depth=4, activation=<PjitFunction of <function tanh>>, params=None, params_are_trainable=False, key=None, *, chebyshev=False, **kwargs)

Initializes the Stage 2 PINN model.

Examples

  • See tests/test_burgers.py.

Parameters:
  • s1 (Stage1) – The frozen model from stage 1.

  • epsilon (float) – Approximate magnitude of this stage's output.

  • kappa (jax.Array) – Approximate dominant frequency of this stage's output along each input dimension. Shape: (s1._mlp.in_size,).

  • width_size (int) – Size of each hidden layer.

  • depth (int) – The number of hidden layers, including the output layer.

  • activation (callable) – The activation function for each hidden layer after the first. Default is jnp.tanh.

  • params (dict[str, jax.Array]) – Dictionary of parameters to learn, mapped to their initial guesses. E.g. for Burgers:

        {
            "lambda_1": jax.random.normal(l1_key, (1,)) * 0.1,
            "log_lambda_2": -6.0 + jax.random.normal(l2_key, (1,)) * 0.1,
        }

  • params_are_trainable (bool) – Whether the values in params are optimized during training (True) or kept frozen (False). Default is False.

  • key (jax.Array) – JAX PRNG key (e.g. from jax.random.PRNGKey) for reproducibility.

  • chebyshev (bool) – Whether the frequency kappa is associated with a Chebyshev feature mapping instead of Fourier. Default is False.

  • kwargs (dict) – Keyword arguments to equinox.nn.MLP.
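The parameter descriptions leave open how epsilon and kappa are chosen. The sketch below assumes one plausible heuristic: epsilon is the root-mean-square of the stage-1 residual, and kappa is the angular frequency of that residual's strongest non-constant Fourier mode. The helper estimate_epsilon_kappa is hypothetical (not part of multistage). It handles a single input dimension sampled on a uniform grid and uses NumPy; jax.numpy provides the same functions.

```python
import numpy as np


def estimate_epsilon_kappa(x, residual):
    """Hypothetical helper: estimate (epsilon, kappa) from a stage-1 residual.

    Assumptions (not taken from multistage):
      * x is a uniformly spaced 1D grid,
      * epsilon is the RMS magnitude of the residual,
      * kappa is an angular frequency (radians per unit of x).
    """
    x = np.asarray(x, dtype=float)
    r = np.asarray(residual, dtype=float)

    # Magnitude scale: root-mean-square of the residual.
    epsilon = float(np.sqrt(np.mean(r ** 2)))

    # Dominant frequency: peak of the one-sided amplitude spectrum.
    dx = x[1] - x[0]
    spectrum = np.abs(np.fft.rfft(r - r.mean()))
    freqs = np.fft.rfftfreq(r.size, d=dx)  # cycles per unit of x
    dominant = freqs[np.argmax(spectrum[1:]) + 1]  # skip the zero-frequency bin

    kappa = 2.0 * np.pi * dominant  # convert cycles to radians (assumed convention)
    return epsilon, kappa


# Usage: a small residual oscillating 8 times over [0, 1).
x = np.linspace(0.0, 1.0, 256, endpoint=False)
residual = 0.01 * np.sin(2.0 * np.pi * 8.0 * x)
epsilon, kappa = estimate_epsilon_kappa(x, residual)
```

For a multi-dimensional input, the same estimate would be repeated along each axis to build a kappa array of shape (s1._mlp.in_size,).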


Methods

__init__(s1, epsilon, kappa[, width_size, ...])

compute_s2(*args)

Compute only this stage's contribution to the output.
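To illustrate what "this stage's contribution" means, the sketch below assumes the full prediction is the frozen stage-1 output plus epsilon times an O(1) stage-2 output, so the stage-2 contribution is the difference between the two. The functions stage1, stage2_normalized, full_output and s2_contribution are hypothetical stand-ins. Whether compute_s2 applies the epsilon scaling itself is not stated in this documentation.

```python
import numpy as np


def stage1(x):
    # Hypothetical frozen stage-1 model: the smooth, large-scale part of the solution.
    return np.sin(np.pi * x)


def stage2_normalized(x, kappa):
    # Hypothetical stage-2 network output: assumed O(1), oscillating at roughly kappa.
    return np.sin(kappa * x)


def full_output(x, epsilon, kappa):
    # Assumed composition: stage 2 is rescaled by epsilon and added to stage 1.
    return stage1(x) + epsilon * stage2_normalized(x, kappa)


def s2_contribution(x, epsilon, kappa):
    # Analogue of compute_s2 under the assumption above: total minus stage 1.
    return full_output(x, epsilon, kappa) - stage1(x)
```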

get_param(key[, default])

Return self.params[key] if it exists and is not None, else default.
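The lookup rule above can be written as a standalone function. This is a sketch, not the method's source: it assumes params may itself be None (the constructor default) and treats a missing key and an explicit None value the same way.

```python
from typing import Any, Mapping, Optional


def get_param(params: Optional[Mapping[str, Any]], key: str, default: Any = None) -> Any:
    """Standalone sketch of the lookup described for Stage2.get_param."""
    if params is None:
        # No params dictionary was supplied (params=None in the constructor).
        return default
    value = params.get(key)
    # A missing key and an explicit None value both fall back to the default.
    return default if value is None else value
```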

print_frozen_params()

Print the frozen parameters of this network.

print_params()

Print the params of this network.

Attributes

epsilon

Estimated magnitude scale of output for this stage.

in_size

Number of input dimensions.

kappa

Estimated dominant frequency of output for this stage.

lb

Lower bound of input coordinates.

out_size

Number of output dimensions.

params

Params for this stage.

s1

The frozen Stage 1 network that this stage builds on.

ub

Upper bound of input coordinates.
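lb and ub bound the input domain. A common PINN practice, assumed here and not confirmed for Stage2, is to rescale each coordinate to [-1, 1] before the first layer. That interval is also the standard domain of Chebyshev polynomials, which is relevant when chebyshev=True. normalize_inputs is a hypothetical helper.

```python
import numpy as np


def normalize_inputs(x, lb, ub):
    """Hypothetical helper: map coordinates from [lb, ub] to [-1, 1] per input dimension.

    x has shape (n_points, in_size); lb and ub have shape (in_size,).
    """
    x = np.asarray(x, dtype=float)
    lb = np.asarray(lb, dtype=float)
    ub = np.asarray(ub, dtype=float)
    # Affine map: lb -> -1, ub -> +1, broadcast over all points.
    return 2.0 * (x - lb) / (ub - lb) - 1.0


# Usage: a Burgers-style domain with x in [-1, 1] and t in [0, 1].
points = np.array([[-1.0, 0.0], [1.0, 1.0], [0.0, 0.5]])
scaled = normalize_inputs(points, lb=[-1.0, 0.0], ub=[1.0, 1.0])
```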