Mirror of https://github.com/huggingface/candle.git, synced 2025-06-17 02:58:50 +00:00
feat: implement VGG13, VGG16 and VGG19 (#1211)
* feat: implement VGG13, VGG16 and VGG19
* Cosmetic fixes.
* More cosmetic tweaks + avoid re-loading the weights on each final layer.

Co-authored-by: Laurent <laurent.mazare@gmail.com>
candle-examples/examples/vgg/README.md (new file, 13 lines)
@@ -0,0 +1,13 @@
## VGG Model Implementation

This example demonstrates the implementation of VGG models (VGG13, VGG16, VGG19) using the Candle library.

The VGG models are defined in `candle-transformers/src/models/vgg.rs`. The main function in `candle-examples/examples/vgg/main.rs` loads an image, selects the VGG model based on the provided argument, and applies the model to the loaded image.

You can run the example with the following command:

```bash
cargo run --example vgg --release -- --image ../yolo-v8/assets/bike.jpg --which vgg13
```

In the command above, `--image` specifies the path to the image file and `--which` specifies the VGG model to use (vgg13, vgg16, or vgg19).
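For readers who want to call the model from their own code rather than through the example binary, the flow the README describes reduces to a few calls. The sketch below is distilled from `main.rs` (shown next); the `classify` helper name, the hard-coded VGG16 variant, and a locally available `model.safetensors` path are illustrative assumptions, not part of the commit:

```rust
use candle::{DType, Module, D};
use candle_nn::VarBuilder;
use candle_transformers::models::vgg::{Models, Vgg};

// Hypothetical helper: classify a 224x224 image with locally stored VGG16 weights.
fn classify(image_path: &str, weights_path: &str) -> anyhow::Result<()> {
    let device = candle::Device::Cpu;
    // Load and ImageNet-normalize the image, as the example does.
    let image = candle_examples::imagenet::load_image224(image_path)?;
    // Memory-map the safetensors checkpoint and build the chosen VGG variant.
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[weights_path], DType::F32, &device)? };
    let model = Vgg::new(vb, Models::Vgg16)?;
    // Forward pass, then softmax over the 1000 ImageNet classes.
    let logits = model.forward(&image)?;
    let probs = candle_nn::ops::softmax(&logits, D::Minus1)?;
    println!("{probs}");
    Ok(())
}
```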
candle-examples/examples/vgg/main.rs (new file, 77 lines)
@@ -0,0 +1,77 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use candle::{DType, IndexOp, D};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::vgg::{Models, Vgg};
use clap::{Parser, ValueEnum};

#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
    Vgg13,
    Vgg16,
    Vgg19,
}

#[derive(Parser)]
struct Args {
    #[arg(long)]
    image: String,

    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Variant of the model to use.
    #[arg(value_enum, long, default_value_t = Which::Vgg13)]
    which: Which,
}

pub fn main() -> anyhow::Result<()> {
    let args = Args::parse();
    let device = candle_examples::device(args.cpu)?;
    let image = candle_examples::imagenet::load_image224(args.image)?;

    println!("loaded image {image:?}");

    let api = hf_hub::api::sync::Api::new()?;
    let repo = match args.which {
        Which::Vgg13 => "timm/vgg13.tv_in1k",
        Which::Vgg16 => "timm/vgg16.tv_in1k",
        Which::Vgg19 => "timm/vgg19.tv_in1k",
    };
    let api = api.model(repo.into());
    let filename = "model.safetensors";
    let model_file = api.get(filename)?;

    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
    let model = match args.which {
        Which::Vgg13 => Vgg::new(vb, Models::Vgg13)?,
        Which::Vgg16 => Vgg::new(vb, Models::Vgg16)?,
        Which::Vgg19 => Vgg::new(vb, Models::Vgg19)?,
    };
    let logits = model.forward(&image)?;

    let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
        .i(0)?
        .to_vec1::<f32>()?;

    // Sort the predictions and take the top 5
    let mut top: Vec<_> = prs.iter().enumerate().collect();
    top.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap());
    let top = top.into_iter().take(5).collect::<Vec<_>>();

    // Print the top predictions
    for &(i, p) in &top {
        println!(
            "{:50}: {:.2}%",
            candle_examples::imagenet::CLASSES[i],
            p * 100.0
        );
    }

    Ok(())
}
candle-transformers/src/models/mod.rs
@@ -28,6 +28,7 @@ pub mod segment_anything;
 pub mod stable_diffusion;
 pub mod stable_lm;
 pub mod t5;
+pub mod vgg;
 pub mod vit;
 pub mod whisper;
 pub mod with_tracing;
candle-transformers/src/models/vgg.rs (new file, 254 lines)
@@ -0,0 +1,254 @@
//! VGG model implementation (VGG-13, VGG-16, VGG-19).
//!
//! See Very Deep Convolutional Networks for Large-Scale Image Recognition
//! <https://arxiv.org/abs/1409.1556>
use candle::{Module, Result, Tensor};
use candle_nn::{Func, VarBuilder};

// Enum representing the different VGG models
pub enum Models {
    Vgg13,
    Vgg16,
    Vgg19,
}

// Struct representing a VGG model
#[derive(Debug)]
pub struct Vgg<'a> {
    blocks: Vec<Func<'a>>,
}

// Struct representing the configuration for the pre-logit layer
struct PreLogitConfig {
    in_dim: (usize, usize, usize, usize),
    target_in: usize,
    target_out: usize,
}

// Implementation of the VGG model
impl<'a> Vgg<'a> {
    // Function to create a new VGG model
    pub fn new(vb: VarBuilder<'a>, model: Models) -> Result<Self> {
        let blocks = match model {
            Models::Vgg13 => vgg13_blocks(vb)?,
            Models::Vgg16 => vgg16_blocks(vb)?,
            Models::Vgg19 => vgg19_blocks(vb)?,
        };
        Ok(Self { blocks })
    }
}

// Implementation of the forward pass for the VGG model
impl Module for Vgg<'_> {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let mut xs = xs.unsqueeze(0)?;
        for block in self.blocks.iter() {
            xs = xs.apply(block)?;
        }
        Ok(xs)
    }
}

// Function to create a conv2d block
// The block is composed of two to four conv2d layers (each followed by ReLU) and a final max pool layer
fn conv2d_block(convs: &[(usize, usize, &str)], vb: &VarBuilder) -> Result<Func<'static>> {
    let layers = convs
        .iter()
        .enumerate()
        .map(|(_, &(in_c, out_c, name))| {
            candle_nn::conv2d(
                in_c,
                out_c,
                3,
                candle_nn::Conv2dConfig {
                    stride: 1,
                    padding: 1,
                    ..Default::default()
                },
                vb.pp(name),
            )
        })
        .collect::<Result<Vec<_>>>()?;

    Ok(Func::new(move |xs| {
        let mut xs = xs.clone();
        for layer in layers.iter() {
            xs = xs.apply(layer)?.relu()?
        }
        xs = xs.max_pool2d_with_stride(2, 2)?;
        Ok(xs)
    }))
}

// Function to create the fully connected (classifier) layers
// They consist of two pre-logit linear layers and the final classification layer, each preceded by dropout
fn fully_connected(
    num_classes: usize,
    pre_logit_1: PreLogitConfig,
    pre_logit_2: PreLogitConfig,
    vb: VarBuilder,
) -> Result<Func> {
    let lin = get_weights_and_biases(
        &vb.pp("pre_logits.fc1"),
        pre_logit_1.in_dim,
        pre_logit_1.target_in,
        pre_logit_1.target_out,
    )?;
    let lin2 = get_weights_and_biases(
        &vb.pp("pre_logits.fc2"),
        pre_logit_2.in_dim,
        pre_logit_2.target_in,
        pre_logit_2.target_out,
    )?;
    Ok(Func::new(move |xs| {
        let xs = xs.reshape((1, pre_logit_1.target_out))?;
        let xs = candle_nn::ops::dropout(&xs, 0.5)?.apply(&lin)?.relu()?;
        let xs = candle_nn::ops::dropout(&xs, 0.5)?.apply(&lin2)?.relu()?;
        let lin3 = candle_nn::linear(4096, num_classes, vb.pp("head.fc"))?;
        let xs = candle_nn::ops::dropout(&xs, 0.5)?.apply(&lin3)?.relu()?;
        Ok(xs)
    }))
}

// Function to get the weights and biases for a layer
// This is required because the weights and biases are stored in a different format than our linear layer expects
fn get_weights_and_biases(
    vs: &VarBuilder,
    in_dim: (usize, usize, usize, usize),
    target_in: usize,
    target_out: usize,
) -> Result<candle_nn::Linear> {
    let init_ws = candle_nn::init::DEFAULT_KAIMING_NORMAL;
    let ws = vs.get_with_hints(in_dim, "weight", init_ws)?;
    let ws = ws.reshape((target_in, target_out))?;
    let bound = 1. / (target_out as f64).sqrt();
    let init_bs = candle_nn::Init::Uniform {
        lo: -bound,
        up: bound,
    };
    let bs = vs.get_with_hints(target_in, "bias", init_bs)?;
    Ok(candle_nn::Linear::new(ws, Some(bs)))
}

fn vgg13_blocks(vb: VarBuilder) -> Result<Vec<Func>> {
    let num_classes = 1000;
    let blocks = vec![
        conv2d_block(&[(3, 64, "features.0"), (64, 64, "features.2")], &vb)?,
        conv2d_block(&[(64, 128, "features.5"), (128, 128, "features.7")], &vb)?,
        conv2d_block(&[(128, 256, "features.10"), (256, 256, "features.12")], &vb)?,
        conv2d_block(&[(256, 512, "features.15"), (512, 512, "features.17")], &vb)?,
        conv2d_block(&[(512, 512, "features.20"), (512, 512, "features.22")], &vb)?,
        fully_connected(
            num_classes,
            PreLogitConfig {
                in_dim: (4096, 512, 7, 7),
                target_in: 4096,
                target_out: 512 * 7 * 7,
            },
            PreLogitConfig {
                in_dim: (4096, 4096, 1, 1),
                target_in: 4096,
                target_out: 4096,
            },
            vb.clone(),
        )?,
    ];
    Ok(blocks)
}

fn vgg16_blocks(vb: VarBuilder) -> Result<Vec<Func>> {
    let num_classes = 1000;
    let blocks = vec![
        conv2d_block(&[(3, 64, "features.0"), (64, 64, "features.2")], &vb)?,
        conv2d_block(&[(64, 128, "features.5"), (128, 128, "features.7")], &vb)?,
        conv2d_block(
            &[
                (128, 256, "features.10"),
                (256, 256, "features.12"),
                (256, 256, "features.14"),
            ],
            &vb,
        )?,
        conv2d_block(
            &[
                (256, 512, "features.17"),
                (512, 512, "features.19"),
                (512, 512, "features.21"),
            ],
            &vb,
        )?,
        conv2d_block(
            &[
                (512, 512, "features.24"),
                (512, 512, "features.26"),
                (512, 512, "features.28"),
            ],
            &vb,
        )?,
        fully_connected(
            num_classes,
            PreLogitConfig {
                in_dim: (4096, 512, 7, 7),
                target_in: 4096,
                target_out: 512 * 7 * 7,
            },
            PreLogitConfig {
                in_dim: (4096, 4096, 1, 1),
                target_in: 4096,
                target_out: 4096,
            },
            vb.clone(),
        )?,
    ];
    Ok(blocks)
}

fn vgg19_blocks(vb: VarBuilder) -> Result<Vec<Func>> {
    let num_classes = 1000;
    let blocks = vec![
        conv2d_block(&[(3, 64, "features.0"), (64, 64, "features.2")], &vb)?,
        conv2d_block(&[(64, 128, "features.5"), (128, 128, "features.7")], &vb)?,
        conv2d_block(
            &[
                (128, 256, "features.10"),
                (256, 256, "features.12"),
                (256, 256, "features.14"),
                (256, 256, "features.16"),
            ],
            &vb,
        )?,
        conv2d_block(
            &[
                (256, 512, "features.19"),
                (512, 512, "features.21"),
                (512, 512, "features.23"),
                (512, 512, "features.25"),
            ],
            &vb,
        )?,
        conv2d_block(
            &[
                (512, 512, "features.28"),
                (512, 512, "features.30"),
                (512, 512, "features.32"),
                (512, 512, "features.34"),
            ],
            &vb,
        )?,
        fully_connected(
            num_classes,
            PreLogitConfig {
                in_dim: (4096, 512, 7, 7),
                target_in: 4096,
                target_out: 512 * 7 * 7,
            },
            PreLogitConfig {
                in_dim: (4096, 4096, 1, 1),
                target_in: 4096,
                target_out: 4096,
            },
            vb.clone(),
        )?,
    ];
    Ok(blocks)
}
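One detail worth spelling out from `get_weights_and_biases` above: the timm checkpoints store the two pre-logit layers as 4-D convolution-style weights, `pre_logits.fc1` with shape (4096, 512, 7, 7) and `pre_logits.fc2` with shape (4096, 4096, 1, 1). Flattening the trailing dimensions turns them into the (out_features, in_features) matrices that `candle_nn::Linear` expects, with 512 * 7 * 7 = 25088 input features for fc1; that is also why `fully_connected` first reshapes the last conv block's 7x7x512 activation into a (1, 25088) row. A minimal sketch of that reshape, with the shapes hard-coded purely for illustration:

```rust
use candle::{DType, Device, Result, Tensor};

fn flatten_fc1_weight_demo() -> Result<()> {
    let dev = Device::Cpu;
    // Stand-in for the stored `pre_logits.fc1` conv weight: (out=4096, in=512, 7, 7).
    let conv_w = Tensor::zeros((4096, 512, 7, 7), DType::F32, &dev)?;
    // Flatten to the (out_features, in_features) layout used by candle_nn::Linear.
    let lin_w = conv_w.reshape((4096, 512 * 7 * 7))?;
    assert_eq!(lin_w.dims2()?, (4096, 25088));
    Ok(())
}
```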