* Hook the quantized matmul cuda kernels.
* Add a (currently broken) test.
* Kernel fixes.
* Fix by transposing the rhs matrix.
* Add the q4-1 kernels.
* Proper block sizes.
* More details in the tests.
* Move the metal kernels utils in a separate module.
* Use the BufferOffset for unary ops.
* Fix clippy lints.
* Use the new BufferOffset.
* Adapt the binary ops.
* Affine.
* More ops (powf, elu, cast).
* add the sign unary operator
* remove uneeded import
* remove uneeded import
* undo formatting
* undo formatting
* remove unnecessary redefintion
* allow gradient to flow through for sign and round
* fix cpu ops to ensure that negzero and positive zero are handled properly
* clippy fixes
* Properly avoid gradient tracking.
* Use a branchless version.
---------
Co-authored-by: laurent <laurent.mazare@gmail.com>
* Add more cuda kernels for quantized matmul.
* Add the vec-dot bits.
* Expose the quantized matmul-vec kernels.
* Also include the quantize-q8-1 kernel.
* Glue code for the q8-1 quantization.
* mm-vec product via q8-1 quantization.
* Add a test.
* Add a mm test.
* Get the test to return some sensible results.
* Also test dmmv.
* Fix the launch params.
* Allow for tweaking the force_dmmv parameter while it's experimental.
* Avoid copying the data on squeeze and unsqueeze.
* Fix the quantized llama example.
* Unrelated fix for the quantized stable-lm example on cuda.
* Fix for mamba on cuda (unrelated to the PR).
* first attempt
* progress
* integrate into metal backend
* finish and get test passing
* add other dtype support
* update transpose1d dtypes supported
* first pass at implementation of maxpool2d
* Add definitions for other dtypes
* add tests for other dtypes
* Cosmetic tweaks + re-enable maxpool2d tests for metal.
---------
Co-authored-by: Laurent <laurent.mazare@gmail.com>
* Add a specialized kernel for copy2d.
* Move the cat operations.
* Avoid transpositions in cat.
* Bugfix.
* Bugfix for the cuda kernel.
* Add a benchmark.
* Add more testing.
* Test fix.
* Faster kernel.
* Add the missing kernel.
* Tweak the test.
* Add a metal kernel.
* Fix for the metal kernel.
* Get the tests to pass on metal.
* Also use this opportunity to fix the metal kernel for ELU.
* Add some bf16 kernels.
* Clippy fixes.