mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Adding matmul?
This commit is contained in:
@ -16,6 +16,7 @@ members = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
|
ggblas = "0.1.0"
|
||||||
safetensors = "0.3.1"
|
safetensors = "0.3.1"
|
||||||
thiserror = "1"
|
thiserror = "1"
|
||||||
cudarc = { version = "0.9.9", optional = true }
|
cudarc = { version = "0.9.9", optional = true }
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
use crate::storage::{BinaryOp, UnaryOp};
|
use crate::storage::{BinaryOp, UnaryOp};
|
||||||
use crate::{DType, Error, Result, Shape, StridedIndex};
|
use crate::{DType, Error, Result, Shape, StridedIndex};
|
||||||
|
use ggblas::batched_sgemm;
|
||||||
|
|
||||||
// TODO: Think about whether we would be better off with a dtype and
|
// TODO: Think about whether we would be better off with a dtype and
|
||||||
// a buffer as an owned slice of bytes.
|
// a buffer as an owned slice of bytes.
|
||||||
@ -17,6 +18,14 @@ impl CpuStorage {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn as_slice<D: crate::WithDType>(&self) -> Result<&[D]> {
|
||||||
|
D::cpu_storage_as_slice(self)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn as_mut_slice<D: crate::WithDType>(&mut self) -> Result<&mut [D]> {
|
||||||
|
D::cpu_storage_as_mut_slice(self)
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) fn affine_impl(
|
pub(crate) fn affine_impl(
|
||||||
&self,
|
&self,
|
||||||
shape: &Shape,
|
shape: &Shape,
|
||||||
@ -97,6 +106,38 @@ impl CpuStorage {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn matmul_impl(
|
||||||
|
&self,
|
||||||
|
rhs: &Self,
|
||||||
|
(b, m, n, k): (usize, usize, usize, usize),
|
||||||
|
lhs_stride: &[usize],
|
||||||
|
rhs_stride: &[usize],
|
||||||
|
) -> Result<Self> {
|
||||||
|
println!("rhs {rhs:?}");
|
||||||
|
println!("lhs_stride {lhs_stride:?}");
|
||||||
|
println!("rhs_stride {rhs_stride:?}");
|
||||||
|
// todo!("matmul");
|
||||||
|
let a_skip: usize = m * k;
|
||||||
|
let b_skip: usize = n * k;
|
||||||
|
let c_skip: usize = m * n;
|
||||||
|
|
||||||
|
let mut c = Self::F32(vec![0.0; b * m * n]);
|
||||||
|
|
||||||
|
batched_sgemm(
|
||||||
|
self.as_slice()?,
|
||||||
|
a_skip,
|
||||||
|
rhs.as_slice()?,
|
||||||
|
b_skip,
|
||||||
|
c.as_mut_slice()?,
|
||||||
|
c_skip,
|
||||||
|
m,
|
||||||
|
n,
|
||||||
|
k,
|
||||||
|
b,
|
||||||
|
);
|
||||||
|
Ok(c)
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) fn ones_impl(shape: &Shape, dtype: DType) -> Self {
|
pub(crate) fn ones_impl(shape: &Shape, dtype: DType) -> Self {
|
||||||
let elem_count = shape.elem_count();
|
let elem_count = shape.elem_count();
|
||||||
match dtype {
|
match dtype {
|
||||||
@ -125,3 +166,45 @@ impl CpuStorage {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::{Device, Tensor};
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn simple_matmul() -> Result<()> {
|
||||||
|
let data = vec![1.0f32, 2.0, 3.0, 4.0];
|
||||||
|
let a = Tensor::from_slice(&data, (2, 2), Device::Cpu)?;
|
||||||
|
let data = vec![1.0f32, 2.0, 3.0, 4.0];
|
||||||
|
let b = Tensor::from_slice(&data, (2, 2), Device::Cpu)?;
|
||||||
|
|
||||||
|
let c = a.matmul(&b)?;
|
||||||
|
assert_eq!(c.to_vec2::<f32>()?, &[&[7.0f32, 10.0], &[15.0, 22.0]]);
|
||||||
|
|
||||||
|
let data = vec![1.0f32, 2.0];
|
||||||
|
let a = Tensor::from_slice(&data, (2, 1), Device::Cpu)?;
|
||||||
|
let data = vec![3.0f32, 4.0];
|
||||||
|
let b = Tensor::from_slice(&data, (1, 2), Device::Cpu)?;
|
||||||
|
let c = a.matmul(&b)?;
|
||||||
|
assert_eq!(c.to_vec2::<f32>()?, &[&[3.0, 4.0], &[6.0, 8.0]]);
|
||||||
|
|
||||||
|
let data: Vec<_> = (0..6).map(|i| i as f32).collect();
|
||||||
|
let a = Tensor::from_slice(&data, (2, 3), Device::Cpu)?;
|
||||||
|
let data: Vec<_> = (0..6).map(|i| (i + 2) as f32).collect();
|
||||||
|
let b = Tensor::from_slice(&data, (3, 2), Device::Cpu)?;
|
||||||
|
let c = a.matmul(&b)?;
|
||||||
|
assert_eq!(c.to_vec2::<f32>()?, &[&[16., 19.], &[52., 64.]]);
|
||||||
|
|
||||||
|
let data: Vec<_> = (0..12).map(|i| i as f32).collect();
|
||||||
|
let a = Tensor::from_slice(&data, (2, 2, 3), Device::Cpu)?;
|
||||||
|
let data: Vec<_> = (0..12).map(|i| (i + 2) as f32).collect();
|
||||||
|
let b = Tensor::from_slice(&data, (2, 3, 2), Device::Cpu)?;
|
||||||
|
let c = a.matmul(&b)?;
|
||||||
|
assert_eq!(
|
||||||
|
c.to_vec3::<f32>()?,
|
||||||
|
&[&[&[16., 19.], &[52., 64.]], &[&[214., 235.], &[304., 334.]]]
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -101,7 +101,7 @@ impl Device {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn tensor<A: NdArray>(&self, array: A) -> Result<Storage> {
|
pub(crate) fn storage<A: NdArray>(&self, array: A) -> Result<Storage> {
|
||||||
match self {
|
match self {
|
||||||
Device::Cpu => Ok(Storage::Cpu(array.to_cpu_storage())),
|
Device::Cpu => Ok(Storage::Cpu(array.to_cpu_storage())),
|
||||||
Device::Cuda(device) => {
|
Device::Cuda(device) => {
|
||||||
|
11
src/dtype.rs
11
src/dtype.rs
@ -25,6 +25,7 @@ pub trait WithDType: Sized + Copy {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]>;
|
fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]>;
|
||||||
|
fn cpu_storage_as_mut_slice(s: &mut CpuStorage) -> Result<&mut [Self]>;
|
||||||
}
|
}
|
||||||
|
|
||||||
macro_rules! with_dtype {
|
macro_rules! with_dtype {
|
||||||
@ -45,6 +46,16 @@ macro_rules! with_dtype {
|
|||||||
}),
|
}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn cpu_storage_as_mut_slice(s: &mut CpuStorage) -> Result<&mut [Self]> {
|
||||||
|
match s {
|
||||||
|
CpuStorage::$dtype(data) => Ok(data),
|
||||||
|
_ => Err(Error::UnexpectedDType {
|
||||||
|
expected: DType::$dtype,
|
||||||
|
got: s.dtype(),
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
@ -5,6 +5,7 @@ pub(crate) enum Op {
|
|||||||
Mul(Tensor, Tensor),
|
Mul(Tensor, Tensor),
|
||||||
Sub(Tensor, Tensor),
|
Sub(Tensor, Tensor),
|
||||||
Div(Tensor, Tensor),
|
Div(Tensor, Tensor),
|
||||||
|
Matmul(Tensor, Tensor),
|
||||||
|
|
||||||
#[allow(dead_code)] // add is currently unused.
|
#[allow(dead_code)] // add is currently unused.
|
||||||
Affine {
|
Affine {
|
||||||
|
@ -241,4 +241,22 @@ impl Storage {
|
|||||||
pub(crate) fn sqrt_impl(&self, shape: &Shape, stride: &[usize]) -> Result<Self> {
|
pub(crate) fn sqrt_impl(&self, shape: &Shape, stride: &[usize]) -> Result<Self> {
|
||||||
self.unary_impl::<Sqrt>(shape, stride)
|
self.unary_impl::<Sqrt>(shape, stride)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn matmul_impl(
|
||||||
|
&self,
|
||||||
|
rhs: &Self,
|
||||||
|
bmnk: (usize, usize, usize, usize),
|
||||||
|
lhs_stride: &[usize],
|
||||||
|
rhs_stride: &[usize],
|
||||||
|
) -> Result<Self> {
|
||||||
|
self.same_device(rhs, "matmul")?;
|
||||||
|
self.same_dtype(rhs, "matmul")?;
|
||||||
|
match (self, rhs) {
|
||||||
|
(Storage::Cpu(storage), Storage::Cpu(rhs_storage)) => {
|
||||||
|
let storage = storage.matmul_impl(rhs_storage, bmnk, lhs_stride, rhs_stride)?;
|
||||||
|
Ok(Self::Cpu(storage))
|
||||||
|
}
|
||||||
|
_ => todo!(),
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
130
src/tensor.rs
130
src/tensor.rs
@ -151,7 +151,7 @@ impl Tensor {
|
|||||||
is_variable: bool,
|
is_variable: bool,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
let shape = array.shape()?;
|
let shape = array.shape()?;
|
||||||
let storage = device.tensor(array)?;
|
let storage = device.storage(array)?;
|
||||||
let stride = shape.stride_contiguous();
|
let stride = shape.stride_contiguous();
|
||||||
let tensor_ = Tensor_ {
|
let tensor_ = Tensor_ {
|
||||||
id: TensorId::new(),
|
id: TensorId::new(),
|
||||||
@ -172,6 +172,26 @@ impl Tensor {
|
|||||||
Self::new_impl(array, device, true)
|
Self::new_impl(array, device, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn from_slice<S: Into<Shape>, D: crate::WithDType>(
|
||||||
|
a: &[D],
|
||||||
|
shape: S,
|
||||||
|
device: Device,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let shape = shape.into();
|
||||||
|
let storage = device.storage(a);
|
||||||
|
let stride = shape.stride_contiguous();
|
||||||
|
let is_variable = false;
|
||||||
|
let tensor_ = Tensor_ {
|
||||||
|
id: TensorId::new(),
|
||||||
|
storage,
|
||||||
|
shape,
|
||||||
|
stride,
|
||||||
|
op: None,
|
||||||
|
is_variable,
|
||||||
|
};
|
||||||
|
Ok(Self(Arc::new(tensor_)))
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<&Shape> {
|
pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<&Shape> {
|
||||||
let lhs = self.shape();
|
let lhs = self.shape();
|
||||||
let rhs = rhs.shape();
|
let rhs = rhs.shape();
|
||||||
@ -234,6 +254,57 @@ impl Tensor {
|
|||||||
Ok(Self(Arc::new(tensor_)))
|
Ok(Self(Arc::new(tensor_)))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn matmul(&self, rhs: &Self) -> Result<Self> {
|
||||||
|
let a_dims = self.shape().dims();
|
||||||
|
let b_dims = rhs.shape().dims();
|
||||||
|
|
||||||
|
let dim = a_dims.len();
|
||||||
|
|
||||||
|
// if dim < 2 {
|
||||||
|
// return Err(SmeltError::InsufficientRank { minimum_rank: 2 });
|
||||||
|
// }
|
||||||
|
if b_dims.len() != dim {
|
||||||
|
return Err(Error::ShapeMismatchBinaryOp {
|
||||||
|
lhs: self.shape().clone(),
|
||||||
|
rhs: rhs.shape().clone(),
|
||||||
|
op: "matmul",
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let m = a_dims[dim - 2];
|
||||||
|
let k = a_dims[dim - 1];
|
||||||
|
let k2 = b_dims[dim - 2];
|
||||||
|
let n = b_dims[dim - 1];
|
||||||
|
if k != k2 {
|
||||||
|
return Err(Error::ShapeMismatchBinaryOp {
|
||||||
|
lhs: self.shape().clone(),
|
||||||
|
rhs: rhs.shape().clone(),
|
||||||
|
op: "matmul",
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut c_shape: Vec<_> = a_dims[..dim - 2].into();
|
||||||
|
c_shape.extend(&[m, n]);
|
||||||
|
let c_shape: Shape = Shape(c_shape);
|
||||||
|
let batching: usize = a_dims[..dim - 2].iter().product();
|
||||||
|
|
||||||
|
let storage = self.storage.matmul_impl(
|
||||||
|
&rhs.storage,
|
||||||
|
(batching, m, n, k),
|
||||||
|
self.stride(),
|
||||||
|
rhs.stride(),
|
||||||
|
)?;
|
||||||
|
let tensor_ = Tensor_ {
|
||||||
|
id: TensorId::new(),
|
||||||
|
storage,
|
||||||
|
shape: c_shape.clone(),
|
||||||
|
stride: c_shape.stride_contiguous(),
|
||||||
|
op: Some(Op::Matmul(self.clone(), rhs.clone())),
|
||||||
|
is_variable: false,
|
||||||
|
};
|
||||||
|
Ok(Self(Arc::new(tensor_)))
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) fn strided_index(&self) -> crate::StridedIndex {
|
pub(crate) fn strided_index(&self) -> crate::StridedIndex {
|
||||||
crate::StridedIndex::new(self.dims(), self.stride())
|
crate::StridedIndex::new(self.dims(), self.stride())
|
||||||
}
|
}
|
||||||
@ -279,6 +350,28 @@ impl Tensor {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn to_vec3<S: crate::WithDType>(&self) -> Result<Vec<Vec<Vec<S>>>> {
|
||||||
|
let (dim1, dim2, dim3) = self.shape().r3()?;
|
||||||
|
match &self.storage {
|
||||||
|
Storage::Cpu(cpu_storage) => {
|
||||||
|
let data = S::cpu_storage_as_slice(cpu_storage)?;
|
||||||
|
let mut top_rows = vec![];
|
||||||
|
let mut src_index = self.strided_index();
|
||||||
|
for _idx in 0..dim1 {
|
||||||
|
let mut rows = vec![];
|
||||||
|
for _jdx in 0..dim2 {
|
||||||
|
let row = (0..dim3).map(|_| data[src_index.next().unwrap()]).collect();
|
||||||
|
rows.push(row)
|
||||||
|
}
|
||||||
|
top_rows.push(rows);
|
||||||
|
}
|
||||||
|
assert!(src_index.next().is_none());
|
||||||
|
Ok(top_rows)
|
||||||
|
}
|
||||||
|
Storage::Cuda { .. } => todo!(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn dtype(&self) -> DType {
|
pub fn dtype(&self) -> DType {
|
||||||
self.storage.dtype()
|
self.storage.dtype()
|
||||||
}
|
}
|
||||||
@ -340,7 +433,8 @@ impl Tensor {
|
|||||||
Op::Add(lhs, rhs)
|
Op::Add(lhs, rhs)
|
||||||
| Op::Mul(lhs, rhs)
|
| Op::Mul(lhs, rhs)
|
||||||
| Op::Sub(lhs, rhs)
|
| Op::Sub(lhs, rhs)
|
||||||
| Op::Div(lhs, rhs) => {
|
| Op::Div(lhs, rhs)
|
||||||
|
| Op::Matmul(lhs, rhs) => {
|
||||||
let (tg, nodes) = walk(lhs, nodes, already_seen);
|
let (tg, nodes) = walk(lhs, nodes, already_seen);
|
||||||
track_grad |= tg;
|
track_grad |= tg;
|
||||||
let (tg, nodes) = walk(rhs, nodes, already_seen);
|
let (tg, nodes) = walk(rhs, nodes, already_seen);
|
||||||
@ -420,6 +514,38 @@ impl Tensor {
|
|||||||
let rhs_sum_grad = grads.or_insert(rhs)?;
|
let rhs_sum_grad = grads.or_insert(rhs)?;
|
||||||
*rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
|
*rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
|
||||||
}
|
}
|
||||||
|
Op::Matmul(lhs, rhs) => {
|
||||||
|
// let (m, k) = lhs.shape;
|
||||||
|
// let n = rhs.shape.1;
|
||||||
|
// let strides = (m, n).strides();
|
||||||
|
// Self::matmul(
|
||||||
|
// (m, n, k),
|
||||||
|
// true,
|
||||||
|
// grad_out.as_ptr(),
|
||||||
|
// strides,
|
||||||
|
// rhs.data.as_ptr(),
|
||||||
|
// [rhs.strides[1], rhs.strides[0]],
|
||||||
|
// grad_lhs.as_mut_ptr(),
|
||||||
|
// lhs.strides,
|
||||||
|
// );
|
||||||
|
// Self::matmul(
|
||||||
|
// (k, m, n),
|
||||||
|
// true,
|
||||||
|
// lhs.data.as_ptr(),
|
||||||
|
// [lhs.strides[1], lhs.strides[0]],
|
||||||
|
// grad_out.as_ptr(),
|
||||||
|
// strides,
|
||||||
|
// grad_rhs.as_mut_ptr(),
|
||||||
|
// rhs.strides,
|
||||||
|
// );
|
||||||
|
|
||||||
|
let lhs_grad = grad.matmul(rhs)?;
|
||||||
|
let lhs_sum_grad = grads.entry(lhs.id).or_insert_with(|| lhs.zeros_like());
|
||||||
|
*lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
|
||||||
|
let rhs_grad = grad.mul(lhs)?.div(&rhs.sqr()?)?;
|
||||||
|
let rhs_sum_grad = grads.entry(rhs.id).or_insert_with(|| rhs.zeros_like());
|
||||||
|
*rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
|
||||||
|
}
|
||||||
Op::Affine { arg, mul, .. } => {
|
Op::Affine { arg, mul, .. } => {
|
||||||
let arg_grad = grad.affine(*mul, 0.)?;
|
let arg_grad = grad.affine(*mul, 0.)?;
|
||||||
let sum_grad = grads.or_insert(arg)?;
|
let sum_grad = grads.or_insert(arg)?;
|
||||||
|
Reference in New Issue
Block a user