Following the work completed in https://github.com/asmith26/jax_toolkit/pull/119, it could be useful to add the `vmap` functionality to additional losses and metrics.
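
For illustration, here is a minimal sketch of what `vmap` support for a loss could look like. The `mean_squared_error` function and the `batched_mse` name are hypothetical and are not taken from jax_toolkit or from the referenced PR; only `jax.vmap` and `jax.numpy` are real JAX APIs.

```python
import jax
import jax.numpy as jnp


def mean_squared_error(y_true, y_pred):
    # Hypothetical single-example loss (an assumption, not jax_toolkit's
    # implementation): reduces one pair of 1-D arrays to a scalar.
    return jnp.mean((y_true - y_pred) ** 2)


# Map the loss over the leading axis of both arguments, so each row of
# y_true is paired with the matching row of y_pred.
batched_mse = jax.vmap(mean_squared_error, in_axes=(0, 0))

y_true = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y_pred = jnp.array([[1.0, 2.0], [1.0, 2.0]])

# One loss value per row, i.e. an array of shape (2,).
losses = batched_mse(y_true, y_pred)
```

The same wrapping pattern would presumably apply to any loss or metric written for a single example, as long as it is built from JAX-traceable operations.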