Add the const-set op. (#2910)

* Add the const-set op.

* Cuda implementation.

* Bugfix.

* Metal cleanup.

* Add the metal kernels.

* Add some testing.

* Finish the metal implementation.

* Bump the version.
This commit is contained in:
Laurent Mazare
2025-04-19 10:07:02 +02:00
committed by GitHub
parent b2904a830b
commit a4c56a958e
20 changed files with 414 additions and 209 deletions

View File

@ -20,7 +20,7 @@ exclude = [
resolver = "2"
[workspace.package]
version = "0.9.0-alpha.4"
version = "0.9.0-alpha.5"
edition = "2021"
description = "Minimalist ML framework."
repository = "https://github.com/huggingface/candle"
@ -33,14 +33,14 @@ ab_glyph = "0.2.23"
accelerate-src = { version = "0.3.2" }
anyhow = { version = "1", features = ["backtrace"] }
byteorder = "1.4.3"
candle = { path = "./candle-core", package = "candle-core", version = "0.9.0-alpha.4" }
candle-datasets = { path = "./candle-datasets", version = "0.9.0-alpha.4" }
candle-flash-attn = { path = "./candle-flash-attn", version = "0.9.0-alpha.4" }
candle-kernels = { path = "./candle-kernels", version = "0.9.0-alpha.4" }
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.9.0-alpha.4" }
candle-nn = { path = "./candle-nn", version = "0.9.0-alpha.4" }
candle-onnx = { path = "./candle-onnx", version = "0.9.0-alpha.4" }
candle-transformers = { path = "./candle-transformers", version = "0.9.0-alpha.4" }
candle = { path = "./candle-core", package = "candle-core", version = "0.9.0-alpha.5" }
candle-datasets = { path = "./candle-datasets", version = "0.9.0-alpha.5" }
candle-flash-attn = { path = "./candle-flash-attn", version = "0.9.0-alpha.5" }
candle-kernels = { path = "./candle-kernels", version = "0.9.0-alpha.5" }
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.9.0-alpha.5" }
candle-nn = { path = "./candle-nn", version = "0.9.0-alpha.5" }
candle-onnx = { path = "./candle-onnx", version = "0.9.0-alpha.5" }
candle-transformers = { path = "./candle-transformers", version = "0.9.0-alpha.5" }
clap = { version = "4.2.4", features = ["derive"] }
criterion = { version = "0.5.1", default-features=false }
cudarc = { version = "0.16.0", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }