Add an initial Segformer implementation (#1617)

* add segformer

* Make the id2label field optional.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
This commit is contained in:
Jiayu Liu
2024-03-03 23:01:46 +08:00
committed by GitHub
parent 60dc72b96b
commit 924ccae30c
5 changed files with 1641 additions and 0 deletions

View File

@ -0,0 +1,28 @@
# candle-segformer
- [HuggingFace Segformer Model Card][segformer]
- [`mit-b0` - An encoder only pretrained model][encoder]
- [`segformer-b0-finetuned-ade-512-512` - A fine tuned model for segmentation][ade512]
## How to run the example
If you want you can use the example images from this [pull request][pr], download them and supply the path to the image as an argument to the example.
```bash
# run the image classification task
cargo run --example segformer classify <path-to-image>
# run the segmentation task
cargo run --example segformer segment <path-to-image>
```
Example output for classification:
```text
classification logits [3.275261e-5, 0.0008562019, 0.0008868563, 0.9977506, 0.0002465068, 0.0002241473, 2.846596e-6]
label: hamburger
```
[pr]: https://github.com/huggingface/candle/pull/1617
[segformer]: https://huggingface.co/docs/transformers/model_doc/segformer
[encoder]: https://huggingface.co/nvidia/mit-b0
[ade512]: https://huggingface.co/nvidia/segformer-b0-finetuned-ade-512-512

View File

@ -0,0 +1,752 @@
[
{
"index": 1,
"color": "#787878",
"label": "wall"
},
{
"index": 2,
"color": "#B47878",
"label": "building;edifice"
},
{
"index": 3,
"color": "#06E6E6",
"label": "sky"
},
{
"index": 4,
"color": "#503232",
"label": "floor;flooring"
},
{
"index": 5,
"color": "#04C803",
"label": "tree"
},
{
"index": 6,
"color": "#787850",
"label": "ceiling"
},
{
"index": 7,
"color": "#8C8C8C",
"label": "road;route"
},
{
"index": 8,
"color": "#CC05FF",
"label": "bed"
},
{
"index": 9,
"color": "#E6E6E6",
"label": "windowpane;window"
},
{
"index": 10,
"color": "#04FA07",
"label": "grass"
},
{
"index": 11,
"color": "#E005FF",
"label": "cabinet"
},
{
"index": 12,
"color": "#EBFF07",
"label": "sidewalk;pavement"
},
{
"index": 13,
"color": "#96053D",
"label": "person;individual;someone;somebody;mortal;soul"
},
{
"index": 14,
"color": "#787846",
"label": "earth;ground"
},
{
"index": 15,
"color": "#08FF33",
"label": "door;double;door"
},
{
"index": 16,
"color": "#FF0652",
"label": "table"
},
{
"index": 17,
"color": "#8FFF8C",
"label": "mountain;mount"
},
{
"index": 18,
"color": "#CCFF04",
"label": "plant;flora;plant;life"
},
{
"index": 19,
"color": "#FF3307",
"label": "curtain;drape;drapery;mantle;pall"
},
{
"index": 20,
"color": "#CC4603",
"label": "chair"
},
{
"index": 21,
"color": "#0066C8",
"label": "car;auto;automobile;machine;motorcar"
},
{
"index": 22,
"color": "#3DE6FA",
"label": "water"
},
{
"index": 23,
"color": "#FF0633",
"label": "painting;picture"
},
{
"index": 24,
"color": "#0B66FF",
"label": "sofa;couch;lounge"
},
{
"index": 25,
"color": "#FF0747",
"label": "shelf"
},
{
"index": 26,
"color": "#FF09E0",
"label": "house"
},
{
"index": 27,
"color": "#0907E6",
"label": "sea"
},
{
"index": 28,
"color": "#DCDCDC",
"label": "mirror"
},
{
"index": 29,
"color": "#FF095C",
"label": "rug;carpet;carpeting"
},
{
"index": 30,
"color": "#7009FF",
"label": "field"
},
{
"index": 31,
"color": "#08FFD6",
"label": "armchair"
},
{
"index": 32,
"color": "#07FFE0",
"label": "seat"
},
{
"index": 33,
"color": "#FFB806",
"label": "fence;fencing"
},
{
"index": 34,
"color": "#0AFF47",
"label": "desk"
},
{
"index": 35,
"color": "#FF290A",
"label": "rock;stone"
},
{
"index": 36,
"color": "#07FFFF",
"label": "wardrobe;closet;press"
},
{
"index": 37,
"color": "#E0FF08",
"label": "lamp"
},
{
"index": 38,
"color": "#6608FF",
"label": "bathtub;bathing;tub;bath;tub"
},
{
"index": 39,
"color": "#FF3D06",
"label": "railing;rail"
},
{
"index": 40,
"color": "#FFC207",
"label": "cushion"
},
{
"index": 41,
"color": "#FF7A08",
"label": "base;pedestal;stand"
},
{
"index": 42,
"color": "#00FF14",
"label": "box"
},
{
"index": 43,
"color": "#FF0829",
"label": "column;pillar"
},
{
"index": 44,
"color": "#FF0599",
"label": "signboard;sign"
},
{
"index": 45,
"color": "#0633FF",
"label": "chest;of;drawers;chest;bureau;dresser"
},
{
"index": 46,
"color": "#EB0CFF",
"label": "counter"
},
{
"index": 47,
"color": "#A09614",
"label": "sand"
},
{
"index": 48,
"color": "#00A3FF",
"label": "sink"
},
{
"index": 49,
"color": "#8C8C8C",
"label": "skyscraper"
},
{
"index": 50,
"color": "#FA0A0F",
"label": "fireplace;hearth;open;fireplace"
},
{
"index": 51,
"color": "#14FF00",
"label": "refrigerator;icebox"
},
{
"index": 52,
"color": "#1FFF00",
"label": "grandstand;covered;stand"
},
{
"index": 53,
"color": "#FF1F00",
"label": "path"
},
{
"index": 54,
"color": "#FFE000",
"label": "stairs;steps"
},
{
"index": 55,
"color": "#99FF00",
"label": "runway"
},
{
"index": 56,
"color": "#0000FF",
"label": "case;display;case;showcase;vitrine"
},
{
"index": 57,
"color": "#FF4700",
"label": "pool;table;billiard;table;snooker;table"
},
{
"index": 58,
"color": "#00EBFF",
"label": "pillow"
},
{
"index": 59,
"color": "#00ADFF",
"label": "screen;door;screen"
},
{
"index": 60,
"color": "#1F00FF",
"label": "stairway;staircase"
},
{
"index": 61,
"color": "#0BC8C8",
"label": "river"
},
{
"index": 62,
"color": "#FF5200",
"label": "bridge;span"
},
{
"index": 63,
"color": "#00FFF5",
"label": "bookcase"
},
{
"index": 64,
"color": "#003DFF",
"label": "blind;screen"
},
{
"index": 65,
"color": "#00FF70",
"label": "coffee;table;cocktail;table"
},
{
"index": 66,
"color": "#00FF85",
"label": "toilet;can;commode;crapper;pot;potty;stool;throne"
},
{
"index": 67,
"color": "#FF0000",
"label": "flower"
},
{
"index": 68,
"color": "#FFA300",
"label": "book"
},
{
"index": 69,
"color": "#FF6600",
"label": "hill"
},
{
"index": 70,
"color": "#C2FF00",
"label": "bench"
},
{
"index": 71,
"color": "#008FFF",
"label": "countertop"
},
{
"index": 72,
"color": "#33FF00",
"label": "stove;kitchen;stove;range;kitchen;range;cooking;stove"
},
{
"index": 73,
"color": "#0052FF",
"label": "palm;palm;tree"
},
{
"index": 74,
"color": "#00FF29",
"label": "kitchen;island"
},
{
"index": 75,
"color": "#00FFAD",
"label": "computer;computing;machine;computing;device;data;processor;electronic;computer;information;processing;system"
},
{
"index": 76,
"color": "#0A00FF",
"label": "swivel;chair"
},
{
"index": 77,
"color": "#ADFF00",
"label": "boat"
},
{
"index": 78,
"color": "#00FF99",
"label": "bar"
},
{
"index": 79,
"color": "#FF5C00",
"label": "arcade;machine"
},
{
"index": 80,
"color": "#FF00FF",
"label": "hovel;hut;hutch;shack;shanty"
},
{
"index": 81,
"color": "#FF00F5",
"label": "bus;autobus;coach;charabanc;double-decker;jitney;motorbus;motorcoach;omnibus;passenger;vehicle"
},
{
"index": 82,
"color": "#FF0066",
"label": "towel"
},
{
"index": 83,
"color": "#FFAD00",
"label": "light;light;source"
},
{
"index": 84,
"color": "#FF0014",
"label": "truck;motortruck"
},
{
"index": 85,
"color": "#FFB8B8",
"label": "tower"
},
{
"index": 86,
"color": "#001FFF",
"label": "chandelier;pendant;pendent"
},
{
"index": 87,
"color": "#00FF3D",
"label": "awning;sunshade;sunblind"
},
{
"index": 88,
"color": "#0047FF",
"label": "streetlight;street;lamp"
},
{
"index": 89,
"color": "#FF00CC",
"label": "booth;cubicle;stall;kiosk"
},
{
"index": 90,
"color": "#00FFC2",
"label": "television;television;receiver;television;set;tv;tv;set;idiot;box;boob;tube;telly;goggle;box"
},
{
"index": 91,
"color": "#00FF52",
"label": "airplane;aeroplane;plane"
},
{
"index": 92,
"color": "#000AFF",
"label": "dirt;track"
},
{
"index": 93,
"color": "#0070FF",
"label": "apparel;wearing;apparel;dress;clothes"
},
{
"index": 94,
"color": "#3300FF",
"label": "pole"
},
{
"index": 95,
"color": "#00C2FF",
"label": "land;ground;soil"
},
{
"index": 96,
"color": "#007AFF",
"label": "bannister;banister;balustrade;balusters;handrail"
},
{
"index": 97,
"color": "#00FFA3",
"label": "escalator;moving;staircase;moving;stairway"
},
{
"index": 98,
"color": "#FF9900",
"label": "ottoman;pouf;pouffe;puff;hassock"
},
{
"index": 99,
"color": "#00FF0A",
"label": "bottle"
},
{
"index": 100,
"color": "#FF7000",
"label": "buffet;counter;sideboard"
},
{
"index": 101,
"color": "#8FFF00",
"label": "poster;posting;placard;notice;bill;card"
},
{
"index": 102,
"color": "#5200FF",
"label": "stage"
},
{
"index": 103,
"color": "#A3FF00",
"label": "van"
},
{
"index": 104,
"color": "#FFEB00",
"label": "ship"
},
{
"index": 105,
"color": "#08B8AA",
"label": "fountain"
},
{
"index": 106,
"color": "#8500FF",
"label": "conveyer;belt;conveyor;belt;conveyer;conveyor;transporter"
},
{
"index": 107,
"color": "#00FF5C",
"label": "canopy"
},
{
"index": 108,
"color": "#B800FF",
"label": "washer;automatic;washer;washing;machine"
},
{
"index": 109,
"color": "#FF001F",
"label": "plaything;toy"
},
{
"index": 110,
"color": "#00B8FF",
"label": "swimming;pool;swimming;bath;natatorium"
},
{
"index": 111,
"color": "#00D6FF",
"label": "stool"
},
{
"index": 112,
"color": "#FF0070",
"label": "barrel;cask"
},
{
"index": 113,
"color": "#5CFF00",
"label": "basket;handbasket"
},
{
"index": 114,
"color": "#00E0FF",
"label": "waterfall;falls"
},
{
"index": 115,
"color": "#70E0FF",
"label": "tent;collapsible;shelter"
},
{
"index": 116,
"color": "#46B8A0",
"label": "bag"
},
{
"index": 117,
"color": "#A300FF",
"label": "minibike;motorbike"
},
{
"index": 118,
"color": "#9900FF",
"label": "cradle"
},
{
"index": 119,
"color": "#47FF00",
"label": "oven"
},
{
"index": 120,
"color": "#FF00A3",
"label": "ball"
},
{
"index": 121,
"color": "#FFCC00",
"label": "food;solid;food"
},
{
"index": 122,
"color": "#FF008F",
"label": "step;stair"
},
{
"index": 123,
"color": "#00FFEB",
"label": "tank;storage;tank"
},
{
"index": 124,
"color": "#85FF00",
"label": "trade;name;brand;name;brand;marque"
},
{
"index": 125,
"color": "#FF00EB",
"label": "microwave;microwave;oven"
},
{
"index": 126,
"color": "#F500FF",
"label": "pot;flowerpot"
},
{
"index": 127,
"color": "#FF007A",
"label": "animal;animate;being;beast;brute;creature;fauna"
},
{
"index": 128,
"color": "#FFF500",
"label": "bicycle;bike;wheel;cycle"
},
{
"index": 129,
"color": "#0ABED4",
"label": "lake"
},
{
"index": 130,
"color": "#D6FF00",
"label": "dishwasher;dish;washer;dishwashing;machine"
},
{
"index": 131,
"color": "#00CCFF",
"label": "screen;silver;screen;projection;screen"
},
{
"index": 132,
"color": "#1400FF",
"label": "blanket;cover"
},
{
"index": 133,
"color": "#FFFF00",
"label": "sculpture"
},
{
"index": 134,
"color": "#0099FF",
"label": "hood;exhaust;hood"
},
{
"index": 135,
"color": "#0029FF",
"label": "sconce"
},
{
"index": 136,
"color": "#00FFCC",
"label": "vase"
},
{
"index": 137,
"color": "#2900FF",
"label": "traffic;light;traffic;signal;stoplight"
},
{
"index": 138,
"color": "#29FF00",
"label": "tray"
},
{
"index": 139,
"color": "#AD00FF",
"label": "ashcan;trash;can;garbage;can;wastebin;ash;bin;ash-bin;ashbin;dustbin;trash;barrel;trash;bin"
},
{
"index": 140,
"color": "#00F5FF",
"label": "fan"
},
{
"index": 141,
"color": "#4700FF",
"label": "pier;wharf;wharfage;dock"
},
{
"index": 142,
"color": "#7A00FF",
"label": "crt;screen"
},
{
"index": 143,
"color": "#00FFB8",
"label": "plate"
},
{
"index": 144,
"color": "#005CFF",
"label": "monitor;monitoring;device"
},
{
"index": 145,
"color": "#B8FF00",
"label": "bulletin;board;notice;board"
},
{
"index": 146,
"color": "#0085FF",
"label": "shower"
},
{
"index": 147,
"color": "#FFD600",
"label": "radiator"
},
{
"index": 148,
"color": "#19C2C2",
"label": "glass;drinking;glass"
},
{
"index": 149,
"color": "#66FF00",
"label": "clock"
},
{
"index": 150,
"color": "#5C00FF",
"label": "flag"
}
]

View File

@ -0,0 +1,155 @@
use candle::Device;
use candle::Module;
use candle_nn::VarBuilder;
use candle_transformers::models::segformer::{
Config, ImageClassificationModel, SemanticSegmentationModel,
};
use clap::{Args, Parser, Subcommand};
use image::Rgb;
use imageproc::integral_image::ArrayData;
use std::collections::HashMap;
use std::path::PathBuf;
#[derive(Parser)]
#[clap(about, version, long_about = None)]
struct CliArgs {
#[arg(long, help = "use cpu")]
cpu: bool,
#[command(subcommand)]
command: Commands,
}
#[derive(Args, Debug)]
struct SegmentationArgs {
#[arg(
long,
help = "name of the huggingface hub model",
default_value = "nvidia/segformer-b0-finetuned-ade-512-512"
)]
model_name: String,
#[arg(
long,
help = "path to the label file in json format",
default_value = "candle-examples/examples/segformer/assets/labels.json"
)]
label_path: PathBuf,
#[arg(long, help = "path to for the output mask image")]
output_path: PathBuf,
#[arg(help = "path to image as input")]
image: PathBuf,
}
#[derive(Args, Debug)]
struct ClassificationArgs {
#[arg(
long,
help = "name of the huggingface hub model",
default_value = "paolinox/segformer-finetuned-food101"
)]
model_name: String,
#[arg(help = "path to image as input")]
image: PathBuf,
}
#[derive(Subcommand, Debug)]
enum Commands {
Segment(SegmentationArgs),
Classify(ClassificationArgs),
}
fn get_vb_and_config(model_name: String, device: &Device) -> anyhow::Result<(VarBuilder, Config)> {
println!("loading model {} via huggingface hub", model_name);
let api = hf_hub::api::sync::Api::new()?;
let api = api.model(model_name.clone());
let model_file = api.get("model.safetensors")?;
println!("model {} downloaded and loaded", model_name);
let vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], candle::DType::F32, device)? };
let config = std::fs::read_to_string(api.get("config.json")?)?;
let config: Config = serde_json::from_str(&config)?;
println!("{:?}", config);
Ok((vb, config))
}
#[derive(Debug, serde::Deserialize)]
struct LabelItem {
index: u32,
color: String,
}
fn segmentation_task(args: SegmentationArgs, device: &Device) -> anyhow::Result<()> {
let label_file = std::fs::read_to_string(&args.label_path)?;
let label_items: Vec<LabelItem> = serde_json::from_str(&label_file)?;
let label_colors: HashMap<u32, Rgb<u8>> = label_items
.iter()
.map(|x| {
(x.index - 1, {
let color = x.color.trim_start_matches('#');
let r = u8::from_str_radix(&color[0..2], 16).unwrap();
let g = u8::from_str_radix(&color[2..4], 16).unwrap();
let b = u8::from_str_radix(&color[4..6], 16).unwrap();
Rgb([r, g, b])
})
})
.collect();
let image = candle_examples::imagenet::load_image224(args.image)?
.unsqueeze(0)?
.to_device(device)?;
let (vb, config) = get_vb_and_config(args.model_name, device)?;
let num_labels = label_items.len();
let model = SemanticSegmentationModel::new(&config, num_labels, vb)?;
let segmentations = model.forward(&image)?;
// generate a mask image
let mask = &segmentations.squeeze(0)?.argmax(0)?;
let (h, w) = mask.dims2()?;
let mask = mask.flatten_all()?.to_vec1::<u32>()?;
let mask = mask
.iter()
.flat_map(|x| label_colors[x].data())
.collect::<Vec<u8>>();
let mask: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
image::ImageBuffer::from_raw(w as u32, h as u32, mask).unwrap();
// resize
let mask = image::DynamicImage::from(mask);
let mask = mask.resize_to_fill(
w as u32 * 4,
h as u32 * 4,
image::imageops::FilterType::CatmullRom,
);
mask.save(args.output_path.clone())?;
println!("mask image saved to {:?}", args.output_path);
Ok(())
}
fn classification_task(args: ClassificationArgs, device: &Device) -> anyhow::Result<()> {
let image = candle_examples::imagenet::load_image224(args.image)?
.unsqueeze(0)?
.to_device(device)?;
let (vb, config) = get_vb_and_config(args.model_name, device)?;
let num_labels = 7;
let model = ImageClassificationModel::new(&config, num_labels, vb)?;
let classification = model.forward(&image)?;
let classification = candle_nn::ops::softmax_last_dim(&classification)?;
let classification = classification.squeeze(0)?;
println!(
"classification logits {:?}",
classification.to_vec1::<f32>()?
);
let label_id = classification.argmax(0)?.to_scalar::<u32>()?;
let label_id = format!("{}", label_id);
println!("label: {}", config.id2label[&label_id]);
Ok(())
}
pub fn main() -> anyhow::Result<()> {
let args = CliArgs::parse();
let device = candle_examples::device(args.cpu)?;
if let Commands::Segment(args) = args.command {
segmentation_task(args, &device)?
} else if let Commands::Classify(args) = args.command {
classification_task(args, &device)?
}
Ok(())
}