Add a yolo-v3 example. (#528)

* Add a couple functions required for yolo.

* Add the yolo-v3 example.

* Add minimum and maximum.

* Use the newly introduced maximum.

* Cuda support for min/max + add some testing.

* Allow for more tests to work with accelerate.

* Fix a typo.
This commit is contained in:
Laurent Mazare
2023-08-20 18:19:37 +01:00
committed by GitHub
parent e3d2786ffb
commit a1812f934f
24 changed files with 1497 additions and 8 deletions

View File

@ -0,0 +1,82 @@
pub const NAMES: [&str; 80] = [
"person",
"bicycle",
"car",
"motorbike",
"aeroplane",
"bus",
"train",
"truck",
"boat",
"traffic light",
"fire hydrant",
"stop sign",
"parking meter",
"bench",
"bird",
"cat",
"dog",
"horse",
"sheep",
"cow",
"elephant",
"bear",
"zebra",
"giraffe",
"backpack",
"umbrella",
"handbag",
"tie",
"suitcase",
"frisbee",
"skis",
"snowboard",
"sports ball",
"kite",
"baseball bat",
"baseball glove",
"skateboard",
"surfboard",
"tennis racket",
"bottle",
"wine glass",
"cup",
"fork",
"knife",
"spoon",
"bowl",
"banana",
"apple",
"sandwich",
"orange",
"broccoli",
"carrot",
"hot dog",
"pizza",
"donut",
"cake",
"chair",
"sofa",
"pottedplant",
"bed",
"diningtable",
"toilet",
"tvmonitor",
"laptop",
"mouse",
"remote",
"keyboard",
"cell phone",
"microwave",
"oven",
"toaster",
"sink",
"refrigerator",
"book",
"clock",
"vase",
"scissors",
"teddy bear",
"hair drier",
"toothbrush",
];

View File

@ -0,0 +1,304 @@
use candle::{Device, IndexOp, Result, Tensor};
use candle_nn::{batch_norm, conv2d, conv2d_no_bias, Func, Module, VarBuilder};
use std::collections::BTreeMap;
use std::fs::File;
use std::io::{BufRead, BufReader};
use std::path::Path;
#[derive(Debug)]
struct Block {
block_type: String,
parameters: BTreeMap<String, String>,
}
impl Block {
fn get(&self, key: &str) -> Result<&str> {
match self.parameters.get(&key.to_string()) {
None => candle::bail!("cannot find {} in {}", key, self.block_type),
Some(value) => Ok(value),
}
}
}
#[derive(Debug)]
pub struct Darknet {
blocks: Vec<Block>,
parameters: BTreeMap<String, String>,
}
impl Darknet {
fn get(&self, key: &str) -> Result<&str> {
match self.parameters.get(&key.to_string()) {
None => candle::bail!("cannot find {} in net parameters", key),
Some(value) => Ok(value),
}
}
}
struct Accumulator {
block_type: Option<String>,
parameters: BTreeMap<String, String>,
net: Darknet,
}
impl Accumulator {
fn new() -> Accumulator {
Accumulator {
block_type: None,
parameters: BTreeMap::new(),
net: Darknet {
blocks: vec![],
parameters: BTreeMap::new(),
},
}
}
fn finish_block(&mut self) {
match &self.block_type {
None => (),
Some(block_type) => {
if block_type == "net" {
self.net.parameters = self.parameters.clone();
} else {
let block = Block {
block_type: block_type.to_string(),
parameters: self.parameters.clone(),
};
self.net.blocks.push(block);
}
self.parameters.clear();
}
}
self.block_type = None;
}
}
pub fn parse_config<T: AsRef<Path>>(path: T) -> Result<Darknet> {
let file = File::open(path.as_ref())?;
let mut acc = Accumulator::new();
for line in BufReader::new(file).lines() {
let line = line?;
if line.is_empty() || line.starts_with('#') {
continue;
}
let line = line.trim();
if line.starts_with('[') {
if !line.ends_with(']') {
candle::bail!("line does not end with ']' {line}")
}
let line = &line[1..line.len() - 1];
acc.finish_block();
acc.block_type = Some(line.to_string());
} else {
let key_value: Vec<&str> = line.splitn(2, '=').collect();
if key_value.len() != 2 {
candle::bail!("missing equal {line}")
}
let prev = acc.parameters.insert(
key_value[0].trim().to_owned(),
key_value[1].trim().to_owned(),
);
if prev.is_some() {
candle::bail!("multiple value for key {}", line)
}
}
}
acc.finish_block();
Ok(acc.net)
}
enum Bl {
Layer(Box<dyn candle_nn::Module + Send>),
Route(Vec<usize>),
Shortcut(usize),
Yolo(usize, Vec<(usize, usize)>),
}
fn conv(vb: VarBuilder, index: usize, p: usize, b: &Block) -> Result<(usize, Bl)> {
let activation = b.get("activation")?;
let filters = b.get("filters")?.parse::<usize>()?;
let pad = b.get("pad")?.parse::<usize>()?;
let size = b.get("size")?.parse::<usize>()?;
let stride = b.get("stride")?.parse::<usize>()?;
let padding = if pad != 0 { (size - 1) / 2 } else { 0 };
let (bn, bias) = match b.parameters.get("batch_normalize") {
Some(p) if p.parse::<usize>()? != 0 => {
let bn = batch_norm(filters, 1e-5, vb.pp(&format!("batch_norm_{index}")))?;
(Some(bn), false)
}
Some(_) | None => (None, true),
};
let conv_cfg = candle_nn::Conv2dConfig { stride, padding };
let conv = if bias {
conv2d(p, filters, size, conv_cfg, vb.pp(&format!("conv_{index}")))?
} else {
conv2d_no_bias(p, filters, size, conv_cfg, vb.pp(&format!("conv_{index}")))?
};
let leaky = match activation {
"leaky" => true,
"linear" => false,
otherwise => candle::bail!("unsupported activation {}", otherwise),
};
let func = candle_nn::func(move |xs| {
let xs = conv.forward(xs)?;
let xs = match &bn {
Some(bn) => bn.forward(&xs)?,
None => xs,
};
if leaky {
xs.maximum(&(&xs * 0.1)?)
} else {
Ok(xs)
}
});
Ok((filters, Bl::Layer(Box::new(func))))
}
fn upsample(prev_channels: usize) -> Result<(usize, Bl)> {
let layer = candle_nn::func(|xs| {
let (_n, _c, h, w) = xs.dims4()?;
xs.upsample_nearest2d(2 * h, 2 * w)
});
Ok((prev_channels, Bl::Layer(Box::new(layer))))
}
fn int_list_of_string(s: &str) -> Result<Vec<i64>> {
let res: std::result::Result<Vec<_>, _> =
s.split(',').map(|xs| xs.trim().parse::<i64>()).collect();
Ok(res?)
}
fn usize_of_index(index: usize, i: i64) -> usize {
if i >= 0 {
i as usize
} else {
(index as i64 + i) as usize
}
}
fn route(index: usize, p: &[(usize, Bl)], block: &Block) -> Result<(usize, Bl)> {
let layers = int_list_of_string(block.get("layers")?)?;
let layers: Vec<usize> = layers
.into_iter()
.map(|l| usize_of_index(index, l))
.collect();
let channels = layers.iter().map(|&l| p[l].0).sum();
Ok((channels, Bl::Route(layers)))
}
fn shortcut(index: usize, p: usize, block: &Block) -> Result<(usize, Bl)> {
let from = block.get("from")?.parse::<i64>()?;
Ok((p, Bl::Shortcut(usize_of_index(index, from))))
}
fn yolo(p: usize, block: &Block) -> Result<(usize, Bl)> {
let classes = block.get("classes")?.parse::<usize>()?;
let flat = int_list_of_string(block.get("anchors")?)?;
if flat.len() % 2 != 0 {
candle::bail!("even number of anchors");
}
let flat = flat.into_iter().map(|i| i as usize).collect::<Vec<_>>();
let anchors: Vec<_> = (0..(flat.len() / 2))
.map(|i| (flat[2 * i], flat[2 * i + 1]))
.collect();
let mask = int_list_of_string(block.get("mask")?)?;
let anchors = mask.into_iter().map(|i| anchors[i as usize]).collect();
Ok((p, Bl::Yolo(classes, anchors)))
}
fn detect(
xs: &Tensor,
image_height: usize,
classes: usize,
anchors: &Vec<(usize, usize)>,
) -> Result<Tensor> {
let (bsize, _channels, height, _width) = xs.dims4()?;
let stride = image_height / height;
let grid_size = image_height / stride;
let bbox_attrs = 5 + classes;
let nanchors = anchors.len();
let xs = xs
.reshape((bsize, bbox_attrs * nanchors, grid_size * grid_size))?
.transpose(1, 2)?
.contiguous()?
.reshape((bsize, grid_size * grid_size * nanchors, bbox_attrs))?;
let grid = Tensor::arange(0u32, grid_size as u32, &Device::Cpu)?;
let a = grid.repeat((grid_size, 1))?;
let b = a.t()?.contiguous()?;
let x_offset = a.unsqueeze(1)?;
let y_offset = b.unsqueeze(1)?;
let xy_offset = Tensor::cat(&[&x_offset, &y_offset], 1)?
.repeat((1, nanchors))?
.reshape((grid_size * grid_size * nanchors, 2))?
.unsqueeze(0)?;
let anchors: Vec<f32> = anchors
.iter()
.flat_map(|&(x, y)| vec![x as f32 / stride as f32, y as f32 / stride as f32].into_iter())
.collect();
let anchors = Tensor::new(anchors.as_slice(), &Device::Cpu)?
.reshape((anchors.len() / 2, 2))?
.repeat((grid_size * grid_size, 1))?
.unsqueeze(0)?;
let ys02 = xs.i((.., .., 0..2))?;
let ys24 = xs.i((.., .., 2..4))?;
let ys4 = xs.i((.., .., 4..))?;
let ys02 = (candle_nn::ops::sigmoid(&ys02)?.add(&xy_offset)? * stride as f64)?;
let ys24 = (ys24.exp()?.mul(&anchors)? * stride as f64)?;
let ys4 = candle_nn::ops::sigmoid(&ys4)?;
Tensor::cat(&[ys02, ys24, ys4], 2)
}
impl Darknet {
pub fn height(&self) -> Result<usize> {
let image_height = self.get("height")?.parse::<usize>()?;
Ok(image_height)
}
pub fn width(&self) -> Result<usize> {
let image_width = self.get("width")?.parse::<usize>()?;
Ok(image_width)
}
pub fn build_model(&self, vb: VarBuilder) -> Result<Func> {
let mut blocks: Vec<(usize, Bl)> = vec![];
let mut prev_channels: usize = 3;
for (index, block) in self.blocks.iter().enumerate() {
let channels_and_bl = match block.block_type.as_str() {
"convolutional" => conv(vb.pp(&index.to_string()), index, prev_channels, block)?,
"upsample" => upsample(prev_channels)?,
"shortcut" => shortcut(index, prev_channels, block)?,
"route" => route(index, &blocks, block)?,
"yolo" => yolo(prev_channels, block)?,
otherwise => candle::bail!("unsupported block type {}", otherwise),
};
prev_channels = channels_and_bl.0;
blocks.push(channels_and_bl);
}
let image_height = self.height()?;
let func = candle_nn::func(move |xs| {
let mut prev_ys: Vec<Tensor> = vec![];
let mut detections: Vec<Tensor> = vec![];
for (_, b) in blocks.iter() {
let ys = match b {
Bl::Layer(l) => {
let xs = prev_ys.last().unwrap_or(xs);
l.forward(xs)?
}
Bl::Route(layers) => {
let layers: Vec<_> = layers.iter().map(|&i| &prev_ys[i]).collect();
Tensor::cat(&layers, 1)?
}
Bl::Shortcut(from) => (prev_ys.last().unwrap() + prev_ys.get(*from).unwrap())?,
Bl::Yolo(classes, anchors) => {
let xs = prev_ys.last().unwrap_or(xs);
detections.push(detect(xs, image_height, *classes, anchors)?);
Tensor::new(&[0u32], &Device::Cpu)?
}
};
prev_ys.push(ys);
}
Tensor::cat(&detections, 1)
});
Ok(func)
}
}

View File

@ -0,0 +1,145 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
mod coco_classes;
mod darknet;
use anyhow::Result;
use candle::{DType, Device, Tensor};
use candle_nn::{Module, VarBuilder};
use clap::Parser;
const CONFIG_NAME: &str = "candle-examples/examples/yolo/yolo-v3.cfg";
const CONFIDENCE_THRESHOLD: f64 = 0.5;
const NMS_THRESHOLD: f64 = 0.4;
#[derive(Debug, Clone, Copy)]
struct Bbox {
xmin: f64,
ymin: f64,
xmax: f64,
ymax: f64,
confidence: f64,
}
// Intersection over union of two bounding boxes.
fn iou(b1: &Bbox, b2: &Bbox) -> f64 {
let b1_area = (b1.xmax - b1.xmin + 1.) * (b1.ymax - b1.ymin + 1.);
let b2_area = (b2.xmax - b2.xmin + 1.) * (b2.ymax - b2.ymin + 1.);
let i_xmin = b1.xmin.max(b2.xmin);
let i_xmax = b1.xmax.min(b2.xmax);
let i_ymin = b1.ymin.max(b2.ymin);
let i_ymax = b1.ymax.min(b2.ymax);
let i_area = (i_xmax - i_xmin + 1.).max(0.) * (i_ymax - i_ymin + 1.).max(0.);
i_area / (b1_area + b2_area - i_area)
}
// Assumes x1 <= x2 and y1 <= y2
pub fn draw_rect(_: &mut Tensor, _x1: usize, _x2: usize, _y1: usize, _y2: usize) {
todo!()
}
pub fn report(pred: &Tensor, img: &Tensor, w: usize, h: usize) -> Result<Tensor> {
let (npreds, pred_size) = pred.dims2()?;
let nclasses = pred_size - 5;
// The bounding boxes grouped by (maximum) class index.
let mut bboxes: Vec<Vec<Bbox>> = (0..nclasses).map(|_| vec![]).collect();
// Extract the bounding boxes for which confidence is above the threshold.
for index in 0..npreds {
let pred = Vec::<f64>::try_from(pred.get(index)?)?;
let confidence = pred[4];
if confidence > CONFIDENCE_THRESHOLD {
let mut class_index = 0;
for i in 0..nclasses {
if pred[5 + i] > pred[5 + class_index] {
class_index = i
}
}
if pred[class_index + 5] > 0. {
let bbox = Bbox {
xmin: pred[0] - pred[2] / 2.,
ymin: pred[1] - pred[3] / 2.,
xmax: pred[0] + pred[2] / 2.,
ymax: pred[1] + pred[3] / 2.,
confidence,
};
bboxes[class_index].push(bbox)
}
}
}
// Perform non-maximum suppression.
for bboxes_for_class in bboxes.iter_mut() {
bboxes_for_class.sort_by(|b1, b2| b2.confidence.partial_cmp(&b1.confidence).unwrap());
let mut current_index = 0;
for index in 0..bboxes_for_class.len() {
let mut drop = false;
for prev_index in 0..current_index {
let iou = iou(&bboxes_for_class[prev_index], &bboxes_for_class[index]);
if iou > NMS_THRESHOLD {
drop = true;
break;
}
}
if !drop {
bboxes_for_class.swap(current_index, index);
current_index += 1;
}
}
bboxes_for_class.truncate(current_index);
}
// Annotate the original image and print boxes information.
let (_, initial_h, initial_w) = img.dims3()?;
let mut img = (img.to_dtype(DType::F32)? * (1. / 255.))?;
let w_ratio = initial_w as f64 / w as f64;
let h_ratio = initial_h as f64 / h as f64;
for (class_index, bboxes_for_class) in bboxes.iter().enumerate() {
for b in bboxes_for_class.iter() {
println!("{}: {:?}", coco_classes::NAMES[class_index], b);
let xmin = ((b.xmin * w_ratio) as usize).clamp(0, initial_w - 1);
let ymin = ((b.ymin * h_ratio) as usize).clamp(0, initial_h - 1);
let xmax = ((b.xmax * w_ratio) as usize).clamp(0, initial_w - 1);
let ymax = ((b.ymax * h_ratio) as usize).clamp(0, initial_h - 1);
draw_rect(&mut img, xmin, xmax, ymin, ymax.min(ymin + 2));
draw_rect(&mut img, xmin, xmax, ymin.max(ymax - 2), ymax);
draw_rect(&mut img, xmin, xmax.min(xmin + 2), ymin, ymax);
draw_rect(&mut img, xmin.max(xmax - 2), xmax, ymin, ymax);
}
}
Ok((img * 255.)?.to_dtype(DType::U8)?)
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Model weights, in safetensors format.
#[arg(long)]
model: String,
images: Vec<String>,
}
pub fn main() -> Result<()> {
let args = Args::parse();
// Create the model and load the weights from the file.
let weights = unsafe { candle::safetensors::MmapedFile::new(&args.model)? };
let weights = weights.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &Device::Cpu);
let darknet = darknet::parse_config(CONFIG_NAME)?;
let model = darknet.build_model(vb)?;
for image in args.images.iter() {
// Load the image file and resize it.
let net_width = darknet.width()?;
let net_height = darknet.height()?;
let image = candle_examples::load_image_and_resize(image, net_width, net_height)?;
let image = (image.unsqueeze(0)?.to_dtype(DType::F32)? * (1. / 255.))?;
let predictions = model.forward(&image)?.squeeze(0)?;
let _image = report(&predictions, &image, net_width, net_height)?;
println!("converted {image}");
}
Ok(())
}

