1c9e5394a5
Add a custom softmax implementation. ( #744 )
* Add a custom softmax implementation.
* Add softmax_last_dim to the benchmarks.
* And add a test.
* Support more dtypes.
* Polish the code.
* Use the slow implementation on cuda.
* Add a todo for the cuda kernel.
2023-09-05 14:20:23 +01:00
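For reference, a softmax over the last dimension is usually computed in the numerically stable form below. This is a plain-Rust sketch of the operation on a single row of logits, not the custom kernel added in #744:

```rust
/// Numerically stable softmax over a single row of logits.
/// Plain-Rust sketch of the operation; the real implementation works on
/// tensors and has dtype- and device-specific paths.
fn softmax_last_dim(row: &[f32]) -> Vec<f32> {
    // Subtract the row max before exponentiating to avoid overflow.
    let max = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = row.iter().map(|&x| (x - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}

fn main() {
    let probs = softmax_last_dim(&[1.0, 2.0, 3.0]);
    println!("{probs:?}"); // the probabilities sum to 1.0
}
```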
a8410bf35e
Add some documentation. ( #743 )
2023-09-05 09:51:12 +01:00
cda45a7443
Let outside CustomOp2 implementations use binary_map/binary_map_vec ( #741 )
2023-09-05 09:27:32 +01:00
4698eb5cb6
Fix typo in the nll function document ( #742 )
2023-09-05 09:25:11 +01:00
000487c36f
Add a python function to save as safetensors. ( #740 )
2023-09-04 20:32:14 +01:00
ab0d9fbdd1
Properly set the is_bf16 flag. ( #738 )
2023-09-04 16:45:26 +01:00
f80fd44201
BF16 support for flash-attn. ( #737 )
2023-09-04 16:35:43 +01:00
0d00c06a83
Fix clippy lint. ( #736 )
2023-09-04 16:09:19 +01:00
8395152d20
Llama2c WASM UI improvements ( #732 )
* pass seed, expose model seq_len
* wip new llama2.c ui
* final new UI example
* small copy
* copy
2023-09-04 15:59:22 +01:00
e2f9f60ac2
Avoid some redundant clone. ( #731 )
2023-09-04 09:18:32 +02:00
d0cdea95a5
Add back the bf16 flash-attn kernels. ( #730 )
2023-09-04 07:50:52 +01:00
20512ba408
Return the metadata in the gguf pyo3 bindings. ( #729 )
* Return the metadata in the gguf pyo3 bindings.
* Read the metadata in the quantized llama example.
* Get inference to work on gguf files.
2023-09-04 07:07:00 +01:00
9c61b0fc9b
Proper log buckets for t5. ( #727 )
* Proper log buckets for t5.
* Properly pass the position bias.
2023-09-03 20:33:50 +01:00
26cd266e65
Musicgen text embeddings. ( #726 )
* Musicgen text embeddings.
* Bugfix for layer norm.
* Proper position bias.
* Expose the weights.
2023-09-03 18:27:48 +01:00
bbec527bb9
Fix the musicgen example. ( #724 )
* Fix the musicgen example.
* Retrieve the weights from the hub.
2023-09-03 14:50:39 +01:00
f7980e07e0
Add ggufv2 support ( #725 )
2023-09-03 14:41:57 +01:00
74a82c358a
Add the mse loss. ( #723 )
2023-09-03 10:51:40 +01:00
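The mse loss is the mean of the squared element-wise differences between predictions and targets. A plain-Rust sketch of that formula (the actual loss operates on tensors):

```rust
/// Mean squared error over two equally sized slices; a sketch of the
/// formula behind the loss added above, not the tensor-based API.
fn mse(pred: &[f32], target: &[f32]) -> f32 {
    assert_eq!(pred.len(), target.len());
    let sum: f32 = pred
        .iter()
        .zip(target)
        .map(|(p, t)| (p - t) * (p - t))
        .sum();
    sum / pred.len() as f32
}

fn main() {
    println!("{}", mse(&[0.0, 1.0], &[1.0, 1.0])); // prints 0.5
}
```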
84d003ff53
Handle arbitrary shapes in Tensor::new. ( #718 )
2023-09-02 19:59:21 +01:00
21109e1983
Recommend using maturin. ( #717 )
2023-09-02 16:19:35 +01:00
ad796eb4be
More quantized llama in python. ( #716 )
* More quantized llama in python.
* Expose a couple more functions.
* Apply the last layer.
* Use the vocab from the ggml files.
2023-09-02 13:41:48 +01:00
e8e33752f4
Sketch a quantized llama using the pyo3 api. ( #715 )
* Sketch a quantized llama using the pyo3 api.
* Add more ops.
* Expose a few more functions to use in the quantized model.
* Rope embeddings.
* Get the forward pass to work.
2023-09-02 11:26:05 +01:00
dabaa479b9
Update README.md ( #714 )
2023-09-02 07:56:12 +01:00
2c1df6bba1
Add a repeat penalty to the llama2-c command line example. ( #713 )
* Add a repeat penalty to the llama2-c command line example.
* Another fix attempt.
2023-09-01 20:38:58 +01:00
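A repeat penalty dampens the logits of tokens that already appeared in the generated sequence so they are less likely to be sampled again. The sketch below shows one common formulation over raw logits; the scaling rule and names are illustrative, not necessarily what the example uses:

```rust
/// Dampen the logits of tokens already present in the context. A common rule:
/// divide positive logits by the penalty and multiply negative ones by it.
/// Illustrative sketch only, not the example's actual code.
fn apply_repeat_penalty(logits: &mut [f32], penalty: f32, previous_tokens: &[usize]) {
    for &token in previous_tokens {
        if let Some(logit) = logits.get_mut(token) {
            *logit = if *logit >= 0.0 {
                *logit / penalty
            } else {
                *logit * penalty
            };
        }
    }
}

fn main() {
    let mut logits = vec![2.0, -1.0, 0.5];
    apply_repeat_penalty(&mut logits, 1.1, &[0, 1]);
    println!("{logits:?}");
}
```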
4d56cef583
Handle the empty sequence case properly. ( #712 )
* Handle the empty sequence case properly.
* Proper fix.
2023-09-01 20:12:30 +01:00
19042962d5
Whisper fix ( #711 )
* Remove unnecessary file.
* Whisper fix.
2023-09-01 20:04:07 +01:00
731e3ffb03
Remove unnecessary file. ( #710 )
2023-09-01 19:42:23 +01:00
2fef14cb14
Add a repeat penalty to the llama2.c wasm example. ( #709 )
2023-09-01 19:32:28 +01:00
1e5b2cc1d5
Add some quantized functions to pyo3. ( #708 )
2023-09-01 19:45:36 +02:00
2ed78ab336
Support for quantized tensors in the python api. ( #706 )
* Add more pyo3 support.
* Add some support for quantized tensors in pyo3.
* Add an arc layer on qmatmul.
* Add the quantized matmul.
* Quantization support.
* More quantization support.
* Test the python quantization.
2023-09-01 15:53:42 +01:00
237323c2bc
Cleanup the pyo3 setup. ( #705 )
2023-09-01 14:26:18 +01:00
af552a5274
Fix the rnn tests for accelerate. ( #704 )
2023-09-01 13:21:38 +01:00
7529531056
Add the optimizer trait. ( #702 )
2023-09-01 12:55:39 +01:00
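An optimizer trait lets SGD-style and Adam-style optimizers share a single interface. The sketch below illustrates the idea with assumed names and signatures; it is not the trait added in #702:

```rust
/// Illustrative optimizer abstraction: concrete optimizers (SGD, AdamW, ...)
/// implement a common `step` over parameter/gradient pairs.
/// Names and signatures are assumptions, not the crate's actual trait.
trait Optimizer {
    fn step(&mut self, params: &mut [f32], grads: &[f32]);
    fn learning_rate(&self) -> f64;
    fn set_learning_rate(&mut self, lr: f64);
}

struct Sgd {
    lr: f64,
}

impl Optimizer for Sgd {
    fn step(&mut self, params: &mut [f32], grads: &[f32]) {
        // Plain gradient descent update: p <- p - lr * g.
        for (p, g) in params.iter_mut().zip(grads) {
            *p -= (self.lr as f32) * g;
        }
    }
    fn learning_rate(&self) -> f64 {
        self.lr
    }
    fn set_learning_rate(&mut self, lr: f64) {
        self.lr = lr;
    }
}

fn main() {
    let mut params = vec![1.0_f32, 2.0];
    let grads = vec![0.1_f32, -0.2];
    let mut opt = Sgd { lr: 0.01 };
    opt.step(&mut params, &grads);
    println!("{params:?}");
}
```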
f2d476ca65
Replace the discord link. ( #701 )
2023-09-01 09:43:55 +01:00
f9f482d4e5
Add some doc to the varbuilder. ( #700 )
2023-09-01 08:28:35 +01:00
9736236175
Allow retrieving and setting prefix of VarBuilder ( #699 )
2023-09-01 08:08:41 +01:00
30a4b593d7
More ops again. ( #697 )
2023-08-31 22:28:48 +01:00
949f1eae6f
Implement a couple more binary ops. ( #693 )
2023-08-31 21:30:15 +01:00
7cef35c84d
Tweak some quantized args ( #692 )
* Print the args + change the default temp/repeat penalty.
* Minor formatting tweak.
2023-08-31 17:25:21 +01:00
7509c98970
Interactive mode for the quantized model. ( #690 )
2023-08-31 10:52:42 +01:00
94aa234dfd
Add the kv-cache to the whisper wasm version. ( #689 )
* Add the kv-cache to the whisper wasm version.
* Improve the handling of special tokens.
2023-08-31 09:37:44 +01:00
db59816087
Add a GRU layer. ( #688 )
* Add a GRU layer.
* Fix the n gate computation.
2023-08-31 08:43:10 +01:00
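On the n gate mentioned above: in a standard GRU the candidate state applies the reset gate to the hidden-state projection before the tanh. A scalar sketch of one step for a single hidden unit, assuming the input and hidden linear projections are already computed (not the layer's actual API):

```rust
/// One GRU step for a single hidden unit. `ir`/`hr`, `iz`/`hz`, `in_`/`hn`
/// are the input- and hidden-side projections for the reset, update and
/// candidate ("n") gates; `h` is the previous hidden state.
fn gru_cell(ir: f32, hr: f32, iz: f32, hz: f32, in_: f32, hn: f32, h: f32) -> f32 {
    let sigmoid = |x: f32| 1.0 / (1.0 + (-x).exp());
    let r = sigmoid(ir + hr); // reset gate
    let z = sigmoid(iz + hz); // update gate
    let n = (in_ + r * hn).tanh(); // candidate gate: reset applied before tanh
    (1.0 - z) * n + z * h // new hidden state
}

fn main() {
    println!("{}", gru_cell(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.0));
}
```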
d210c71d77
Set the learning rate. ( #687 )
2023-08-31 08:03:40 +01:00
8e84d8a59b
Llama2.c wasm module. ( #686 )
2023-08-31 07:44:32 +01:00
9bd486fb96
Add Yolo Pose to JS Example ( #684 )
* add support for yolo pose models
* fix copy
2023-08-31 06:32:57 +01:00
eaf760a751
Add a python variant for the lstm test. ( #682 )
2023-08-30 22:32:08 +01:00
1d0bb48fae
Improve Whisper WASM UI example ( #669 )
* wip add module and js worker example
* params
* clean up, send error
* final UI with whisper webworker
* add simple instructions
2023-08-30 20:35:41 +02:00
21e1c73892
Add a LSTM test. ( #681 )
* Add a LSTM test.
* Clippy.
2023-08-30 20:05:42 +02:00
2047d34b7c
More robust tests (so that they pass on accelerate). ( #679 )
2023-08-30 18:10:10 +01:00
9874d843f1
Fix the accelerate build ( #678 )
* Cosmetic changes.
* Fix the accelerate build for tanh.
2023-08-30 18:31:14 +02:00
7d753d3acd
Mnist training dropout ( #677 )
* Use dropout in the mnist training.
* Fix.
2023-08-30 16:41:01 +01:00
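Dropout zeroes each activation with probability p during training and rescales the survivors by 1/(1-p); at evaluation time it is a no-op. A minimal sketch of that behavior (assumes the rand crate; the mnist example uses the crate's dropout layer, not this helper):

```rust
use rand::Rng; // assumes the `rand` crate as a dependency

/// Inverted dropout: during training, zero each activation with probability
/// `p` and scale survivors by 1/(1-p); at eval time, pass values through.
/// Illustrative sketch, not the layer used in the mnist training example.
fn dropout(xs: &[f32], p: f32, train: bool) -> Vec<f32> {
    if !train || p == 0.0 {
        return xs.to_vec();
    }
    let mut rng = rand::thread_rng();
    let scale = 1.0 / (1.0 - p);
    xs.iter()
        .map(|&x| if rng.gen::<f32>() < p { 0.0 } else { x * scale })
        .collect()
}

fn main() {
    let xs = vec![1.0_f32, 2.0, 3.0, 4.0];
    println!("{:?}", dropout(&xs, 0.5, true));
}
```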