Mirror of https://github.com/huggingface/candle.git
Use broadcasted scalars for const tensors.
@@ -99,7 +99,7 @@ impl Tensor {
     pub fn backward(&self) -> Result<GradStore> {
         let sorted_nodes = self.sorted_nodes();
         let mut grads = GradStore::new();
-        grads.insert(self, self.ones_like()?);
+        grads.insert(self, self.ones_like()?.contiguous()?);
         for node in sorted_nodes.iter() {
             if node.is_variable() {
                 continue;
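
A note on the `.contiguous()?` added above, with a minimal standalone sketch (hypothetical code, not candle's internals): once `ones_like` returns a scalar broadcast through zero strides, every logical element of the seed gradient aliases the same storage slot, so the seed is materialized into dense storage before gradients are accumulated element by element.

fn main() {
    let storage = [1.0f32];   // one real element backs the whole "ones" tensor
    let stride = [0usize, 0]; // zero strides broadcast it to a logical 2x3 shape
    let offset = |i: usize, j: usize| i * stride[0] + j * stride[1];
    assert_eq!(offset(0, 0), offset(1, 2)); // every coordinate maps to offset 0
    assert_eq!(storage[offset(1, 2)], 1.0);

    // The `.contiguous()` step in spirit: copy into a dense buffer with one slot
    // per element so the backward pass can hold six independent gradient values.
    let seed: Vec<f32> = vec![storage[0]; 2 * 3];
    assert_eq!(seed.len(), 6);
}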
|
@@ -3,6 +3,8 @@ use crate::{Error, Result};
 #[derive(Clone, PartialEq, Eq)]
 pub struct Shape(Vec<usize>);
 
+pub const SCALAR: Shape = Shape(vec![]);
+
 impl std::fmt::Debug for Shape {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         write!(f, "{:?}", &self.dims())
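
The `SCALAR` constant introduced above is a rank-0 shape. A standalone analogue (illustrative only, mirroring the struct shown in the hunk): an empty dimension list describes a tensor holding a single element, and `vec![]` is allowed in a `const` because it expands to `Vec::new()`, which is a `const fn`.

#[derive(Debug, Clone, PartialEq, Eq)]
struct Shape(Vec<usize>);

const SCALAR: Shape = Shape(vec![]);

fn main() {
    assert_eq!(SCALAR.0.len(), 0);                     // rank 0: no dimensions
    assert_eq!(SCALAR.0.iter().product::<usize>(), 1); // a scalar holds one element
    println!("{:?}", SCALAR);                          // derived Debug prints: Shape([])
}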
|
@@ -115,16 +115,14 @@ fn from_storage<S: Into<Shape>>(
 }
 
 impl Tensor {
-    // TODO: Maybe this should be a broadcast rather than actually creating the full tensor.
     fn ones_impl<S: Into<Shape>>(
         shape: S,
         dtype: DType,
         device: &Device,
         is_variable: bool,
     ) -> Result<Self> {
-        let shape = shape.into();
-        let storage = device.ones(&shape, dtype)?;
-        Ok(from_storage(storage, shape, None, is_variable))
+        let storage = device.ones(&crate::shape::SCALAR, dtype)?;
+        from_storage(storage, crate::shape::SCALAR, None, is_variable).broadcast_as(shape)
     }
 
     pub fn ones<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
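
A sketch of the pattern the reworked `ones_impl` follows (and `zeros_impl` below mirrors), using hypothetical standalone types rather than candle's: allocate a single scalar element, then expose it under the requested shape through zero strides, so a constant tensor costs O(1) storage until something actually needs a dense copy.

#[derive(Debug)]
struct ConstTensor {
    storage: Vec<f32>,  // a single element backs the whole constant
    shape: Vec<usize>,
    stride: Vec<usize>, // all zeros: every index reads the same element
}

impl ConstTensor {
    fn ones(shape: &[usize]) -> Self {
        Self {
            storage: vec![1.0],
            shape: shape.to_vec(),
            stride: vec![0; shape.len()],
        }
    }

    // Analogue of calling `.contiguous()` on the broadcast view: materialize the
    // constant into dense storage with one slot per element.
    fn contiguous(&self) -> Vec<f32> {
        let len: usize = self.shape.iter().product();
        vec![self.storage[0]; len]
    }
}

fn main() {
    let t = ConstTensor::ones(&[2, 3]);
    assert_eq!(t.storage.len(), 1);      // O(1) memory for the constant itself
    assert_eq!(t.stride, vec![0, 0]);    // broadcast strides
    assert_eq!(t.contiguous().len(), 6); // dense copy only when required
}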
@@ -132,6 +130,8 @@ impl Tensor {
     }
 
     pub fn ones_var<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
+        // Maybe we should allocate some actual storage for vars rather than just using a
+        // broadcasted scalar?
         Self::ones_impl(shape, dtype, device, true)
     }
 
@@ -139,16 +139,14 @@ impl Tensor {
         Tensor::ones(self.shape(), self.dtype(), &self.device())
     }
 
-    // TODO: Maybe this should be a broadcast rather than actually creating the full tensor.
     fn zeros_impl<S: Into<Shape>>(
         shape: S,
         dtype: DType,
         device: &Device,
         is_variable: bool,
     ) -> Result<Self> {
-        let shape = shape.into();
-        let storage = device.zeros(&shape, dtype)?;
-        Ok(from_storage(storage, shape, None, is_variable))
+        let storage = device.zeros(&crate::shape::SCALAR, dtype)?;
+        from_storage(storage, crate::shape::SCALAR, None, is_variable).broadcast_as(shape)
     }
 
     pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Self> {
@@ -599,8 +597,7 @@ impl Tensor {
         &self.layout
     }
 
-    // TODO: Rename to `stride` once the PR that introduced the layout has been merged.
-    pub fn stride_tmp(&self) -> &[usize] {
+    pub fn stride(&self) -> &[usize] {
         self.layout.stride()
     }
 
|
@@ -20,7 +20,6 @@ fn matmul_grad(device: &Device) -> Result<()> {
     let x = Tensor::var_from_slice(&data, (2, 2, 3), device)?;
     let data: Vec<_> = (0..12).map(|i| i as f32).collect();
     let y = Tensor::var_from_slice(&data, (2, 3, 2), device)?;
-
     let c = x.matmul(&y)?;
     let grads = c.backward()?;
     let grad_x = grads.get(&x).context("no grad for x")?;
|
@@ -263,12 +263,12 @@ fn matmul(device: &Device) -> Result<()> {
     let a_tt = a.t()?.contiguous()?.t()?;
     assert!(!a_tt.is_contiguous());
     assert_eq!(a.dims(), a_tt.dims());
-    assert_eq!(a_tt.stride_tmp(), &[6, 1, 2]);
+    assert_eq!(a_tt.stride(), &[6, 1, 2]);
 
     let b_tt = b.t()?.contiguous()?.t()?;
     assert!(!b_tt.is_contiguous());
     assert_eq!(b.dims(), b_tt.dims());
-    assert_eq!(b_tt.stride_tmp(), &[6, 1, 3]);
+    assert_eq!(b_tt.stride(), &[6, 1, 3]);
 
     assert_eq!(a_tt.matmul(&b)?.to_vec3::<f32>()?, &expected);
     assert_eq!(a.matmul(&b_tt)?.to_vec3::<f32>()?, &expected);
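
The expected strides in the assertions above follow from transposition swapping strides. A small worked check (the shapes are inferred from the asserted strides; the leading batch dimension is assumed to be 2 here, which does not change the computed strides):

fn contiguous_strides(dims: &[usize]) -> Vec<usize> {
    // Row-major strides: the last dimension has stride 1, earlier ones multiply up.
    let mut strides = vec![1; dims.len()];
    for i in (0..dims.len().saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * dims[i + 1];
    }
    strides
}

fn transpose_last_two(mut v: Vec<usize>) -> Vec<usize> {
    let n = v.len();
    v.swap(n - 2, n - 1);
    v
}

fn main() {
    // a.t() has dims (2, 3, 2); contiguous() gives strides [6, 2, 1]; transposing
    // the last two dims back to (2, 2, 3) swaps the last two strides: [6, 1, 2].
    assert_eq!(transpose_last_two(contiguous_strides(&[2, 3, 2])), vec![6, 1, 2]);

    // Same reasoning for b: (2, 2, 3) contiguous is [6, 3, 1]; swapping the last
    // two strides gives [6, 1, 3], and neither view is contiguous any more.
    assert_eq!(transpose_last_two(contiguous_strides(&[2, 2, 3])), vec![6, 1, 3]);
}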