21e1c73892
Add an LSTM test. ( #681 )
...
* Add an LSTM test.
* Clippy.
2023-08-30 20:05:42 +02:00
2047d34b7c
More robust tests (so that they pass on accelerate). ( #679 )
2023-08-30 18:10:10 +01:00
3159982a89
Add a Dropout layer ( #676 )
...
* Add a dropout layer.
* Add an actual layer.
2023-08-30 16:19:28 +01:00
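A minimal sketch of the new layer in use (assuming the Dropout::new(p) constructor and a forward that takes a train flag):

```rust
use candle::{DType, Device, Result, Tensor};
use candle_nn::Dropout;

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let x = Tensor::ones((2, 4), DType::F32, &dev)?;
    let dropout = Dropout::new(0.5);
    // In training mode roughly half the activations are zeroed and the
    // survivors are scaled by 1/(1-p) to keep the expectation unchanged.
    let y_train = dropout.forward(&x, true)?;
    // In eval mode the input passes through untouched.
    let y_eval = dropout.forward(&x, false)?;
    println!("{y_train}\n{y_eval}");
    Ok(())
}
```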
ad8a62dbf5
Add tanh. ( #675 )
...
* Add tanh.
* Use tanh in the lstm block.
* Add a test for tanh forward and backward passes.
2023-08-30 13:54:50 +01:00
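A short sketch of the op in both directions, assuming candle's Var/backward autodiff API:

```rust
use candle::{Device, Result, Var};

fn main() -> Result<()> {
    let x = Var::new(&[0.0f32, 1.0, -1.0], &Device::Cpu)?;
    let y = x.tanh()?;
    // d tanh(x)/dx = 1 - tanh(x)^2, so the gradient of sum(tanh(x))
    // should be 1 at x = 0 and about 0.42 at x = +/-1.
    let grads = y.sum_all()?.backward()?;
    println!("{y}\n{}", grads.get(&x).expect("missing gradient"));
    Ok(())
}
```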
f35b9f6baa
Add some recurrent neural networks ( #674 )
...
* Add the rnn module.
* More LSTM.
* Implement the RNN forward pass.
* More forward pass for LSTM.
2023-08-30 13:27:09 +01:00
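A sketch of driving the new block, assuming the lstm builder, LSTMConfig, and the RNN trait's seq method:

```rust
use candle::{DType, Device, Result, Tensor};
use candle_nn::{lstm, LSTMConfig, VarBuilder, RNN};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Zero weights are enough to exercise the shapes.
    let vb = VarBuilder::zeros(DType::F32, &dev);
    let lstm = lstm(4, 8, LSTMConfig::default(), vb)?;
    // One sequence of three timesteps with four features each.
    let input = Tensor::zeros((1, 3, 4), DType::F32, &dev)?;
    let states = lstm.seq(&input)?;
    let last = states.last().expect("empty sequence");
    println!("h: {:?}, c: {:?}", last.h().shape(), last.c().shape());
    Ok(())
}
```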
2d3fcad267
Simplify usage of the pool functions. ( #662 )
...
* Simplify usage of the pool functions.
* Small tweak.
* Attempt at using apply to simplify the convnet definition.
2023-08-29 19:12:16 +01:00
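After this change the pool helpers take a single size and default the stride to it; a sketch:

```rust
use candle::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // NCHW input: one image, one channel, 4x4.
    let x = Tensor::arange(0f32, 16., &dev)?.reshape((1, 1, 4, 4))?;
    // Kernel size 2, stride defaulting to the kernel size.
    let max = x.max_pool2d(2)?; // (1, 1, 2, 2)
    let avg = x.avg_pool2d(2)?; // (1, 1, 2, 2)
    // An explicit stride is still available when needed.
    let dense = x.max_pool2d_with_stride(2, 1)?; // (1, 1, 3, 3)
    println!("{:?} {:?} {:?}", max.shape(), avg.shape(), dense.shape());
    Ok(())
}
```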
a044907ffc
Dilated convolutions ( #657 )
...
* Add the dilation parameter.
* Restore the basic optimizer example.
* Dilation support in cudnn.
* Use the dilation parameter in the cpu backend.
* More dilation support.
* No support for dilation in transposed convolutions.
* Add dilation to a test.
* Remove a print.
* Helper function.
2023-08-29 16:12:11 +01:00
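A sketch of the dilation parameter through Conv2dConfig (field names as they ended up in candle_nn):

```rust
use candle::{DType, Device, Result, Tensor};
use candle_nn::{conv2d, Conv2dConfig, Module, VarBuilder};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let vb = VarBuilder::zeros(DType::F32, &dev);
    let cfg = Conv2dConfig {
        padding: 2,
        dilation: 2, // a 3x3 kernel now sees a 5x5 receptive field
        ..Default::default()
    };
    let conv = conv2d(3, 8, 3, cfg, vb)?;
    let x = Tensor::zeros((1, 3, 32, 32), DType::F32, &dev)?;
    // With padding 2 the spatial size is preserved: (1, 8, 32, 32).
    println!("{:?}", conv.forward(&x)?.shape());
    Ok(())
}
```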
33c23c19b6
Preliminary support for SDXL. ( #647 )
...
* Preliminary support for SDXL.
* More SDXL support.
* More SDXL.
* Use the proper clip config.
* Querying for existing tensors.
* More robust test.
2023-08-29 09:00:04 +01:00
4c338b0cd9
VarBuilder cleanup ( #627 )
...
* VarBuilder cleanup.
* Implement the basic varbuilders.
* Add the sharded code.
* Proper support for tensor sharding.
2023-08-27 18:03:26 +01:00
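A sketch of the cleaned-up builder, assuming a VarMap-backed VarBuilder with pp prefixes and on-demand get:

```rust
use candle::{DType, Device, Result};
use candle_nn::{VarBuilder, VarMap};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // A VarMap-backed builder creates and records variables on demand.
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
    // pp pushes a prefix: this variable ends up named "encoder.proj.weight".
    let w = vb.pp("encoder").pp("proj").get((8, 4), "weight")?;
    println!("{:?}, {} vars tracked", w.shape(), varmap.all_vars().len());
    Ok(())
}
```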
431051cc32
Add Efficientnet ( #572 )
...
* EfficientNet.
* Complete the efficientnet implementation.
* Improve group handling.
* Get the efficientnet to work.
2023-08-23 18:02:58 +01:00
aba1e90797
Add a groups parameter to convolutions. ( #566 )
...
* Add a groups parameter to convolutions.
* Avoid some unnecessary groups checks.
* Move the tensor convolution bits.
* Proper handling of groups.
* Bump the crate version.
* And add a changelog.
2023-08-23 12:58:55 +01:00
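A sketch of the groups parameter; setting groups equal to the channel count gives a depthwise convolution:

```rust
use candle::{DType, Device, Result, Tensor};
use candle_nn::{conv2d, Conv2dConfig, Module, VarBuilder};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let vb = VarBuilder::zeros(DType::F32, &dev);
    // Each of the 16 input channels is convolved with its own filters.
    let cfg = Conv2dConfig { padding: 1, groups: 16, ..Default::default() };
    let conv = conv2d(16, 16, 3, cfg, vb)?;
    let x = Tensor::zeros((1, 16, 8, 8), DType::F32, &dev)?;
    println!("{:?}", conv.forward(&x)?.shape()); // (1, 16, 8, 8)
    Ok(())
}
```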
11c7e7bd67
Some fixes for yolo-v3. ( #529 )
...
* Some fixes for yolo-v3.
* Use the running stats for inference in the batch-norm layer.
* Get some proper predictions for yolo.
* Avoid the quadratic insertion.
2023-08-20 23:19:15 +01:00
e3d2786ffb
Add a couple of functions required for yolo. ( #527 )
2023-08-20 17:02:05 +01:00
d2622a8160
Move the VarMap to a separate file ( #525 )
...
* Move the var-map struct to a separate file.
* Fix some typos.
2023-08-20 14:25:07 +01:00
42e1cc8062
Add a batch normalization layer ( #508 )
...
* Add BatchNormalization.
* More batch-norm.
* Add some validation of the inputs.
* More validation.
2023-08-18 20:05:56 +01:00
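A sketch of the layer, assuming the batch_norm builder and a dedicated training-mode forward (names as in the eventual candle_nn API):

```rust
use candle::{DType, Device, Result, Tensor};
use candle_nn::{batch_norm, BatchNormConfig, VarBuilder};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let vb = VarBuilder::zeros(DType::F32, &dev);
    let bn = batch_norm(4, BatchNormConfig::default(), vb)?;
    // NCHW input with 4 channels; statistics are computed per channel.
    let x = Tensor::randn(0f32, 1., (8, 4, 5, 5), &dev)?;
    // Training mode normalizes with batch statistics; inference should
    // use the running estimates instead (see the later yolo-v3 fix).
    let y = bn.forward_train(&x)?;
    println!("{:?}", y.shape());
    Ok(())
}
```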
c78ce76501
Add a simple Module trait and implement it for the various nn layers ( #500 )
...
* Start adding the module trait.
* Use the module trait.
* Implement module for qmatmul.
2023-08-18 09:38:22 +01:00
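The trait boils down to a single forward method; a sketch of implementing it for a toy layer:

```rust
use candle::{DType, Device, Result, Tensor};
use candle_nn::Module;

// A toy layer that multiplies its input by a fixed scale.
struct Scale(f64);

impl Module for Scale {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        xs.affine(self.0, 0.)
    }
}

fn main() -> Result<()> {
    let x = Tensor::ones((2, 2), DType::F32, &Device::Cpu)?;
    println!("{}", Scale(3.0).forward(&x)?);
    Ok(())
}
```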
13401df4d1
Add an abstract type for RmsNorm. ( #499 )
2023-08-18 08:52:14 +01:00
d32e8199cd
Layer norm tweaks ( #482 )
...
* Add some options to make layer-norm more configurable.
* Add the rms-norm variant.
* Replace the RmsNorm with the shared bits.
2023-08-17 10:07:13 +01:00
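A sketch of the two variants after the tweaks, assuming the layer_norm and rms_norm builders:

```rust
use candle::{DType, Device, Result, Tensor};
use candle_nn::{layer_norm, rms_norm, LayerNormConfig, Module, VarBuilder};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let vb = VarBuilder::zeros(DType::F32, &dev);
    // Classic layer-norm: remove the mean, scale by the std, then affine.
    let ln = layer_norm(16, LayerNormConfig::default(), vb.pp("ln"))?;
    // The rms-norm variant skips the mean removal.
    let rms = rms_norm(16, 1e-5, vb.pp("rms"))?;
    let x = Tensor::randn(0f32, 1., (2, 16), &dev)?;
    println!("{:?} {:?}", ln.forward(&x)?.shape(), rms.forward(&x)?.shape());
    Ok(())
}
```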
55e428c8ae
Expose the varmap inner data. ( #411 )
2023-08-11 16:58:56 +01:00
89d3926c9b
Fixes for the stable diffusion example. ( #342 )
...
* Fixes for the stable diffusion example.
* Bugfix.
* Another fix.
* Fix for group-norm.
* More fixes to get SD to work.
2023-08-08 14:57:09 +01:00
2345b8ce3f
Skeleton for the avg-pool2d and upsample-nearest2d ops. ( #337 )
...
* Skeleton for the avg-pool2d and upsample-nearest2d ops.
* Preliminary conv2d support.
2023-08-07 16:15:38 +01:00
5bb2fce998
Implement group-norm. ( #334 )
...
* Implement group-norm.
* Add some testing for group-norm.
2023-08-07 06:53:05 +01:00
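A sketch of the new layer, assuming a group_norm(groups, channels, eps, vb) builder:

```rust
use candle::{DType, Device, Result, Tensor};
use candle_nn::{group_norm, Module, VarBuilder};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let vb = VarBuilder::zeros(DType::F32, &dev);
    // 4 groups over 16 channels: mean/variance are computed per group.
    let gn = group_norm(4, 16, 1e-5, vb)?;
    let x = Tensor::randn(0f32, 1., (1, 16, 8, 8), &dev)?;
    println!("{:?}", gn.forward(&x)?.shape());
    Ok(())
}
```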
d34039e352
Add a stable diffusion example ( #328 )
...
* Start adding a stable-diffusion example.
* Proper computation of the causal mask.
* Add the chunk operation.
* Work in progress: port the attention module.
* Add some dummy modules for conv2d and group-norm, get the attention module to compile.
* Re-enable the 2d convolution.
* Add the embeddings module.
* Add the resnet module.
* Add the unet blocks.
* Add the unet.
* And add the variational auto-encoder.
* Use the pad function from utils.
2023-08-06 17:49:43 +01:00
620f83cf66
Add the candle-datasets crate ( #322 )
...
* Move the vision datasets to a separate crate.
* Move the batcher bits.
* Update the readme.
* Move the tiny-stories bits.
---------
Co-authored-by: Jane Doe <jane.doe@example.org>
2023-08-05 08:56:50 +01:00
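A sketch of the relocated vision datasets, assuming an mnist::load helper returning a Dataset of train/test tensors:

```rust
fn main() -> candle::Result<()> {
    // Fetches MNIST and returns all four splits as tensors.
    let ds = candle_datasets::vision::mnist::load()?;
    println!(
        "train images: {:?}, test images: {:?}",
        ds.train_images.shape(),
        ds.test_images.shape(),
    );
    Ok(())
}
```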
0902846f25
Add the AdamW optimizer. ( #307 )
...
* Add the AdamW optimizer.
* Add some AdamW test validated against PyTorch.
2023-08-02 14:03:49 +01:00
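A sketch of the optimizer, assuming ParamsAdamW and a backward_step that runs backprop and updates in one call (names as in the crate's eventual Optimizer trait):

```rust
use candle::{DType, Device, Result, Var};
use candle_nn::{AdamW, Optimizer, ParamsAdamW};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let w = Var::zeros((4, 2), DType::F32, &dev)?;
    let params = ParamsAdamW { lr: 1e-3, ..Default::default() };
    let mut opt = AdamW::new(vec![w.clone()], params)?;
    // A dummy scalar loss; backward_step computes the gradients and
    // applies the decoupled weight-decay update to every tracked Var.
    let loss = w.sum_all()?;
    opt.backward_step(&loss)?;
    Ok(())
}
```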
cc76c63202
Use index-select for the embeddings as it supports backprop. ( #298 )
2023-08-01 20:44:43 +01:00
ff876c2103
Llama more training ( #297 )
...
* Rework the var-builder to handle initializations.
* Add some helper functions for layer creation.
* Improve the layer initializations.
* Get initialized variables.
* Precompute the rotary embeddings when training llamas.
2023-08-01 19:53:41 +01:00
614f911e9e
Add some batcher variants that handle errors. ( #294 )
2023-08-01 09:40:34 +01:00
e1e8127f15
Add the batcher. ( #293 )
2023-08-01 09:16:10 +01:00
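A sketch of batching (input, target) pairs, assuming the Batcher's new2 constructor as it ended up in candle-datasets; the new_r2 variant from the follow-up commit accepts Result items instead:

```rust
use candle::{Device, Result, Tensor};
use candle_datasets::Batcher;

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Ten (input, target) pairs, stacked four at a time.
    let pairs = (0..10u32).map(|i| {
        let x = Tensor::full(i as f32, (3,), &dev).unwrap();
        let y = Tensor::new(i, &dev).unwrap();
        (x, y)
    });
    for batch in Batcher::new2(pairs).batch_size(4) {
        let (xs, ys) = batch?;
        println!("{:?} {:?}", xs.shape(), ys.shape());
    }
    Ok(())
}
```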
1064b9b031
Add the cross-entropy loss. ( #287 )
2023-07-31 14:26:36 +01:00
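A sketch of the loss, which takes raw logits and u32 class indices:

```rust
use candle::{Device, Result, Tensor};
use candle_nn::loss;

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Unnormalized logits for two samples over three classes.
    let logits = Tensor::new(&[[2.0f32, 0.5, 0.1], [0.1, 0.2, 3.0]], &dev)?;
    // Target class indices, one per sample.
    let targets = Tensor::new(&[0u32, 2], &dev)?;
    let ce = loss::cross_entropy(&logits, &targets)?;
    println!("{ce}");
    Ok(())
}
```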
ffeafbfc43
Make the nll op closer to the pytorch version + add a test. ( #286 )
2023-07-31 14:14:01 +01:00
16c33383eb
Improve the mnist training example. ( #276 )
...
* Improve the mnist training example.
* Add some initialization routine that can be used for nn.
* Proper initialization in the mnist example.
2023-07-29 16:28:22 +01:00
07eb899729
More mnist training. ( #275 )
2023-07-29 13:29:31 +01:00
3eb2bc6d07
Softmax numerical stability. ( #267 )
...
* Softmax numerical stability.
* Fix the flash-attn test.
2023-07-28 13:13:01 +01:00
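The fix is the usual max-subtraction trick; a standalone sketch of the idea the commit applies inside the softmax op:

```rust
use candle::{D, Device, Result, Tensor};

fn stable_softmax(x: &Tensor) -> Result<Tensor> {
    // exp(x) overflows for large logits; softmax(x) == softmax(x - max(x)),
    // so shifting by the row max keeps the exponentials in range.
    let max = x.max_keepdim(D::Minus1)?;
    let diff = x.broadcast_sub(&max)?;
    let num = diff.exp()?;
    let den = num.sum_keepdim(D::Minus1)?;
    num.broadcast_div(&den)
}

fn main() -> Result<()> {
    let x = Tensor::new(&[[1000.0f32, 1001., 1002.]], &Device::Cpu)?;
    println!("{}", stable_softmax(&x)?);
    Ok(())
}
```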
8435a99edd
Added comment about offsets.
2023-07-27 20:11:57 +02:00
952eca6b54
Fixing slice errors + comments.
2023-07-27 16:59:32 +02:00
7c7e6ba201
Removing inner dependency on safetensors.
2023-07-27 09:58:47 +02:00
1735e4831e
TP sharding v2
2023-07-27 09:58:14 +02:00
1f26042693
Move some shared functions to the nn module. ( #221 )
2023-07-22 13:25:11 +01:00
43c7223292
Rename the .r functions to .dims so as to be a bit more explicit. ( #220 )
2023-07-22 10:39:27 +01:00
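After the rename, shape extraction reads as below (.r2() becoming .dims2(), and so on):

```rust
use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let t = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
    // Formerly t.r2(): destructure the shape into its two dimensions.
    let (rows, cols) = t.dims2()?;
    println!("{rows}x{cols}");
    Ok(())
}
```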
dfd624dbd3
[Proposal] Remove SafeTensor wrapper (allows finer control for users).
2023-07-19 16:25:44 +02:00
2a74019ec6
Vision dataset ( #179 )
...
* Add some readers for the mnist dataset.
* Import the cifar and mnist dataset.
2023-07-16 23:43:55 +01:00
d88b6cdca9
Add backtrace information to errors where relevant. ( #166 )
...
* Add backtrace information to errors where relevant.
* More backtrace information.
* Add to the FAQ.
2023-07-14 09:31:25 +01:00
a2f72edc0d
Simplify the parameters used by sum and sum_keepdim. ( #165 )
2023-07-14 08:22:08 +01:00
2bfa791336
Use the same default as pytorch for sum. ( #164 )
2023-07-13 21:32:32 +01:00
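A sketch of the resulting behavior: sum drops the reduced dims like pytorch, while sum_keepdim keeps them:

```rust
use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let t = Tensor::ones((2, 3), DType::F32, &Device::Cpu)?;
    // Like pytorch, summing over a dim removes it by default...
    let s = t.sum(1)?; // shape (2,)
    // ...while the _keepdim variant keeps it with size 1.
    let k = t.sum_keepdim(1)?; // shape (2, 1)
    println!("{:?} {:?}", s.shape(), k.shape());
    Ok(())
}
```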
23e105cd94
Add the gradient for reduce-sum. ( #162 )
...
* Add the gradient for reduce-sum.
* And add the gradient for the broadcast ops.
* Add some backprop tests.
* Add some linear regression example.
2023-07-13 20:14:10 +01:00
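A sketch of backprop through a reduce-sum, using Var and the gradient store:

```rust
use candle::{Device, Result, Var};

fn main() -> Result<()> {
    let x = Var::new(&[1.0f32, 2., 3.], &Device::Cpu)?;
    // loss = sum(x^2), so dloss/dx = 2x = [2, 4, 6].
    let loss = x.sqr()?.sum_all()?;
    let grads = loss.backward()?;
    println!("{}", grads.get(&x).expect("missing gradient"));
    Ok(())
}
```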
ded93a1169
Add the SGD optimizer ( #160 )
...
* Add the nn::optim and some conversion traits.
* Add the backward_step function for SGD.
* Get the SGD optimizer to work and add a test.
* Make the test slightly simpler.
2023-07-13 19:05:44 +01:00
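A sketch combining the optimizer with the linear-regression example mentioned just above, assuming SGD::new(vars, lr) and backward_step:

```rust
use candle::{DType, Device, Result, Tensor, Var};
use candle_nn::{Optimizer, SGD};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Recover w = 3 in y = w * x from four samples.
    let x = Tensor::new(&[[1.0f32], [2.], [3.], [4.]], &dev)?;
    let y = Tensor::new(&[[3.0f32], [6.], [9.], [12.]], &dev)?;
    let w = Var::zeros((1, 1), DType::F32, &dev)?;
    let mut sgd = SGD::new(vec![w.clone()], 0.01)?;
    for _step in 0..100 {
        let loss = (x.matmul(w.as_tensor())? - &y)?.sqr()?.mean_all()?;
        // Computes the gradients and applies w -= lr * grad.
        sgd.backward_step(&loss)?;
    }
    println!("w = {}", w.as_tensor());
    Ok(())
}
```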
465fc8c0c5
Add some documentation and test to the linear layer. ( #151 )
...
* Add some documentation and test to the linear layer.
* Layer norm doc.
* Minor tweaks.
2023-07-12 20:24:23 +01:00
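For reference, the layer under test, assuming the linear builder with a weight of shape (out_dim, in_dim):

```rust
use candle::{DType, Device, Result, Tensor};
use candle_nn::{linear, Module, VarBuilder};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let vb = VarBuilder::zeros(DType::F32, &dev);
    // y = x W^T + b, with W of shape (out_dim, in_dim).
    let layer = linear(4, 2, vb)?;
    let x = Tensor::ones((3, 4), DType::F32, &dev)?;
    println!("{:?}", layer.forward(&x)?.shape()); // (3, 2)
    Ok(())
}
```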
a76ec797da
Clean up the main crate error and add a couple of dedicated ones ( #142 )
...
* Cosmetic cleanups to the error enum.
* More error cleanup.
* Proper error handling rather than panicking.
* Add some conv1d dedicated error.
2023-07-12 09:17:08 +01:00
fa760759e5
Allow for lazy loading of npz files; use it in llama to reduce memory usage in the cpu version. ( #141 )
2023-07-11 20:22:34 +01:00