Mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 18:48:51 +00:00)
Some fixes for yolo-v3. (#529)
* Some fixes for yolo-v3.
* Use the running stats for inference in the batch-norm layer.
* Get some proper predictions for yolo.
* Avoid the quadratic insertion.
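The batch-norm change is the substance of the fix: at inference time the layer now normalizes with its stored running statistics rather than statistics computed from the current batch (the training-style path is kept as `forward_learning`). Below is a minimal sketch of that inference path for a 4-D NCHW input, using the same candle broadcasting ops as the `forward` implementation added in this commit; the helper name and the concrete shapes are illustrative, not part of the commit.

use candle::{DType, Device, Result, Tensor};

// Inference-style batch norm over an NCHW tensor: normalize each channel with
// stored running statistics instead of statistics computed from the batch.
fn bn_inference(x: &Tensor, running_mean: &Tensor, running_var: &Tensor, eps: f64) -> Result<Tensor> {
    // Reshape the per-channel stats to (1, C, 1, 1) so they broadcast over N, H and W.
    let c = x.dims()[1];
    let mean = running_mean.reshape((1, c, 1, 1))?;
    let var = running_var.reshape((1, c, 1, 1))?;
    x.broadcast_sub(&mean)?
        .broadcast_div(&(var + eps)?.sqrt()?)
}

fn main() -> Result<()> {
    let x = Tensor::zeros((2, 5, 3, 4), DType::F32, &Device::Cpu)?;
    let running_mean = Tensor::zeros(5, DType::F32, &Device::Cpu)?;
    let running_var = Tensor::ones(5, DType::F32, &Device::Cpu)?;
    let y = bn_inference(&x, &running_mean, &running_var, 1e-8)?;
    assert_eq!(y.dims(), &[2, 5, 3, 4]);
    Ok(())
}

When an affine weight and bias are present they are applied the same way, via broadcast_mul and broadcast_add, as in the forward added below.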
@@ -497,10 +497,7 @@ impl Tensor {
         let repeats = shape.into();
         let repeats = repeats.dims();
         let mut inp = if self.rank() < repeats.len() {
-            let mut shape = self.dims().to_vec();
-            while shape.len() < repeats.len() {
-                shape.push(1)
-            }
+            let shape = [vec![1; repeats.len() - self.rank()], self.dims().to_vec()].concat();
             self.reshape(shape)?
         } else {
             self.clone()
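As context for the change above: when the repeat specification has a higher rank than the tensor, the new code pads the shape with leading 1s in a single concat before reshaping and tiling. A small usage sketch of `Tensor::repeat` (values are illustrative; the calls are the same ones used in the yolo-v3 `detect` code further down in this diff):

use candle::{Device, Result, Tensor};

fn main() -> Result<()> {
    // A rank-1 tensor of length 2.
    let t = Tensor::new(&[1f32, 2.], &Device::Cpu)?;
    // A rank-2 repeat spec: t is first viewed as (1, 2), then tiled 3 times
    // along the first dimension and once along the second.
    let r = t.repeat((3, 1))?;
    assert_eq!(r.dims(), &[3, 2]);
    Ok(())
}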
@@ -1,4 +1,4 @@
-use candle::{Device, IndexOp, Result, Tensor};
+use candle::{DType, Device, IndexOp, Result, Tensor};
 use candle_nn::{batch_norm, conv2d, conv2d_no_bias, Func, Module, VarBuilder};
 use std::collections::BTreeMap;
 use std::fs::File;
@@ -145,11 +145,12 @@ fn conv(vb: VarBuilder, index: usize, p: usize, b: &Block) -> Result<(usize, Bl
             Some(bn) => bn.forward(&xs)?,
             None => xs,
         };
-        if leaky {
-            xs.maximum(&(&xs * 0.1)?)
-        } else {
-            Ok(xs)
-        }
+        let xs = if leaky {
+            xs.maximum(&(&xs * 0.1)?)?
+        } else {
+            xs
+        };
+        Ok(xs)
     });
     Ok((filters, Bl::Layer(Box::new(func))))
 }
@@ -225,12 +226,13 @@ fn detect(
     let grid = Tensor::arange(0u32, grid_size as u32, &Device::Cpu)?;
     let a = grid.repeat((grid_size, 1))?;
     let b = a.t()?.contiguous()?;
-    let x_offset = a.unsqueeze(1)?;
-    let y_offset = b.unsqueeze(1)?;
+    let x_offset = a.flatten_all()?.unsqueeze(1)?;
+    let y_offset = b.flatten_all()?.unsqueeze(1)?;
     let xy_offset = Tensor::cat(&[&x_offset, &y_offset], 1)?
         .repeat((1, nanchors))?
         .reshape((grid_size * grid_size * nanchors, 2))?
-        .unsqueeze(0)?;
+        .unsqueeze(0)?
+        .to_dtype(DType::F32)?;
     let anchors: Vec<f32> = anchors
         .iter()
         .flat_map(|&(x, y)| vec![x as f32 / stride as f32, y as f32 / stride as f32].into_iter())
@@ -245,7 +247,8 @@ fn detect(
     let ys02 = (candle_nn::ops::sigmoid(&ys02)?.add(&xy_offset)? * stride as f64)?;
     let ys24 = (ys24.exp()?.mul(&anchors)? * stride as f64)?;
     let ys4 = candle_nn::ops::sigmoid(&ys4)?;
-    Tensor::cat(&[ys02, ys24, ys4], 2)
+    let ys = Tensor::cat(&[ys02, ys24, ys4], 2)?;
+    Ok(ys)
 }
 
 impl Darknet {
candle-examples/examples/yolo-v3/extract-weights.py (new file, 7 lines)
@@ -0,0 +1,7 @@
+def remove_prefix(text, prefix):
+    return text[text.startswith(prefix) and len(prefix):]
+nps = {}
+for k, v in model.state_dict().items():
+    k = remove_prefix(k, 'module_list.')
+    nps[k] = v.detach().numpy()
+np.savez('yolo-v3.ot', **nps)
@@ -11,22 +11,22 @@ use anyhow::Result;
 use candle::{DType, Device, Tensor};
 use candle_nn::{Module, VarBuilder};
 use clap::Parser;
+use image::{DynamicImage, ImageBuffer};
 
-const CONFIG_NAME: &str = "candle-examples/examples/yolo/yolo-v3.cfg";
-const CONFIDENCE_THRESHOLD: f64 = 0.5;
-const NMS_THRESHOLD: f64 = 0.4;
+const CONFIDENCE_THRESHOLD: f32 = 0.5;
+const NMS_THRESHOLD: f32 = 0.4;
 
 #[derive(Debug, Clone, Copy)]
 struct Bbox {
-    xmin: f64,
-    ymin: f64,
-    xmax: f64,
-    ymax: f64,
-    confidence: f64,
+    xmin: f32,
+    ymin: f32,
+    xmax: f32,
+    ymax: f32,
+    confidence: f32,
 }
 
 // Intersection over union of two bounding boxes.
-fn iou(b1: &Bbox, b2: &Bbox) -> f64 {
+fn iou(b1: &Bbox, b2: &Bbox) -> f32 {
     let b1_area = (b1.xmax - b1.xmin + 1.) * (b1.ymax - b1.ymin + 1.);
     let b2_area = (b2.xmax - b2.xmin + 1.) * (b2.ymax - b2.ymin + 1.);
     let i_xmin = b1.xmin.max(b2.xmin);
@@ -38,18 +38,35 @@ fn iou(b1: &Bbox, b2: &Bbox) -> f64 {
 }
 
 // Assumes x1 <= x2 and y1 <= y2
-pub fn draw_rect(_: &mut Tensor, _x1: usize, _x2: usize, _y1: usize, _y2: usize) {
-    todo!()
+pub fn draw_rect(
+    img: &mut ImageBuffer<image::Rgb<u8>, Vec<u8>>,
+    x1: u32,
+    x2: u32,
+    y1: u32,
+    y2: u32,
+) {
+    for x in x1..=x2 {
+        let pixel = img.get_pixel_mut(x, y1);
+        *pixel = image::Rgb([255, 0, 0]);
+        let pixel = img.get_pixel_mut(x, y2);
+        *pixel = image::Rgb([255, 0, 0]);
+    }
+    for y in y1..=y2 {
+        let pixel = img.get_pixel_mut(x1, y);
+        *pixel = image::Rgb([255, 0, 0]);
+        let pixel = img.get_pixel_mut(x2, y);
+        *pixel = image::Rgb([255, 0, 0]);
+    }
 }
 
-pub fn report(pred: &Tensor, img: &Tensor, w: usize, h: usize) -> Result<Tensor> {
+pub fn report(pred: &Tensor, img: DynamicImage, w: usize, h: usize) -> Result<DynamicImage> {
     let (npreds, pred_size) = pred.dims2()?;
     let nclasses = pred_size - 5;
     // The bounding boxes grouped by (maximum) class index.
     let mut bboxes: Vec<Vec<Bbox>> = (0..nclasses).map(|_| vec![]).collect();
     // Extract the bounding boxes for which confidence is above the threshold.
     for index in 0..npreds {
-        let pred = Vec::<f64>::try_from(pred.get(index)?)?;
+        let pred = Vec::<f32>::try_from(pred.get(index)?)?;
         let confidence = pred[4];
         if confidence > CONFIDENCE_THRESHOLD {
             let mut class_index = 0;
@@ -91,24 +108,21 @@ pub fn report(pred: &Tensor, img: &Tensor, w: usize, h: usize) -> Result<Tensor>
         bboxes_for_class.truncate(current_index);
     }
     // Annotate the original image and print boxes information.
-    let (_, initial_h, initial_w) = img.dims3()?;
-    let mut img = (img.to_dtype(DType::F32)? * (1. / 255.))?;
-    let w_ratio = initial_w as f64 / w as f64;
-    let h_ratio = initial_h as f64 / h as f64;
+    let (initial_h, initial_w) = (img.height(), img.width());
+    let w_ratio = initial_w as f32 / w as f32;
+    let h_ratio = initial_h as f32 / h as f32;
+    let mut img = img.to_rgb8();
     for (class_index, bboxes_for_class) in bboxes.iter().enumerate() {
         for b in bboxes_for_class.iter() {
             println!("{}: {:?}", coco_classes::NAMES[class_index], b);
-            let xmin = ((b.xmin * w_ratio) as usize).clamp(0, initial_w - 1);
-            let ymin = ((b.ymin * h_ratio) as usize).clamp(0, initial_h - 1);
-            let xmax = ((b.xmax * w_ratio) as usize).clamp(0, initial_w - 1);
-            let ymax = ((b.ymax * h_ratio) as usize).clamp(0, initial_h - 1);
-            draw_rect(&mut img, xmin, xmax, ymin, ymax.min(ymin + 2));
-            draw_rect(&mut img, xmin, xmax, ymin.max(ymax - 2), ymax);
-            draw_rect(&mut img, xmin, xmax.min(xmin + 2), ymin, ymax);
-            draw_rect(&mut img, xmin.max(xmax - 2), xmax, ymin, ymax);
+            let xmin = ((b.xmin * w_ratio) as u32).clamp(0, initial_w - 1);
+            let ymin = ((b.ymin * h_ratio) as u32).clamp(0, initial_h - 1);
+            let xmax = ((b.xmax * w_ratio) as u32).clamp(0, initial_w - 1);
+            let ymax = ((b.ymax * h_ratio) as u32).clamp(0, initial_h - 1);
+            draw_rect(&mut img, xmin, xmax, ymin, ymax);
         }
     }
-    Ok((img * 255.)?.to_dtype(DType::U8)?)
+    Ok(DynamicImage::ImageRgb8(img))
 }
 
 #[derive(Parser, Debug)]
@@ -118,6 +132,9 @@ struct Args {
     #[arg(long)]
     model: String,
 
+    #[arg(long)]
+    config: String,
+
     images: Vec<String>,
 }
 
@@ -128,18 +145,36 @@ pub fn main() -> Result<()> {
     let weights = unsafe { candle::safetensors::MmapedFile::new(&args.model)? };
     let weights = weights.deserialize()?;
     let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &Device::Cpu);
-    let darknet = darknet::parse_config(CONFIG_NAME)?;
+    let darknet = darknet::parse_config(&args.config)?;
     let model = darknet.build_model(vb)?;
 
-    for image in args.images.iter() {
+    for image_name in args.images.iter() {
+        println!("processing {image_name}");
+        let mut image_name = std::path::PathBuf::from(image_name);
         // Load the image file and resize it.
         let net_width = darknet.width()?;
         let net_height = darknet.height()?;
-        let image = candle_examples::load_image_and_resize(image, net_width, net_height)?;
+        let original_image = image::io::Reader::open(&image_name)?
+            .decode()
+            .map_err(candle::Error::wrap)?;
+        let image = {
+            let data = original_image
+                .resize_exact(
+                    net_width as u32,
+                    net_height as u32,
+                    image::imageops::FilterType::Triangle,
+                )
+                .to_rgb8()
+                .into_raw();
+            Tensor::from_vec(data, (net_width, net_height, 3), &Device::Cpu)?.permute((2, 0, 1))?
+        };
         let image = (image.unsqueeze(0)?.to_dtype(DType::F32)? * (1. / 255.))?;
         let predictions = model.forward(&image)?.squeeze(0)?;
-        let _image = report(&predictions, &image, net_width, net_height)?;
-        println!("converted {image}");
+        let image = report(&predictions, original_image, net_width, net_height)?;
+        image_name.set_extension("pp.jpg");
+        println!("writing {image_name:?}");
+        image.save(image_name)?
     }
     Ok(())
 }
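With the configuration file now taken from the command line rather than the hard-coded CONFIG_NAME constant, a run of the example would look roughly like the following (the weight and image file names are hypothetical); each input image gets an annotated copy written next to it with a .pp.jpg extension:

cargo run --example yolo-v3 --release -- --model yolo-v3.safetensors --config yolo-v3.cfg images/dog.jpg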
@@ -40,6 +40,8 @@ impl From<f64> for BatchNormConfig {
 
 #[derive(Debug)]
 pub struct BatchNorm {
+    running_mean: Tensor,
+    running_var: Tensor,
     weight_and_bias: Option<(Tensor, Tensor)>,
     remove_mean: bool,
     eps: f64,
@@ -47,7 +49,14 @@ pub struct BatchNorm {
 }
 
 impl BatchNorm {
-    pub fn new(num_features: usize, weight: Tensor, bias: Tensor, eps: f64) -> Result<Self> {
+    pub fn new(
+        num_features: usize,
+        running_mean: Tensor,
+        running_var: Tensor,
+        weight: Tensor,
+        bias: Tensor,
+        eps: f64,
+    ) -> Result<Self> {
         if eps < 0. {
             candle::bail!("batch-norm eps cannot be negative {eps}")
         }
@@ -64,6 +73,8 @@ impl BatchNorm {
             )
         }
         Ok(Self {
+            running_mean,
+            running_var,
             weight_and_bias: Some((weight, bias)),
             remove_mean: true,
             eps,
@@ -71,11 +82,18 @@ impl BatchNorm {
         })
     }
 
-    pub fn new_no_bias(num_features: usize, eps: f64) -> Result<Self> {
+    pub fn new_no_bias(
+        num_features: usize,
+        running_mean: Tensor,
+        running_var: Tensor,
+        eps: f64,
+    ) -> Result<Self> {
         if eps < 0. {
             candle::bail!("batch-norm eps cannot be negative {eps}")
         }
         Ok(Self {
+            running_mean,
+            running_var,
             weight_and_bias: None,
             remove_mean: true,
             eps,
@@ -84,8 +102,8 @@ impl BatchNorm {
     }
 }
 
-impl crate::Module for BatchNorm {
-    fn forward(&self, x: &Tensor) -> Result<Tensor> {
+impl BatchNorm {
+    pub fn forward_learning(&self, x: &Tensor) -> Result<Tensor> {
         let x_dtype = x.dtype();
         let internal_dtype = match x_dtype {
             DType::F16 | DType::BF16 => DType::F32,
@@ -129,6 +147,29 @@ impl crate::Module for BatchNorm {
     }
 }
 
+impl crate::Module for BatchNorm {
+    fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        let target_shape: Vec<usize> = x
+            .dims()
+            .iter()
+            .enumerate()
+            .map(|(idx, v)| if idx == 1 { *v } else { 1 })
+            .collect();
+        let target_shape = target_shape.as_slice();
+        let x = x
+            .broadcast_sub(&self.running_mean.reshape(target_shape)?)?
+            .broadcast_div(&(self.running_var.reshape(target_shape)? + self.eps)?.sqrt()?)?;
+        match &self.weight_and_bias {
+            None => Ok(x),
+            Some((weight, bias)) => {
+                let weight = weight.reshape(target_shape)?;
+                let bias = bias.reshape(target_shape)?;
+                x.broadcast_mul(&weight)?.broadcast_add(&bias)
+            }
+        }
+    }
+}
+
 pub fn batch_norm<C: Into<BatchNormConfig>>(
     num_features: usize,
     config: C,
@@ -138,6 +179,8 @@ pub fn batch_norm<C: Into<BatchNormConfig>>(
     if config.eps < 0. {
         candle::bail!("batch-norm eps cannot be negative {}", config.eps)
     }
+    let running_mean = vb.get_or_init(num_features, "running_mean", crate::Init::Const(0.))?;
+    let running_var = vb.get_or_init(num_features, "running_var", crate::Init::Const(1.))?;
     let weight_and_bias = if config.affine {
         let weight = vb.get_or_init(num_features, "weight", crate::Init::Const(1.))?;
         let bias = vb.get_or_init(num_features, "bias", crate::Init::Const(0.))?;
@@ -146,6 +189,8 @@ pub fn batch_norm<C: Into<BatchNormConfig>>(
         None
     };
     Ok(BatchNorm {
+        running_mean,
+        running_var,
         weight_and_bias,
         remove_mean: config.remove_mean,
         eps: config.eps,
@@ -7,8 +7,8 @@ extern crate accelerate_src;
 mod test_utils;
 
 use anyhow::Result;
-use candle::{Device, Tensor};
-use candle_nn::{BatchNorm, Module};
+use candle::{DType, Device, Tensor};
+use candle_nn::BatchNorm;
 
 /* The test below has been generated using the following PyTorch code:
 import torch
@@ -21,7 +21,9 @@ print(output.flatten())
 */
 #[test]
 fn batch_norm() -> Result<()> {
-    let bn = BatchNorm::new_no_bias(5, 1e-8)?;
+    let running_mean = Tensor::zeros(5, DType::F32, &Device::Cpu)?;
+    let running_var = Tensor::ones(5, DType::F32, &Device::Cpu)?;
+    let bn = BatchNorm::new_no_bias(5, running_mean.clone(), running_var.clone(), 1e-8)?;
     let input: [f32; 120] = [
         -0.7493, -1.0410, 1.6977, -0.6579, 1.7982, -0.0087, 0.2812, -0.1190, 0.2908, -0.5975,
         -0.0278, -0.2138, -1.3130, -1.6048, -2.2028, 0.9452, 0.4002, 0.0831, 1.0004, 0.1860,
@@ -37,7 +39,7 @@ fn batch_norm() -> Result<()> {
         1.4252, -0.9115, -0.1093, -0.3100, -0.6734, -1.4357, 0.9205,
     ];
     let input = Tensor::new(&input, &Device::Cpu)?.reshape((2, 5, 3, 4))?;
-    let output = bn.forward(&input)?;
+    let output = bn.forward_learning(&input)?;
     assert_eq!(output.dims(), &[2, 5, 3, 4]);
     let output = output.flatten_all()?;
     assert_eq!(
@@ -59,11 +61,13 @@ fn batch_norm() -> Result<()> {
     );
     let bn2 = BatchNorm::new(
         5,
+        running_mean.clone(),
+        running_var.clone(),
         Tensor::new(&[0.5f32], &Device::Cpu)?.broadcast_as(5)?,
         Tensor::new(&[-1.5f32], &Device::Cpu)?.broadcast_as(5)?,
         1e-8,
     )?;
-    let output2 = bn2.forward(&input)?;
+    let output2 = bn2.forward_learning(&input)?;
     assert_eq!(output2.dims(), &[2, 5, 3, 4]);
     let output2 = output2.flatten_all()?;
     let diff2 = ((output2 - (output * 0.5)?)? + 1.5)?.sqr()?;