JAX based parallel inference for reactive probabilistic programming
- Author
Baudart, Guillaume; Mandel, Louis; Tekin, Reyyan. Affiliations: Parallélisme de Kahn Synchrone (Parkas), Département d'informatique - ENS Paris (DI-ENS), École normale supérieure - Paris (ENS-PSL), Université Paris sciences et lettres (PSL), Institut National de Recherche en Informatique et en Automatique (Inria), Centre National de la Recherche Scientifique (CNRS), Inria de Paris, and IBM T. J. Watson Research Centre
- Subjects
Streaming Inference, Parallel Computing, Probabilistic Programming, Reactive Programming, Compilation, [INFO.INFO-PL] Computer Science [cs]/Programming Languages [cs.PL], [INFO.INFO-ES] Computer Science [cs]/Embedded Systems, [INFO.INFO-SE] Computer Science [cs]/Software Engineering [cs.SE]
- Abstract
International audience; ProbZelus is a synchronous probabilistic language for designing reactive probabilistic models that interact with an environment. Reactive inference methods continuously learn distributions over the unobserved parameters of a model from statistical observations. Unfortunately, this inference problem is intractable in general. Monte Carlo inference techniques therefore rely on many independent executions (particles) to compute accurate approximations. These methods are expensive, but they can be parallelized. We propose to use JAX to parallelize the ProbZelus reactive inference engine. JAX is a recent library for compiling Python code that can then be executed on massively parallel architectures such as GPUs or TPUs. In this paper, we describe a new reactive inference engine implemented in JAX and the associated new JAX backend for ProbZelus. We show on existing benchmarks that our new parallel implementation outperforms the original sequential implementation when the number of particles is high.
- Published
2022
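The abstract's central idea is that Monte Carlo inference runs many independent particle executions, and that JAX can batch these executions on parallel hardware. The following is a minimal sketch of that idea, not the ProbZelus JAX engine itself: it assumes a hypothetical 1-D Gaussian random-walk model (`PROCESS_STD`, `OBS_STD`, `particle_step`, `filter_step`, and `run_filter` are illustrative names, not from the paper) and uses a plain bootstrap particle filter with multinomial resampling. `jax.vmap` turns the single-particle step into one batched computation over all particles, and `jax.jit` compiles it for CPU, GPU, or TPU.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# Hypothetical model (assumption, not from the paper):
#   x_t = x_{t-1} + N(0, PROCESS_STD^2)   (latent state)
#   y_t = x_t     + N(0, OBS_STD^2)       (observation)
PROCESS_STD = 1.0
OBS_STD = 0.5

def particle_step(key, x_prev, y):
    """Advance ONE particle: sample its transition, then score the observation."""
    x = x_prev + PROCESS_STD * jax.random.normal(key)
    log_w = norm.logpdf(y, loc=x, scale=OBS_STD)
    return x, log_w

@jax.jit
def filter_step(key, particles, y):
    """One reactive step: consume observation y, update all particles at once."""
    n = particles.shape[0]  # static under jit
    key_prop, key_resample = jax.random.split(key)
    keys = jax.random.split(key_prop, n)  # one independent key per particle
    # vmap batches the per-particle step across the particle axis;
    # y is shared by all particles (in_axes=None).
    xs, log_ws = jax.vmap(particle_step, in_axes=(0, 0, None))(keys, particles, y)
    # Posterior-mean estimate from normalized weights (before resampling).
    mean = jnp.sum(jax.nn.softmax(log_ws) * xs)
    # Multinomial resampling: draw n indices proportionally to the weights.
    idx = jax.random.categorical(key_resample, log_ws, shape=(n,))
    return xs[idx], mean

def run_filter(observations, num_particles=1000, seed=0):
    """Feed a stream of observations one at a time, as a reactive loop would."""
    key = jax.random.PRNGKey(seed)
    key, init_key = jax.random.split(key)
    particles = jax.random.normal(init_key, (num_particles,))  # prior N(0, 1)
    means = []
    for y in observations:
        key, step_key = jax.random.split(key)
        particles, mean = filter_step(step_key, particles, y)
        means.append(float(mean))
    return means

means = run_filter([2.0] * 20, num_particles=1000)
```

The per-particle code in this sketch is written as if it ran on a single execution, and parallelism comes entirely from `vmap` and `jit`. Increasing `num_particles` therefore mostly grows the batch dimension rather than a Python loop. This is consistent with the abstract's observation that a parallel implementation pays off when the number of particles is high, but the sketch makes no claim about the paper's actual benchmarks.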