pjz 😴👕👖💀: Photonics on JAX

pjz combines JAX and fdtd-z into a set of tools for running photonic simulation and optimization workflows at scale.

Main API

pjz.scatter(epsilon: jax.Array, omega: jax.Array, ports: Sequence[Tuple[jax.Array, int]], field_kwargs: Dict)

Returns scattering between ports, differentiable w.r.t. epsilon.

Parameters:
  • epsilon – (3, xx, yy, zz) array of permittivity values.

  • omega – (ww,) array of angular frequencies.

  • ports – Sequence of (mode, pos) tuples that define both the excitation and output overlap operations for modes.

Returns:

Nested lists svals, where svals[i][j] is a (ww,) array of scattering values from input port i to output port j at each angular frequency in omega.
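Because svals is a nested list rather than a single array, it can be convenient to stack it before analysis. The sketch below relies only on the svals[i][j] layout documented above; `stack_svals` is a hypothetical helper, the numbers are placeholders rather than simulation output, and reading |s|**2 as a power fraction assumes the scattering values are amplitudes normalized to unit-power modes. NumPy is used so the sketch runs anywhere.

```python
import numpy as np


def stack_svals(svals):
    """Stack nested lists svals[i][j] of (ww,) arrays into one (n_in, n_out, ww) array.

    Hypothetical helper: it relies only on the documented svals layout.
    """
    return np.stack([np.stack(row) for row in svals])


# Placeholder values for 2 ports and 3 frequencies (not simulation output).
ww = 3
svals = [
    [np.full(ww, 0.1 + 0.0j), np.full(ww, 0.9j)],
    [np.full(ww, 0.9j), np.full(ww, 0.1 + 0.0j)],
]

s = stack_svals(svals)  # s[i, j, w]: input port i -> output port j at omega[w]
# Assumes svals are amplitudes normalized so that |s|**2 is a power fraction.
power = np.abs(s) ** 2
```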

pjz.field(epsilon: jax.Array, source: jax.Array, omega: jax.Array, source_pos: int, sim_params: SimParams)

Time-harmonic solution of Maxwell’s equations.

Parameters:
  • epsilon – (3, xx, yy, zz) array of permittivity values.

  • source – Array of excitation values of shape (2, 1, yy, zz), (2, xx, 1, zz), or (2, xx, yy, 1).

  • omega – (ww,) array of angular frequencies.

  • source_pos – Position of the source along the axis of propagation.

  • sim_params – pjz.SimParams instance holding the simulation parameters listed below.

  • omega_range – (omega_min, omega_max) range for omega values.

  • tt – See fdtdz_jax.fdtdz().

  • dt – See fdtdz_jax.fdtdz().

  • source_width – Number of periods for ramp-up of time-harmonic sources.

  • source_delay – Delay before ramping up source, in source_width units.

  • absorption_padding – Padding cells to add along both boundaries of the x- and y-axes for adiabatic absorption boundary conditions.

  • absorption_strength – Scaling coefficient for adiabatic absorption boundary.

  • pml_widths – See fdtdz_jax.fdtdz() documentation.

  • pml_alpha_coeff – Constant value for pml_alpha parameter of fdtdz_jax.fdtdz().

  • pml_sigma_lnr – Natural logarithm of PML reflectivity.

  • pml_sigma_m – Exponent for spatial scaling of PML.

  • use_reduced_precision – See fdtdz_jax.fdtdz() documentation.

  • launch_params – See fdtdz_jax.fdtdz() documentation.
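The names pml_sigma_lnr and pml_sigma_m match the textbook polynomial grading of PML conductivity, in which sigma(d) = sigma_max * (d / w)**m across a PML of width w, with sigma_max chosen so that the normal-incidence round-trip reflection equals R. The sketch below shows that textbook formula in normalized units (c = eta = 1). It is an assumption that pjz uses this exact normalization, how pml_alpha_coeff enters is not covered, and `pml_sigma` is a hypothetical helper, not part of pjz or fdtdz_jax.

```python
import numpy as np


def pml_sigma(depth, width, ln_r, m):
    """Polynomial-graded PML conductivity in normalized units (c = eta = 1).

    depth: distance into the PML, 0 at the inner interface, `width` at the outer edge.
    ln_r:  natural log of the target normal-incidence reflectivity R (e.g. -16.0).
    m:     grading exponent.
    sigma_max is chosen so that exp(-2 * integral_0^width sigma) == R.
    """
    sigma_max = -(m + 1) * ln_r / (2.0 * width)
    return sigma_max * (np.asarray(depth, dtype=float) / width) ** m


# Conductivity sampled at the centers of a 10-cell PML.
width = 10
sigma = pml_sigma(np.arange(width) + 0.5, width, ln_r=-16.0, m=4)
```

Larger m keeps the conductivity small near the inner interface (reducing discretization reflections) while the total attenuation, and hence R, stays fixed.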

Returns:

(ww, 3, xx, yy, zz) array of complex-valued field values, one (3, xx, yy, zz) field per angular frequency in omega.
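The source shape determines the propagation axis: the spatial axis of size 1 in source is the one along which source_pos is measured. The sketch below checks input shapes and returns that axis together with the expected output shape. `field_shapes` is a hypothetical helper; it assumes that exactly one spatial axis of the source has size 1 and that the source's transverse axes match those of epsilon.

```python
def field_shapes(epsilon_shape, source_shape, omega_shape):
    """Return (propagation_axis, output_shape) for pjz.field-style inputs.

    Hypothetical helper. Axes 0, 1, 2 correspond to x, y, z. Assumes exactly one
    spatial axis of the source has size 1 and the other two match epsilon.
    """
    if len(epsilon_shape) != 4 or epsilon_shape[0] != 3:
        raise ValueError("epsilon must have shape (3, xx, yy, zz)")
    if len(source_shape) != 4 or source_shape[0] != 2:
        raise ValueError("source must have shape (2, ., ., .)")
    singleton = [ax for ax in range(3) if source_shape[1 + ax] == 1]
    if len(singleton) != 1:
        raise ValueError("expected exactly one source axis of size 1")
    axis = singleton[0]
    for ax in range(3):
        if ax != axis and source_shape[1 + ax] != epsilon_shape[1 + ax]:
            raise ValueError("source and epsilon disagree on a transverse axis")
    (ww,) = omega_shape
    return axis, (ww,) + tuple(epsilon_shape)


# A y-propagating source on a (3, 64, 128, 32) grid, solved at 4 frequencies.
axis, out_shape = field_shapes((3, 64, 128, 32), (2, 64, 1, 32), (4,))
```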

class pjz.SimParams(omega_range, tt, dt, source_width, source_delay, absorption_padding, absorption_coeff, pml_widths, pml_alpha_coeff, pml_sigma_lnr, pml_sigma_m, use_reduced_precision, launch_params)

pjz.mode(epsilon: jax.Array, omega: jax.Array, num_modes: int, init: jax.Array | None = None, shift_iters: int = 10, max_iters: int = 100000, tol: float = 0.0001) → Tuple[jax.Array, jax.Array, jax.Array, int]

Solve for waveguide modes.

Parameters:
  • epsilon – (3, xx, yy, zz) array of permittivity values with exactly one xx, yy, or zz equal to 1.

  • omega – Real-valued scalar angular frequency.

  • num_modes – Integer denoting number of modes to solve for.

  • init – (2, xx, yy, zz, num_modes) array of values to use as the initial guess.

  • shift_iters – Number of iterations used to determine the largest eigenvalue of the waveguide operator.

  • max_iters – Maximum number of eigenvalue solver iterations to execute.

  • tol – Error threshold for eigenvalue solver.
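shift_iters controls how many iterations are spent estimating the largest eigenvalue of the waveguide operator. Power iteration is the standard cheap way to obtain such an estimate; the sketch below shows it on a small explicit matrix. It illustrates the technique only and is not pjz's implementation; how pjz then uses the estimate is not described here, and `estimate_largest_eigenvalue` is a hypothetical helper.

```python
import numpy as np


def estimate_largest_eigenvalue(apply_op, n, shift_iters=10, seed=0):
    """Power-iteration estimate of the largest-magnitude eigenvalue of a symmetric operator.

    apply_op: maps an (n,) vector to an (n,) vector; a stand-in for the waveguide operator.
    """
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(n)
    x /= np.linalg.norm(x)
    for _ in range(shift_iters):
        y = apply_op(x)
        x = y / np.linalg.norm(y)
    # Rayleigh quotient of the final (unit-norm) iterate.
    return float(x @ apply_op(x))


a = np.diag([1.0, 2.0, 10.0])
estimate = estimate_largest_eigenvalue(lambda v: a @ v, n=3, shift_iters=50)
```

Convergence is governed by the ratio of the two largest eigenvalue magnitudes (here 2 / 10), so a small shift_iters often suffices for a rough estimate.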

Returns:

(wavevector, excitation, err, iters) where iters is the number of executed solver iterations, and excitation.shape == (2, xx, yy, zz, num_modes) and wavevector.shape == err.shape == (num_modes,), with excitation[..., i], wavevector[i], and err[i] being ordered such that i == 0 corresponds to the fundamental mode.
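The returned arrays share a mode index, so mode i is described by excitation[..., i], wavevector[i], and err[i]. If mode data arrive unordered (for example from a custom solver), the usual convention that the fundamental mode has the largest propagation constant can be used to restore the documented ordering. The sketch sorts on the real part of the wavevector, which assumes (near-)lossless modes; `order_modes` is a hypothetical helper and the arrays are synthetic.

```python
import numpy as np


def order_modes(wavevector, excitation, err):
    """Reorder mode data so index 0 holds the largest (real part of the) wavevector."""
    wavevector = np.asarray(wavevector)
    order = np.argsort(-np.real(wavevector))
    return wavevector[order], np.asarray(excitation)[..., order], np.asarray(err)[order]


# Synthetic data for 3 modes on a (2, 1, 4, 5) cross-section; mode i is filled with i.
beta = np.array([1.0, 3.0, 2.0])
exc = np.broadcast_to(np.arange(3.0), (2, 1, 4, 5, 3))
errs = np.array([1e-5, 2e-5, 3e-5])

beta_s, exc_s, err_s = order_modes(beta, exc, errs)
```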

pjz.epsilon(layers: jax.Array, interface_positions: jax.Array, magnification: int, zz: int) → jax.Array

Render a three-dimensional vector array of permittivity values.

Produces a 3D vector array of permittivity values on the Yee cell from a layered stack of 2D profiles at magnification 2 * m, where m is the magnification argument. Along the z-axis, both the layer boundaries and the grid positions may vary continuously, while along the x- and y-axes each (unmagnified) cell is assumed to have size 1.
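The array shapes follow directly from the magnification: every unmagnified cell corresponds to a 2 * m by 2 * m block of layer pixels, where m is the magnification argument. The sketch below computes the rendered shape from the layer stack; `epsilon_output_shape` is a hypothetical helper. One consequence of the factor of 2 is that a half-cell offset, as used by the Yee grid, is exactly m magnified pixels.

```python
def epsilon_output_shape(layers_shape, magnification, zz):
    """Shape of the permittivity array rendered from layers of shape (ll, 2*m*xx, 2*m*yy).

    Hypothetical helper; m is the `magnification` argument.
    """
    ll, mx, my = layers_shape
    factor = 2 * magnification
    if mx % factor or my % factor:
        raise ValueError("layer dimensions must be multiples of 2 * magnification")
    # A half-cell (Yee) offset corresponds to exactly `magnification` layer pixels.
    return (3, mx // factor, my // factor, zz)


shape = epsilon_output_shape((2, 40, 60), magnification=2, zz=16)
```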

Attempts to follow [1] but only computes the on-diagonal elements of the projection matrix and is adapted to a situation where there are no explicit interfaces because the pixel values are allowed to vary continuously within each layer.

Instead, the diagonal elements of the projection matrix for a given subvolume are estimated from gradients across it, where df(u)/du is approximated as the integral of f(u) * u divided by the integral of u**2, with u measured relative to the center of the cell.
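This gradient estimate is a least-squares slope through the cell center: for samples u placed symmetrically about the center, sum(f(u) * u) / sum(u**2) is the discrete form of the ratio of integrals, and it recovers b exactly for any linear profile f(u) = a + b * u. The sketch below demonstrates this with NumPy on hypothetical sample points; it is not pjz's own code.

```python
import numpy as np


def slope_estimate(f_vals, u):
    """Discrete df/du estimate: sum(f(u) * u) / sum(u**2), with u relative to the cell center."""
    u = np.asarray(u, dtype=float)
    return float(np.sum(np.asarray(f_vals, dtype=float) * u) / np.sum(u ** 2))


# Midpoint samples across a unit cell, symmetric about its center.
n = 1000
u = (np.arange(n) + 0.5) / n - 0.5

linear = slope_estimate(2.0 + 3.0 * u, u)        # linear profile: recovers the slope 3
step = slope_estimate((u > 0).astype(float), u)  # unit step at the cell center
```

For a unit step at the cell center the continuous ratio is (1/8) / (1/12) = 1.5, which the sampled estimate approaches as n grows.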

Parameters:
  • layers – (ll, 2 * m * xx, 2 * m * yy) array of magnified layer profiles.

  • interface_positions – (ll - 1,) array of positions of the interfaces between the ll layers, assumed to be in monotonically increasing order.

  • magnification – Denotes a 2 * m in-plane magnification factor of layer profiles.

  • zz – Number of cells along z-axis.
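Because interface positions vary continuously along z, a single z-cell can be split between two or more layers. The sketch below computes the fraction of each z-cell occupied by each layer from sorted interface_positions. It simplifies the documented behavior by fixing the z-cells at unit-thickness intervals [k, k + 1), whereas pjz allows grid positions to vary continuously; `layer_fill_fractions` is a hypothetical helper.

```python
import numpy as np


def layer_fill_fractions(interface_positions, zz):
    """(ll, zz) fraction of each unit z-cell [k, k + 1) occupied by each of the ll layers."""
    p = np.asarray(interface_positions, dtype=float)
    if np.any(np.diff(p) < 0):
        raise ValueError("interface_positions must be monotonically increasing")
    lo = np.concatenate([[-np.inf], p])  # lower bound of each layer
    hi = np.concatenate([p, [np.inf]])   # upper bound of each layer
    k = np.arange(zz, dtype=float)
    # Length of the overlap between layer [lo, hi) and cell [k, k + 1).
    start = np.maximum(lo[:, None], k[None, :])
    stop = np.minimum(hi[:, None], k[None, :] + 1.0)
    return np.clip(stop - start, 0.0, None)


fractions = layer_fill_fractions([1.5], zz=3)
```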

Returns:

(3, xx, yy, zz) array of permittivity values with offsets and vector components according to the finite-difference Yee cell.

pjz.density(u, radius, alpha, c=1.0, eta=0.5, eta_lo=0.25, eta_hi=0.75)

Shape functions

pjz.rect(shape: Tuple[int, ...], center: jax.Array, widths: jax.Array) → jax.Array

Rectangle.

pjz.circ(shape: Tuple[int, ...], center: jax.Array, radius: jax.Array) → jax.Array

Circle.
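rect and circ produce shape arrays on a grid of the given shape. A minimal binary-mask version is sketched below, assuming cell i has coordinate i along each axis and that points on the boundary count as inside; pjz's own functions may anti-alias edges or use a different coordinate convention. `rect_mask` and `circ_mask` are hypothetical names.

```python
import numpy as np


def _coords(shape, center):
    """Grid coordinates (ndim, *shape) relative to `center`, with cell i at coordinate i."""
    coords = np.indices(shape).astype(float)
    return coords - np.asarray(center, dtype=float).reshape((-1,) + (1,) * len(shape))


def rect_mask(shape, center, widths):
    """1.0 where every |x_k - center_k| <= widths_k / 2, else 0.0."""
    d = _coords(shape, center)
    half = np.asarray(widths, dtype=float).reshape((-1,) + (1,) * len(shape)) / 2
    return np.all(np.abs(d) <= half, axis=0).astype(float)


def circ_mask(shape, center, radius):
    """1.0 where the distance to `center` is at most `radius`, else 0.0."""
    d = _coords(shape, center)
    return (np.sum(d ** 2, axis=0) <= radius ** 2).astype(float)


r = rect_mask((5, 5), center=(2, 2), widths=(3, 1))
c = circ_mask((5, 5), center=(2, 2), radius=1.0)
```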

pjz.invert(a)
pjz.union(a, b)
pjz.intersect(a, b)
pjz.dilate(a: jax.Array, radius: float) → jax.Array

Dilate.
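invert, union, intersect, and dilate are documented only by their signatures. For shape arrays with values in [0, 1], one common convention is invert(a) = 1 - a, union as the elementwise maximum, and intersect as the elementwise minimum; dilation by a radius then takes the maximum over a disk-shaped neighborhood. The sketch below implements those conventions in 2D with NumPy. Whether pjz uses the same definitions is an assumption, and all function names here are hypothetical.

```python
import numpy as np


def invert_shape(a):
    return 1.0 - np.asarray(a, dtype=float)


def union_shape(a, b):
    return np.maximum(a, b)


def intersect_shape(a, b):
    return np.minimum(a, b)


def dilate_shape(a, radius):
    """2D grayscale dilation: max of `a` over a disk of `radius` cells (zero outside the grid)."""
    a = np.asarray(a, dtype=float)
    r = int(np.floor(radius))
    padded = np.pad(a, r)  # zero padding so every window stays inside the array
    nx, ny = a.shape
    out = np.zeros_like(a)
    for dx in range(-r, r + 1):
        for dy in range(-r, r + 1):
            if dx * dx + dy * dy <= radius * radius:
                window = padded[r + dx : r + dx + nx, r + dy : r + dy + ny]
                out = np.maximum(out, window)
    return out


dot = np.zeros((5, 5))
dot[2, 2] = 1.0
plus = dilate_shape(dot, 1.0)    # center plus its 4 nearest neighbors
square = dilate_shape(dot, 1.5)  # full 3x3 block, since the diagonal distance sqrt(2) <= 1.5
```

With max/min and 1 - a, these operations satisfy De Morgan's laws, so intersect can equivalently be written as invert(union(invert(a), invert(b))).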

pjz.shift(a: jax.Array, axis: int, distance: float) → jax.Array

Shift.
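shift accepts a fractional distance. One way to implement a sub-cell shift is linear interpolation between the two neighboring integer shifts, filling vacated cells with zeros. The sketch below does that along an arbitrary axis; the interpolation scheme, the zero fill, and the sign convention (a positive distance moves content toward higher indices) are assumptions rather than documented pjz behavior, and `shift_shape` is a hypothetical name.

```python
import math

import numpy as np


def _integer_shift(a, n, axis):
    """Shift `a` by an integer `n` cells along `axis`, filling vacated cells with zeros."""
    out = np.zeros_like(a)
    size = a.shape[axis]
    if abs(n) >= size:
        return out
    src = [slice(None)] * a.ndim
    dst = [slice(None)] * a.ndim
    if n >= 0:
        src[axis], dst[axis] = slice(0, size - n), slice(n, size)
    else:
        src[axis], dst[axis] = slice(-n, size), slice(0, size + n)
    out[tuple(dst)] = a[tuple(src)]
    return out


def shift_shape(a, axis, distance):
    """Shift by a fractional `distance` via linear interpolation between integer shifts."""
    a = np.asarray(a, dtype=float)
    n = math.floor(distance)
    frac = distance - n
    return (1.0 - frac) * _integer_shift(a, n, axis) + frac * _integer_shift(a, n + 1, axis)


row = np.array([0.0, 2.0, 0.0])
half = shift_shape(row, axis=0, distance=0.5)  # -> [0.0, 1.0, 1.0]
```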