mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Use broadcasted scalars for const tensors.
This commit is contained in:
@ -115,16 +115,14 @@ fn from_storage<S: Into<Shape>>(
|
||||
}
|
||||
|
||||
impl Tensor {
|
||||
// TODO: Maybe this should be a broadcast rather than actually creating the full tensor.
|
||||
fn ones_impl<S: Into<Shape>>(
|
||||
shape: S,
|
||||
dtype: DType,
|
||||
device: &Device,
|
||||
is_variable: bool,
|
||||
) -> Result<Self> {
|
||||
let shape = shape.into();
|
||||
let storage = device.ones(&shape, dtype)?;
|
||||
Ok(from_storage(storage, shape, None, is_variable))
|
||||
let storage = device.ones(&crate::shape::SCALAR, dtype)?;
|
||||
from_storage(storage, crate::shape::SCALAR, None, is_variable).broadcast_as(shape)
|
||||
}
|
||||
|
||||
pub fn ones<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
|
||||
@ -132,6 +130,8 @@ impl Tensor {
|
||||
}
|
||||
|
||||
pub fn ones_var<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
|
||||
// Maybe we should allocate some actual storage for vars rather than just using a
|
||||
// broadcasted scalar?
|
||||
Self::ones_impl(shape, dtype, device, true)
|
||||
}
|
||||
|
||||
@ -139,16 +139,14 @@ impl Tensor {
|
||||
Tensor::ones(self.shape(), self.dtype(), &self.device())
|
||||
}
|
||||
|
||||
// TODO: Maybe this should be a broadcast rather than actually creating the full tensor.
|
||||
fn zeros_impl<S: Into<Shape>>(
|
||||
shape: S,
|
||||
dtype: DType,
|
||||
device: &Device,
|
||||
is_variable: bool,
|
||||
) -> Result<Self> {
|
||||
let shape = shape.into();
|
||||
let storage = device.zeros(&shape, dtype)?;
|
||||
Ok(from_storage(storage, shape, None, is_variable))
|
||||
let storage = device.zeros(&crate::shape::SCALAR, dtype)?;
|
||||
from_storage(storage, crate::shape::SCALAR, None, is_variable).broadcast_as(shape)
|
||||
}
|
||||
|
||||
pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
|
||||
@ -599,8 +597,7 @@ impl Tensor {
|
||||
&self.layout
|
||||
}
|
||||
|
||||
// TODO: Rename to `stride` once the PR that introduced the layout has been merged.
|
||||
pub fn stride_tmp(&self) -> &[usize] {
|
||||
pub fn stride(&self) -> &[usize] {
|
||||
self.layout.stride()
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user