desc.integrals.Bounce2D.batch

static Bounce2D.batch(fun, data, grid, *, angle, names=(), custom_data=None, flux_data=None, batch_size=1, sparse=True, shard_input_data=False)

Compute the function fun in batches over flux surfaces.

You may also want to JIT-compile the code that calls this utility method.
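As a minimal sketch of what JIT-compiling the calling code looks like, the snippet below wraps a caller in jax.jit. The function surface_average is a hypothetical stand-in for user code that would call Bounce2D.batch; it does not use DESC and only shows the compilation pattern.

```python
import jax
import jax.numpy as jnp


def surface_average(f):
    # Hypothetical stand-in for code that would call Bounce2D.batch.
    # It averages over the last two axes, leaving one value per surface.
    return jnp.mean(f, axis=(-2, -1))


# Compile the caller once; later calls with the same shapes and dtypes
# reuse the compiled version.
surface_average_jit = jax.jit(surface_average)

f = jnp.arange(24.0).reshape(2, 4, 3)  # shape (num rho, X, Y)
result = surface_average_jit(f)
```

The compiled function returns the same values as the uncompiled one; only the execution path changes.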

Examples

  • desc/compute/_fast_ion.py

  • desc/compute/_neoclassical.py

  • desc/compute/_turbulence.py

Parameters:
  • fun (callable) – A function that takes a single argument, fun_data, and computes bounce integrals. It may assume that fun_data holds every quantity required to construct a Bounce2D operator and to call its methods.

  • data (dict[str, jnp.ndarray]) – Data dictionary with the same structure as the data returned by the functions in desc.compute. Must contain the quantities in Bounce2D.required_names, min_tz |B|, max_tz |B|, and any entries requested by names.

  • grid (Grid) – Grid on which data was computed.

  • angle (jnp.ndarray) – Shape (num rho, X, Y). Angle returned by Bounce2D.angle.

  • names (tuple[str]) – Optional. Names of quantities in data that are not constant on each flux surface. These will be Fourier transformed (FFT) and passed to fun in batches.

  • custom_data (dict[str, jnp.ndarray]) – Optional. Other data that is not constant on each flux surface. These will be Fourier transformed (FFT) and passed to fun in batches.

  • flux_data (dict[str, jnp.ndarray]) – Optional. Other data that is constant on each flux surface. Each entry is passed to fun as a scalar for each surface. All arrays must have dimension one.

  • batch_size (int or None) – Number of flux surfaces to compute simultaneously. Default is 1.

  • sparse (bool) – Whether to use sparsity-preserving pullbacks. Default is True, which makes the most sense if the output has shape (num rho, ). If the output is larger and the final objective of interest is a lower-dimensional quantity than the output, it may be preferable to delay the vector-Jacobian product (vjp) by setting this to False.

  • shard_input_data (bool) – Whether to shard batched input data across devices before applying chunked batching. Default is False.
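The role of batch_size can be pictured with a small NumPy sketch. This is an illustration under stated assumptions, not DESC's implementation: batch_over_surfaces and surface_mean are hypothetical names, every array is assumed to carry the flux-surface index on its leading axis, and batch_size=None is assumed to mean "all surfaces at once". The real method also handles the Fourier transforms and JAX pullbacks described above.

```python
import numpy as np


def batch_over_surfaces(fun, fun_data, batch_size=1):
    """Apply ``fun`` to chunks of ``batch_size`` surfaces and concatenate.

    Hypothetical helper for illustration only. Every array in ``fun_data``
    is assumed to have the flux-surface index as its leading axis.
    """
    num_rho = next(iter(fun_data.values())).shape[0]
    if batch_size is None:
        # Assumption: None means "process all surfaces in one batch".
        batch_size = num_rho
    outputs = []
    for start in range(0, num_rho, batch_size):
        stop = start + batch_size
        chunk = {key: val[start:stop] for key, val in fun_data.items()}
        outputs.append(fun(chunk))
    return np.concatenate(outputs, axis=0)


def surface_mean(chunk):
    # Stand-in for a bounce-integral computation: average over the
    # non-surface axes, then scale by a per-surface constant.
    return chunk["f"].mean(axis=(1, 2)) * chunk["scale"]


num_rho, X, Y = 5, 4, 3
fun_data = {
    # Varies on each surface, shape (num rho, X, Y).
    "f": np.arange(num_rho * X * Y, dtype=float).reshape(num_rho, X, Y),
    # Constant on each surface, shape (num rho,).
    "scale": np.linspace(1.0, 2.0, num_rho),
}

batched = batch_over_surfaces(surface_mean, fun_data, batch_size=2)
unbatched = surface_mean(fun_data)
```

In this sketch, batching changes only how many surfaces are held in memory at once. The concatenated result matches the unbatched call, so a larger batch_size trades memory for fewer passes through fun.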

Returns:

The output fun(fun_data).