diff --git a/candle-core/src/backend.rs b/candle-core/src/backend.rs
index 345db0e5..307b56dc 100644
--- a/candle-core/src/backend.rs
+++ b/candle-core/src/backend.rs
@@ -37,6 +37,8 @@ pub trait BackendStorage: Sized {
         _params: &crate::conv::ParamsConv1D,
     ) -> Result<Self>;
 
+    fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self>;
+
     fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self>;
     fn scatter_add(
         &self,
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 4aa2f880..401a2c0e 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -633,6 +633,49 @@ impl Map1 for Affine {
     }
 }
 
+struct AvgPool2D((usize, usize), (usize, usize));
+
+impl Map1 for AvgPool2D {
+    fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
+        // https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html
+        let (k_h, k_w) = self.0;
+        let (s_h, s_w) = self.1;
+        let (b_sz, c, h, w) = layout.shape().dims4()?;
+        let stride = layout.stride();
+        let (stride_h, stride_w) = (stride[2], stride[3]);
+        let h_out = (h - k_h) / s_h + 1;
+        let w_out = (w - k_w) / s_w + 1;
+        let src_index = layout.start_offset();
+        let mut dst = vec![T::zero(); b_sz * c * h_out * w_out];
+        let scale = 1f64 / (k_h * k_w) as f64;
+        let scale = T::from_f64(scale);
+        for b_idx in 0..b_sz {
+            let dst = &mut dst[b_idx * c * h_out * w_out..];
+            let src_index = src_index + b_idx * stride[0];
+            for c_idx in 0..c {
+                let dst = &mut dst[c_idx * h_out * w_out..];
+                let src_index = src_index + c_idx * stride[1];
+                for h_idx in 0..h_out {
+                    for w_idx in 0..w_out {
+                        let mut sum = T::zero();
+                        for m in 0..k_h {
+                            for n in 0..k_w {
+                                // Offset the window origin by the stride so each
+                                // output cell averages its own k_h x k_w patch.
+                                let m = s_h * h_idx + m;
+                                let n = s_w * w_idx + n;
+                                sum += src[src_index + m * stride_h + n * stride_w];
+                            }
+                        }
+                        dst[h_idx * w_out + w_idx] = sum * scale;
+                    }
+                }
+            }
+        }
+        Ok(dst)
+    }
+}
+
 struct Gather<'a, I: IntDType> {
     ids: &'a [I],
     ids_l: &'a Layout,
@@ -1529,6 +1572,15 @@ impl BackendStorage for CpuStorage {
         Affine(mul, add).map(self, layout)
     }
 
+    fn avg_pool2d(
+        &self,
+        layout: &Layout,
+        kernel_size: (usize, usize),
+        stride: (usize, usize),
+    ) -> Result<Self> {
+        AvgPool2D(kernel_size, stride).map(self, layout)
+    }
+
     fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
         // TODO: Have some generic map for functions that apply on num_traits::Float elements.
         match self {
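A note on the CPU kernel above: each output cell averages the `k_h x k_w` window whose origin advances by the stride, which is why the inner indices are rebased with `s_h * h_idx + m` and `s_w * w_idx + n`. Below is a minimal standalone sketch of the same arithmetic on a contiguous 1x1x4x4 buffer; the function name and signature are illustrative only, not part of the candle API:

```rust
// Reference average pooling over a contiguous HxW f32 buffer with a
// square kernel `k` and stride `s`, mirroring the loop structure above.
fn avg_pool2d_ref(src: &[f32], h: usize, w: usize, k: usize, s: usize) -> Vec<f32> {
    let (h_out, w_out) = ((h - k) / s + 1, (w - k) / s + 1);
    let scale = 1f32 / (k * k) as f32;
    let mut dst = vec![0f32; h_out * w_out];
    for h_idx in 0..h_out {
        for w_idx in 0..w_out {
            let mut sum = 0f32;
            for m in 0..k {
                for n in 0..k {
                    // Window origin offset by the stride, matching the
                    // `s_h * h_idx + m` / `s_w * w_idx + n` indexing above.
                    sum += src[(h_idx * s + m) * w + (w_idx * s + n)];
                }
            }
            dst[h_idx * w_out + w_idx] = sum * scale;
        }
    }
    dst
}

fn main() {
    let src: Vec<f32> = (0..16).map(|v| v as f32).collect();
    // Each 2x2 window of the 4x4 grid 0..15 averages to 2.5, 4.5, 10.5, 12.5.
    assert_eq!(avg_pool2d_ref(&src, 4, 4, 2, 2), vec![2.5, 4.5, 10.5, 12.5]);
}
```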
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index 7b4b358d..e71ecfce 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -1381,6 +1381,10 @@ impl BackendStorage for CudaStorage {
         Ok(Self { slice, device })
     }
 
+    fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
+        todo!()
+    }
+
     fn index_select(&self, ids: &Self, l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
         let device = self.device().clone();
         let slice = IndexSelect(ids, ids_l, dim).map(&self.slice, &device, l)?;
diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs
index 17d4a22e..2d5f955c 100644
--- a/candle-core/src/dummy_cuda_backend.rs
+++ b/candle-core/src/dummy_cuda_backend.rs
@@ -119,6 +119,10 @@ impl crate::backend::BackendStorage for CudaStorage {
     fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()> {
         Err(Error::NotCompiledWithCudaSupport)
     }
+
+    fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
 }
 
 impl crate::backend::BackendDevice for CudaDevice {
diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs
index cbca4fc4..47df689c 100644
--- a/candle-core/src/storage.rs
+++ b/candle-core/src/storage.rs
@@ -268,11 +268,20 @@ impl Storage {
 
     pub(crate) fn avg_pool2d(
         &self,
-        _layout: &Layout,
-        _kernel_size: (usize, usize),
-        _stride: (usize, usize),
+        layout: &Layout,
+        kernel_size: (usize, usize),
+        stride: (usize, usize),
     ) -> Result<Self> {
-        todo!()
+        match self {
+            Self::Cpu(storage) => {
+                let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
+                Ok(Self::Cpu(storage))
+            }
+            Self::Cuda(storage) => {
+                let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
+                Ok(Self::Cuda(storage))
+            }
+        }
     }
 
     pub(crate) fn upsample_nearest2d(
diff --git a/candle-examples/examples/stable-diffusion/clip.rs b/candle-examples/examples/stable-diffusion/clip.rs
index 227660b1..ac9843f7 100644
--- a/candle-examples/examples/stable-diffusion/clip.rs
+++ b/candle-examples/examples/stable-diffusion/clip.rs
@@ -103,7 +103,7 @@ impl ClipTextEmbeddings {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         let token_embedding = self.token_embedding.forward(xs)?;
         let position_embedding = self.position_embedding.forward(&self.position_ids)?;
-        token_embedding + position_embedding
+        token_embedding.broadcast_add(&position_embedding)
     }
 }
 
@@ -161,9 +161,9 @@ impl ClipAttention {
         let attn_weights = query_states.matmul(&key_states.transpose(1, 2)?)?;
 
         let src_len = key_states.dim(1)?;
-        let attn_weights =
-            (attn_weights.reshape((bsz, self.num_attention_heads, seq_len, src_len))?
-                + causal_attention_mask)?;
+        let attn_weights = attn_weights
+            .reshape((bsz, self.num_attention_heads, seq_len, src_len))?
+            .broadcast_add(causal_attention_mask)?;
         let attn_weights =
             attn_weights.reshape((bsz * self.num_attention_heads, seq_len, src_len))?;
         let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
@@ -287,7 +287,7 @@ impl ClipTextTransformer {
     // https://github.com/huggingface/transformers/blob/674f750a57431222fa2832503a108df3badf1564/src/transformers/models/clip/modeling_clip.py#L678
     fn build_causal_attention_mask(bsz: usize, seq_len: usize, device: &Device) -> Result<Tensor> {
         let mask: Vec<_> = (0..seq_len)
-            .flat_map(|i| (0..seq_len).map(move |j| u8::from(j > i)))
+            .flat_map(|i| (0..seq_len).map(move |j| if j > i { f32::MIN } else { 0. }))
            .collect();
         let mask = Tensor::from_slice(&mask, (seq_len, seq_len), device)?;
         mask.broadcast_as((bsz, seq_len, seq_len))
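The clip.rs change above switches the causal mask from a 0/1 `u8` mask to an additive `f32` mask: adding `f32::MIN` to a logit before softmax drives that position's attention weight to effectively zero, which is what `broadcast_add(causal_attention_mask)` relies on (a 0/1 mask would instead require a masked-fill-style op). A toy check of that behavior, using a plain-Rust softmax rather than the `candle_nn::ops::softmax` used in the diff:

```rust
// Numerically stable softmax over a slice of logits.
fn softmax(xs: &[f32]) -> Vec<f32> {
    let max = xs.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = xs.iter().map(|x| (x - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}

fn main() {
    let logits = [1.0f32, 2.0, 3.0];
    // Additively mask the last position (a "future" token for this query).
    let masked: Vec<f32> = logits
        .iter()
        .zip([0.0, 0.0, f32::MIN])
        .map(|(l, m)| l + m)
        .collect();
    let probs = softmax(&masked);
    assert!(probs[2] < 1e-6); // the masked position gets ~zero attention
}
```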
diff --git a/candle-examples/examples/stable-diffusion/main.rs b/candle-examples/examples/stable-diffusion/main.rs
index 2203b03a..d8327c0e 100644
--- a/candle-examples/examples/stable-diffusion/main.rs
+++ b/candle-examples/examples/stable-diffusion/main.rs
@@ -57,13 +57,9 @@ struct Args {
     #[arg(long, value_name = "FILE")]
     vae_weights: Option<String>,
 
-    #[arg(
-        long,
-        value_name = "FILE",
-        default_value = "data/bpe_simple_vocab_16e6.txt"
-    )]
-    /// The file specifying the vocabulary to used for tokenization.
-    vocab_file: String,
+    #[arg(long, value_name = "FILE")]
+    /// The file specifying the tokenizer to use for tokenization.
+    tokenizer: String,
 
     /// The size of the sliced attention or 0 for automatic slicing (disabled by default)
     #[arg(long)]
@@ -165,7 +161,7 @@ fn run(args: Args) -> Result<()> {
         height,
         width,
         n_steps,
-        vocab_file,
+        tokenizer,
         final_image,
         sliced_attention_size,
         num_samples,
@@ -184,7 +180,7 @@ fn run(args: Args) -> Result<()> {
     let scheduler = sd_config.build_scheduler(n_steps)?;
     let device = candle_examples::device(cpu)?;
 
-    let tokenizer = Tokenizer::from_file(vocab_file).map_err(E::msg)?;
+    let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
     println!("Running with prompt \"{prompt}\".");
     let tokens = tokenizer
         .encode(prompt, true)
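With this change the `--tokenizer` flag expects a HuggingFace `tokenizer.json` file loaded via the `tokenizers` crate, rather than the old raw BPE vocab text file. A minimal sketch of that loading path, assuming the `tokenizers` and `anyhow` crates; the file name and prompt string are illustrative:

```rust
use tokenizers::Tokenizer;

fn main() -> anyhow::Result<()> {
    // Load a serialized tokenizer, mirroring Tokenizer::from_file in main.rs.
    let tokenizer = Tokenizer::from_file("tokenizer.json").map_err(anyhow::Error::msg)?;
    // Encode a prompt with special tokens added, as the example does.
    let tokens = tokenizer
        .encode("A very rusty robot", true)
        .map_err(anyhow::Error::msg)?
        .get_ids()
        .to_vec();
    println!("{tokens:?}");
    Ok(())
}
```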