Mirror of https://github.com/huggingface/candle.git
Add a Context trait similar to anyhow::Context. (#2676)
* Add a Context trait similar to anyhow::Context.
* Switch two unwraps to context.
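In short: the diff below adds a `Context` trait modeled on `anyhow::Context` to candle-core's error module, re-exports it from the crate root, and replaces a number of `unwrap()` calls on `Option` values with `.context(...)?`, so failures surface as `Error::UnwrapNone` carrying a message instead of a panic. A minimal usage sketch against a candle version containing this change (function names and messages are illustrative, not part of the commit):

use candle::{Context, Result, Tensor};

// `context` turns an Option into a Result, attaching a message on None.
fn last_state(states: &[Tensor]) -> Result<&Tensor> {
    states.last().context("no hidden states were produced")
}

// `with_context` builds the message lazily, only if the Option is None.
fn state_at(states: &[Tensor], step: usize) -> Result<&Tensor> {
    states
        .last()
        .with_context(|| format!("no hidden state at step {step}"))
}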
@@ -9,8 +9,14 @@ pub struct MatMulUnexpectedStriding {
     pub msg: &'static str,
 }
 
+impl std::fmt::Debug for Error {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{self}")
+    }
+}
+
 /// Main library error type.
-#[derive(thiserror::Error, Debug)]
+#[derive(thiserror::Error)]
 pub enum Error {
     // === DType Errors ===
     #[error("{msg}, expected: {expected:?}, got: {got:?}")]
@@ -199,8 +205,14 @@ pub enum Error {
     UnsupportedSafeTensorDtype(safetensors::Dtype),
 
     /// Arbitrary errors wrapping.
-    #[error(transparent)]
-    Wrapped(Box<dyn std::error::Error + Send + Sync>),
+    #[error("{0}")]
+    Wrapped(Box<dyn std::fmt::Display + Send + Sync>),
 
+    #[error("{context}\n{inner}")]
+    Context {
+        inner: Box<Self>,
+        context: Box<dyn std::fmt::Display + Send + Sync>,
+    },
+
     /// Adding path information to an error.
     #[error("path: {path:?} {inner}")]
@@ -218,16 +230,19 @@ pub enum Error {
     /// User generated error message, typically created via `bail!`.
     #[error("{0}")]
     Msg(String),
+
+    #[error("unwrap none")]
+    UnwrapNone,
 }
 
 pub type Result<T> = std::result::Result<T, Error>;
 
 impl Error {
-    pub fn wrap(err: impl std::error::Error + Send + Sync + 'static) -> Self {
+    pub fn wrap(err: impl std::fmt::Display + Send + Sync + 'static) -> Self {
         Self::Wrapped(Box::new(err)).bt()
     }
 
-    pub fn msg(err: impl std::error::Error) -> Self {
+    pub fn msg(err: impl std::fmt::Display) -> Self {
         Self::Msg(err.to_string()).bt()
     }
 
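Note the relaxed bounds above: `Error::wrap` and `Error::msg` now accept any `std::fmt::Display` value rather than only `std::error::Error` implementors. A sketch of what that permits (assuming this commit; the values are illustrative):

use candle::Error;

fn demo() -> (Error, Error) {
    // Plain strings and other Display values now qualify, not just error types.
    let wrapped = Error::wrap("backend returned an empty buffer");
    let msg = Error::msg(42); // i32 implements Display
    (wrapped, msg)
}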
@@ -253,6 +268,13 @@ impl Error {
             path: p.as_ref().to_path_buf(),
         }
     }
+
+    pub fn context(self, c: impl std::fmt::Display + Send + Sync + 'static) -> Self {
+        Self::Context {
+            inner: Box::new(self),
+            context: Box::new(c),
+        }
+    }
 }
 
 #[macro_export]
@@ -275,3 +297,41 @@ pub fn zip<T, U>(r1: Result<T>, r2: Result<U>) -> Result<(T, U)> {
         (_, Err(e)) => Err(e),
     }
 }
+
+// Taken from anyhow.
+pub trait Context<T> {
+    /// Wrap the error value with additional context.
+    fn context<C>(self, context: C) -> Result<T>
+    where
+        C: std::fmt::Display + Send + Sync + 'static;
+
+    /// Wrap the error value with additional context that is evaluated lazily
+    /// only once an error does occur.
+    fn with_context<C, F>(self, f: F) -> Result<T>
+    where
+        C: std::fmt::Display + Send + Sync + 'static,
+        F: FnOnce() -> C;
+}
+
+impl<T> Context<T> for Option<T> {
+    fn context<C>(self, context: C) -> Result<T>
+    where
+        C: std::fmt::Display + Send + Sync + 'static,
+    {
+        match self {
+            Some(v) => Ok(v),
+            None => Err(Error::UnwrapNone.context(context).bt()),
+        }
+    }
+
+    fn with_context<C, F>(self, f: F) -> Result<T>
+    where
+        C: std::fmt::Display + Send + Sync + 'static,
+        F: FnOnce() -> C,
+    {
+        match self {
+            Some(v) => Ok(v),
+            None => Err(Error::UnwrapNone.context(f()).bt()),
+        }
+    }
+}
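The trait is only implemented for `Option<T>` here, but `Error::context` is an inherent method on `Error`, so an existing `Result` can also be annotated through `map_err`. A sketch (names illustrative, assuming this commit):

use candle::{Result, Tensor};

fn checked_matmul(a: &Tensor, b: &Tensor) -> Result<Tensor> {
    // Wrap whatever error matmul produced in the new Error::Context variant.
    a.matmul(b).map_err(|e| e.context("while multiplying a and b"))
}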
@@ -94,7 +94,7 @@ pub use cpu_backend::{CpuStorage, CpuStorageRef};
 pub use custom_op::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3, UgIOp1};
 pub use device::{Device, DeviceLocation, NdArray};
 pub use dtype::{DType, DTypeParseError, FloatDType, IntDType, WithDType};
-pub use error::{Error, Result};
+pub use error::{Context, Error, Result};
 pub use indexer::{IndexOp, TensorIndexer};
 pub use layout::Layout;
 pub use shape::{Shape, D};
@@ -1,7 +1,7 @@
 //! Just enough pickle support to be able to read PyTorch checkpoints.
 // This hardcodes objects that are required for tensor reading, we may want to make this a bit more
 // composable/tensor agnostic at some point.
-use crate::{DType, Error as E, Layout, Result, Tensor};
+use crate::{Context, DType, Error as E, Layout, Result, Tensor};
 use byteorder::{LittleEndian, ReadBytesExt};
 use std::collections::HashMap;
 use std::io::BufRead;
@@ -537,7 +537,7 @@ impl Stack {
                     crate::bail!("setitems: not an even number of objects")
                 }
                 while let Some(value) = objs.pop() {
-                    let key = objs.pop().unwrap();
+                    let key = objs.pop().context("empty objs")?;
                     d.push((key, value))
                 }
             } else {
@@ -557,7 +557,7 @@ impl Stack {
                     crate::bail!("setitems: not an even number of objects")
                 }
                 while let Some(value) = objs.pop() {
-                    let key = objs.pop().unwrap();
+                    let key = objs.pop().context("empty objs")?;
                     pydict.push((key, value))
                 }
                 self.push(Object::Dict(pydict))
@@ -661,7 +661,7 @@ pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
         if !file_name.ends_with("data.pkl") {
             continue;
         }
-        let dir_name = std::path::PathBuf::from(file_name.strip_suffix(".pkl").unwrap());
+        let dir_name = std::path::PathBuf::from(file_name.strip_suffix(".pkl").context("no .pkl")?);
         let reader = zip.by_name(file_name)?;
         let mut reader = std::io::BufReader::new(reader);
         let mut stack = Stack::empty();
@@ -2,7 +2,7 @@
 //!
 
 use super::{GgmlDType, QTensor};
-use crate::{Device, Result};
+use crate::{Context, Device, Result};
 use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
 use std::collections::HashMap;
 
@@ -338,7 +338,7 @@ impl Value {
             if value_type.len() != 1 {
                 crate::bail!("multiple value-types in the same array {value_type:?}")
             }
-            value_type.into_iter().next().unwrap()
+            value_type.into_iter().next().context("empty value_type")?
         };
         w.write_u32::<LittleEndian>(value_type.to_u32())?;
         w.write_u64::<LittleEndian>(v.len() as u64)?;
@@ -1,5 +1,5 @@
 //! Code for GGML and GGUF files
-use crate::{CpuStorage, DType, Device, Result, Shape, Storage, Tensor};
+use crate::{Context, CpuStorage, DType, Device, Result, Shape, Storage, Tensor};
 use k_quants::*;
 use std::borrow::Cow;
 
@@ -481,7 +481,7 @@ impl crate::CustomOp1 for QTensor {
             crate::bail!("input tensor has only one dimension {layout:?}")
         }
         let mut dst_shape = src_shape.dims().to_vec();
-        let last_k = dst_shape.pop().unwrap();
+        let last_k = dst_shape.pop().context("empty dst_shape")?;
         if last_k != k {
             crate::bail!("input tensor {layout:?} incompatible with {:?}", self.shape)
         }
@@ -1,4 +1,4 @@
-use crate::{shape::Dim, Error, Result, Shape, Tensor};
+use crate::{shape::Dim, Context, Error, Result, Shape, Tensor};
 
 impl Tensor {
     /// Concatenates two or more tensors along a particular dimension.
@@ -134,7 +134,7 @@ impl Tensor {
                 .bt())?
             }
         }
-        let next_offset = offsets.last().unwrap() + arg.elem_count();
+        let next_offset = offsets.last().context("empty offsets")? + arg.elem_count();
         offsets.push(next_offset);
     }
     let shape = Shape::from(cat_dims);
@@ -3,7 +3,7 @@
 //! Functionality for modeling sampling strategies and logits processing in text generation
 //! with support for temperature-based sampling, top-k filtering, nucleus sampling (top-p),
 //! and combinations thereof.
-use candle::{DType, Error, Result, Tensor};
+use candle::{Context, DType, Error, Result, Tensor};
 use rand::{distributions::Distribution, SeedableRng};
 
 #[derive(Clone, PartialEq, Debug)]
@@ -45,7 +45,7 @@ impl LogitsProcessor {
             .enumerate()
             .max_by(|(_, u), (_, v)| u.total_cmp(v))
             .map(|(i, _)| i as u32)
-            .unwrap();
+            .context("empty logits")?;
         Ok(next_token)
     }
 
@@ -6,7 +6,7 @@
 //! - 💻 [Chinese-CLIP](https://github.com/OFA-Sys/Chinese-CLIP)
 //! - 💻 [GH](https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/chinese_clip/modeling_chinese_clip.py_
 
-use candle::{DType, IndexOp, Module, Result, Shape, Tensor, D};
+use candle::{Context, DType, IndexOp, Module, Result, Shape, Tensor, D};
 use candle_nn as nn;
 
 use super::{Activation, EncoderConfig};
@@ -363,7 +363,7 @@ impl ChineseClipVisionTransformer {
             .apply(&self.pre_layer_norm)?;
 
         let mut result = self.encoder.output_hidden_states(&hidden_states, None)?;
-        let encoder_outputs = result.last().unwrap();
+        let encoder_outputs = result.last().context("no last")?;
         let pooled_output = encoder_outputs.i((.., 0, ..))?;
         result.push(self.final_layer_norm.forward(&pooled_output)?.clone());
         Ok(result)
@@ -6,7 +6,7 @@
 //! https://github.com/openai/CLIP
 //! https://github.com/huggingface/transformers/tree/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip
 
-use candle::{IndexOp, Result, Shape, Tensor, D};
+use candle::{Context, IndexOp, Result, Shape, Tensor, D};
 use candle_nn as nn;
 use candle_nn::Module;
 use nn::Conv2dConfig;
@@ -149,7 +149,7 @@ impl ClipVisionTransformer {
             .apply(&self.embeddings)?
             .apply(&self.pre_layer_norm)?;
         let mut result = self.encoder.output_hidden_states(&hidden_states, None)?;
-        let encoder_outputs = result.last().unwrap();
+        let encoder_outputs = result.last().context("no last")?;
         let pooled_output = encoder_outputs.i((.., 0, ..))?;
         result.push(self.final_layer_norm.forward(&pooled_output)?.clone());
         Ok(result)
@@ -3,7 +3,7 @@
 //! See:
 //! - ["EfficientBERT: Progressively Searching Multilayer Perceptron Architectures for BERT"](https://arxiv.org/abs/2201.00462)
 //!
-use candle::{Result, Tensor, D};
+use candle::{Context, Result, Tensor, D};
 use candle_nn as nn;
 use nn::{Module, VarBuilder};
 
@@ -289,7 +289,7 @@ impl EfficientNet {
     pub fn new(p: VarBuilder, configs: Vec<MBConvConfig>, nclasses: usize) -> Result<Self> {
         let f_p = p.pp("features");
         let first_in_c = configs[0].input_channels;
-        let last_out_c = configs.last().unwrap().out_channels;
+        let last_out_c = configs.last().context("no last")?.out_channels;
         let final_out_c = 4 * last_out_c;
         let init_cna = ConvNormActivation::new(f_p.pp(0), 3, first_in_c, 3, 2, 1)?;
         let nconfigs = configs.len();
@@ -5,7 +5,7 @@
 //!
 //! Implementation based on [timm model](https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/fastvit.py)
 
-use candle::{DType, Result, Tensor, D};
+use candle::{Context, DType, Result, Tensor, D};
 use candle_nn::{
     batch_norm, conv2d, conv2d_no_bias, linear, linear_no_bias, ops::sigmoid, ops::softmax,
     BatchNorm, Conv2d, Conv2dConfig, Func, VarBuilder,
@@ -178,7 +178,7 @@ fn squeeze_and_excitation(
 // based on the _fuse_bn_tensor method in timm
 // see https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L602
 fn fuse_conv_bn(weights: &Tensor, bn: BatchNorm) -> Result<(Tensor, Tensor)> {
-    let (gamma, beta) = bn.weight_and_bias().unwrap();
+    let (gamma, beta) = bn.weight_and_bias().context("no weight-bias")?;
     let mu = bn.running_mean();
     let sigma = (bn.running_var() + bn.eps())?.sqrt();
     let gps = (gamma / sigma)?;
@@ -14,7 +14,7 @@ use crate::models::clip::vision_model::{ClipVisionConfig, ClipVisionTransformer}
 use crate::models::llama::{Cache, Llama};
 use crate::models::with_tracing::linear;
 
-use candle::{bail, Device, IndexOp, Result, Tensor};
+use candle::{bail, Context, Device, IndexOp, Result, Tensor};
 use candle_nn::{seq, Activation, Module, Sequential, VarBuilder};
 use fancy_regex::Regex;
 use utils::get_anyres_image_grid_shape;
@@ -145,7 +145,7 @@ impl ClipVisionTower {
         let config = if config.is_none() {
             ClipVisionConfig::clip_vit_large_patch14_336()
         } else {
-            config.clone().unwrap()
+            config.clone().context("no config")?
         };
         let select_layer = match select_layer {
             -1 | -2 => select_layer,
@@ -262,14 +262,14 @@ impl LLaVA {
         let image_features = if mm_patch_merge_type == "flat" {
             image_features
                 .iter()
-                .map(|x| x.flatten(0, 1).unwrap())
-                .collect::<Vec<Tensor>>()
+                .map(|x| x.flatten(0, 1))
+                .collect::<Result<Vec<Tensor>>>()?
         } else if mm_patch_merge_type.starts_with("spatial") {
             let mut new_image_features = Vec::new();
             for (image_idx, image_feature) in image_features.iter().enumerate() {
                 let new_image_feature = if image_feature.dims()[0] > 1 {
-                    let base_image_feature = image_feature.get(0).unwrap();
-                    let patch_image_feature = image_feature.i(1..).unwrap();
+                    let base_image_feature = image_feature.get(0)?;
+                    let patch_image_feature = image_feature.i(1..)?;
                     let height = self.clip_vision_tower.num_patches_per_side();
                     let width = height;
                     assert_eq!(height * width, base_image_feature.dims()[0]);
@@ -313,16 +313,12 @@ impl LLaVA {
                 };
                 Tensor::cat(&[base_image_feature, new_image_feature], 0)?
             } else {
-                let new_image_feature = image_feature.get(0).unwrap();
+                let new_image_feature = image_feature.get(0)?;
                 if mm_patch_merge_type.contains("unpad") {
                     Tensor::cat(
-                        &[
-                            new_image_feature,
-                            self.image_newline.clone().unsqueeze(0).unwrap(),
-                        ],
+                        &[new_image_feature, self.image_newline.clone().unsqueeze(0)?],
                         0,
-                    )
-                    .unwrap()
+                    )?
                 } else {
                     new_image_feature
                 }
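The LLaVA hunks above also use a related cleanup: rather than unwrapping inside `map`, the iterator of `Result`s is collected into `Result<Vec<Tensor>>`, which short-circuits on the first error. The same idiom in isolation (a sketch, not code from the commit):

use candle::{Result, Tensor};

fn flatten_all(xs: &[Tensor]) -> Result<Vec<Tensor>> {
    // Collecting Result items into Result<Vec<_>> stops at the first failure
    // instead of panicking part-way through.
    xs.iter().map(|x| x.flatten(0, 1)).collect()
}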
@@ -15,7 +15,7 @@
 //!
 
 use crate::models::with_tracing::{conv2d, linear, Conv2d, Linear};
-use candle::{Module, ModuleT, Result, Tensor, D};
+use candle::{Context, Module, ModuleT, Result, Tensor, D};
 use candle_nn::{conv2d_no_bias, layer_norm, Activation, Conv2dConfig, VarBuilder};
 use serde::Deserialize;
 use std::collections::HashMap;
@@ -633,7 +633,7 @@ impl ImageClassificationModel {
 impl Module for ImageClassificationModel {
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
         let all_hidden_states = self.segformer.forward(x)?;
-        let hidden_states = all_hidden_states.last().unwrap();
+        let hidden_states = all_hidden_states.last().context("no last")?;
         let hidden_states = hidden_states.flatten_from(2)?.permute((0, 2, 1))?;
         let mean = hidden_states.mean(1)?;
         self.classifier.forward(&mean)