c12db594e3
fix typo ( #2606 )
2024-11-23 08:40:00 +01:00
0ed24b9852
Add max-all/min-all. ( #2616 )
2024-11-14 21:08:04 +01:00
dcd83336b6
Testcases ( #2567 )
2024-10-17 13:00:45 +02:00
382c6b51af
Improve error message ( #2485 )
2024-09-20 07:11:41 -06:00
d3fe989d08
Add documentation examples for Tensor::i and Tensor::narrow methods ( #2308 )
...
* Add documentation examples for `Tensor` methods
* Apply fmt.
* Cosmetic tweaks.
---------
Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-08-10 08:11:09 +02:00
bd80078acf
Fix log_sum_exp to handle large positive/negative inputs ( #2367 )
2024-08-01 10:37:02 +02:00
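The overflow this commit fixes comes from exponentiating large inputs directly. A minimal pure-Python sketch of the standard stable trick (illustrative only, not the candle implementation): subtract the maximum before exponentiating so `exp()` never overflows.

```python
import math

def logsumexp(xs):
    # Stable log-sum-exp: log(sum(exp(x))) computed as
    # m + log(sum(exp(x - m))) with m = max(xs), so the
    # largest exponent argument is 0 and exp() cannot overflow.
    m = max(xs)
    if math.isinf(m) and m < 0:
        return m  # all inputs are -inf
    return m + math.log(sum(math.exp(x - m) for x in xs))
```

With `xs = [1000.0, 1000.0]`, the naive `log(exp(1000) + exp(1000))` overflows to `inf`, while the stable form returns `1000 + log(2)`.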
0f5cbb08b3
Add support for Llama 3.1 ( #2359 )
...
* Add Llama 3.1 rope
* Clippy
* Format
* Clippy
* Add support for multiple eos tokens:
* Untagged either
* Remove either dep and fix settings.json
* Make the max positional embeddings configurable
2024-07-26 21:32:26 +02:00
8a05743a21
Add StorageRef. ( #2113 )
...
* Add the storage-ref bits.
* Add the metal implementation.
2024-04-23 13:23:27 +02:00
e198bb0816
Handle zero dims in some simple operations. ( #2064 )
...
* Handle zero dims in some simple operations.
* Handle zero-dims in matmul.
* More testing.
2024-04-15 09:18:54 +02:00
c5626b8271
Add support for "sign" on tensors ( #2012 )
...
* add the sign unary operator
* remove unneeded import
* remove unneeded import
* undo formatting
* undo formatting
* remove unnecessary redefinition
* allow gradient to flow through for sign and round
* fix cpu ops to ensure that negative zero and positive zero are handled properly
* clippy fixes
* Properly avoid gradient tracking.
* Use a branchless version.
---------
Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-04-04 22:32:47 +02:00
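The branchless formulation mentioned in the last bullet can be sketched in pure Python (illustrative, not the kernel code): comparisons yield 0/1, so both signed zeros, and NaN, map to 0 without any branching on the value.

```python
def sign(x: float) -> float:
    # Branchless sign: (x > 0) - (x < 0) is 1, -1, or 0.
    # Both +0.0 and -0.0 fail both comparisons and map to 0.0.
    return float((x > 0) - (x < 0))
```

This is the same trick commonly used in C kernels, where the comparison results are integer 0/1.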
a9abde5f93
More flexible matmul contiguity checks. ( #1949 )
...
* More flexible matmul contiguity checks.
* Also relax the checks on the metal side.
2024-03-27 10:59:05 +01:00
fdfe8fd129
Preliminary support for inplace ops. ( #1921 )
...
* Preliminary support for inplace ops.
* Add a test.
2024-03-23 14:16:19 +01:00
6708870e63
Add the alloc_uninit function. ( #1901 )
...
* Add the alloc_uninit function.
* Dummy metal fix.
* Lazy initialization.
2024-03-22 07:25:23 +01:00
74b7f59261
Prepare for the custom-op extension. ( #1892 )
2024-03-21 07:02:20 +01:00
455c42aa72
Avoid copying the data on squeeze and unsqueeze. ( #1884 )
...
* Avoid copying the data on squeeze and unsqueeze.
* Fix the quantized llama example.
* Unrelated fix for the quantized stable-lm example on cuda.
* Fix for mamba on cuda (unrelated to the PR).
2024-03-20 13:04:36 +01:00
ce9fbc3682
Optimize the cat operation on contiguous tensors ( #1855 )
...
* Add a specialized kernel for copy2d.
* Move the cat operations.
* Avoid transpositions in cat.
* Bugfix.
* Bugfix for the cuda kernel.
* Add a benchmark.
* Add more testing.
* Test fix.
* Faster kernel.
* Add the missing kernel.
* Tweak the test.
* Add a metal kernel.
* Fix for the metal kernel.
* Get the tests to pass on metal.
* Also use this opportunity to fix the metal kernel for ELU.
* Add some bf16 kernels.
* Clippy fixes.
2024-03-17 10:49:13 +01:00
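The copy2d primitive this PR introduces can be illustrated with a pure-Python sketch (the names and signature here are illustrative, not the actual kernel interface): concatenating contiguous tensors along a middle dimension reduces to strided 2D copies, one per input, with no transposition.

```python
def copy2d(dst, src, rows, row_len, dst_stride, src_stride,
           dst_off=0, src_off=0):
    # Strided 2D copy: move `rows` contiguous runs of `row_len`
    # elements each, advancing by a separate stride on each side.
    for r in range(rows):
        d = dst_off + r * dst_stride
        s = src_off + r * src_stride
        dst[d:d + row_len] = src[s:s + row_len]

# Concatenate two contiguous (2, 3) buffers along dim 1 into (2, 6).
a = [1, 2, 3, 4, 5, 6]
b = [7, 8, 9, 10, 11, 12]
out = [0] * 12
copy2d(out, a, rows=2, row_len=3, dst_stride=6, src_stride=3)
copy2d(out, b, rows=2, row_len=3, dst_stride=6, src_stride=3, dst_off=3)
```

Each input contributes one copy2d call; on GPU backends the same pattern maps to a single specialized kernel launch per input.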
8013b50829
Add grads for interpolate1d ( #1742 )
...
* add backprop for interpolate1d
* fix clippy lint
* correct fix clippy lint
2024-02-22 08:44:01 +01:00
b60064780d
feat: add silu activation function ( #1706 )
...
* feat: add silu activation function
* use silu/arg in grad
* update candle-nn
* use node
2024-02-14 10:27:22 +01:00
ad73e93da2
Detach the tensors on batch-norm eval. ( #1702 )
...
* Detach the tensors on batch-norm eval.
* Fix pyo3 bindings.
* Black tweak.
* Formatting.
* Also update the pyo3-onnx formatting.
* Apply black.
2024-02-13 14:26:32 +01:00
b545f54a19
Fix clippy lints. ( #1667 )
2024-02-06 09:03:36 +01:00
982722019b
add roll function to tensor ( #1666 )
2024-02-06 08:49:45 +01:00
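A 1-D pure-Python sketch of roll semantics (illustrative, not the candle signature): element `i` moves to index `(i + shift) % n`, and negative shifts roll the other way.

```python
def roll(xs, shift):
    # Circularly shift a 1-D sequence: element i lands at
    # (i + shift) % n; shift may be negative or exceed n.
    n = len(xs)
    if n == 0:
        return list(xs)
    shift %= n
    return xs[-shift:] + xs[:-shift] if shift else list(xs)
```

So `roll([1, 2, 3, 4], 1)` gives `[4, 1, 2, 3]`.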
e6d86b0819
Add the pow operator. ( #1583 )
...
* Add the pow operator.
* Support the pow operation in onnx.
2024-01-13 20:24:06 +01:00
ba1fae590e
Validate the kernel size in pooling ops. ( #1473 )
...
* Validate the kernel size in pooling ops.
* Revert the changes to basics.
2023-12-23 11:19:22 +01:00
9fc210fae8
Merge pull request #1318 from huggingface/metal4
...
Starting to fix some tests.
2023-12-20 15:37:31 +01:00
96f1a28e39
Add a simple full method. ( #1455 )
...
* Add a simple implementation of the full method.
* Add the docstring.
2023-12-17 20:15:57 -05:00
1e86717bf2
Fix a couple typos ( #1451 )
...
* Mixtral quantized instruct.
* Fix a couple typos.
2023-12-17 05:20:05 -06:00
77197379cc
More cleanup.
2023-12-15 11:17:05 +01:00
87dc559817
Lots of updates, including a stack of command buffers.
2023-12-12 17:41:56 +01:00
77252ffb82
Add logsumexp function ( #1424 )
2023-12-12 10:32:17 -06:00
18eb87f25f
Upsample grad ( #1420 )
...
* encode size of upsample in enum
* working convolution method for limited 2d kernels
* add test for sf 3 interpolation
* add higher dimensional tests, fix to work with multichannel input
* Remove commented out line.
---------
Co-authored-by: Laurent <laurent.mazare@gmail.com>
2023-12-10 08:43:24 +01:00
481c45d78d
Add a basic implementation for slice-assign. ( #1377 )
2023-11-26 17:31:22 +00:00
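The semantics of slice-assign, sketched for the 1-D case in pure Python (names illustrative, not the candle API): the op is out-of-place, returning a copy of the destination with the source written into the given range.

```python
def slice_assign(dst, src, start):
    # Out-of-place slice assignment: copy `dst`, then overwrite
    # the range [start, start + len(src)) with `src`.
    assert start + len(src) <= len(dst)
    out = list(dst)
    out[start:start + len(src)] = src
    return out
```

`slice_assign([0, 0, 0, 0, 0], [1, 2], 1)` yields `[0, 1, 2, 0, 0]` and leaves the original untouched.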
2813fb5dbc
Cleanup: fixed a few ops, removed debugging scaffolding.
2023-11-20 14:12:57 +01:00
d46670f7c0
Tmp state.
2023-11-20 14:12:57 +01:00
df6814f34e
Refactor to simplify our lives for setting the params in the encoder.
2023-11-20 14:12:57 +01:00
c6763e3b41
Add a simple implementation of cumsum. ( #1334 )
...
* Add a simple implementation of cumsum.
* Add another test.
2023-11-15 21:11:15 +00:00
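The op's semantics, sketched in pure Python for one dimension (illustrative, not the candle code): an inclusive prefix sum, `out[i] = xs[0] + ... + xs[i]`.

```python
def cumsum(xs):
    # Inclusive running sum over a 1-D sequence.
    out, total = [], 0
    for x in xs:
        total += x
        out.append(total)
    return out
```

`cumsum([1, 2, 3, 4])` gives `[1, 3, 6, 10]`.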
347e31c9ff
Add the tril/triu/eye ops. ( #1333 )
...
* Add tril/triu/eye.
* Revert the metal crate tweak.
2023-11-15 20:34:37 +00:00
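The semantics of these three ops, sketched over plain nested lists (illustrative, not the candle signatures): eye builds an identity matrix, tril/triu zero out entries above/below the main diagonal.

```python
def eye(n):
    # n x n identity: ones on the main diagonal, zeros elsewhere.
    return [[1 if i == j else 0 for j in range(n)] for i in range(n)]

def tril(m):
    # Keep the lower triangle (j <= i), zero the rest.
    return [[v if j <= i else 0 for j, v in enumerate(row)]
            for i, row in enumerate(m)]

def triu(m):
    # Keep the upper triangle (j >= i), zero the rest.
    return [[v if j >= i else 0 for j, v in enumerate(row)]
            for i, row in enumerate(m)]
```

tril in particular is the building block for causal attention masks.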
9e666d4229
Add the var method. ( #1315 )
...
* Add the var method.
* Add a test.
2023-11-10 22:47:57 +01:00
26c4e5bf1d
Metal part 1 - Scaffolding for metal. ( #1308 )
...
* Metal part 1 - Scaffolding for metal.
* Remove tracing.
2023-11-10 08:35:48 +01:00
a773a4b22b
[ONNX] Support a couple more ops. ( #1284 )
...
* Support the shape op in ONNX.
* Share the axis normalization bits.
* Add some limited support for gather.
* Unsqueeze.
* Comparison with broadcasting.
* Add Not + handle i32.
2023-11-06 22:44:58 +01:00
fbd69f952c
Lazy detach. ( #1242 )
2023-11-02 07:33:48 +00:00
5fc66bd4ba
Support negative steps in arange. ( #1218 )
2023-10-30 07:40:54 +00:00
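The extended behavior, sketched in pure Python (illustrative, not the candle signature): the range stays half-open, and a negative step counts down while `start > end`.

```python
def arange(start, end, step):
    # Half-open range [start, end) supporting negative steps.
    assert step != 0
    out, x = [], start
    while (x < end) if step > 0 else (x > end):
        out.append(x)
        x += step
    return out
```

`arange(5, 1, -1)` gives `[5, 4, 3, 2]`, mirroring the positive-step convention that the end point is excluded.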
55bc3382cf
Allow for different behavior between training and eval ( #1213 )
...
* Forward with training.
* Do not use dropout on vgg evaluation.
2023-10-29 07:53:09 +01:00
9b1158b315
Add some missing backtraces. ( #1193 )
2023-10-27 06:09:11 +01:00
c698e17619
Enable the test for meshgrid + fix the implementation. ( #1175 )
2023-10-25 13:47:54 +01:00
e4c9adfdbe
Implemented meshgrid ( #1174 )
...
* Implemented meshgrid
* Resolved feedback from LaurentMazare
* Rustfmt
* Updated docstring
* Removed outdated error mode from docstring
2023-10-25 12:49:11 +01:00
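Meshgrid semantics for two 1-D inputs, sketched in pure Python with "ij" indexing (illustrative, not the candle API): both outputs have shape `(len(xs), len(ys))`, with `X[i][j] = xs[i]` and `Y[i][j] = ys[j]`.

```python
def meshgrid(xs, ys):
    # "ij"-indexed coordinate grids: X varies along rows,
    # Y varies along columns.
    X = [[x for _ in ys] for x in xs]
    Y = [[y for y in ys] for _ in xs]
    return X, Y
```

For `meshgrid([1, 2], [3, 4, 5])`, X is `[[1, 1, 1], [2, 2, 2]]` and Y is `[[3, 4, 5], [3, 4, 5]]`.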
62fc965617
Expose the track-op method. ( #1148 )
2023-10-22 06:57:03 +01:00
e8f760ee44
Add get_on_dim. ( #1142 )
2023-10-21 15:01:38 +01:00
87eb1658e1
Add pad_with_same. ( #1127 )
...
* More model cloning.
* More cloning on quantized models.
* Add pad-with-same.
* Add some tests.
2023-10-18 23:13:37 +01:00
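"Same" padding here means replicating the edge values rather than padding with zeros; a 1-D pure-Python sketch (illustrative, not the candle signature):

```python
def pad_with_same(xs, left, right):
    # Replicate-pad a 1-D sequence: repeat the first value
    # `left` times and the last value `right` times.
    if not xs:
        return []
    return [xs[0]] * left + list(xs) + [xs[-1]] * right
```

`pad_with_same([1, 2, 3], 2, 1)` gives `[1, 1, 1, 2, 3, 3]`.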
662c186fd5
Better error message when overflowing in narrow. ( #1119 )
2023-10-18 08:40:14 +01:00
37dbbff261
Use full tensors for zeros and ones ( #1071 )
...
* Only optimize float tensors.
* Use full tensors for zeros and ones.
2023-10-11 08:16:04 +01:00