Files
candle/candle-examples/examples/stable-diffusion/embeddings.rs
Laurent Mazare d34039e352 Add a stable diffusion example (#328)
* Start adding a stable-diffusion example.

* Proper computation of the causal mask.

* Add the chunk operation.

* Work in progress: port the attention module.

* Add some dummy modules for conv2d and group-norm, get the attention module to compile.

* Re-enable the 2d convolution.

* Add the embeddings module.

* Add the resnet module.

* Add the unet blocks.

* Add the unet.

* And add the variational auto-encoder.

* Use the pad function from utils.
2023-08-06 17:49:43 +01:00

66 lines
1.9 KiB
Rust

#![allow(dead_code)]
use candle::{Result, Tensor, D};
use candle_nn as nn;
/// Two-layer projection for timestep features.
///
/// `forward` applies `linear_1`, a SiLU activation, then `linear_2`.
#[derive(Debug)]
pub struct TimestepEmbedding {
    /// First linear layer, built with `channel` inputs and `time_embed_dim` outputs.
    linear_1: nn::Linear,
    /// Second linear layer, built with `time_embed_dim` inputs and outputs.
    linear_2: nn::Linear,
}
impl TimestepEmbedding {
    /// Creates the two linear layers from the variable builder `vs`.
    ///
    /// The layers are looked up under the `"linear_1"` and `"linear_2"`
    /// prefixes of `vs`. The first maps `channel` to `time_embed_dim`
    /// features, the second maps `time_embed_dim` to `time_embed_dim`.
    /// The activation used between them is fixed (act_fn: "silu").
    ///
    /// # Errors
    ///
    /// Returns the error from `nn::linear` if either layer cannot be built.
    pub fn new(vs: nn::VarBuilder, channel: usize, time_embed_dim: usize) -> Result<Self> {
        Ok(Self {
            linear_1: nn::linear(channel, time_embed_dim, vs.pp("linear_1"))?,
            linear_2: nn::linear(time_embed_dim, time_embed_dim, vs.pp("linear_2"))?,
        })
    }
}
impl TimestepEmbedding {
    /// Applies `linear_1`, then SiLU, then `linear_2` to `xs`.
    ///
    /// # Errors
    ///
    /// Propagates any error raised by the linear layers or the activation.
    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let projected = self.linear_1.forward(xs)?;
        let activated = nn::ops::silu(&projected)?;
        self.linear_2.forward(&activated)
    }
}
/// Configuration for the sinusoidal timestep embedding computed by `forward`.
#[derive(Debug)]
pub struct Timesteps {
    /// Requested embedding width; `forward` uses `num_channels / 2`
    /// frequencies and pads the result when this value is odd.
    num_channels: usize,
    /// When `true`, `forward` concatenates cosines before sines;
    /// otherwise sines come first.
    flip_sin_to_cos: bool,
    /// Subtracted from the half dimension in the divisor of the
    /// frequency exponent.
    downscale_freq_shift: f64,
}
impl Timesteps {
pub fn new(num_channels: usize, flip_sin_to_cos: bool, downscale_freq_shift: f64) -> Self {
Self {
num_channels,
flip_sin_to_cos,
downscale_freq_shift,
}
}
}
impl Timesteps {
    /// Computes sinusoidal embeddings for the timestep values in `xs`.
    ///
    /// With `half_dim = num_channels / 2`, the frequencies are
    /// `exp(-ln(10000) * i / (half_dim - downscale_freq_shift))` for
    /// `i` in `0..half_dim`. A trailing axis is appended to `xs`, the result
    /// is multiplied (with broadcasting) by the frequencies, and the cosine
    /// and sine of that product are concatenated along the last axis:
    /// cosine first when `flip_sin_to_cos` is set, sine first otherwise.
    /// When `num_channels` is odd the result is passed through
    /// `crate::utils::pad`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the underlying tensor operations.
    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        // NOTE(review): `as u32` truncates silently; assumes `num_channels / 2`
        // fits in a `u32`.
        let half_dim = (self.num_channels / 2) as u32;
        let exponent =
            (Tensor::arange(0, half_dim, xs.device())?.to_dtype(xs.dtype())? * -f64::ln(10000.))?;
        let exponent = (exponent / (half_dim as f64 - self.downscale_freq_shift))?;
        let freqs = exponent.exp()?;
        // emb = timesteps[:, None].float() * emb[None, :]
        // The two operands have different shapes (`(.., 1)` and `(1, half_dim)`).
        // The plain `*` operator on tensors expects identical shapes, so the
        // product has to go through the explicit broadcasting variant.
        let emb = xs.unsqueeze(D::Minus1)?.broadcast_mul(&freqs.unsqueeze(0)?)?;
        let (cos, sin) = (emb.cos()?, emb.sin()?);
        let emb = if self.flip_sin_to_cos {
            Tensor::cat(&[&cos, &sin], D::Minus1)?
        } else {
            Tensor::cat(&[&sin, &cos], D::Minus1)?
        };
        if self.num_channels % 2 == 1 {
            // Reference implementation pads with ([0, 1, 0, 0], 'constant', None);
            // presumably `utils::pad` mirrors that — confirm against the helper.
            crate::utils::pad(&emb)
        } else {
            Ok(emb)
        }
    }
}