Commit Graph

32 Commits

SHA1 Message Date
8cf39d27ce Metal part 1 - Scaffolding for metal. 2023-11-09 19:30:59 +01:00
e4c9adfdbe Implemented meshgrid (#1174)
* Implemented meshgrid

* Resolved feedback from LaurentMazare

* Rustfmt

* Updated docstring

* Removed outdated error mode from docstring
2023-10-25 12:49:11 +01:00
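
As a usage illustration for the op added above, a minimal sketch assuming the `candle_core` crate name and a `Tensor::meshgrid(&[&x, &y], xy_indexing)` signature along the lines of what #1174 introduces:

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let x = Tensor::new(&[0f32, 1., 2.], &dev)?;
    let y = Tensor::new(&[10f32, 20.], &dev)?;
    // Build coordinate grids from two 1-D tensors; the boolean picks
    // xy (Cartesian) vs ij (matrix) indexing in the assumed signature.
    let grids = Tensor::meshgrid(&[&x, &y], true)?;
    for g in grids.iter() {
        println!("{:?}", g.dims());
    }
    Ok(())
}
```
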
7396b8ed1a Segment Anything - process images (#766)
* Start processing images.

* Add LayerNorm2d.

* Properly use LayerNorm2d.

* Tweak eps.

* Use LayerNorm on inputs with a rank different from 3.

* Window partitioning.

* Fix a couple todos.

* More todos.

* Hard-code the einsums.

* More padding support.

* Some sizes tweaks.

* Use the hub to get the weights.

* Use a batch matmul.

* Tweaks.

* More fixes.

* Get some predictions to be generated.
2023-09-07 19:22:45 +01:00
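
The entry above adds a `LayerNorm2d` for image inputs. A hypothetical sketch of the idea (not the PR's actual code), assuming the current `candle_nn::Module` trait and NCHW inputs: normalize over the channel dimension, then apply a learned per-channel scale and shift with a small `eps` for stability.

```rust
use candle_core::{Result, Tensor};
use candle_nn::Module;

// Illustrative LayerNorm2d sketch: normalizes an NCHW tensor over its
// channel dimension, then applies per-channel weight and bias.
struct LayerNorm2d {
    weight: Tensor, // shape (c,)
    bias: Tensor,   // shape (c,)
    eps: f64,
}

impl Module for LayerNorm2d {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let (_b, c, _h, _w) = xs.dims4()?;
        let mean = xs.mean_keepdim(1)?;
        let xs = xs.broadcast_sub(&mean)?;
        let var = xs.sqr()?.mean_keepdim(1)?;
        let xs = xs.broadcast_div(&(var + self.eps)?.sqrt()?)?;
        xs.broadcast_mul(&self.weight.reshape((1, c, 1, 1))?)?
            .broadcast_add(&self.bias.reshape((1, c, 1, 1))?)
    }
}
```
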
7299a68353 img2img pipeline for stable diffusion. (#752)
* img2img pipeline for stable diffusion.

* Rename the arguments + fix.

* Fix for zero strength.

* Another fix.

* Another fix.

* Revert.

* Include the backtrace.

* Noise scaling.

* Fix the height/width.
2023-09-06 07:06:49 +01:00
3071134788 Get the ggml based llama to generate some text. (#464)
* Add more stats to the ggml example.

* Build a quantized model from the file content.

* Move the tensor retrieval in the main crate.

* Start adding the forward pass.

* Add more to the forward pass of the quantized llama.

* Apply the attention layers.

* Add the sampling loop.

* Get the sampling loop to work.

* Minor tweak.

* Add a quantize/dequantize test.

* Bugfix.

* Add a comment + swap the order.

* Bugfixes.
2023-08-16 12:41:07 +01:00
f3fe730a30 Npy tweaks & error with path (#384)
* Simplify the npy writing.

* Wrap the file path so as to provide better errors.
2023-08-10 06:21:58 +01:00
a27239f3d9 Add training for the llama2.c example (#296)
* Rework the commands and run inference by default.

* Add the training module and load the training dataset.

* Random dataset iterator.

* Proper valid-loss computation.

* Compute the evaluation loss.

* Add more substance to the training loop.
2023-08-01 17:23:07 +01:00
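
The training work above includes a validation-loss pass. A minimal, hypothetical sketch of such a computation, assuming `candle_nn::loss::cross_entropy` and a model whose forward pass returns `(batch, vocab)` logits; names and shapes are illustrative, not the example's actual code (which iterates over a dataset rather than a pre-built slice):

```rust
use candle_core::{Result, Tensor};
use candle_nn::{loss::cross_entropy, Module};

// Hypothetical helper: average the cross-entropy loss over a slice of
// (input, target) batches.
fn evaluation_loss(model: &impl Module, batches: &[(Tensor, Tensor)]) -> Result<f32> {
    let mut total = 0f32;
    for (xs, ys) in batches {
        let logits = model.forward(xs)?;
        let loss = cross_entropy(&logits, ys)?;
        total += loss.to_scalar::<f32>()?;
    }
    Ok(total / batches.len() as f32)
}
```
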
952eca6b54 Fixing slice errors + comments. 2023-07-27 16:59:32 +02:00
1235aa2536 Use bail rather than wrapping a string where possible. (#249)
* Use bail rather than wrapping a string where possible.

* Revert the cuda default bit.
2023-07-26 15:42:46 +01:00
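
For reference, a small sketch of the pattern the commit above moves towards, assuming candle's `bail!` macro (which returns an `Error::Msg` carrying a backtrace) instead of hand-wrapping a formatted string:

```rust
use candle_core::{bail, Result, Tensor};

// Illustrative check: reject tensors that are not rank 2.
fn check_rank_2(t: &Tensor) -> Result<()> {
    if t.rank() != 2 {
        bail!("expected a rank-2 tensor, got shape {:?}", t.shape());
    }
    Ok(())
}
```
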
fa2b64d678 Proper flash-attn parameters. (#244)
* Proper flash-attn parameters.

* Set the flash attention parameters.

* Add more validations.

* Set up the o_ flash attn parameters.

* More flash-attn support.

* Set more flash attn parameters.
2023-07-26 10:13:40 +01:00
6eeea1b04e Polish the index-add op and use it in the index-select backprop (#218)
* Add the cpu version of index-add.

* More cpu support for index-add.

* Use index-add in the backprop.
2023-07-22 05:31:46 +01:00
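
A usage sketch for the op polished above, assuming an `index_add(&indexes, &source, dim)` signature and the `candle_core` crate name; the result accumulates `source` rows into a copy of `self` at the given positions:

```rust
use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Add the rows of `src` into a copy of `acc` at positions 0 and 2 along dim 0.
    let acc = Tensor::zeros((4, 3), DType::F32, &dev)?;
    let src = Tensor::ones((2, 3), DType::F32, &dev)?;
    let ids = Tensor::new(&[0u32, 2], &dev)?;
    let out = acc.index_add(&ids, &src, 0)?;
    println!("{out}");
    Ok(())
}
```
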
410654525f Refactor the reduce ops in order to introduce argmin/argmax. (#212)
* Refactor the reduce ops in order to introduce argmin/argmax.

* Clippy fixes.

* Use the newly introduced argmax.

* Fix the strided case.

* Handle the non-contiguous case.
2023-07-21 11:41:08 +01:00
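
A usage sketch for the newly introduced reductions, assuming `Tensor::argmax(dim)` (with `argmin` mirroring it):

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let t = Tensor::new(&[[1f32, 5., 2.], [7., 0., 3.]], &dev)?;
    // Index of the largest element along the last dimension: [1, 0].
    let idx = t.argmax(1)?;
    println!("{idx}");
    Ok(())
}
```
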
fa08fb3126 Add the index-select op. (#209)
* Add the index-select op.

* Cpu implementation of index-select.

* Add the cpu implementation for index-select.
2023-07-20 14:01:03 +01:00
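
A usage sketch for the op above, assuming an `index_select(&indexes, dim)` signature that gathers slices along `dim`:

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let t = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &dev)?;
    let ids = Tensor::new(&[2u32, 0], &dev)?;
    // Gather rows 2 and 0 along dimension 0: [[4, 5], [0, 1]].
    let rows = t.index_select(&ids, 0)?;
    println!("{rows}");
    Ok(())
}
```
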
ad12e20f6b Add cpu support for min and max. (#202)
* Add cpu support for min and max.

* Add min/max all.
2023-07-19 17:11:44 +01:00
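
A usage sketch for the reductions above, assuming per-dimension `min`/`max` plus the `min_all`/`max_all` variants mentioned in the entry:

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let t = Tensor::new(&[[1f32, 5., 2.], [7., 0., 3.]], &dev)?;
    let row_max = t.max(1)?;       // per-row maximum: [5, 7]
    let global_min = t.min_all()?; // scalar tensor holding 0
    println!("{row_max} {global_min}");
    Ok(())
}
```
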
d88b6cdca9 Add backtrace information to errors where relevant. (#166)
* Add backtrace information to errors where relevant.

* More backtrace information.

* Add to the FAQ.
2023-07-14 09:31:25 +01:00
a2f72edc0d Simplify the parameters used by sum and sum_keepdim. (#165) 2023-07-14 08:22:08 +01:00
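
A usage sketch of the simplified API, assuming the current behaviour where `sum(dims)` drops the reduced dimensions and `sum_keepdim(dims)` keeps them with size 1:

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let t = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &dev)?;
    let s = t.sum(1)?;          // shape (2,): the reduced dim is dropped
    let sk = t.sum_keepdim(1)?; // shape (2, 1): the reduced dim is kept as size 1
    println!("{s} {sk}");
    Ok(())
}
```
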
5ee3c95582 Move the variable creation to the variable module. (#159)
* Move the variable creation to the variable module.

* Make it possible to set a variable.

* Add some basic gradient descent test.

* Get the gradient descent test to work.
2023-07-13 16:55:40 +01:00
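
The entry above adds a basic gradient-descent test on top of `Var`. A minimal sketch in that spirit, with hypothetical values, assuming `Var::new`, `Var::set`, `Tensor::backward` and the gradient store's `get`:

```rust
use candle_core::{Device, Result, Tensor, Var};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Fit a scalar w so that w * x is approximately y, using plain gradient descent.
    let x = Tensor::new(&[1f32, 2., 3., 4.], &dev)?;
    let y = Tensor::new(&[3f32, 6., 9., 12.], &dev)?;
    let w = Var::new(0f32, &dev)?;
    for _step in 0..200 {
        let pred = x.broadcast_mul(&w)?;
        let loss = (pred - &y)?.sqr()?.mean_all()?;
        let grads = loss.backward()?;
        let grad_w = grads.get(&w).expect("no gradient for w");
        // Manual SGD update through Var::set.
        w.set(&(w.as_tensor() - (grad_w * 0.01)?)?)?;
    }
    println!("w = {}", w.to_scalar::<f32>()?);
    Ok(())
}
```
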
21aa29ddce Use a rwlock for inner mutability. (#156)
* Use a rw-lock.

* Make clippy happier.
2023-07-13 11:25:24 +01:00
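
A generic illustration of the interior-mutability pattern the commit above switches to (plain std types, not candle's actual internals): shared state behind `Arc<RwLock<..>>`, allowing many readers or one writer at a time.

```rust
use std::sync::{Arc, RwLock};

// Clonable handle to shared, thread-safe mutable state.
#[derive(Clone)]
struct Shared(Arc<RwLock<Vec<f32>>>);

impl Shared {
    fn len(&self) -> usize {
        self.0.read().unwrap().len()
    }
    fn push(&self, v: f32) {
        self.0.write().unwrap().push(v)
    }
}

fn main() {
    let s = Shared(Arc::new(RwLock::new(vec![1.0])));
    s.push(2.0);
    println!("{}", s.len());
}
```
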
50b0946a2d Tensor mutability (#154)
* Working towards tensor mutability.

* Use a ref-cell to provide tensor mutability.
2023-07-13 11:04:40 +01:00
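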
ba35d895e7 Sketch the candle-transformers crate. (#147)
* Sketch the candle-transformers crate.

* Format the empty files.
2023-07-12 13:49:31 +01:00
a76ec797da Cleanup the main crate error and add a couple dedicated ones (#142)
* Cosmetic cleanups to the error enum.

* More error cleanup.

* Proper error handling rather than panicking.

* Add some conv1d dedicated error.
2023-07-12 09:17:08 +01:00
37cad85869 Resurrect the llama npy support. (#140) 2023-07-11 19:32:10 +01:00
64264d97c1 Modular backends (#138)
* Add some trait to formalize backends.

* Use the generic backend trait.
2023-07-11 11:17:02 +01:00
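
A hypothetical sketch of the idea behind the backend traits (names invented; candle's real traits live in candle-core's backend module and differ in detail): formalize what a device backend must provide so the tensor code can be generic over CPU, CUDA, and so on.

```rust
// Invented trait names for illustration only.
trait BackendStorage {
    fn device_name(&self) -> &'static str;
    fn to_vec_f32(&self) -> Vec<f32>;
    fn affine(&self, mul: f32, add: f32) -> Self;
}

struct CpuStorage(Vec<f32>);

impl BackendStorage for CpuStorage {
    fn device_name(&self) -> &'static str {
        "cpu"
    }
    fn to_vec_f32(&self) -> Vec<f32> {
        self.0.clone()
    }
    fn affine(&self, mul: f32, add: f32) -> Self {
        CpuStorage(self.0.iter().map(|v| v * mul + add).collect())
    }
}

fn main() {
    let s = CpuStorage(vec![1., 2., 3.]);
    let t = s.affine(2., 1.);
    println!("{} {:?}", t.device_name(), t.to_vec_f32());
}
```
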
270997a055 Add the elu op. (#113) 2023-07-09 21:56:31 +01:00
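
A usage sketch for the op, assuming `Tensor::elu(alpha)`:

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let t = Tensor::new(&[-1f32, 0., 2.], &dev)?;
    // ELU keeps positive values and maps negatives to alpha * (exp(x) - 1).
    let out = t.elu(1.0)?;
    println!("{out}");
    Ok(())
}
```
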
dd60bd84bb MKL adjustments. (#87) 2023-07-06 11:37:27 +01:00
a57b314780 Add a batch dimension on the bert example. 2023-07-04 06:10:52 +01:00
ad52b0377c Add the varbuilder + check shapes. 2023-07-03 15:32:20 +01:00
899c76de75 Handle more types in safetensors. 2023-07-03 10:09:46 +01:00
fe2c07e368 Add the ST error. 2023-07-03 08:44:00 +01:00
19cbbc5212 Improve how we check that the dims are in bounds. 2023-06-30 09:11:00 +01:00
b50bd880ce Only narrow when needed + deactivate the kv cache. 2023-06-29 19:07:52 +01:00
d7f729fb8f Refactor the hierarchy. 2023-06-27 11:57:27 +02:00