mirror of https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Use broadcasted scalars for const tensors.
@@ -99,7 +99,7 @@ impl Tensor {
     pub fn backward(&self) -> Result<GradStore> {
         let sorted_nodes = self.sorted_nodes();
         let mut grads = GradStore::new();
-        grads.insert(self, self.ones_like()?);
+        grads.insert(self, self.ones_like()?.contiguous()?);
         for node in sorted_nodes.iter() {
             if node.is_variable() {
                 continue;
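
Note: `ones_like` now yields a zero-stride broadcast view rather than dense storage (see the `ones_impl` change below), so the seed gradient is made contiguous before the backward pass consumes it. A minimal sketch of the distinction, assuming the crate is imported as `candle` and reusing the `var_from_slice` constructor from the tests further down:

    use candle::{Device, Result, Tensor};

    fn seed_grad_demo() -> Result<()> {
        let data: Vec<_> = (0..4).map(|i| i as f32).collect();
        let t = Tensor::var_from_slice(&data, (2, 2), &Device::Cpu)?;
        let seed = t.ones_like()?;      // broadcast of a single scalar element
        assert!(!seed.is_contiguous()); // zero strides, not dense
        let seed = seed.contiguous()?;  // copy into owned, dense storage
        assert!(seed.is_contiguous());
        Ok(())
    }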
@@ -3,6 +3,8 @@ use crate::{Error, Result};
 #[derive(Clone, PartialEq, Eq)]
 pub struct Shape(Vec<usize>);
 
+pub const SCALAR: Shape = Shape(vec![]);
+
 impl std::fmt::Debug for Shape {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         write!(f, "{:?}", &self.dims())
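
`SCALAR` is a rank-0 shape: no dimensions, exactly one element. The const compiles because an empty `vec![]` expands to `Vec::new()`, which is a `const fn`. A quick sketch of what rank 0 means here, assuming `dims()` returns the dimension slice as in the Debug impl above:

    // Rank-0 shape: `dims()` is empty, yet it describes a single element
    // (the product of an empty list of dimensions is 1).
    let s = crate::shape::SCALAR;
    assert_eq!(s.dims(), &[] as &[usize]);
    assert_eq!(s.dims().iter().product::<usize>(), 1);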
@@ -115,16 +115,14 @@ fn from_storage<S: Into<Shape>>(
 }
 
 impl Tensor {
-    // TODO: Maybe this should be a broadcast rather than actually creating the full tensor.
     fn ones_impl<S: Into<Shape>>(
         shape: S,
         dtype: DType,
         device: &Device,
         is_variable: bool,
     ) -> Result<Self> {
         let shape = shape.into();
-        let storage = device.ones(&shape, dtype)?;
-        Ok(from_storage(storage, shape, None, is_variable))
+        let storage = device.ones(&crate::shape::SCALAR, dtype)?;
+        from_storage(storage, crate::shape::SCALAR, None, is_variable).broadcast_as(shape)
     }
 
     pub fn ones<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
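
This resolves the TODO it deletes: instead of allocating one value per element, the device allocates a single scalar and `broadcast_as` exposes it under the requested shape through zero strides, so the memory cost is constant regardless of shape. A usage sketch, assuming the crate is imported as `candle` and that broadcast views report themselves as non-contiguous (which the `backward` change above relies on):

    use candle::{DType, Device, Result, Tensor};

    fn ones_demo() -> Result<()> {
        let t = Tensor::ones((2, 3), DType::F32, &Device::Cpu)?;
        assert_eq!(t.dims(), &[2, 3]); // logical shape is (2, 3)...
        assert!(!t.is_contiguous());   // ...but only one element is stored
        Ok(())
    }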
@@ -132,6 +130,8 @@ impl Tensor {
     }
 
     pub fn ones_var<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
+        // Maybe we should allocate some actual storage for vars rather than just using a
+        // broadcasted scalar?
         Self::ones_impl(shape, dtype, device, true)
     }
 
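
The new comment flags a real trade-off: an optimizer updates a variable elementwise, while every logical element of a broadcast-scalar var aliases the same storage cell. A standalone illustration of that aliasing, in plain Rust with no candle types:

    // A logical length-4 "tensor" backed by one element via a zero stride.
    let storage = vec![1.0f32];
    let stride = 0usize;
    let read: Vec<f32> = (0..4).map(|i| storage[i * stride]).collect();
    assert_eq!(read, vec![1.0; 4]); // reads broadcast fine...
    // ...but four elementwise writes would all land on storage[0].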
@@ -139,16 +139,14 @@ impl Tensor {
         Tensor::ones(self.shape(), self.dtype(), &self.device())
     }
 
-    // TODO: Maybe this should be a broadcast rather than actually creating the full tensor.
     fn zeros_impl<S: Into<Shape>>(
         shape: S,
         dtype: DType,
         device: &Device,
         is_variable: bool,
     ) -> Result<Self> {
         let shape = shape.into();
-        let storage = device.zeros(&shape, dtype)?;
-        Ok(from_storage(storage, shape, None, is_variable))
+        let storage = device.zeros(&crate::shape::SCALAR, dtype)?;
+        from_storage(storage, crate::shape::SCALAR, None, is_variable).broadcast_as(shape)
     }
 
     pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
@@ -599,8 +597,7 @@ impl Tensor {
         &self.layout
     }
 
-    // TODO: Rename to `stride` once the PR that introduced the layout has been merged.
-    pub fn stride_tmp(&self) -> &[usize] {
+    pub fn stride(&self) -> &[usize] {
         self.layout.stride()
     }
 
@@ -20,7 +20,6 @@ fn matmul_grad(device: &Device) -> Result<()> {
     let x = Tensor::var_from_slice(&data, (2, 2, 3), device)?;
     let data: Vec<_> = (0..12).map(|i| i as f32).collect();
     let y = Tensor::var_from_slice(&data, (2, 3, 2), device)?;
 
     let c = x.matmul(&y)?;
     let grads = c.backward()?;
     let grad_x = grads.get(&x).context("no grad for x")?;
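
For reference, the gradients this test checks follow the standard matmul rules. With c = x ⋅ y and upstream gradient g = dL/dc (here the all-ones seed from `backward()`):

    // `t()` swaps the last two dims; batched shapes shown on the right:
    //   grad_x = g.matmul(&y.t()?)?   // (2, 2, 2) ⋅ (2, 2, 3) -> (2, 2, 3)
    //   grad_y = x.t()?.matmul(&g)?   // (2, 3, 2) ⋅ (2, 2, 2) -> (2, 3, 2)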
@@ -263,12 +263,12 @@ fn matmul(device: &Device) -> Result<()> {
     let a_tt = a.t()?.contiguous()?.t()?;
     assert!(!a_tt.is_contiguous());
     assert_eq!(a.dims(), a_tt.dims());
-    assert_eq!(a_tt.stride_tmp(), &[6, 1, 2]);
+    assert_eq!(a_tt.stride(), &[6, 1, 2]);
 
     let b_tt = b.t()?.contiguous()?.t()?;
     assert!(!b_tt.is_contiguous());
     assert_eq!(b.dims(), b_tt.dims());
-    assert_eq!(b_tt.stride_tmp(), &[6, 1, 3]);
+    assert_eq!(b_tt.stride(), &[6, 1, 3]);
 
     assert_eq!(a_tt.matmul(&b)?.to_vec3::<f32>()?, &expected);
     assert_eq!(a.matmul(&b_tt)?.to_vec3::<f32>()?, &expected);
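
The expected strides fall out of the double transpose. Assuming `a` is (2, 2, 3) and `b` is (2, 3, 2) as in the gradient test above: `a.t()?` is (2, 3, 2), `.contiguous()?` materializes it with dense strides [6, 2, 1], and the final `.t()?` swaps the last two strides, giving shape (2, 2, 3) with strides [6, 1, 2]. Likewise for `b`: dense (2, 2, 3) has strides [6, 3, 1], and transposing yields [6, 1, 3]. Neither result matches the dense strides for its shape, hence the `!is_contiguous()` assertions.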