Mirror of https://github.com/huggingface/candle.git (synced 2025-06-17 02:58:50 +00:00)
Add a stable diffusion example (#328)
* Start adding a stable-diffusion example.
* Proper computation of the causal mask (sketched below).
* Add the chunk operation.
* Work in progress: port the attention module.
* Add some dummy modules for conv2d and group-norm, get the attention module to compile.
* Re-enable the 2d convolution.
* Add the embeddings module.
* Add the resnet module.
* Add the unet blocks.
* Add the unet.
* And add the variational auto-encoder.
* Use the pad function from utils.
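The commit message mentions a proper computation of the causal mask; as a point of reference, here is a minimal sketch of how such a lower-triangular mask can be built with candle's public tensor API. The helper name `causal_mask`, the `u8` encoding, and the masking convention (1 marks the positions to suppress before the softmax) are illustrative assumptions, not the commit's actual code.

```rust
use candle::{Device, Result, Tensor};

// Hypothetical sketch: build a (t, t) mask where entry (i, j) is 1 for j > i,
// i.e. for the future positions an attention layer should not look at.
fn causal_mask(t: usize, device: &Device) -> Result<Tensor> {
    let mask: Vec<u8> = (0..t)
        .flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
        .collect();
    Tensor::from_slice(&mask, (t, t), device)
}
```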
candle-examples/examples/stable-diffusion/embeddings.rs (new file, +65)
@@ -0,0 +1,65 @@
#![allow(dead_code)]
use candle::{Result, Tensor, D};
use candle_nn as nn;

#[derive(Debug)]
pub struct TimestepEmbedding {
    linear_1: nn::Linear,
    linear_2: nn::Linear,
}
impl TimestepEmbedding {
    // act_fn: "silu"
    pub fn new(vs: nn::VarBuilder, channel: usize, time_embed_dim: usize) -> Result<Self> {
        let linear_1 = nn::linear(channel, time_embed_dim, vs.pp("linear_1"))?;
        let linear_2 = nn::linear(time_embed_dim, time_embed_dim, vs.pp("linear_2"))?;
        Ok(Self { linear_1, linear_2 })
    }

    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let xs = nn::ops::silu(&self.linear_1.forward(xs)?)?;
        self.linear_2.forward(&xs)
    }
}
#[derive(Debug)]
pub struct Timesteps {
    num_channels: usize,
    flip_sin_to_cos: bool,
    downscale_freq_shift: f64,
}

impl Timesteps {
    pub fn new(num_channels: usize, flip_sin_to_cos: bool, downscale_freq_shift: f64) -> Self {
        Self {
            num_channels,
            flip_sin_to_cos,
            downscale_freq_shift,
        }
    }

    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let half_dim = (self.num_channels / 2) as u32;
        // Standard sinusoidal embedding: per-channel frequencies decay
        // geometrically from 1 toward 1/10000 across the first half of the
        // channels.
        let exponent =
            (Tensor::arange(0, half_dim, xs.device())?.to_dtype(xs.dtype())? * -f64::ln(10000.))?;
        let exponent = (exponent / (half_dim as f64 - self.downscale_freq_shift))?;
        let emb = exponent.exp()?;
        // emb = timesteps[:, None].float() * emb[None, :]
        let emb = (xs.unsqueeze(D::Minus1)? * emb.unsqueeze(0)?)?;
        let (cos, sin) = (emb.cos()?, emb.sin()?);
        let emb = if self.flip_sin_to_cos {
            Tensor::cat(&[&cos, &sin], D::Minus1)?
        } else {
            Tensor::cat(&[&sin, &cos], D::Minus1)?
        };
        if self.num_channels % 2 == 1 {
            // Odd channel count: zero-pad the last dimension by one entry.
            crate::utils::pad(&emb) // ([0, 1, 0, 0], 'constant', None)
        } else {
            Ok(emb)
        }
    }
}
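To make the two modules above concrete, here is a minimal usage sketch, assuming the items from embeddings.rs are in scope. It is an illustrative assumption rather than part of the commit: the `main` function, the toy timestep values, the 320/1280 widths (the base and time-embedding sizes of the SD 1.x unet), and the zero-initialized `VarBuilder` are placeholders; real weights would come from a converted checkpoint.

```rust
use candle::{DType, Device, Result, Tensor};
use candle_nn::VarBuilder;

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Three toy diffusion timesteps (placeholder values).
    let t = Tensor::new(&[0f32, 500., 999.], &device)?;
    // Sinusoidal embedding over 320 channels, cos-first as in stable diffusion.
    let ts = Timesteps::new(320, true, 0.);
    let emb = ts.forward(&t)?; // shape (3, 320)
    // Project to the unet's time-embedding width with the two-layer MLP.
    let vs = VarBuilder::zeros(DType::F32, &device);
    let te = TimestepEmbedding::new(vs.pp("time_embedding"), 320, 1280)?;
    let out = te.forward(&emb)?; // shape (3, 1280)
    println!("{:?}", out.shape());
    Ok(())
}
```

One design note: the `flip_sin_to_cos` flag mirrors diffusers, which emits the cosine half first for stable diffusion, while the classic transformer embedding is sin-first; keeping the flag lets the port match either weight layout.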