desc.integrals.Bounce2D.batch
- static Bounce2D.batch(fun, data, grid, *, angle, names=(), custom_data=None, flux_data=None, batch_size=1, sparse=True, shard_input_data=False)
Compute the function fun batched over flux surfaces. You may also want to JIT compile the code that calls this utility method.
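To illustrate what "batched over flux surfaces" means, here is a minimal NumPy sketch. It assumes every input array has the flux-surface index (num rho) as its leading axis, and that fun returns an array with that same leading axis. The helper map_in_batches and the toy function toy_fun are hypothetical illustrations, not part of DESC, and the sketch says nothing about how Bounce2D.batch is actually implemented.

```python
import numpy as np


def map_in_batches(fun, arrays, batch_size=1):
    """Apply ``fun`` to ``batch_size`` flux surfaces at a time (hypothetical helper).

    Assumes every array in ``arrays`` has the flux-surface index (num rho) as
    its leading axis and that ``fun`` returns an array with that leading axis.
    """
    num_rho = len(next(iter(arrays.values())))
    # Assumption for this sketch: ``None`` means "all surfaces in one batch".
    step = num_rho if batch_size is None else batch_size
    outputs = []
    for start in range(0, num_rho, step):
        # Slice out surfaces [start, start + step) from every input array.
        chunk = {key: value[start:start + step] for key, value in arrays.items()}
        outputs.append(fun(chunk))
    # Stitch the per-batch results back together along the surface axis.
    return np.concatenate(outputs, axis=0)


def toy_fun(fun_data):
    # Toy stand-in for a bounce-integral computation: mean of |B| per surface.
    return fun_data["|B|"].mean(axis=(1, 2))


B = np.arange(24.0).reshape(4, 2, 3)  # shape (num rho, X, Y)
print(map_in_batches(toy_fun, {"|B|": B}, batch_size=3))
```

A smaller batch_size lowers peak memory at the cost of more sequential work. The plain Python loop keeps the sketch readable. In real use the computation would be written in JAX, and the calling code could be wrapped in jax.jit.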
Examples
- desc/compute/_fast_ion.py
- desc/compute/_neoclassical.py
- desc/compute/_turbulence.py
- Parameters:
  - fun (callable) – A function which takes a single argument fun_data and computes bounce integrals, assuming fun_data holds all quantities required to construct a Bounce2D operator and to call its methods.
  - data (dict[str, jnp.ndarray]) – Data dictionary with the same structure as the data returned by the functions in desc.compute. Must contain the quantities in Bounce2D.required_names, "min_tz |B|", "max_tz |B|", and any entries requested by names.
  - grid (Grid) – Grid on which data was computed.
  - angle (jnp.ndarray) – Shape (num rho, X, Y). Angle returned by Bounce2D.angle.
  - names (tuple[str]) – Optional. Entries in data that are not constant on each flux surface. These will be FFT’d and passed to fun in batches.
  - custom_data (dict[str, jnp.ndarray]) – Optional. Other data that is not constant on each flux surface. These will be FFT’d and passed to fun in batches.
  - flux_data (dict[str, jnp.ndarray]) – Optional. Other data that is constant on each flux surface. These will be passed to fun as a scalar for each surface. All arrays must be one-dimensional.
  - batch_size (int or None) – Number of flux surfaces to compute simultaneously. Default is 1.
  - sparse (bool) – Whether to use sparsity-preserving pullbacks. Default is True, which makes the most sense if the output has shape (num rho,). Otherwise, if the output is larger and the final objective of interest is a lower-dimensional quantity than the output, it may be preferable to delay the vjp by setting this to False.
  - shard_input_data (bool) – Whether to shard batched input data across devices before applying chunked batching. Default is False.
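The parameters above distinguish inputs that vary over each flux surface (names, custom_data) from inputs that are constant on each surface (flux_data). The following minimal sketch shows how a per-batch fun_data dictionary might be assembled from these three kinds of input. It assumes, for illustration only, that the surface-varying arrays are already shaped (num rho, X, Y) and that flux_data arrays have shape (num rho,). The FFT that the real method applies to the names and custom_data entries is omitted. The helper surface_batch and the "iota" entry are hypothetical.

```python
import numpy as np


def surface_batch(start, stop, data, names=(), custom_data=None, flux_data=None):
    """Assemble a hypothetical ``fun_data`` dict for surfaces [start, stop).

    Assumes ``data`` and ``custom_data`` entries are shaped (num rho, X, Y)
    and ``flux_data`` entries are one-dimensional with shape (num rho,).
    The FFT that the real method applies to surface-varying data is omitted.
    """
    custom_data = custom_data or {}
    flux_data = flux_data or {}
    fun_data = {}
    for name in names:
        fun_data[name] = data[name][start:stop]  # varies over each surface
    for key, value in custom_data.items():
        fun_data[key] = value[start:stop]  # varies over each surface
    for key, value in flux_data.items():
        # Mirrors the documented rule that flux_data arrays are one-dimensional.
        if np.ndim(value) != 1:
            raise ValueError(f"flux_data[{key!r}] must be one-dimensional")
        fun_data[key] = value[start:stop]  # one scalar per surface
    return fun_data


num_rho, X, Y = 4, 2, 3
data = {"|B|": np.ones((num_rho, X, Y))}
batch = surface_batch(
    1, 3, data, names=("|B|",), flux_data={"iota": np.linspace(0.1, 0.4, num_rho)}
)
print(batch["|B|"].shape, batch["iota"].shape)  # (2, 2, 3) (2,)
```

With a batch of two surfaces, fun receives surface-varying arrays of shape (2, X, Y) and one flux-surface scalar per surface in the batch.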
- Returns:
  The output of fun(fun_data).