Commit Graph

159 Commits

Author SHA1 Message Date
9fc210fae8 Merge pull request #1318 from huggingface/metal4
Starting to fix some tests.
2023-12-20 15:37:31 +01:00
03641293ee Clippy pass. 2023-12-18 15:22:43 +01:00
94817dac56 Bump the crate version to 0.3.2. (#1452) 2023-12-17 05:34:53 -06:00
1e86717bf2 Fix a couple typos (#1451)
* Mixtral quantized instruct.

* Fix a couple typos.
2023-12-17 05:20:05 -06:00
c630622a07 Expose AdamW parameters (#1449)
* Expose AdamW parameters

* Use reference
2023-12-16 18:41:56 -06:00
6bc92e63cb Addressing a lot of comments. 2023-12-15 13:06:04 +01:00
aa04015098 Remove unwrap(). 2023-12-15 12:23:28 +01:00
26540641c1 Renamed all kernel names. 2023-12-15 11:24:47 +01:00
ece4c69a68 Fixing softmax. 2023-12-15 01:35:08 +01:00
361f2ad2af Working with merging encoders and using fences. 2023-12-14 16:05:33 +01:00
e60f9b5dfc Speedup ShardedSafeTensors to load Tensors with default hints (#1384)
* Speedup ShardedSafeTensors to load Tensors with default hints

* Tweaks.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2023-12-14 08:08:56 -06:00
87dc559817 Lots of updates including some stack of command buffers. 2023-12-12 17:41:56 +01:00
236b820e28 Another prelu bugfix. (#1407) 2023-12-06 09:54:41 +01:00
2648e797c2 Use the proper broadcasting for prelu. (#1406) 2023-12-05 07:09:31 +01:00
b5c283e86f Add the prelu layer. (#1402) 2023-12-03 16:06:09 +00:00
4349ff1fc2 Starting to fix some tests.
Few fixes.

Going back on remote metal-rs.

Reusing a single buffer (for now) to speed things up.

Adding some half kernels.

All tests are panicking instead of random failure.

Putting back f16 index select.

Add erf.

Working version for llama2-c.

Fixes + cache compute_pipeline_state.

BF16 metal fix.

Remove some prints.

new_owned -> new()..to_owned().

Better batched matmul.

Metal operational.

Reuse buffers on our own reference counts.

Tmp gemm.

Revert "Tmp gemm."

This reverts commit c65f68e988.

Interleave committing.

Speeding up copies using blit.

Fmt.

Fmt.

Remove the assert!

Fmt all.

Fixes after big rebase.

Add softmax for half and bfloat + tests

Fixing Llama example + accumulate softmax in float.
2023-11-30 11:30:31 +01:00
bfa7c8fc01 Implement the module trait directly for QMatMul. (#1372) 2023-11-25 10:09:45 +00:00
a209ce8ceb Update for 0.3.1. (#1324) 2023-11-11 18:48:52 +00:00
18d30005c5 Add support to UL2 model family (#1300)
* Add support to UL2 model family

* Update docs with UL2

* Create ActivationWithOptionalGating to avoid polluting activations

* Also refactor quantized t5

* Remove useless conversion

* Revert Activation::NewGelu name change

* Remove useless return

* Apply rustfmt and clippy recommendations

* Reuse t5::ActivationWithOptionalGating in quantized version

* (cosmetic change) use a match rather than ifs + avoid early returns.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2023-11-09 18:55:09 +01:00
e6697471bb Add weight and bias functions to LayerNorm (#1306) 2023-11-09 16:09:01 +01:00
3b0d1e7d03 Transposed conv1d in candle-nn. (#1252) 2023-11-03 11:18:25 +01:00
a2a20aeecc Add the swiglu activation from the chatglm PR. (#1246) 2023-11-02 20:01:34 +01:00
d39d0c40fd Add hard-sigmoid and hard-swish activations (#1244)
* Add hard-sigmoid and hard-swish activations

* Update ops.rs

* Use / rather than div.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2023-11-02 18:20:27 +01:00
392a00a147 Add support for the marian base model. (#1221) 2023-10-30 19:20:36 +00:00
55bc3382cf Allow for different behavior between training and eval (#1213)
* Forward with training.

* Do not use dropout on vgg evaluation.
2023-10-29 07:53:09 +01:00
c8face3f95 Add the relu2 and relu6 activations. (#1201) 2023-10-27 20:51:16 +01:00
b3181455d5 Add fuse-conv-bn method for Conv2d (#1196)
* Add fuse-conv-bn method for Conv2d

* no unwrap

* run rustfmp and clippy
2023-10-27 15:56:50 +01:00
0acd16751d Expose the fields from batch-norm. (#1176) 2023-10-25 15:35:32 +01:00
86e1803191 Add Binary Cross Entropy With Logit Loss to nn crate (#1157)
* add bce with logit loss

* add bce with logit loss

* remove imports

* fix tiny bug

* add test documentation and refactor function

* fix test cases and formatting
2023-10-23 17:12:44 +01:00
7366aeac21 Make func cloneable. (#1137) 2023-10-20 16:28:50 +01:00
99cf13e8e2 Add the sequential layer. (#1136) 2023-10-20 16:08:50 +01:00
8e773cc0c6 Experiment with resnet (#1128)
* Add some preliminary support for resnet.

* Add an actual resnet example.
2023-10-19 09:25:03 +01:00
122da87580 feat: add pth varbuilder (#1108) 2023-10-16 16:20:36 +01:00
c096f02411 Add a matvec cpu benchmark. (#1076) 2023-10-12 09:29:18 +01:00
89b525b5e7 Convmixer (#1073)
* Only optimize float tensors.

* Use full tensors for zeros and ones.

* Add a benchmark for the matmul slowness.

* Add the convmixer model.

* Proper adaptive pooling.
2023-10-11 18:24:32 +01:00
9fea56d28e Only optimize float tensors. (#1069) 2023-10-10 09:05:41 +01:00
a4967600d0 More general seq forward functions for RNNs. (#1050) 2023-10-07 15:08:01 +01:00
f0c619a4af Use AsRef<str> for set_one. (#1033) 2023-10-05 06:05:44 +01:00
089fc3b584 Improve the quantized whisper setup. (#1018)
* Improve the quantized whisper setup.

* Fix the config file paths.

* Use the standard matmul where possible.
2023-10-02 17:17:46 +01:00
096dee7073 Bump the version to 0.3.0. (#1014)
* Bump the version to 0.3.0.

* Changelog update.
2023-10-01 13:51:57 +01:00
53510ce427 Use a silu activation in mistral. (#991) 2023-09-29 07:06:54 +01:00
ce0a4e3a85 Use the gelu-erf activation. (#969) 2023-09-26 22:30:21 +01:00
c798184c2b Configurable layer idx for the lstm layer. (#962) 2023-09-25 21:31:14 +01:00
4aeb449017 Depreate the VarBuilder::from_safetensors function. (#951) 2023-09-24 11:18:17 +01:00
bcb0ed8f1c Self-contained safetensors for the multiprocess llama example. (#950) 2023-09-24 06:54:49 +01:00
e32c89d90c Add the buffered safetensor wrapper. (#948) 2023-09-23 22:57:42 +01:00
890d069092 Self-contained safetensor wrappers (#946)
* Self-contained safetensor wrappers.

* Use the new safetensor container in varbuilders.
2023-09-23 20:39:52 +01:00
ccf352f3d1 Use yoke to provide a self-referential container for mmaped safetenso… (#939)
* Use yoke to provide a self-referential container for mmaped safetensor files.

* Add the new self-owned type for safetensor files without removing the previous version.

* Add routing.

* Add an initializer for the case of multiple files.
2023-09-23 15:43:11 +01:00
402d207f0f VarMap setter functions (#938)
* Add some setter helper functions for varmap.

* Add more comments.
2023-09-23 10:27:51 +01:00
7b1ddcff47 Add clone to various nn layers. (#910) 2023-09-20 11:33:51 +01:00