Add an initial Segformer implementation (#1617)

* add segformer

* Make the id2label field optional.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
Author: Jiayu Liu
Date: 2024-03-03 23:01:46 +08:00
Committed by: GitHub
Parent: 60dc72b96b
Commit: 924ccae30c

5 changed files with 1641 additions and 0 deletions

candle-examples/examples/segformer/README.md
@@ -0,0 +1,28 @@
# candle-segformer
- [HuggingFace Segformer Model Card][segformer]
- [`mit-b0` - An encoder-only pretrained model][encoder]
- [`segformer-b0-finetuned-ade-512-512` - A fine-tuned model for semantic segmentation][ade512]
## How to run the example
You can use the example images from this [pull request][pr]: download an image and supply its path as an argument to the example.
```bash
# run the image classification task
cargo run --example segformer classify <path-to-image>
# run the segmentation task
cargo run --example segformer segment --output-path <path-to-mask-image> <path-to-image>
```
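The segmentation task writes a color-coded mask image to the path given by `--output-path`; the colors for the 150 ADE20K classes are taken from `assets/labels.json`. For example, with illustrative paths:
```bash
# segment an image and write the upscaled mask to mask.png
cargo run --example segformer segment --output-path mask.png my-image.jpg
```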
Example output for classification:
```text
classification probabilities [3.275261e-5, 0.0008562019, 0.0008868563, 0.9977506, 0.0002465068, 0.0002241473, 2.846596e-6]
label: hamburger
```
[pr]: https://github.com/huggingface/candle/pull/1617
[segformer]: https://huggingface.co/docs/transformers/model_doc/segformer
[encoder]: https://huggingface.co/nvidia/mit-b0
[ade512]: https://huggingface.co/nvidia/segformer-b0-finetuned-ade-512-512

candle-examples/examples/segformer/assets/labels.json
@@ -0,0 +1,752 @@
[
{
"index": 1,
"color": "#787878",
"label": "wall"
},
{
"index": 2,
"color": "#B47878",
"label": "building;edifice"
},
{
"index": 3,
"color": "#06E6E6",
"label": "sky"
},
{
"index": 4,
"color": "#503232",
"label": "floor;flooring"
},
{
"index": 5,
"color": "#04C803",
"label": "tree"
},
{
"index": 6,
"color": "#787850",
"label": "ceiling"
},
{
"index": 7,
"color": "#8C8C8C",
"label": "road;route"
},
{
"index": 8,
"color": "#CC05FF",
"label": "bed"
},
{
"index": 9,
"color": "#E6E6E6",
"label": "windowpane;window"
},
{
"index": 10,
"color": "#04FA07",
"label": "grass"
},
{
"index": 11,
"color": "#E005FF",
"label": "cabinet"
},
{
"index": 12,
"color": "#EBFF07",
"label": "sidewalk;pavement"
},
{
"index": 13,
"color": "#96053D",
"label": "person;individual;someone;somebody;mortal;soul"
},
{
"index": 14,
"color": "#787846",
"label": "earth;ground"
},
{
"index": 15,
"color": "#08FF33",
"label": "door;double;door"
},
{
"index": 16,
"color": "#FF0652",
"label": "table"
},
{
"index": 17,
"color": "#8FFF8C",
"label": "mountain;mount"
},
{
"index": 18,
"color": "#CCFF04",
"label": "plant;flora;plant;life"
},
{
"index": 19,
"color": "#FF3307",
"label": "curtain;drape;drapery;mantle;pall"
},
{
"index": 20,
"color": "#CC4603",
"label": "chair"
},
{
"index": 21,
"color": "#0066C8",
"label": "car;auto;automobile;machine;motorcar"
},
{
"index": 22,
"color": "#3DE6FA",
"label": "water"
},
{
"index": 23,
"color": "#FF0633",
"label": "painting;picture"
},
{
"index": 24,
"color": "#0B66FF",
"label": "sofa;couch;lounge"
},
{
"index": 25,
"color": "#FF0747",
"label": "shelf"
},
{
"index": 26,
"color": "#FF09E0",
"label": "house"
},
{
"index": 27,
"color": "#0907E6",
"label": "sea"
},
{
"index": 28,
"color": "#DCDCDC",
"label": "mirror"
},
{
"index": 29,
"color": "#FF095C",
"label": "rug;carpet;carpeting"
},
{
"index": 30,
"color": "#7009FF",
"label": "field"
},
{
"index": 31,
"color": "#08FFD6",
"label": "armchair"
},
{
"index": 32,
"color": "#07FFE0",
"label": "seat"
},
{
"index": 33,
"color": "#FFB806",
"label": "fence;fencing"
},
{
"index": 34,
"color": "#0AFF47",
"label": "desk"
},
{
"index": 35,
"color": "#FF290A",
"label": "rock;stone"
},
{
"index": 36,
"color": "#07FFFF",
"label": "wardrobe;closet;press"
},
{
"index": 37,
"color": "#E0FF08",
"label": "lamp"
},
{
"index": 38,
"color": "#6608FF",
"label": "bathtub;bathing;tub;bath;tub"
},
{
"index": 39,
"color": "#FF3D06",
"label": "railing;rail"
},
{
"index": 40,
"color": "#FFC207",
"label": "cushion"
},
{
"index": 41,
"color": "#FF7A08",
"label": "base;pedestal;stand"
},
{
"index": 42,
"color": "#00FF14",
"label": "box"
},
{
"index": 43,
"color": "#FF0829",
"label": "column;pillar"
},
{
"index": 44,
"color": "#FF0599",
"label": "signboard;sign"
},
{
"index": 45,
"color": "#0633FF",
"label": "chest;of;drawers;chest;bureau;dresser"
},
{
"index": 46,
"color": "#EB0CFF",
"label": "counter"
},
{
"index": 47,
"color": "#A09614",
"label": "sand"
},
{
"index": 48,
"color": "#00A3FF",
"label": "sink"
},
{
"index": 49,
"color": "#8C8C8C",
"label": "skyscraper"
},
{
"index": 50,
"color": "#FA0A0F",
"label": "fireplace;hearth;open;fireplace"
},
{
"index": 51,
"color": "#14FF00",
"label": "refrigerator;icebox"
},
{
"index": 52,
"color": "#1FFF00",
"label": "grandstand;covered;stand"
},
{
"index": 53,
"color": "#FF1F00",
"label": "path"
},
{
"index": 54,
"color": "#FFE000",
"label": "stairs;steps"
},
{
"index": 55,
"color": "#99FF00",
"label": "runway"
},
{
"index": 56,
"color": "#0000FF",
"label": "case;display;case;showcase;vitrine"
},
{
"index": 57,
"color": "#FF4700",
"label": "pool;table;billiard;table;snooker;table"
},
{
"index": 58,
"color": "#00EBFF",
"label": "pillow"
},
{
"index": 59,
"color": "#00ADFF",
"label": "screen;door;screen"
},
{
"index": 60,
"color": "#1F00FF",
"label": "stairway;staircase"
},
{
"index": 61,
"color": "#0BC8C8",
"label": "river"
},
{
"index": 62,
"color": "#FF5200",
"label": "bridge;span"
},
{
"index": 63,
"color": "#00FFF5",
"label": "bookcase"
},
{
"index": 64,
"color": "#003DFF",
"label": "blind;screen"
},
{
"index": 65,
"color": "#00FF70",
"label": "coffee;table;cocktail;table"
},
{
"index": 66,
"color": "#00FF85",
"label": "toilet;can;commode;crapper;pot;potty;stool;throne"
},
{
"index": 67,
"color": "#FF0000",
"label": "flower"
},
{
"index": 68,
"color": "#FFA300",
"label": "book"
},
{
"index": 69,
"color": "#FF6600",
"label": "hill"
},
{
"index": 70,
"color": "#C2FF00",
"label": "bench"
},
{
"index": 71,
"color": "#008FFF",
"label": "countertop"
},
{
"index": 72,
"color": "#33FF00",
"label": "stove;kitchen;stove;range;kitchen;range;cooking;stove"
},
{
"index": 73,
"color": "#0052FF",
"label": "palm;palm;tree"
},
{
"index": 74,
"color": "#00FF29",
"label": "kitchen;island"
},
{
"index": 75,
"color": "#00FFAD",
"label": "computer;computing;machine;computing;device;data;processor;electronic;computer;information;processing;system"
},
{
"index": 76,
"color": "#0A00FF",
"label": "swivel;chair"
},
{
"index": 77,
"color": "#ADFF00",
"label": "boat"
},
{
"index": 78,
"color": "#00FF99",
"label": "bar"
},
{
"index": 79,
"color": "#FF5C00",
"label": "arcade;machine"
},
{
"index": 80,
"color": "#FF00FF",
"label": "hovel;hut;hutch;shack;shanty"
},
{
"index": 81,
"color": "#FF00F5",
"label": "bus;autobus;coach;charabanc;double-decker;jitney;motorbus;motorcoach;omnibus;passenger;vehicle"
},
{
"index": 82,
"color": "#FF0066",
"label": "towel"
},
{
"index": 83,
"color": "#FFAD00",
"label": "light;light;source"
},
{
"index": 84,
"color": "#FF0014",
"label": "truck;motortruck"
},
{
"index": 85,
"color": "#FFB8B8",
"label": "tower"
},
{
"index": 86,
"color": "#001FFF",
"label": "chandelier;pendant;pendent"
},
{
"index": 87,
"color": "#00FF3D",
"label": "awning;sunshade;sunblind"
},
{
"index": 88,
"color": "#0047FF",
"label": "streetlight;street;lamp"
},
{
"index": 89,
"color": "#FF00CC",
"label": "booth;cubicle;stall;kiosk"
},
{
"index": 90,
"color": "#00FFC2",
"label": "television;television;receiver;television;set;tv;tv;set;idiot;box;boob;tube;telly;goggle;box"
},
{
"index": 91,
"color": "#00FF52",
"label": "airplane;aeroplane;plane"
},
{
"index": 92,
"color": "#000AFF",
"label": "dirt;track"
},
{
"index": 93,
"color": "#0070FF",
"label": "apparel;wearing;apparel;dress;clothes"
},
{
"index": 94,
"color": "#3300FF",
"label": "pole"
},
{
"index": 95,
"color": "#00C2FF",
"label": "land;ground;soil"
},
{
"index": 96,
"color": "#007AFF",
"label": "bannister;banister;balustrade;balusters;handrail"
},
{
"index": 97,
"color": "#00FFA3",
"label": "escalator;moving;staircase;moving;stairway"
},
{
"index": 98,
"color": "#FF9900",
"label": "ottoman;pouf;pouffe;puff;hassock"
},
{
"index": 99,
"color": "#00FF0A",
"label": "bottle"
},
{
"index": 100,
"color": "#FF7000",
"label": "buffet;counter;sideboard"
},
{
"index": 101,
"color": "#8FFF00",
"label": "poster;posting;placard;notice;bill;card"
},
{
"index": 102,
"color": "#5200FF",
"label": "stage"
},
{
"index": 103,
"color": "#A3FF00",
"label": "van"
},
{
"index": 104,
"color": "#FFEB00",
"label": "ship"
},
{
"index": 105,
"color": "#08B8AA",
"label": "fountain"
},
{
"index": 106,
"color": "#8500FF",
"label": "conveyer;belt;conveyor;belt;conveyer;conveyor;transporter"
},
{
"index": 107,
"color": "#00FF5C",
"label": "canopy"
},
{
"index": 108,
"color": "#B800FF",
"label": "washer;automatic;washer;washing;machine"
},
{
"index": 109,
"color": "#FF001F",
"label": "plaything;toy"
},
{
"index": 110,
"color": "#00B8FF",
"label": "swimming;pool;swimming;bath;natatorium"
},
{
"index": 111,
"color": "#00D6FF",
"label": "stool"
},
{
"index": 112,
"color": "#FF0070",
"label": "barrel;cask"
},
{
"index": 113,
"color": "#5CFF00",
"label": "basket;handbasket"
},
{
"index": 114,
"color": "#00E0FF",
"label": "waterfall;falls"
},
{
"index": 115,
"color": "#70E0FF",
"label": "tent;collapsible;shelter"
},
{
"index": 116,
"color": "#46B8A0",
"label": "bag"
},
{
"index": 117,
"color": "#A300FF",
"label": "minibike;motorbike"
},
{
"index": 118,
"color": "#9900FF",
"label": "cradle"
},
{
"index": 119,
"color": "#47FF00",
"label": "oven"
},
{
"index": 120,
"color": "#FF00A3",
"label": "ball"
},
{
"index": 121,
"color": "#FFCC00",
"label": "food;solid;food"
},
{
"index": 122,
"color": "#FF008F",
"label": "step;stair"
},
{
"index": 123,
"color": "#00FFEB",
"label": "tank;storage;tank"
},
{
"index": 124,
"color": "#85FF00",
"label": "trade;name;brand;name;brand;marque"
},
{
"index": 125,
"color": "#FF00EB",
"label": "microwave;microwave;oven"
},
{
"index": 126,
"color": "#F500FF",
"label": "pot;flowerpot"
},
{
"index": 127,
"color": "#FF007A",
"label": "animal;animate;being;beast;brute;creature;fauna"
},
{
"index": 128,
"color": "#FFF500",
"label": "bicycle;bike;wheel;cycle"
},
{
"index": 129,
"color": "#0ABED4",
"label": "lake"
},
{
"index": 130,
"color": "#D6FF00",
"label": "dishwasher;dish;washer;dishwashing;machine"
},
{
"index": 131,
"color": "#00CCFF",
"label": "screen;silver;screen;projection;screen"
},
{
"index": 132,
"color": "#1400FF",
"label": "blanket;cover"
},
{
"index": 133,
"color": "#FFFF00",
"label": "sculpture"
},
{
"index": 134,
"color": "#0099FF",
"label": "hood;exhaust;hood"
},
{
"index": 135,
"color": "#0029FF",
"label": "sconce"
},
{
"index": 136,
"color": "#00FFCC",
"label": "vase"
},
{
"index": 137,
"color": "#2900FF",
"label": "traffic;light;traffic;signal;stoplight"
},
{
"index": 138,
"color": "#29FF00",
"label": "tray"
},
{
"index": 139,
"color": "#AD00FF",
"label": "ashcan;trash;can;garbage;can;wastebin;ash;bin;ash-bin;ashbin;dustbin;trash;barrel;trash;bin"
},
{
"index": 140,
"color": "#00F5FF",
"label": "fan"
},
{
"index": 141,
"color": "#4700FF",
"label": "pier;wharf;wharfage;dock"
},
{
"index": 142,
"color": "#7A00FF",
"label": "crt;screen"
},
{
"index": 143,
"color": "#00FFB8",
"label": "plate"
},
{
"index": 144,
"color": "#005CFF",
"label": "monitor;monitoring;device"
},
{
"index": 145,
"color": "#B8FF00",
"label": "bulletin;board;notice;board"
},
{
"index": 146,
"color": "#0085FF",
"label": "shower"
},
{
"index": 147,
"color": "#FFD600",
"label": "radiator"
},
{
"index": 148,
"color": "#19C2C2",
"label": "glass;drinking;glass"
},
{
"index": 149,
"color": "#66FF00",
"label": "clock"
},
{
"index": 150,
"color": "#5C00FF",
"label": "flag"
}
]

candle-examples/examples/segformer/main.rs
@@ -0,0 +1,155 @@
use candle::Device;
use candle::Module;
use candle_nn::VarBuilder;
use candle_transformers::models::segformer::{
Config, ImageClassificationModel, SemanticSegmentationModel,
};
use clap::{Args, Parser, Subcommand};
use image::Rgb;
use imageproc::integral_image::ArrayData;
use std::collections::HashMap;
use std::path::PathBuf;
#[derive(Parser)]
#[clap(about, version, long_about = None)]
struct CliArgs {
#[arg(long, help = "use cpu")]
cpu: bool,
#[command(subcommand)]
command: Commands,
}
#[derive(Args, Debug)]
struct SegmentationArgs {
#[arg(
long,
help = "name of the huggingface hub model",
default_value = "nvidia/segformer-b0-finetuned-ade-512-512"
)]
model_name: String,
#[arg(
long,
help = "path to the label file in json format",
default_value = "candle-examples/examples/segformer/assets/labels.json"
)]
label_path: PathBuf,
#[arg(long, help = "path to for the output mask image")]
output_path: PathBuf,
#[arg(help = "path to image as input")]
image: PathBuf,
}
#[derive(Args, Debug)]
struct ClassificationArgs {
#[arg(
long,
help = "name of the huggingface hub model",
default_value = "paolinox/segformer-finetuned-food101"
)]
model_name: String,
#[arg(help = "path to image as input")]
image: PathBuf,
}
#[derive(Subcommand, Debug)]
enum Commands {
Segment(SegmentationArgs),
Classify(ClassificationArgs),
}
fn get_vb_and_config(model_name: String, device: &Device) -> anyhow::Result<(VarBuilder, Config)> {
println!("loading model {} via huggingface hub", model_name);
let api = hf_hub::api::sync::Api::new()?;
let api = api.model(model_name.clone());
let model_file = api.get("model.safetensors")?;
println!("model {} downloaded and loaded", model_name);
let vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], candle::DType::F32, device)? };
let config = std::fs::read_to_string(api.get("config.json")?)?;
let config: Config = serde_json::from_str(&config)?;
println!("{:?}", config);
Ok((vb, config))
}
#[derive(Debug, serde::Deserialize)]
struct LabelItem {
index: u32,
color: String,
}
fn segmentation_task(args: SegmentationArgs, device: &Device) -> anyhow::Result<()> {
let label_file = std::fs::read_to_string(&args.label_path)?;
let label_items: Vec<LabelItem> = serde_json::from_str(&label_file)?;
let label_colors: HashMap<u32, Rgb<u8>> = label_items
.iter()
.map(|x| {
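// label indices in the json file are 1-based, while argmax class ids are 0-based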
(x.index - 1, {
let color = x.color.trim_start_matches('#');
let r = u8::from_str_radix(&color[0..2], 16).unwrap();
let g = u8::from_str_radix(&color[2..4], 16).unwrap();
let b = u8::from_str_radix(&color[4..6], 16).unwrap();
Rgb([r, g, b])
})
})
.collect();
let image = candle_examples::imagenet::load_image224(args.image)?
.unsqueeze(0)?
.to_device(device)?;
let (vb, config) = get_vb_and_config(args.model_name, device)?;
let num_labels = label_items.len();
let model = SemanticSegmentationModel::new(&config, num_labels, vb)?;
let segmentations = model.forward(&image)?;
// generate a mask image
let mask = &segmentations.squeeze(0)?.argmax(0)?;
let (h, w) = mask.dims2()?;
let mask = mask.flatten_all()?.to_vec1::<u32>()?;
let mask = mask
.iter()
.flat_map(|x| label_colors[x].data())
.collect::<Vec<u8>>();
let mask: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
image::ImageBuffer::from_raw(w as u32, h as u32, mask).unwrap();
// resize
let mask = image::DynamicImage::from(mask);
let mask = mask.resize_to_fill(
w as u32 * 4,
h as u32 * 4,
image::imageops::FilterType::CatmullRom,
);
mask.save(args.output_path.clone())?;
println!("mask image saved to {:?}", args.output_path);
Ok(())
}
fn classification_task(args: ClassificationArgs, device: &Device) -> anyhow::Result<()> {
let image = candle_examples::imagenet::load_image224(args.image)?
.unsqueeze(0)?
.to_device(device)?;
let (vb, config) = get_vb_and_config(args.model_name, device)?;
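// the default paolinox/segformer-finetuned-food101 checkpoint predicts 7 classes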
let num_labels = 7;
let model = ImageClassificationModel::new(&config, num_labels, vb)?;
let classification = model.forward(&image)?;
let classification = candle_nn::ops::softmax_last_dim(&classification)?;
let classification = classification.squeeze(0)?;
println!(
"classification logits {:?}",
classification.to_vec1::<f32>()?
);
let label_id = classification.argmax(0)?.to_scalar::<u32>()?;
let label_id = format!("{}", label_id);
println!("label: {}", config.id2label[&label_id]);
Ok(())
}
pub fn main() -> anyhow::Result<()> {
let args = CliArgs::parse();
let device = candle_examples::device(args.cpu)?;
if let Commands::Segment(args) = args.command {
segmentation_task(args, &device)?
} else if let Commands::Classify(args) = args.command {
classification_task(args, &device)?
}
Ok(())
}

candle-transformers/src/models/mod.rs
@@ -42,6 +42,7 @@ pub mod repvgg;
pub mod resnet;
pub mod rwkv_v5;
pub mod rwkv_v6;
pub mod segformer;
pub mod segment_anything;
pub mod stable_diffusion;
pub mod stable_lm;

candle-transformers/src/models/segformer.rs
@@ -0,0 +1,705 @@
use crate::models::with_tracing::{conv2d, linear, Conv2d, Linear};
use candle::{Module, ModuleT, Result, Tensor, D};
use candle_nn::{conv2d_no_bias, layer_norm, Activation, Conv2dConfig, VarBuilder};
use serde::Deserialize;
use std::collections::HashMap;
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/segformer/configuration_segformer.py
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config {
#[serde(default)]
pub id2label: HashMap<String, String>,
pub num_channels: usize,
pub num_encoder_blocks: usize,
pub depths: Vec<usize>,
pub sr_ratios: Vec<usize>,
pub hidden_sizes: Vec<usize>,
pub patch_sizes: Vec<usize>,
pub strides: Vec<usize>,
pub num_attention_heads: Vec<usize>,
pub mlp_ratios: Vec<usize>,
pub hidden_act: candle_nn::Activation,
pub layer_norm_eps: f64,
pub decoder_hidden_size: usize,
}
#[derive(Debug, Clone)]
struct SegformerOverlapPatchEmbeddings {
projection: Conv2d,
layer_norm: candle_nn::LayerNorm,
}
impl SegformerOverlapPatchEmbeddings {
fn new(
config: &Config,
patch_size: usize,
stride: usize,
num_channels: usize,
hidden_size: usize,
vb: VarBuilder,
) -> Result<Self> {
let projection = conv2d(
num_channels,
hidden_size,
patch_size,
Conv2dConfig {
stride,
padding: patch_size / 2,
..Default::default()
},
vb.pp("proj"),
)?;
let layer_norm =
candle_nn::layer_norm(hidden_size, config.layer_norm_eps, vb.pp("layer_norm"))?;
Ok(Self {
projection,
layer_norm,
})
}
}
impl Module for SegformerOverlapPatchEmbeddings {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let embeddings = self.projection.forward(x)?;
let shape = embeddings.shape();
// [B, C, H, W] -> [B, H * W, C]
let embeddings = embeddings.flatten_from(2)?.transpose(1, 2)?;
let embeddings = self.layer_norm.forward(&embeddings)?;
// [B, H * W, C] -> [B, C, H, W]
let embeddings = embeddings.transpose(1, 2)?.reshape(shape)?;
Ok(embeddings)
}
}
#[derive(Debug, Clone)]
struct SegformerEfficientSelfAttention {
num_attention_heads: usize,
attention_head_size: usize,
query: Linear,
key: Linear,
value: Linear,
sr: Option<Conv2d>,
layer_norm: Option<layer_norm::LayerNorm>,
}
impl SegformerEfficientSelfAttention {
fn new(
config: &Config,
hidden_size: usize,
num_attention_heads: usize,
sequence_reduction_ratio: usize,
vb: VarBuilder,
) -> Result<Self> {
if hidden_size % num_attention_heads != 0 {
candle::bail!(
"The hidden size {} is not a multiple of the number of attention heads {}",
hidden_size,
num_attention_heads
)
}
let attention_head_size = hidden_size / num_attention_heads;
let all_head_size = num_attention_heads * attention_head_size;
let query = linear(hidden_size, all_head_size, vb.pp("query"))?;
let key = linear(hidden_size, all_head_size, vb.pp("key"))?;
let value = linear(hidden_size, all_head_size, vb.pp("value"))?;
let (sr, layer_norm) = if sequence_reduction_ratio > 1 {
(
Some(conv2d(
hidden_size,
hidden_size,
sequence_reduction_ratio,
Conv2dConfig {
stride: sequence_reduction_ratio,
..Default::default()
},
vb.pp("sr"),
)?),
Some(candle_nn::layer_norm(
hidden_size,
config.layer_norm_eps,
vb.pp("layer_norm"),
)?),
)
} else {
(None, None)
};
Ok(Self {
num_attention_heads,
attention_head_size,
query,
key,
value,
sr,
layer_norm,
})
}
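// reshape [B, seq_len, hidden_size] to [B, num_heads, seq_len, head_size]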
fn transpose_for_scores(&self, hidden_states: Tensor) -> Result<Tensor> {
let (batch, seq_length, _) = hidden_states.shape().dims3()?;
let new_shape = &[
batch,
seq_length,
self.num_attention_heads,
self.attention_head_size,
];
let hidden_states = hidden_states.reshape(new_shape)?;
let hidden_states = hidden_states.permute((0, 2, 1, 3))?;
Ok(hidden_states)
}
}
impl Module for SegformerEfficientSelfAttention {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
// [B, C, H, W] -> [B, H * W, C]
let hidden_states = x.flatten_from(2)?.permute((0, 2, 1))?;
let query = self
.transpose_for_scores(self.query.forward(&hidden_states)?)?
.contiguous()?;
let hidden_states = if let (Some(sr), Some(layer_norm)) = (&self.sr, &self.layer_norm) {
let hidden_states = sr.forward(x)?;
// [B, C, H, W] -> [B, H * W, C]
let hidden_states = hidden_states.flatten_from(2)?.permute((0, 2, 1))?;
layer_norm.forward(&hidden_states)?
} else {
// already [B, H * W, C]
hidden_states
};
// standard self-attention
let key = self
.transpose_for_scores(self.key.forward(&hidden_states)?)?
.contiguous()?;
let value = self
.transpose_for_scores(self.value.forward(&hidden_states)?)?
.contiguous()?;
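// scaled dot-product attention: softmax(q k^T / sqrt(head_size)) v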
let attention_scores =
(query.matmul(&key.t()?)? / f64::sqrt(self.attention_head_size as f64))?;
let attention_scores = candle_nn::ops::softmax_last_dim(&attention_scores)?;
let result = attention_scores.matmul(&value)?;
let result = result.permute((0, 2, 1, 3))?.contiguous()?;
result.flatten_from(D::Minus2)
}
}
#[derive(Debug, Clone)]
struct SegformerSelfOutput {
dense: Linear,
}
impl SegformerSelfOutput {
fn new(hidden_size: usize, vb: VarBuilder) -> Result<Self> {
let dense = linear(hidden_size, hidden_size, vb.pp("dense"))?;
Ok(Self { dense })
}
}
impl Module for SegformerSelfOutput {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
self.dense.forward(x)
}
}
#[derive(Debug, Clone)]
struct SegformerAttention {
attention: SegformerEfficientSelfAttention,
output: SegformerSelfOutput,
}
impl SegformerAttention {
fn new(
config: &Config,
hidden_size: usize,
num_attention_heads: usize,
sequence_reduction_ratio: usize,
vb: VarBuilder,
) -> Result<Self> {
let attention = SegformerEfficientSelfAttention::new(
config,
hidden_size,
num_attention_heads,
sequence_reduction_ratio,
vb.pp("self"),
)?;
let output = SegformerSelfOutput::new(hidden_size, vb.pp("output"))?;
Ok(Self { attention, output })
}
}
impl Module for SegformerAttention {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let attention_output = self.attention.forward(x)?;
self.output.forward(&attention_output)
}
}
#[derive(Debug, Clone)]
struct SegformerDWConv {
dw_conv: Conv2d,
}
impl SegformerDWConv {
fn new(dim: usize, vb: VarBuilder) -> Result<Self> {
let dw_conv = conv2d(
dim,
dim,
3,
Conv2dConfig {
stride: 1,
padding: 1,
groups: dim,
..Default::default()
},
vb.pp("dwconv"),
)?;
Ok(Self { dw_conv })
}
}
impl Module for SegformerDWConv {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
self.dw_conv.forward(x)
}
}
#[derive(Debug, Clone)]
struct SegformerMixFFN {
dense1: Linear,
dw_conv: SegformerDWConv,
act: Activation,
dense2: Linear,
}
impl SegformerMixFFN {
fn new(
config: &Config,
in_features: usize,
hidden_features: usize,
out_features: usize,
vb: VarBuilder,
) -> Result<Self> {
let dense1 = linear(in_features, hidden_features, vb.pp("dense1"))?;
let dw_conv = SegformerDWConv::new(hidden_features, vb.pp("dwconv"))?;
let act = config.hidden_act;
let dense2 = linear(hidden_features, out_features, vb.pp("dense2"))?;
Ok(Self {
dense1,
dw_conv,
act,
dense2,
})
}
}
impl Module for SegformerMixFFN {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let (batch, _, height, width) = x.shape().dims4()?;
let hidden_states = self
.dense1
.forward(&x.flatten_from(2)?.permute((0, 2, 1))?)?;
let channels = hidden_states.dim(2)?;
let hidden_states = self.dw_conv.forward(
&hidden_states
.permute((0, 2, 1))?
.reshape((batch, channels, height, width))?,
)?;
let hidden_states = self.act.forward(&hidden_states)?;
let hidden_states = self
.dense2
.forward(&hidden_states.flatten_from(2)?.permute((0, 2, 1))?)?;
let channels = hidden_states.dim(2)?;
hidden_states
.permute((0, 2, 1))?
.reshape((batch, channels, height, width))
}
}
#[derive(Debug, Clone)]
struct SegformerLayer {
layer_norm_1: candle_nn::LayerNorm,
attention: SegformerAttention,
layer_norm_2: candle_nn::LayerNorm,
mlp: SegformerMixFFN,
}
impl SegformerLayer {
fn new(
config: &Config,
hidden_size: usize,
num_attention_heads: usize,
sequence_reduction_ratio: usize,
mlp_ratio: usize,
vb: VarBuilder,
) -> Result<Self> {
let layer_norm_1 = layer_norm(hidden_size, config.layer_norm_eps, vb.pp("layer_norm_1"))?;
let attention = SegformerAttention::new(
config,
hidden_size,
num_attention_heads,
sequence_reduction_ratio,
vb.pp("attention"),
)?;
let layer_norm_2 = layer_norm(hidden_size, config.layer_norm_eps, vb.pp("layer_norm_2"))?;
let mlp = SegformerMixFFN::new(
config,
hidden_size,
hidden_size * mlp_ratio,
hidden_size,
vb.pp("mlp"),
)?;
Ok(Self {
layer_norm_1,
attention,
layer_norm_2,
mlp,
})
}
}
impl Module for SegformerLayer {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let shape = x.shape().dims4()?;
// [B, C, H, W] -> [B, H * W, C]
let hidden_states = x.flatten_from(2)?.permute((0, 2, 1))?;
let layer_norm_output = self.layer_norm_1.forward(&hidden_states)?;
let layer_norm_output = layer_norm_output.permute((0, 2, 1))?.reshape(shape)?;
// attention takes in [B, C, H, W] in order to properly do conv2d (and output [B, H * W, C])
let attention_output = self.attention.forward(&layer_norm_output)?;
let hidden_states = (attention_output + hidden_states)?;
let layer_norm_output = self.layer_norm_2.forward(&hidden_states)?;
let mlp_output = self
.mlp
.forward(&layer_norm_output.permute((0, 2, 1))?.reshape(shape)?)?;
hidden_states.permute((0, 2, 1))?.reshape(shape)? + mlp_output
}
}
#[derive(Debug, Clone)]
struct SegformerEncoder {
/// model configuration
config: Config,
/// a list of embeddings
patch_embeddings: Vec<SegformerOverlapPatchEmbeddings>,
/// a list of attention blocks, each consisting of layers
blocks: Vec<Vec<SegformerLayer>>,
/// a final list of layer norms
layer_norms: Vec<candle_nn::LayerNorm>,
}
impl SegformerEncoder {
fn new(config: Config, vb: VarBuilder) -> Result<Self> {
let mut patch_embeddings = Vec::with_capacity(config.num_encoder_blocks);
let mut blocks = Vec::with_capacity(config.num_encoder_blocks);
let mut layer_norms = Vec::with_capacity(config.num_encoder_blocks);
for i in 0..config.num_encoder_blocks {
let patch_size = config.patch_sizes[i];
let stride = config.strides[i];
let hidden_size = config.hidden_sizes[i];
let num_channels = if i == 0 {
config.num_channels
} else {
config.hidden_sizes[i - 1]
};
patch_embeddings.push(SegformerOverlapPatchEmbeddings::new(
&config,
patch_size,
stride,
num_channels,
hidden_size,
vb.pp(&format!("patch_embeddings.{}", i)),
)?);
let mut layers = Vec::with_capacity(config.depths[i]);
for j in 0..config.depths[i] {
let sequence_reduction_ratio = config.sr_ratios[i];
let num_attention_heads = config.num_attention_heads[i];
let mlp_ratio = config.mlp_ratios[i];
layers.push(SegformerLayer::new(
&config,
hidden_size,
num_attention_heads,
sequence_reduction_ratio,
mlp_ratio,
vb.pp(&format!("block.{}.{}", i, j)),
)?);
}
blocks.push(layers);
layer_norms.push(layer_norm(
hidden_size,
config.layer_norm_eps,
vb.pp(&format!("layer_norm.{}", i)),
)?);
}
Ok(Self {
config,
patch_embeddings,
blocks,
layer_norms,
})
}
}
impl ModuleWithHiddenStates for SegformerEncoder {
fn forward(&self, x: &Tensor) -> Result<Vec<Tensor>> {
let mut all_hidden_states = Vec::with_capacity(self.config.num_encoder_blocks);
let mut hidden_states = x.clone();
for i in 0..self.config.num_encoder_blocks {
hidden_states = self.patch_embeddings[i].forward(&hidden_states)?;
for layer in &self.blocks[i] {
hidden_states = layer.forward(&hidden_states)?;
}
let shape = hidden_states.shape().dims4()?;
hidden_states =
self.layer_norms[i].forward(&hidden_states.flatten_from(2)?.permute((0, 2, 1))?)?;
hidden_states = hidden_states.permute((0, 2, 1))?.reshape(shape)?;
all_hidden_states.push(hidden_states.clone());
}
Ok(all_hidden_states)
}
}
#[derive(Debug, Clone)]
struct SegformerModel {
encoder: SegformerEncoder,
}
impl SegformerModel {
fn new(config: &Config, vb: VarBuilder) -> Result<Self> {
let encoder = SegformerEncoder::new(config.clone(), vb.pp("encoder"))?;
Ok(Self { encoder })
}
}
impl ModuleWithHiddenStates for SegformerModel {
fn forward(&self, x: &Tensor) -> Result<Vec<Tensor>> {
self.encoder.forward(x)
}
}
#[derive(Debug, Clone)]
struct SegformerMLP {
proj: Linear,
}
impl SegformerMLP {
fn new(config: &Config, input_dim: usize, vb: VarBuilder) -> Result<Self> {
let proj = linear(input_dim, config.decoder_hidden_size, vb.pp("proj"))?;
Ok(Self { proj })
}
}
impl Module for SegformerMLP {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
self.proj.forward(x)
}
}
#[derive(Debug, Clone)]
struct SegformerDecodeHead {
linear_c: Vec<SegformerMLP>,
linear_fuse: candle_nn::Conv2d,
batch_norm: candle_nn::BatchNorm,
classifier: candle_nn::Conv2d,
}
impl SegformerDecodeHead {
fn new(config: &Config, num_labels: usize, vb: VarBuilder) -> Result<Self> {
let mut linear_c = Vec::with_capacity(config.num_encoder_blocks);
for i in 0..config.num_encoder_blocks {
let hidden_size = config.hidden_sizes[i];
linear_c.push(SegformerMLP::new(
config,
hidden_size,
vb.pp(&format!("linear_c.{}", i)),
)?);
}
let linear_fuse = conv2d_no_bias(
config.decoder_hidden_size * config.num_encoder_blocks,
config.decoder_hidden_size,
1,
Conv2dConfig::default(),
vb.pp("linear_fuse"),
)?;
let batch_norm = candle_nn::batch_norm(
config.decoder_hidden_size,
config.layer_norm_eps,
vb.pp("batch_norm"),
)?;
let classifier = conv2d_no_bias(
config.decoder_hidden_size,
num_labels,
1,
Conv2dConfig::default(),
vb.pp("classifier"),
)?;
Ok(Self {
linear_c,
linear_fuse,
batch_norm,
classifier,
})
}
fn forward(&self, encoder_hidden_states: &[Tensor]) -> Result<Tensor> {
if encoder_hidden_states.len() != self.linear_c.len() {
candle::bail!(
"The number of encoder hidden states {} is not equal to the number of linear layers {}",
encoder_hidden_states.len(),
self.linear_c.len()
)
}
// the first hidden state has the finest resolution; use it as the upsampling target
let (_, _, upsample_height, upsample_width) = encoder_hidden_states[0].shape().dims4()?;
let mut hidden_states = Vec::with_capacity(self.linear_c.len());
for (hidden_state, mlp) in encoder_hidden_states.iter().zip(&self.linear_c) {
let (batch, _, height, width) = hidden_state.shape().dims4()?;
let hidden_state = mlp.forward(&hidden_state.flatten_from(2)?.permute((0, 2, 1))?)?;
let hidden_state = hidden_state.permute((0, 2, 1))?.reshape((
batch,
hidden_state.dim(2)?,
height,
width,
))?;
let hidden_state = hidden_state.upsample_nearest2d(upsample_height, upsample_width)?;
hidden_states.push(hidden_state);
}
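// fuse the feature maps from coarsest to finest, matching the transformers implementation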
hidden_states.reverse();
let hidden_states = Tensor::cat(&hidden_states, 1)?;
let hidden_states = self.linear_fuse.forward(&hidden_states)?;
let hidden_states = self.batch_norm.forward_t(&hidden_states, false)?;
let hidden_states = hidden_states.relu()?;
self.classifier.forward(&hidden_states)
}
}
trait ModuleWithHiddenStates {
fn forward(&self, xs: &Tensor) -> Result<Vec<Tensor>>;
}
#[derive(Debug, Clone)]
pub struct SemanticSegmentationModel {
segformer: SegformerModel,
decode_head: SegformerDecodeHead,
}
impl SemanticSegmentationModel {
pub fn new(config: &Config, num_labels: usize, vb: VarBuilder) -> Result<Self> {
let segformer = SegformerModel::new(config, vb.pp("segformer"))?;
let decode_head = SegformerDecodeHead::new(config, num_labels, vb.pp("decode_head"))?;
Ok(Self {
segformer,
decode_head,
})
}
}
impl Module for SemanticSegmentationModel {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let hidden_states = self.segformer.forward(x)?;
self.decode_head.forward(&hidden_states)
}
}
#[derive(Debug, Clone)]
pub struct ImageClassificationModel {
segformer: SegformerModel,
classifier: Linear,
}
impl ImageClassificationModel {
pub fn new(config: &Config, num_labels: usize, vb: VarBuilder) -> Result<Self> {
let segformer = SegformerModel::new(config, vb.pp("segformer"))?;
let classifier = linear(config.decoder_hidden_size, num_labels, vb.pp("classifier"))?;
Ok(Self {
segformer,
classifier,
})
}
}
impl Module for ImageClassificationModel {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let all_hidden_states = self.segformer.forward(x)?;
let hidden_states = all_hidden_states.last().unwrap();
let hidden_states = hidden_states.flatten_from(2)?.permute((0, 2, 1))?;
let mean = hidden_states.mean(1)?;
self.classifier.forward(&mean)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_config_json_load() {
let raw_json = r#"{
"architectures": [
"SegformerForImageClassification"
],
"attention_probs_dropout_prob": 0.0,
"classifier_dropout_prob": 0.1,
"decoder_hidden_size": 256,
"depths": [
2,
2,
2,
2
],
"downsampling_rates": [
1,
4,
8,
16
],
"drop_path_rate": 0.1,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.0,
"hidden_sizes": [
32,
64,
160,
256
],
"image_size": 224,
"initializer_range": 0.02,
"layer_norm_eps": 1e-06,
"mlp_ratios": [
4,
4,
4,
4
],
"model_type": "segformer",
"num_attention_heads": [
1,
2,
5,
8
],
"num_channels": 3,
"num_encoder_blocks": 4,
"patch_sizes": [
7,
3,
3,
3
],
"sr_ratios": [
8,
4,
2,
1
],
"strides": [
4,
2,
2,
2
],
"torch_dtype": "float32",
"transformers_version": "4.12.0.dev0"
}"#;
let config: Config = serde_json::from_str(raw_json).unwrap();
assert_eq!(vec![4, 2, 2, 2], config.strides);
assert_eq!(1e-6, config.layer_norm_eps);
}
}