Add some 'cuda-if-available' helper function. (#172)

This commit is contained in:
Laurent Mazare
2023-07-15 08:25:15 +01:00
committed by GitHub
parent 2ddda706bd
commit 66750f9827
8 changed files with 33 additions and 72 deletions

View File

@ -16,7 +16,7 @@ use musicgen_model::{GenConfig, MusicgenForConditionalGeneration};
use nn::VarBuilder;
use anyhow::{Error as E, Result};
use candle::{DType, Device};
use candle::DType;
use clap::Parser;
const DTYPE: DType = DType::F32;
@ -41,20 +41,7 @@ fn main() -> Result<()> {
use tokenizers::Tokenizer;
let args = Args::parse();
#[cfg(feature = "cuda")]
let default_device = Device::new_cuda(0)?;
#[cfg(not(feature = "cuda"))]
let default_device = {
println!("Running on CPU, to run on GPU, run this example with `--features cuda`");
Device::Cpu
};
let device = if args.cpu {
Device::Cpu
} else {
default_device
};
let device = candle_examples::device(args.cpu)?;
let mut tokenizer = Tokenizer::from_file(args.tokenizer).map_err(E::msg)?;
let _tokenizer = tokenizer.with_padding(None).with_truncation(None);