Show HN: Zephyr: New [WIP] NN JAX Framework; Short, Simple, Declarative


Hello HN! I have an early, work-in-progress neural network framework written on top of JAX.

Simple | Declarative | No need to learn duplicated JAX transforms or specialized manipulation functions

How to build a NN?

1. Write `f(params, x, hyperparameters)`.
2. Initialize the params: `params = trace(f, key, x, *hyperparameters)`.
3. Use the very function `f` you wrote: `f(params, x, hyperparameters)`.
4. Optionally, bake in the hyperparameters: `model = f(_, _, hyperparameters)` (the `_` placeholder is explained in the README; you can also use `functools.partial`).

A minimal sketch of these steps follows below.
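Here's a rough sketch of the four steps. The import paths and the `mlp` building block are assumptions on my part; check the README for the real names.

```python
# Minimal sketch of the 4 steps above. NOTE: the import paths and
# the `mlp` helper are assumed; the actual zephyr names may differ.
from functools import partial

import jax
import jax.numpy as jnp
from zephyr import trace        # assumed import path
from zephyr.nets import mlp     # hypothetical built-in net

# 1. a network is just a function of (params, input, hyperparameters)
def f(params, x, layer_dims):
    return mlp(params, x, layer_dims)

key = jax.random.PRNGKey(0)
x = jnp.ones((8, 16))           # a sample batch

# 2. initialize the params by tracing f on a sample input
params = trace(f, key, x, [32, 1])

# 3. apply the same function you wrote
y = f(params, x, [32, 1])

# 4. optionally bake in the hyperparameters
model = partial(f, layer_dims=[32, 1])
y = model(params, x)
```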

Its key difference from other JAX frameworks is its simplicity and straightforwardness. With other frameworks, the network you write is transformed into an (init, apply) pair of pure functions, so you never really call the code you wrote; you call the transformed pair instead. With zephyr, neural networks look like neural networks: a function of parameters, input data, and hyperparameters. It's also patterned for functional-programming use, so partial application is a useful alternative to state in OO.
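For contrast, here's roughly what the (init, apply) pattern looks like in a transform-style framework such as Haiku: the `forward` you write is only ever used through the transformed pair.

```python
# The (init, apply) pattern in Haiku, shown for contrast.
import haiku as hk
import jax
import jax.numpy as jnp

def forward(x):
    return hk.Linear(32)(x)

model = hk.transform(forward)    # you never call forward directly
key = jax.random.PRNGKey(0)
x = jnp.ones((8, 16))

params = model.init(key, x)      # use model.init, not forward
y = model.apply(params, key, x)  # use model.apply, not forward
```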

Lastly, it's meant to be declarative and simple. Neural networks are just functions, not objects that need instantiation. This usually means shorter code, because declaration and usage/computation happen in the same place; those two things are highly coupled, so putting them together means less cognitive load. (No more wondering what `MLP(x)`'s output dim per layer is, since it's all bundled together.)
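To make that concrete, here's a hypothetical comparison, reusing the assumed `mlp` helper from the sketch above:

```python
# Declaration and usage in one place: the per-layer output dims
# sit right at the call site. (`mlp` is the assumed zephyr helper.)
y = mlp(params, x, [128, 64, 10])

# versus the OO style, where the dims live in a distant constructor:
#   model = MLP([128, 64, 10])   # declared here...
#   ...
#   y = model(x)                 # ...used way over there
```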

It's early stage, so the built-in nets are few and unpolished, but I want to get the core right before moving on to implementing the full set of standard nets. Feedback on any part of it is very welcome!