Skip to content

Commit a4edfd0

Browse files
Martin SchubertMartin Schubert
Martin Schubert
authored and
Martin Schubert
committed
Update loss fn
1 parent aa3bbdc commit a4edfd0

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

src/invrs_gym/challenges/metalens/challenge.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,9 @@ def loss(self, response: metalens_component.MetalensResponse) -> jnp.ndarray:
6464
enhancement = response.enhancement_ex
6565
else:
6666
enhancement = response.enhancement_ey
67-
return -jnp.mean(enhancement)
67+
# The `sqrt` is used to help avoid situtations where one wavelength has
68+
# strong enhancement while the others have less.
69+
return -jnp.mean(jnp.sqrt(enhancement))
6870

6971
def distance_to_target(
7072
self, response: metalens_component.MetalensResponse

0 commit comments

Comments
 (0)