mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Add some 'cuda-if-available' helper function. (#172)
This commit is contained in:
@ -16,7 +16,7 @@ use musicgen_model::{GenConfig, MusicgenForConditionalGeneration};
|
||||
use nn::VarBuilder;
|
||||
|
||||
use anyhow::{Error as E, Result};
|
||||
use candle::{DType, Device};
|
||||
use candle::DType;
|
||||
use clap::Parser;
|
||||
|
||||
const DTYPE: DType = DType::F32;
|
||||
@ -41,20 +41,7 @@ fn main() -> Result<()> {
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
let args = Args::parse();
|
||||
#[cfg(feature = "cuda")]
|
||||
let default_device = Device::new_cuda(0)?;
|
||||
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
let default_device = {
|
||||
println!("Running on CPU, to run on GPU, run this example with `--features cuda`");
|
||||
Device::Cpu
|
||||
};
|
||||
let device = if args.cpu {
|
||||
Device::Cpu
|
||||
} else {
|
||||
default_device
|
||||
};
|
||||
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let mut tokenizer = Tokenizer::from_file(args.tokenizer).map_err(E::msg)?;
|
||||
let _tokenizer = tokenizer.with_padding(None).with_truncation(None);
|
||||
|
||||
|
Reference in New Issue
Block a user