Mirror of https://github.com/huggingface/candle.git (synced 2025-06-17 19:18:50 +00:00)
Allow for different behavior between training and eval (#1213)

* Forward with training.
* Do not use dropout on vgg evaluation.
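For context, here is a minimal sketch (not part of the diff below) of the call pattern this commit enables: a dropout layer can be driven through the new `ModuleT::forward_t`, applying dropout only when `train` is true and passing the input through unchanged at evaluation time.

use candle::{DType, Device, Result, Tensor};
use candle_nn::{Dropout, ModuleT};

fn sketch() -> Result<()> {
    let xs = Tensor::ones((1, 4), DType::F32, &Device::Cpu)?;
    let dropout = Dropout::new(0.5);
    // Training: roughly half of the activations are zeroed (the rest rescaled).
    let _train_out = dropout.forward_t(&xs, true)?;
    // Evaluation: dropout is skipped and the input passes through unchanged.
    let _eval_out = dropout.forward_t(&xs, false)?;
    Ok(())
}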
@@ -125,3 +125,15 @@ impl<T: Fn(&Tensor) -> Result<Tensor>> Module for T {
         self(xs)
     }
 }
+
+// A trait defining a module with forward method using a single tensor argument and a flag to
+// separate the training and evaluation behaviors.
+pub trait ModuleT {
+    fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor>;
+}
+
+impl<M: Module> ModuleT for M {
+    fn forward_t(&self, xs: &Tensor, _train: bool) -> Result<Tensor> {
+        self.forward(xs)
+    }
+}
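As a rough illustration (not from the commit), implementing the new trait directly lets a layer branch on the `train` flag. `TrainScale` and its `factor` field below are made-up names used only for this sketch.

use candle::{ModuleT, Result, Tensor};

// Hypothetical layer, for illustration only.
struct TrainScale {
    factor: f64,
}

impl ModuleT for TrainScale {
    fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor> {
        if train {
            // Scale activations during training...
            xs.affine(self.factor, 0.)
        } else {
            // ...and act as the identity at evaluation time.
            Ok(xs.clone())
        }
    }
}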
@@ -2271,6 +2271,11 @@ impl Tensor {
         m.forward(self)
     }
 
+    /// Run the `forward` method of `m` on `self`.
+    pub fn apply_t<M: crate::ModuleT>(&self, m: &M, train: bool) -> Result<Self> {
+        m.forward_t(self, train)
+    }
+
     pub(crate) fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> {
         self.storage.read().unwrap()
     }
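A small sketch (assuming a `candle_nn::Linear` built by hand, not taken from the commit) of how the blanket `impl<M: Module> ModuleT for M` combines with the new `Tensor::apply_t`: existing layers keep working, and the `train` flag is simply ignored for them.

use candle::{DType, Device, Result, Tensor};
use candle_nn::Linear;

fn sketch() -> Result<()> {
    // Weight of shape (out_features, in_features), no bias.
    let w = Tensor::ones((2, 3), DType::F32, &Device::Cpu)?;
    let layer = Linear::new(w, None);
    let xs = Tensor::ones((1, 3), DType::F32, &Device::Cpu)?;
    // Via the blanket impl, apply_t falls back to Module::forward,
    // so both calls below compute the same result.
    let ys_t = xs.apply_t(&layer, true)?;
    let ys = xs.apply(&layer)?;
    assert_eq!(ys_t.to_vec2::<f32>()?, ys.to_vec2::<f32>()?);
    Ok(())
}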
@@ -9,7 +9,7 @@ use clap::{Parser, ValueEnum};
 use rand::prelude::*;
 
 use candle::{DType, Result, Tensor, D};
-use candle_nn::{loss, ops, Conv2d, Linear, Module, Optimizer, VarBuilder, VarMap};
+use candle_nn::{loss, ops, Conv2d, Linear, Module, ModuleT, Optimizer, VarBuilder, VarMap};
 
 const IMAGE_DIM: usize = 784;
 const LABELS: usize = 10;
@@ -95,7 +95,7 @@ impl ConvNet {
             .flatten_from(1)?
             .apply(&self.fc1)?
             .relu()?;
-        self.dropout.forward(&xs, train)?.apply(&self.fc2)
+        self.dropout.forward_t(&xs, train)?.apply(&self.fc2)
     }
 }
 
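Note that the MNIST ConvNet was already threading a `train` flag into its dropout call; the change above simply routes it through the new `ModuleT::forward_t` method (and imports `ModuleT`), so the example now exercises the trait added in this commit.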
@@ -5,7 +5,7 @@ extern crate intel_mkl_src;
 extern crate accelerate_src;
 
 use candle::{DType, IndexOp, D};
-use candle_nn::{Module, VarBuilder};
+use candle_nn::{ModuleT, VarBuilder};
 use candle_transformers::models::vgg::{Models, Vgg};
 use clap::{Parser, ValueEnum};
 
@@ -53,7 +53,7 @@ pub fn main() -> anyhow::Result<()> {
         Which::Vgg16 => Vgg::new(vb, Models::Vgg16)?,
         Which::Vgg19 => Vgg::new(vb, Models::Vgg19)?,
     };
-    let logits = model.forward(&image)?;
+    let logits = model.forward_t(&image, /*train=*/ false)?;
 
     let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
         .i(0)?
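Since the VGG example only performs inference, it now calls `forward_t` with `train` set to `false`; per the commit message, this is what disables dropout during evaluation.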
@@ -36,3 +36,38 @@ impl<'a> Func<'a> {
         Self { f: Arc::new(f) }
     }
 }
+
+/// A layer defined by a simple closure.
+#[derive(Clone)]
+pub struct FuncT<'a> {
+    #[allow(clippy::type_complexity)]
+    f: Arc<dyn 'a + Fn(&Tensor, bool) -> Result<Tensor> + Send + Sync>,
+}
+
+impl<'a> std::fmt::Debug for FuncT<'a> {
+    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        write!(f, "func")
+    }
+}
+
+pub fn func_t<'a, F>(f: F) -> FuncT<'a>
+where
+    F: 'a + Fn(&Tensor, bool) -> Result<Tensor> + Send + Sync,
+{
+    FuncT { f: Arc::new(f) }
+}
+
+impl<'a> super::ModuleT for FuncT<'a> {
+    fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor> {
+        (*self.f)(xs, train)
+    }
+}
+
+impl<'a> FuncT<'a> {
+    pub fn new<F>(f: F) -> Self
+    where
+        F: 'a + Fn(&Tensor, bool) -> Result<Tensor> + Send + Sync,
+    {
+        Self { f: Arc::new(f) }
+    }
+}
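A short sketch (not from the commit) of how the new `func_t` helper can wrap a train-aware closure into a layer; the halving factor is an arbitrary choice for illustration.

use candle::{DType, Device, Result, Tensor};
use candle_nn::{func_t, ModuleT};

fn sketch() -> Result<Tensor> {
    // A closure-based layer: halve activations while training, pass through otherwise.
    let layer = func_t(|xs: &Tensor, train: bool| {
        if train {
            xs.affine(0.5, 0.)
        } else {
            Ok(xs.clone())
        }
    });
    let xs = Tensor::ones((2, 2), DType::F32, &Device::Cpu)?;
    // The resulting FuncT implements ModuleT, so it can be called with forward_t
    // or applied through Tensor::apply_t.
    layer.forward_t(&xs, /*train=*/ true)
}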
@@ -22,7 +22,7 @@ pub use conv::{
     Conv1dConfig, Conv2d, Conv2dConfig, ConvTranspose2d, ConvTranspose2dConfig,
 };
 pub use embedding::{embedding, Embedding};
-pub use func::{func, Func};
+pub use func::{func, func_t, Func, FuncT};
 pub use group_norm::{group_norm, GroupNorm};
 pub use init::Init;
 pub use layer_norm::{layer_norm, rms_norm, LayerNorm, LayerNormConfig, RmsNorm};
@@ -34,4 +34,4 @@ pub use sequential::{seq, Sequential};
 pub use var_builder::VarBuilder;
 pub use var_map::VarMap;
 
-pub use candle::Module;
+pub use candle::{Module, ModuleT};
@@ -84,6 +84,12 @@ impl Dropout {
     }
 }
 
+impl candle::ModuleT for Dropout {
+    fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor> {
+        self.forward(xs, train)
+    }
+}
+
 struct SoftmaxLastDim;
 
 impl candle::CustomOp1 for SoftmaxLastDim {
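With this impl, `Dropout` participates in the `ModuleT` machinery introduced above: it can be passed to `Tensor::apply_t` or composed inside a `FuncT` closure, as the VGG model changes below do.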
@@ -2,8 +2,8 @@
 //!
 //! See Very Deep Convolutional Networks for Large-Scale Image Recognition
 //! <https://arxiv.org/abs/1409.1556>
-use candle::{Module, Result, Tensor};
-use candle_nn::{Func, VarBuilder};
+use candle::{ModuleT, Result, Tensor};
+use candle_nn::{FuncT, VarBuilder};
 
 // Enum representing the different VGG models
 pub enum Models {
@@ -15,7 +15,7 @@ pub enum Models {
 // Struct representing a VGG model
 #[derive(Debug)]
 pub struct Vgg<'a> {
-    blocks: Vec<Func<'a>>,
+    blocks: Vec<FuncT<'a>>,
 }
 
 // Struct representing the configuration for the pre-logit layer
@@ -39,11 +39,11 @@ impl<'a> Vgg<'a> {
 }
 
 // Implementation of the forward pass for the VGG model
-impl Module for Vgg<'_> {
-    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+impl ModuleT for Vgg<'_> {
+    fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor> {
         let mut xs = xs.unsqueeze(0)?;
         for block in self.blocks.iter() {
-            xs = xs.apply(block)?;
+            xs = xs.apply_t(block, train)?;
         }
         Ok(xs)
     }
@@ -51,7 +51,7 @@ impl Module for Vgg<'_> {
 
 // Function to create a conv2d block
 // The block is composed of two conv2d layers followed by a max pool layer
-fn conv2d_block(convs: &[(usize, usize, &str)], vb: &VarBuilder) -> Result<Func<'static>> {
+fn conv2d_block(convs: &[(usize, usize, &str)], vb: &VarBuilder) -> Result<FuncT<'static>> {
     let layers = convs
         .iter()
         .enumerate()
@@ -70,7 +70,7 @@ fn conv2d_block(convs: &[(usize, usize, &str)], vb: &VarBuilder) -> Result<Func<
         })
         .collect::<Result<Vec<_>>>()?;
 
-    Ok(Func::new(move |xs| {
+    Ok(FuncT::new(move |xs, _train| {
         let mut xs = xs.clone();
         for layer in layers.iter() {
             xs = xs.apply(layer)?.relu()?
@@ -87,7 +87,7 @@ fn fully_connected(
     pre_logit_1: PreLogitConfig,
     pre_logit_2: PreLogitConfig,
     vb: VarBuilder,
-) -> Result<Func> {
+) -> Result<FuncT> {
     let lin = get_weights_and_biases(
         &vb.pp("pre_logits.fc1"),
         pre_logit_1.in_dim,
@@ -100,12 +100,15 @@ fn fully_connected(
         pre_logit_2.target_in,
         pre_logit_2.target_out,
     )?;
-    Ok(Func::new(move |xs| {
+    let dropout1 = candle_nn::Dropout::new(0.5);
+    let dropout2 = candle_nn::Dropout::new(0.5);
+    let dropout3 = candle_nn::Dropout::new(0.5);
+    Ok(FuncT::new(move |xs, train| {
         let xs = xs.reshape((1, pre_logit_1.target_out))?;
-        let xs = candle_nn::ops::dropout(&xs, 0.5)?.apply(&lin)?.relu()?;
-        let xs = candle_nn::ops::dropout(&xs, 0.5)?.apply(&lin2)?.relu()?;
+        let xs = xs.apply_t(&dropout1, train)?.apply(&lin)?.relu()?;
+        let xs = xs.apply_t(&dropout2, train)?.apply(&lin2)?.relu()?;
         let lin3 = candle_nn::linear(4096, num_classes, vb.pp("head.fc"))?;
-        let xs = candle_nn::ops::dropout(&xs, 0.5)?.apply(&lin3)?.relu()?;
+        let xs = xs.apply_t(&dropout3, train)?.apply(&lin3)?.relu()?;
         Ok(xs)
     }))
 }
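The three `Dropout` layers created outside the closure replace the previous unconditional `candle_nn::ops::dropout` calls. Because they are applied through `apply_t` with the `train` flag, dropout now only fires during training, which is the behavioral change called out in the commit title.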
@@ -130,7 +133,7 @@ fn get_weights_and_biases(
     Ok(candle_nn::Linear::new(ws, Some(bs)))
 }
 
-fn vgg13_blocks(vb: VarBuilder) -> Result<Vec<Func>> {
+fn vgg13_blocks(vb: VarBuilder) -> Result<Vec<FuncT>> {
     let num_classes = 1000;
     let blocks = vec![
         conv2d_block(&[(3, 64, "features.0"), (64, 64, "features.2")], &vb)?,
@@ -156,7 +159,7 @@ fn vgg13_blocks(vb: VarBuilder) -> Result<Vec<Func>> {
     Ok(blocks)
 }
 
-fn vgg16_blocks(vb: VarBuilder) -> Result<Vec<Func>> {
+fn vgg16_blocks(vb: VarBuilder) -> Result<Vec<FuncT>> {
     let num_classes = 1000;
     let blocks = vec![
         conv2d_block(&[(3, 64, "features.0"), (64, 64, "features.2")], &vb)?,
@@ -203,7 +206,7 @@ fn vgg16_blocks(vb: VarBuilder) -> Result<Vec<Func>> {
     Ok(blocks)
 }
 
-fn vgg19_blocks(vb: VarBuilder) -> Result<Vec<Func>> {
+fn vgg19_blocks(vb: VarBuilder) -> Result<Vec<FuncT>> {
     let num_classes = 1000;
     let blocks = vec![
         conv2d_block(&[(3, 64, "features.0"), (64, 64, "features.2")], &vb)?,