
DrJAX: Scalable and Differentiable MapReduce Primitives in JAX

Authors:
Rush, Keith
Charles, Zachary
Garrett, Zachary
Augenstein, Sean
Mitchell, Nicole
Publication Year:
2024

Abstract

We present DrJAX, a JAX-based library designed to support large-scale distributed and parallel machine learning algorithms that use MapReduce-style operations. DrJAX leverages JAX's sharding mechanisms to enable native targeting of TPUs and state-of-the-art JAX runtimes, including Pathways. DrJAX embeds building blocks for MapReduce computations as primitives in JAX. This enables three key benefits. First, DrJAX computations can be translated directly to XLA HLO, enabling flexible integration with a wide array of ML training platforms. Second, DrJAX computations are fully differentiable. Last, DrJAX computations can be interpreted out to existing batch-processing compute systems, including traditional MapReduce systems like Apache Beam and cross-device compute systems like those powering federated learning applications. We show that DrJAX provides an easily programmable, performant, and scalable framework for parallelized algorithm development. DrJAX is available at https://github.com/google-research/google-research/tree/master/drjax.
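The abstract's central point, that MapReduce-style map and reduce steps expressed as JAX computations compose with JAX's transformations and stay differentiable end to end, can be illustrated with a minimal sketch in plain JAX. The sketch below is only an assumption-laden illustration of the general pattern and does not use DrJAX's actual primitives; the names client_loss, federated_objective, and the toy data shapes are hypothetical.

# Illustrative map-reduce pattern in plain JAX (not the DrJAX API).
import jax
import jax.numpy as jnp


def client_loss(params, client_batch):
  # Per-client ("map") computation: a simple squared-error loss.
  features, targets = client_batch
  preds = features @ params
  return jnp.mean((preds - targets) ** 2)


def federated_objective(server_params, all_client_batches):
  # "Map": evaluate the loss on every client in parallel.
  per_client_losses = jax.vmap(client_loss, in_axes=(None, 0))(
      server_params, all_client_batches)
  # "Reduce": aggregate per-client results into a single scalar.
  return jnp.mean(per_client_losses)


# Because both steps are ordinary JAX computations, the whole pipeline
# is differentiable and can be compiled.
grad_fn = jax.jit(jax.grad(federated_objective))

# Toy data: 4 clients, each with 8 examples of 3 features.
key = jax.random.PRNGKey(0)
features = jax.random.normal(key, (4, 8, 3))
targets = jax.random.normal(key, (4, 8))
params = jnp.zeros((3,))

server_grad = grad_fn(params, (features, targets))
print(server_grad.shape)  # (3,)

Per the abstract, DrJAX expresses analogous map and aggregation steps as placement-aware primitives lowered through JAX's sharding machinery, which is what yields the direct XLA HLO translation, differentiability, and interpretation onto systems such as Apache Beam described above.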

Details

Database:
arXiv
Publication Type:
Report
Accession number:
edsarx.2403.07128
Document Type:
Working Paper