98d1242b8f
im2col-based conv2d ( #802 )
* im2col implementation for conv2d.
* Fix for the im2col implementation to match the current conv2d.
* Small optimization.
* Add a cuda kernel.
* Handle arbitrary layouts.
* Im2Col cuda code.
2023-09-10 21:02:42 +01:00
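For context on the approach: im2col turns conv2d into a single matrix multiplication by unrolling every receptive-field patch into a column of a scratch buffer, so the convolution becomes one GEMM between the flattened filters, shaped (out_channels, c*kh*kw), and that buffer. A minimal sketch of the unrolling on plain slices (contiguous CHW layout, stride 1, no padding; the actual kernels above also handle arbitrary layouts):

```rust
// Sketch: unroll (c, kh, kw) patches so conv2d reduces to one matmul.
// Contiguous CHW input, stride 1, no padding; illustrative only.
fn im2col(input: &[f32], c: usize, h: usize, w: usize, kh: usize, kw: usize) -> Vec<f32> {
    let (out_h, out_w) = (h - kh + 1, w - kw + 1);
    // One column of length c*kh*kw per output position.
    let mut cols = vec![0f32; c * kh * kw * out_h * out_w];
    for ci in 0..c {
        for ki in 0..kh {
            for kj in 0..kw {
                let row = (ci * kh + ki) * kw + kj;
                for oi in 0..out_h {
                    for oj in 0..out_w {
                        let src = (ci * h + oi + ki) * w + oj + kj;
                        cols[row * out_h * out_w + oi * out_w + oj] = input[src];
                    }
                }
            }
        }
    }
    cols // (c*kh*kw, out_h*out_w); the GEMM result is reshaped to (out_c, out_h, out_w)
}
```

The final reshape folds the (out_channels, out_h*out_w) GEMM output back into the usual (out_channels, out_h, out_w) activation.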
4f18180fc7
Bugfix so that im2col produces the same results as conv2d. ( #801 )
2023-09-10 16:59:46 +01:00
559944146f
Add an im2col-based benchmark. ( #800 )
* Add an im2col-based benchmark.
* Reshape the final result.
2023-09-10 16:56:28 +01:00
b7cd58473b
TinyViT backbone for segment-anything. ( #787 )
* TinyViT.
* More TinyViT.
* Add more to the tinyvit backbone.
* Proper padding.
* Plus ViT.
* Add the tiniest vit spec.
2023-09-09 15:10:06 +01:00
7396b8ed1a
Segment Anything - process images ( #766 )
* Start processing images.
* Add LayerNorm2d.
* Properly use LayerNorm2d.
* Tweak eps.
* Use LayerNorm on inputs with a rank different from 3.
* Window partitioning.
* Fix a couple todos.
* More todos.
* Hard-code the einsums.
* More padding support.
* Some sizes tweaks.
* Use the hub to get the weights.
* Use a batch matmul.
* Tweaks.
* More fixes.
* Get some predictions to be generated.
2023-09-07 19:22:45 +01:00
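The LayerNorm2d mentioned above normalizes an NCHW activation over its channel dimension, which the stock LayerNorm (built for rank-3 inputs) does not cover. A hedged sketch using the crate's tensor ops (struct and field names are illustrative, and imports assume the candle_core / candle_nn package names):

```rust
use candle_core::{Result, Tensor};

// Sketch: channel-wise layer norm for NCHW tensors, as used in SAM-style
// image encoders. Illustrative; not necessarily the crate's exact version.
struct LayerNorm2d {
    weight: Tensor, // (c,)
    bias: Tensor,   // (c,)
    eps: f64,
}

impl LayerNorm2d {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let c = self.weight.dim(0)?;
        let mean = xs.mean_keepdim(1)?; // (n, 1, h, w)
        let xs = xs.broadcast_sub(&mean)?;
        let var = xs.sqr()?.mean_keepdim(1)?; // biased variance over channels
        let xs = xs.broadcast_div(&(var + self.eps)?.sqrt()?)?;
        xs.broadcast_mul(&self.weight.reshape((1, c, 1, 1))?)?
            .broadcast_add(&self.bias.reshape((1, c, 1, 1))?)
    }
}
```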
8c991df394
More segment-anything. ( #763 )
* More segment-anything.
* Split the model in multiple files.
* Start adding the transformer.
* Add the attention block.
* Move the MLP Block.
2023-09-07 07:28:30 +01:00
000fa00e31
Expose the conv2d-transpose layers. ( #761 )
2023-09-07 06:04:52 +01:00
a17a7c42c1
Add a nn layer for conv-transpose2d. ( #760 )
2023-09-07 05:47:28 +01:00
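As a reminder of what these layers compute, the spatial size of a transposed convolution per dimension, in the usual (PyTorch-style) formulation, is:

```latex
out = (in - 1)\,stride - 2\,padding + dilation\,(kernel - 1) + output\_padding + 1
```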
bdc9d46fe3
Use an Arc in the VarBuilder rather than Rc. ( #757 )
* Use an Arc in the VarBuilder rather than Rc.
* Require the backends to be Send.
* Request Send and Sync.
2023-09-06 15:29:09 +01:00
a0d65585db
Softmax implementation for cuda. ( #747 )
2023-09-05 18:38:03 +01:00
6615daf242
Tweaks to softmax. ( #745 )
2023-09-05 15:22:27 +01:00
1c9e5394a5
Add a custom softmax implementation. ( #744 )
* Add a custom softmax implementation.
* Add softmax-last-dim to the benchmarks.
* And add a test.
* Support more dtypes.
* Polish the code.
* Use the slow implementation on cuda.
* Add a todo for the cuda kernel.
2023-09-05 14:20:23 +01:00
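The custom implementation here is the numerically stable last-dim softmax: subtract the row max before exponentiating so large logits cannot overflow, then normalize. A per-row sketch (the crate's kernel is the vectorized, multi-dtype version of this):

```rust
// Numerically stable softmax over a single row of logits.
fn softmax_row(row: &mut [f32]) {
    // Subtracting the max keeps exp() in a safe range.
    let max = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let mut sum = 0f32;
    for v in row.iter_mut() {
        *v = (*v - max).exp();
        sum += *v;
    }
    for v in row.iter_mut() {
        *v /= sum;
    }
}
```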
4698eb5cb6
Fix a typo in the nll function documentation. ( #742 )
2023-09-05 09:25:11 +01:00
e2f9f60ac2
Avoid some redundant clones. ( #731 )
2023-09-04 09:18:32 +02:00
26cd266e65
Musicgen text embeddings. ( #726 )
* Musicgen text embeddings.
* Bugfix for layer norm.
* Proper position bias.
* Expose the weights.
2023-09-03 18:27:48 +01:00
74a82c358a
Add the mse loss. ( #723 )
2023-09-03 10:51:40 +01:00
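The mse loss is simply mean((prediction - target)^2). A usage sketch, assuming the candle_nn::loss::mse signature from later versions of the crate:

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let pred = Tensor::new(&[1.0f32, 2.0, 3.0], &dev)?;
    let target = Tensor::new(&[1.5f32, 2.0, 2.0], &dev)?;
    // mse = mean((pred - target)^2), a scalar tensor.
    let loss = candle_nn::loss::mse(&pred, &target)?;
    println!("{}", loss.to_scalar::<f32>()?);
    Ok(())
}
```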
af552a5274
Fix the rnn tests for accelerate. ( #704 )
2023-09-01 13:21:38 +01:00
7529531056
Add the optimizer trait. ( #702 )
2023-09-01 12:55:39 +01:00
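A hedged sketch of the shape such a trait can take: concrete optimizers implement the parameter update, and a provided method wires the backward pass to it (simplified; not necessarily the trait exactly as merged):

```rust
use candle_core::{Result, Tensor};

// Illustrative optimizer abstraction: SGD, AdamW, etc. implement `step`
// and the learning-rate accessors; `backward_step` is shared glue.
trait Optimizer: Sized {
    fn step(&mut self, grads: &candle_core::backprop::GradStore) -> Result<()>;
    fn learning_rate(&self) -> f64;
    fn set_learning_rate(&mut self, lr: f64);

    fn backward_step(&mut self, loss: &Tensor) -> Result<()> {
        let grads = loss.backward()?;
        self.step(&grads)
    }
}
```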
f9f482d4e5
Add some doc to the varbuilder. ( #700 )
2023-09-01 08:28:35 +01:00
9736236175
Allow retrieving and setting the prefix of a VarBuilder ( #699 )
2023-09-01 08:08:41 +01:00
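Prefixes are how a VarBuilder namespaces weights: pushing a prefix scopes every subsequent variable lookup, so the linear below stores its parameters under "encoder.proj.weight" and "encoder.proj.bias". A sketch assuming the pp helper and from_varmap constructor from later versions of the crate:

```rust
use candle_core::{DType, Device, Result};
use candle_nn::{VarBuilder, VarMap};

fn main() -> Result<()> {
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::Cpu);
    // `pp` pushes a prefix: this layer's weights land under "encoder.proj.*".
    let proj = candle_nn::linear(16, 32, vb.pp("encoder").pp("proj"))?;
    let _ = proj;
    Ok(())
}
```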
db59816087
Add a GRU layer. ( #688 )
* Add a GRU layer.
* Fix the n gate computation.
2023-08-31 08:43:10 +01:00
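For reference, the GRU step in the standard (PyTorch-style) formulation; the n gate fixed above is the one that applies the reset gate inside the tanh:

```latex
r_t = \sigma(W_{ir} x_t + b_{ir} + W_{hr} h_{t-1} + b_{hr})
z_t = \sigma(W_{iz} x_t + b_{iz} + W_{hz} h_{t-1} + b_{hz})
n_t = \tanh(W_{in} x_t + b_{in} + r_t \odot (W_{hn} h_{t-1} + b_{hn}))
h_t = (1 - z_t) \odot n_t + z_t \odot h_{t-1}
```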
d210c71d77
Set the learning rate. ( #687 )
2023-08-31 08:03:40 +01:00
eaf760a751
Add a python variant for the lstm test. ( #682 )
2023-08-30 22:32:08 +01:00
21e1c73892
Add an LSTM test. ( #681 )
* Add an LSTM test.
* Clippy.
2023-08-30 20:05:42 +02:00
2047d34b7c
More robust tests (so that they pass on accelerate). ( #679 )
2023-08-30 18:10:10 +01:00
3159982a89
Add a Dropout layer ( #676 )
* Add a dropout layer.
* Add an actual layer.
2023-08-30 16:19:28 +01:00
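Dropout zeroes each activation with probability p during training and rescales the survivors by 1/(1-p) so the expected activation is unchanged; at inference it is the identity. A sketch of that inverted-dropout rule on a slice (illustrative; the crate's layer works on tensors):

```rust
// Inverted dropout: zero with probability p, scale survivors by 1/(1-p).
// `uniform` is any source of U[0,1) samples; at eval time this is a no-op.
fn dropout(xs: &mut [f32], p: f32, training: bool, mut uniform: impl FnMut() -> f32) {
    if !training || p == 0.0 {
        return;
    }
    let scale = 1.0 / (1.0 - p);
    for v in xs.iter_mut() {
        *v = if uniform() < p { 0.0 } else { *v * scale };
    }
}
```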
ad8a62dbf5
Add tanh. ( #675 )
* Add tanh.
* Use tanh in the lstm block.
* Add a test for tanh forward and backward passes.
2023-08-30 13:54:50 +01:00
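The backward pass in that test checks the usual identity:

```latex
\frac{d}{dx}\tanh(x) = 1 - \tanh^2(x)
```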
f35b9f6baa
Add some recurrent neural networks ( #674 )
* Add the rnn module.
* More LSTM.
* Implement the RNN forward pass.
* More forward pass for LSTM.
2023-08-30 13:27:09 +01:00
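For reference, the LSTM forward pass in the standard (PyTorch-style) formulation:

```latex
i_t = \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi})
f_t = \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf})
g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg})
o_t = \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho})
c_t = f_t \odot c_{t-1} + i_t \odot g_t
h_t = o_t \odot \tanh(c_t)
```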
618f4e4c78
Add some documentation. ( #673 )
* Add some documentation.
* Bump the crate version.
2023-08-30 11:54:00 +01:00
2d3fcad267
Simplify usage of the pool functions. ( #662 )
* Simplify usage of the pool functions.
* Small tweak.
* Attempt at using apply to simplify the convnet definition.
2023-08-29 19:12:16 +01:00
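The simplification is the usual combinator pattern: with pooling available directly on tensors and layers sharing a common forward, a convnet body collapses into one chain. A hedged sketch, assuming Tensor::apply, max_pool2d, and flatten_from as in current versions of the crate:

```rust
use candle_core::{Result, Tensor};
use candle_nn::{Conv2d, Linear};

// Illustrative LeNet-style forward written as a single chain.
fn forward(xs: &Tensor, conv1: &Conv2d, conv2: &Conv2d, fc: &Linear) -> Result<Tensor> {
    xs.apply(conv1)?
        .max_pool2d(2)?
        .apply(conv2)?
        .max_pool2d(2)?
        .flatten_from(1)?
        .apply(fc)
}
```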
a044907ffc
Dilated convolutions ( #657 )
* Add the dilation parameter.
* Restore the basic optimizer example.
* Dilation support in cudnn.
* Use the dilation parameter in the cpu backend.
* More dilation support.
* No support for dilation in transposed convolutions.
* Add dilation to a test.
* Remove a print.
* Helper function.
2023-08-29 16:12:11 +01:00
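Dilation spaces the kernel taps d elements apart, so the effective kernel size grows to d·(k-1)+1 and the conv2d output size per dimension becomes:

```latex
out = \left\lfloor \frac{in + 2\,padding - dilation\,(kernel - 1) - 1}{stride} \right\rfloor + 1
```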
33c23c19b6
Preliminary support for SDXL. ( #647 )
* Preliminary support for SDXL.
* More SDXL support.
* More SDXL.
* Use the proper clip config.
* Querying for existing tensors.
* More robust test.
2023-08-29 09:00:04 +01:00
a3f97c143d
Bump the crate version + update CHANGELOG. ( #628 )
2023-08-27 18:17:11 +01:00
4c338b0cd9
VarBuilder cleanup ( #627 )
* VarBuilder cleanup.
* Implement the basic varbuilders.
* Add the sharded code.
* Proper support for tensor sharding.
2023-08-27 18:03:26 +01:00
5320aa6b7d
Move the test-utils bits to a shared place. ( #619 )
2023-08-27 09:42:22 +01:00
431051cc32
Add EfficientNet ( #572 )
* EfficientNet.
* Complete the efficientnet implementation.
* Improve group handling.
* Get the efficientnet to work.
2023-08-23 18:02:58 +01:00
aba1e90797
Add a groups parameter to convolutions. ( #566 )
* Add a groups parameter to convolutions.
* Avoid some unnecessary groups checks.
* Move the tensor convolution bits.
* Proper handling of groups.
* Bump the crate version.
* And add a changelog.
2023-08-23 12:58:55 +01:00
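With groups = g, the input channels are split into g independent slices, each convolved with its own filters, so both channel counts must be divisible by g and the weight tensor shrinks to (out_c, in_c / g, k_h, k_w). A small sketch of that bookkeeping (illustrative helper, not the crate's API):

```rust
// Sketch: shape checks for grouped conv2d. Each of the `groups` groups maps
// in_c/groups input channels to out_c/groups output channels.
fn grouped_conv_weight_shape(
    in_c: usize,
    out_c: usize,
    k: usize,
    groups: usize,
) -> (usize, usize, usize, usize) {
    assert_eq!(in_c % groups, 0, "in_channels must be divisible by groups");
    assert_eq!(out_c % groups, 0, "out_channels must be divisible by groups");
    (out_c, in_c / groups, k, k)
}
```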
11c7e7bd67
Some fixes for yolo-v3. ( #529 )
* Some fixes for yolo-v3.
* Use the running stats for inference in the batch-norm layer.
* Get some proper predictions for yolo.
* Avoid the quadratic insertion.
2023-08-20 23:19:15 +01:00
a1812f934f
Add a yolo-v3 example. ( #528 )
* Add a couple functions required for yolo.
* Add the yolo-v3 example.
* Add minimum and maximum.
* Use the newly introduced maximum.
* Cuda support for min/max + add some testing.
* Allow for more tests to work with accelerate.
* Fix a typo.
2023-08-20 18:19:37 +01:00
e3d2786ffb
Add a couple functions required for yolo. ( #527 )
2023-08-20 17:02:05 +01:00
d2622a8160
Move the VarMap to a separate file ( #525 )
* Move the var-map struct to a separate file.
* Fix some typos.
2023-08-20 14:25:07 +01:00
a8f61e66cc
Bump the crates version to 0.1.2. ( #522 )
2023-08-20 08:07:07 +01:00
42e1cc8062
Add a batch normalization layer ( #508 )
* Add BatchNormalization.
* More batch-norm.
* Add some validation of the inputs.
* More validation.
2023-08-18 20:05:56 +01:00
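For reference, batch normalization computes

```latex
y = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} \cdot \gamma + \beta
```

where training uses the per-batch mean and variance and updates running averages, while inference (as in the yolo-v3 fix above) reuses those running estimates.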
c78ce76501
Add a simple Module trait and implement it for the various nn layers ( #500 )
* Start adding the module trait.
* Use the module trait.
* Implement module for qmatmul.
2023-08-18 09:38:22 +01:00
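The trait itself is minimal: a single tensor-to-tensor forward, which is what makes the apply-style chaining earlier in the log possible. A sketch matching the shape of the trait as it exists in later versions of the crate:

```rust
use candle_core::{Result, Tensor};

// A minimal module abstraction: anything with a tensor-to-tensor forward.
pub trait Module {
    fn forward(&self, xs: &Tensor) -> Result<Tensor>;
}

// Example impl for a no-op layer.
struct Identity;

impl Module for Identity {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        Ok(xs.clone())
    }
}
```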
13401df4d1
Add an abstract type for RmsNorm. ( #499 )
2023-08-18 08:52:14 +01:00
d32e8199cd
Layer norm tweaks ( #482 )
* Add some options to make layer-norm more configurable.
* Add the rms-norm variant.
* Replace the RmsNorm with the shared bits.
2023-08-17 10:07:13 +01:00
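The rms-norm variant drops layer norm's mean-centering and bias and only rescales by the root-mean-square:

```latex
\mathrm{rmsnorm}(x) = \frac{x}{\sqrt{\frac{1}{n}\sum_i x_i^2 + \epsilon}} \cdot \gamma
```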
8ad4a21ffc
Add a basic optimizer example. ( #454 )
2023-08-15 17:19:18 +01:00
531f23b4d0
Rename vec-dot to vec-ops. ( #449 )
* Rename vec-dot to vec-ops.
* Also bump the crate version.
* Add a currently empty readme.
2023-08-15 10:48:57 +01:00
eab54e4490
Fix the tests for mkl. ( #437 )
2023-08-14 08:09:27 +01:00
55e428c8ae
Expose the varmap inner data. ( #411 )
2023-08-11 16:58:56 +01:00