* added policy_gradient, modified main, ddpg and README
* fixed typo in README
* removed unnecessary imports
* small refactor
* Use clap for picking up the subcommand to run.
---------
Co-authored-by: Laurent <laurent.mazare@gmail.com>
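A minimal sketch of the clap-based subcommand selection mentioned above, assuming hypothetical `PolicyGradient` and `Ddpg` variants named after the examples in this entry; the actual wiring in the example may differ.

```rust
// Sketch only: picking the example to run via a clap subcommand.
// Assumes clap = { version = "4", features = ["derive"] } in Cargo.toml.
use clap::{Parser, Subcommand};

#[derive(Parser)]
struct Args {
    #[command(subcommand)]
    task: Task,
}

#[derive(Subcommand)]
enum Task {
    /// Run the policy gradient example.
    PolicyGradient,
    /// Run the DDPG example.
    Ddpg,
}

fn main() {
    match Args::parse().task {
        Task::PolicyGradient => println!("running policy_gradient"),
        Task::Ddpg => println!("running ddpg"),
    }
}
```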
* Add the Mixtral model.
* Add more of the mixtral layers.
* Add the final layers for mixtral.
* Sketch the expert selection.
* Add some expert routing logic.
* Hopefully finish the routing logic for mixtral.
* Add the mixtral example.
* Fix the weight filenames.
* Bugfix.
* Another fix.
* Yet another fix + remove the unused pragma.
* Shape fix.
* Support for quantized mixtral.
* Support mixtral in the quantized example.
* Add an MLP-or-MoE layer type (expert routing sketched after this entry).
* Fix the expert field names.
* Refactor the mlp bit.
* More MoE logic.
* Add the MoE quantized logic.
* Fix the experts length.
* Add a readme.
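The expert selection and routing commits above come down to top-k routing over a gating distribution: softmax the gate logits, keep the top experts, and renormalize their weights. Below is a minimal self-contained sketch of that idea using plain `Vec<f32>` rather than the crate's tensor types; the function names are illustrative, not the model's actual API.

```rust
// Softmax over a slice of gate logits, with the usual max subtraction for stability.
fn softmax(logits: &[f32]) -> Vec<f32> {
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|&l| (l - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

/// Returns the selected expert indices and their renormalized routing weights.
fn route_top_k(gate_logits: &[f32], k: usize) -> Vec<(usize, f32)> {
    let probs = softmax(gate_logits);
    let mut ranked: Vec<(usize, f32)> = probs.into_iter().enumerate().collect();
    // Sort experts by routing probability, descending, and keep the top k.
    ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
    ranked.truncate(k);
    // Renormalize the weights over the selected experts only.
    let total: f32 = ranked.iter().map(|&(_, w)| w).sum();
    ranked.into_iter().map(|(i, w)| (i, w / total)).collect()
}

fn main() {
    // One router output per expert (Mixtral uses 8 experts with top-2 routing).
    let gate_logits = [0.1, 2.3, -0.5, 1.7, 0.0, 0.2, -1.0, 0.9];
    for (expert, weight) in route_top_k(&gate_logits, 2) {
        println!("expert {expert} with weight {weight:.3}");
    }
}
```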
* Add support for SD Turbo
* Set Leading as the default timestep spacing in euler_ancestral discrete.
* Use the appropriate default values for n_steps and guidance_scale.
---------
Co-authored-by: Laurent <laurent.mazare@gmail.com>
A few fixes.
Going back to the remote metal-rs.
Reusing a single buffer (for now) to speed things up.
Adding some half kernels.
Make all tests panic instead of failing randomly.
Putting back f16 index select.
Add erf.
Working version for llama2-c.
Fixes + cache compute_pipeline_state.
BF16 metal fix.
Remove some prints.
new_owned -> new().to_owned().
Better batched matmul.
Metal operational.
Reuse buffers on our own reference counts.
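A simplified sketch of the buffer-reuse idea above: a buffer is handed back out only when the pool holds the sole remaining `Arc` reference to it. This illustrates the reference-counting scheme only; it is not the Metal backend's actual code, and it uses a plain `Vec<u8>` in place of a Metal buffer.

```rust
use std::sync::{Arc, Mutex};

// A pool that reuses allocations based on its own reference counts:
// strong_count == 1 means only the pool still holds the buffer.
struct BufferPool {
    buffers: Mutex<Vec<Arc<Vec<u8>>>>,
}

impl BufferPool {
    fn new() -> Self {
        Self { buffers: Mutex::new(Vec::new()) }
    }

    fn get(&self, size: usize) -> Arc<Vec<u8>> {
        let mut buffers = self.buffers.lock().unwrap();
        // Reuse a buffer that is large enough and no longer referenced elsewhere.
        if let Some(buf) = buffers
            .iter()
            .find(|b| b.len() >= size && Arc::strong_count(b) == 1)
        {
            return buf.clone();
        }
        // Otherwise allocate a fresh buffer and keep track of it for later reuse.
        let buf = Arc::new(vec![0u8; size]);
        buffers.push(buf.clone());
        buf
    }
}

fn main() {
    let pool = BufferPool::new();
    let a = pool.get(1024);
    drop(a); // once dropped, a later request for <= 1024 bytes can reuse this buffer
    let _b = pool.get(512);
}
```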
Tmp gemm.
Revert "Tmp gemm."
This reverts commit c65f68e988.
Interleave committing.
Speeding up copies using blit.
Fmt.
Fmt.
Remove the assert!
Fmt all.
Fixes after big rebase.
Add softmax for half and bfloat + tests
Fixing Llama example + accumulate softmax in float.
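A small sketch of the accumulate-in-float idea from the two commits above, assuming the `half` crate for the f16 type: the max, exponentials, and sum are computed in f32, and only the final probabilities are converted back to half precision.

```rust
use half::f16; // assumes half = "2" as a dependency

// Softmax over f16 inputs that accumulates in f32; accumulating in the
// input precision loses too much accuracy for long rows.
fn softmax_f16(row: &[f16]) -> Vec<f16> {
    let max = row.iter().map(|v| v.to_f32()).fold(f32::NEG_INFINITY, f32::max);
    // Exponentiate and sum in f32.
    let exps: Vec<f32> = row.iter().map(|v| (v.to_f32() - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    // Convert back to f16 only at the very end.
    exps.into_iter().map(|e| f16::from_f32(e / sum)).collect()
}

fn main() {
    let row: Vec<f16> = [1.0f32, 2.0, 3.0].iter().map(|&v| f16::from_f32(v)).collect();
    let probs = softmax_f16(&row);
    println!("{:?}", probs.iter().map(|p| p.to_f32()).collect::<Vec<_>>());
}
```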
* add bce with logit loss (sketched after this entry)
* add bce with logit loss
* remove imports
* fix tiny bug
* add test documentation and refactor function
* fix test cases and formatting
* distilbert files
* Apply various cleanups.
* More cleanups.
* More polish.
---------
Co-authored-by: laurent <laurent.mazare@gmail.com>
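The bce-with-logit-loss commits above refer to binary cross-entropy computed directly from logits. A scalar sketch of the numerically stable form, max(x, 0) - x*z + ln(1 + exp(-|x|)), is shown below; it avoids applying a sigmoid followed by a log, which loses precision for large |x|.

```rust
// Scalar sketch of bce-with-logits: `logit` is the raw model output x,
// `target` is the label z in [0, 1].
// Stable form: max(x, 0) - x * z + ln(1 + exp(-|x|)).
fn bce_with_logits(logit: f64, target: f64) -> f64 {
    logit.max(0.0) - logit * target + (-logit.abs()).exp().ln_1p()
}

fn main() {
    // A confident correct prediction gives a small loss...
    println!("{:.4}", bce_with_logits(5.0, 1.0));
    // ...while a confident wrong prediction gives a large one.
    println!("{:.4}", bce_with_logits(5.0, 0.0));
}
```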
* Fix linspace implementation
`steps` should be strictly greater than 1 for the computed spacing, which divides by `steps - 1`, to be well defined.
* Handle steps == 0 and steps == 1.
* Fix rustfmt.
---------
Co-authored-by: laurent <laurent.mazare@gmail.com>
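As a reference for the edge cases handled above, a minimal standalone linspace sketch (illustrative names, not the crate's actual signature): steps == 0 yields an empty result, steps == 1 yields just the start value, and larger values divide the interval into steps - 1 equal segments so both endpoints are included.

```rust
fn linspace(start: f64, stop: f64, steps: usize) -> Vec<f64> {
    match steps {
        0 => Vec::new(),
        1 => vec![start],
        // steps > 1: include both endpoints, dividing by steps - 1.
        _ => (0..steps)
            .map(|i| start + (stop - start) * i as f64 / (steps - 1) as f64)
            .collect(),
    }
}

fn main() {
    assert!(linspace(0.0, 1.0, 0).is_empty());
    assert_eq!(linspace(0.0, 1.0, 1), vec![0.0]);
    assert_eq!(linspace(0.0, 1.0, 3), vec![0.0, 0.5, 1.0]);
}
```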
* Add OpenChat to quantized examples
* Add chat prompt
* Make the openchat example more in line with the other models.
* Fix a typo.
---------
Co-authored-by: laurent <laurent.mazare@gmail.com>
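For the chat prompt added above, here is a sketch of the kind of formatting the OpenChat 3.5 models expect; the exact template is an assumption and should be checked against the model card before relying on it.

```rust
// Assumed OpenChat 3.5 chat template; verify against the model card.
fn openchat_prompt(user_message: &str) -> String {
    format!("GPT4 Correct User: {user_message}<|end_of_turn|>GPT4 Correct Assistant:")
}

fn main() {
    println!("{}", openchat_prompt("Write a haiku about Rust."));
}
```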
Updating the readme to be consistent with the other examples. Running the command as previously written produces a "cannot find the path specified" error.