https://github.com/huggingface/candle.git
Use the module trait in stable-diffusion. (#817)
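The commit below switches the hand-rolled `forward` methods in the stable-diffusion code over to candle's `Module` trait. For orientation, the trait is essentially a single method; this is a minimal sketch rather than the exact candle source:

use candle::{Result, Tensor};

// Minimal sketch of the trait being adopted below; the real definition
// is candle::Module in the candle core crate.
pub trait Module {
    fn forward(&self, xs: &Tensor) -> Result<Tensor>;
}

Once a type implements `Module` it composes generically, e.g. via `Tensor::apply`, which the `Encoder` hunk near the end relies on.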
@@ -78,11 +78,7 @@ impl st::View for &Tensor {
 }
 
 impl Tensor {
-    pub fn save_safetensors<P: AsRef<std::path::Path>>(
-        &self,
-        name: &str,
-        filename: P,
-    ) -> Result<()> {
+    pub fn save_safetensors<P: AsRef<Path>>(&self, name: &str, filename: P) -> Result<()> {
         let data = [(name, self.clone())];
         Ok(st::serialize_to_file(data, &None, filename.as_ref())?)
     }
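The collapsed signature writes a single named tensor to a safetensors file. A hypothetical usage sketch (the key "weight" and the filename are made up for illustration):

use candle::{DType, Device, Result, Tensor};

// Hypothetical example: serialize one tensor under the key "weight".
fn save_example() -> Result<()> {
    let t = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
    t.save_safetensors("weight", "weight.safetensors")
}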
@@ -267,7 +263,7 @@ impl MmapedFile {
     /// # Safety
     ///
     /// The unsafe is inherited from [`memmap2::MmapOptions`].
-    pub unsafe fn new<P: AsRef<std::path::Path>>(p: P) -> Result<Self> {
+    pub unsafe fn new<P: AsRef<Path>>(p: P) -> Result<Self> {
         let p = p.as_ref();
         let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
         let inner = memmap2::MmapOptions::new()
@@ -7,7 +7,7 @@ extern crate intel_mkl_src;
 use candle_transformers::models::stable_diffusion;
 
 use anyhow::{Error as E, Result};
-use candle::{DType, Device, IndexOp, Tensor, D};
+use candle::{DType, Device, IndexOp, Module, Tensor, D};
 use clap::Parser;
 use tokenizers::Tokenizer;
 
@@ -17,7 +17,7 @@ impl GeGlu {
     }
 }
 
-impl GeGlu {
+impl Module for GeGlu {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
         let hidden_states_and_gate = self.proj.forward(xs)?.chunk(2, D::Minus1)?;
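The same `impl X` → `impl Module for X` move repeats for every block below; the payoff is that layers can be handled uniformly behind the trait. A hypothetical helper (not part of the commit) showing what that enables:

use candle::{Module, Result, Tensor};

// Hypothetical: run any sequence of Module implementors in order.
fn forward_all(xs: &Tensor, layers: &[&dyn Module]) -> Result<Tensor> {
    let mut xs = xs.clone();
    for layer in layers {
        xs = layer.forward(&xs)?;
    }
    Ok(xs)
}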
@@ -53,7 +53,7 @@ impl FeedForward {
     }
 }
 
-impl FeedForward {
+impl Module for FeedForward {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
         let xs = self.project_in.forward(xs)?;
@@ -501,8 +501,10 @@ impl AttentionBlock {
         xs.reshape((batch, t, self.num_heads, h_times_d / self.num_heads))?
             .transpose(1, 2)
     }
+}
 
-    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+impl Module for AttentionBlock {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
         let in_dtype = xs.dtype();
         let residual = xs;
@@ -14,7 +14,7 @@ pub enum Activation {
     Gelu,
 }
 
-impl Activation {
+impl Module for Activation {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         match self {
             Activation::QuickGelu => xs * nn::ops::sigmoid(&(xs * 1.702f64)?)?,
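For reference, the `QuickGelu` arm in this hunk is the sigmoid approximation of GELU, gelu(x) ≈ x · sigmoid(1.702·x). A scalar version for intuition (hypothetical helper, not in the codebase):

// Scalar QuickGELU: x * sigmoid(1.702 * x), mirroring the tensor
// expression in the hunk above.
fn quick_gelu(x: f64) -> f64 {
    x * (1.0 / (1.0 + (-1.702 * x).exp()))
}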
@@ -129,7 +129,7 @@ impl ClipTextEmbeddings {
     }
 }
 
-impl ClipTextEmbeddings {
+impl Module for ClipTextEmbeddings {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         let token_embedding = self.token_embedding.forward(xs)?;
         let position_embedding = self.position_embedding.forward(&self.position_ids)?;
@@ -328,8 +328,8 @@ impl ClipTextTransformer {
     }
 }
 
-impl ClipTextTransformer {
-    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+impl Module for ClipTextTransformer {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         let (bsz, seq_len) = xs.dims2()?;
         let xs = self.embeddings.forward(xs)?;
         let causal_attention_mask = Self::build_causal_attention_mask(bsz, seq_len, xs.device())?;
@@ -17,8 +17,8 @@ impl TimestepEmbedding {
     }
 }
 
-impl TimestepEmbedding {
-    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+impl Module for TimestepEmbedding {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         let xs = nn::ops::silu(&self.linear_1.forward(xs)?)?;
         self.linear_2.forward(&xs)
     }
@@ -41,8 +41,8 @@ impl Timesteps {
     }
 }
 
-impl Timesteps {
-    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+impl Module for Timesteps {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         let half_dim = (self.num_channels / 2) as u32;
         let exponent = (Tensor::arange(0, half_dim, xs.device())?.to_dtype(candle::DType::F32)?
             * -f64::ln(10000.))?;
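The exponent line above is the start of the standard sinusoidal timestep embedding: channel i gets a frequency roughly exp(-ln(10000)·i/half_dim), later fed through sin/cos pairs. A hypothetical sketch of that frequency schedule (candle's exact normalization may differ, e.g. by a frequency shift):

// Hypothetical: freq_i ≈ exp(-ln(10000) * i / half_dim).
fn frequencies(half_dim: usize) -> Vec<f64> {
    (0..half_dim)
        .map(|i| (-f64::ln(10000.0) * i as f64 / half_dim as f64).exp())
        .collect()
}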
@@ -5,7 +5,7 @@ use super::attention::{
 };
 use super::resnet::{ResnetBlock2D, ResnetBlock2DConfig};
 use super::utils::{conv2d, Conv2d};
-use candle::{Result, Tensor, D};
+use candle::{Module, Result, Tensor, D};
 use candle_nn as nn;
 
 #[derive(Debug)]
@@ -43,7 +43,7 @@ impl Downsample2D {
     }
 }
 
-impl Downsample2D {
+impl Module for Downsample2D {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
         match &self.conv {
@@ -172,8 +172,8 @@ impl DownEncoderBlock2D {
     }
 }
 
-impl DownEncoderBlock2D {
-    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+impl Module for DownEncoderBlock2D {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
         let mut xs = xs.clone();
         for resnet in self.resnets.iter() {
@@ -256,8 +256,8 @@ impl UpDecoderBlock2D {
     }
 }
 
-impl UpDecoderBlock2D {
-    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+impl Module for UpDecoderBlock2D {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
         let mut xs = xs.clone();
         for resnet in self.resnets.iter() {
@@ -132,14 +132,15 @@ impl Encoder {
 
 impl Encoder {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        let mut xs = self.conv_in.forward(xs)?;
+        let mut xs = xs.apply(&self.conv_in)?;
         for down_block in self.down_blocks.iter() {
-            xs = down_block.forward(&xs)?
+            xs = xs.apply(down_block)?
         }
-        let xs = self.mid_block.forward(&xs, None)?;
-        let xs = self.conv_norm_out.forward(&xs)?;
-        let xs = nn::ops::silu(&xs)?;
-        self.conv_out.forward(&xs)
+        let xs = self
+            .mid_block
+            .forward(&xs, None)?
+            .apply(&self.conv_norm_out)?;
+        nn::ops::silu(&xs)?.apply(&self.conv_out)
     }
 }
 
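The `Encoder` rewrite leans on `Tensor::apply`, which puts a module call into method-chaining position. Conceptually it is just forwarding; a stand-in free function (hypothetical, since inherent `Tensor` methods live inside candle):

use candle::{Module, Result, Tensor};

// Conceptual stand-in for xs.apply(&m): just calls m.forward(xs),
// which is what enables xs.apply(&a)?.apply(&b)? chains.
fn apply<M: Module>(xs: &Tensor, m: &M) -> Result<Tensor> {
    m.forward(xs)
}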
@@ -302,7 +303,7 @@ impl DiagonalGaussianDistribution {
     }
 
     pub fn sample(&self) -> Result<Tensor> {
-        let sample = Tensor::randn(0., 1f32, self.mean.shape(), self.mean.device());
+        let sample = self.mean.randn_like(0., 1.);
         &self.mean + &self.std * sample
     }
 }
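`randn_like` draws a standard-normal tensor with the same shape, dtype and device as the receiver, so `sample` keeps the usual reparameterization mean + std · ε while dropping the explicit shape/device bookkeeping. A hedged sketch of that pattern as a free function:

use candle::{Result, Tensor};

// Hypothetical reparameterization helper: eps ~ N(0, 1) shaped like
// `mean`, then mean + std * eps.
fn reparameterize(mean: &Tensor, std: &Tensor) -> Result<Tensor> {
    let eps = mean.randn_like(0., 1.)?;
    mean + (std * &eps)?
}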