Compare commits

..

221 Commits

Author SHA1 Message Date
69c1fb1ee8 Add a benchmark for the matmul slowness. 2023-10-11 15:49:42 +02:00
c55ebaf477 Use full tensors for zeros and ones. 2023-10-11 08:50:43 +02:00
4c91dd2ff4 Only optimize float tensors. 2023-10-10 09:45:49 +02:00
bc3351bce4 Tracing for StableLM and quantized StableLM. (#1068) 2023-10-10 08:09:25 +02:00
b34d7f0248 Remove some unusued bits. (#1067) 2023-10-09 19:49:57 +01:00
4d04ac83c7 Override the repo for SDXL f16 vae weights. (#1064)
* Override the repo for SDXL f16 vae weights.

* Slightly simpler change.
2023-10-09 06:52:28 +01:00
392fe02fba Move the common quantized-nn code to a shared module. (#1063) 2023-10-09 06:22:22 +01:00
59ab6d7832 Quantized version of StableLM. (#1058)
* Quantized version of StableLM.

* Adapt the stable-lm example to support quantizsed.

* Use some separate hub repo.

* Another repo name tweak.
2023-10-08 15:42:38 +01:00
783735cf22 Use softmax-last-dim where possible. (#1057) 2023-10-08 13:16:42 +01:00
9abeddd750 Make the cuda rng seedable. (#1056) 2023-10-08 09:32:36 +01:00
2e5fb0b251 Do not use the kv-cache on external key-value states. (#1054) 2023-10-07 22:37:19 +01:00
823fe23f9b Add flash-attn support for stable-lm. (#1052) 2023-10-07 21:12:54 +01:00
d833527fda Use candle_nn::LSTM in encodec. (#1051)
* Use candle_nn::LSTM in encodec.

* More Encodec implementation.

* Decoder implementation.
2023-10-07 19:43:06 +01:00
a4967600d0 More general seq forward functions for RNNs. (#1050) 2023-10-07 15:08:01 +01:00
aa53368aeb Better control on the optional dequantization in QMatMul (#1049)
* Cosmetic change to the quantized whisper model.

* Fix the dequantization.

* Add the dequantize all variable.
2023-10-07 10:16:18 +01:00
955e00b2e8 Add to the readmes for stable-lm. (#1047) 2023-10-06 21:26:04 +01:00
d5f7267087 Add the stable-lm example. (#1046)
* Add the stable-lm example.

* Get stable-lm to generate some proper text.
2023-10-06 19:20:35 +01:00
904bbdae65 Make the Python Wrapper more Hackable and simplify Quantization (#1010)
* Some first `Module` implementations

* Add `state_dict` and `load_state_dict` functionality

* Move modules around and create `candle.nn.Linear`

* Add `nn.Embedding` and `nn.LayerNorm`

* Add BERT implementation

* Batch q-matmul

* Automatically dequantize `QTensors` if a `Tensor` is expected

* Add Module `.to()`, `.cuda()`, `cpu()` and `.type()` functionality

* Unittests for `Module`, `Tensor` and `candle.utils`

* Add `pytorch` like slicing to `Tensor`

* Cleanup and BERT fixes

* `black` formatting + unit-test for `nn.Linear`

* Refactor slicing implementation
2023-10-06 19:01:07 +01:00
b0442eff8a Sketch the stable-lm model. (#1045) 2023-10-06 18:19:06 +01:00
4631c48273 Remove some todos. (#1042) 2023-10-05 22:42:20 +01:00
716883e9b0 Add the clamping for stable-diffusion. (#1041) 2023-10-05 22:20:39 +01:00
47c25a567b feat: [SAM] able to download the result as png (#1035)
* feat: able to download the result as png

* feat: update function and wording
2023-10-05 22:14:47 +01:00
7f7d95e2c3 Add the round-to function. (#1039) 2023-10-05 20:28:09 +01:00
f47bd9bab5 Delete invalid comment (#1038) 2023-10-05 19:28:08 +01:00
8f7973958c fix: fix index_select cuda kernel for src target dim different than ids dim when selecting dim > 0 (#1037)
* fix: fix index_select cuda kernel for src target dim different than ids dim when selecting dim > 0

* cargo fmt
2023-10-05 18:46:13 +01:00
f0c619a4af Use AsRef<str> for set_one. (#1033) 2023-10-05 06:05:44 +01:00
b86ac0c507 Quant t5: Add coedit model to wasm demo and readme (#1031) 2023-10-04 20:57:33 +01:00
27e70a5093 Whisper quantized wasm (#1028)
* [Whisper] Update to use quantized model

* [whisper] add language detection

* [whisper] change assets location

* [whisper] adapt js example with quantized models

* [whisper] better task parsing

* [whisper] minor fixes
2023-10-04 20:22:57 +01:00
c18a856e76 Add the rounding operators. (#1030)
* Add the rounding operators.

* Avoid tracking gradients for the rounding operations.

* Add some rounding tests.
2023-10-04 17:58:44 +01:00
3349c89252 Add quantized t5 args for weight and config (#1029) 2023-10-04 17:02:49 +01:00
11d3687cc6 Simd128 optimized q8k vecdot. (#1026) 2023-10-03 15:29:48 +01:00
dac73edb34 AVX optimized q8k vecdot. (#1024) 2023-10-03 12:10:58 +01:00
b4da19d1be Merge pull request #1023 from evgenyigumnov/simlified-book-polish
small misspeling and polish fix
2023-10-03 12:29:41 +02:00
ff513314fc small misspeling and polish fix 2023-10-03 15:47:04 +06:00
043cc25766 Fix for the index-select cuda setup. (#1022)
* Fix for index-select.

* Better fix + add some testing.
2023-10-03 10:21:46 +01:00
7b06872f90 Merge pull request #926 from evgenyigumnov/book-trainin-simplified
Book train simlified example
2023-10-03 10:41:30 +02:00
65825e7240 [SAM] Add undo button and background point mode (#1020)
* [SAM] Add undo button and background point mode

* [SAM] remove pts on near clicks

* [SAM] check shiftKey toggle point mode

* [SAM] clear points when clearing image
2023-10-02 23:33:46 +01:00
7670fe7d1f neon optimized q8k multiplication. (#1021)
* neon optimized q8k multiplication.

* Bugfixes.

* simdification.
2023-10-02 23:26:34 +01:00
cddfc3944c Add the q8k vec-dot multiplication. (#1019) 2023-10-02 21:53:34 +01:00
089fc3b584 Improve the quantized whisper setup. (#1018)
* Improve the quantized whisper setup.

* Fix the config file paths.

* Use the standard matmul where possible.
2023-10-02 17:17:46 +01:00
e04c789230 Add a quantized variant of whisper (#1017)
* Add the quantized-whisper model.

* Quantized the whisper model.

* Adapt the whisper example to handle quantization.

* Add the quantized flag.

* Load the proper weights.
2023-10-02 14:59:53 +01:00
263a172202 Improve the testing of the optimized quantized vec-dot ops (#1016)
* Expose the unopt functions for testing.

* Better testing of the optimized quantized computations.
2023-10-02 09:50:43 +01:00
638ccf9f46 Fix include code. 2023-10-02 10:22:44 +02:00
0baf5a1e19 Fixed PR warnings. 2023-10-02 10:15:10 +02:00
5130a7da32 Simd128 version of q6k vec-dot. (#1015)
* Add a specific function for the simd128 q6k vec-dot.

* Simdification.

* More simdification.
2023-10-01 19:44:12 +01:00
41143db1af [segment-anything] add multi point logic for demo site (#1002)
* [segment-anything] add multi point logic for demo site

* [segment-anything] remove libs and update functions
2023-10-01 18:25:22 +01:00
096dee7073 Bump the version to 0.3.0. (#1014)
* Bump the version to 0.3.0.

* Changelog update.
2023-10-01 13:51:57 +01:00
f6054e9d60 Fix the prompt for mistral when using instruct/interactive mode. (#1013) 2023-10-01 06:44:30 +01:00
328167ec04 Integrate TheBloke quantized mistral weights. (#1012) 2023-09-30 22:39:42 +01:00
4e55aaa51f Simd128 version of the q2k-q8k vecdot product. (#1011)
* Sketch the simd128 version of q2k vecdot.

* Use a single accumulator.

* Simdify the q2k-q8k vecdot product.

* Cosmetic change.
2023-09-30 20:12:41 +01:00
deee7612da Quantized version of mistral. (#1009)
* Quantized version of mistral.

* Integrate the quantized mistral variant.

* Use the quantized weight files.

* Tweak the quantization command.

* Fix the dtype when computing the rotary embeddings.

* Update the readme with the quantized version.

* Fix the decoding of the remaining tokens.
2023-09-30 18:25:47 +01:00
06207332bc Streaming mode for reporting the generated tokens (#1007)
* Token streaming.

* Use the token output stream.

* Flush the output.

* Ensure that the last characters get reported.
2023-09-30 15:04:11 +01:00
4021272875 Use flash-attn for mistral. (#1004) 2023-09-30 12:15:10 +01:00
87e3a4e175 Mistral: exit on eos token. (#1001)
* Mistral: exit on eos token.

* Print the proper stats.

* Also add a short flag.
2023-09-30 07:07:06 +01:00
6203ced495 Add negative prompts to segment-anything. (#1000) 2023-09-30 06:17:42 +01:00
34842fb234 [segment-anything] Print IOU values to help with debugging (#999) 2023-09-30 05:44:42 +01:00
d188d6a764 Fix the multiple points case for sam. (#998) 2023-09-29 22:39:43 +02:00
0ac2db577b Add an entry about WSL slowness to the faq. (#997) 2023-09-29 17:04:52 +01:00
fc59bc31bf fix: add missing gpu fill_* (#996) 2023-09-29 15:49:30 +01:00
03348e2e6f Update mistral README.md (#995) 2023-09-29 12:24:32 +01:00
49fa184a35 Mistral readme (#994)
* Mistral: print the generated text.

* Add mistral to the readmes.
2023-09-29 11:50:50 +01:00
6f17ef82be Mistral: print the generated text. (#992) 2023-09-29 10:56:11 +01:00
01b92cd959 fixes slice_scatter dim type (#988) 2023-09-29 07:54:45 +01:00
53510ce427 Use a silu activation in mistral. (#991) 2023-09-29 07:06:54 +01:00
23b3576c47 Add the sliding window. (#986) 2023-09-28 17:26:33 +01:00
716ab2ccdc Mistral gpu fix (#985)
* Add the mistral example.

* Use the two model files.

* Adjust the dtype.

* Tweak the weight paths.

* Remove the end of text token.

* Get the mistral model to generate some text.

* Fix when running on the gpu.

* More gpu fixes.
2023-09-28 16:38:13 +01:00
ada8851a23 Add the mistral example. (#984)
* Add the mistral example.

* Use the two model files.

* Adjust the dtype.

* Tweak the weight paths.

* Remove the end of text token.

* Get the mistral model to generate some text.
2023-09-28 16:19:18 +01:00
c05a348e36 Add the Mistral 7b model (#983)
* Start sketching the mistral 7b model.

* Add the kv cache.

* Add the decoder layer.

* Add the mistral model.

* Rotary embeddings.

* Add the attention mask.
2023-09-28 14:29:41 +01:00
25657804ef Simd128 q2k vecdot (#982)
* Sketch the simd128 version of q2k vecdot.

* Use a single accumulator.
2023-09-28 12:16:35 +01:00
5e1c595e00 Optimize the index-select cuda kernel. (#976) 2023-09-28 09:05:29 +01:00
8a49e01b9d Add the remaining quantized tests to the wasm suite. (#980) 2023-09-28 08:42:56 +01:00
9cb110c44c Sketch a simd128 optimized q4k vecdot. (#977)
* Sketch a simd128 optimized q4k vecdot.

* Simdify.

* More quantization optimizations.

* Again more simdification.

* Simdify the splitting loop.
2023-09-27 20:19:38 +01:00
667f01c173 Simd128 vec-dot for q4_0. (#974)
* Simd128 vec-dot for q4_0.

* Bugfix.

* Add wasm tests.

* Bugfix for the q40 vecdot.

* More quantization tests.
2023-09-27 14:15:30 +01:00
e59784e353 simd128 optimized q8_0 vecdot (#972)
* wasm/simd128 version of the quantized q8_0 vecdot.

* Add the missing conversion.
2023-09-27 11:03:20 +01:00
29bd6b2979 Phi 1.5 wasm module (#966)
* add phi wasm module

* replace input with textarea

* trim input prompt

* stop on <|endoftext|>

* formatting

* clean up

* add blurb, and syntax highlighting

* add phi-v1.5 wasm

* add note

* hide Options on details

* add first token to generated text

* whitespaces for new line

* fix: abort -> aborted
2023-09-27 06:07:11 +01:00
9571b200c9 fix firstToken, minor ui changes (#971) 2023-09-27 06:01:59 +01:00
ce0a4e3a85 Use the gelu-erf activation. (#969) 2023-09-26 22:30:21 +01:00
4abc1ea34d Avoid some overflows on wasm32. (#968) 2023-09-26 11:15:38 +01:00
2dd43d6cdd add eos token to phi example (#965)
* add eos token to phi example

* rustfmt + get the token directly.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2023-09-26 09:21:22 +01:00
1fcac4afed Expose a function to clear the KV cache on mixformers. (#964) 2023-09-26 05:41:07 +01:00
a084f65f9a fix rep penalty min value (#963) 2023-09-26 05:23:50 +01:00
c798184c2b Configurable layer idx for the lstm layer. (#962) 2023-09-25 21:31:14 +01:00
c78a294323 Add some repeat penalty to the phi example. (#961) 2023-09-25 20:53:30 +01:00
a36d883254 Use a single flag for the point argument. (#958) 2023-09-25 12:53:24 +01:00
7f2bbcf746 [segment-anything] Support multi-point as the prompt input (#945)
* [sam] Support multi-point prompts

* [segment-anything] Pass points by reference

* [segment-anything] Update example code and image

* Fix clippy lint.

---------

Co-authored-by: Yun Ding <yunding@nvidia.com>
Co-authored-by: laurent <laurent.mazare@gmail.com>
2023-09-25 12:14:10 +01:00
dc47224ab9 Override the default cudnn heuristics. (#957) 2023-09-25 10:31:53 +01:00
1ce7fe2543 Add more examples to the phi readme. (#956) 2023-09-24 18:19:05 +01:00
402ddcfcb4 Add the missing kernel. (#955) 2023-09-24 17:21:37 +01:00
f5069dd354 Use the repo for the quantized phi model. (#954) 2023-09-24 16:30:26 +01:00
0007ae9c11 Add the quantized mixformer model. (#953)
* Add the quantized mixformer model.

* Add the quantized option in the phi example.
2023-09-24 15:03:48 +01:00
e15862cfdb Shared the quantized var-builder code. (#952)
* Shared the quantized var-builder code.

* Fix compilation.
2023-09-24 12:55:07 +01:00
4aeb449017 Depreate the VarBuilder::from_safetensors function. (#951) 2023-09-24 11:18:17 +01:00
bcb0ed8f1c Self-contained safetensors for the multiprocess llama example. (#950) 2023-09-24 06:54:49 +01:00
7edd755756 Pass directly the buffer ownership. (#949) 2023-09-24 06:34:44 +01:00
e32c89d90c Add the buffered safetensor wrapper. (#948) 2023-09-23 22:57:42 +01:00
bb3471ea31 Adapt more examples to the updated safetensor api. (#947)
* Simplify the safetensor usage.

* Convert more examples.

* Move more examples.

* Adapt stable-diffusion.
2023-09-23 21:26:03 +01:00
890d069092 Self-contained safetensor wrappers (#946)
* Self-contained safetensor wrappers.

* Use the new safetensor container in varbuilders.
2023-09-23 20:39:52 +01:00
5dbe46b389 Add tracing. (#943) 2023-09-23 16:55:46 +01:00
ccf352f3d1 Use yoke to provide a self-referential container for mmaped safetenso… (#939)
* Use yoke to provide a self-referential container for mmaped safetensor files.

* Add the new self-owned type for safetensor files without removing the previous version.

* Add routing.

* Add an initializer for the case of multiple files.
2023-09-23 15:43:11 +01:00
402d207f0f VarMap setter functions (#938)
* Add some setter helper functions for varmap.

* Add more comments.
2023-09-23 10:27:51 +01:00
7582937a32 Add the causal mask in mixformer. (#937) 2023-09-23 09:50:26 +01:00
b54acfa3d0 Tracing for the phi model (#936)
* Add some tracing bits to mixformers.

* Add the missing file.

* Add the conv2d layer to with-tracing.

* Improve the tracing usage.
2023-09-23 09:19:34 +01:00
cda1786eed smaller t5 models quantized (#934) 2023-09-22 22:31:23 +01:00
912a3d63b0 Use the proper block size for quantizing models. (#933)
* Use the proper block size for quantizing models.

* Use the proper dimension.
2023-09-22 21:36:56 +01:00
3ef328c53d Mention the new phi model in the readme. (#932) 2023-09-22 21:24:51 +01:00
0c8e983514 update link to t5 (#931) 2023-09-22 20:30:01 +01:00
df6f5240ba Complete the mixformer implementation. (#930)
* Complete the mixformers implementation.

* Tweak the attention.

* Add the phi-1.5 example.

* Improve the phi example.

* Bugfix.

* Get the phi example to work.
2023-09-22 20:03:16 +01:00
a46b1b4657 Mixformer (#929)
* Sketch the mixformer model.

* More modeling code.

* More mixformers.

* MixFormer creation.

* More mixformers.
2023-09-22 16:17:14 +01:00
19e52e5007 T5 Wasm (#918)
* init t5 wasm model

* split workers for each model

* clean up

* add some ui

* readme

* index

* typo

* remove cache param, clear_kv_cache

* add max_length as param

* add model tasks option to ui

* add method to load quantized gguf from buffer

* Add quantized wasm module

* add quantized models to UI, dynamic import wasms

* link to quantized

* fix copy

* fix ModelEncoder

* fix README.md
2023-09-22 15:31:10 +01:00
8601537e31 Add slice-scatter. (#927)
* Add slice-scatter.

* Add the op.

* Make transpose be a no-op when the dimensions are identical.

* Add the backprop.

* And add some gradient test.
2023-09-22 12:18:16 +01:00
4ac6039a42 Merge branch 'main' into book-trainin-simplified 2023-09-22 11:01:23 +06:00
52a60ca3ad https://github.com/huggingface/candle/issues/637 2023-09-22 10:57:11 +06:00
a96878f235 cuda cast i64 (#925) 2023-09-21 19:52:39 +01:00
aa8ec06fd2 Add the t5-xxl version. (#924) 2023-09-21 14:48:13 +01:00
b43ca493f6 Add more quantized flan t5 variants (#923)
* Add the quantized flan-t5-large variant.

* Add more sizes.
2023-09-21 13:23:30 +01:00
3b557765e8 T5 quantized example (#922)
* Load gguf files for the quantized t5.

* Add the quantized t5 example.

* Allow for loading local files.

* Add some support for quantizing safetensor files.

* Transpose before quantizing.

* Quantized t5.

* Retrieve the weights from the hub.
2023-09-21 12:33:15 +01:00
2619c4307f Add a quantized version of the t5 model. (#921) 2023-09-21 11:13:39 +01:00
c89b82b2d4 Add a clear cache function to the t5 model. (#919) 2023-09-21 09:01:06 +01:00
7b26e513f1 Add the erf function. (#917) 2023-09-21 06:19:10 +01:00
ab1d40ea97 Add more t5 tracing. (#915) 2023-09-20 20:20:54 +01:00
3a0d3e05df Add more t5 tracing. (#914)
* Add more t5 tracing.

* Rever the sm change.
2023-09-20 16:37:51 +01:00
9b24d89d2d Tracing mode for T5. (#913)
* Tracing mode for T5.

* Tracing for the linear layer.
2023-09-20 15:03:35 +01:00
fb1c2ac535 Add flash-attn support. (#912)
* Add flash-attn support.

* Add the use-flash-attn flag.

* Re-enable flash-attn.
2023-09-20 14:07:55 +01:00
728e167334 Add details on wuerstchen. (#911) 2023-09-20 13:09:35 +01:00
7b1ddcff47 Add clone to various nn layers. (#910) 2023-09-20 11:33:51 +01:00
f685b2231c Add some missing biases. (#908) 2023-09-20 10:14:51 +01:00
c0b49d5a50 Wuerstchen parameter tweaks. (#907) 2023-09-20 09:26:24 +01:00
098dd0d1e9 fix: add missingtop_p in llama_multiprocess (#905) 2023-09-20 08:54:56 +01:00
05626ef492 Flan T5: Read lm_head when word embeddings are not tied (#903)
* Read lm_head when word embeddings are not tied

* Fix formatting

* Address comments
2023-09-19 22:36:47 +01:00
67a486d18d Line-up the wuerstchen model with the python implementation. (#901)
* Line-up the wuerstchen model with the python implementation.

* Missing cos.

* Fix the picture denormalization.
2023-09-19 21:59:44 +01:00
7ad82b87e4 BERT Wasm (#902)
* implement wasm module

* add example to workspace

* add UI explore semantic similiarity

* change status messages

* formatting

* minor changes
2023-09-19 21:31:37 +01:00
8696f64bae Fix T5 kv cache (#899)
* Fix T5 kv cache

* Add argument for decoder prompt

* Fix range
2023-09-19 20:36:15 +01:00
d7e48234d4 Add an erf based gelu op (#900)
* Erf based gelu.

* Add the erf backed gelu.

* Test the new gelu op (which is not gelu_new).
2023-09-19 19:54:28 +01:00
34f2ecbc3b Fix the leaky relu. (#898) 2023-09-19 18:17:17 +01:00
4f91c8e109 Improve the error message on shape mismatch for cat. (#897)
* Improve the error message on shape mismatch for cat.

* Cosmetic tweak.
2023-09-19 15:09:47 +01:00
06e46d7c3b Only use classifier free guidance for the prior. (#896)
* Only use classifier free guidance for the prior.

* Add another specific layer-norm structure.

* Tweaks.

* Fix the latent shape.

* Print the prior shape.

* More shape fixes.

* Remove some debugging continue.
2023-09-19 14:13:05 +01:00
9cf26c5cff Fix typo in error_manage.md (#888)
occured -> occurred
2023-09-19 07:14:15 +01:00
aaa9d4ed6c W decoding. (#893)
* W decoding.

* Add the diffusion loop.

* Use the appropriate config.
2023-09-19 07:13:44 +01:00
92db8cecd3 Specialized attention module for Wuerstchen. (#890)
* Specialized attention module for Wuerstchen.

* Reshaping ops.

* Attention processor.

* Finish the forward pass.

* Hook the new attention processor.

* Get the prior forward pass to work.

* Make it contiguous.
2023-09-18 21:16:09 +01:00
1542e92629 T5: Add option to override use_cache from config (#892)
* Add option to override use_cache from config

* Disable cache by default and cleanup code
2023-09-18 20:20:21 +01:00
82a98f6da0 Prior denoising. (#889) 2023-09-18 16:51:38 +01:00
5082954c52 Fix the W clip embeddings. (#887)
* Fix the W clip embeddings.

* Add the specialized ddpm scheduler.
2023-09-18 14:50:14 +01:00
7dd8e12472 Bump the crate versions to v0.2.3. (#886)
* Bump the crate version.

* Also update the python bindings.
2023-09-18 12:14:03 +01:00
12696b7b2d Fix typos in SAM WASM example (#884) 2023-09-18 09:41:50 +01:00
ef8cd8fea0 Update the candle-gemm version. (#885) 2023-09-18 09:36:20 +01:00
03e194123d Add return types to *.pyi stubs (#880)
* Start generating return types

* Finish tensor type hinting

* Add `save_gguf` to `utils`

* Typehint `quant-llama.py`
2023-09-17 22:11:01 +01:00
c2b866172a More Wuerstchen fixes. (#882)
* More Weurstchen fixes.

* More shape fixes.

* Add more of the prior specific bits.

* Broadcast add.

* Fix the clip config.

* Add some masking options to the clip model.
2023-09-17 22:08:11 +01:00
06cc329e71 Remove the parameters for the Wuerstchen layer-norm. (#879)
* Remove the parameters for the Wuerstchen layer-norm.

* Fixes.

* More fixes (including conv-transpose2d.

* More fixes.

* Again more fixes.
2023-09-17 15:59:27 +01:00
5f83c13f17 Add the DDPM scheduler. (#877)
* Add the DDPM scheduler.

* Minor tweaks.
2023-09-17 15:03:01 +01:00
db3e9dae04 Wuerstchen main (#876)
* Wuerstchen main.

* More of the wuerstchen cli example.

* Paella creation.

* Build the prior model.

* Fix the weight file names.
2023-09-17 12:46:38 +01:00
7f65af1f0d Avoid re-encoding the input in the T5 example. (#875) 2023-09-17 10:25:54 +01:00
eeb54716dd Tweaks for the T5 example. (#874) 2023-09-17 10:05:15 +01:00
1a276b5da7 Add a KV cache to T5. (#873)
* Add a KV cache to T5.

* Suggest using release mode.

* Use the kv cache in decoding.

* Add a comment.
2023-09-17 08:00:45 +01:00
8658df3485 Generate *.pyi stubs for PyO3 wrapper (#870)
* Begin to generate typehints.

* generate correct stubs

* Correctly include stubs

* Add comments and typhints to static functions

* ensure candle-pyo3 directory

* Make `llama.rope.freq_base` optional

* `fmt`
2023-09-16 17:23:38 +01:00
7cafca835a readme tweaks. (#867) 2023-09-16 07:22:24 +01:00
04ca2b9ebd Update README + SAM (#866)
* use serde-wasm-bindgen, faster serialization

* update readme with demos
2023-09-16 07:34:13 +02:00
635012d770 Do not backprop through argmin/argmax. (#865) 2023-09-15 22:15:40 +01:00
3e49f8fce5 Implement T5 decoding (#864)
* Load t5 decoder

* Run enc, dec, and lm head, but no cross attn

* Cross-attention over key_value_states

* New arg for decoder input ids

* Add mask, don't forward position biases through decoder

* Update t5 examples

* Clippy + rustfmt
2023-09-15 22:05:12 +02:00
c2007ac88f W fixes. (#862) 2023-09-15 15:11:11 +01:00
30be5b6660 Replication pad (#861)
* Add the embed mapper convolutions.

* Add the replication pad layer.

* Use the replication-pad op.

* Tweak a todo.
2023-09-15 14:06:21 +01:00
107d3d9530 Add the embed mapper convolutions. (#860) 2023-09-15 11:38:38 +02:00
2746f2c4be DiffNeXt/unet (#859)
* DiffNeXt/unet

* Start adding the vae.

* VAE residual block.

* VAE forward pass.

* Add pixel shuffling.

* Actually use pixel shuffling.
2023-09-15 10:14:02 +01:00
81a36b8713 Add link error info (#851)
* add link error info

* grammar fix
2023-09-15 07:25:10 +01:00
0633c85514 Add leaky-relu in the activation enum. (#858) 2023-09-15 07:05:38 +01:00
39157346cb Add SAM UI Demo (#854)
* fix tensor flattening

* send image data back

* sam ui worker example

* SAM example

* resize container

* no need for this
2023-09-15 06:31:58 +01:00
5cefbba757 minor UI fixes (#856)
* fixes

* remove listener

* remove event listener
2023-09-15 06:30:50 +01:00
130fe5a087 Add the upblocks. (#853) 2023-09-14 22:24:56 +01:00
91ec546feb More DiffNeXt. (#847)
* More DiffNeXt.

* Down blocks.
2023-09-14 22:16:31 +02:00
0a647875ec Use softmax-last-dim in the quantized example. (#848) 2023-09-14 17:29:24 +01:00
a0c6d5548c Add the attention block. (#846)
* Add the attention block.

* Add more to clipnext.
2023-09-14 15:40:09 +01:00
286f01db14 Start adding the Wuerstchen diffusion pipeline (#843)
* Wuerstchen common bits.

* Add the prior layer.

* Start adding diffnext.
2023-09-14 10:56:07 +01:00
d6447ad635 Tensor based indexing. (#842) 2023-09-14 07:47:07 +01:00
49d3f7f708 Add support to flan-t5 (#840) 2023-09-13 19:27:20 +02:00
9a465e1b26 Add 1d upsampling. (#839)
* Add 1d upsampling.

* Add the interpolate functions.
2023-09-13 16:50:39 +01:00
31ab2ddaeb Remove the padding. (#838) 2023-09-13 13:00:59 +01:00
b11a2a7b9d Move the constant to avoid some unused warning. (#837) 2023-09-13 11:56:53 +01:00
1c09164021 Add CANDLE_NVCC_CCBIN support for candle-kernels, and eliminate warning. (#836) 2023-09-13 11:39:22 +01:00
3e94324012 Add some sentence similarity part to the t5 example. (#835)
* Add some sentence similarity part to the t5 example.

* Clippy fix.
2023-09-13 10:44:02 +01:00
e6f040d6e3 Readme gallery (#834)
* More readme tweaks.

* Update README.md
2023-09-13 09:05:47 +01:00
cbd36157ac Add a gif to the quantized readme. (#833)
* Add a gif to the quantized readme.

* gif update.
2023-09-13 08:43:52 +01:00
18d3c803a8 Scalar support in minimum/maximum. (#832)
* Scalar support in minimum/maximum.

* Add a clamp method to tensors.
2023-09-13 08:24:58 +01:00
e4553fb355 T5 tweaks (#831)
* Use default values rather than options.

* Avoid exposing the device field.

* More tweaks.
2023-09-13 07:37:04 +01:00
d801e1d564 Clippy fix. (#830) 2023-09-13 07:16:20 +01:00
9daa6dbe87 Extract T5 module and add main function to use it (#829)
* Extract t5 out of musicgen

* Add main for t5 module
2023-09-13 07:14:05 +01:00
e82fcf1c59 Add more example readmes. (#828)
* Add more readmes.

* Add a readme for dinov2.

* Add some skeleton files for a couple more examples.

* More whisper details.
2023-09-12 17:21:24 +01:00
805bf9ffa7 Implement top_p / nucleus sampling (#819)
* Implement top_p / nucleus sampling

* Update changelog

* rustfmt

* Add tests

* Fix clippy warning

* Fix another clippy error
2023-09-12 18:10:16 +02:00
42da17694a Segment Anything readme (#827)
* Add a readme for the segment-anything model.

* Add the original image.

* Clean-up the segment anything cli example.

* Also print the mask id in the outputs.
2023-09-12 14:35:55 +01:00
25aacda28e Add useful libraries section (#825)
* Add useful libraries section

* Add link
2023-09-12 11:06:21 +01:00
7a62aad24a Add a readme for yolo-v8. (#824) 2023-09-12 11:01:06 +01:00
bb23b90b1d Add a small readme for the quantized example. (#823) 2023-09-12 10:17:31 +01:00
2257f4d475 Bump the crate version + update the changelog. (#822) 2023-09-12 06:39:24 +01:00
871efc0307 Bugfix for the conv2d cpu kernel. (#820) 2023-09-11 23:11:27 +01:00
c5a058b169 Use the module trait in stable-diffusion. (#817) 2023-09-11 20:40:07 +01:00
59e63d690c Add weight, bias, and hidden_size methods (#816)
* Add weight, bias methods to Conv(1|2)

* Add hidden_size method to Embedding

* Expose hidden_size
2023-09-11 16:01:11 +01:00
dbd4561416 im2col version of the conv1d kernel. (#815)
* im2col version of the cuda conv1d kernel.

* im2col version of the conv1d cpu kernel.
2023-09-11 14:40:09 +01:00
5c35fbbb13 Stable-Diffusion readme (#814)
* Stable Diffusion readme.

* Fix the image path.

* Move the assets.

* Resize the sample image.

* Lower resolution.
2023-09-11 13:06:29 +01:00
70f38c2069 Proper error on unsupported dtypes when using gemm. (#813) 2023-09-11 12:10:51 +01:00
d7b9fec849 Move the stable-diffusion modeling code so that it's easier to re-use. (#812) 2023-09-11 11:45:57 +01:00
84ee870efd Use softmax-last-dim in whisper. (#810) 2023-09-11 11:05:05 +01:00
df712ecf64 Handle the case where the kernel is not contiguous in the cuda backend. (#809) 2023-09-11 09:48:31 +01:00
6fb665004c Enable im2col on the cpu side. (#805)
* Enable im2col on the cpu side.

* Hook im2col on the cpu backend.

* Use the kernel offset.

* Avoid an unnecessary copy.

* Handle non-contiguous kernels.

* Add a const to select the conv2d kernel.
2023-09-11 09:28:13 +01:00
1cd74129d4 Add Im2Col support on the gpu side. (#808)
* Add Im2Col support on the gpu side.

* Actually enable.
2023-09-11 08:52:33 +01:00
98d1242b8f im2col based conv2d (#802)
* im2col implementation for conv2d.

* Fix for the im2col implementation to match the current conv2d.

* Small optimization.

* Add a cuda kernel.

* Handle arbitrary layouts.

* Im2Col cuda code.
2023-09-10 21:02:42 +01:00
18d6db2180 more doc fixes (#804) 2023-09-10 20:36:29 +01:00
4f18180fc7 Bugfix so that im2col produce the same results as conv2d. (#801) 2023-09-10 16:59:46 +01:00
559944146f Add an im2col based benchmark. (#800)
* Add an im2col based benchmark.

* Reshape the final result.
2023-09-10 16:56:28 +01:00
3dd5804299 Fix typo in readme. (#799) 2023-09-10 13:49:47 +01:00
90e077e409 Return the low res mask in the wasm segment-anything module. (#798)
* Return the low res mask.

* Add some validations.
2023-09-10 13:03:02 +01:00
584171cae1 Add a wasm module for the segment anything example. (#797) 2023-09-10 12:29:37 +01:00
6c58fc59fd Little docs changes (#791)
* Little doc fixes

* change imports in lib

* rename candle_core to candle

* revert "rename candle_core to candle"
2023-09-10 12:02:52 +01:00
35f72514f5 Move more models to candle-transformers (#796)
* Move dinov2.

* Move efficientnet.

* Move the quantized llama model.

* Move segment-anything.
2023-09-10 10:20:18 +01:00
d3f05eae8c Move some models to candle-transformers so that it's easier to re-use. (#794)
* Move some models to candle-transformers so that they can be shared.

* Also move falcon.

* Move Llama.

* Move whisper (partial).
2023-09-10 09:40:27 +01:00
258ac32c38 Fix cuda randn when generating an odd number of values. (#793) 2023-09-09 18:44:21 +01:00
31936c08fe ViT tracing. (#790) 2023-09-09 17:26:39 +01:00
74ad4deb42 Get the MobileSAM TinyViT based version to work. (#789)
* More TinyViT support in SA.

* More mobilesam work.

* Add the mobile-sam weights to the hub.
2023-09-09 16:21:44 +01:00
b7cd58473b TinyViT backbone for segment-anything. (#787)
* TinyViT.

* More TinyViT.

* Add more to the tinyvit backbone.

* Proper padding.

* Plus ViT.

* Add the tiniest vit spec.
2023-09-09 15:10:06 +01:00
3cd7e7b51d Fuse the rel-pos additions via a custom-op. (#786)
* Fuse the rel-pos additions via a custom-op.

* Run with rayon.

* Add more tracing.
2023-09-09 10:46:09 +01:00
722c50bb0c Use byteorder in mnist. (#785) 2023-09-09 09:03:59 +01:00
976a1086ee feat: u32 from_be_bytes (#765) 2023-09-09 08:55:35 +01:00
c88d6fd4b9 Remove set_training. (#784) 2023-09-09 08:27:37 +01:00
057f7909bc Accelerate support for gelu. (#782) 2023-09-08 21:58:56 +01:00
265 changed files with 22376 additions and 2766 deletions

8
.gitignore vendored
View File

@ -23,14 +23,16 @@ flamegraph.svg
*.dylib
*.so
*.swp
*.swo
trace-*.json
candle-wasm-examples/*/build
candle-wasm-examples/*/*.bin
candle-wasm-examples/*/*.jpeg
candle-wasm-examples/*/*.wav
candle-wasm-examples/*/*.safetensors
candle-wasm-examples/*/audios/*.wav
candle-wasm-examples/**/*.safetensors
candle-wasm-examples/**/*.gguf
candle-wasm-examples/*/package-lock.json
candle-wasm-examples/**/config*.json
.DS_Store
.idea/*

11
.vscode/settings.json vendored Normal file
View File

@ -0,0 +1,11 @@
{
"[python]": {
"editor.defaultFormatter": "ms-python.black-formatter"
},
"python.formatting.provider": "none",
"python.testing.pytestArgs": [
"candle-pyo3"
],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true
}

View File

@ -1,13 +1,84 @@
# Changelog
This documents the main changes to the `candle` crate.
## v0.2.1 - Unreleased
## v0.3.1 - Unreleased
### Added
### Modified
## v0.3.0 - 2023-10-01
### Added
- Added the Mistral 7b v0.1 model
[983](https://github.com/huggingface/candle/pull/983).
- Quantized version of the Mistral model
[1009](https://github.com/huggingface/candle/pull/1009).
- Add the gelu-erf op and activation function
[969](https://github.com/huggingface/candle/pull/969).
- Add the mixformer/phi-v1.5 model
[930](https://github.com/huggingface/candle/pull/930).
- Add the sclice-scatter op
[927](https://github.com/huggingface/candle/pull/927).
- Add the Wuerstchen diffusion model
[911](https://github.com/huggingface/candle/pull/911).
### Modified
- Support for simd128 intrinsics in some quantized vecdots
[982](https://github.com/huggingface/candle/pull/982).
- Optimize the index-select cuda kernel
[976](https://github.com/huggingface/candle/pull/976).
- Self-contained safetensor wrappers
[946](https://github.com/huggingface/candle/pull/946).
## v0.2.2 - 2023-09-18
### Added
- Support for `top_p` sampling
[819](https://github.com/huggingface/candle/pull/819).
- T5 model including decoding
[864](https://github.com/huggingface/candle/pull/864).
- 1-d upsampling
[839](https://github.com/huggingface/candle/pull/839).
### Modified
- Bugfix for conv2d
[820](https://github.com/huggingface/candle/pull/820).
- Support tensor based indexing using `.i`
[842](https://github.com/huggingface/candle/pull/842).
## v0.2.1 - 2023-09-11
### Added
- Add some RNNs (GRU and LSTM) in `candle-nn`
[674](https://github.com/huggingface/candle/pull/674),
[688](https://github.com/huggingface/candle/pull/688).
- gguf v2 support
[725](https://github.com/huggingface/candle/pull/725).
- Quantized llama example in Python using the pyo3 api
[716](https://github.com/huggingface/candle/pull/716).
- `candle-nn` layer for conv2d-transposed
[760](https://github.com/huggingface/candle/pull/760).
- Add the Segment-Anything Model (SAM) as an example
[773](https://github.com/huggingface/candle/pull/773).
- TinyViT backbone for the segemnt anything example
[787](https://github.com/huggingface/candle/pull/787).
- Shape with holes support
[770](https://github.com/huggingface/candle/pull/770).
### Modified
- Dilations are now supported in conv-transpose2d.
[671](https://github.com/huggingface/candle/pull/671).
- Interactive mode for the quantized model
[690](https://github.com/huggingface/candle/pull/690).
- Faster softmax operation
[747](https://github.com/huggingface/candle/pull/747).
- Faster convolution operations on CPU and CUDA via im2col
[802](https://github.com/huggingface/candle/pull/802).
- Moving some models to a more central location
[796](https://github.com/huggingface/candle/pull/796).
## v0.2.0 - 2023-08-30

View File

@ -8,17 +8,19 @@ members = [
"candle-pyo3",
"candle-transformers",
"candle-wasm-examples/llama2-c",
"candle-wasm-examples/segment-anything",
"candle-wasm-examples/whisper",
"candle-wasm-examples/yolo",
"candle-wasm-examples/bert",
"candle-wasm-examples/phi",
"candle-wasm-examples/t5",
"candle-wasm-tests",
]
exclude = [
"candle-flash-attn",
"candle-kernels",
]
exclude = ["candle-flash-attn", "candle-kernels"]
resolver = "2"
[workspace.package]
version = "0.2.1"
version = "0.3.0"
edition = "2021"
description = "Minimalist ML framework."
repository = "https://github.com/huggingface/candle"
@ -33,7 +35,7 @@ byteorder = "1.4.3"
clap = { version = "4.2.4", features = ["derive"] }
cudarc = { version = "0.9.14", features = ["f16"] }
# TODO: Switch back to the official gemm implementation once it has caught up.
gemm = { version = "0.15.6", package = "candle-gemm" }
gemm = { version = "0.16.0", package = "candle-gemm" }
hf-hub = "0.3.0"
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
image = { version = "0.24.7", default-features = false, features = ["jpeg", "png"] }
@ -41,9 +43,10 @@ imageproc = { version = "0.23.0", default-features = false }
intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] }
libc = { version = "0.2.147" }
log = "0.4"
memmap2 = "0.7.1"
memmap2 = { version = "0.7.1", features = ["stable_deref_trait"] }
num_cpus = "1.15.0"
num-traits = "0.2.15"
parquet = { version = "45.0.0" }
rand = "0.8.5"
rand_distr = "0.4.3"
rayon = "1.7.0"
@ -57,8 +60,8 @@ tracing = "0.1.37"
tracing-chrome = "0.7.1"
tracing-subscriber = "0.3.7"
wav = "1.0.0"
yoke = { version = "0.7.2", features = ["derive"] }
zip = { version = "0.6.6", default-features = false }
parquet = { version = "45.0.0" }
[profile.release-with-debug]
inherits = "release"

126
README.md
View File

@ -8,7 +8,10 @@ Candle is a minimalist ML framework for Rust with a focus on performance (includ
and ease of use. Try our online demos:
[whisper](https://huggingface.co/spaces/lmz/candle-whisper),
[LLaMA2](https://huggingface.co/spaces/lmz/candle-llama2),
[yolo](https://huggingface.co/spaces/lmz/candle-yolo).
[T5](https://huggingface.co/spaces/radames/Candle-T5-Generation-Wasm),
[yolo](https://huggingface.co/spaces/lmz/candle-yolo),
[Segment
Anything](https://huggingface.co/spaces/radames/candle-segment-anything-wasm).
## Get started
@ -45,40 +48,60 @@ For more advanced examples, please have a look at the following section.
## Check out our examples
Check out our [examples](./candle-examples/examples/):
These online demos run entirely in your browser:
- [yolo](https://huggingface.co/spaces/lmz/candle-yolo): pose estimation and
object recognition.
- [whisper](https://huggingface.co/spaces/lmz/candle-whisper): text to speech.
- [LLaMA2](https://huggingface.co/spaces/lmz/candle-llama2): text generation.
- [T5](https://huggingface.co/spaces/radames/Candle-T5-Generation-Wasm): text generation.
- [Phi-v1.5](https://huggingface.co/spaces/radames/Candle-Phi-1.5-Wasm): text generation.
- [Segment Anything Model](https://huggingface.co/spaces/radames/candle-segment-anything-wasm): Image segmentation.
We also provide a some command line based examples using state of the art models:
- [Whisper](./candle-examples/examples/whisper/): speech recognition model.
- [LLaMA and LLaMA-v2](./candle-examples/examples/llama/): general LLM.
- [Falcon](./candle-examples/examples/falcon/): general LLM.
- [Bert](./candle-examples/examples/bert/): useful for sentence embeddings.
- [StarCoder](./candle-examples/examples/bigcode/): LLM specialized to code
generation.
- [Stable Diffusion](./candle-examples/examples/stable-diffusion/): text to
image generative model, support for the 1.5, 2.1, and SDXL 1.0 versions.
- [DINOv2](./candle-examples/examples/dinov2/): computer vision model trained
using self-supervision (can be used for imagenet classification, depth
evaluation, segmentation).
- [Phi-v1.5](./candle-examples/examples/phi/): a 1.3b general LLM with performance on par with LLaMA-v2 7b.
- [StableLM-3B-4E1T](./candle-examples/examples/stable-lm/): a 3b general LLM
pre-trained on 1T tokens of English and code datasets.
- [Mistral7b-v0.1](./candle-examples/examples/mistral/): a 7b general LLM with
performance larger than all publicly available 13b models as of 2023-09-28.
- [StarCoder](./candle-examples/examples/bigcode/): LLM specialized to code generation.
- [Quantized LLaMA](./candle-examples/examples/quantized/): quantized version of
the LLaMA model using the same quantization techniques as
[llama.cpp](https://github.com/ggerganov/llama.cpp).
<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/quantized/assets/aoc.gif" width="600">
- [Stable Diffusion](./candle-examples/examples/stable-diffusion/): text to
image generative model, support for the 1.5, 2.1, and SDXL 1.0 versions.
<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg" width="200">
- [Wuerstchen](./candle-examples/examples/wuerstchen/): another text to
image generative model.
<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/wuerstchen/assets/cat.jpg" width="200">
- [yolo-v3](./candle-examples/examples/yolo-v3/) and
[yolo-v8](./candle-examples/examples/yolo-v8/): object detection and pose
estimation models.
[segment-anything](./candle-examples/examples/segment-anything/): image
<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/yolo-v8/assets/bike.od.jpg" width="200"><img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/yolo-v8/assets/bike.pose.jpg" width="200">
- [segment-anything](./candle-examples/examples/segment-anything/): image
segmentation model with prompt.
Run them using the following commands:
<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/segment-anything/assets/sam_merged.jpg" width="200">
- [Whisper](./candle-examples/examples/whisper/): speech recognition model.
- [T5](./candle-examples/examples/t5), [Bert](./candle-examples/examples/bert/): useful for sentence embeddings.
- [DINOv2](./candle-examples/examples/dinov2/): computer vision model trained
using self-supervision (can be used for imagenet classification, depth
evaluation, segmentation).
Run them using commands like:
```
cargo run --example whisper --release
cargo run --example llama --release
cargo run --example falcon --release
cargo run --example bert --release
cargo run --example bigcode --release
cargo run --example stable-diffusion --release -- --prompt "a rusty robot holding a fire torch"
cargo run --example dinov2 --release -- --image path/to/myinput.jpg
cargo run --example quantized --release
cargo run --example yolo-v3 --release -- myimage.jpg
cargo run --example yolo-v8 --release -- myimage.jpg # for pose estimation, add --task pose
cargo run --example segment-anything --release -- --image myimage.jpg
```
In order to use **CUDA** add `--features cuda` to the example command line. If
@ -88,7 +111,10 @@ There are also some wasm examples for whisper and
[llama2.c](https://github.com/karpathy/llama2.c). You can either build them with
`trunk` or try them online:
[whisper](https://huggingface.co/spaces/lmz/candle-whisper),
[llama2](https://huggingface.co/spaces/lmz/candle-llama2).
[llama2](https://huggingface.co/spaces/lmz/candle-llama2),
[T5](https://huggingface.co/spaces/radames/Candle-T5-Generation-Wasm),
[Phi-v1.5](https://huggingface.co/spaces/radames/Candle-Phi-1.5-Wasm),
[Segment Anything Model](https://huggingface.co/spaces/radames/candle-segment-anything-wasm).
For LLaMA2, run the following command to retrieve the weight files and start a
test server:
@ -101,6 +127,15 @@ trunk serve --release --port 8081
And then head over to
[http://localhost:8081/](http://localhost:8081/).
<!--- ANCHOR: useful_libraries --->
## Useful Libraries
- [`candle-lora`](https://github.com/EricLBuehler/candle-lora) provides a LoRA implementation that conforms to the official `peft` implementation.
If you have an addition to this list, please submit a pull request.
<!--- ANCHOR_END: useful_libraries --->
<!--- ANCHOR: features --->
## Features
@ -113,10 +148,24 @@ And then head over to
- CUDA backend for efficiently running on GPUs, multiple GPU distribution via NCCL.
- WASM support, run your models in a browser.
- Included models.
- LLMs: LLaMA v1 and v2, Falcon, StarCoder.
- Language Models.
- LLaMA v1 and v2.
- Falcon.
- StarCoder.
- Phi v1.5.
- Mistral 7b v0.1.
- StableLM-3B-4E1T.
- T5.
- Bert.
- Whisper (multi-lingual support).
- Stable Diffusion.
- Computer Vision: DINOv2, EfficientNet, yolo-v3, yolo-v8.
- Stable Diffusion v1.5, v2.1, XL v1.0.
- Wurstchen v2.
- Computer Vision Models.
- DINOv2.
- EfficientNet.
- yolo-v3.
- yolo-v8.
- Segment-Anything Model (SAM).
- File formats: load models from safetensors, npz, ggml, or PyTorch files.
- Serverless (on CPU), small and fast deployments.
- Quantization support using the llama.cpp quantized types.
@ -257,6 +306,29 @@ This is a bug in gcc-11 triggered by the Cuda compiler. To fix this, install a d
env CANDLE_NVCC_CCBIN=/usr/lib/gcc/x86_64-linux-gnu/10 cargo ...
```
#### Linking error on windows when running rustdoc or mdbook tests
```
Couldn't compile the test.
---- .\candle-book\src\inference\hub.md - Using_the_hub::Using_in_a_real_model_ (line 50) stdout ----
error: linking with `link.exe` failed: exit code: 1181
//very long chain of linking
= note: LINK : fatal error LNK1181: cannot open input file 'windows.0.48.5.lib'
```
Make sure you link all native libraries that might be located outside a project target, e.g., to run mdbook tests, you should run:
```
mdbook test candle-book -L .\target\debug\deps\ `
-L native=$env:USERPROFILE\.cargo\registry\src\index.crates.io-6f17d22bba15001f\windows_x86_64_msvc-0.42.2\lib `
-L native=$env:USERPROFILE\.cargo\registry\src\index.crates.io-6f17d22bba15001f\windows_x86_64_msvc-0.48.5\lib
```
#### Extremely slow model load time with WSL
This may be caused by the models being loaded from `/mnt/c`, more details on
[stackoverflow](https://stackoverflow.com/questions/68972448/why-is-wsl-extremely-slow-when-compared-with-native-windows-npm-yarn-processing).
#### Tracking down errors
You can set `RUST_BACKTRACE=1` to be provided with backtraces when a candle

View File

@ -11,11 +11,11 @@ readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
candle = { path = "../candle-core", version = "0.2.1", package = "candle-core" }
candle-datasets = { path = "../candle-datasets", version = "0.2.1" }
candle-nn = { path = "../candle-nn", version = "0.2.1" }
candle-transformers = { path = "../candle-transformers", version = "0.2.1" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.2.1", optional = true }
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
candle-datasets = { path = "../candle-datasets", version = "0.3.0" }
candle-nn = { path = "../candle-nn", version = "0.3.0" }
candle-transformers = { path = "../candle-transformers", version = "0.3.0" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.0", optional = true }
safetensors = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
@ -24,9 +24,10 @@ intel-mkl-src = { workspace = true, optional = true }
cudarc = { workspace = true, optional = true }
half = { workspace = true, optional = true }
image = { workspace = true, optional = true }
anyhow = { workspace = true }
tokio = "1.29.1"
[dev-dependencies]
anyhow = { workspace = true }
byteorder = { workspace = true }
hf-hub = { workspace = true, features=["tokio"]}
clap = { workspace = true }
@ -38,7 +39,6 @@ tracing-chrome = { workspace = true }
tracing-subscriber = { workspace = true }
wav = { workspace = true }
# Necessary to disambiguate with tokio in wasm examples which are 1.28.1
tokio = "1.29.1"
parquet = { workspace = true }
image = { workspace = true }

View File

@ -10,10 +10,11 @@
# Reference Guide
- [Running a model](inference/README.md)
- [Running a model](inference/inference.md)
- [Using the hub](inference/hub.md)
- [Error management](error_manage.md)
- [Training](training/README.md)
- [Training](training/training.md)
- [Simplified](training/simplified.md)
- [MNIST](training/mnist.md)
- [Fine-tuning]()
- [Serialization]()

View File

@ -29,7 +29,7 @@ After adding `RUST_BACKTRACE=1`:
Error: WithBacktrace { inner: ShapeMismatchBinaryOp { lhs: [1, 784], rhs: [1, 784], op: "matmul" }, backtrace: Backtrace [{ fn: "candle::error::Error::bt", file: "/home/nicolas/.cargo/git/checkouts/candle-5bb8ef7e0626d693/f291065/candle-core/src/error.rs", line: 200 }, { fn: "candle::tensor::Tensor::matmul", file: "/home/nicolas/.cargo/git/checkouts/candle-5bb8ef7e0626d693/f291065/candle-core/src/tensor.rs", line: 816 }, { fn: "myapp::main", file: "./src/main.rs", line: 29 }, { fn: "core::ops::function::FnOnce::call_once", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/core/src/ops/function.rs", line: 250 }, { fn: "std::sys_common::backtrace::__rust_begin_short_backtrace", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/sys_common/backtrace.rs", line: 135 }, { fn: "std::rt::lang_start::{{closure}}", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 166 }, { fn: "core::ops::function::impls::<impl core::ops::function::FnOnce<A> for &F>::call_once", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/core/src/ops/function.rs", line: 284 }, { fn: "std::panicking::try::do_call", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 500 }, { fn: "std::panicking::try", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 464 }, { fn: "std::panic::catch_unwind", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panic.rs", line: 142 }, { fn: "std::rt::lang_start_internal::{{closure}}", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 148 }, { fn: "std::panicking::try::do_call", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 500 }, { fn: "std::panicking::try", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 464 }, { fn: "std::panic::catch_unwind", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panic.rs", line: 142 }, { fn: "std::rt::lang_start_internal", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 148 }, { fn: "std::rt::lang_start", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 165 }, { fn: "main" }, { fn: "__libc_start_main" }, { fn: "_start" }] }
```
Not super pretty at the moment, but we can see error occured on `{ fn: "myapp::main", file: "./src/main.rs", line: 29 }`
Not super pretty at the moment, but we can see error occurred on `{ fn: "myapp::main", file: "./src/main.rs", line: 29 }`
Another thing to note, is that since Rust is compiled it is not necessarily as easy to recover proper stacktraces

View File

@ -6,7 +6,7 @@ Open `src/main.rs` and fill in this content:
```rust
# extern crate candle_core;
use candle_core::{DType, Device, Result, Tensor};
use candle_core::{Device, Result, Tensor};
struct Model {
first: Tensor,
@ -25,11 +25,11 @@ fn main() -> Result<()> {
// Use Device::new_cuda(0)?; to use the GPU.
let device = Device::Cpu;
let first = Tensor::zeros((784, 100), DType::F32, &device)?;
let second = Tensor::zeros((100, 10), DType::F32, &device)?;
let first = Tensor::randn(0f32, 1.0, (784, 100), &device)?;
let second = Tensor::randn(0f32, 1.0, (100, 10), &device)?;
let model = Model { first, second };
let dummy_image = Tensor::zeros((1, 784), DType::F32, &device)?;
let dummy_image = Tensor::randn(0f32, 1.0, (1, 784), &device)?;
let digit = model.forward(&dummy_image)?;
println!("Digit {digit:?} digit");
@ -50,7 +50,7 @@ the classical `Linear` layer. We can do as such
```rust
# extern crate candle_core;
# use candle_core::{DType, Device, Result, Tensor};
# use candle_core::{Device, Result, Tensor};
struct Linear{
weight: Tensor,
bias: Tensor,
@ -80,7 +80,7 @@ This will change the model running code into a new function
```rust
# extern crate candle_core;
# use candle_core::{DType, Device, Result, Tensor};
# use candle_core::{Device, Result, Tensor};
# struct Linear{
# weight: Tensor,
# bias: Tensor,
@ -110,15 +110,15 @@ fn main() -> Result<()> {
let device = Device::cuda_if_available(0)?;
// Creating a dummy model
let weight = Tensor::zeros((784, 100), DType::F32, &device)?;
let bias = Tensor::zeros((100, ), DType::F32, &device)?;
let weight = Tensor::randn(0f32, 1.0, (784, 100), &device)?;
let bias = Tensor::randn(0f32, 1.0, (100, ), &device)?;
let first = Linear{weight, bias};
let weight = Tensor::zeros((100, 10), DType::F32, &device)?;
let bias = Tensor::zeros((10, ), DType::F32, &device)?;
let weight = Tensor::randn(0f32, 1.0, (100, 10), &device)?;
let bias = Tensor::randn(0f32, 1.0, (10, ), &device)?;
let second = Linear{weight, bias};
let model = Model { first, second };
let dummy_image = Tensor::zeros((1, 784), DType::F32, &device)?;
let dummy_image = Tensor::randn(0f32, 1.0, (1, 784), &device)?;
// Inference on the model
let digit = model.forward(&dummy_image)?;
@ -146,7 +146,7 @@ And rewrite our examples using it
```rust
# extern crate candle_core;
# extern crate candle_nn;
use candle_core::{DType, Device, Result, Tensor};
use candle_core::{Device, Result, Tensor};
use candle_nn::{Linear, Module};
struct Model {
@ -167,15 +167,15 @@ fn main() -> Result<()> {
let device = Device::Cpu;
// This has changed (784, 100) -> (100, 784) !
let weight = Tensor::zeros((100, 784), DType::F32, &device)?;
let bias = Tensor::zeros((100, ), DType::F32, &device)?;
let weight = Tensor::randn(0f32, 1.0, (100, 784), &device)?;
let bias = Tensor::randn(0f32, 1.0, (100, ), &device)?;
let first = Linear::new(weight, Some(bias));
let weight = Tensor::zeros((10, 100), DType::F32, &device)?;
let bias = Tensor::zeros((10, ), DType::F32, &device)?;
let weight = Tensor::randn(0f32, 1.0, (10, 100), &device)?;
let bias = Tensor::randn(0f32, 1.0, (10, ), &device)?;
let second = Linear::new(weight, Some(bias));
let model = Model { first, second };
let dummy_image = Tensor::zeros((1, 784), DType::F32, &device)?;
let dummy_image = Tensor::randn(0f32, 1.0, (1, 784), &device)?;
let digit = model.forward(&dummy_image)?;
println!("Digit {digit:?} digit");
@ -188,8 +188,8 @@ Feel free to modify this example to use `Conv2d` to create a classical convnet i
Now that we have the running dummy code we can get to more advanced topics:
- [For PyTorch users](./guide/cheatsheet.md)
- [Running existing models](./inference/README.md)
- [Training models](./training/README.md)
- [For PyTorch users](../guide/cheatsheet.md)
- [Running existing models](../inference/inference.md)
- [Training models](../training/training.md)

View File

@ -1,3 +1,6 @@
#[cfg(test)]
pub mod simplified;
#[cfg(test)]
mod tests {
use anyhow::Result;

View File

@ -0,0 +1,196 @@
//! #A simplified example in Rust of training a neural network and then using it based on the Candle Framework by Hugging Face.
//! Author: Evgeny Igumnov 2023 igumnovnsk@gmail.com
//! This program implements a neural network to predict the winner of the second round of elections based on the results of the first round.
//!
//! ##Basic moments:
//!
//! A multilayer perceptron with two hidden layers is used. The first hidden layer has 4 neurons, the second has 2 neurons.
//! The input is a vector of 2 numbers - the percentage of votes for the first and second candidates in the first stage.
//! The output is the number 0 or 1, where 1 means that the first candidate will win in the second stage, 0 means that he will lose.
//! For training, samples with real data on the results of the first and second stages of different elections are used.
//! The model is trained by backpropagation using gradient descent and the cross-entropy loss function.
//! Model parameters (weights of neurons) are initialized randomly, then optimized during training.
//! After training, the model is tested on a deferred sample to evaluate the accuracy.
//! If the accuracy on the test set is below 100%, the model is considered underfit and the learning process is repeated.
//! Thus, this neural network learns to find hidden relationships between the results of the first and second rounds of voting in order to make predictions for new data.
#[rustfmt::skip]
mod tests {
use candle::{DType, Result, Tensor, D, Device};
use candle_nn::{loss, ops, Linear, Module, VarBuilder, VarMap, Optimizer};
// ANCHOR: book_training_simplified1
const VOTE_DIM: usize = 2;
const RESULTS: usize = 1;
const EPOCHS: usize = 10;
const LAYER1_OUT_SIZE: usize = 4;
const LAYER2_OUT_SIZE: usize = 2;
const LEARNING_RATE: f64 = 0.05;
#[derive(Clone)]
pub struct Dataset {
pub train_votes: Tensor,
pub train_results: Tensor,
pub test_votes: Tensor,
pub test_results: Tensor,
}
struct MultiLevelPerceptron {
ln1: Linear,
ln2: Linear,
ln3: Linear,
}
impl MultiLevelPerceptron {
fn new(vs: VarBuilder) -> Result<Self> {
let ln1 = candle_nn::linear(VOTE_DIM, LAYER1_OUT_SIZE, vs.pp("ln1"))?;
let ln2 = candle_nn::linear(LAYER1_OUT_SIZE, LAYER2_OUT_SIZE, vs.pp("ln2"))?;
let ln3 = candle_nn::linear(LAYER2_OUT_SIZE, RESULTS + 1, vs.pp("ln3"))?;
Ok(Self { ln1, ln2, ln3 })
}
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let xs = self.ln1.forward(xs)?;
let xs = xs.relu()?;
let xs = self.ln2.forward(&xs)?;
let xs = xs.relu()?;
self.ln3.forward(&xs)
}
}
// ANCHOR_END: book_training_simplified1
// ANCHOR: book_training_simplified3
#[tokio::test]
async fn simplified() -> anyhow::Result<()> {
let dev = Device::cuda_if_available(0)?;
let train_votes_vec: Vec<u32> = vec![
15, 10,
10, 15,
5, 12,
30, 20,
16, 12,
13, 25,
6, 14,
31, 21,
];
let train_votes_tensor = Tensor::from_vec(train_votes_vec.clone(), (train_votes_vec.len() / VOTE_DIM, VOTE_DIM), &dev)?.to_dtype(DType::F32)?;
let train_results_vec: Vec<u32> = vec![
1,
0,
0,
1,
1,
0,
0,
1,
];
let train_results_tensor = Tensor::from_vec(train_results_vec, train_votes_vec.len() / VOTE_DIM, &dev)?;
let test_votes_vec: Vec<u32> = vec![
13, 9,
8, 14,
3, 10,
];
let test_votes_tensor = Tensor::from_vec(test_votes_vec.clone(), (test_votes_vec.len() / VOTE_DIM, VOTE_DIM), &dev)?.to_dtype(DType::F32)?;
let test_results_vec: Vec<u32> = vec![
1,
0,
0,
];
let test_results_tensor = Tensor::from_vec(test_results_vec.clone(), test_results_vec.len(), &dev)?;
let m = Dataset {
train_votes: train_votes_tensor,
train_results: train_results_tensor,
test_votes: test_votes_tensor,
test_results: test_results_tensor,
};
let trained_model: MultiLevelPerceptron;
loop {
println!("Trying to train neural network.");
match train(m.clone(), &dev) {
Ok(model) => {
trained_model = model;
break;
},
Err(e) => {
println!("Error: {}", e);
continue;
}
}
}
let real_world_votes: Vec<u32> = vec![
13, 22,
];
let tensor_test_votes = Tensor::from_vec(real_world_votes.clone(), (1, VOTE_DIM), &dev)?.to_dtype(DType::F32)?;
let final_result = trained_model.forward(&tensor_test_votes)?;
let result = final_result
.argmax(D::Minus1)?
.to_dtype(DType::F32)?
.get(0).map(|x| x.to_scalar::<f32>())??;
println!("real_life_votes: {:?}", real_world_votes);
println!("neural_network_prediction_result: {:?}", result);
Ok(())
}
// ANCHOR_END: book_training_simplified3
// ANCHOR: book_training_simplified2
fn train(m: Dataset, dev: &Device) -> anyhow::Result<MultiLevelPerceptron> {
let train_results = m.train_results.to_device(dev)?;
let train_votes = m.train_votes.to_device(dev)?;
let varmap = VarMap::new();
let vs = VarBuilder::from_varmap(&varmap, DType::F32, dev);
let model = MultiLevelPerceptron::new(vs.clone())?;
let mut sgd = candle_nn::SGD::new(varmap.all_vars(), LEARNING_RATE)?;
let test_votes = m.test_votes.to_device(dev)?;
let test_results = m.test_results.to_device(dev)?;
let mut final_accuracy: f32 = 0.0;
for epoch in 1..EPOCHS + 1 {
let logits = model.forward(&train_votes)?;
let log_sm = ops::log_softmax(&logits, D::Minus1)?;
let loss = loss::nll(&log_sm, &train_results)?;
sgd.backward_step(&loss)?;
let test_logits = model.forward(&test_votes)?;
let sum_ok = test_logits
.argmax(D::Minus1)?
.eq(&test_results)?
.to_dtype(DType::F32)?
.sum_all()?
.to_scalar::<f32>()?;
let test_accuracy = sum_ok / test_results.dims1()? as f32;
final_accuracy = 100. * test_accuracy;
println!("Epoch: {epoch:3} Train loss: {:8.5} Test accuracy: {:5.2}%",
loss.to_scalar::<f32>()?,
final_accuracy
);
if final_accuracy == 100.0 {
break;
}
}
if final_accuracy < 100.0 {
Err(anyhow::Error::msg("The model is not trained well enough."))
} else {
Ok(model)
}
}
// ANCHOR_END: book_training_simplified2
}

View File

@ -0,0 +1,45 @@
# Simplified
## How its works
This program implements a neural network to predict the winner of the second round of elections based on the results of the first round.
Basic moments:
1. A multilayer perceptron with two hidden layers is used. The first hidden layer has 4 neurons, the second has 2 neurons.
2. The input is a vector of 2 numbers - the percentage of votes for the first and second candidates in the first stage.
3. The output is the number 0 or 1, where 1 means that the first candidate will win in the second stage, 0 means that he will lose.
4. For training, samples with real data on the results of the first and second stages of different elections are used.
5. The model is trained by backpropagation using gradient descent and the cross-entropy loss function.
6. Model parameters (weights of neurons) are initialized randomly, then optimized during training.
7. After training, the model is tested on a deferred sample to evaluate the accuracy.
8. If the accuracy on the test set is below 100%, the model is considered underfit and the learning process is repeated.
Thus, this neural network learns to find hidden relationships between the results of the first and second rounds of voting in order to make predictions for new data.
```rust,ignore
{{#include ../simplified.rs:book_training_simplified1}}
```
```rust,ignore
{{#include ../simplified.rs:book_training_simplified2}}
```
```rust,ignore
{{#include ../simplified.rs:book_training_simplified3}}
```
## Example output
```bash
Trying to train neural network.
Epoch: 1 Train loss: 4.42555 Test accuracy: 0.00%
Epoch: 2 Train loss: 0.84677 Test accuracy: 33.33%
Epoch: 3 Train loss: 2.54335 Test accuracy: 33.33%
Epoch: 4 Train loss: 0.37806 Test accuracy: 33.33%
Epoch: 5 Train loss: 0.36647 Test accuracy: 100.00%
real_life_votes: [13, 22]
neural_network_prediction_result: 0.0
```

View File

@ -12,7 +12,7 @@ readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
byteorder = { workspace = true }
candle-kernels = { path = "../candle-kernels", version = "0.2.1", optional = true }
candle-kernels = { path = "../candle-kernels", version = "0.3.0", optional = true }
cudarc = { workspace = true, optional = true }
gemm = { workspace = true }
half = { workspace = true }
@ -26,6 +26,7 @@ rand_distr = { workspace = true }
rayon = { workspace = true }
safetensors = { workspace = true }
thiserror = { workspace = true }
yoke = { workspace = true }
zip = { workspace = true }
[dev-dependencies]

View File

@ -103,8 +103,10 @@ enum Command {
Quantize {
/// The input file, in gguf format.
in_file: std::path::PathBuf,
in_file: Vec<std::path::PathBuf>,
/// The output file, in gguf format.
#[arg(long)]
out_file: std::path::PathBuf,
/// The quantization schema to apply.
@ -150,8 +152,7 @@ fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> R
}
}
Format::Safetensors => {
let tensors = unsafe { candle_core::safetensors::MmapedFile::new(file)? };
let tensors = tensors.deserialize()?;
let tensors = unsafe { candle_core::safetensors::MmapedSafetensors::new(file)? };
let mut tensors = tensors.tensors();
tensors.sort_by(|a, b| a.0.cmp(&b.0));
for (name, view) in tensors.iter() {
@ -218,15 +219,99 @@ fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> R
Ok(())
}
fn run_quantize_safetensors(
in_files: &[std::path::PathBuf],
out_file: std::path::PathBuf,
q: Quantization,
) -> Result<()> {
let mut out_file = std::fs::File::create(out_file)?;
let mut tensors = std::collections::HashMap::new();
for in_file in in_files.iter() {
let in_tensors = candle_core::safetensors::load(in_file, &Device::Cpu)?;
tensors.extend(in_tensors)
}
println!("tensors: {}", tensors.len());
let quantize_fn = match q {
Quantization::Q4_0 => QTensor::quantize::<k_quants::BlockQ4_0>,
Quantization::Q4_1 => QTensor::quantize::<k_quants::BlockQ4_1>,
Quantization::Q5_0 => QTensor::quantize::<k_quants::BlockQ5_0>,
Quantization::Q5_1 => QTensor::quantize::<k_quants::BlockQ5_1>,
Quantization::Q8_0 => QTensor::quantize::<k_quants::BlockQ8_0>,
Quantization::Q8_1 => QTensor::quantize::<k_quants::BlockQ8_1>,
Quantization::Q2k => QTensor::quantize::<k_quants::BlockQ2K>,
Quantization::Q3k => QTensor::quantize::<k_quants::BlockQ3K>,
Quantization::Q4k => QTensor::quantize::<k_quants::BlockQ4K>,
Quantization::Q5k => QTensor::quantize::<k_quants::BlockQ5K>,
Quantization::Q6k => QTensor::quantize::<k_quants::BlockQ6K>,
Quantization::Q8k => QTensor::quantize::<k_quants::BlockQ8K>,
Quantization::F16 => QTensor::quantize::<half::f16>,
Quantization::F32 => QTensor::quantize::<f32>,
};
let block_size = match q {
Quantization::Q4_0 => k_quants::QK4_0,
Quantization::Q4_1 => k_quants::QK4_1,
Quantization::Q5_0 => k_quants::QK5_0,
Quantization::Q5_1 => k_quants::QK5_1,
Quantization::Q8_0 => k_quants::QK8_0,
Quantization::Q8_1 => k_quants::QK8_1,
Quantization::Q2k
| Quantization::Q3k
| Quantization::Q4k
| Quantization::Q5k
| Quantization::Q6k
| Quantization::Q8k => k_quants::QK_K,
Quantization::F16 | Quantization::F32 => 1,
};
let qtensors = tensors
.into_par_iter()
.map(|(name, tensor)| {
let should_quantize = tensor.rank() == 2 && tensor.dim(1)? % block_size == 0;
println!(" quantizing {name} {tensor:?} {should_quantize}");
let tensor = if should_quantize {
quantize_fn(&tensor)?
} else {
QTensor::quantize::<f32>(&tensor)?
};
Ok((name, tensor))
})
.collect::<Result<Vec<_>>>()?;
let qtensors = qtensors
.iter()
.map(|(k, v)| (k.as_str(), v))
.collect::<Vec<_>>();
gguf_file::write(&mut out_file, &[], &qtensors)?;
Ok(())
}
fn run_quantize(
in_file: std::path::PathBuf,
in_files: &[std::path::PathBuf],
out_file: std::path::PathBuf,
q: Quantization,
qmode: QuantizationMode,
) -> Result<()> {
if in_files.is_empty() {
candle_core::bail!("no specified input files")
}
if let Some(extension) = out_file.extension() {
if extension == "safetensors" {
candle_core::bail!("the generated file cannot use the safetensors extension")
}
}
if let Some(extension) = in_files[0].extension() {
if extension == "safetensors" {
return run_quantize_safetensors(in_files, out_file, q);
}
}
if in_files.len() != 1 {
candle_core::bail!("only a single in-file can be used when quantizing gguf files")
}
// Open the out file early so as to fail directly on missing directories etc.
let mut out_file = std::fs::File::create(out_file)?;
let mut in_ = std::fs::File::open(&in_file)?;
let mut in_ = std::fs::File::open(&in_files[0])?;
let content = gguf_file::Content::read(&mut in_)?;
println!("tensors: {}", content.tensor_infos.len());
@ -252,7 +337,7 @@ fn run_quantize(
.par_iter()
.map(|(name, _)| {
println!(" quantizing {name}");
let mut in_file = std::fs::File::open(&in_file)?;
let mut in_file = std::fs::File::open(&in_files[0])?;
let tensor = content.tensor(&mut in_file, name)?;
let tensor = qmode.quantize(name, tensor, quantize_fn)?;
Ok((name, tensor))
@ -293,7 +378,7 @@ fn main() -> anyhow::Result<()> {
out_file,
quantization,
mode,
} => run_quantize(in_file, out_file, quantization, mode)?,
} => run_quantize(&in_file, out_file, quantization, mode)?,
}
Ok(())
}

View File

@ -370,6 +370,38 @@ pub fn vd_sqr(a: &[f64], y: &mut [f64]) {
y.iter_mut().zip(a.iter()).for_each(|(y, a)| *y = *a * *a)
}
#[inline]
pub fn vs_tanh_inplace(y: &mut [f32]) {
unsafe { ffi::vvtanhf(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
}
#[inline]
pub fn vd_tanh_inplace(y: &mut [f64]) {
unsafe { ffi::vvtanh(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
}
#[inline]
pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) {
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
*y = (2.0f32 / std::f32::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)
}
vs_tanh_inplace(ys);
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
*y = 0.5 * v * (1.0 + *y)
}
}
#[inline]
pub fn vd_gelu(vs: &[f64], ys: &mut [f64]) {
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
*y = (2.0f64 / std::f64::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)
}
vd_tanh_inplace(ys);
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
*y = 0.5 * v * (1.0 + *y)
}
}
macro_rules! binary_op {
($fn_name:ident, $ty:ty, $accelerate_name:ident) => {
#[inline]

View File

@ -57,6 +57,7 @@ pub trait BackendStorage: Sized {
fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self>;
fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self>;
fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result<Self>;
fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self>;
fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self>;
@ -110,4 +111,6 @@ pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;
fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;
fn set_seed(&self, _: u64) -> Result<()>;
}

View File

@ -69,7 +69,8 @@ impl Tensor {
| Op::Binary(lhs, rhs, _)
| Op::Gather(lhs, rhs, _)
| Op::IndexSelect(lhs, rhs, _)
| Op::Matmul(lhs, rhs) => {
| Op::Matmul(lhs, rhs)
| Op::SliceScatter0(lhs, rhs, _) => {
let (tg, nodes) = walk(lhs, nodes, already_seen);
track_grad |= tg;
let (tg, nodes) = walk(rhs, nodes, already_seen);
@ -90,14 +91,18 @@ impl Tensor {
nodes
}
}
Op::Unary(_node, UnaryOp::Ceil)
| Op::Unary(_node, UnaryOp::Floor)
| Op::Unary(_node, UnaryOp::Round) => nodes,
Op::Reshape(node)
| Op::UpsampleNearest1D(node)
| Op::UpsampleNearest2D(node)
| Op::AvgPool2D { arg: node, .. }
| Op::MaxPool2D { arg: node, .. }
| Op::Copy(node)
| Op::Broadcast(node)
| Op::Cmp(node, _)
| Op::Reduce(node, _, _)
| Op::Reduce(node, ReduceOp::Min | ReduceOp::Sum | ReduceOp::Max, _)
| Op::ToDType(node)
| Op::ToDevice(node)
| Op::Transpose(node, _, _)
@ -111,6 +116,7 @@ impl Tensor {
track_grad |= tg;
nodes
}
Op::Reduce(_, ReduceOp::ArgMin | ReduceOp::ArgMax, _) => nodes,
}
} else {
nodes
@ -262,9 +268,21 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad_arg)?;
}
Op::UpsampleNearest1D { .. } => Err(Error::BackwardNotSupported {
op: "upsample-nearest1d",
})?,
Op::UpsampleNearest2D { .. } => Err(Error::BackwardNotSupported {
op: "upsample-nearest2d",
})?,
Op::SliceScatter0(lhs, rhs, start_rhs) => {
let rhs_sum_grad = grads.or_insert(rhs)?;
let rhs_grad = grad.narrow(0, *start_rhs, rhs.dim(0)?)?;
*rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
let lhs_sum_grad = grads.or_insert(lhs)?;
let lhs_grad = grad.slice_scatter0(&rhs.zeros_like()?, *start_rhs)?;
*lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?
}
Op::Gather(arg, indexes, dim) => {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.scatter_add(indexes, &grad, *dim)?;
@ -436,7 +454,18 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&arg_grad)?
}
Op::Unary(_, UnaryOp::Ceil) => Err(Error::BackwardNotSupported { op: "ceil" })?,
Op::Unary(_, UnaryOp::Floor) => {
Err(Error::BackwardNotSupported { op: "floor" })?
}
Op::Unary(_, UnaryOp::Round) => {
Err(Error::BackwardNotSupported { op: "round" })?
}
Op::Unary(_, UnaryOp::Gelu) => Err(Error::BackwardNotSupported { op: "gelu" })?,
Op::Unary(_, UnaryOp::Erf) => Err(Error::BackwardNotSupported { op: "erf" })?,
Op::Unary(_, UnaryOp::GeluErf) => {
Err(Error::BackwardNotSupported { op: "gelu-erf" })?
}
Op::Unary(arg, UnaryOp::Relu) => {
let sum_grad = grads.or_insert(arg)?;
let relu_grad = arg.ge(&arg.zeros_like()?)?.to_dtype(arg.dtype())?;
@ -517,6 +546,7 @@ impl Tensor {
}
}
#[derive(Debug)]
pub struct GradStore(HashMap<TensorId, Tensor>);
impl GradStore {

View File

@ -25,6 +25,19 @@ impl ParamsConv1D {
}
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum CudnnFwdAlgo {
ImplicitGemm,
ImplicitPrecompGemm,
Gemm,
Direct,
Fft,
FftTiling,
Winograd,
WinogradNonFused,
Count,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParamsConv2D {
pub(crate) b_size: usize,
@ -37,6 +50,7 @@ pub struct ParamsConv2D {
pub(crate) padding: usize,
pub(crate) stride: usize,
pub(crate) dilation: usize,
pub cudnn_fwd_algo: Option<CudnnFwdAlgo>,
}
impl ParamsConv2D {
@ -188,6 +202,7 @@ impl Tensor {
padding,
stride,
dilation,
cudnn_fwd_algo: None,
};
if groups == 1 {
self.conv2d_single_group(kernel, &params)

763
candle-core/src/cpu/erf.rs Normal file
View File

@ -0,0 +1,763 @@
#![allow(clippy::excessive_precision)]
// Code taken from https://github.com/statrs-dev/statrs
//! Provides the [error](https://en.wikipedia.org/wiki/Error_function) and
//! related functions
mod evaluate {
//! Provides functions that don't have a numerical solution and must
//! be solved computationally (e.g. evaluation of a polynomial)
/// evaluates a polynomial at `z` where `coeff` are the coeffecients
/// to a polynomial of order `k` where `k` is the length of `coeff` and the
/// coeffecient
/// to the `k`th power is the `k`th element in coeff. E.g. [3,-1,2] equates to
/// `2z^2 - z + 3`
///
/// # Remarks
///
/// Returns 0 for a 0 length coefficient slice
pub fn polynomial(z: f64, coeff: &[f64]) -> f64 {
let n = coeff.len();
if n == 0 {
return 0.0;
}
let mut sum = *coeff.last().unwrap();
for c in coeff[0..n - 1].iter().rev() {
sum = *c + z * sum;
}
sum
}
}
use std::f64;
/// `erf` calculates the error function at `x`.
pub fn erf(x: f64) -> f64 {
if x.is_nan() {
f64::NAN
} else if x >= 0.0 && x.is_infinite() {
1.0
} else if x <= 0.0 && x.is_infinite() {
-1.0
} else if x == 0. {
0.0
} else {
erf_impl(x, false)
}
}
/// `erf_inv` calculates the inverse error function
/// at `x`.
pub fn erf_inv(x: f64) -> f64 {
if x == 0.0 {
0.0
} else if x >= 1.0 {
f64::INFINITY
} else if x <= -1.0 {
f64::NEG_INFINITY
} else if x < 0.0 {
erf_inv_impl(-x, 1.0 + x, -1.0)
} else {
erf_inv_impl(x, 1.0 - x, 1.0)
}
}
/// `erfc` calculates the complementary error function
/// at `x`.
pub fn erfc(x: f64) -> f64 {
if x.is_nan() {
f64::NAN
} else if x == f64::INFINITY {
0.0
} else if x == f64::NEG_INFINITY {
2.0
} else {
erf_impl(x, true)
}
}
/// `erfc_inv` calculates the complementary inverse
/// error function at `x`.
pub fn erfc_inv(x: f64) -> f64 {
if x <= 0.0 {
f64::INFINITY
} else if x >= 2.0 {
f64::NEG_INFINITY
} else if x > 1.0 {
erf_inv_impl(-1.0 + x, 2.0 - x, -1.0)
} else {
erf_inv_impl(1.0 - x, x, 1.0)
}
}
// **********************************************************
// ********** Coefficients for erf_impl polynomial **********
// **********************************************************
/// Polynomial coefficients for a numerator of `erf_impl`
/// in the interval [1e-10, 0.5].
const ERF_IMPL_AN: &[f64] = &[
0.00337916709551257388990745,
-0.00073695653048167948530905,
-0.374732337392919607868241,
0.0817442448733587196071743,
-0.0421089319936548595203468,
0.0070165709512095756344528,
-0.00495091255982435110337458,
0.000871646599037922480317225,
];
/// Polynomial coefficients for a denominator of `erf_impl`
/// in the interval [1e-10, 0.5]
const ERF_IMPL_AD: &[f64] = &[
1.0,
-0.218088218087924645390535,
0.412542972725442099083918,
-0.0841891147873106755410271,
0.0655338856400241519690695,
-0.0120019604454941768171266,
0.00408165558926174048329689,
-0.000615900721557769691924509,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [0.5, 0.75].
const ERF_IMPL_BN: &[f64] = &[
-0.0361790390718262471360258,
0.292251883444882683221149,
0.281447041797604512774415,
0.125610208862766947294894,
0.0274135028268930549240776,
0.00250839672168065762786937,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [0.5, 0.75].
const ERF_IMPL_BD: &[f64] = &[
1.0,
1.8545005897903486499845,
1.43575803037831418074962,
0.582827658753036572454135,
0.124810476932949746447682,
0.0113724176546353285778481,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [0.75, 1.25].
const ERF_IMPL_CN: &[f64] = &[
-0.0397876892611136856954425,
0.153165212467878293257683,
0.191260295600936245503129,
0.10276327061989304213645,
0.029637090615738836726027,
0.0046093486780275489468812,
0.000307607820348680180548455,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [0.75, 1.25].
const ERF_IMPL_CD: &[f64] = &[
1.0,
1.95520072987627704987886,
1.64762317199384860109595,
0.768238607022126250082483,
0.209793185936509782784315,
0.0319569316899913392596356,
0.00213363160895785378615014,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [1.25, 2.25].
const ERF_IMPL_DN: &[f64] = &[
-0.0300838560557949717328341,
0.0538578829844454508530552,
0.0726211541651914182692959,
0.0367628469888049348429018,
0.00964629015572527529605267,
0.00133453480075291076745275,
0.778087599782504251917881e-4,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [1.25, 2.25].
const ERF_IMPL_DD: &[f64] = &[
1.0,
1.75967098147167528287343,
1.32883571437961120556307,
0.552528596508757581287907,
0.133793056941332861912279,
0.0179509645176280768640766,
0.00104712440019937356634038,
-0.106640381820357337177643e-7,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [2.25, 3.5].
const ERF_IMPL_EN: &[f64] = &[
-0.0117907570137227847827732,
0.014262132090538809896674,
0.0202234435902960820020765,
0.00930668299990432009042239,
0.00213357802422065994322516,
0.00025022987386460102395382,
0.120534912219588189822126e-4,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [2.25, 3.5].
const ERF_IMPL_ED: &[f64] = &[
1.0,
1.50376225203620482047419,
0.965397786204462896346934,
0.339265230476796681555511,
0.0689740649541569716897427,
0.00771060262491768307365526,
0.000371421101531069302990367,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [3.5, 5.25].
const ERF_IMPL_FN: &[f64] = &[
-0.00546954795538729307482955,
0.00404190278731707110245394,
0.0054963369553161170521356,
0.00212616472603945399437862,
0.000394984014495083900689956,
0.365565477064442377259271e-4,
0.135485897109932323253786e-5,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [3.5, 5.25].
const ERF_IMPL_FD: &[f64] = &[
1.0,
1.21019697773630784832251,
0.620914668221143886601045,
0.173038430661142762569515,
0.0276550813773432047594539,
0.00240625974424309709745382,
0.891811817251336577241006e-4,
-0.465528836283382684461025e-11,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [5.25, 8].
const ERF_IMPL_GN: &[f64] = &[
-0.00270722535905778347999196,
0.0013187563425029400461378,
0.00119925933261002333923989,
0.00027849619811344664248235,
0.267822988218331849989363e-4,
0.923043672315028197865066e-6,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [5.25, 8].
const ERF_IMPL_GD: &[f64] = &[
1.0,
0.814632808543141591118279,
0.268901665856299542168425,
0.0449877216103041118694989,
0.00381759663320248459168994,
0.000131571897888596914350697,
0.404815359675764138445257e-11,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [8, 11.5].
const ERF_IMPL_HN: &[f64] = &[
-0.00109946720691742196814323,
0.000406425442750422675169153,
0.000274499489416900707787024,
0.465293770646659383436343e-4,
0.320955425395767463401993e-5,
0.778286018145020892261936e-7,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [8, 11.5].
const ERF_IMPL_HD: &[f64] = &[
1.0,
0.588173710611846046373373,
0.139363331289409746077541,
0.0166329340417083678763028,
0.00100023921310234908642639,
0.24254837521587225125068e-4,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [11.5, 17].
const ERF_IMPL_IN: &[f64] = &[
-0.00056907993601094962855594,
0.000169498540373762264416984,
0.518472354581100890120501e-4,
0.382819312231928859704678e-5,
0.824989931281894431781794e-7,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [11.5, 17].
const ERF_IMPL_ID: &[f64] = &[
1.0,
0.339637250051139347430323,
0.043472647870310663055044,
0.00248549335224637114641629,
0.535633305337152900549536e-4,
-0.117490944405459578783846e-12,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [17, 24].
const ERF_IMPL_JN: &[f64] = &[
-0.000241313599483991337479091,
0.574224975202501512365975e-4,
0.115998962927383778460557e-4,
0.581762134402593739370875e-6,
0.853971555085673614607418e-8,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [17, 24].
const ERF_IMPL_JD: &[f64] = &[
1.0,
0.233044138299687841018015,
0.0204186940546440312625597,
0.000797185647564398289151125,
0.117019281670172327758019e-4,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [24, 38].
const ERF_IMPL_KN: &[f64] = &[
-0.000146674699277760365803642,
0.162666552112280519955647e-4,
0.269116248509165239294897e-5,
0.979584479468091935086972e-7,
0.101994647625723465722285e-8,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [24, 38].
const ERF_IMPL_KD: &[f64] = &[
1.0,
0.165907812944847226546036,
0.0103361716191505884359634,
0.000286593026373868366935721,
0.298401570840900340874568e-5,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [38, 60].
const ERF_IMPL_LN: &[f64] = &[
-0.583905797629771786720406e-4,
0.412510325105496173512992e-5,
0.431790922420250949096906e-6,
0.993365155590013193345569e-8,
0.653480510020104699270084e-10,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [38, 60].
const ERF_IMPL_LD: &[f64] = &[
1.0,
0.105077086072039915406159,
0.00414278428675475620830226,
0.726338754644523769144108e-4,
0.477818471047398785369849e-6,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [60, 85].
const ERF_IMPL_MN: &[f64] = &[
-0.196457797609229579459841e-4,
0.157243887666800692441195e-5,
0.543902511192700878690335e-7,
0.317472492369117710852685e-9,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [60, 85].
const ERF_IMPL_MD: &[f64] = &[
1.0,
0.052803989240957632204885,
0.000926876069151753290378112,
0.541011723226630257077328e-5,
0.535093845803642394908747e-15,
];
/// Polynomial coefficients for a numerator in `erf_impl`
/// in the interval [85, 110].
const ERF_IMPL_NN: &[f64] = &[
-0.789224703978722689089794e-5,
0.622088451660986955124162e-6,
0.145728445676882396797184e-7,
0.603715505542715364529243e-10,
];
/// Polynomial coefficients for a denominator in `erf_impl`
/// in the interval [85, 110].
const ERF_IMPL_ND: &[f64] = &[
1.0,
0.0375328846356293715248719,
0.000467919535974625308126054,
0.193847039275845656900547e-5,
];
// **********************************************************
// ********** Coefficients for erf_inv_impl polynomial ******
// **********************************************************
/// Polynomial coefficients for a numerator of `erf_inv_impl`
/// in the interval [0, 0.5].
const ERF_INV_IMPL_AN: &[f64] = &[
-0.000508781949658280665617,
-0.00836874819741736770379,
0.0334806625409744615033,
-0.0126926147662974029034,
-0.0365637971411762664006,
0.0219878681111168899165,
0.00822687874676915743155,
-0.00538772965071242932965,
];
/// Polynomial coefficients for a denominator of `erf_inv_impl`
/// in the interval [0, 0.5].
const ERF_INV_IMPL_AD: &[f64] = &[
1.0,
-0.970005043303290640362,
-1.56574558234175846809,
1.56221558398423026363,
0.662328840472002992063,
-0.71228902341542847553,
-0.0527396382340099713954,
0.0795283687341571680018,
-0.00233393759374190016776,
0.000886216390456424707504,
];
/// Polynomial coefficients for a numerator of `erf_inv_impl`
/// in the interval [0.5, 0.75].
const ERF_INV_IMPL_BN: &[f64] = &[
-0.202433508355938759655,
0.105264680699391713268,
8.37050328343119927838,
17.6447298408374015486,
-18.8510648058714251895,
-44.6382324441786960818,
17.445385985570866523,
21.1294655448340526258,
-3.67192254707729348546,
];
/// Polynomial coefficients for a denominator of `erf_inv_impl`
/// in the interval [0.5, 0.75].
const ERF_INV_IMPL_BD: &[f64] = &[
1.0,
6.24264124854247537712,
3.9713437953343869095,
-28.6608180499800029974,
-20.1432634680485188801,
48.5609213108739935468,
10.8268667355460159008,
-22.6436933413139721736,
1.72114765761200282724,
];
/// Polynomial coefficients for a numerator of `erf_inv_impl`
/// in the interval [0.75, 1] with x less than 3.
const ERF_INV_IMPL_CN: &[f64] = &[
-0.131102781679951906451,
-0.163794047193317060787,
0.117030156341995252019,
0.387079738972604337464,
0.337785538912035898924,
0.142869534408157156766,
0.0290157910005329060432,
0.00214558995388805277169,
-0.679465575181126350155e-6,
0.285225331782217055858e-7,
-0.681149956853776992068e-9,
];
/// Polynomial coefficients for a denominator of `erf_inv_impl`
/// in the interval [0.75, 1] with x less than 3.
const ERF_INV_IMPL_CD: &[f64] = &[
1.0,
3.46625407242567245975,
5.38168345707006855425,
4.77846592945843778382,
2.59301921623620271374,
0.848854343457902036425,
0.152264338295331783612,
0.01105924229346489121,
];
/// Polynomial coefficients for a numerator of `erf_inv_impl`
/// in the interval [0.75, 1] with x between 3 and 6.
const ERF_INV_IMPL_DN: &[f64] = &[
-0.0350353787183177984712,
-0.00222426529213447927281,
0.0185573306514231072324,
0.00950804701325919603619,
0.00187123492819559223345,
0.000157544617424960554631,
0.460469890584317994083e-5,
-0.230404776911882601748e-9,
0.266339227425782031962e-11,
];
/// Polynomial coefficients for a denominator of `erf_inv_impl`
/// in the interval [0.75, 1] with x between 3 and 6.
const ERF_INV_IMPL_DD: &[f64] = &[
1.0,
1.3653349817554063097,
0.762059164553623404043,
0.220091105764131249824,
0.0341589143670947727934,
0.00263861676657015992959,
0.764675292302794483503e-4,
];
/// Polynomial coefficients for a numerator of `erf_inv_impl`
/// in the interval [0.75, 1] with x between 6 and 18.
const ERF_INV_IMPL_EN: &[f64] = &[
-0.0167431005076633737133,
-0.00112951438745580278863,
0.00105628862152492910091,
0.000209386317487588078668,
0.149624783758342370182e-4,
0.449696789927706453732e-6,
0.462596163522878599135e-8,
-0.281128735628831791805e-13,
0.99055709973310326855e-16,
];
/// Polynomial coefficients for a denominator of `erf_inv_impl`
/// in the interval [0.75, 1] with x between 6 and 18.
const ERF_INV_IMPL_ED: &[f64] = &[
1.0,
0.591429344886417493481,
0.138151865749083321638,
0.0160746087093676504695,
0.000964011807005165528527,
0.275335474764726041141e-4,
0.282243172016108031869e-6,
];
/// Polynomial coefficients for a numerator of `erf_inv_impl`
/// in the interval [0.75, 1] with x between 18 and 44.
const ERF_INV_IMPL_FN: &[f64] = &[
-0.0024978212791898131227,
-0.779190719229053954292e-5,
0.254723037413027451751e-4,
0.162397777342510920873e-5,
0.396341011304801168516e-7,
0.411632831190944208473e-9,
0.145596286718675035587e-11,
-0.116765012397184275695e-17,
];
/// Polynomial coefficients for a denominator of `erf_inv_impl`
/// in the interval [0.75, 1] with x between 18 and 44.
const ERF_INV_IMPL_FD: &[f64] = &[
1.0,
0.207123112214422517181,
0.0169410838120975906478,
0.000690538265622684595676,
0.145007359818232637924e-4,
0.144437756628144157666e-6,
0.509761276599778486139e-9,
];
/// Polynomial coefficients for a numerator of `erf_inv_impl`
/// in the interval [0.75, 1] with x greater than 44.
const ERF_INV_IMPL_GN: &[f64] = &[
-0.000539042911019078575891,
-0.28398759004727721098e-6,
0.899465114892291446442e-6,
0.229345859265920864296e-7,
0.225561444863500149219e-9,
0.947846627503022684216e-12,
0.135880130108924861008e-14,
-0.348890393399948882918e-21,
];
/// Polynomial coefficients for a denominator of `erf_inv_impl`
/// in the interval [0.75, 1] with x greater than 44.
const ERF_INV_IMPL_GD: &[f64] = &[
1.0,
0.0845746234001899436914,
0.00282092984726264681981,
0.468292921940894236786e-4,
0.399968812193862100054e-6,
0.161809290887904476097e-8,
0.231558608310259605225e-11,
];
/// `erf_impl` computes the error function at `z`.
/// If `inv` is true, `1 - erf` is calculated as opposed to `erf`
fn erf_impl(z: f64, inv: bool) -> f64 {
if z < 0.0 {
if !inv {
return -erf_impl(-z, false);
}
if z < -0.5 {
return 2.0 - erf_impl(-z, true);
}
return 1.0 + erf_impl(-z, false);
}
let result = if z < 0.5 {
if z < 1e-10 {
z * 1.125 + z * 0.003379167095512573896158903121545171688
} else {
z * 1.125
+ z * evaluate::polynomial(z, ERF_IMPL_AN) / evaluate::polynomial(z, ERF_IMPL_AD)
}
} else if z < 110.0 {
let (r, b) = if z < 0.75 {
(
evaluate::polynomial(z - 0.5, ERF_IMPL_BN)
/ evaluate::polynomial(z - 0.5, ERF_IMPL_BD),
0.3440242112,
)
} else if z < 1.25 {
(
evaluate::polynomial(z - 0.75, ERF_IMPL_CN)
/ evaluate::polynomial(z - 0.75, ERF_IMPL_CD),
0.419990927,
)
} else if z < 2.25 {
(
evaluate::polynomial(z - 1.25, ERF_IMPL_DN)
/ evaluate::polynomial(z - 1.25, ERF_IMPL_DD),
0.4898625016,
)
} else if z < 3.5 {
(
evaluate::polynomial(z - 2.25, ERF_IMPL_EN)
/ evaluate::polynomial(z - 2.25, ERF_IMPL_ED),
0.5317370892,
)
} else if z < 5.25 {
(
evaluate::polynomial(z - 3.5, ERF_IMPL_FN)
/ evaluate::polynomial(z - 3.5, ERF_IMPL_FD),
0.5489973426,
)
} else if z < 8.0 {
(
evaluate::polynomial(z - 5.25, ERF_IMPL_GN)
/ evaluate::polynomial(z - 5.25, ERF_IMPL_GD),
0.5571740866,
)
} else if z < 11.5 {
(
evaluate::polynomial(z - 8.0, ERF_IMPL_HN)
/ evaluate::polynomial(z - 8.0, ERF_IMPL_HD),
0.5609807968,
)
} else if z < 17.0 {
(
evaluate::polynomial(z - 11.5, ERF_IMPL_IN)
/ evaluate::polynomial(z - 11.5, ERF_IMPL_ID),
0.5626493692,
)
} else if z < 24.0 {
(
evaluate::polynomial(z - 17.0, ERF_IMPL_JN)
/ evaluate::polynomial(z - 17.0, ERF_IMPL_JD),
0.5634598136,
)
} else if z < 38.0 {
(
evaluate::polynomial(z - 24.0, ERF_IMPL_KN)
/ evaluate::polynomial(z - 24.0, ERF_IMPL_KD),
0.5638477802,
)
} else if z < 60.0 {
(
evaluate::polynomial(z - 38.0, ERF_IMPL_LN)
/ evaluate::polynomial(z - 38.0, ERF_IMPL_LD),
0.5640528202,
)
} else if z < 85.0 {
(
evaluate::polynomial(z - 60.0, ERF_IMPL_MN)
/ evaluate::polynomial(z - 60.0, ERF_IMPL_MD),
0.5641309023,
)
} else {
(
evaluate::polynomial(z - 85.0, ERF_IMPL_NN)
/ evaluate::polynomial(z - 85.0, ERF_IMPL_ND),
0.5641584396,
)
};
let g = (-z * z).exp() / z;
g * b + g * r
} else {
0.0
};
if inv && z >= 0.5 {
result
} else if z >= 0.5 || inv {
1.0 - result
} else {
result
}
}
// `erf_inv_impl` computes the inverse error function where
// `p`,`q`, and `s` are the first, second, and third intermediate
// parameters respectively
fn erf_inv_impl(p: f64, q: f64, s: f64) -> f64 {
let result = if p <= 0.5 {
let y = 0.0891314744949340820313;
let g = p * (p + 10.0);
let r = evaluate::polynomial(p, ERF_INV_IMPL_AN) / evaluate::polynomial(p, ERF_INV_IMPL_AD);
g * y + g * r
} else if q >= 0.25 {
let y = 2.249481201171875;
let g = (-2.0 * q.ln()).sqrt();
let xs = q - 0.25;
let r =
evaluate::polynomial(xs, ERF_INV_IMPL_BN) / evaluate::polynomial(xs, ERF_INV_IMPL_BD);
g / (y + r)
} else {
let x = (-q.ln()).sqrt();
if x < 3.0 {
let y = 0.807220458984375;
let xs = x - 1.125;
let r = evaluate::polynomial(xs, ERF_INV_IMPL_CN)
/ evaluate::polynomial(xs, ERF_INV_IMPL_CD);
y * x + r * x
} else if x < 6.0 {
let y = 0.93995571136474609375;
let xs = x - 3.0;
let r = evaluate::polynomial(xs, ERF_INV_IMPL_DN)
/ evaluate::polynomial(xs, ERF_INV_IMPL_DD);
y * x + r * x
} else if x < 18.0 {
let y = 0.98362827301025390625;
let xs = x - 6.0;
let r = evaluate::polynomial(xs, ERF_INV_IMPL_EN)
/ evaluate::polynomial(xs, ERF_INV_IMPL_ED);
y * x + r * x
} else if x < 44.0 {
let y = 0.99714565277099609375;
let xs = x - 18.0;
let r = evaluate::polynomial(xs, ERF_INV_IMPL_FN)
/ evaluate::polynomial(xs, ERF_INV_IMPL_FD);
y * x + r * x
} else {
let y = 0.99941349029541015625;
let xs = x - 44.0;
let r = evaluate::polynomial(xs, ERF_INV_IMPL_GN)
/ evaluate::polynomial(xs, ERF_INV_IMPL_GD);
y * x + r * x
}
};
s * result
}

View File

@ -1,3 +1,4 @@
pub mod erf;
pub mod kernels;
trait Cpu<const ARR: usize> {

View File

@ -4,6 +4,9 @@ use crate::{DType, Error, IntDType, Layout, Result, Shape, WithDType};
use half::{bf16, f16};
use rayon::prelude::*;
const USE_IM2COL_CONV1D: bool = true;
const USE_IM2COL_CONV2D: bool = true;
// TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator +
// intercept the oom errors to avoid panicking and provide a proper error.
#[derive(Debug, Clone)]
@ -724,6 +727,36 @@ impl Map1 for MaxPool2D {
}
}
struct UpsampleNearest1D(usize);
impl Map1 for UpsampleNearest1D {
fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
// TODO: Specialized implementation for the case 2*sz?
let dst_sz = self.0;
let (b_sz, c, src_sz) = layout.shape().dims3()?;
let stride = layout.stride();
let stride_sz = stride[2];
let src_index = layout.start_offset();
let scale_sz = src_sz as f64 / dst_sz as f64;
let mut dst = vec![T::zero(); b_sz * c * dst_sz];
let src_idxs = (0..dst_sz)
.map(|idx| usize::min(src_sz - 1, (idx as f64 * scale_sz) as usize))
.collect::<Vec<_>>();
for b_idx in 0..b_sz {
let dst = &mut dst[b_idx * c * dst_sz..];
let src_index = src_index + b_idx * stride[0];
for c_idx in 0..c {
let dst = &mut dst[c_idx * dst_sz..];
let src_index = src_index + c_idx * stride[1];
for (idx, src_idx) in src_idxs.iter().enumerate() {
dst[idx] = src[src_index + src_idx * stride_sz]
}
}
}
Ok(dst)
}
}
struct UpsampleNearest2D(usize, usize);
impl Map1 for UpsampleNearest2D {
@ -1089,6 +1122,140 @@ impl<'a> Map2 for Conv1D<'a> {
}
}
struct Im2Col1D {
l_k: usize,
stride: usize,
dilation: usize,
padding: usize,
}
impl Im2Col1D {
fn l_out(&self, l: usize) -> usize {
(l + 2 * self.padding - self.dilation * (self.l_k - 1) - 1) / self.stride + 1
}
}
impl Map1 for Im2Col1D {
fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
let &Self {
l_k,
stride,
dilation,
padding,
} = self;
let (b, c, l) = layout.shape().dims3()?;
let l_out = self.l_out(l);
let src = &vs[layout.start_offset()..];
let mut dst = vec![T::zero(); b * l_out * c * l_k];
let (src_s0, src_s1, src_s2) = {
let s = layout.stride();
(s[0], s[1], s[2])
};
// TODO: provide specialized kernels for the common use cases.
// - l_k = 1
// - padding = 0
// - stride = 1
// - dilation = 1
for b_idx in 0..b {
let src_idx = b_idx * src_s0;
let dst_idx = b_idx * l_out * c * l_k;
for l_idx in 0..l_out {
let dst_idx = dst_idx + l_idx * c * l_k;
for c_idx in 0..c {
let dst_idx = dst_idx + c_idx * l_k;
let src_idx = c_idx * src_s1 + src_idx;
for l_k_idx in 0..l_k {
let src_l = l_idx * stride + l_k_idx * dilation;
if padding != 0 && (src_l < padding || src_l >= l + padding) {
continue;
}
let src_l = src_l - padding;
let src_idx = src_idx + src_l * src_s2;
let dst_idx = dst_idx + l_k_idx;
dst[dst_idx] = src[src_idx]
}
}
}
}
Ok(dst)
}
}
struct Im2Col {
h_k: usize,
w_k: usize,
stride: usize,
dilation: usize,
padding: usize,
}
impl Im2Col {
fn hw_out(&self, h: usize, w: usize) -> (usize, usize) {
let h_out = (h + 2 * self.padding - self.dilation * (self.h_k - 1) - 1) / self.stride + 1;
let w_out = (w + 2 * self.padding - self.dilation * (self.w_k - 1) - 1) / self.stride + 1;
(h_out, w_out)
}
}
impl Map1 for Im2Col {
fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
let &Self {
h_k,
w_k,
stride,
dilation,
padding,
} = self;
let (b, c, h, w) = layout.shape().dims4()?;
let (h_out, w_out) = self.hw_out(h, w);
let src = &vs[layout.start_offset()..];
let mut dst = vec![T::zero(); b * h_out * w_out * c * h_k * w_k];
let (src_s0, src_s1, src_s2, src_s3) = {
let s = layout.stride();
(s[0], s[1], s[2], s[3])
};
// TODO: provide specialized kernels for the common use cases.
// - h_k = w_k = 1
// - padding = 0
// - stride = 1
// - dilation = 1
for b_idx in 0..b {
let src_idx = b_idx * src_s0;
let dst_idx = b_idx * h_out * w_out * c * h_k * w_k;
for h_idx in 0..h_out {
let dst_idx = dst_idx + h_idx * w_out * c * h_k * w_k;
for w_idx in 0..w_out {
let dst_idx = dst_idx + w_idx * c * h_k * w_k;
for c_idx in 0..c {
let dst_idx = dst_idx + c_idx * h_k * w_k;
let src_idx = c_idx * src_s1 + src_idx;
for h_k_idx in 0..h_k {
let src_h = h_idx * stride + h_k_idx * dilation;
if padding != 0 && (src_h < padding || src_h >= h + padding) {
continue;
}
let src_h = src_h - padding;
let src_idx = src_idx + src_h * src_s2;
let dst_idx = dst_idx + h_k_idx * w_k;
for w_k_idx in 0..w_k {
let src_w = w_idx * stride + w_k_idx * dilation;
if padding != 0 && (src_w < padding || src_w >= w + padding) {
continue;
}
let src_w = src_w - padding;
let src_idx = src_idx + src_w * src_s3;
let dst_idx = dst_idx + w_k_idx;
dst[dst_idx] = src[src_idx]
}
}
}
}
}
}
Ok(dst)
}
}
struct Conv2D<'a>(&'a crate::conv::ParamsConv2D);
impl<'a> Map2 for Conv2D<'a> {
@ -1294,8 +1461,9 @@ impl Map2 for MatMul {
) -> Result<Vec<T>> {
use gemm::{gemm, Parallelism};
if T::DTYPE == DType::BF16 {
return Err(Error::UnsupportedDTypeForOp(T::DTYPE, "matmul").bt())?;
match T::DTYPE {
DType::F16 | DType::F32 | DType::F64 => {}
_ => Err(Error::UnsupportedDTypeForOp(T::DTYPE, "matmul").bt())?,
}
let (b, m, n, k) = self.0;
@ -1999,6 +2167,10 @@ impl BackendStorage for CpuStorage {
MaxPool2D(kernel_size, stride).map(self, layout)
}
fn upsample_nearest1d(&self, layout: &Layout, sz: usize) -> Result<Self> {
UpsampleNearest1D(sz).map(self, layout)
}
fn upsample_nearest2d(&self, layout: &Layout, h: usize, w: usize) -> Result<Self> {
UpsampleNearest2D(h, w).map(self, layout)
}
@ -2227,7 +2399,40 @@ impl BackendStorage for CpuStorage {
kernel_l: &Layout,
params: &crate::conv::ParamsConv1D,
) -> Result<Self> {
Conv1D(params).map(self, l, kernel, kernel_l)
if !USE_IM2COL_CONV1D {
return Conv1D(params).map(self, l, kernel, kernel_l);
}
let op = Im2Col1D {
l_k: params.k_size,
padding: params.padding,
stride: params.stride,
dilation: params.dilation,
};
let col = op.map(self, l)?;
let b = params.b_size;
let n = params.c_out;
let l_out = params.l_out();
let k = op.l_k * params.c_in;
let m = l_out;
let col_l = Layout::contiguous((b, m, k));
let res = if kernel_l.is_contiguous() {
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
.transpose(1, 2)?
.broadcast_as((b, k, n))?;
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
} else {
// Make the kernel contiguous if not already the case.
let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
.transpose(1, 2)?
.broadcast_as((b, k, n))?;
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
};
let res_l = Layout::contiguous((b, l_out, params.c_out)).transpose(1, 2)?;
let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
res.copy_strided_src(&mut res_t, 0, &res_l)?;
Ok(res_t)
}
fn conv2d(
@ -2237,7 +2442,43 @@ impl BackendStorage for CpuStorage {
kernel_l: &Layout,
params: &crate::conv::ParamsConv2D,
) -> Result<Self> {
Conv2D(params).map(self, l, kernel, kernel_l)
if !USE_IM2COL_CONV2D {
return Conv2D(params).map(self, l, kernel, kernel_l);
}
let op = Im2Col {
h_k: params.k_h,
w_k: params.k_w,
padding: params.padding,
stride: params.stride,
dilation: params.dilation,
};
let col = op.map(self, l)?;
let b = params.b_size;
let n = params.c_out;
let (h_out, w_out) = (params.out_h(), params.out_w());
let k = op.h_k * op.w_k * params.c_in;
let m = h_out * w_out;
let col_l = Layout::contiguous((b, m, k));
let res = if kernel_l.is_contiguous() {
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
.transpose(1, 2)?
.broadcast_as((b, k, n))?;
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
} else {
// Make the kernel contiguous if not already the case.
let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
.transpose(1, 2)?
.broadcast_as((b, k, n))?;
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
};
let res_l = Layout::contiguous((b, h_out, w_out, params.c_out))
.transpose(1, 2)?
.transpose(1, 3)?;
let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
res.copy_strided_src(&mut res_t, 0, &res_l)?;
Ok(res_t)
}
fn conv_transpose2d(
@ -2362,6 +2603,10 @@ impl BackendDevice for CpuDevice {
Ok(Self)
}
fn set_seed(&self, _seed: u64) -> Result<()> {
crate::bail!("cannot seed the CPU rng with set_seed")
}
fn rand_uniform(&self, shape: &Shape, dtype: DType, min: f64, max: f64) -> Result<CpuStorage> {
use rand::prelude::*;

View File

@ -223,6 +223,12 @@ impl BackendDevice for CudaDevice {
})
}
fn set_seed(&self, seed: u64) -> Result<()> {
let mut curand = self.curand.lock().unwrap();
curand.0.set_seed(seed).w()?;
Ok(())
}
fn location(&self) -> crate::DeviceLocation {
crate::DeviceLocation::Cuda {
gpu_id: self.device.ordinal(),
@ -312,6 +318,13 @@ impl BackendDevice for CudaDevice {
// cudarc changes.
let elem_count = shape.elem_count();
let curand = self.curand.lock().unwrap();
// curand can only generate an odd number of values.
// https://github.com/huggingface/candle/issues/734
let elem_count_round = if elem_count % 2 == 1 {
elem_count + 1
} else {
elem_count
};
let slice = match dtype {
DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => {
Err(CudaError::UnsupportedDtype {
@ -321,7 +334,7 @@ impl BackendDevice for CudaDevice {
.w()?
}
DType::F32 => {
let mut data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
let mut data = unsafe { self.alloc::<f32>(elem_count_round) }.w()?;
curand
.0
.fill_with_normal(&mut data, mean as f32, std as f32)
@ -329,7 +342,7 @@ impl BackendDevice for CudaDevice {
CudaStorageSlice::F32(data)
}
DType::F64 => {
let mut data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
let mut data = unsafe { self.alloc::<f64>(elem_count_round) }.w()?;
curand.0.fill_with_normal(&mut data, mean, std).w()?;
CudaStorageSlice::F64(data)
}
@ -593,6 +606,105 @@ impl Map1 for Elu {
}
}
struct Im2Col1D {
l_k: usize,
stride: usize,
dilation: usize,
padding: usize,
}
impl Im2Col1D {
fn l_out(&self, l: usize) -> usize {
(l + 2 * self.padding - self.dilation * (self.l_k - 1) - 1) / self.stride + 1
}
}
impl Map1 for Im2Col1D {
fn f<T: DeviceRepr + WithDType>(
&self,
src: &CudaSlice<T>,
dev: &CudaDevice,
layout: &Layout,
) -> Result<CudaSlice<T>> {
let shape = layout.shape();
let dims = shape.dims();
let l_out = self.l_out(dims[2]);
let dst_el = dims[0] * l_out * dims[1] * self.l_k;
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
let src = &src.slice(layout.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>("im2col1d"), kernels::CONV)?;
// SAFETY: Set later by running the kernel.
let dst = unsafe { dev.alloc::<T>(dst_el) }.w()?;
let params = (
dst_el,
l_out,
self.l_k,
self.stride,
self.padding,
self.dilation,
&ds,
src,
&dst,
);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(dst)
}
}
struct Im2Col {
h_k: usize,
w_k: usize,
stride: usize,
dilation: usize,
padding: usize,
}
impl Im2Col {
fn hw_out(&self, h: usize, w: usize) -> (usize, usize) {
let h_out = (h + 2 * self.padding - self.dilation * (self.h_k - 1) - 1) / self.stride + 1;
let w_out = (w + 2 * self.padding - self.dilation * (self.w_k - 1) - 1) / self.stride + 1;
(h_out, w_out)
}
}
impl Map1 for Im2Col {
fn f<T: DeviceRepr + WithDType>(
&self,
src: &CudaSlice<T>,
dev: &CudaDevice,
layout: &Layout,
) -> Result<CudaSlice<T>> {
let shape = layout.shape();
let dims = shape.dims();
let (h_out, w_out) = self.hw_out(dims[2], dims[3]);
let dst_el = dims[0] * h_out * w_out * dims[1] * self.h_k * self.w_k;
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
let ds = dev.htod_copy([dims, layout.stride()].concat()).w()?;
let src = &src.slice(layout.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>("im2col"), kernels::CONV)?;
// SAFETY: Set later by running the kernel.
let dst = unsafe { dev.alloc::<T>(dst_el) }.w()?;
let params = (
dst_el,
h_out,
w_out,
self.h_k,
self.w_k,
self.stride,
self.padding,
self.dilation,
&ds,
src,
&dst,
);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(dst)
}
}
struct Powf(f64);
impl Map1 for Powf {
fn f<T: DeviceRepr + WithDType>(
@ -778,8 +890,6 @@ impl<'a> Map1 for IndexSelect<'a> {
};
let ids_shape = ids_l.shape();
let ids_dims = ids_shape.dims();
let ids_el = ids_shape.elem_count();
let cfg = LaunchConfig::for_num_elems(ids_el as u32);
let ds = dev.htod_copy([ids_dims, ids_l.stride()].concat()).w()?;
let src = match src_l.contiguous_offsets() {
Some((o1, o2)) => src.slice(o1..o2),
@ -787,19 +897,23 @@ impl<'a> Map1 for IndexSelect<'a> {
};
let left_size: usize = src_l.dims()[..self.2].iter().product();
let right_size: usize = src_l.dims()[self.2 + 1..].iter().product();
let dim_size = src_l.dims()[self.2];
let src_dim_size = src_l.dims()[self.2];
let ids_dim_size = ids_shape.elem_count();
let dst_el = ids_shape.elem_count() * left_size * right_size;
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::INDEXING)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(ids_el * left_size * right_size) }.w()?;
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
let params = (
ids_el,
dst_el,
ids_dims.len(),
&ds,
ids,
&src,
&out,
left_size,
dim_size,
src_dim_size,
ids_dim_size,
right_size,
);
// SAFETY: ffi.
@ -1650,9 +1764,46 @@ impl BackendStorage for CudaStorage {
kernel_l: &Layout,
params: &crate::conv::ParamsConv1D,
) -> Result<Self> {
const USE_IM2COL_CONV1D: bool = true;
let device = self.device().clone();
let slice = Conv1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
Ok(Self { slice, device })
if !USE_IM2COL_CONV1D {
let slice = Conv1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
return Ok(Self { slice, device });
}
let col = Im2Col1D {
l_k: params.k_size,
stride: params.stride,
dilation: params.dilation,
padding: params.padding,
}
.map(&self.slice, &device, l)?;
let col = Self { slice: col, device };
let l_out = params.l_out();
let b = params.b_size;
let n = params.c_out;
let k = params.k_size * params.c_in;
let m = l_out;
let col_l = Layout::contiguous((b, m, k));
let res = if kernel_l.is_contiguous() {
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
.transpose(1, 2)?
.broadcast_as((b, k, n))?;
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
} else {
// Make the kernel contiguous if not already the case.
let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
.transpose(1, 2)?
.broadcast_as((b, k, n))?;
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
};
let res_l = Layout::contiguous((b, l_out, n)).transpose(1, 2)?;
let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
res.copy_strided_src(&mut res_t, 0, &res_l)?;
Ok(res_t)
}
#[cfg(not(feature = "cudnn"))]
@ -1663,9 +1814,50 @@ impl BackendStorage for CudaStorage {
kernel_l: &Layout,
params: &crate::conv::ParamsConv2D,
) -> Result<Self> {
const USE_IM2COL_CONV2D: bool = true;
let device = self.device().clone();
let slice = Conv2D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
Ok(Self { slice, device })
if !USE_IM2COL_CONV2D {
let slice = Conv2D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
return Ok(Self { slice, device });
}
let col = Im2Col {
h_k: params.k_h,
w_k: params.k_w,
stride: params.stride,
dilation: params.dilation,
padding: params.padding,
}
.map(&self.slice, &device, l)?;
let col = Self { slice: col, device };
let h_out = params.out_h();
let w_out = params.out_w();
let b = params.b_size;
let n = params.c_out;
let k = params.k_h * params.k_w * params.c_in;
let m = h_out * w_out;
let col_l = Layout::contiguous((b, m, k));
let res = if kernel_l.is_contiguous() {
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
.transpose(1, 2)?
.broadcast_as((b, k, n))?;
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
} else {
// Make the kernel contiguous if not already the case.
let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
.transpose(1, 2)?
.broadcast_as((b, k, n))?;
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
};
let res_l = Layout::contiguous((b, h_out, w_out, n))
.transpose(1, 2)?
.transpose(1, 3)?;
let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
res.copy_strided_src(&mut res_t, 0, &res_l)?;
Ok(res_t)
}
#[cfg(feature = "cudnn")]
@ -1770,6 +1962,10 @@ impl BackendStorage for CudaStorage {
Ok(Self { slice, device })
}
fn upsample_nearest1d(&self, _: &Layout, _out_sz: usize) -> Result<Self> {
crate::bail!("upsample-nearest1d is not supported on cuda")
}
fn upsample_nearest2d(&self, l: &Layout, out_w: usize, out_h: usize) -> Result<Self> {
let device = self.device().clone();
let slice = UpsampleNearest2D(out_w, out_h).map(&self.slice, &device, l)?;

View File

@ -34,6 +34,9 @@ pub(crate) fn launch_conv2d<
params: &crate::conv::ParamsConv2D,
dev: &crate::cuda_backend::CudaDevice,
) -> crate::Result<()> {
use crate::conv::CudnnFwdAlgo as CandleAlgo;
use cudarc::cudnn::sys::cudnnConvolutionFwdAlgo_t as A;
let device_id = dev.id();
let cudnn = CUDNN.with(|cudnn| {
if let Some(cudnn) = cudnn.borrow().get(&device_id) {
@ -90,7 +93,20 @@ pub(crate) fn launch_conv2d<
w: &w,
y: &y,
};
let alg = conv2d.pick_algorithm()?;
let alg = match params.cudnn_fwd_algo {
None => conv2d.pick_algorithm()?,
Some(CandleAlgo::ImplicitGemm) => A::CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
Some(CandleAlgo::ImplicitPrecompGemm) => {
A::CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM
}
Some(CandleAlgo::Gemm) => A::CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
Some(CandleAlgo::Direct) => A::CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
Some(CandleAlgo::Fft) => A::CUDNN_CONVOLUTION_FWD_ALGO_FFT,
Some(CandleAlgo::FftTiling) => A::CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING,
Some(CandleAlgo::Winograd) => A::CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD,
Some(CandleAlgo::WinogradNonFused) => A::CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED,
Some(CandleAlgo::Count) => A::CUDNN_CONVOLUTION_FWD_ALGO_COUNT,
};
let workspace_size = conv2d.get_workspace_size(alg)?;
let mut workspace = dev.cuda_device().alloc_zeros::<u8>(workspace_size)?;
unsafe {

View File

@ -67,6 +67,20 @@ impl DType {
Self::F64 => 8,
}
}
pub fn is_int(&self) -> bool {
match self {
Self::U8 | Self::U32 | Self::I64 => true,
Self::BF16 | Self::F16 | Self::F32 | Self::F64 => false,
}
}
pub fn is_float(&self) -> bool {
match self {
Self::U8 | Self::U32 | Self::I64 => false,
Self::BF16 | Self::F16 | Self::F32 | Self::F64 => true,
}
}
}
pub trait WithDType:

View File

@ -152,6 +152,10 @@ impl crate::backend::BackendStorage for CudaStorage {
Err(Error::NotCompiledWithCudaSupport)
}
fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
@ -163,6 +167,10 @@ impl crate::backend::BackendDevice for CudaDevice {
Err(Error::NotCompiledWithCudaSupport)
}
fn set_seed(&self, _: u64) -> Result<()> {
Err(Error::NotCompiledWithCudaSupport)
}
fn location(&self) -> crate::DeviceLocation {
fail!()
}

View File

@ -46,19 +46,31 @@ impl Tensor {
current_dim += 1;
out
}
TensorIndexer::IndexSelect(indexes) => {
if indexes.rank() != 1 {
crate::bail!("multi-dimensional tensor indexing is not supported")
}
let out = x.index_select(&indexes.to_device(x.device())?, current_dim)?;
current_dim += 1;
out
}
TensorIndexer::Err(e) => crate::bail!("indexing error {e:?}"),
};
}
Ok(x)
}
}
#[derive(Debug, Clone)]
#[derive(Debug)]
/// Generic structure used to index a slice of the tensor
pub enum TensorIndexer {
/// This selects the elemnts for which an index has some specific value.
Select(usize),
/// This is a regular slice, purely indexing a chunk of the tensor
Narrow(Bound<usize>, Bound<usize>),
/// Indexing via a 1d tensor
IndexSelect(Tensor),
Err(Error),
}
impl From<usize> for TensorIndexer {
@ -67,6 +79,31 @@ impl From<usize> for TensorIndexer {
}
}
impl From<&[u32]> for TensorIndexer {
fn from(index: &[u32]) -> Self {
match Tensor::new(index, &crate::Device::Cpu) {
Ok(tensor) => TensorIndexer::IndexSelect(tensor),
Err(e) => TensorIndexer::Err(e),
}
}
}
impl From<Vec<u32>> for TensorIndexer {
fn from(index: Vec<u32>) -> Self {
let len = index.len();
match Tensor::from_vec(index, len, &crate::Device::Cpu) {
Ok(tensor) => TensorIndexer::IndexSelect(tensor),
Err(e) => TensorIndexer::Err(e),
}
}
}
impl From<&Tensor> for TensorIndexer {
fn from(tensor: &Tensor) -> Self {
TensorIndexer::IndexSelect(tensor.clone())
}
}
macro_rules! impl_from_range {
($range_type:ty) => {
impl From<$range_type> for TensorIndexer {

View File

@ -110,14 +110,8 @@ impl ToUsize2 for (usize, usize) {
}
// A simple trait defining a module with forward method using a single argument.
pub trait Module: std::fmt::Debug {
pub trait Module {
fn forward(&self, xs: &Tensor) -> Result<Tensor>;
/// Change the module to use training mode vs eval mode.
///
/// The default implementation does nothing as this is only used for a couple modules such as
/// dropout or batch-normalization.
fn set_training(&mut self, _training: bool) {}
}
impl Module for quantized::QMatMul {
@ -125,3 +119,9 @@ impl Module for quantized::QMatMul {
self.forward(xs)
}
}
impl<T: Fn(&Tensor) -> Result<Tensor>> Module for T {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
self(xs)
}
}

View File

@ -58,8 +58,13 @@ pub enum UnaryOp {
Sqr,
Sqrt,
Gelu,
GeluErf,
Erf,
Relu,
Tanh,
Floor,
Ceil,
Round,
}
#[derive(Clone)]
@ -116,6 +121,7 @@ pub enum Op {
stride: (usize, usize),
},
UpsampleNearest1D(Tensor),
UpsampleNearest2D(Tensor),
Cat(Vec<Tensor>, usize),
@ -130,6 +136,7 @@ pub enum Op {
Copy(Tensor),
Broadcast(Tensor),
Narrow(Tensor, usize, usize, usize),
SliceScatter0(Tensor, Tensor, usize),
Reshape(Tensor),
ToDevice(Tensor),
Transpose(Tensor, usize, usize),
@ -324,8 +331,13 @@ pub(crate) struct Recip;
pub(crate) struct Sqr;
pub(crate) struct Sqrt;
pub(crate) struct Gelu;
pub(crate) struct GeluErf;
pub(crate) struct Erf;
pub(crate) struct Relu;
pub(crate) struct Tanh;
pub(crate) struct Floor;
pub(crate) struct Ceil;
pub(crate) struct Round;
macro_rules! bin_op {
($op:ident, $name: literal, $e: expr, $f32_vec: ident, $f64_vec: ident) => {
@ -600,6 +612,194 @@ impl UnaryOpT for Gelu {
fn f64_vec(xs: &[f64], ys: &mut [f64]) {
crate::mkl::vd_gelu(xs, ys)
}
#[cfg(feature = "accelerate")]
const F32_VEC: bool = true;
#[cfg(feature = "accelerate")]
#[inline(always)]
fn f32_vec(xs: &[f32], ys: &mut [f32]) {
crate::accelerate::vs_gelu(xs, ys)
}
#[cfg(feature = "accelerate")]
const F64_VEC: bool = true;
#[cfg(feature = "accelerate")]
#[inline(always)]
fn f64_vec(xs: &[f64], ys: &mut [f64]) {
crate::accelerate::vd_gelu(xs, ys)
}
}
impl UnaryOpT for Erf {
const NAME: &'static str = "erf";
const KERNEL: &'static str = "uerf";
const V: Self = Erf;
#[inline(always)]
fn bf16(v: bf16) -> bf16 {
bf16::from_f64(Self::f64(v.to_f64()))
}
#[inline(always)]
fn f16(v: f16) -> f16 {
f16::from_f64(Self::f64(v.to_f64()))
}
#[inline(always)]
fn f32(v: f32) -> f32 {
Self::f64(v as f64) as f32
}
#[inline(always)]
fn f64(v: f64) -> f64 {
crate::cpu::erf::erf(v)
}
#[inline(always)]
fn u8(_: u8) -> u8 {
0
}
#[inline(always)]
fn u32(_: u32) -> u32 {
0
}
#[inline(always)]
fn i64(_: i64) -> i64 {
0
}
}
impl UnaryOpT for Ceil {
const NAME: &'static str = "ceil";
const KERNEL: &'static str = "uceil";
const V: Self = Ceil;
#[inline(always)]
fn bf16(v: bf16) -> bf16 {
v.ceil()
}
#[inline(always)]
fn f16(v: f16) -> f16 {
v.ceil()
}
#[inline(always)]
fn f32(v: f32) -> f32 {
v.ceil()
}
#[inline(always)]
fn f64(v: f64) -> f64 {
v.ceil()
}
#[inline(always)]
fn u8(v: u8) -> u8 {
v
}
#[inline(always)]
fn u32(v: u32) -> u32 {
v
}
#[inline(always)]
fn i64(v: i64) -> i64 {
v
}
}
impl UnaryOpT for Floor {
const NAME: &'static str = "floor";
const KERNEL: &'static str = "ufloor";
const V: Self = Floor;
#[inline(always)]
fn bf16(v: bf16) -> bf16 {
v.floor()
}
#[inline(always)]
fn f16(v: f16) -> f16 {
v.floor()
}
#[inline(always)]
fn f32(v: f32) -> f32 {
v.floor()
}
#[inline(always)]
fn f64(v: f64) -> f64 {
v.floor()
}
#[inline(always)]
fn u8(v: u8) -> u8 {
v
}
#[inline(always)]
fn u32(v: u32) -> u32 {
v
}
#[inline(always)]
fn i64(v: i64) -> i64 {
v
}
}
impl UnaryOpT for Round {
const NAME: &'static str = "round";
const KERNEL: &'static str = "uround";
const V: Self = Round;
#[inline(always)]
fn bf16(v: bf16) -> bf16 {
v.round()
}
#[inline(always)]
fn f16(v: f16) -> f16 {
v.round()
}
#[inline(always)]
fn f32(v: f32) -> f32 {
v.round()
}
#[inline(always)]
fn f64(v: f64) -> f64 {
v.round()
}
#[inline(always)]
fn u8(v: u8) -> u8 {
v
}
#[inline(always)]
fn u32(v: u32) -> u32 {
v
}
#[inline(always)]
fn i64(v: i64) -> i64 {
v
}
}
impl UnaryOpT for GeluErf {
const NAME: &'static str = "gelu_erf";
const KERNEL: &'static str = "ugelu_erf";
const V: Self = GeluErf;
#[inline(always)]
fn bf16(v: bf16) -> bf16 {
bf16::from_f64(Self::f64(v.to_f64()))
}
#[inline(always)]
fn f16(v: f16) -> f16 {
f16::from_f64(Self::f64(v.to_f64()))
}
#[inline(always)]
fn f32(v: f32) -> f32 {
Self::f64(v as f64) as f32
}
#[inline(always)]
fn f64(v: f64) -> f64 {
(crate::cpu::erf::erf(v / 2f64.sqrt()) + 1.) * 0.5 * v
}
#[inline(always)]
fn u8(_: u8) -> u8 {
0
}
#[inline(always)]
fn u32(_: u32) -> u32 {
0
}
#[inline(always)]
fn i64(_: i64) -> i64 {
0
}
}
impl UnaryOpT for Relu {

View File

@ -638,3 +638,35 @@ pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Res
Ok(hsum_float_8(acc) + summs)
}
}
#[inline(always)]
pub(crate) fn vec_dot_q8k_q8k(n: usize, xs: &[BlockQ8K], ys: &[BlockQ8K]) -> Result<f32> {
let qk = QK_K;
if n % qk != 0 {
crate::bail!("vec_dot_q8k_8k: {n} is not divisible by {qk}")
}
unsafe {
let mut acc = _mm256_setzero_ps();
for (xs, ys) in xs.iter().zip(ys.iter()) {
let mut sumi = _mm256_setzero_si256();
let x_qs = xs.qs.as_ptr();
let y_qs = ys.qs.as_ptr();
for j in (0..QK_K).step_by(32) {
let xs = _mm256_loadu_si256(x_qs.add(j) as *const __m256i);
let ys = _mm256_loadu_si256(y_qs.add(j) as *const __m256i);
let xs0 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(xs, 0));
let ys0 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(ys, 0));
sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(xs0, ys0));
let xs1 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(xs, 1));
let ys1 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(ys, 1));
sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(xs1, ys1));
}
let d = _mm256_set1_ps(xs.d * ys.d);
acc = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi), acc);
}
Ok(hsum_float_8(acc))
}
}

View File

@ -135,7 +135,13 @@ pub fn qtensor_from_ggml(
dims: Vec<usize>,
) -> Result<super::QTensor> {
let tensor_elems = dims.iter().product::<usize>();
let size_in_bytes = tensor_elems * ggml_dtype.type_size() / ggml_dtype.blck_size();
let blck_size = ggml_dtype.blck_size();
if tensor_elems % blck_size != 0 {
crate::bail!(
"the number of elements {tensor_elems} is not divisible by the block size {blck_size}"
)
}
let size_in_bytes = tensor_elems / blck_size * ggml_dtype.type_size();
match ggml_dtype {
GgmlDType::F32 => from_raw_data::<f32>(raw_data, size_in_bytes, dims),

View File

@ -59,8 +59,13 @@ impl TensorInfo {
tensor_data_offset: u64,
) -> Result<QTensor> {
let tensor_elems = self.shape.elem_count();
let size_in_bytes =
tensor_elems * self.ggml_dtype.type_size() / self.ggml_dtype.blck_size();
let blck_size = self.ggml_dtype.blck_size();
if tensor_elems % blck_size != 0 {
crate::bail!(
"the number of elements {tensor_elems} is not divisible by the block size {blck_size}"
)
}
let size_in_bytes = tensor_elems / blck_size * self.ggml_dtype.type_size();
let mut raw_data = vec![0u8; size_in_bytes];
reader.seek(std::io::SeekFrom::Start(tensor_data_offset + self.offset))?;
reader.read_exact(&mut raw_data)?;

View File

@ -34,6 +34,9 @@ pub trait GgmlType: Sized + Clone + Send + Sync {
/// Dot product used as a building block for quantized mat-mul.
/// n is the number of elements to be considered.
fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32>;
/// Generic implementation of the dot product without simd optimizations.
fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32>;
}
#[derive(Debug, Clone, PartialEq)]
@ -225,6 +228,13 @@ impl GgmlType for BlockQ4_0 {
#[cfg(target_feature = "neon")]
return super::neon::vec_dot_q4_0_q8_0(n, xs, ys);
#[cfg(target_feature = "simd128")]
return super::simd128::vec_dot_q4_0_q8_0(n, xs, ys);
Self::vec_dot_unopt(n, xs, ys)
}
fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
let qk = QK8_0;
let nb = n / qk;
if n % QK8_0 != 0 {
@ -255,6 +265,10 @@ impl GgmlType for BlockQ4_1 {
type VecDotType = BlockQ8_1;
fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
Self::vec_dot_unopt(n, xs, ys)
}
fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
// ggml_vec_dot_q4_1_q8_1
let qk = QK8_1;
if n % qk != 0 {
@ -354,7 +368,10 @@ impl GgmlType for BlockQ5_0 {
if nb % 2 != 0 {
crate::bail!("vec_dot_q5_0_q8_0: {n}, nb is not divisible by 2")
}
Self::vec_dot_unopt(n, xs, ys)
}
fn vec_dot_unopt(_n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
// Generic implementation.
let mut sumf = 0f32;
@ -445,6 +462,10 @@ impl GgmlType for BlockQ5_1 {
type VecDotType = BlockQ8_1;
fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
Self::vec_dot_unopt(n, xs, ys)
}
fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
let qk = Self::BLCK_SIZE;
if n % Self::BLCK_SIZE != 0 {
crate::bail!("vec_dot_q5_1_q8_1: {n} is not divisible by {qk}")
@ -606,6 +627,13 @@ impl GgmlType for BlockQ8_0 {
#[cfg(target_feature = "neon")]
return super::neon::vec_dot_q8_0_q8_0(n, xs, ys);
#[cfg(target_feature = "simd128")]
return super::simd128::vec_dot_q8_0_q8_0(n, xs, ys);
Self::vec_dot_unopt(n, xs, ys)
}
fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
let qk = QK8_0;
if n % QK8_0 != 0 {
crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")
@ -631,7 +659,11 @@ impl GgmlType for BlockQ8_1 {
const BLCK_SIZE: usize = QK8_1;
type VecDotType = BlockQ8_1;
fn vec_dot(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result<f32> {
fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
Self::vec_dot_unopt(n, xs, ys)
}
fn vec_dot_unopt(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result<f32> {
unimplemented!("no support for vec-dot on Q8_1")
}
@ -681,6 +713,13 @@ impl GgmlType for BlockQ2K {
#[cfg(target_feature = "neon")]
return super::neon::vec_dot_q2k_q8k(n, xs, ys);
#[cfg(target_feature = "simd128")]
return super::simd128::vec_dot_q2k_q8k(n, xs, ys);
Self::vec_dot_unopt(n, xs, ys)
}
fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
if n % QK_K != 0 {
crate::bail!("vec_dot_q2k_q8k: {n} is not divisible by {QK_K}")
}
@ -701,18 +740,17 @@ impl GgmlType for BlockQ2K {
let mut isum = 0;
let mut is = 0;
let mut d;
for _ in 0..(QK_K / 128) {
let mut shift = 0;
for _ in 0..4 {
d = (sc[is] & 0xF) as i32;
let d = (sc[is] & 0xF) as i32;
is += 1;
let mut isuml = 0;
for l in 0..16 {
isuml += q8[l] as i32 * (((q2[l] >> shift) & 3) as i32);
}
isum += d * isuml;
d = (sc[is] & 0xF) as i32;
let d = (sc[is] & 0xF) as i32;
is += 1;
isuml = 0;
for l in 16..32 {
@ -851,6 +889,10 @@ impl GgmlType for BlockQ3K {
#[cfg(target_feature = "neon")]
return super::neon::vec_dot_q3k_q8k(n, xs, ys);
Self::vec_dot_unopt(n, xs, ys)
}
fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
if n % QK_K != 0 {
crate::bail!("vec_dot_q3k_q8k: {n} is not divisible by {QK_K}")
}
@ -1077,7 +1119,6 @@ impl GgmlType for BlockQ3K {
let d_all = block.d.to_f32();
let mut m = 1;
let mut is = 0;
let mut dl;
// Dequantize both 128 long blocks
// 32 qs values per 128 long block
@ -1088,7 +1129,7 @@ impl GgmlType for BlockQ3K {
for (scale_index, scale_scoped_y) in
shift_scoped_y.chunks_exact_mut(16).enumerate()
{
dl = d_all * (scales[is] as f32 - 32.0);
let dl = d_all * (scales[is] as f32 - 32.0);
for (i, inner_y) in scale_scoped_y.iter_mut().enumerate() {
let new_y = dl
* (((qs[i + 16 * scale_index] >> shift) & 3) as i8
@ -1126,6 +1167,13 @@ impl GgmlType for BlockQ4K {
#[cfg(target_feature = "neon")]
return super::neon::vec_dot_q4k_q8k(n, xs, ys);
#[cfg(target_feature = "simd128")]
return super::simd128::vec_dot_q4k_q8k(n, xs, ys);
Self::vec_dot_unopt(n, xs, ys)
}
fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
if n % QK_K != 0 {
crate::bail!("vec_dot_q4k_q8k: {n} is not divisible by {QK_K}")
}
@ -1312,6 +1360,10 @@ impl GgmlType for BlockQ5K {
#[cfg(target_feature = "neon")]
return super::neon::vec_dot_q5k_q8k(n, xs, ys);
Self::vec_dot_unopt(n, xs, ys)
}
fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
if n % QK_K != 0 {
crate::bail!("vec_dot_q5k_q8k: {n} is not divisible by {QK_K}")
}
@ -1529,6 +1581,13 @@ impl GgmlType for BlockQ6K {
#[cfg(target_feature = "neon")]
return super::neon::vec_dot_q6k_q8k(n, xs, ys);
#[cfg(target_feature = "simd128")]
return super::simd128::vec_dot_q6k_q8k(n, xs, ys);
Self::vec_dot_unopt(n, xs, ys)
}
fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
if n % QK_K != 0 {
crate::bail!("vec_dot_q6k_q8k: {n} is not divisible by {QK_K}")
}
@ -1697,8 +1756,38 @@ impl GgmlType for BlockQ8K {
const BLCK_SIZE: usize = QK_K;
type VecDotType = BlockQ8K;
fn vec_dot(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result<f32> {
unreachable!()
#[allow(unreachable_code)]
fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
#[cfg(target_feature = "avx")]
return super::avx::vec_dot_q8k_q8k(n, xs, ys);
#[cfg(target_feature = "neon")]
return super::neon::vec_dot_q8k_q8k(n, xs, ys);
#[cfg(target_feature = "simd128")]
return super::simd128::vec_dot_q8k_q8k(n, xs, ys);
Self::vec_dot_unopt(n, xs, ys)
}
fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
let qk = QK_K;
if n % QK_K != 0 {
crate::bail!("vec_dot_q8k_q8k: {n} is not divisible by {qk}")
}
// Generic implementation.
let mut sumf = 0f32;
for (xs, ys) in xs.iter().zip(ys.iter()) {
let sum_i = xs
.qs
.iter()
.zip(ys.qs.iter())
.map(|(&x, &y)| x as i32 * y as i32)
.sum::<i32>();
sumf += sum_i as f32 * xs.d * ys.d
}
Ok(sumf)
}
fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
@ -1804,6 +1893,10 @@ impl GgmlType for f32 {
type VecDotType = f32;
fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
Self::vec_dot_unopt(n, xs, ys)
}
fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
if xs.len() < n {
crate::bail!("size mismatch {} < {n}", xs.len())
}
@ -1838,6 +1931,10 @@ impl GgmlType for f16 {
type VecDotType = f16;
fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
Self::vec_dot_unopt(n, xs, ys)
}
fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
if xs.len() < n {
crate::bail!("size mismatch {} < {n}", xs.len())
}

View File

@ -7,6 +7,8 @@ pub mod gguf_file;
pub mod k_quants;
#[cfg(target_feature = "neon")]
pub mod neon;
#[cfg(target_feature = "simd128")]
pub mod simd128;
pub mod utils;
pub use k_quants::GgmlType;
@ -229,20 +231,40 @@ impl QTensor {
}
}
#[derive(Debug)]
pub struct QMatMul(std::sync::Arc<QTensor>);
#[derive(Clone, Debug)]
pub enum QMatMul {
QTensor(std::sync::Arc<QTensor>),
Tensor(Tensor),
}
thread_local! {
static DEQUANTIZE_ALL: bool = {
match std::env::var("CANDLE_DEQUANTIZE_ALL") {
Ok(s) => {
!s.is_empty() && s != "0"
},
Err(_) => false,
}
}
}
impl QMatMul {
pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Self {
Self(qtensor)
pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Result<Self> {
let dequantize = match qtensor.dtype() {
GgmlDType::F32 | GgmlDType::F16 => true,
_ => DEQUANTIZE_ALL.with(|b| *b),
};
let t = if dequantize {
let tensor = qtensor.dequantize(&Device::Cpu)?;
Self::Tensor(tensor)
} else {
Self::QTensor(qtensor)
};
Ok(t)
}
pub fn from_qtensor(qtensor: QTensor) -> Self {
Self(std::sync::Arc::new(qtensor))
}
pub fn inner(&self) -> &std::sync::Arc<QTensor> {
&self.0
pub fn from_qtensor(qtensor: QTensor) -> Result<Self> {
Self::from_arc(std::sync::Arc::new(qtensor))
}
}
@ -287,6 +309,16 @@ impl crate::CustomOp1 for QTensor {
impl QMatMul {
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
xs.apply_op1_no_bwd(self.0.as_ref())
match self {
Self::QTensor(t) => xs.apply_op1_no_bwd(t.as_ref()),
Self::Tensor(w) => {
let w = match *xs.dims() {
[b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
[bsize, _, _] => w.broadcast_left(bsize)?.t()?,
_ => w.t()?,
};
xs.matmul(&w)
}
}
}
}

View File

@ -148,6 +148,35 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) ->
}
}
#[inline(always)]
pub(crate) fn vec_dot_q8k_q8k(n: usize, xs: &[BlockQ8K], ys: &[BlockQ8K]) -> Result<f32> {
let qk = QK_K;
if n % QK_K != 0 {
crate::bail!("vec_dot_q8k_q8k: {n} is not divisible by {qk}")
}
let mut sumf = 0f32;
for (xs, ys) in xs.iter().zip(ys.iter()) {
unsafe {
let mut sum_i = vdupq_n_s32(0);
let scale = xs.d * ys.d;
let xs = xs.qs.as_ptr();
let ys = ys.qs.as_ptr();
for i in (0..QK_K).step_by(16) {
let xs = vld1q_s8(xs.add(i));
let ys = vld1q_s8(ys.add(i));
let xy_lo = vmull_s8(vget_low_s8(xs), vget_low_s8(ys));
let xy_up = vmull_s8(vget_high_s8(xs), vget_high_s8(ys));
let xy = vaddq_s32(vpaddlq_s16(xy_lo), vpaddlq_s16(xy_up));
sum_i = vaddq_s32(sum_i, xy)
}
sumf += vaddvq_s32(sum_i) as f32 * scale
}
}
Ok(sumf)
}
#[inline(always)]
pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Result<f32> {
if n % QK_K != 0 {

View File

@ -0,0 +1,427 @@
use super::k_quants::{BlockQ2K, BlockQ4K, BlockQ4_0, BlockQ6K, BlockQ8K, BlockQ8_0, QK8_0, QK_K};
use crate::Result;
use byteorder::{ByteOrder, LittleEndian};
use half::f16;
use core::arch::wasm32::*;
#[inline(always)]
pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result<f32> {
let qk = QK8_0;
if n % QK8_0 != 0 {
crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
}
let nb = n / QK8_0;
if nb % 2 != 0 {
crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even")
}
unsafe {
let mut acc = f32x4_splat(0.0f32);
for (x, y) in xs.iter().zip(ys.iter()) {
let x1234 = v128_load(x.qs.as_ptr() as *const v128);
let x12 = v128_and(x1234, u8x16_splat(0x0F));
let x12 = i8x16_sub(x12, i8x16_splat(8));
let x34 = u8x16_shr(x1234, 4);
let x34 = i8x16_sub(x34, i8x16_splat(8));
let x1 = i16x8_extend_low_i8x16(x12);
let y1 = i16x8_load_extend_i8x8(y.qs.as_ptr());
let sum_xy = i32x4_dot_i16x8(x1, y1);
let x2 = i16x8_extend_high_i8x16(x12);
let y2 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(8));
let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x2, y2));
let x3 = i16x8_extend_low_i8x16(x34);
let y3 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(16));
let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x3, y3));
let x4 = i16x8_extend_high_i8x16(x34);
let y4 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(24));
let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x4, y4));
let sum_xy = f32x4_convert_i32x4(sum_xy);
// f32x4_relaxed_madd is nightly only.
let d = f32x4_splat(f16::to_f32(x.d) * f16::to_f32(y.d));
let scaled = f32x4_mul(sum_xy, d);
acc = f32x4_add(acc, scaled)
}
let res = f32x4_extract_lane::<0>(acc)
+ f32x4_extract_lane::<1>(acc)
+ f32x4_extract_lane::<2>(acc)
+ f32x4_extract_lane::<3>(acc);
Ok(res)
}
}
#[inline(always)]
pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) -> Result<f32> {
let qk = QK8_0;
if n % QK8_0 != 0 {
crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")
}
let nb = n / QK8_0;
if nb % 2 != 0 {
crate::bail!("vec_dot_q8_0_q8_0: {nb} is not even")
}
unsafe {
let mut acc = f32x4_splat(0.0f32);
for (x, y) in xs.iter().zip(ys.iter()) {
let x1 = i16x8_load_extend_i8x8(x.qs.as_ptr());
let y1 = i16x8_load_extend_i8x8(y.qs.as_ptr());
let sum_xy = i32x4_dot_i16x8(x1, y1);
let x2 = i16x8_load_extend_i8x8(x.qs.as_ptr().add(8));
let y2 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(8));
let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x2, y2));
let x3 = i16x8_load_extend_i8x8(x.qs.as_ptr().add(16));
let y3 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(16));
let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x3, y3));
let x4 = i16x8_load_extend_i8x8(x.qs.as_ptr().add(24));
let y4 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(24));
let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x4, y4));
let sum_xy = f32x4_convert_i32x4(sum_xy);
// f32x4_relaxed_madd is nightly only.
let d = f32x4_splat(f16::to_f32(x.d) * f16::to_f32(y.d));
let scaled = f32x4_mul(sum_xy, d);
acc = f32x4_add(acc, scaled)
}
let res = f32x4_extract_lane::<0>(acc)
+ f32x4_extract_lane::<1>(acc)
+ f32x4_extract_lane::<2>(acc)
+ f32x4_extract_lane::<3>(acc);
Ok(res)
}
}
#[inline(always)]
pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Result<f32> {
if n % QK_K != 0 {
crate::bail!("vec_dot_q2k_q8k: {n} is not divisible by {QK_K}")
}
unsafe {
let mut sumf = f32x4_splat(0f32);
for (x, y) in xs.iter().zip(ys.iter()) {
let mut q2: &[_] = &x.qs;
let mut q8: &[_] = &y.qs;
let sc = &x.scales;
let mut summs = i32x4_splat(0);
for i in (0..(QK_K / 16)).step_by(4) {
let bsums = i32x4_load_extend_i16x4(y.bsums.as_ptr().add(i));
let scales = i32x4_shr(
i32x4(
sc[i] as i32,
sc[i + 1] as i32,
sc[i + 2] as i32,
sc[i + 3] as i32,
),
4,
);
summs = i32x4_add(summs, i32x4_mul(bsums, scales))
}
let summs = f32x4_convert_i32x4(summs);
let dall = y.d * x.d.to_f32();
let dmin = y.d * x.dmin.to_f32();
let mut isum = i32x4_splat(0);
let mut is = 0;
for _ in 0..(QK_K / 128) {
let mut shift = 0;
for _ in 0..4 {
let d = (sc[is] & 0xF) as i32;
is += 1;
let mut isuml = i16x8_splat(0);
for l in (0..16).step_by(8) {
let q8 = i16x8_load_extend_i8x8(q8.as_ptr().add(l));
let q2 = i16x8_load_extend_u8x8(q2.as_ptr().add(l));
let q2 = v128_and(i16x8_shr(q2, shift), i16x8_splat(3));
isuml = i16x8_add(isuml, i16x8_mul(q2, q8))
}
let dd = i32x4_splat(d);
isum = i32x4_add(isum, i32x4_mul(i32x4_extend_low_i16x8(isuml), dd));
isum = i32x4_add(isum, i32x4_mul(i32x4_extend_high_i16x8(isuml), dd));
let d = (sc[is] & 0xF) as i32;
is += 1;
let mut isuml = i16x8_splat(0);
for l in (16..32).step_by(8) {
let q8 = i16x8_load_extend_i8x8(q8.as_ptr().add(l));
let q2 = i16x8_load_extend_u8x8(q2.as_ptr().add(l));
let q2 = v128_and(i16x8_shr(q2, shift), i16x8_splat(3));
isuml = i16x8_add(isuml, i16x8_mul(q2, q8))
}
let dd = i32x4_splat(d);
isum = i32x4_add(isum, i32x4_mul(i32x4_extend_low_i16x8(isuml), dd));
isum = i32x4_add(isum, i32x4_mul(i32x4_extend_high_i16x8(isuml), dd));
shift += 2;
// adjust the indexing
q8 = &q8[32..];
}
// adjust the indexing
q2 = &q2[32..];
}
let isum = f32x4_convert_i32x4(isum);
sumf = f32x4_add(
sumf,
f32x4_sub(
f32x4_mul(isum, f32x4_splat(dall)),
f32x4_mul(summs, f32x4_splat(dmin)),
),
);
}
let sumf = f32x4_extract_lane::<0>(sumf)
+ f32x4_extract_lane::<1>(sumf)
+ f32x4_extract_lane::<2>(sumf)
+ f32x4_extract_lane::<3>(sumf);
Ok(sumf)
}
}
#[inline(always)]
pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Result<f32> {
if n % QK_K != 0 {
crate::bail!("vec_dot_q4k_q8k: {n} is not divisible by {QK_K}")
}
const KMASK1: u32 = 0x3f3f3f3f;
const KMASK2: u32 = 0x0f0f0f0f;
const KMASK3: u32 = 0x03030303;
let mut utmp: [u32; 4] = [0; 4];
let mut scales: [u8; 8] = [0; 8];
let mut mins: [u8; 8] = [0; 8];
let mut aux8: [u8; QK_K] = [0; QK_K];
let mut sums = f32x4_splat(0f32);
unsafe {
for (y, x) in ys.iter().zip(xs.iter()) {
let q4 = &x.qs;
let q8 = &y.qs;
for j in 0..QK_K / 64 {
let q4_1 = v128_load(q4.as_ptr().add(32 * j) as *const v128);
let q4_2 = v128_load(q4.as_ptr().add(32 * j + 16) as *const v128);
v128_store(
aux8.as_mut_ptr().add(64 * j) as *mut v128,
v128_and(q4_1, u8x16_splat(0x0F)),
);
v128_store(
aux8.as_mut_ptr().add(64 * j + 16) as *mut v128,
v128_and(q4_2, u8x16_splat(0x0F)),
);
v128_store(
aux8.as_mut_ptr().add(64 * j + 32) as *mut v128,
u8x16_shr(q4_1, 4),
);
v128_store(
aux8.as_mut_ptr().add(64 * j + 48) as *mut v128,
u8x16_shr(q4_2, 4),
);
}
LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]);
utmp[3] = ((utmp[2] >> 4) & KMASK2) | (((utmp[1] >> 6) & KMASK3) << 4);
let uaux = utmp[1] & KMASK1;
utmp[1] = (utmp[2] & KMASK2) | (((utmp[0] >> 6) & KMASK3) << 4);
utmp[2] = uaux;
utmp[0] &= KMASK1;
//extract scales and mins
LittleEndian::write_u32_into(&utmp[0..2], &mut scales);
LittleEndian::write_u32_into(&utmp[2..4], &mut mins);
let mut sumi = i32x4_splat(0);
for j in (0..QK_K / 16).step_by(4) {
let bsums = i32x4_load_extend_i16x4(y.bsums.as_ptr().add(j));
let (m1, m2) = (mins[j / 2] as i32, mins[j / 2 + 1] as i32);
let mins = i32x4(m1, m1, m2, m2);
sumi = i32x4_add(sumi, i32x4_mul(bsums, mins));
}
let mut aux32 = i32x4_splat(0i32);
for (scale_i, scale) in scales.iter().enumerate() {
let scale = i32x4_splat(*scale as i32);
for j in 0..4 {
let i = 32 * scale_i + 8 * j;
let q8 = i16x8_load_extend_i8x8(q8.as_ptr().add(i));
let aux8 = i16x8_load_extend_u8x8(aux8.as_ptr().add(i));
let aux16 = i16x8_mul(q8, aux8);
aux32 = i32x4_add(aux32, i32x4_mul(scale, i32x4_extend_low_i16x8(aux16)));
aux32 = i32x4_add(aux32, i32x4_mul(scale, i32x4_extend_high_i16x8(aux16)));
}
}
let aux32 = f32x4_convert_i32x4(aux32);
let d = f32x4_splat(x.d.to_f32() * y.d);
sums = f32x4_add(sums, f32x4_mul(aux32, d));
let dmin = x.dmin.to_f32() * y.d;
let dmin = f32x4_splat(dmin);
let sumi = f32x4_convert_i32x4(sumi);
sums = f32x4_sub(sums, f32x4_mul(sumi, dmin));
}
let sums = f32x4_extract_lane::<0>(sums)
+ f32x4_extract_lane::<1>(sums)
+ f32x4_extract_lane::<2>(sums)
+ f32x4_extract_lane::<3>(sums);
Ok(sums)
}
}
#[inline(always)]
pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Result<f32> {
if n % QK_K != 0 {
crate::bail!("vec_dot_q6k_q8k: {n} is not divisible by {QK_K}")
}
let mut aux8 = [0i8; QK_K];
unsafe {
let mut sums = f32x4_splat(0f32);
for (x, y) in xs.iter().zip(ys.iter()) {
let q4 = &x.ql;
let qh = &x.qh;
let q8 = &y.qs;
let mut aux32 = f32x4_splat(0f32);
for j in (0..QK_K).step_by(128) {
let aux8 = aux8.as_mut_ptr().add(j);
let q4 = &q4.as_ptr().add(j / 2);
let qh = &qh.as_ptr().add(j / 4);
for l in (0..32).step_by(16) {
// aux8[l] = (((q4[l] & 0xF) | ((qh[l] & 3) << 4)) as i32 - 32) as i8;
let a8 = v128_or(
v128_and(v128_load(q4.add(l) as *const v128), u8x16_splat(0xF)),
u8x16_shl(
v128_and(v128_load(qh.add(l) as *const v128), u8x16_splat(3)),
4,
),
);
let a8_low = i16x8_sub(i16x8_extend_low_u8x16(a8), i16x8_splat(32));
let a8_high = i16x8_sub(i16x8_extend_high_u8x16(a8), i16x8_splat(32));
v128_store(
aux8.add(l) as *mut v128,
i8x16_narrow_i16x8(a8_low, a8_high),
);
// aux8[l + 32] =
// (((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) as i32 - 32) as i8;
let a8 = v128_or(
v128_and(v128_load(q4.add(l + 32) as *const v128), u8x16_splat(0xF)),
u8x16_shl(
v128_and(
u8x16_shr(v128_load(qh.add(l) as *const v128), 2),
u8x16_splat(3),
),
4,
),
);
let a8_low = i16x8_sub(i16x8_extend_low_u8x16(a8), i16x8_splat(32));
let a8_high = i16x8_sub(i16x8_extend_high_u8x16(a8), i16x8_splat(32));
v128_store(
aux8.add(l + 32) as *mut v128,
i8x16_narrow_i16x8(a8_low, a8_high),
);
// aux8[l + 64] = (((q4[l] >> 4) | (((qh[l] >> 4) & 3) << 4)) as i32 - 32) as i8;
let a8 = v128_or(
u8x16_shr(v128_load(q4.add(l) as *const v128), 4),
u8x16_shl(
v128_and(
u8x16_shr(v128_load(qh.add(l) as *const v128), 4),
u8x16_splat(3),
),
4,
),
);
let a8_low = i16x8_sub(i16x8_extend_low_u8x16(a8), i16x8_splat(32));
let a8_high = i16x8_sub(i16x8_extend_high_u8x16(a8), i16x8_splat(32));
v128_store(
aux8.add(l + 64) as *mut v128,
i8x16_narrow_i16x8(a8_low, a8_high),
);
// aux8[l + 96] =
// (((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) as i32 - 32) as i8;
let a8 = v128_or(
u8x16_shr(v128_load(q4.add(l + 32) as *const v128), 4),
u8x16_shl(
v128_and(
u8x16_shr(v128_load(qh.add(l) as *const v128), 6),
u8x16_splat(3),
),
4,
),
);
let a8_low = i16x8_sub(i16x8_extend_low_u8x16(a8), i16x8_splat(32));
let a8_high = i16x8_sub(i16x8_extend_high_u8x16(a8), i16x8_splat(32));
v128_store(
aux8.add(l + 96) as *mut v128,
i8x16_narrow_i16x8(a8_low, a8_high),
);
}
}
for (j, &scale) in x.scales.iter().enumerate() {
let scale = f32x4_splat(scale as f32);
for offset in [0, 8] {
let aux16 = i16x8_mul(
i16x8_load_extend_i8x8(q8.as_ptr().add(16 * j + offset)),
i16x8_load_extend_i8x8(aux8.as_ptr().add(16 * j + offset)),
);
aux32 = f32x4_add(
aux32,
f32x4_mul(f32x4_convert_i32x4(i32x4_extend_low_i16x8(aux16)), scale),
);
aux32 = f32x4_add(
aux32,
f32x4_mul(f32x4_convert_i32x4(i32x4_extend_high_i16x8(aux16)), scale),
);
}
}
let d = f32x4_splat(x.d.to_f32() * y.d);
sums = f32x4_add(sums, f32x4_mul(aux32, d));
}
let sums = f32x4_extract_lane::<0>(sums)
+ f32x4_extract_lane::<1>(sums)
+ f32x4_extract_lane::<2>(sums)
+ f32x4_extract_lane::<3>(sums);
Ok(sums)
}
}
#[inline(always)]
pub(crate) fn vec_dot_q8k_q8k(n: usize, xs: &[BlockQ8K], ys: &[BlockQ8K]) -> Result<f32> {
let qk = QK_K;
if n % QK_K != 0 {
crate::bail!("vec_dot_q8k_q8k: {n} is not divisible by {qk}")
}
unsafe {
let mut acc = f32x4_splat(0.0f32);
for (xs, ys) in xs.iter().zip(ys.iter()) {
let x_qs = xs.qs.as_ptr();
let y_qs = ys.qs.as_ptr();
let mut sumi = i32x4_splat(0);
for j in (0..QK_K).step_by(8) {
let xs = i16x8_load_extend_i8x8(x_qs.add(j));
let ys = i16x8_load_extend_i8x8(y_qs.add(j));
let sum_xy = i32x4_dot_i16x8(xs, ys);
sumi = i32x4_add(sumi, sum_xy)
}
let d = f32x4_splat(xs.d * ys.d);
acc = f32x4_add(acc, f32x4_mul(f32x4_convert_i32x4(sumi), d))
}
let res = f32x4_extract_lane::<0>(acc)
+ f32x4_extract_lane::<1>(acc)
+ f32x4_extract_lane::<2>(acc)
+ f32x4_extract_lane::<3>(acc);
Ok(res)
}
}

View File

@ -17,7 +17,7 @@ pub(super) fn group_for_quantization<'a, 'b, T: super::k_quants::GgmlType>(
let expected_blocks = xs.len() / block_size;
let actual_blocks = ys.len();
//validate that the input is the right size
// Validate that the input is the right size
if expected_blocks != actual_blocks {
crate::bail!("quantize {dtype:?}: expected {expected_blocks} blocks but only {actual_blocks} were provided!")
}
@ -37,12 +37,12 @@ pub(super) fn group_for_dequantization<'a, 'b, T: super::k_quants::GgmlType>(
let actual_output_len = ys.len();
let expected_output_len = xs.len() * block_size;
//validate that the output is the right size
// Validate that the output is the right size
if expected_output_len != actual_output_len {
crate::bail!("dequantize {dtype:?}: ys (len = {actual_output_len}) does not match the expected length of {expected_output_len}!")
}
//zip the blocks and outputs together
// Zip the blocks and outputs together
Ok(xs.iter().zip(ys.chunks_exact_mut(block_size)).collect())
}

View File

@ -78,11 +78,7 @@ impl st::View for &Tensor {
}
impl Tensor {
pub fn save_safetensors<P: AsRef<std::path::Path>>(
&self,
name: &str,
filename: P,
) -> Result<()> {
pub fn save_safetensors<P: AsRef<Path>>(&self, name: &str, filename: P) -> Result<()> {
let data = [(name, self.clone())];
Ok(st::serialize_to_file(data, &None, filename.as_ref())?)
}
@ -255,6 +251,134 @@ pub fn save<K: AsRef<str> + Ord + std::fmt::Display, P: AsRef<Path>>(
Ok(st::serialize_to_file(tensors, &None, filename.as_ref())?)
}
#[derive(yoke::Yokeable)]
struct SafeTensors_<'a>(SafeTensors<'a>);
pub struct MmapedSafetensors {
safetensors: Vec<yoke::Yoke<SafeTensors_<'static>, memmap2::Mmap>>,
routing: Option<HashMap<String, usize>>,
}
impl MmapedSafetensors {
/// Creates a wrapper around a memory mapped file and deserialize the safetensors header.
///
/// # Safety
///
/// The unsafe is inherited from [`memmap2::MmapOptions`].
pub unsafe fn new<P: AsRef<Path>>(p: P) -> Result<Self> {
let p = p.as_ref();
let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
let file = memmap2::MmapOptions::new()
.map(&file)
.map_err(|e| Error::from(e).with_path(p))?;
let safetensors = yoke::Yoke::<SafeTensors_<'static>, memmap2::Mmap>::try_attach_to_cart(
file,
|data: &[u8]| {
let st = safetensors::SafeTensors::deserialize(data)
.map_err(|e| Error::from(e).with_path(p))?;
Ok::<_, Error>(SafeTensors_(st))
},
)?;
Ok(Self {
safetensors: vec![safetensors],
routing: None,
})
}
/// Creates a wrapper around multiple memory mapped file and deserialize the safetensors headers.
///
/// If a tensor name appears in multiple files, the last entry is returned.
///
/// # Safety
///
/// The unsafe is inherited from [`memmap2::MmapOptions`].
pub unsafe fn multi<P: AsRef<Path>>(paths: &[P]) -> Result<Self> {
let mut routing = HashMap::new();
let mut safetensors = vec![];
for (index, p) in paths.iter().enumerate() {
let p = p.as_ref();
let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
let file = memmap2::MmapOptions::new()
.map(&file)
.map_err(|e| Error::from(e).with_path(p))?;
let data = yoke::Yoke::<SafeTensors_<'static>, memmap2::Mmap>::try_attach_to_cart(
file,
|data: &[u8]| {
let st = safetensors::SafeTensors::deserialize(data)
.map_err(|e| Error::from(e).with_path(p))?;
Ok::<_, Error>(SafeTensors_(st))
},
)?;
for k in data.get().0.names() {
routing.insert(k.to_string(), index);
}
safetensors.push(data)
}
Ok(Self {
safetensors,
routing: Some(routing),
})
}
pub fn load(&self, name: &str, dev: &Device) -> Result<Tensor> {
self.get(name)?.load(dev)
}
pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
let mut tensors = vec![];
for safetensors in self.safetensors.iter() {
tensors.push(safetensors.get().0.tensors())
}
tensors.into_iter().flatten().collect()
}
pub fn get(&self, name: &str) -> Result<st::TensorView<'_>> {
let index = match &self.routing {
None => 0,
Some(routing) => {
let index = routing.get(name).ok_or_else(|| {
Error::CannotFindTensor {
path: name.to_string(),
}
.bt()
})?;
*index
}
};
Ok(self.safetensors[index].get().0.tensor(name)?)
}
}
pub struct BufferedSafetensors {
safetensors: yoke::Yoke<SafeTensors_<'static>, Vec<u8>>,
}
impl BufferedSafetensors {
/// Creates a wrapper around a binary buffer and deserialize the safetensors header.
pub fn new(buffer: Vec<u8>) -> Result<Self> {
let safetensors = yoke::Yoke::<SafeTensors_<'static>, Vec<u8>>::try_attach_to_cart(
buffer,
|data: &[u8]| {
let st = safetensors::SafeTensors::deserialize(data)?;
Ok::<_, Error>(SafeTensors_(st))
},
)?;
Ok(Self { safetensors })
}
pub fn load(&self, name: &str, dev: &Device) -> Result<Tensor> {
self.get(name)?.load(dev)
}
pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
self.safetensors.get().0.tensors()
}
pub fn get(&self, name: &str) -> Result<st::TensorView<'_>> {
Ok(self.safetensors.get().0.tensor(name)?)
}
}
pub struct MmapedFile {
path: std::path::PathBuf,
inner: memmap2::Mmap,
@ -267,7 +391,7 @@ impl MmapedFile {
/// # Safety
///
/// The unsafe is inherited from [`memmap2::MmapOptions`].
pub unsafe fn new<P: AsRef<std::path::Path>>(p: P) -> Result<Self> {
pub unsafe fn new<P: AsRef<Path>>(p: P) -> Result<Self> {
let p = p.as_ref();
let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
let inner = memmap2::MmapOptions::new()

View File

@ -444,6 +444,18 @@ impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim, D5: Dim> Dims for (D1, D2, D3, D4, D5)
}
}
impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim, D5: Dim, D6: Dim> Dims for (D1, D2, D3, D4, D5, D6) {
fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
let d0 = self.0.to_index(shape, op)?;
let d1 = self.1.to_index(shape, op)?;
let d2 = self.2.to_index(shape, op)?;
let d3 = self.3.to_index(shape, op)?;
let d4 = self.4.to_index(shape, op)?;
let d5 = self.5.to_index(shape, op)?;
Ok(vec![d0, d1, d2, d3, d4, d5])
}
}
extract_dims!(dims0, 0, |_: &[usize]| (), ());
extract_dims!(dims1, 1, |d: &[usize]| d[0], usize);
extract_dims!(dims2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize));

View File

@ -369,6 +369,19 @@ impl Storage {
}
}
pub(crate) fn upsample_nearest1d(&self, layout: &Layout, sz: usize) -> Result<Self> {
match self {
Storage::Cpu(storage) => {
let storage = storage.upsample_nearest1d(layout, sz)?;
Ok(Self::Cpu(storage))
}
Self::Cuda(storage) => {
let storage = storage.upsample_nearest1d(layout, sz)?;
Ok(Self::Cuda(storage))
}
}
}
pub(crate) fn upsample_nearest2d(&self, layout: &Layout, h: usize, w: usize) -> Result<Self> {
match self {
Storage::Cpu(storage) => {

View File

@ -105,6 +105,28 @@ macro_rules! binary_op {
};
}
macro_rules! binary_op_scalar {
($fn_name:ident, $op_name:ident) => {
pub fn $fn_name<T: TensorOrScalar>(&self, rhs: T) -> Result<Self> {
let rhs = match rhs.to_tensor_scalar()? {
crate::scalar::TensorScalar::Tensor(rhs) => rhs,
crate::scalar::TensorScalar::Scalar(rhs) => rhs
.to_dtype(self.dtype())?
.to_device(self.device())?
.broadcast_as(self.shape())?,
};
let shape = self.same_shape_binary_op(&rhs, stringify!($fn_name))?;
let storage = self.storage().binary_impl::<crate::op::$op_name>(
&*rhs.storage(),
self.layout(),
rhs.layout(),
)?;
let op = BackpropOp::new2(self, &rhs, |t1, t2| Op::Binary(t1, t2, BinaryOp::$op_name));
Ok(from_storage(storage, shape.clone(), op, false))
}
};
}
macro_rules! broadcast_binary_op {
($fn_name:ident, $inner_fn_name:ident) => {
pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
@ -155,14 +177,9 @@ impl Tensor {
is_variable: bool,
) -> Result<Self> {
let none = BackpropOp::none();
if is_variable {
let shape = shape.into();
let storage = device.ones(&shape, dtype)?;
Ok(from_storage(storage, shape, none, is_variable))
} else {
let storage = device.ones(&crate::shape::SCALAR, dtype)?;
from_storage(storage, crate::shape::SCALAR, none, is_variable).broadcast_as(shape)
}
let shape = shape.into();
let storage = device.ones(&shape, dtype)?;
Ok(from_storage(storage, shape, none, is_variable))
}
/// Creates a new tensor filled with ones.
@ -200,14 +217,9 @@ impl Tensor {
is_variable: bool,
) -> Result<Self> {
let none = BackpropOp::none();
if is_variable {
let shape = shape.into();
let storage = device.zeros(&shape, dtype)?;
Ok(from_storage(storage, shape, none, is_variable))
} else {
let storage = device.zeros(&crate::shape::SCALAR, dtype)?;
from_storage(storage, crate::shape::SCALAR, none, is_variable).broadcast_as(shape)
}
let shape = shape.into();
let storage = device.zeros(&shape, dtype)?;
Ok(from_storage(storage, shape, none, is_variable))
}
/// Creates a new tensor filled with zeros.
@ -447,8 +459,8 @@ impl Tensor {
binary_op!(mul, Mul);
binary_op!(sub, Sub);
binary_op!(div, Div);
binary_op!(maximum, Maximum);
binary_op!(minimum, Minimum);
binary_op_scalar!(maximum, Maximum);
binary_op_scalar!(minimum, Minimum);
broadcast_binary_op!(broadcast_add, add);
broadcast_binary_op!(broadcast_mul, mul);
broadcast_binary_op!(broadcast_sub, sub);
@ -467,7 +479,21 @@ impl Tensor {
unary_op!(sqr, Sqr);
unary_op!(sqrt, Sqrt);
unary_op!(gelu, Gelu);
unary_op!(gelu_erf, GeluErf);
unary_op!(erf, Erf);
unary_op!(relu, Relu);
unary_op!(ceil, Ceil);
unary_op!(floor, Floor);
unary_op!(round, Round);
/// Round element of the input tensor to the nearest integer.
///
/// If the number of decimals is negative, it specifies the number of positions to the left of
/// the decimal point.
pub fn round_to(&self, decimals: i32) -> Result<Self> {
let mult = 10f64.powi(decimals);
(self * mult)?.round()? * (1f64 / mult)
}
/// Retrieves the single scalar value hold in the tensor. If the tensor contains multiple
/// dimensions, an error is returned instead.
@ -644,7 +670,12 @@ impl Tensor {
let storage = self.storage().reduce_op(op, self.layout(), &[dim])?;
let mut dims = self.dims().to_vec();
dims[dim] = 1;
let op = BackpropOp::new1(self, |arg| Op::Reduce(arg, op, dims.to_vec()));
let op = match op {
ReduceOp::Sum | ReduceOp::Min | ReduceOp::Max => {
BackpropOp::new1(self, |arg| Op::Reduce(arg, op, dims.to_vec()))
}
ReduceOp::ArgMin | ReduceOp::ArgMax => BackpropOp::none(),
};
let res = from_storage(storage, dims, op, false);
if keepdim {
Ok(res)
@ -827,12 +858,35 @@ impl Tensor {
self.cmp(rhs, CmpOp::Le)
}
/// Upsample the input tensor to the `(target_h, target_w)` size, taking the value of the
/// Clamp the tensor values to be between `min` and `max`.
pub fn clamp<T1: TensorOrScalar, T2: TensorOrScalar>(&self, min: T1, max: T2) -> Result<Self> {
self.maximum(min)?.minimum(max)
}
/// Interpolate the input tensor to the `target_size` size, taking the value of the nearest element.
///
/// The input tensor should have three dimensions, `(batch, channels, l)`, the returned
/// tensor also has three dimensions, `(batch, channels, target_size)`.
pub fn interpolate1d(&self, target_size: usize) -> Result<Self> {
let (n, c, _l) = self.dims3()?;
let op = BackpropOp::new1(self, Op::UpsampleNearest1D);
let storage = self
.storage()
.upsample_nearest1d(self.layout(), target_size)?;
Ok(from_storage(storage, (n, c, target_size), op, false))
}
/// Alias for `interpolate1d`.
pub fn upsample_nearest1d(&self, target_size: usize) -> Result<Self> {
self.interpolate1d(target_size)
}
/// Interpolate the input tensor to the `(target_h, target_w)` size, taking the value of the
/// nearest element.
///
/// The input tensor should have four dimensions, `(batch, channels, h, w)`, the returned
/// tensor also has four dimensions, `(batch, channels, target_h, target_w)`.
pub fn upsample_nearest2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
pub fn interpolate2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
let (n, c, _h, _w) = self.dims4()?;
let op = BackpropOp::new1(self, Op::UpsampleNearest2D);
let storage = self
@ -841,6 +895,11 @@ impl Tensor {
Ok(from_storage(storage, (n, c, target_h, target_w), op, false))
}
/// Alias for `interpolate2d`.
pub fn upsample_nearest2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
self.interpolate2d(target_h, target_w)
}
/// 2D average pooling over an input tensor with multiple channels.
///
/// The input tensor should have four dimensions, `(batch, channels, h, w)`, the returned
@ -1075,6 +1134,74 @@ impl Tensor {
Ok(from_storage(storage, self.shape(), op, false))
}
/// Embeds the values of the `src` tensor into the `self` tensor on the specified dimension.
pub fn slice_scatter<D: Dim>(&self, src: &Self, dim: D, start: usize) -> Result<Self> {
let dim = dim.to_index(self.shape(), "slice-scatter")?;
if dim == 0 {
self.slice_scatter0(src, start)
} else {
// TODO: Maybe we want to add a more efficient implementation at some point.
self.transpose(0, dim)?
.slice_scatter0(&src.transpose(0, dim)?, start)?
.transpose(0, dim)
}
}
/// Embeds the values of the `src` tensor into the `self` tensor on the first dimension.
pub fn slice_scatter0(&self, src: &Self, start: usize) -> Result<Self> {
if self.dtype() != src.dtype() {
Err(Error::DTypeMismatchBinaryOp {
lhs: self.dtype(),
rhs: src.dtype(),
op: "slice-scatter",
}
.bt())?
}
if self.device().location() != src.device.location() {
Err(Error::DeviceMismatchBinaryOp {
lhs: self.device().location(),
rhs: src.device().location(),
op: "slice-scatter",
}
.bt())?
}
if self.rank() != src.rank() {
Err(Error::UnexpectedNumberOfDims {
expected: self.rank(),
got: src.rank(),
shape: src.shape().clone(),
}
.bt())?
}
let shape_ok =
self.dims()
.iter()
.zip(src.dims().iter())
.enumerate()
.all(|(dim_idx, (&d1, &d2))| {
if 0 == dim_idx {
d2 + start <= d1
} else {
d1 == d2
}
});
if !shape_ok {
Err(Error::ShapeMismatchBinaryOp {
op: "slice-scatter (self, src)",
lhs: self.shape().clone(),
rhs: src.shape().clone(),
})?
}
let mut storage = self.device().zeros(self.shape(), self.dtype())?;
self.storage()
.copy_strided_src(&mut storage, 0, self.layout())?;
let offset = start * src.dims()[1..].iter().product::<usize>();
src.storage()
.copy_strided_src(&mut storage, offset, src.layout())?;
let op = BackpropOp::new2(self, src, |t1, t2| Op::SliceScatter0(t1, t2, start));
Ok(from_storage(storage, self.shape(), op, false))
}
/// Accumulate element from `source` at indexes `indexes` and add them to `self`.
pub fn index_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
let dim = dim.to_index(self.shape(), "index-add")?;
@ -1491,6 +1618,9 @@ impl Tensor {
pub fn transpose<D1: Dim, D2: Dim>(&self, dim1: D1, dim2: D2) -> Result<Tensor> {
let dim1 = dim1.to_index(self.shape(), "transpose")?;
let dim2 = dim2.to_index(self.shape(), "transpose")?;
if dim1 == dim2 {
return Ok(self.clone());
}
let op = BackpropOp::new1(self, |t| Op::Transpose(t, dim1, dim2));
let tensor_ = Tensor_ {
id: TensorId::new(),
@ -1852,6 +1982,34 @@ impl Tensor {
for arg in args {
arg.as_ref().check_dim(dim, "cat")?;
}
for (arg_idx, arg) in args.iter().enumerate() {
let arg = arg.as_ref();
if arg0.rank() != arg.rank() {
Err(Error::UnexpectedNumberOfDims {
expected: arg0.rank(),
got: arg.rank(),
shape: arg.shape().clone(),
}
.bt())?
}
for (dim_idx, (v1, v2)) in arg0
.shape()
.dims()
.iter()
.zip(arg.shape().dims().iter())
.enumerate()
{
if dim_idx != dim && v1 != v2 {
Err(Error::ShapeMismatchCat {
dim: dim_idx,
first_shape: arg0.shape().clone(),
n: arg_idx + 1,
nth_shape: arg.shape().clone(),
}
.bt())?
}
}
}
if dim == 0 {
Self::cat0(args)
} else {

View File

@ -218,6 +218,22 @@ fn binary_grad(device: &Device) -> Result<()> {
let grad_x = grads.get(x).context("no grad for x")?;
assert_eq!(y.to_vec1::<f32>()?, [3., 1., -4., -1.]);
assert_eq!(grad_x.to_vec1::<f32>()?, [1., 1., 1., 1.]);
let x_var = Var::new(&[3f32, 1., -4., -1., 5., 9.], device)?;
let x = x_var.as_tensor();
let y_var = Var::new(&[2f32, 7., 1.], device)?;
let y = y_var.as_tensor();
let ss = x
.reshape((2, 3))?
.slice_scatter0(&y.reshape((1, 3))?, 1)?
.sqr()?;
let grads = ss.backward()?;
let grad_x = grads.get(x).context("no grad for x")?;
let grad_y = grads.get(y).context("no grad for y")?;
assert_eq!(ss.to_vec2::<f32>()?, [[9., 1., 16.], [4., 49., 1.]]);
assert_eq!(grad_x.to_vec1::<f32>()?, [6.0, 2.0, -8.0, 0.0, 0.0, 0.0]);
assert_eq!(grad_y.to_vec1::<f32>()?, [4.0, 14.0, 2.0]);
Ok(())
}

View File

@ -43,7 +43,7 @@ fn quantized_matmul() -> Result<()> {
);
let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?;
let matmul = quantized::QMatMul::from_qtensor(qtensor);
let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
let res = matmul.forward(&tensor_lhs)?;
assert_eq!(
to_vec2_round(&res, 0)?,
@ -91,7 +91,7 @@ fn quantized_matmul_neg() -> Result<()> {
);
let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?;
let matmul = quantized::QMatMul::from_qtensor(qtensor);
let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
let res = matmul.forward(&tensor_lhs)?;
assert_eq!(
to_vec2_round(&res, 0)?,
@ -491,6 +491,9 @@ fn ggml_reference_matmul_error(dtype: GgmlDType) -> Result<f32> {
GgmlDType::Q5_0 => 0.001353,
GgmlDType::Q5_1 => 0.001363,
GgmlDType::Q8_0 => 0.000092,
// Not from the ggml repo.
GgmlDType::Q8K => 0.00065,
_ => candle_core::bail!("No GGML results for quantization type {dtype:?}",),
};
Ok(err)
@ -508,17 +511,22 @@ fn ggml_matmul_error_test<T: GgmlType>() -> Result<()> {
T::VecDotType::from_float(&b, &mut b_quant)?;
let result = T::vec_dot(length, &a_quant, &b_quant)?;
let result_unopt = T::vec_dot_unopt(length, &a_quant, &b_quant)?;
let reference_result = vec_dot_reference(&a, &b);
if (result - result_unopt).abs() / length as f32 > 1e-6 {
candle_core::bail!(
"the opt and unopt vec-dot returned different values, opt {result}, unopt {result_unopt}"
)
}
let error = (result - reference_result).abs() / length as f32;
let ggml_error = ggml_reference_matmul_error(T::DTYPE)?;
if error > GGML_MAX_DOT_PRODUCT_ERROR {
if !error.is_finite() || error > GGML_MAX_DOT_PRODUCT_ERROR {
candle_core::bail!(
"Dot product error {} exceeds max error {}",
error,
GGML_MAX_DOT_PRODUCT_ERROR
"Dot product error {error} exceeds max error {GGML_MAX_DOT_PRODUCT_ERROR}",
);
}
@ -571,7 +579,7 @@ fn quantized_matmul_q2k() -> Result<()> {
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
let rhs = quantized::QTensor::quantize::<BlockQ2K>(&rhs)?;
let rhs = quantized::QMatMul::from_qtensor(rhs);
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
let mm = rhs.forward(&lhs)?;
assert_eq!(mm.dims(), [m, n]);
@ -597,7 +605,7 @@ fn quantized_matmul_q3k() -> Result<()> {
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
let rhs = quantized::QTensor::quantize::<BlockQ3K>(&rhs)?;
let rhs = quantized::QMatMul::from_qtensor(rhs);
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
let mm = rhs.forward(&lhs)?;
assert_eq!(mm.dims(), [m, n]);
@ -623,7 +631,7 @@ fn quantized_matmul_q4k() -> Result<()> {
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
let rhs = quantized::QTensor::quantize::<BlockQ4K>(&rhs)?;
let rhs = quantized::QMatMul::from_qtensor(rhs);
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
let mm = rhs.forward(&lhs)?;
assert_eq!(mm.dims(), [m, n]);
@ -649,7 +657,7 @@ fn quantized_matmul_q5k() -> Result<()> {
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
let rhs = quantized::QTensor::quantize::<BlockQ5K>(&rhs)?;
let rhs = quantized::QMatMul::from_qtensor(rhs);
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
let mm = rhs.forward(&lhs)?;
assert_eq!(mm.dims(), [m, n]);
@ -676,7 +684,7 @@ fn quantized_matmul_q6k() -> Result<()> {
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
let rhs = quantized::QTensor::quantize::<BlockQ6K>(&rhs)?;
let rhs = quantized::QMatMul::from_qtensor(rhs);
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
let mm = rhs.forward(&lhs)?;
assert_eq!(mm.dims(), [m, n]);
@ -687,3 +695,28 @@ fn quantized_matmul_q6k() -> Result<()> {
ggml_matmul_error_test::<BlockQ6K>()?;
Ok(())
}
#[test]
fn quantized_matmul_q8k() -> Result<()> {
use k_quants::BlockQ8K;
let cpu = &Device::Cpu;
let (m, k, n) = (11, 512, 21);
let (lhs, rhs, mm) = get_random_tensors(m, k, n, cpu)?;
assert_eq!(mm.dims(), [m, n]);
let dst = mm.flatten_all()?.to_vec1::<f32>()?;
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);
let rhs = quantized::QTensor::quantize::<BlockQ8K>(&rhs)?;
let rhs = quantized::QMatMul::from_qtensor(rhs)?;
let mm = rhs.forward(&lhs)?;
assert_eq!(mm.dims(), [m, n]);
let dst = mm.flatten_all()?.to_vec1::<f32>()?;
let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
assert_eq!(dst, [1.266, 1.504, -0.204, 1.7]);
ggml_matmul_error_test::<BlockQ8K>()?;
Ok(())
}

View File

@ -1,4 +1,4 @@
use candle_core::{test_device, DType, Device, IndexOp, Result, Tensor};
use candle_core::{test_device, test_utils, DType, Device, IndexOp, Result, Tensor};
fn zeros(device: &Device) -> Result<()> {
let tensor = Tensor::zeros((5, 2), DType::F32, device)?;
@ -8,6 +8,31 @@ fn zeros(device: &Device) -> Result<()> {
Ok(())
}
fn ones(device: &Device) -> Result<()> {
assert_eq!(
Tensor::ones((2, 3), DType::U8, device)?.to_vec2::<u8>()?,
[[1, 1, 1], [1, 1, 1]],
);
assert_eq!(
Tensor::ones((2, 3), DType::U32, device)?.to_vec2::<u32>()?,
[[1, 1, 1], [1, 1, 1]],
);
assert_eq!(
Tensor::ones((2, 3), DType::I64, device)?.to_vec2::<i64>()?,
[[1, 1, 1], [1, 1, 1]],
);
assert_eq!(
Tensor::ones((2, 3), DType::F32, device)?.to_vec2::<f32>()?,
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
);
assert_eq!(
Tensor::ones((2, 3), DType::F64, device)?.to_vec2::<f64>()?,
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
);
Ok(())
}
fn add_mul(device: &Device) -> Result<()> {
let tensor = Tensor::new(&[3f32, 1., 4.], device)?;
let dim1 = tensor.dims1()?;
@ -33,6 +58,65 @@ fn tensor_2d(device: &Device) -> Result<()> {
Ok(())
}
fn clamp(device: &Device) -> Result<()> {
let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
let tensor = Tensor::new(data, device)?;
let tensor = tensor.clamp(1.5, 6.2)?;
assert_eq!(
tensor.to_vec2::<f32>()?,
[[3.0, 1.5, 4.0, 1.5, 5.0], [2.0, 1.5, 6.2, 6.2, 2.0]],
);
Ok(())
}
fn unary_op(device: &Device) -> Result<()> {
let data = &[[-3f32, 1., 4., -0.1, 0.5], [2.7, -1.8, -0.28, 1.8, 2.8]];
let tensor = Tensor::new(data, device)?;
assert_eq!(
test_utils::to_vec2_round(&tensor.gelu()?, 4)?,
[
[-0.0036, 0.8412, 3.9999, -0.046, 0.3457],
[2.6911, -0.0647, -0.1091, 1.7353, 2.7933]
]
);
assert_eq!(
test_utils::to_vec2_round(&tensor.gelu_erf()?, 4)?,
[
[-0.004, 0.8413, 3.9999, -0.046, 0.3457],
[2.6906, -0.0647, -0.1091, 1.7353, 2.7928]
]
);
assert_eq!(
test_utils::to_vec2_round(&tensor.erf()?, 4)?,
[
[-1.0, 0.8427, 1.0, -0.1125, 0.5205],
[0.9999, -0.9891, -0.3079, 0.9891, 0.9999]
]
);
assert_eq!(
test_utils::to_vec2_round(&tensor.ceil()?, 4)?,
[[-3.0, 1.0, 4.0, -0.0, 1.0], [3.0, -1.0, -0.0, 2.0, 3.0]]
);
assert_eq!(
test_utils::to_vec2_round(&tensor.floor()?, 4)?,
[[-3.0, 1.0, 4.0, -1.0, 0.0], [2.0, -2.0, -1.0, 1.0, 2.0]]
);
assert_eq!(
test_utils::to_vec2_round(&tensor.round()?, 4)?,
[[-3.0, 1.0, 4.0, -0.0, 1.0], [3.0, -2.0, -0.0, 2.0, 3.0]]
);
let tensor = Tensor::new(&[2997.9246, 314.15926f32], device)?;
assert_eq!(
test_utils::to_vec1_round(&tensor.round_to(2)?, 4)?,
[2997.92, 314.16]
);
assert_eq!(
test_utils::to_vec1_round(&tensor.round_to(-2)?, 4)?,
[3000.0, 300.]
);
Ok(())
}
fn binary_op(device: &Device) -> Result<()> {
let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
let tensor1 = Tensor::new(data, device)?;
@ -590,6 +674,30 @@ fn index_select(device: &Device) -> Result<()> {
hs.to_vec2::<f32>()?,
&[[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]]
);
// Prior to https://github.com/huggingface/candle/pull/1022
// There would be a bug where the last values in the result tensor would be set to 0.
let ids = Tensor::new(&[0u32, 2u32, 1u32, 0u32, 2u32, 1u32], device)?;
let hs = t.index_select(&ids, 0)?;
assert_eq!(
hs.to_vec2::<f32>()?,
&[
[0.0, 1.0, 2.0],
[6.0, 7.0, 8.0],
[3.0, 4.0, 5.0],
[0.0, 1.0, 2.0],
[6.0, 7.0, 8.0],
[3.0, 4.0, 5.0],
]
);
// Test when selecting dim > 0 with ids size different from elem count of
// target dim in source/input.
let ids = Tensor::new(&[1u32, 0u32, 1u32], device)?;
let t = Tensor::arange(1f32, 5f32, device)?.reshape((2, 2))?;
assert_eq!(t.to_vec2::<f32>()?, &[[1.0, 2.0], [3.0, 4.0]]);
let hs = t.index_select(&ids, 1)?;
assert_eq!(hs.to_vec2::<f32>()?, &[[2.0, 1.0, 2.0], [4.0, 3.0, 4.0]]);
Ok(())
}
@ -636,6 +744,48 @@ fn index_add(device: &Device) -> Result<()> {
Ok(())
}
fn slice_scatter(device: &Device) -> Result<()> {
let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?;
assert_eq!(
t.to_vec2::<f32>()?,
&[
[0.0, 1.0, 2.0],
[3.0, 4.0, 5.0],
[6.0, 7.0, 8.0],
[9.0, 10.0, 11.0]
]
);
let src = Tensor::arange(100f32, 106f32, device)?.reshape((2, 3))?;
assert_eq!(
t.slice_scatter0(&src, 0)?.to_vec2::<f32>()?,
&[
[100.0, 101.0, 102.0],
[103.0, 104.0, 105.0],
[6.0, 7.0, 8.0],
[9.0, 10.0, 11.0]
]
);
assert_eq!(
t.slice_scatter0(&src, 1)?.to_vec2::<f32>()?,
&[
[0.0, 1.0, 2.0],
[100.0, 101.0, 102.0],
[103.0, 104.0, 105.0],
[9.0, 10.0, 11.0]
]
);
assert_eq!(
t.slice_scatter0(&src, 2)?.to_vec2::<f32>()?,
&[
[0.0, 1.0, 2.0],
[3.0, 4.0, 5.0],
[100.0, 101.0, 102.0],
[103.0, 104.0, 105.0],
]
);
Ok(())
}
fn scatter_add(device: &Device) -> Result<()> {
let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?;
assert_eq!(
@ -877,7 +1027,16 @@ fn broadcasting(device: &Device) -> Result<()> {
Ok(())
}
fn randn(device: &Device) -> Result<()> {
let tensor = Tensor::randn(0f32, 1f32, (5, 3), device)?;
assert_eq!(tensor.dims(), [5, 3]);
let tensor = Tensor::rand(0f32, 1f32, (5, 3), device)?;
assert_eq!(tensor.dims(), [5, 3]);
Ok(())
}
test_device!(zeros, zeros_cpu, zeros_gpu);
test_device!(ones, ones_cpu, ones_gpu);
test_device!(add_mul, add_mul_cpu, add_mul_gpu);
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu);
test_device!(narrow, narrow_cpu, narrow_gpu);
@ -889,6 +1048,7 @@ test_device!(max, max_cpu, max_gpu);
test_device!(argmax, argmax_cpu, argmax_gpu);
test_device!(argmin, argmin_cpu, argmin_gpu);
test_device!(transpose, transpose_cpu, transpose_gpu);
test_device!(unary_op, unary_op_cpu, unary_op_gpu);
test_device!(binary_op, binary_op_cpu, binary_op_gpu);
test_device!(embeddings, embeddings_cpu, embeddings_gpu);
test_device!(cmp, cmp_cpu, cmp_gpu);
@ -899,6 +1059,9 @@ test_device!(index_select, index_select_cpu, index_select_gpu);
test_device!(index_add, index_add_cpu, index_add_gpu);
test_device!(gather, gather_cpu, gather_gpu);
test_device!(scatter_add, scatter_add_cpu, scatter_add_gpu);
test_device!(slice_scatter, slice_scatter_cpu, slice_scatter_gpu);
test_device!(randn, randn_cpu, randn_gpu);
test_device!(clamp, clamp_cpu, clamp_gpu);
// There was originally a bug on the CPU implementation for randn
// https://github.com/huggingface/candle/issues/381

View File

@ -11,8 +11,8 @@ readme = "README.md"
[dependencies]
byteorder = { workspace = true }
candle = { path = "../candle-core", version = "0.2.1", package = "candle-core" }
candle-nn = { path = "../candle-nn", version = "0.2.1" }
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
candle-nn = { path = "../candle-nn", version = "0.3.0" }
hf-hub = { workspace = true}
intel-mkl-src = { workspace = true, optional = true }
memmap2 = { workspace = true }

View File

@ -8,13 +8,9 @@ use parquet::file::reader::{FileReader, SerializedFileReader};
use std::fs::File;
use std::io::{self, BufReader, Read};
fn read_u32<T: Read>(reader: &mut T) -> Result<u32> {
let mut b = vec![0u8; 4];
reader.read_exact(&mut b)?;
let (result, _) = b.iter().rev().fold((0u64, 1u64), |(s, basis), &x| {
(s + basis * u64::from(x), basis * 256)
});
Ok(result as u32)
fn read_u32<T: Read>(reader: &mut T) -> std::io::Result<u32> {
use byteorder::ReadBytesExt;
reader.read_u32::<byteorder::BigEndian>()
}
fn check_magic_number<T: Read>(reader: &mut T, expected: u32) -> Result<()> {

View File

@ -11,19 +11,21 @@ readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
candle = { path = "../candle-core", version = "0.2.1", package = "candle-core" }
candle-datasets = { path = "../candle-datasets", version = "0.2.1" }
candle-nn = { path = "../candle-nn", version = "0.2.1" }
candle-transformers = { path = "../candle-transformers", version = "0.2.1" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.2.1", optional = true }
safetensors = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
num-traits = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
candle-datasets = { path = "../candle-datasets", version = "0.3.0" }
candle-nn = { path = "../candle-nn", version = "0.3.0" }
candle-transformers = { path = "../candle-transformers", version = "0.3.0" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.0", optional = true }
cudarc = { workspace = true, optional = true }
half = { workspace = true, optional = true }
image = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
num-traits = { workspace = true }
rayon = { workspace = true }
safetensors = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
tokenizers = { workspace = true, features = ["onig"] }
[dev-dependencies]
anyhow = { workspace = true }
@ -34,7 +36,6 @@ imageproc = { workspace = true }
memmap2 = { workspace = true }
rand = { workspace = true }
rusttype = { workspace = true }
tokenizers = { workspace = true, features = ["onig"] }
tracing = { workspace = true }
tracing-chrome = { workspace = true }
tracing-subscriber = { workspace = true }
@ -50,7 +51,7 @@ default = []
accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
cudnn = ["candle/cudnn"]
flash-attn = ["cuda", "dep:candle-flash-attn"]
flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
nccl = ["cuda", "cudarc/nccl", "dep:half"]

View File

@ -0,0 +1,44 @@
# candle-bert
Bert is a general large language model. In this example it can be used for two
different tasks:
- Compute sentence embeddings for a prompt.
- Compute similarities between a set of sentences.
## Sentence embeddings
Bert is used to compute the sentence embeddings for a prompt. The model weights
are downloaded from the hub on the first run.
```bash
cargo run --example bert --release -- --prompt "Here is a test sentence"
> [[[ 0.0798, -0.0665, -0.0247, ..., -0.1082, -0.1000, -0.2751],
> [ 0.4218, 0.2690, 0.2740, ..., 0.3889, 1.3503, 0.9908],
> [ 0.0466, 0.3041, -0.1143, ..., 0.4427, 0.6926, -0.1515],
> ...
> [ 0.3396, 0.4320, -0.4408, ..., 0.9212, 0.2331, -0.6777],
> [ 0.2789, 0.7539, 0.4306, ..., -0.0095, 0.3375, -1.7529],
> [ 0.6737, 0.7882, 0.0548, ..., 0.1836, 0.7299, -0.6617]]]
> Tensor[[1, 7, 384], f32]
```
## Similarities
In this example, Bert is used to compute the sentence embeddings for a set of
sentences (hardcoded in the examples). Then cosine similarities are computed for
each sentence pair and they are reported by decreasing values, hence the first
reported pair contains the two sentences that have the highest similarity score.
The sentence embeddings are computed using average pooling through all the
sentence tokens, including some potential padding.
```bash
cargo run --example bert --release
> score: 0.85 'The new movie is awesome' 'The new movie is so great'
> score: 0.61 'The cat sits outside' 'The cat plays in the garden'
> score: 0.52 'I love pasta' 'Do you like pizza?'
> score: 0.23 'The new movie is awesome' 'Do you like pizza?'
> score: 0.22 'I love pasta' 'The new movie is awesome'
```

View File

@ -3,14 +3,13 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
mod model;
use candle_transformers::models::bert::{BertModel, Config, DTYPE};
use anyhow::{anyhow, Error as E, Result};
use candle::Tensor;
use candle_nn::VarBuilder;
use clap::Parser;
use hf_hub::{api::sync::Api, Cache, Repo, RepoType};
use model::{BertModel, Config, DTYPE};
use tokenizers::{PaddingParams, Tokenizer};
#[derive(Parser, Debug)]
@ -87,9 +86,8 @@ impl Args {
let config: Config = serde_json::from_str(&config)?;
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
let weights = weights.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
let vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? };
let model = BertModel::load(vb, &config)?;
Ok((model, tokenizer))
}

View File

@ -0,0 +1,19 @@
# candle-starcoder: code generation model
[StarCoder/BigCode](https://huggingface.co/bigcode/starcoderbase-1b) is a LLM
model specialized to code generation. The initial model was trained on 80
programming languages.
## Running some example
```bash
cargo run --example bigcode --release -- --prompt "fn fact(n: u64) -> u64 "
> fn fact(n: u64) -> u64 {
> if n == 0 {
> 1
> } else {
> n * fact(n - 1)
> }
> }
```

View File

@ -7,8 +7,7 @@ extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;
mod model;
use model::{Config, GPTBigCode};
use candle_transformers::models::bigcode::{Config, GPTBigCode};
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
@ -29,9 +28,10 @@ impl TextGeneration {
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
device: &Device,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp);
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
tokenizer,
@ -95,6 +95,10 @@ struct Args {
#[arg(long)]
temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
@ -134,23 +138,21 @@ fn main() -> Result<()> {
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let weights = filenames
.iter()
.map(|f| Ok(unsafe { candle::safetensors::MmapedFile::new(f)? }))
.collect::<Result<Vec<_>>>()?;
let weights = weights
.iter()
.map(|f| Ok(f.deserialize()?))
.collect::<Result<Vec<_>>>()?;
let start = std::time::Instant::now();
let device = candle_examples::device(args.cpu)?;
let vb = VarBuilder::from_safetensors(weights, DType::F32, &device);
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
let config = Config::starcoder_1b();
let model = GPTBigCode::load(vb, config)?;
println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(model, tokenizer, args.seed, args.temperature, &device);
let mut pipeline = TextGeneration::new(
model,
tokenizer,
args.seed,
args.temperature,
args.top_p,
&device,
);
pipeline.run(&args.prompt, args.sample_len)?;
Ok(())
}

View File

@ -0,0 +1,19 @@
# candle-dinov2
[DINOv2](https://github.com/facebookresearch/dinov2) is a computer vision model.
In this example, it is used as an ImageNet classifier: the model returns the
probability for the image to belong to each of the 1000 ImageNet categories.
## Running some example
```bash
cargo run --example dinov2 --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg
> mountain bike, all-terrain bike, off-roader: 43.67%
> bicycle-built-for-two, tandem bicycle, tandem: 33.20%
> crash helmet : 13.23%
> unicycle, monocycle : 2.44%
> maillot : 2.42%
```
![Leading group, Giro d'Italia 2021](../yolo-v8/assets/bike.jpg)

View File

@ -9,285 +9,10 @@ extern crate accelerate_src;
use clap::Parser;
use candle::{DType, IndexOp, Result, Tensor, D};
use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
use candle::{DType, IndexOp, D};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::dinov2;
const IMG_SIZE: usize = 518;
const PATCH_SIZE: usize = 14;
const NUM_CLASSES: usize = 1000;
fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
if bias {
candle_nn::linear(in_dim, out_dim, vb)
} else {
candle_nn::linear_no_bias(in_dim, out_dim, vb)
}
}
#[derive(Debug)]
struct Attention {
qkv: Linear,
proj: Linear,
num_heads: usize,
scale: f64,
}
impl Attention {
fn new(
vb: VarBuilder,
dim: usize,
num_heads: usize,
qkv_bias: bool,
proj_bias: bool,
) -> Result<Self> {
let qkv = linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?;
let proj = linear(vb.pp("proj"), dim, dim, proj_bias)?;
let scale = 1. / ((dim / num_heads) as f64).sqrt();
Ok(Self {
qkv,
proj,
num_heads,
scale,
})
}
}
impl Module for Attention {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let (b, n, c) = xs.dims3()?;
let qkv = self
.qkv
.forward(xs)?
.reshape((b, n, 3, self.num_heads, c / self.num_heads))?
.transpose(1, 2)? // 02134
.transpose(0, 1)? // 20134
.transpose(2, 3)?; // 20314
let q = (qkv.i(0)? * self.scale)?;
let k = qkv.i(1)?;
let v = qkv.i(2)?;
let attn = candle_nn::ops::softmax(&q.matmul(&k.t()?)?, D::Minus1)?;
let attn = attn.matmul(&v)?.transpose(1, 2)?.reshape((b, n, c))?;
self.proj.forward(&attn)
}
}
#[derive(Debug)]
struct LayerScale {
gamma: Tensor,
}
impl LayerScale {
fn new(vb: VarBuilder, dim: usize) -> Result<Self> {
let gamma = vb.get(dim, "gamma")?;
Ok(Self { gamma })
}
}
impl Module for LayerScale {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
xs.broadcast_mul(&self.gamma)
}
}
#[derive(Debug)]
struct Mlp {
fc1: Linear,
fc2: Linear,
}
impl Mlp {
fn new(vb: VarBuilder, in_features: usize, hidden_features: usize, bias: bool) -> Result<Self> {
let out_features = in_features;
let fc1 = linear(vb.pp("fc1"), in_features, hidden_features, bias)?;
let fc2 = linear(vb.pp("fc2"), hidden_features, out_features, bias)?;
Ok(Self { fc1, fc2 })
}
}
impl Module for Mlp {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let xs = self.fc1.forward(xs)?.gelu()?;
self.fc2.forward(&xs)
}
}
#[derive(Debug)]
struct Block {
norm1: LayerNorm,
attn: Attention,
ls1: LayerScale,
norm2: LayerNorm,
mlp: Mlp,
ls2: LayerScale,
}
impl Block {
fn new(vb: VarBuilder, dim: usize, num_heads: usize) -> Result<Self> {
let norm1 = layer_norm(dim, 1e-5, vb.pp("norm1"))?;
let attn = Attention::new(vb.pp("attn"), dim, num_heads, true, true)?;
let ls1 = LayerScale::new(vb.pp("ls1"), dim)?;
let norm2 = layer_norm(dim, 1e-5, vb.pp("norm2"))?;
let mlp = Mlp::new(vb.pp("mlp"), dim, dim * 4, true)?;
let ls2 = LayerScale::new(vb.pp("ls2"), dim)?;
Ok(Self {
norm1,
attn,
ls1,
norm2,
mlp,
ls2,
})
}
}
impl Module for Block {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let residual = xs;
let xs = self
.ls1
.forward(&self.attn.forward(&self.norm1.forward(xs)?)?)?;
let xs = (xs + residual)?;
let residual = &xs;
let xs = self
.ls2
.forward(&self.mlp.forward(&self.norm2.forward(&xs)?)?)?;
xs + residual
}
}
#[derive(Debug)]
struct PatchEmbed {
proj: candle_nn::Conv2d,
patch_size: (usize, usize),
num_patches: usize,
}
impl PatchEmbed {
fn new(
vb: VarBuilder,
img_size: usize,
patch_size: usize,
in_chans: usize,
embed_dim: usize,
) -> Result<Self> {
let config = candle_nn::Conv2dConfig {
stride: patch_size,
..Default::default()
};
let proj = candle_nn::conv2d(in_chans, embed_dim, patch_size, config, vb.pp("proj"))?;
let num_patches = (img_size / patch_size) * (img_size / patch_size);
Ok(Self {
proj,
patch_size: (patch_size, patch_size),
num_patches,
})
}
}
impl Module for PatchEmbed {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let (_b, _c, h, w) = xs.dims4()?;
let (patch_h, patch_w) = self.patch_size;
if (h % patch_h) != 0 {
candle::bail!("image height {h} is not a multiple of patch height {patch_h}")
}
if (w % patch_w) != 0 {
candle::bail!("image width {w} is not a multiple of patch width {patch_w}")
}
let xs = self.proj.forward(xs)?;
let (b, c, h, w) = xs.dims4()?;
// flatten embeddings.
xs.reshape((b, c, h * w))?.transpose(1, 2)
}
}
#[derive(Debug)]
pub struct DinoVisionTransformer {
patch_embed: PatchEmbed,
cls_token: Tensor,
pos_embed: Tensor,
blocks: Vec<Block>,
norm: LayerNorm,
head: Linear,
}
impl DinoVisionTransformer {
pub fn new(vb: VarBuilder, depth: usize, embed_dim: usize, num_heads: usize) -> Result<Self> {
let patch_embed =
PatchEmbed::new(vb.pp("patch_embed"), IMG_SIZE, PATCH_SIZE, 3, embed_dim)?;
let cls_token = vb.get((1, 1, embed_dim), "cls_token")?;
let num_tokens = 1;
let pos_embed = vb.get(
(1, patch_embed.num_patches + num_tokens, embed_dim),
"pos_embed",
)?;
let head = linear(vb.pp("head"), 2 * embed_dim, NUM_CLASSES, true)?;
let norm = layer_norm(embed_dim, 1e-5, vb.pp("norm"))?;
let vb_b = vb.pp("blocks");
let blocks = (0..depth)
.map(|i| Block::new(vb_b.pp(&i.to_string()), embed_dim, num_heads))
.collect::<Result<Vec<_>>>()?;
Ok(Self {
patch_embed,
cls_token,
pos_embed,
blocks,
norm,
head,
})
}
fn interpolate_pos_encoding(&self, xs: &Tensor, w: usize, h: usize) -> Result<Tensor> {
let npatch = xs.dim(1)? - 1;
let n = self.pos_embed.dim(1)? - 1;
let sqrt_n = (n as f64).sqrt();
if npatch == n && w == h {
return Ok(xs.clone());
}
let class_pos_embed = self.pos_embed.i((.., ..1))?;
let patch_pos_embed = self.pos_embed.i((.., 1..))?;
let dim = xs.dim(D::Minus1)?;
let (w0, h0) = ((w / PATCH_SIZE) as f64 + 0.1, (h / PATCH_SIZE) as f64 + 0.1);
let patch_pos_embed = patch_pos_embed
.reshape((1, sqrt_n as usize, sqrt_n as usize, dim))?
.transpose(2, 3)?
.transpose(1, 2)?;
// This uses bicubic interpolation in the original implementation.
let patch_pos_embed = patch_pos_embed.upsample_nearest2d(h0 as usize, w0 as usize)?;
let el_count = patch_pos_embed.shape().elem_count();
let patch_pos_embed =
patch_pos_embed
.transpose(1, 2)?
.transpose(2, 3)?
.reshape((1, el_count / dim, dim))?;
Tensor::cat(&[&class_pos_embed, &patch_pos_embed], 1)
}
fn prepare_tokens_with_mask(&self, xs: &Tensor) -> Result<Tensor> {
let (_b, _nc, w, h) = xs.dims4()?;
let xs = self.patch_embed.forward(xs)?;
let xs = Tensor::cat(&[&self.cls_token, &xs], 1)?;
&xs + &self.interpolate_pos_encoding(&xs, w, h)?
}
}
impl Module for DinoVisionTransformer {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let mut xs = self.prepare_tokens_with_mask(xs)?;
for blk in self.blocks.iter() {
xs = blk.forward(&xs)?
}
let xs = self.norm.forward(&xs)?;
let xs_norm_clstoken = xs.i((.., 0))?;
let xs_norm_patchtokens = xs.i((.., 1..))?.mean(1)?;
let xs = Tensor::cat(&[xs_norm_clstoken, xs_norm_patchtokens], D::Minus1)?;
self.head.forward(&xs)
}
}
pub fn vit_small(vb: VarBuilder) -> Result<DinoVisionTransformer> {
DinoVisionTransformer::new(vb, 12, 384, 6)
}
#[derive(Parser)]
struct Args {
#[arg(long)]
@ -317,10 +42,8 @@ pub fn main() -> anyhow::Result<()> {
}
Some(model) => model.into(),
};
let weights = unsafe { candle::safetensors::MmapedFile::new(model_file)? };
let weights = weights.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
let model = vit_small(vb)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
let model = dinov2::vit_small(vb)?;
println!("model built");
let logits = model.forward(&image.unsqueeze(0)?)?;
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?

View File

@ -8,340 +8,11 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use candle::{DType, IndexOp, D};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::efficientnet::{EfficientNet, MBConvConfig};
use clap::{Parser, ValueEnum};
use candle::{DType, IndexOp, Result, Tensor, D};
use candle_nn as nn;
use nn::{Module, VarBuilder};
// Based on the Python version from torchvision.
// https://github.com/pytorch/vision/blob/0d75d9e5516f446c9c0ef93bd4ed9fea13992d06/torchvision/models/efficientnet.py#L47
#[derive(Debug, Clone, Copy)]
pub struct MBConvConfig {
expand_ratio: f64,
kernel: usize,
stride: usize,
input_channels: usize,
out_channels: usize,
num_layers: usize,
}
fn make_divisible(v: f64, divisor: usize) -> usize {
let min_value = divisor;
let new_v = usize::max(
min_value,
(v + divisor as f64 * 0.5) as usize / divisor * divisor,
);
if (new_v as f64) < 0.9 * v {
new_v + divisor
} else {
new_v
}
}
fn bneck_confs(width_mult: f64, depth_mult: f64) -> Vec<MBConvConfig> {
let bneck_conf = |e, k, s, i, o, n| {
let input_channels = make_divisible(i as f64 * width_mult, 8);
let out_channels = make_divisible(o as f64 * width_mult, 8);
let num_layers = (n as f64 * depth_mult).ceil() as usize;
MBConvConfig {
expand_ratio: e,
kernel: k,
stride: s,
input_channels,
out_channels,
num_layers,
}
};
vec![
bneck_conf(1., 3, 1, 32, 16, 1),
bneck_conf(6., 3, 2, 16, 24, 2),
bneck_conf(6., 5, 2, 24, 40, 2),
bneck_conf(6., 3, 2, 40, 80, 3),
bneck_conf(6., 5, 1, 80, 112, 3),
bneck_conf(6., 5, 2, 112, 192, 4),
bneck_conf(6., 3, 1, 192, 320, 1),
]
}
impl MBConvConfig {
fn b0() -> Vec<Self> {
bneck_confs(1.0, 1.0)
}
fn b1() -> Vec<Self> {
bneck_confs(1.0, 1.1)
}
fn b2() -> Vec<Self> {
bneck_confs(1.1, 1.2)
}
fn b3() -> Vec<Self> {
bneck_confs(1.2, 1.4)
}
fn b4() -> Vec<Self> {
bneck_confs(1.4, 1.8)
}
fn b5() -> Vec<Self> {
bneck_confs(1.6, 2.2)
}
fn b6() -> Vec<Self> {
bneck_confs(1.8, 2.6)
}
fn b7() -> Vec<Self> {
bneck_confs(2.0, 3.1)
}
}
/// Conv2D with same padding.
#[derive(Debug)]
struct Conv2DSame {
conv2d: nn::Conv2d,
s: usize,
k: usize,
}
impl Conv2DSame {
fn new(
vb: VarBuilder,
i: usize,
o: usize,
k: usize,
stride: usize,
groups: usize,
bias: bool,
) -> Result<Self> {
let conv_config = nn::Conv2dConfig {
stride,
groups,
..Default::default()
};
let conv2d = if bias {
nn::conv2d(i, o, k, conv_config, vb)?
} else {
nn::conv2d_no_bias(i, o, k, conv_config, vb)?
};
Ok(Self {
conv2d,
s: stride,
k,
})
}
}
impl Module for Conv2DSame {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let s = self.s;
let k = self.k;
let (_, _, ih, iw) = xs.dims4()?;
let oh = (ih + s - 1) / s;
let ow = (iw + s - 1) / s;
let pad_h = usize::max((oh - 1) * s + k - ih, 0);
let pad_w = usize::max((ow - 1) * s + k - iw, 0);
if pad_h > 0 || pad_w > 0 {
let xs = xs.pad_with_zeros(2, pad_h / 2, pad_h - pad_h / 2)?;
let xs = xs.pad_with_zeros(3, pad_w / 2, pad_w - pad_w / 2)?;
self.conv2d.forward(&xs)
} else {
self.conv2d.forward(xs)
}
}
}
#[derive(Debug)]
struct ConvNormActivation {
conv2d: Conv2DSame,
bn2d: nn::BatchNorm,
activation: bool,
}
impl ConvNormActivation {
fn new(
vb: VarBuilder,
i: usize,
o: usize,
k: usize,
stride: usize,
groups: usize,
) -> Result<Self> {
let conv2d = Conv2DSame::new(vb.pp("0"), i, o, k, stride, groups, false)?;
let bn2d = nn::batch_norm(o, 1e-3, vb.pp("1"))?;
Ok(Self {
conv2d,
bn2d,
activation: true,
})
}
fn no_activation(self) -> Self {
Self {
activation: false,
..self
}
}
}
impl Module for ConvNormActivation {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let xs = self.conv2d.forward(xs)?;
let xs = self.bn2d.forward(&xs)?;
if self.activation {
swish(&xs)
} else {
Ok(xs)
}
}
}
#[derive(Debug)]
struct SqueezeExcitation {
fc1: Conv2DSame,
fc2: Conv2DSame,
}
impl SqueezeExcitation {
fn new(vb: VarBuilder, in_channels: usize, squeeze_channels: usize) -> Result<Self> {
let fc1 = Conv2DSame::new(vb.pp("fc1"), in_channels, squeeze_channels, 1, 1, 1, true)?;
let fc2 = Conv2DSame::new(vb.pp("fc2"), squeeze_channels, in_channels, 1, 1, 1, true)?;
Ok(Self { fc1, fc2 })
}
}
impl Module for SqueezeExcitation {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let residual = xs;
// equivalent to adaptive_avg_pool2d([1, 1])
let xs = xs.mean_keepdim(D::Minus2)?.mean_keepdim(D::Minus1)?;
let xs = self.fc1.forward(&xs)?;
let xs = swish(&xs)?;
let xs = self.fc2.forward(&xs)?;
let xs = nn::ops::sigmoid(&xs)?;
residual.broadcast_mul(&xs)
}
}
#[derive(Debug)]
struct MBConv {
expand_cna: Option<ConvNormActivation>,
depthwise_cna: ConvNormActivation,
squeeze_excitation: SqueezeExcitation,
project_cna: ConvNormActivation,
config: MBConvConfig,
}
impl MBConv {
fn new(vb: VarBuilder, c: MBConvConfig) -> Result<Self> {
let vb = vb.pp("block");
let exp = make_divisible(c.input_channels as f64 * c.expand_ratio, 8);
let expand_cna = if exp != c.input_channels {
Some(ConvNormActivation::new(
vb.pp("0"),
c.input_channels,
exp,
1,
1,
1,
)?)
} else {
None
};
let start_index = if expand_cna.is_some() { 1 } else { 0 };
let depthwise_cna =
ConvNormActivation::new(vb.pp(start_index), exp, exp, c.kernel, c.stride, exp)?;
let squeeze_channels = usize::max(1, c.input_channels / 4);
let squeeze_excitation =
SqueezeExcitation::new(vb.pp(start_index + 1), exp, squeeze_channels)?;
let project_cna =
ConvNormActivation::new(vb.pp(start_index + 2), exp, c.out_channels, 1, 1, 1)?
.no_activation();
Ok(Self {
expand_cna,
depthwise_cna,
squeeze_excitation,
project_cna,
config: c,
})
}
}
impl Module for MBConv {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let use_res_connect =
self.config.stride == 1 && self.config.input_channels == self.config.out_channels;
let ys = match &self.expand_cna {
Some(expand_cna) => expand_cna.forward(xs)?,
None => xs.clone(),
};
let ys = self.depthwise_cna.forward(&ys)?;
let ys = self.squeeze_excitation.forward(&ys)?;
let ys = self.project_cna.forward(&ys)?;
if use_res_connect {
ys + xs
} else {
Ok(ys)
}
}
}
fn swish(s: &Tensor) -> Result<Tensor> {
s * nn::ops::sigmoid(s)?
}
#[derive(Debug)]
struct EfficientNet {
init_cna: ConvNormActivation,
blocks: Vec<MBConv>,
final_cna: ConvNormActivation,
classifier: nn::Linear,
}
impl EfficientNet {
fn new(p: VarBuilder, configs: Vec<MBConvConfig>, nclasses: usize) -> Result<Self> {
let f_p = p.pp("features");
let first_in_c = configs[0].input_channels;
let last_out_c = configs.last().unwrap().out_channels;
let final_out_c = 4 * last_out_c;
let init_cna = ConvNormActivation::new(f_p.pp(0), 3, first_in_c, 3, 2, 1)?;
let nconfigs = configs.len();
let mut blocks = vec![];
for (index, cnf) in configs.into_iter().enumerate() {
let f_p = f_p.pp(index + 1);
for r_index in 0..cnf.num_layers {
let cnf = if r_index == 0 {
cnf
} else {
MBConvConfig {
input_channels: cnf.out_channels,
stride: 1,
..cnf
}
};
blocks.push(MBConv::new(f_p.pp(r_index), cnf)?)
}
}
let final_cna =
ConvNormActivation::new(f_p.pp(nconfigs + 1), last_out_c, final_out_c, 1, 1, 1)?;
let classifier = nn::linear(final_out_c, nclasses, p.pp("classifier.1"))?;
Ok(Self {
init_cna,
blocks,
final_cna,
classifier,
})
}
}
impl Module for EfficientNet {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let mut xs = self.init_cna.forward(xs)?;
for block in self.blocks.iter() {
xs = block.forward(&xs)?
}
let xs = self.final_cna.forward(&xs)?;
// Equivalent to adaptive_avg_pool2d([1, 1]) -> squeeze(-1) -> squeeze(-1)
let xs = xs.mean(D::Minus1)?.mean(D::Minus1)?;
self.classifier.forward(&xs)
}
}
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
B0,
@ -397,9 +68,7 @@ pub fn main() -> anyhow::Result<()> {
}
Some(model) => model.into(),
};
let weights = unsafe { candle::safetensors::MmapedFile::new(model_file)? };
let weights = weights.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
let cfg = match args.which {
Which::B0 => MBConvConfig::b0(),
Which::B1 => MBConvConfig::b1(),

View File

@ -0,0 +1,3 @@
# candle-falcon
Falcon is a general large language model.

View File

@ -14,8 +14,7 @@ use clap::Parser;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
mod model;
use model::{Config, Falcon};
use candle_transformers::models::falcon::{Config, Falcon};
struct TextGeneration {
model: Falcon,
@ -26,17 +25,25 @@ struct TextGeneration {
repeat_last_n: usize,
}
struct GenerationOptions {
temp: Option<f64>,
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
}
impl TextGeneration {
fn new(
model: Falcon,
tokenizer: Tokenizer,
generation_options: GenerationOptions,
seed: u64,
temp: Option<f64>,
device: &Device,
repeat_penalty: f32,
repeat_last_n: usize,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp);
let logits_processor =
LogitsProcessor::new(seed, generation_options.temp, generation_options.top_p);
let repeat_penalty = generation_options.repeat_penalty;
let repeat_last_n = generation_options.repeat_last_n;
Self {
model,
tokenizer,
@ -119,6 +126,10 @@ struct Args {
#[arg(long)]
temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
@ -166,35 +177,25 @@ fn main() -> Result<()> {
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
let weights = filenames
.iter()
.map(|f| Ok(unsafe { candle::safetensors::MmapedFile::new(f)? }))
.collect::<Result<Vec<_>>>()?;
let weights = weights
.iter()
.map(|f| Ok(f.deserialize()?))
.collect::<Result<Vec<_>>>()?;
let dtype = if args.use_f32 {
DType::F32
} else {
DType::BF16
};
let vb = VarBuilder::from_safetensors(weights, dtype, &device);
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let config = Config::falcon7b();
config.validate()?;
let model = Falcon::load(vb, config)?;
println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(
model,
tokenizer,
args.seed,
args.temperature,
&device,
args.repeat_penalty,
args.repeat_last_n,
);
let generation_options = GenerationOptions {
temp: args.temperature,
top_p: args.top_p,
repeat_penalty: args.repeat_penalty,
repeat_last_n: args.repeat_last_n,
};
let mut pipeline =
TextGeneration::new(model, tokenizer, generation_options, args.seed, &device);
pipeline.run(&args.prompt, args.sample_len)?;
Ok(())
}

View File

@ -21,11 +21,10 @@ use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use std::io::Write;
mod model;
use candle_transformers::models::llama as model;
use model::{Config, Llama, LlamaConfig};
const EOS_TOKEN: &str = "</s>";
const MAX_SEQ_LEN: usize = 4096;
const DEFAULT_PROMPT: &str = "My favorite theorem is ";
#[derive(Parser, Debug)]
@ -43,6 +42,10 @@ struct Args {
#[arg(long)]
temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
@ -169,17 +172,9 @@ fn main() -> Result<()> {
}
println!("building the model");
let handles = filenames
.iter()
.map(|f| Ok(unsafe { candle::safetensors::MmapedFile::new(f.as_path())? }))
.collect::<Result<Vec<_>>>()?;
let tensors: Vec<_> = handles
.iter()
.map(|h| Ok(h.deserialize()?))
.collect::<Result<Vec<_>>>()?;
let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
let vb = VarBuilder::from_safetensors(tensors, dtype, &device);
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
(Llama::load(vb, &cache, &config)?, tokenizer_filename, cache)
}
};
@ -194,7 +189,7 @@ fn main() -> Result<()> {
println!("starting the inference loop");
print!("{prompt}");
let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature);
let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature, args.top_p);
let start_gen = std::time::Instant::now();
let mut index_pos = 0;
let mut token_generated = 0;

View File

@ -27,6 +27,10 @@ struct InferenceCmd {
#[arg(long)]
temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
#[arg(long, default_value = "")]
prompt: String,
@ -133,6 +137,7 @@ fn main() -> anyhow::Result<()> {
None => {
let cmd = InferenceCmd {
temperature: None,
top_p: None,
prompt: "".to_string(),
config: None,
model_id: "karpathy/tinyllamas".to_string(),
@ -256,7 +261,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
let model = Llama::load(vb, &cache, config)?;
println!("starting the inference loop");
let mut logits_processor = LogitsProcessor::new(299792458, args.temperature);
let mut logits_processor = LogitsProcessor::new(299792458, args.temperature, args.top_p);
let mut index_pos = 0;
print!("{}", args.prompt);

View File

@ -89,6 +89,10 @@ struct Args {
#[arg(long)]
temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
@ -201,16 +205,9 @@ fn main() -> Result<()> {
let cache = model::Cache::new(dtype, &config, &device)?;
println!("building the model");
let handles = filenames
.iter()
.map(|f| Ok(unsafe { candle::safetensors::MmapedFile::new(f.as_path())? }))
.collect::<Result<Vec<_>>>()?;
let tensors: Vec<_> = handles
.iter()
.map(|h| Ok(h.deserialize()?))
.collect::<Result<Vec<_>>>()?;
let vb = candle_nn::var_builder::ShardedSafeTensors::var_builder(tensors, dtype, &device);
let vb = unsafe {
candle_nn::var_builder::ShardedSafeTensors::var_builder(&filenames, dtype, &device)?
};
let llama = Llama::load(vb, &cache, &config, comm)?;
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
@ -222,7 +219,7 @@ fn main() -> Result<()> {
.to_vec();
println!("starting the inference loop");
let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature);
let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature, args.top_p);
let mut new_tokens = vec![];
let start_gen = std::time::Instant::now();
let mut index_pos = 0;

View File

@ -0,0 +1,90 @@
# candle-mistral: 7b LLM with Apache 2.0 licensed weights
Mistral-7B-v0.1 is a pretrained generative LLM with 7 billion parameters. It outperforms all the publicly available 13b models
as of 2023-09-28. Weights (and the original Python model code) are released under the permissive Apache 2.0 license.
- [Blog post](https://mistral.ai/news/announcing-mistral-7b/) from Mistral announcing the model release.
- [Model card](https://huggingface.co/mistralai/Mistral-7B-v0.1) on the
HuggingFace Hub.
This example supports the initial model as well as a quantized variant.
## Running the example
```bash
$ cargo run --example mistral --release --features cuda -- --prompt 'Write helloworld code in Rust' --sample-len 150
Generated text:
Write helloworld code in Rust
=============================
This is a simple example of how to write "Hello, world!" program in Rust.
## Compile and run
``bash
$ cargo build --release
Compiling hello-world v0.1.0 (/home/user/rust/hello-world)
Finished release [optimized] target(s) in 0.26s
$ ./target/release/hello-world
Hello, world!
``
## Source code
``rust
fn main() {
println!("Hello, world!");
}
``
## License
This example is released under the terms
```
## Running the quantized version of the model
```bash
$ cargo run --example mistral --features accelerate --release -- \
$ --prompt "Here is a sample quick sort implementation in rust " --quantized -n 400
avx: false, neon: true, simd128: false, f16c: false
temp: 0.00 repeat-penalty: 1.10 repeat-last-n: 64
retrieved the files in 562.292µs
loaded the model in 1.100323667s
Here is a sample quick sort implementation in rust
``rust
fn quick_sort(arr: &mut [i32]) {
if arr.len() <= 1 {
return;
}
let pivot = arr[0];
let mut left = vec![];
let mut right = vec![];
for i in 1..arr.len() {
if arr[i] < pivot {
left.push(arr[i]);
} else {
right.push(arr[i]);
}
}
quick_sort(&mut left);
quick_sort(&mut right);
let mut i = 0;
for _ in &left {
arr[i] = left.pop().unwrap();
i += 1;
}
for _ in &right {
arr[i] = right.pop().unwrap();
i += 1;
}
}
``
226 tokens generated (10.91 token/s)
```

View File

@ -0,0 +1,271 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;
use candle_transformers::models::mistral::{Config, Model as Mistral};
use candle_transformers::models::quantized_mistral::Model as QMistral;
use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
enum Model {
Mistral(Mistral),
Quantized(QMistral),
}
struct TextGeneration {
model: Model,
device: Device,
tokenizer: TokenOutputStream,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
}
impl TextGeneration {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
device: &Device,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
tokenizer: TokenOutputStream::new(tokenizer),
logits_processor,
repeat_penalty,
repeat_last_n,
device: device.clone(),
}
}
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
use std::io::Write;
self.tokenizer.clear();
let mut tokens = self
.tokenizer
.tokenizer()
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
for &t in tokens.iter() {
if let Some(t) = self.tokenizer.next_token(t)? {
print!("{t}")
}
}
std::io::stdout().flush()?;
let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_token("</s>") {
Some(token) => token,
None => anyhow::bail!("cannot find the </s> token"),
};
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
let start_pos = tokens.len().saturating_sub(context_size);
let ctxt = &tokens[start_pos..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = match &mut self.model {
Model::Mistral(m) => m.forward(&input, start_pos)?,
Model::Quantized(m) => m.forward(&input, start_pos)?,
};
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
self.repeat_penalty,
&tokens[start_at..],
)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
}
let dt = start_gen.elapsed();
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
std::io::stdout().flush()?;
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
Ok(())
}
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
#[arg(long)]
use_flash_attn: bool,
#[arg(long)]
prompt: String,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 100)]
sample_len: usize,
#[arg(long, default_value = "lmz/candle-mistral")]
model_id: String,
#[arg(long, default_value = "main")]
revision: String,
#[arg(long)]
tokenizer_file: Option<String>,
#[arg(long)]
weight_files: Option<String>,
#[arg(long)]
quantized: bool,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature.unwrap_or(0.),
args.repeat_penalty,
args.repeat_last_n
);
let start = std::time::Instant::now();
let api = Api::new()?;
let repo = api.repo(Repo::with_revision(
args.model_id,
RepoType::Model,
args.revision,
));
let tokenizer_filename = match args.tokenizer_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("tokenizer.json")?,
};
let filenames = match args.weight_files {
Some(files) => files
.split(',')
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => {
if args.quantized {
vec![repo.get("model-q4k.gguf")?]
} else {
vec![
repo.get("pytorch_model-00001-of-00002.safetensors")?,
repo.get("pytorch_model-00002-of-00002.safetensors")?,
]
}
}
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
let config = Config::config_7b_v0_1(args.use_flash_attn);
let (model, device) = if args.quantized {
let filename = &filenames[0];
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename)?;
let model = QMistral::new(&config, vb)?;
(Model::Quantized(model), Device::Cpu)
} else {
let device = candle_examples::device(args.cpu)?;
let dtype = if device.is_cuda() {
DType::BF16
} else {
DType::F32
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = Mistral::new(&config, vb)?;
(Model::Mistral(model), device)
};
println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(
model,
tokenizer,
args.seed,
args.temperature,
args.top_p,
args.repeat_penalty,
args.repeat_last_n,
&device,
);
pipeline.run(&args.prompt, args.sample_len)?;
Ok(())
}

View File

@ -1,6 +1,6 @@
use crate::nn::conv1d_weight_norm;
use candle::{DType, IndexOp, Result, Tensor};
use candle_nn::{conv1d, Conv1d, Conv1dConfig, Module, VarBuilder};
use candle::{DType, IndexOp, Module, Result, Tensor};
use candle_nn::{conv1d, Conv1d, Conv1dConfig, VarBuilder};
// Encodec Model
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/encodec/modeling_encodec.py
@ -199,25 +199,34 @@ impl EncodecResidualVectorQuantizer {
// https://github.com/huggingface/transformers/blob/abaca9f9432a84cfaa95531de4c72334f38a42f2/src/transformers/models/encodec/modeling_encodec.py#L226
#[derive(Debug)]
struct EncodecLSTM {
layers: Vec<(Tensor, Tensor, Tensor, Tensor)>,
layers: Vec<candle_nn::LSTM>,
}
impl EncodecLSTM {
fn load(dim: usize, vb: VarBuilder, cfg: &Config) -> Result<Self> {
let vb = &vb.pp("lstm");
let mut layers = vec![];
for i in 0..cfg.num_lstm_layers {
let w_hh = vb.get((4 * dim, dim), &format!("weight_hh_l{i}"))?;
let w_ih = vb.get((4 * dim, dim), &format!("weight_ih_l{i}"))?;
let b_hh = vb.get(4 * dim, &format!("bias_hh_l{i}"))?;
let b_ih = vb.get(4 * dim, &format!("bias_ih_l{i}"))?;
layers.push((w_hh, w_ih, b_hh, b_ih))
for layer_idx in 0..cfg.num_lstm_layers {
let config = candle_nn::LSTMConfig {
layer_idx,
..Default::default()
};
let lstm = candle_nn::lstm(dim, dim, config, vb.clone())?;
layers.push(lstm)
}
Ok(Self { layers })
}
}
fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
todo!()
impl Module for EncodecLSTM {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
use candle_nn::RNN;
let mut xs = xs.clone();
for layer in self.layers.iter() {
let states = layer.seq(&xs)?;
xs = layer.states_to_tensor(&states)?;
}
Ok(xs)
}
}
@ -247,7 +256,9 @@ impl EncodecConvTranspose1d {
bias,
})
}
}
impl Module for EncodecConvTranspose1d {
fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
todo!()
}
@ -299,7 +310,9 @@ impl EncodecConv1d {
conv,
})
}
}
impl Module for EncodecConv1d {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
// TODO: padding, depending on causal.
let xs = self.conv.forward(xs)?;
@ -340,7 +353,9 @@ impl EncodecResnetBlock {
shortcut,
})
}
}
impl Module for EncodecResnetBlock {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let residual = xs.clone();
let xs = xs.elu(1.)?;
@ -439,8 +454,17 @@ impl EncodecEncoder {
})
}
fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
todo!()
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let mut xs = xs.apply(&self.init_conv)?;
for (resnets, conv) in self.sampling_layers.iter() {
for resnet in resnets.iter() {
xs = xs.apply(resnet)?;
}
xs = xs.elu(1.0)?.apply(conv)?;
}
xs.apply(&self.final_lstm)?
.elu(1.0)?
.apply(&self.final_conv)
}
}
@ -507,8 +531,15 @@ impl EncodecDecoder {
})
}
fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
todo!()
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let mut xs = xs.apply(&self.init_conv)?.apply(&self.init_lstm)?;
for (conv, resnets) in self.sampling_layers.iter() {
xs = xs.elu(1.)?.apply(conv)?;
for resnet in resnets.iter() {
xs = xs.apply(resnet)?
}
}
xs.elu(1.)?.apply(&self.final_conv)
}
}

View File

@ -13,7 +13,6 @@ extern crate accelerate_src;
mod encodec_model;
mod musicgen_model;
mod nn;
mod t5_model;
use musicgen_model::{GenConfig, MusicgenForConditionalGeneration};
@ -74,11 +73,9 @@ fn main() -> Result<()> {
))
.get("model.safetensors")?,
};
let model = unsafe { candle::safetensors::MmapedFile::new(model)? };
let model = model.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![model], DTYPE, &device);
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DTYPE, &device)? };
let config = GenConfig::small();
let model = MusicgenForConditionalGeneration::load(vb, config)?;
let mut model = MusicgenForConditionalGeneration::load(vb, config)?;
let tokens = tokenizer
.encode(args.prompt.as_str(), true)

View File

@ -1,9 +1,10 @@
use crate::{encodec_model, t5_model};
use crate::encodec_model;
use candle::{DType, Device, Result, Tensor, D};
use candle_nn::{
embedding, layer_norm, linear_no_bias, Activation, Embedding, LayerNorm, Linear, Module,
VarBuilder,
};
use candle_transformers::models::t5;
// https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/models/musicgen/configuration_musicgen.py#L83
#[derive(Debug, Clone, PartialEq)]
@ -39,7 +40,7 @@ impl Default for Config {
num_attention_heads: 16,
layerdrop: 0.0,
use_cache: true,
activation_function: Activation::Gelu, // TODO: Handle old style gelu.
activation_function: Activation::Gelu,
hidden_size: 1024,
dropout: 0.1,
attention_dropout: 0.0,
@ -65,7 +66,7 @@ impl Config {
num_attention_heads: 16,
layerdrop: 0.0,
use_cache: true,
activation_function: Activation::Gelu, // TODO: Handle old style gelu.
activation_function: Activation::Gelu,
hidden_size: 1024,
dropout: 0.1,
attention_dropout: 0.0,
@ -370,7 +371,7 @@ impl MusicgenForCausalLM {
#[derive(Debug)]
pub struct MusicgenForConditionalGeneration {
pub text_encoder: crate::t5_model::T5EncoderModel,
pub text_encoder: t5::T5EncoderModel,
pub audio_encoder: crate::encodec_model::EncodecModel,
pub decoder: MusicgenForCausalLM,
cfg: GenConfig,
@ -379,7 +380,7 @@ pub struct MusicgenForConditionalGeneration {
#[derive(Debug, Clone, PartialEq)]
pub struct GenConfig {
musicgen: Config,
t5: crate::t5_model::Config,
t5: t5::Config,
encodec: crate::encodec_model::Config,
}
@ -387,7 +388,7 @@ impl GenConfig {
pub fn small() -> Self {
Self {
musicgen: Config::musicgen_small(),
t5: t5_model::Config::musicgen_small(),
t5: t5::Config::musicgen_small(),
encodec: encodec_model::Config::musicgen_small(),
}
}
@ -399,7 +400,7 @@ impl MusicgenForConditionalGeneration {
}
pub fn load(vb: VarBuilder, cfg: GenConfig) -> Result<Self> {
let text_encoder = t5_model::T5EncoderModel::load(vb.pp("text_encoder"), &cfg.t5)?;
let text_encoder = t5::T5EncoderModel::load(vb.pp("text_encoder"), &cfg.t5)?;
let audio_encoder =
encodec_model::EncodecModel::load(vb.pp("audio_encoder"), &cfg.encodec)?;
let decoder = MusicgenForCausalLM::load(vb.pp("decoder"), &cfg.musicgen)?;

View File

@ -1,434 +0,0 @@
// T5 Text Encoder
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
use candle::{DType, Result, Tensor, D};
use candle_nn::{embedding, linear_no_bias, Activation, Embedding, Linear, Module, VarBuilder};
use std::sync::Arc;
#[derive(Debug, Clone, PartialEq)]
pub struct Config {
vocab_size: usize,
d_model: usize,
d_kv: usize,
d_ff: usize,
num_layers: usize,
num_decoder_layers: Option<usize>,
num_heads: usize,
relative_attention_num_buckets: usize,
relative_attention_max_distance: usize,
dropout_rate: f64,
layer_norm_epsilon: f64,
initializer_factor: f64,
feed_forward_proj: Activation,
is_decoder: bool,
is_encoder_decoder: bool,
use_cache: bool,
pad_token_id: usize,
eos_token_id: usize,
}
impl Default for Config {
fn default() -> Self {
Self {
vocab_size: 32128,
d_model: 512,
d_kv: 64,
d_ff: 2048,
num_layers: 6,
num_decoder_layers: None,
num_heads: 8,
relative_attention_num_buckets: 32,
relative_attention_max_distance: 128,
dropout_rate: 0.1,
layer_norm_epsilon: 1e-6,
initializer_factor: 1.0,
feed_forward_proj: Activation::Relu,
is_decoder: false,
is_encoder_decoder: true,
use_cache: true,
pad_token_id: 0,
eos_token_id: 1,
}
}
}
impl Config {
// https://huggingface.co/facebook/musicgen-small/blob/495da4ad086b3416a27c6187f9239f9fd96f3962/config.json#L184
pub fn musicgen_small() -> Self {
Self {
d_ff: 3072,
d_kv: 64,
d_model: 768,
dropout_rate: 0.1,
eos_token_id: 1,
feed_forward_proj: Activation::Relu,
initializer_factor: 1.0,
is_decoder: false,
is_encoder_decoder: true,
layer_norm_epsilon: 1e-6,
num_decoder_layers: Some(12),
num_heads: 12,
num_layers: 12,
pad_token_id: 0,
relative_attention_max_distance: 128,
relative_attention_num_buckets: 32,
use_cache: true,
vocab_size: 32128,
}
}
}
#[derive(Debug)]
struct T5LayerNorm {
weight: Tensor,
variance_epsilon: f64,
}
impl T5LayerNorm {
fn load(h: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
let weight = vb.get(h, "weight")?;
Ok(Self {
weight,
variance_epsilon: eps,
})
}
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let dtype = xs.dtype();
let xs_f32 = xs.to_dtype(DType::F32)?;
// variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
let variance = xs_f32.sqr()?.mean_keepdim(D::Minus1)?;
let xs = xs.broadcast_div(&(variance + self.variance_epsilon)?.sqrt()?)?;
let xs = xs.to_dtype(dtype)?;
let xs = xs.broadcast_mul(&self.weight)?;
Ok(xs)
}
}
#[derive(Debug)]
struct T5DenseActDense {
wi: Linear,
wo: Linear,
act: Activation,
}
impl T5DenseActDense {
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let wi = linear_no_bias(cfg.d_model, cfg.d_ff, vb.pp("wi"))?;
let wo = linear_no_bias(cfg.d_ff, cfg.d_model, vb.pp("wo"))?;
Ok(Self {
wi,
wo,
act: Activation::Relu,
})
}
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let xs = self.wi.forward(xs)?;
let xs = self.act.forward(&xs)?;
let xs = self.wo.forward(&xs)?;
Ok(xs)
}
}
#[derive(Debug)]
struct T5LayerFF {
dense_relu_dense: T5DenseActDense,
layer_norm: T5LayerNorm,
}
impl T5LayerFF {
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
// is_gated_act is not supported.
let dense_relu_dense = T5DenseActDense::load(vb.pp("DenseReluDense"), cfg)?;
let layer_norm =
T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
Ok(Self {
dense_relu_dense,
layer_norm,
})
}
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let ys = self.layer_norm.forward(xs)?;
let ys = self.dense_relu_dense.forward(&ys)?;
let xs = (xs + ys)?;
Ok(xs)
}
}
#[derive(Debug)]
struct T5Attention {
q: Linear,
k: Linear,
v: Linear,
o: Linear,
n_heads: usize,
d_kv: usize,
relative_attention_bias: Option<Embedding>,
relative_attention_num_buckets: usize,
relative_attention_max_distance: usize,
inner_dim: usize,
}
impl T5Attention {
fn load(h: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
let inner_dim = cfg.num_heads * cfg.d_kv;
let q = linear_no_bias(cfg.d_model, inner_dim, vb.pp("q"))?;
let k = linear_no_bias(cfg.d_model, inner_dim, vb.pp("k"))?;
let v = linear_no_bias(cfg.d_model, inner_dim, vb.pp("v"))?;
let o = linear_no_bias(inner_dim, cfg.d_model, vb.pp("o"))?;
let relative_attention_bias = if h {
let emb = embedding(
cfg.relative_attention_num_buckets,
cfg.num_heads,
vb.pp("relative_attention_bias"),
)?;
Some(emb)
} else {
None
};
Ok(Self {
q,
k,
v,
o,
n_heads: cfg.num_heads,
d_kv: cfg.d_kv,
relative_attention_bias,
relative_attention_num_buckets: cfg.relative_attention_num_buckets,
relative_attention_max_distance: cfg.relative_attention_max_distance,
inner_dim,
})
}
fn forward(
&self,
xs: &Tensor,
position_bias: Option<&Tensor>,
) -> Result<(Tensor, Option<Tensor>)> {
// TODO: Apply the mask(s)?
// TODO: kv caching.
let (b_sz, seq_len) = (xs.dim(0)?, xs.dim(1)?);
let q = self.q.forward(xs)?;
let k = self.k.forward(xs)?;
let v = self.v.forward(xs)?;
let q = q
.reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
.transpose(1, 2)?
.contiguous()?;
let k = k
.reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
.transpose(1, 2)?
.contiguous()?;
let v = v
.reshape((b_sz, seq_len, self.n_heads, self.d_kv))?
.transpose(1, 2)?
.contiguous()?;
let scores = q.matmul(&k.t()?)?;
let (scores, position_bias) = match position_bias {
Some(position_bias) => ((scores + position_bias)?, Some(position_bias.clone())),
None => match &self.relative_attention_bias {
None => (scores, None),
Some(relative_attention_bias) => {
let query_length = seq_len;
let key_length = seq_len;
// This only handles the bidirectional case.
let num_buckets = self.relative_attention_num_buckets as u32 / 2;
let max_exact = num_buckets / 2;
let relative_position = (0..query_length as u32)
.map(|i| {
(0..key_length as u32)
.map(|j| {
if i < j {
if j - i < max_exact {
j - i + num_buckets
} else {
let b = f32::log(
(j - i) as f32 / max_exact as f32,
self.relative_attention_max_distance as f32
/ max_exact as f32,
) * (num_buckets - max_exact) as f32;
u32::min(
max_exact + num_buckets + b as u32,
self.relative_attention_num_buckets as u32 - 1,
)
}
} else if i - j < max_exact {
i - j
} else {
let b = f32::log(
(i - j) as f32 / max_exact as f32,
self.relative_attention_max_distance as f32
/ max_exact as f32,
) * (num_buckets - max_exact) as f32;
max_exact + b as u32
}
})
.collect::<Vec<u32>>()
})
.collect::<Vec<Vec<_>>>();
let relative_buckets = Tensor::new(relative_position, q.device())?;
let position_bias = relative_attention_bias
.forward(&relative_buckets)?
.permute((2, 0, 1))?
.unsqueeze(0)?;
((scores + &position_bias)?, Some(position_bias))
// TODO: position_bias_masked?
}
},
};
let attn_weights = candle_nn::ops::softmax(&scores, D::Minus1)?;
let attn_output = attn_weights.matmul(&v)?;
let attn_output = attn_output
.transpose(1, 2)?
.reshape((b_sz, seq_len, self.inner_dim))?;
let attn_output = self.o.forward(&attn_output)?;
Ok((attn_output, position_bias))
}
}
#[derive(Debug)]
struct T5LayerSelfAttention {
self_attention: T5Attention,
layer_norm: T5LayerNorm,
}
impl T5LayerSelfAttention {
fn load(h: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
let self_attention = T5Attention::load(h, vb.pp("SelfAttention"), cfg)?;
let layer_norm =
T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
Ok(Self {
self_attention,
layer_norm,
})
}
fn forward(
&self,
xs: &Tensor,
position_bias: Option<&Tensor>,
) -> Result<(Tensor, Option<Tensor>)> {
let normed_xs = self.layer_norm.forward(xs)?;
let (ys, position_bias) = self.self_attention.forward(&normed_xs, position_bias)?;
let ys = (xs + ys)?;
Ok((ys, position_bias))
}
}
#[derive(Debug)]
struct T5LayerCrossAttention {}
impl T5LayerCrossAttention {
fn load(_vb: VarBuilder, _cfg: &Config) -> Result<Self> {
todo!()
}
fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
todo!()
}
}
#[derive(Debug)]
struct T5Block {
self_attn: T5LayerSelfAttention,
cross_attn: Option<T5LayerCrossAttention>,
ff: T5LayerFF,
}
impl T5Block {
fn load(has_relative_attention_bias: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
let vb = vb.pp("layer");
let self_attn = T5LayerSelfAttention::load(has_relative_attention_bias, vb.pp("0"), cfg)?;
let cross_attn = if cfg.is_decoder {
Some(T5LayerCrossAttention::load(vb.pp("1"), cfg)?)
} else {
None
};
let ff_i = if cross_attn.is_some() { 2 } else { 1 };
let ff = T5LayerFF::load(vb.pp(&ff_i.to_string()), cfg)?;
Ok(Self {
self_attn,
cross_attn,
ff,
})
}
fn forward(
&self,
xs: &Tensor,
position_bias: Option<&Tensor>,
) -> Result<(Tensor, Option<Tensor>)> {
let (mut xs, position_bias) = self.self_attn.forward(xs, position_bias)?;
// TODO: clamp for f16?
if let Some(cross_attn) = &self.cross_attn {
xs = cross_attn.forward(&xs)?;
// TODO: clamp for f16?
}
let xs = self.ff.forward(&xs)?;
// TODO: clamp for f16?
Ok((xs, position_bias))
}
}
#[derive(Debug)]
struct T5Stack {
block: Vec<T5Block>,
shared: Arc<Embedding>,
final_layer_norm: T5LayerNorm,
}
impl T5Stack {
fn load(vb: VarBuilder, shared: &Arc<Embedding>, cfg: &Config) -> Result<Self> {
let block = (0..cfg.num_layers)
.map(|i| T5Block::load(i == 0, vb.pp(&format!("block.{i}")), cfg))
.collect::<Result<Vec<_>>>()?;
let final_layer_norm = T5LayerNorm::load(
cfg.d_model,
cfg.layer_norm_epsilon,
vb.pp("final_layer_norm"),
)?;
Ok(Self {
block,
shared: shared.clone(),
final_layer_norm,
})
}
fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
let input_embeds = self.shared.as_ref().forward(input_ids)?;
let (_b_sz, _seq_len) = (input_embeds.dim(0)?, input_embeds.dim(1)?);
let mut hidden_states = input_embeds;
let mut position_bias = None;
for block in self.block.iter() {
(hidden_states, position_bias) =
block.forward(&hidden_states, position_bias.as_ref())?
}
let hidden_states = self.final_layer_norm.forward(&hidden_states)?;
Ok(hidden_states)
}
}
#[derive(Debug)]
pub struct T5EncoderModel {
shared: Arc<Embedding>,
encoder: T5Stack,
}
impl T5EncoderModel {
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let shared = embedding(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
let shared = Arc::new(shared);
let encoder = T5Stack::load(vb.pp("encoder"), &shared, cfg)?;
Ok(Self { shared, encoder })
}
pub fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
let encoder_outputs = self.encoder.forward(input_ids)?;
Ok(encoder_outputs)
}
}

View File

@ -0,0 +1,43 @@
# candle-phi: 1.3b LLM with state of the art performance for <10b models.
[Phi-1.5](https://huggingface.co/microsoft/phi-1_5) is a language model using
only 1.3 billion parameters but with state of the art performance compared to
models with up to 10 billion parameters.
The candle implementation provides both the standard version as well as a
quantized variant.
## Running some example
```bash
$ cargo run --example phi --release -- --prompt "def print_prime(n): "
def print_prime(n):
print("Printing prime numbers")
for i in range(2, n+1):
if is_prime(i):
print(i)
def is_prime(n):
if n <= 1:
return False
for i in range(2, int(math.sqrt(n))+1):
if n % i == 0:
return False
return True
$ cargo run --example phi --release -- \
--prompt "Explain how to find the median of an array and write the corresponding python function.\nAnswer:" \
--quantized --sample-len 200
Explain how to find the median of an array and write the corresponding python function.
Answer: The median is the middle value in an array. If the array has an even number of elements, the median is the average of the two middle values.
def median(arr):
arr.sort()
n = len(arr)
if n % 2 == 0:
return (arr[n//2 - 1] + arr[n//2]) / 2
else:
return arr[n//2]
```

View File

@ -0,0 +1,238 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;
use candle_transformers::models::mixformer::{Config, MixFormerSequentialForCausalLM as MixFormer};
use candle_transformers::models::quantized_mixformer::MixFormerSequentialForCausalLM as QMixFormer;
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
enum Model {
MixFormer(MixFormer),
Quantized(QMixFormer),
}
struct TextGeneration {
model: Model,
device: Device,
tokenizer: Tokenizer,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
}
impl TextGeneration {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
device: &Device,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
tokenizer,
logits_processor,
repeat_penalty,
repeat_last_n,
device: device.clone(),
}
}
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
use std::io::Write;
println!("starting the inference loop");
print!("{prompt}");
std::io::stdout().flush()?;
let mut tokens = self
.tokenizer
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
Some(token) => *token,
None => anyhow::bail!("cannot find the endoftext token"),
};
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = match &mut self.model {
Model::MixFormer(m) => m.forward(&input)?,
Model::Quantized(m) => m.forward(&input)?,
};
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
self.repeat_penalty,
&tokens[start_at..],
)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
break;
}
let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?;
print!("{token}");
std::io::stdout().flush()?;
}
let dt = start_gen.elapsed();
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
Ok(())
}
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
#[arg(long)]
prompt: String,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 100)]
sample_len: usize,
#[arg(long, default_value = "microsoft/phi-1_5")]
model_id: String,
#[arg(long, default_value = "refs/pr/18")]
revision: String,
#[arg(long)]
weight_file: Option<String>,
#[arg(long)]
quantized: bool,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature.unwrap_or(0.),
args.repeat_penalty,
args.repeat_last_n
);
let start = std::time::Instant::now();
let api = Api::new()?;
let repo = api.repo(Repo::with_revision(
args.model_id,
RepoType::Model,
args.revision,
));
let tokenizer_filename = repo.get("tokenizer.json")?;
let filename = match args.weight_file {
Some(weight_file) => std::path::PathBuf::from(weight_file),
None => {
if args.quantized {
api.model("lmz/candle-quantized-phi".to_string())
.get("model-q4k.gguf")?
} else {
repo.get("model.safetensors")?
}
}
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
let config = Config::v1_5();
let (model, device) = if args.quantized {
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename)?;
let model = QMixFormer::new(&config, vb)?;
(Model::Quantized(model), Device::Cpu)
} else {
let device = candle_examples::device(args.cpu)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[filename], DType::F32, &device)? };
let model = MixFormer::new(&config, vb)?;
(Model::MixFormer(model), device)
};
println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(
model,
tokenizer,
args.seed,
args.temperature,
args.top_p,
args.repeat_penalty,
args.repeat_last_n,
&device,
);
pipeline.run(&args.prompt, args.sample_len)?;
Ok(())
}

View File

@ -0,0 +1,42 @@
# candle-quantized-t5
This example uses a quantized version of the t5 model.
```bash
$ cargo run --example quantized-t5 --release -- --prompt "translate to German: A beautiful candle."
...
Eine schöne Kerze.
```
The weight file is automatically retrieved from the hub. It is also possible to
generate quantized weight files from the original safetensors file by using the
`tensor-tools` command line utility via:
```bash
$ cargo run --example tensor-tools --release -- quantize --quantization q6k PATH/TO/T5/model.safetensors /tmp/model.gguf
```
To use a different model, specify the `model-id`. For example, you can use
quantized [CoEdit models](https://huggingface.co/jbochi/candle-coedit-quantized).
```bash
$ cargo run --example quantized-t5 --release -- \
--model-id "jbochi/candle-coedit-quantized" \
--prompt "Make this text coherent: Their flight is weak. They run quickly through the tree canopy." \
--temperature 0
...
Although their flight is weak, they run quickly through the tree canopy.
By default, it will look for `model.gguf` and `config.json`, but you can specify
custom local or remote `weight-file` and `config-file`s:
```bash
cargo run --example quantized-t5 --release -- \
--model-id "jbochi/candle-coedit-quantized" \
--weight-file "model-xl.gguf" \
--config-file "config-xl.json" \
--prompt "Rewrite to make this easier to understand: Note that a storm surge is what forecasters consider a hurricane's most treacherous aspect." \
--temperature 0
...
Note that a storm surge is what forecasters consider a hurricane's most dangerous part.
```

View File

@ -0,0 +1,228 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use std::io::Write;
use std::path::PathBuf;
use candle_transformers::models::quantized_t5 as t5;
use anyhow::{Error as E, Result};
use candle::{Device, Tensor};
use candle_transformers::generation::LogitsProcessor;
use clap::{Parser, ValueEnum};
use hf_hub::{api::sync::Api, api::sync::ApiRepo, Repo, RepoType};
use tokenizers::Tokenizer;
#[derive(Clone, Debug, Copy, ValueEnum)]
enum Which {
T5Small,
FlanT5Small,
FlanT5Base,
FlanT5Large,
FlanT5Xl,
FlanT5Xxl,
}
#[derive(Parser, Debug, Clone)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
/// The model repository to use on the HuggingFace hub.
#[arg(long)]
model_id: Option<String>,
#[arg(long)]
revision: Option<String>,
#[arg(long)]
weight_file: Option<String>,
#[arg(long)]
config_file: Option<String>,
// Enable/disable decoding.
#[arg(long, default_value = "false")]
disable_cache: bool,
/// Use this prompt, otherwise compute sentence similarities.
#[arg(long)]
prompt: String,
/// The temperature used to generate samples.
#[arg(long, default_value_t = 0.8)]
temperature: f64,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
/// The model size to use.
#[arg(long, default_value = "t5-small")]
which: Which,
}
struct T5ModelBuilder {
device: Device,
config: t5::Config,
weights_filename: PathBuf,
}
impl T5ModelBuilder {
pub fn load(args: &Args) -> Result<(Self, Tokenizer)> {
let device = Device::Cpu;
let default_model = "lmz/candle-quantized-t5".to_string();
let (model_id, revision) = match (args.model_id.to_owned(), args.revision.to_owned()) {
(Some(model_id), Some(revision)) => (model_id, revision),
(Some(model_id), None) => (model_id, "main".to_string()),
(None, Some(revision)) => (default_model, revision),
(None, None) => (default_model, "main".to_string()),
};
let repo = Repo::with_revision(model_id, RepoType::Model, revision);
let api = Api::new()?;
let api = api.repo(repo);
let config_filename = match &args.config_file {
Some(filename) => Self::get_local_or_remote_file(filename, &api)?,
None => match args.which {
Which::T5Small => api.get("config.json")?,
Which::FlanT5Small => api.get("config-flan-t5-small.json")?,
Which::FlanT5Base => api.get("config-flan-t5-base.json")?,
Which::FlanT5Large => api.get("config-flan-t5-large.json")?,
Which::FlanT5Xl => api.get("config-flan-t5-xl.json")?,
Which::FlanT5Xxl => api.get("config-flan-t5-xxl.json")?,
},
};
let tokenizer_filename = api.get("tokenizer.json")?;
let weights_filename = match &args.weight_file {
Some(filename) => Self::get_local_or_remote_file(filename, &api)?,
None => match args.which {
Which::T5Small => api.get("model.gguf")?,
Which::FlanT5Small => api.get("model-flan-t5-small.gguf")?,
Which::FlanT5Base => api.get("model-flan-t5-base.gguf")?,
Which::FlanT5Large => api.get("model-flan-t5-large.gguf")?,
Which::FlanT5Xl => api.get("model-flan-t5-xl.gguf")?,
Which::FlanT5Xxl => api.get("model-flan-t5-xxl.gguf")?,
},
};
let config = std::fs::read_to_string(config_filename)?;
let mut config: t5::Config = serde_json::from_str(&config)?;
config.use_cache = !args.disable_cache;
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
Ok((
Self {
device,
config,
weights_filename,
},
tokenizer,
))
}
pub fn build_model(&self) -> Result<t5::T5ForConditionalGeneration> {
let vb = t5::VarBuilder::from_gguf(&self.weights_filename)?;
Ok(t5::T5ForConditionalGeneration::load(vb, &self.config)?)
}
fn get_local_or_remote_file(filename: &str, api: &ApiRepo) -> Result<PathBuf> {
let local_filename = std::path::PathBuf::from(filename);
if local_filename.exists() {
Ok(local_filename)
} else {
Ok(api.get(filename)?)
}
}
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
let (builder, mut tokenizer) = T5ModelBuilder::load(&args)?;
let device = &builder.device;
let tokenizer = tokenizer
.with_padding(None)
.with_truncation(None)
.map_err(E::msg)?;
let tokens = tokenizer
.encode(args.prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let input_token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
let mut model = builder.build_model()?;
let mut output_token_ids = [builder.config.pad_token_id as u32].to_vec();
let temperature = if args.temperature <= 0. {
None
} else {
Some(args.temperature)
};
let mut logits_processor = LogitsProcessor::new(299792458, temperature, args.top_p);
let encoder_output = model.encode(&input_token_ids)?;
let start = std::time::Instant::now();
for index in 0.. {
if output_token_ids.len() > 512 {
break;
}
let decoder_token_ids = if index == 0 || !builder.config.use_cache {
Tensor::new(output_token_ids.as_slice(), device)?.unsqueeze(0)?
} else {
let last_token = *output_token_ids.last().unwrap();
Tensor::new(&[last_token], device)?.unsqueeze(0)?
};
let logits = model
.decode(&decoder_token_ids, &encoder_output)?
.squeeze(0)?;
let logits = if args.repeat_penalty == 1. {
logits
} else {
let start_at = output_token_ids.len().saturating_sub(args.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
args.repeat_penalty,
&output_token_ids[start_at..],
)?
};
let next_token_id = logits_processor.sample(&logits)?;
if next_token_id as usize == builder.config.eos_token_id {
break;
}
output_token_ids.push(next_token_id);
if let Some(text) = tokenizer.id_to_token(next_token_id) {
let text = text.replace('▁', " ").replace("<0x0A>", "\n");
print!("{text}");
std::io::stdout().flush()?;
}
}
let dt = start.elapsed();
println!(
"\n{} tokens generated ({:.2} token/s)\n",
output_token_ids.len(),
output_token_ids.len() as f64 / dt.as_secs_f64(),
);
Ok(())
}

View File

@ -0,0 +1,37 @@
# candle-quantized-llama: Fast Inference of quantized LLaMA models
This example provides a quantized LLaMA model similar to
[llama.cpp](https://github.com/ggerganov/llama.cpp). This is based on candle
built-in quantization methods. Supported features include:
- 2-bit, 3-bit, 4-bit, 5-bit, 6-bit and 8-bit integer quantization support.
- SIMD optimizations on Apple Silicon and x86.
- Support using the `gguf` and `ggml` file formats.
The weights are automatically downloaded for you from the [HuggingFace
Hub](https://huggingface.co/) on the first run. There are various command line
flags to use local files instead, run with `--help` to learn about them.
![Axiom of Choice](./assets/aoc.gif)
## Running some example.
```bash
cargo run --example quantized --release -- --prompt "The best thing about coding in rust is "
> avx: true, neon: false, simd128: false, f16c: true
> temp: 0.80 repeat-penalty: 1.10 repeat-last-n: 64
> loaded 291 tensors (3.79GB) in 2.17s
> params: HParams { n_vocab: 32000, n_embd: 4096, n_mult: 256, n_head: 32, n_layer: 32, n_rot: 128, ftype: 2 }
> The best thing about coding in rust is 1.) that I dont need to worry about memory leaks, 2.) speed and 3.) my program will compile even on old machines.
```
## Command-line flags
Run with `--help` to see all options.
- `--which`: specify the model to use, e.g. `7b`, `13-chat`, `7b-code`.
- `--prompt interactive`: interactive mode where multiple prompts can be
entered.
- `--model mymodelfile.gguf`: use a local model file rather than getting one
from the hub.

Binary file not shown.

After

Width:  |  Height:  |  Size: 119 KiB

View File

@ -12,7 +12,7 @@ use candle::quantized::{ggml_file, gguf_file};
use candle::{Device, Tensor};
use candle_transformers::generation::LogitsProcessor;
mod model;
use candle_transformers::models::quantized_llama as model;
use model::ModelWeights;
const DEFAULT_PROMPT: &str = "My favorite theorem is ";
@ -44,6 +44,27 @@ enum Which {
L13bCode,
#[value(name = "32b-code")]
L34bCode,
#[value(name = "7b-mistral")]
Mistral7b,
#[value(name = "7b-mistral-instruct")]
Mistral7bInstruct,
}
impl Which {
fn is_mistral(&self) -> bool {
match self {
Self::L7b
| Self::L13b
| Self::L70b
| Self::L7bChat
| Self::L13bChat
| Self::L70bChat
| Self::L7bCode
| Self::L13bCode
| Self::L34bCode => false,
Self::Mistral7b | Self::Mistral7bInstruct => true,
}
}
}
#[derive(Parser, Debug)]
@ -71,6 +92,10 @@ struct Args {
#[arg(long, default_value_t = 0.8)]
temperature: f64,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
@ -106,7 +131,12 @@ impl Args {
Some(config) => std::path::PathBuf::from(config),
None => {
let api = hf_hub::api::sync::Api::new()?;
let api = api.model("hf-internal-testing/llama-tokenizer".to_string());
let repo = if self.which.is_mistral() {
"mistralai/Mistral-7B-v0.1"
} else {
"hf-internal-testing/llama-tokenizer"
};
let api = api.model(repo.to_string());
api.get("tokenizer.json")?
}
};
@ -136,6 +166,14 @@ impl Args {
Which::L7bCode => ("TheBloke/CodeLlama-7B-GGUF", "codellama-7b.Q8_0.gguf"),
Which::L13bCode => ("TheBloke/CodeLlama-13B-GGUF", "codellama-13b.Q8_0.gguf"),
Which::L34bCode => ("TheBloke/CodeLlama-34B-GGUF", "codellama-34b.Q8_0.gguf"),
Which::Mistral7b => (
"TheBloke/Mistral-7B-v0.1-GGUF",
"mistral-7b-v0.1.Q4_K_S.gguf",
),
Which::Mistral7bInstruct => (
"TheBloke/Mistral-7B-Instruct-v0.1-GGUF",
"mistral-7b-instruct-v0.1.Q4_K_S.gguf",
),
};
let api = hf_hub::api::sync::Api::new()?;
let api = api.model(repo.to_string());
@ -257,7 +295,7 @@ fn main() -> anyhow::Result<()> {
| Which::L7bCode
| Which::L13bCode
| Which::L34bCode => 1,
Which::L70b | Which::L70bChat => 8,
Which::Mistral7b | Which::Mistral7bInstruct | Which::L70b | Which::L70bChat => 8,
};
ModelWeights::from_ggml(model, args.gqa.unwrap_or(default_gqa))?
}
@ -287,7 +325,11 @@ fn main() -> anyhow::Result<()> {
prompt.pop();
}
}
prompt
if args.which.is_mistral() {
format!("[INST] {prompt} [/INST]")
} else {
prompt
}
}
};
print!("{}", &prompt_str);
@ -310,7 +352,7 @@ fn main() -> anyhow::Result<()> {
prompt_tokens
};
let mut all_tokens = vec![];
let mut logits_processor = LogitsProcessor::new(args.seed, temperature);
let mut logits_processor = LogitsProcessor::new(args.seed, temperature, args.top_p);
let start_prompt_processing = std::time::Instant::now();
let mut next_token = {
@ -323,6 +365,8 @@ fn main() -> anyhow::Result<()> {
all_tokens.push(next_token);
print_token(next_token, &tokenizer);
let eos_token = *tokenizer.get_vocab(true).get("</s>").unwrap();
let start_post_prompt = std::time::Instant::now();
for index in 0..to_sample {
let input = Tensor::new(&[next_token], &Device::Cpu)?.unsqueeze(0)?;
@ -341,6 +385,9 @@ fn main() -> anyhow::Result<()> {
next_token = logits_processor.sample(&logits)?;
all_tokens.push(next_token);
print_token(next_token, &tokenizer);
if next_token == eos_token {
break;
};
}
let dt = start_post_prompt.elapsed();
println!(

View File

@ -0,0 +1,44 @@
# candle-segment-anything: Segment-Anything Model
This example is based on Meta AI [Segment-Anything
Model](https://github.com/facebookresearch/segment-anything). This model
provides a robust and fast image segmentation pipeline that can be tweaked via
some prompting (requesting some points to be in the target mask, requesting some
points to be part of the background so _not_ in the target mask, specifying some
bounding box).
The default backbone can be replaced by the smaller and faster TinyViT model
based on [MobileSAM](https://github.com/ChaoningZhang/MobileSAM).
## Running some example.
```bash
cargo run --example segment-anything --release -- \
--image candle-examples/examples/yolo-v8/assets/bike.jpg
--use-tiny
--point 0.6,0.6 --point 0.6,0.55
```
Running this command generates a `sam_merged.jpg` file containing the original
image with a blue overlay of the selected mask. The red dots represent the prompt
specified by `--point 0.6,0.6 --point 0.6,0.55`, this prompt is assumed to be part
of the target mask.
The values used for `--point` should be a comma delimited pair of float values.
They are proportional to the image dimension, i.e. use 0.5 for the image center.
Original image:
![Leading group, Giro d'Italia 2021](../yolo-v8/assets/bike.jpg)
Segment results by prompting with a single point `--point 0.6,0.55`:
![Leading group, Giro d'Italia 2021](./assets/single_pt_prompt.jpg)
Segment results by prompting with multiple points `--point 0.6,0.6 --point 0.6,0.55`:
![Leading group, Giro d'Italia 2021](./assets/two_pt_prompt.jpg)
### Command-line flags
- `--use-tiny`: use the TinyViT based MobileSAM backbone rather than the default
one.
- `--point`: specifies the location of the target points.
- `--threshold`: sets the threshold value to be part of the mask, a negative
value results in a larger mask and can be specified via `--threshold=-1.2`.

Binary file not shown.

After

Width:  |  Height:  |  Size: 157 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 158 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 158 KiB

View File

@ -7,107 +7,11 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
pub mod model_image_encoder;
pub mod model_mask_decoder;
pub mod model_prompt_encoder;
pub mod model_sam;
pub mod model_transformer;
use candle::{DType, Result, Tensor};
use candle_nn::{Module, VarBuilder};
use candle::DType;
use candle_nn::VarBuilder;
use candle_transformers::models::segment_anything::sam;
use clap::Parser;
pub fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
let inner = if bias {
candle_nn::linear(in_dim, out_dim, vb)?
} else {
candle_nn::linear_no_bias(in_dim, out_dim, vb)?
};
let span = tracing::span!(tracing::Level::TRACE, "linear");
Ok(Linear { inner, span })
}
#[derive(Debug)]
pub struct LayerNorm2d {
weight: Tensor,
bias: Tensor,
num_channels: usize,
eps: f64,
}
impl LayerNorm2d {
pub fn new(num_channels: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
let weight = vb.get(num_channels, "weight")?;
let bias = vb.get(num_channels, "bias")?;
Ok(Self {
weight,
bias,
num_channels,
eps,
})
}
}
impl Module for LayerNorm2d {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let u = xs.mean_keepdim(1)?;
let xs = xs.broadcast_sub(&u)?;
let s = xs.sqr()?.mean_keepdim(1)?;
let xs = xs.broadcast_div(&(s + self.eps)?.sqrt()?)?;
xs.broadcast_mul(&self.weight.reshape((1, self.num_channels, 1, 1))?)?
.broadcast_add(&self.bias.reshape((1, self.num_channels, 1, 1))?)
}
}
#[derive(Debug)]
pub struct MlpBlock {
lin1: Linear,
lin2: Linear,
activation: candle_nn::Activation,
span: tracing::Span,
}
impl MlpBlock {
pub fn new(
embedding_dim: usize,
mlp_dim: usize,
activation: candle_nn::Activation,
vb: VarBuilder,
) -> Result<Self> {
let lin1 = linear(vb.pp("lin1"), embedding_dim, mlp_dim, true)?;
let lin2 = linear(vb.pp("lin2"), mlp_dim, embedding_dim, true)?;
let span = tracing::span!(tracing::Level::TRACE, "mlp-block");
Ok(Self {
lin1,
lin2,
activation,
span,
})
}
}
impl Module for MlpBlock {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
xs.apply(&self.lin1)?
.apply(&self.activation)?
.apply(&self.lin2)
}
}
#[derive(Debug)]
pub struct Linear {
inner: candle_nn::Linear,
span: tracing::Span,
}
impl Module for Linear {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
self.inner.forward(x)
}
}
#[derive(Parser)]
struct Args {
#[arg(long)]
@ -123,15 +27,28 @@ struct Args {
#[arg(long)]
generate_masks: bool,
#[arg(long, default_value_t = 0.5)]
point_x: f64,
/// List of x,y coordinates, between 0 and 1 (0.5 is at the middle of the image). These points
/// should be part of the generated mask.
#[arg(long)]
point: Vec<String>,
#[arg(long, default_value_t = 0.5)]
point_y: f64,
/// List of x,y coordinates, between 0 and 1 (0.5 is at the middle of the image). These points
/// should not be part of the generated mask and should be part of the background instead.
#[arg(long)]
neg_point: Vec<String>,
/// The detection threshold for the mask, 0 is the default value, negative values mean a larger
/// mask, positive makes the mask more selective.
#[arg(long, default_value_t = 0.)]
threshold: f32,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
/// Use the TinyViT based models from MobileSAM
#[arg(long)]
use_tiny: bool,
}
pub fn main() -> anyhow::Result<()> {
@ -149,28 +66,9 @@ pub fn main() -> anyhow::Result<()> {
let device = candle_examples::device(args.cpu)?;
let (image, initial_h, initial_w) = if args.image.ends_with(".safetensors") {
let mut tensors = candle::safetensors::load(&args.image, &device)?;
let image = match tensors.remove("image") {
Some(image) => image,
None => {
if tensors.len() != 1 {
anyhow::bail!("multiple tensors in '{}'", args.image)
}
tensors.into_values().next().unwrap()
}
};
let image = if image.rank() == 4 {
image.get(0)?
} else {
image
};
let (_c, h, w) = image.dims3()?;
(image, h, w)
} else {
let (image, h, w) = candle_examples::load_image(&args.image, Some(model_sam::IMAGE_SIZE))?;
(image.to_device(&device)?, h, w)
};
let (image, initial_h, initial_w) =
candle_examples::load_image(&args.image, Some(sam::IMAGE_SIZE))?;
let image = image.to_device(&device)?;
println!("loaded image {image:?}");
let model = match args.model {
@ -178,13 +76,20 @@ pub fn main() -> anyhow::Result<()> {
None => {
let api = hf_hub::api::sync::Api::new()?;
let api = api.model("lmz/candle-sam".to_string());
api.get("sam_vit_b_01ec64.safetensors")?
let filename = if args.use_tiny {
"mobile_sam-tiny-vitt.safetensors"
} else {
"sam_vit_b_01ec64.safetensors"
};
api.get(filename)?
}
};
let weights = unsafe { candle::safetensors::MmapedFile::new(model)? };
let weights = weights.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
let sam = model_sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)?; // sam_vit_b
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? };
let sam = if args.use_tiny {
sam::Sam::new_tiny(vb)? // tiny vit_t
} else {
sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)? // sam_vit_b
};
if args.generate_masks {
// Default options similar to the Python version.
@ -196,7 +101,7 @@ pub fn main() -> anyhow::Result<()> {
/* crop_n_points_downscale_factor */ 1,
)?;
for (idx, bbox) in bboxes.iter().enumerate() {
println!("{bbox:?}");
println!("{idx} {bbox:?}");
let mask = (&bbox.data.to_dtype(DType::U8)? * 255.)?;
let (h, w) = mask.dims2()?;
let mask = mask.broadcast_as((3, h, w))?;
@ -208,66 +113,69 @@ pub fn main() -> anyhow::Result<()> {
)?;
}
} else {
let point = Some((args.point_x, args.point_y));
let iter_points = args.point.iter().map(|p| (p, true));
let iter_neg_points = args.neg_point.iter().map(|p| (p, false));
let points = iter_points
.chain(iter_neg_points)
.map(|(point, b)| {
use std::str::FromStr;
let xy = point.split(',').collect::<Vec<_>>();
if xy.len() != 2 {
anyhow::bail!("expected format for points is 0.4,0.2")
}
Ok((f64::from_str(xy[0])?, f64::from_str(xy[1])?, b))
})
.collect::<anyhow::Result<Vec<_>>>()?;
let start_time = std::time::Instant::now();
let (mask, iou_predictions) = sam.forward(&image, point, false)?;
let (mask, iou_predictions) = sam.forward(&image, &points, false)?;
println!(
"mask generated in {:.2}s",
start_time.elapsed().as_secs_f32()
);
println!("mask:\n{mask}");
println!("iou_predictions: {iou_predictions:?}");
println!("iou_predictions: {iou_predictions}");
// Save the mask as an image.
let mask = (mask.ge(0f32)? * 255.)?;
let mask = (mask.ge(args.threshold)? * 255.)?;
let (_one, h, w) = mask.dims3()?;
let mask = mask.expand((3, h, w))?;
candle_examples::save_image_resize(&mask, "sam_mask.png", initial_h, initial_w)?;
if !args.image.ends_with(".safetensors") {
let mut img = image::io::Reader::open(&args.image)?
.decode()
.map_err(candle::Error::wrap)?;
let mask_pixels = mask.permute((1, 2, 0))?.flatten_all()?.to_vec1::<u8>()?;
let mask_img: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
match image::ImageBuffer::from_raw(w as u32, h as u32, mask_pixels) {
Some(image) => image,
None => anyhow::bail!("error saving merged image"),
};
let mask_img = image::DynamicImage::from(mask_img).resize_to_fill(
img.width(),
img.height(),
image::imageops::FilterType::CatmullRom,
);
for x in 0..img.width() {
for y in 0..img.height() {
let mask_p = imageproc::drawing::Canvas::get_pixel(&mask_img, x, y);
if mask_p.0[0] > 100 {
let mut img_p = imageproc::drawing::Canvas::get_pixel(&img, x, y);
img_p.0[2] = 255 - (255 - img_p.0[2]) / 2;
img_p.0[1] /= 2;
img_p.0[0] /= 2;
imageproc::drawing::Canvas::draw_pixel(&mut img, x, y, img_p)
}
let mut img = image::io::Reader::open(&args.image)?
.decode()
.map_err(candle::Error::wrap)?;
let mask_pixels = mask.permute((1, 2, 0))?.flatten_all()?.to_vec1::<u8>()?;
let mask_img: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
match image::ImageBuffer::from_raw(w as u32, h as u32, mask_pixels) {
Some(image) => image,
None => anyhow::bail!("error saving merged image"),
};
let mask_img = image::DynamicImage::from(mask_img).resize_to_fill(
img.width(),
img.height(),
image::imageops::FilterType::CatmullRom,
);
for x in 0..img.width() {
for y in 0..img.height() {
let mask_p = imageproc::drawing::Canvas::get_pixel(&mask_img, x, y);
if mask_p.0[0] > 100 {
let mut img_p = imageproc::drawing::Canvas::get_pixel(&img, x, y);
img_p.0[2] = 255 - (255 - img_p.0[2]) / 2;
img_p.0[1] /= 2;
img_p.0[0] /= 2;
imageproc::drawing::Canvas::draw_pixel(&mut img, x, y, img_p)
}
}
match point {
Some((x, y)) => {
let (x, y) = (
(x * img.width() as f64) as i32,
(y * img.height() as f64) as i32,
);
imageproc::drawing::draw_filled_circle(
&img,
(x, y),
3,
image::Rgba([255, 0, 0, 200]),
)
.save("sam_merged.jpg")?
}
None => img.save("sam_merged.jpg")?,
};
}
for (x, y, b) in points {
let x = (x * img.width() as f64) as i32;
let y = (y * img.height() as f64) as i32;
let color = if b {
image::Rgba([255, 0, 0, 200])
} else {
image::Rgba([0, 255, 0, 200])
};
imageproc::drawing::draw_filled_circle_mut(&mut img, (x, y), 3, color);
}
img.save("sam_merged.jpg")?
}
Ok(())
}

View File

@ -0,0 +1,63 @@
# candle-stable-diffusion: A Diffusers API in Rust/Candle
![rusty robot holding a candle](./assets/stable-diffusion-xl.jpg)
_A rusty robot holding a fire torch in its hand_, generated by Stable Diffusion
XL using Rust and [candle](https://github.com/huggingface/candle).
The `stable-diffusion` example is a conversion of
[diffusers-rs](https://github.com/LaurentMazare/diffusers-rs) using candle
rather than libtorch. This implementation supports Stable Diffusion v1.5, v2.1,
as well as Stable Diffusion XL 1.0.
## Getting the weights
The weights are automatically downloaded for you from the [HuggingFace
Hub](https://huggingface.co/) on the first run. There are various command line
flags to use local files instead, run with `--help` to learn about them.
## Running some example.
```bash
cargo run --example stable-diffusion --release --features=cuda,cudnn \
-- --prompt "a cosmonaut on a horse (hd, realistic, high-def)"
```
The final image is named `sd_final.png` by default.
The default scheduler is the Denoising Diffusion Implicit Model scheduler (DDIM). The
original paper and some code can be found in the [associated repo](https://github.com/ermongroup/ddim).
### Command-line flags
- `--prompt`: the prompt to be used to generate the image.
- `--uncond-prompt`: the optional unconditional prompt.
- `--sd-version`: the Stable Diffusion version to use, can be `v1-5`, `v2-1`, or
`xl`.
- `--cpu`: use the cpu rather than the gpu (much slower).
- `--height`, `--width`: set the height and width for the generated image.
- `--n-steps`: the number of steps to be used in the diffusion process.
- `--num-samples`: the number of samples to generate.
- `--final-image`: the filename for the generated image(s).
### Using flash-attention
Using flash attention makes image generation a lot faster and uses less memory.
The downside is some long compilation time. You can set the
`CANDLE_FLASH_ATTN_BUILD_DIR` environment variable to something like
`/home/user/.candle` to ensures that the compilation artifacts are properly
cached.
Enabling flash-attention requires both a feature flag, `--feature flash-attn`
and using the command line flag `--use-flash-attn`.
## Image to Image Pipeline
...
## FAQ
### Memory Issues
This requires a GPU with more than 8GB of memory, as a fallback the CPU version can be used
with the `--cpu` flag but is much slower.
Alternatively, reducing the height and width with the `--height` and `--width`
flag is likely to reduce memory usage significantly.

Binary file not shown.

After

Width:  |  Height:  |  Size: 36 KiB

View File

@ -4,20 +4,10 @@ extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
mod attention;
mod clip;
mod ddim;
mod embeddings;
mod resnet;
mod schedulers;
mod stable_diffusion;
mod unet_2d;
mod unet_2d_blocks;
mod utils;
mod vae;
use candle_transformers::models::stable_diffusion;
use anyhow::{Error as E, Result};
use candle::{DType, Device, IndexOp, Tensor, D};
use candle::{DType, Device, IndexOp, Module, Tensor, D};
use clap::Parser;
use tokenizers::Tokenizer;
@ -107,14 +97,13 @@ struct Args {
img2img_strength: f64,
}
#[derive(Debug, Clone, Copy, clap::ValueEnum)]
#[derive(Debug, Clone, Copy, clap::ValueEnum, PartialEq, Eq)]
enum StableDiffusionVersion {
V1_5,
V2_1,
Xl,
}
#[allow(unused)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ModelFile {
Tokenizer,
@ -214,7 +203,18 @@ impl ModelFile {
Self::Clip => (version.repo(), version.clip_file(use_f16)),
Self::Clip2 => (version.repo(), version.clip2_file(use_f16)),
Self::Unet => (version.repo(), version.unet_file(use_f16)),
Self::Vae => (version.repo(), version.vae_file(use_f16)),
Self::Vae => {
// Override for SDXL when using f16 weights.
// See https://github.com/huggingface/candle/issues/1060
if version == StableDiffusionVersion::Xl && use_f16 {
(
"madebyollin/sdxl-vae-fp16-fix",
"diffusion_pytorch_model.safetensors",
)
} else {
(version.repo(), version.vae_file(use_f16))
}
}
};
let filename = Api::new()?.model(repo.to_string()).get(path)?;
Ok(filename)
@ -494,9 +494,8 @@ fn run(args: Args) -> Result<()> {
num_samples
);
let image = vae.decode(&(&latents / 0.18215)?)?;
// TODO: Add the clamping between 0 and 1.
let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
let image = (image * 255.)?.to_dtype(DType::U8)?.i(0)?;
let image = (image.clamp(0f32, 1.)? * 255.)?.to_dtype(DType::U8)?.i(0)?;
let image_filename = output_filename(&final_image, idx + 1, num_samples, None);
candle_examples::save_image(&image, image_filename)?
}

View File

@ -1,39 +0,0 @@
use candle::{Device, Result, Tensor};
use candle_nn::Module;
pub fn linspace(start: f64, stop: f64, steps: usize) -> Result<Tensor> {
if steps < 1 {
candle::bail!("cannot use linspace with steps {steps} <= 1")
}
let delta = (stop - start) / (steps - 1) as f64;
let vs = (0..steps)
.map(|step| start + step as f64 * delta)
.collect::<Vec<_>>();
Tensor::from_vec(vs, steps, &Device::Cpu)
}
// Wrap the conv2d op to provide some tracing.
#[derive(Debug)]
pub struct Conv2d {
inner: candle_nn::Conv2d,
span: tracing::Span,
}
impl Conv2d {
pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
self.inner.forward(x)
}
}
pub fn conv2d(
in_channels: usize,
out_channels: usize,
kernel_size: usize,
cfg: candle_nn::Conv2dConfig,
vs: candle_nn::VarBuilder,
) -> Result<Conv2d> {
let span = tracing::span!(tracing::Level::TRACE, "conv2d");
let inner = candle_nn::conv2d(in_channels, out_channels, kernel_size, cfg, vs)?;
Ok(Conv2d { inner, span })
}

View File

@ -0,0 +1,25 @@
# candle-stable-lm
StableLM-3B-4E1T is a 3 billion parameter decoder-only language model
pre-trained on 1 trillion tokens of diverse English and code datasets for 4
epochs. See the [HuggingFace Hub Model
Card](https://huggingface.co/stabilityai/stablelm-3b-4e1t).
Note that this model is gated so you will have to request access on the Hub in
order to be able to use it.
## Running some example
```bash
$ cargo run --example stable-lm --release --features cuda -- --prompt 'What is the most efficient programming language in use?' --sample-len 150
avx: true, neon: false, simd128: false, f16c: true
temp: 0.00 repeat-penalty: 1.10 repeat-last-n: 64
retrieved the files in 126.593µs
loaded the model in 3.474148965s
What is the most efficient programming language in use?
The answer to this question depends on what you mean by "efficient". If you're talking about speed, then C++ and Java are probably your best bets. But if you're talking about ease of development, then Python is probably the way to go.
Python is a high-level, interpreted language that is easy to learn and use. It has a large community of developers who are always working on new features and improvements.
C++ is a low-level, compiled language that can be used for both desktop applications and web development. It's more difficult to learn than Python but offers greater control over the code.
Java is another high-level language that is popular with programmers because it runs on many different platforms (including Android phones
150 tokens generated (37.61 token/s)
```

View File

@ -0,0 +1,268 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;
use candle_transformers::models::quantized_stable_lm::Model as QStableLM;
use candle_transformers::models::stable_lm::{Config, Model as StableLM};
use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
enum Model {
StableLM(StableLM),
Quantized(QStableLM),
}
struct TextGeneration {
model: Model,
device: Device,
tokenizer: TokenOutputStream,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
}
impl TextGeneration {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
device: &Device,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
tokenizer: TokenOutputStream::new(tokenizer),
logits_processor,
repeat_penalty,
repeat_last_n,
device: device.clone(),
}
}
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
use std::io::Write;
self.tokenizer.clear();
let mut tokens = self
.tokenizer
.tokenizer()
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
for &t in tokens.iter() {
if let Some(t) = self.tokenizer.next_token(t)? {
print!("{t}")
}
}
std::io::stdout().flush()?;
let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
Some(token) => token,
None => anyhow::bail!("cannot find the <|endoftext|> token"),
};
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
let start_pos = tokens.len().saturating_sub(context_size);
let ctxt = &tokens[start_pos..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = match &mut self.model {
Model::StableLM(m) => m.forward(&input, start_pos)?,
Model::Quantized(m) => m.forward(&input, start_pos)?,
};
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
self.repeat_penalty,
&tokens[start_at..],
)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
}
let dt = start_gen.elapsed();
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
std::io::stdout().flush()?;
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
Ok(())
}
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
#[arg(long)]
use_flash_attn: bool,
#[arg(long)]
prompt: String,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 100)]
sample_len: usize,
#[arg(long, default_value = "lmz/candle-stablelm-3b-4e1t")]
model_id: String,
#[arg(long, default_value = "main")]
revision: String,
#[arg(long)]
tokenizer_file: Option<String>,
#[arg(long)]
weight_files: Option<String>,
#[arg(long)]
quantized: bool,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature.unwrap_or(0.),
args.repeat_penalty,
args.repeat_last_n
);
let start = std::time::Instant::now();
let api = Api::new()?;
let repo = api.repo(Repo::with_revision(
args.model_id,
RepoType::Model,
args.revision,
));
let tokenizer_filename = match args.tokenizer_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("tokenizer.json")?,
};
let filenames = match args.weight_files {
Some(files) => files
.split(',')
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => {
if args.quantized {
vec![repo.get("model-q4k.gguf")?]
} else {
vec![repo.get("model.safetensors")?]
}
}
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
let config = Config::stablelm_3b_4e1t(args.use_flash_attn);
let (model, device) = if args.quantized {
let filename = &filenames[0];
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename)?;
let model = QStableLM::new(&config, vb)?;
(Model::Quantized(model), Device::Cpu)
} else {
let device = candle_examples::device(args.cpu)?;
let dtype = if device.is_cuda() {
DType::BF16
} else {
DType::F32
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = StableLM::new(&config, vb)?;
(Model::StableLM(model), device)
};
println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(
model,
tokenizer,
args.seed,
args.temperature,
args.top_p,
args.repeat_penalty,
args.repeat_last_n,
&device,
);
pipeline.run(&args.prompt, args.sample_len)?;
Ok(())
}

View File

@ -0,0 +1,25 @@
# candle-t5
## Encoder-decoder example:
```bash
$ cargo run --example t5 --release -- --model-id "t5-small" --prompt "translate to German: A beautiful candle." --decode
...
Running on CPU, to run on GPU, build this example with `--features cuda`
Eine schöne Kerze.
9 tokens generated (2.42 token/s)
```
## Sentence embedding example:
```bash
$ cargo run --example t5 --release -- --model-id "t5-small" --prompt "A beautiful candle."
...
[[[ 0.0515, -0.0541, -0.0761, ..., -0.0392, 0.1511, -0.0265],
[-0.0974, 0.0998, -0.1659, ..., -0.2450, 0.1738, -0.0164],
[ 0.0624, -0.1024, 0.0430, ..., -0.1388, 0.0564, -0.2962],
[-0.0389, -0.1173, 0.0026, ..., 0.1064, -0.1065, 0.0990],
[ 0.1300, 0.0027, -0.0326, ..., 0.0026, -0.0317, 0.0851]]]
Tensor[[1, 5, 512], f32]
Took 303.766583ms
```

View File

@ -0,0 +1,300 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use std::io::Write;
use std::path::PathBuf;
use candle_transformers::models::t5;
use anyhow::{Error as E, Result};
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use clap::Parser;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
const DTYPE: DType = DType::F32;
#[derive(Parser, Debug, Clone)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
/// The model repository to use on the HuggingFace hub.
#[arg(long)]
model_id: Option<String>,
#[arg(long)]
revision: Option<String>,
/// Enable decoding.
#[arg(long)]
decode: bool,
// Enable/disable decoding.
#[arg(long, default_value = "false")]
disable_cache: bool,
/// Use this prompt, otherwise compute sentence similarities.
#[arg(long)]
prompt: Option<String>,
/// If set along with --decode, will use this prompt to initialize the decoder.
#[arg(long)]
decoder_prompt: Option<String>,
/// L2 normalization for embeddings.
#[arg(long, default_value = "true")]
normalize_embeddings: bool,
/// The temperature used to generate samples.
#[arg(long, default_value_t = 0.8)]
temperature: f64,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
}
struct T5ModelBuilder {
device: Device,
config: t5::Config,
weights_filename: Vec<PathBuf>,
}
impl T5ModelBuilder {
pub fn load(args: &Args) -> Result<(Self, Tokenizer)> {
let device = candle_examples::device(args.cpu)?;
let default_model = "t5-small".to_string();
let default_revision = "refs/pr/15".to_string();
let (model_id, revision) = match (args.model_id.to_owned(), args.revision.to_owned()) {
(Some(model_id), Some(revision)) => (model_id, revision),
(Some(model_id), None) => (model_id, "main".to_string()),
(None, Some(revision)) => (default_model, revision),
(None, None) => (default_model, default_revision),
};
let repo = Repo::with_revision(model_id.clone(), RepoType::Model, revision);
let api = Api::new()?;
let api = api.repo(repo);
let config_filename = api.get("config.json")?;
let tokenizer_filename = api.get("tokenizer.json")?;
let weights_filename = if model_id == "google/flan-t5-xxl" {
vec![
api.get("model-00001-of-00005.safetensors")?,
api.get("model-00002-of-00005.safetensors")?,
api.get("model-00003-of-00005.safetensors")?,
api.get("model-00004-of-00005.safetensors")?,
api.get("model-00005-of-00005.safetensors")?,
]
} else {
vec![api.get("model.safetensors")?]
};
let config = std::fs::read_to_string(config_filename)?;
let mut config: t5::Config = serde_json::from_str(&config)?;
config.use_cache = !args.disable_cache;
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
Ok((
Self {
device,
config,
weights_filename,
},
tokenizer,
))
}
pub fn build_encoder(&self) -> Result<t5::T5EncoderModel> {
let vb = unsafe {
VarBuilder::from_mmaped_safetensors(&self.weights_filename, DTYPE, &self.device)?
};
Ok(t5::T5EncoderModel::load(vb, &self.config)?)
}
pub fn build_conditional_generation(&self) -> Result<t5::T5ForConditionalGeneration> {
let vb = unsafe {
VarBuilder::from_mmaped_safetensors(&self.weights_filename, DTYPE, &self.device)?
};
Ok(t5::T5ForConditionalGeneration::load(vb, &self.config)?)
}
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
println!("tracing...");
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
let (builder, mut tokenizer) = T5ModelBuilder::load(&args)?;
let device = &builder.device;
let tokenizer = tokenizer
.with_padding(None)
.with_truncation(None)
.map_err(E::msg)?;
match args.prompt {
Some(prompt) => {
let tokens = tokenizer
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let input_token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
if !args.decode {
let mut model = builder.build_encoder()?;
let start = std::time::Instant::now();
let ys = model.forward(&input_token_ids)?;
println!("{ys}");
println!("Took {:?}", start.elapsed());
} else {
let mut model = builder.build_conditional_generation()?;
let mut output_token_ids = [builder.config.pad_token_id as u32].to_vec();
if let Some(decoder_prompt) = &args.decoder_prompt {
print!("{decoder_prompt}");
output_token_ids.extend(
tokenizer
.encode(decoder_prompt.to_string(), false)
.map_err(E::msg)?
.get_ids()
.to_vec(),
);
}
let temperature = if args.temperature <= 0. {
None
} else {
Some(args.temperature)
};
let mut logits_processor = LogitsProcessor::new(299792458, temperature, args.top_p);
let encoder_output = model.encode(&input_token_ids)?;
let start = std::time::Instant::now();
for index in 0.. {
if output_token_ids.len() > 512 {
break;
}
let decoder_token_ids = if index == 0 || !builder.config.use_cache {
Tensor::new(output_token_ids.as_slice(), device)?.unsqueeze(0)?
} else {
let last_token = *output_token_ids.last().unwrap();
Tensor::new(&[last_token], device)?.unsqueeze(0)?
};
let logits = model
.decode(&decoder_token_ids, &encoder_output)?
.squeeze(0)?;
let logits = if args.repeat_penalty == 1. {
logits
} else {
let start_at = output_token_ids.len().saturating_sub(args.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
args.repeat_penalty,
&output_token_ids[start_at..],
)?
};
let next_token_id = logits_processor.sample(&logits)?;
if next_token_id as usize == builder.config.eos_token_id {
break;
}
output_token_ids.push(next_token_id);
if let Some(text) = tokenizer.id_to_token(next_token_id) {
let text = text.replace('▁', " ").replace("<0x0A>", "\n");
print!("{text}");
std::io::stdout().flush()?;
}
}
let dt = start.elapsed();
println!(
"\n{} tokens generated ({:.2} token/s)\n",
output_token_ids.len(),
output_token_ids.len() as f64 / dt.as_secs_f64(),
);
}
}
None => {
let mut model = builder.build_encoder()?;
let sentences = [
"The cat sits outside",
"A man is playing guitar",
"I love pasta",
"The new movie is awesome",
"The cat plays in the garden",
"A woman watches TV",
"The new movie is so great",
"Do you like pizza?",
];
let n_sentences = sentences.len();
let mut all_embeddings = Vec::with_capacity(n_sentences);
for sentence in sentences {
let tokens = tokenizer
.encode(sentence, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let token_ids = Tensor::new(&tokens[..], model.device())?.unsqueeze(0)?;
let embeddings = model.forward(&token_ids)?;
println!("generated embeddings {:?}", embeddings.shape());
// Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
let embeddings = if args.normalize_embeddings {
normalize_l2(&embeddings)?
} else {
embeddings
};
println!("pooled embeddings {:?}", embeddings.shape());
all_embeddings.push(embeddings)
}
let mut similarities = vec![];
for (i, e_i) in all_embeddings.iter().enumerate() {
for (j, e_j) in all_embeddings
.iter()
.enumerate()
.take(n_sentences)
.skip(i + 1)
{
let sum_ij = (e_i * e_j)?.sum_all()?.to_scalar::<f32>()?;
let sum_i2 = (e_i * e_i)?.sum_all()?.to_scalar::<f32>()?;
let sum_j2 = (e_j * e_j)?.sum_all()?.to_scalar::<f32>()?;
let cosine_similarity = sum_ij / (sum_i2 * sum_j2).sqrt();
similarities.push((cosine_similarity, i, j))
}
}
similarities.sort_by(|u, v| v.0.total_cmp(&u.0));
for &(score, i, j) in similarities[..5].iter() {
println!("score: {score:.2} '{}' '{}'", sentences[i], sentences[j])
}
}
}
Ok(())
}
pub fn normalize_l2(v: &Tensor) -> Result<Tensor> {
Ok(v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)?)
}

View File

@ -0,0 +1,39 @@
# candle-whisper: speech recognition
An implementation of [OpenAI Whisper](https://github.com/openai/whisper) using
candle. Whisper is a general purpose speech recognition model, it can be used to
convert audio files (in the `.wav` format) to text. Supported features include
language detection as well as multilingual speech recognition.
## Running some example
If no audio file is passed as input, a [sample
file](https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/samples_jfk.wav) is automatically downloaded
from the hub.
```bash
cargo run --example whisper --release
> No audio file submitted: Downloading https://huggingface.co/datasets/Narsil/candle_demo/blob/main/samples_jfk.wav
> loaded wav data: Header { audio_format: 1, channel_count: 1, sampling_rate: 16000, bytes_per_second: 32000, bytes_per_sample: 2, bits_per_sample: 16 }
> pcm data loaded 176000
> loaded mel: [1, 80, 3000]
> 0.0s -- 30.0s: And so my fellow Americans ask not what your country can do for you ask what you can do for your country
```
In order to use the multilingual mode, specify a multilingual model via the
`--model` flag, see the details below.
## Command line flags
- `--input`: the audio file to be converted to text, in wav format.
- `--language`: force the language to some specific value rather than being
detected, e.g. `en`.
- `--task`: the task to be performed, can be `transcribe` (return the text data
in the original language) or `translate` (translate the text to English).
- `--timestamps`: enable the timestamp mode where some timestamps are reported
for each recognized audio extracts.
- `--model`: the model to be used. Models that do not end with `-en` are
multilingual models, other ones are English only models. The supported models
are `tiny`, `tiny.en`, `base`, `base.en`, `small`, `small.en`, `medium`,
`medium.en`, `large`, and `large-v2`.

View File

@ -10,41 +10,56 @@ extern crate accelerate_src;
extern crate intel_mkl_src;
use anyhow::{Error as E, Result};
use candle::{DType, Device, IndexOp, Tensor};
use candle::{Device, IndexOp, Tensor};
use candle_nn::{ops::softmax, VarBuilder};
use clap::{Parser, ValueEnum};
use hf_hub::{api::sync::Api, Repo, RepoType};
use rand::{distributions::Distribution, SeedableRng};
use tokenizers::Tokenizer;
mod audio;
mod model;
use model::{Config, Whisper};
mod multilingual;
use candle_transformers::models::whisper::{self as m, audio, Config};
const DTYPE: DType = DType::F32;
pub enum Model {
Normal(m::model::Whisper),
Quantized(m::quantized_model::Whisper),
}
// Audio parameters.
const SAMPLE_RATE: usize = 16000;
const N_FFT: usize = 400;
const N_MELS: usize = 80;
const HOP_LENGTH: usize = 160;
const CHUNK_LENGTH: usize = 30;
const N_SAMPLES: usize = CHUNK_LENGTH * SAMPLE_RATE; // 480000 samples in a 30-second chunk
const N_FRAMES: usize = N_SAMPLES / HOP_LENGTH; // 3000 frames in a mel spectrogram input
// Maybe we should use some traits rather than doing the dispatch for all these.
impl Model {
pub fn config(&self) -> &Config {
match self {
Self::Normal(m) => &m.config,
Self::Quantized(m) => &m.config,
}
}
const NO_SPEECH_THRESHOLD: f64 = 0.6;
const LOGPROB_THRESHOLD: f64 = -1.0;
const TEMPERATURES: [f64; 6] = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0];
const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4;
pub fn encoder_forward(&mut self, x: &Tensor, flush: bool) -> candle::Result<Tensor> {
match self {
Self::Normal(m) => m.encoder.forward(x, flush),
Self::Quantized(m) => m.encoder.forward(x, flush),
}
}
// Tokenizer dependent bits.
const SOT_TOKEN: &str = "<|startoftranscript|>";
const TRANSCRIBE_TOKEN: &str = "<|transcribe|>";
const TRANSLATE_TOKEN: &str = "<|translate|>";
const NO_TIMESTAMPS_TOKEN: &str = "<|notimestamps|>";
const EOT_TOKEN: &str = "<|endoftext|>";
const NO_SPEECH_TOKEN: &str = "<|nocaptions|>";
pub fn decoder_forward(
&mut self,
x: &Tensor,
xa: &Tensor,
flush: bool,
) -> candle::Result<Tensor> {
match self {
Self::Normal(m) => m.decoder.forward(x, xa, flush),
Self::Quantized(m) => m.decoder.forward(x, xa, flush),
}
}
pub fn decoder_final_linear(&self, x: &Tensor) -> candle::Result<Tensor> {
match self {
Self::Normal(m) => m.decoder.final_linear(x),
Self::Quantized(m) => m.decoder.final_linear(x),
}
}
}
#[allow(dead_code)]
#[derive(Debug, Clone)]
@ -66,7 +81,7 @@ struct Segment {
}
struct Decoder {
model: Whisper,
model: Model,
rng: rand::rngs::StdRng,
task: Option<Task>,
timestamps: bool,
@ -85,7 +100,7 @@ struct Decoder {
impl Decoder {
#[allow(clippy::too_many_arguments)]
fn new(
model: Whisper,
model: Model,
tokenizer: Tokenizer,
seed: u64,
device: &Device,
@ -94,12 +109,12 @@ impl Decoder {
timestamps: bool,
verbose: bool,
) -> Result<Self> {
let no_timestamps_token = token_id(&tokenizer, NO_TIMESTAMPS_TOKEN)?;
let no_timestamps_token = token_id(&tokenizer, m::NO_TIMESTAMPS_TOKEN)?;
// Suppress the notimestamps token when in timestamps mode.
// https://github.com/openai/whisper/blob/e8622f9afc4eba139bf796c210f5c01081000472/whisper/decoding.py#L452
let suppress_tokens: Vec<f32> = (0..model.config.vocab_size as u32)
let suppress_tokens: Vec<f32> = (0..model.config().vocab_size as u32)
.map(|i| {
if model.config.suppress_tokens.contains(&i)
if model.config().suppress_tokens.contains(&i)
|| timestamps && i == no_timestamps_token
{
f32::NEG_INFINITY
@ -109,11 +124,11 @@ impl Decoder {
})
.collect();
let suppress_tokens = Tensor::new(suppress_tokens.as_slice(), device)?;
let sot_token = token_id(&tokenizer, SOT_TOKEN)?;
let transcribe_token = token_id(&tokenizer, TRANSCRIBE_TOKEN)?;
let translate_token = token_id(&tokenizer, TRANSLATE_TOKEN)?;
let eot_token = token_id(&tokenizer, EOT_TOKEN)?;
let no_speech_token = token_id(&tokenizer, NO_SPEECH_TOKEN)?;
let sot_token = token_id(&tokenizer, m::SOT_TOKEN)?;
let transcribe_token = token_id(&tokenizer, m::TRANSCRIBE_TOKEN)?;
let translate_token = token_id(&tokenizer, m::TRANSLATE_TOKEN)?;
let eot_token = token_id(&tokenizer, m::EOT_TOKEN)?;
let no_speech_token = token_id(&tokenizer, m::NO_SPEECH_TOKEN)?;
Ok(Self {
model,
rng: rand::rngs::StdRng::seed_from_u64(seed),
@ -134,11 +149,11 @@ impl Decoder {
fn decode(&mut self, mel: &Tensor, t: f64) -> Result<DecodingResult> {
let model = &mut self.model;
let audio_features = model.encoder.forward(mel, true)?;
let audio_features = model.encoder_forward(mel, true)?;
if self.verbose {
println!("audio features: {:?}", audio_features.dims());
}
let sample_len = model.config.max_target_positions / 2;
let sample_len = model.config().max_target_positions / 2;
let mut sum_logprob = 0f64;
let mut no_speech_prob = f64::NAN;
let mut tokens = vec![self.sot_token];
@ -158,12 +173,12 @@ impl Decoder {
// The model expects a batch dim but this inference loop does not handle
// it so we add it at this point.
let tokens_t = tokens_t.unsqueeze(0)?;
let ys = model.decoder.forward(&tokens_t, &audio_features, i == 0)?;
let ys = model.decoder_forward(&tokens_t, &audio_features, i == 0)?;
// Extract the no speech probability on the first iteration by looking at the first
// token logits and the probability for the according token.
if i == 0 {
let logits = model.decoder.final_linear(&ys.i(..1)?)?.i(0)?.i(0)?;
let logits = model.decoder_final_linear(&ys.i(..1)?)?.i(0)?.i(0)?;
no_speech_prob = softmax(&logits, 0)?
.i(self.no_speech_token as usize)?
.to_scalar::<f32>()? as f64;
@ -171,8 +186,7 @@ impl Decoder {
let (_, seq_len, _) = ys.dims3()?;
let logits = model
.decoder
.final_linear(&ys.i((..1, seq_len - 1..))?)?
.decoder_final_linear(&ys.i((..1, seq_len - 1..))?)?
.i(0)?
.i(0)?;
// TODO: Besides suppress tokens, we should apply the heuristics from
@ -201,7 +215,7 @@ impl Decoder {
let prob = softmax(&logits, candle::D::Minus1)?
.i(next_token as usize)?
.to_scalar::<f32>()? as f64;
if next_token == self.eot_token || tokens.len() > model.config.max_target_positions {
if next_token == self.eot_token || tokens.len() > model.config().max_target_positions {
break;
}
sum_logprob += prob.ln();
@ -220,17 +234,17 @@ impl Decoder {
}
fn decode_with_fallback(&mut self, segment: &Tensor) -> Result<DecodingResult> {
for (i, &t) in TEMPERATURES.iter().enumerate() {
for (i, &t) in m::TEMPERATURES.iter().enumerate() {
let dr: Result<DecodingResult> = self.decode(segment, t);
if i == TEMPERATURES.len() - 1 {
if i == m::TEMPERATURES.len() - 1 {
return dr;
}
// On errors, we try again with a different temperature.
match dr {
Ok(dr) => {
let needs_fallback = dr.compression_ratio > COMPRESSION_RATIO_THRESHOLD
|| dr.avg_logprob < LOGPROB_THRESHOLD;
if !needs_fallback || dr.no_speech_prob > NO_SPEECH_THRESHOLD {
let needs_fallback = dr.compression_ratio > m::COMPRESSION_RATIO_THRESHOLD
|| dr.avg_logprob < m::LOGPROB_THRESHOLD;
if !needs_fallback || dr.no_speech_prob > m::NO_SPEECH_THRESHOLD {
return Ok(dr);
}
}
@ -248,13 +262,13 @@ impl Decoder {
let mut segments = vec![];
while seek < content_frames {
let start = std::time::Instant::now();
let time_offset = (seek * HOP_LENGTH) as f64 / SAMPLE_RATE as f64;
let segment_size = usize::min(content_frames - seek, N_FRAMES);
let time_offset = (seek * m::HOP_LENGTH) as f64 / m::SAMPLE_RATE as f64;
let segment_size = usize::min(content_frames - seek, m::N_FRAMES);
let mel_segment = mel.narrow(2, seek, segment_size)?;
let segment_duration = (segment_size * HOP_LENGTH) as f64 / SAMPLE_RATE as f64;
let segment_duration = (segment_size * m::HOP_LENGTH) as f64 / m::SAMPLE_RATE as f64;
let dr = self.decode_with_fallback(&mel_segment)?;
seek += segment_size;
if dr.no_speech_prob > NO_SPEECH_THRESHOLD && dr.avg_logprob < LOGPROB_THRESHOLD {
if dr.no_speech_prob > m::NO_SPEECH_THRESHOLD && dr.avg_logprob < m::LOGPROB_THRESHOLD {
println!("no speech detected, skipping {seek} {dr:?}");
continue;
}
@ -358,6 +372,7 @@ impl WhichModel {
Self::TinyEn | Self::BaseEn | Self::SmallEn | Self::MediumEn => false,
}
}
fn model_and_revision(&self) -> (&'static str, &'static str) {
match self {
Self::Tiny => ("openai/whisper-tiny", "main"),
@ -407,6 +422,9 @@ struct Args {
#[arg(long)]
tracing: bool,
#[arg(long)]
quantized: bool,
/// Language.
#[arg(long)]
language: Option<String>,
@ -438,10 +456,13 @@ fn main() -> Result<()> {
None
};
let device = candle_examples::device(args.cpu)?;
let (default_model, default_revision) = args.model.model_and_revision();
let (default_model, default_revision) = if args.quantized {
("lmz/candle-whisper", "main")
} else {
args.model.model_and_revision()
};
let default_model = default_model.to_string();
let default_revision = default_revision.to_string();
let path = std::path::PathBuf::from(default_model.clone());
let (model_id, revision) = match (args.model_id, args.revision) {
(Some(model_id), Some(revision)) => (model_id, revision),
(Some(model_id), None) => (model_id, "main".to_string()),
@ -449,20 +470,7 @@ fn main() -> Result<()> {
(None, None) => (default_model, default_revision),
};
let (config_filename, tokenizer_filename, weights_filename, input) = if path.exists() {
let mut config_filename = path.clone();
config_filename.push("config.json");
let mut tokenizer_filename = path.clone();
tokenizer_filename.push("tokenizer.json");
let mut model_filename = path;
model_filename.push("model.safetensors");
(
config_filename,
tokenizer_filename,
model_filename,
std::path::PathBuf::from(args.input.expect("You didn't specify a file to read from yet, are using a local model, please add `--input example.wav` to read some audio file")),
)
} else {
let (config_filename, tokenizer_filename, weights_filename, input) = {
let api = Api::new()?;
let dataset = api.dataset("Narsil/candle-examples".to_string());
let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
@ -476,12 +484,25 @@ fn main() -> Result<()> {
println!("No audio file submitted: Downloading https://huggingface.co/datasets/Narsil/candle_demo/blob/main/samples_jfk.wav");
dataset.get("samples_jfk.wav")?
};
(
repo.get("config.json")?,
repo.get("tokenizer.json")?,
repo.get("model.safetensors")?,
sample,
)
let (config, tokenizer, model) = if args.quantized {
let ext = match args.model {
WhichModel::TinyEn => "tiny-en",
WhichModel::Tiny => "tiny",
_ => unimplemented!("no quantized support for {:?}", args.model),
};
(
repo.get(&format!("config-{ext}.json"))?,
repo.get(&format!("tokenizer-{ext}.json"))?,
repo.get(&format!("model-{ext}-q80.gguf"))?,
)
} else {
(
repo.get("config.json")?,
repo.get("tokenizer.json")?,
repo.get("model.safetensors")?,
)
};
(config, tokenizer, model, sample)
};
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
@ -492,8 +513,8 @@ fn main() -> Result<()> {
let mut input = std::fs::File::open(input)?;
let (header, data) = wav::read(&mut input)?;
println!("loaded wav data: {header:?}");
if header.sampling_rate != SAMPLE_RATE as u32 {
anyhow::bail!("wav file must have a {} sampling rate", SAMPLE_RATE)
if header.sampling_rate != m::SAMPLE_RATE as u32 {
anyhow::bail!("wav file must have a {} sampling rate", m::SAMPLE_RATE)
}
let data = data.as_sixteen().expect("expected 16 bit wav file");
let pcm_data: Vec<_> = data[..data.len() / header.channel_count as usize]
@ -501,16 +522,21 @@ fn main() -> Result<()> {
.map(|v| *v as f32 / 32768.)
.collect();
println!("pcm data loaded {}", pcm_data.len());
let mel = audio::pcm_to_mel(&pcm_data, &mel_filters)?;
let mel = audio::pcm_to_mel(&pcm_data, &mel_filters);
let mel_len = mel.len();
let mel = Tensor::from_vec(mel, (1, N_MELS, mel_len / N_MELS), &device)?;
let mel = Tensor::from_vec(mel, (1, m::N_MELS, mel_len / m::N_MELS), &device)?;
println!("loaded mel: {:?}", mel.dims());
let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
let weights = weights.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
let mut model = Whisper::load(&vb, config)?;
let mut model = if args.quantized {
let vb =
candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&weights_filename)?;
Model::Quantized(m::quantized_model::Whisper::load(&vb, config)?)
} else {
let vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], m::DTYPE, &device)? };
Model::Normal(m::model::Whisper::load(&vb, config)?)
};
let language_token = match (args.model.is_multilingual(), args.language) {
(true, None) => Some(multilingual::detect_language(&mut model, &tokenizer, &mel)?),

View File

@ -1,4 +1,3 @@
use crate::Whisper;
use candle::{IndexOp, Result, Tensor, D};
use tokenizers::Tokenizer;
@ -105,20 +104,28 @@ const LANGUAGES: [(&str, &str); 99] = [
];
/// Returns the token id for the selected language.
pub fn detect_language(model: &mut Whisper, tokenizer: &Tokenizer, mel: &Tensor) -> Result<u32> {
pub fn detect_language(
model: &mut super::Model,
tokenizer: &Tokenizer,
mel: &Tensor,
) -> Result<u32> {
let (_bsize, _, seq_len) = mel.dims3()?;
let mel = mel.narrow(2, 0, usize::min(seq_len, model.config.max_source_positions))?;
let mel = mel.narrow(
2,
0,
usize::min(seq_len, model.config().max_source_positions),
)?;
let device = mel.device();
let language_token_ids = LANGUAGES
.iter()
.map(|(t, _)| crate::token_id(tokenizer, &format!("<|{t}|>")))
.collect::<Result<Vec<_>>>()?;
let sot_token = crate::token_id(tokenizer, crate::SOT_TOKEN)?;
let audio_features = model.encoder.forward(&mel, true)?;
let sot_token = crate::token_id(tokenizer, crate::m::SOT_TOKEN)?;
let audio_features = model.encoder_forward(&mel, true)?;
let tokens = Tensor::new(&[[sot_token]], device)?;
let language_token_ids = Tensor::new(language_token_ids.as_slice(), device)?;
let ys = model.decoder.forward(&tokens, &audio_features, true)?;
let logits = model.decoder.final_linear(&ys.i(..1)?)?.i(0)?.i(0)?;
let ys = model.decoder_forward(&tokens, &audio_features, true)?;
let logits = model.decoder_final_linear(&ys.i(..1)?)?.i(0)?.i(0)?;
let logits = logits.index_select(&language_token_ids, 0)?;
let probs = candle_nn::ops::softmax(&logits, D::Minus1)?;
let probs = probs.to_vec1::<f32>()?;

View File

@ -0,0 +1,27 @@
# candle-wuerstchen: Efficient Pretraining of Text-to-Image Models
![anthropomorphic cat dressed as a fire fighter](./assets/cat.jpg)
The `wuerstchen` example is a port of the [diffusers
implementation](https://github.com/huggingface/diffusers/tree/19edca82f1ff194c07317369a92b470dbae97f34/src/diffusers/pipelines/wuerstchen) for Würstchen v2.
The candle implementation reproduces the same structure/files for models and
pipelines. Useful resources:
- [Official implementation](https://github.com/dome272/Wuerstchen).
- [Arxiv paper](https://arxiv.org/abs/2306.00637).
- Blog post: [Introducing Würstchen: Fast Diffusion for Image Generation](https://huggingface.co/blog/wuerstchen).
## Getting the weights
The weights are automatically downloaded for you from the [HuggingFace
Hub](https://huggingface.co/) on the first run. There are various command line
flags to use local files instead, run with `--help` to learn about them.
## Running some example.
```bash
cargo run --example wuerstchen --release --features cuda,cudnn -- \
--prompt "Anthropomorphic cat dressed as a fire fighter"
```
The final image is named `sd_final.png` by default.

Binary file not shown.

After

Width:  |  Height:  |  Size: 38 KiB

View File

@ -0,0 +1,385 @@
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
use candle_transformers::models::stable_diffusion;
use candle_transformers::models::wuerstchen;
use anyhow::{Error as E, Result};
use candle::{DType, Device, IndexOp, Tensor};
use clap::Parser;
use tokenizers::Tokenizer;
const PRIOR_GUIDANCE_SCALE: f64 = 4.0;
const RESOLUTION_MULTIPLE: f64 = 42.67;
const LATENT_DIM_SCALE: f64 = 10.67;
const PRIOR_CIN: usize = 16;
const DECODER_CIN: usize = 4;
#[derive(Parser)]
#[command(author, version, about, long_about = None)]
struct Args {
/// The prompt to be used for image generation.
#[arg(
long,
default_value = "A very realistic photo of a rusty robot walking on a sandy beach"
)]
prompt: String,
#[arg(long, default_value = "")]
uncond_prompt: String,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
#[arg(long)]
use_flash_attn: bool,
/// The height in pixels of the generated image.
#[arg(long)]
height: Option<usize>,
/// The width in pixels of the generated image.
#[arg(long)]
width: Option<usize>,
/// The decoder weight file, in .safetensors format.
#[arg(long, value_name = "FILE")]
decoder_weights: Option<String>,
/// The CLIP weight file, in .safetensors format.
#[arg(long, value_name = "FILE")]
clip_weights: Option<String>,
/// The CLIP weight file used by the prior model, in .safetensors format.
#[arg(long, value_name = "FILE")]
prior_clip_weights: Option<String>,
/// The prior weight file, in .safetensors format.
#[arg(long, value_name = "FILE")]
prior_weights: Option<String>,
/// The VQGAN weight file, in .safetensors format.
#[arg(long, value_name = "FILE")]
vqgan_weights: Option<String>,
#[arg(long, value_name = "FILE")]
/// The file specifying the tokenizer to used for tokenization.
tokenizer: Option<String>,
#[arg(long, value_name = "FILE")]
/// The file specifying the tokenizer to used for prior tokenization.
prior_tokenizer: Option<String>,
/// The number of samples to generate.
#[arg(long, default_value_t = 1)]
num_samples: i64,
/// The name of the final image to generate.
#[arg(long, value_name = "FILE", default_value = "sd_final.png")]
final_image: String,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ModelFile {
Tokenizer,
PriorTokenizer,
Clip,
PriorClip,
Decoder,
VqGan,
Prior,
}
impl ModelFile {
fn get(&self, filename: Option<String>) -> Result<std::path::PathBuf> {
use hf_hub::api::sync::Api;
match filename {
Some(filename) => Ok(std::path::PathBuf::from(filename)),
None => {
let repo_main = "warp-ai/wuerstchen";
let repo_prior = "warp-ai/wuerstchen-prior";
let (repo, path) = match self {
Self::Tokenizer => (repo_main, "tokenizer/tokenizer.json"),
Self::PriorTokenizer => (repo_prior, "tokenizer/tokenizer.json"),
Self::Clip => (repo_main, "text_encoder/model.safetensors"),
Self::PriorClip => (repo_prior, "text_encoder/model.safetensors"),
Self::Decoder => (repo_main, "decoder/diffusion_pytorch_model.safetensors"),
Self::VqGan => (repo_main, "vqgan/diffusion_pytorch_model.safetensors"),
Self::Prior => (repo_prior, "prior/diffusion_pytorch_model.safetensors"),
};
let filename = Api::new()?.model(repo.to_string()).get(path)?;
Ok(filename)
}
}
}
}
fn output_filename(
basename: &str,
sample_idx: i64,
num_samples: i64,
timestep_idx: Option<usize>,
) -> String {
let filename = if num_samples > 1 {
match basename.rsplit_once('.') {
None => format!("{basename}.{sample_idx}.png"),
Some((filename_no_extension, extension)) => {
format!("{filename_no_extension}.{sample_idx}.{extension}")
}
}
} else {
basename.to_string()
};
match timestep_idx {
None => filename,
Some(timestep_idx) => match filename.rsplit_once('.') {
None => format!("{filename}-{timestep_idx}.png"),
Some((filename_no_extension, extension)) => {
format!("{filename_no_extension}-{timestep_idx}.{extension}")
}
},
}
}
fn encode_prompt(
prompt: &str,
uncond_prompt: Option<&str>,
tokenizer: std::path::PathBuf,
clip_weights: std::path::PathBuf,
clip_config: stable_diffusion::clip::Config,
device: &Device,
) -> Result<Tensor> {
let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
let pad_id = match &clip_config.pad_with {
Some(padding) => *tokenizer.get_vocab(true).get(padding.as_str()).unwrap(),
None => *tokenizer.get_vocab(true).get("<|endoftext|>").unwrap(),
};
println!("Running with prompt \"{prompt}\".");
let mut tokens = tokenizer
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let tokens_len = tokens.len();
while tokens.len() < clip_config.max_position_embeddings {
tokens.push(pad_id)
}
let tokens = Tensor::new(tokens.as_slice(), device)?.unsqueeze(0)?;
println!("Building the clip transformer.");
let text_model =
stable_diffusion::build_clip_transformer(&clip_config, clip_weights, device, DType::F32)?;
let text_embeddings = text_model.forward_with_mask(&tokens, tokens_len - 1)?;
match uncond_prompt {
None => Ok(text_embeddings),
Some(uncond_prompt) => {
let mut uncond_tokens = tokenizer
.encode(uncond_prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let uncond_tokens_len = uncond_tokens.len();
while uncond_tokens.len() < clip_config.max_position_embeddings {
uncond_tokens.push(pad_id)
}
let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), device)?.unsqueeze(0)?;
let uncond_embeddings =
text_model.forward_with_mask(&uncond_tokens, uncond_tokens_len - 1)?;
let text_embeddings = Tensor::cat(&[text_embeddings, uncond_embeddings], 0)?;
Ok(text_embeddings)
}
}
}
fn run(args: Args) -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let Args {
prompt,
uncond_prompt,
cpu,
height,
width,
tokenizer,
final_image,
num_samples,
clip_weights,
prior_weights,
vqgan_weights,
decoder_weights,
tracing,
..
} = args;
let _guard = if tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
let device = candle_examples::device(cpu)?;
let height = height.unwrap_or(1024);
let width = width.unwrap_or(1024);
let prior_text_embeddings = {
let tokenizer = ModelFile::PriorTokenizer.get(args.prior_tokenizer)?;
let weights = ModelFile::PriorClip.get(args.prior_clip_weights)?;
encode_prompt(
&prompt,
Some(&uncond_prompt),
tokenizer.clone(),
weights,
stable_diffusion::clip::Config::wuerstchen_prior(),
&device,
)?
};
println!("generated prior text embeddings {prior_text_embeddings:?}");
let text_embeddings = {
let tokenizer = ModelFile::Tokenizer.get(tokenizer)?;
let weights = ModelFile::Clip.get(clip_weights)?;
encode_prompt(
&prompt,
None,
tokenizer.clone(),
weights,
stable_diffusion::clip::Config::wuerstchen(),
&device,
)?
};
println!("generated text embeddings {text_embeddings:?}");
println!("Building the prior.");
let b_size = 1;
let image_embeddings = {
// https://huggingface.co/warp-ai/wuerstchen-prior/blob/main/prior/config.json
let latent_height = (height as f64 / RESOLUTION_MULTIPLE).ceil() as usize;
let latent_width = (width as f64 / RESOLUTION_MULTIPLE).ceil() as usize;
let mut latents = Tensor::randn(
0f32,
1f32,
(b_size, PRIOR_CIN, latent_height, latent_width),
&device,
)?;
let prior = {
let file = ModelFile::Prior.get(prior_weights)?;
let vb = unsafe {
candle_nn::VarBuilder::from_mmaped_safetensors(&[file], DType::F32, &device)?
};
wuerstchen::prior::WPrior::new(
/* c_in */ PRIOR_CIN,
/* c */ 1536,
/* c_cond */ 1280,
/* c_r */ 64,
/* depth */ 32,
/* nhead */ 24,
args.use_flash_attn,
vb,
)?
};
let prior_scheduler = wuerstchen::ddpm::DDPMWScheduler::new(60, Default::default())?;
let timesteps = prior_scheduler.timesteps();
let timesteps = &timesteps[..timesteps.len() - 1];
println!("prior denoising");
for (index, &t) in timesteps.iter().enumerate() {
let start_time = std::time::Instant::now();
let latent_model_input = Tensor::cat(&[&latents, &latents], 0)?;
let ratio = (Tensor::ones(2, DType::F32, &device)? * t)?;
let noise_pred = prior.forward(&latent_model_input, &ratio, &prior_text_embeddings)?;
let noise_pred = noise_pred.chunk(2, 0)?;
let (noise_pred_text, noise_pred_uncond) = (&noise_pred[0], &noise_pred[1]);
let noise_pred = (noise_pred_uncond
+ ((noise_pred_text - noise_pred_uncond)? * PRIOR_GUIDANCE_SCALE)?)?;
latents = prior_scheduler.step(&noise_pred, t, &latents)?;
let dt = start_time.elapsed().as_secs_f32();
println!("step {}/{} done, {:.2}s", index + 1, timesteps.len(), dt);
}
((latents * 42.)? - 1.)?
};
println!("Building the vqgan.");
let vqgan = {
let file = ModelFile::VqGan.get(vqgan_weights)?;
let vb = unsafe {
candle_nn::VarBuilder::from_mmaped_safetensors(&[file], DType::F32, &device)?
};
wuerstchen::paella_vq::PaellaVQ::new(vb)?
};
println!("Building the decoder.");
// https://huggingface.co/warp-ai/wuerstchen/blob/main/decoder/config.json
let decoder = {
let file = ModelFile::Decoder.get(decoder_weights)?;
let vb = unsafe {
candle_nn::VarBuilder::from_mmaped_safetensors(&[file], DType::F32, &device)?
};
wuerstchen::diffnext::WDiffNeXt::new(
/* c_in */ DECODER_CIN,
/* c_out */ DECODER_CIN,
/* c_r */ 64,
/* c_cond */ 1024,
/* clip_embd */ 1024,
/* patch_size */ 2,
args.use_flash_attn,
vb,
)?
};
for idx in 0..num_samples {
// https://huggingface.co/warp-ai/wuerstchen/blob/main/model_index.json
let latent_height = (image_embeddings.dim(2)? as f64 * LATENT_DIM_SCALE) as usize;
let latent_width = (image_embeddings.dim(3)? as f64 * LATENT_DIM_SCALE) as usize;
let mut latents = Tensor::randn(
0f32,
1f32,
(b_size, DECODER_CIN, latent_height, latent_width),
&device,
)?;
println!("diffusion process with prior {image_embeddings:?}");
let scheduler = wuerstchen::ddpm::DDPMWScheduler::new(12, Default::default())?;
let timesteps = scheduler.timesteps();
let timesteps = &timesteps[..timesteps.len() - 1];
for (index, &t) in timesteps.iter().enumerate() {
let start_time = std::time::Instant::now();
let ratio = (Tensor::ones(1, DType::F32, &device)? * t)?;
let noise_pred =
decoder.forward(&latents, &ratio, &image_embeddings, Some(&text_embeddings))?;
latents = scheduler.step(&noise_pred, t, &latents)?;
let dt = start_time.elapsed().as_secs_f32();
println!("step {}/{} done, {:.2}s", index + 1, timesteps.len(), dt);
}
println!(
"Generating the final image for sample {}/{}.",
idx + 1,
num_samples
);
let image = vqgan.decode(&(&latents * 0.3764)?)?;
let image = (image.clamp(0f32, 1f32)? * 255.)?
.to_dtype(DType::U8)?
.i(0)?;
let image_filename = output_filename(&final_image, idx + 1, num_samples, None);
candle_examples::save_image(&image, image_filename)?
}
Ok(())
}
fn main() -> Result<()> {
let args = Args::parse();
run(args)
}

View File

@ -4,7 +4,7 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use candle_examples::object_detection::{non_maximum_suppression, Bbox};
use candle_transformers::object_detection::{non_maximum_suppression, Bbox};
mod darknet;
use anyhow::Result;
@ -146,9 +146,7 @@ pub fn main() -> Result<()> {
// Create the model and load the weights from the file.
let model = args.model()?;
let weights = unsafe { candle::safetensors::MmapedFile::new(model)? };
let weights = weights.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &Device::Cpu);
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &Device::Cpu)? };
let config = args.config()?;
let darknet = darknet::parse_config(config)?;
let model = darknet.build_model(vb)?;

View File

@ -0,0 +1,47 @@
# candle-yolo-v8: Object Detection and Pose Estimation
This is a port of [Ultralytics
YOLOv8](https://github.com/ultralytics/ultralytics). The implementation is based
on the [tinygrad
version](https://github.com/tinygrad/tinygrad/blob/master/examples/yolov8.py)
and on the model architecture described in this
[issue](https://github.com/ultralytics/ultralytics/issues/189). The supported
tasks are object detection and pose estimation.
You can try this model online on the [Candle YOLOv8
Space](https://huggingface.co/spaces/lmz/candle-yolo). The model then fully runs
in your browser using WebAssembly - if you use a custom image it will never
leave your phone/computer!
## Running some example
### Object Detection
```bash
cargo run --example yolo-v8 --release -- candle-examples/examples/yolo-v8/assets/bike.jpg
```
This prints details about the detected objects and generates a `bike.pp.jpg` file.
![Leading group, Giro d'Italia 2021](./assets/bike.jpg)
Image source:
[wikimedia](https://commons.wikimedia.org/wiki/File:Leading_group,_Giro_d%27Italia_2021,_Stage_15.jpg).
![Leading group, Giro d'Italia 2021](./assets/bike.od.jpg)
### Pose Estimation
```bash
cargo run --example yolo-v8 --release -- \
candle-examples/examples/yolo-v8/assets/peoples.jpeg --task pose
```
![Leading group, Giro d'Italia 2021](./assets/bike.pose.jpg)
### Command-line flags
- `--which`: select the model variant to be used, `n`, `s` , `m`, `l`, or `x` by
increasing size and quality.
- `--task`: `detect` for object detection and `pose` for pose estimation.
- `--legend-size`: the size of the characters to print.
- `--model`: use a local model file rather than downloading it from the hub.

Binary file not shown.

After

Width:  |  Height:  |  Size: 179 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 175 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 189 KiB

View File

@ -7,9 +7,9 @@ extern crate accelerate_src;
mod model;
use model::{Multiples, YoloV8, YoloV8Pose};
use candle::{DType, Device, IndexOp, Result, Tensor};
use candle_examples::object_detection::{non_maximum_suppression, Bbox, KeyPoint};
use candle::{DType, IndexOp, Result, Tensor};
use candle_nn::{Module, VarBuilder};
use candle_transformers::object_detection::{non_maximum_suppression, Bbox, KeyPoint};
use clap::{Parser, ValueEnum};
use image::DynamicImage;
@ -253,6 +253,14 @@ enum YoloTask {
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
pub struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
/// Model weights, in safetensors format.
#[arg(long)]
model: Option<String>,
@ -363,6 +371,7 @@ impl Task for YoloV8Pose {
}
pub fn run<T: Task>(args: Args) -> anyhow::Result<()> {
let device = candle_examples::device(args.cpu)?;
// Create the model and load the weights from the file.
let multiples = match args.which {
Which::N => Multiples::n(),
@ -372,9 +381,7 @@ pub fn run<T: Task>(args: Args) -> anyhow::Result<()> {
Which::X => Multiples::x(),
};
let model = args.model()?;
let weights = unsafe { candle::safetensors::MmapedFile::new(model)? };
let weights = weights.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &Device::Cpu);
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? };
let model = T::load(vb, multiples)?;
println!("model loaded");
for image_name in args.images.iter() {
@ -405,7 +412,7 @@ pub fn run<T: Task>(args: Args) -> anyhow::Result<()> {
Tensor::from_vec(
data,
(img.height() as usize, img.width() as usize, 3),
&Device::Cpu,
&device,
)?
.permute((2, 0, 1))?
};
@ -430,7 +437,19 @@ pub fn run<T: Task>(args: Args) -> anyhow::Result<()> {
}
pub fn main() -> anyhow::Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
match args.task {
YoloTask::Detect => run::<YoloV8>(args)?,
YoloTask::Pose => run::<YoloV8Pose>(args)?,

View File

@ -77,6 +77,7 @@ impl Module for Upsample {
struct ConvBlock {
conv: Conv2d,
bn: BatchNorm,
span: tracing::Span,
}
impl ConvBlock {
@ -97,12 +98,17 @@ impl ConvBlock {
};
let conv = conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?;
let bn = batch_norm(c2, 1e-3, vb.pp("bn"))?;
Ok(Self { conv, bn })
Ok(Self {
conv,
bn,
span: tracing::span!(tracing::Level::TRACE, "conv-block"),
})
}
}
impl Module for ConvBlock {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let xs = self.conv.forward(xs)?;
let xs = self.bn.forward(&xs)?;
candle_nn::ops::silu(&xs)
@ -114,6 +120,7 @@ struct Bottleneck {
cv1: ConvBlock,
cv2: ConvBlock,
residual: bool,
span: tracing::Span,
}
impl Bottleneck {
@ -123,12 +130,18 @@ impl Bottleneck {
let cv1 = ConvBlock::load(vb.pp("cv1"), c1, c_, 3, 1, None)?;
let cv2 = ConvBlock::load(vb.pp("cv2"), c_, c2, 3, 1, None)?;
let residual = c1 == c2 && shortcut;
Ok(Self { cv1, cv2, residual })
Ok(Self {
cv1,
cv2,
residual,
span: tracing::span!(tracing::Level::TRACE, "bottleneck"),
})
}
}
impl Module for Bottleneck {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let ys = self.cv2.forward(&self.cv1.forward(xs)?)?;
if self.residual {
xs + ys
@ -143,6 +156,7 @@ struct C2f {
cv1: ConvBlock,
cv2: ConvBlock,
bottleneck: Vec<Bottleneck>,
span: tracing::Span,
}
impl C2f {
@ -159,12 +173,14 @@ impl C2f {
cv1,
cv2,
bottleneck,
span: tracing::span!(tracing::Level::TRACE, "c2f"),
})
}
}
impl Module for C2f {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let ys = self.cv1.forward(xs)?;
let mut ys = ys.chunk(2, 1)?;
for m in self.bottleneck.iter() {
@ -180,6 +196,7 @@ struct Sppf {
cv1: ConvBlock,
cv2: ConvBlock,
k: usize,
span: tracing::Span,
}
impl Sppf {
@ -187,12 +204,18 @@ impl Sppf {
let c_ = c1 / 2;
let cv1 = ConvBlock::load(vb.pp("cv1"), c1, c_, 1, 1, None)?;
let cv2 = ConvBlock::load(vb.pp("cv2"), c_ * 4, c2, 1, 1, None)?;
Ok(Self { cv1, cv2, k })
Ok(Self {
cv1,
cv2,
k,
span: tracing::span!(tracing::Level::TRACE, "sppf"),
})
}
}
impl Module for Sppf {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let (_, _, _, _) = xs.dims4()?;
let xs = self.cv1.forward(xs)?;
let xs2 = xs
@ -215,17 +238,23 @@ impl Module for Sppf {
struct Dfl {
conv: Conv2d,
num_classes: usize,
span: tracing::Span,
}
impl Dfl {
fn load(vb: VarBuilder, num_classes: usize) -> Result<Self> {
let conv = conv2d_no_bias(num_classes, 1, 1, Default::default(), vb.pp("conv"))?;
Ok(Self { conv, num_classes })
Ok(Self {
conv,
num_classes,
span: tracing::span!(tracing::Level::TRACE, "dfl"),
})
}
}
impl Module for Dfl {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let (b_sz, _channels, anchors) = xs.dims3()?;
let xs = xs
.reshape((b_sz, 4, self.num_classes, anchors))?
@ -247,6 +276,7 @@ struct DarkNet {
b4_0: ConvBlock,
b4_1: C2f,
b5: Sppf,
span: tracing::Span,
}
impl DarkNet {
@ -330,10 +360,12 @@ impl DarkNet {
b4_0,
b4_1,
b5,
span: tracing::span!(tracing::Level::TRACE, "darknet"),
})
}
fn forward(&self, xs: &Tensor) -> Result<(Tensor, Tensor, Tensor)> {
let _enter = self.span.enter();
let x1 = self.b1_1.forward(&self.b1_0.forward(xs)?)?;
let x2 = self
.b2_2
@ -354,6 +386,7 @@ struct YoloV8Neck {
n4: C2f,
n5: ConvBlock,
n6: C2f,
span: tracing::Span,
}
impl YoloV8Neck {
@ -413,10 +446,12 @@ impl YoloV8Neck {
n4,
n5,
n6,
span: tracing::span!(tracing::Level::TRACE, "neck"),
})
}
fn forward(&self, p3: &Tensor, p4: &Tensor, p5: &Tensor) -> Result<(Tensor, Tensor, Tensor)> {
let _enter = self.span.enter();
let x = self
.n1
.forward(&Tensor::cat(&[&self.up.forward(p5)?, p4], 1)?)?;
@ -440,6 +475,7 @@ struct DetectionHead {
cv3: [(ConvBlock, ConvBlock, Conv2d); 3],
ch: usize,
no: usize,
span: tracing::Span,
}
#[derive(Debug)]
@ -447,6 +483,7 @@ struct PoseHead {
detect: DetectionHead,
cv4: [(ConvBlock, ConvBlock, Conv2d); 3],
kpt: (usize, usize),
span: tracing::Span,
}
fn make_anchors(
@ -519,6 +556,7 @@ impl DetectionHead {
cv3,
ch,
no,
span: tracing::span!(tracing::Level::TRACE, "detection-head"),
})
}
@ -547,6 +585,7 @@ impl DetectionHead {
}
fn forward(&self, xs0: &Tensor, xs1: &Tensor, xs2: &Tensor) -> Result<DetectionHeadOut> {
let _enter = self.span.enter();
let forward_cv = |xs, i: usize| {
let xs_2 = self.cv2[i].0.forward(xs)?;
let xs_2 = self.cv2[i].1.forward(&xs_2)?;
@ -606,7 +645,12 @@ impl PoseHead {
Self::load_cv4(vb.pp("cv4.1"), c4, nk, filters.1)?,
Self::load_cv4(vb.pp("cv4.2"), c4, nk, filters.2)?,
];
Ok(Self { detect, cv4, kpt })
Ok(Self {
detect,
cv4,
kpt,
span: tracing::span!(tracing::Level::TRACE, "pose-head"),
})
}
fn load_cv4(
@ -622,6 +666,7 @@ impl PoseHead {
}
fn forward(&self, xs0: &Tensor, xs1: &Tensor, xs2: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let d = self.detect.forward(xs0, xs1, xs2)?;
let forward_cv = |xs: &Tensor, i: usize| {
let (b_sz, _, h, w) = xs.dims4()?;
@ -650,6 +695,7 @@ pub struct YoloV8 {
net: DarkNet,
fpn: YoloV8Neck,
head: DetectionHead,
span: tracing::Span,
}
impl YoloV8 {
@ -657,12 +703,18 @@ impl YoloV8 {
let net = DarkNet::load(vb.pp("net"), m)?;
let fpn = YoloV8Neck::load(vb.pp("fpn"), m)?;
let head = DetectionHead::load(vb.pp("head"), num_classes, m.filters())?;
Ok(Self { net, fpn, head })
Ok(Self {
net,
fpn,
head,
span: tracing::span!(tracing::Level::TRACE, "yolo-v8"),
})
}
}
impl Module for YoloV8 {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let (xs1, xs2, xs3) = self.net.forward(xs)?;
let (xs1, xs2, xs3) = self.fpn.forward(&xs1, &xs2, &xs3)?;
Ok(self.head.forward(&xs1, &xs2, &xs3)?.pred)
@ -674,6 +726,7 @@ pub struct YoloV8Pose {
net: DarkNet,
fpn: YoloV8Neck,
head: PoseHead,
span: tracing::Span,
}
impl YoloV8Pose {
@ -686,12 +739,18 @@ impl YoloV8Pose {
let net = DarkNet::load(vb.pp("net"), m)?;
let fpn = YoloV8Neck::load(vb.pp("fpn"), m)?;
let head = PoseHead::load(vb.pp("head"), num_classes, kpt, m.filters())?;
Ok(Self { net, fpn, head })
Ok(Self {
net,
fpn,
head,
span: tracing::span!(tracing::Level::TRACE, "yolo-v8-pose"),
})
}
}
impl Module for YoloV8Pose {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let (xs1, xs2, xs3) = self.net.forward(xs)?;
let (xs1, xs2, xs3) = self.fpn.forward(&xs1, &xs2, &xs3)?;
self.head.forward(&xs1, &xs2, &xs3)

View File

@ -1,6 +1,6 @@
pub mod coco_classes;
pub mod imagenet;
pub mod object_detection;
pub mod token_output_stream;
use candle::{Device, Result, Tensor};

Some files were not shown because too many files have changed in this diff Show More