Commit Graph

128 Commits

Author SHA1 Message Date
8e773cc0c6 Experiment with resnet (#1128)
* Add some preliminary support for resnet.

* Add an actual resnet example.
2023-10-19 09:25:03 +01:00
122da87580 feat: add pth varbuilder (#1108) 2023-10-16 16:20:36 +01:00
c096f02411 Add a matvec cpu benchmark. (#1076) 2023-10-12 09:29:18 +01:00
89b525b5e7 Convmixer (#1073)
* Only optimize float tensors.

* Use full tensors for zeros and ones.

* Add a benchmark for the matmul slowness.

* Add the convmixer model.

* Proper adaptive pooling.
2023-10-11 18:24:32 +01:00
9fea56d28e Only optimize float tensors. (#1069) 2023-10-10 09:05:41 +01:00
a4967600d0 More general seq forward functions for RNNs. (#1050) 2023-10-07 15:08:01 +01:00
f0c619a4af Use AsRef<str> for set_one. (#1033) 2023-10-05 06:05:44 +01:00
089fc3b584 Improve the quantized whisper setup. (#1018)
* Improve the quantized whisper setup.

* Fix the config file paths.

* Use the standard matmul where possible.
2023-10-02 17:17:46 +01:00
096dee7073 Bump the version to 0.3.0. (#1014)
* Bump the version to 0.3.0.

* Changelog update.
2023-10-01 13:51:57 +01:00
53510ce427 Use a silu activation in mistral. (#991) 2023-09-29 07:06:54 +01:00
ce0a4e3a85 Use the gelu-erf activation. (#969) 2023-09-26 22:30:21 +01:00
c798184c2b Configurable layer idx for the lstm layer. (#962) 2023-09-25 21:31:14 +01:00
4aeb449017 Depreate the VarBuilder::from_safetensors function. (#951) 2023-09-24 11:18:17 +01:00
bcb0ed8f1c Self-contained safetensors for the multiprocess llama example. (#950) 2023-09-24 06:54:49 +01:00
e32c89d90c Add the buffered safetensor wrapper. (#948) 2023-09-23 22:57:42 +01:00
890d069092 Self-contained safetensor wrappers (#946)
* Self-contained safetensor wrappers.

* Use the new safetensor container in varbuilders.
2023-09-23 20:39:52 +01:00
ccf352f3d1 Use yoke to provide a self-referential container for mmaped safetenso… (#939)
* Use yoke to provide a self-referential container for mmaped safetensor files.

* Add the new self-owned type for safetensor files without removing the previous version.

* Add routing.

* Add an initializer for the case of multiple files.
2023-09-23 15:43:11 +01:00
402d207f0f VarMap setter functions (#938)
* Add some setter helper functions for varmap.

* Add more comments.
2023-09-23 10:27:51 +01:00
7b1ddcff47 Add clone to various nn layers. (#910) 2023-09-20 11:33:51 +01:00
34f2ecbc3b Fix the leaky relu. (#898) 2023-09-19 18:17:17 +01:00
7dd8e12472 Bump the crate versions to v0.2.3. (#886)
* Bump the crate version.

* Also update the python bindings.
2023-09-18 12:14:03 +01:00
06cc329e71 Remove the parameters for the Wuerstchen layer-norm. (#879)
* Remove the parameters for the Wuerstchen layer-norm.

* Fixes.

* More fixes (including conv-transpose2d.

* More fixes.

* Again more fixes.
2023-09-17 15:59:27 +01:00
30be5b6660 Replication pad (#861)
* Add the embed mapper convolutions.

* Add the replication pad layer.

* Use the replication-pad op.

* Tweak a todo.
2023-09-15 14:06:21 +01:00
2746f2c4be DiffNeXt/unet (#859)
* DiffNeXt/unet

* Start adding the vae.

* VAE residual block.

* VAE forward pass.

* Add pixel shuffling.

* Actually use pixel shuffling.
2023-09-15 10:14:02 +01:00
0633c85514 Add leaky-relu in the activation enum. (#858) 2023-09-15 07:05:38 +01:00
130fe5a087 Add the upblocks. (#853) 2023-09-14 22:24:56 +01:00
49d3f7f708 Add support to flan-t5 (#840) 2023-09-13 19:27:20 +02:00
9daa6dbe87 Extract T5 module and add main function to use it (#829)
* Extract t5 out of musicgen

* Add main for t5 module
2023-09-13 07:14:05 +01:00
2257f4d475 Bump the crate version + update the changelog. (#822) 2023-09-12 06:39:24 +01:00
871efc0307 Bugfix for the conv2d cpu kernel. (#820) 2023-09-11 23:11:27 +01:00
59e63d690c Add weight, bias, and hidden_size methods (#816)
* Add weight, bias methods to Conv(1|2)

* Add hidden_size method to Embedding

* Expose hidden_size
2023-09-11 16:01:11 +01:00
98d1242b8f im2col based conv2d (#802)
* im2col implementation for conv2d.

* Fix for the im2col implementation to match the current conv2d.

* Small optimization.

* Add a cuda kernel.

* Handle arbitrary layouts.

* Im2Col cuda code.
2023-09-10 21:02:42 +01:00
4f18180fc7 Bugfix so that im2col produce the same results as conv2d. (#801) 2023-09-10 16:59:46 +01:00
559944146f Add an im2col based benchmark. (#800)
* Add an im2col based benchmark.

* Reshape the final result.
2023-09-10 16:56:28 +01:00
b7cd58473b TinyViT backbone for segment-anything. (#787)
* TinyViT.

* More TinyViT.

* Add more to the tinyvit backbone.

* Proper padding.

* Plus ViT.

* Add the tiniest vit spec.
2023-09-09 15:10:06 +01:00
7396b8ed1a Segment Anything - process images (#766)
* Start processing images.

* Add LayerNorm2d.

* Properly use LayerNorm2d.

* Tweak eps.

* Use LayerNorm on inputs with a rank different from 3.

* Window partitioning.

* Fix a couple todos.

* More todos.

* Hard-code the einsums.

* More padding support.

* Some sizes tweaks.

* Use the hub to get the weights.

* Use a batch matmul.

* Tweaks.

* More fixes.

* Get some predictions to be generated.
2023-09-07 19:22:45 +01:00
8c991df394 More segment-anything. (#763)
* More segment-anything.

* Split the model in multiple files.

* Start adding the transformer.

* Add the attention block.

* Move the MLP Block.
2023-09-07 07:28:30 +01:00
000fa00e31 Expose the conv2d-transpose layers. (#761) 2023-09-07 06:04:52 +01:00
a17a7c42c1 Add a nn layer for conv-transpose2d. (#760) 2023-09-07 05:47:28 +01:00
bdc9d46fe3 Use an arc in the varbuilder rather than rc. (#757)
* Use an arc in the varbuilder rather than rc.

* Require the backends to be send.

* Request send and sync.
2023-09-06 15:29:09 +01:00
a0d65585db Softmax implementation for cuda. (#747) 2023-09-05 18:38:03 +01:00
6615daf242 Tweaks to softmax. (#745) 2023-09-05 15:22:27 +01:00
1c9e5394a5 Add a custom softmax implementation. (#744)
* Add a custom softmax implementation.

* Add softmaxlastdim to the benchmarks.

* And add a test.

* Support more dtypes.

* Polish the code.

* Use the slow implementation on cuda.

* Add a todo for the cuda kernel.
2023-09-05 14:20:23 +01:00
4698eb5cb6 Fix typo in the nll function document (#742) 2023-09-05 09:25:11 +01:00
e2f9f60ac2 Avoid some redundant clone. (#731) 2023-09-04 09:18:32 +02:00
26cd266e65 Musicgen text embeddings. (#726)
* Musicgen text embeddings.

* Bugfix for layer norm.

* Proper position bias.

* Expose the weights.
2023-09-03 18:27:48 +01:00
74a82c358a Add the mse loss. (#723) 2023-09-03 10:51:40 +01:00
af552a5274 Fix the rnn tests for accelerate. (#704) 2023-09-01 13:21:38 +01:00
7529531056 Add the optimizer trait. (#702) 2023-09-01 12:55:39 +01:00
f9f482d4e5 Add some doc to the varbuilder. (#700) 2023-09-01 08:28:35 +01:00