Commit Graph

1769 Commits

SHA1 Message Date
b4cb982e49 Simplifying our internal cargo dependencies. (#1529) 2024-01-07 12:04:14 +01:00
6ebe043273 Merge branch 'main' into ivarflakstad/metal-prng 2024-01-07 11:52:03 +01:00
6bf52b9fdf Gaussian normal distribution of PRNG via Box-Muller transform 2024-01-07 11:39:46 +01:00
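The entry above mentions generating Gaussian samples via the Box-Muller transform. As a point of reference, here is a minimal CPU-side sketch of that transform in Rust; the function name and fixed inputs are illustrative, not the interface of the Metal kernel itself.

```rust
// Box-Muller: map two independent uniform samples in (0, 1] to two
// independent standard normal samples. The Metal kernel applies the same
// math per thread; everything here is just a plain-Rust illustration.
fn box_muller(u1: f32, u2: f32) -> (f32, f32) {
    let r = (-2.0 * u1.ln()).sqrt();
    let theta = 2.0 * std::f32::consts::PI * u2;
    (r * theta.cos(), r * theta.sin())
}

fn main() {
    // Fixed uniform inputs just to show the mapping; a PRNG would supply these.
    let (z0, z1) = box_muller(0.42, 0.17);
    println!("z0 = {z0}, z1 = {z1}");
}
```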
84250bf52f fix index_pos bug when kv cache is disabled. (#1517)
* fix index_pos bug when kv cache is disabled

* Tweak the fix.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-01-06 11:43:01 +01:00
8d1a57c9a0 chore: update flash attention kernels (#1518)
* chore: update flash attention kernels

* fmt

* remove unused kernels

* force f32

* correct stride
2024-01-05 18:28:55 +01:00
955e63c803 Implement hybrid Tausworthe + LCG pseudo-random number generator in metal 2024-01-05 13:27:59 +01:00
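The hybrid Tausworthe + LCG scheme named above is the classic GPU-friendly combination of three Tausworthe steps and one linear congruential step, XORed together. Below is a hedged CPU reference sketch in Rust using the commonly published constants; it illustrates the scheme and does not claim to mirror the Metal kernel's actual state layout.

```rust
// CPU reference sketch of a hybrid Tausworthe + LCG uniform generator.
// The three Tausworthe states should be seeded with values well above 128
// for the shifts below to behave; the fourth state drives the LCG.
struct HybridTaus {
    z: [u32; 4],
}

impl HybridTaus {
    fn taus_step(z: &mut u32, s1: u32, s2: u32, s3: u32, m: u32) -> u32 {
        let b = ((*z << s1) ^ *z) >> s2;
        *z = ((*z & m) << s3) ^ b;
        *z
    }

    fn lcg_step(z: &mut u32) -> u32 {
        *z = z.wrapping_mul(1_664_525).wrapping_add(1_013_904_223);
        *z
    }

    /// Returns a uniform sample in [0, 1).
    fn next_f32(&mut self) -> f32 {
        let bits = Self::taus_step(&mut self.z[0], 13, 19, 12, 0xFFFF_FFFE)
            ^ Self::taus_step(&mut self.z[1], 2, 25, 4, 0xFFFF_FFF8)
            ^ Self::taus_step(&mut self.z[2], 3, 11, 17, 0xFFFF_FFF0)
            ^ Self::lcg_step(&mut self.z[3]);
        bits as f32 * 2.328_306_4e-10 // 1 / 2^32
    }
}
```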
3a7304cb0d add link to gpt-from-scratch-rs (#1525) 2024-01-05 11:59:46 +01:00
fa3ea98ba9 Adding bfloat16 support for the cast kernels. (#1520) 2024-01-04 12:12:56 +01:00
135ae5f3eb Simplify the one-hot implementation, support arbitrary rank. (#1514)
* Simplify the one-hot implementation, support arbitrary rank.

* More cleanup.
2024-01-01 11:40:17 +01:00
41614b4a9b Add one-hot/cold encoding (#1489)
* add one-hot encoding

* one_hot: improve error handling, use generic to_vecN::<D>

Bails if the index value is equal to or greater than the depth value,
which would result in an out-of-bounds error.

A redundant check is added to ensure the index value does not exceed
the size of the one-hot matrix, which would also result in an
out-of-bounds error.

Bails if the index value is less than -1. An index value of -1 is not an
error: setting the on_value is simply skipped for that position. Only
values less than -1 are considered errors.

* one-hot: use two generics, one_hot::<I, O>, for input and output data types

Separating the input and output data types allows the input tensor
indices to use a different data type than the encoded output tensor.

For example, one_hot::<i64, u8>(...) will take an input tensor of i64 values
and encode the output tensor using u8 values.

The generic I::DTYPE must match the data type of the input indices, otherwise
the method will bail.

Additionally, this method adds an `allow_f64` option to enable the input indices
data type to be f64 values. f64 values are disabled by default.

TODO: the indices data type and the generic I data type are currently
not checked at compile time.

* one_hot: remove input generic, use indices dtype matching

This commit removes the to_f64() type cast and explicitly
matches the DType of the input tensor. Currently, only U8,
U32 and I64 are supported for input tensors.

The match arms on the dtype are verbose. It would be nice
to use a generic type with a WithDType trait bound to
pass to the to_vecN method and then return an inner value.

Open to suggestions for better approaches here to reduce
the match arm verbosity.

* one_hot: use flat_map iterator over dims instead of nested for loop

This commit replaces the nested for loops with a flat_map iterator over
the dimensions of the input tensor.

This commit also adds a test for a rank 3 input tensor.

* one_hot: use mandatory on/off-values, remove const msgs

This commit also updates doc tests, comments and test cases.

* Small cleanups.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-01-01 11:18:40 +01:00
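As a rough companion to the one_hot work above, here is a standalone Rust sketch of the behaviour those commit messages describe, written against plain vectors rather than candle tensors (the name and signature are illustrative, not the API that landed): each index expands to a row of length depth, -1 leaves its row at off_value, and out-of-range indices are errors.

```rust
// Standalone sketch of the described one-hot semantics on plain vectors.
// Every index becomes a row of length `depth`; -1 keeps the whole row at
// `off_value`; indices < -1 or >= depth are rejected, mirroring the bails.
fn one_hot(indices: &[i64], depth: usize, on_value: u8, off_value: u8) -> Result<Vec<u8>, String> {
    indices
        .iter()
        .map(|&idx| {
            if idx < -1 || idx >= depth as i64 {
                return Err(format!("index {idx} out of range for depth {depth}"));
            }
            let mut row = vec![off_value; depth];
            if idx >= 0 {
                row[idx as usize] = on_value;
            }
            Ok(row)
        })
        .collect::<Result<Vec<_>, _>>()
        // Flatten the per-index rows, mirroring the flat_map over the input dims.
        .map(|rows| rows.into_iter().flatten().collect())
}

fn main() {
    // [2, 0, -1] with depth 3 -> rows [0,0,1], [1,0,0], [0,0,0], flattened.
    println!("{:?}", one_hot(&[2, 0, -1], 3, 1u8, 0u8));
}
```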
03ce8caf40 Properly format the Stable Diffusion example run with params (#1511)
Move the --sd-version flag out of the prompt.
2024-01-01 11:13:35 +01:00
b0fe5e4453 Do not implement Module for BatchNorm. (#1513) 2024-01-01 10:13:13 +01:00
1fb2dd905c Add support for tiny-llama-1.1b. (#1512) 2023-12-31 12:18:25 +01:00
a0facd0e67 Small tweaks to batch-norm. (#1505) 2023-12-30 17:06:07 +01:00
4290b81244 [Breaking] Add training to batchnorm with exponential moving average (#1504)
* Add training to batchnorm with exponential moving average

* Add more checks to batch norm

* Resolve some review comments

* Add with_momentum variants of `new` methods

* Add check for range of momentum variable; update batch norm test

* Run cargo fmt

* Add back num_features parameter

* Format; tiny simplification
2023-12-30 16:42:08 +01:00
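For the batch-norm change above, the core of "training with an exponential moving average" is the running-statistics update sketched below. Which side of the blend gets `momentum` is an assumption made for illustration, not necessarily the exact convention candle adopted.

```rust
// Sketch of the exponential-moving-average update for batch-norm running
// statistics. `momentum` is expected to lie in [0, 1], matching the range
// check mentioned in the commit; the blend direction is an assumption.
struct RunningStats {
    mean: f32,
    var: f32,
    momentum: f32,
}

impl RunningStats {
    fn update(&mut self, batch_mean: f32, batch_var: f32) {
        assert!((0.0..=1.0).contains(&self.momentum), "momentum out of range");
        self.mean = (1.0 - self.momentum) * self.mean + self.momentum * batch_mean;
        self.var = (1.0 - self.momentum) * self.var + self.momentum * batch_var;
    }
}
```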
51e577a682 Add Policy Gradient to Reinforcement Learning examples (#1500)
* added policy_gradient, modified main, ddpg and README

* fixed typo in README

* removed unnecessary imports

* small refactor

* Use clap for picking up the subcommand to run.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2023-12-30 09:01:29 +01:00
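The policy-gradient example above optimizes a REINFORCE-style objective; the tiny sketch below shows the loss it reduces to, on plain slices rather than candle tensors (names are illustrative): each action's log-probability is weighted by the return that followed it, negated so that minimizing the loss increases the probability of well-rewarded actions.

```rust
// REINFORCE-style loss on plain slices: -sum(log_prob * return).
// A real example would compute discounted returns and use candle tensors.
fn policy_gradient_loss(log_probs: &[f32], returns: &[f32]) -> f32 {
    log_probs.iter().zip(returns).map(|(lp, r)| -lp * r).sum()
}
```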
0a245e6fa4 Metal: support unary abs (#1503)
* Metal: support unary abs

* cargo fmt
2023-12-30 00:00:12 +01:00
87d7f81b43 Metal: more u8/u32 (#1502)
* Adds more metal u8

* Metal: more u32
2023-12-29 23:56:21 +01:00
4373534d59 Metal: i64 basic support (#1495)
* Adds basic metal i64 support

* metal copy i64
2023-12-29 19:42:50 +01:00
f4a2787217 Merge pull request #1498 from huggingface/debugging_windows_ci
Fix CI
2023-12-29 12:33:50 +01:00
488e02a3f6 Merge pull request #1496 from bayedieng/unary
Implement urecip op for metal backend
2023-12-29 12:20:52 +01:00
adc95ca2bf Ignore skipped. 2023-12-29 12:15:57 +01:00
4907c63ea1 Ignore stop on remote forks. 2023-12-29 12:12:10 +01:00
d76ac20e0e Fix. 2023-12-29 12:06:38 +01:00
f5c98f22c7 Merge pull request #1491 from mimiquate/metal-errors
Improves metal's not implemented error messages
2023-12-29 12:03:40 +01:00
5b12fbb143 Trying to fix flakiness by making hub_2 and hub_3 serial tests (potential issue on mingw with mmap). 2023-12-29 11:13:33 +01:00
cc06ba2294 fix bad pattern matching and function name 2023-12-29 09:46:24 +00:00
a6bd0b47a5 Fix the CI. 2023-12-29 10:17:52 +01:00
b59b1b2bb6 remove generated png 2023-12-28 21:50:58 +00:00
3922b42c18 add urecip op to metal backend 2023-12-28 21:50:12 +00:00
1e442d4bb9 Fix lints for clippy 1.75. (#1494) 2023-12-28 20:26:20 +01:00
cd889c0f8a add config_amazon_mistral_lite (#1493)
Co-authored-by: Ubuntu <danielclough@users.noreply.github.com>
2023-12-28 19:59:58 +01:00
8e93e76a91 fixes error message 2023-12-28 15:03:05 -03:00
b3e838f3e2 cargo fmt 2023-12-28 14:07:34 -03:00
8bf892403a Improves metal's not implemented error messages 2023-12-28 11:04:06 -03:00
d35f0a1376 Bump the crate version to 0.3.3. (#1490) 2023-12-28 13:38:30 +01:00
65cb90bd40 Add some mention to SOLAR-10.7B in the readme. (#1487) 2023-12-27 15:25:39 +01:00
996a7f2e24 Rework the llama example config, add the solar model. (#1485) 2023-12-26 22:24:04 +01:00
3071ea6c3e Use the new hub helper function. (#1484) 2023-12-26 09:44:30 +01:00
37c539f2b7 Helper function to load sharded safetensors files (#1481)
* Fix the quantized mistral example.

* Add a helper function to load sharded safetensors weights.

* Use the sharded loader.
2023-12-25 21:49:21 +01:00
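Sharded checkpoints ship a model.safetensors.index.json whose weight_map points each tensor at the shard file that contains it, so a loading helper essentially collects those shard names and hands them to the weight loader. The sketch below is a hedged illustration of that step; it assumes the serde_json and anyhow crates and is not the exact helper that landed in candle.

```rust
use std::collections::HashSet;

// Collect the unique shard file names referenced by a safetensors index file.
// `index_json` is the text of model.safetensors.index.json; names are illustrative.
fn sharded_safetensors_files(index_json: &str) -> anyhow::Result<Vec<String>> {
    let index: serde_json::Value = serde_json::from_str(index_json)?;
    let weight_map = index["weight_map"]
        .as_object()
        .ok_or_else(|| anyhow::anyhow!("no weight_map in index file"))?;
    // Each tensor maps to the shard containing it; several tensors share a shard.
    let files: HashSet<String> = weight_map
        .values()
        .filter_map(|v| v.as_str().map(str::to_string))
        .collect();
    Ok(files.into_iter().collect())
}
```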
eae3a20d43 Merge pull request #1479 from huggingface/upsample_metal
Adding upsample_nearest_2d.
2023-12-25 14:25:53 +01:00
13a5d15ebc Adding upsample_nearest_2d. 2023-12-25 14:25:19 +01:00
1505d85276 Merge pull request #1461 from huggingface/metal-conv
Adding the convolutions (1d + 2d) to candle on metal.
2023-12-25 12:48:09 +01:00
95e18ef675 Fixing matmul for convolutions. 2023-12-25 12:29:34 +01:00
7135791dd5 Fix the quantized mistral example. (#1478) 2023-12-25 09:31:24 +01:00
88589d8815 Support mistral instruct v0.2. (#1475)
* Support mistral instruct v0.2.

* Use the safetensors model now that they are available.
2023-12-23 16:18:49 +01:00
5b35fd0fcf MMLU evaluation for Phi. (#1474)
* MMLU evaluation for Phi.

* Improve the evaluation.
2023-12-23 15:28:36 +01:00
ba1fae590e Validate the kernel size in pooling ops. (#1473)
* Validate the kernel size in pooling ops.

* Revert the changes to basics.
2023-12-23 11:19:22 +01:00
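The validation the entry above refers to is the usual sanity check on pooling parameters; a minimal sketch of that kind of check follows, with illustrative names rather than the exact code in the commit.

```rust
// Fail early when a pooling kernel cannot produce any output window:
// zero-sized kernels and kernels larger than the input are rejected.
fn validate_pool2d(kernel: (usize, usize), input_hw: (usize, usize)) -> Result<(), String> {
    let (kh, kw) = kernel;
    let (h, w) = input_hw;
    if kh == 0 || kw == 0 {
        return Err("pooling kernel size must be non-zero".to_string());
    }
    if kh > h || kw > w {
        return Err(format!("kernel {kh}x{kw} larger than input {h}x{w}"));
    }
    Ok(())
}
```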
78d982e1bd Fix for mamba 2.8b. (#1472) 2023-12-23 11:01:39 +01:00
d8b9a727fc Support different mamba models. (#1471) 2023-12-23 10:46:02 +01:00