Insights: jax-ml/jax
Overview
1 Release published by 1 person
jax-v0.7.0 JAX v0.7.0
published Jul 22, 2025
541 Pull requests merged by 26 people
[build] avoid top-level dependencies in jax/_src
#30461 merged
Jul 24, 2025 -
Refactor: Add carry support to nd_loop
#30237 merged
Jul 24, 2025 -
Allow fusing all stride == 1 to support jax.lax.slice_in_dim in @fuser.fuse
#30425 merged
Jul 24, 2025 -
before iterating, copy dict that can be popped from via ARC / object deletion
#30465 merged
Jul 24, 2025 -
add psend/precv to auto-generated docs
#30464 merged
Jul 24, 2025 -
[Mosaic GPU] Fix some typos in docs
#30447 merged
Jul 23, 2025 -
Refactoring HTTIJ and surrounding pages, take 2
#30420 merged
Jul 23, 2025 -
avoid top-level imports in source_info_util
#30446 merged
Jul 23, 2025 -
Test B200 against CUDA 12.8 only
#30383 merged
Jul 23, 2025 -
Add Hermetic C++ Toolchains for Linux x86_64 builds.
#29672 merged
Jul 23, 2025 -
[Mosaic GPU] Query amount of shared memory programmatically
#30257 merged
Jul 23, 2025 -
[Pallas TPU] Remove use_sreg_for_state flag
#30421 merged
Jul 23, 2025 -
Remove out_sharding from scatter_add because we now have a scatter sharding rule
#30443 merged
Jul 23, 2025 -
Use the correct tracing context when creating zeros in linearize_from_jvp
#30440 merged
Jul 23, 2025 -
Implement explicit mode sharding rule for scatter
#30377 merged
Jul 23, 2025 -
[jax] Export ir_attribute in jax.interpreters.mlir
#30430 merged
Jul 23, 2025 -
Fix typo in doc.
#30431 merged
Jul 23, 2025 -
[Mosaic] NFC: Just use ExtFOp and TruncFOp for bf16.
#30398 merged
Jul 23, 2025 -
[Mosaic GPU] Simplify type constraints in MGPU dialect.
#30397 merged
Jul 23, 2025 -
[Mosaic GPU] Add support for warp shuffles with elements wider than 32-bit
#30387 merged
Jul 23, 2025 -
[Mosaic] Add tpu.wait_indirect_dma assembly format
#30169 merged
Jul 23, 2025 -
Extend numpy.size to accept multiple axes.
#30132 merged
Jul 23, 2025 -
Avoid top-level imports in jax/_src/internal_test_util/*
#30402 merged
Jul 23, 2025 -
Avoid top-level imports in jax/_src/extend
#30403 merged
Jul 23, 2025 -
Remove JAX_IFRT_VERSION_NUMBER guards after the release of 0.7.0
#30419 merged
Jul 22, 2025 -
#sdy Delete JAX test configs enabling Shardy.
#30418 merged
Jul 22, 2025 -
Add LRU cache on get_process_index_and_count.
#30417 merged
Jul 22, 2025 -
Avoid top-level imports in jax/_src/test_multiprocess
#30415 merged
Jul 22, 2025 -
Remove vmap and pmap mentions in the rst section title.
#30416 merged
Jul 22, 2025 -
Remove jaxlib_extension_version, ifrt_version guards now that 0.7.0 release is out
#30406 merged
Jul 22, 2025 -
Reverts d86857a9c8cdcae8d27e6a371daceb0ac3f53e28
#30408 merged
Jul 22, 2025 -
Release 0.7.0
#30404 merged
Jul 22, 2025 -
Refactor "How to think in JAX" and surrounding pages
#29541 merged
Jul 22, 2025 -
Avoid top-level imports in jax/_src/test_util.py
#30400 merged
Jul 22, 2025 -
[jax/BUILD] avoid _src dependencies on top-level jax packages
#30394 merged
Jul 22, 2025 -
[JAX] Make util.lru_cache(..., trace_context_in_key=False) always cache
#30392 merged
Jul 22, 2025 -
Enforce the shape of MMA operands and result in their types.
#30390 merged
Jul 22, 2025 -
[pallas:mosaic_gpu] Do not use lax.zeros_like_array
#30385 merged
Jul 22, 2025 -
[pallas] pl.loop now accepts step=
#30274 merged
Jul 22, 2025 -
Update deprecation schedule for mlir.custom_call, attributes of xla_client, and xla_bridge.get_backend.
#30386 merged
Jul 22, 2025 -
[pallas:mosaic_gpu] async_load_tmem no longer accepts an idx
#30382 merged
Jul 22, 2025 -
Rollback XLA archival change, current automation does not calculate the shasum correctly
#30379 merged
Jul 22, 2025 -
[Mosaic GPU][NFC] Implement inline_mgpu and custom_primitive for WG semantics.
#30147 merged
Jul 22, 2025 -
Add visibility to //jax:pretty_printer
#30373 merged
Jul 22, 2025 -
write docs about controlling array layouts
#30368 merged
Jul 22, 2025 -
Remove SafeNumDevices now that all shardings are implemented in C++
#30375 merged
Jul 22, 2025 -
Loosen test tolerance for tests/linalg_sharding_test.py::LinalgShardingTest::test_batch_axis_sharding_jvp.
#30370 merged
Jul 22, 2025 -
Fix handling of empty arrays in ufunc.reduce/accumulate
#29248 merged
Jul 21, 2025 -
Fix mesh_cast + new rng keys usage
#30367 merged
Jul 21, 2025 -
[Pallas] Only run TPU interpret mode tests on CPU.
#30366 merged
Jul 21, 2025 -
Fix py_import dependencies.
#30365 merged
Jul 21, 2025 -
[Pallas][Mosaic GPU] Fix racy test by having each warpgroup work on non-overlapping data.
#30354 merged
Jul 21, 2025 -
[JAX] Purge more caches when clearing backends
#30324 merged
Jul 21, 2025 -
Migrate third-party uses of jax.lib.xla_bridge.get_backend to jax.extend.backend.get_backend.
#30359 merged
Jul 21, 2025 -
[Pallas] Make semantics of BlockSpec.memory_space in emit_pipeline consistent with pallas_call.
#30284 merged
Jul 21, 2025 -
jnp.fft.fftfreq: fix regression in complex dtype support
#30356 merged
Jul 21, 2025 -
[jex] Add proto <-> xla_client.HloSharding and proto -> xla_client.OpSharding APIs to jex.
#30357 merged
Jul 21, 2025 -
Fetch XLA archive using the GitHub API endpoint instead of the web link
#30358 merged
Jul 21, 2025 -
Parametrize build system on CUDA major version
#28968 merged
Jul 21, 2025 -
Allow setting transfer_size for cross host transfers.
#30355 merged
Jul 21, 2025 -
gumbel distribution implementation
#29343 merged
Jul 21, 2025 -
Automated g4 rollback of changelist 784947368.
#30353 merged
Jul 21, 2025 -
Consider this example:
#30350 merged
Jul 21, 2025 -
[Mosaic] Add store canonicalization for an expand reshape->store fusion
#30333 merged
Jul 21, 2025 -
Add note about direct linearize to changelog
#30351 merged
Jul 21, 2025 -
Update editorconfig to the Python line length of 88
#30348 merged
Jul 21, 2025 -
Repair initializers not matching Initializer protocol
#29561 merged
Jul 21, 2025 -
test_multiprocess: only initialize GPUs that exist
#30248 merged
Jul 21, 2025 -
#sdy Add section on Shardy migration issues about all meshes being the same and custom_partitioning.
#30347 merged
Jul 21, 2025 -
update use_shardy_partitioner description.
#30345 merged
Jul 21, 2025 -
improve AOT input sharding/layout mismatch error message
#30344 merged
Jul 21, 2025 -
Fix creation of sharding rules for fused_attention_stablehlo.
#30343 merged
Jul 20, 2025 -
[JAX] Enable Shardy by default in JAX.
#30302 merged
Jul 20, 2025 -
[mutable-arrays] remat discharge rule
#29370 merged
Jul 19, 2025 -
[mutable-arrays] basic shard_map + mutable arrays support
#30340 merged
Jul 19, 2025 -
Add JAX tests for deadlock verifier
#30264 merged
Jul 19, 2025 -
Fix CI config for non-rbe CUDA tests with py_import.
#30329 merged
Jul 19, 2025 -
[mosaic] tpu_custom_call.CostEstimate is now a typed dict
#30303 merged
Jul 19, 2025 -
Reverts 7c394f41574b1f669cbb6a47fd03bf1925a19444
#30331 merged
Jul 19, 2025 -
[Pallas/TPU] Interpret mode: Don't use source_info context manager if it's None
#30326 merged
Jul 19, 2025 -
Use operand.aval.sharding in convert_element_type's transpose rule
#30330 merged
Jul 19, 2025 -
Fixed a bug previously hidden by jax_use_direct_linearize=True
#30325 merged
Jul 19, 2025 -
Move ShapeDtypeStruct to core.py to break circular deps
#30323 merged
Jul 18, 2025 -
Create workflow for testing bazel cuda non rbe + py import
#28860 merged
Jul 18, 2025 -
Replace jax.stages.OutInfo with jax.ShapeDtypeStruct
#30318 merged
Jul 18, 2025 -
#sdy condition jax export tests on compatibility version when using shardy
#30312 merged
Jul 18, 2025 -
[Mosaic GPU] Fix barrier indexing in tests.
#30307 merged
Jul 18, 2025 -
Exit the CUDA test job early if not all wheels can be downloaded.
#30316 merged
Jul 18, 2025 -
Add eval_shape as a method of Traced.
#30314 merged
Jul 18, 2025 -
#sdy fix-forward jaxlib compatibility issue
#30311 merged
Jul 18, 2025 -
Use direct linearization by default
#30262 merged
Jul 18, 2025 -
[Scaled Matmul] add a sharding rule and fix custom partition method
#30250 merged
Jul 18, 2025 -
[JAX] add a missing sharding rule in ffi.md and ffi.ipynb
#30305 merged
Jul 18, 2025 -
[Mosaic GPU] Forbid the assignment of splat layouts to non-splat constants.
#30280 merged
Jul 18, 2025 -
[Mosaic GPU] Prevent vector.load ops from being assigned splat layouts.
#30278 merged
Jul 18, 2025 -
[Mosaic GPU] Introduce constraints in the equational layout inference.
#30275 merged
Jul 18, 2025 -
#sdy add missing include in py_client.cc for Shardy.
#30300 merged
Jul 18, 2025 -
[Mosaic GPU] Add lowering for tmem_alloc and tmem_dealloc.
#30223 merged
Jul 18, 2025 -
[Pallas/Mosaic GPU] Add a deviceless export test for Pallas/Mosaic GPU.
#30298 merged
Jul 18, 2025 -
[Mosaic] Add store canonicalization for an expand reshape->store fusion
#30254 merged
Jul 18, 2025 -
[Mosaic TPU] Allow producing predicate phi node in control flow.
#30292 merged
Jul 18, 2025 -
Fix JAX_ENABLE_X64 env var name for CI scripts.
#30295 merged
Jul 17, 2025 -
[XLA] Add a primitive for tagging individual XLA ops with frontend_attributes.
#30282 merged
Jul 17, 2025 -
Add type interface for jax.nn
#30196 merged
Jul 17, 2025 -
layout: Implement out-of-line template method as a non-template.
#30285 merged
Jul 17, 2025 -
[rollforward]: Add make_transfer_server_interface_factory to the py_socket_transfer jaxlib
#30283 merged
Jul 17, 2025 -
fused_attention_stablehlo_test_gpu: test skip for 12.0
#30272 merged
Jul 17, 2025 -
[Pallas] Add option to allow out-of-bound reads in TPU interpret mode.
#30259 merged
Jul 17, 2025 -
[Mosaic GPU] Add the Distinct constraint to the equation system.
#30273 merged
Jul 17, 2025 -
[rollforward]: Add transfer_server_interface to xla.cc as an extra optional factory argument
#30269 merged
Jul 17, 2025 -
[Pallas/TPU] Add empty_ref_like function + SMEM support for Refs
#30239 merged
Jul 17, 2025 -
Rollback #30206 due to downstream test failures
#30266 merged
Jul 17, 2025 -
Use registers for state by default in emit_pipeline
#30265 merged
Jul 16, 2025 -
[XLA] Refactor xla_metadata.py to resolve dependency cycle
#30240 merged
Jul 16, 2025 -
Add nn.logmeanexp.
#30206 merged
Jul 16, 2025 -
[Mosaic] Allow sublane rotation for non-sublane dim aligned shape.
#29419 merged
Jul 16, 2025 -
Reverts aee89744efbd3aa93ffea7a3bab803483add9289
#30260 merged
Jul 16, 2025 -
Disable too_slow in data.draw() for test_cast_from_32bit
#30261 merged
Jul 16, 2025 -
[jax:benchmark] Add tracing benchmarks for some common operations.
#29413 merged
Jul 16, 2025 -
Set check_leaks=False in def direct_linearize to fix some tests
#30256 merged
Jul 16, 2025 -
[Mosaic] Support multiple non-contracting dims if they are collapsable.
#30076 merged
Jul 16, 2025 -
fix jax2tf scatter impl
#30255 merged
Jul 16, 2025 -
[pallas] pl.core_map now supports functools.partial'ed functions
#30220 merged
Jul 16, 2025 -
[Mosaic TPU] Fix the assumption that beforeBody's inputs size is same as afterBody's in scf.while.
#30244 merged
Jul 16, 2025 -
support cudnn sdpa on gb300 (compute_cap=10.3) with cudnn > 9.11
#30242 merged
Jul 16, 2025 -
[Mosaic GPU] Add tmem_alloc and tmem_dealloc to the Mosaic GPU dialect.
#30183 merged
Jul 16, 2025 -
Includes sharding_constraint_p primitive in roofline.
#30232 merged
Jul 16, 2025 -
[Mosaic GPU] Add tcgen05.mma to the Mosaic GPU dialect.
#30189 merged
Jul 16, 2025 -
[Mosaic GPU] Add equational layout inference rule for scf.WhileOp.
#30230 merged
Jul 16, 2025 -
[Mosaic GPU] Add equational layout inference rule for scf.ForOp.
#30229 merged
Jul 16, 2025 -
[Pallas] Physicalize fusion-dtype input avals before evaluating expressions.
#30241 merged
Jul 16, 2025 -
Add more items in the ragged paged attention auto tuning table
#30151 merged
Jul 15, 2025 -
Remove version checks for ml_dtypes >= 0.5.0
#30235 merged
Jul 15, 2025 -
[Pallas] Support all integer casting cases.
#30127 merged
Jul 15, 2025 -
[Pallas][Mosaic GPU] Thread the python kernel function name to the MLIR kernel name.
#30200 merged
Jul 15, 2025 -
[doc] add docs for jax.lax.reduce_window
#30225 merged
Jul 15, 2025 -
Mark the jax.util submodule as deprecated
#30227 merged
Jul 15, 2025 -
[direct-linearize] add instantiate to linearize_jaxpr, fixpoint test
#30231 merged
Jul 15, 2025 -
[doc] replace tensorflow.org/xla with openxla.org/xla
#30224 merged
Jul 15, 2025 -
[pallas:mosaic] Log an error when an AsyncCopyDescriptor was unused
#30219 merged
Jul 15, 2025 -
Add default window_strides and padding to lax.reduce_window.
#30179 merged
Jul 15, 2025 -
Bump jax_export_calling_convention_version to 10
#30186 merged
Jul 15, 2025 -
[Pallas] Adds context manager for forcing TPU interpret mode.
#29909 merged
Jul 15, 2025 -
Raise exception instead of returning NotImplementedError.
#30222 merged
Jul 15, 2025 -
Fix to allow slicing into replicated dims
#30216 merged
Jul 15, 2025 -
Add missing not-implemented sharding rule in third_party/py/jax/_src/cudnn/fused_attention_stablehlo.py.
#30215 merged
Jul 15, 2025 -
[Mosaic GPU] Remove unused MosaicGPU_WGMMALayout from the dialect.
#30217 merged
Jul 15, 2025 -
[Mosaic GPU] Improve error messages in mosaic_gpu_get_function
#30199 merged
Jul 15, 2025 -
[shape_poly] Add support for triu_indices and tril_indices
#30211 merged
Jul 15, 2025 -
[Pallas/Mosaic GPU] Introduce the print_layout primitive.
#30201 merged
Jul 15, 2025 -
remove caveat about random bit generation from auto-parallelization doc
#30212 merged
Jul 15, 2025 -
[JAX] Use JAX test utils to assert array equality/closeness
#30210 merged
Jul 15, 2025 -
If operand and indices are fully replicated, allow that case to work in gather.
#30208 merged
Jul 14, 2025 -
Minor cleanup in stages.py
#30192 merged
Jul 14, 2025 -
Eliminate skiptest for C API in memories_test and add TPU C API config
#30207 merged
Jul 14, 2025 -
[pallas:mosaic] Added "core_parallel" to the allowed dimension_semantics=
#30152 merged
Jul 14, 2025 -
[direct-linearize] fix hijax tests with JAX_USE_DIRECT_LINEARIZE=1
#30161 merged
Jul 14, 2025 -
[pallas] Removed some unused internal functions
#30182 merged
Jul 14, 2025 -
[Pallas] Fix some sus tests, make them faster
#30181 merged
Jul 14, 2025 -
Add clarification to softmax and log_softmax regarding masked-out elements.
#30167 merged
Jul 14, 2025 -
Gather sharding rule in explicit mode
#30184 merged
Jul 14, 2025 -
[pallas:mosaic] Ensured that a nested jit in tpu_fusible_matmul_test.py has no compiler_options=
#30193 merged
Jul 14, 2025 -
[Mosaic GPU] Add constraints to EquationSystem.
#30149 merged
Jul 14, 2025 -
[Mosaic GPU] Fix a race in test_remote_async_copy
#30159 merged
Jul 14, 2025 -
[Mosaic] Fold load->reshape and reshape->store into strided load/store respectively.
#29640 merged
Jul 12, 2025 -
Summarize large constants that may appear in Jaxpr literals
#30173 merged
Jul 12, 2025 -
Pallas/Mosaic support host input.
#30102 merged
Jul 12, 2025 -
Allow dropping into explicit_axes with a reshard in an outer vmap with spmd_axis_name set.
#30165 merged
Jul 11, 2025 -
[Profiler] Correct xprof trace call and update tests
#30155 merged
Jul 11, 2025 -
Shard flash_attention_test
#30162 merged
Jul 11, 2025 -
Remove deprecated jax.util.HashableFunction
#30128 merged
Jul 11, 2025 -
[NFC][Mosaic:TPU] Clean up elementwise rules
#30139 merged
Jul 11, 2025 -
Fix pjit failures with direct linearize
#30160 merged
Jul 11, 2025 -
[direct-linearize] fix tests to pass under JAX_USE_DIRECT_LINEARIZE=1
#30040 merged
Jul 11, 2025 -
The repr of Literal should mention the class name "Literal" in it
#30154 merged
Jul 11, 2025 -
[pallas:mosaic] Enabled a few more primitives for all core types
#30153 merged
Jul 11, 2025 -
[Pallas] Allow 8bit gathers
#30144 merged
Jul 11, 2025 -
Compute non-zero flops for gather with FILL_OR_DROP.
#29914 merged
Jul 11, 2025 -
Simplify the handling of closed-over constants in Jaxprs
#29882 merged
Jul 11, 2025 -
[Pallas:MGPU] Fix a B200 test that was still using the recently removed direct TMEM load syntax
#30148 merged
Jul 11, 2025 -
[Mosaic GPU][NFC] Clean up pattern matching logic.
#30146 merged
Jul 11, 2025 -
[Mosaic GPU][NFC] Remove duplicated code.
#30145 merged
Jul 11, 2025 -
[Pallas:MGPU] Raise a helpful error message in case someone tries to load/store to TMEM refs directly
#30142 merged
Jul 11, 2025 -
[Pallas:MGPU] Disallow direct reads from TMEM refs
#30119 merged
Jul 11, 2025 -
[Mosaic:TPU][Relayouts] Materialize replicated offsets
#29760 merged
Jul 11, 2025 -
[Pallas:MGPU] Expose block-scaled MMA to Pallas on Blackwell
#30116 merged
Jul 11, 2025 -
[Mosaic GPU][NFC] Add utils to create and check the SMEM memory space attribute.
#30114 merged
Jul 11, 2025 -
[Pallas:MGPU] Align TMEM allocations to 16 bytes
#30115 merged
Jul 11, 2025 -
[Mosaic GPU] Fix B200 CI failure - missed one call site while changing function signature
#30136 merged
Jul 11, 2025 -
[Mosaic:TPU][Relayouts] Fallback for implicit dim change that goes through 32-bit native tiling
#29799 merged
Jul 11, 2025 -
roll-forward with fixes
#30135 merged
Jul 11, 2025 -
[Pallas] Make platforms param configurable for pl.lower_as_mlir.
#30121 merged
Jul 11, 2025 -
fix cudnn sdpa invalid seqlen for unused segments
#30023 merged
Jul 11, 2025 -
Reverts d7badcae5a2ab3009641c1a06e9490c5ba2013ac
#30134 merged
Jul 11, 2025 -
On-the-fly trace-time DCE using Python ref counting
#30062 merged
Jul 11, 2025 -
Add scatter* primitives to roofline.
#29824 merged
Jul 10, 2025 -
Increase mutable_array_test sharding
#30131 merged
Jul 10, 2025 -
remove pjit_p, leaving only jax.extend.core.primitives.jit_p
#30080 merged
Jul 10, 2025 -
Add is_ref to ShapeDtypeStruct to allow doing AOT with duck types.
#30129 merged
Jul 10, 2025 -
[Pallas] Update changelog for new emit_pipeline features.
#30124 merged
Jul 10, 2025 -
Add sharding to fused_attention_stablehlo_test_gpu
#30126 merged
Jul 10, 2025 -
[Pallas TPU] Refactor casting logic to use bits instead of bytes and allow uint upcasts.
#30103 merged
Jul 10, 2025 -
[Mosaic GPU][NFC] Enable the gmem argument to async_copy to be an ir.BlockArgument
#30112 merged
Jul 10, 2025 -
[Mosaic GPU] Add CustomReturnOp to serve as the terminator of CustomPrimitiveOp.
#30111 merged
Jul 10, 2025 -
Remove a hanging test from the TPU continuous and nightly wheel test.
#30099 merged
Jul 10, 2025 -
[Mosaic GPU] Get rid of func.FuncOp ops in gpu_layout_inference_test.
#30087 merged
Jul 10, 2025 -
[Pallas:MGPU] Allow specifying layout for TMEM refs
#30110 merged
Jul 10, 2025 -
Internal CI Change
#30108 merged
Jul 10, 2025 -
[Mosaic TPU] Explicitly specify the padding value in vector.transfer_read
#30117 merged
Jul 10, 2025 -
[pallas:mosaic] Enabled pytype for Mosaic GPU lowering
#30100 merged
Jul 10, 2025 -
[Mosaic GPU][NFC] Rename equations to eqns in gpu_layout_inference_test.py.
#30085 merged
Jul 10, 2025 -
[Mosaic GPU] Add a rule for mosaic_gpu.WGMMAOp in the equational layout inference system.
#30020 merged
Jul 10, 2025 -
[Mosaic GPU] Add support for reductions with data partially replicated across warps
#30053 merged
Jul 10, 2025 -
[Mosaic GPU] Add layout inference rule for vector.splat in the equational layout inference.
#30018 merged
Jul 10, 2025 -
[Mosaic GPU] Increase fuel parameter for peer id recomputation.
#30098 merged
Jul 10, 2025 -
[Mosaic] Add support for dynamic filter values
#29938 merged
Jul 10, 2025 -
Allow async remote copy to take device id as a dict of axis names and indices.
#29911 merged
Jul 10, 2025 -
[shard-map] fix partial manual eager mode
#30078 merged
Jul 9, 2025 -
Allow P(('a', 'b')) in aval.sharding where mesh is ('a', 2, Explicit), ('b', 2, Auto).
#30083 merged
Jul 9, 2025 -
Move xla_workspace4() and xla_workspace3() calls above Python init.
#30088 merged
Jul 9, 2025 -
Deprecate jax.interpreters.xla.canonicalize_dtype.
#30082 merged
Jul 9, 2025 -
Add support for logging vectors on SC.
#29955 merged
Jul 9, 2025 -
Deprecate jax.lib.xla_bridge.get_compile_options
#30079 merged
Jul 9, 2025 -
[pallas:mosaic] Enabled pytype for Mosaic TPU lowering
#30061 merged
Jul 9, 2025 -
[Pallas:MGPU] Add support for reductions that might require cross-warp communication
#30054 merged
Jul 9, 2025 -
[Mosaic GPU] Insert nvvm.minctasm = 1 to make sure ptxas recognizes setmaxnreg
#30090 merged
Jul 9, 2025 -
Finalize deprecation of jax.core.typecheck
#30069 merged
Jul 9, 2025 -
Add use_shardy_partitioner in tf.tpu.XlaOptions, TPUReplicateMetadata. The default value is false.
#29646 merged
Jul 9, 2025 -
[Mosaic GPU] Improve the error message when FragmentedArray slice is not tile aligned
#30086 merged
Jul 9, 2025 -
[Pallas:MGPU] Create a programmatic way to derive reduced layouts from tiled layouts in Pallas
#30051 merged
Jul 9, 2025 -
[mosaic_gpu] Enable pytype for Mosaic GPU .py files
#30065 merged
Jul 9, 2025 -
[Mosaic GPU] Shorten the name of equations.Expression variants.
#29980 merged
Jul 9, 2025 -
remove unused internal function
#30077 merged
Jul 9, 2025 -
Simplify as_manual_mesh now that shard_map only tracks manual axes and not auto as it did previously.
#30081 merged
Jul 9, 2025 -
[Pallas TPU] Enable constraining memory spaces of Refs and closing over them in core_map.
#30035 merged
Jul 9, 2025 -
jax/BUILD: switch to pytype_strict_library everywhere
#30028 merged
Jul 9, 2025 -
Finalize deprecation of jax.lib.xla_client.DeviceAssignment
#30074 merged
Jul 8, 2025 -
[Pallas TPU] Add lookahead for input buffers to emit_pipeline.
#30070 merged
Jul 8, 2025 -
[Mosaic] Add tpu.wait_indirect_dma
#29935 merged
Jul 8, 2025 -
Fix broken JAX AI Stack links and tweak surrounding wording + styling
#30037 merged
Jul 8, 2025 -
[Pallas] Make CommsEffect lowerable
#30075 merged
Jul 8, 2025 -
Add a few tests to check lowering cache hits
#30072 merged
Jul 8, 2025 -
In print_environment_info, print environment variables that start with XLA_
#28839 merged
Jul 8, 2025 -
Add jax.tree.reduce_associative.
#29997 merged
Jul 8, 2025 -
[Pallas] Properly handle consts in while_loop fusible_dtype rule.
#30067 merged
Jul 8, 2025 -
Remove print statement from tree_util_test.TreeTest.testStringRepresentation.
#30071 merged
Jul 8, 2025 -
Update changelog
#30060 merged
Jul 8, 2025 -
Reverts e562cddc3421f3ee10a031ae21c48ed85f3edbb2
#30064 merged
Jul 8, 2025 -
Add heartbeat argument to distributed.initialize.
#30034 merged
Jul 8, 2025 -
[Mosaic GPU] Enable mypy and fix type errors for dialect_lowering.py.
#30046 merged
Jul 8, 2025 -
[Mosaic GPU] Remove a few stale jaxlib guards.
#30045 merged
Jul 8, 2025 -
Fix checking presence of in_devices in _get_global_axis_size.
#30030 merged
Jul 8, 2025 -
[Mosaic GPU] Add optimization_barrier to the new equational layout inference.
#29974 merged
Jul 8, 2025 -
[Mosaic GPU] Add support for reductions across a subset of warps
#30049 merged
Jul 8, 2025 -
[Mosaic GPU] Make sure that TMEMLayout.canonicalize returns a TMEMLayout
#30059 merged
Jul 8, 2025 -
[Mosaic GPU] Add support for elementwise ops in the new equational layout inference.
#29973 merged
Jul 8, 2025 -
[pallas:mosaic] Semaphore lowering rules are now registered for all core types
#30031 merged
Jul 8, 2025 -
Enable lowering ragged dot through GPU
#30006 merged
Jul 8, 2025 -
[Mosaic GPU] Replace repeated calls to reduce_hint with reduce_hints.
#29962 merged
Jul 8, 2025 -
[jax:pallas] Remove weakref_lru_cache from _trace_kernel_to_jaxpr.
#29934 merged
Jul 8, 2025 -
Mark ApiTest.test_pmap_global_cache as thread-unsafe.
#30057 merged
Jul 8, 2025 -
scaled_matmul_stablehlo_test: add multiaccelerator tag
#30047 merged
Jul 8, 2025 -
[Mosaic GPU] Remove trivial warp and lane_dims while canonicalizing TiledLayouts
#30052 merged
Jul 8, 2025 -
[Mosaic GPU] Allow the loose extraction of assignments from hints.
#29971 merged
Jul 8, 2025 -
[pallas:mosaic] Added return types to async copy APIs
#30032 merged
Jul 8, 2025 -
Fix ragged contracting mode to handle batch dimensions correctly
#30005 merged
Jul 8, 2025 -
[Mosaic GPU] Change the introduction of assignments for Hints to account for unsatisfiable systems.
#29970 merged
Jul 8, 2025 -
[Mosaic GPU] Implement meet and join for replicated layouts.
#29952 merged
Jul 8, 2025 -
[Pallas:MGPU] Fix a type annotation for layouts in MGPU primitives
#30014 merged
Jul 8, 2025 -
rename pjit_p to jit_p internally, export both as aliases
#29998 merged
Jul 8, 2025 -
[Pallas TPU] Enable pltpu.HBM as BlockSpec memory space
#29873 merged
Jul 8, 2025 -
[pallas] Support reshape with trailing 1s in fuser.
#29920 merged
Jul 8, 2025 -
[pallas] Generalize support for merging dimensions in fuser reshape.
#29800 merged
Jul 8, 2025 -
[caches] Fix pe.close_jaxpr cache leak
#29808 merged
Jul 8, 2025 -
[caches] Register weakref_lru_caches with util._caches
#30039 merged
Jul 8, 2025 -
[mutable-arrays] fix scan ad bug
#30026 merged
Jul 8, 2025 -
Reverts 8d304fe7635fc450817b6abae5ef703a397e9b68
#30038 merged
Jul 8, 2025 -
Finalize several deprecations in jax.core and jax.lib.xla_client
#30036 merged
Jul 7, 2025 -
[caches] Simplifications in the management of caches
#30022 merged
Jul 7, 2025 -
Fix a bug where vmapping over an explicit sharded dim on 1 mesh axis led to a KeyError in shard_map.
#30029 merged
Jul 7, 2025 -
Remove deprecated APIs jax.lib.xla_extension.ArrayImpl and XlaRuntimeError
#30025 merged
Jul 7, 2025 -
Finalize deprecation of jax.extend.ffi
#30013 merged
Jul 7, 2025 -
Simplify the jaxpr emitted for jnp.repeat in the case that the input dim size is 1.
#30024 merged
Jul 7, 2025 -
Finalize deprecations of jax.interpreters.xla.abstractify and pytype_aval_mappings
#30021 merged
Jul 7, 2025 -
[JAX] Use qualified names in tracebacks.
#29932 merged
Jul 7, 2025 -
Updated links in NamedSharding and Mesh docstrings
#29977 merged
Jul 7, 2025 -
Remove a number of deprecated symbols in jax.lib.xla_client
#30017 merged
Jul 7, 2025 -
[Mosaic GPU] Fix a broken B200 test
#30016 merged
Jul 7, 2025 -
[Mosaic GPU] Add support for block-scaled f4e2m1fn MMAs
#29985 merged
Jul 7, 2025 -
[debug_info] Fix debug info in presence of move_binders_to_front.
#30012 merged
Jul 7, 2025 -
Fix breakage introduced in lax.parallel build refactor
#30011 merged
Jul 7, 2025 -
Deprecate jax.scipy.special.sph_harm
#30010 merged
Jul 7, 2025 -
Dropped a few unused functions
#30007 merged
Jul 7, 2025 -
[Pallas:MGPU] Allow scalar division in warp-level code
#30004 merged
Jul 7, 2025 -
Improve rendering of function names in jaxpr equation profile.
#29918 merged
Jul 7, 2025 -
Readd a name to the MLIR lowering of core.closed_call.
#29913 merged
Jul 7, 2025 -
[Pallas:MGPU] Add support for n == 256 to block-scaled MMA
#29978 merged
Jul 7, 2025 -
[Pallas:MGPU] Flip the incorrectly ordered arguments of infer_tmem_cols_layout
#29979 merged
Jul 7, 2025 -
[Mosaic GPU] Change simplify_* to reduce_* everywhere to make the terminology consistent and accurate.
#29961 merged
Jul 7, 2025 -
Fix typo in references for numpy.column_stack()
#29988 merged
Jul 7, 2025 -
tweak pe._drop_unused_vars so we only call make_jaxpr_effects once
#29995 merged
Jul 6, 2025 -
split out some cleanup from #29967
#29994 merged
Jul 6, 2025 -
This change fixes 3 issues:
#29992 merged
Jul 6, 2025 -
implements non-factorially scaled jet computation
#29676 merged
Jul 5, 2025 -
increase shard count on test file
#29991 merged
Jul 5, 2025 -
Reorder GSPMDSharding constructor methods to ensure fast path is taken.
#29983 merged
Jul 4, 2025 -
Add some asserts to make sure device_list is an instance of xc.DeviceList
#29981 merged
Jul 4, 2025 -
[Pallas:MGPU] Expose TMEM_NATIVE_ROW_LAYOUT
#29975 merged
Jul 4, 2025 -
[Mosaic GPU] Use the newly introduced scale copy API to simplify and strengthen the block scaled test
#29958 merged
Jul 4, 2025 -
[Mosaic GPU] Add a new API to help with copying matmul scales to TMEM
#29957 merged
Jul 4, 2025 -
[jax] Test int4 host compute
#29959 merged
Jul 4, 2025 -
[Pallas:MGPU] Allow explicit arrivals on Barriers with orders_tensor_core=True
#29892 merged
Jul 4, 2025 -
[Mosaic GPU][NFC] Fix the typing annotation for TiledLayout.base_tile_shape.
#29972 merged
Jul 4, 2025 -
[Mosaic GPU][NFC] Fix type of hints in test.
#29969 merged
Jul 4, 2025 -
[mutable-arrays] several fixes for direct-linearize + mutable arrays
#29966 merged
Jul 3, 2025 -
Avoid strong refs to tracers in DynamicJaxprTrace.
#29937 merged
Jul 3, 2025 -
[JAX][DOC] Add optimizer state offloading doc
#28988 merged
Jul 3, 2025 -
[Mosaic GPU] Add logic to derive default layouts in an undetermined system.
#29923 merged
Jul 3, 2025 -
[Mosaic GPU] Unify evaluate_equation and simplify_equation into reduce_equation.
#29928 merged
Jul 3, 2025 -
[Mosaic GPU][NFC] Import equations as eqns in layout_inference2.py.
#29925 merged
Jul 3, 2025 -
[mosaic] better error for in_layout size mismatch
#29899 merged
Jul 3, 2025 -
[Mosaic GPU] Propagate layouts from consumers and producers in the new layout inference.
#29922 merged
Jul 3, 2025 -
[Pallas] Improve the error message when writing a ref to a ref
#29950 merged
Jul 3, 2025 -
[Pallas:MGPU] Add support for TCGEN05_TRANSPOSED layout
#29895 merged
Jul 3, 2025 -
[Pallas:MGPU] Add documentation for tcgen05 functions to the MGPU reference
#29894 merged
Jul 3, 2025 -
[Mosaic] Add tpu.enqueue_indirect_dma verification test on TC
#29936 merged
Jul 3, 2025 -
Move jax._src.image to its own build rule
#29889 merged
Jul 2, 2025 -
[dep] remove some finalized deprecations in jax.core
#29943 merged
Jul 2, 2025 -
Replace old heartbeat options with simplified heartbeat options.
#29813 merged
Jul 2, 2025 -
Document out_sharding in jax.numpy functions
#29930 merged
Jul 2, 2025 -
Move jax._src.scipy to its own build rule
#29915 merged
Jul 2, 2025 -
Update xla_flags.md
#29901 merged
Jul 2, 2025 -
[Mosaic GPU] Introduce {Least,Most}ReplicatedExpression in the equation system.
#29890 merged
Jul 2, 2025 -
Add out_sharding to jnp.arange
#29910 merged
Jul 2, 2025 -
Skip test when there aren't enough GPUs.
#29903 merged
Jul 2, 2025 -
[JAX] Add a cache around HLO lowering rules.
#29789 merged
Jul 1, 2025 -
Disallow aliased mutable array arguments to vmap.
#29433 merged
Jul 1, 2025 -
[Pallas TPU] Add support for per-input multiple buffering to emit_pipeline.
#29821 merged
Jul 1, 2025 -
Fix JAX scatter docs.
#29825 merged
Jul 1, 2025 -
Remove the implementation of infeed and outfeed.
#29854 merged
Jul 1, 2025 -
Move jax._src.tpu to its own build rule
#29902 merged
Jul 1, 2025 -
Reverts 822c4f8534b7bd4ba9eea1d333fb9257af9285e7
#29905 merged
Jul 1, 2025 -
Make axis_data.explicit_mesh_axis dynamic so that its value depends on the current mesh context.
#29877 merged
Jul 1, 2025 -
Fix failing test because of stale StableHLO version.
#29900 merged
Jul 1, 2025 -
Use more numerically stable computation for jax.random.logistic.
#29857 merged
Jul 1, 2025 -
Increase error tolerance for cudnn sdpa fp8 inference test
#29501 merged
Jul 1, 2025 -
simplify our jaxpr builder not to hold strong refs to tracers
#29898 merged
Jul 1, 2025 -
Move jax._src.nn to its own build rule
#29879 merged
Jul 1, 2025 -
[Mosaic GPU] Implement a skeleton for the new equation-system-driven layout inference.
#29887 merged
Jul 1, 2025 -
[Mosaic GPU] Factor parametrization of WGMMA tests
#29891 merged
Jul 1, 2025 -
[pallas:triton] Use metadata instead of TritonCompilerParams.serialized_metadata
#29885 merged
Jul 1, 2025 -
[Pallas][Mosaic GPU] Rename for_tensor_core to orders_tensor_core
#29888 merged
Jul 1, 2025 -
[Pallas:MGPU] Make TMEM reads and writes explicitly asynchronous
#29886 merged
Jul 1, 2025 -
[Pallas:MGPU] Add support for single-level slicing of the WGMMA accumulator
#29884 merged
Jul 1, 2025 -
[Pallas:TPU] Canonicalize axis in pltpu.repeat
#29880 merged
Jul 1, 2025 -
[Mosaic GPU] Add an __and__ implementation for EquationSystem.
#29777 merged
Jul 1, 2025 -
Automated Code Change
#29840 merged
Jul 1, 2025 -
Silence buildifier string list warnings
#29878 merged
Jul 1, 2025 -
print pjit_p as "jit" (in jaxpr, etc.)
#29876 merged
Jul 1, 2025 -
[Pallas] Pass a metadata dict to mosaic
#29531 merged
Jul 1, 2025 -
Correct documentation in gpu_memory_allocation.rst
#29864 merged
Jul 1, 2025 -
[doc] Clarify Profiling Docs for XProf and Tensorboard integration
#29509 merged
Jun 30, 2025 -
Add memory_space to state.AbstractRef and use it instead of AbstractMemoryRef in Pallas
#29761 merged
Jun 30, 2025 -
Move `jax._src.debugger` to its own build rule
#29870 merged
Jun 30, 2025 -
Move `jax._src.blocked_sampler` to its own build rule
#29869 merged
Jun 30, 2025 -
[remat] fix 'rematted_computation' annotation
#29868 merged
Jun 30, 2025 -
[Mosaic] Add DMA source and target shape verification
#29755 merged
Jun 30, 2025 -
Show stdout and stderr on test failure.
#29858 merged
Jun 30, 2025 -
Move jax._src.random to its own build target
#29855 merged
Jun 30, 2025 -
Change ValueError to NotImplementedError when setting partitioned=True on callbacks for non-CPU/GPU devices.
#29850 merged
Jun 30, 2025 -
[CI] Relax the timeout constraint on pytest_tpu tests
#29861 merged
Jun 30, 2025 -
[Mosaic GPU] Fix another small bug in FragmentedArray.transfer_tiled
#29851 merged
Jun 30, 2025 -
[JAX] Remove lax.infeed and lax.outfeed from JAX's public APIs.
#29815 merged
Jun 30, 2025 -
[Pallas:TPU] Bump the too low minimum libtpu version for a test
#29852 merged
Jun 30, 2025 -
[Mosaic GPU] Add an initial skeleton for an equation system for layout/transform inference.
#29771 merged
Jun 30, 2025 -
[Mosaic GPU] Fix a minor bug in FragmentedArray.transfer_tiled
#29847 merged
Jun 30, 2025 -
[debug_info] Improve debug info for loop constructs.
#29539 merged
Jun 30, 2025 -
[Mosaic GPU] Fix failures in B200 CI
#29845 merged
Jun 30, 2025 -
[Mosaic GPU] Fix bug in `Tiling.remove_dimension`.
#29844 merged
Jun 30, 2025 -
[Mosaic:TPU][Relayouts] 2nd minor implicit -> minor implicit
#29795 merged
Jun 30, 2025 -
[Mosaic:TPU][Relayouts] (1, 128 * packing) <-> (n * packing, 128) retilings for n > 1
#29732 merged
Jun 30, 2025 -
Move jax._src.dlpack to its own build rule
#29831 merged
Jun 29, 2025 -
Do not propagate name_stacks into lower_jaxpr_to_fun.
#29783 merged
Jun 28, 2025 -
Populate the HLO op_type field from JAX StableHLO.
#29811 merged
Jun 28, 2025 -
Move jax._src.debugging to its own build rule
#29830 merged
Jun 28, 2025 -
Fix `dynamic_update_slice` transpose rule to create `zeros` on the same sharding as `update`
#29828 merged
Jun 28, 2025 -
mutable array scan ad fix
#29817 merged
Jun 28, 2025 -
Move jax._src.checkify to its own build rule.
#29826 merged
Jun 28, 2025 -
Remove JAX_IFRT_VERSION_NUMBER check for types in XLA/JAX
#29784 merged
Jun 28, 2025 -
Add a test for bf16 1D memref load -> broadcast issue.
#29661 merged
Jun 27, 2025 -
[Mosaic] Support faster packing/unpacking for bf16 <-> f8e5m2 and f8e4m3fn.
#29700 merged
Jun 27, 2025 -
[Mosaic:TPU][Relayouts] (packing, 128) to (1, 128 * packing)
#29729 merged
Jun 27, 2025 -
Include `dispatch` primitives in `roofline`.
#29793 merged
Jun 27, 2025 -
Add `out_sharding` to `jnp.ones`, `jnp.zeros` and `jnp.empty`
#29814 merged
Jun 27, 2025 -
Allow `prng.KeyTy` as a valid `RooflineShape.dtype`.
#29790 merged
Jun 27, 2025 -
[Pallas MGPU] Remove unnecessary synchronization if we’re not copying out refs.
#29713 merged
Jun 27, 2025 -
Reverts ebf7d6498526f601da3ffeb3c6083702ceee3940
#29796 merged
Jun 27, 2025 -
Add an environment variable to enable GPU collective cancelling.
#29505 merged
Jun 27, 2025 -
[JAX] Validate efficient resharding via `jax.device_put` with a complex mesh change
#29798 merged
Jun 27, 2025 -
[Pallas:MGPU] Bring back plgpu.load_p to control the optimized flag
#29804 merged
Jun 27, 2025 -
Relax `rtol` in `test_shmap_unreduced_custom_vjp_bwd` for GPU backend
#29812 merged
Jun 27, 2025 -
Flip the order of arguments in `util.wrap_name` to match the usage in implementation
#29785 merged
Jun 27, 2025 -
Add B200 and H100 testing to nightlies
#29778 merged
Jun 27, 2025 -
[jaxpr consts] Propagate closed-over constants to the top-level jit (take 2)
#29768 merged
Jun 27, 2025 -
[Pallas:MGPU] Enable 2CTA tcgen05.mma with M=128 (64 per block) in Pallas
#29805 merged
Jun 27, 2025 -
[pallas] Fix broadcast in `None` block dim in fuser.
#29765 merged
Jun 27, 2025 -
[Mosaic GPU] Generalize transfer_tiled to allow tiled dims split inside and across memory tiles
#29770 merged
Jun 27, 2025 -
[Mosaic GPU][NFC] Refactor transfer_tiled in preparation of new features
#29769 merged
Jun 27, 2025 -
[Pallas:MGPU] Add support for the TMEM_NATIVE layout
#29747 merged
Jun 27, 2025 -
[Pallas:MGPU] Provide output layout hint if a primitive is followed by a layout cast
#29741 merged
Jun 27, 2025 -
[attrs] remove attrs and boxes v1 (attrs_tracked etc)
#29794 merged
Jun 27, 2025 -
[Mosaic GPU] Make the Pallas Blackwell matmul kernel persistent.
#29312 merged
Jun 27, 2025 -
[pallas] Relax check for block spec equality in fuser.
#29766 merged
Jun 27, 2025 -
[mutable-arrays] fix mutable array AD bug with HOPs
#29797 merged
Jun 27, 2025 -
[ra2a] improve ragged_all_to_all batching rule to perform one call
#29763 merged
Jun 26, 2025 -
Add a comment to make the current McJAX-E code a bit clearer
#29792 merged
Jun 26, 2025 -
[Mosaic] Move gather/scatter differentiation to a separate method
#29781 merged
Jun 26, 2025 -
add psend/precv test for 2 GPUs and re-enable tests for 8 GPUs
#29762 merged
Jun 26, 2025 -
[Pallas][Mosaic GPU] Expose tcgen05.commit_arrive as a standalone primitive.
#29673 merged
Jun 26, 2025 -
Reverts 4d142d89e8e260266258ebf55e15620a4eeea8d4
#29788 merged
Jun 26, 2025 -
disable gb300 sdpa test until cudnn supports it properly
#29786 merged
Jun 26, 2025 -
Add _XlaShardingV2 to tf.XlaShardOp and use it for tf2xla lowering.
#29172 merged
Jun 26, 2025 -
Update the next release version to 0.7.0
#29782 merged
Jun 26, 2025 -
Prepare TPU interpret mode for AbstractRefs having a memory_space attr.
#29780 merged
Jun 26, 2025 -
Make the `device_assignment` argument of `.compile` kwarg only
#29776 merged
Jun 26, 2025 -
[Mosaic] Support bf16 <-> f8 casting.
#29557 merged
Jun 26, 2025 -
Remove donated_invars from PjitParams as a field since it already exists in PjitParams
#29779 merged
Jun 26, 2025 -
Remove name_stack argument from lower_jaxpr_to_module.
#29775 merged
Jun 26, 2025 -
[jaxpr consts] Propagate closed-over constants in lax.composite
#29738 merged
Jun 26, 2025 -
Add modes for `config.jax_dump_ir_to` via comma-delimited `config.jax_dump_ir_modes`.
#29750 merged
Jun 26, 2025 -
[Mosaic GPU] Add a general broadcast_in_dim method that works for any TiledLayout
#29737 merged
Jun 26, 2025 -
fix ra2a vmap (batching) rule
#29757 merged
Jun 26, 2025 -
Add a field to plugin_attributes to indicate whether the PjRt plugin supports cross-host device transfers.
#29659 merged
Jun 26, 2025 -
[NFC][Mosaic:TPU] Minor refactor of changeTiling to make chaining packs/unpacks and relayouts easier
#29730 merged
Jun 25, 2025 -
Include `ad_checkpoint` primitives in `roofline`.
#29753 merged
Jun 25, 2025 -
Expose CompiledMemoryStats::peak_memory_in_bytes to Python.
#29668 merged
Jun 25, 2025 -
[JAX] Catch a Python exception from `nb::hash` call
#29756 merged
Jun 25, 2025 -
[Mosaic:TPU][Relayouts] Express some retilings as unpack+pack and remove offset restrictions
#29534 merged
Jun 25, 2025 -
[Pallas][Mosaic GPU] Add support for TMEM Ref aliasing.
#29471 merged
Jun 25, 2025 -
[Mosaic] Add tpu.enqueue_indirect_dma and verification tests
#29665 merged
Jun 25, 2025 -
Add an inline= argument to mlir.register_lowering.
#29711 merged
Jun 25, 2025 -
[Mosaic] Add vmem_shared
#29720 merged
Jun 25, 2025 -
[jaxlib] Guard new transfer library API calls on `JAX_IFRT_VERSION_NUMBER`.
#29746 merged
Jun 25, 2025 -
[JAX] Remove jax.interpreters.mlir.flatten_lowering_ir_args.
#29706 merged
Jun 25, 2025 -
[Mosaic GPU] Add preliminary support for scaled tcgen05.mma
#29656 merged
Jun 25, 2025 -
Update `third_party/xla/workspace.bzl` references to `third_party/xla/revision.bzl`.
#29745 merged
Jun 25, 2025 -
Migrate to `jax.sharding` endpoint for `reshard` and `auto_axes` instead of the experimental endpoint.
#29735 merged
Jun 25, 2025 -
cuFFT: remove deprecated enum values
#29739 merged
Jun 25, 2025 -
[pallas:mosaic] Added `device_id` and `core_id` to `tpu.dma_wait2` op
#29682 merged
Jun 25, 2025 -
[Mosaic:TPU] 16-bit tpu.iota
#28987 merged
Jun 25, 2025 -
Removed unused matplotlib dependency from //tests/mosaic:gpu_test
#29734 merged
Jun 25, 2025 -
Allow all shardings if `exported.nr_devices` is 1 in `_export.py`.
#29690 merged
Jun 25, 2025 -
Update socket transfers to call new transfer library APIs.
#29716 merged
Jun 25, 2025 -
Update ml_dtypes to 0.5.1 to align with JAX and TensorFlow
#29718 merged
Jun 25, 2025 -
Export `jax.P` as an API which is an alias for `jax.sharding.PartitionSpec`.
#29725 merged
Jun 25, 2025 -
[JAX] Fallback to deprecated field output_memory_colors for now, due to a bug in libtpu.
#29712 merged
Jun 25, 2025 -
[mosaic] Removed `kernel` and `backend` parameters from `tpu_custom_call` APIs
#29687 merged
Jun 25, 2025 -
remove use of nanobind for _export pass with shardy
#29555 merged
Jun 25, 2025 -
[pallas] Fixed the type of `Mesh.backend`
#29723 merged
Jun 24, 2025 -
Add JAX documentation for psend and precv API
#29677 merged
Jun 24, 2025 -
Use more concise tracer reprs by default.
#29507 merged
Jun 24, 2025 -
Add tpu7x support in JAX.
#29719 merged
Jun 24, 2025 -
Add a hash/equality to JaxprEqnContext.
#29707 merged
Jun 24, 2025 -
Move `jax._src.cudnn.*` sources to their own build rule
#29708 merged
Jun 24, 2025 -
Rollback "Custom batching rule for Cholesky that ensures symmetrize."
#29710 merged
Jun 24, 2025 -
Add very basic support for shard_map + unreduced.
#29508 merged
Jun 24, 2025 -
[JAX] Disable psend precv shard_map tests temporarily.
#29704 merged
Jun 24, 2025 -
Deprecate jax.dtypes.SUPPORTED_DTYPES and add is_supported_dtype()
#29616 merged
Jun 24, 2025 -
Remove obsolete submodule jax._src.scipy.interpolate
#29696 merged
Jun 24, 2025 -
Removed a few old jaxlib version guards
#29693 merged
Jun 24, 2025 -
Remove type declarations for deprecated APIs
#29701 merged
Jun 24, 2025 -
Fix corner case check in 2x2x2 VF mesh creation.
#29699 merged
Jun 24, 2025 -
[jax:pjit] Add jaxlib version guards to `ArrayPjitTest.jit_mul_sum_sharding_preserved`.
#29692 merged
Jun 24, 2025 -
Custom batching rule for Cholesky that ensures symmetrize.
#29689 merged
Jun 24, 2025 -
[Pallas][NFC] Add documentation for some CompilerParams.
#29688 merged
Jun 24, 2025 -
[Mosaic GPU] Use mgpu.Barrier instead of mgpu.TMABarrier for the MMA barrier
#29650 merged
Jun 24, 2025 -
[Mosaic GPU] Use the 64-bit mov instruction while moving 64-bit timer values
#29648 merged
Jun 24, 2025 -
[Mosaic GPU] Simplify layout inference in TMEMRef.load
#29612 merged
Jun 24, 2025 -
Move jax._src.numpy to its own build rule
#29685 merged
Jun 24, 2025 -
[Mosaic GPU] Add support for 2CTA MMA with M=128 (64 per block)
#29613 merged
Jun 24, 2025 -
[Pallas:MGPU] Enable tcgen05.mma with m=64 in Pallas
#29680 merged
Jun 24, 2025 -
[Mosaic GPU][NFC] Make FragmentedArray.warp_dim into a sequence
#29606 merged
Jun 24, 2025 -
[Mosaic GPU] Add support for 1CTA MMA with M=64
#29602 merged
Jun 24, 2025 -
Automated Code Change
#29599 merged
Jun 24, 2025
123 Pull requests opened by 20 people
-
jnp.diagonal: avoid negative index overhead
#29709 opened
Jun 24, 2025 -
Add metadata for CUDA and libtpu versions
#29715 opened
Jun 24, 2025 -
Add [query,key_value]_seq_offsets arguments to dot_product_attention
#29731 opened
Jun 25, 2025 -
[pallas] Moved `pl.broadcast_to` into a Triton-specific submodule
#29736 opened
Jun 25, 2025 -
[Mosaic] Extend indirect DMA verification to dynamic shapes
#29759 opened
Jun 25, 2025 -
Add nd_loop and Enable block_n tiling for all_gather_lhs_matmul
#29822 opened
Jun 27, 2025 -
[mutable-arrays] systematic scan ad test
#29834 opened
Jun 28, 2025 -
repro multi axes of mgpu collective matmul
#29849 opened
Jun 30, 2025 -
[CI] Fix shell lint warnings in jax GH actions workflows
#29853 opened
Jun 30, 2025 -
Construct XLA GPU client with coordination service client.
#29872 opened
Jun 30, 2025 -
Experimental CuTe DSL Integration into Jax
#29897 opened
Jul 1, 2025 -
Include non-IFRT-PjRT arrays in JAX heap profiles.
#29917 opened
Jul 2, 2025 -
[CI] Update CI builds to Ubuntu 22.04
#29926 opened
Jul 2, 2025 -
Remove pretty_printer Python code.
#29927 opened
Jul 2, 2025 -
Precompute has_changed and will_change during pallas pipelines.
#29931 opened
Jul 2, 2025 -
Read corresponding environment variables for setting config flags.
#29945 opened
Jul 2, 2025 -
[Pallas:MGPU] Missing `div` lowering for LoweringSemantics.Lane and PrimitiveSemantics.Warp.
#29953 opened
Jul 3, 2025 -
Add array-api copy semantics to DLPackManagedTensorToBuffer
#29963 opened
Jul 3, 2025 -
Automated Code Change
#29968 opened
Jul 4, 2025 -
Cleaned up the includes in //jaxlib
#29976 opened
Jul 4, 2025 -
Reordered the GSPMDSharding constructor overloads to prioritize the PyDeviceList version.
#29984 opened
Jul 4, 2025 -
Add gemma-3-12b to auto tune and update tuned block for v6e.
#29999 opened
Jul 7, 2025 -
Clarify `wait_recv` semantics
#30000 opened
Jul 7, 2025 -
[Mosaic] Emulate shrui
#30008 opened
Jul 7, 2025 -
Fix cache leaks for pe._cached_abstract_eval; add util.multi_weakref_lru_cache
#30009 opened
Jul 7, 2025 -
Support complex numbers in pallas `uninitialised_value`
#30043 opened
Jul 8, 2025 -
[Mosaic GPU] Add a layout attribute to `mosaic_gpu.BroadcastInDimOp`.
#30044 opened
Jul 8, 2025 -
Always lower ragged dot for cpu, gpu, and tpu
#30058 opened
Jul 8, 2025 -
[Mosaic GPU] Fix `extract_constant_from_{least,most}_replicated_expression_for_hint` type annotation.
#30063 opened
Jul 8, 2025 -
Reverts e562cddc3421f3ee10a031ae21c48ed85f3edbb2
#30066 opened
Jul 8, 2025 -
Update Protobuf to 6.31.1
#30089 opened
Jul 9, 2025 -
#sdy Remove MHLO shardings from round-trip export
#30091 opened
Jul 9, 2025 -
Don't export Shardy in MPMD before going to IFRT
#30093 opened
Jul 9, 2025 -
[XLA:MGPU][Experimental] HLO -> Pallas.
#30122 opened
Jul 10, 2025 -
[Pallas] Remove deprecated symbols for TPUMemorySpace, TPUCompilerParams, and TritonCompilerParams.
#30123 opened
Jul 10, 2025 -
Refactor to avoid linking DCN transfer server in the Windows build.
#30130 opened
Jul 10, 2025 -
Simplify tree_reduce signature using Unspecified bare class.
#30133 opened
Jul 11, 2025 -
[Mosaic:TPU][Relayouts] Fallback for (1, x) column shifts that retiles
#30137 opened
Jul 11, 2025 -
[Mosaic:TPU] tileArrayShape is 1 for replicated dims
#30138 opened
Jul 11, 2025 -
[Mosaic:TPU][infer-vector-layout] Don't force sublane broadcasts to native tiling
#30140 opened
Jul 11, 2025 -
Experimenting with replicate vs maximal sharding issue
#30143 opened
Jul 11, 2025 -
Debug some fun flakes
#30168 opened
Jul 12, 2025 -
Add DETAILS.md for improved documentation
#30170 opened
Jul 12, 2025 -
Add tip to tree reduce and reduce_associative about how to exclude leaves from the reduction.
#30172 opened
Jul 12, 2025 -
Automated Code Change
#30175 opened
Jul 12, 2025 -
Document all checkpoint policies in one place, on the JAX public API page.
#30177 opened
Jul 12, 2025 -
[jaxprs] Hoist large constants as arguments during lowering
#30180 opened
Jul 14, 2025 -
Debug some fun flakes
#30198 opened
Jul 14, 2025 -
Bump medyagh/setup-minikube from 0.0.19 to 0.0.20
#30203 opened
Jul 14, 2025 -
gather/scatter: push negative index handling into primitives
#30205 opened
Jul 14, 2025 -
Clean up handling of transfer server factory.
#30209 opened
Jul 15, 2025 -
Implementing LİSHT
#30218 opened
Jul 15, 2025 -
Testing orbax changes with shardy enabled by default
#30226 opened
Jul 15, 2025 -
FIX: Handle WGSplatFragLayout in cond lowering.
#30238 opened
Jul 15, 2025 -
Support core axis index in the `device_id` dict for async copy and semaphore.
#30243 opened
Jul 16, 2025 -
[XLA:GPU] Add JAX-based precision tests for Triton and cuBLAS
#30246 opened
Jul 16, 2025 -
Skip Pallas and Mosaic GPU tests that don't fit on RTX 6000 PRO
#30258 opened
Jul 16, 2025 -
[Mosaic:TPU] Explicitly instantiate VectorLayout::print
#30263 opened
Jul 16, 2025 -
Remove `local_config_nvshmem` repository and corresponding macros.
#30267 opened
Jul 17, 2025 -
Implement performance optimized w8a8 pallas kernel
#30268 opened
Jul 17, 2025 -
Add Windows Bazel CPU tests with py_import dependency to continuous tests.
#30286 opened
Jul 17, 2025 -
[Pallas/TPU] Add option to allow skipping the device barrier
#30289 opened
Jul 17, 2025 -
[jax:custom_partitioning] Allow factors for non-batching dimensions to
#30291 opened
Jul 17, 2025 -
[Mosaic][SC] Add custom assembly format to tpu.enqueue_indirect_dma
#30293 opened
Jul 17, 2025 -
[JAX] Disable Shardy in `JaxExportTest` and `CompatTest` if jaxlib version is before 0.7.0.
#30304 opened
Jul 18, 2025 -
[JAX] Disable Shardy in JAX export if jaxlib version is before 0.7.0.
#30306 opened
Jul 18, 2025 -
Test PR for runner
#30315 opened
Jul 18, 2025 -
Update ragged_dot kernels to use new GroupInfo for persistence
#30317 opened
Jul 18, 2025 -
Reverts dd59b47c07caa777f57637107379374e7906de12
#30319 opened
Jul 18, 2025 -
Reverts c9700e637550b6404e85aeae1ff4eb207e1f2d76
#30320 opened
Jul 18, 2025 -
flip flag to check OSS tests
#30321 opened
Jul 18, 2025 -
[Pallas:MGPU] Expose TCGEN05_TMEM_NATIVE_COL
#30322 opened
Jul 18, 2025 -
Add aval out to pull_block_spec signature in fusible dtype
#30332 opened
Jul 19, 2025 -
Automated Code Change
#30334 opened
Jul 19, 2025 -
fix core.Tracer inheritance (just to see what breaks)
#30341 opened
Jul 19, 2025 -
[mutable-arrays] flip JAX_MUTABLE_ARRAY_CHECKS=True by default
#30342 opened
Jul 20, 2025 -
Fix numerical bugs in the gradients of sigmoid/logistic, tanh, expm1.
#30346 opened
Jul 21, 2025 -
Accelerate deprecation for jax.lib.xla_bridge.get_backend.
#30349 opened
Jul 21, 2025 -
Add more Bazel tests to Nightly/Release job.
#30361 opened
Jul 21, 2025 -
Bump fonttools from 4.51.0 to 4.59.0
#30362 opened
Jul 21, 2025 -
Bump fsspec from 2024.5.0 to 2025.7.0
#30363 opened
Jul 21, 2025 -
Bump hypothesis from 6.102.4 to 6.136.1
#30364 opened
Jul 21, 2025 -
[Mosaic] Allow matrix-vector dot.
#30369 opened
Jul 21, 2025 -
[Mosaic] Allow vector::Extract for non-32 bits vector result.
#30371 opened
Jul 21, 2025 -
Bump libtpu version before release.
#30380 opened
Jul 22, 2025 -
#sdy Delete JAX test configs enabling Shardy.
#30381 opened
Jul 22, 2025 -
[pallas] Forked `load` and `store` into `triton` and `tpu`
#30384 opened
Jul 22, 2025 -
Finalize deprecation of xla_bridge.get_compile_options.
#30388 opened
Jul 22, 2025 -
Finalize deprecation of xla_extension.
#30389 opened
Jul 22, 2025 -
Refactor Bazel CPU RBE and Bazel GPU Non-RBE and add more Bazel tests to Nightly/Release job.
#30393 opened
Jul 22, 2025 -
#sdy Delete JAX test configs enabling Shardy.
#30395 opened
Jul 22, 2025 -
[Mosaic] Lower exp2 as math.exp2.
#30396 opened
Jul 22, 2025 -
First prototype of explicit async scheduling feature
#30399 opened
Jul 22, 2025 -
Remove meaningless sdpa util test
#30405 opened
Jul 22, 2025 -
[Mosaic][SC] Declare `tpu.scatter_store`. Stores values of a vector to arbitrary locations in memory.
#30409 opened
Jul 22, 2025 -
Reverts 8061b583e29fb62a39c876048e9a0109ca69beec
#30410 opened
Jul 22, 2025 -
Calculate roofline peak HBM bytes incrementally.
#30412 opened
Jul 22, 2025 -
Roofline: stop printing entire jaxpr when erroring.
#30413 opened
Jul 22, 2025 -
[Pallas][Mosaic GPU] Support pytree in/out specs in the warp specialized pipeline.
#30414 opened
Jul 22, 2025 -
[pallas] Moved atomic APIs to `jax.experimental.pallas.triton`
#30427 opened
Jul 23, 2025 -
Improve jax.nn.standardize numerical stability
#30428 opened
Jul 23, 2025 -
`lax.zeros_like_array` is no longer public
#30429 opened
Jul 23, 2025 -
[pallas:mgpu] Add batching rule for `plgpu.kernel`.
#30433 opened
Jul 23, 2025 -
Fix breakage caused by cl/785610645.
#30435 opened
Jul 23, 2025 -
Expose GPU tracing knobs for 3P
#30436 opened
Jul 23, 2025 -
[Mosaic GPU] Allow tcgen05.mma with different swizzles for A and B
#30438 opened
Jul 23, 2025 -
[RPA] Add random seed and disable unaligned num_heads for RPA test
#30444 opened
Jul 23, 2025 -
Implement PjRtClient::MakeCrossHostReceiveBuffers and PjRtBuffer:CopyToRemoteDevice in the PjRt C API.
#30449 opened
Jul 23, 2025 -
Adding tests for the LİSHT activation function
#30451 opened
Jul 23, 2025 -
[bazel] Create jax/_src/BUILD and move _src build declarations
#30452 opened
Jul 23, 2025 -
[Mosaic GPU] Explicitly set kernel_name on both Pallas and plain Mosaic GPU codepaths.
#30453 opened
Jul 23, 2025 -
[JAX][Mosaic] Add `llvm` namespace to `cast` op in mosaic_gpu.cc.
#30454 opened
Jul 23, 2025 -
Add JAX tests for deadlock verifier
#30455 opened
Jul 23, 2025 -
[Mosaic GPU] Add missing namespace specifier
#30456 opened
Jul 23, 2025 -
Add option to enable coordination service client to recover.
#30457 opened
Jul 23, 2025 -
Add option to enable coordination service client to recover.
#30458 opened
Jul 23, 2025 -
Add option to enable TFRT GPU client in C API plugin.
#30459 opened
Jul 23, 2025 -
Add environment variable to enable TFRT GPU client in GPU plugin.
#30460 opened
Jul 23, 2025 -
[JAX] Remove enable_empty_arrays config flag. Behavior is now always True.
#30463 opened
Jul 24, 2025
57 Issues closed by 20 people
-
Segfault while Performing All-to-All Collective Operation on 8xH100 SXM5 (Shard + Swapaxes + Shard)
#30335 closed
Jul 24, 2025 -
jax.grad precision: float32 gradients of bfloat16 weights
#30337 closed
Jul 23, 2025 -
Fusing the optimizer step into the backward pass
#30338 closed
Jul 23, 2025 -
[jax._src.stages.ArgInfo] does not declare `__len__`, breaking `jit(f).lower()`
#30441 closed
Jul 23, 2025 -
jax.smap is not in the online documentation
#30407 closed
Jul 22, 2025 -
[HPC] Missing libdevice.so.10
#22590 closed
Jul 22, 2025 -
indexing and sharding
#13632 closed
Jul 22, 2025 -
CuSolver: Switch to 64 bit api to allow for eigh on matrices > than 26732x26732
#23413 closed
Jul 22, 2025 -
[How?] Debugging crash in jax.jit function when running under pytest-xdist
#10242 closed
Jul 22, 2025 -
[RFE] Add support for distributed CPU-backend mode
#11182 closed
Jul 22, 2025 -
Failed build: CI - with Numpy/Scipy nightly wheels (nightly)
#30055 closed
Jul 22, 2025 -
Add Gumbel distribution to scipy.stats
#29319 closed
Jul 21, 2025 -
jax.numpy.fft.fftfreq no longer supports complex dtype
#30287 closed
Jul 21, 2025 -
`jax.tree_util.tree_map` fails when a registered pydantic object which has been copied using `deep=True`
#30299 closed
Jul 19, 2025 -
Could you add the support of the new optimizer: Muon
#30309 closed
Jul 19, 2025 -
[sharding-in-types] setting global mesh+complex tensors+linear solve = problems
#30327 closed
Jul 19, 2025 -
Add nn.logmeanexp
#30178 closed
Jul 16, 2025 -
`fori_loop` gets slower after `jit`
#30245 closed
Jul 16, 2025 -
[Pallas, jax 0.6.0] Interpret mode seems to invoke GPU compilation process instead of being CPU only
#30214 closed
Jul 16, 2025 -
No error message but failing zero-copy?
#30228 closed
Jul 16, 2025 -
[sharding-in-types] `jnp.linalg.{slogdet/solve}` do not work with explicit sharding
#29883 closed
Jul 15, 2025 -
[sharding-in-types] jnp.linalg.inv does not work inside of shard_map
#30157 closed
Jul 11, 2025 -
Add include argument to tree_reduce and tree_reduce_associative
#30156 closed
Jul 11, 2025 -
Failing Pallas tests on TPU
#30150 closed
Jul 11, 2025 -
[sharding-in-types] `AxisType.Auto` and `random.normal(..., out_sharding)`
#30095 closed
Jul 11, 2025 -
Cannot extract graph node from different trace level
#30092 closed
Jul 10, 2025 -
RuntimeError: Unable to load cuSPARSE. Is it installed?
#29843 closed
Jul 9, 2025 -
Add associative / parallel version of jax.tree.reduce
#29774 closed
Jul 8, 2025 -
partial_eval.close_jaxpr cache is leaking
#29803 closed
Jul 8, 2025 -
Failed build: CI - with Numpy/Scipy nightly wheels (nightly)
#29807 closed
Jul 7, 2025 -
Numerical Discrepancy Between Equivalent jnp.einsum Formulations
#29990 closed
Jul 6, 2025 -
[sharding-in-types] vma not propagated by full_like and zeros_like
#29965 closed
Jul 6, 2025 -
[sharding-in-types] vma not set correctly by eval_shape
#29987 closed
Jul 6, 2025 -
[sharding-in-types] Add `out_sharding` argument to `jnp.tensordot`
#29986 closed
Jul 6, 2025 -
cusolver internal error on orin
#29802 closed
Jul 4, 2025 -
Extremely slow jax.lax.scan performance on GPU (GTX 1650) with simple sum
#29946 closed
Jul 3, 2025 -
[sharding-in-types] Hard to understand error when composing with vmap
#29839 closed
Jul 1, 2025 -
[sharding-in-types] `jax.random.choice` missing some `pvary` annotations, so fails under `shard_map`
#29881 closed
Jul 1, 2025 -
[sharding-in-types] sharding rule for scatter is not implemented.
#29252 closed
Jul 1, 2025 -
Incorrect documentation in `gpu_memory_allocation.rst`
#29865 closed
Jul 1, 2025 -
Redacted per user request
#29874 closed
Jul 1, 2025 -
[sharding-in-types] `jax.lax.map` fails when `batch_size>len(input)`
#29867 closed
Jun 30, 2025 -
[sharding-in-types] `jnp.reshape` gives surprising assertion error
#29859 closed
Jun 30, 2025 -
[sharding-in-types] scatter is not jit invariant
#29837 closed
Jun 30, 2025 -
`jax.numpy.ndarray.at` OOB indexing with negative indices and `mode="fill"` not correct
#28998 closed
Jun 29, 2025 -
[sharding-in-types] `jax.jacrev` can fail inside shard_map
#29832 closed
Jun 28, 2025 -
jax vmap when used twice returns unexpected behaviour
#29829 closed
Jun 28, 2025 -
Cannot lower jaxpr with verifier errors
#29818 closed
Jun 28, 2025 -
[sharding-in-types] `jnp.nonzero` fails on replicated arrays outside jit
#29654 closed
Jun 27, 2025 -
Default value of JAX_DUMP_IR_MODES does nothing
#29819 closed
Jun 27, 2025 -
[sharding-in-types] Missing `out_sharding` on `jnp.matmul`
#29754 closed
Jun 27, 2025 -
[sharding-in-types] vmap+ jax.random.uniform: 'no mesh found' error
#29694 closed
Jun 25, 2025 -
The question about jax on Windows
#29686 closed
Jun 24, 2025 -
Vmap of dynamic_update_slice very slow on TPU
#21367 closed
Jun 24, 2025 -
[sharding-in-types] cannot multiply replicated jax.Array with np.Array
#29683 closed
Jun 24, 2025
55 Issues opened by 47 people
-
`visualize_array_sharding` unable to show device name on some shapes
#30462 opened
Jul 23, 2025 -
Build error: jaxlib 0.7.0, GCC 13.2.0
#30437 opened
Jul 23, 2025 -
`jax.nn.standardize` returns `nan` when variance is lower than `-epsilon`
#30426 opened
Jul 23, 2025 -
Not Implemented functions in Pallas
#30423 opened
Jul 23, 2025 -
Revisiting `custom_jvp` with `pure_callback`
#30401 opened
Jul 22, 2025 -
Pallas fails to write to output ref when using bfloat16 on TPU v3-8
#30391 opened
Jul 22, 2025 -
Profiling tool shows no GPU processes despite GPU Util being as expected.
#30378 opened
Jul 22, 2025 -
jet and equinox.nn.MLP
#30352 opened
Jul 21, 2025 -
Inconsistent Reduction precision in backwards computation
#30310 opened
Jul 18, 2025 -
jax.numpy.sort performance regression
#30296 opened
Jul 18, 2025 -
Json parse exceptions (and others) in perfetto traces
#30290 opened
Jul 17, 2025 -
[Feature Request] Add Sparse Attention kernel for GPUs in Pallas
#30281 opened
Jul 17, 2025 -
"ComputeCallSignature failed" in verbose logs investigating performance loss.
#30270 opened
Jul 17, 2025 -
Add custom derivative for scipy.special.hyp2f1
#30195 opened
Jul 14, 2025 -
jetson orin nx jetpack 6.2.1 cuSolver internal error
#30188 opened
Jul 14, 2025 -
Spikes in compilation time with jax.jit()
#30185 opened
Jul 14, 2025 -
bfloat16, float16 not supported by lax.linalg.qr, scipy.linalg.expm
#30176 opened
Jul 12, 2025 -
jax_explain_cache_misses is not thread safe
#30163 opened
Jul 11, 2025 -
Mosaic failed to compile TPU kernel: Target does not support this comparison
#30104 opened
Jul 10, 2025 -
`jax.numpy.polysub` brings different results with `numpy.polysub`
#30097 opened
Jul 9, 2025 -
`jax.numpy.var` brings different results with `numpy.var`
#30096 opened
Jul 9, 2025 -
[pallas] Does not support initialising complex values
#30084 opened
Jul 9, 2025 -
SlurmCluster isn't robust to only partially set SLURM_ env variables
#30073 opened
Jul 8, 2025 -
Build and upload cp313t jaxlib wheels to PyPI on all supported platforms
#30068 opened
Jul 8, 2025 -
jaxlib 0.6.2 (CUDA 12) fails to start: “Unable to load cuSPARSE” → falls back to CPU
#30050 opened
Jul 8, 2025 -
Gradient of gather uses 2x to 4x memory
#30015 opened
Jul 7, 2025 -
Error with `jax.custom_batching.sequential_vmap` in the `jax.ensure_compile_time_eval` context
#29996 opened
Jul 6, 2025 -
Wrong kernel selection in transposed conv
#29993 opened
Jul 6, 2025 -
Cannot Initialize TPU on Google Colab
#29989 opened
Jul 5, 2025 -
Creating a differentiable sharded FFT with custom_partition
#29954 opened
Jul 3, 2025 -
custom jvp + transpose raises UnexpectedTracerError
#29948 opened
Jul 3, 2025 -
Foreign function interface (FFI) tutorial bugs
#29924 opened
Jul 2, 2025 -
Inefficient 1D convolution compared to PyTorch
#29875 opened
Jul 1, 2025 -
lax.associative_scan causes kernel crash in JAX 0.6.2 on GPU (PJRT)
#29866 opened
Jun 30, 2025 -
JAX >= 0.5. fails on rocm with lax.dot_general
#29846 opened
Jun 30, 2025 -
Erroneous pmap outputs on multi-device machine
#29841 opened
Jun 30, 2025 -
Decorator versions of pure_callback and io_callback
#29838 opened
Jun 29, 2025 -
jax.numpy.from_dlpack warns for unaligned data
#29810 opened
Jun 27, 2025 -
Jax compilation Profile
#29809 opened
Jun 27, 2025 -
build fails with unknown type name errors
#29801 opened
Jun 27, 2025 -
Add critical path length to cost_analysis
#29773 opened
Jun 26, 2025 -
Performance regression on CPU after 0.4.32
#29772 opened
Jun 26, 2025 -
How can cross-architecture operator libraries be applied in JAX, such as cuBLAS?
#29764 opened
Jun 26, 2025 -
`jax.block_until_ready()` should error on non-arrays by default
#29744 opened
Jun 25, 2025 -
Warning - All configs were filtered out because none of them sufficiently match the hints.
#29740 opened
Jun 25, 2025 -
JAX 0.6.2 and 0.6.1 do not find GPU while 0.6.0 works
#29728 opened
Jun 25, 2025 -
Unexpected 2x memory for scan-like operations over trailing axes of an array
#29714 opened
Jun 24, 2025 -
Pallas tpu batched dot_general tries to broadcast lhs and output
#29698 opened
Jun 24, 2025 -
Large closed-over constants are inlined in the HLO code
#29684 opened
Jun 24, 2025 -
Simplify handling of closed-over constants in Jaxpr
#29679 opened
Jun 24, 2025
46 Unresolved conversations
Sometimes conversations happen on old items that aren’t yet closed. Here is a list of all the Issues and Pull Requests with unresolved conversations.
-
added solve_sylvester and accompanying tests
#28810 commented on
Jul 23, 2025 • 13 new comments -
OOM error on GPU despite more than enough available
#28316 commented on
Jul 15, 2025 • 0 new comments -
Implement JVP for SVD when full_matrices=True
#508 commented on
Jul 15, 2025 • 0 new comments -
Gemma 3 + `jax-metal`: 'mhlo.convolution' op Not supported
#27288 commented on
Jul 17, 2025 • 0 new comments -
Saturating arithmetic
#26566 commented on
Jul 17, 2025 • 0 new comments -
Segmentation fault when calling exported cholesky on CPU
#29610 commented on
Jul 18, 2025 • 0 new comments -
feature request: sparse jacobian and sparse hessians
#1032 commented on
Jul 19, 2025 • 0 new comments -
Support autodiff of Eigendecomposition with repeated eigenvalues
#669 commented on
Jul 21, 2025 • 0 new comments -
[sharding-in-types] `jnp.reshape` not figuring out a reshape output sharding when it should
#29235 commented on
Jul 22, 2025 • 0 new comments -
vmap of cond's predicate results in select, leading to unexpected compute/memory use
#8409 commented on
Jul 23, 2025 • 0 new comments -
[Doc] Rename gpu_performance_tips.md to performance_tips.md with new CPU performance tips session
#24961 commented on
Jul 16, 2025 • 0 new comments -
Consolidate material on debugging NaNs.
#24989 commented on
Jul 7, 2025 • 0 new comments -
Move the section on jitting methods from FAQ to Sharp Bits.
#25273 commented on
Jul 7, 2025 • 0 new comments -
[Draft] cudnn sdpa flex attention
#28048 commented on
Jul 7, 2025 • 0 new comments -
Initial commit for attaching XLA metadata to individual HLO operations via 'jax.attach_metadata(...)'
#28953 commented on
Jul 15, 2025 • 0 new comments -
Bump fsspec from 2024.5.0 to 2025.5.1
#29011 commented on
Jul 21, 2025 • 0 new comments -
[ROCm] ROCm7 Plugin Updates
#29281 commented on
Jul 1, 2025 • 0 new comments -
Refactor `custom_call` to use common `FindCudaExecutable` method from XLA repository to find CUDA binaries.
#29412 commented on
Jul 8, 2025 • 0 new comments -
Add support for deserializing xplanes to Jaxlib
#29431 commented on
Jul 8, 2025 • 0 new comments -
Add jax.nn.min_max_normalize.
#29569 commented on
Jul 9, 2025 • 0 new comments -
[Pallas:MGPU] Add more API links in the reference guide
#29579 commented on
Jun 25, 2025 • 0 new comments -
Move Mosaic CC sources into XLA
#29643 commented on
Jul 1, 2025 • 0 new comments -
[CI] Add bazel TPU presubmit testing
#29660 commented on
Jul 18, 2025 • 0 new comments -
CPU slowdown with new runtime (v0.4.32 and newer)
#26145 commented on
Jun 24, 2025 • 0 new comments -
CPU Slowdown introduced in 0.4.32 and in following versions
#26021 commented on
Jun 26, 2025 • 0 new comments -
XlaRuntimeError: failed to legalize operation 'mhlo.popcnt' during NUTS sampling on Sequoia 15.3/M1 Pro
#28460 commented on
Jun 27, 2025 • 0 new comments -
cuSolverMG support for distributed arrays
#16597 commented on
Jun 30, 2025 • 0 new comments -
Transposed preconditioned GMRES
#29449 commented on
Jun 30, 2025 • 0 new comments -
Slow transpose convolutions (both cpu and cuda backends)
#23783 commented on
Jun 30, 2025 • 0 new comments -
CPU Over-utilization and taskset
#29499 commented on
Jul 1, 2025 • 0 new comments -
Reorganize the tutorials
#24632 commented on
Jul 1, 2025 • 0 new comments -
Results do not match the reference. This is likely a bug/unexpected loss of precision.
#27188 commented on
Jul 2, 2025 • 0 new comments -
[jet] use proper Taylor coefficients during calculation
#29624 commented on
Jul 3, 2025 • 0 new comments -
Gradient of SVD with degenerate singular values becomes NaN
#2311 commented on
Jul 3, 2025 • 0 new comments -
Spoof multiple hosts
#5155 commented on
Jul 3, 2025 • 0 new comments -
jaxlib build fails on FreeBSD
#6076 commented on
Jul 7, 2025 • 0 new comments -
ComplexWarning in VJP when using complex matrix multiplication
#21188 commented on
Jul 7, 2025 • 0 new comments -
Add `jax.nn.normalize`
#20556 commented on
Jul 8, 2025 • 0 new comments -
Global mutable arrays not supported in `custom_jvp`
#27470 commented on
Jul 9, 2025 • 0 new comments -
[GPU] FlashAttention performance lags behind PyTorch
#24934 commented on
Jul 9, 2025 • 0 new comments -
Generalized directional derivative for abs
#12142 commented on
Jul 9, 2025 • 0 new comments -
`jax.numpy.percentile` brings different results with `numpy.percentile`
#29572 commented on
Jul 9, 2025 • 0 new comments -
`jax.numpy.corrcoef` brings different results with `numpy.corrcoef`
#29571 commented on
Jul 9, 2025 • 0 new comments -
ValueError: Mosaic failed to compile TPU kernel: Not Implemented: The last dim size is not 128 in original base memref
#29520 commented on
Jul 10, 2025 • 0 new comments -
NotImplementedError: MLIR translation rule for primitive 'schur' not found for platform cuda
#28927 commented on
Jul 11, 2025 • 0 new comments -
Hard to find documentation of predefined jax checkpoint policies
#12417 commented on
Jul 12, 2025 • 0 new comments