diff --git a/candle-examples/examples/yolo-v3/main.rs b/candle-examples/examples/yolo-v3/main.rs index 5b1937ac..a6574697 100644 --- a/candle-examples/examples/yolo-v3/main.rs +++ b/candle-examples/examples/yolo-v3/main.rs @@ -43,6 +43,7 @@ pub fn report( confidence_threshold: f32, nms_threshold: f32, ) -> Result { + let pred = pred.to_device(&Device::Cpu)?; let (npreds, pred_size) = pred.dims2()?; let nclasses = pred_size - 5; // The bounding boxes grouped by (maximum) class index. diff --git a/candle-examples/examples/yolo-v8/main.rs b/candle-examples/examples/yolo-v8/main.rs index af8cf98a..54414fb5 100644 --- a/candle-examples/examples/yolo-v8/main.rs +++ b/candle-examples/examples/yolo-v8/main.rs @@ -7,7 +7,7 @@ extern crate accelerate_src; mod model; use model::{Multiples, YoloV8, YoloV8Pose}; -use candle::{DType, IndexOp, Result, Tensor}; +use candle::{DType, IndexOp, Result, Tensor, Device}; use candle_nn::{Module, VarBuilder}; use candle_transformers::object_detection::{non_maximum_suppression, Bbox, KeyPoint}; use clap::{Parser, ValueEnum}; @@ -61,6 +61,7 @@ pub fn report_detect( nms_threshold: f32, legend_size: u32, ) -> Result { + let pred = pred.to_device(&Device::Cpu)?; let (pred_size, npreds) = pred.dims2()?; let nclasses = pred_size - 4; // The bounding boxes grouped by (maximum) class index. @@ -153,6 +154,7 @@ pub fn report_pose( confidence_threshold: f32, nms_threshold: f32, ) -> Result { + let pred = pred.to_device(&Device::Cpu)?; let (pred_size, npreds) = pred.dims2()?; if pred_size != 17 * 3 + 4 + 1 { candle::bail!("unexpected pred-size {pred_size}");