diff --git a/candle-examples/examples/vgg/README.md b/candle-examples/examples/vgg/README.md
new file mode 100644
index 00000000..473038e8
--- /dev/null
+++ b/candle-examples/examples/vgg/README.md
@@ -0,0 +1,15 @@
+## VGG Model Implementation
+
+This example demonstrates running the VGG image-classification models (VGG13, VGG16, VGG19) with the Candle library.
+
+The VGG models are defined in `candle-transformers/src/models/vgg.rs`. The main function in `candle-examples/examples/vgg/main.rs` loads an image, selects the VGG model based on the provided argument, downloads its weights from the Hugging Face Hub, applies the model to the loaded image, and prints the five most likely classes.
+
+You can run the example with the following command:
+
+```bash
+cargo run --example vgg --release -- --image ../yolo-v8/assets/bike.jpg --which vgg13
+```
+
+In the command above, `--image` specifies the path to the image file and `--which` specifies the VGG model to use (vgg13, vgg16, or vgg19; vgg13 is the default). Pass `--cpu` to run on the CPU rather than on the GPU.
+
+The image path is resolved relative to the current working directory; the sample path above assumes the command is run from `candle-examples/examples/vgg`.
diff --git a/candle-examples/examples/vgg/main.rs b/candle-examples/examples/vgg/main.rs
new file mode 100644
index 00000000..e01fa8e8
--- /dev/null
+++ b/candle-examples/examples/vgg/main.rs
@@ -0,0 +1,78 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
+use candle::{DType, IndexOp, D};
+use candle_nn::{Module, VarBuilder};
+use candle_transformers::models::vgg::{Models, Vgg};
+use clap::{Parser, ValueEnum};
+
+/// VGG variant selectable from the command line.
+#[derive(Clone, Copy, Debug, ValueEnum)]
+enum Which {
+    Vgg13,
+    Vgg16,
+    Vgg19,
+}
+
+#[derive(Parser)]
+struct Args {
+    /// Path of the image to classify.
+    #[arg(long)]
+    image: String,
+
+    /// Run on CPU rather than on GPU.
+    #[arg(long)]
+    cpu: bool,
+
+    /// Variant of the model to use.
+    #[arg(value_enum, long, default_value_t = Which::Vgg13)]
+    which: Which,
+}
+
+pub fn main() -> anyhow::Result<()> {
+    let args = Args::parse();
+    let device = candle_examples::device(args.cpu)?;
+    // `load_image224` is not told which device to use, so move its result onto
+    // the device that will hold the weights before running the model.
+    let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?;
+
+    println!("loaded image {image:?}");
+
+    // Each variant has its own Hub repository and its own `Models` tag.
+    let (repo, variant) = match args.which {
+        Which::Vgg13 => ("timm/vgg13.tv_in1k", Models::Vgg13),
+        Which::Vgg16 => ("timm/vgg16.tv_in1k", Models::Vgg16),
+        Which::Vgg19 => ("timm/vgg19.tv_in1k", Models::Vgg19),
+    };
+    let api = hf_hub::api::sync::Api::new()?;
+    let model_file = api.model(repo.into()).get("model.safetensors")?;
+
+    // SAFETY: the weights file is memory-mapped, so it must not be modified
+    // while the mapping is in use.
+    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
+    let model = Vgg::new(vb, variant)?;
+    let logits = model.forward(&image)?;
+
+    let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
+        .i(0)?
+        .to_vec1::<f32>()?;
+
+    // Rank the classes by descending probability. `total_cmp` is a total order
+    // over floats, so a NaN in the output cannot panic the sort.
+    let mut top: Vec<_> = prs.iter().enumerate().collect();
+    top.sort_by(|a, b| b.1.total_cmp(a.1));
+
+    // Print the five most likely classes.
+    for &(i, p) in top.iter().take(5) {
+        println!(
+            "{:50}: {:.2}%",
+            candle_examples::imagenet::CLASSES[i],
+            p * 100.0
+        );
+    }
+
+    Ok(())
+}
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index c59bd880..aecfcd67 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -28,6 +28,7 @@ pub mod segment_anything;
 pub mod stable_diffusion;
 pub mod stable_lm;
 pub mod t5;
+pub mod vgg;
 pub mod vit;
 pub mod whisper;
 pub mod with_tracing;
