Commit Graph

1304 Commits

Author SHA1 Message Date
a0d65585db Softmax implementation for cuda. (#747) 2023-09-05 18:38:03 +01:00
94c6a8d3d3 Add a dedicated cuda kernel for softmax. (#746) 2023-09-05 17:53:20 +02:00
6615daf242 Tweaks to softmax. (#745) 2023-09-05 15:22:27 +01:00
1c9e5394a5 Add a custom softmax implementation. (#744)
* Add a custom softmax implementation.

* Add softmaxlastdim to the benchmarks.

* And add a test.

* Support more dtypes.

* Polish the code.

* Use the slow implementation on cuda.

* Add a todo for the cuda kernel.
2023-09-05 14:20:23 +01:00
a8410bf35e Add some documentation. (#743) 2023-09-05 09:51:12 +01:00
cda45a7443 Let outside CustomOp2 implementations use binary_map/binary_map_vec (#741) 2023-09-05 09:27:32 +01:00
4698eb5cb6 Fix typo in the nll function document (#742) 2023-09-05 09:25:11 +01:00
000487c36f Add a python function to save as safetensors. (#740) 2023-09-04 20:32:14 +01:00
ab0d9fbdd1 Properly set the is_bf16 flag. (#738) 2023-09-04 16:45:26 +01:00
f80fd44201 BF16 support for flash-attn. (#737) 2023-09-04 16:35:43 +01:00
0d00c06a83 Fix clippy lint. (#736) 2023-09-04 16:09:19 +01:00
8395152d20 Llama2c WASM UI improvements (#732)
* pass seed, expose model seq_len

* wip new llama2.c ui

* final new UI example

* small coppy

* copy
2023-09-04 15:59:22 +01:00
e2f9f60ac2 Avoid some redundant clone. (#731) 2023-09-04 09:18:32 +02:00
d0cdea95a5 Add back the bf16 flash-attn kernels. (#730) 2023-09-04 07:50:52 +01:00
20512ba408 Return the metadata in the gguf pyo3 bindings. (#729)
* Return the metadata in the gguf pyo3 bindings.

* Read the metadata in the quantized llama example.

* Get inference to work on gguf files.
2023-09-04 07:07:00 +01:00
9c61b0fc9b Proper log buckets for t5. (#727)
* Proper log buckets for t5.

* Properly pass the position bias.
2023-09-03 20:33:50 +01:00
26cd266e65 Musicgen text embeddings. (#726)
* Musicgen text embeddings.

* Bugfix for layer norm.

* Proper position bias.

* Expose the weights.
2023-09-03 18:27:48 +01:00
bbec527bb9 Fix the musicgen example. (#724)
* Fix the musicgen example.

* Retrieve the weights from the hub.
2023-09-03 14:50:39 +01:00
f7980e07e0 Add ggufv2 support (#725) 2023-09-03 14:41:57 +01:00
74a82c358a Add the mse loss. (#723) 2023-09-03 10:51:40 +01:00
84d003ff53 Handle arbitrary shapes in Tensor::new. (#718) 2023-09-02 19:59:21 +01:00
21109e1983 Recommend using maturin. (#717) 2023-09-02 16:19:35 +01:00
ad796eb4be More quantized llama in python. (#716)
* More quantized llama in python.

* Expose a couple more functions.

* Apply the last layer.

* Use the vocab from the ggml files.
2023-09-02 13:41:48 +01:00
e8e33752f4 Sketch a quantized llama using the pyo3 api. (#715)
* Sketch a quantized llama using the pyo3 api.

* Add more ops.

* Expose a few more functions to use in the quantized model.

* Rope embeddings.

* Get the forward pass to work.
2023-09-02 11:26:05 +01:00
dabaa479b9 Update README.md (#714) 2023-09-02 07:56:12 +01:00
2c1df6bba1 Add a repeat penality to the llama2-c command line example. (#713)
* Add a repeat penality to the llama2-c command line example.

* Another fix attempt.
2023-09-01 20:38:58 +01:00
4d56cef583 Handle the empty sequence case properly. (#712)
* Handle the empty sequence case properly.

* Proper fix.
2023-09-01 20:12:30 +01:00
19042962d5 Whisper fix (#711)
* Remove unnecessary file.

* Whisper fix.
2023-09-01 20:04:07 +01:00
731e3ffb03 Remove unnecessary file. (#710) 2023-09-01 19:42:23 +01:00
2fef14cb14 Add a repeat penalty to the llama2.c wasm example. (#709) 2023-09-01 19:32:28 +01:00
1e5b2cc1d5 Add some quantized functions to pyo3. (#708) 2023-09-01 19:45:36 +02:00
2ed78ab336 Support for quantized tensors in the python api. (#706)
* Add more pyo3 support.

* Add some support for quantized tensors in pyo3.

* Add an arc layer on qmatmul.

* Add the quantized matmul.

* Quantization support.

* More quantization support.

* Test the python quantization.
2023-09-01 15:53:42 +01:00
237323c2bc Cleanup the pyo3 setup. (#705) 2023-09-01 14:26:18 +01:00
af552a5274 Fix the rnn tests for accelerate. (#704) 2023-09-01 13:21:38 +01:00
7529531056 Add the optimizer trait. (#702) 2023-09-01 12:55:39 +01:00
f2d476ca65 Replace the discord link. (#701) 2023-09-01 09:43:55 +01:00
f9f482d4e5 Add some doc to the varbuilder. (#700) 2023-09-01 08:28:35 +01:00
9736236175 Allow retrieving and setting prefix of VarBuilder (#699) 2023-09-01 08:08:41 +01:00
30a4b593d7 More ops again. (#697) 2023-08-31 22:28:48 +01:00
949f1eae6f Implement a couple more binary ops. (#693) 2023-08-31 21:30:15 +01:00
7cef35c84d Tweak some quantized args (#692)
* Print the args + change the default temp/repeat penalty.

* Minor formatting tweak.
2023-08-31 17:25:21 +01:00
7509c98970 Interactive mode for the quantized model. (#690) 2023-08-31 10:52:42 +01:00
94aa234dfd Add the kv-cache to the whisper wasm version. (#689)
* Add the kv-cache to the whisper wasm version.

* Improve the handling of special tokens.
2023-08-31 09:37:44 +01:00
db59816087 Add a GRU layer. (#688)
* Add a GRU layer.

* Fix the n gate computation.
2023-08-31 08:43:10 +01:00
d210c71d77 Set the learning rate. (#687) 2023-08-31 08:03:40 +01:00
8e84d8a59b Llama2.c wasm module. (#686) 2023-08-31 07:44:32 +01:00
9bd486fb96 Add Yolo Pose to JS Example (#684)
* add support for yolo pose models

* fix copy
2023-08-31 06:32:57 +01:00
eaf760a751 Add a python variant for the lstm test. (#682) 2023-08-30 22:32:08 +01:00
1d0bb48fae Improve Whisper WASM UI example (#669)
* wip add module and js worker example

* params

* clean up, send error

* final UI with whisper webworker

* add simple instructions
2023-08-30 20:35:41 +02:00
21e1c73892 Add a LSTM test. (#681)
* Add a LSTM test.

* Clippy.
2023-08-30 20:05:42 +02:00