Use the module trait in stable-diffusion. (#817)

Laurent Mazare
2023-09-11 20:40:07 +01:00
committed by GitHub
parent 59e63d690c
commit c5a058b169
7 changed files with 30 additions and 31 deletions
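The diff below retires per-type inherent forward methods in favor of implementations of candle's Module trait, giving every stable-diffusion layer one shared interface. As a reading aid, here is a minimal sketch of the trait and the Tensor::apply helper it enables, paraphrased from candle-core rather than copied from this commit:

use candle::{Result, Tensor};

// The whole trait is a single method:
pub trait Module {
    fn forward(&self, xs: &Tensor) -> Result<Tensor>;
}

// Tensor::apply(&self, m) simply calls m.forward(self), which is what lets
// the Encoder rewrite in the last file below say xs.apply(&self.conv_in)
// instead of self.conv_in.forward(&xs).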

View File

@@ -78,11 +78,7 @@ impl st::View for &Tensor {
 }
 
 impl Tensor {
-    pub fn save_safetensors<P: AsRef<std::path::Path>>(
-        &self,
-        name: &str,
-        filename: P,
-    ) -> Result<()> {
+    pub fn save_safetensors<P: AsRef<Path>>(&self, name: &str, filename: P) -> Result<()> {
         let data = [(name, self.clone())];
         Ok(st::serialize_to_file(data, &None, filename.as_ref())?)
     }
@@ -267,7 +263,7 @@ impl MmapedFile {
     /// # Safety
     ///
     /// The unsafe is inherited from [`memmap2::MmapOptions`].
-    pub unsafe fn new<P: AsRef<std::path::Path>>(p: P) -> Result<Self> {
+    pub unsafe fn new<P: AsRef<Path>>(p: P) -> Result<Self> {
        let p = p.as_ref();
        let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
        let inner = memmap2::MmapOptions::new()
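
Both hunks in this first file only shorten std::path::Path to Path, presumably via a use std::path::Path import added elsewhere in the file; the public API is unchanged. A quick usage sketch of save_safetensors, with made-up tensor and file names:

use candle::{DType, Device, Result, Tensor};

fn demo() -> Result<()> {
    let t = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
    t.save_safetensors("weight", "weight.safetensors")
}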

View File

@@ -7,7 +7,7 @@ extern crate intel_mkl_src;
 use candle_transformers::models::stable_diffusion;
 
 use anyhow::{Error as E, Result};
-use candle::{DType, Device, IndexOp, Tensor, D};
+use candle::{DType, Device, IndexOp, Module, Tensor, D};
 use clap::Parser;
 use tokenizers::Tokenizer;
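
The example binary only gains the Module import: trait methods are callable only when the trait is in scope, so the forward calls the example already makes now resolve through the trait implementations added below. Schematically, with hypothetical model and tensor names:

use candle::Module; // without this, text_model.forward(&tokens) would not compile
let embeddings = text_model.forward(&tokens)?;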

View File

@@ -17,7 +17,7 @@ impl GeGlu {
     }
 }
 
-impl GeGlu {
+impl Module for GeGlu {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
         let hidden_states_and_gate = self.proj.forward(xs)?.chunk(2, D::Minus1)?;
@@ -53,7 +53,7 @@ impl FeedForward {
     }
 }
 
-impl FeedForward {
+impl Module for FeedForward {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
         let xs = self.project_in.forward(xs)?;
@@ -501,8 +501,10 @@ impl AttentionBlock {
         xs.reshape((batch, t, self.num_heads, h_times_d / self.num_heads))?
             .transpose(1, 2)
     }
+}
 
-    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+impl Module for AttentionBlock {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
         let in_dtype = xs.dtype();
         let residual = xs;
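
With GeGlu, FeedForward, and AttentionBlock all behind Module (note the added closing brace that splits AttentionBlock's inherent impl from the new trait impl), callers can treat them uniformly. A small equivalence sketch, where block is any of these types:

let ys = block.forward(&xs)?; // trait call, needs use candle::Module in scope
let ys = xs.apply(&block)?;   // same thing, tensor-first style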

View File

@@ -14,7 +14,7 @@ pub enum Activation {
     Gelu,
 }
 
-impl Activation {
+impl Module for Activation {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         match self {
             Activation::QuickGelu => xs * nn::ops::sigmoid(&(xs * 1.702f64)?)?,
@@ -129,7 +129,7 @@ impl ClipTextEmbeddings {
     }
 }
 
-impl ClipTextEmbeddings {
+impl Module for ClipTextEmbeddings {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         let token_embedding = self.token_embedding.forward(xs)?;
         let position_embedding = self.position_embedding.forward(&self.position_ids)?;
@@ -328,8 +328,8 @@ impl ClipTextTransformer {
     }
 }
 
-impl ClipTextTransformer {
-    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+impl Module for ClipTextTransformer {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         let (bsz, seq_len) = xs.dims2()?;
         let xs = self.embeddings.forward(xs)?;
         let causal_attention_mask = Self::build_causal_attention_mask(bsz, seq_len, xs.device())?;
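
ClipTextTransformer::forward also loses its pub qualifier here: visibility now comes from the trait itself, and any code generic over Module picks these types up for free. For instance, a hypothetical helper (not part of this diff):

use candle::{Module, Result, Tensor};

fn run_all<M: Module>(layers: &[M], xs: &Tensor) -> Result<Tensor> {
    let mut xs = xs.clone();
    for layer in layers {
        xs = layer.forward(&xs)?;
    }
    Ok(xs)
}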

View File

@@ -17,8 +17,8 @@ impl TimestepEmbedding {
     }
 }
 
-impl TimestepEmbedding {
-    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+impl Module for TimestepEmbedding {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         let xs = nn::ops::silu(&self.linear_1.forward(xs)?)?;
         self.linear_2.forward(&xs)
     }
@@ -41,8 +41,8 @@ impl Timesteps {
     }
 }
 
-impl Timesteps {
-    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+impl Module for Timesteps {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         let half_dim = (self.num_channels / 2) as u32;
         let exponent = (Tensor::arange(0, half_dim, xs.device())?.to_dtype(candle::DType::F32)?
             * -f64::ln(10000.))?;
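
The context lines above begin the standard sinusoidal timestep embedding: channel i gets frequency 10000^(-i / half_dim), computed as exp(i * -ln(10000) / half_dim). The remainder of the method falls outside this hunk, so the continuation below is an assumption based on the usual formulation, not code from this commit:

let exponent = (exponent / half_dim as f64)?;      // assumed: divide by half_dim
let freqs = exponent.exp()?;                       // freqs[i] = 10000^(-i / half_dim)
let emb = (freqs * timestep)?;                     // timestep: f64, hypothetical scalar input
let emb = Tensor::cat(&[emb.sin()?, emb.cos()?], candle::D::Minus1)?;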

View File

@@ -5,7 +5,7 @@ use super::attention::{
 };
 use super::resnet::{ResnetBlock2D, ResnetBlock2DConfig};
 use super::utils::{conv2d, Conv2d};
-use candle::{Result, Tensor, D};
+use candle::{Module, Result, Tensor, D};
 use candle_nn as nn;
 
 #[derive(Debug)]
@@ -43,7 +43,7 @@ impl Downsample2D {
     }
 }
 
-impl Downsample2D {
+impl Module for Downsample2D {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
         match &self.conv {
@@ -172,8 +172,8 @@ impl DownEncoderBlock2D {
     }
 }
 
-impl DownEncoderBlock2D {
-    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+impl Module for DownEncoderBlock2D {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
         let mut xs = xs.clone();
         for resnet in self.resnets.iter() {
@@ -256,8 +256,8 @@ impl UpDecoderBlock2D {
     }
 }
 
-impl UpDecoderBlock2D {
-    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+impl Module for UpDecoderBlock2D {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
         let mut xs = xs.clone();
         for resnet in self.resnets.iter() {

View File

@@ -132,14 +132,15 @@ impl Encoder {
 impl Encoder {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        let mut xs = self.conv_in.forward(xs)?;
+        let mut xs = xs.apply(&self.conv_in)?;
         for down_block in self.down_blocks.iter() {
-            xs = down_block.forward(&xs)?
+            xs = xs.apply(down_block)?
         }
-        let xs = self.mid_block.forward(&xs, None)?;
-        let xs = self.conv_norm_out.forward(&xs)?;
-        let xs = nn::ops::silu(&xs)?;
-        self.conv_out.forward(&xs)
+        let xs = self
+            .mid_block
+            .forward(&xs, None)?
+            .apply(&self.conv_norm_out)?;
+        nn::ops::silu(&xs)?.apply(&self.conv_out)
     }
 }
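
This Encoder rewrite is the payoff of the trait change: conv_in, the down blocks, conv_norm_out, and conv_out all implement Module now, so the forward pass chains left to right through apply instead of nesting forward calls. The two styles are equivalent; a sketch with placeholder layer names:

let ys = norm.forward(&conv.forward(&xs)?)?; // nested inherent calls
let ys = xs.apply(&conv)?.apply(&norm)?;     // Module-based chaining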
@@ -302,7 +303,7 @@ impl DiagonalGaussianDistribution {
     }
 
     pub fn sample(&self) -> Result<Tensor> {
-        let sample = Tensor::randn(0., 1f32, self.mean.shape(), self.mean.device());
+        let sample = self.mean.randn_like(0., 1.);
         &self.mean + &self.std * sample
     }
 }
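
randn_like is a convenience: it samples a fresh standard-normal tensor matching the receiver's shape, dtype, and device, removing the repeated shape/device plumbing of the explicit Tensor::randn call. A hedged equivalence sketch, assuming t is an F32 tensor (randn_like follows the receiver's dtype, while the old call pinned F32):

let a = t.randn_like(0., 1.)?;
let b = Tensor::randn(0f32, 1f32, t.shape(), t.device())?; // same distribution as a

Note also that the diff keeps sample as a Result<Tensor> with no ? on either version; candle's arithmetic operators accept Result operands, which is why &self.std * sample typechecks as-is.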