Skip to content

Commit bcc6edd

Browse files
authored
Merge pull request invrs-io#99 from invrs-io/scan
Make use of new method for computing scattering matrices
2 parents 4656bfd + a0afa37 commit bcc6edd

File tree

3 files changed

+49
-13
lines changed

3 files changed

+49
-13
lines changed

docs/notebooks/metalens_challenge.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@
151151
"\n",
152152
"ex, ey, ez = aux[\"efield\"]\n",
153153
"x, _, z = aux[\"field_coordinates\"]\n",
154-
"xplot, zplot = onp.meshgrid(x[:, 0], z, indexing=\"ij\")\n",
154+
"xplot, zplot = onp.meshgrid(x[0, :, 0], z, indexing=\"ij\")\n",
155155
"\n",
156156
"abs_field = onp.sqrt(onp.abs(ex) ** 2 + onp.abs(ey) ** 2 + onp.abs(ez) ** 2)\n",
157157
"\n",

src/invrs_gym/challenges/metalens/challenge.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from typing import Tuple
99

1010
from fmmax import fmm # type: ignore[import-untyped]
11+
import jax
12+
from jax import nn
1113
from jax import numpy as jnp
1214
from totypes import symmetry, types
1315

@@ -64,9 +66,8 @@ def loss(self, response: metalens_component.MetalensResponse) -> jnp.ndarray:
6466
enhancement = response.enhancement_ex
6567
else:
6668
enhancement = response.enhancement_ey
67-
# The `sqrt` is used to help avoid situtations where one wavelength has
68-
# strong enhancement while the others have less.
69-
return -jnp.mean(jnp.sqrt(enhancement))
69+
70+
return soft_amax(-enhancement, scale=10.0)
7071

7172
def distance_to_target(
7273
self, response: metalens_component.MetalensResponse
@@ -107,6 +108,24 @@ def metrics(
107108
return metrics
108109

109110

111+
def soft_amax(x: jnp.ndarray, scale: float) -> jnp.ndarray:
112+
"""A soft version of `amax`.
113+
114+
The softness is set by `scale`. For small values, the output is close to that of
115+
`amax`, while for larger values it is closer to that of `mean`. This function can
116+
be used to scalarize a vector objective in a manner related to the concept of
117+
minimax optimization.
118+
119+
Args:
120+
x: The array to be scalarized.
121+
scale: The scale of smoothness.
122+
123+
Returns:
124+
The scalarized array.
125+
"""
126+
return jnp.sum(jax.lax.stop_gradient(nn.softmax(x / scale)) * x)
127+
128+
110129
METALENS_SPEC = metalens_component.MetalensSpec(
111130
permittivity_ambient=(1.0 + 0.0001j) ** 2,
112131
permittivity_metalens=(2.4 + 0.0001j) ** 2,

src/invrs_gym/challenges/metalens/component.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -372,30 +372,47 @@ def eigensolve_fn(permittivity: jnp.ndarray) -> fmm.LayerSolveResult:
372372
solve_result_ambient = eigensolve_fn(
373373
permittivity=jnp.full(spec.grid_shape, spec.permittivity_ambient)
374374
)
375-
solve_results_metalens = [
376-
eigensolve_fn(permittivity=p) for p in metalens_permittivities
377-
]
378375
solve_result_substrate = eigensolve_fn(
379376
permittivity=jnp.full(spec.grid_shape, spec.permittivity_substrate)
380377
)
381-
layer_solve_results = (
382-
[solve_result_ambient] + solve_results_metalens + [solve_result_substrate]
383-
)
384378

385379
if compute_fields:
386380
# If the field calculation is desired, compute the interior scattering
387381
# matrices. For each layer in the stack, the interior scattering matrices
388382
# consist of a pair of matrices, one for the substack below the layer, and
389383
# one for the substack above the layer.
384+
solve_results_metalens = [
385+
eigensolve_fn(permittivity=p) for p in metalens_permittivities
386+
]
387+
layer_solve_results = (
388+
[solve_result_ambient] + solve_results_metalens + [solve_result_substrate]
389+
)
390390
s_matrices_interior = scattering.stack_s_matrices_interior(
391391
layer_solve_results=layer_solve_results,
392392
layer_thicknesses=layer_thicknesses,
393393
)
394394
s_matrix = s_matrices_interior[-1][0]
395395
else:
396-
s_matrix = scattering.stack_s_matrix(
397-
layer_solve_results=layer_solve_results,
398-
layer_thicknesses=layer_thicknesses,
396+
solve_results_metalens_batch = eigensolve_fn(
397+
permittivity=jnp.asarray(metalens_permittivities)[:, jnp.newaxis, :, :]
398+
)
399+
# Merge with the ambient and substrate solve results to get the solve results
400+
# for the full stack, needed by `stack_s_matrix_scan`.
401+
stack_layer_solve_results = tree_util.tree_map(
402+
lambda a, b, c: jnp.concatenate(
403+
[
404+
a[jnp.newaxis, ...],
405+
jnp.broadcast_to(b, (num_layers,) + b.shape[1:]),
406+
c[jnp.newaxis, ...],
407+
]
408+
),
409+
solve_result_ambient,
410+
solve_results_metalens_batch,
411+
solve_result_substrate,
412+
)
413+
s_matrix = scattering.stack_s_matrix_scan(
414+
layer_solve_results=stack_layer_solve_results,
415+
layer_thicknesses=jnp.asarray(layer_thicknesses),
399416
)
400417

401418
# Compute the source, consisting of a smoothed step function.

0 commit comments

Comments
 (0)