From 67589791d20272ac6df0685872df7d5b1c6c3e04 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Sat, 10 Feb 2024 11:08:50 +0100
Subject: [PATCH] Remove the unused pragma in vit + handle the final layernorm. (#1688)

---
 candle-transformers/src/models/vit.rs | 16 +++++++++-------
 1 file changed, 9 insertions(+), 7 deletions(-)

diff --git a/candle-transformers/src/models/vit.rs b/candle-transformers/src/models/vit.rs
index 962528c1..7a028a55 100644
--- a/candle-transformers/src/models/vit.rs
+++ b/candle-transformers/src/models/vit.rs
@@ -1,4 +1,3 @@
-#![allow(unused)]
 use crate::models::with_tracing::{conv2d, linear, linear_no_bias, Conv2d, Linear};
 use candle::{IndexOp, Module, Result, Tensor, D};
 use candle_nn::{layer_norm, LayerNorm, VarBuilder};
@@ -82,7 +81,7 @@ impl PatchEmbeddings {
 
 impl Module for PatchEmbeddings {
     fn forward(&self, pixel_values: &Tensor) -> Result<Tensor> {
-        let (b_size, num_channels, height, width) = pixel_values.dims4()?;
+        let (_b_size, _num_channels, _height, _width) = pixel_values.dims4()?;
         self.projection
             .forward(pixel_values)?
             .flatten_from(2)?
@@ -123,9 +122,9 @@ impl Embeddings {
 
     fn interpolate_pos_encoding(
         &self,
-        embeddings: &Tensor,
-        height: usize,
-        width: usize,
+        _embeddings: &Tensor,
+        _height: usize,
+        _width: usize,
     ) -> Result<Tensor> {
         todo!()
     }
@@ -136,7 +135,7 @@ impl Embeddings {
         bool_masked_pos: Option<&Tensor>,
         interpolate_pos_encoding: bool,
     ) -> Result<Tensor> {
-        let (b_size, num_channels, height, width) = pixel_values.dims4()?;
+        let (b_size, _num_channels, height, width) = pixel_values.dims4()?;
         let embeddings = self.patch_embeddings.forward(pixel_values)?;
         let embeddings = match (bool_masked_pos, &self.mask_token) {
             (None, _) => embeddings,
@@ -392,6 +391,9 @@ impl Model {
     pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         let embedding_output = self.embeddings.forward(xs, None, false)?;
        let encoder_outputs = self.encoder.forward(&embedding_output)?;
-        encoder_outputs.i((.., 0, ..))?.apply(&self.classifier)
+        encoder_outputs
+            .i((.., 0, ..))?
+            .apply(&self.layernorm)?
+            .apply(&self.classifier)
     }
 }
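
Note (not part of the patch, added for review context): after this change the classification head keeps only the CLS token from the encoder output, applies the model's final layer norm (which was previously skipped), and then the linear classifier. Below is a minimal sketch of that head in isolation; the names (`classify`, `encoder_outputs`, `layernorm`, `classifier`) are illustrative, and it uses plain candle_nn layers rather than the with_tracing wrappers that vit.rs itself uses.

use candle::{IndexOp, Result, Tensor};
use candle_nn::{LayerNorm, Linear};

// Sketch only: `encoder_outputs` is assumed to have shape (batch, num_tokens, hidden_size).
fn classify(encoder_outputs: &Tensor, layernorm: &LayerNorm, classifier: &Linear) -> Result<Tensor> {
    encoder_outputs
        .i((.., 0, ..))?    // keep the CLS token for every batch element
        .apply(layernorm)?  // final layer norm that this patch adds to the head
        .apply(classifier)  // linear head producing (batch, num_classes) logits
}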