29
29
from typing import Any , Callable , Sequence
30
30
31
31
import numpy as np
32
- import jax
32
+
33
33
from jax import lax
34
- from jax .sharding import Sharding
35
34
from jax ._src import api
36
35
from jax ._src import core
37
36
from jax ._src import dtypes
44
43
from jax ._src .numpy import lax_numpy
45
44
from jax ._src .numpy import tensor_contractions
46
45
from jax ._src .pjit import PartitionSpec
46
+ from jax ._src .sharding import Sharding
47
47
from jax ._src .sharding_impls import canonicalize_sharding , NamedSharding
48
48
from jax ._src .numpy import reductions
49
49
from jax ._src .numpy import ufuncs
@@ -612,12 +612,13 @@ def _deepcopy(self: Array, memo: Any) -> Array:
612
612
613
613
def __array_module__ (self , types ):
614
614
if all (issubclass (t , _HANDLED_ARRAY_TYPES ) for t in types ):
615
+ import jax .numpy # pytype: disable=import-error
615
616
return jax .numpy
616
617
else :
617
618
return NotImplemented
618
619
619
620
620
- @partial (jax .jit , static_argnums = (1 ,2 ,3 ))
621
+ @partial (api .jit , static_argnums = (1 ,2 ,3 ))
621
622
def _multi_slice (self : Array ,
622
623
start_indices : tuple [tuple [int , ...]],
623
624
limit_indices : tuple [tuple [int , ...]],
@@ -637,7 +638,7 @@ def _multi_slice(self: Array,
637
638
638
639
# The next two functions are related to iter(array), implemented here to
639
640
# avoid circular imports.
640
- @jax .jit
641
+ @api .jit
641
642
def _unstack (x : Array ) -> list [Array ]:
642
643
dims = (0 ,)
643
644
return [lax .squeeze (t , dims ) for t in lax .split (x , (1 ,) * x .shape [0 ])]
@@ -776,7 +777,7 @@ def __repr__(self) -> str:
776
777
return f"_IndexUpdateRef({ self .array !r} , { self .index !r} )"
777
778
778
779
def get (self , * , indices_are_sorted : bool = False , unique_indices : bool = False ,
779
- mode : str | jax . lax .GatherScatterMode | None = None ,
780
+ mode : str | lax .GatherScatterMode | None = None ,
780
781
fill_value : ArrayLike | None = None , out_sharding : Sharding | None = None ):
781
782
"""Equivalent to ``x[idx]``.
782
783
@@ -798,7 +799,7 @@ def get(self, *, indices_are_sorted: bool = False, unique_indices: bool = False,
798
799
799
800
def set (self , values : ArrayLike , * , indices_are_sorted : bool = False ,
800
801
unique_indices : bool = False ,
801
- mode : str | jax . lax .GatherScatterMode | None = None ) -> None :
802
+ mode : str | lax .GatherScatterMode | None = None ) -> None :
802
803
"""Pure equivalent of ``x[idx] = y``.
803
804
804
805
Returns the value of ``x`` that would result from the NumPy-style
@@ -816,7 +817,7 @@ def set(self, values: ArrayLike, *, indices_are_sorted: bool = False,
816
817
817
818
def apply (self , func : Callable [[ArrayLike ], Array ], * ,
818
819
indices_are_sorted : bool = False , unique_indices : bool = False ,
819
- mode : str | jax . lax .GatherScatterMode | None = None ) -> Array :
820
+ mode : str | lax .GatherScatterMode | None = None ) -> Array :
820
821
"""Pure equivalent of ``func.at(x, idx)`` for a unary ufunc ``func``.
821
822
822
823
Returns the value of ``x`` that would result from applying the unary
@@ -840,7 +841,7 @@ def _scatter_apply(x, indices, y, dims, **kwargs):
840
841
841
842
def add (self , values : ArrayLike , * ,
842
843
indices_are_sorted : bool = False , unique_indices : bool = False ,
843
- mode : str | jax . lax .GatherScatterMode | None = None ) -> Array :
844
+ mode : str | lax .GatherScatterMode | None = None ) -> Array :
844
845
"""Pure equivalent of ``x[idx] += y``.
845
846
846
847
Returns the value of ``x`` that would result from the NumPy-style
@@ -855,7 +856,7 @@ def add(self, values: ArrayLike, *,
855
856
856
857
def subtract (self , values : ArrayLike , * ,
857
858
indices_are_sorted : bool = False , unique_indices : bool = False ,
858
- mode : str | jax . lax .GatherScatterMode | None = None ) -> Array :
859
+ mode : str | lax .GatherScatterMode | None = None ) -> Array :
859
860
"""Pure equivalent of ``x[idx] -= y``.
860
861
861
862
Returns the value of ``x`` that would result from the NumPy-style
@@ -870,7 +871,7 @@ def subtract(self, values: ArrayLike, *,
870
871
871
872
def multiply (self , values : ArrayLike , * ,
872
873
indices_are_sorted : bool = False , unique_indices : bool = False ,
873
- mode : str | jax . lax .GatherScatterMode | None = None ) -> Array :
874
+ mode : str | lax .GatherScatterMode | None = None ) -> Array :
874
875
"""Pure equivalent of ``x[idx] *= y``.
875
876
876
877
Returns the value of ``x`` that would result from the NumPy-style
@@ -887,7 +888,7 @@ def multiply(self, values: ArrayLike, *,
887
888
888
889
def divide (self , values : ArrayLike , * ,
889
890
indices_are_sorted : bool = False , unique_indices : bool = False ,
890
- mode : str | jax . lax .GatherScatterMode | None = None ) -> Array :
891
+ mode : str | lax .GatherScatterMode | None = None ) -> Array :
891
892
"""Pure equivalent of ``x[idx] /= y``.
892
893
893
894
Returns the value of ``x`` that would result from the NumPy-style
@@ -904,7 +905,7 @@ def divide(self, values: ArrayLike, *,
904
905
905
906
def power (self , values : ArrayLike , * ,
906
907
indices_are_sorted : bool = False , unique_indices : bool = False ,
907
- mode : str | jax . lax .GatherScatterMode | None = None ) -> Array :
908
+ mode : str | lax .GatherScatterMode | None = None ) -> Array :
908
909
"""Pure equivalent of ``x[idx] **= y``.
909
910
910
911
Returns the value of ``x`` that would result from the NumPy-style
@@ -921,7 +922,7 @@ def power(self, values: ArrayLike, *,
921
922
922
923
def min (self , values : ArrayLike , * ,
923
924
indices_are_sorted : bool = False , unique_indices : bool = False ,
924
- mode : str | jax . lax .GatherScatterMode | None = None ) -> Array :
925
+ mode : str | lax .GatherScatterMode | None = None ) -> Array :
925
926
"""Pure equivalent of ``x[idx] = minimum(x[idx], y)``.
926
927
927
928
Returns the value of ``x`` that would result from the NumPy-style
@@ -937,7 +938,7 @@ def min(self, values: ArrayLike, *,
937
938
938
939
def max (self , values : ArrayLike , * ,
939
940
indices_are_sorted : bool = False , unique_indices : bool = False ,
940
- mode : str | jax . lax .GatherScatterMode | None = None ) -> Array :
941
+ mode : str | lax .GatherScatterMode | None = None ) -> Array :
941
942
"""Pure equivalent of ``x[idx] = maximum(x[idx], y)``.
942
943
943
944
Returns the value of ``x`` that would result from the NumPy-style
0 commit comments