From ca19a9af6220366bb0beaea1eb7f34a1a8a2e07b Mon Sep 17 00:00:00 2001 From: MilkFather <31627231+MilkFather@users.noreply.github.com> Date: Thu, 23 Nov 2023 15:35:13 +0800 Subject: [PATCH] Fix linspace implementation (#1358) * Fix linspace implementation `steps` should be strictly greater than 1 to make it consistent with the context. * Handle steps == 0 and steps == 1. * Fix rustfmt. --------- Co-authored-by: laurent --- candle-examples/examples/yolo-v8/main.rs | 2 +- .../src/models/stable_diffusion/utils.rs | 17 ++++++++++------- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/candle-examples/examples/yolo-v8/main.rs b/candle-examples/examples/yolo-v8/main.rs index 54414fb5..c65a5ca1 100644 --- a/candle-examples/examples/yolo-v8/main.rs +++ b/candle-examples/examples/yolo-v8/main.rs @@ -7,7 +7,7 @@ extern crate accelerate_src; mod model; use model::{Multiples, YoloV8, YoloV8Pose}; -use candle::{DType, IndexOp, Result, Tensor, Device}; +use candle::{DType, Device, IndexOp, Result, Tensor}; use candle_nn::{Module, VarBuilder}; use candle_transformers::object_detection::{non_maximum_suppression, Bbox, KeyPoint}; use clap::{Parser, ValueEnum}; diff --git a/candle-transformers/src/models/stable_diffusion/utils.rs b/candle-transformers/src/models/stable_diffusion/utils.rs index 0c95cfef..cef06f1c 100644 --- a/candle-transformers/src/models/stable_diffusion/utils.rs +++ b/candle-transformers/src/models/stable_diffusion/utils.rs @@ -1,12 +1,15 @@ use candle::{Device, Result, Tensor}; pub fn linspace(start: f64, stop: f64, steps: usize) -> Result { - if steps < 1 { - candle::bail!("cannot use linspace with steps {steps} <= 1") + if steps == 0 { + Tensor::from_vec(Vec::::new(), steps, &Device::Cpu) + } else if steps == 1 { + Tensor::from_vec(vec![start], steps, &Device::Cpu) + } else { + let delta = (stop - start) / (steps - 1) as f64; + let vs = (0..steps) + .map(|step| start + step as f64 * delta) + .collect::>(); + Tensor::from_vec(vs, steps, &Device::Cpu) } - let delta = (stop - start) / (steps - 1) as f64; - let vs = (0..steps) - .map(|step| start + step as f64 * delta) - .collect::>(); - Tensor::from_vec(vs, steps, &Device::Cpu) }