
JAXMg Enables Scalable Multi-GPU Linear Solves Beyond Single-GPU Memory Limits

Quantum Zeitgeist
⚡ Quantum Brief
Flatiron Institute researchers led by Roeland Wiersema introduced JAXMg, a multi-GPU linear solver integrated with JAX, overcoming single-GPU memory limits for large-scale dense linear systems. It bridges optimized but rigid libraries with JAX’s flexible JIT-compiled workflows. JAXMg leverages cuSOLVERMg via XLA, enabling JIT-compatible multi-GPU operations like matrix inversion (potri) and eigenvalue decomposition (syevd) without leaving JAX’s execution model. This preserves composability for scientific applications requiring repeated linear solves. The system uses a 1D block-cyclic data distribution, mapping matrix columns to GPUs in configurable tiles via peer-to-peer transfers, minimizing overhead. It supports SPMD and MPMD execution modes, handling complex memory management transparently. Benchmarks on 8 NVIDIA H200 GPUs show JAXMg scales efficiently, solving matrices of size up to 524,288 (over 1 TB of memory) with 30% performance gains over single-GPU methods. It supports float32/64 and complex data types. This breakthrough enables previously infeasible simulations in quantum physics and optimization, maintaining JAX’s ease of use while unlocking multi-GPU performance for memory-bound problems. Future work may expand supported operations.

Scientists are tackling a persistent bottleneck in modern scientific computing: efficiently solving large, dense linear systems. Roeland Wiersema from the Center for Computational Quantum Physics, Flatiron Institute, and colleagues demonstrate a novel solution with JAXMg, a multi-GPU linear solver built within the JAX framework.

This research is significant because it bridges the gap between highly optimised, yet often inflexible, multi-GPU libraries and the increasingly popular, composable JIT-compiled workflows of JAX. By integrating with cuSOLVERMg via XLA, JAXMg allows scalable linear algebra to be embedded directly into JAX programs, unlocking multi-GPU performance for end-to-end scientific applications. Existing multi-GPU solver libraries often prove difficult to integrate into composable Python workflows, requiring users to exit the JAX execution model and manually manage memory. JAXMg circumvents these challenges with a unified, JIT-compatible interface, letting researchers leverage multiple GPUs without sacrificing JAX’s streamlined programming environment. This is particularly valuable for applications demanding repeated linear system solves or eigenvalue decompositions within larger simulation loops or differentiable optimization processes.

JAXMg supports CUDA 12 and CUDA 13 compatible devices and offers JIT-able interfaces to core cuSOLVERMg routines: potrs for solving symmetric positive-definite systems, potri for computing matrix inverses, and syevd for eigenvalue decomposition. The implementation uses a 1D block-cyclic data distribution scheme, mapping matrix columns to GPUs in fixed-size tiles of user-configurable size TA to balance the computational load. This distribution is achieved through deterministic, in-place rotations using peer-to-peer GPU copies and small staging buffers, minimizing data movement and maximizing performance. The study also details memory-management techniques supporting both Single Program, Multiple Data (SPMD) and Multiple Program, Multiple Data (MPMD) execution modes.
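The 1D block-cyclic column mapping described above can be sketched in a few lines of plain Python. This is an illustration of the layout only; the function name and the tiny 12-column example are not part of JAXMg's API.

```python
def block_cyclic_owner(col, tile_size, num_gpus):
    """Return the GPU index that owns a matrix column under a 1D
    block-cyclic layout: columns are grouped into fixed-size tiles
    (tile_size plays the role of TA), and tiles are dealt out to
    GPUs in round-robin order."""
    tile_index = col // tile_size
    return tile_index % num_gpus

# Example: 12 columns, tile size TA = 2, 3 GPUs.
owners = [block_cyclic_owner(c, tile_size=2, num_gpus=3) for c in range(12)]
# Columns 0-1 land on GPU 0, 2-3 on GPU 1, 4-5 on GPU 2, then the
# pattern wraps: 6-7 on GPU 0, 8-9 on GPU 1, 10-11 on GPU 2.
```

The round-robin wrap is what keeps the per-GPU workload balanced even when later panels of a factorization touch only the trailing columns of the matrix.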
In SPMD mode, shared virtual address spaces facilitate straightforward pointer sharing, while MPMD mode leverages the CUDA IPC API to enable inter-process communication and GPU allocation sharing. Benchmarks conducted on a system equipped with 8 NVIDIA H200 GPUs (143 GB VRAM each) reveal that JAXMg consistently outperforms native single-GPU linear algebra routines, particularly for larger problem sizes, and scales effectively with increasing numbers of GPUs.

The team reports performance gains across various data types, including float32, float64, complex64, and complex128, demonstrating the versatility of the approach. The system integrates the cuSOLVERMg routines potrs, potri, and syevd to solve symmetric positive-definite systems, compute matrix inverses, and compute eigenvalues and eigenvectors within the JAX ecosystem, and is compatible with CUDA 12 and CUDA 13 devices. The 1D block-cyclic data distribution is constructed in a C++ backend to ensure a balanced workload across GPUs, assigning matrix columns in fixed-size tiles distributed in round-robin fashion. Efficient in-place redistribution is achieved by decomposing column-index mappings into disjoint permutation cycles, enabling peer-to-peer GPU transfers via cudaMemcpyPeerAsync with minimal staging overhead.

Using jax.shard, the system exposes per-device shards and passes the corresponding GPU pointers to the backend, supporting both SPMD and MPMD execution while maintaining a single controlling process capable of accessing all device memory. This approach enables scalable multi-GPU linear algebra for problems exceeding single-GPU memory limits, preserves composability within JAX pipelines, and overcomes key memory-management challenges in existing multi-GPU solutions. In the potrs benchmark, with right-hand side b = (1, …, 1)ᵀ, the team varied the tile size TA and recorded wall-clock timings, finding that JAXMg surpasses native single-GPU linear algebra routines, particularly for larger matrices.
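The cycle-based in-place redistribution can be modelled in plain Python: each disjoint cycle of the column-index permutation is rotated through a single temporary slot, which stands in for the small staging buffers and cudaMemcpyPeerAsync transfers used on real GPUs. All names here are illustrative, not JAXMg internals.

```python
def permutation_cycles(perm):
    """Decompose a permutation into disjoint cycles, where perm[i]
    is the index whose data should end up at position i."""
    seen = [False] * len(perm)
    cycles = []
    for start in range(len(perm)):
        if seen[start]:
            continue
        cycle, i = [], start
        while not seen[i]:
            seen[i] = True
            cycle.append(i)
            i = perm[i]
        if len(cycle) > 1:          # fixed points need no data movement
            cycles.append(cycle)
    return cycles

def apply_in_place(data, perm):
    """Realise new[i] = old[perm[i]] by rotating each cycle through
    one staging slot, mimicking peer-to-peer copies with a buffer."""
    for cycle in permutation_cycles(perm):
        staged = data[cycle[0]]                 # stash the first element
        for j in range(len(cycle) - 1):
            data[cycle[j]] = data[cycle[j + 1]]  # shift along the cycle
        data[cycle[-1]] = staged                 # close the cycle
    return data
```

Because the cycles are disjoint, each one can be rotated independently (and deterministically) with only one buffer's worth of extra memory, which is why the scheme keeps staging overhead minimal.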

Results demonstrate that the largest solvable problem reached N = 524,288, utilising over 1 TB of memory, a feat previously infeasible on a single GPU, while delivering a 30% performance increase. Larger tile sizes improved performance once the problem size became sufficiently large, consistent with increased GPU utilisation, while tile size had minimal impact on syevd. Both syevd and potri require significantly more workspace memory than potrs, which limits the maximum achievable matrix sizes. Measurements confirm that JAXMg enables dense linear solves and eigendecompositions that would otherwise be bottlenecked by single-GPU memory capacity, all while maintaining JAX’s composability and JIT-compiled programming model. Comparisons of jaxmg.potri with jax.numpy.linalg.inv for a complex128 matrix, and of jaxmg.syevd with jax.numpy.linalg.eigh for a float64 matrix, further validate the library’s effectiveness: potri shows a strong dependence on the tile size TA, whereas syevd is largely insensitive to it.

The breakthrough delivers the ability to tackle matrix sizes previously impossible on single GPUs and to enhance throughput by leveraging aggregate device memory and compute, opening new avenues for complex scientific simulations and analyses. This work highlights the potential for JAXMg to accelerate research across diverse fields reliant on large-scale linear algebra. The tool enables Cholesky-based linear solves and symmetric eigendecompositions for matrices exceeding the memory capacity of a single GPU, while maintaining composability with JAX transformations and allowing multi-GPU execution within complete scientific workflows, a feature often lacking in existing highly optimised multi-GPU solver libraries.
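To make the Cholesky-based solve concrete, here is a single-device, pure-Python sketch of the numerics behind potrf/potrs on a tiny symmetric positive-definite system. This illustrates what the routines compute, not JAXMg's implementation, which dispatches these steps to cuSOLVERMg across multiple GPUs.

```python
import math

def cholesky(A):
    """Lower-triangular L with A = L Lᵀ for a symmetric
    positive-definite matrix (the potrf step)."""
    n = len(A)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = sum(L[i][k] * L[j][k] for k in range(j))
            if i == j:
                L[i][j] = math.sqrt(A[i][i] - s)
            else:
                L[i][j] = (A[i][j] - s) / L[j][j]
    return L

def potrs_solve(A, b):
    """Solve A x = b from the Cholesky factor: forward-substitute
    L y = b, then back-substitute Lᵀ x = y (the potrs step)."""
    n = len(A)
    L = cholesky(A)
    y = [0.0] * n
    for i in range(n):
        y[i] = (b[i] - sum(L[i][k] * y[k] for k in range(i))) / L[i][i]
    x = [0.0] * n
    for i in range(n - 1, -1, -1):
        x[i] = (y[i] - sum(L[k][i] * x[k] for k in range(i + 1, n))) / L[i][i]
    return x

# Tiny example in the spirit of the benchmark: b is a vector of ones.
x = potrs_solve([[4.0, 2.0], [2.0, 3.0]], [1.0, 1.0])
```

The two triangular substitutions are cheap relative to the factorization, which is why distributing the factor across GPUs dominates the design of a multi-GPU solver.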
Benchmarks demonstrate JAXMg’s ability to solve problems with matrices up to 524,288 in size, utilising over 1 TB of memory, and show competitive performance against established JAX-based linear algebra routines. The authors acknowledge a dependence on the tile size TA and note that further optimisation may be possible. Future work could explore extending JAXMg to a wider range of linear algebra operations and solvers, broadening its applicability across diverse scientific domains.

👉 More information
🗞 JAXMg: A multi-GPU linear solver in JAX
🧠 ArXiv: https://arxiv.org/abs/2601.14466


Source Information

Source: Quantum Zeitgeist