From c5a058b16954154a68899d26d681adfba003babf Mon Sep 17 00:00:00 2001 From: Laurent Mazare <laurent.mazare@gmail.com> Date: Mon, 11 Sep 2023 20:40:07 +0100 Subject: [PATCH] Use the module trait in stable-diffusion. (#817) --- candle-core/src/safetensors.rs | 8 ++------ candle-examples/examples/stable-diffusion/main.rs | 2 +- .../src/models/stable_diffusion/attention.rs | 8 +++++--- .../src/models/stable_diffusion/clip.rs | 8 ++++---- .../src/models/stable_diffusion/embeddings.rs | 8 ++++---- .../src/models/stable_diffusion/unet_2d_blocks.rs | 12 ++++++------ .../src/models/stable_diffusion/vae.rs | 15 ++++++++------- 7 files changed, 30 insertions(+), 31 deletions(-) diff --git a/candle-core/src/safetensors.rs b/candle-core/src/safetensors.rs index f37bb8ef..d588ea67 100644 --- a/candle-core/src/safetensors.rs +++ b/candle-core/src/safetensors.rs @@ -78,11 +78,7 @@ impl st::View for &Tensor { } impl Tensor { - pub fn save_safetensors<P: AsRef<std::path::Path>>( - &self, - name: &str, - filename: P, - ) -> Result<()> { + pub fn save_safetensors<P: AsRef<std::path::Path>>(&self, name: &str, filename: P) -> Result<()> { let data = [(name, self.clone())]; Ok(st::serialize_to_file(data, &None, filename.as_ref())?) } @@ -267,7 +263,7 @@ impl MmapedFile { /// # Safety /// /// The unsafe is inherited from [`memmap2::MmapOptions`].
- pub unsafe fn new<P: AsRef<std::path::Path>>(p: P) -> Result<MmapedFile> { + pub unsafe fn new<P: AsRef<std::path::Path>>(p: P) -> Result<Self> { let p = p.as_ref(); let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?; let inner = memmap2::MmapOptions::new() diff --git a/candle-examples/examples/stable-diffusion/main.rs b/candle-examples/examples/stable-diffusion/main.rs index 6bce2917..c8b771a0 100644 --- a/candle-examples/examples/stable-diffusion/main.rs +++ b/candle-examples/examples/stable-diffusion/main.rs @@ -7,7 +7,7 @@ extern crate intel_mkl_src; use candle_transformers::models::stable_diffusion; use anyhow::{Error as E, Result}; -use candle::{DType, Device, IndexOp, Tensor, D}; +use candle::{DType, Device, IndexOp, Module, Tensor, D}; use clap::Parser; use tokenizers::Tokenizer; diff --git a/candle-transformers/src/models/stable_diffusion/attention.rs b/candle-transformers/src/models/stable_diffusion/attention.rs index 000cd2fe..2b925cee 100644 --- a/candle-transformers/src/models/stable_diffusion/attention.rs +++ b/candle-transformers/src/models/stable_diffusion/attention.rs @@ -17,7 +17,7 @@ impl GeGlu { } } -impl GeGlu { +impl Module for GeGlu { fn forward(&self, xs: &Tensor) -> Result<Tensor> { let _enter = self.span.enter(); let hidden_states_and_gate = self.proj.forward(xs)?.chunk(2, D::Minus1)?; @@ -53,7 +53,7 @@ impl FeedForward { } } -impl FeedForward { +impl Module for FeedForward { fn forward(&self, xs: &Tensor) -> Result<Tensor> { let _enter = self.span.enter(); let xs = self.project_in.forward(xs)?; @@ -501,8 +501,10 @@ impl AttentionBlock { xs.reshape((batch, t, self.num_heads, h_times_d / self.num_heads))?
.transpose(1, 2) } +} - pub fn forward(&self, xs: &Tensor) -> Result<Tensor> { +impl Module for AttentionBlock { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { let _enter = self.span.enter(); let in_dtype = xs.dtype(); let residual = xs; diff --git a/candle-transformers/src/models/stable_diffusion/clip.rs b/candle-transformers/src/models/stable_diffusion/clip.rs index d26c1c46..397a1cef 100644 --- a/candle-transformers/src/models/stable_diffusion/clip.rs +++ b/candle-transformers/src/models/stable_diffusion/clip.rs @@ -14,7 +14,7 @@ pub enum Activation { Gelu, } -impl Activation { +impl Module for Activation { fn forward(&self, xs: &Tensor) -> Result<Tensor> { match self { Activation::QuickGelu => xs * nn::ops::sigmoid(&(xs * 1.702f64)?)?, @@ -129,7 +129,7 @@ impl ClipTextEmbeddings { } } -impl ClipTextEmbeddings { +impl Module for ClipTextEmbeddings { fn forward(&self, xs: &Tensor) -> Result<Tensor> { let token_embedding = self.token_embedding.forward(xs)?; let position_embedding = self.position_embedding.forward(&self.position_ids)?; @@ -328,8 +328,8 @@ impl ClipTextTransformer { } } -impl ClipTextTransformer { - pub fn forward(&self, xs: &Tensor) -> Result<Tensor> { +impl Module for ClipTextTransformer { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { let (bsz, seq_len) = xs.dims2()?; let xs = self.embeddings.forward(xs)?; let causal_attention_mask = Self::build_causal_attention_mask(bsz, seq_len, xs.device())?; diff --git a/candle-transformers/src/models/stable_diffusion/embeddings.rs b/candle-transformers/src/models/stable_diffusion/embeddings.rs index 97bc61f1..0de5f9a7 100644 --- a/candle-transformers/src/models/stable_diffusion/embeddings.rs +++ b/candle-transformers/src/models/stable_diffusion/embeddings.rs @@ -17,8 +17,8 @@ impl TimestepEmbedding { } } -impl TimestepEmbedding { - pub fn forward(&self, xs: &Tensor) -> Result<Tensor> { +impl Module for TimestepEmbedding { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { let xs = nn::ops::silu(&self.linear_1.forward(xs)?)?; self.linear_2.forward(&xs) } } @@
-41,8 +41,8 @@ impl Timesteps { } } -impl Timesteps { - pub fn forward(&self, xs: &Tensor) -> Result<Tensor> { +impl Module for Timesteps { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { let half_dim = (self.num_channels / 2) as u32; let exponent = (Tensor::arange(0, half_dim, xs.device())?.to_dtype(candle::DType::F32)? * -f64::ln(10000.))?; diff --git a/candle-transformers/src/models/stable_diffusion/unet_2d_blocks.rs b/candle-transformers/src/models/stable_diffusion/unet_2d_blocks.rs index c53bd542..29510cef 100644 --- a/candle-transformers/src/models/stable_diffusion/unet_2d_blocks.rs +++ b/candle-transformers/src/models/stable_diffusion/unet_2d_blocks.rs @@ -5,7 +5,7 @@ use super::attention::{ }; use super::resnet::{ResnetBlock2D, ResnetBlock2DConfig}; use super::utils::{conv2d, Conv2d}; -use candle::{Result, Tensor, D}; +use candle::{Module, Result, Tensor, D}; use candle_nn as nn; #[derive(Debug)] @@ -43,7 +43,7 @@ impl Downsample2D { } } -impl Downsample2D { +impl Module for Downsample2D { fn forward(&self, xs: &Tensor) -> Result<Tensor> { let _enter = self.span.enter(); match &self.conv { @@ -172,8 +172,8 @@ impl DownEncoderBlock2D { } } -impl DownEncoderBlock2D { - pub fn forward(&self, xs: &Tensor) -> Result<Tensor> { +impl Module for DownEncoderBlock2D { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { let _enter = self.span.enter(); let mut xs = xs.clone(); for resnet in self.resnets.iter() { @@ -256,8 +256,8 @@ impl UpDecoderBlock2D { } } -impl UpDecoderBlock2D { - pub fn forward(&self, xs: &Tensor) -> Result<Tensor> { +impl Module for UpDecoderBlock2D { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { let _enter = self.span.enter(); let mut xs = xs.clone(); for resnet in self.resnets.iter() { diff --git a/candle-transformers/src/models/stable_diffusion/vae.rs b/candle-transformers/src/models/stable_diffusion/vae.rs index 48155ada..21709afe 100644 --- a/candle-transformers/src/models/stable_diffusion/vae.rs +++ b/candle-transformers/src/models/stable_diffusion/vae.rs @@ -132,14 +132,15 @@ impl
Encoder { impl Encoder { fn forward(&self, xs: &Tensor) -> Result<Tensor> { - let mut xs = self.conv_in.forward(xs)?; + let mut xs = xs.apply(&self.conv_in)?; for down_block in self.down_blocks.iter() { - xs = down_block.forward(&xs)? + xs = xs.apply(down_block)? } - let xs = self.mid_block.forward(&xs, None)?; - let xs = self.conv_norm_out.forward(&xs)?; - let xs = nn::ops::silu(&xs)?; - self.conv_out.forward(&xs) + let xs = self + .mid_block + .forward(&xs, None)? + .apply(&self.conv_norm_out)?; + nn::ops::silu(&xs)?.apply(&self.conv_out) } } @@ -302,7 +303,7 @@ impl DiagonalGaussianDistribution { } pub fn sample(&self) -> Result<Tensor> { - let sample = Tensor::randn(0., 1f32, self.mean.shape(), self.mean.device()); + let sample = self.mean.randn_like(0., 1.); &self.mean + &self.std * sample } }