Parax: Parametric Modeling in JAX + Equinox [P]
r/MachineLearning
Hi everyone! Just wanted to share my Python project Parax, an add-on on top of the Equinox library catering to parameter-first modeling in JAX. For our scientific applications, we found that we often needed to attach metadata to our parameter objects, such as marking them as fixed or attaching a prior probability distribution. Further, we often needed to manipulate these parameters in very deep hierarchies, which can be unintuitive with eqx.tree_at. We therefore developed Parax, which provides parax.Parameter and parax.Module (which both inherit from eqx.