Commit Graph

131 Commits

Author SHA1 Message Date
87dc559817 Lots of updates including some stack of command buffers. 2023-12-12 17:41:56 +01:00
481c45d78d Add a basic implementation for slice-assign. (#1377) 2023-11-26 17:31:22 +00:00
2813fb5dbc Cleanup fixed a few ops removed debugging scaffolding. 2023-11-20 14:12:57 +01:00
d46670f7c0 Tmp state. 2023-11-20 14:12:57 +01:00
df6814f34e Refactor to simplify our lives for settings the params in the encoder. 2023-11-20 14:12:57 +01:00
c6763e3b41 Add a simple implementation of cumsum. (#1334)
* Add a simple implementation of cumsum.

* Add another test.
2023-11-15 21:11:15 +00:00
347e31c9ff Add the tril/triu/eye ops. (#1333)
* Add tril/triu/eye.

* Revert the metal crate tweak.
2023-11-15 20:34:37 +00:00
9e666d4229 Add the var method. (#1315)
* Add the var method.

* Add a test.
2023-11-10 22:47:57 +01:00
26c4e5bf1d Metal part 1 - Scaffolding for metal. (#1308)
* Metal part 1 - Scaffolding for metal.

* Remove tracing.
2023-11-10 08:35:48 +01:00
a773a4b22b [ONNX] Support a couple more ops. (#1284)
* Support the shape op in ONNX.

* Share the axis normalization bits.

* Add some limited support for gather.

* Unsqueeze.

* Comparison with broadcasting.

* Add Not + handle i32.
2023-11-06 22:44:58 +01:00
fbd69f952c Lazy detach. (#1242) 2023-11-02 07:33:48 +00:00
5fc66bd4ba Support negative steps in arange. (#1218) 2023-10-30 07:40:54 +00:00
55bc3382cf Allow for different behavior between training and eval (#1213)
* Forward with training.

* Do not use dropout on vgg evaluation.
2023-10-29 07:53:09 +01:00
9b1158b315 Add some missing backtraces. (#1193) 2023-10-27 06:09:11 +01:00
c698e17619 Enable the test for meshgrid + fix the implementation. (#1175) 2023-10-25 13:47:54 +01:00
e4c9adfdbe Implemented meshgrid (#1174)
* Implemented meshgrid

* Resolved feedback from LaurentMazare

* Rustfmt

* Updated docstring

* Removed outdated error mode from docstring
2023-10-25 12:49:11 +01:00
62fc965617 Expose the track-op method. (#1148) 2023-10-22 06:57:03 +01:00
e8f760ee44 Add get_on_dim. (#1142) 2023-10-21 15:01:38 +01:00
87eb1658e1 Add pad_with_same. (#1127)
* More model cloning.

* More cloning on quantized models.

* Add pad-with-same.

* Add some tests.
2023-10-18 23:13:37 +01:00
662c186fd5 Better error message when overflowing in narrow. (#1119) 2023-10-18 08:40:14 +01:00
37dbbff261 Use full tensors for zeros and ones (#1071)
* Only optimize float tensors.

* Use full tensors for zeros and ones.
2023-10-11 08:16:04 +01:00
7f7d95e2c3 Add the round-to function. (#1039) 2023-10-05 20:28:09 +01:00
c18a856e76 Add the rounding operators. (#1030)
* Add the rounding operators.

* Avoid tracking gradients for the rounding operations.

* Add some rounding tests.
2023-10-04 17:58:44 +01:00
01b92cd959 fixes slice_scatter dim type (#988) 2023-09-29 07:54:45 +01:00
8601537e31 Add slice-scatter. (#927)
* Add slice-scatter.

* Add the op.

* Make transpose be a no-op when the dimensions are identical.

* Add the backprop.

* And add some gradient test.
2023-09-22 12:18:16 +01:00
7b26e513f1 Add the erf function. (#917) 2023-09-21 06:19:10 +01:00
d7e48234d4 Add an erf based gelu op (#900)
* Erf based gelu.

* Add the erf backed gelu.

* Test the new gelu op (which is not gelu_new).
2023-09-19 19:54:28 +01:00
4f91c8e109 Improve the error message on shape mismatch for cat. (#897)
* Improve the error message on shape mismatch for cat.

* Cosmetic tweak.
2023-09-19 15:09:47 +01:00
635012d770 Do not backprop through argmin/argmax. (#865) 2023-09-15 22:15:40 +01:00
9a465e1b26 Add 1d upsampling. (#839)
* Add 1d upsampling.

* Add the interpolate functions.
2023-09-13 16:50:39 +01:00
18d3c803a8 Scalar support in minimum/maximum. (#832)
* Scalar support in minimum/maximum.

* Add a clamp method to tensors.
2023-09-13 08:24:58 +01:00
acf8f10ae1 Get the comparison operation to work on scalar values. (#780)
* Get the comparison operation to work on scalar values.

* Add some time measurement.
2023-09-08 20:13:29 +01:00
0e250aee4f Shape with holes (#770)
* Shape with holes.

* rustfmt.
2023-09-08 08:38:13 +01:00
a8410bf35e Add some documentation. (#743) 2023-09-05 09:51:12 +01:00
74a82c358a Add the mse loss. (#723) 2023-09-03 10:51:40 +01:00
30a4b593d7 More ops again. (#697) 2023-08-31 22:28:48 +01:00
949f1eae6f Implement a couple more binary ops. (#693) 2023-08-31 21:30:15 +01:00
ad8a62dbf5 Add tanh. (#675)
* Add tanh.

* Use tanh in the lstm block.

* Add a test for tanh forward and backward passes.
2023-08-30 13:54:50 +01:00
618f4e4c78 Add some documentation. (#673)
* Add some documentation.

* Bump the crate version.
2023-08-30 11:54:00 +01:00
59b731de99 Add the powf op. (#664)
* Add the powf op.

* Cuda kernels and backprop.

* Add a test.
2023-08-29 20:48:18 +01:00
2d3fcad267 Simplify usage of the pool functions. (#662)
* Simplify usage of the pool functions.

* Small tweak.

* Attempt at using apply to simplify the convnet definition.
2023-08-29 19:12:16 +01:00
e21c686cdc Fixes for clippy 1.72. (#587) 2023-08-24 17:46:17 +01:00
aba1e90797 Add some group parameter to convolutions. (#566)
* Add some group parameter to convolutions.

* Avoid some unnecessary groups checks.

* Move the tensor convolution bits.

* Properh handling of groups.

* Bump the crate version.

* And add a changelog.
2023-08-23 12:58:55 +01:00
11c7e7bd67 Some fixes for yolo-v3. (#529)
* Some fixes for yolo-v3.

* Use the running stats for inference in the batch-norm layer.

* Get some proper predictions for yolo.

* Avoid the quadratic insertion.
2023-08-20 23:19:15 +01:00
a1812f934f Add a yolo-v3 example. (#528)
* Add a couple functions required for yolo.

* Add the yolo-v3 example.

* Add minimum and maximum.

* Use the newly introduced maximum.

* Cuda support for min/max + add some testing.

* Allow for more tests to work with accelerate.

* Fix a typo.
2023-08-20 18:19:37 +01:00
e3d2786ffb Add a couple functions required for yolo. (#527) 2023-08-20 17:02:05 +01:00
2fcb386f17 Add a broadcast variant to matmul. (#523)
* Add a broadcast variant to matmul.

* Get the test to pass.
2023-08-20 13:20:42 +01:00
cb069d6063 Add the permute op (similar to pytorch). (#504)
* Add the permute op (similar to pytorch).

* Add the backprop for dimension permutation.
2023-08-18 16:30:53 +01:00
95462c6a2e Add a vision transformer example (dino-v2). (#502)
* Add a vision transformer example (dino-v2).

* Add some documentation + test.

* CI fix.

* Another fix (still unable to replicate the errors locally :( )
2023-08-18 11:58:06 +01:00
03be33eea4 Relax the requirements on CustomOp. (#486)
* Relax the requirements on CustomOp.

* Simplify the custom-ops when no backward is required.
2023-08-17 11:12:05 +01:00