Enable stable-diffusion 3 on metal. (#2560)

This commit is contained in:
Laurent Mazare
2024-10-14 08:59:12 +02:00
committed by GitHub
parent f553ab5eb4
commit 3d1dc06cdb
4 changed files with 11 additions and 12 deletions

View File

@ -122,6 +122,3 @@ required-features = ["onnx"]
[[example]]
name = "colpali"
required-features = ["pdf2image"]
[[example]]
name = "stable-diffusion-3"

View File

@ -30,9 +30,9 @@ struct Args {
#[arg(long)]
cpu: bool,
/// The CUDA device ID to use.
#[arg(long, default_value = "0")]
cuda_device_id: usize,
/// The GPU device ID to use.
#[arg(long, default_value_t = 0)]
gpu_device_id: usize,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
@ -81,7 +81,7 @@ fn run(args: Args) -> Result<()> {
prompt,
uncond_prompt,
cpu,
cuda_device_id,
gpu_device_id,
tracing,
use_flash_attn,
height,
@ -100,11 +100,14 @@ fn run(args: Args) -> Result<()> {
None
};
// TODO: Support and test on Metal.
let device = if cpu {
candle::Device::Cpu
} else if candle::utils::cuda_is_available() {
candle::Device::new_cuda(gpu_device_id)?
} else if candle::utils::metal_is_available() {
candle::Device::new_metal(gpu_device_id)?
} else {
candle::Device::cuda_if_available(cuda_device_id)?
candle::Device::Cpu
};
let api = hf_hub::api::sync::Api::new()?;

View File

@ -31,7 +31,7 @@ pub fn euler_sample(
let timestep = (*s_curr) * 1000.0;
let noise_pred = mmdit.forward(
&Tensor::cat(&[x.clone(), x.clone()], 0)?,
&Tensor::full(timestep, (2,), x.device())?.contiguous()?,
&Tensor::full(timestep as f32, (2,), x.device())?.contiguous()?,
y,
context,
)?;

View File

@ -1,9 +1,8 @@
use super::with_tracing::{linear, Embedding, Linear};
use candle::{Result, Tensor};
use candle_nn::{layer_norm, LayerNorm, VarBuilder};
use serde::Deserialize;
#[derive(Debug, Clone, Deserialize)]
#[derive(Debug, Clone, serde::Deserialize)]
pub struct Config {
pub vocab_size: usize,
pub decoder_vocab_size: Option<usize>,