mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Enable stable-diffusion 3 on metal. (#2560)
This commit is contained in:
@ -30,9 +30,9 @@ struct Args {
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// The CUDA device ID to use.
|
||||
#[arg(long, default_value = "0")]
|
||||
cuda_device_id: usize,
|
||||
/// The GPU device ID to use.
|
||||
#[arg(long, default_value_t = 0)]
|
||||
gpu_device_id: usize,
|
||||
|
||||
/// Enable tracing (generates a trace-timestamp.json file).
|
||||
#[arg(long)]
|
||||
@ -81,7 +81,7 @@ fn run(args: Args) -> Result<()> {
|
||||
prompt,
|
||||
uncond_prompt,
|
||||
cpu,
|
||||
cuda_device_id,
|
||||
gpu_device_id,
|
||||
tracing,
|
||||
use_flash_attn,
|
||||
height,
|
||||
@ -100,11 +100,14 @@ fn run(args: Args) -> Result<()> {
|
||||
None
|
||||
};
|
||||
|
||||
// TODO: Support and test on Metal.
|
||||
let device = if cpu {
|
||||
candle::Device::Cpu
|
||||
} else if candle::utils::cuda_is_available() {
|
||||
candle::Device::new_cuda(gpu_device_id)?
|
||||
} else if candle::utils::metal_is_available() {
|
||||
candle::Device::new_metal(gpu_device_id)?
|
||||
} else {
|
||||
candle::Device::cuda_if_available(cuda_device_id)?
|
||||
candle::Device::Cpu
|
||||
};
|
||||
|
||||
let api = hf_hub::api::sync::Api::new()?;
|
||||
|
Reference in New Issue
Block a user