Mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 10:38:54 +00:00)
15  .pre-commit-config.yaml (new file)
@@ -0,0 +1,15 @@
repos:
  - repo: https://github.com/Narsil/pre-commit-rust
    rev: 2eed6366172ef2a5186e8785ec0e67243d7d73d0
    hooks:
      - id: fmt
        name: "Rust (fmt)"
      - id: clippy
        name: "Rust (clippy)"
        args:
          [
            "--tests",
            "--examples",
            "--",
            "-Dwarnings",
          ]
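The hooks above run rustfmt and clippy (with warnings denied, including for tests and examples) before each commit. As a usage note that is not part of the diff: with the pre-commit tool installed, the hooks are typically enabled locally with "pre-commit install" and can be run over the whole tree with "pre-commit run --all-files".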
@@ -20,12 +20,13 @@ safetensors = "0.3.1"
thiserror = "1"
cudarc = { version = "0.9.9", optional = true }
candle-kernels = { path = "kernels", optional = true }
gemm = "0.15.4"

[dev-dependencies]
anyhow = "1"
clap = { version = "4.2.4", features = ["derive"] }
rand = "0.8.5"
-tokenizers = "0.13.3"
tokenizers = { version = "0.13.3", default-features=false, features=["onig"] }

[features]
default = []
@@ -1,5 +1,6 @@
use crate::storage::{BinaryOp, UnaryOp};
use crate::{DType, Error, Result, Shape, StridedIndex};
use gemm::{gemm, Parallelism};

// TODO: Think about whether we would be better off with a dtype and
// a buffer as an owned slice of bytes.
@@ -17,6 +18,14 @@ impl CpuStorage {
        }
    }

    pub fn as_slice<D: crate::WithDType>(&self) -> Result<&[D]> {
        D::cpu_storage_as_slice(self)
    }

    pub fn as_mut_slice<D: crate::WithDType>(&mut self) -> Result<&mut [D]> {
        D::cpu_storage_as_mut_slice(self)
    }

    pub(crate) fn affine_impl(
        &self,
        shape: &Shape,
@@ -97,6 +106,93 @@ impl CpuStorage {
        }
    }

    pub(crate) fn matmul_impl(
        &self,
        rhs: &Self,
        (b, m, n, k): (usize, usize, usize, usize),
        lhs_stride: &[usize],
        rhs_stride: &[usize],
    ) -> Result<Self> {
        let a_skip: usize = m * k;
        let b_skip: usize = n * k;
        let c_skip: usize = m * n;

        let rank = lhs_stride.len();
        let lhs_cs = lhs_stride[rank - 1];
        let lhs_rs = lhs_stride[rank - 2];

        let rhs_cs = rhs_stride[rank - 1];
        let rhs_rs = rhs_stride[rank - 2];

        if lhs_stride.len() > 2 {
            let lhs_batch_stride = &lhs_stride[..rank - 2];
            let rhs_batch_stride = &rhs_stride[..rank - 2];

            if lhs_batch_stride != [a_skip] || rhs_batch_stride != [b_skip] {
                // Temporary error before we support arbitrary striding.
                return Err(Error::UnexpectedStriding);
            }
        }

        let mut dst = vec![0.0; b * m * n];

        let dst_shape: Shape = (m, n).into();
        let dst_strides = dst_shape.stride_contiguous();
        let dst_rs = dst_strides[0];
        let dst_cs = dst_strides[1];

        for step in 0..b {
            let lhs_p = &self.as_slice::<f32>()?[step * a_skip..];
            let rhs_p = &rhs.as_slice::<f32>()?[step * b_skip..];
            let dst_p = &mut dst[step * c_skip..];
            unsafe {
                gemm(
                    // m: usize,
                    m,
                    // n: usize,
                    n,
                    // k: usize,
                    k,
                    // dst: *mut T,
                    dst_p.as_mut_ptr(),
                    // dst_cs: isize,
                    dst_cs as isize,
                    // dst_rs: isize,
                    dst_rs as isize,
                    // read_dst: bool,
                    false,
                    // lhs: *const T,
                    lhs_p.as_ptr(),
                    // lhs_cs: isize,
                    lhs_cs as isize,
                    // lhs_rs: isize,
                    lhs_rs as isize,
                    // rhs: *const T,
                    rhs_p.as_ptr(),
                    // rhs_cs: isize,
                    rhs_cs as isize,
                    // rhs_rs: isize,
                    rhs_rs as isize,
                    // alpha: T,
                    1.0,
                    // beta: T,
                    1.0,
                    // conj_dst: bool,
                    false,
                    // conj_lhs: bool,
                    false,
                    // conj_rhs: bool,
                    true,
                    // parallelism: Parallelism
                    Parallelism::None,
                )
            }
        }

        let c = Self::F32(dst);
        Ok(c)
    }

    pub(crate) fn ones_impl(shape: &Shape, dtype: DType) -> Self {
        let elem_count = shape.elem_count();
        match dtype {
@@ -125,3 +221,45 @@ impl CpuStorage {
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Device, Tensor};

    #[test]
    fn simple_matmul() -> Result<()> {
        let data = vec![1.0f32, 2.0, 3.0, 4.0];
        let a = Tensor::from_slice(&data, (2, 2), &Device::Cpu)?;
        let data = vec![1.0f32, 2.0, 3.0, 4.0];
        let b = Tensor::from_slice(&data, (2, 2), &Device::Cpu)?;

        let c = a.matmul(&b)?;
        assert_eq!(c.to_vec2::<f32>()?, &[&[7.0f32, 10.0], &[15.0, 22.0]]);

        let data = vec![1.0f32, 2.0];
        let a = Tensor::from_slice(&data, (2, 1), &Device::Cpu)?;
        let data = vec![3.0f32, 4.0];
        let b = Tensor::from_slice(&data, (1, 2), &Device::Cpu)?;
        let c = a.matmul(&b)?;
        assert_eq!(c.to_vec2::<f32>()?, &[&[3.0, 4.0], &[6.0, 8.0]]);

        let data: Vec<_> = (0..6).map(|i| i as f32).collect();
        let a = Tensor::from_slice(&data, (2, 3), &Device::Cpu)?;
        let data: Vec<_> = (0..6).map(|i| (i + 2) as f32).collect();
        let b = Tensor::from_slice(&data, (3, 2), &Device::Cpu)?;
        let c = a.matmul(&b)?;
        assert_eq!(c.to_vec2::<f32>()?, &[&[16., 19.], &[52., 64.]]);

        let data: Vec<_> = (0..12).map(|i| i as f32).collect();
        let a = Tensor::from_slice(&data, (2, 2, 3), &Device::Cpu)?;
        let data: Vec<_> = (0..12).map(|i| (i + 2) as f32).collect();
        let b = Tensor::from_slice(&data, (2, 3, 2), &Device::Cpu)?;
        let c = a.matmul(&b)?;
        assert_eq!(
            c.to_vec3::<f32>()?,
            &[&[&[16., 19.], &[52., 64.]], &[&[214., 235.], &[304., 334.]]]
        );
        Ok(())
    }
}
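The expected values above can be checked by hand. As an illustrative, dependency-free sketch that is not part of the diff, a naive triple loop over row-major buffers reproduces the (2, 3) x (3, 2) expectation:

fn main() {
    let a: Vec<f32> = (0..6).map(|i| i as f32).collect(); // 2x3, row-major
    let b: Vec<f32> = (0..6).map(|i| (i + 2) as f32).collect(); // 3x2, row-major
    let (m, k, n) = (2, 3, 2);
    let mut c = vec![0.0f32; m * n];
    for i in 0..m {
        for j in 0..n {
            for p in 0..k {
                c[i * n + j] += a[i * k + p] * b[p * n + j];
            }
        }
    }
    // Matches the values asserted in simple_matmul.
    assert_eq!(c, vec![16., 19., 52., 64.]);
}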
@@ -101,7 +101,7 @@ impl Device {
        }
    }

-    pub(crate) fn tensor<A: NdArray>(&self, array: A) -> Result<Storage> {
    pub(crate) fn storage<A: NdArray>(&self, array: A) -> Result<Storage> {
        match self {
            Device::Cpu => Ok(Storage::Cpu(array.to_cpu_storage())),
            Device::Cuda(device) => {
11  src/dtype.rs
@@ -25,6 +25,7 @@ pub trait WithDType: Sized + Copy {
    }

    fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]>;
    fn cpu_storage_as_mut_slice(s: &mut CpuStorage) -> Result<&mut [Self]>;
}

macro_rules! with_dtype {
@@ -45,6 +46,16 @@ macro_rules! with_dtype {
                    }),
                }
            }

            fn cpu_storage_as_mut_slice(s: &mut CpuStorage) -> Result<&mut [Self]> {
                match s {
                    CpuStorage::$dtype(data) => Ok(data),
                    _ => Err(Error::UnexpectedDType {
                        expected: DType::$dtype,
                        got: s.dtype(),
                    }),
                }
            }
        }
    };
}
@@ -12,6 +12,11 @@ pub enum Error {
    #[error("the candle crate has not been built with cuda support")]
    NotCompiledWithCudaSupport,

    #[error(
        "shape mismatch, got a buffer of size {buffer_size} which is incompatible with shape {shape:?}"
    )]
    ShapeMismatch { buffer_size: usize, shape: Shape },

    #[error("shape mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")]
    ShapeMismatchBinaryOp {
        lhs: Shape,
@@ -40,6 +45,10 @@ pub enum Error {
        shape: Shape,
    },

    // TODO: this is temporary until matmul supports arbitrary striding.
    #[error("temporary error where matmul doesn't support arbitrary striding")]
    UnexpectedStriding,

    #[error(transparent)]
    Cuda(#[from] crate::CudaError),
}
@@ -5,6 +5,7 @@ pub(crate) enum Op {
    Mul(Tensor, Tensor),
    Sub(Tensor, Tensor),
    Div(Tensor, Tensor),
    Matmul(Tensor, Tensor),

    #[allow(dead_code)] // add is currently unused.
    Affine {
@@ -241,4 +241,22 @@ impl Storage {
    pub(crate) fn sqrt_impl(&self, shape: &Shape, stride: &[usize]) -> Result<Self> {
        self.unary_impl::<Sqrt>(shape, stride)
    }

    pub(crate) fn matmul_impl(
        &self,
        rhs: &Self,
        bmnk: (usize, usize, usize, usize),
        lhs_stride: &[usize],
        rhs_stride: &[usize],
    ) -> Result<Self> {
        self.same_device(rhs, "matmul")?;
        self.same_dtype(rhs, "matmul")?;
        match (self, rhs) {
            (Storage::Cpu(storage), Storage::Cpu(rhs_storage)) => {
                let storage = storage.matmul_impl(rhs_storage, bmnk, lhs_stride, rhs_stride)?;
                Ok(Self::Cpu(storage))
            }
            _ => todo!(),
        }
    }
}
148  src/tensor.rs
@@ -147,11 +147,16 @@ impl Tensor {

    pub fn new_impl<A: crate::device::NdArray>(
        array: A,
        shape: Shape,
        device: &Device,
        is_variable: bool,
    ) -> Result<Self> {
-        let shape = array.shape()?;
-        let storage = device.tensor(array)?;
        let n: usize = shape.elem_count();
        let buffer_size: usize = array.shape()?.elem_count();
        if buffer_size != n {
            return Err(Error::ShapeMismatch { buffer_size, shape });
        }
        let storage = device.storage(array)?;
        let stride = shape.stride_contiguous();
        let tensor_ = Tensor_ {
            id: TensorId::new(),
@@ -165,11 +170,29 @@
    }

    pub fn new<A: crate::device::NdArray>(array: A, device: &Device) -> Result<Self> {
-        Self::new_impl(array, device, false)
        let shape = array.shape()?;
        Self::new_impl(array, shape, device, false)
    }

    pub fn var<A: crate::device::NdArray>(array: A, device: &Device) -> Result<Self> {
-        Self::new_impl(array, device, true)
        let shape = array.shape()?;
        Self::new_impl(array, shape, device, true)
    }

    pub fn from_slice<S: Into<Shape>, D: crate::WithDType>(
        array: &[D],
        shape: S,
        device: &Device,
    ) -> Result<Self> {
        Self::new_impl(array, shape.into(), device, false)
    }

    pub fn var_from_slice<S: Into<Shape>, D: crate::WithDType>(
        array: &[D],
        shape: S,
        device: &Device,
    ) -> Result<Self> {
        Self::new_impl(array, shape.into(), device, true)
    }

    pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<&Shape> {
@@ -234,10 +257,65 @@
        Ok(Self(Arc::new(tensor_)))
    }

    pub fn matmul(&self, rhs: &Self) -> Result<Self> {
        let a_dims = self.shape().dims();
        let b_dims = rhs.shape().dims();

        let dim = a_dims.len();

        if dim < 2 || b_dims.len() != dim {
            return Err(Error::ShapeMismatchBinaryOp {
                lhs: self.shape().clone(),
                rhs: rhs.shape().clone(),
                op: "matmul",
            });
        }

        let m = a_dims[dim - 2];
        let k = a_dims[dim - 1];
        let k2 = b_dims[dim - 2];
        let n = b_dims[dim - 1];
        if k != k2 {
            return Err(Error::ShapeMismatchBinaryOp {
                lhs: self.shape().clone(),
                rhs: rhs.shape().clone(),
                op: "matmul",
            });
        }

        let mut c_shape: Vec<_> = a_dims[..dim - 2].into();
        c_shape.extend(&[m, n]);
        let c_shape = Shape(c_shape);
        let batching: usize = a_dims[..dim - 2].iter().product();

        let storage = self.storage.matmul_impl(
            &rhs.storage,
            (batching, m, n, k),
            self.stride(),
            rhs.stride(),
        )?;
        let tensor_ = Tensor_ {
            id: TensorId::new(),
            storage,
            shape: c_shape.clone(),
            stride: c_shape.stride_contiguous(),
            op: Some(Op::Matmul(self.clone(), rhs.clone())),
            is_variable: false,
        };
        Ok(Self(Arc::new(tensor_)))
    }
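For the batched case used in the tests, a (2, 2, 3) lhs multiplied by a (2, 3, 2) rhs gives batching = 2, (m, n, k) = (2, 2, 3), and an output shape of (2, 2, 2); with a single leading dimension, the contiguous batch strides equal a_skip and b_skip, so the striding check in the CPU backend above passes.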
    pub(crate) fn strided_index(&self) -> crate::StridedIndex {
        crate::StridedIndex::new(self.dims(), self.stride())
    }

    pub fn as_slice<S: crate::WithDType>(&self) -> Result<&[S]> {
        match &self.storage {
            Storage::Cpu(cpu_storage) => S::cpu_storage_as_slice(cpu_storage),
            Storage::Cuda { .. } => todo!(),
        }
    }

    pub fn to_vec1<S: crate::WithDType>(&self) -> Result<Vec<S>> {
        if self.rank() != 1 {
            return Err(Error::UnexpectedNumberOfDims {
@@ -279,6 +357,28 @@
        }
    }

    pub fn to_vec3<S: crate::WithDType>(&self) -> Result<Vec<Vec<Vec<S>>>> {
        let (dim1, dim2, dim3) = self.shape().r3()?;
        match &self.storage {
            Storage::Cpu(cpu_storage) => {
                let data = S::cpu_storage_as_slice(cpu_storage)?;
                let mut top_rows = vec![];
                let mut src_index = self.strided_index();
                for _idx in 0..dim1 {
                    let mut rows = vec![];
                    for _jdx in 0..dim2 {
                        let row = (0..dim3).map(|_| data[src_index.next().unwrap()]).collect();
                        rows.push(row)
                    }
                    top_rows.push(rows);
                }
                assert!(src_index.next().is_none());
                Ok(top_rows)
            }
            Storage::Cuda { .. } => todo!(),
        }
    }

    pub fn dtype(&self) -> DType {
        self.storage.dtype()
    }
@@ -311,6 +411,31 @@
        self.id
    }

    pub fn t(&self) -> Result<Tensor> {
        let mut stride = self.stride().to_vec();
        let mut shape = self.shape().clone();
        let n = stride.len();
        if n < 2 {
            return Err(Error::UnexpectedNumberOfDims {
                expected: 2,
                got: n,
                shape: self.shape().clone(),
            });
        }
        (shape.0[n - 2], shape.0[n - 1]) = (shape.0[n - 1], shape.0[n - 2]);
        (stride[n - 2], stride[n - 1]) = (stride[n - 1], stride[n - 2]);
        let tensor_ = Tensor_ {
            id: TensorId::new(),
            storage: self.storage.clone(),
            shape,
            stride,
            // TODO The op should have a backward
            op: None,
            is_variable: false,
        };
        Ok(Tensor(Arc::new(tensor_)))
    }

    pub fn is_contiguous(&self) -> bool {
        self.shape.is_contiguous(&self.stride)
    }
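Note that t() transposes the two innermost dimensions purely by swapping the corresponding shape and stride entries, without copying data: a contiguous (2, 3) tensor with strides [3, 1] becomes a (3, 2) view with strides [1, 3]. The result is generally not contiguous, which is why matmul_impl takes explicit lhs_stride/rhs_stride and forwards per-matrix row and column strides to gemm.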
@@ -340,7 +465,8 @@ impl Tensor {
                Op::Add(lhs, rhs)
                | Op::Mul(lhs, rhs)
                | Op::Sub(lhs, rhs)
-                | Op::Div(lhs, rhs) => {
                | Op::Div(lhs, rhs)
                | Op::Matmul(lhs, rhs) => {
                    let (tg, nodes) = walk(lhs, nodes, already_seen);
                    track_grad |= tg;
                    let (tg, nodes) = walk(rhs, nodes, already_seen);
@@ -420,6 +546,18 @@ impl Tensor {
                    let rhs_sum_grad = grads.or_insert(rhs)?;
                    *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
                }
                Op::Matmul(lhs, rhs) => {
                    // The sizes were already checked when the forward op was
                    // built, so the matmul size checks can be skipped here.

                    let lhs_grad = grad.matmul(&rhs.t()?)?;
                    let lhs_sum_grad = grads.or_insert(lhs)?;
                    *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;

                    let rhs_grad = lhs.t()?.matmul(&grad)?;
                    let rhs_sum_grad = grads.or_insert(rhs)?;
                    *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
                }
                Op::Affine { arg, mul, .. } => {
                    let arg_grad = grad.affine(*mul, 0.)?;
                    let sum_grad = grads.or_insert(arg)?;
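The Op::Matmul backward arm applies the standard matrix-calculus identities: for C = A·B with upstream gradient G = ∂L/∂C, the gradients are ∂L/∂A = G·Bᵀ and ∂L/∂B = Aᵀ·G, which is exactly grad.matmul(&rhs.t()?) and lhs.t()?.matmul(&grad) above, accumulated into the running sums for lhs and rhs.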
@@ -1,5 +1,5 @@
use anyhow::{Context, Result};
-use candle::{Device, Tensor};
use candle::{Device, Shape, Tensor};

#[test]
fn simple_grad() -> Result<()> {
@@ -14,3 +14,27 @@ fn simple_grad() -> Result<()> {
    assert_eq!(grad_x.to_vec1::<f32>()?, [11., 7., 13.]);
    Ok(())
}

#[test]
fn matmul_grad() -> Result<()> {
    let data: Vec<_> = (0..12).map(|i| i as f32).collect();
    let x = Tensor::var_from_slice(&data, (2, 2, 3), &Device::Cpu)?;
    let data: Vec<_> = (0..12).map(|i| i as f32).collect();
    let y = Tensor::var_from_slice(&data, (2, 3, 2), &Device::Cpu)?;

    let c = x.matmul(&y)?;
    let grads = c.backward()?;
    let grad_x = grads.get(&x).context("no grad for x")?;
    let grad_y = grads.get(&y).context("no grad for y")?;
    assert_eq!(grad_x.shape(), &Shape::from((2, 2, 3)));
    assert_eq!(grad_y.shape(), &Shape::from((2, 3, 2)));
    assert_eq!(
        grad_x.as_slice::<f32>()?,
        &[1., 5., 9., 1., 5., 9., 13., 17., 21., 13., 17., 21.]
    );
    assert_eq!(
        grad_y.as_slice::<f32>()?,
        &[3., 3., 5., 5., 7., 7., 15., 15., 17., 17., 19., 19.]
    );
    Ok(())
}
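These expected values are consistent with backward() seeding the output gradient with ones (an assumption about this commit, not stated in the diff): with G filled with ones, grad_x = G·yᵀ gives the row [0 + 1, 2 + 3, 4 + 5] = [1, 5, 9] for the first batch, and grad_y = xᵀ·G gives the column sums [3, 5, 7] of the first 2x3 slice of x, each repeated across the two output columns.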