diff --git a/Cargo.toml b/Cargo.toml index 299ccb3b..316d9e75 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,7 +20,7 @@ exclude = [ resolver = "2" [workspace.package] -version = "0.9.0-alpha.3" +version = "0.9.0-alpha.4" edition = "2021" description = "Minimalist ML framework." repository = "https://github.com/huggingface/candle" @@ -33,17 +33,17 @@ ab_glyph = "0.2.23" accelerate-src = { version = "0.3.2" } anyhow = { version = "1", features = ["backtrace"] } byteorder = "1.4.3" -candle = { path = "./candle-core", package = "candle-core", version = "0.9.0-alpha.3" } -candle-datasets = { path = "./candle-datasets", version = "0.9.0-alpha.3" } -candle-flash-attn = { path = "./candle-flash-attn", version = "0.9.0-alpha.3" } -candle-kernels = { path = "./candle-kernels", version = "0.9.0-alpha.3" } -candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.9.0-alpha.3" } -candle-nn = { path = "./candle-nn", version = "0.9.0-alpha.3" } -candle-onnx = { path = "./candle-onnx", version = "0.9.0-alpha.3" } -candle-transformers = { path = "./candle-transformers", version = "0.9.0-alpha.3" } +candle = { path = "./candle-core", package = "candle-core", version = "0.9.0-alpha.4" } +candle-datasets = { path = "./candle-datasets", version = "0.9.0-alpha.4" } +candle-flash-attn = { path = "./candle-flash-attn", version = "0.9.0-alpha.4" } +candle-kernels = { path = "./candle-kernels", version = "0.9.0-alpha.4" } +candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.9.0-alpha.4" } +candle-nn = { path = "./candle-nn", version = "0.9.0-alpha.4" } +candle-onnx = { path = "./candle-onnx", version = "0.9.0-alpha.4" } +candle-transformers = { path = "./candle-transformers", version = "0.9.0-alpha.4" } clap = { version = "4.2.4", features = ["derive"] } criterion = { version = "0.5.1", default-features=false } -cudarc = { version = "0.15.2", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", 
"dynamic-linking"], default-features=false } +cudarc = { version = "0.16.0", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false } fancy-regex = "0.13.0" gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] } hf-hub = "0.4.1" @@ -70,9 +70,9 @@ tokenizers = { version = "0.21.0", default-features = false } tracing = "0.1.37" tracing-chrome = "0.7.1" tracing-subscriber = "0.3.7" -ug = "0.3.1" -ug-cuda = "0.3.1" -ug-metal = "0.3.1" +ug = "0.4.0" +ug-cuda = "0.4.0" +ug-metal = "0.4.0" yoke = { version = "0.7.2", features = ["derive"] } zip = { version = "1.1.1", default-features = false } metal = { version = "0.27.0", features = ["mps"]} diff --git a/candle-core/src/cuda_backend/device.rs b/candle-core/src/cuda_backend/device.rs index a2674d67..7dd18b7a 100644 --- a/candle-core/src/cuda_backend/device.rs +++ b/candle-core/src/cuda_backend/device.rs @@ -144,6 +144,24 @@ impl CudaDevice { self.stream.clone() } + /// When turned on, all cuda tensors **created after calling this function** will + /// not track uses via cuda events. + /// + /// # Safety + /// + /// It is up to the user to ensure proper synchronization between multiple streams: + /// - Ensure that no tensor is freed before a use on another stream is finished. + /// - Ensure that a tensor is not used on another stream before allocation on the + /// allocating stream finishes. + /// - Ensure that a tensor is not written to concurrently by multiple streams. 
+ pub unsafe fn disable_event_tracking(&self) { + self.context.disable_event_tracking() + } + + pub fn is_event_tracking(&self) -> bool { + self.context.is_event_tracking() + } + #[cfg(not(target_arch = "wasm32"))] pub fn compile( &self, diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs index 1a82bf1f..6471a6ac 100644 --- a/candle-examples/examples/llama2-c/main.rs +++ b/candle-examples/examples/llama2-c/main.rs @@ -256,6 +256,12 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> { let tokenizer = common_args.tokenizer()?; let device = candle_examples::device(common_args.cpu)?; + #[cfg(feature = "cuda")] + if let candle::Device::Cuda(d) = &device { + unsafe { + d.disable_event_tracking(); + } + }; let is_gguf = config_path.extension().map_or(false, |v| v == "gguf"); let is_safetensors = config_path diff --git a/candle-flash-attn/Cargo.toml b/candle-flash-attn/Cargo.toml index 59007118..40063ba9 100644 --- a/candle-flash-attn/Cargo.toml +++ b/candle-flash-attn/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "candle-flash-attn" -version = "0.9.0-alpha.3" +version = "0.9.0-alpha.4" edition = "2021" description = "Flash attention layer for the candle ML framework." 
@@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0" readme = "README.md" [dependencies] -candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.9.0-alpha.3" } +candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.9.0-alpha.4" } half = { version = "2.3.1", features = ["num-traits"] } [build-dependencies] diff --git a/candle-kernels/Cargo.toml b/candle-kernels/Cargo.toml index 66152928..f786aaa4 100644 --- a/candle-kernels/Cargo.toml +++ b/candle-kernels/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "candle-kernels" -version = "0.9.0-alpha.3" +version = "0.9.0-alpha.4" edition = "2021" description = "CUDA kernels for Candle" diff --git a/candle-metal-kernels/Cargo.toml b/candle-metal-kernels/Cargo.toml index 7fad905c..d84f6824 100644 --- a/candle-metal-kernels/Cargo.toml +++ b/candle-metal-kernels/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "candle-metal-kernels" -version = "0.9.0-alpha.3" +version = "0.9.0-alpha.4" edition = "2021" description = "Metal kernels for Candle" diff --git a/candle-onnx/Cargo.toml b/candle-onnx/Cargo.toml index b118ef68..6954257d 100644 --- a/candle-onnx/Cargo.toml +++ b/candle-onnx/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "candle-onnx" -version = "0.9.0-alpha.3" +version = "0.9.0-alpha.4" edition = "2021" description = "ONNX support for Candle" @@ -10,8 +10,8 @@ categories = ["science"] license = "MIT OR Apache-2.0" [dependencies] -candle = { path = "../candle-core", package = "candle-core", version = "0.9.0-alpha.3" } -candle-nn = { path = "../candle-nn", version = "0.9.0-alpha.3" } +candle = { path = "../candle-core", package = "candle-core", version = "0.9.0-alpha.4" } +candle-nn = { path = "../candle-nn", version = "0.9.0-alpha.4" } prost = "0.12.1" [build-dependencies]