View File

@ -0,0 +1,790 @@
[net]
# Testing
batch=1
subdivisions=1
# Training
# batch=64
# subdivisions=16
width= 416
height = 416
channels=3
momentum=0.9
decay=0.0005
angle=0
saturation = 1.5
exposure = 1.5
hue=.1
learning_rate=0.001
burn_in=1000
max_batches = 500200
policy=steps
steps=400000,450000
scales=.1,.1
[convolutional]
batch_normalize=1
filters=32
size=3
stride=1
pad=1
activation=leaky
# Downsample
[convolutional]
batch_normalize=1
filters=64
size=3
stride=2
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=32
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=64
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
# Downsample
[convolutional]
batch_normalize=1
filters=128
size=3
stride=2
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=64
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=128
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=64
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=128
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
# Downsample
[convolutional]
batch_normalize=1
filters=256
size=3
stride=2
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
# Downsample
[convolutional]
batch_normalize=1
filters=512
size=3
stride=2
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
# Downsample
[convolutional]
batch_normalize=1
filters=1024
size=3
stride=2
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=1024
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=512
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=1024
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=512
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=1024
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=512
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=1024
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
######################
[convolutional]
batch_normalize=1
filters=512
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=1024
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=1024
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=1024
activation=leaky
[convolutional]
size=1
stride=1
pad=1
filters=255
activation=linear
[yolo]
mask = 6,7,8
anchors = 10,13, 16,30, 33,23, 30,61, 62,45, 59,119, 116,90, 156,198, 373,326
classes=80
num=9
jitter=.3
ignore_thresh = .5
truth_thresh = 1
random=1
[route]
layers = -4
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[upsample]
stride=2
[route]
layers = -1, 61
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=512
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=512
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=512
activation=leaky
[convolutional]
size=1
stride=1
pad=1
filters=255
activation=linear
[yolo]
mask = 3,4,5
anchors = 10,13, 16,30, 33,23, 30,61, 62,45, 59,119, 116,90, 156,198, 373,326
classes=80
num=9
jitter=.3
ignore_thresh = .5
truth_thresh = 1
random=1
[route]
layers = -4
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[upsample]
stride=2
[route]
layers = -1, 36
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=256
activation=leaky
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=256
activation=leaky
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=256
activation=leaky
[convolutional]
size=1
stride=1
pad=1
filters=255
activation=linear
[yolo]
mask = 0,1,2
anchors = 10,13, 16,30, 33,23, 30,61, 62,45, 59,119, 116,90, 156,198, 373,326
classes=80
num=9
jitter=.3
ignore_thresh = .5
truth_thresh = 1
random=1