Mirror of https://github.com/huggingface/candle.git, synced 2025-06-19 11:56:45 +00:00

* add mmdit of stable diffusion 3; lint; add comments
* correct a misplaced comment
* fix cargo fmt
* fix clippy error
* use bail! instead of assert!
* use get_on_dim in splitting qkv
198 lines | 5.1 KiB | Rust
use candle::{bail, DType, Module, Result, Tensor};
use candle_nn as nn;

pub struct PatchEmbedder {
    proj: nn::Conv2d,
}

impl PatchEmbedder {
    pub fn new(
        patch_size: usize,
        in_channels: usize,
        embed_dim: usize,
        vb: nn::VarBuilder,
    ) -> Result<Self> {
        let proj = nn::conv2d(
            in_channels,
            embed_dim,
            patch_size,
            nn::Conv2dConfig {
                stride: patch_size,
                ..Default::default()
            },
            vb.pp("proj"),
        )?;

        Ok(Self { proj })
    }
}

impl Module for PatchEmbedder {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let x = self.proj.forward(x)?;

        // flatten spatial dim and transpose to channels last
        let (b, c, h, w) = x.dims4()?;
        x.reshape((b, c, h * w))?.transpose(1, 2)
    }
}
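
// Usage sketch (illustrative; the concrete numbers are assumptions based on
// the SD3-medium configuration: patch_size = 2, 16 latent channels, hidden
// size 1536, and a hypothetical VarBuilder `vb` rooted at the MMDiT weights):
//
//     let embedder = PatchEmbedder::new(2, 16, 1536, vb.pp("x_embedder"))?;
//     // (1, 16, 64, 64) latent --conv, stride 2--> (1, 1536, 32, 32)
//     // --flatten + transpose--> (1, 1024, 1536) token sequence
//     let tokens = embedder.forward(&latent)?;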

pub struct Unpatchifier {
    patch_size: usize,
    out_channels: usize,
}

impl Unpatchifier {
    pub fn new(patch_size: usize, out_channels: usize) -> Result<Self> {
        Ok(Self {
            patch_size,
            out_channels,
        })
    }

    pub fn unpatchify(&self, x: &Tensor, h: usize, w: usize) -> Result<Tensor> {
        // number of patches along each spatial dim, given h and w in latent pixels
        let h = (h + 1) / self.patch_size;
        let w = (w + 1) / self.patch_size;

        let x = x.reshape((
            x.dim(0)?,
            h,
            w,
            self.patch_size,
            self.patch_size,
            self.out_channels,
        ))?;
        let x = x.permute((0, 5, 1, 3, 2, 4))?; // "nhwpqc->nchpwq"
        x.reshape((
            x.dim(0)?,
            self.out_channels,
            self.patch_size * h,
            self.patch_size * w,
        ))
    }
}
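
// Shape walkthrough (illustrative, continuing the patch_size = 2, 16-channel
// example above): a (1, 1024, 64) token tensor predicted for a 64x64 latent
// reshapes to (1, 32, 32, 2, 2, 16), permutes to channels-first per the
// "nhwpqc->nchpwq" note, and folds back into a (1, 16, 64, 64) image,
// inverting the patch layout produced by PatchEmbedder.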

pub struct PositionEmbedder {
    pos_embed: Tensor,
    patch_size: usize,
    pos_embed_max_size: usize,
}

impl PositionEmbedder {
    pub fn new(
        hidden_size: usize,
        patch_size: usize,
        pos_embed_max_size: usize,
        vb: nn::VarBuilder,
    ) -> Result<Self> {
        let pos_embed = vb.get(
            (1, pos_embed_max_size * pos_embed_max_size, hidden_size),
            "pos_embed",
        )?;
        Ok(Self {
            pos_embed,
            patch_size,
            pos_embed_max_size,
        })
    }

    pub fn get_cropped_pos_embed(&self, h: usize, w: usize) -> Result<Tensor> {
        let h = (h + 1) / self.patch_size;
        let w = (w + 1) / self.patch_size;

        if h > self.pos_embed_max_size || w > self.pos_embed_max_size {
            bail!("Input size is too large for the position embedding")
        }

        // center-crop an h x w window out of the full positional grid
        let top = (self.pos_embed_max_size - h) / 2;
        let left = (self.pos_embed_max_size - w) / 2;

        let pos_embed =
            self.pos_embed
                .reshape((1, self.pos_embed_max_size, self.pos_embed_max_size, ()))?;
        let pos_embed = pos_embed.narrow(1, top, h)?.narrow(2, left, w)?;
        pos_embed.reshape((1, h * w, ()))
    }
}
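
// Note: pos_embed is a single learned table covering a square grid of
// pos_embed_max_size x pos_embed_max_size patch positions. get_cropped_pos_embed
// views it as that 2D grid, center-crops the h x w window the current latent
// actually occupies, and flattens it back to (1, h * w, hidden_size), so one
// table serves multiple resolutions.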

pub struct TimestepEmbedder {
    mlp: nn::Sequential,
    frequency_embedding_size: usize,
}

impl TimestepEmbedder {
    pub fn new(
        hidden_size: usize,
        frequency_embedding_size: usize,
        vb: nn::VarBuilder,
    ) -> Result<Self> {
        let mlp = nn::seq()
            .add(nn::linear(
                frequency_embedding_size,
                hidden_size,
                vb.pp("mlp.0"),
            )?)
            .add(nn::Activation::Silu)
            .add(nn::linear(hidden_size, hidden_size, vb.pp("mlp.2"))?);

        Ok(Self {
            mlp,
            frequency_embedding_size,
        })
    }

    fn timestep_embedding(t: &Tensor, dim: usize, max_period: f64) -> Result<Tensor> {
        if dim % 2 != 0 {
            bail!("Embedding dimension must be even")
        }

        if t.dtype() != DType::F32 && t.dtype() != DType::F64 {
            bail!("Input tensor must be floating point")
        }

        // geometric ladder of frequencies: freqs[i] = max_period^(-i / half)
        let half = dim / 2;
        let freqs = Tensor::arange(0f32, half as f32, t.device())?
            .to_dtype(candle::DType::F32)?
            .mul(&Tensor::full(
                (-f64::ln(max_period) / half as f64) as f32,
                half,
                t.device(),
            )?)?
            .exp()?;

        // outer product of t and freqs via a (N, 1) x (1, half) matmul
        let args = t
            .unsqueeze(1)?
            .to_dtype(candle::DType::F32)?
            .matmul(&freqs.unsqueeze(0)?)?;
        let embedding = Tensor::cat(&[args.cos()?, args.sin()?], 1)?;
        embedding.to_dtype(candle::DType::F16)
    }
}

impl Module for TimestepEmbedder {
    fn forward(&self, t: &Tensor) -> Result<Tensor> {
        let t_freq = Self::timestep_embedding(t, self.frequency_embedding_size, 10000.0)?;
        self.mlp.forward(&t_freq)
    }
}
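
// In formulas, timestep_embedding is the standard sinusoidal encoding: for
// i in 0..dim/2,
//
//     freq_i = exp(-ln(max_period) * i / (dim / 2)) = max_period^(-i / (dim / 2))
//     emb(t) = [cos(t * freq_0), .., cos(t * freq_{dim/2 - 1}),
//               sin(t * freq_0), .., sin(t * freq_{dim/2 - 1})]
//
// with the cos half first, matching the cat order above; it is computed in
// f32 and cast to f16 before the SiLU MLP projects it to hidden_size.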

pub struct VectorEmbedder {
    mlp: nn::Sequential,
}

impl VectorEmbedder {
    pub fn new(input_dim: usize, hidden_size: usize, vb: nn::VarBuilder) -> Result<Self> {
        let mlp = nn::seq()
            .add(nn::linear(input_dim, hidden_size, vb.pp("mlp.0"))?)
            .add(nn::Activation::Silu)
            .add(nn::linear(hidden_size, hidden_size, vb.pp("mlp.2"))?);

        Ok(Self { mlp })
    }
}

impl Module for VectorEmbedder {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        self.mlp.forward(x)
    }
}
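
// Note: VectorEmbedder is a plain two-layer SiLU MLP from input_dim to
// hidden_size. In SD3 this kind of embedder is presumably used for pooled
// conditioning vectors (e.g. the pooled CLIP text embedding); what actually
// flows through it is decided by the surrounding model, not by this file.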