mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Some CLIP fixes for stable diffusion. (#338)
* Some CLIP fixes for stable diffusion. * Add the avg-pool2d operation on cpu.
This commit is contained in:
@ -37,6 +37,8 @@ pub trait BackendStorage: Sized {
|
|||||||
_params: &crate::conv::ParamsConv1D,
|
_params: &crate::conv::ParamsConv1D,
|
||||||
) -> Result<Self>;
|
) -> Result<Self>;
|
||||||
|
|
||||||
|
/// 2D average pooling over the storage described by the given layout.
///
/// The two tuple arguments are, in order, the kernel size and the stride; the CPU
/// implementation reads each of them as `(height, width)`.
fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self>;
|
||||||
|
|
||||||
fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self>;
|
fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self>;
|
||||||
fn scatter_add(
|
fn scatter_add(
|
||||||
&self,
|
&self,
|
||||||
|
@ -633,6 +633,45 @@ impl Map1 for Affine {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Parameters of a 2D average pooling map: field `0` is the kernel size `(k_h, k_w)`
/// and field `1` is the stride `(s_h, s_w)`.
struct AvgPool2D((usize, usize), (usize, usize));
|
||||||
|
|
||||||
|
impl Map1 for AvgPool2D {
|
||||||
|
fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
|
||||||
|
// https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html
|
||||||
|
let (k_h, k_w) = self.0;
|
||||||
|
let (s_h, s_w) = self.1;
|
||||||
|
let (b_sz, c, h, w) = layout.shape().dims4()?;
|
||||||
|
let stride = layout.stride();
|
||||||
|
let (stride_h, stride_w) = (stride[2], stride[3]);
|
||||||
|
let h_out = (h - k_h) / s_h + 1;
|
||||||
|
let w_out = (w - k_w) / s_w + 1;
|
||||||
|
let src_index = layout.start_offset();
|
||||||
|
let mut dst = vec![T::zero(); b_sz * c * h_out * w_out];
|
||||||
|
let scale = 1f64 / (k_h * k_w) as f64;
|
||||||
|
let scale = T::from_f64(scale);
|
||||||
|
for b_idx in 0..b_sz {
|
||||||
|
let dst = &mut dst[b_idx * c * h_out * w_out..];
|
||||||
|
let src_index = src_index + b_idx * stride[0];
|
||||||
|
for c_idx in 0..c {
|
||||||
|
let dst = &mut dst[c_idx * h_out * w_out..];
|
||||||
|
let src_index = src_index + c_idx * stride[1];
|
||||||
|
for h_idx in 0..h_out {
|
||||||
|
for w_idx in 0..w_out {
|
||||||
|
let mut sum = T::zero();
|
||||||
|
for m in 0..k_h {
|
||||||
|
for n in 0..k_w {
|
||||||
|
sum += src[src_index + m * stride_h + n * stride_w]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
dst[h_idx * w_out + w_idx] = sum * scale;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(dst)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
struct Gather<'a, I: IntDType> {
|
struct Gather<'a, I: IntDType> {
|
||||||
ids: &'a [I],
|
ids: &'a [I],
|
||||||
ids_l: &'a Layout,
|
ids_l: &'a Layout,
|
||||||
@ -1529,6 +1568,15 @@ impl BackendStorage for CpuStorage {
|
|||||||
Affine(mul, add).map(self, layout)
|
Affine(mul, add).map(self, layout)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn avg_pool2d(
|
||||||
|
&self,
|
||||||
|
layout: &Layout,
|
||||||
|
kernel_size: (usize, usize),
|
||||||
|
stride: (usize, usize),
|
||||||
|
) -> Result<Self> {
|
||||||
|
AvgPool2D(kernel_size, stride).map(self, layout)
|
||||||
|
}
|
||||||
|
|
||||||
fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
|
fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
|
||||||
// TODO: Have some generic map for functions that apply on num_traits::Float elements.
|
// TODO: Have some generic map for functions that apply on num_traits::Float elements.
|
||||||
match self {
|
match self {
|
||||||
|
@ -1381,6 +1381,10 @@ impl BackendStorage for CudaStorage {
|
|||||||
Ok(Self { slice, device })
|
Ok(Self { slice, device })
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// 2D average pooling for CUDA storage.
///
/// Not implemented yet: the body is a `todo!()`, so calling this panics.
fn avg_pool2d(
    &self,
    _layout: &Layout,
    _kernel_size: (usize, usize),
    _stride: (usize, usize),
) -> Result<Self> {
    todo!()
}
|
||||||
|
|
||||||
fn index_select(&self, ids: &Self, l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
|
fn index_select(&self, ids: &Self, l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
|
||||||
let device = self.device().clone();
|
let device = self.device().clone();
|
||||||
let slice = IndexSelect(ids, ids_l, dim).map(&self.slice, &device, l)?;
|
let slice = IndexSelect(ids, ids_l, dim).map(&self.slice, &device, l)?;
|
||||||
|
@ -119,6 +119,10 @@ impl crate::backend::BackendStorage for CudaStorage {
|
|||||||
fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()> {
|
fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()> {
|
||||||
Err(Error::NotCompiledWithCudaSupport)
|
Err(Error::NotCompiledWithCudaSupport)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
|
||||||
|
Err(Error::NotCompiledWithCudaSupport)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl crate::backend::BackendDevice for CudaDevice {
|
impl crate::backend::BackendDevice for CudaDevice {
|
||||||
|
@ -268,11 +268,20 @@ impl Storage {
|
|||||||
|
|
||||||
pub(crate) fn avg_pool2d(
|
pub(crate) fn avg_pool2d(
|
||||||
&self,
|
&self,
|
||||||
_layout: &Layout,
|
layout: &Layout,
|
||||||
_kernel_size: (usize, usize),
|
kernel_size: (usize, usize),
|
||||||
_stride: (usize, usize),
|
stride: (usize, usize),
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
todo!()
|
match self {
|
||||||
|
Storage::Cpu(storage) => {
|
||||||
|
let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
|
||||||
|
Ok(Self::Cpu(storage))
|
||||||
|
}
|
||||||
|
Self::Cuda(storage) => {
|
||||||
|
let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
|
||||||
|
Ok(Self::Cuda(storage))
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn upsample_nearest2d(
|
pub(crate) fn upsample_nearest2d(
|
||||||
|
@ -103,7 +103,7 @@ impl ClipTextEmbeddings {
|
|||||||
/// Embeds `xs` with the token embedding, embeds the stored `position_ids` with the
/// position embedding, and returns their broadcasted sum.
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
    let tokens = self.token_embedding.forward(xs)?;
    let positions = self.position_embedding.forward(&self.position_ids)?;
    tokens.broadcast_add(&positions)
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -161,9 +161,9 @@ impl ClipAttention {
|
|||||||
let attn_weights = query_states.matmul(&key_states.transpose(1, 2)?)?;
|
let attn_weights = query_states.matmul(&key_states.transpose(1, 2)?)?;
|
||||||
|
|
||||||
let src_len = key_states.dim(1)?;
|
let src_len = key_states.dim(1)?;
|
||||||
let attn_weights =
|
let attn_weights = attn_weights
|
||||||
(attn_weights.reshape((bsz, self.num_attention_heads, seq_len, src_len))?
|
.reshape((bsz, self.num_attention_heads, seq_len, src_len))?
|
||||||
+ causal_attention_mask)?;
|
.broadcast_add(causal_attention_mask)?;
|
||||||
let attn_weights =
|
let attn_weights =
|
||||||
attn_weights.reshape((bsz * self.num_attention_heads, seq_len, src_len))?;
|
attn_weights.reshape((bsz * self.num_attention_heads, seq_len, src_len))?;
|
||||||
let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
|
let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
|
||||||
@ -287,7 +287,7 @@ impl ClipTextTransformer {
|
|||||||
// https://github.com/huggingface/transformers/blob/674f750a57431222fa2832503a108df3badf1564/src/transformers/models/clip/modeling_clip.py#L678
|
// https://github.com/huggingface/transformers/blob/674f750a57431222fa2832503a108df3badf1564/src/transformers/models/clip/modeling_clip.py#L678
|
||||||
fn build_causal_attention_mask(bsz: usize, seq_len: usize, device: &Device) -> Result<Tensor> {
|
fn build_causal_attention_mask(bsz: usize, seq_len: usize, device: &Device) -> Result<Tensor> {
|
||||||
let mask: Vec<_> = (0..seq_len)
|
let mask: Vec<_> = (0..seq_len)
|
||||||
.flat_map(|i| (0..seq_len).map(move |j| u8::from(j > i)))
|
.flat_map(|i| (0..seq_len).map(move |j| if j > i { f32::MIN } else { 0. }))
|
||||||
.collect();
|
.collect();
|
||||||
let mask = Tensor::from_slice(&mask, (seq_len, seq_len), device)?;
|
let mask = Tensor::from_slice(&mask, (seq_len, seq_len), device)?;
|
||||||
mask.broadcast_as((bsz, seq_len, seq_len))
|
mask.broadcast_as((bsz, seq_len, seq_len))
|
||||||
|
@ -57,13 +57,9 @@ struct Args {
|
|||||||
#[arg(long, value_name = "FILE")]
|
#[arg(long, value_name = "FILE")]
|
||||||
vae_weights: Option<String>,
|
vae_weights: Option<String>,
|
||||||
|
|
||||||
#[arg(
|
#[arg(long, value_name = "FILE")]
|
||||||
long,
|
/// The file specifying the tokenizer to used for tokenization.
|
||||||
value_name = "FILE",
|
tokenizer: String,
|
||||||
default_value = "data/bpe_simple_vocab_16e6.txt"
|
|
||||||
)]
|
|
||||||
/// The file specifying the vocabulary to used for tokenization.
|
|
||||||
vocab_file: String,
|
|
||||||
|
|
||||||
/// The size of the sliced attention or 0 for automatic slicing (disabled by default)
|
/// The size of the sliced attention or 0 for automatic slicing (disabled by default)
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
@ -165,7 +161,7 @@ fn run(args: Args) -> Result<()> {
|
|||||||
height,
|
height,
|
||||||
width,
|
width,
|
||||||
n_steps,
|
n_steps,
|
||||||
vocab_file,
|
tokenizer,
|
||||||
final_image,
|
final_image,
|
||||||
sliced_attention_size,
|
sliced_attention_size,
|
||||||
num_samples,
|
num_samples,
|
||||||
@ -184,7 +180,7 @@ fn run(args: Args) -> Result<()> {
|
|||||||
let scheduler = sd_config.build_scheduler(n_steps)?;
|
let scheduler = sd_config.build_scheduler(n_steps)?;
|
||||||
let device = candle_examples::device(cpu)?;
|
let device = candle_examples::device(cpu)?;
|
||||||
|
|
||||||
let tokenizer = Tokenizer::from_file(vocab_file).map_err(E::msg)?;
|
let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
|
||||||
println!("Running with prompt \"{prompt}\".");
|
println!("Running with prompt \"{prompt}\".");
|
||||||
let tokens = tokenizer
|
let tokens = tokenizer
|
||||||
.encode(prompt, true)
|
.encode(prompt, true)
|
||||||
|
Reference in New Issue
Block a user