I find that JAX tends to produce messy code unless you build good abstractions. Personally, I don't like Flax and Haiku; I prefer stax and Equinox because they are more transparent about what is happening, feel a lot less like magic, and are more Pythonic (explicit is better than implicit, etc.).
PyTorch is far friendlier for deep learning work, but sometimes all you want is pure numerical computation that can be vmapped across tensors, and this is where JAX shines IMHO.
Personal example: I needed to sample a bunch of datapoints, build distributions from them, sample from those, and then compute the density of each sample under every distribution. Doing this in PyTorch was rather slow. I was probably doing something wrong with vectorization and broadcasting, but I didn't have time to figure it out.
With JAX, I wrote a function that produces the samples, vmapped the evaluation of one sample across all distributions, and then vmapped that over all samples. It took a couple of minutes to implement and seconds to execute.
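The nested-vmap pattern described above can be sketched roughly like this. It assumes one univariate Gaussian per datapoint with a fixed standard deviation of 0.5; the data, the scale, and the function names (`gaussian_logpdf`, `per_sample`, `all_pairs`) are hypothetical stand-ins, not the commenter's actual code.

```python
import jax
import jax.numpy as jnp

def gaussian_logpdf(x, mean, std):
    # Log density of a univariate normal N(mean, std^2) evaluated at x.
    z = (x - mean) / std
    return -0.5 * z**2 - jnp.log(std) - 0.5 * jnp.log(2 * jnp.pi)

# Inner vmap: one sample against every distribution (map over mean and std, not x).
per_sample = jax.vmap(gaussian_logpdf, in_axes=(None, 0, 0))
# Outer vmap: repeat the inner computation for every sample (map over x only).
all_pairs = jax.jit(jax.vmap(per_sample, in_axes=(0, None, None)))

key = jax.random.PRNGKey(0)
k_data, k_choice, k_noise = jax.random.split(key, 3)

datapoints = jax.random.normal(k_data, (5,))  # hypothetical datapoints
means = datapoints                            # one distribution per datapoint
stds = jnp.full_like(means, 0.5)              # assumed fixed scale

# Draw samples: pick a distribution at random, then sample from it.
idx = jax.random.randint(k_choice, (100,), 0, means.shape[0])
samples = means[idx] + stds[idx] * jax.random.normal(k_noise, (100,))

# Row i holds the log density of samples[i] under each of the 5 distributions.
log_dens = all_pairs(samples, means, stds)
densities = jnp.exp(log_dens)
```

Each vmap only states which arguments carry the batch axis (`in_axes`), so no manual broadcasting is needed; working in log space and exponentiating at the end avoids underflow for samples far from a distribution's mean.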
PyTorch also has the advantage of a far more mature ecosystem: libraries like Lightning, Accelerate, Transformers, Evaluate, and so on make building models a breeze.
> Personal Example: I needed to sample a bunch of datapoints, make distributions out of them, sample, and then compute the density of each sample across distributions. Doing this with pytorch was rather slow, I was probably doing something wrong with vectorization and broadcasting, but I didn't have the time to figure it out.
You probably weren't doing anything wrong. I spent a lot of time trying to be clever to parallelize things like this, and it just wasn't possible without writing CUDA extensions. But it is now! PyTorch has vmap through functorch, and it works.
Probably not your issue, but one annoying bit is that the inputs need to be tensors. I ended up wrapping the function I was experimenting with in `partial` and then vmapping the partial, which seemed to work.
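The `partial` workaround mentioned above was described for PyTorch's vmap; a rough JAX analogue of the same idea is sketched below. The Gaussian-mixture function and its plain-Python `params` tuple of lists are hypothetical, chosen only to show a non-array argument being bound before vmapping.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def mixture_logpdf(x, params):
    # `params` is a hypothetical plain-Python tuple of lists:
    # (weights, means, stds) describing a 1-D Gaussian mixture.
    weights, means, stds = (jnp.asarray(p) for p in params)
    z = (x - means) / stds
    log_comp = -0.5 * z**2 - jnp.log(stds) - 0.5 * jnp.log(2 * jnp.pi)
    # Combine components in log space: log(sum_k w_k * N_k(x)).
    return logsumexp(log_comp + jnp.log(weights))

params = ([0.3, 0.7], [-1.0, 2.0], [0.5, 1.0])

# Bind the non-array argument with partial so vmap only maps over x.
batched = jax.vmap(partial(mixture_logpdf, params=params))

xs = jnp.linspace(-3.0, 3.0, 7)
log_dens = batched(xs)  # one log density per entry of xs
```

Because `params` is captured by `partial`, vmap never has to decide how to batch the tuples and lists; they are treated as constants for every element of `xs`.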
Pretty similar, actually: I needed to pass in some parameters as tuples, lists, etc. `partial` would probably have worked, but tbh I wasn't in the mood to try harder at 2am.