diff --git a/candle-transformers/src/models/vgg.rs b/candle-transformers/src/models/vgg.rs
new file mode 100644
index 00000000..7837dc3e
--- /dev/null
+++ b/candle-transformers/src/models/vgg.rs
@@ -0,0 +1,255 @@
+//! VGG model implementations (VGG13, VGG16 and VGG19).
+//!
+//! See Very Deep Convolutional Networks for Large-Scale Image Recognition
+//! <https://arxiv.org/abs/1409.1556>
+use candle::{Module, Result, Tensor};
+use candle_nn::{Func, VarBuilder};
+
+/// The VGG variants that can be built.
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum Models {
+    Vgg13,
+    Vgg16,
+    Vgg19,
+}
+
+/// A VGG network, stored as the sequence of blocks applied by `forward`.
+#[derive(Debug)]
+pub struct Vgg<'a> {
+    blocks: Vec<Func<'a>>,
+}
+
+/// Shapes describing one pre-logit layer.
+///
+/// The weight is read with shape `in_dim` and reshaped to
+/// `(target_in, target_out)`; the bias has `target_in` elements.
+struct PreLogitConfig {
+    in_dim: (usize, usize, usize, usize),
+    target_in: usize,
+    target_out: usize,
+}
+
+/// Number of output classes of the classifier head.
+const NUM_CLASSES: usize = 1000;
+
+impl<'a> Vgg<'a> {
+    /// Builds the requested variant, reading its weights from `vb`.
+    pub fn new(vb: VarBuilder<'a>, model: Models) -> Result<Self> {
+        let blocks = match model {
+            Models::Vgg13 => vgg13_blocks(vb)?,
+            Models::Vgg16 => vgg16_blocks(vb)?,
+            Models::Vgg19 => vgg19_blocks(vb)?,
+        };
+        Ok(Self { blocks })
+    }
+}
+
+impl Module for Vgg<'_> {
+    /// Adds a leading dimension of size one to `xs`, then applies every block
+    /// in order.
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let mut xs = xs.unsqueeze(0)?;
+        for block in self.blocks.iter() {
+            xs = xs.apply(block)?;
+        }
+        Ok(xs)
+    }
+}
+
+/// Builds one convolutional stage.
+///
+/// Every `(in_channels, out_channels, name)` entry of `convs` becomes a 3x3
+/// convolution with stride 1 and padding 1 whose weights are read from `vb`
+/// under `name`. The stage applies each convolution followed by a ReLU, then a
+/// single max-pool with kernel size 2 and stride 2.
+fn conv2d_block(convs: &[(usize, usize, &str)], vb: &VarBuilder) -> Result<Func<'static>> {
+    let layers = convs
+        .iter()
+        .map(|&(in_c, out_c, name)| {
+            candle_nn::conv2d(
+                in_c,
+                out_c,
+                3,
+                candle_nn::Conv2dConfig {
+                    stride: 1,
+                    padding: 1,
+                    ..Default::default()
+                },
+                vb.pp(name),
+            )
+        })
+        .collect::<Result<Vec<_>>>()?;
+
+    Ok(Func::new(move |xs| {
+        let mut xs = xs.clone();
+        for layer in layers.iter() {
+            xs = xs.apply(layer)?.relu()?
+        }
+        xs.max_pool2d_with_stride(2, 2)
+    }))
+}
+
+/// Builds the classifier: two pre-logit layers followed by `head.fc`.
+///
+/// All three layers are created up front, so a missing or mis-shaped tensor is
+/// reported when the model is built rather than on the first forward pass.
+///
+/// No dropout is applied: `Module::forward` carries no training flag and
+/// `candle_nn::ops::dropout` always drops activations, which would make
+/// inference results random. The output of `head.fc` is returned as raw
+/// logits, without a ReLU, so that it can be fed directly to a softmax.
+fn fully_connected(
+    num_classes: usize,
+    pre_logit_1: PreLogitConfig,
+    pre_logit_2: PreLogitConfig,
+    vb: VarBuilder,
+) -> Result<Func<'static>> {
+    let fc1 = get_weights_and_biases(
+        &vb.pp("pre_logits.fc1"),
+        pre_logit_1.in_dim,
+        pre_logit_1.target_in,
+        pre_logit_1.target_out,
+    )?;
+    let fc2 = get_weights_and_biases(
+        &vb.pp("pre_logits.fc2"),
+        pre_logit_2.in_dim,
+        pre_logit_2.target_in,
+        pre_logit_2.target_out,
+    )?;
+    let head = candle_nn::linear(4096, num_classes, vb.pp("head.fc"))?;
+    // Number of features per image once the feature map is flattened.
+    let flat_dim = pre_logit_1.target_out;
+    Ok(Func::new(move |xs| {
+        xs.reshape((1, flat_dim))?
+            .apply(&fc1)?
+            .relu()?
+            .apply(&fc2)?
+            .relu()?
+            .apply(&head)
+    }))
+}
+
+/// Loads one pre-logit layer as a `candle_nn::Linear`.
+///
+/// The weight is stored with the 4-D shape `in_dim`, which is not what the
+/// linear layer expects, so it is reshaped to `(target_in, target_out)` and
+/// paired with a bias of `target_in` elements.
+fn get_weights_and_biases(
+    vs: &VarBuilder,
+    in_dim: (usize, usize, usize, usize),
+    target_in: usize,
+    target_out: usize,
+) -> Result<candle_nn::Linear> {
+    let init_ws = candle_nn::init::DEFAULT_KAIMING_NORMAL;
+    let ws = vs.get_with_hints(in_dim, "weight", init_ws)?;
+    let ws = ws.reshape((target_in, target_out))?;
+    let bound = 1. / (target_out as f64).sqrt();
+    let init_bs = candle_nn::Init::Uniform {
+        lo: -bound,
+        up: bound,
+    };
+    let bs = vs.get_with_hints(target_in, "bias", init_bs)?;
+    Ok(candle_nn::Linear::new(ws, Some(bs)))
+}
+
+/// Builds the classifier shared by every variant.
+fn classifier(vb: VarBuilder) -> Result<Func<'static>> {
+    fully_connected(
+        NUM_CLASSES,
+        PreLogitConfig {
+            in_dim: (4096, 512, 7, 7),
+            target_in: 4096,
+            target_out: 512 * 7 * 7,
+        },
+        PreLogitConfig {
+            in_dim: (4096, 4096, 1, 1),
+            target_in: 4096,
+            target_out: 4096,
+        },
+        vb,
+    )
+}
+
+/// Blocks of VGG13: five stages of two convolutions, then the classifier.
+fn vgg13_blocks(vb: VarBuilder) -> Result<Vec<Func<'static>>> {
+    let blocks = vec![
+        conv2d_block(&[(3, 64, "features.0"), (64, 64, "features.2")], &vb)?,
+        conv2d_block(&[(64, 128, "features.5"), (128, 128, "features.7")], &vb)?,
+        conv2d_block(&[(128, 256, "features.10"), (256, 256, "features.12")], &vb)?,
+        conv2d_block(&[(256, 512, "features.15"), (512, 512, "features.17")], &vb)?,
+        conv2d_block(&[(512, 512, "features.20"), (512, 512, "features.22")], &vb)?,
+        classifier(vb)?,
+    ];
+    Ok(blocks)
+}
+
+/// Blocks of VGG16: stages of 2, 2, 3, 3 and 3 convolutions, then the classifier.
+fn vgg16_blocks(vb: VarBuilder) -> Result<Vec<Func<'static>>> {
+    let blocks = vec![
+        conv2d_block(&[(3, 64, "features.0"), (64, 64, "features.2")], &vb)?,
+        conv2d_block(&[(64, 128, "features.5"), (128, 128, "features.7")], &vb)?,
+        conv2d_block(
+            &[
+                (128, 256, "features.10"),
+                (256, 256, "features.12"),
+                (256, 256, "features.14"),
+            ],
+            &vb,
+        )?,
+        conv2d_block(
+            &[
+                (256, 512, "features.17"),
+                (512, 512, "features.19"),
+                (512, 512, "features.21"),
+            ],
+            &vb,
+        )?,
+        conv2d_block(
+            &[
+                (512, 512, "features.24"),
+                (512, 512, "features.26"),
+                (512, 512, "features.28"),
+            ],
+            &vb,
+        )?,
+        classifier(vb)?,
+    ];
+    Ok(blocks)
+}
+
+/// Blocks of VGG19: stages of 2, 2, 4, 4 and 4 convolutions, then the classifier.
+fn vgg19_blocks(vb: VarBuilder) -> Result<Vec<Func<'static>>> {
+    let blocks = vec![
+        conv2d_block(&[(3, 64, "features.0"), (64, 64, "features.2")], &vb)?,
+        conv2d_block(&[(64, 128, "features.5"), (128, 128, "features.7")], &vb)?,
+        conv2d_block(
+            &[
+                (128, 256, "features.10"),
+                (256, 256, "features.12"),
+                (256, 256, "features.14"),
+                (256, 256, "features.16"),
+            ],
+            &vb,
+        )?,
+        conv2d_block(
+            &[
+                (256, 512, "features.19"),
+                (512, 512, "features.21"),
+                (512, 512, "features.23"),
+                (512, 512, "features.25"),
+            ],
+            &vb,
+        )?,
+        conv2d_block(
+            &[
+                (512, 512, "features.28"),
+                (512, 512, "features.30"),
+                (512, 512, "features.32"),
+                (512, 512, "features.34"),
+            ],
+            &vb,
+        )?,
+        classifier(vb)?,
+    ];
+    Ok(blocks)
+}