multistage.multistage_train
- multistage.multistage_train(net, residual_fun_s1, residual_fun_s2, loss_fun_s1, loss_fun_s2, x, training_samples, optimizer, steps, *, learning_rate=None, adaptive_sample_freq=1000, n_stages=2, width_size=20, depth=4, activation=<PjitFunction of <function tanh>>, num_samples_for_epsilon=(1024, ), order=(1, ), beta_fun=None, heuristic=0.9, chebyshev=False, x_stage2=None, training_samples_stage2=None, return_loss_history=True, print_every=100, key=None, net_kwargs_for_save=None, name='', checkpoint_dir='checkpoints', checkpoint_every=5000, benchmark_state=None, **adaptive_sample_kwargs)
Multi-stage training.
Examples
See tests/test_burgers.py.
- Parameters:
net (eqx.Module) – The initial model architecture.
residual_fun_s1 (callable) – Function to compute PDE residuals for the first stage.
residual_fun_s2 (callable) – Function to compute PDE residuals for subsequent stages.
loss_fun_s1 (callable) – Scalar output loss function for the first stage.
loss_fun_s2 (callable) – Scalar output loss function for subsequent stages.
x (tuple or list of jax.Array) – Input coordinates for the first stage.
training_samples (jax.Array) – Target values for the first stage.
optimizer (optax.GradientTransformation) – Optimizer for training loops.
steps (int) – Number of training steps per stage.
learning_rate (float, optional) – Learning rate passed to the optimizer.
adaptive_sample_freq (int, optional) – Frequency of adaptive sampling during training.
n_stages (int, optional) – Total number of training stages. Default is 2.
width_size (int, optional) – Width of the sub-networks added in later stages.
depth (int, optional) – Depth of the sub-networks added in later stages.
activation (callable, optional) – Activation function for new stages. Default is jnp.tanh.
num_samples_for_epsilon (tuple, optional) – Number of samples used to estimate error statistics between stages.
order (tuple, optional) – Order of error estimation.
beta_fun (callable, optional) – Function defining the beta distribution for error bounds.
heuristic (float, optional) – Heuristic multiplier for error estimation. Default is 0.9.
chebyshev (bool, optional) – Whether to use Chebyshev feature mapping instead of Fourier. If True, heuristic is ignored.
x_stage2 (tuple of jax.Array, optional) – Input coordinates for stage 2 and beyond. Default is x.
training_samples_stage2 (jax.Array, optional) – Training data for stage 2 and beyond. Default is training_samples.
return_loss_history (bool, optional) – If True, returns loss histories for all stages.
print_every (int, optional) – Logging frequency.
key (jax.random.PRNGKey, optional) – Random key for initialization and sampling.
net_kwargs_for_save (dict, optional) – Additional metadata to save with the model.
name (str, optional) – Base name for saving models and checkpoints.
checkpoint_dir (str, optional) – Directory to store stage-specific checkpoints.
checkpoint_every (int, optional) – Frequency of checkpointing within stages.
benchmark_state (callable, optional) – Callback for external benchmarking or logging. Signature: benchmark_state(net, stage, name, step=step).
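A minimal sketch of a callback compatible with the documented benchmark_state(net, stage, name, step=step) signature. The factory name make_benchmark_logger and the recorded fields are hypothetical; when the callback is invoked, and whether stage is 0-based or 1-based, is not stated in this page and is left open here.

```python
def make_benchmark_logger():
    """Build a callback with the documented signature and a list it appends to.

    Hypothetical helper for illustration; not part of the library.
    """
    records = []

    def benchmark_state(net, stage, name, step=None):
        # Store only lightweight metadata; the model itself is not copied.
        records.append(
            {
                "stage": stage,
                "name": name,
                "step": step,
                "net_type": type(net).__name__,
            }
        )

    return benchmark_state, records


# Stand-in invocation with a placeholder object instead of a real model.
callback, records = make_benchmark_logger()
callback(object(), 1, "burgers", step=5000)
```

The callback returned by the factory could then be passed as benchmark_state=callback, and records inspected after training.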
- Returns:
net (eqx.Module) – The final trained multi-stage model.
loss_histories (list of list of float, optional) – The loss history for each stage, returned only if return_loss_history is True.
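As an illustration of working with the documented loss_histories structure (one list of floats per stage), the sketch below summarizes each stage. The helper summarize_loss_histories and the sample values are hypothetical and made up for the example; they are not produced by the library.

```python
def summarize_loss_histories(loss_histories):
    """Return one summary dict per stage from a list of per-stage loss lists."""
    summary = []
    for stage_index, history in enumerate(loss_histories):
        # Convert defensively in case entries are array scalars rather than floats.
        values = [float(v) for v in history]
        if not values:
            continue  # skip stages that logged nothing
        summary.append(
            {
                "stage_index": stage_index,  # position in the list, 0-based
                "entries": len(values),
                "final": values[-1],
                "best": min(values),
            }
        )
    return summary


# Made-up loss values for a two-stage run.
example_histories = [[1.0, 0.1, 0.01], [1e-3, 1e-5]]
summary = summarize_loss_histories(example_histories)
```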