diff --git a/README.md b/README.md
index dbb5d583..ef1e55dd 100644
--- a/README.md
+++ b/README.md
@@ -28,7 +28,10 @@ Check out our [examples](./candle-examples/examples/):
 
 - [StarCoder](./candle-examples/examples/bigcode/): LLM specialized to code generation.
 - [Stable Diffusion](./candle-examples/examples/stable-diffusion/): text to
-  image generative model, yet to be optimized.
+  image generative model.
+- [DINOv2](./candle-examples/examples/dinov2/): computer vision model trained
+  using self-supervision (can be used for imagenet classification, depth
+  evaluation, segmentation).
 
 Run them using the following commands:
 ```
@@ -38,6 +41,7 @@ cargo run --example falcon --release
 cargo run --example bert --release
 cargo run --example bigcode --release
 cargo run --example stable-diffusion --release -- --prompt "a rusty robot holding a fire torch"
+cargo run --example dinov2 --release -- --image path/to/myinput.jpg
 ```
 
 In order to use **CUDA** add `--features cuda` to the example command line. If
@@ -75,6 +79,7 @@ And then head over to
   - LLMs: Llama v1 and v2, Falcon, StarCoder.
   - Whisper (multi-lingual support).
   - Stable Diffusion.
+  - Computer Vision: DINOv2.
 - Serverless (on CPU), small and fast deployments.
 - Quantization support using the llama.cpp quantized types.
 
diff --git a/candle-examples/examples/dinov2/main.rs b/candle-examples/examples/dinov2/main.rs
index 2de28459..44204b28 100644
--- a/candle-examples/examples/dinov2/main.rs
+++ b/candle-examples/examples/dinov2/main.rs
@@ -291,7 +291,7 @@ pub fn vit_small(vb: VarBuilder) -> Result<DinoVisionTransformer> {
 #[derive(Parser)]
 struct Args {
     #[arg(long)]
-    model: String,
+    model: Option<String>,
 
     #[arg(long)]
     image: String,
@@ -309,7 +309,15 @@ pub fn main() -> anyhow::Result<()> {
     let image = candle_examples::load_image224(args.image)?;
     println!("loaded image {image:?}");
 
-    let weights = unsafe { candle::safetensors::MmapedFile::new(args.model)? };
+    let model_file = match args.model {
+        None => {
+            let api = hf_hub::api::sync::Api::new()?;
+            let api = api.model("lmz/candle-dino-v2".into());
+            api.get("dinov2_vits14.safetensors")?
+        }
+        Some(model) => model.into(),
+    };
+    let weights = unsafe { candle::safetensors::MmapedFile::new(model_file)? };
     let weights = weights.deserialize()?;
     let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
     let model = vit_small(vb)?;