Insights: jax-ml/jax
Overview
98 Pull requests merged by 12 people
[mutable-arrays] make partial_eval_jaxpr forward input-residuals
#29311 merged
Jun 7, 2025 -
[Pallas][Mosaic GPU] Add collective (CTA-pair) MMAs to blackwell matmul kernel.
#29244 merged
Jun 7, 2025 -
Small speedups to pretty-printing.
#29306 merged
Jun 6, 2025 -
[Pallas TPU] Add custom_vjp_call lowering rule
#29302 merged
Jun 6, 2025 -
Reverts c1bb095c5ce5b0286dc5052abf3b597b6f23cea5
#29305 merged
Jun 6, 2025 -
[JAX] Allow registering callbacks to be called when backends are cleared
#29301 merged
Jun 6, 2025 -
Optimize jaxpr equation pretty-printing.
#29296 merged
Jun 6, 2025 -
[Pallas][Mosaic GPU] Expose partitioned collective loads to copy_gmem_to_smem.
#29243 merged
Jun 6, 2025 -
Fix segfault if None is passed to PartitionSpec.__eq__.
#29299 merged
Jun 6, 2025 -
Port PartitionSpec to C++.
#29221 merged
Jun 6, 2025 -
[Mosaic GPU] Extract the type-related logic out of `reinterpret_smem_ref`.
#29270 merged
Jun 6, 2025 -
Reverts 5c33588b30edbae51d5b63b0bd7cc8d9058d7ccb
#29295 merged
Jun 6, 2025 -
[Mosaic GPU] Add support for lowering `2xbf16 -> 2xf8e4m3fn` converts.
#29264 merged
Jun 6, 2025 -
Bring back tree concat optimization for np.array(...)
#29291 merged
Jun 6, 2025 -
lax.top_k: raise error if indices will overflow
#29254 merged
Jun 6, 2025 -
[Pallas][Mosaic GPU] Add support for load/broadcast using TCGEN05 ROW/COL layouts.
#29275 merged
Jun 5, 2025 -
[Pallas][Mosaic GPU] Skip tcgen05 reduce test on WG semantics.
#29279 merged
Jun 5, 2025 -
[Mosaic GPU] Fix `2xf32 -> 2xf8e4m3fn` conversion.
#29263 merged
Jun 5, 2025 -
[Pallas][Mosaic GPU] Support column slicing on TMEM.
#29253 merged
Jun 5, 2025 -
Don't recompute source_info.current() in DynamicJaxprTracer.
#29266 merged
Jun 5, 2025 -
[cleanup] remove core.gensym, and Var.suffix
#29273 merged
Jun 5, 2025 -
lax_numpy: move array and asarray to their own submodule
#29246 merged
Jun 5, 2025 -
doc: clarified lack of gpu support for schur and sqrtm
#29087 merged
Jun 5, 2025 -
skip pytype on slow file
#29274 merged
Jun 5, 2025 -
jnp.array: avoid call to stack
#29224 merged
Jun 5, 2025 -
[Pallas][Mosaic GPU] Use separate allocations for collective TMEM.
#29240 merged
Jun 5, 2025 -
[Mosaic GPU] Move `should_have_transforms` to `inference_utils`.
#29268 merged
Jun 5, 2025 -
[Pallas/Mosaic GPU] Expose the new `TCGEN05_COL` layout.
#29242 merged
Jun 5, 2025 -
fix sharding-in-types + from_edtype
#29258 merged
Jun 5, 2025 -
Add a general system for keeping track of quasi-dynamic data (QDD)
#29245 merged
Jun 5, 2025 -
[Rollback] Roll-forward with fix and test: prototype of cross-host device transfers for TFRT TPU.
#29250 merged
Jun 5, 2025 -
[Pallas Fuser] Add basic reshape push rule
#29251 merged
Jun 5, 2025 -
Link c-api raw buffer support into jaxlib.
#29255 merged
Jun 4, 2025 -
Fix documentation for the CLI `up` command in the debugger.
#29247 merged
Jun 4, 2025 -
Fix typos discovered by codespell
#29146 merged
Jun 4, 2025 -
Add not-implemented sharding rule in `third_party/py/jax/_src/cudnn/fused_attention_stablehlo.py`.
#29241 merged
Jun 4, 2025 -
Fix a rare numerical flake in svd_test seen on TPU v6e.
#29132 merged
Jun 4, 2025 -
Reverts 6cd196a5db22b8db0ed4000e4cf67ad748bf52f3
#29239 merged
Jun 4, 2025 -
[jaxlib] Bind 'compile' to `xla::PyClient::Compile` rather than `xla::PyClient::CompileAndLoad`.
#29212 merged
Jun 4, 2025 -
Raise `NotImplementedError` instead of `ValueError` when using Shardy without sharding rule.
#29233 merged
Jun 4, 2025 -
[Mosaic GPU] Use the `mosaic_gpu.sliceSMEM` MLIR op when using WG semantics.
#29237 merged
Jun 4, 2025 -
#sdy Have JAX export compat tests also run on Shardy.
#29161 merged
Jun 4, 2025 -
[Pallas/Mosaic GPU] Expose the new `TCGEN05_ROW` layout.
#29210 merged
Jun 4, 2025 -
#sdy Fallback to GSPMD in JAX export if the loaded module was lowered for GSPMD.
#29033 merged
Jun 4, 2025 -
Make experimental pytree_serialization visible in OSS jax build
#29229 merged
Jun 4, 2025 -
[Pallas] Fix missing sub lowering rule for sparsecore.
#29226 merged
Jun 4, 2025 -
Update more uses of `backend.compile` to `backend.compile_and_load`.
#29218 merged
Jun 3, 2025 -
[Pallas] Add forward-compatible i1 broadcast.
#29225 merged
Jun 3, 2025 -
[pallas] In TPU interpret mode, run kernels in parallel over Megacore cores.
#29187 merged
Jun 3, 2025 -
Resurrect _pjit_lower's cache because it's important for python dispatch performance.
#29192 merged
Jun 3, 2025 -
[Mosaic GPU] Fix `bitcast` logic in `shfl_bfly`.
#29211 merged
Jun 3, 2025 -
Prototype of cross-host device transfers in IFRT-PJRT.
#28867 merged
Jun 3, 2025 -
[jaxlib] Add `PyClient::Compile` method that returns an unloaded `PyExecutable`.
#29104 merged
Jun 3, 2025 -
[JAX] Remove the redundant pjit BUILD target.
#29126 merged
Jun 3, 2025 -
[Mosaic GPU] Add lowering for `2xf32 -> 2xf8e4m3fn` conversions.
#29198 merged
Jun 3, 2025 -
[imports] avoid top-level imports in jax.numpy sources
#29186 merged
Jun 3, 2025 -
Propagate layouts correctly via mutable arrays
#29194 merged
Jun 3, 2025 -
[pallas:mosaic] Dropped the `TPU` prefix from the recently added `TPUInterpreterParams`
#29197 merged
Jun 3, 2025 -
[pallas:mosaic] Removed the `TPU` prefix from `TPUCompilerParams` and `TPUMemorySpace`
#29115 merged
Jun 3, 2025 -
[Mosaic GPU] Add support for tiled loads and stores of `f8` data types.
#29196 merged
Jun 3, 2025 -
move jax/_src/custom_partitioning_sharding_rule.py to its own build rule
#29188 merged
Jun 3, 2025 -
Move jax/_src/extend/* to its own build rule
#29189 merged
Jun 3, 2025 -
Simplify `jnp.isclose`
#29153 merged
Jun 2, 2025 -
Clean up unused GPU RNN kernels.
#29178 merged
Jun 2, 2025 -
[Mosaic GPU] Add reduction support for TCGEN05 layout.
#29184 merged
Jun 2, 2025 -
[pallas:mosaic_gpu] `plgpu.nd_loop` is now a decorator similar to `pl.loop`
#29123 merged
Jun 2, 2025 -
[pallas] Added a note on `pl.loop` to the changelog
#29118 merged
Jun 2, 2025 -
Maintain the dtype of the input on the output in `broadcast_one_to_all`.
#29181 merged
Jun 2, 2025 -
Clean up some unused GPU sparse kernels.
#29177 merged
Jun 2, 2025 -
Clean up some unused GPU linear algebra kernels.
#29175 merged
Jun 2, 2025 -
Enable profiler_test for TPU's
#29098 merged
Jun 2, 2025 -
Adding tests and improving jnp.ufunc support
#29144 merged
Jun 2, 2025 -
Update workflow files to use new ml-build containers.
#29171 merged
Jun 2, 2025 -
always compile Pallas calls, enabling `pallas_call` under `disable_jit`
#29168 merged
Jun 2, 2025 -
[cleanup] inline uses of NumpyComplexWarning
#29170 merged
Jun 2, 2025 -
Fix native tiling logic in infer_vector_layout.
#28862 merged
Jun 2, 2025 -
Add a pretty printing rule for custom_lin_p.
#29169 merged
Jun 2, 2025 -
Bump the minimum NumPy and SciPy versions.
#29166 merged
Jun 2, 2025 -
[mutable-arrays] don't let scan AD hoist mutable operations
#29127 merged
Jun 2, 2025 -
[CI] Move k8s tests files out of .github/workflows
#29084 merged
Jun 2, 2025 -
Update partial eval to avoid DCEing a specific set of effects
#29165 merged
Jun 2, 2025 -
[jaxlib] Use SafeStaticInit in more places.
#29131 merged
Jun 2, 2025 -
Automated Code Change
#29156 merged
Jun 2, 2025 -
Introduce profiler_options in the documentation.
#28880 merged
Jun 2, 2025 -
Reverts 73c016a534af51614741d70d36c2c75ca59f2dcc
#29151 merged
Jun 1, 2025 -
Allow specifying non-differentiable arguments by name
#29149 merged
Jun 1, 2025
53 Pull requests opened by 16 people
Implementing PReLu
#29147 opened
May 31, 2025 -
[xla:cpu] Deprecate API_VERSION_STATUS_RETURNING custom calls
#29150 opened
Jun 1, 2025 -
Add initial devcontainer configuration for Alpine environment
#29154 opened
Jun 1, 2025 -
[jax] Always mask padded region if `query_seq_lengths` is passed to `dot_product_attention`.
#29158 opened
Jun 2, 2025 -
Fixed TSAN CI NumPy build step
#29160 opened
Jun 2, 2025 -
Add _XlaShardingV2 to tf.XlaShardOp and use it for tf2xla lowering.
#29172 opened
Jun 2, 2025 -
Bump hypothesis from 6.102.4 to 6.133.0
#29174 opened
Jun 2, 2025 -
Bump fonttools from 4.51.0 to 4.58.1
#29176 opened
Jun 2, 2025 -
Add pylint to disable warning unused imports in API files.
#29179 opened
Jun 2, 2025 -
Remove legacy CPU custom calls.
#29180 opened
Jun 2, 2025 -
Improve error handling in transmission of buffer metadata for experimental cross-host device transfers.
#29182 opened
Jun 2, 2025 -
[Mosaic] Support arbitrary 1D layout mask.
#29190 opened
Jun 2, 2025 -
[jaxlib] Change Traceback to be a raw CPython class rather than a nanobind class.
#29191 opened
Jun 3, 2025 -
Set an upper limit on McCabe code complexity
#29199 opened
Jun 3, 2025 -
Ruff rules for Perflint
#29200 opened
Jun 3, 2025 -
[CI] Debug tsan job speed
#29201 opened
Jun 3, 2025 -
[JAX] Wrap triton_call custom call using the FFI.
#29202 opened
Jun 3, 2025 -
Ruff rules C4 for comprehensions
#29203 opened
Jun 3, 2025 -
[Pallas TPU] Avoid using SMEM tensors in pipeline loop to track buffer slots
#29204 opened
Jun 3, 2025 -
fix missing input checks in jax.nn.functions
#29206 opened
Jun 3, 2025 -
[Mosaic] Use BF16 ops for math::PowF on TPUv6+.
#29214 opened
Jun 3, 2025 -
Grad of unreduced
#29219 opened
Jun 3, 2025 -
[jaxlib] Remove extraneous `compile_and_load` bindings for `xla::PyClient`.
#29222 opened
Jun 3, 2025 -
Removing Tensorflow references from the document.
#29231 opened
Jun 4, 2025 -
Add a pytype disable around zstandard.
#29238 opened
Jun 4, 2025 -
Fix handling of empty arrays in ufunc.reduce/accumulate
#29248 opened
Jun 4, 2025 -
[mosaic-gpu] add utility to get number of SMs
#29249 opened
Jun 4, 2025 -
fix type annotation for _IndexUpdateRef.get
#29257 opened
Jun 5, 2025 -
Fixed NaN propagation in optimal_step_size using nanmin/nanmax
#29261 opened
Jun 5, 2025 -
Change `use_shardy_partitioner` default to `None` instead of `False`.
#29262 opened
Jun 5, 2025 -
[Mosaic GPU] Handle `None` transforms in `swizzle_and_transforms_from_transforms_attr`
#29269 opened
Jun 5, 2025 -
[pallas:mgpu] Optionally skip allocating registers in warp specialized pipelining.
#29271 opened
Jun 5, 2025 -
[ROCm] ROCm7 Plugin Updates
#29281 opened
Jun 5, 2025 -
Remove forward_compat as it is past the support date.
#29282 opened
Jun 5, 2025 -
Expose local/global `ExchangeTopologies` timeouts for PJRT CPU client.
#29283 opened
Jun 5, 2025 -
Add unit tests for parameter and optimizer state offload
#29284 opened
Jun 5, 2025 -
Implemented cross-host memory transfer on GPU.
#29286 opened
Jun 5, 2025 -
Improve device assignment string format and add AbslStringify implementation.
#29287 opened
Jun 5, 2025 -
Add test and refactor Device Assignment.
#29288 opened
Jun 5, 2025 -
Delete extraneous `block_until_ready`s on JAX dispatch benchmarks.
#29289 opened
Jun 5, 2025 -
[Mosaic GPU] Simplify how we skip transforms for unrealized casts and gpu shared memory.
#29293 opened
Jun 6, 2025 -
//tests:scaled_matmul_stablehlo_test: fix for xla#27096
#29294 opened
Jun 6, 2025 -
Pallas documentation fixes.
#29297 opened
Jun 6, 2025 -
[CI] Run Mosaic H100 and B200 tests on all PRs that target mosaic subpaths
#29298 opened
Jun 6, 2025 -
Add alternative location of `CUDA_ROOT` for Bazel build/tests with hermetic CUDA.
#29300 opened
Jun 6, 2025 -
Adding `jax._src.util.cache_clearing_funs` to `jax.extend.backend`.
#29303 opened
Jun 6, 2025 -
Port pretty-printer to C++.
#29308 opened
Jun 6, 2025 -
Add numerics check for unreduced e2e.
#29310 opened
Jun 6, 2025 -
[Mosaic GPU] Make the Pallas Blackwell matmul kernel persistent.
#29312 opened
Jun 6, 2025 -
Clarify argument order for lax.associative_scan when reverse=True.
#29313 opened
Jun 7, 2025
6 Issues closed by 5 people
Suggestion add a random generator class that works with vmap, jit, etc.
#29277 closed
Jun 5, 2025 -
[sharding-in-types] `jax.lax.map(...batch_size=)` bug when using Explicit shapes
#29195 closed
Jun 3, 2025 -
FFI breaks if sm_90a is targeted
#29148 closed
Jun 3, 2025 -
`debug.print` inside `scan` gets lost during AD
#28738 closed
Jun 2, 2025 -
Unable to Reuse GPU After JAX Execution
#29159 closed
Jun 2, 2025 -
Gradients from `jacrev` and `jacfwd` not matching Finite Differences
#29111 closed
Jun 1, 2025
9 Issues opened by 7 people
Unsupported int8 in mosaic transpose
#29278 opened
Jun 5, 2025 -
jax.nn.dot_product_attention(...implementation='cudnn') fails due to incorrect loading of libnvrtc
#29260 opened
Jun 5, 2025 -
[sharding-in-types] sharding rule for scatter is not implemented.
#29252 opened
Jun 4, 2025 -
[sharding-in-types] `jnp.reshape` not figuring out a reshape output sharding when it should
#29235 opened
Jun 4, 2025 -
custom_partitioning compatibility with lax.composite
#29223 opened
Jun 3, 2025 -
Expose ormqr, allowing more efficient linear least squares solves using QR factorization
#29173 opened
Jun 2, 2025 -
[sharding-in-types+shard_map] Support closing over explicit meshes when using shard map
#29162 opened
Jun 2, 2025 -
Fix broken navigation breadcrumbs on website
#29155 opened
Jun 1, 2025 -
Jax jacobian returns unexpected nan
#29152 opened
Jun 1, 2025
37 Unresolved conversations
Sometimes conversations happen on old items that aren’t yet closed. Here is a list of all the Issues and Pull Requests with unresolved conversations.
add psend and precv to jax/lax/parallel
#29135 commented on
Jun 6, 2025 • 15 new comments -
Parametrize build system on CUDA major version
#28968 commented on
Jun 6, 2025 • 14 new comments -
[XProf] Change tensorboard-plugin-profile to new xprof package
#29129 commented on
Jun 3, 2025 • 3 new comments -
Add cudnn paged attention support in JAX cuDNN SDPA API
#28102 commented on
Jun 5, 2025 • 2 new comments -
added solve_sylvester and accompanying tests
#28810 commented on
Jun 4, 2025 • 2 new comments -
Tighten numpy.linalg.qr annotation
#28997 commented on
Jun 3, 2025 • 1 new comment -
jax reported a 'Segmentation fault' error for backend_compile in compiler.py
#29139 commented on
Jun 1, 2025 • 0 new comments -
* Add support for output and input memory space colors in tpu custom calls via CustomCallConfig.
#28290 commented on
Jun 6, 2025 • 0 new comments -
Add is_leaf_with_path predicate.
#28300 commented on
Jun 5, 2025 • 0 new comments -
Bump kiwisolver from 1.4.5 to 1.4.8
#28350 commented on
Jun 2, 2025 • 0 new comments -
Bump importlib-resources from 6.4.0 to 6.5.2
#28352 commented on
Jun 2, 2025 • 0 new comments -
Major deps update:
#28497 commented on
Jun 7, 2025 • 0 new comments -
add jax cudnn sdpa mla support
#28872 commented on
Jun 5, 2025 • 0 new comments -
Add hermetic `nvshmem` implementation.
#28892 commented on
Jun 5, 2025 • 0 new comments -
Explicitly make `out` a `jax.Array` before converting it in `lax_numpy.array`.
#28966 commented on
Jun 6, 2025 • 0 new comments -
[JAX][DOC] Add optimizer state offloading doc
#28988 commented on
Jun 4, 2025 • 0 new comments -
Enable more flexible handling of custom_vjp under remat by deferring partial eval
#29116 commented on
Jun 6, 2025 • 0 new comments -
[Mosaic GPU] Convert all memrefs with transforms to unrealized casts and check them.
#29122 commented on
Jun 2, 2025 • 0 new comments -
[Mosaic GPU] Error when causal masking is used on cuda versions known to result in a ptxas miscompilation (between 12.8.0 and 12.9.1).
#29141 commented on
Jun 4, 2025 • 0 new comments -
jax.nn.dot_product_attention with implementation='cudnn' fails when only `mask`, but no `bias` provided
#28974 commented on
Jun 2, 2025 • 0 new comments -
Missing `.device` attribute inside `@jax.jit`
#26000 commented on
Jun 3, 2025 • 0 new comments -
Unimplemented primitive in Pallas GPU lowering: gather
#29143 commented on
Jun 3, 2025 • 0 new comments -
Support for ragged arrays, like torch.nested
#17863 commented on
Jun 4, 2025 • 0 new comments -
JAX segment_sum is two times slower for FP16 inputs than FP32 inputs
#23136 commented on
Jun 4, 2025 • 0 new comments -
jax.ops.segment_sum is 100x slower when segment_ids are sorted, but only for float16
#26227 commented on
Jun 4, 2025 • 0 new comments -
`jax.debug.breakpoint` crashes in a hard-to-describe way.
#16732 commented on
Jun 4, 2025 • 0 new comments -
`XlaRuntimeError` on latest jax-metal
#27062 commented on
Jun 4, 2025 • 0 new comments -
cuSolver internal error. jax problems on gpu mode. everything works with cpu mode
#8916 commented on
Jun 4, 2025 • 0 new comments -
lax.cond crashes on Windows
#29049 commented on
Jun 5, 2025 • 0 new comments -
jax.scipy.special.hyp1f1 unstable where scipy.special.hyp1f1 is not
#21503 commented on
Jun 5, 2025 • 0 new comments -
xla_gpu_deterministic_ops=true breaks simple indexing in jax
#27796 commented on
Jun 5, 2025 • 0 new comments -
NotImplementedError: MLIR translation rule for primitive 'schur' not found for platform cuda
#28927 commented on
Jun 5, 2025 • 0 new comments -
Be able to consider subtree's key path when determining if is_leaf in tree operations
#27996 commented on
Jun 5, 2025 • 0 new comments -
Results do not match the reference. This is likely a bug/unexpected loss of precision.
#27188 commented on
Jun 6, 2025 • 0 new comments -
CPU slowdown with new runtime (v0.4.32 and newer)
#26145 commented on
Jun 6, 2025 • 0 new comments -
Tracking: MPMD/pipeline parallelism in eager-mode McJAX
#26645 commented on
Jun 6, 2025 • 0 new comments -
Add NSA Attention implementation
#26943 commented on
Jun 3, 2025 • 0 new comments