#![allow(dead_code)]
//! 2D UNet Building Blocks
//!
use crate::attention::{
AttentionBlock, AttentionBlockConfig, SpatialTransformer, SpatialTransformerConfig,
};
use crate::resnet::{ResnetBlock2D, ResnetBlock2DConfig};
use candle::{Result, Tensor};
use candle_nn as nn;
#[derive(Debug)]
struct Downsample2D {
conv: Option<nn::Conv2d>,
padding: usize,
}
impl Downsample2D {
fn new(
vs: nn::VarBuilder,
in_channels: usize,
use_conv: bool,
out_channels: usize,
padding: usize,
) -> Result<Self> {
let conv = if use_conv {
let config = nn::Conv2dConfig { stride: 2, padding };
let conv = nn::conv2d(in_channels, out_channels, 3, config, vs.pp("conv"))?;
Some(conv)
} else {
None
};
Ok(Downsample2D { conv, padding })
}
}
impl Downsample2D {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
match &self.conv {
None => crate::utils::avg_pool2d(xs), // 2x2 average pooling, stride 2: [2, 2], [2, 2], [0, 0], false, true, None
Some(conv) => {
if self.padding == 0 {
let xs = crate::utils::pad(xs)?; // zero-pad by one on the right and bottom: [0, 1, 0, 1], "constant", Some(0.)
conv.forward(&xs)
} else {
conv.forward(xs)
}
}
}
}
}
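// `Downsample2D` usage sketch (a minimal illustration; `vb`, `xs` and the 320-channel
// width are assumed bindings, not part of this module):
//
//     let down = Downsample2D::new(vb.pp("downsamplers").pp("0"), 320, true, 320, 1)?;
//     // The 3x3, stride-2 convolution halves even spatial sizes:
//     // (b, 320, h, w) -> (b, 320, h / 2, w / 2).
//     let ys = down.forward(&xs)?;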
// This does not support the conv-transpose mode.
#[derive(Debug)]
struct Upsample2D {
conv: nn::Conv2d,
}
impl Upsample2D {
fn new(vs: nn::VarBuilder, in_channels: usize, out_channels: usize) -> Result<Self> {
let config = nn::Conv2dConfig {
padding: 1,
..Default::default()
};
let conv = nn::conv2d(in_channels, out_channels, 3, config, vs.pp("conv"))?;
Ok(Self { conv })
}
}
impl Upsample2D {
fn forward(&self, xs: &Tensor, size: Option<(usize, usize)>) -> Result<Tensor> {
let xs = match size {
None => {
// The commented-out call below does not work and it is tricky to express
// "no fixed dimensions", so work around this by doubling the input size explicitly.
// xs.upsample_nearest2d(&[], Some(2.), Some(2.))
let (_bsize, _channels, _h, _w) = xs.dims4()?;
crate::utils::upsample_nearest2d(xs)? // nearest upsample to [2 * h, 2 * w], Some(2.), Some(2.)
}
Some((_h, _w)) => crate::utils::upsample_nearest2d(xs)?, // nearest upsample to [h, w], None, None
};
self.conv.forward(&xs)
}
}
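// `Upsample2D` usage sketch (illustrative; `vb` and `xs` are assumed bindings):
//
//     let up = Upsample2D::new(vb.pp("upsamplers").pp("0"), 320, 320)?;
//     // With `size: None`, the input is nearest-upsampled by a factor of 2 and then
//     // passed through the 3x3 convolution: (b, 320, h, w) -> (b, 320, 2 * h, 2 * w).
//     let ys = up.forward(&xs, None)?;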
#[derive(Debug, Clone, Copy)]
pub struct DownEncoderBlock2DConfig {
pub num_layers: usize,
pub resnet_eps: f64,
pub resnet_groups: usize,
pub output_scale_factor: f64,
pub add_downsample: bool,
pub downsample_padding: usize,
}
impl Default for DownEncoderBlock2DConfig {
fn default() -> Self {
Self {
num_layers: 1,
resnet_eps: 1e-6,
resnet_groups: 32,
output_scale_factor: 1.,
add_downsample: true,
downsample_padding: 1,
}
}
}
#[derive(Debug)]
pub struct DownEncoderBlock2D {
resnets: Vec<ResnetBlock2D>,
downsampler: Option<Downsample2D>,
pub config: DownEncoderBlock2DConfig,
}
impl DownEncoderBlock2D {
pub fn new(
vs: nn::VarBuilder,
in_channels: usize,
out_channels: usize,
config: DownEncoderBlock2DConfig,
) -> Result<Self> {
let resnets: Vec<_> = {
let vs = vs.pp("resnets");
let conv_cfg = ResnetBlock2DConfig {
eps: config.resnet_eps,
out_channels: Some(out_channels),
groups: config.resnet_groups,
output_scale_factor: config.output_scale_factor,
temb_channels: None,
..Default::default()
};
(0..(config.num_layers))
.map(|i| {
let in_channels = if i == 0 { in_channels } else { out_channels };
ResnetBlock2D::new(vs.pp(&i.to_string()), in_channels, conv_cfg)
})
.collect::<Result<Vec<_>>>()?
};
let downsampler = if config.add_downsample {
let downsample = Downsample2D::new(
vs.pp("downsamplers").pp("0"),
out_channels,
true,
out_channels,
config.downsample_padding,
)?;
Some(downsample)
} else {
None
};
Ok(Self {
resnets,
downsampler,
config,
})
}
}
impl DownEncoderBlock2D {
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let mut xs = xs.clone();
for resnet in self.resnets.iter() {
xs = resnet.forward(&xs, None)?
}
match &self.downsampler {
Some(downsampler) => downsampler.forward(&xs),
None => Ok(xs),
}
}
}
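// `DownEncoderBlock2D` usage sketch, e.g. as one stage of a VAE-style encoder
// (channel sizes and variable names are illustrative):
//
//     let cfg = DownEncoderBlock2DConfig { num_layers: 2, ..Default::default() };
//     let block = DownEncoderBlock2D::new(vb.pp("down_blocks").pp("0"), 128, 128, cfg)?;
//     // num_layers resnet blocks, then (with add_downsample) a stride-2 downsample.
//     let ys = block.forward(&xs)?;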
#[derive(Debug, Clone, Copy)]
pub struct UpDecoderBlock2DConfig {
pub num_layers: usize,
pub resnet_eps: f64,
pub resnet_groups: usize,
pub output_scale_factor: f64,
pub add_upsample: bool,
}
impl Default for UpDecoderBlock2DConfig {
fn default() -> Self {
Self {
num_layers: 1,
resnet_eps: 1e-6,
resnet_groups: 32,
output_scale_factor: 1.,
add_upsample: true,
}
}
}
#[derive(Debug)]
pub struct UpDecoderBlock2D {
resnets: Vec<ResnetBlock2D>,
upsampler: Option<Upsample2D>,
pub config: UpDecoderBlock2DConfig,
}
impl UpDecoderBlock2D {
pub fn new(
vs: nn::VarBuilder,
in_channels: usize,
out_channels: usize,
config: UpDecoderBlock2DConfig,
) -> Result<Self> {
let resnets: Vec<_> = {
let vs = vs.pp("resnets");
let conv_cfg = ResnetBlock2DConfig {
out_channels: Some(out_channels),
eps: config.resnet_eps,
groups: config.resnet_groups,
output_scale_factor: config.output_scale_factor,
temb_channels: None,
..Default::default()
};
(0..(config.num_layers))
.map(|i| {
let in_channels = if i == 0 { in_channels } else { out_channels };
ResnetBlock2D::new(vs.pp(&i.to_string()), in_channels, conv_cfg)
})
.collect::<Result<Vec<_>>>()?
};
let upsampler = if config.add_upsample {
let upsample =
Upsample2D::new(vs.pp("upsamplers").pp("0"), out_channels, out_channels)?;
Some(upsample)
} else {
None
};
Ok(Self {
resnets,
upsampler,
config,
})
}
}
impl UpDecoderBlock2D {
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let mut xs = xs.clone();
for resnet in self.resnets.iter() {
xs = resnet.forward(&xs, None)?
}
match &self.upsampler {
Some(upsampler) => upsampler.forward(&xs, None),
None => Ok(xs),
}
}
}
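// `UpDecoderBlock2D` usage sketch, e.g. as one stage of a VAE-style decoder
// (channel sizes and variable names are illustrative):
//
//     let cfg = UpDecoderBlock2DConfig { num_layers: 3, ..Default::default() };
//     let block = UpDecoderBlock2D::new(vb.pp("up_blocks").pp("0"), 512, 512, cfg)?;
//     // num_layers resnet blocks, then (with add_upsample) a 2x nearest upsample.
//     let ys = block.forward(&xs)?;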
#[derive(Debug, Clone, Copy)]
pub struct UNetMidBlock2DConfig {
pub num_layers: usize,
pub resnet_eps: f64,
pub resnet_groups: Option<usize>,
pub attn_num_head_channels: Option<usize>,
// attention_type "default"
pub output_scale_factor: f64,
}
impl Default for UNetMidBlock2DConfig {
fn default() -> Self {
Self {
num_layers: 1,
resnet_eps: 1e-6,
resnet_groups: Some(32),
attn_num_head_channels: Some(1),
output_scale_factor: 1.,
}
}
}
#[derive(Debug)]
pub struct UNetMidBlock2D {
resnet: ResnetBlock2D,
attn_resnets: Vec<(AttentionBlock, ResnetBlock2D)>,
pub config: UNetMidBlock2DConfig,
}
impl UNetMidBlock2D {
pub fn new(
vs: nn::VarBuilder,
in_channels: usize,
temb_channels: Option<usize>,
config: UNetMidBlock2DConfig,
) -> Result<Self> {
let vs_resnets = vs.pp("resnets");
let vs_attns = vs.pp("attentions");
let resnet_groups = config
.resnet_groups
.unwrap_or_else(|| usize::min(in_channels / 4, 32));
let resnet_cfg = ResnetBlock2DConfig {
eps: config.resnet_eps,
groups: resnet_groups,
output_scale_factor: config.output_scale_factor,
temb_channels,
..Default::default()
};
let resnet = ResnetBlock2D::new(vs_resnets.pp("0"), in_channels, resnet_cfg)?;
let attn_cfg = AttentionBlockConfig {
num_head_channels: config.attn_num_head_channels,
num_groups: resnet_groups,
rescale_output_factor: config.output_scale_factor,
eps: config.resnet_eps,
};
let mut attn_resnets = vec![];
for index in 0..config.num_layers {
let attn = AttentionBlock::new(vs_attns.pp(&index.to_string()), in_channels, attn_cfg)?;
let resnet = ResnetBlock2D::new(
vs_resnets.pp(&(index + 1).to_string()),
in_channels,
resnet_cfg,
)?;
attn_resnets.push((attn, resnet))
}
Ok(Self {
resnet,
attn_resnets,
config,
})
}
pub fn forward(&self, xs: &Tensor, temb: Option<&Tensor>) -> Result<Tensor> {
let mut xs = self.resnet.forward(xs, temb)?;
for (attn, resnet) in self.attn_resnets.iter() {
xs = resnet.forward(&attn.forward(&xs)?, temb)?
}
Ok(xs)
}
}
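// `UNetMidBlock2D` usage sketch (illustrative; the block keeps the channel count
// unchanged, and passing `None` for the time embedding matches the auto-encoder use):
//
//     let block = UNetMidBlock2D::new(vb.pp("mid_block"), 512, None, Default::default())?;
//     // resnet, then num_layers x (self-attention -> resnet).
//     let ys = block.forward(&xs, None)?;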
#[derive(Debug, Clone, Copy)]
pub struct UNetMidBlock2DCrossAttnConfig {
pub num_layers: usize,
pub resnet_eps: f64,
pub resnet_groups: Option<usize>,
pub attn_num_head_channels: usize,
// attention_type "default"
pub output_scale_factor: f64,
pub cross_attn_dim: usize,
pub sliced_attention_size: Option<usize>,
pub use_linear_projection: bool,
}
impl Default for UNetMidBlock2DCrossAttnConfig {
fn default() -> Self {
Self {
num_layers: 1,
resnet_eps: 1e-6,
resnet_groups: Some(32),
attn_num_head_channels: 1,
output_scale_factor: 1.,
cross_attn_dim: 1280,
sliced_attention_size: None, // Sliced attention disabled
use_linear_projection: false,
}
}
}
#[derive(Debug)]
pub struct UNetMidBlock2DCrossAttn {
resnet: ResnetBlock2D,
attn_resnets: Vec<(SpatialTransformer, ResnetBlock2D)>,
pub config: UNetMidBlock2DCrossAttnConfig,
}
impl UNetMidBlock2DCrossAttn {
pub fn new(
vs: nn::VarBuilder,
in_channels: usize,
temb_channels: Option<usize>,
config: UNetMidBlock2DCrossAttnConfig,
) -> Result<Self> {
let vs_resnets = vs.pp("resnets");
let vs_attns = vs.pp("attentions");
let resnet_groups = config
.resnet_groups
.unwrap_or_else(|| usize::min(in_channels / 4, 32));
let resnet_cfg = ResnetBlock2DConfig {
eps: config.resnet_eps,
groups: resnet_groups,
output_scale_factor: config.output_scale_factor,
temb_channels,
..Default::default()
};
let resnet = ResnetBlock2D::new(vs_resnets.pp("0"), in_channels, resnet_cfg)?;
let n_heads = config.attn_num_head_channels;
let attn_cfg = SpatialTransformerConfig {
depth: 1,
num_groups: resnet_groups,
context_dim: Some(config.cross_attn_dim),
sliced_attention_size: config.sliced_attention_size,
use_linear_projection: config.use_linear_projection,
};
let mut attn_resnets = vec![];
for index in 0..config.num_layers {
let attn = SpatialTransformer::new(
vs_attns.pp(&index.to_string()),
in_channels,
n_heads,
in_channels / n_heads,
attn_cfg,
)?;
let resnet = ResnetBlock2D::new(
vs_resnets.pp(&(index + 1).to_string()),
in_channels,
resnet_cfg,
)?;
attn_resnets.push((attn, resnet))
}
Ok(Self {
resnet,
attn_resnets,
config,
})
}
pub fn forward(
&self,
xs: &Tensor,
temb: Option<&Tensor>,
encoder_hidden_states: Option<&Tensor>,
) -> Result<Tensor> {
let mut xs = self.resnet.forward(xs, temb)?;
for (attn, resnet) in self.attn_resnets.iter() {
xs = resnet.forward(&attn.forward(&xs, encoder_hidden_states)?, temb)?
}
Ok(xs)
}
}
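// `UNetMidBlock2DCrossAttn` usage sketch (illustrative; `temb` is a time embedding and
// `text_emb` a context tensor whose last dimension matches `cross_attn_dim`):
//
//     let block =
//         UNetMidBlock2DCrossAttn::new(vb.pp("mid_block"), 1280, Some(1280), Default::default())?;
//     // resnet, then num_layers x (cross-attention transformer -> resnet).
//     let ys = block.forward(&xs, Some(&temb), Some(&text_emb))?;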
#[derive(Debug, Clone, Copy)]
pub struct DownBlock2DConfig {
pub num_layers: usize,
pub resnet_eps: f64,
// resnet_time_scale_shift: "default"
// resnet_act_fn: "swish"
pub resnet_groups: usize,
pub output_scale_factor: f64,
pub add_downsample: bool,
pub downsample_padding: usize,
}
impl Default for DownBlock2DConfig {
fn default() -> Self {
Self {
num_layers: 1,
resnet_eps: 1e-6,
resnet_groups: 32,
output_scale_factor: 1.,
add_downsample: true,
downsample_padding: 1,
}
}
}
#[derive(Debug)]
pub struct DownBlock2D {
resnets: Vec<ResnetBlock2D>,
downsampler: Option<Downsample2D>,
pub config: DownBlock2DConfig,
}
impl DownBlock2D {
pub fn new(
vs: nn::VarBuilder,
in_channels: usize,
out_channels: usize,
temb_channels: Option<usize>,
config: DownBlock2DConfig,
) -> Result<Self> {
let vs_resnets = vs.pp("resnets");
let resnet_cfg = ResnetBlock2DConfig {
out_channels: Some(out_channels),
eps: config.resnet_eps,
output_scale_factor: config.output_scale_factor,
temb_channels,
..Default::default()
};
let resnets = (0..config.num_layers)
.map(|i| {
let in_channels = if i == 0 { in_channels } else { out_channels };
ResnetBlock2D::new(vs_resnets.pp(&i.to_string()), in_channels, resnet_cfg)
})
.collect::<Result<Vec<_>>>()?;
let downsampler = if config.add_downsample {
let downsampler = Downsample2D::new(
vs.pp("downsamplers").pp("0"),
out_channels,
true,
out_channels,
config.downsample_padding,
)?;
Some(downsampler)
} else {
None
};
Ok(Self {
resnets,
downsampler,
config,
})
}
pub fn forward(&self, xs: &Tensor, temb: Option<&Tensor>) -> Result<(Tensor, Vec<Tensor>)> {
let mut xs = xs.clone();
let mut output_states = vec![];
for resnet in self.resnets.iter() {
xs = resnet.forward(&xs, temb)?;
output_states.push(xs.clone());
}
let xs = match &self.downsampler {
Some(downsampler) => {
let xs = downsampler.forward(&xs)?;
output_states.push(xs.clone());
xs
}
None => xs,
};
Ok((xs, output_states))
}
}
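// `DownBlock2D` usage sketch (illustrative): besides the downsampled output, the block
// returns the intermediate states that later feed the skip connections of the matching
// up block.
//
//     let cfg = DownBlock2DConfig { num_layers: 2, ..Default::default() };
//     let block = DownBlock2D::new(vb.pp("down_blocks").pp("3"), 1280, 1280, Some(1280), cfg)?;
//     let (ys, skip_states) = block.forward(&xs, Some(&temb))?;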
#[derive(Debug, Clone, Copy)]
pub struct CrossAttnDownBlock2DConfig {
pub downblock: DownBlock2DConfig,
pub attn_num_head_channels: usize,
pub cross_attention_dim: usize,
// attention_type: "default"
pub sliced_attention_size: Option<usize>,
pub use_linear_projection: bool,
}
impl Default for CrossAttnDownBlock2DConfig {
fn default() -> Self {
Self {
downblock: Default::default(),
attn_num_head_channels: 1,
cross_attention_dim: 1280,
sliced_attention_size: None,
use_linear_projection: false,
}
}
}
#[derive(Debug)]
pub struct CrossAttnDownBlock2D {
downblock: DownBlock2D,
attentions: Vec<SpatialTransformer>,
pub config: CrossAttnDownBlock2DConfig,
}
impl CrossAttnDownBlock2D {
pub fn new(
vs: nn::VarBuilder,
in_channels: usize,
out_channels: usize,
temb_channels: Option<usize>,
config: CrossAttnDownBlock2DConfig,
) -> Result<Self> {
let downblock = DownBlock2D::new(
vs.clone(),
in_channels,
out_channels,
temb_channels,
config.downblock,
)?;
let n_heads = config.attn_num_head_channels;
let cfg = SpatialTransformerConfig {
depth: 1,
context_dim: Some(config.cross_attention_dim),
num_groups: config.downblock.resnet_groups,
sliced_attention_size: config.sliced_attention_size,
use_linear_projection: config.use_linear_projection,
};
let vs_attn = vs.pp("attentions");
let attentions = (0..config.downblock.num_layers)
.map(|i| {
SpatialTransformer::new(
vs_attn.pp(&i.to_string()),
out_channels,
n_heads,
out_channels / n_heads,
cfg,
)
})
.collect::<Result<Vec<_>>>()?;
Ok(Self {
downblock,
attentions,
config,
})
}
pub fn forward(
&self,
xs: &Tensor,
temb: Option<&Tensor>,
encoder_hidden_states: Option<&Tensor>,
) -> Result<(Tensor, Vec<Tensor>)> {
let mut output_states = vec![];
let mut xs = xs.clone();
for (resnet, attn) in self.downblock.resnets.iter().zip(self.attentions.iter()) {
xs = resnet.forward(&xs, temb)?;
xs = attn.forward(&xs, encoder_hidden_states)?;
output_states.push(xs.clone());
}
let xs = match &self.downblock.downsampler {
Some(downsampler) => {
let xs = downsampler.forward(&xs)?;
output_states.push(xs.clone());
xs
}
None => xs,
};
Ok((xs, output_states))
}
}
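// `CrossAttnDownBlock2D` usage sketch (illustrative; `text_emb` must have
// `cross_attention_dim` as its last dimension):
//
//     let cfg = CrossAttnDownBlock2DConfig { attn_num_head_channels: 8, ..Default::default() };
//     let block =
//         CrossAttnDownBlock2D::new(vb.pp("down_blocks").pp("0"), 320, 320, Some(1280), cfg)?;
//     let (ys, skip_states) = block.forward(&xs, Some(&temb), Some(&text_emb))?;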
#[derive(Debug, Clone, Copy)]
pub struct UpBlock2DConfig {
pub num_layers: usize,
pub resnet_eps: f64,
// resnet_time_scale_shift: "default"
// resnet_act_fn: "swish"
pub resnet_groups: usize,
pub output_scale_factor: f64,
pub add_upsample: bool,
}
impl Default for UpBlock2DConfig {
fn default() -> Self {
Self {
num_layers: 1,
resnet_eps: 1e-6,
resnet_groups: 32,
output_scale_factor: 1.,
add_upsample: true,
}
}
}
#[derive(Debug)]
pub struct UpBlock2D {
pub resnets: Vec<ResnetBlock2D>,
upsampler: Option<Upsample2D>,
pub config: UpBlock2DConfig,
}
impl UpBlock2D {
pub fn new(
vs: nn::VarBuilder,
in_channels: usize,
prev_output_channels: usize,
out_channels: usize,
temb_channels: Option<usize>,
config: UpBlock2DConfig,
) -> Result<Self> {
let vs_resnets = vs.pp("resnets");
let resnet_cfg = ResnetBlock2DConfig {
out_channels: Some(out_channels),
temb_channels,
eps: config.resnet_eps,
output_scale_factor: config.output_scale_factor,
..Default::default()
};
let resnets = (0..config.num_layers)
.map(|i| {
let res_skip_channels = if i == config.num_layers - 1 {
in_channels
} else {
out_channels
};
let resnet_in_channels = if i == 0 {
prev_output_channels
} else {
out_channels
};
let in_channels = resnet_in_channels + res_skip_channels;
ResnetBlock2D::new(vs_resnets.pp(&i.to_string()), in_channels, resnet_cfg)
})
.collect::<Result<Vec<_>>>()?;
let upsampler = if config.add_upsample {
let upsampler =
Upsample2D::new(vs.pp("upsamplers").pp("0"), out_channels, out_channels)?;
Some(upsampler)
} else {
None
};
Ok(Self {
resnets,
upsampler,
config,
})
}
pub fn forward(
&self,
xs: &Tensor,
res_xs: &[Tensor],
temb: Option<&Tensor>,
upsample_size: Option<(usize, usize)>,
) -> Result<Tensor> {
let mut xs = xs.clone();
for (index, resnet) in self.resnets.iter().enumerate() {
xs = Tensor::cat(&[&xs, &res_xs[res_xs.len() - index - 1]], 1)?;
xs = resnet.forward(&xs, temb)?;
}
match &self.upsampler {
Some(upsampler) => upsampler.forward(&xs, upsample_size),
None => Ok(xs),
}
}
}
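// `UpBlock2D` usage sketch (illustrative): `skip_states` are the residuals saved by the
// matching down block; they are consumed in reverse order, one per resnet, by
// concatenation along the channel dimension.
//
//     let cfg = UpBlock2DConfig { num_layers: 3, ..Default::default() };
//     let block = UpBlock2D::new(vb.pp("up_blocks").pp("3"), 320, 640, 320, Some(1280), cfg)?;
//     let ys = block.forward(&xs, &skip_states, Some(&temb), None)?;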
#[derive(Debug, Clone, Copy)]
pub struct CrossAttnUpBlock2DConfig {
pub upblock: UpBlock2DConfig,
pub attn_num_head_channels: usize,
pub cross_attention_dim: usize,
// attention_type: "default"
pub sliced_attention_size: Option<usize>,
pub use_linear_projection: bool,
}
impl Default for CrossAttnUpBlock2DConfig {
fn default() -> Self {
Self {
upblock: Default::default(),
attn_num_head_channels: 1,
cross_attention_dim: 1280,
sliced_attention_size: None,
use_linear_projection: false,
}
}
}
#[derive(Debug)]
pub struct CrossAttnUpBlock2D {
pub upblock: UpBlock2D,
pub attentions: Vec<SpatialTransformer>,
pub config: CrossAttnUpBlock2DConfig,
}
impl CrossAttnUpBlock2D {
pub fn new(
vs: nn::VarBuilder,
in_channels: usize,
prev_output_channels: usize,
out_channels: usize,
temb_channels: Option<usize>,
config: CrossAttnUpBlock2DConfig,
) -> Result<Self> {
let upblock = UpBlock2D::new(
vs.clone(),
in_channels,
prev_output_channels,
out_channels,
temb_channels,
config.upblock,
)?;
let n_heads = config.attn_num_head_channels;
let cfg = SpatialTransformerConfig {
depth: 1,
context_dim: Some(config.cross_attention_dim),
num_groups: config.upblock.resnet_groups,
sliced_attention_size: config.sliced_attention_size,
use_linear_projection: config.use_linear_projection,
};
let vs_attn = vs.pp("attentions");
let attentions = (0..config.upblock.num_layers)
.map(|i| {
SpatialTransformer::new(
vs_attn.pp(&i.to_string()),
out_channels,
n_heads,
out_channels / n_heads,
cfg,
)
})
.collect::<Result<Vec<_>>>()?;
Ok(Self {
upblock,
attentions,
config,
})
}
pub fn forward(
&self,
xs: &Tensor,
res_xs: &[Tensor],
temb: Option<&Tensor>,
upsample_size: Option<(usize, usize)>,
encoder_hidden_states: Option<&Tensor>,
) -> Result<Tensor> {
let mut xs = xs.clone();
for (index, resnet) in self.upblock.resnets.iter().enumerate() {
xs = Tensor::cat(&[&xs, &res_xs[res_xs.len() - index - 1]], 1)?;
xs = resnet.forward(&xs, temb)?;
xs = self.attentions[index].forward(&xs, encoder_hidden_states)?;
}
match &self.upblock.upsampler {
Some(upsampler) => upsampler.forward(&xs, upsample_size),
None => Ok(xs),
}
}
}
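// `CrossAttnUpBlock2D` usage sketch (illustrative; mirrors `UpBlock2D` but interleaves a
// cross-attention transformer after each resnet):
//
//     let cfg = CrossAttnUpBlock2DConfig { attn_num_head_channels: 8, ..Default::default() };
//     let block =
//         CrossAttnUpBlock2D::new(vb.pp("up_blocks").pp("1"), 320, 1280, 640, Some(1280), cfg)?;
//     let ys = block.forward(&xs, &skip_states, Some(&temb), None, Some(&text_emb))?;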