Mirror of https://github.com/huggingface/candle.git (synced 2025-06-17 11:08:52 +00:00)

Remove the unused pragma in vit + handle the final layernorm. (#1688)

@@ -1,4 +1,3 @@
-#![allow(unused)]
 use crate::models::with_tracing::{conv2d, linear, linear_no_bias, Conv2d, Linear};
 use candle::{IndexOp, Module, Result, Tensor, D};
 use candle_nn::{layer_norm, LayerNorm, VarBuilder};
@@ -82,7 +81,7 @@ impl PatchEmbeddings
 
 impl Module for PatchEmbeddings {
     fn forward(&self, pixel_values: &Tensor) -> Result<Tensor> {
-        let (b_size, num_channels, height, width) = pixel_values.dims4()?;
+        let (_b_size, _num_channels, _height, _width) = pixel_values.dims4()?;
         self.projection
             .forward(pixel_values)?
             .flatten_from(2)?
@@ -123,9 +122,9 @@ impl Embeddings
 
     fn interpolate_pos_encoding(
         &self,
-        embeddings: &Tensor,
-        height: usize,
-        width: usize,
+        _embeddings: &Tensor,
+        _height: usize,
+        _width: usize,
     ) -> Result<Tensor> {
         todo!()
     }
@@ -136,7 +135,7 @@ impl Embeddings
         bool_masked_pos: Option<&Tensor>,
         interpolate_pos_encoding: bool,
     ) -> Result<Tensor> {
-        let (b_size, num_channels, height, width) = pixel_values.dims4()?;
+        let (b_size, _num_channels, height, width) = pixel_values.dims4()?;
         let embeddings = self.patch_embeddings.forward(pixel_values)?;
         let embeddings = match (bool_masked_pos, &self.mask_token) {
            (None, _) => embeddings,
@@ -392,6 +391,9 @@ impl Model
     pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         let embedding_output = self.embeddings.forward(xs, None, false)?;
         let encoder_outputs = self.encoder.forward(&embedding_output)?;
-        encoder_outputs.i((.., 0, ..))?.apply(&self.classifier)
+        encoder_outputs
+            .i((.., 0, ..))?
+            .apply(&self.layernorm)?
+            .apply(&self.classifier)
     }
 }
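
For context, the substantive change is in the last hunk: Model::forward now routes the CLS-token output of the encoder through the final layernorm before the classifier head, which is how the reference ViT applies its final layer norm; before this commit the classifier was applied directly to the un-normalized CLS token, as the removed line shows. Below is a minimal, standalone sketch of that pattern built from candle_nn primitives. It is an illustration, not the vit.rs implementation itself; the shapes (batch 2, 197 tokens, hidden size 768, 1000 classes), the layernorm epsilon, and the variable names are assumptions made for the example.

// Illustrative sketch only: CLS token -> layernorm -> classifier with candle_nn
// primitives. Shapes, names, and epsilon are assumptions for the example.
use candle::{DType, Device, IndexOp, Result, Tensor};
use candle_nn::{layer_norm, linear, VarBuilder, VarMap};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);

    // Stand-in for the encoder output: (batch, tokens, hidden) = (2, 197, 768).
    let encoder_outputs = Tensor::randn(0f32, 1f32, (2, 197, 768), &dev)?;

    // Final layer norm and classification head, analogous to the model's
    // `layernorm` and `classifier` fields in the diff above.
    let layernorm = layer_norm(768, 1e-12, vb.pp("layernorm"))?;
    let classifier = linear(768, 1000, vb.pp("classifier"))?;

    // CLS token -> layernorm -> classifier, mirroring the updated Model::forward.
    let logits = encoder_outputs
        .i((.., 0, ..))?
        .apply(&layernorm)?
        .apply(&classifier)?;
    assert_eq!(logits.dims(), &[2, 1000]);
    Ok(())
}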