Commit Graph

13 Commits

Author SHA1 Message Date
04a61a9c72 Add avg_pool2d metal implementation for the metal backend (#1869)
* implement metal avg pool 2d

* fixX

* add suggested precision workaround for the accumulator
2024-03-18 18:50:14 +01:00
754fa1e813 Add support for max_pool2d for Metal backend (#1863)
* first pass at implementation of maxpool2d

* Add definitions for other dtypes

* add tests for other dtypes

* Cosmetic tweaks + re-enable maxpool2d tests for metal.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-03-18 08:33:30 +01:00
ce9fbc3682 Optimize the cat operation on contiguous tensors (#1855)
* Add a specialized kernel for copy2d.

* Move the cat operations.

* Avoid transpositions in cat.

* Bugfix.

* Bugfix for the cuda kernel.

* Add a benchmark.

* Add more testing.

* Test fix.

* Faster kernel.

* Add the missing kernel.

* Tweak the test.

* Add a metal kernel.

* Fix for the metal kernel.

* Get the tests to pass on metal.

* Also use this opportunity to fix the metal kernel for ELU.

* Add some bf16 kernels.

* Clippy fixes.
2024-03-17 10:49:13 +01:00
7ec345c2eb Adding the test scaffolding. 2023-11-20 14:38:35 +01:00
2d3fcad267 Simplify usage of the pool functions. (#662)
* Simplify usage of the pool functions.

* Small tweak.

* Attempt at using apply to simplify the convnet definition.
2023-08-29 19:12:16 +01:00
5320aa6b7d Move the test-utils bits to a shared place. (#619) 2023-08-27 09:42:22 +01:00
afd965f77c More non square testing (#582)
* Add more non square testing.

* More testing.
2023-08-24 13:01:04 +01:00
c84883ecf2 Add a cuda kernel for upsampling. (#441)
* Add a cuda kernel for upsampling.

* Update for the latest tokenizers version.
2023-08-14 13:12:17 +01:00
a094dc503d Add a cuda kernel for avg-pool2d. (#440)
* Add a cuda kernel for avg-pool2d.

* Avoid running out of bounds.

* Finish wiring the avg pool kernel + add some testing.

* Support for max-pool + testing.
2023-08-14 12:32:05 +01:00
a325c1aa50 Upsample test + bugfix. (#399) 2023-08-10 21:02:35 +02:00
a5c5a893aa add max_pool2d (#371)
Co-authored-by: 赵理山 <ls@zhaolishandeMacBook-Air.local>
2023-08-09 18:05:26 +01:00
cd225bd3b1 More testing for avg-pool2d. (#366)
* More testing for avg-pool2d.

* Another fix.

* Add a max-pool test with non-divisible kernel sizes.
2023-08-09 16:12:23 +01:00
b80348d22f Bugfix for avg-pool + add some test. (#365) 2023-08-09 15:44:16 +01:00