Moving to gemm and adding matmul backprop.

- Tentative `T` operator.
@@ -16,11 +16,11 @@ members = [
 ]
 
 [dependencies]
-ggblas = "0.1.0"
 safetensors = "0.3.1"
 thiserror = "1"
 cudarc = { version = "0.9.9", optional = true }
 candle-kernels = { path = "kernels", optional = true }
+gemm = "0.15.4"
 
 [dev-dependencies]
 anyhow = "1"
@@ -1,6 +1,6 @@
 use crate::storage::{BinaryOp, UnaryOp};
 use crate::{DType, Error, Result, Shape, StridedIndex};
-use ggblas::batched_sgemm;
+use gemm::{gemm, Parallelism};
 
 // TODO: Think about whether we would be better off with a dtype and
 // a buffer as an owned slice of bytes.
@@ -113,28 +113,83 @@ impl CpuStorage {
         lhs_stride: &[usize],
         rhs_stride: &[usize],
     ) -> Result<Self> {
-        println!("rhs {rhs:?}");
-        println!("lhs_stride {lhs_stride:?}");
-        println!("rhs_stride {rhs_stride:?}");
-        // todo!("matmul");
         let a_skip: usize = m * k;
         let b_skip: usize = n * k;
         let c_skip: usize = m * n;
 
-        let mut c = Self::F32(vec![0.0; b * m * n]);
-
-        batched_sgemm(
-            self.as_slice()?,
-            a_skip,
-            rhs.as_slice()?,
-            b_skip,
-            c.as_mut_slice()?,
-            c_skip,
-            m,
-            n,
-            k,
-            b,
-        );
+        let rank = lhs_stride.len();
+        let lhs_cs = lhs_stride[rank - 1];
+        let lhs_rs = lhs_stride[rank - 2];
+        let rhs_cs = rhs_stride[rank - 1];
+        let rhs_rs = rhs_stride[rank - 2];
+
+        if lhs_stride.len() > 2 {
+            let lhs_batch_stride = &lhs_stride[..rank - 2];
+            let rhs_batch_stride = &rhs_stride[..rank - 2];
+
+            if lhs_batch_stride != &[a_skip] || rhs_batch_stride != &[b_skip] {
+                // Temporary error before we support arbitrary striding.
+                return Err(Error::UnexpectedStriding);
+            }
+        }
+
+        let mut dst = vec![0.0; b * m * n];
+
+        let dst_shape: Shape = (m, n).into();
+        let dst_strides = dst_shape.stride_contiguous();
+        let dst_rs = dst_strides[0];
+        let dst_cs = dst_strides[1];
+
+        for step in 0..b {
+            let lhs_p = &self.as_slice::<f32>()?[step * a_skip..];
+            let rhs_p = &rhs.as_slice::<f32>()?[step * b_skip..];
+            let dst_p = &mut dst[step * c_skip..];
+            unsafe {
+                gemm(
+                    // m: usize,
+                    m,
+                    // n: usize,
+                    n,
+                    // k: usize,
+                    k,
+                    // dst: *mut T,
+                    dst_p.as_mut_ptr(),
+                    // dst_cs: isize,
+                    dst_cs as isize,
+                    // dst_rs: isize,
+                    dst_rs as isize,
+                    // read_dst: bool,
+                    false,
+                    // lhs: *const T,
+                    lhs_p.as_ptr(),
+                    // lhs_cs: isize,
+                    lhs_cs as isize,
+                    // lhs_rs: isize,
+                    lhs_rs as isize,
+                    // rhs: *const T,
+                    rhs_p.as_ptr(),
+                    // rhs_cs: isize,
+                    rhs_cs as isize,
+                    // rhs_rs: isize,
+                    rhs_rs as isize,
+                    // alpha: T,
+                    1.0,
+                    // beta: T,
+                    1.0,
+                    // conj_dst: bool,
+                    false,
+                    // conj_lhs: bool,
+                    false,
+                    // conj_rhs: bool,
+                    true,
+                    // parallelism: Parallelism
+                    Parallelism::None,
+                )
+            }
+        }
+
+        let c = Self::F32(dst);
         Ok(c)
     }
 
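Editorial note (not part of the commit): the sketch below is a minimal, self-contained illustration of the `gemm` 0.15 call pattern used above, with the argument order taken from the inline comments in the diff. It multiplies the same row-major 2x2 f32 matrices as the `simple_matmul` test further down; for a row-major matrix the column stride is 1 and the row stride is the number of columns.

// Standalone sketch, assuming gemm = "0.15.4" as added in Cargo.toml above.
use gemm::{gemm, Parallelism};

fn main() {
    let lhs = [1.0f32, 2.0, 3.0, 4.0]; // 2x2, row-major
    let rhs = [1.0f32, 2.0, 3.0, 4.0]; // 2x2, row-major
    let mut dst = [0.0f32; 4]; // 2x2 result buffer
    let (m, n, k) = (2usize, 2usize, 2usize);
    unsafe {
        gemm(
            m,
            n,
            k,
            dst.as_mut_ptr(),
            1isize,       // dst_cs: column stride of dst
            n as isize,   // dst_rs: row stride of dst
            false,        // read_dst: ignore the initial contents of dst
            lhs.as_ptr(),
            1isize,       // lhs_cs
            k as isize,   // lhs_rs
            rhs.as_ptr(),
            1isize,       // rhs_cs
            n as isize,   // rhs_rs
            1.0f32,       // alpha
            1.0f32,       // beta
            false,        // conj_dst
            false,        // conj_lhs
            false,        // conj_rhs (a no-op for f32)
            Parallelism::None,
        )
    }
    // Matches the expected result asserted in the simple_matmul test below.
    assert_eq!(dst, [7.0, 10.0, 15.0, 22.0]);
}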
@@ -175,31 +230,31 @@ mod tests {
     #[test]
     fn simple_matmul() -> Result<()> {
         let data = vec![1.0f32, 2.0, 3.0, 4.0];
-        let a = Tensor::from_slice(&data, (2, 2), Device::Cpu)?;
+        let a = Tensor::from_slice(&data, (2, 2), &Device::Cpu)?;
         let data = vec![1.0f32, 2.0, 3.0, 4.0];
-        let b = Tensor::from_slice(&data, (2, 2), Device::Cpu)?;
+        let b = Tensor::from_slice(&data, (2, 2), &Device::Cpu)?;
 
         let c = a.matmul(&b)?;
         assert_eq!(c.to_vec2::<f32>()?, &[&[7.0f32, 10.0], &[15.0, 22.0]]);
 
         let data = vec![1.0f32, 2.0];
-        let a = Tensor::from_slice(&data, (2, 1), Device::Cpu)?;
+        let a = Tensor::from_slice(&data, (2, 1), &Device::Cpu)?;
         let data = vec![3.0f32, 4.0];
-        let b = Tensor::from_slice(&data, (1, 2), Device::Cpu)?;
+        let b = Tensor::from_slice(&data, (1, 2), &Device::Cpu)?;
         let c = a.matmul(&b)?;
         assert_eq!(c.to_vec2::<f32>()?, &[&[3.0, 4.0], &[6.0, 8.0]]);
 
         let data: Vec<_> = (0..6).map(|i| i as f32).collect();
-        let a = Tensor::from_slice(&data, (2, 3), Device::Cpu)?;
+        let a = Tensor::from_slice(&data, (2, 3), &Device::Cpu)?;
         let data: Vec<_> = (0..6).map(|i| (i + 2) as f32).collect();
-        let b = Tensor::from_slice(&data, (3, 2), Device::Cpu)?;
+        let b = Tensor::from_slice(&data, (3, 2), &Device::Cpu)?;
         let c = a.matmul(&b)?;
         assert_eq!(c.to_vec2::<f32>()?, &[&[16., 19.], &[52., 64.]]);
 
         let data: Vec<_> = (0..12).map(|i| i as f32).collect();
-        let a = Tensor::from_slice(&data, (2, 2, 3), Device::Cpu)?;
+        let a = Tensor::from_slice(&data, (2, 2, 3), &Device::Cpu)?;
        let data: Vec<_> = (0..12).map(|i| (i + 2) as f32).collect();
-        let b = Tensor::from_slice(&data, (2, 3, 2), Device::Cpu)?;
+        let b = Tensor::from_slice(&data, (2, 3, 2), &Device::Cpu)?;
         let c = a.matmul(&b)?;
         assert_eq!(
             c.to_vec3::<f32>()?,
@@ -40,6 +40,9 @@ pub enum Error {
         shape: Shape,
     },
 
+    #[error("temporary error where matmul doesn't support arbitrary striding")]
+    UnexpectedStriding,
+
     #[error(transparent)]
     Cuda(#[from] crate::CudaError),
 }
src/tensor.rs
@@ -147,10 +147,11 @@ impl Tensor {
 
     pub fn new_impl<A: crate::device::NdArray>(
         array: A,
+        shape: Shape,
         device: &Device,
         is_variable: bool,
     ) -> Result<Self> {
-        let shape = array.shape()?;
+        // let shape = array.shape()?;
         let storage = device.storage(array)?;
         let stride = shape.stride_contiguous();
         let tensor_ = Tensor_ {
@@ -165,31 +166,29 @@ impl Tensor {
     }
 
     pub fn new<A: crate::device::NdArray>(array: A, device: &Device) -> Result<Self> {
-        Self::new_impl(array, device, false)
+        let shape = array.shape()?.clone();
+        Self::new_impl(array, shape, device, false)
     }
 
     pub fn var<A: crate::device::NdArray>(array: A, device: &Device) -> Result<Self> {
-        Self::new_impl(array, device, true)
+        let shape = array.shape()?.clone();
+        Self::new_impl(array, shape, device, true)
     }
 
     pub fn from_slice<S: Into<Shape>, D: crate::WithDType>(
-        a: &[D],
+        array: &[D],
         shape: S,
-        device: Device,
+        device: &Device,
     ) -> Result<Self> {
-        let shape = shape.into();
-        let storage = device.storage(a)?;
-        let stride = shape.stride_contiguous();
-        let is_variable = false;
-        let tensor_ = Tensor_ {
-            id: TensorId::new(),
-            storage,
-            shape,
-            stride,
-            op: None,
-            is_variable,
-        };
-        Ok(Self(Arc::new(tensor_)))
+        Self::new_impl(array, shape.into(), device, false)
+    }
+
+    pub fn var_from_slice<S: Into<Shape>, D: crate::WithDType>(
+        array: &[D],
+        shape: S,
+        device: &Device,
+    ) -> Result<Self> {
+        Self::new_impl(array, shape.into(), device, true)
     }
 
     pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<&Shape> {
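Editorial note (not part of the commit): a quick usage sketch of the refactored constructors, assuming the API exactly as exercised by the tests in this diff (`from_slice` now borrows the device, `var_from_slice` builds a gradient-tracked tensor).

// Standalone sketch against the candle crate at this commit.
use candle::{Device, Tensor};

fn main() -> anyhow::Result<()> {
    let data = vec![1.0f32, 2.0, 3.0, 4.0];
    let a = Tensor::from_slice(&data, (2, 2), &Device::Cpu)?;     // plain tensor
    let v = Tensor::var_from_slice(&data, (2, 2), &Device::Cpu)?; // variable, tracked for grads
    let c = v.matmul(&a)?;
    println!("{:?}", c.to_vec2::<f32>()?);
    Ok(())
}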
@@ -260,6 +259,7 @@ impl Tensor {
 
         let dim = a_dims.len();
 
+        // TODO
         // if dim < 2 {
         //     return Err(SmeltError::InsufficientRank { minimum_rank: 2 });
         // }
@@ -309,6 +309,13 @@ impl Tensor {
         crate::StridedIndex::new(self.dims(), self.stride())
     }
 
+    pub fn as_slice<S: crate::WithDType>(&self) -> Result<&[S]> {
+        match &self.storage {
+            Storage::Cpu(cpu_storage) => S::cpu_storage_as_slice(cpu_storage),
+            Storage::Cuda { .. } => todo!(),
+        }
+    }
+
     pub fn to_vec1<S: crate::WithDType>(&self) -> Result<Vec<S>> {
         if self.rank() != 1 {
             return Err(Error::UnexpectedNumberOfDims {
@@ -404,6 +411,31 @@ impl Tensor {
         self.id
     }
 
+    pub fn t(&self) -> Result<Tensor> {
+        let mut stride = self.stride().to_vec();
+        let mut shape = self.shape().clone();
+        let n = stride.len();
+        if n < 2 {
+            return Err(Error::UnexpectedNumberOfDims {
+                expected: 2,
+                got: n,
+                shape: self.shape().clone(),
+            });
+        }
+        (shape.0[n - 2], shape.0[n - 1]) = (shape.0[n - 1], shape.0[n - 2]);
+        (stride[n - 2], stride[n - 1]) = (stride[n - 1], stride[n - 2]);
+        let tensor_ = Tensor_ {
+            id: TensorId::new(),
+            storage: self.storage.clone(),
+            shape,
+            stride,
+            // TODO The op should have a backward
+            op: None,
+            is_variable: false,
+        };
+        Ok(Tensor(Arc::new(tensor_)))
+    }
+
     pub fn is_contiguous(&self) -> bool {
         self.shape.is_contiguous(&self.stride)
     }
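Editorial note (not part of the commit): `t()` transposes by swapping the last two shape dimensions and the last two strides, never touching the buffer. A tiny sketch of why that works; the `at` helper is hypothetical, introduced only for illustration.

// Element (i, j) of a 2-D strided view lives at data[i * stride[0] + j * stride[1]],
// so swapping the two strides yields the transposed view with no data movement.
fn at(data: &[f32], stride: &[usize], i: usize, j: usize) -> f32 {
    data[i * stride[0] + j * stride[1]]
}

fn main() {
    let data = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; // 2x3 row-major, strides [3, 1]
    let stride = [3usize, 1];
    let t_stride = [1usize, 3]; // strides swapped: a 3x2 transposed view
    assert_eq!(at(&data, &stride, 0, 2), 3.0); // original element (0, 2)
    assert_eq!(at(&data, &t_stride, 2, 0), 3.0); // same element, read as (2, 0)
}

Downstream, `CpuStorage::matmul` reads these strides directly into the `rs`/`cs` arguments of `gemm`, so a transposed view needs no copy.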
@@ -514,37 +546,17 @@ impl Tensor {
                     let rhs_sum_grad = grads.or_insert(rhs)?;
                     *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
                 }
-                Op::Matmul(_lhs, _rhs) => {
-                    // let (m, k) = lhs.shape;
-                    // let n = rhs.shape.1;
-                    // let strides = (m, n).strides();
-                    // Self::matmul(
-                    //     (m, n, k),
-                    //     true,
-                    //     grad_out.as_ptr(),
-                    //     strides,
-                    //     rhs.data.as_ptr(),
-                    //     [rhs.strides[1], rhs.strides[0]],
-                    //     grad_lhs.as_mut_ptr(),
-                    //     lhs.strides,
-                    // );
-                    // Self::matmul(
-                    //     (k, m, n),
-                    //     true,
-                    //     lhs.data.as_ptr(),
-                    //     [lhs.strides[1], lhs.strides[0]],
-                    //     grad_out.as_ptr(),
-                    //     strides,
-                    //     grad_rhs.as_mut_ptr(),
-                    //     rhs.strides,
-                    // );
-                    // let lhs_grad = grad.matmul(rhs)?;
-                    // let lhs_sum_grad = grads.entry(lhs.id).or_insert_with(|| lhs.zeros_like());
-                    // *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
-                    // let rhs_grad = grad.mul(lhs)?.div(&rhs.sqr()?)?;
-                    // let rhs_sum_grad = grads.entry(rhs.id).or_insert_with(|| rhs.zeros_like());
-                    // *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
+                Op::Matmul(lhs, rhs) => {
+                    // Skipping checks, the op went ok, we can skip
+                    // the matmul size checks for now.
+
+                    let lhs_grad = grad.matmul(&rhs.t()?)?;
+                    let lhs_sum_grad = grads.or_insert(lhs)?;
+                    *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
+
+                    let rhs_grad = lhs.t()?.matmul(&grad)?;
+                    let rhs_sum_grad = grads.or_insert(rhs)?;
+                    *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
                 }
                 Op::Affine { arg, mul, .. } => {
                     let arg_grad = grad.affine(*mul, 0.)?;
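Editorial note (not part of the commit): the gradient rule implemented in the `Op::Matmul` arm above is the standard one. With C = A B and upstream gradient G = ∂L/∂C,

\frac{\partial L}{\partial A} = G\,B^{\top},
\qquad
\frac{\partial L}{\partial B} = A^{\top} G,

which corresponds line for line to `lhs_grad = grad.matmul(&rhs.t()?)` and `rhs_grad = lhs.t()?.matmul(&grad)`.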
@@ -1,5 +1,5 @@
 use anyhow::{Context, Result};
-use candle::{Device, Tensor};
+use candle::{Device, Shape, Tensor};
 
 #[test]
 fn simple_grad() -> Result<()> {
@@ -14,3 +14,27 @@ fn simple_grad() -> Result<()> {
     assert_eq!(grad_x.to_vec1::<f32>()?, [11., 7., 13.]);
     Ok(())
 }
+
+#[test]
+fn matmul_grad() -> Result<()> {
+    let data: Vec<_> = (0..12).map(|i| i as f32).collect();
+    let x = Tensor::var_from_slice(&data, (2, 2, 3), &Device::Cpu)?;
+    let data: Vec<_> = (0..12).map(|i| i as f32).collect();
+    let y = Tensor::var_from_slice(&data, (2, 3, 2), &Device::Cpu)?;
+
+    let c = x.matmul(&y)?;
+    let grads = c.backward()?;
+    let grad_x = grads.get(&x).context("no grad for x")?;
+    let grad_y = grads.get(&y).context("no grad for y")?;
+    assert_eq!(grad_x.shape(), &Shape::from((2, 2, 3)));
+    assert_eq!(grad_y.shape(), &Shape::from((2, 3, 2)));
+    assert_eq!(
+        grad_x.as_slice::<f32>()?,
+        &[1., 5., 9., 1., 5., 9., 13., 17., 21., 13., 17., 21.]
+    );
+    assert_eq!(
+        grad_y.as_slice::<f32>()?,
+        &[3., 3., 5., 5., 7., 7., 15., 15., 17., 17., 19., 19.]
+    );
+    Ok(())
+}
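Editorial note (not part of the commit): the expected values in `matmul_grad` are consistent with `backward()` seeding the output gradient with ones; that seeding is an assumption about `backward()`, which this diff does not show. Under that assumption, for the first batch G is a 2x2 matrix of ones and Y_0 = [[0, 1], [2, 3], [4, 5]], so

G Y_0^{\top}
= \begin{pmatrix} 1 & 1 \\ 1 & 1 \end{pmatrix}
  \begin{pmatrix} 0 & 2 & 4 \\ 1 & 3 & 5 \end{pmatrix}
= \begin{pmatrix} 1 & 5 & 9 \\ 1 & 5 & 9 \end{pmatrix},

matching the first six entries of `grad_x`; similarly X_0^{\top} G = [[3, 3], [5, 5], [7, 7]] matches the first six entries of `grad_y`.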