multistage.Stage1

class multistage.Stage1(lb, ub, in_size, out_size, width_size=20, depth=4, activation=jax.numpy.tanh, params=None, params_are_trainable=False, key=None, **kwargs)

First-stage physics-informed neural network (PINN) solver.

Examples

  • See tests/test_burgers.py.

Parameters:
  • lb (jax.Array) – Lower bounds of the domain [x1_min, …, x_i_min, …, x_n_min].

  • ub (jax.Array) – Upper bounds of the domain [x1_max, …, x_i_max, …, x_n_max].

  • in_size (int) – Number of input dimensions. The network is called with in_size arguments, one per input coordinate.

  • out_size (int) – Number of output dimensions. The output has shape (out_size,).

  • width_size (int) – Size of each hidden layer.

  • depth (int) – The number of hidden layers, including the output layer.

  • activation (callable) – The activation function after each hidden layer. Default is jnp.tanh.

  • params (dict[str, jax.Array]) –

    Dictionary mapping parameter names to their initial values (initial guesses, if they are learned). E.g. for Burgers:

        {
            "lambda_1": jax.random.normal(l1_key, (1,)) * 0.1,
            "log_lambda_2": -6.0 + jax.random.normal(l2_key, (1,)) * 0.1,
        }

  • params_are_trainable (bool) – Whether the values in params are optimized during training (True) or kept frozen (False). Default is False (frozen).

  • key (jax.Array) – JAX random key for reproducibility, e.g. jax.random.PRNGKey(0).

  • kwargs (dict) – Additional keyword arguments passed to equinox.nn.MLP.
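To show how lb, ub, in_size, out_size, width_size, depth and activation fit together, here is a hypothetical plain-JAX stand-in for the network. It is not the Stage1 implementation. It assumes inputs are rescaled linearly from [lb, ub] to [-1, 1] (a common PINN convention), follows the equinox.nn.MLP convention that depth counts hidden layers including the output layer (so depth=4 gives 5 linear layers), and uses an illustrative Burgers domain x in [-1, 1], t in [0, 1]. The helper names scale_inputs, init_mlp and forward are made up for this example, as is the Glorot-style weight initialization.

```python
import jax
import jax.numpy as jnp


def scale_inputs(x, lb, ub):
    # Map each coordinate linearly from [lb, ub] to [-1, 1] (assumed convention).
    return 2.0 * (x - lb) / (ub - lb) - 1.0


def init_mlp(key, in_size, out_size, width_size=20, depth=4):
    # depth hidden layers (counting the output layer, as in equinox.nn.MLP)
    # gives depth + 1 weight matrices in total.
    sizes = [in_size] + [width_size] * depth + [out_size]
    keys = jax.random.split(key, len(sizes) - 1)
    layers = []
    for k, n_in, n_out in zip(keys, sizes[:-1], sizes[1:]):
        # Glorot-style normal scaling; the real initializer may differ.
        w = jax.random.normal(k, (n_out, n_in)) * jnp.sqrt(2.0 / (n_in + n_out))
        layers.append((w, jnp.zeros(n_out)))
    return layers


def forward(layers, lb, ub, *coords, activation=jnp.tanh):
    # The network takes in_size scalar arguments, e.g. forward(..., x, t).
    z = scale_inputs(jnp.stack([jnp.asarray(c) for c in coords]), lb, ub)
    for w, b in layers[:-1]:
        z = activation(w @ z + b)  # activation after each hidden layer
    w, b = layers[-1]
    return w @ z + b  # no activation on the output; shape (out_size,)


key = jax.random.PRNGKey(0)
lb = jnp.array([-1.0, 0.0])  # illustrative Burgers domain: x in [-1, 1]
ub = jnp.array([1.0, 1.0])   # and t in [0, 1]
layers = init_mlp(key, in_size=2, out_size=1)

u = forward(layers, lb, ub, 0.5, 0.25)
print(u.shape)  # (1,)

# Evaluate on a batch of collocation points by vmapping over the scalar arguments.
xs = jnp.linspace(-1.0, 1.0, 8)
ts = jnp.linspace(0.0, 1.0, 8)
batch = jax.vmap(lambda x, t: forward(layers, lb, ub, x, t))(xs, ts)
print(batch.shape)  # (8, 1)
```

Keeping the network a function of separate scalar arguments makes partial derivatives easy to take, e.g. jax.grad(lambda x, t: forward(layers, lb, ub, x, t)[0], argnums=0) for du/dx, which is what a PINN residual needs.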


Methods

  • __init__(lb, ub, in_size, out_size[, ...]) – Construct the stage; see the parameters above.

  • get_param(key[, default]) – Return self.params[key] if it exists and is not None, else default.

  • print_frozen_params() – Print the frozen parameters of this network.

  • print_params() – Print the params of this network.
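The lookup rule documented for get_param can be mirrored as a standalone function. This is a sketch of the documented behaviour only, not the method's source: it assumes default falls back to None when omitted and, since params itself defaults to None in the constructor, treats a missing dictionary like a missing key. The function name and the sample dictionary are illustrative.

```python
from typing import Any, Mapping, Optional


def get_param(params: Optional[Mapping[str, Any]], key: str, default: Any = None) -> Any:
    # No params dictionary at all (e.g. params=None in the constructor).
    if params is None:
        return default
    value = params.get(key)
    # A key that is absent and a key explicitly set to None both yield default.
    return default if value is None else value


params = {"lambda_1": 0.1, "log_lambda_2": None}
print(get_param(params, "lambda_1"))            # 0.1
print(get_param(params, "log_lambda_2", -6.0))  # -6.0 (stored value is None)
print(get_param(params, "missing", 1.0))        # 1.0
print(get_param(None, "lambda_1", 0.0))         # 0.0
```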

Attributes

  • epsilon – Estimated magnitude scale of output.

  • in_size – Number of input dimensions.

  • kappa – Estimated dominant frequency of output.

  • lb – Lower bound of input coordinates.

  • out_size – Number of output dimensions.

  • params – Params for this stage.

  • ub – Upper bound of input coordinates.
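The attributes epsilon and kappa are documented only as estimates of the output's magnitude scale and dominant frequency. One plausible way to obtain such estimates from samples on a uniform 1D grid is sketched below: epsilon as the root-mean-square value and kappa as the frequency, in cycles per unit of the input coordinate, of the largest non-constant FFT component. This is an assumption for illustration; the formulas Stage1 actually uses (including whether kappa is an angular frequency or further scaled) are not stated here, and estimate_scale_and_frequency is a hypothetical name.

```python
import jax.numpy as jnp


def estimate_scale_and_frequency(values, spacing):
    # values: 1D samples on a uniform grid with the given spacing.
    values = jnp.asarray(values)
    # Magnitude scale: root-mean-square of the samples.
    epsilon = jnp.sqrt(jnp.mean(values ** 2))
    # Frequency content of the real signal and the matching frequencies.
    spectrum = jnp.abs(jnp.fft.rfft(values))
    freqs = jnp.fft.rfftfreq(values.shape[0], d=spacing)
    # Skip index 0 (the mean / DC component) when picking the dominant frequency.
    kappa = freqs[1 + jnp.argmax(spectrum[1:])]
    return epsilon, kappa


n = 256
x = jnp.arange(n) / n
signal = 0.5 * jnp.sin(2 * jnp.pi * 3.0 * x)
eps, kappa = estimate_scale_and_frequency(signal, spacing=1.0 / n)
print(float(eps), float(kappa))  # approximately 0.3536 and 3.0
```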