9fc210fae8
Merge pull request #1318 from huggingface/metal4
Starting to fix some tests.
2023-12-20 15:37:31 +01:00
03641293ee
Clippy pass.
2023-12-18 15:22:43 +01:00
1e86717bf2
Fix a couple typos ( #1451 )
* Mixtral quantized instruct.
* Fix a couple typos.
2023-12-17 05:20:05 -06:00
c630622a07
Expose AdamW parameters ( #1449 )
* Expose AdamW parameters
* Use reference
2023-12-16 18:41:56 -06:00
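A minimal sketch of what the exposed optimizer parameters look like from the caller's side, assuming candle-nn's `ParamsAdamW` struct and the `Optimizer`/`AdamW::new` API; field names are from memory and worth double-checking.

```rust
use candle_core::Result;
use candle_nn::{AdamW, Optimizer, ParamsAdamW, VarMap};

fn main() -> Result<()> {
    let varmap = VarMap::new();
    // Spell out the hyper-parameters instead of relying on the defaults;
    // field names follow candle-nn's ParamsAdamW as I recall it.
    let params = ParamsAdamW {
        lr: 1e-3,
        beta1: 0.9,
        beta2: 0.999,
        eps: 1e-8,
        weight_decay: 0.01,
    };
    let _opt = AdamW::new(varmap.all_vars(), params)?;
    Ok(())
}
```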
6bc92e63cb
Addressing a lot of comments.
2023-12-15 13:06:04 +01:00
aa04015098
Remove unwrap().
2023-12-15 12:23:28 +01:00
26540641c1
Renamed all kernel names.
2023-12-15 11:24:47 +01:00
ece4c69a68
Fixing softmax.
2023-12-15 01:35:08 +01:00
361f2ad2af
Working with merging encoders and using fences.
2023-12-14 16:05:33 +01:00
e60f9b5dfc
Speedup ShardedSafeTensors to load Tensors with default hints ( #1384 )
* Speedup ShardedSafeTensors to load Tensors with default hints
* Tweaks.
---------
Co-authored-by: Laurent <laurent.mazare@gmail.com>
2023-12-14 08:08:56 -06:00
87dc559817
Lots of updates including some stack of command buffers.
2023-12-12 17:41:56 +01:00
236b820e28
Another prelu bugfix. ( #1407 )
2023-12-06 09:54:41 +01:00
2648e797c2
Use the proper broadcasting for prelu. ( #1406 )
2023-12-05 07:09:31 +01:00
b5c283e86f
Add the prelu layer. ( #1402 )
2023-12-03 16:06:09 +00:00
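For context on the prelu entries above (#1402/#1406/#1407): prelu(x) = relu(x) - a * relu(-x), with a per-channel slope `a`. The sketch below restates that with plain candle tensor ops rather than the candle-nn layer itself, assuming an NCHW input.

```rust
use candle_core::{Result, Tensor};

// prelu(x) = relu(x) - a * relu(-x), with `a` broadcast over the channel dim.
fn prelu_sketch(xs: &Tensor, alpha: &Tensor) -> Result<Tensor> {
    // Reshape the (C,) slope to (1, C, 1, 1) so it broadcasts against NCHW input;
    // getting this broadcast right is what the follow-up fixes were about.
    let alpha = alpha.reshape((1, alpha.elem_count(), 1, 1))?;
    xs.relu()?.sub(&alpha.broadcast_mul(&xs.neg()?.relu()?)?)
}
```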
4349ff1fc2
Starting to fix some tests.
Few fixes.
Going back on remote metal-rs.
Reusing a single buffer (for now) to speed things up.
Adding some half kernels.
All tests are panicking instead of failing randomly.
Putting back f16 index select.
Add erf.
Working version for llama2-c.
Fixes + cache compute_pipeline_state.
BF16 metal fix.
Remove some prints.
new_owned -> new()..to_owned().
Better batched matmul.
Metal operational.
Reuse buffers on our own reference counts.
Tmp gemm.
Revert "Tmp gemm."
This reverts commit c65f68e988.
Interleave committing.
Speeding up copies using blit.
Fmt.
Fmt.
Remove the assert!
Fmt all.
Fixes after big rebase.
Add softmax for half and bfloat + tests
Fixing Llama example + accumulate softmax in float.
2023-11-30 11:30:31 +01:00
18d30005c5
Add support to UL2 model family ( #1300 )
* Add support to UL2 model family
* Update docs with UL2
* Create ActivationWithOptionalGating to avoid polluting activations
* Also refactor quantized t5
* Remove useless conversion
* Revert Activation::NewGelu name change
* Remove useless return
* Apply rustfmt and clippy recommendations
* Reuse t5::ActivationWithOptionalGating in quantized version
* (cosmetic change) use a match rather than ifs + avoid early returns.
---------
Co-authored-by: Laurent <laurent.mazare@gmail.com>
2023-11-09 18:55:09 +01:00
e6697471bb
Add weight and bias functions to LayerNorm ( #1306 )
2023-11-09 16:09:01 +01:00
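A small usage sketch for the accessors named in #1306, assuming candle-nn's `LayerNorm::new(weight, bias, eps)` constructor; the exact return types of `weight()`/`bias()` are not annotated and may differ.

```rust
use candle_core::{DType, Device, Result, Tensor};
use candle_nn::LayerNorm;

fn main() -> Result<()> {
    let device = Device::Cpu;
    let weight = Tensor::ones(4, DType::F32, &device)?;
    let bias = Tensor::zeros(4, DType::F32, &device)?;
    let ln = LayerNorm::new(weight, bias, 1e-5);
    // Read the affine parameters back (accessor names per the PR title).
    let _w = ln.weight();
    let _b = ln.bias();
    Ok(())
}
```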
3b0d1e7d03
Transposed conv1d in candle-nn. ( #1252 )
2023-11-03 11:18:25 +01:00
a2a20aeecc
Add the swiglu activation from the chatglm PR. ( #1246 )
2023-11-02 20:01:34 +01:00
d39d0c40fd
Add hard-sigmoid and hard-swish activations ( #1244 )
* Add hard-sigmoid and hard-swish activations
* Update ops.rs
* Use / rather than div.
---------
Co-authored-by: Laurent <laurent.mazare@gmail.com>
2023-11-02 18:20:27 +01:00
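The activations from #1244 have simple closed forms: hard_sigmoid(x) = clamp((x + 3) / 6, 0, 1) and hard_swish(x) = x * hard_sigmoid(x). Below is a sketch using only core tensor ops; the actual implementations live in candle-nn's ops.rs.

```rust
use candle_core::{Result, Tensor};

// hard_sigmoid(x) = clamp((x + 3) / 6, 0, 1)
fn hard_sigmoid_sketch(xs: &Tensor) -> Result<Tensor> {
    xs.affine(1.0 / 6.0, 0.5)?.clamp(0f32, 1f32)
}

// hard_swish(x) = x * hard_sigmoid(x)
fn hard_swish_sketch(xs: &Tensor) -> Result<Tensor> {
    xs.mul(&hard_sigmoid_sketch(xs)?)
}
```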
392a00a147
Add support for the marian base model. ( #1221 )
2023-10-30 19:20:36 +00:00
55bc3382cf
Allow for different behavior between training and eval ( #1213 )
* Forward with training.
* Do not use dropout on vgg evaluation.
2023-10-29 07:53:09 +01:00
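A minimal sketch of the train/eval split from #1213, assuming a `ModuleT` trait with a `forward_t(xs, train)` method and the `Dropout` helper exported from candle-nn.

```rust
use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{Dropout, ModuleT};

fn main() -> Result<()> {
    let device = Device::Cpu;
    let xs = Tensor::ones((2, 4), DType::F32, &device)?;
    let dropout = Dropout::new(0.5);
    // train = true: elements are randomly zeroed; train = false: identity.
    let _train = dropout.forward_t(&xs, true)?;
    let _eval = dropout.forward_t(&xs, false)?;
    Ok(())
}
```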
c8face3f95
Add the relu2 and relu6 activations. ( #1201 )
2023-10-27 20:51:16 +01:00
b3181455d5
Add fuse-conv-bn method for Conv2d ( #1196 )
* Add fuse-conv-bn method for Conv2d
* no unwrap
* run rustfmt and clippy
2023-10-27 15:56:50 +01:00
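The fusion in #1196 folds batch-norm statistics into the preceding convolution: w' = w * gamma / sqrt(var + eps) and b' = beta - gamma * mean / sqrt(var + eps). The helper below only illustrates that arithmetic on raw tensors; it is not the Conv2d method the PR adds.

```rust
use candle_core::{Result, Tensor};

// Illustrative only: fold batch-norm stats into conv weights.
//   w' = w * gamma / sqrt(var + eps)
//   b' = beta - gamma * mean / sqrt(var + eps)
fn fuse_conv_bn_sketch(
    w: &Tensor,     // (out_c, in_c, kh, kw)
    gamma: &Tensor, // (out_c,)
    beta: &Tensor,  // (out_c,)
    mean: &Tensor,  // (out_c,)
    var: &Tensor,   // (out_c,)
    eps: f64,
) -> Result<(Tensor, Tensor)> {
    let scale = gamma.div(&var.affine(1.0, eps)?.sqrt()?)?;
    let w = w.broadcast_mul(&scale.reshape((scale.elem_count(), 1, 1, 1))?)?;
    let b = beta.sub(&mean.mul(&scale)?)?;
    Ok((w, b))
}
```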
0acd16751d
Expose the fields from batch-norm. ( #1176 )
2023-10-25 15:35:32 +01:00
86e1803191
Add Binary Cross Entropy With Logit Loss to nn crate ( #1157 )
* add bce with logit loss
* add bce with logit loss
* remove imports
* fix tiny bug
* add test documentation and refactor function
* fix test cases and formatting
2023-10-23 17:12:44 +01:00
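A usage sketch for the loss added in #1157; the function name `binary_cross_entropy_with_logit` in candle-nn's `loss` module is from memory and should be verified.

```rust
use candle_core::{Device, Result, Tensor};
use candle_nn::loss;

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Raw logits (pre-sigmoid) and float targets in {0, 1}.
    let logits = Tensor::new(&[[0.5f32, -1.0], [2.0, 0.1]], &device)?;
    let targets = Tensor::new(&[[1.0f32, 0.0], [1.0, 0.0]], &device)?;
    let bce = loss::binary_cross_entropy_with_logit(&logits, &targets)?;
    println!("{bce}");
    Ok(())
}
```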
7366aeac21
Make func cloneable. ( #1137 )
2023-10-20 16:28:50 +01:00
99cf13e8e2
Add the sequential layer. ( #1136 )
2023-10-20 16:08:50 +01:00
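A short sketch of the sequential layer from #1136, chaining a linear layer, an activation, and a closure; `seq()`, `add`, and `add_fn` follow the builder-style API as I understand it.

```rust
use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{linear, seq, Activation, Module, VarBuilder, VarMap};

fn main() -> Result<()> {
    let device = Device::Cpu;
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);
    // Chain layers and closures; forward applies them in order.
    let model = seq()
        .add(linear(4, 8, vb.pp("fc1"))?)
        .add(Activation::Relu)
        .add_fn(|xs| xs.affine(2.0, 0.0));
    let xs = Tensor::zeros((1, 4), DType::F32, &device)?;
    let _ys = model.forward(&xs)?;
    Ok(())
}
```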
8e773cc0c6
Experiment with resnet ( #1128 )
* Add some preliminary support for resnet.
* Add an actual resnet example.
2023-10-19 09:25:03 +01:00
122da87580
feat: add pth varbuilder ( #1108 )
2023-10-16 16:20:36 +01:00
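The pth VarBuilder from #1108 loads weights straight from a PyTorch checkpoint; a sketch assuming `VarBuilder::from_pth(path, dtype, device)`, with a made-up file and tensor name.

```rust
use candle_core::{DType, Device, Result};
use candle_nn::VarBuilder;

fn main() -> Result<()> {
    let device = Device::Cpu;
    // File and tensor names are made up for the example.
    let vb = VarBuilder::from_pth("model.pth", DType::F32, &device)?;
    let _w = vb.get((10, 5), "linear.weight")?;
    Ok(())
}
```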
9fea56d28e
Only optimize float tensors. ( #1069 )
2023-10-10 09:05:41 +01:00
a4967600d0
More general seq forward functions for RNNs. ( #1050 )
2023-10-07 15:08:01 +01:00
f0c619a4af
Use AsRef<str> for set_one. ( #1033 )
2023-10-05 06:05:44 +01:00
096dee7073
Bump the version to 0.3.0. ( #1014 )
* Bump the version to 0.3.0.
* Changelog update.
2023-10-01 13:51:57 +01:00
53510ce427
Use a silu activation in mistral. ( #991 )
2023-09-29 07:06:54 +01:00
ce0a4e3a85
Use the gelu-erf activation. ( #969 )
2023-09-26 22:30:21 +01:00
c798184c2b
Configurable layer idx for the lstm layer. ( #962 )
2023-09-25 21:31:14 +01:00
4aeb449017
Deprecate the VarBuilder::from_safetensors function. ( #951 )
2023-09-24 11:18:17 +01:00
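With #951 deprecating `VarBuilder::from_safetensors`, the mmap-based constructor is the usual replacement; a sketch assuming `VarBuilder::from_mmaped_safetensors` (an `unsafe` fn because the mapped file must not change underneath), with illustrative names.

```rust
use candle_core::{DType, Device, Result};
use candle_nn::VarBuilder;

fn main() -> Result<()> {
    let device = Device::Cpu;
    // File and tensor names are made up; the call is unsafe because the mapped
    // file must stay unchanged while the VarBuilder is alive.
    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(&["model.safetensors"], DType::F32, &device)?
    };
    let _w = vb.get((768, 768), "encoder.layer.0.weight")?;
    Ok(())
}
```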
bcb0ed8f1c
Self-contained safetensors for the multiprocess llama example. ( #950 )
2023-09-24 06:54:49 +01:00
e32c89d90c
Add the buffered safetensor wrapper. ( #948 )
2023-09-23 22:57:42 +01:00
890d069092
Self-contained safetensor wrappers ( #946 )
* Self-contained safetensor wrappers.
* Use the new safetensor container in varbuilders.
2023-09-23 20:39:52 +01:00
ccf352f3d1
Use yoke to provide a self-referential container for mmaped safetensor files ( #939 )
* Use yoke to provide a self-referential container for mmaped safetensor files.
* Add the new self-owned type for safetensor files without removing the previous version.
* Add routing.
* Add an initializer for the case of multiple files.
2023-09-23 15:43:11 +01:00
402d207f0f
VarMap setter functions ( #938 )
* Add some setter helper functions for varmap.
* Add more comments.
2023-09-23 10:27:51 +01:00
7b1ddcff47
Add clone to various nn layers. ( #910 )
2023-09-20 11:33:51 +01:00
34f2ecbc3b
Fix the leaky relu. ( #898 )
2023-09-19 18:17:17 +01:00
06cc329e71
Remove the parameters for the Wuerstchen layer-norm. ( #879 )
* Remove the parameters for the Wuerstchen layer-norm.
* Fixes.
* More fixes (including conv-transpose2d).
* More fixes.
* Again more fixes.
2023-09-17 15:59:27 +01:00
30be5b6660
Replication pad ( #861 )
* Add the embed mapper convolutions.
* Add the replication pad layer.
* Use the replication-pad op.
* Tweak a todo.
2023-09-15 14:06:21 +01:00
2746f2c4be
DiffNeXt/unet ( #859 )
* DiffNeXt/unet
* Start adding the vae.
* VAE residual block.
* VAE forward pass.
* Add pixel shuffling.
* Actually use pixel shuffling.
2023-09-15 10:14:02 +01:00
0633c85514
Add leaky-relu in the activation enum. ( #858 )
2023-09-15 07:05:38 +01:00
130fe5a087
Add the upblocks. ( #853 )
2023-09-14 22:24:56 +01:00