mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Enable stable-diffusion 3 on metal. (#2560)
This commit is contained in:
@ -122,6 +122,3 @@ required-features = ["onnx"]
|
||||
[[example]]
|
||||
name = "colpali"
|
||||
required-features = ["pdf2image"]
|
||||
|
||||
[[example]]
|
||||
name = "stable-diffusion-3"
|
@ -30,9 +30,9 @@ struct Args {
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// The CUDA device ID to use.
|
||||
#[arg(long, default_value = "0")]
|
||||
cuda_device_id: usize,
|
||||
/// The GPU device ID to use.
|
||||
#[arg(long, default_value_t = 0)]
|
||||
gpu_device_id: usize,
|
||||
|
||||
/// Enable tracing (generates a trace-timestamp.json file).
|
||||
#[arg(long)]
|
||||
@ -81,7 +81,7 @@ fn run(args: Args) -> Result<()> {
|
||||
prompt,
|
||||
uncond_prompt,
|
||||
cpu,
|
||||
cuda_device_id,
|
||||
gpu_device_id,
|
||||
tracing,
|
||||
use_flash_attn,
|
||||
height,
|
||||
@ -100,11 +100,14 @@ fn run(args: Args) -> Result<()> {
|
||||
None
|
||||
};
|
||||
|
||||
// TODO: Support and test on Metal.
|
||||
let device = if cpu {
|
||||
candle::Device::Cpu
|
||||
} else if candle::utils::cuda_is_available() {
|
||||
candle::Device::new_cuda(gpu_device_id)?
|
||||
} else if candle::utils::metal_is_available() {
|
||||
candle::Device::new_metal(gpu_device_id)?
|
||||
} else {
|
||||
candle::Device::cuda_if_available(cuda_device_id)?
|
||||
candle::Device::Cpu
|
||||
};
|
||||
|
||||
let api = hf_hub::api::sync::Api::new()?;
|
||||
|
@ -31,7 +31,7 @@ pub fn euler_sample(
|
||||
let timestep = (*s_curr) * 1000.0;
|
||||
let noise_pred = mmdit.forward(
|
||||
&Tensor::cat(&[x.clone(), x.clone()], 0)?,
|
||||
&Tensor::full(timestep, (2,), x.device())?.contiguous()?,
|
||||
&Tensor::full(timestep as f32, (2,), x.device())?.contiguous()?,
|
||||
y,
|
||||
context,
|
||||
)?;
|
||||
|
@ -1,9 +1,8 @@
|
||||
use super::with_tracing::{linear, Embedding, Linear};
|
||||
use candle::{Result, Tensor};
|
||||
use candle_nn::{layer_norm, LayerNorm, VarBuilder};
|
||||
use serde::Deserialize;
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
#[derive(Debug, Clone, serde::Deserialize)]
|
||||
pub struct Config {
|
||||
pub vocab_size: usize,
|
||||
pub decoder_vocab_size: Option<usize>,
|
||||
|
Reference in New Issue
Block a user