Skip to content

Commit 0b89b23

Browse files
Merge pull request #29186 from jakevdp:jax-numpy-imports
PiperOrigin-RevId: 766730509
2 parents d17b292 + cecf2f6 commit 0b89b23

19 files changed

+182
-167
lines changed

jax/_src/lax/fft.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,14 @@
2121

2222
import numpy as np
2323

24-
from jax import lax
25-
2624
from jax._src import dispatch
2725
from jax._src import dtypes
2826
from jax._src.api import jit, linear_transpose, ShapeDtypeStruct
2927
from jax._src.core import Primitive, is_constant_shape
3028
from jax._src.interpreters import ad
3129
from jax._src.interpreters import batching
3230
from jax._src.interpreters import mlir
31+
from jax._src.lax import lax
3332
from jax._src.lib.mlir.dialects import hlo
3433

3534
__all__ = [

jax/_src/numpy/array_api_metadata.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121

2222
from types import ModuleType
2323

24-
import jax
2524
from jax._src.sharding import Sharding
2625
from jax._src.lib import xla_client as xc
2726
from jax._src import config
@@ -40,6 +39,7 @@ def __array_namespace__(self, *, api_version: None | str = None) -> ModuleType:
4039
if api_version is not None and api_version != __array_api_version__:
4140
raise ValueError(f"{api_version=!r} is not available; "
4241
f"available versions are: {[__array_api_version__]}")
42+
import jax.numpy # pytype: disable=import-error
4343
return jax.numpy
4444

4545

@@ -77,7 +77,7 @@ def default_device(self):
7777
def devices(self):
7878
out = [None] # None indicates "uncommitted"
7979
for backend in xb.backends():
80-
out.extend(jax.devices(backend))
80+
out.extend(xb.devices(backend))
8181
return out
8282

8383
def capabilities(self):

jax/_src/numpy/array_creation.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,16 @@
1919

2020
import numpy as np
2121

22-
import jax
23-
from jax import lax
24-
from jax._src.api import jit
22+
from jax._src.api import device_put, jit
2523
from jax._src import core
2624
from jax._src import dtypes
27-
from jax._src.lax import lax as lax_internal
25+
from jax._src.lax import lax
2826
from jax._src.lib import xla_client as xc
2927
from jax._src.numpy import ufuncs
3028
from jax._src.numpy import util
29+
from jax._src.sharding import Sharding
3130
from jax._src.typing import Array, ArrayLike, DuckTypedArray, DTypeLike
3231
from jax._src.util import canonicalize_axis, set_module
33-
from jax.sharding import Sharding
3432

3533

3634
export = set_module('jax.numpy')
@@ -205,15 +203,17 @@ def full(shape: Any, fill_value: ArrayLike,
205203
Array([[0, 1, 2],
206204
[0, 1, 2]], dtype=int32)
207205
"""
206+
from jax._src.numpy.lax_numpy import asarray # pytype: disable=import-error
207+
208208
dtypes.check_user_dtype_supported(dtype, "full")
209209
util.check_arraylike("full", fill_value)
210210

211211
if np.ndim(fill_value) == 0:
212212
shape = canonicalize_shape(shape)
213213
return lax.full(shape, fill_value, dtype, sharding=util.normalize_device_to_sharding(device))
214214
else:
215-
return jax.device_put(
216-
util._broadcast_to(jax.numpy.asarray(fill_value, dtype=dtype), shape), device)
215+
return device_put(
216+
util._broadcast_to(asarray(fill_value, dtype=dtype), shape), device)
217217

218218

219219
@export
@@ -394,6 +394,8 @@ def full_like(a: ArrayLike | DuckTypedArray,
394394
Array([[1, 1, 1],
395395
[2, 2, 2]], dtype=int32)
396396
"""
397+
from jax._src.numpy.lax_numpy import asarray # pytype: disable=import-error
398+
397399
if hasattr(a, 'dtype') and hasattr(a, 'shape'): # support duck typing
398400
util.check_arraylike("full_like", 0, fill_value)
399401
else:
@@ -408,8 +410,8 @@ def full_like(a: ArrayLike | DuckTypedArray,
408410
else:
409411
shape = np.shape(a) if shape is None else shape # type: ignore[arg-type]
410412
dtype = dtypes.result_type(a) if dtype is None else dtype
411-
return jax.device_put(
412-
util._broadcast_to(jax.numpy.asarray(fill_value, dtype=dtype), shape), device)
413+
return device_put(
414+
util._broadcast_to(asarray(fill_value, dtype=dtype), shape), device)
413415

414416
@overload
415417
def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
@@ -510,6 +512,8 @@ def _linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
510512
axis: int = 0,
511513
*, device: xc.Device | Sharding | None = None) -> Array | tuple[Array, Array]:
512514
"""Implementation of linspace differentiable in start and stop args."""
515+
from jax._src.numpy.lax_numpy import asarray # pytype: disable=import-error
516+
513517
dtypes.check_user_dtype_supported(dtype, "linspace")
514518
if num < 0:
515519
raise ValueError(f"Number of samples, {num}, must be non-negative.")
@@ -529,13 +533,13 @@ def _linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
529533
bounds_shape.insert(axis, 1)
530534
div = (num - 1) if endpoint else num
531535
if num > 1:
532-
delta: Array = lax.convert_element_type(stop - start, computation_dtype) / jax.numpy.array(div, dtype=computation_dtype)
536+
delta: Array = lax.convert_element_type(stop - start, computation_dtype) / asarray(div, dtype=computation_dtype)
533537
iota_shape = [1,] * len(bounds_shape)
534538
iota_shape[axis] = div
535539
# This approach recovers the endpoints with float32 arithmetic,
536540
# but can lead to rounding errors for integer outputs.
537541
real_dtype = dtypes.finfo(computation_dtype).dtype
538-
step = lax.iota(real_dtype, div).reshape(iota_shape) / jax.numpy.array(div, real_dtype)
542+
step = lax.iota(real_dtype, div).reshape(iota_shape) / asarray(div, real_dtype)
539543
step = step.astype(computation_dtype)
540544
out = (broadcast_start.reshape(bounds_shape) * (1 - step) +
541545
broadcast_stop.reshape(bounds_shape) * step)
@@ -545,7 +549,7 @@ def _linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
545549
canonicalize_axis(axis, out.ndim))
546550

547551
elif num == 1:
548-
delta = jax.numpy.asarray(np.nan if endpoint else stop - start, dtype=computation_dtype)
552+
delta = asarray(np.nan if endpoint else stop - start, dtype=computation_dtype)
549553
out = broadcast_start.reshape(bounds_shape)
550554
else: # num == 0 degenerate case, match numpy behavior
551555
empty_shape = list(lax.broadcast_shapes(np.shape(start), np.shape(stop)))
@@ -557,7 +561,7 @@ def _linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
557561
out = lax.floor(out)
558562

559563
sharding = util.canonicalize_device_to_sharding(device)
560-
result = lax_internal._convert_element_type(out, dtype, sharding=sharding)
564+
result = lax._convert_element_type(out, dtype, sharding=sharding)
561565
return (result, delta) if retstep else result
562566

563567

jax/_src/numpy/array_methods.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,8 @@
2929
from typing import Any, Callable, Sequence
3030

3131
import numpy as np
32-
import jax
32+
3333
from jax import lax
34-
from jax.sharding import Sharding
3534
from jax._src import api
3635
from jax._src import core
3736
from jax._src import dtypes
@@ -44,6 +43,7 @@
4443
from jax._src.numpy import lax_numpy
4544
from jax._src.numpy import tensor_contractions
4645
from jax._src.pjit import PartitionSpec
46+
from jax._src.sharding import Sharding
4747
from jax._src.sharding_impls import canonicalize_sharding, NamedSharding
4848
from jax._src.numpy import reductions
4949
from jax._src.numpy import ufuncs
@@ -612,12 +612,13 @@ def _deepcopy(self: Array, memo: Any) -> Array:
612612

613613
def __array_module__(self, types):
614614
if all(issubclass(t, _HANDLED_ARRAY_TYPES) for t in types):
615+
import jax.numpy # pytype: disable=import-error
615616
return jax.numpy
616617
else:
617618
return NotImplemented
618619

619620

620-
@partial(jax.jit, static_argnums=(1,2,3))
621+
@partial(api.jit, static_argnums=(1,2,3))
621622
def _multi_slice(self: Array,
622623
start_indices: tuple[tuple[int, ...]],
623624
limit_indices: tuple[tuple[int, ...]],
@@ -637,7 +638,7 @@ def _multi_slice(self: Array,
637638

638639
# The next two functions are related to iter(array), implemented here to
639640
# avoid circular imports.
640-
@jax.jit
641+
@api.jit
641642
def _unstack(x: Array) -> list[Array]:
642643
dims = (0,)
643644
return [lax.squeeze(t, dims) for t in lax.split(x, (1,) * x.shape[0])]
@@ -776,7 +777,7 @@ def __repr__(self) -> str:
776777
return f"_IndexUpdateRef({self.array!r}, {self.index!r})"
777778

778779
def get(self, *, indices_are_sorted: bool = False, unique_indices: bool = False,
779-
mode: str | jax.lax.GatherScatterMode | None = None,
780+
mode: str | lax.GatherScatterMode | None = None,
780781
fill_value: ArrayLike | None = None, out_sharding: Sharding | None = None):
781782
"""Equivalent to ``x[idx]``.
782783
@@ -798,7 +799,7 @@ def get(self, *, indices_are_sorted: bool = False, unique_indices: bool = False,
798799

799800
def set(self, values: ArrayLike, *, indices_are_sorted: bool = False,
800801
unique_indices: bool = False,
801-
mode: str | jax.lax.GatherScatterMode | None = None) -> None:
802+
mode: str | lax.GatherScatterMode | None = None) -> None:
802803
"""Pure equivalent of ``x[idx] = y``.
803804
804805
Returns the value of ``x`` that would result from the NumPy-style
@@ -816,7 +817,7 @@ def set(self, values: ArrayLike, *, indices_are_sorted: bool = False,
816817

817818
def apply(self, func: Callable[[ArrayLike], Array], *,
818819
indices_are_sorted: bool = False, unique_indices: bool = False,
819-
mode: str | jax.lax.GatherScatterMode | None = None) -> Array:
820+
mode: str | lax.GatherScatterMode | None = None) -> Array:
820821
"""Pure equivalent of ``func.at(x, idx)`` for a unary ufunc ``func``.
821822
822823
Returns the value of ``x`` that would result from applying the unary
@@ -840,7 +841,7 @@ def _scatter_apply(x, indices, y, dims, **kwargs):
840841

841842
def add(self, values: ArrayLike, *,
842843
indices_are_sorted: bool = False, unique_indices: bool = False,
843-
mode: str | jax.lax.GatherScatterMode | None = None) -> Array:
844+
mode: str | lax.GatherScatterMode | None = None) -> Array:
844845
"""Pure equivalent of ``x[idx] += y``.
845846
846847
Returns the value of ``x`` that would result from the NumPy-style
@@ -855,7 +856,7 @@ def add(self, values: ArrayLike, *,
855856

856857
def subtract(self, values: ArrayLike, *,
857858
indices_are_sorted: bool = False, unique_indices: bool = False,
858-
mode: str | jax.lax.GatherScatterMode | None = None) -> Array:
859+
mode: str | lax.GatherScatterMode | None = None) -> Array:
859860
"""Pure equivalent of ``x[idx] -= y``.
860861
861862
Returns the value of ``x`` that would result from the NumPy-style
@@ -870,7 +871,7 @@ def subtract(self, values: ArrayLike, *,
870871

871872
def multiply(self, values: ArrayLike, *,
872873
indices_are_sorted: bool = False, unique_indices: bool = False,
873-
mode: str | jax.lax.GatherScatterMode | None = None) -> Array:
874+
mode: str | lax.GatherScatterMode | None = None) -> Array:
874875
"""Pure equivalent of ``x[idx] *= y``.
875876
876877
Returns the value of ``x`` that would result from the NumPy-style
@@ -887,7 +888,7 @@ def multiply(self, values: ArrayLike, *,
887888

888889
def divide(self, values: ArrayLike, *,
889890
indices_are_sorted: bool = False, unique_indices: bool = False,
890-
mode: str | jax.lax.GatherScatterMode | None = None) -> Array:
891+
mode: str | lax.GatherScatterMode | None = None) -> Array:
891892
"""Pure equivalent of ``x[idx] /= y``.
892893
893894
Returns the value of ``x`` that would result from the NumPy-style
@@ -904,7 +905,7 @@ def divide(self, values: ArrayLike, *,
904905

905906
def power(self, values: ArrayLike, *,
906907
indices_are_sorted: bool = False, unique_indices: bool = False,
907-
mode: str | jax.lax.GatherScatterMode | None = None) -> Array:
908+
mode: str | lax.GatherScatterMode | None = None) -> Array:
908909
"""Pure equivalent of ``x[idx] **= y``.
909910
910911
Returns the value of ``x`` that would result from the NumPy-style
@@ -921,7 +922,7 @@ def power(self, values: ArrayLike, *,
921922

922923
def min(self, values: ArrayLike, *,
923924
indices_are_sorted: bool = False, unique_indices: bool = False,
924-
mode: str | jax.lax.GatherScatterMode | None = None) -> Array:
925+
mode: str | lax.GatherScatterMode | None = None) -> Array:
925926
"""Pure equivalent of ``x[idx] = minimum(x[idx], y)``.
926927
927928
Returns the value of ``x`` that would result from the NumPy-style
@@ -937,7 +938,7 @@ def min(self, values: ArrayLike, *,
937938

938939
def max(self, values: ArrayLike, *,
939940
indices_are_sorted: bool = False, unique_indices: bool = False,
940-
mode: str | jax.lax.GatherScatterMode | None = None) -> Array:
941+
mode: str | lax.GatherScatterMode | None = None) -> Array:
941942
"""Pure equivalent of ``x[idx] = maximum(x[idx], y)``.
942943
943944
Returns the value of ``x`` that would result from the NumPy-style

jax/_src/numpy/error.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,11 @@
1515
import contextlib
1616
from typing import Literal, Sequence
1717

18-
import jax
18+
import numpy as np
19+
1920
from jax._src import config
20-
from jax._src.typing import ArrayLike
21+
from jax._src import dtypes
22+
from jax._src.typing import Array, ArrayLike
2123

2224
Category = Literal["nan", "divide", "oob"]
2325

@@ -40,7 +42,7 @@ def _is_category_disabled(
4042

4143

4244
def _set_error_if_with_category(
43-
pred: jax.Array,
45+
pred: Array,
4446
/,
4547
msg: str,
4648
category: Category | None = None,
@@ -65,7 +67,7 @@ def _set_error_if_with_category(
6567
error_check_lib.set_error_if(pred, msg)
6668

6769

68-
def _set_error_if_nan(pred: jax.Array, /):
70+
def _set_error_if_nan(pred: Array, /):
6971
"""Set the internal error state if any element of `pred` is `NaN`.
7072
7173
This function is disabled if the `jax_error_checking_behavior_nan` flag is
@@ -74,17 +76,17 @@ def _set_error_if_nan(pred: jax.Array, /):
7476
if config.error_checking_behavior_nan.value == "ignore":
7577
return
7678

77-
# TODO(mattjj): fix the circular import issue.
78-
import jax.numpy as jnp
79-
if not jnp.issubdtype(pred.dtype, jnp.floating): # only check floats
79+
if not dtypes.issubdtype(pred.dtype, np.floating): # only check floats
8080
return
8181

8282
# TODO(mattjj): fix the circular import issue.
8383
from jax._src import error_check as error_check_lib
84+
import jax.numpy as jnp
85+
8486
error_check_lib.set_error_if(jnp.isnan(pred), "NaN encountered")
8587

8688

87-
def _set_error_if_divide_by_zero(pred: jax.Array, /):
89+
def _set_error_if_divide_by_zero(pred: Array, /):
8890
"""Set the internal error state if any element of `pred` is zero.
8991
9092
This function is intended for checking if the denominator of a division is

0 commit comments

Comments
 (0)