From 25960676caefcb10060fb36a8d66efa9fa731dec Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Tue, 9 Jul 2024 12:38:11 +0200 Subject: [PATCH] Add a basic metal example with capture (#2324) * Add some tracing. * Get the trace to work. --- candle-core/Cargo.toml | 4 ++++ candle-core/examples/metal_basics.rs | 28 +++++++++++++++++++++++++ candle-core/src/metal_backend/device.rs | 8 ++++++- 3 files changed, 39 insertions(+), 1 deletion(-) create mode 100644 candle-core/examples/metal_basics.rs diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml index 92a04917..cbf8f200 100644 --- a/candle-core/Cargo.toml +++ b/candle-core/Cargo.toml @@ -48,3 +48,7 @@ metal = ["dep:metal", "dep:candle-metal-kernels"] [[bench]] name = "bench_main" harness = false + +[[example]] +name = "metal_basics" +required-features = ["metal"] diff --git a/candle-core/examples/metal_basics.rs b/candle-core/examples/metal_basics.rs new file mode 100644 index 00000000..f9ff81ad --- /dev/null +++ b/candle-core/examples/metal_basics.rs @@ -0,0 +1,28 @@ +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +use anyhow::Result; +use candle_core::{Device, Tensor}; + +fn main() -> Result<()> { + // This requires the code to be run with MTL_CAPTURE_ENABLED=1 + let device = Device::new_metal(0)?; + let metal_device = match &device { + Device::Metal(m) => m, + _ => anyhow::bail!("unexpected device"), + }; + metal_device.capture("/tmp/candle.gputrace")?; + // This first synchronize ensures that a new command buffer gets created after setting up the + // capture scope. + device.synchronize()?; + let x = Tensor::randn(0f32, 1.0, (128, 128), &device)?; + let x1 = x.add(&x)?; + println!("{x1:?}"); + // This second synchronize ensures that the command buffer gets committed before the end of the + // capture scope. 
+ device.synchronize()?; + Ok(()) +} diff --git a/candle-core/src/metal_backend/device.rs b/candle-core/src/metal_backend/device.rs index 785fe621..07210c68 100644 --- a/candle-core/src/metal_backend/device.rs +++ b/candle-core/src/metal_backend/device.rs @@ -273,7 +273,13 @@ impl MetalDevice { let descriptor = metal::CaptureDescriptor::new(); descriptor.set_destination(metal::MTLCaptureDestination::GpuTraceDocument); descriptor.set_capture_device(self); - descriptor.set_output_url(path); + // The [set_output_url] call requires an absolute path so we convert it if needed. + if path.as_ref().is_absolute() { + descriptor.set_output_url(path); + } else { + let path = std::env::current_dir()?.join(path); + descriptor.set_output_url(path); + } capture .start_capture(&descriptor)