API Documentation

Multistage networks

multistage.Stage1(lb, ub, in_size, out_size)

First-stage PINN solver.

multistage.Stage2(s1, epsilon, kappa[, ...])

Initializes the Stage 2 PINN model.

multistage.multistage_train(net, ...[, ...])

Multi-stage training.

multistage.multistage_trust_region_train(...)

Multi-stage training using trust-region-based optimization.
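The `Stage2(s1, epsilon, kappa)` signature suggests that the second stage builds on a trained first stage. The sketch below is a minimal illustration of that idea. It assumes the multi-stage neural network scheme, in which stage 2 fits the stage-1 residual normalized by its magnitude `epsilon`, and `kappa` raises the frequency content of the second-stage inputs. Both networks are replaced by least-squares fits over a Fourier basis in NumPy. `features`, `fit_stage`, and the target `u` are hypothetical and do not describe how `multistage` trains its PINNs.

```python
import numpy as np

def features(x, n_freq, kappa=1.0):
    """Hypothetical Fourier-feature basis; kappa scales every input frequency."""
    k = np.arange(1, n_freq + 1)
    z = kappa * np.pi * np.outer(x, k)                 # shape (len(x), n_freq)
    return np.hstack([np.ones((len(x), 1)), np.sin(z), np.cos(z)])

def fit_stage(x, target, n_freq, kappa=1.0):
    """Least-squares fit standing in for training one network stage."""
    coef, *_ = np.linalg.lstsq(features(x, n_freq, kappa), target, rcond=None)
    return lambda xq: features(xq, n_freq, kappa) @ coef

x = np.linspace(-1.0, 1.0, 400)
u = np.sin(np.pi * x) + 0.01 * np.sin(25 * np.pi * x)  # smooth part + small high-frequency part

# Stage 1: coarse fit using low frequencies only.
stage1 = fit_stage(x, u, n_freq=3)
residual = u - stage1(x)

# Stage 2: fit the residual rescaled to unit RMS; kappa shifts the basis to higher frequencies.
epsilon = np.sqrt(np.mean(residual**2))
stage2 = fit_stage(x, residual / epsilon, n_freq=6, kappa=5.0)

def combined(xq):
    """Final prediction: stage-1 output plus the rescaled stage-2 correction."""
    return stage1(xq) + epsilon * stage2(xq)

rms = lambda e: np.sqrt(np.mean(e**2))
print(f"RMS error: stage 1 {rms(u - stage1(x)):.2e}, stages 1+2 {rms(u - combined(x)):.2e}")
```

Stage 2 solves a least-squares problem on the normalized residual, so the combined RMS error on the fitting points can never exceed the stage-1 error. Dividing by `epsilon` keeps the second-stage target at order one instead of at the scale of the remaining error.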

Plotting

multistage.plot_2d_residual(net, residual_fun)

Plots the signed PDE residual on a uniform 2D grid with a linear scale.

multistage.plot_2d_solution(net[, ...])

Plots solution, truth, and error on a uniform 2D grid.

multistage.plot_loss(checkpoint_path, figname)

Plots the loss history from the specified checkpoint path.
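The two 2D plotting entries evaluate fields on a uniform grid. The NumPy sketch below shows what those quantities could look like. It assumes a Poisson problem Δu = f on the unit square and uses closed-form functions in place of a trained `net`. It approximates the Laplacian with central finite differences, whereas a PINN library would more likely differentiate the network with automatic differentiation. `uniform_grid`, `poisson_residual`, and `plot_fields` are hypothetical helpers and do not match the signatures of `plot_2d_residual` or `plot_2d_solution`.

```python
import numpy as np

def uniform_grid(lb, ub, n=101):
    """Uniform n-by-n grid on [lb[0], ub[0]] x [lb[1], ub[1]], plus the grid spacings."""
    x = np.linspace(lb[0], ub[0], n)
    y = np.linspace(lb[1], ub[1], n)
    X, Y = np.meshgrid(x, y, indexing="ij")
    return X, Y, x[1] - x[0], y[1] - y[0]

def poisson_residual(U, F, hx, hy):
    """Signed residual (Laplacian of U) - F at interior points, via central differences."""
    lap = ((U[2:, 1:-1] - 2 * U[1:-1, 1:-1] + U[:-2, 1:-1]) / hx**2
           + (U[1:-1, 2:] - 2 * U[1:-1, 1:-1] + U[1:-1, :-2]) / hy**2)
    return lap - F[1:-1, 1:-1]

def plot_fields(X, Y, fields, figname):
    """One panel per field, each on a linear color scale centered at zero."""
    import matplotlib.pyplot as plt  # lazy import: the numerical part needs only NumPy
    fig, axes = plt.subplots(1, len(fields), figsize=(4 * len(fields), 3.5), squeeze=False)
    for ax, (title, Z) in zip(axes[0], fields.items()):
        vmax = np.abs(Z).max()  # symmetric limits make the sign of residuals/errors readable
        mesh = ax.pcolormesh(X, Y, Z, shading="auto", cmap="RdBu_r", vmin=-vmax, vmax=vmax)
        ax.set_title(title)
        fig.colorbar(mesh, ax=ax)
    fig.tight_layout()
    fig.savefig(figname)
    plt.close(fig)

# Stand-in for a trained network: the exact solution of Δu = -2π² sin(πx) sin(πy)
# plus a small cubic perturbation whose Laplacian is 0.06 (x - 0.5).
truth = lambda x, y: np.sin(np.pi * x) * np.sin(np.pi * y)
pred = lambda x, y: truth(x, y) + 0.01 * (x - 0.5) ** 3
source = lambda x, y: -2 * np.pi**2 * truth(x, y)

X, Y, hx, hy = uniform_grid((0.0, 0.0), (1.0, 1.0))
U, T = pred(X, Y), truth(X, Y)
fields = {"solution": U, "truth": T, "error": U - T}
R = poisson_residual(U, source(X, Y), hx, hy)
print("max |error|:", np.abs(fields["error"]).max(), "max |residual|:", np.abs(R).max())
# plot_fields(X, Y, fields, "solution.png")
# plot_fields(X[1:-1, 1:-1], Y[1:-1, 1:-1], {"residual": R}, "residual.png")
```

Uncommenting the last two lines writes the figures, which requires Matplotlib. The residual is defined only at interior points, so it is plotted on the trimmed grid.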

IO

multistage.load(filename, model_constructor, ...)

Loads the model from a file containing a Pytree skeleton and binary data.

multistage.save(filename, model, **kwargs)

Saves the model weights and configuration to a single file.