mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
Add some 'cuda-if-available' helper function. (#172)
This commit is contained in:
@ -109,6 +109,14 @@ impl Device {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn cuda_if_available(ordinal: usize) -> Result<Self> {
|
||||||
|
if crate::utils::cuda_is_available() {
|
||||||
|
Self::new_cuda(ordinal)
|
||||||
|
} else {
|
||||||
|
Ok(Self::Cpu)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) fn rand_uniform(
|
pub(crate) fn rand_uniform(
|
||||||
&self,
|
&self,
|
||||||
shape: &Shape,
|
shape: &Shape,
|
||||||
|
@ -17,3 +17,10 @@ pub fn has_mkl() -> bool {
|
|||||||
#[cfg(not(feature = "mkl"))]
|
#[cfg(not(feature = "mkl"))]
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn cuda_is_available() -> bool {
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
return true;
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
@ -495,20 +495,7 @@ struct Args {
|
|||||||
|
|
||||||
impl Args {
|
impl Args {
|
||||||
fn build_model_and_tokenizer(&self) -> Result<(BertModel, Tokenizer)> {
|
fn build_model_and_tokenizer(&self) -> Result<(BertModel, Tokenizer)> {
|
||||||
#[cfg(feature = "cuda")]
|
let device = candle_examples::device(self.cpu)?;
|
||||||
let default_device = Device::new_cuda(0)?;
|
|
||||||
|
|
||||||
#[cfg(not(feature = "cuda"))]
|
|
||||||
let default_device = {
|
|
||||||
println!("Running on CPU, to run on GPU, run this example with `--features cuda`");
|
|
||||||
Device::Cpu
|
|
||||||
};
|
|
||||||
|
|
||||||
let device = if self.cpu {
|
|
||||||
Device::Cpu
|
|
||||||
} else {
|
|
||||||
default_device
|
|
||||||
};
|
|
||||||
let default_model = "sentence-transformers/all-MiniLM-L6-v2".to_string();
|
let default_model = "sentence-transformers/all-MiniLM-L6-v2".to_string();
|
||||||
let default_revision = "refs/pr/21".to_string();
|
let default_revision = "refs/pr/21".to_string();
|
||||||
let (model_id, revision) = match (self.model_id.to_owned(), self.revision.to_owned()) {
|
let (model_id, revision) = match (self.model_id.to_owned(), self.revision.to_owned()) {
|
||||||
|
@ -120,20 +120,7 @@ struct Args {
|
|||||||
fn main() -> Result<()> {
|
fn main() -> Result<()> {
|
||||||
let args = Args::parse();
|
let args = Args::parse();
|
||||||
|
|
||||||
#[cfg(feature = "cuda")]
|
let device = candle_examples::device(args.cpu)?;
|
||||||
let default_device = Device::new_cuda(0)?;
|
|
||||||
|
|
||||||
#[cfg(not(feature = "cuda"))]
|
|
||||||
let default_device = {
|
|
||||||
println!("Running on CPU, to run on GPU, run this example with `--features cuda`");
|
|
||||||
Device::Cpu
|
|
||||||
};
|
|
||||||
let device = if args.cpu {
|
|
||||||
Device::Cpu
|
|
||||||
} else {
|
|
||||||
default_device
|
|
||||||
};
|
|
||||||
|
|
||||||
let start = std::time::Instant::now();
|
let start = std::time::Instant::now();
|
||||||
let api = Api::new()?;
|
let api = Api::new()?;
|
||||||
let repo = Repo::with_revision(args.model_id, RepoType::Model, args.revision);
|
let repo = Repo::with_revision(args.model_id, RepoType::Model, args.revision);
|
||||||
|
@ -134,20 +134,7 @@ fn main() -> Result<()> {
|
|||||||
|
|
||||||
let args = Args::parse();
|
let args = Args::parse();
|
||||||
|
|
||||||
#[cfg(feature = "cuda")]
|
let device = candle_examples::device(args.cpu)?;
|
||||||
let default_device = Device::new_cuda(0)?;
|
|
||||||
|
|
||||||
#[cfg(not(feature = "cuda"))]
|
|
||||||
let default_device = {
|
|
||||||
println!("Running on CPU, to run on GPU, run this example with `--features cuda`");
|
|
||||||
Device::Cpu
|
|
||||||
};
|
|
||||||
|
|
||||||
let device = if args.cpu {
|
|
||||||
Device::Cpu
|
|
||||||
} else {
|
|
||||||
default_device
|
|
||||||
};
|
|
||||||
let config = Config::config_7b();
|
let config = Config::config_7b();
|
||||||
let cache = model::Cache::new(!args.no_kv_cache, &config, &device);
|
let cache = model::Cache::new(!args.no_kv_cache, &config, &device);
|
||||||
let dtype = if args.use_f32 { DType::F32 } else { DType::F16 };
|
let dtype = if args.use_f32 { DType::F32 } else { DType::F16 };
|
||||||
|
@ -16,7 +16,7 @@ use musicgen_model::{GenConfig, MusicgenForConditionalGeneration};
|
|||||||
use nn::VarBuilder;
|
use nn::VarBuilder;
|
||||||
|
|
||||||
use anyhow::{Error as E, Result};
|
use anyhow::{Error as E, Result};
|
||||||
use candle::{DType, Device};
|
use candle::DType;
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
|
|
||||||
const DTYPE: DType = DType::F32;
|
const DTYPE: DType = DType::F32;
|
||||||
@ -41,20 +41,7 @@ fn main() -> Result<()> {
|
|||||||
use tokenizers::Tokenizer;
|
use tokenizers::Tokenizer;
|
||||||
|
|
||||||
let args = Args::parse();
|
let args = Args::parse();
|
||||||
#[cfg(feature = "cuda")]
|
let device = candle_examples::device(args.cpu)?;
|
||||||
let default_device = Device::new_cuda(0)?;
|
|
||||||
|
|
||||||
#[cfg(not(feature = "cuda"))]
|
|
||||||
let default_device = {
|
|
||||||
println!("Running on CPU, to run on GPU, run this example with `--features cuda`");
|
|
||||||
Device::Cpu
|
|
||||||
};
|
|
||||||
let device = if args.cpu {
|
|
||||||
Device::Cpu
|
|
||||||
} else {
|
|
||||||
default_device
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut tokenizer = Tokenizer::from_file(args.tokenizer).map_err(E::msg)?;
|
let mut tokenizer = Tokenizer::from_file(args.tokenizer).map_err(E::msg)?;
|
||||||
let _tokenizer = tokenizer.with_padding(None).with_truncation(None);
|
let _tokenizer = tokenizer.with_padding(None).with_truncation(None);
|
||||||
|
|
||||||
|
@ -257,21 +257,7 @@ struct Args {
|
|||||||
|
|
||||||
fn main() -> Result<()> {
|
fn main() -> Result<()> {
|
||||||
let args = Args::parse();
|
let args = Args::parse();
|
||||||
|
let device = candle_examples::device(args.cpu)?;
|
||||||
#[cfg(feature = "cuda")]
|
|
||||||
let default_device = Device::new_cuda(0)?;
|
|
||||||
|
|
||||||
#[cfg(not(feature = "cuda"))]
|
|
||||||
let default_device = {
|
|
||||||
println!("Running on CPU, to run on GPU, run this example with `--features cuda`");
|
|
||||||
Device::Cpu
|
|
||||||
};
|
|
||||||
|
|
||||||
let device = if args.cpu {
|
|
||||||
Device::Cpu
|
|
||||||
} else {
|
|
||||||
default_device
|
|
||||||
};
|
|
||||||
let default_model = "openai/whisper-tiny.en".to_string();
|
let default_model = "openai/whisper-tiny.en".to_string();
|
||||||
let path = std::path::PathBuf::from(default_model.clone());
|
let path = std::path::PathBuf::from(default_model.clone());
|
||||||
let default_revision = "refs/pr/15".to_string();
|
let default_revision = "refs/pr/15".to_string();
|
||||||
|
@ -1 +1,13 @@
|
|||||||
|
use candle::{Device, Result};
|
||||||
|
|
||||||
|
pub fn device(cpu: bool) -> Result<Device> {
|
||||||
|
if cpu {
|
||||||
|
Ok(Device::Cpu)
|
||||||
|
} else {
|
||||||
|
let device = Device::cuda_if_available(0)?;
|
||||||
|
if !device.is_cuda() {
|
||||||
|
println!("Running on CPU, to run on GPU, build this example with `--features cuda`");
|
||||||
|
}
|
||||||
|
Ok(device)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user