Commit Graph

114 Commits

4349ff1fc2 Starting to fix some tests.
Few fixes.

Going back on remote metal-rs.

Reusing a single buffer (for now) to speed things up.

Adding some half kernels.

Make all tests panic instead of failing randomly.

Putting back f16 index select.

Add erf.

Working version for llama2-c.

Fixes + cache compute_pipeline_state.

BF16 metal fix.

Remove some prints.

new_owned -> new().to_owned().

Better batched matmul.

Metal operational.

Reuse buffers on our own reference counts.

Tmp gemm.

Revert "Tmp gemm."

This reverts commit c65f68e988.

Interleave committing.

Speeding up copies using blit.

Fmt.

Fmt.

Remove the assert!

Fmt all.

Fixes after big rebase.

Add softmax for half and bfloat + tests

Fixing Llama example + accumulate softmax in float.
2023-11-30 11:30:31 +01:00
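The tail of the squash above adds f16/bf16 softmax kernels and accumulates the softmax in float. A sketch of that accumulation trick on the host side (this is not the Metal kernel itself; `softmax_last_dim` is candle-nn's fused entry point):

```rust
use candle::{DType, Result, Tensor};

// Up-cast to f32 so the exp/sum accumulate in full precision, then cast back
// to the original half/bfloat dtype.
fn softmax_f32_acc(xs: &Tensor) -> Result<Tensor> {
    let dtype = xs.dtype();
    let ys = candle_nn::ops::softmax_last_dim(&xs.to_dtype(DType::F32)?)?;
    ys.to_dtype(dtype)
}
```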
18d30005c5 Add support for the UL2 model family (#1300)
* Add support for the UL2 model family

* Update docs with UL2

* Create ActivationWithOptionalGating to avoid polluting activations

* Also refactor quantized t5

* Remove useless conversion

* Revert Activation::NewGelu name change

* Remove useless return

* Apply rustfmt and clippy recommendations

* Reuse t5::ActivationWithOptionalGating in quantized version

* (cosmetic change) use a match rather than ifs + avoid early returns.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2023-11-09 18:55:09 +01:00
e6697471bb Add weight and bias functions to LayerNorm (#1306) 2023-11-09 16:09:01 +01:00
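A usage sketch for the accessors added in #1306 (the construction values are illustrative):

```rust
use candle::{DType, Device, Result, Tensor};
use candle_nn::LayerNorm;

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let ln = LayerNorm::new(
        Tensor::ones(4, DType::F32, &dev)?,  // gamma
        Tensor::zeros(4, DType::F32, &dev)?, // beta
        1e-5,
    );
    // weight()/bias() read the learned parameters back, e.g. for inspection
    // or manual serialization.
    println!("gamma shape: {:?}", ln.weight().shape());
    Ok(())
}
```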
3b0d1e7d03 Transposed conv1d in candle-nn. (#1252) 2023-11-03 11:18:25 +01:00
a2a20aeecc Add the swiglu activation from the chatglm PR. (#1246) 2023-11-02 20:01:34 +01:00
d39d0c40fd Add hard-sigmoid and hard-swish activations (#1244)
* Add hard-sigmoid and hard-swish activations

* Update ops.rs

* Use / rather than div.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2023-11-02 18:20:27 +01:00
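The two activations from #1244 follow the usual MobileNetV3 definitions. A minimal sketch, assuming candle's scalar ops and `clamp`:

```rust
use candle::{Result, Tensor};

// hard_sigmoid(x) = clamp(x/6 + 1/2, 0, 1), i.e. relu6(x + 3) / 6.
fn hard_sigmoid(xs: &Tensor) -> Result<Tensor> {
    ((xs / 6.0)? + 0.5)?.clamp(0.0, 1.0)
}

// hard_swish(x) = x * hard_sigmoid(x).
fn hard_swish(xs: &Tensor) -> Result<Tensor> {
    xs.mul(&hard_sigmoid(xs)?)
}
```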
392a00a147 Add support for the marian base model. (#1221) 2023-10-30 19:20:36 +00:00
55bc3382cf Allow for different behavior between training and eval (#1213)
* Forward with training.

* Do not use dropout on vgg evaluation.
2023-10-29 07:53:09 +01:00
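The training/eval split from #1213 threads a `train` flag through the forward pass. A sketch assuming the `ModuleT` trait shape this PR introduced:

```rust
use candle::{Result, Tensor};
use candle_nn::{Dropout, ModuleT};

struct Head {
    dropout: Dropout,
}

impl ModuleT for Head {
    fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor> {
        // Dropout only fires when train == true; at evaluation time it is the
        // identity, which is the vgg fix in this PR.
        self.dropout.forward_t(xs, train)
    }
}
```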
c8face3f95 Add the relu2 and relu6 activations. (#1201) 2023-10-27 20:51:16 +01:00
b3181455d5 Add fuse-conv-bn method for Conv2d (#1196)
* Add fuse-conv-bn method for Conv2d

* no unwrap

* run rustfmt and clippy
2023-10-27 15:56:50 +01:00
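The fusion behind #1196 is the standard one: fold the batch-norm statistics into the conv, scaling each output channel by gamma / sqrt(var + eps). A sketch with hypothetical names, not the merged method:

```rust
use candle::{Result, Tensor};

// w: (c_out, c_in, kh, kw); b and the batch-norm parameters: (c_out,).
fn fuse_conv_bn(
    w: &Tensor,
    b: &Tensor,
    gamma: &Tensor,
    beta: &Tensor,
    mean: &Tensor,
    var: &Tensor,
    eps: f64,
) -> Result<(Tensor, Tensor)> {
    let scale = gamma.div(&(var + eps)?.sqrt()?)?; // gamma / sqrt(var + eps)
    let (c_out, _, _, _) = w.dims4()?;
    let w_fused = w.broadcast_mul(&scale.reshape((c_out, 1, 1, 1))?)?;
    let b_fused = (b - mean)?.mul(&scale)?.add(beta)?;
    Ok((w_fused, b_fused))
}
```

The fused conv computes (W·s)*x + ((b − mean)·s + beta), which equals batch norm applied to the original conv output, so inference skips the normalization entirely.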
0acd16751d Expose the fields from batch-norm. (#1176) 2023-10-25 15:35:32 +01:00
86e1803191 Add Binary Cross Entropy With Logit Loss to nn crate (#1157)
* add bce with logit loss

* add bce with logit loss

* remove imports

* fix tiny bug

* add test documentation and refactor function

* fix test cases and formatting
2023-10-23 17:12:44 +01:00
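The loss added in #1157 has a standard numerically stable form: mean(max(x, 0) − x·y + ln(1 + e^(−|x|))). A sketch of that identity (not necessarily the exact merged code):

```rust
use candle::{Result, Tensor};

fn bce_with_logits(logits: &Tensor, targets: &Tensor) -> Result<Tensor> {
    let max_part = logits.relu()?;                               // max(x, 0)
    let xy = logits.mul(targets)?;                               // x * y
    let log_part = (logits.abs()?.neg()?.exp()? + 1.0)?.log()?;  // ln(1 + e^{-|x|})
    (max_part - xy)?.add(&log_part)?.mean_all()
}
```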
7366aeac21 Make func cloneable. (#1137) 2023-10-20 16:28:50 +01:00
99cf13e8e2 Add the sequential layer. (#1136) 2023-10-20 16:08:50 +01:00
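A hypothetical use of the sequential container from #1136, chaining ad-hoc closures into a single `Module`:

```rust
use candle::{Result, Tensor};
use candle_nn::{seq, Module, Sequential};

fn build() -> Sequential {
    seq()
        .add_fn(|xs: &Tensor| xs.relu())
        .add_fn(|xs: &Tensor| xs * 2.0)
}

fn run(xs: &Tensor) -> Result<Tensor> {
    build().forward(xs)
}
```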
8e773cc0c6 Experiment with resnet (#1128)
* Add some preliminary support for resnet.

* Add an actual resnet example.
2023-10-19 09:25:03 +01:00
122da87580 feat: add pth varbuilder (#1108) 2023-10-16 16:20:36 +01:00
9fea56d28e Only optimize float tensors. (#1069) 2023-10-10 09:05:41 +01:00
a4967600d0 More general seq forward functions for RNNs. (#1050) 2023-10-07 15:08:01 +01:00
f0c619a4af Use AsRef<str> for set_one. (#1033) 2023-10-05 06:05:44 +01:00
096dee7073 Bump the version to 0.3.0. (#1014)
* Bump the version to 0.3.0.

* Changelog update.
2023-10-01 13:51:57 +01:00
53510ce427 Use a silu activation in mistral. (#991) 2023-09-29 07:06:54 +01:00
ce0a4e3a85 Use the gelu-erf activation. (#969) 2023-09-26 22:30:21 +01:00
c798184c2b Configurable layer idx for the lstm layer. (#962) 2023-09-25 21:31:14 +01:00
4aeb449017 Deprecate the VarBuilder::from_safetensors function. (#951) 2023-09-24 11:18:17 +01:00
bcb0ed8f1c Self-contained safetensors for the multiprocess llama example. (#950) 2023-09-24 06:54:49 +01:00
e32c89d90c Add the buffered safetensor wrapper. (#948) 2023-09-23 22:57:42 +01:00
890d069092 Self-contained safetensor wrappers (#946)
* Self-contained safetensor wrappers.

* Use the new safetensor container in varbuilders.
2023-09-23 20:39:52 +01:00
ccf352f3d1 Use yoke to provide a self-referential container for mmaped safetensor files. (#939)
* Use yoke to provide a self-referential container for mmaped safetensor files.

* Add the new self-owned type for safetensor files without removing the previous version.

* Add routing.

* Add an initializer for the case of multiple files.
2023-09-23 15:43:11 +01:00
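The yoke pattern in miniature, with simplified types (this is not the candle code): the cart owns the bytes, here an `Arc<[u8]>` standing in for the mmapped file, and the yokeable is a view borrowing from them; `Yoke` packages the pair so it can be moved and stored without a self-referential struct.

```rust
use std::sync::Arc;
use yoke::Yoke;

fn header_view(bytes: Arc<[u8]>) -> Yoke<&'static [u8], Arc<[u8]>> {
    // The closure borrows from the cart; Yoke ties the two lifetimes together.
    Yoke::attach_to_cart(bytes, |data| &data[..8.min(data.len())])
}
```

Callers re-borrow the view at their own lifetime via `Yoke::get`, which is what lets a parsed safetensors view travel inside a plain owned value.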
402d207f0f VarMap setter functions (#938)
* Add some setter helper functions for varmap.

* Add more comments.
2023-09-23 10:27:51 +01:00
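A hypothetical use of the setter helpers from #938 (the `set_one` name also appears in #1033 above; the exact signature is assumed):

```rust
use candle::{Result, Tensor};
use candle_nn::VarMap;

// Overwrite a named variable in place, e.g. to inject externally computed
// weights before training resumes.
fn overwrite(varmap: &mut VarMap, value: Tensor) -> Result<()> {
    varmap.set_one("layer1.weight", value)
}
```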
7b1ddcff47 Add clone to various nn layers. (#910) 2023-09-20 11:33:51 +01:00
34f2ecbc3b Fix the leaky relu. (#898) 2023-09-19 18:17:17 +01:00
06cc329e71 Remove the parameters for the Wuerstchen layer-norm. (#879)
* Remove the parameters for the Wuerstchen layer-norm.

* Fixes.

* More fixes (including conv-transpose2d).

* More fixes.

* Again more fixes.
2023-09-17 15:59:27 +01:00
30be5b6660 Replication pad (#861)
* Add the embed mapper convolutions.

* Add the replication pad layer.

* Use the replication-pad op.

* Tweak a todo.
2023-09-15 14:06:21 +01:00
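What the replication-pad op does, as a standalone sketch built from `narrow` and `cat` (the merged op is presumably a dedicated kernel):

```rust
use candle::{Result, Tensor};

// Extend `dim` by repeating the edge slices `left` and `right` times.
fn replication_pad(xs: &Tensor, dim: usize, left: usize, right: usize) -> Result<Tensor> {
    let len = xs.dim(dim)?;
    let mut parts = Vec::with_capacity(left + right + 1);
    for _ in 0..left {
        parts.push(xs.narrow(dim, 0, 1)?);
    }
    parts.push(xs.clone());
    for _ in 0..right {
        parts.push(xs.narrow(dim, len - 1, 1)?);
    }
    Tensor::cat(&parts, dim)
}
```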
2746f2c4be DiffNeXt/unet (#859)
* DiffNeXt/unet

* Start adding the vae.

* VAE residual block.

* VAE forward pass.

* Add pixel shuffling.

* Actually use pixel shuffling.
2023-09-15 10:14:02 +01:00
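Pixel shuffling (depth-to-space) rearranges (b, c·r², h, w) into (b, c, h·r, w·r). A sketch of the standard reshape/permute formulation; the function name is illustrative:

```rust
use candle::{Result, Tensor};

fn pixel_shuffle(xs: &Tensor, r: usize) -> Result<Tensor> {
    let (b, c, h, w) = xs.dims4()?;
    let c_out = c / (r * r);
    xs.reshape((b, c_out, r, r, h, w))?
        .permute((0, 1, 4, 2, 5, 3))? // (b, c_out, h, r, w, r)
        .reshape((b, c_out, h * r, w * r))
}
```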
0633c85514 Add leaky-relu in the activation enum. (#858) 2023-09-15 07:05:38 +01:00
130fe5a087 Add the upblocks. (#853) 2023-09-14 22:24:56 +01:00
49d3f7f708 Add support for flan-t5 (#840) 2023-09-13 19:27:20 +02:00
9daa6dbe87 Extract T5 module and add main function to use it (#829)
* Extract t5 out of musicgen

* Add main for t5 module
2023-09-13 07:14:05 +01:00
59e63d690c Add weight, bias, and hidden_size methods (#816)
* Add weight, bias methods to Conv(1|2)

* Add hidden_size method to Embedding

* Expose hidden_size
2023-09-11 16:01:11 +01:00
b7cd58473b TinyViT backbone for segment-anything. (#787)
* TinyViT.

* More TinyViT.

* Add more to the tinyvit backbone.

* Proper padding.

* Plus ViT.

* Add the tiniest vit spec.
2023-09-09 15:10:06 +01:00
7396b8ed1a Segment Anything - process images (#766)
* Start processing images.

* Add LayerNorm2d.

* Properly use LayerNorm2d.

* Tweak eps.

* Use LayerNorm on inputs with a rank different from 3.

* Window partitioning.

* Fix a couple todos.

* More todos.

* Hard-code the einsums.

* More padding support.

* Some sizes tweaks.

* Use the hub to get the weights.

* Use a batch matmul.

* Tweaks.

* More fixes.

* Get some predictions to be generated.
2023-09-07 19:22:45 +01:00
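LayerNorm2d, as used here, normalizes each spatial position across the channel dimension of an NCHW tensor and then applies a per-channel scale and shift. A sketch of that computation (not the merged code):

```rust
use candle::{Result, Tensor};

fn layer_norm_2d(xs: &Tensor, w: &Tensor, b: &Tensor, eps: f64) -> Result<Tensor> {
    let mean = xs.mean_keepdim(1)?;                 // (b, 1, h, w)
    let xs = xs.broadcast_sub(&mean)?;
    let var = xs.sqr()?.mean_keepdim(1)?;
    let xs = xs.broadcast_div(&(var + eps)?.sqrt()?)?;
    let c = w.dims1()?;
    xs.broadcast_mul(&w.reshape((1, c, 1, 1))?)?
        .broadcast_add(&b.reshape((1, c, 1, 1))?)
}
```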
8c991df394 More segment-anything. (#763)
* More segment-anything.

* Split the model in multiple files.

* Start adding the transformer.

* Add the attention block.

* Move the MLP Block.
2023-09-07 07:28:30 +01:00
000fa00e31 Expose the conv2d-transpose layers. (#761) 2023-09-07 06:04:52 +01:00
a17a7c42c1 Add a nn layer for conv-transpose2d. (#760) 2023-09-07 05:47:28 +01:00
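A hypothetical construction of the layer exposed by #760/#761, here a 2x upsampling stage (the channel counts and prefix are illustrative):

```rust
use candle::Result;
use candle_nn::{conv_transpose2d, ConvTranspose2d, ConvTranspose2dConfig, VarBuilder};

fn upsample(vb: VarBuilder) -> Result<ConvTranspose2d> {
    let cfg = ConvTranspose2dConfig {
        stride: 2,
        ..Default::default()
    };
    conv_transpose2d(64, 32, 4, cfg, vb.pp("upsample"))
}
```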
bdc9d46fe3 Use an arc in the varbuilder rather than rc. (#757)
* Use an arc in the varbuilder rather than rc.

* Require the backends to be send.

* Request send and sync.
2023-09-06 15:29:09 +01:00
a0d65585db Softmax implementation for cuda. (#747) 2023-09-05 18:38:03 +01:00
6615daf242 Tweaks to softmax. (#745) 2023-09-05 15:22:27 +01:00
1c9e5394a5 Add a custom softmax implementation. (#744)
* Add a custom softmax implementation.

* Add softmaxlastdim to the benchmarks.

* And add a test.

* Support more dtypes.

* Polish the code.

* Use the slow implementation on cuda.

* Add a todo for the cuda kernel.
2023-09-05 14:20:23 +01:00
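The reference computation behind the custom softmax is the numerically stable softmax over the last dimension: subtract the row max before exponentiating so the intermediate exp never overflows. A sketch:

```rust
use candle::{D, Result, Tensor};

fn softmax_last_dim(xs: &Tensor) -> Result<Tensor> {
    let max = xs.max_keepdim(D::Minus1)?;
    let num = xs.broadcast_sub(&max)?.exp()?;
    let den = num.sum_keepdim(D::Minus1)?;
    num.broadcast_div(&den)
}
```

Fusing these steps into one kernel avoids materializing the intermediates, which is the usual motivation for a custom implementation.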
4698eb5cb6 Fix a typo in the nll function documentation (#742) 2023-09-05 09:25:11 +01:00
26cd266e65 Musicgen text embeddings. (#726)
* Musicgen text embeddings.

* Bugfix for layer norm.

* Proper position bias.

* Expose the weights.
2023-09-03 18:27:48 +01:00