multistage.multistage_train

multistage.multistage_train(net, residual_fun_s1, residual_fun_s2, loss_fun_s1, loss_fun_s2, x, training_samples, optimizer, steps, *, learning_rate=None, adaptive_sample_freq=1000, n_stages=2, width_size=20, depth=4, activation=jnp.tanh, num_samples_for_epsilon=(1024,), order=(1,), beta_fun=None, heuristic=0.9, chebyshev=False, x_stage2=None, training_samples_stage2=None, return_loss_history=True, print_every=100, key=None, net_kwargs_for_save=None, name='', checkpoint_dir='checkpoints', checkpoint_every=5000, benchmark_state=None, **adaptive_sample_kwargs)

Multi-stage training: train net in a first stage, then add sub-networks in later stages that are trained to fit the residual error of the preceding stages.

Examples

  • See tests/test_burgers.py.
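A minimal usage sketch for a 1D toy problem. The residual, loss, and data below are illustrative assumptions (the library does not prescribe a particular PDE); only the final commented call follows the signature documented above:

```python
import jax
import jax.numpy as jnp

# Illustrative stage-1 residual for the toy ODE u''(x) + sin(x) = 0
# (an assumed problem, chosen only to show the expected callable shape).
def residual_fun_s1(net, x):
    u_xx = jax.vmap(jax.grad(jax.grad(net)))(x)
    return u_xx + jnp.sin(x)

# Scalar loss: mean squared residual plus a data-misfit term (assumed form).
def loss_fun_s1(net, x, samples):
    r = residual_fun_s1(net, x)
    return jnp.mean(r**2) + jnp.mean((jax.vmap(net)(x) - samples) ** 2)

key = jax.random.PRNGKey(0)
x = jnp.linspace(0.0, 2.0 * jnp.pi, 128)
training_samples = jnp.sin(x)  # assumed target data

# With a real eqx.Module `net`, stage-2 counterparts of the functions
# above, and an optax optimizer, the call would look like:
# net, loss_histories = multistage.multistage_train(
#     net, residual_fun_s1, residual_fun_s2, loss_fun_s1, loss_fun_s2,
#     (x,), training_samples, optax.adam(1e-3), steps=10_000,
#     n_stages=2, key=key,
# )
```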

Parameters:
  • net (eqx.Module) – The initial model architecture.

  • residual_fun_s1 (callable) – Function to compute PDE residuals for the first stage.

  • residual_fun_s2 (callable) – Function to compute PDE residuals for subsequent stages.

  • loss_fun_s1 (callable) – Scalar output loss function for the first stage.

  • loss_fun_s2 (callable) – Scalar output loss function for subsequent stages.

  • x (tuple or list of jax.Array) – Input coordinates for the first stage.

  • training_samples (jax.Array) – Target values for the first stage.

  • optimizer (optax.GradientTransformation) – Optimizer for training loops.

  • steps (int) – Number of training steps per stage.

  • learning_rate (float, optional) – Learning rate passed to the optimizer.

  • adaptive_sample_freq (int, optional) – Frequency of adaptive sampling during training.

  • n_stages (int, optional) – Total number of training stages. Default is 2.

  • width_size (int, optional) – Width of the sub-networks added in later stages.

  • depth (int, optional) – Depth of the sub-networks added in later stages.

  • activation (callable, optional) – Activation function for new stages. Default is jnp.tanh.

  • num_samples_for_epsilon (tuple, optional) – Number of samples used to estimate error statistics between stages.

  • order (tuple, optional) – Order of error estimation.

  • beta_fun (callable, optional) – Function defining the beta distribution for error bounds.

  • heuristic (float, optional) – Heuristic multiplier for error estimation. Default is 0.9.

  • chebyshev (bool, optional) – Whether to use Chebyshev feature mapping instead of Fourier. If True, heuristic is ignored.

  • x_stage2 (tuple of jax.Array, optional) – Input coordinates for stage 2 and beyond. Default is x.

  • training_samples_stage2 (jax.Array, optional) – Training data for stage 2 and beyond. Default is training_samples.

  • return_loss_history (bool, optional) – If True, returns loss histories for all stages.

  • print_every (int, optional) – Logging frequency.

  • key (jax.random.PRNGKey, optional) – Random key for initialization and sampling.

  • net_kwargs_for_save (dict, optional) – Additional metadata to save with the model.

  • name (str, optional) – Base name for saving models and checkpoints.

  • checkpoint_dir (str, optional) – Directory to store stage-specific checkpoints.

  • checkpoint_every (int, optional) – Frequency of checkpointing within stages.

  • benchmark_state (callable, optional) – Callback for external benchmarking or logging. Signature: benchmark_state(net, stage, name, step=step).
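The benchmark_state callback can be any callable matching the documented signature. A minimal illustrative example (the body is an assumption; a real callback might evaluate net against a reference solution):

```python
def my_benchmark(net, stage, name, step=0):
    # Called periodically during training with the current model, the
    # stage index, the run name, and the current step.
    print(f"[{name}] stage {stage}, step {step}")
```

It would be passed as benchmark_state=my_benchmark.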

Returns:

  • net (eqx.Module) – The final trained multi-stage model.

  • loss_histories (list of list of float, optional) – A list containing the loss history for each stage (if return_loss_history is True).
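Conceptually, each later stage fits a new sub-network to the error left by the stages before it, scaled by an estimated error magnitude (the statistics gathered via num_samples_for_epsilon), and the final model sums the stages. A hand-rolled sketch of that composition idea (all names here are illustrative, not the library's internals):

```python
import jax.numpy as jnp

def compose_stages(stage_nets, epsilons, x):
    """Sum stage predictions: u(x) ~ u1(x) + eps1*u2(x) + eps2*u3(x) + ...
    where each eps is the estimated error magnitude after the preceding
    stages (a sketch of the idea, not the library's actual code)."""
    u = stage_nets[0](x)
    for net, eps in zip(stage_nets[1:], epsilons):
        u = u + eps * net(x)
    return u
```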