Use the module trait in stable-diffusion. (#817)

This commit is contained in:
Laurent Mazare
2023-09-11 20:40:07 +01:00
committed by GitHub
parent 59e63d690c
commit c5a058b169
7 changed files with 30 additions and 31 deletions

View File

@ -78,11 +78,7 @@ impl st::View for &Tensor {
}
impl Tensor {
pub fn save_safetensors<P: AsRef<std::path::Path>>(
&self,
name: &str,
filename: P,
) -> Result<()> {
pub fn save_safetensors<P: AsRef<Path>>(&self, name: &str, filename: P) -> Result<()> {
let data = [(name, self.clone())];
Ok(st::serialize_to_file(data, &None, filename.as_ref())?)
}
@ -267,7 +263,7 @@ impl MmapedFile {
/// # Safety
///
/// The unsafe is inherited from [`memmap2::MmapOptions`].
pub unsafe fn new<P: AsRef<std::path::Path>>(p: P) -> Result<Self> {
pub unsafe fn new<P: AsRef<Path>>(p: P) -> Result<Self> {
let p = p.as_ref();
let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
let inner = memmap2::MmapOptions::new()

View File

@ -7,7 +7,7 @@ extern crate intel_mkl_src;
use candle_transformers::models::stable_diffusion;
use anyhow::{Error as E, Result};
use candle::{DType, Device, IndexOp, Tensor, D};
use candle::{DType, Device, IndexOp, Module, Tensor, D};
use clap::Parser;
use tokenizers::Tokenizer;

View File

@ -17,7 +17,7 @@ impl GeGlu {
}
}
impl GeGlu {
impl Module for GeGlu {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let hidden_states_and_gate = self.proj.forward(xs)?.chunk(2, D::Minus1)?;
@ -53,7 +53,7 @@ impl FeedForward {
}
}
impl FeedForward {
impl Module for FeedForward {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let xs = self.project_in.forward(xs)?;
@ -501,8 +501,10 @@ impl AttentionBlock {
xs.reshape((batch, t, self.num_heads, h_times_d / self.num_heads))?
.transpose(1, 2)
}
}
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
impl Module for AttentionBlock {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let in_dtype = xs.dtype();
let residual = xs;

View File

@ -14,7 +14,7 @@ pub enum Activation {
Gelu,
}
impl Activation {
impl Module for Activation {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
match self {
Activation::QuickGelu => xs * nn::ops::sigmoid(&(xs * 1.702f64)?)?,
@ -129,7 +129,7 @@ impl ClipTextEmbeddings {
}
}
impl ClipTextEmbeddings {
impl Module for ClipTextEmbeddings {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let token_embedding = self.token_embedding.forward(xs)?;
let position_embedding = self.position_embedding.forward(&self.position_ids)?;
@ -328,8 +328,8 @@ impl ClipTextTransformer {
}
}
impl ClipTextTransformer {
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
impl Module for ClipTextTransformer {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let (bsz, seq_len) = xs.dims2()?;
let xs = self.embeddings.forward(xs)?;
let causal_attention_mask = Self::build_causal_attention_mask(bsz, seq_len, xs.device())?;

View File

@ -17,8 +17,8 @@ impl TimestepEmbedding {
}
}
impl TimestepEmbedding {
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
impl Module for TimestepEmbedding {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let xs = nn::ops::silu(&self.linear_1.forward(xs)?)?;
self.linear_2.forward(&xs)
}
@ -41,8 +41,8 @@ impl Timesteps {
}
}
impl Timesteps {
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
impl Module for Timesteps {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let half_dim = (self.num_channels / 2) as u32;
let exponent = (Tensor::arange(0, half_dim, xs.device())?.to_dtype(candle::DType::F32)?
* -f64::ln(10000.))?;

View File

@ -5,7 +5,7 @@ use super::attention::{
};
use super::resnet::{ResnetBlock2D, ResnetBlock2DConfig};
use super::utils::{conv2d, Conv2d};
use candle::{Result, Tensor, D};
use candle::{Module, Result, Tensor, D};
use candle_nn as nn;
#[derive(Debug)]
@ -43,7 +43,7 @@ impl Downsample2D {
}
}
impl Downsample2D {
impl Module for Downsample2D {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
match &self.conv {
@ -172,8 +172,8 @@ impl DownEncoderBlock2D {
}
}
impl DownEncoderBlock2D {
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
impl Module for DownEncoderBlock2D {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let mut xs = xs.clone();
for resnet in self.resnets.iter() {
@ -256,8 +256,8 @@ impl UpDecoderBlock2D {
}
}
impl UpDecoderBlock2D {
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
impl Module for UpDecoderBlock2D {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let mut xs = xs.clone();
for resnet in self.resnets.iter() {

View File

@ -132,14 +132,15 @@ impl Encoder {
impl Encoder {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let mut xs = self.conv_in.forward(xs)?;
let mut xs = xs.apply(&self.conv_in)?;
for down_block in self.down_blocks.iter() {
xs = down_block.forward(&xs)?
xs = xs.apply(down_block)?
}
let xs = self.mid_block.forward(&xs, None)?;
let xs = self.conv_norm_out.forward(&xs)?;
let xs = nn::ops::silu(&xs)?;
self.conv_out.forward(&xs)
let xs = self
.mid_block
.forward(&xs, None)?
.apply(&self.conv_norm_out)?;
nn::ops::silu(&xs)?.apply(&self.conv_out)
}
}
@ -302,7 +303,7 @@ impl DiagonalGaussianDistribution {
}
pub fn sample(&self) -> Result<Tensor> {
let sample = Tensor::randn(0., 1f32, self.mean.shape(), self.mean.device());
let sample = self.mean.randn_like(0., 1.);
&self.mean + &self.std * sample
}
}