mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00
Use cudarc 0.16. (#2900)
* Use cudarc 0.16.
* Allow for disabling event tracking.
* Tweaks.
* Bump the ug version.
* And bump the candle version too.
This commit is contained in:
26
Cargo.toml
26
Cargo.toml
@ -20,7 +20,7 @@ exclude = [
|
|||||||
resolver = "2"
|
resolver = "2"
|
||||||
|
|
||||||
[workspace.package]
|
[workspace.package]
|
||||||
version = "0.9.0-alpha.3"
|
version = "0.9.0-alpha.4"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
description = "Minimalist ML framework."
|
description = "Minimalist ML framework."
|
||||||
repository = "https://github.com/huggingface/candle"
|
repository = "https://github.com/huggingface/candle"
|
||||||
@ -33,17 +33,17 @@ ab_glyph = "0.2.23"
|
|||||||
accelerate-src = { version = "0.3.2" }
|
accelerate-src = { version = "0.3.2" }
|
||||||
anyhow = { version = "1", features = ["backtrace"] }
|
anyhow = { version = "1", features = ["backtrace"] }
|
||||||
byteorder = "1.4.3"
|
byteorder = "1.4.3"
|
||||||
candle = { path = "./candle-core", package = "candle-core", version = "0.9.0-alpha.3" }
|
candle = { path = "./candle-core", package = "candle-core", version = "0.9.0-alpha.4" }
|
||||||
candle-datasets = { path = "./candle-datasets", version = "0.9.0-alpha.3" }
|
candle-datasets = { path = "./candle-datasets", version = "0.9.0-alpha.4" }
|
||||||
candle-flash-attn = { path = "./candle-flash-attn", version = "0.9.0-alpha.3" }
|
candle-flash-attn = { path = "./candle-flash-attn", version = "0.9.0-alpha.4" }
|
||||||
candle-kernels = { path = "./candle-kernels", version = "0.9.0-alpha.3" }
|
candle-kernels = { path = "./candle-kernels", version = "0.9.0-alpha.4" }
|
||||||
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.9.0-alpha.3" }
|
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.9.0-alpha.4" }
|
||||||
candle-nn = { path = "./candle-nn", version = "0.9.0-alpha.3" }
|
candle-nn = { path = "./candle-nn", version = "0.9.0-alpha.4" }
|
||||||
candle-onnx = { path = "./candle-onnx", version = "0.9.0-alpha.3" }
|
candle-onnx = { path = "./candle-onnx", version = "0.9.0-alpha.4" }
|
||||||
candle-transformers = { path = "./candle-transformers", version = "0.9.0-alpha.3" }
|
candle-transformers = { path = "./candle-transformers", version = "0.9.0-alpha.4" }
|
||||||
clap = { version = "4.2.4", features = ["derive"] }
|
clap = { version = "4.2.4", features = ["derive"] }
|
||||||
criterion = { version = "0.5.1", default-features=false }
|
criterion = { version = "0.5.1", default-features=false }
|
||||||
cudarc = { version = "0.15.2", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
|
cudarc = { version = "0.16.0", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
|
||||||
fancy-regex = "0.13.0"
|
fancy-regex = "0.13.0"
|
||||||
gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
|
gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
|
||||||
hf-hub = "0.4.1"
|
hf-hub = "0.4.1"
|
||||||
@ -70,9 +70,9 @@ tokenizers = { version = "0.21.0", default-features = false }
|
|||||||
tracing = "0.1.37"
|
tracing = "0.1.37"
|
||||||
tracing-chrome = "0.7.1"
|
tracing-chrome = "0.7.1"
|
||||||
tracing-subscriber = "0.3.7"
|
tracing-subscriber = "0.3.7"
|
||||||
ug = "0.3.1"
|
ug = "0.4.0"
|
||||||
ug-cuda = "0.3.1"
|
ug-cuda = "0.4.0"
|
||||||
ug-metal = "0.3.1"
|
ug-metal = "0.4.0"
|
||||||
yoke = { version = "0.7.2", features = ["derive"] }
|
yoke = { version = "0.7.2", features = ["derive"] }
|
||||||
zip = { version = "1.1.1", default-features = false }
|
zip = { version = "1.1.1", default-features = false }
|
||||||
metal = { version = "0.27.0", features = ["mps"]}
|
metal = { version = "0.27.0", features = ["mps"]}
|
||||||
|
@ -144,6 +144,24 @@ impl CudaDevice {
|
|||||||
self.stream.clone()
|
self.stream.clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// When turned on, all cuda tensors **created after calling this function** will
|
||||||
|
/// not track uses via cuda events.
|
||||||
|
///
|
||||||
|
/// # Safety
|
||||||
|
///
|
||||||
|
/// It is up to the user to ensure proper synchronization between multiple streams:
|
||||||
|
/// - Ensure that no tensor is freed before a use on another stream is finished.
|
||||||
|
/// - Ensure that a tensor is not used on another stream before allocation on the
|
||||||
|
/// allocating stream finishes.
|
||||||
|
/// - Ensure that a tensor is not written to concurrently by multiple streams.
|
||||||
|
pub unsafe fn disable_event_tracking(&self) {
|
||||||
|
self.context.disable_event_tracking()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn is_event_tracking(&self) -> bool {
|
||||||
|
self.context.is_event_tracking()
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(not(target_arch = "wasm32"))]
|
#[cfg(not(target_arch = "wasm32"))]
|
||||||
pub fn compile(
|
pub fn compile(
|
||||||
&self,
|
&self,
|
||||||
|
@ -256,6 +256,12 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
|
|||||||
let tokenizer = common_args.tokenizer()?;
|
let tokenizer = common_args.tokenizer()?;
|
||||||
|
|
||||||
let device = candle_examples::device(common_args.cpu)?;
|
let device = candle_examples::device(common_args.cpu)?;
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
if let candle::Device::Cuda(d) = &device {
|
||||||
|
unsafe {
|
||||||
|
d.disable_event_tracking();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
let is_gguf = config_path.extension().map_or(false, |v| v == "gguf");
|
let is_gguf = config_path.extension().map_or(false, |v| v == "gguf");
|
||||||
let is_safetensors = config_path
|
let is_safetensors = config_path
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "candle-flash-attn"
|
name = "candle-flash-attn"
|
||||||
version = "0.9.0-alpha.3"
|
version = "0.9.0-alpha.4"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
description = "Flash attention layer for the candle ML framework."
|
description = "Flash attention layer for the candle ML framework."
|
||||||
@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
|
|||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.9.0-alpha.3" }
|
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.9.0-alpha.4" }
|
||||||
half = { version = "2.3.1", features = ["num-traits"] }
|
half = { version = "2.3.1", features = ["num-traits"] }
|
||||||
|
|
||||||
[build-dependencies]
|
[build-dependencies]
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "candle-kernels"
|
name = "candle-kernels"
|
||||||
version = "0.9.0-alpha.3"
|
version = "0.9.0-alpha.4"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
description = "CUDA kernels for Candle"
|
description = "CUDA kernels for Candle"
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "candle-metal-kernels"
|
name = "candle-metal-kernels"
|
||||||
version = "0.9.0-alpha.3"
|
version = "0.9.0-alpha.4"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
description = "Metal kernels for Candle"
|
description = "Metal kernels for Candle"
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "candle-onnx"
|
name = "candle-onnx"
|
||||||
version = "0.9.0-alpha.3"
|
version = "0.9.0-alpha.4"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
description = "ONNX support for Candle"
|
description = "ONNX support for Candle"
|
||||||
@ -10,8 +10,8 @@ categories = ["science"]
|
|||||||
license = "MIT OR Apache-2.0"
|
license = "MIT OR Apache-2.0"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
candle = { path = "../candle-core", package = "candle-core", version = "0.9.0-alpha.3" }
|
candle = { path = "../candle-core", package = "candle-core", version = "0.9.0-alpha.4" }
|
||||||
candle-nn = { path = "../candle-nn", version = "0.9.0-alpha.3" }
|
candle-nn = { path = "../candle-nn", version = "0.9.0-alpha.4" }
|
||||||
prost = "0.12.1"
|
prost = "0.12.1"
|
||||||
|
|
||||||
[build-dependencies]
|
[build-dependencies]
|
||||||
|
Reference in New Issue
Block a user