196765e995
Use the new rope kernel in mistral. ( #1937 )
...
* Use the new rope kernel in mistral.
* Compute the cos and sin with full precision.
* Bugfix.
2024-03-25 23:26:05 +01:00
60676780a9
Fix detail in new RoPE implementation ( #1935 )
2024-03-25 18:20:09 +01:00
d3a8d291d5
Avoid the attention mask where possible. ( #1933 )
2024-03-25 15:31:04 +01:00
cd254074f3
Really unique identifier for metal device ids. ( #1932 )
...
* Really unique identifier for metal device ids.
* Same device.
2024-03-25 11:48:16 +01:00
e7f8e72588
Contiguous variant of the rope kernel. ( #1929 )
...
* Contiguous variant of the rope kernel.
* Add the cuda kernel.
* Metal kernel.
2024-03-25 09:11:20 +01:00
1b98f84a2b
Fast kernels for rotary embeddings. ( #1928 )
...
* Fast kernels for rotary embeddings.
* Add a test for the fast CPU kernel.
* Rope cuda bindings.
* Cuda kernel.
* Metal kernel (part 1).
* Cuda kernels.
* Finish the metal kernel.
* Use the new kernels in the quantized example.
* Fix warning.
2024-03-24 22:48:52 +01:00
cf7d7fcf2f
Also avoid the mask in the llama example.
2024-03-24 19:04:32 +01:00
8c0db87992
Avoid using the attn mask when not necessary.
2024-03-24 18:55:56 +01:00
e2b4829531
Support more mistral models. ( #1927 )
...
* Support more mistral models.
* Use the appropriate rope parameter.
2024-03-24 08:04:04 +01:00
5e70821dd0
Allow for arbitrary temperature modifications.
2024-03-23 15:47:39 +01:00
a62a97340c
Add topk sampling. ( #1923 )
2024-03-23 15:26:09 +01:00
fdfe8fd129
Preliminary support for inplace ops. ( #1921 )
...
* Preliminary support for inplace ops.
* Add a test.
2024-03-23 14:16:19 +01:00
790037390c
Add cast_bf16_x/cast_x_bf16 when CUDA_ARCH<800 but CUDA_VERSION >= 11000 ( #1919 )
...
- This makes it possible to load bf16 models on T4 (sm75)
2024-03-23 13:44:10 +01:00
6f877592a7
Avoid broadcasting on the batch dimension for the attention mask. ( #1920 )
2024-03-23 13:08:53 +01:00
cc856db9ce
Backwards for ConvTranspose2D ( #1910 )
...
* add documentation for backprop
* add backwards for ConvTranspose2D
* add test python code to test
2024-03-23 07:05:55 +01:00
fc1fe5e45b
Support scatter/index_add with i64 indices for f16 ( #1915 )
2024-03-22 11:51:41 +01:00
32f567bac4
Fix loading the gguf files. ( #1913 )
2024-03-22 10:28:38 +01:00
fee33b45c2
Add support for strided index-select on Metal ( #1909 )
...
* initial implementation
* use correct index, but still not breaking like it should have...
* fix test
2024-03-22 07:30:02 +01:00
6708870e63
Add the alloc_uninit function. ( #1901 )
...
* Add the alloc_uninit function.
* Dummy metal fix.
* Lazy initialization.
2024-03-22 07:25:23 +01:00
a00e24d752
Improve the error message on overlong prompts. ( #1908 )
2024-03-21 21:08:07 +01:00
c07e4057ab
Fix for the llama model. ( #1906 )
2024-03-21 19:36:10 +01:00
c0bdd9c7a6
Use the fast RmsNorm in the quantized model. ( #1904 )
2024-03-21 18:49:35 +01:00
9563a5fee4
Add support for conv_transpose2d on Metal backend ( #1903 )
...
* add support for conv transpose 2d and add benchmark for float types
* update bench calculation
* enable testing all conv operations on metal
2024-03-21 18:08:45 +01:00
ec97c98e81
Async tensor copying. ( #1900 )
2024-03-21 13:09:42 +01:00
bb3ee48039
whisper readme ( #1899 )
2024-03-21 12:54:09 +01:00
0c11e055be
support distil-large-v3 ( #1898 )
2024-03-21 11:46:49 +01:00
18036c6ccb
Update the image crate + use the re-exported version. ( #1893 )
...
* Update the image crate + use the re-exported version.
* Update to using ab_glyph.
2024-03-21 10:56:41 +01:00
0fddec762e
RmsNorm kernel for metal. ( #1895 )
...
* RmsNorm kernel for metal.
* Wrapper for the metal kernel.
* Get the ops to actually work.
* Fix, get the tests to pass.
2024-03-21 09:48:56 +01:00
74b7f59261
Prepare for the custom-op extension. ( #1892 )
2024-03-21 07:02:20 +01:00
af7f8b87d3
Custom op for RmsNorm ( #1890 )
...
* Trying out a custom RmsNorm cuda kernel.
* CPU implementation for rms-norm.
* Cuda wrappers.
* Add some validation.
* Add some testing.
* More testing.
2024-03-21 06:36:28 +01:00
b219903d0f
Cuda backend optimization ( #1886 )
...
* Attempt at making the kernel faster.
* Also adapt the cast kernels.
* Also apply to binary ops.
2024-03-20 18:32:55 +01:00
469635a3eb
Minor cleanup. ( #1885 )
2024-03-20 14:38:27 +01:00
455c42aa72
Avoid copying the data on squeeze and unsqueeze. ( #1884 )
...
* Avoid copying the data on squeeze and unsqueeze.
* Fix the quantized llama example.
* Unrelated fix for the quantized stable-lm example on cuda.
* Fix for mamba on cuda (unrelated to the PR).
2024-03-20 13:04:36 +01:00
2a8679509e
Add support for conv_transpose1d for metal backend ( #1874 )
...
* first attempt
* progress
* integrate into metal backend
* finish and get test passing
* add other dtype support
* update transpose1d dtypes supported
2024-03-19 08:46:58 +01:00
143c481c20
Expose candle gather op in pyo3. ( #1870 )
2024-03-18 21:54:15 +01:00
f115895b9e
Apply rustfmt. ( #1873 )
2024-03-18 21:43:31 +01:00
90fc82211f
Use a common with_tracing::RmsNorm in a few models. ( #1871 )
...
* Add RmsNorm with tracing.
* Use with_tracing::RmsNorm in some models.
2024-03-18 21:40:06 +01:00
6a966cf9e0
Add a DQN example to the reinforcement-learning section ( #1872 )
2024-03-18 21:22:53 +01:00
04a61a9c72
Add avg_pool2d metal implementation for the metal backend ( #1869 )
...
* implement metal avg pool 2d
* fix
* add suggested precision workaround for the accumulator
2024-03-18 18:50:14 +01:00
58605252e8
Microphone support for the encodec example. ( #1866 )
2024-03-18 11:19:46 +01:00
d365ef32d9
Improve the encodec example: handle resampling. ( #1865 )
...
* Improve the encodec example: handle resampling.
* Play the audio directly.
2024-03-18 10:09:40 +01:00
754fa1e813
Add support for max_pool2d for Metal backend ( #1863 )
...
* first pass at implementation of maxpool2d
* Add definitions for other dtypes
* add tests for other dtypes
* Cosmetic tweaks + re-enable maxpool2d tests for metal.
---------
Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-03-18 08:33:30 +01:00
184105792f
add test for index add and add missing match statements ( #1862 )
2024-03-17 22:19:12 +01:00
a15f859ab4
Fix for the encodec example. ( #1861 )
2024-03-17 21:15:12 +01:00
e316cb6997
add support for casting between all datatypes ( #1860 )
2024-03-17 20:55:11 +01:00
ce9fbc3682
Optimize the cat operation on contiguous tensors ( #1855 )
...
* Add a specialized kernel for copy2d.
* Move the cat operations.
* Avoid transpositions in cat.
* Bugfix.
* Bugfix for the cuda kernel.
* Add a benchmark.
* Add more testing.
* Test fix.
* Faster kernel.
* Add the missing kernel.
* Tweak the test.
* Add a metal kernel.
* Fix for the metal kernel.
* Get the tests to pass on metal.
* Also use this opportunity to fix the metal kernel for ELU.
* Add some bf16 kernels.
* Clippy fixes.
2024-03-17 10:49:13 +01:00
db8b24ae92
Add support for index u8/i64 and input f16/bf16 scatter-add on metal ( #1849 )
...
* add support and tests for scatter add on metal
* add support for all datatypes
2024-03-17 08:09:43 +01:00
74bf6994b1
Move the image tensor to the appropriate device. ( #1856 )
2024-03-16 22:25:46 +01:00
cdc4c172c4
Implement the error trait for DTypeParseError. ( #1852 )
2024-03-15 08:37:27 +01:00
e1f9c3776d
StableLM-2 models were updated to use GPT-2 tokenization. ( #1847 )
2024-03-14 21:01:36 +01:00