mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 02:16:37 +00:00
Add an initial Segformer implementation (#1617)
* add segformer * Make the id2label field optional. --------- Co-authored-by: laurent <laurent.mazare@gmail.com>
This commit is contained in:
28
candle-examples/examples/segformer/README.md
Normal file
28
candle-examples/examples/segformer/README.md
Normal file
@ -0,0 +1,28 @@
|
||||
# candle-segformer
|
||||
|
||||
- [HuggingFace Segformer Model Card][segformer]
|
||||
- [`mit-b0` - An encoder only pretrained model][encoder]
|
||||
- [`segformer-b0-finetuned-ade-512-512` - A fine tuned model for segmentation][ade512]
|
||||
|
||||
## How to run the example
|
||||
|
||||
If you want you can use the example images from this [pull request][pr], download them and supply the path to the image as an argument to the example.
|
||||
|
||||
```bash
|
||||
# run the image classification task
|
||||
cargo run --example segformer classify <path-to-image>
|
||||
# run the segmentation task
|
||||
cargo run --example segformer segment <path-to-image>
|
||||
```
|
||||
|
||||
Example output for classification:
|
||||
|
||||
```text
|
||||
classification logits [3.275261e-5, 0.0008562019, 0.0008868563, 0.9977506, 0.0002465068, 0.0002241473, 2.846596e-6]
|
||||
label: hamburger
|
||||
```
|
||||
|
||||
[pr]: https://github.com/huggingface/candle/pull/1617
|
||||
[segformer]: https://huggingface.co/docs/transformers/model_doc/segformer
|
||||
[encoder]: https://huggingface.co/nvidia/mit-b0
|
||||
[ade512]: https://huggingface.co/nvidia/segformer-b0-finetuned-ade-512-512
|
752
candle-examples/examples/segformer/assets/labels.json
Normal file
752
candle-examples/examples/segformer/assets/labels.json
Normal file
@ -0,0 +1,752 @@
|
||||
[
|
||||
{
|
||||
"index": 1,
|
||||
"color": "#787878",
|
||||
"label": "wall"
|
||||
},
|
||||
{
|
||||
"index": 2,
|
||||
"color": "#B47878",
|
||||
"label": "building;edifice"
|
||||
},
|
||||
{
|
||||
"index": 3,
|
||||
"color": "#06E6E6",
|
||||
"label": "sky"
|
||||
},
|
||||
{
|
||||
"index": 4,
|
||||
"color": "#503232",
|
||||
"label": "floor;flooring"
|
||||
},
|
||||
{
|
||||
"index": 5,
|
||||
"color": "#04C803",
|
||||
"label": "tree"
|
||||
},
|
||||
{
|
||||
"index": 6,
|
||||
"color": "#787850",
|
||||
"label": "ceiling"
|
||||
},
|
||||
{
|
||||
"index": 7,
|
||||
"color": "#8C8C8C",
|
||||
"label": "road;route"
|
||||
},
|
||||
{
|
||||
"index": 8,
|
||||
"color": "#CC05FF",
|
||||
"label": "bed"
|
||||
},
|
||||
{
|
||||
"index": 9,
|
||||
"color": "#E6E6E6",
|
||||
"label": "windowpane;window"
|
||||
},
|
||||
{
|
||||
"index": 10,
|
||||
"color": "#04FA07",
|
||||
"label": "grass"
|
||||
},
|
||||
{
|
||||
"index": 11,
|
||||
"color": "#E005FF",
|
||||
"label": "cabinet"
|
||||
},
|
||||
{
|
||||
"index": 12,
|
||||
"color": "#EBFF07",
|
||||
"label": "sidewalk;pavement"
|
||||
},
|
||||
{
|
||||
"index": 13,
|
||||
"color": "#96053D",
|
||||
"label": "person;individual;someone;somebody;mortal;soul"
|
||||
},
|
||||
{
|
||||
"index": 14,
|
||||
"color": "#787846",
|
||||
"label": "earth;ground"
|
||||
},
|
||||
{
|
||||
"index": 15,
|
||||
"color": "#08FF33",
|
||||
"label": "door;double;door"
|
||||
},
|
||||
{
|
||||
"index": 16,
|
||||
"color": "#FF0652",
|
||||
"label": "table"
|
||||
},
|
||||
{
|
||||
"index": 17,
|
||||
"color": "#8FFF8C",
|
||||
"label": "mountain;mount"
|
||||
},
|
||||
{
|
||||
"index": 18,
|
||||
"color": "#CCFF04",
|
||||
"label": "plant;flora;plant;life"
|
||||
},
|
||||
{
|
||||
"index": 19,
|
||||
"color": "#FF3307",
|
||||
"label": "curtain;drape;drapery;mantle;pall"
|
||||
},
|
||||
{
|
||||
"index": 20,
|
||||
"color": "#CC4603",
|
||||
"label": "chair"
|
||||
},
|
||||
{
|
||||
"index": 21,
|
||||
"color": "#0066C8",
|
||||
"label": "car;auto;automobile;machine;motorcar"
|
||||
},
|
||||
{
|
||||
"index": 22,
|
||||
"color": "#3DE6FA",
|
||||
"label": "water"
|
||||
},
|
||||
{
|
||||
"index": 23,
|
||||
"color": "#FF0633",
|
||||
"label": "painting;picture"
|
||||
},
|
||||
{
|
||||
"index": 24,
|
||||
"color": "#0B66FF",
|
||||
"label": "sofa;couch;lounge"
|
||||
},
|
||||
{
|
||||
"index": 25,
|
||||
"color": "#FF0747",
|
||||
"label": "shelf"
|
||||
},
|
||||
{
|
||||
"index": 26,
|
||||
"color": "#FF09E0",
|
||||
"label": "house"
|
||||
},
|
||||
{
|
||||
"index": 27,
|
||||
"color": "#0907E6",
|
||||
"label": "sea"
|
||||
},
|
||||
{
|
||||
"index": 28,
|
||||
"color": "#DCDCDC",
|
||||
"label": "mirror"
|
||||
},
|
||||
{
|
||||
"index": 29,
|
||||
"color": "#FF095C",
|
||||
"label": "rug;carpet;carpeting"
|
||||
},
|
||||
{
|
||||
"index": 30,
|
||||
"color": "#7009FF",
|
||||
"label": "field"
|
||||
},
|
||||
{
|
||||
"index": 31,
|
||||
"color": "#08FFD6",
|
||||
"label": "armchair"
|
||||
},
|
||||
{
|
||||
"index": 32,
|
||||
"color": "#07FFE0",
|
||||
"label": "seat"
|
||||
},
|
||||
{
|
||||
"index": 33,
|
||||
"color": "#FFB806",
|
||||
"label": "fence;fencing"
|
||||
},
|
||||
{
|
||||
"index": 34,
|
||||
"color": "#0AFF47",
|
||||
"label": "desk"
|
||||
},
|
||||
{
|
||||
"index": 35,
|
||||
"color": "#FF290A",
|
||||
"label": "rock;stone"
|
||||
},
|
||||
{
|
||||
"index": 36,
|
||||
"color": "#07FFFF",
|
||||
"label": "wardrobe;closet;press"
|
||||
},
|
||||
{
|
||||
"index": 37,
|
||||
"color": "#E0FF08",
|
||||
"label": "lamp"
|
||||
},
|
||||
{
|
||||
"index": 38,
|
||||
"color": "#6608FF",
|
||||
"label": "bathtub;bathing;tub;bath;tub"
|
||||
},
|
||||
{
|
||||
"index": 39,
|
||||
"color": "#FF3D06",
|
||||
"label": "railing;rail"
|
||||
},
|
||||
{
|
||||
"index": 40,
|
||||
"color": "#FFC207",
|
||||
"label": "cushion"
|
||||
},
|
||||
{
|
||||
"index": 41,
|
||||
"color": "#FF7A08",
|
||||
"label": "base;pedestal;stand"
|
||||
},
|
||||
{
|
||||
"index": 42,
|
||||
"color": "#00FF14",
|
||||
"label": "box"
|
||||
},
|
||||
{
|
||||
"index": 43,
|
||||
"color": "#FF0829",
|
||||
"label": "column;pillar"
|
||||
},
|
||||
{
|
||||
"index": 44,
|
||||
"color": "#FF0599",
|
||||
"label": "signboard;sign"
|
||||
},
|
||||
{
|
||||
"index": 45,
|
||||
"color": "#0633FF",
|
||||
"label": "chest;of;drawers;chest;bureau;dresser"
|
||||
},
|
||||
{
|
||||
"index": 46,
|
||||
"color": "#EB0CFF",
|
||||
"label": "counter"
|
||||
},
|
||||
{
|
||||
"index": 47,
|
||||
"color": "#A09614",
|
||||
"label": "sand"
|
||||
},
|
||||
{
|
||||
"index": 48,
|
||||
"color": "#00A3FF",
|
||||
"label": "sink"
|
||||
},
|
||||
{
|
||||
"index": 49,
|
||||
"color": "#8C8C8C",
|
||||
"label": "skyscraper"
|
||||
},
|
||||
{
|
||||
"index": 50,
|
||||
"color": "#FA0A0F",
|
||||
"label": "fireplace;hearth;open;fireplace"
|
||||
},
|
||||
{
|
||||
"index": 51,
|
||||
"color": "#14FF00",
|
||||
"label": "refrigerator;icebox"
|
||||
},
|
||||
{
|
||||
"index": 52,
|
||||
"color": "#1FFF00",
|
||||
"label": "grandstand;covered;stand"
|
||||
},
|
||||
{
|
||||
"index": 53,
|
||||
"color": "#FF1F00",
|
||||
"label": "path"
|
||||
},
|
||||
{
|
||||
"index": 54,
|
||||
"color": "#FFE000",
|
||||
"label": "stairs;steps"
|
||||
},
|
||||
{
|
||||
"index": 55,
|
||||
"color": "#99FF00",
|
||||
"label": "runway"
|
||||
},
|
||||
{
|
||||
"index": 56,
|
||||
"color": "#0000FF",
|
||||
"label": "case;display;case;showcase;vitrine"
|
||||
},
|
||||
{
|
||||
"index": 57,
|
||||
"color": "#FF4700",
|
||||
"label": "pool;table;billiard;table;snooker;table"
|
||||
},
|
||||
{
|
||||
"index": 58,
|
||||
"color": "#00EBFF",
|
||||
"label": "pillow"
|
||||
},
|
||||
{
|
||||
"index": 59,
|
||||
"color": "#00ADFF",
|
||||
"label": "screen;door;screen"
|
||||
},
|
||||
{
|
||||
"index": 60,
|
||||
"color": "#1F00FF",
|
||||
"label": "stairway;staircase"
|
||||
},
|
||||
{
|
||||
"index": 61,
|
||||
"color": "#0BC8C8",
|
||||
"label": "river"
|
||||
},
|
||||
{
|
||||
"index": 62,
|
||||
"color": "#FF5200",
|
||||
"label": "bridge;span"
|
||||
},
|
||||
{
|
||||
"index": 63,
|
||||
"color": "#00FFF5",
|
||||
"label": "bookcase"
|
||||
},
|
||||
{
|
||||
"index": 64,
|
||||
"color": "#003DFF",
|
||||
"label": "blind;screen"
|
||||
},
|
||||
{
|
||||
"index": 65,
|
||||
"color": "#00FF70",
|
||||
"label": "coffee;table;cocktail;table"
|
||||
},
|
||||
{
|
||||
"index": 66,
|
||||
"color": "#00FF85",
|
||||
"label": "toilet;can;commode;crapper;pot;potty;stool;throne"
|
||||
},
|
||||
{
|
||||
"index": 67,
|
||||
"color": "#FF0000",
|
||||
"label": "flower"
|
||||
},
|
||||
{
|
||||
"index": 68,
|
||||
"color": "#FFA300",
|
||||
"label": "book"
|
||||
},
|
||||
{
|
||||
"index": 69,
|
||||
"color": "#FF6600",
|
||||
"label": "hill"
|
||||
},
|
||||
{
|
||||
"index": 70,
|
||||
"color": "#C2FF00",
|
||||
"label": "bench"
|
||||
},
|
||||
{
|
||||
"index": 71,
|
||||
"color": "#008FFF",
|
||||
"label": "countertop"
|
||||
},
|
||||
{
|
||||
"index": 72,
|
||||
"color": "#33FF00",
|
||||
"label": "stove;kitchen;stove;range;kitchen;range;cooking;stove"
|
||||
},
|
||||
{
|
||||
"index": 73,
|
||||
"color": "#0052FF",
|
||||
"label": "palm;palm;tree"
|
||||
},
|
||||
{
|
||||
"index": 74,
|
||||
"color": "#00FF29",
|
||||
"label": "kitchen;island"
|
||||
},
|
||||
{
|
||||
"index": 75,
|
||||
"color": "#00FFAD",
|
||||
"label": "computer;computing;machine;computing;device;data;processor;electronic;computer;information;processing;system"
|
||||
},
|
||||
{
|
||||
"index": 76,
|
||||
"color": "#0A00FF",
|
||||
"label": "swivel;chair"
|
||||
},
|
||||
{
|
||||
"index": 77,
|
||||
"color": "#ADFF00",
|
||||
"label": "boat"
|
||||
},
|
||||
{
|
||||
"index": 78,
|
||||
"color": "#00FF99",
|
||||
"label": "bar"
|
||||
},
|
||||
{
|
||||
"index": 79,
|
||||
"color": "#FF5C00",
|
||||
"label": "arcade;machine"
|
||||
},
|
||||
{
|
||||
"index": 80,
|
||||
"color": "#FF00FF",
|
||||
"label": "hovel;hut;hutch;shack;shanty"
|
||||
},
|
||||
{
|
||||
"index": 81,
|
||||
"color": "#FF00F5",
|
||||
"label": "bus;autobus;coach;charabanc;double-decker;jitney;motorbus;motorcoach;omnibus;passenger;vehicle"
|
||||
},
|
||||
{
|
||||
"index": 82,
|
||||
"color": "#FF0066",
|
||||
"label": "towel"
|
||||
},
|
||||
{
|
||||
"index": 83,
|
||||
"color": "#FFAD00",
|
||||
"label": "light;light;source"
|
||||
},
|
||||
{
|
||||
"index": 84,
|
||||
"color": "#FF0014",
|
||||
"label": "truck;motortruck"
|
||||
},
|
||||
{
|
||||
"index": 85,
|
||||
"color": "#FFB8B8",
|
||||
"label": "tower"
|
||||
},
|
||||
{
|
||||
"index": 86,
|
||||
"color": "#001FFF",
|
||||
"label": "chandelier;pendant;pendent"
|
||||
},
|
||||
{
|
||||
"index": 87,
|
||||
"color": "#00FF3D",
|
||||
"label": "awning;sunshade;sunblind"
|
||||
},
|
||||
{
|
||||
"index": 88,
|
||||
"color": "#0047FF",
|
||||
"label": "streetlight;street;lamp"
|
||||
},
|
||||
{
|
||||
"index": 89,
|
||||
"color": "#FF00CC",
|
||||
"label": "booth;cubicle;stall;kiosk"
|
||||
},
|
||||
{
|
||||
"index": 90,
|
||||
"color": "#00FFC2",
|
||||
"label": "television;television;receiver;television;set;tv;tv;set;idiot;box;boob;tube;telly;goggle;box"
|
||||
},
|
||||
{
|
||||
"index": 91,
|
||||
"color": "#00FF52",
|
||||
"label": "airplane;aeroplane;plane"
|
||||
},
|
||||
{
|
||||
"index": 92,
|
||||
"color": "#000AFF",
|
||||
"label": "dirt;track"
|
||||
},
|
||||
{
|
||||
"index": 93,
|
||||
"color": "#0070FF",
|
||||
"label": "apparel;wearing;apparel;dress;clothes"
|
||||
},
|
||||
{
|
||||
"index": 94,
|
||||
"color": "#3300FF",
|
||||
"label": "pole"
|
||||
},
|
||||
{
|
||||
"index": 95,
|
||||
"color": "#00C2FF",
|
||||
"label": "land;ground;soil"
|
||||
},
|
||||
{
|
||||
"index": 96,
|
||||
"color": "#007AFF",
|
||||
"label": "bannister;banister;balustrade;balusters;handrail"
|
||||
},
|
||||
{
|
||||
"index": 97,
|
||||
"color": "#00FFA3",
|
||||
"label": "escalator;moving;staircase;moving;stairway"
|
||||
},
|
||||
{
|
||||
"index": 98,
|
||||
"color": "#FF9900",
|
||||
"label": "ottoman;pouf;pouffe;puff;hassock"
|
||||
},
|
||||
{
|
||||
"index": 99,
|
||||
"color": "#00FF0A",
|
||||
"label": "bottle"
|
||||
},
|
||||
{
|
||||
"index": 100,
|
||||
"color": "#FF7000",
|
||||
"label": "buffet;counter;sideboard"
|
||||
},
|
||||
{
|
||||
"index": 101,
|
||||
"color": "#8FFF00",
|
||||
"label": "poster;posting;placard;notice;bill;card"
|
||||
},
|
||||
{
|
||||
"index": 102,
|
||||
"color": "#5200FF",
|
||||
"label": "stage"
|
||||
},
|
||||
{
|
||||
"index": 103,
|
||||
"color": "#A3FF00",
|
||||
"label": "van"
|
||||
},
|
||||
{
|
||||
"index": 104,
|
||||
"color": "#FFEB00",
|
||||
"label": "ship"
|
||||
},
|
||||
{
|
||||
"index": 105,
|
||||
"color": "#08B8AA",
|
||||
"label": "fountain"
|
||||
},
|
||||
{
|
||||
"index": 106,
|
||||
"color": "#8500FF",
|
||||
"label": "conveyer;belt;conveyor;belt;conveyer;conveyor;transporter"
|
||||
},
|
||||
{
|
||||
"index": 107,
|
||||
"color": "#00FF5C",
|
||||
"label": "canopy"
|
||||
},
|
||||
{
|
||||
"index": 108,
|
||||
"color": "#B800FF",
|
||||
"label": "washer;automatic;washer;washing;machine"
|
||||
},
|
||||
{
|
||||
"index": 109,
|
||||
"color": "#FF001F",
|
||||
"label": "plaything;toy"
|
||||
},
|
||||
{
|
||||
"index": 110,
|
||||
"color": "#00B8FF",
|
||||
"label": "swimming;pool;swimming;bath;natatorium"
|
||||
},
|
||||
{
|
||||
"index": 111,
|
||||
"color": "#00D6FF",
|
||||
"label": "stool"
|
||||
},
|
||||
{
|
||||
"index": 112,
|
||||
"color": "#FF0070",
|
||||
"label": "barrel;cask"
|
||||
},
|
||||
{
|
||||
"index": 113,
|
||||
"color": "#5CFF00",
|
||||
"label": "basket;handbasket"
|
||||
},
|
||||
{
|
||||
"index": 114,
|
||||
"color": "#00E0FF",
|
||||
"label": "waterfall;falls"
|
||||
},
|
||||
{
|
||||
"index": 115,
|
||||
"color": "#70E0FF",
|
||||
"label": "tent;collapsible;shelter"
|
||||
},
|
||||
{
|
||||
"index": 116,
|
||||
"color": "#46B8A0",
|
||||
"label": "bag"
|
||||
},
|
||||
{
|
||||
"index": 117,
|
||||
"color": "#A300FF",
|
||||
"label": "minibike;motorbike"
|
||||
},
|
||||
{
|
||||
"index": 118,
|
||||
"color": "#9900FF",
|
||||
"label": "cradle"
|
||||
},
|
||||
{
|
||||
"index": 119,
|
||||
"color": "#47FF00",
|
||||
"label": "oven"
|
||||
},
|
||||
{
|
||||
"index": 120,
|
||||
"color": "#FF00A3",
|
||||
"label": "ball"
|
||||
},
|
||||
{
|
||||
"index": 121,
|
||||
"color": "#FFCC00",
|
||||
"label": "food;solid;food"
|
||||
},
|
||||
{
|
||||
"index": 122,
|
||||
"color": "#FF008F",
|
||||
"label": "step;stair"
|
||||
},
|
||||
{
|
||||
"index": 123,
|
||||
"color": "#00FFEB",
|
||||
"label": "tank;storage;tank"
|
||||
},
|
||||
{
|
||||
"index": 124,
|
||||
"color": "#85FF00",
|
||||
"label": "trade;name;brand;name;brand;marque"
|
||||
},
|
||||
{
|
||||
"index": 125,
|
||||
"color": "#FF00EB",
|
||||
"label": "microwave;microwave;oven"
|
||||
},
|
||||
{
|
||||
"index": 126,
|
||||
"color": "#F500FF",
|
||||
"label": "pot;flowerpot"
|
||||
},
|
||||
{
|
||||
"index": 127,
|
||||
"color": "#FF007A",
|
||||
"label": "animal;animate;being;beast;brute;creature;fauna"
|
||||
},
|
||||
{
|
||||
"index": 128,
|
||||
"color": "#FFF500",
|
||||
"label": "bicycle;bike;wheel;cycle"
|
||||
},
|
||||
{
|
||||
"index": 129,
|
||||
"color": "#0ABED4",
|
||||
"label": "lake"
|
||||
},
|
||||
{
|
||||
"index": 130,
|
||||
"color": "#D6FF00",
|
||||
"label": "dishwasher;dish;washer;dishwashing;machine"
|
||||
},
|
||||
{
|
||||
"index": 131,
|
||||
"color": "#00CCFF",
|
||||
"label": "screen;silver;screen;projection;screen"
|
||||
},
|
||||
{
|
||||
"index": 132,
|
||||
"color": "#1400FF",
|
||||
"label": "blanket;cover"
|
||||
},
|
||||
{
|
||||
"index": 133,
|
||||
"color": "#FFFF00",
|
||||
"label": "sculpture"
|
||||
},
|
||||
{
|
||||
"index": 134,
|
||||
"color": "#0099FF",
|
||||
"label": "hood;exhaust;hood"
|
||||
},
|
||||
{
|
||||
"index": 135,
|
||||
"color": "#0029FF",
|
||||
"label": "sconce"
|
||||
},
|
||||
{
|
||||
"index": 136,
|
||||
"color": "#00FFCC",
|
||||
"label": "vase"
|
||||
},
|
||||
{
|
||||
"index": 137,
|
||||
"color": "#2900FF",
|
||||
"label": "traffic;light;traffic;signal;stoplight"
|
||||
},
|
||||
{
|
||||
"index": 138,
|
||||
"color": "#29FF00",
|
||||
"label": "tray"
|
||||
},
|
||||
{
|
||||
"index": 139,
|
||||
"color": "#AD00FF",
|
||||
"label": "ashcan;trash;can;garbage;can;wastebin;ash;bin;ash-bin;ashbin;dustbin;trash;barrel;trash;bin"
|
||||
},
|
||||
{
|
||||
"index": 140,
|
||||
"color": "#00F5FF",
|
||||
"label": "fan"
|
||||
},
|
||||
{
|
||||
"index": 141,
|
||||
"color": "#4700FF",
|
||||
"label": "pier;wharf;wharfage;dock"
|
||||
},
|
||||
{
|
||||
"index": 142,
|
||||
"color": "#7A00FF",
|
||||
"label": "crt;screen"
|
||||
},
|
||||
{
|
||||
"index": 143,
|
||||
"color": "#00FFB8",
|
||||
"label": "plate"
|
||||
},
|
||||
{
|
||||
"index": 144,
|
||||
"color": "#005CFF",
|
||||
"label": "monitor;monitoring;device"
|
||||
},
|
||||
{
|
||||
"index": 145,
|
||||
"color": "#B8FF00",
|
||||
"label": "bulletin;board;notice;board"
|
||||
},
|
||||
{
|
||||
"index": 146,
|
||||
"color": "#0085FF",
|
||||
"label": "shower"
|
||||
},
|
||||
{
|
||||
"index": 147,
|
||||
"color": "#FFD600",
|
||||
"label": "radiator"
|
||||
},
|
||||
{
|
||||
"index": 148,
|
||||
"color": "#19C2C2",
|
||||
"label": "glass;drinking;glass"
|
||||
},
|
||||
{
|
||||
"index": 149,
|
||||
"color": "#66FF00",
|
||||
"label": "clock"
|
||||
},
|
||||
{
|
||||
"index": 150,
|
||||
"color": "#5C00FF",
|
||||
"label": "flag"
|
||||
}
|
||||
]
|
155
candle-examples/examples/segformer/main.rs
Normal file
155
candle-examples/examples/segformer/main.rs
Normal file
@ -0,0 +1,155 @@
|
||||
use candle::Device;
|
||||
use candle::Module;
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::models::segformer::{
|
||||
Config, ImageClassificationModel, SemanticSegmentationModel,
|
||||
};
|
||||
use clap::{Args, Parser, Subcommand};
|
||||
use image::Rgb;
|
||||
use imageproc::integral_image::ArrayData;
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
|
||||
#[derive(Parser)]
|
||||
#[clap(about, version, long_about = None)]
|
||||
struct CliArgs {
|
||||
#[arg(long, help = "use cpu")]
|
||||
cpu: bool,
|
||||
#[command(subcommand)]
|
||||
command: Commands,
|
||||
}
|
||||
#[derive(Args, Debug)]
|
||||
struct SegmentationArgs {
|
||||
#[arg(
|
||||
long,
|
||||
help = "name of the huggingface hub model",
|
||||
default_value = "nvidia/segformer-b0-finetuned-ade-512-512"
|
||||
)]
|
||||
model_name: String,
|
||||
#[arg(
|
||||
long,
|
||||
help = "path to the label file in json format",
|
||||
default_value = "candle-examples/examples/segformer/assets/labels.json"
|
||||
)]
|
||||
label_path: PathBuf,
|
||||
#[arg(long, help = "path to for the output mask image")]
|
||||
output_path: PathBuf,
|
||||
#[arg(help = "path to image as input")]
|
||||
image: PathBuf,
|
||||
}
|
||||
|
||||
#[derive(Args, Debug)]
|
||||
struct ClassificationArgs {
|
||||
#[arg(
|
||||
long,
|
||||
help = "name of the huggingface hub model",
|
||||
default_value = "paolinox/segformer-finetuned-food101"
|
||||
)]
|
||||
model_name: String,
|
||||
#[arg(help = "path to image as input")]
|
||||
image: PathBuf,
|
||||
}
|
||||
|
||||
#[derive(Subcommand, Debug)]
|
||||
enum Commands {
|
||||
Segment(SegmentationArgs),
|
||||
Classify(ClassificationArgs),
|
||||
}
|
||||
|
||||
fn get_vb_and_config(model_name: String, device: &Device) -> anyhow::Result<(VarBuilder, Config)> {
|
||||
println!("loading model {} via huggingface hub", model_name);
|
||||
let api = hf_hub::api::sync::Api::new()?;
|
||||
let api = api.model(model_name.clone());
|
||||
let model_file = api.get("model.safetensors")?;
|
||||
println!("model {} downloaded and loaded", model_name);
|
||||
let vb =
|
||||
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], candle::DType::F32, device)? };
|
||||
let config = std::fs::read_to_string(api.get("config.json")?)?;
|
||||
let config: Config = serde_json::from_str(&config)?;
|
||||
println!("{:?}", config);
|
||||
Ok((vb, config))
|
||||
}
|
||||
|
||||
#[derive(Debug, serde::Deserialize)]
|
||||
struct LabelItem {
|
||||
index: u32,
|
||||
color: String,
|
||||
}
|
||||
|
||||
fn segmentation_task(args: SegmentationArgs, device: &Device) -> anyhow::Result<()> {
|
||||
let label_file = std::fs::read_to_string(&args.label_path)?;
|
||||
let label_items: Vec<LabelItem> = serde_json::from_str(&label_file)?;
|
||||
let label_colors: HashMap<u32, Rgb<u8>> = label_items
|
||||
.iter()
|
||||
.map(|x| {
|
||||
(x.index - 1, {
|
||||
let color = x.color.trim_start_matches('#');
|
||||
let r = u8::from_str_radix(&color[0..2], 16).unwrap();
|
||||
let g = u8::from_str_radix(&color[2..4], 16).unwrap();
|
||||
let b = u8::from_str_radix(&color[4..6], 16).unwrap();
|
||||
Rgb([r, g, b])
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
let image = candle_examples::imagenet::load_image224(args.image)?
|
||||
.unsqueeze(0)?
|
||||
.to_device(device)?;
|
||||
let (vb, config) = get_vb_and_config(args.model_name, device)?;
|
||||
let num_labels = label_items.len();
|
||||
|
||||
let model = SemanticSegmentationModel::new(&config, num_labels, vb)?;
|
||||
let segmentations = model.forward(&image)?;
|
||||
|
||||
// generate a mask image
|
||||
let mask = &segmentations.squeeze(0)?.argmax(0)?;
|
||||
let (h, w) = mask.dims2()?;
|
||||
let mask = mask.flatten_all()?.to_vec1::<u32>()?;
|
||||
let mask = mask
|
||||
.iter()
|
||||
.flat_map(|x| label_colors[x].data())
|
||||
.collect::<Vec<u8>>();
|
||||
let mask: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
|
||||
image::ImageBuffer::from_raw(w as u32, h as u32, mask).unwrap();
|
||||
// resize
|
||||
let mask = image::DynamicImage::from(mask);
|
||||
let mask = mask.resize_to_fill(
|
||||
w as u32 * 4,
|
||||
h as u32 * 4,
|
||||
image::imageops::FilterType::CatmullRom,
|
||||
);
|
||||
mask.save(args.output_path.clone())?;
|
||||
println!("mask image saved to {:?}", args.output_path);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn classification_task(args: ClassificationArgs, device: &Device) -> anyhow::Result<()> {
|
||||
let image = candle_examples::imagenet::load_image224(args.image)?
|
||||
.unsqueeze(0)?
|
||||
.to_device(device)?;
|
||||
let (vb, config) = get_vb_and_config(args.model_name, device)?;
|
||||
let num_labels = 7;
|
||||
let model = ImageClassificationModel::new(&config, num_labels, vb)?;
|
||||
let classification = model.forward(&image)?;
|
||||
let classification = candle_nn::ops::softmax_last_dim(&classification)?;
|
||||
let classification = classification.squeeze(0)?;
|
||||
println!(
|
||||
"classification logits {:?}",
|
||||
classification.to_vec1::<f32>()?
|
||||
);
|
||||
let label_id = classification.argmax(0)?.to_scalar::<u32>()?;
|
||||
let label_id = format!("{}", label_id);
|
||||
println!("label: {}", config.id2label[&label_id]);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn main() -> anyhow::Result<()> {
|
||||
let args = CliArgs::parse();
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
if let Commands::Segment(args) = args.command {
|
||||
segmentation_task(args, &device)?
|
||||
} else if let Commands::Classify(args) = args.command {
|
||||
classification_task(args, &device)?
|
||||
}
|
||||
Ok(())
|
||||
}
|
Reference in New Issue
Block a user