Mirror of https://github.com/huggingface/candle.git (synced 2025-06-19 11:56:45 +00:00)
Add a Context trait similar to anyhow::Context. (#2676)
* Add a Context trait similar to anyhow::Context.
* Switch two unwraps to context.
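For orientation, here is a minimal sketch of what such a trait could look like. The real definition lives in candle-core and is not part of this diff, so the signature below (a plain `&'static str` message, `Option<T>` only) is an illustrative assumption rather than the actual API; it only covers what the call sites in this commit need (`.last()`, `weight_and_bias()`, `config.clone()` all yield `Option`s).

// Illustrative sketch, not candle's real definition: an anyhow-style `context`
// extension that turns a missing Option value into a regular candle error.
use candle::{Error, Result};

pub trait Context<T> {
    /// Convert a missing value into a candle error carrying `msg`.
    fn context(self, msg: &'static str) -> Result<T>;
}

impl<T> Context<T> for Option<T> {
    fn context(self, msg: &'static str) -> Result<T> {
        // Error::Msg is candle's generic string-message error variant.
        self.ok_or_else(|| Error::Msg(msg.to_string()))
    }
}

With something like this in scope, `result.last().unwrap()` becomes `result.last().context("no last")?`, so an empty result turns into an error that propagates through `?` instead of a panic.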
@@ -6,7 +6,7 @@
 //! - 💻 [Chinese-CLIP](https://github.com/OFA-Sys/Chinese-CLIP)
 //! - 💻 [GH](https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/chinese_clip/modeling_chinese_clip.py)

-use candle::{DType, IndexOp, Module, Result, Shape, Tensor, D};
+use candle::{Context, DType, IndexOp, Module, Result, Shape, Tensor, D};
 use candle_nn as nn;

 use super::{Activation, EncoderConfig};
@@ -363,7 +363,7 @@ impl ChineseClipVisionTransformer {
             .apply(&self.pre_layer_norm)?;

         let mut result = self.encoder.output_hidden_states(&hidden_states, None)?;
-        let encoder_outputs = result.last().unwrap();
+        let encoder_outputs = result.last().context("no last")?;
         let pooled_output = encoder_outputs.i((.., 0, ..))?;
         result.push(self.final_layer_norm.forward(&pooled_output)?.clone());
         Ok(result)
@@ -6,7 +6,7 @@
 //! https://github.com/openai/CLIP
 //! https://github.com/huggingface/transformers/tree/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip

-use candle::{IndexOp, Result, Shape, Tensor, D};
+use candle::{Context, IndexOp, Result, Shape, Tensor, D};
 use candle_nn as nn;
 use candle_nn::Module;
 use nn::Conv2dConfig;
@@ -149,7 +149,7 @@ impl ClipVisionTransformer {
             .apply(&self.embeddings)?
             .apply(&self.pre_layer_norm)?;
         let mut result = self.encoder.output_hidden_states(&hidden_states, None)?;
-        let encoder_outputs = result.last().unwrap();
+        let encoder_outputs = result.last().context("no last")?;
         let pooled_output = encoder_outputs.i((.., 0, ..))?;
         result.push(self.final_layer_norm.forward(&pooled_output)?.clone());
         Ok(result)
@@ -3,7 +3,7 @@
 //! See:
 //! - ["EfficientBERT: Progressively Searching Multilayer Perceptron Architectures for BERT"](https://arxiv.org/abs/2201.00462)
 //!
-use candle::{Result, Tensor, D};
+use candle::{Context, Result, Tensor, D};
 use candle_nn as nn;
 use nn::{Module, VarBuilder};

@@ -289,7 +289,7 @@ impl EfficientNet {
     pub fn new(p: VarBuilder, configs: Vec<MBConvConfig>, nclasses: usize) -> Result<Self> {
         let f_p = p.pp("features");
         let first_in_c = configs[0].input_channels;
-        let last_out_c = configs.last().unwrap().out_channels;
+        let last_out_c = configs.last().context("no last")?.out_channels;
         let final_out_c = 4 * last_out_c;
         let init_cna = ConvNormActivation::new(f_p.pp(0), 3, first_in_c, 3, 2, 1)?;
         let nconfigs = configs.len();
@@ -5,7 +5,7 @@
 //!
 //! Implementation based on [timm model](https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/fastvit.py)

-use candle::{DType, Result, Tensor, D};
+use candle::{Context, DType, Result, Tensor, D};
 use candle_nn::{
     batch_norm, conv2d, conv2d_no_bias, linear, linear_no_bias, ops::sigmoid, ops::softmax,
     BatchNorm, Conv2d, Conv2dConfig, Func, VarBuilder,
@@ -178,7 +178,7 @@ fn squeeze_and_excitation(
 // based on the _fuse_bn_tensor method in timm
 // see https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L602
 fn fuse_conv_bn(weights: &Tensor, bn: BatchNorm) -> Result<(Tensor, Tensor)> {
-    let (gamma, beta) = bn.weight_and_bias().unwrap();
+    let (gamma, beta) = bn.weight_and_bias().context("no weight-bias")?;
     let mu = bn.running_mean();
     let sigma = (bn.running_var() + bn.eps())?.sqrt();
     let gps = (gamma / sigma)?;
@@ -14,7 +14,7 @@ use crate::models::clip::vision_model::{ClipVisionConfig, ClipVisionTransformer}
 use crate::models::llama::{Cache, Llama};
 use crate::models::with_tracing::linear;

-use candle::{bail, Device, IndexOp, Result, Tensor};
+use candle::{bail, Context, Device, IndexOp, Result, Tensor};
 use candle_nn::{seq, Activation, Module, Sequential, VarBuilder};
 use fancy_regex::Regex;
 use utils::get_anyres_image_grid_shape;
@@ -145,7 +145,7 @@ impl ClipVisionTower {
         let config = if config.is_none() {
             ClipVisionConfig::clip_vit_large_patch14_336()
         } else {
-            config.clone().unwrap()
+            config.clone().context("no config")?
         };
         let select_layer = match select_layer {
             -1 | -2 => select_layer,
@@ -262,14 +262,14 @@ impl LLaVA {
         let image_features = if mm_patch_merge_type == "flat" {
             image_features
                 .iter()
-                .map(|x| x.flatten(0, 1).unwrap())
-                .collect::<Vec<Tensor>>()
+                .map(|x| x.flatten(0, 1))
+                .collect::<Result<Vec<Tensor>>>()?
         } else if mm_patch_merge_type.starts_with("spatial") {
             let mut new_image_features = Vec::new();
             for (image_idx, image_feature) in image_features.iter().enumerate() {
                 let new_image_feature = if image_feature.dims()[0] > 1 {
-                    let base_image_feature = image_feature.get(0).unwrap();
-                    let patch_image_feature = image_feature.i(1..).unwrap();
+                    let base_image_feature = image_feature.get(0)?;
+                    let patch_image_feature = image_feature.i(1..)?;
                     let height = self.clip_vision_tower.num_patches_per_side();
                     let width = height;
                     assert_eq!(height * width, base_image_feature.dims()[0]);
@@ -313,16 +313,12 @@ impl LLaVA {
                     };
                     Tensor::cat(&[base_image_feature, new_image_feature], 0)?
                 } else {
-                    let new_image_feature = image_feature.get(0).unwrap();
+                    let new_image_feature = image_feature.get(0)?;
                     if mm_patch_merge_type.contains("unpad") {
                         Tensor::cat(
-                            &[
-                                new_image_feature,
-                                self.image_newline.clone().unsqueeze(0).unwrap(),
-                            ],
+                            &[new_image_feature, self.image_newline.clone().unsqueeze(0)?],
                             0,
-                        )
-                        .unwrap()
+                        )?
                     } else {
                         new_image_feature
                     }
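Aside from the `context` calls, the LLaVA hunk above also removes an `unwrap` inside a `map` by collecting into `Result<Vec<Tensor>>` instead of `Vec<Tensor>`. That relies on plain standard-library behaviour, no candle types needed; a small self-contained sketch with made-up data and a hypothetical helper name:

// Collecting an iterator of Results into Result<Vec<_>> short-circuits on the
// first Err and returns it, instead of panicking inside the closure.
fn double_evens(xs: &[i32]) -> Result<Vec<i32>, String> {
    xs.iter()
        .map(|&x| {
            if x % 2 == 0 {
                Ok(x * 2)
            } else {
                Err(format!("{x} is odd"))
            }
        })
        // FromIterator<Result<T, E>> for Result<Vec<T>, E> does the short-circuiting.
        .collect::<Result<Vec<i32>, String>>()
}

fn main() {
    assert_eq!(double_evens(&[2, 4]), Ok(vec![4, 8]));
    assert_eq!(double_evens(&[2, 3]), Err("3 is odd".to_string()));
}

In the diff, `candle::Result` plays the role of the error-carrying result type, and the trailing `?` propagates the first failed `flatten` to the caller.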
@@ -15,7 +15,7 @@
 //!

 use crate::models::with_tracing::{conv2d, linear, Conv2d, Linear};
-use candle::{Module, ModuleT, Result, Tensor, D};
+use candle::{Context, Module, ModuleT, Result, Tensor, D};
 use candle_nn::{conv2d_no_bias, layer_norm, Activation, Conv2dConfig, VarBuilder};
 use serde::Deserialize;
 use std::collections::HashMap;
@@ -633,7 +633,7 @@ impl ImageClassificationModel {
 impl Module for ImageClassificationModel {
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
         let all_hidden_states = self.segformer.forward(x)?;
-        let hidden_states = all_hidden_states.last().unwrap();
+        let hidden_states = all_hidden_states.last().context("no last")?;
         let hidden_states = hidden_states.flatten_from(2)?.permute((0, 2, 1))?;
         let mean = hidden_states.mean(1)?;
         self.classifier.forward(&mean)