8de0ce6cba
Add more QMMV cuda kernels. ( #2077 )
* Add more QMMV cuda kernels.
* Enable the new kernels.
* Adapt the testing.
2024-04-18 08:36:43 +02:00
2817643db9
Add the mmv kernels for small batch sizes. ( #2075 )
* Add the mmv kernels for smaller sizes.
* Support more mmv kernels.
* Use the new kernels.
* Fix the call.
* Silly fix.
* Improve the testing.
* Fix for dmmv.
* Add another dedicated test for the batching mmv.
2024-04-16 21:30:51 +02:00
f135b7963d
Fix for the batch dim in the quantized matmul example. ( #2073 )
* Fix for the batch dim in the quantized matmul example.
* Enable more tests on cuda.
* Add a test for qmm with a batch.
* Fix the zeros-dim test on metal.
2024-04-15 20:00:28 +02:00
8ad822a983
Add a function to clear the KV cache in falcon. ( #2066 )
* Add a function to clear the KV cache in falcon.
* Clippy.
2024-04-15 09:29:25 +02:00
e198bb0816
Handle zero dims in some simple operations. ( #2064 )
* Handle zero dims in some simple operations.
* Handle zero-dims in matmul.
* More testing.
2024-04-15 09:18:54 +02:00
f7d5bf5b97
Faster kernels for quantized matmul on cuda ( #2060 )
* Hook the quantized matmul cuda kernels.
* Add a (currently broken) test.
* Kernel fixes.
* Fix by transposing the rhs matrix.
* Add the q4-1 kernels.
* Proper block sizes.
* More details in the tests.
2024-04-15 08:32:47 +02:00
c449f65b12
Expose the synchronize function on the generic device. ( #2062 )
2024-04-14 23:02:03 +02:00
db7dbf3071
Add missing bfloat unary strided kernels and fix typo ( #2058 )
2024-04-14 20:01:13 +02:00
53e5380bf6
Add a synchronize method to devices. ( #2055 )
* Add a synchronize method to devices.
* Metal version.
2024-04-14 16:32:55 +02:00
4c88c3ce06
Add benchmarks for qmatmul operations ( #2048 )
* Add qmatmul bench
* add all dtypes
2024-04-13 12:30:14 +02:00
a4d5a414e3
Support gather on bf16 for metal. ( #2035 )
2024-04-10 12:49:25 +02:00
718671a0d5
Use BufferOffset in metal backend ops. ( #2029 )
* Use BufferOffset in the metal backend.
* More BufferOffset usage.
* Use in where-cond.
2024-04-08 09:37:25 +02:00
c5fe4a7f89
Rework the buffer offset logic for metal kernels ( #2028 )
* Move the metal kernels utils in a separate module.
* Use the BufferOffset for unary ops.
* Fix clippy lints.
* Use the new BufferOffset.
* Adapt the binary ops.
* Affine.
* More ops (powf, elu, cast).
2024-04-07 22:37:53 +02:00
9fd52b3b71
Handle the batch dimension in quantized MMV on metal. ( #2022 )
2024-04-06 20:02:24 +02:00
ab892274d1
first commit ( #2018 )
2024-04-05 15:20:28 +02:00
c5626b8271
Add support for "sign" on tensors ( #2012 )
* add the sign unary operator
* remove unneeded import
* remove unneeded import
* undo formatting
* undo formatting
* remove unnecessary redefinition
* allow gradient to flow through for sign and round
* fix cpu ops to ensure that negzero and positive zero are handled properly
* clippy fixes
* Properly avoid gradient tracking.
* Use a branchless version.
---------
Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-04-04 22:32:47 +02:00
e6a5b82ba6
Fix the matmul layout for accelerate & mkl. ( #2011 )
* Fix the matmul layout for accelerate & mkl.
* Reduce the required precision for pow (because of accelerate).
* Also fix the gelu f16 test.
2024-04-04 19:18:03 +02:00
5aebe53dd2
update dtypes checks for several metal operations ( #2010 )
2024-04-04 18:39:06 +02:00
30b145150f
Optimize the gelu f16 opt. ( #2008 )
* Optimize the gelu f16 opt.
* And add a test.
2024-04-04 16:28:23 +02:00
8967c46563
Split the cuda error file. ( #2003 )
2024-04-04 08:27:23 +02:00
318d143224
Relax the contiguous check for cuda kernels. ( #2000 )
* Relax the contiguous check for cuda kernels.
* Ensure contiguity for RNNs.
* Unrelated fix for segment anything.
* Better error message + allow concatenating empty slices.
2024-04-03 09:02:38 +02:00
08c049def3
Improve the handling of matmul with squeezed layouts. ( #1998 )
* Improve the handling of matmul with squeezed layouts.
* Fix for the cuda backend.
* Revert the temporary fix.
2024-04-02 23:17:05 +02:00
308ea070ed
Modify access for conv and op to be pub to allow external packages to have custom backends ( #1986 )
2024-04-01 17:44:49 +02:00
318cb82f16
Quantized cuda tweaks. ( #1981 )
* Quantized cuda tweaks.
* Add some safety checks.
* Factorize the dequantization bits.
2024-04-01 11:06:42 +02:00
c7557b65dc
Switch the default to using the faster kernels. ( #1978 )
* Switch the default to using the faster kernels.
* Add the force-dmmv flag.
2024-04-01 10:00:11 +02:00
cd29c7ccd4
More ggml cuda kernels ( #1977 )
* Add more cuda kernels for quantized matmul.
* Add the vec-dot bits.
* Expose the quantized matmul-vec kernels.
* Also include the quantize-q8-1 kernel.
* Glue code for the q8-1 quantization.
* mm-vec product via q8-1 quantization.
* Add a test.
* Add a mm test.
* Get the test to return some sensible results.
* Also test dmmv.
* Fix the launch params.
* Allow for tweaking the force_dmmv parameter while it's experimental.
2024-04-01 00:15:48 +02:00
3144150b8d
Move the tensor-tools binary in a separate crate. ( #1969 )
2024-03-30 15:49:37 +01:00
b190fd8592
Remove some unnecessary calls to contiguous. ( #1968 )
* Remove some unnecessary calls to contiguous.
* Slightly improved kv cache concatenation.
2024-03-30 13:22:00 +01:00
efe4a0c84b
Add a print command to tensor-tools. ( #1967 )
* Add a print command to tensor-tools.
* Add some flags to tweak the formatting.
2024-03-30 11:34:33 +01:00
665da30487
Backend refactoring. ( #1966 )
* Backend refactoring.
* Metal tweaks.
* Move the cudnn module.
2024-03-29 23:02:11 +01:00
7ecbc6d50b
fix minor typo ( #1924 )
2024-03-29 18:09:57 +01:00
b3484e7a5e
Fix for the RWKV models. ( #1955 )
* Fix for the RWKV models.
* More general fix + revert the rwkv hack.
* Remove the old hack.
2024-03-28 10:17:38 +01:00
ab86cd37c8
Support i64 in index-select on metal. ( #1951 )
* Support i64 in index-select on metal.
* Add some testing of index-select for all dtypes.
2024-03-27 16:30:07 +01:00
a9abde5f93
More flexible matmul contiguity checks. ( #1949 )
* More flexible matmul contiguity checks.
* Also relax the checks on the metal side.
2024-03-27 10:59:05 +01:00
66f0a4eeea
Another fix for squeezing. ( #1943 )
2024-03-26 17:05:26 +01:00
f5dfe883d7
Extend supported dtypes for metal (im2col & upsample_2d) ( #1938 )
* update im2col dtype implementations
* update dtypes for upsample
2024-03-26 06:48:56 +01:00
cd254074f3
Really unique identifier for metal device ids. ( #1932 )
* Really unique identifier for metal device ids.
* Same device.
2024-03-25 11:48:16 +01:00
fdfe8fd129
Preliminary support for inplace ops. ( #1921 )
* Preliminary support for inplace ops.
* Add a test.
2024-03-23 14:16:19 +01:00
cc856db9ce
Backwards for ConvTranspose2D ( #1910 )
* add documentation for backprop
* add backwards for ConvTranspose2D
* add test python code to test
2024-03-23 07:05:55 +01:00
fee33b45c2
Add support for strided index-select on Metal ( #1909 )
* initial implementation
* use correct index, but still not breaking like it should have...
* fix test
2024-03-22 07:30:02 +01:00
6708870e63
Add the alloc_uninit function. ( #1901 )
* Add the alloc_uninit function.
* Dummy metal fix.
* Lazy initialization.
2024-03-22 07:25:23 +01:00
9563a5fee4
Add support for conv_transpose2d on Metal backend ( #1903 )
* add support for conv transpose 2d and add benchmark for float types
* update bench calculation
* enable testing all conv operations on metal
2024-03-21 18:08:45 +01:00
ec97c98e81
Async tensor copying. ( #1900 )
2024-03-21 13:09:42 +01:00
74b7f59261
Prepare for the custom-op extension. ( #1892 )
2024-03-21 07:02:20 +01:00
b219903d0f
Cuda backend optimization ( #1886 )
* Attempt at making the kernel faster.
* Also adapt the cast kernels.
* Also apply to binary ops.
2024-03-20 18:32:55 +01:00
469635a3eb
Minor cleanup. ( #1885 )
2024-03-20 14:38:27 +01:00
455c42aa72
Avoid copying the data on squeeze and unsqueeze. ( #1884 )
* Avoid copying the data on squeeze and unsqueeze.
* Fix the quantized llama example.
* Unrelated fix for the quantized stable-lm example on cuda.
* Fix for mamba on cuda (unrelated to the PR).
2024-03-20 13:04:36 +01:00
2a8679509e
Add support for conv_transpose1d for metal backend ( #1874 )
* first attempt
* progress
* integrate into metal backend
* finish and get test passing
* add other dtype support
* update transpose1d dtypes supported
2024-03-19 08:46:58 +01:00
04a61a9c72
Add avg_pool2d metal implementation for the metal backend ( #1869 )
* implement metal avg pool 2d
* fix
* add suggested precision workaround for the accumulator
2024-03-18 18:50:14 +01:00
754fa1e813
Add support for max_pool2d for Metal backend ( #1863 )
* first pass at implementation of maxpool2d
* Add definitions for other dtypes
* add tests for other dtypes
* Cosmetic tweaks + re-enable maxpool2d tests for metal.
---------
Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-03-18 08:33:30 +01:00