Add a yolo-v3 example. (#528)
* Add a couple functions required for yolo.
* Add the yolo-v3 example.
* Add minimum and maximum.
* Use the newly introduced maximum.
* Cuda support for min/max + add some testing.
* Allow for more tests to work with accelerate.
* Fix a typo.
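The message above mentions newly introduced element-wise `minimum`/`maximum` tensor ops. A minimal standalone sketch of how they compose (an assumption-laden sketch of the candle API as of this commit, where both ops take a tensor argument, mirroring the leaky ReLU in darknet.rs below):

use candle::{Device, Result, Tensor};

// Leaky ReLU as used by the darknet example below: max(x, 0.1 * x).
fn leaky_relu(xs: &Tensor) -> Result<Tensor> {
    xs.maximum(&(xs * 0.1)?)
}

fn main() -> Result<()> {
    let xs = Tensor::new(&[-1.5f32, 0.5, 2.0], &Device::Cpu)?;
    println!("{}", leaky_relu(&xs)?);
    // Clamping into [0, 1] by chaining the two new ops.
    let clamped = xs.maximum(&xs.zeros_like()?)?.minimum(&xs.ones_like()?)?;
    println!("{clamped}");
    Ok(())
}
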
candle-examples/examples/yolo-v3/coco_classes.rs (new file, 82 lines)
@@ -0,0 +1,82 @@
pub const NAMES: [&str; 80] = [
    "person",
    "bicycle",
    "car",
    "motorbike",
    "aeroplane",
    "bus",
    "train",
    "truck",
    "boat",
    "traffic light",
    "fire hydrant",
    "stop sign",
    "parking meter",
    "bench",
    "bird",
    "cat",
    "dog",
    "horse",
    "sheep",
    "cow",
    "elephant",
    "bear",
    "zebra",
    "giraffe",
    "backpack",
    "umbrella",
    "handbag",
    "tie",
    "suitcase",
    "frisbee",
    "skis",
    "snowboard",
    "sports ball",
    "kite",
    "baseball bat",
    "baseball glove",
    "skateboard",
    "surfboard",
    "tennis racket",
    "bottle",
    "wine glass",
    "cup",
    "fork",
    "knife",
    "spoon",
    "bowl",
    "banana",
    "apple",
    "sandwich",
    "orange",
    "broccoli",
    "carrot",
    "hot dog",
    "pizza",
    "donut",
    "cake",
    "chair",
    "sofa",
    "pottedplant",
    "bed",
    "diningtable",
    "toilet",
    "tvmonitor",
    "laptop",
    "mouse",
    "remote",
    "keyboard",
    "cell phone",
    "microwave",
    "oven",
    "toaster",
    "sink",
    "refrigerator",
    "book",
    "clock",
    "vase",
    "scissors",
    "teddy bear",
    "hair drier",
    "toothbrush",
];

candle-examples/examples/yolo-v3/darknet.rs (new file, 304 lines)
@@ -0,0 +1,304 @@
use candle::{Device, IndexOp, Result, Tensor};
use candle_nn::{batch_norm, conv2d, conv2d_no_bias, Func, Module, VarBuilder};
use std::collections::BTreeMap;
use std::fs::File;
use std::io::{BufRead, BufReader};
use std::path::Path;

// A single section of the darknet config file: its name and key/value parameters.
#[derive(Debug)]
struct Block {
    block_type: String,
    parameters: BTreeMap<String, String>,
}

impl Block {
    fn get(&self, key: &str) -> Result<&str> {
        match self.parameters.get(&key.to_string()) {
            None => candle::bail!("cannot find {} in {}", key, self.block_type),
            Some(value) => Ok(value),
        }
    }
}

#[derive(Debug)]
pub struct Darknet {
    blocks: Vec<Block>,
    parameters: BTreeMap<String, String>,
}

impl Darknet {
    fn get(&self, key: &str) -> Result<&str> {
        match self.parameters.get(&key.to_string()) {
            None => candle::bail!("cannot find {} in net parameters", key),
            Some(value) => Ok(value),
        }
    }
}

// Parser state: accumulates the current block while reading the config line by line.
struct Accumulator {
    block_type: Option<String>,
    parameters: BTreeMap<String, String>,
    net: Darknet,
}

impl Accumulator {
    fn new() -> Accumulator {
        Accumulator {
            block_type: None,
            parameters: BTreeMap::new(),
            net: Darknet {
                blocks: vec![],
                parameters: BTreeMap::new(),
            },
        }
    }

    fn finish_block(&mut self) {
        match &self.block_type {
            None => (),
            Some(block_type) => {
                if block_type == "net" {
                    self.net.parameters = self.parameters.clone();
                } else {
                    let block = Block {
                        block_type: block_type.to_string(),
                        parameters: self.parameters.clone(),
                    };
                    self.net.blocks.push(block);
                }
                self.parameters.clear();
            }
        }
        self.block_type = None;
    }
}

// Parse a darknet .cfg file into a Darknet description.
pub fn parse_config<T: AsRef<Path>>(path: T) -> Result<Darknet> {
    let file = File::open(path.as_ref())?;
    let mut acc = Accumulator::new();
    for line in BufReader::new(file).lines() {
        let line = line?;
        if line.is_empty() || line.starts_with('#') {
            continue;
        }
        let line = line.trim();
        if line.starts_with('[') {
            if !line.ends_with(']') {
                candle::bail!("line does not end with ']' {line}")
            }
            let line = &line[1..line.len() - 1];
            acc.finish_block();
            acc.block_type = Some(line.to_string());
        } else {
            let key_value: Vec<&str> = line.splitn(2, '=').collect();
            if key_value.len() != 2 {
                candle::bail!("missing equal {line}")
            }
            let prev = acc.parameters.insert(
                key_value[0].trim().to_owned(),
                key_value[1].trim().to_owned(),
            );
            if prev.is_some() {
                candle::bail!("multiple value for key {}", line)
            }
        }
    }
    acc.finish_block();
    Ok(acc.net)
}

// A built layer, or one of the structural ops that reference earlier layers.
enum Bl {
    Layer(Box<dyn candle_nn::Module + Send>),
    Route(Vec<usize>),
    Shortcut(usize),
    Yolo(usize, Vec<(usize, usize)>),
}

fn conv(vb: VarBuilder, index: usize, p: usize, b: &Block) -> Result<(usize, Bl)> {
    let activation = b.get("activation")?;
    let filters = b.get("filters")?.parse::<usize>()?;
    let pad = b.get("pad")?.parse::<usize>()?;
    let size = b.get("size")?.parse::<usize>()?;
    let stride = b.get("stride")?.parse::<usize>()?;
    let padding = if pad != 0 { (size - 1) / 2 } else { 0 };
    let (bn, bias) = match b.parameters.get("batch_normalize") {
        Some(p) if p.parse::<usize>()? != 0 => {
            let bn = batch_norm(filters, 1e-5, vb.pp(&format!("batch_norm_{index}")))?;
            (Some(bn), false)
        }
        Some(_) | None => (None, true),
    };
    let conv_cfg = candle_nn::Conv2dConfig { stride, padding };
    let conv = if bias {
        conv2d(p, filters, size, conv_cfg, vb.pp(&format!("conv_{index}")))?
    } else {
        conv2d_no_bias(p, filters, size, conv_cfg, vb.pp(&format!("conv_{index}")))?
    };
    let leaky = match activation {
        "leaky" => true,
        "linear" => false,
        otherwise => candle::bail!("unsupported activation {}", otherwise),
    };
    let func = candle_nn::func(move |xs| {
        let xs = conv.forward(xs)?;
        let xs = match &bn {
            Some(bn) => bn.forward(&xs)?,
            None => xs,
        };
        if leaky {
            // Leaky ReLU, expressed with the newly introduced maximum op.
            xs.maximum(&(&xs * 0.1)?)
        } else {
            Ok(xs)
        }
    });
    Ok((filters, Bl::Layer(Box::new(func))))
}

fn upsample(prev_channels: usize) -> Result<(usize, Bl)> {
    let layer = candle_nn::func(|xs| {
        let (_n, _c, h, w) = xs.dims4()?;
        xs.upsample_nearest2d(2 * h, 2 * w)
    });
    Ok((prev_channels, Bl::Layer(Box::new(layer))))
}

fn int_list_of_string(s: &str) -> Result<Vec<i64>> {
    let res: std::result::Result<Vec<_>, _> =
        s.split(',').map(|xs| xs.trim().parse::<i64>()).collect();
    Ok(res?)
}

// Resolve a darknet layer index: negative values are relative to the current block.
fn usize_of_index(index: usize, i: i64) -> usize {
    if i >= 0 {
        i as usize
    } else {
        (index as i64 + i) as usize
    }
}

fn route(index: usize, p: &[(usize, Bl)], block: &Block) -> Result<(usize, Bl)> {
    let layers = int_list_of_string(block.get("layers")?)?;
    let layers: Vec<usize> = layers
        .into_iter()
        .map(|l| usize_of_index(index, l))
        .collect();
    let channels = layers.iter().map(|&l| p[l].0).sum();
    Ok((channels, Bl::Route(layers)))
}

fn shortcut(index: usize, p: usize, block: &Block) -> Result<(usize, Bl)> {
    let from = block.get("from")?.parse::<i64>()?;
    Ok((p, Bl::Shortcut(usize_of_index(index, from))))
}

fn yolo(p: usize, block: &Block) -> Result<(usize, Bl)> {
    let classes = block.get("classes")?.parse::<usize>()?;
    let flat = int_list_of_string(block.get("anchors")?)?;
    if flat.len() % 2 != 0 {
        candle::bail!("expected an even number of anchors");
    }
    let flat = flat.into_iter().map(|i| i as usize).collect::<Vec<_>>();
    let anchors: Vec<_> = (0..(flat.len() / 2))
        .map(|i| (flat[2 * i], flat[2 * i + 1]))
        .collect();
    let mask = int_list_of_string(block.get("mask")?)?;
    let anchors = mask.into_iter().map(|i| anchors[i as usize]).collect();
    Ok((p, Bl::Yolo(classes, anchors)))
}

// Decode the raw output of a yolo head into per-anchor box attributes.
fn detect(
    xs: &Tensor,
    image_height: usize,
    classes: usize,
    anchors: &Vec<(usize, usize)>,
) -> Result<Tensor> {
    let (bsize, _channels, height, _width) = xs.dims4()?;
    let stride = image_height / height;
    let grid_size = image_height / stride;
    let bbox_attrs = 5 + classes;
    let nanchors = anchors.len();
    let xs = xs
        .reshape((bsize, bbox_attrs * nanchors, grid_size * grid_size))?
        .transpose(1, 2)?
        .contiguous()?
        .reshape((bsize, grid_size * grid_size * nanchors, bbox_attrs))?;
    let grid = Tensor::arange(0u32, grid_size as u32, &Device::Cpu)?;
    let a = grid.repeat((grid_size, 1))?;
    let b = a.t()?.contiguous()?;
    let x_offset = a.unsqueeze(1)?;
    let y_offset = b.unsqueeze(1)?;
    let xy_offset = Tensor::cat(&[&x_offset, &y_offset], 1)?
        .repeat((1, nanchors))?
        .reshape((grid_size * grid_size * nanchors, 2))?
        .unsqueeze(0)?;
    let anchors: Vec<f32> = anchors
        .iter()
        .flat_map(|&(x, y)| vec![x as f32 / stride as f32, y as f32 / stride as f32].into_iter())
        .collect();
    let anchors = Tensor::new(anchors.as_slice(), &Device::Cpu)?
        .reshape((anchors.len() / 2, 2))?
        .repeat((grid_size * grid_size, 1))?
        .unsqueeze(0)?;
    let ys02 = xs.i((.., .., 0..2))?;
    let ys24 = xs.i((.., .., 2..4))?;
    let ys4 = xs.i((.., .., 4..))?;
    let ys02 = (candle_nn::ops::sigmoid(&ys02)?.add(&xy_offset)? * stride as f64)?;
    let ys24 = (ys24.exp()?.mul(&anchors)? * stride as f64)?;
    let ys4 = candle_nn::ops::sigmoid(&ys4)?;
    Tensor::cat(&[ys02, ys24, ys4], 2)
}

impl Darknet {
    pub fn height(&self) -> Result<usize> {
        let image_height = self.get("height")?.parse::<usize>()?;
        Ok(image_height)
    }

    pub fn width(&self) -> Result<usize> {
        let image_width = self.get("width")?.parse::<usize>()?;
        Ok(image_width)
    }

    pub fn build_model(&self, vb: VarBuilder) -> Result<Func> {
        let mut blocks: Vec<(usize, Bl)> = vec![];
        let mut prev_channels: usize = 3;
        for (index, block) in self.blocks.iter().enumerate() {
            let channels_and_bl = match block.block_type.as_str() {
                "convolutional" => conv(vb.pp(&index.to_string()), index, prev_channels, block)?,
                "upsample" => upsample(prev_channels)?,
                "shortcut" => shortcut(index, prev_channels, block)?,
                "route" => route(index, &blocks, block)?,
                "yolo" => yolo(prev_channels, block)?,
                otherwise => candle::bail!("unsupported block type {}", otherwise),
            };
            prev_channels = channels_and_bl.0;
            blocks.push(channels_and_bl);
        }
        let image_height = self.height()?;
        let func = candle_nn::func(move |xs| {
            let mut prev_ys: Vec<Tensor> = vec![];
            let mut detections: Vec<Tensor> = vec![];
            for (_, b) in blocks.iter() {
                let ys = match b {
                    Bl::Layer(l) => {
                        let xs = prev_ys.last().unwrap_or(xs);
                        l.forward(xs)?
                    }
                    Bl::Route(layers) => {
                        let layers: Vec<_> = layers.iter().map(|&i| &prev_ys[i]).collect();
                        Tensor::cat(&layers, 1)?
                    }
                    Bl::Shortcut(from) => (prev_ys.last().unwrap() + prev_ys.get(*from).unwrap())?,
                    Bl::Yolo(classes, anchors) => {
                        let xs = prev_ys.last().unwrap_or(xs);
                        detections.push(detect(xs, image_height, *classes, anchors)?);
                        // Push a placeholder so that the layer indices stay aligned.
                        Tensor::new(&[0u32], &Device::Cpu)?
                    }
                };
                prev_ys.push(ys);
            }
            Tensor::cat(&detections, 1)
        });
        Ok(func)
    }
}

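For reference, the detect function above implements the standard YOLOv3 box decoding. In equation form (matching the code, with stride $s = \text{image\_height} / \text{height}$, grid offsets $c$ from `xy_offset`, and anchor priors $p$ already divided by the stride):

$$b_{x,y} = (\sigma(t_{x,y}) + c_{x,y}) \cdot s, \qquad b_{w,h} = p_{w,h}\, e^{t_{w,h}} \cdot s, \qquad \text{conf} = \sigma(t_o), \quad \text{class}_k = \sigma(t_k)$$

so `ys02` holds box centers in input pixels, `ys24` the widths and heights, and `ys4` the sigmoid-activated objectness and per-class scores.
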
candle-examples/examples/yolo-v3/main.rs (new file, 145 lines)
@@ -0,0 +1,145 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

mod coco_classes;
mod darknet;

use anyhow::Result;
use candle::{DType, Device, Tensor};
use candle_nn::{Module, VarBuilder};
use clap::Parser;

const CONFIG_NAME: &str = "candle-examples/examples/yolo-v3/yolo-v3.cfg";
const CONFIDENCE_THRESHOLD: f64 = 0.5;
const NMS_THRESHOLD: f64 = 0.4;

#[derive(Debug, Clone, Copy)]
struct Bbox {
    xmin: f64,
    ymin: f64,
    xmax: f64,
    ymax: f64,
    confidence: f64,
}

// Intersection over union of two bounding boxes.
fn iou(b1: &Bbox, b2: &Bbox) -> f64 {
    let b1_area = (b1.xmax - b1.xmin + 1.) * (b1.ymax - b1.ymin + 1.);
    let b2_area = (b2.xmax - b2.xmin + 1.) * (b2.ymax - b2.ymin + 1.);
    let i_xmin = b1.xmin.max(b2.xmin);
    let i_xmax = b1.xmax.min(b2.xmax);
    let i_ymin = b1.ymin.max(b2.ymin);
    let i_ymax = b1.ymax.min(b2.ymax);
    let i_area = (i_xmax - i_xmin + 1.).max(0.) * (i_ymax - i_ymin + 1.).max(0.);
    i_area / (b1_area + b2_area - i_area)
}

// Assumes x1 <= x2 and y1 <= y2. Drawing is left unimplemented in this commit.
pub fn draw_rect(_: &mut Tensor, _x1: usize, _x2: usize, _y1: usize, _y2: usize) {
    todo!()
}

pub fn report(pred: &Tensor, img: &Tensor, w: usize, h: usize) -> Result<Tensor> {
    let (npreds, pred_size) = pred.dims2()?;
    let nclasses = pred_size - 5;
    // The bounding boxes grouped by (maximum) class index.
    let mut bboxes: Vec<Vec<Bbox>> = (0..nclasses).map(|_| vec![]).collect();
    // Extract the bounding boxes for which confidence is above the threshold.
    for index in 0..npreds {
        let pred = Vec::<f64>::try_from(pred.get(index)?)?;
        let confidence = pred[4];
        if confidence > CONFIDENCE_THRESHOLD {
            let mut class_index = 0;
            for i in 0..nclasses {
                if pred[5 + i] > pred[5 + class_index] {
                    class_index = i
                }
            }
            if pred[class_index + 5] > 0. {
                let bbox = Bbox {
                    xmin: pred[0] - pred[2] / 2.,
                    ymin: pred[1] - pred[3] / 2.,
                    xmax: pred[0] + pred[2] / 2.,
                    ymax: pred[1] + pred[3] / 2.,
                    confidence,
                };
                bboxes[class_index].push(bbox)
            }
        }
    }
    // Perform non-maximum suppression.
    for bboxes_for_class in bboxes.iter_mut() {
        bboxes_for_class.sort_by(|b1, b2| b2.confidence.partial_cmp(&b1.confidence).unwrap());
        let mut current_index = 0;
        for index in 0..bboxes_for_class.len() {
            let mut drop = false;
            for prev_index in 0..current_index {
                let iou = iou(&bboxes_for_class[prev_index], &bboxes_for_class[index]);
                if iou > NMS_THRESHOLD {
                    drop = true;
                    break;
                }
            }
            if !drop {
                bboxes_for_class.swap(current_index, index);
                current_index += 1;
            }
        }
        bboxes_for_class.truncate(current_index);
    }
    // Annotate the original image and print boxes information.
    let (_, initial_h, initial_w) = img.dims3()?;
    let mut img = (img.to_dtype(DType::F32)? * (1. / 255.))?;
    let w_ratio = initial_w as f64 / w as f64;
    let h_ratio = initial_h as f64 / h as f64;
    for (class_index, bboxes_for_class) in bboxes.iter().enumerate() {
        for b in bboxes_for_class.iter() {
            println!("{}: {:?}", coco_classes::NAMES[class_index], b);
            let xmin = ((b.xmin * w_ratio) as usize).clamp(0, initial_w - 1);
            let ymin = ((b.ymin * h_ratio) as usize).clamp(0, initial_h - 1);
            let xmax = ((b.xmax * w_ratio) as usize).clamp(0, initial_w - 1);
            let ymax = ((b.ymax * h_ratio) as usize).clamp(0, initial_h - 1);
            // Draw the four edges of the box as thin rectangles.
            draw_rect(&mut img, xmin, xmax, ymin, ymax.min(ymin + 2));
            draw_rect(&mut img, xmin, xmax, ymin.max(ymax - 2), ymax);
            draw_rect(&mut img, xmin, xmax.min(xmin + 2), ymin, ymax);
            draw_rect(&mut img, xmin.max(xmax - 2), xmax, ymin, ymax);
        }
    }
    Ok((img * 255.)?.to_dtype(DType::U8)?)
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Model weights, in safetensors format.
    #[arg(long)]
    model: String,

    images: Vec<String>,
}

pub fn main() -> Result<()> {
    let args = Args::parse();

    // Create the model and load the weights from the file.
    let weights = unsafe { candle::safetensors::MmapedFile::new(&args.model)? };
    let weights = weights.deserialize()?;
    let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &Device::Cpu);
    let darknet = darknet::parse_config(CONFIG_NAME)?;
    let model = darknet.build_model(vb)?;

    for image_path in args.images.iter() {
        // Load the image file and resize it to the network input size.
        let net_width = darknet.width()?;
        let net_height = darknet.height()?;
        let image = candle_examples::load_image_and_resize(image_path, net_width, net_height)?;
        // Normalize to [0, 1] and add a batch dimension for the forward pass.
        let image_t = (image.unsqueeze(0)?.to_dtype(DType::F32)? * (1. / 255.))?;
        let predictions = model.forward(&image_t)?.squeeze(0)?;
        // report expects the original (c, h, w) u8 image, not the normalized batch.
        let _image = report(&predictions, &image, net_width, net_height)?;
        println!("processed {image_path}");
    }
    Ok(())
}

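With YOLOv3 weights converted to safetensors, the example can then be run on one or more images. A hypothetical invocation (the weights file name is an assumption; `--model` and the positional image paths come from the `Args` struct above):

cargo run --example yolo-v3 --release -- --model yolo-v3.safetensors path/to/image.jpg
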
candle-examples/examples/yolo-v3/yolo-v3.cfg (new file, 790 lines)
@@ -0,0 +1,790 @@
[net]
# Testing
batch=1
subdivisions=1
# Training
# batch=64
# subdivisions=16
width= 416

height = 416
channels=3
momentum=0.9
decay=0.0005
angle=0
saturation = 1.5
exposure = 1.5
hue=.1

learning_rate=0.001
burn_in=1000
max_batches = 500200
policy=steps
steps=400000,450000
scales=.1,.1

[convolutional]
batch_normalize=1
filters=32
size=3
stride=1
pad=1
activation=leaky

# Downsample

[convolutional]
batch_normalize=1
filters=64
size=3
stride=2
pad=1
activation=leaky

[convolutional]
batch_normalize=1
filters=32
size=1
stride=1
pad=1
activation=leaky

[convolutional]
batch_normalize=1
filters=64
size=3
stride=1
pad=1
activation=leaky

[shortcut]
from=-3
activation=linear

# Downsample

[convolutional]
batch_normalize=1
filters=128
size=3
stride=2
pad=1
activation=leaky

[convolutional]
batch_normalize=1
filters=64
size=1
stride=1
pad=1
activation=leaky

[convolutional]
batch_normalize=1
filters=128
size=3
stride=1
pad=1
activation=leaky

[shortcut]
from=-3
activation=linear

[convolutional]
batch_normalize=1
filters=64
size=1
stride=1
pad=1
activation=leaky

[convolutional]
batch_normalize=1
filters=128
size=3
stride=1
pad=1
activation=leaky

[shortcut]
from=-3
activation=linear

# Downsample

[convolutional]
batch_normalize=1
filters=256
size=3
stride=2
pad=1
activation=leaky

[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky

[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky

[shortcut]
from=-3
activation=linear

[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky

[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky

[shortcut]
from=-3
activation=linear

[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky

[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky

[shortcut]
from=-3
activation=linear

[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky

[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky

[shortcut]
from=-3
activation=linear


[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky

[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky

[shortcut]
from=-3
activation=linear

[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky

[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky

[shortcut]
from=-3
activation=linear

[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky

[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky

[shortcut]
from=-3
activation=linear

[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky

[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky

[shortcut]
from=-3
activation=linear

# Downsample

[convolutional]
batch_normalize=1
filters=512
size=3
stride=2
pad=1
activation=leaky

[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky

[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky

[shortcut]
from=-3
activation=linear


[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky

[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky

[shortcut]
from=-3
activation=linear


[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky

[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky

[shortcut]
from=-3
activation=linear


[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky

[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky

[shortcut]
from=-3
activation=linear

[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky

[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky

[shortcut]
from=-3
activation=linear


[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky

[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky

[shortcut]
from=-3
activation=linear


[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky

[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky

[shortcut]
from=-3
activation=linear

[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky

[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky

[shortcut]
from=-3
activation=linear

# Downsample

[convolutional]
batch_normalize=1
filters=1024
size=3
stride=2
pad=1
activation=leaky

[convolutional]
batch_normalize=1
filters=512
size=1
stride=1
pad=1
activation=leaky

[convolutional]
batch_normalize=1
filters=1024
size=3
stride=1
pad=1
activation=leaky

[shortcut]
from=-3
activation=linear

[convolutional]
batch_normalize=1
filters=512
size=1
stride=1
pad=1
activation=leaky

[convolutional]
batch_normalize=1
filters=1024
size=3
stride=1
pad=1
activation=leaky

[shortcut]
from=-3
activation=linear

[convolutional]
batch_normalize=1
filters=512
size=1
stride=1
pad=1
activation=leaky

[convolutional]
batch_normalize=1
filters=1024
size=3
stride=1
pad=1
activation=leaky

[shortcut]
from=-3
activation=linear

[convolutional]
batch_normalize=1
filters=512
size=1
stride=1
pad=1
activation=leaky

[convolutional]
batch_normalize=1
filters=1024
size=3
stride=1
pad=1
activation=leaky

[shortcut]
from=-3
activation=linear

######################

[convolutional]
batch_normalize=1
filters=512
size=1
stride=1
pad=1
activation=leaky

[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=1024
activation=leaky

[convolutional]
batch_normalize=1
filters=512
size=1
stride=1
pad=1
activation=leaky

[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=1024
activation=leaky

[convolutional]
batch_normalize=1
filters=512
size=1
stride=1
pad=1
activation=leaky

[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=1024
activation=leaky

[convolutional]
size=1
stride=1
pad=1
filters=255
activation=linear


[yolo]
mask = 6,7,8
anchors = 10,13, 16,30, 33,23, 30,61, 62,45, 59,119, 116,90, 156,198, 373,326
classes=80
num=9
jitter=.3
ignore_thresh = .5
truth_thresh = 1
random=1


[route]
layers = -4

[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky

[upsample]
stride=2

[route]
layers = -1, 61



[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky

[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=512
activation=leaky

[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky

[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=512
activation=leaky

[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky

[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=512
activation=leaky

[convolutional]
size=1
stride=1
pad=1
filters=255
activation=linear


[yolo]
mask = 3,4,5
anchors = 10,13, 16,30, 33,23, 30,61, 62,45, 59,119, 116,90, 156,198, 373,326
classes=80
num=9
jitter=.3
ignore_thresh = .5
truth_thresh = 1
random=1



[route]
layers = -4

[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky

[upsample]
stride=2

[route]
layers = -1, 36



[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky

[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=256
activation=leaky

[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky

[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=256
activation=leaky

[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky

[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=256
activation=leaky

[convolutional]
size=1
stride=1
pad=1
filters=255
activation=linear


[yolo]
mask = 0,1,2
anchors = 10,13, 16,30, 33,23, 30,61, 62,45, 59,119, 116,90, 156,198, 373,326
classes=80
num=9
jitter=.3
ignore_thresh = .5
truth_thresh = 1
random=1

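A note on the [route] blocks above: layer indices follow darknet conventions, where negative values are relative to the current block and non-negative values are absolute; `usize_of_index` in darknet.rs resolves them before `Bl::Route` concatenates the referenced outputs along the channel dimension. A standalone sketch of that resolution (the block index 98 is illustrative, not taken from this cfg):

fn resolve(current: usize, i: i64) -> usize {
    if i >= 0 {
        i as usize
    } else {
        (current as i64 + i) as usize
    }
}

fn main() {
    // "layers = -1, 61" evaluated at block 98 concatenates blocks 97 and 61.
    assert_eq!(resolve(98, -1), 97);
    assert_eq!(resolve(98, 61), 61);
}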