mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
Add an initial Segformer implementation (#1617)
* add segformer * Make the id2label field optional. --------- Co-authored-by: laurent <laurent.mazare@gmail.com>
This commit is contained in:
28
candle-examples/examples/segformer/README.md
Normal file
28
candle-examples/examples/segformer/README.md
Normal file
@ -0,0 +1,28 @@
|
||||
# candle-segformer
|
||||
|
||||
- [HuggingFace Segformer Model Card][segformer]
|
||||
- [`mit-b0` - An encoder only pretrained model][encoder]
|
||||
- [`segformer-b0-finetuned-ade-512-512` - A fine tuned model for segmentation][ade512]
|
||||
|
||||
## How to run the example
|
||||
|
||||
If you want, you can use the example images from this [pull request][pr]: download them and supply the path to an image as an argument to the example.
|
||||
|
||||
```bash
|
||||
# run the image classification task
|
||||
cargo run --example segformer classify <path-to-image>
|
||||
# run the segmentation task
|
||||
cargo run --example segformer segment --output-path <path-to-output-mask> <path-to-image>
|
||||
```
|
||||
|
||||
Example output for classification:
|
||||
|
||||
```text
|
||||
classification logits [3.275261e-5, 0.0008562019, 0.0008868563, 0.9977506, 0.0002465068, 0.0002241473, 2.846596e-6]
|
||||
label: hamburger
|
||||
```
|
||||
|
||||
[pr]: https://github.com/huggingface/candle/pull/1617
|
||||
[segformer]: https://huggingface.co/docs/transformers/model_doc/segformer
|
||||
[encoder]: https://huggingface.co/nvidia/mit-b0
|
||||
[ade512]: https://huggingface.co/nvidia/segformer-b0-finetuned-ade-512-512
|
752
candle-examples/examples/segformer/assets/labels.json
Normal file
752
candle-examples/examples/segformer/assets/labels.json
Normal file
@ -0,0 +1,752 @@
|
||||
[
|
||||
{
|
||||
"index": 1,
|
||||
"color": "#787878",
|
||||
"label": "wall"
|
||||
},
|
||||
{
|
||||
"index": 2,
|
||||
"color": "#B47878",
|
||||
"label": "building;edifice"
|
||||
},
|
||||
{
|
||||
"index": 3,
|
||||
"color": "#06E6E6",
|
||||
"label": "sky"
|
||||
},
|
||||
{
|
||||
"index": 4,
|
||||
"color": "#503232",
|
||||
"label": "floor;flooring"
|
||||
},
|
||||
{
|
||||
"index": 5,
|
||||
"color": "#04C803",
|
||||
"label": "tree"
|
||||
},
|
||||
{
|
||||
"index": 6,
|
||||
"color": "#787850",
|
||||
"label": "ceiling"
|
||||
},
|
||||
{
|
||||
"index": 7,
|
||||
"color": "#8C8C8C",
|
||||
"label": "road;route"
|
||||
},
|
||||
{
|
||||
"index": 8,
|
||||
"color": "#CC05FF",
|
||||
"label": "bed"
|
||||
},
|
||||
{
|
||||
"index": 9,
|
||||
"color": "#E6E6E6",
|
||||
"label": "windowpane;window"
|
||||
},
|
||||
{
|
||||
"index": 10,
|
||||
"color": "#04FA07",
|
||||
"label": "grass"
|
||||
},
|
||||
{
|
||||
"index": 11,
|
||||
"color": "#E005FF",
|
||||
"label": "cabinet"
|
||||
},
|
||||
{
|
||||
"index": 12,
|
||||
"color": "#EBFF07",
|
||||
"label": "sidewalk;pavement"
|
||||
},
|
||||
{
|
||||
"index": 13,
|
||||
"color": "#96053D",
|
||||
"label": "person;individual;someone;somebody;mortal;soul"
|
||||
},
|
||||
{
|
||||
"index": 14,
|
||||
"color": "#787846",
|
||||
"label": "earth;ground"
|
||||
},
|
||||
{
|
||||
"index": 15,
|
||||
"color": "#08FF33",
|
||||
"label": "door;double;door"
|
||||
},
|
||||
{
|
||||
"index": 16,
|
||||
"color": "#FF0652",
|
||||
"label": "table"
|
||||
},
|
||||
{
|
||||
"index": 17,
|
||||
"color": "#8FFF8C",
|
||||
"label": "mountain;mount"
|
||||
},
|
||||
{
|
||||
"index": 18,
|
||||
"color": "#CCFF04",
|
||||
"label": "plant;flora;plant;life"
|
||||
},
|
||||
{
|
||||
"index": 19,
|
||||
"color": "#FF3307",
|
||||
"label": "curtain;drape;drapery;mantle;pall"
|
||||
},
|
||||
{
|
||||
"index": 20,
|
||||
"color": "#CC4603",
|
||||
"label": "chair"
|
||||
},
|
||||
{
|
||||
"index": 21,
|
||||
"color": "#0066C8",
|
||||
"label": "car;auto;automobile;machine;motorcar"
|
||||
},
|
||||
{
|
||||
"index": 22,
|
||||
"color": "#3DE6FA",
|
||||
"label": "water"
|
||||
},
|
||||
{
|
||||
"index": 23,
|
||||
"color": "#FF0633",
|
||||
"label": "painting;picture"
|
||||
},
|
||||
{
|
||||
"index": 24,
|
||||
"color": "#0B66FF",
|
||||
"label": "sofa;couch;lounge"
|
||||
},
|
||||
{
|
||||
"index": 25,
|
||||
"color": "#FF0747",
|
||||
"label": "shelf"
|
||||
},
|
||||
{
|
||||
"index": 26,
|
||||
"color": "#FF09E0",
|
||||
"label": "house"
|
||||
},
|
||||
{
|
||||
"index": 27,
|
||||
"color": "#0907E6",
|
||||
"label": "sea"
|
||||
},
|
||||
{
|
||||
"index": 28,
|
||||
"color": "#DCDCDC",
|
||||
"label": "mirror"
|
||||
},
|
||||
{
|
||||
"index": 29,
|
||||
"color": "#FF095C",
|
||||
"label": "rug;carpet;carpeting"
|
||||
},
|
||||
{
|
||||
"index": 30,
|
||||
"color": "#7009FF",
|
||||
"label": "field"
|
||||
},
|
||||
{
|
||||
"index": 31,
|
||||
"color": "#08FFD6",
|
||||
"label": "armchair"
|
||||
},
|
||||
{
|
||||
"index": 32,
|
||||
"color": "#07FFE0",
|
||||
"label": "seat"
|
||||
},
|
||||
{
|
||||
"index": 33,
|
||||
"color": "#FFB806",
|
||||
"label": "fence;fencing"
|
||||
},
|
||||
{
|
||||
"index": 34,
|
||||
"color": "#0AFF47",
|
||||
"label": "desk"
|
||||
},
|
||||
{
|
||||
"index": 35,
|
||||
"color": "#FF290A",
|
||||
"label": "rock;stone"
|
||||
},
|
||||
{
|
||||
"index": 36,
|
||||
"color": "#07FFFF",
|
||||
"label": "wardrobe;closet;press"
|
||||
},
|
||||
{
|
||||
"index": 37,
|
||||
"color": "#E0FF08",
|
||||
"label": "lamp"
|
||||
},
|
||||
{
|
||||
"index": 38,
|
||||
"color": "#6608FF",
|
||||
"label": "bathtub;bathing;tub;bath;tub"
|
||||
},
|
||||
{
|
||||
"index": 39,
|
||||
"color": "#FF3D06",
|
||||
"label": "railing;rail"
|
||||
},
|
||||
{
|
||||
"index": 40,
|
||||
"color": "#FFC207",
|
||||
"label": "cushion"
|
||||
},
|
||||
{
|
||||
"index": 41,
|
||||
"color": "#FF7A08",
|
||||
"label": "base;pedestal;stand"
|
||||
},
|
||||
{
|
||||
"index": 42,
|
||||
"color": "#00FF14",
|
||||
"label": "box"
|
||||
},
|
||||
{
|
||||
"index": 43,
|
||||
"color": "#FF0829",
|
||||
"label": "column;pillar"
|
||||
},
|
||||
{
|
||||
"index": 44,
|
||||
"color": "#FF0599",
|
||||
"label": "signboard;sign"
|
||||
},
|
||||
{
|
||||
"index": 45,
|
||||
"color": "#0633FF",
|
||||
"label": "chest;of;drawers;chest;bureau;dresser"
|
||||
},
|
||||
{
|
||||
"index": 46,
|
||||
"color": "#EB0CFF",
|
||||
"label": "counter"
|
||||
},
|
||||
{
|
||||
"index": 47,
|
||||
"color": "#A09614",
|
||||
"label": "sand"
|
||||
},
|
||||
{
|
||||
"index": 48,
|
||||
"color": "#00A3FF",
|
||||
"label": "sink"
|
||||
},
|
||||
{
|
||||
"index": 49,
|
||||
"color": "#8C8C8C",
|
||||
"label": "skyscraper"
|
||||
},
|
||||
{
|
||||
"index": 50,
|
||||
"color": "#FA0A0F",
|
||||
"label": "fireplace;hearth;open;fireplace"
|
||||
},
|
||||
{
|
||||
"index": 51,
|
||||
"color": "#14FF00",
|
||||
"label": "refrigerator;icebox"
|
||||
},
|
||||
{
|
||||
"index": 52,
|
||||
"color": "#1FFF00",
|
||||
"label": "grandstand;covered;stand"
|
||||
},
|
||||
{
|
||||
"index": 53,
|
||||
"color": "#FF1F00",
|
||||
"label": "path"
|
||||
},
|
||||
{
|
||||
"index": 54,
|
||||
"color": "#FFE000",
|
||||
"label": "stairs;steps"
|
||||
},
|
||||
{
|
||||
"index": 55,
|
||||
"color": "#99FF00",
|
||||
"label": "runway"
|
||||
},
|
||||
{
|
||||
"index": 56,
|
||||
"color": "#0000FF",
|
||||
"label": "case;display;case;showcase;vitrine"
|
||||
},
|
||||
{
|
||||
"index": 57,
|
||||
"color": "#FF4700",
|
||||
"label": "pool;table;billiard;table;snooker;table"
|
||||
},
|
||||
{
|
||||
"index": 58,
|
||||
"color": "#00EBFF",
|
||||
"label": "pillow"
|
||||
},
|
||||
{
|
||||
"index": 59,
|
||||
"color": "#00ADFF",
|
||||
"label": "screen;door;screen"
|
||||
},
|
||||
{
|
||||
"index": 60,
|
||||
"color": "#1F00FF",
|
||||
"label": "stairway;staircase"
|
||||
},
|
||||
{
|
||||
"index": 61,
|
||||
"color": "#0BC8C8",
|
||||
"label": "river"
|
||||
},
|
||||
{
|
||||
"index": 62,
|
||||
"color": "#FF5200",
|
||||
"label": "bridge;span"
|
||||
},
|
||||
{
|
||||
"index": 63,
|
||||
"color": "#00FFF5",
|
||||
"label": "bookcase"
|
||||
},
|
||||
{
|
||||
"index": 64,
|
||||
"color": "#003DFF",
|
||||
"label": "blind;screen"
|
||||
},
|
||||
{
|
||||
"index": 65,
|
||||
"color": "#00FF70",
|
||||
"label": "coffee;table;cocktail;table"
|
||||
},
|
||||
{
|
||||
"index": 66,
|
||||
"color": "#00FF85",
|
||||
"label": "toilet;can;commode;crapper;pot;potty;stool;throne"
|
||||
},
|
||||
{
|
||||
"index": 67,
|
||||
"color": "#FF0000",
|
||||
"label": "flower"
|
||||
},
|
||||
{
|
||||
"index": 68,
|
||||
"color": "#FFA300",
|
||||
"label": "book"
|
||||
},
|
||||
{
|
||||
"index": 69,
|
||||
"color": "#FF6600",
|
||||
"label": "hill"
|
||||
},
|
||||
{
|
||||
"index": 70,
|
||||
"color": "#C2FF00",
|
||||
"label": "bench"
|
||||
},
|
||||
{
|
||||
"index": 71,
|
||||
"color": "#008FFF",
|
||||
"label": "countertop"
|
||||
},
|
||||
{
|
||||
"index": 72,
|
||||
"color": "#33FF00",
|
||||
"label": "stove;kitchen;stove;range;kitchen;range;cooking;stove"
|
||||
},
|
||||
{
|
||||
"index": 73,
|
||||
"color": "#0052FF",
|
||||
"label": "palm;palm;tree"
|
||||
},
|
||||
{
|
||||
"index": 74,
|
||||
"color": "#00FF29",
|
||||
"label": "kitchen;island"
|
||||
},
|
||||
{
|
||||
"index": 75,
|
||||
"color": "#00FFAD",
|
||||
"label": "computer;computing;machine;computing;device;data;processor;electronic;computer;information;processing;system"
|
||||
},
|
||||
{
|
||||
"index": 76,
|
||||
"color": "#0A00FF",
|
||||
"label": "swivel;chair"
|
||||
},
|
||||
{
|
||||
"index": 77,
|
||||
"color": "#ADFF00",
|
||||
"label": "boat"
|
||||
},
|
||||
{
|
||||
"index": 78,
|
||||
"color": "#00FF99",
|
||||
"label": "bar"
|
||||
},
|
||||
{
|
||||
"index": 79,
|
||||
"color": "#FF5C00",
|
||||
"label": "arcade;machine"
|
||||
},
|
||||
{
|
||||
"index": 80,
|
||||
"color": "#FF00FF",
|
||||
"label": "hovel;hut;hutch;shack;shanty"
|
||||
},
|
||||
{
|
||||
"index": 81,
|
||||
"color": "#FF00F5",
|
||||
"label": "bus;autobus;coach;charabanc;double-decker;jitney;motorbus;motorcoach;omnibus;passenger;vehicle"
|
||||
},
|
||||
{
|
||||
"index": 82,
|
||||
"color": "#FF0066",
|
||||
"label": "towel"
|
||||
},
|
||||
{
|
||||
"index": 83,
|
||||
"color": "#FFAD00",
|
||||
"label": "light;light;source"
|
||||
},
|
||||
{
|
||||
"index": 84,
|
||||
"color": "#FF0014",
|
||||
"label": "truck;motortruck"
|
||||
},
|
||||
{
|
||||
"index": 85,
|
||||
"color": "#FFB8B8",
|
||||
"label": "tower"
|
||||
},
|
||||
{
|
||||
"index": 86,
|
||||
"color": "#001FFF",
|
||||
"label": "chandelier;pendant;pendent"
|
||||
},
|
||||
{
|
||||
"index": 87,
|
||||
"color": "#00FF3D",
|
||||
"label": "awning;sunshade;sunblind"
|
||||
},
|
||||
{
|
||||
"index": 88,
|
||||
"color": "#0047FF",
|
||||
"label": "streetlight;street;lamp"
|
||||
},
|
||||
{
|
||||
"index": 89,
|
||||
"color": "#FF00CC",
|
||||
"label": "booth;cubicle;stall;kiosk"
|
||||
},
|
||||
{
|
||||
"index": 90,
|
||||
"color": "#00FFC2",
|
||||
"label": "television;television;receiver;television;set;tv;tv;set;idiot;box;boob;tube;telly;goggle;box"
|
||||
},
|
||||
{
|
||||
"index": 91,
|
||||
"color": "#00FF52",
|
||||
"label": "airplane;aeroplane;plane"
|
||||
},
|
||||
{
|
||||
"index": 92,
|
||||
"color": "#000AFF",
|
||||
"label": "dirt;track"
|
||||
},
|
||||
{
|
||||
"index": 93,
|
||||
"color": "#0070FF",
|
||||
"label": "apparel;wearing;apparel;dress;clothes"
|
||||
},
|
||||
{
|
||||
"index": 94,
|
||||
"color": "#3300FF",
|
||||
"label": "pole"
|
||||
},
|
||||
{
|
||||
"index": 95,
|
||||
"color": "#00C2FF",
|
||||
"label": "land;ground;soil"
|
||||
},
|
||||
{
|
||||
"index": 96,
|
||||
"color": "#007AFF",
|
||||
"label": "bannister;banister;balustrade;balusters;handrail"
|
||||
},
|
||||
{
|
||||
"index": 97,
|
||||
"color": "#00FFA3",
|
||||
"label": "escalator;moving;staircase;moving;stairway"
|
||||
},
|
||||
{
|
||||
"index": 98,
|
||||
"color": "#FF9900",
|
||||
"label": "ottoman;pouf;pouffe;puff;hassock"
|
||||
},
|
||||
{
|
||||
"index": 99,
|
||||
"color": "#00FF0A",
|
||||
"label": "bottle"
|
||||
},
|
||||
{
|
||||
"index": 100,
|
||||
"color": "#FF7000",
|
||||
"label": "buffet;counter;sideboard"
|
||||
},
|
||||
{
|
||||
"index": 101,
|
||||
"color": "#8FFF00",
|
||||
"label": "poster;posting;placard;notice;bill;card"
|
||||
},
|
||||
{
|
||||
"index": 102,
|
||||
"color": "#5200FF",
|
||||
"label": "stage"
|
||||
},
|
||||
{
|
||||
"index": 103,
|
||||
"color": "#A3FF00",
|
||||
"label": "van"
|
||||
},
|
||||
{
|
||||
"index": 104,
|
||||
"color": "#FFEB00",
|
||||
"label": "ship"
|
||||
},
|
||||
{
|
||||
"index": 105,
|
||||
"color": "#08B8AA",
|
||||
"label": "fountain"
|
||||
},
|
||||
{
|
||||
"index": 106,
|
||||
"color": "#8500FF",
|
||||
"label": "conveyer;belt;conveyor;belt;conveyer;conveyor;transporter"
|
||||
},
|
||||
{
|
||||
"index": 107,
|
||||
"color": "#00FF5C",
|
||||
"label": "canopy"
|
||||
},
|
||||
{
|
||||
"index": 108,
|
||||
"color": "#B800FF",
|
||||
"label": "washer;automatic;washer;washing;machine"
|
||||
},
|
||||
{
|
||||
"index": 109,
|
||||
"color": "#FF001F",
|
||||
"label": "plaything;toy"
|
||||
},
|
||||
{
|
||||
"index": 110,
|
||||
"color": "#00B8FF",
|
||||
"label": "swimming;pool;swimming;bath;natatorium"
|
||||
},
|
||||
{
|
||||
"index": 111,
|
||||
"color": "#00D6FF",
|
||||
"label": "stool"
|
||||
},
|
||||
{
|
||||
"index": 112,
|
||||
"color": "#FF0070",
|
||||
"label": "barrel;cask"
|
||||
},
|
||||
{
|
||||
"index": 113,
|
||||
"color": "#5CFF00",
|
||||
"label": "basket;handbasket"
|
||||
},
|
||||
{
|
||||
"index": 114,
|
||||
"color": "#00E0FF",
|
||||
"label": "waterfall;falls"
|
||||
},
|
||||
{
|
||||
"index": 115,
|
||||
"color": "#70E0FF",
|
||||
"label": "tent;collapsible;shelter"
|
||||
},
|
||||
{
|
||||
"index": 116,
|
||||
"color": "#46B8A0",
|
||||
"label": "bag"
|
||||
},
|
||||
{
|
||||
"index": 117,
|
||||
"color": "#A300FF",
|
||||
"label": "minibike;motorbike"
|
||||
},
|
||||
{
|
||||
"index": 118,
|
||||
"color": "#9900FF",
|
||||
"label": "cradle"
|
||||
},
|
||||
{
|
||||
"index": 119,
|
||||
"color": "#47FF00",
|
||||
"label": "oven"
|
||||
},
|
||||
{
|
||||
"index": 120,
|
||||
"color": "#FF00A3",
|
||||
"label": "ball"
|
||||
},
|
||||
{
|
||||
"index": 121,
|
||||
"color": "#FFCC00",
|
||||
"label": "food;solid;food"
|
||||
},
|
||||
{
|
||||
"index": 122,
|
||||
"color": "#FF008F",
|
||||
"label": "step;stair"
|
||||
},
|
||||
{
|
||||
"index": 123,
|
||||
"color": "#00FFEB",
|
||||
"label": "tank;storage;tank"
|
||||
},
|
||||
{
|
||||
"index": 124,
|
||||
"color": "#85FF00",
|
||||
"label": "trade;name;brand;name;brand;marque"
|
||||
},
|
||||
{
|
||||
"index": 125,
|
||||
"color": "#FF00EB",
|
||||
"label": "microwave;microwave;oven"
|
||||
},
|
||||
{
|
||||
"index": 126,
|
||||
"color": "#F500FF",
|
||||
"label": "pot;flowerpot"
|
||||
},
|
||||
{
|
||||
"index": 127,
|
||||
"color": "#FF007A",
|
||||
"label": "animal;animate;being;beast;brute;creature;fauna"
|
||||
},
|
||||
{
|
||||
"index": 128,
|
||||
"color": "#FFF500",
|
||||
"label": "bicycle;bike;wheel;cycle"
|
||||
},
|
||||
{
|
||||
"index": 129,
|
||||
"color": "#0ABED4",
|
||||
"label": "lake"
|
||||
},
|
||||
{
|
||||
"index": 130,
|
||||
"color": "#D6FF00",
|
||||
"label": "dishwasher;dish;washer;dishwashing;machine"
|
||||
},
|
||||
{
|
||||
"index": 131,
|
||||
"color": "#00CCFF",
|
||||
"label": "screen;silver;screen;projection;screen"
|
||||
},
|
||||
{
|
||||
"index": 132,
|
||||
"color": "#1400FF",
|
||||
"label": "blanket;cover"
|
||||
},
|
||||
{
|
||||
"index": 133,
|
||||
"color": "#FFFF00",
|
||||
"label": "sculpture"
|
||||
},
|
||||
{
|
||||
"index": 134,
|
||||
"color": "#0099FF",
|
||||
"label": "hood;exhaust;hood"
|
||||
},
|
||||
{
|
||||
"index": 135,
|
||||
"color": "#0029FF",
|
||||
"label": "sconce"
|
||||
},
|
||||
{
|
||||
"index": 136,
|
||||
"color": "#00FFCC",
|
||||
"label": "vase"
|
||||
},
|
||||
{
|
||||
"index": 137,
|
||||
"color": "#2900FF",
|
||||
"label": "traffic;light;traffic;signal;stoplight"
|
||||
},
|
||||
{
|
||||
"index": 138,
|
||||
"color": "#29FF00",
|
||||
"label": "tray"
|
||||
},
|
||||
{
|
||||
"index": 139,
|
||||
"color": "#AD00FF",
|
||||
"label": "ashcan;trash;can;garbage;can;wastebin;ash;bin;ash-bin;ashbin;dustbin;trash;barrel;trash;bin"
|
||||
},
|
||||
{
|
||||
"index": 140,
|
||||
"color": "#00F5FF",
|
||||
"label": "fan"
|
||||
},
|
||||
{
|
||||
"index": 141,
|
||||
"color": "#4700FF",
|
||||
"label": "pier;wharf;wharfage;dock"
|
||||
},
|
||||
{
|
||||
"index": 142,
|
||||
"color": "#7A00FF",
|
||||
"label": "crt;screen"
|
||||
},
|
||||
{
|
||||
"index": 143,
|
||||
"color": "#00FFB8",
|
||||
"label": "plate"
|
||||
},
|
||||
{
|
||||
"index": 144,
|
||||
"color": "#005CFF",
|
||||
"label": "monitor;monitoring;device"
|
||||
},
|
||||
{
|
||||
"index": 145,
|
||||
"color": "#B8FF00",
|
||||
"label": "bulletin;board;notice;board"
|
||||
},
|
||||
{
|
||||
"index": 146,
|
||||
"color": "#0085FF",
|
||||
"label": "shower"
|
||||
},
|
||||
{
|
||||
"index": 147,
|
||||
"color": "#FFD600",
|
||||
"label": "radiator"
|
||||
},
|
||||
{
|
||||
"index": 148,
|
||||
"color": "#19C2C2",
|
||||
"label": "glass;drinking;glass"
|
||||
},
|
||||
{
|
||||
"index": 149,
|
||||
"color": "#66FF00",
|
||||
"label": "clock"
|
||||
},
|
||||
{
|
||||
"index": 150,
|
||||
"color": "#5C00FF",
|
||||
"label": "flag"
|
||||
}
|
||||
]
|
155
candle-examples/examples/segformer/main.rs
Normal file
155
candle-examples/examples/segformer/main.rs
Normal file
@ -0,0 +1,155 @@
|
||||
use candle::Device;
|
||||
use candle::Module;
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::models::segformer::{
|
||||
Config, ImageClassificationModel, SemanticSegmentationModel,
|
||||
};
|
||||
use clap::{Args, Parser, Subcommand};
|
||||
use image::Rgb;
|
||||
use imageproc::integral_image::ArrayData;
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
|
||||
// Top-level command-line arguments shared by both subcommands.
//
// Plain `//` comments are used on purpose in this struct: clap's derive macros
// turn `///` doc comments into help/about text, which would change CLI output.
#[derive(Parser)]
#[clap(about, version, long_about = None)]
struct CliArgs {
    // Forwarded to `candle_examples::device` in `main` to pick the device.
    #[arg(long, help = "use cpu")]
    cpu: bool,
    // Which task to run; see `Commands`.
    #[command(subcommand)]
    command: Commands,
}
|
||||
#[derive(Args, Debug)]
|
||||
struct SegmentationArgs {
|
||||
#[arg(
|
||||
long,
|
||||
help = "name of the huggingface hub model",
|
||||
default_value = "nvidia/segformer-b0-finetuned-ade-512-512"
|
||||
)]
|
||||
model_name: String,
|
||||
#[arg(
|
||||
long,
|
||||
help = "path to the label file in json format",
|
||||
default_value = "candle-examples/examples/segformer/assets/labels.json"
|
||||
)]
|
||||
label_path: PathBuf,
|
||||
#[arg(long, help = "path to for the output mask image")]
|
||||
output_path: PathBuf,
|
||||
#[arg(help = "path to image as input")]
|
||||
image: PathBuf,
|
||||
}
|
||||
|
||||
// Arguments of the `classify` subcommand.
//
// Plain `//` comments are used on purpose in this struct: clap's derive macros
// turn `///` doc comments into help text, which would change CLI output.
#[derive(Args, Debug)]
struct ClassificationArgs {
    // Hub repository the weights (`model.safetensors`) and `config.json` are fetched from.
    #[arg(
        long,
        help = "name of the huggingface hub model",
        default_value = "paolinox/segformer-finetuned-food101"
    )]
    model_name: String,
    // Positional argument: the image to classify.
    #[arg(help = "path to image as input")]
    image: PathBuf,
}
|
||||
|
||||
// The tasks this example can run, selected as a clap subcommand.
//
// Plain `//` comments are used on purpose: clap's derive macros turn `///`
// doc comments on variants into subcommand help text.
#[derive(Subcommand, Debug)]
enum Commands {
    // Semantic segmentation: handled by `segmentation_task`.
    Segment(SegmentationArgs),
    // Image classification: handled by `classification_task`.
    Classify(ClassificationArgs),
}
|
||||
|
||||
fn get_vb_and_config(model_name: String, device: &Device) -> anyhow::Result<(VarBuilder, Config)> {
|
||||
println!("loading model {} via huggingface hub", model_name);
|
||||
let api = hf_hub::api::sync::Api::new()?;
|
||||
let api = api.model(model_name.clone());
|
||||
let model_file = api.get("model.safetensors")?;
|
||||
println!("model {} downloaded and loaded", model_name);
|
||||
let vb =
|
||||
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], candle::DType::F32, device)? };
|
||||
let config = std::fs::read_to_string(api.get("config.json")?)?;
|
||||
let config: Config = serde_json::from_str(&config)?;
|
||||
println!("{:?}", config);
|
||||
Ok((vb, config))
|
||||
}
|
||||
|
||||
/// One entry of the label file read by `segmentation_task`.
///
/// Only the two fields below are deserialized; the bundled `labels.json` also
/// carries a `label` key per entry, which this struct does not declare.
#[derive(Debug, serde::Deserialize)]
struct LabelItem {
    /// Label index as stored in the file (the bundled file starts at 1);
    /// `segmentation_task` subtracts 1 to obtain the predicted class id.
    index: u32,
    /// Color as a hex string such as `#787878`.
    color: String,
}
|
||||
|
||||
fn segmentation_task(args: SegmentationArgs, device: &Device) -> anyhow::Result<()> {
|
||||
let label_file = std::fs::read_to_string(&args.label_path)?;
|
||||
let label_items: Vec<LabelItem> = serde_json::from_str(&label_file)?;
|
||||
let label_colors: HashMap<u32, Rgb<u8>> = label_items
|
||||
.iter()
|
||||
.map(|x| {
|
||||
(x.index - 1, {
|
||||
let color = x.color.trim_start_matches('#');
|
||||
let r = u8::from_str_radix(&color[0..2], 16).unwrap();
|
||||
let g = u8::from_str_radix(&color[2..4], 16).unwrap();
|
||||
let b = u8::from_str_radix(&color[4..6], 16).unwrap();
|
||||
Rgb([r, g, b])
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
let image = candle_examples::imagenet::load_image224(args.image)?
|
||||
.unsqueeze(0)?
|
||||
.to_device(device)?;
|
||||
let (vb, config) = get_vb_and_config(args.model_name, device)?;
|
||||
let num_labels = label_items.len();
|
||||
|
||||
let model = SemanticSegmentationModel::new(&config, num_labels, vb)?;
|
||||
let segmentations = model.forward(&image)?;
|
||||
|
||||
// generate a mask image
|
||||
let mask = &segmentations.squeeze(0)?.argmax(0)?;
|
||||
let (h, w) = mask.dims2()?;
|
||||
let mask = mask.flatten_all()?.to_vec1::<u32>()?;
|
||||
let mask = mask
|
||||
.iter()
|
||||
.flat_map(|x| label_colors[x].data())
|
||||
.collect::<Vec<u8>>();
|
||||
let mask: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
|
||||
image::ImageBuffer::from_raw(w as u32, h as u32, mask).unwrap();
|
||||
// resize
|
||||
let mask = image::DynamicImage::from(mask);
|
||||
let mask = mask.resize_to_fill(
|
||||
w as u32 * 4,
|
||||
h as u32 * 4,
|
||||
image::imageops::FilterType::CatmullRom,
|
||||
);
|
||||
mask.save(args.output_path.clone())?;
|
||||
println!("mask image saved to {:?}", args.output_path);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn classification_task(args: ClassificationArgs, device: &Device) -> anyhow::Result<()> {
|
||||
let image = candle_examples::imagenet::load_image224(args.image)?
|
||||
.unsqueeze(0)?
|
||||
.to_device(device)?;
|
||||
let (vb, config) = get_vb_and_config(args.model_name, device)?;
|
||||
let num_labels = 7;
|
||||
let model = ImageClassificationModel::new(&config, num_labels, vb)?;
|
||||
let classification = model.forward(&image)?;
|
||||
let classification = candle_nn::ops::softmax_last_dim(&classification)?;
|
||||
let classification = classification.squeeze(0)?;
|
||||
println!(
|
||||
"classification logits {:?}",
|
||||
classification.to_vec1::<f32>()?
|
||||
);
|
||||
let label_id = classification.argmax(0)?.to_scalar::<u32>()?;
|
||||
let label_id = format!("{}", label_id);
|
||||
println!("label: {}", config.id2label[&label_id]);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn main() -> anyhow::Result<()> {
|
||||
let args = CliArgs::parse();
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
if let Commands::Segment(args) = args.command {
|
||||
segmentation_task(args, &device)?
|
||||
} else if let Commands::Classify(args) = args.command {
|
||||
classification_task(args, &device)?
|
||||
}
|
||||
Ok(())
|
||||
}
|
@ -42,6 +42,7 @@ pub mod repvgg;
|
||||
pub mod resnet;
|
||||
pub mod rwkv_v5;
|
||||
pub mod rwkv_v6;
|
||||
pub mod segformer;
|
||||
pub mod segment_anything;
|
||||
pub mod stable_diffusion;
|
||||
pub mod stable_lm;
|
||||
|
705
candle-transformers/src/models/segformer.rs
Normal file
705
candle-transformers/src/models/segformer.rs
Normal file
@ -0,0 +1,705 @@
|
||||
use crate::models::with_tracing::{conv2d, linear, Conv2d, Linear};
|
||||
use candle::{Module, ModuleT, Result, Tensor, D};
|
||||
use candle_nn::{conv2d_no_bias, layer_norm, Activation, Conv2dConfig, VarBuilder};
|
||||
use serde::Deserialize;
|
||||
use std::collections::HashMap;
|
||||
|
||||
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/segformer/configuration_segformer.py
|
||||
/// Segformer hyper-parameters, deserialized from a model's `config.json`.
///
/// NOTE(review): the `Vec` fields presumably hold one entry per encoder block
/// (`num_encoder_blocks`); the code consuming them is outside this excerpt, so
/// confirm before relying on it.
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config {
    /// Class id (as a string key) to label name; defaults to an empty map when
    /// the key is absent from the JSON.
    #[serde(default)]
    pub id2label: HashMap<String, String>,
    /// NOTE(review): presumably the number of channels of the input image - confirm.
    pub num_channels: usize,
    pub num_encoder_blocks: usize,
    pub depths: Vec<usize>,
    /// NOTE(review): presumably the sequence reduction ratios handed to the
    /// attention layers - confirm.
    pub sr_ratios: Vec<usize>,
    pub hidden_sizes: Vec<usize>,
    pub patch_sizes: Vec<usize>,
    pub strides: Vec<usize>,
    pub num_attention_heads: Vec<usize>,
    pub mlp_ratios: Vec<usize>,
    /// Activation applied inside the Mix-FFN blocks.
    pub hidden_act: candle_nn::Activation,
    /// Epsilon passed to the layer norms built from this configuration.
    pub layer_norm_eps: f64,
    pub decoder_hidden_size: usize,
}
|
||||
|
||||
/// Patch embedding stage: a strided convolution followed by a layer norm.
#[derive(Debug, Clone)]
struct SegformerOverlapPatchEmbeddings {
    /// Convolution with kernel `patch_size`, the given stride and padding `patch_size / 2`.
    projection: Conv2d,
    /// Layer norm applied over the channel dimension of the projected patches.
    layer_norm: candle_nn::LayerNorm,
}
|
||||
|
||||
impl SegformerOverlapPatchEmbeddings {
|
||||
fn new(
|
||||
config: &Config,
|
||||
patch_size: usize,
|
||||
stride: usize,
|
||||
num_channels: usize,
|
||||
hidden_size: usize,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let projection = conv2d(
|
||||
num_channels,
|
||||
hidden_size,
|
||||
patch_size,
|
||||
Conv2dConfig {
|
||||
stride,
|
||||
padding: patch_size / 2,
|
||||
..Default::default()
|
||||
},
|
||||
vb.pp("proj"),
|
||||
)?;
|
||||
let layer_norm =
|
||||
candle_nn::layer_norm(hidden_size, config.layer_norm_eps, vb.pp("layer_norm"))?;
|
||||
Ok(Self {
|
||||
projection,
|
||||
layer_norm,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for SegformerOverlapPatchEmbeddings {
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let embeddings = self.projection.forward(x)?;
|
||||
let shape = embeddings.shape();
|
||||
// [B, C, H, W] -> [B, H * W, C]
|
||||
let embeddings = embeddings.flatten_from(2)?.transpose(1, 2)?;
|
||||
let embeddings = self.layer_norm.forward(&embeddings)?;
|
||||
// [B, H * W, C] -> [B, C, H, W]
|
||||
let embeddings = embeddings.transpose(1, 2)?.reshape(shape)?;
|
||||
Ok(embeddings)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct SegformerEfficientSelfAttention {
|
||||
num_attention_heads: usize,
|
||||
attention_head_size: usize,
|
||||
query: Linear,
|
||||
key: Linear,
|
||||
value: Linear,
|
||||
sr: Option<Conv2d>,
|
||||
layer_norm: Option<layer_norm::LayerNorm>,
|
||||
}
|
||||
|
||||
impl SegformerEfficientSelfAttention {
|
||||
fn new(
|
||||
config: &Config,
|
||||
hidden_size: usize,
|
||||
num_attention_heads: usize,
|
||||
sequence_reduction_ratio: usize,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
if hidden_size % num_attention_heads != 0 {
|
||||
candle::bail!(
|
||||
"The hidden size {} is not a multiple of the number of attention heads {}",
|
||||
hidden_size,
|
||||
num_attention_heads
|
||||
)
|
||||
}
|
||||
let attention_head_size = hidden_size / num_attention_heads;
|
||||
let all_head_size = num_attention_heads * attention_head_size;
|
||||
let query = linear(hidden_size, all_head_size, vb.pp("query"))?;
|
||||
let key = linear(hidden_size, all_head_size, vb.pp("key"))?;
|
||||
let value = linear(hidden_size, all_head_size, vb.pp("value"))?;
|
||||
let (sr, layer_norm) = if sequence_reduction_ratio > 1 {
|
||||
(
|
||||
Some(conv2d(
|
||||
hidden_size,
|
||||
hidden_size,
|
||||
sequence_reduction_ratio,
|
||||
Conv2dConfig {
|
||||
stride: sequence_reduction_ratio,
|
||||
..Default::default()
|
||||
},
|
||||
vb.pp("sr"),
|
||||
)?),
|
||||
Some(candle_nn::layer_norm(
|
||||
hidden_size,
|
||||
config.layer_norm_eps,
|
||||
vb.pp("layer_norm"),
|
||||
)?),
|
||||
)
|
||||
} else {
|
||||
(None, None)
|
||||
};
|
||||
Ok(Self {
|
||||
num_attention_heads,
|
||||
attention_head_size,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
sr,
|
||||
layer_norm,
|
||||
})
|
||||
}
|
||||
|
||||
fn transpose_for_scores(&self, hidden_states: Tensor) -> Result<Tensor> {
|
||||
let (batch, seq_length, _) = hidden_states.shape().dims3()?;
|
||||
let new_shape = &[
|
||||
batch,
|
||||
seq_length,
|
||||
self.num_attention_heads,
|
||||
self.attention_head_size,
|
||||
];
|
||||
let hidden_states = hidden_states.reshape(new_shape)?;
|
||||
let hidden_states = hidden_states.permute((0, 2, 1, 3))?;
|
||||
Ok(hidden_states)
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for SegformerEfficientSelfAttention {
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
// [B, C, H, W] -> [B, H * W, C]
|
||||
let hidden_states = x.flatten_from(2)?.permute((0, 2, 1))?;
|
||||
let query = self
|
||||
.transpose_for_scores(self.query.forward(&hidden_states)?)?
|
||||
.contiguous()?;
|
||||
let hidden_states = if let (Some(sr), Some(layer_norm)) = (&self.sr, &self.layer_norm) {
|
||||
let hidden_states = sr.forward(x)?;
|
||||
// [B, C, H, W] -> [B, H * W, C]
|
||||
let hidden_states = hidden_states.flatten_from(2)?.permute((0, 2, 1))?;
|
||||
layer_norm.forward(&hidden_states)?
|
||||
} else {
|
||||
// already [B, H * W, C]
|
||||
hidden_states
|
||||
};
|
||||
// standard self-attention
|
||||
let key = self
|
||||
.transpose_for_scores(self.key.forward(&hidden_states)?)?
|
||||
.contiguous()?;
|
||||
let value = self
|
||||
.transpose_for_scores(self.value.forward(&hidden_states)?)?
|
||||
.contiguous()?;
|
||||
let attention_scores =
|
||||
(query.matmul(&key.t()?)? / f64::sqrt(self.attention_head_size as f64))?;
|
||||
let attention_scores = candle_nn::ops::softmax_last_dim(&attention_scores)?;
|
||||
let result = attention_scores.matmul(&value)?;
|
||||
let result = result.permute((0, 2, 1, 3))?.contiguous()?;
|
||||
result.flatten_from(D::Minus2)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct SegformerSelfOutput {
|
||||
dense: Linear,
|
||||
}
|
||||
|
||||
impl SegformerSelfOutput {
|
||||
fn new(hidden_size: usize, vb: VarBuilder) -> Result<Self> {
|
||||
let dense = linear(hidden_size, hidden_size, vb.pp("dense"))?;
|
||||
Ok(Self { dense })
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for SegformerSelfOutput {
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
self.dense.forward(x)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct SegformerAttention {
|
||||
attention: SegformerEfficientSelfAttention,
|
||||
output: SegformerSelfOutput,
|
||||
}
|
||||
|
||||
impl SegformerAttention {
|
||||
fn new(
|
||||
config: &Config,
|
||||
hidden_size: usize,
|
||||
num_attention_heads: usize,
|
||||
sequence_reduction_ratio: usize,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let attention = SegformerEfficientSelfAttention::new(
|
||||
config,
|
||||
hidden_size,
|
||||
num_attention_heads,
|
||||
sequence_reduction_ratio,
|
||||
vb.pp("self"),
|
||||
)?;
|
||||
let output = SegformerSelfOutput::new(hidden_size, vb.pp("output"))?;
|
||||
Ok(Self { attention, output })
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for SegformerAttention {
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let attention_output = self.attention.forward(x)?;
|
||||
self.output.forward(&attention_output)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct SegformerDWConv {
|
||||
dw_conv: Conv2d,
|
||||
}
|
||||
|
||||
impl SegformerDWConv {
|
||||
fn new(dim: usize, vb: VarBuilder) -> Result<Self> {
|
||||
let dw_conv = conv2d(
|
||||
dim,
|
||||
dim,
|
||||
3,
|
||||
Conv2dConfig {
|
||||
stride: 1,
|
||||
padding: 1,
|
||||
groups: dim,
|
||||
..Default::default()
|
||||
},
|
||||
vb.pp("dwconv"),
|
||||
)?;
|
||||
Ok(Self { dw_conv })
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for SegformerDWConv {
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
self.dw_conv.forward(x)
|
||||
}
|
||||
}
|
||||
|
||||
/// Feed-forward block: linear, grouped 3x3 convolution, activation, linear.
#[derive(Debug, Clone)]
struct SegformerMixFFN {
    /// Maps `in_features` to `hidden_features`.
    dense1: Linear,
    /// Convolution applied on the feature map between the two linear layers.
    dw_conv: SegformerDWConv,
    /// Activation taken from `Config::hidden_act`.
    act: Activation,
    /// Maps `hidden_features` to `out_features`.
    dense2: Linear,
}
|
||||
|
||||
impl SegformerMixFFN {
|
||||
fn new(
|
||||
config: &Config,
|
||||
in_features: usize,
|
||||
hidden_features: usize,
|
||||
out_features: usize,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let dense1 = linear(in_features, hidden_features, vb.pp("dense1"))?;
|
||||
let dw_conv = SegformerDWConv::new(hidden_features, vb.pp("dwconv"))?;
|
||||
let act = config.hidden_act;
|
||||
let dense2 = linear(hidden_features, out_features, vb.pp("dense2"))?;
|
||||
Ok(Self {
|
||||
dense1,
|
||||
dw_conv,
|
||||
act,
|
||||
dense2,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for SegformerMixFFN {
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let (batch, _, height, width) = x.shape().dims4()?;
|
||||
let hidden_states = self
|
||||
.dense1
|
||||
.forward(&x.flatten_from(2)?.permute((0, 2, 1))?)?;
|
||||
let channels = hidden_states.dim(2)?;
|
||||
let hidden_states = self.dw_conv.forward(
|
||||
&hidden_states
|
||||
.permute((0, 2, 1))?
|
||||
.reshape((batch, channels, height, width))?,
|
||||
)?;
|
||||
let hidden_states = self.act.forward(&hidden_states)?;
|
||||
let hidden_states = self
|
||||
.dense2
|
||||
.forward(&hidden_states.flatten_from(2)?.permute((0, 2, 1))?)?;
|
||||
let channels = hidden_states.dim(2)?;
|
||||
hidden_states
|
||||
.permute((0, 2, 1))?
|
||||
.reshape((batch, channels, height, width))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct SegformerLayer {
|
||||
layer_norm_1: candle_nn::LayerNorm,
|
||||
attention: SegformerAttention,
|
||||
layer_norm_2: candle_nn::LayerNorm,
|
||||
mlp: SegformerMixFFN,
|
||||
}
|
||||
|
||||
impl SegformerLayer {
|
||||
fn new(
|
||||
config: &Config,
|
||||
hidden_size: usize,
|
||||
num_attention_heads: usize,
|
||||
sequence_reduction_ratio: usize,
|
||||
mlp_ratio: usize,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let layer_norm_1 = layer_norm(hidden_size, config.layer_norm_eps, vb.pp("layer_norm_1"))?;
|
||||
let attention = SegformerAttention::new(
|
||||
config,
|
||||
hidden_size,
|
||||
num_attention_heads,
|
||||
sequence_reduction_ratio,
|
||||
vb.pp("attention"),
|
||||
)?;
|
||||
let layer_norm_2 = layer_norm(hidden_size, config.layer_norm_eps, vb.pp("layer_norm_2"))?;
|
||||
let mlp = SegformerMixFFN::new(
|
||||
config,
|
||||
hidden_size,
|
||||
hidden_size * mlp_ratio,
|
||||
hidden_size,
|
||||
vb.pp("mlp"),
|
||||
)?;
|
||||
Ok(Self {
|
||||
layer_norm_1,
|
||||
attention,
|
||||
layer_norm_2,
|
||||
mlp,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for SegformerLayer {
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let shape = x.shape().dims4()?;
|
||||
// [B, C, H, W] -> [B, H * W, C]
|
||||
let hidden_states = x.flatten_from(2)?.permute((0, 2, 1))?;
|
||||
let layer_norm_output = self.layer_norm_1.forward(&hidden_states)?;
|
||||
let layer_norm_output = layer_norm_output.permute((0, 2, 1))?.reshape(shape)?;
|
||||
// attention takes in [B, C, H, W] in order to properly do conv2d (and output [B, H * W, C])
|
||||
let attention_output = self.attention.forward(&layer_norm_output)?;
|
||||
let hidden_states = (attention_output + hidden_states)?;
|
||||
let layer_norm_output = self.layer_norm_2.forward(&hidden_states)?;
|
||||
let mlp_output = self
|
||||
.mlp
|
||||
.forward(&layer_norm_output.permute((0, 2, 1))?.reshape(shape)?)?;
|
||||
hidden_states.permute((0, 2, 1))?.reshape(shape)? + mlp_output
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct SegformerEncoder {
|
||||
/// config file
|
||||
config: Config,
|
||||
/// a list of embeddings
|
||||
patch_embeddings: Vec<SegformerOverlapPatchEmbeddings>,
|
||||
/// a list of attention blocks, each consisting of layers
|
||||
blocks: Vec<Vec<SegformerLayer>>,
|
||||
/// a final list of layer norms
|
||||
layer_norms: Vec<candle_nn::LayerNorm>,
|
||||
}
|
||||
|
||||
impl SegformerEncoder {
|
||||
fn new(config: Config, vb: VarBuilder) -> Result<Self> {
|
||||
let mut patch_embeddings = Vec::with_capacity(config.num_encoder_blocks);
|
||||
let mut blocks = Vec::with_capacity(config.num_encoder_blocks);
|
||||
let mut layer_norms = Vec::with_capacity(config.num_encoder_blocks);
|
||||
for i in 0..config.num_encoder_blocks {
|
||||
let patch_size = config.patch_sizes[i];
|
||||
let stride = config.strides[i];
|
||||
let hidden_size = config.hidden_sizes[i];
|
||||
let num_channels = if i == 0 {
|
||||
config.num_channels
|
||||
} else {
|
||||
config.hidden_sizes[i - 1]
|
||||
};
|
||||
patch_embeddings.push(SegformerOverlapPatchEmbeddings::new(
|
||||
&config,
|
||||
patch_size,
|
||||
stride,
|
||||
num_channels,
|
||||
hidden_size,
|
||||
vb.pp(&format!("patch_embeddings.{}", i)),
|
||||
)?);
|
||||
let mut layers = Vec::with_capacity(config.depths[i]);
|
||||
for j in 0..config.depths[i] {
|
||||
let sequence_reduction_ratio = config.sr_ratios[i];
|
||||
let num_attention_heads = config.num_attention_heads[i];
|
||||
let mlp_ratio = config.mlp_ratios[i];
|
||||
layers.push(SegformerLayer::new(
|
||||
&config,
|
||||
hidden_size,
|
||||
num_attention_heads,
|
||||
sequence_reduction_ratio,
|
||||
mlp_ratio,
|
||||
vb.pp(&format!("block.{}.{}", i, j)),
|
||||
)?);
|
||||
}
|
||||
blocks.push(layers);
|
||||
layer_norms.push(layer_norm(
|
||||
hidden_size,
|
||||
config.layer_norm_eps,
|
||||
vb.pp(&format!("layer_norm.{}", i)),
|
||||
)?);
|
||||
}
|
||||
Ok(Self {
|
||||
config,
|
||||
patch_embeddings,
|
||||
blocks,
|
||||
layer_norms,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl ModuleWithHiddenStates for SegformerEncoder {
|
||||
fn forward(&self, x: &Tensor) -> Result<Vec<Tensor>> {
|
||||
let mut all_hidden_states = Vec::with_capacity(self.config.num_encoder_blocks);
|
||||
let mut hidden_states = x.clone();
|
||||
for i in 0..self.config.num_encoder_blocks {
|
||||
hidden_states = self.patch_embeddings[i].forward(&hidden_states)?;
|
||||
for layer in &self.blocks[i] {
|
||||
hidden_states = layer.forward(&hidden_states)?;
|
||||
}
|
||||
let shape = hidden_states.shape().dims4()?;
|
||||
hidden_states =
|
||||
self.layer_norms[i].forward(&hidden_states.flatten_from(2)?.permute((0, 2, 1))?)?;
|
||||
hidden_states = hidden_states.permute((0, 2, 1))?.reshape(shape)?;
|
||||
all_hidden_states.push(hidden_states.clone());
|
||||
}
|
||||
Ok(all_hidden_states)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct SegformerModel {
|
||||
encoder: SegformerEncoder,
|
||||
}
|
||||
|
||||
impl SegformerModel {
|
||||
fn new(config: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let encoder = SegformerEncoder::new(config.clone(), vb.pp("encoder"))?;
|
||||
Ok(Self { encoder })
|
||||
}
|
||||
}
|
||||
|
||||
impl ModuleWithHiddenStates for SegformerModel {
|
||||
fn forward(&self, x: &Tensor) -> Result<Vec<Tensor>> {
|
||||
self.encoder.forward(x)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct SegformerMLP {
|
||||
proj: Linear,
|
||||
}
|
||||
|
||||
impl SegformerMLP {
|
||||
fn new(config: &Config, input_dim: usize, vb: VarBuilder) -> Result<Self> {
|
||||
let proj = linear(input_dim, config.decoder_hidden_size, vb.pp("proj"))?;
|
||||
Ok(Self { proj })
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for SegformerMLP {
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
self.proj.forward(x)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct SegformerDecodeHead {
|
||||
linear_c: Vec<SegformerMLP>,
|
||||
linear_fuse: candle_nn::Conv2d,
|
||||
batch_norm: candle_nn::BatchNorm,
|
||||
classifier: candle_nn::Conv2d,
|
||||
}
|
||||
|
||||
impl SegformerDecodeHead {
|
||||
fn new(config: &Config, num_labels: usize, vb: VarBuilder) -> Result<Self> {
|
||||
let mut linear_c = Vec::with_capacity(config.num_encoder_blocks);
|
||||
for i in 0..config.num_encoder_blocks {
|
||||
let hidden_size = config.hidden_sizes[i];
|
||||
linear_c.push(SegformerMLP::new(
|
||||
config,
|
||||
hidden_size,
|
||||
vb.pp(&format!("linear_c.{}", i)),
|
||||
)?);
|
||||
}
|
||||
let linear_fuse = conv2d_no_bias(
|
||||
config.decoder_hidden_size * config.num_encoder_blocks,
|
||||
config.decoder_hidden_size,
|
||||
1,
|
||||
Conv2dConfig::default(),
|
||||
vb.pp("linear_fuse"),
|
||||
)?;
|
||||
let batch_norm = candle_nn::batch_norm(
|
||||
config.decoder_hidden_size,
|
||||
config.layer_norm_eps,
|
||||
vb.pp("batch_norm"),
|
||||
)?;
|
||||
let classifier = conv2d_no_bias(
|
||||
config.decoder_hidden_size,
|
||||
num_labels,
|
||||
1,
|
||||
Conv2dConfig::default(),
|
||||
vb.pp("classifier"),
|
||||
)?;
|
||||
Ok(Self {
|
||||
linear_c,
|
||||
linear_fuse,
|
||||
batch_norm,
|
||||
classifier,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, encoder_hidden_states: &[Tensor]) -> Result<Tensor> {
|
||||
if encoder_hidden_states.len() != self.linear_c.len() {
|
||||
candle::bail!(
|
||||
"The number of encoder hidden states {} is not equal to the number of linear layers {}",
|
||||
encoder_hidden_states.len(),
|
||||
self.linear_c.len()
|
||||
)
|
||||
}
|
||||
// most fine layer
|
||||
let (_, _, upsample_height, upsample_width) = encoder_hidden_states[0].shape().dims4()?;
|
||||
let mut hidden_states = Vec::with_capacity(self.linear_c.len());
|
||||
for (hidden_state, mlp) in encoder_hidden_states.iter().zip(&self.linear_c) {
|
||||
let (batch, _, height, width) = hidden_state.shape().dims4()?;
|
||||
let hidden_state = mlp.forward(&hidden_state.flatten_from(2)?.permute((0, 2, 1))?)?;
|
||||
let hidden_state = hidden_state.permute((0, 2, 1))?.reshape((
|
||||
batch,
|
||||
hidden_state.dim(2)?,
|
||||
height,
|
||||
width,
|
||||
))?;
|
||||
let hidden_state = hidden_state.upsample_nearest2d(upsample_height, upsample_width)?;
|
||||
hidden_states.push(hidden_state);
|
||||
}
|
||||
hidden_states.reverse();
|
||||
let hidden_states = Tensor::cat(&hidden_states, 1)?;
|
||||
let hidden_states = self.linear_fuse.forward(&hidden_states)?;
|
||||
let hidden_states = self.batch_norm.forward_t(&hidden_states, false)?;
|
||||
let hidden_states = hidden_states.relu()?;
|
||||
self.classifier.forward(&hidden_states)
|
||||
}
|
||||
}
|
||||
|
||||
trait ModuleWithHiddenStates {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Vec<Tensor>>;
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SemanticSegmentationModel {
|
||||
segformer: SegformerModel,
|
||||
decode_head: SegformerDecodeHead,
|
||||
}
|
||||
|
||||
impl SemanticSegmentationModel {
|
||||
pub fn new(config: &Config, num_labels: usize, vb: VarBuilder) -> Result<Self> {
|
||||
let segformer = SegformerModel::new(config, vb.pp("segformer"))?;
|
||||
let decode_head = SegformerDecodeHead::new(config, num_labels, vb.pp("decode_head"))?;
|
||||
Ok(Self {
|
||||
segformer,
|
||||
decode_head,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for SemanticSegmentationModel {
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let hidden_states = self.segformer.forward(x)?;
|
||||
self.decode_head.forward(&hidden_states)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ImageClassificationModel {
|
||||
segformer: SegformerModel,
|
||||
classifier: Linear,
|
||||
}
|
||||
|
||||
impl ImageClassificationModel {
|
||||
pub fn new(config: &Config, num_labels: usize, vb: VarBuilder) -> Result<Self> {
|
||||
let segformer = SegformerModel::new(config, vb.pp("segformer"))?;
|
||||
let classifier = linear(config.decoder_hidden_size, num_labels, vb.pp("classifier"))?;
|
||||
Ok(Self {
|
||||
segformer,
|
||||
classifier,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for ImageClassificationModel {
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let all_hidden_states = self.segformer.forward(x)?;
|
||||
let hidden_states = all_hidden_states.last().unwrap();
|
||||
let hidden_states = hidden_states.flatten_from(2)?.permute((0, 2, 1))?;
|
||||
let mean = hidden_states.mean(1)?;
|
||||
self.classifier.forward(&mean)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {

    use super::*;

    /// A Hugging Face style `config.json` must deserialize into `Config`,
    /// including keys that `Config` may not declare.
    #[test]
    fn test_config_json_load() {
        let raw_json = r#"{
            "architectures": [
              "SegformerForImageClassification"
            ],
            "attention_probs_dropout_prob": 0.0,
            "classifier_dropout_prob": 0.1,
            "decoder_hidden_size": 256,
            "depths": [
              2,
              2,
              2,
              2
            ],
            "downsampling_rates": [
              1,
              4,
              8,
              16
            ],
            "drop_path_rate": 0.1,
            "hidden_act": "gelu",
            "hidden_dropout_prob": 0.0,
            "hidden_sizes": [
              32,
              64,
              160,
              256
            ],
            "image_size": 224,
            "initializer_range": 0.02,
            "layer_norm_eps": 1e-06,
            "mlp_ratios": [
              4,
              4,
              4,
              4
            ],
            "model_type": "segformer",
            "num_attention_heads": [
              1,
              2,
              5,
              8
            ],
            "num_channels": 3,
            "num_encoder_blocks": 4,
            "patch_sizes": [
              7,
              3,
              3,
              3
            ],
            "sr_ratios": [
              8,
              4,
              2,
              1
            ],
            "strides": [
              4,
              2,
              2,
              2
            ],
            "torch_dtype": "float32",
            "transformers_version": "4.12.0.dev0"
          }"#;
        let parsed = serde_json::from_str::<Config>(raw_json).unwrap();
        assert_eq!(vec![4, 2, 2, 2], parsed.strides);
        assert_eq!(1e-6, parsed.layer_norm_eps);
    }
}
|
Reference in New Issue
Block a user