Commit Graph

453 Commits

Author SHA1 Message Date
8f7973958c fix: fix index_select cuda kernel for src target dim different than ids dim when selecting dim > 0 (#1037)
* fix: fix index_select cuda kernel for src target dim different than ids dim when selecting dim > 0

* cargo fmt
2023-10-05 18:46:13 +01:00
c18a856e76 Add the rounding operators. (#1030)
* Add the rounding operators.

* Avoid tracking gradients for the rounding operations.

* Add some rounding tests.
2023-10-04 17:58:44 +01:00
11d3687cc6 Simd128 optimized q8k vecdot. (#1026) 2023-10-03 15:29:48 +01:00
dac73edb34 AVX optimized q8k vecdot. (#1024) 2023-10-03 12:10:58 +01:00
043cc25766 Fix for the index-select cuda setup. (#1022)
* Fix for index-select.

* Better fix + add some testing.
2023-10-03 10:21:46 +01:00
7670fe7d1f neon optimized q8k multiplication. (#1021)
* neon optimized q8k multiplication.

* Bugfixes.

* simdification.
2023-10-02 23:26:34 +01:00
cddfc3944c Add the q8k vec-dot multiplication. (#1019) 2023-10-02 21:53:34 +01:00
089fc3b584 Improve the quantized whisper setup. (#1018)
* Improve the quantized whisper setup.

* Fix the config file paths.

* Use the standard matmul where possible.
2023-10-02 17:17:46 +01:00
263a172202 Improve the testing of the optimized quantized vec-dot ops (#1016)
* Expose the unopt functions for testing.

* Better testing of the optimized quantized computations.
2023-10-02 09:50:43 +01:00
5130a7da32 Simd128 version of q6k vec-dot. (#1015)
* Add a specific function for the simd128 q6k vec-dot.

* Simdification.

* More simdification.
2023-10-01 19:44:12 +01:00
096dee7073 Bump the version to 0.3.0. (#1014)
* Bump the version to 0.3.0.

* Changelog update.
2023-10-01 13:51:57 +01:00
4e55aaa51f Simd128 version of the q2k-q8k vecdot product. (#1011)
* Sketch the simd128 version of q2k vecdot.

* Use a single accumulator.

* Simdify the q2k-q8k vecdot product.

* Cosmetic change.
2023-09-30 20:12:41 +01:00
deee7612da Quantized version of mistral. (#1009)
* Quantized version of mistral.

* Integrate the quantized mistral variant.

* Use the quantized weight files.

* Tweak the quantization command.

* Fix the dtype when computing the rotary embeddings.

* Update the readme with the quantized version.

* Fix the decoding of the remaining tokens.
2023-09-30 18:25:47 +01:00
fc59bc31bf fix: add missing gpu fill_* (#996) 2023-09-29 15:49:30 +01:00
01b92cd959 fixes slice_scatter dim type (#988) 2023-09-29 07:54:45 +01:00
25657804ef Simd128 q2k vecdot (#982)
* Sketch the simd128 version of q2k vecdot.

* Use a single accumulator.
2023-09-28 12:16:35 +01:00
5e1c595e00 Optimize the index-select cuda kernel. (#976) 2023-09-28 09:05:29 +01:00
9cb110c44c Sketch a simd128 optimized q4k vecdot. (#977)
* Sketch a simd128 optimized q4k vecdot.

* Simdify.

* More quantization optimizations.

* Again more simdification.

* Simdify the splitting loop.
2023-09-27 20:19:38 +01:00
667f01c173 Simd128 vec-dot for q4_0. (#974)
* Simd128 vec-dot for q4_0.

* Bugfix.

* Add wasm tests.

* Bugfix for the q40 vecdot.

* More quantization tests.
2023-09-27 14:15:30 +01:00
e59784e353 simd128 optimized q8_0 vecdot (#972)
* wasm/simd128 version of the quantized q8_0 vecdot.

* Add the missing conversion.
2023-09-27 11:03:20 +01:00
ce0a4e3a85 Use the gelu-erf activation. (#969) 2023-09-26 22:30:21 +01:00
4abc1ea34d Avoid some overflows on wasm32. (#968) 2023-09-26 11:15:38 +01:00
dc47224ab9 Override the default cudnn heuristics. (#957) 2023-09-25 10:31:53 +01:00
e32c89d90c Add the buffered safetensor wrapper. (#948) 2023-09-23 22:57:42 +01:00
890d069092 Self-contained safetensor wrappers (#946)
* Self-contained safetensor wrappers.

* Use the new safetensor container in varbuilders.
2023-09-23 20:39:52 +01:00
ccf352f3d1 Use yoke to provide a self-referential container for mmaped safetenso… (#939)
* Use yoke to provide a self-referential container for mmaped safetensor files.

* Add the new self-owned type for safetensor files without removing the previous version.

* Add routing.

* Add an initializer for the case of multiple files.
2023-09-23 15:43:11 +01:00
912a3d63b0 Use the proper block size for quantizing models. (#933)
* Use the proper block size for quantizing models.

* Use the proper dimension.
2023-09-22 21:36:56 +01:00
8601537e31 Add slice-scatter. (#927)
* Add slice-scatter.

* Add the op.

* Make transpose be a no-op when the dimensions are identical.

* Add the backprop.

* And add some gradient test.
2023-09-22 12:18:16 +01:00
3b557765e8 T5 quantized example (#922)
* Load gguf files for the quantized t5.

* Add the quantized t5 example.

* Allow for loading local files.

* Add some support for quantizing safetensor files.

* Transpose before quantizing.

* Quantized t5.

* Retrieve the weights from the hub.
2023-09-21 12:33:15 +01:00
2619c4307f Add a quantized version of the t5 model. (#921) 2023-09-21 11:13:39 +01:00
7b26e513f1 Add the erf function. (#917) 2023-09-21 06:19:10 +01:00
d7e48234d4 Add an erf based gelu op (#900)
* Erf based gelu.

* Add the erf backed gelu.

* Test the new gelu op (which is not gelu_new).
2023-09-19 19:54:28 +01:00
4f91c8e109 Improve the error message on shape mismatch for cat. (#897)
* Improve the error message on shape mismatch for cat.

* Cosmetic tweak.
2023-09-19 15:09:47 +01:00
7dd8e12472 Bump the crate versions to v0.2.3. (#886)
* Bump the crate version.

* Also update the python bindings.
2023-09-18 12:14:03 +01:00
635012d770 Do not backprop through argmin/argmax. (#865) 2023-09-15 22:15:40 +01:00
2746f2c4be DiffNeXt/unet (#859)
* DiffNeXt/unet

* Start adding the vae.

* VAE residual block.

* VAE forward pass.

* Add pixel shuffling.

* Actually use pixel shuffling.
2023-09-15 10:14:02 +01:00
130fe5a087 Add the upblocks. (#853) 2023-09-14 22:24:56 +01:00
d6447ad635 Tensor based indexing. (#842) 2023-09-14 07:47:07 +01:00
9a465e1b26 Add 1d upsampling. (#839)
* Add 1d upsampling.

* Add the interpolate functions.
2023-09-13 16:50:39 +01:00
b11a2a7b9d Move the constant to avoid some unused warning. (#837) 2023-09-13 11:56:53 +01:00
1c09164021 Add CANDLE_NVCC_CCBIN support for candle-kernels, and eliminate warning. (#836) 2023-09-13 11:39:22 +01:00
18d3c803a8 Scalar support in minimum/maximum. (#832)
* Scalar support in minimum/maximum.

* Add a clamp method to tensors.
2023-09-13 08:24:58 +01:00
2257f4d475 Bump the crate version + update the changelog. (#822) 2023-09-12 06:39:24 +01:00
871efc0307 Bugfix for the conv2d cpu kernel. (#820) 2023-09-11 23:11:27 +01:00
c5a058b169 Use the module trait in stable-diffusion. (#817) 2023-09-11 20:40:07 +01:00
dbd4561416 im2col version of the conv1d kernel. (#815)
* im2col version of the cuda conv1d kernel.

* im2col version of the conv1d cpu kernel.
2023-09-11 14:40:09 +01:00
70f38c2069 Proper error on unsupported dtypes when using gemm. (#813) 2023-09-11 12:10:51 +01:00
df712ecf64 Handle the case where the kernel is not contiguous in the cuda backend. (#809) 2023-09-11 09:48:31 +01:00
6fb665004c Enable im2col on the cpu side. (#805)
* Enable im2col on the cpu side.

* Hook im2col on the cpu backend.

* Use the kernel offset.

* Avoid an unnecessary copy.

* Handle non-contiguous kernels.

* Add a const to select the conv2d kernel.
2023-09-11 09:28:13 +01:00
1cd74129d4 Add Im2Col support on the gpu side. (#808)
* Add Im2Col support on the gpu side.

* Actually enable.
2023-09-11 08:52:33 +01:00