794 Commits

Author SHA1 Message Date
287013ef28 Add a forward_via_f16 method to the qmatmul op. (#2138) 2024-04-28 20:35:01 +02:00
eb26e2467e Add the cuda dequantize f16 kernels. (#2137)
* Add the cuda dequantize f16 kernels.

* Expose the cuda kernels.

* Add some testing + fix.

* Test the other cases too.

* A few more tests.

* Add an environment variable to enable the dequantize f16 + matmul behavior.
2024-04-28 20:05:05 +02:00
805f3be8e1 Add a sort function. (#2134) 2024-04-28 08:18:04 +02:00
96a48e5cc4 Add argsort. (#2132)
* Add the argsort cuda kernels.

* CPU version of arg-sort.

* Hook the cuda kernel + rework the cpu bits.

* Add some dedicated test.

* Working cuda kernel.

* Metal kernel.

* Metal adjustments.

* Bugfix.

* Use the fast rope in qwen.

* Rework the expert selection in qwen.
2024-04-27 20:17:35 +02:00
8a05743a21 Add StorageRef. (#2113)
* Add the storage-ref bits.

* Add the metal implementation.
2024-04-23 13:23:27 +02:00
08a15cb79e Update zip requirement from 0.6.6 to 1.1.1 (#2103)
* Update zip requirement from 0.6.6 to 1.1.1

---
updated-dependencies:
- dependency-name: zip
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>

* Fix for the zip crate update.

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-04-22 16:23:27 +02:00
0067fe00a8 Metal Unary: Add benchmarks and process kernels in a tile based fashion (#2056)
* add basic unary bench for sqrt

* process unary commands in tiles of 4

* re-enable all benchmarks

* rename helper to unary

* modify approach to split up tiled and non-tiled operations

* undo bench ignore for other tests

* update tile size to 2

* only perform the optimization on the contiguous even numbered element case
2024-04-21 00:10:33 +02:00
587ee3bb6f Small cleanups to the llama multi-process example. (#2098) 2024-04-20 22:19:46 +02:00
dd78422701 Handle multiple dimensions in metal QMM + two fixes. (#2097) 2024-04-20 18:55:45 +02:00
1690ab45d2 Fix the silu gradient issue on 0. (#2083) 2024-04-18 14:31:41 +02:00
8de0ce6cba Add more QMMV cuda kernels. (#2077)
* Add more QMMV cuda kernels.

* Enable the new kernels.

* Adapt the testing.
2024-04-18 08:36:43 +02:00
2817643db9 Add the mmv kernels for small batch sizes. (#2075)
* Add the mmv kernels for smaller sizes.

* Support more mmv kernels.

* Use the new kernels.

* Fix the call.

* Silly fix.

* Improve the testing.

* Fix for dmmv.

* Add another dedicated test for the batching mmv.
2024-04-16 21:30:51 +02:00
f135b7963d Fix for the batch dim in the quantized matmul example. (#2073)
* Fix for the batch dim in the quantized matmul example.

* Enable more tests on cuda.

* Add a test for qmm with a batch.

* Fix the zeros-dim test on metal.
2024-04-15 20:00:28 +02:00
8ad822a983 Add a function to clear the KV cache in falcon. (#2066)
* Add a function to clear the KV cache in falcon.

* Clippy.
2024-04-15 09:29:25 +02:00
e198bb0816 Handle zero dims in some simple operations. (#2064)
* Handle zero dims in some simple operations.

* Handle zero-dims in matmul.

* More testing.
2024-04-15 09:18:54 +02:00
f7d5bf5b97 Faster kernels for quantized matmul on cuda (#2060)
* Hook the quantized matmul cuda kernels.

* Add a (currently broken) test.

* Kernel fixes.

* Fix by transposing the rhs matrix.

* Add the q4-1 kernels.

* Proper block sizes.

* More details in the tests.
2024-04-15 08:32:47 +02:00
c449f65b12 Expose the synchronize function on the generic device. (#2062) 2024-04-14 23:02:03 +02:00
db7dbf3071 Add missing bfloat unary strided kernels and fix typo (#2058) 2024-04-14 20:01:13 +02:00
53e5380bf6 Add a synchronize method to devices. (#2055)
* Add a synchronize method to devices.

* Metal version.
2024-04-14 16:32:55 +02:00
4c88c3ce06 Add benchmarks for qmatmul operations (#2048)
* Add qmatmul bench

* add all dtypes
2024-04-13 12:30:14 +02:00
a4d5a414e3 Support gather on bf16 for metal. (#2035) 2024-04-10 12:49:25 +02:00
718671a0d5 Use BufferOffset in metal backend ops. (#2029)
* Use BufferOffset in the metal backend.

* More BufferOffset usage.

* Use in where-cond.
2024-04-08 09:37:25 +02:00
c5fe4a7f89 Rework the buffer offset logic for metal kernels (#2028)
* Move the metal kernels utils in a separate module.

* Use the BufferOffset for unary ops.

* Fix clippy lints.

* Use the new BufferOffset.

* Adapt the binary ops.

* Affine.

* More ops (powf, elu, cast).
2024-04-07 22:37:53 +02:00
9fd52b3b71 Handle the batch dimension in quantized MMV on metal. (#2022) 2024-04-06 20:02:24 +02:00
ab892274d1 first commit (#2018) 2024-04-05 15:20:28 +02:00
c5626b8271 Add support for "sign" on tensors (#2012)
* add the sign unary operator

* remove uneeded import

* remove uneeded import

* undo formatting

* undo formatting

* remove unnecessary redefintion

* allow gradient to flow through for sign and round

* fix cpu ops to ensure that negzero and positive zero are handled properly

* clippy fixes

* Properly avoid gradient tracking.

* Use a branchless version.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-04-04 22:32:47 +02:00
e6a5b82ba6 Fix the matmul layout for accelerate & mkl. (#2011)
* Fix the matmul layout for accelerate & mkl.

* Reduce the required precision for pow (because of accelerate).

* And a fix the gelu f16 test.
2024-04-04 19:18:03 +02:00
5aebe53dd2 update dtypes checks for several metal operations (#2010) 2024-04-04 18:39:06 +02:00
30b145150f Optimize the gelu f16 opt. (#2008)
* Optimize the gelu f16 opt.

* And add a test.
2024-04-04 16:28:23 +02:00
8967c46563 Split the cuda error file. (#2003) 2024-04-04 08:27:23 +02:00
318d143224 Relax the contiguous check for cuda kernels. (#2000)
* Relax the contiguous check for cuda kernels.

* Ensure contiguity for RNNs.

* Unrelated fix for segment anything.

* Better error message + allow concatenating empty slices.
2024-04-03 09:02:38 +02:00
08c049def3 Improve the handling of matmul with squeezed layouts. (#1998)
* Improve the handling of matmul with squeezed layouts.

* Fix for the cuda backend.

* Revert the temporary fix.
2024-04-02 23:17:05 +02:00
308ea070ed modify access for conv and op to be pub to allow external packages to have custom backends (#1986) 2024-04-01 17:44:49 +02:00
318cb82f16 Quantized cuda tweaks. (#1981)
* Quantized cuda tweaks.

* Add some safety checks.

* Factorize the dequantization bits.
2024-04-01 11:06:42 +02:00
c7557b65dc Switch the default to using the faster kernels. (#1978)
* Switch the default to using the faster kernels.

* Add the force-dmmv flag.
2024-04-01 10:00:11 +02:00
cd29c7ccd4 More ggml cuda kernels (#1977)
* Add more cuda kernels for quantized matmul.

* Add the vec-dot bits.

* Expose the quantized matmul-vec kernels.

* Also include the quantize-q8-1 kernel.

* Glue code for the q8-1 quantization.

* mm-vec product via q8-1 quantization.

* Add a test.

* Add a mm test.

* Get the test to return some sensible results.

* Also test dmmv.

* Fix the launch params.

* Allow for tweaking the force_dmmv parameter while it's experimental.
2024-04-01 00:15:48 +02:00
3144150b8d Move the tensor-tools binary in a separate crate. (#1969) 2024-03-30 15:49:37 +01:00
b190fd8592 Remove some unnecessary calls to contiguous. (#1968)
* Remove some unnecessary calls to contiguous.

* Slightly improved kv cache concatenation.
2024-03-30 13:22:00 +01:00
efe4a0c84b Add a print command to tensor-tools. (#1967)
* Add a print command to tensor-tools.

* Add some flags to tweak the formatting.
2024-03-30 11:34:33 +01:00
665da30487 Backend refactoring. (#1966)
* Backend refactoring.

* Metal tweaks.

* Move the cudnn module.
2024-03-29 23:02:11 +01:00
7ecbc6d50b fix minor typo (#1924) 2024-03-29 18:09:57 +01:00
b3484e7a5e Fix for the RWKV models. (#1955)
* Fix for the RWKV models.

* More general fix + revert the rwkv hack.

* Remove the old hack.
2024-03-28 10:17:38 +01:00
ab86cd37c8 Support i64 in index-select on metal. (#1951)
* Support i64 in index-select on metal.

* Add some testing of index-select for all dtypes.
2024-03-27 16:30:07 +01:00
a9abde5f93 More flexible matmul contiguity checks. (#1949)
* More flexible matmul contiguity checks.

* Also relax the checks on the metal side.
2024-03-27 10:59:05 +01:00
66f0a4eeea Another fix for squeezing. (#1943) 2024-03-26 17:05:26 +01:00
f5dfe883d7 Extend supported dtypes for metal (im2col & upsample_2d) (#1938)
* update im2col dtype implementations

* update dtypes for upsample
2024-03-26 06:48:56 +01:00
cd254074f3 Really unique identifier for metal device ids. (#1932)
* Really unique identifier for metal device ids.

* Same device.
2024-03-25 11:48:16 +01:00
fdfe8fd129 Preliminary support for inplace ops. (#1921)
* Preliminary support for inplace ops.

* Add a test.
2024-03-23 14:16:19 +01:00
cc856db9ce Backwards for ConvTranspose2D (#1910)
* add documentation  for nackprop

* add backwards for ConvTranspose2D

* add test python code to test
2024-03-23 07:05:55 +01:00
fee33b45c2 Add support for strided index-select on Metal (#1909)
* initial implementation

* use correct index, but still not breaking like it should have...

* fix test
2024-03-22 07:30:02 +01:00