From 66750f98271dc852464d55689ca2d29f04f3fa34 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Sat, 15 Jul 2023 08:25:15 +0100
Subject: [PATCH] Add some 'cuda-if-available' helper function. (#172)

---
 candle-core/src/device.rs                 |  8 ++++++++
 candle-core/src/utils.rs                  |  7 +++++++
 candle-examples/examples/bert/main.rs     | 15 +--------------
 candle-examples/examples/falcon/main.rs   | 15 +--------------
 candle-examples/examples/llama/main.rs    | 15 +--------------
 candle-examples/examples/musicgen/main.rs | 17 ++---------------
 candle-examples/examples/whisper/main.rs  | 16 +---------------
 candle-examples/src/lib.rs                | 12 ++++++++++++
 8 files changed, 33 insertions(+), 72 deletions(-)

diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs
index ca408529..53e2de43 100644
--- a/candle-core/src/device.rs
+++ b/candle-core/src/device.rs
@@ -109,6 +109,14 @@ impl Device {
         }
     }
 
+    pub fn cuda_if_available(ordinal: usize) -> Result<Self> {
+        if crate::utils::cuda_is_available() {
+            Self::new_cuda(ordinal)
+        } else {
+            Ok(Self::Cpu)
+        }
+    }
+
     pub(crate) fn rand_uniform(
         &self,
         shape: &Shape,
diff --git a/candle-core/src/utils.rs b/candle-core/src/utils.rs
index b5621e56..895c97e1 100644
--- a/candle-core/src/utils.rs
+++ b/candle-core/src/utils.rs
@@ -17,3 +17,10 @@ pub fn has_mkl() -> bool {
     #[cfg(not(feature = "mkl"))]
     return false;
 }
+
+pub fn cuda_is_available() -> bool {
+    #[cfg(feature = "cuda")]
+    return true;
+    #[cfg(not(feature = "cuda"))]
+    return false;
+}
diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs
index aae8bc50..dca6721b 100644
--- a/candle-examples/examples/bert/main.rs
+++ b/candle-examples/examples/bert/main.rs
@@ -495,20 +495,7 @@ struct Args {
 
 impl Args {
     fn build_model_and_tokenizer(&self) -> Result<(BertModel, Tokenizer)> {
-        #[cfg(feature = "cuda")]
-        let default_device = Device::new_cuda(0)?;
-
-        #[cfg(not(feature = "cuda"))]
-        let default_device = {
-            println!("Running on CPU, to run on GPU, run this example with `--features cuda`");
-            Device::Cpu
-        };
-
-        let device = if self.cpu {
-            Device::Cpu
-        } else {
-            default_device
-        };
+        let device = candle_examples::device(self.cpu)?;
         let default_model = "sentence-transformers/all-MiniLM-L6-v2".to_string();
         let default_revision = "refs/pr/21".to_string();
         let (model_id, revision) = match (self.model_id.to_owned(), self.revision.to_owned()) {
diff --git a/candle-examples/examples/falcon/main.rs b/candle-examples/examples/falcon/main.rs
index 7e20c7d2..7d5eaa52 100644
--- a/candle-examples/examples/falcon/main.rs
+++ b/candle-examples/examples/falcon/main.rs
@@ -120,20 +120,7 @@ struct Args {
 
 fn main() -> Result<()> {
     let args = Args::parse();
-    #[cfg(feature = "cuda")]
-    let default_device = Device::new_cuda(0)?;
-
-    #[cfg(not(feature = "cuda"))]
-    let default_device = {
-        println!("Running on CPU, to run on GPU, run this example with `--features cuda`");
-        Device::Cpu
-    };
-    let device = if args.cpu {
-        Device::Cpu
-    } else {
-        default_device
-    };
-
+    let device = candle_examples::device(args.cpu)?;
     let start = std::time::Instant::now();
     let api = Api::new()?;
     let repo = Repo::with_revision(args.model_id, RepoType::Model, args.revision);
diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs
index 203b4606..aa02299d 100644
--- a/candle-examples/examples/llama/main.rs
+++ b/candle-examples/examples/llama/main.rs
@@ -134,20 +134,7 @@ fn main() -> Result<()> {
 
     let args = Args::parse();
 
-    #[cfg(feature = "cuda")]
-    let default_device = Device::new_cuda(0)?;
-
-    #[cfg(not(feature = "cuda"))]
-    let default_device = {
-        println!("Running on CPU, to run on GPU, run this example with `--features cuda`");
-        Device::Cpu
-    };
-
-    let device = if args.cpu {
-        Device::Cpu
-    } else {
-        default_device
-    };
+    let device = candle_examples::device(args.cpu)?;
     let config = Config::config_7b();
     let cache = model::Cache::new(!args.no_kv_cache, &config, &device);
     let dtype = if args.use_f32 { DType::F32 } else { DType::F16 };
diff --git a/candle-examples/examples/musicgen/main.rs b/candle-examples/examples/musicgen/main.rs
index 90b464c3..3e136e90 100644
--- a/candle-examples/examples/musicgen/main.rs
+++ b/candle-examples/examples/musicgen/main.rs
@@ -16,7 +16,7 @@ use musicgen_model::{GenConfig, MusicgenForConditionalGeneration};
 use nn::VarBuilder;
 
 use anyhow::{Error as E, Result};
-use candle::{DType, Device};
+use candle::DType;
 use clap::Parser;
 
 const DTYPE: DType = DType::F32;
@@ -41,20 +41,7 @@ fn main() -> Result<()> {
     use tokenizers::Tokenizer;
     let args = Args::parse();
 
-    #[cfg(feature = "cuda")]
-    let default_device = Device::new_cuda(0)?;
-
-    #[cfg(not(feature = "cuda"))]
-    let default_device = {
-        println!("Running on CPU, to run on GPU, run this example with `--features cuda`");
-        Device::Cpu
-    };
-    let device = if args.cpu {
-        Device::Cpu
-    } else {
-        default_device
-    };
-
+    let device = candle_examples::device(args.cpu)?;
     let mut tokenizer = Tokenizer::from_file(args.tokenizer).map_err(E::msg)?;
     let _tokenizer = tokenizer.with_padding(None).with_truncation(None);
 
diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs
index 09ef4593..d01fb605 100644
--- a/candle-examples/examples/whisper/main.rs
+++ b/candle-examples/examples/whisper/main.rs
@@ -257,21 +257,7 @@ struct Args {
 
 fn main() -> Result<()> {
     let args = Args::parse();
-
-    #[cfg(feature = "cuda")]
-    let default_device = Device::new_cuda(0)?;
-
-    #[cfg(not(feature = "cuda"))]
-    let default_device = {
-        println!("Running on CPU, to run on GPU, run this example with `--features cuda`");
-        Device::Cpu
-    };
-
-    let device = if args.cpu {
-        Device::Cpu
-    } else {
-        default_device
-    };
+    let device = candle_examples::device(args.cpu)?;
     let default_model = "openai/whisper-tiny.en".to_string();
     let path = std::path::PathBuf::from(default_model.clone());
     let default_revision = "refs/pr/15".to_string();
diff --git a/candle-examples/src/lib.rs b/candle-examples/src/lib.rs
index 8b137891..285aee04 100644
--- a/candle-examples/src/lib.rs
+++ b/candle-examples/src/lib.rs
@@ -1 +1,13 @@
+use candle::{Device, Result};
 
+pub fn device(cpu: bool) -> Result<Device> {
+    if cpu {
+        Ok(Device::Cpu)
+    } else {
+        let device = Device::cuda_if_available(0)?;
+        if !device.is_cuda() {
+            println!("Running on CPU, to run on GPU, build this example with `--features cuda`");
+        }
+        Ok(device)
+    }
+}
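
Usage sketch (not part of the patch): after this change an example binary can resolve its device in one line through the new candle_examples::device helper, which wraps Device::cuda_if_available(0). Only device() and Device::cuda_if_available come from the patch; the Args struct and --cpu flag below are illustrative stand-ins for the clap definitions the examples already use, and the anyhow/clap/candle dependencies are assumed as in the candle-examples crate.

use anyhow::Result;
use candle::Tensor;
use clap::Parser;

/// Illustrative CLI mirroring the `cpu` flag the examples pass to `device()`.
#[derive(Parser, Debug)]
struct Args {
    /// Force CPU execution even when the `cuda` feature is enabled.
    #[arg(long)]
    cpu: bool,
}

fn main() -> Result<()> {
    let args = Args::parse();
    // Picks CUDA device 0 when built with `--features cuda`, otherwise falls
    // back to Device::Cpu and prints a hint, exactly as the examples above do.
    let device = candle_examples::device(args.cpu)?;
    // Tensors created against `device` then live on the selected backend.
    let t = Tensor::new(&[1f32, 2., 3., 4.], &device)?;
    println!("tensor {:?}, on cuda: {}", t.shape(), device.is_cuda());
    Ok(())
}

Outside the examples crate, the same fallback is available directly as candle::Device::cuda_if_available(ordinal), which returns Device::Cpu when the cuda feature is not compiled in.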