Add some 'cuda-if-available' helper function. (#172)

Laurent Mazare
2023-07-15 08:25:15 +01:00
committed by GitHub
parent 2ddda706bd
commit 66750f9827
8 changed files with 33 additions and 72 deletions

View File

@@ -109,6 +109,14 @@ impl Device {
         }
     }
 
+    pub fn cuda_if_available(ordinal: usize) -> Result<Self> {
+        if crate::utils::cuda_is_available() {
+            Self::new_cuda(ordinal)
+        } else {
+            Ok(Self::Cpu)
+        }
+    }
+
     pub(crate) fn rand_uniform(
         &self,
         shape: &Shape,
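
For reference, a minimal sketch of how the new helper can be called from downstream code (an assumption-laden sketch: it presumes the crate is imported as `candle` and that `Device` implements `Debug`):

    use candle::{Device, Result};

    fn main() -> Result<()> {
        // Selects CUDA device 0 when the crate was built with the `cuda`
        // feature, and quietly falls back to the CPU otherwise.
        let device = Device::cuda_if_available(0)?;
        println!("running on: {device:?}");
        Ok(())
    }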

View File

@@ -17,3 +17,10 @@ pub fn has_mkl() -> bool {
     #[cfg(not(feature = "mkl"))]
     return false;
 }
+
+pub fn cuda_is_available() -> bool {
+    #[cfg(feature = "cuda")]
+    return true;
+    #[cfg(not(feature = "cuda"))]
+    return false;
+}
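
Like `has_mkl` above, this is a compile-time check on the `cuda` Cargo feature, not a runtime probe for a physical GPU. A small usage sketch (the `candle::utils` path is an assumption, inferred from the `crate::utils::cuda_is_available()` call in the `Device` hunk above):

    // Reports how the binary was built; it does not detect whether a
    // usable GPU is actually present at runtime.
    if candle::utils::cuda_is_available() {
        println!("compiled with CUDA support");
    } else {
        println!("compiled without CUDA support, falling back to CPU");
    }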

View File

@@ -495,20 +495,7 @@ struct Args {
 impl Args {
     fn build_model_and_tokenizer(&self) -> Result<(BertModel, Tokenizer)> {
 
-        #[cfg(feature = "cuda")]
-        let default_device = Device::new_cuda(0)?;
-
-        #[cfg(not(feature = "cuda"))]
-        let default_device = {
-            println!("Running on CPU, to run on GPU, run this example with `--features cuda`");
-            Device::Cpu
-        };
-
-        let device = if self.cpu {
-            Device::Cpu
-        } else {
-            default_device
-        };
+        let device = candle_examples::device(self.cpu)?;
         let default_model = "sentence-transformers/all-MiniLM-L6-v2".to_string();
         let default_revision = "refs/pr/21".to_string();
         let (model_id, revision) = match (self.model_id.to_owned(), self.revision.to_owned()) {

View File

@@ -120,20 +120,7 @@ struct Args {
 fn main() -> Result<()> {
     let args = Args::parse();
-    #[cfg(feature = "cuda")]
-    let default_device = Device::new_cuda(0)?;
-
-    #[cfg(not(feature = "cuda"))]
-    let default_device = {
-        println!("Running on CPU, to run on GPU, run this example with `--features cuda`");
-        Device::Cpu
-    };
-
-    let device = if args.cpu {
-        Device::Cpu
-    } else {
-        default_device
-    };
+    let device = candle_examples::device(args.cpu)?;
 
     let start = std::time::Instant::now();
     let api = Api::new()?;
     let repo = Repo::with_revision(args.model_id, RepoType::Model, args.revision);

View File

@@ -134,20 +134,7 @@ fn main() -> Result<()> {
     let args = Args::parse();
 
-    #[cfg(feature = "cuda")]
-    let default_device = Device::new_cuda(0)?;
-
-    #[cfg(not(feature = "cuda"))]
-    let default_device = {
-        println!("Running on CPU, to run on GPU, run this example with `--features cuda`");
-        Device::Cpu
-    };
-
-    let device = if args.cpu {
-        Device::Cpu
-    } else {
-        default_device
-    };
+    let device = candle_examples::device(args.cpu)?;
 
     let config = Config::config_7b();
     let cache = model::Cache::new(!args.no_kv_cache, &config, &device);
     let dtype = if args.use_f32 { DType::F32 } else { DType::F16 };

View File

@@ -16,7 +16,7 @@ use musicgen_model::{GenConfig, MusicgenForConditionalGeneration};
 use nn::VarBuilder;
 
 use anyhow::{Error as E, Result};
-use candle::{DType, Device};
+use candle::DType;
 use clap::Parser;
 
 const DTYPE: DType = DType::F32;
@@ -41,20 +41,7 @@ fn main() -> Result<()> {
     use tokenizers::Tokenizer;
 
     let args = Args::parse();
-    #[cfg(feature = "cuda")]
-    let default_device = Device::new_cuda(0)?;
-
-    #[cfg(not(feature = "cuda"))]
-    let default_device = {
-        println!("Running on CPU, to run on GPU, run this example with `--features cuda`");
-        Device::Cpu
-    };
-
-    let device = if args.cpu {
-        Device::Cpu
-    } else {
-        default_device
-    };
+    let device = candle_examples::device(args.cpu)?;
 
     let mut tokenizer = Tokenizer::from_file(args.tokenizer).map_err(E::msg)?;
     let _tokenizer = tokenizer.with_padding(None).with_truncation(None);

View File

@@ -257,21 +257,7 @@ struct Args {
 fn main() -> Result<()> {
     let args = Args::parse();
-
-    #[cfg(feature = "cuda")]
-    let default_device = Device::new_cuda(0)?;
-
-    #[cfg(not(feature = "cuda"))]
-    let default_device = {
-        println!("Running on CPU, to run on GPU, run this example with `--features cuda`");
-        Device::Cpu
-    };
-
-    let device = if args.cpu {
-        Device::Cpu
-    } else {
-        default_device
-    };
+    let device = candle_examples::device(args.cpu)?;
 
     let default_model = "openai/whisper-tiny.en".to_string();
     let path = std::path::PathBuf::from(default_model.clone());
     let default_revision = "refs/pr/15".to_string();

View File

@@ -1 +1,13 @@
+use candle::{Device, Result};
 
+pub fn device(cpu: bool) -> Result<Device> {
+    if cpu {
+        Ok(Device::Cpu)
+    } else {
+        let device = Device::cuda_if_available(0)?;
+        if !device.is_cuda() {
+            println!("Running on CPU, to run on GPU, build this example with `--features cuda`");
+        }
+        Ok(device)
+    }
+}
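
With this helper in place, every example reduces its device selection to a single call. A minimal consumer-side sketch mirroring the example diffs above (the clap-based `Args` struct with a `cpu` flag is assumed from the examples' existing setup):

    use anyhow::Result;
    use clap::Parser;

    #[derive(Parser)]
    struct Args {
        /// Force CPU execution even in a CUDA-enabled build.
        #[arg(long)]
        cpu: bool,
    }

    fn main() -> Result<()> {
        let args = Args::parse();
        // CUDA device 0 when available, otherwise the CPU; the helper
        // prints a hint about `--features cuda` in the fallback case.
        let device = candle_examples::device(args.cpu)?;
        // ... build tensors and models on `device` from here on ...
        Ok(())
    }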