Refactor the hierarchy.

Nicolas Patry
2023-06-27 11:57:27 +02:00
parent 6c4a960b15
commit d7f729fb8f
41 changed files with 35 additions and 33 deletions

candle-core/src/backprop.rs (new file, 274 lines)

@@ -0,0 +1,274 @@
use crate::{op::Op, Error, Result, Tensor, TensorId};
use std::collections::HashMap;
impl Tensor {
/// Return all the nodes that lead to this value in a topologically sorted vec, the first
/// elements having dependencies on the latter ones, e.g. the first element, if any, is this
/// tensor itself.
/// This assumes that the op graph is a DAG.
fn sorted_nodes(&self) -> Vec<&Tensor> {
// The vec of sorted nodes is passed as an owned value rather than a mutable reference
// to get around some lifetime limitations.
fn walk<'a>(
node: &'a Tensor,
nodes: Vec<&'a Tensor>,
already_seen: &mut HashMap<TensorId, bool>,
) -> (bool, Vec<&'a Tensor>) {
if let Some(&tg) = already_seen.get(&node.id()) {
return (tg, nodes);
}
let mut track_grad = false;
let mut nodes = if node.is_variable() {
// Do not call recursively on the "leaf" nodes.
track_grad = true;
nodes
} else if let Some(op) = node.op() {
match op {
Op::WhereCond(t1, t2, t3) => {
let (tg, nodes) = walk(t1, nodes, already_seen);
track_grad |= tg;
let (tg, nodes) = walk(t2, nodes, already_seen);
track_grad |= tg;
let (tg, nodes) = walk(t3, nodes, already_seen);
track_grad |= tg;
nodes
}
Op::Add(lhs, rhs)
| Op::Mul(lhs, rhs)
| Op::Sub(lhs, rhs)
| Op::Div(lhs, rhs)
| Op::Embedding(lhs, rhs)
| Op::Matmul(lhs, rhs) => {
let (tg, nodes) = walk(lhs, nodes, already_seen);
track_grad |= tg;
let (tg, nodes) = walk(rhs, nodes, already_seen);
track_grad |= tg;
nodes
}
Op::Cat(args, _) => args.iter().fold(nodes, |nodes, arg| {
let (tg, nodes) = walk(arg, nodes, already_seen);
track_grad |= tg;
nodes
}),
Op::Affine { arg, mul, .. } => {
if *mul == 0. {
nodes
} else {
let (tg, nodes) = walk(arg, nodes, already_seen);
track_grad |= tg;
nodes
}
}
Op::Reshape(node)
| Op::Broadcast(node)
| Op::Sum(node, _)
| Op::ToDType(node)
| Op::ToDevice(node)
| Op::Transpose(node, _, _)
| Op::Narrow(node, _, _, _)
| Op::Softmax(node, _)
| Op::Sqr(node)
| Op::Sqrt(node)
| Op::Gelu(node)
| Op::Exp(node)
| Op::Log(node)
| Op::Sin(node)
| Op::Cos(node)
| Op::Abs(node)
| Op::Neg(node) => {
let (tg, nodes) = walk(node, nodes, already_seen);
track_grad |= tg;
nodes
}
}
} else {
nodes
};
already_seen.insert(node.id(), track_grad);
if track_grad {
nodes.push(node);
}
(track_grad, nodes)
}
let (_tg, mut nodes) = walk(self, vec![], &mut HashMap::new());
nodes.reverse();
nodes
}
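// A hedged illustration of the ordering above (names local to this comment): for
// c = a.mul(&b)? with variables a and b, walk pushes a, then b, then c, and the
// final reverse yields [c, b, a], so backward below visits every node before the
// nodes it depends on.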
pub fn backward(&self) -> Result<GradStore> {
let sorted_nodes = self.sorted_nodes();
let mut grads = GradStore::new();
grads.insert(self, self.ones_like()?);
for node in sorted_nodes.iter() {
if node.is_variable() {
continue;
}
let grad = grads.remove(node).unwrap();
// TODO: We should perform all these operations in place (or at least not track the
// whole graph).
// The only drawback would be if we wanted to support grad of grad but this is out of
// scope.
if let Some(op) = node.op() {
match op {
Op::Add(lhs, rhs) => {
let lhs_sum_grad = grads.or_insert(lhs)?;
*lhs_sum_grad = lhs_sum_grad.add(&grad)?;
let rhs_sum_grad = grads.or_insert(rhs)?;
*rhs_sum_grad = rhs_sum_grad.add(&grad)?;
}
Op::Sub(lhs, rhs) => {
let lhs_sum_grad = grads.or_insert(lhs)?;
*lhs_sum_grad = lhs_sum_grad.add(&grad)?;
let rhs_sum_grad = grads.or_insert(rhs)?;
*rhs_sum_grad = rhs_sum_grad.sub(&grad)?;
}
Op::Mul(lhs, rhs) => {
let lhs_grad = grad.mul(rhs)?;
let lhs_sum_grad = grads.or_insert(lhs)?;
*lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
let rhs_grad = grad.mul(lhs)?;
let rhs_sum_grad = grads.or_insert(rhs)?;
*rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
}
Op::Div(lhs, rhs) => {
let lhs_grad = grad.div(rhs)?;
let lhs_sum_grad = grads.or_insert(lhs)?;
*lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
let rhs_grad = grad.mul(lhs)?.div(&rhs.sqr()?)?;
let rhs_sum_grad = grads.or_insert(rhs)?;
// d(lhs / rhs) / d(rhs) = -lhs / rhs^2, hence the subtraction.
*rhs_sum_grad = rhs_sum_grad.sub(&rhs_grad)?;
}
Op::WhereCond(_pred, _t, _f) => {
return Err(Error::BackwardNotSupported { op: "where_cond" })
}
Op::Embedding(_lhs, _rhs) => {
return Err(Error::BackwardNotSupported { op: "embedding" })
}
Op::Matmul(lhs, rhs) => {
// The op went through fine so we can skip the matmul size checks for now.
let lhs_grad = grad.matmul(&rhs.t()?)?;
let lhs_sum_grad = grads.or_insert(lhs)?;
*lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
let rhs_grad = lhs.t()?.matmul(&grad)?;
let rhs_sum_grad = grads.or_insert(rhs)?;
*rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
}
Op::Cat(args, dim) => {
let mut start_idx = 0;
for arg in args {
let len = arg.dims()[*dim];
let arg_grad = grad.narrow(*dim, start_idx, len)?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&arg_grad)?;
start_idx += len;
}
}
Op::Broadcast(_arg) => {
return Err(Error::BackwardNotSupported { op: "broadcast" })
}
Op::Sum(_arg, _sum_dims) => {
return Err(Error::BackwardNotSupported { op: "sum" })
}
Op::ToDType(arg) => {
let sum_grad = grads.or_insert(arg)?;
// Cast the gradient back to the argument's dtype before accumulating.
*sum_grad = sum_grad.add(&grad.to_dtype(arg.dtype())?)?
}
Op::Affine { arg, mul, .. } => {
let arg_grad = grad.affine(*mul, 0.)?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&arg_grad)?
}
Op::Log(arg) => {
let sum_grad = grads.or_insert(arg)?;
// d/dx ln(x) = 1/x, so the argument gradient is grad / arg.
*sum_grad = sum_grad.add(&(&grad / arg)?)?
}
Op::Sin(arg) => {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&(&grad * arg.cos())?)?
}
Op::Cos(arg) => {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.sub(&(&grad * arg.sin())?)?
}
Op::Abs(_args) => return Err(Error::BackwardNotSupported { op: "abs" }),
Op::Exp(arg) => {
let sum_grad = grads.or_insert(arg)?;
// d/dx exp(x) = exp(x), which is the node value itself.
*sum_grad = sum_grad.add(&(&grad * *node)?)?
}
Op::Neg(arg) => {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.sub(&grad)?
}
Op::Narrow(_arg, _, _, _) => {
return Err(Error::BackwardNotSupported { op: "narrow" })
}
Op::Softmax(_arg, _) => {
return Err(Error::BackwardNotSupported { op: "softmax" })
}
Op::Reshape(_arg) => return Err(Error::BackwardNotSupported { op: "reshape" }),
Op::Gelu(_) => return Err(Error::BackwardNotSupported { op: "gelu" }),
Op::Sqr(arg) => {
let arg_grad = arg.mul(&grad)?.affine(2., 0.)?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&arg_grad)?
}
Op::Sqrt(arg) => {
// d/dx sqrt(x) = 1 / (2 * sqrt(x)) and node = sqrt(arg), hence grad / (2 * node).
let arg_grad = grad.div(*node)?.affine(0.5, 0.)?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&arg_grad)?
}
Op::ToDevice(arg) => {
let sum_grad = grads.or_insert(arg)?;
let arg_grad = grad.to_device(&sum_grad.device())?;
*sum_grad = sum_grad.add(&arg_grad)?
}
Op::Transpose(arg, dim1, dim2) => {
let arg_grad = grad.transpose(*dim1, *dim2)?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&arg_grad)?
}
};
}
}
Ok(grads)
}
}
pub struct GradStore(HashMap<TensorId, Tensor>);
impl GradStore {
fn new() -> Self {
GradStore(HashMap::new())
}
pub fn get_id(&self, id: TensorId) -> Option<&Tensor> {
self.0.get(&id)
}
pub fn get(&self, tensor: &Tensor) -> Option<&Tensor> {
self.0.get(&tensor.id())
}
pub fn remove(&mut self, tensor: &Tensor) -> Option<Tensor> {
self.0.remove(&tensor.id())
}
pub fn insert(&mut self, tensor: &Tensor, grad: Tensor) -> Option<Tensor> {
self.0.insert(tensor.id(), grad)
}
fn or_insert(&mut self, tensor: &Tensor) -> Result<&mut Tensor> {
use std::collections::hash_map::Entry;
let grad = match self.0.entry(tensor.id()) {
Entry::Occupied(entry) => entry.into_mut(),
Entry::Vacant(entry) => {
let grad = tensor.zeros_like()?;
entry.insert(grad)
}
};
Ok(grad)
}
}
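Taken together, `sorted_nodes`, `backward` and `GradStore` implement reverse-mode autodiff over the op graph: the gradient starts as `ones_like` on the output and gets accumulated into each argument's slot through `or_insert`. Below is a minimal end-to-end sketch written as it might appear in a test inside this crate; the `Tensor::var` constructor and the `to_vec1` accessor are assumptions based on the surrounding code rather than part of this diff:

use crate::{Device, Result, Tensor};

#[test]
fn backward_demo() -> Result<()> {
    // `Tensor::var` is assumed to build a leaf with `is_variable() == true`,
    // so that `sorted_nodes` keeps it and `backward` retains its gradient.
    let x = Tensor::var(&[3f32], &Device::Cpu)?;
    let y = x.mul(&x)?.add(&x)?; // y = x^2 + x
    let grads = y.backward()?;
    let dx = grads.get(&x).expect("no gradient for x");
    // dy/dx = 2x + 1, which is 7 at x = 3.
    assert_eq!(dx.to_vec1::<f32>()?, [7f32]);
    Ok(())
}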

candle-core/src/cpu_backend.rs (new file, 886 lines)

@@ -0,0 +1,886 @@
use crate::op::{BinaryOp, UnaryOp};
use crate::{DType, Error, Result, Shape, StridedIndex};
use gemm::{gemm, Parallelism};
use half::{bf16, f16};
// TODO: Think about whether we would be better off with a dtype and
// a buffer as an owned slice of bytes.
// TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator +
// intercept the oom errors to avoid panicking and provide a proper error.
#[derive(Debug, Clone)]
pub enum CpuStorage {
U32(Vec<u32>),
BF16(Vec<bf16>),
F16(Vec<f16>),
F32(Vec<f32>),
F64(Vec<f64>),
}
fn wcond<T: Copy>(
pred: &[u32],
shape: &Shape,
stride: &[usize],
t: &[T],
stride_t: &[usize],
f: &[T],
stride_f: &[usize],
) -> Vec<T> {
if shape.is_contiguous(stride) && shape.is_contiguous(stride_t) && shape.is_contiguous(stride_f)
{
let elem_count = shape.elem_count();
let pred = &pred[..elem_count];
let t = &t[..elem_count];
let f = &f[..elem_count];
pred.iter()
.zip(t.iter().zip(f.iter()))
.map(|(&p, (&t, &f))| if p > 0 { t } else { f })
.collect::<Vec<_>>()
} else {
let dims = shape.dims();
let it_p = StridedIndex::new(dims, stride);
let it_t = StridedIndex::new(dims, stride_t);
let it_f = StridedIndex::new(dims, stride_f);
it_p.zip(it_t.zip(it_f))
.map(|(i_p, (i_t, i_f))| if pred[i_p] > 0 { t[i_t] } else { f[i_f] })
.collect::<Vec<_>>()
}
}
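// Illustration of the predicate encoding above (values are examples only): the mask
// is a u32 buffer where a nonzero entry selects from `t` and a zero entry selects
// from `f`, so pred = [1, 0, 1], t = [1., 2., 3.], f = [10., 20., 30.] yields
// [1., 20., 3.] on both the contiguous fast path and the strided path.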
macro_rules! map1 {
($v: expr, $fn: ident, $( $args:expr ),*) => {{
let v = match $v {
CpuStorage::BF16(__s) => CpuStorage::BF16($fn::<bf16>(__s, $($args),*)?),
CpuStorage::F16(__s) => CpuStorage::F16($fn::<f16>(__s, $($args),*)?),
CpuStorage::F32(__s) => CpuStorage::F32($fn::<f32>(__s, $($args),*)?),
CpuStorage::F64(__s) => CpuStorage::F64($fn::<f64>(__s, $($args),*)?),
CpuStorage::U32(__s) => CpuStorage::U32($fn::<u32>(__s, $($args),*)?),
};
Ok(v)
}};
}
fn sum_impl1<T: Copy + num_traits::NumAssign>(
src: &[T],
dst_shape: &Shape,
src_dims: &[usize],
stride: &[usize],
to_dst_index: impl Fn(usize) -> usize,
) -> Result<Vec<T>> {
let mut dst = vec![T::zero(); dst_shape.elem_count()];
for (unstr_index, src_index) in StridedIndex::new(src_dims, stride).enumerate() {
dst[to_dst_index(unstr_index)] += src[src_index];
}
Ok(dst)
}
fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(
vs: &[T],
shape: &Shape,
stride: &[usize],
mut f: F,
) -> Vec<U> {
if shape.is_contiguous(stride) {
vs[..shape.elem_count()].iter().map(|&v| f(v)).collect()
} else {
StridedIndex::new(shape.dims(), stride)
.map(|i| f(vs[i]))
.collect()
}
}
// This function maps over two strided index sequences.
fn binary_map<T: Copy, F: FnMut(T, T) -> T>(
shape: &Shape,
lhs_stride: &[usize],
rhs_stride: &[usize],
lhs: &[T],
rhs: &[T],
mut f: F,
) -> Vec<T> {
let dims = shape.dims();
if shape.is_contiguous(lhs_stride) && shape.is_contiguous(rhs_stride) {
(0..shape.elem_count()).map(|i| f(lhs[i], rhs[i])).collect()
} else {
let lhs_index = StridedIndex::new(dims, lhs_stride);
let rhs_index = StridedIndex::new(dims, rhs_stride);
lhs_index
.zip(rhs_index)
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
.collect()
}
}
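// An illustration of the strided path above (values are examples only): adding a
// contiguous 2x3 lhs (stride [3, 1]) to the transpose of a contiguous 3x2 buffer
// viewed as 2x3 (stride [1, 2]) zips lhs indices 0, 1, 2, 3, 4, 5 with rhs indices
// 0, 2, 4, 1, 3, 5, so transposed views are consumed without being materialized
// first.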
fn take_impl1<T: Copy>(
vs: &[T],
ids: &[u32],
shape: &Shape,
stride: &[usize],
vocab_size: usize,
hidden_size: usize,
) -> Result<Vec<T>> {
let mut values = Vec::with_capacity(shape.elem_count() * hidden_size);
for index in StridedIndex::new(shape.dims(), stride) {
let index = ids[index].try_into()?;
if index >= vocab_size {
return Err(Error::InvalidIndex {
index,
vocab_size,
op: "take",
});
} else {
values.extend(&vs[hidden_size * index..hidden_size * (index + 1)]);
}
}
Ok(values)
}
fn copy_strided_src_<T: Copy + std::fmt::Display>(
src: &[T],
dst: &mut [T],
dst_offset: usize,
src_shape: &Shape,
src_stride: &[usize],
src_offset: usize,
) {
let src = &src[src_offset..];
if src_shape.is_contiguous(src_stride) {
let elem_to_copy = (dst.len() - dst_offset).min(src.len());
dst[dst_offset..dst_offset + elem_to_copy].copy_from_slice(&src[..elem_to_copy])
} else {
let src_indexes = StridedIndex::new(src_shape.dims(), src_stride);
for (dst_index, src_index) in src_indexes.enumerate() {
let dst_index = dst_index + dst_offset;
if dst_index >= dst.len() {
break;
}
dst[dst_index] = src[src_index]
}
}
}
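// Illustration of the strided branch above (values are examples only): with src
// dims [2, 2] and stride [1, 2], i.e. a transposed view of a contiguous 2x2
// buffer, StridedIndex yields 0, 2, 1, 3, so dst starting at dst_offset receives
// src[0], src[2], src[1], src[3]; this is how ops like Cat can pack
// non-contiguous arguments into a fresh buffer.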
impl CpuStorage {
pub fn dtype(&self) -> DType {
match self {
Self::U32(_) => DType::U32,
Self::BF16(_) => DType::BF16,
Self::F16(_) => DType::F16,
Self::F32(_) => DType::F32,
Self::F64(_) => DType::F64,
}
}
pub fn as_slice<D: crate::WithDType>(&self) -> Result<&[D]> {
D::cpu_storage_as_slice(self)
}
pub fn as_mut_slice<D: crate::WithDType>(&mut self) -> Result<&mut [D]> {
D::cpu_storage_as_mut_slice(self)
}
pub(crate) fn to_dtype(&self, shape: &Shape, stride: &[usize], dtype: DType) -> Result<Self> {
// TODO: find a way around the quadratic number of cases below.
match (self, dtype) {
(Self::U32(storage), DType::BF16) => {
let data = unary_map(storage, shape, stride, |v| bf16::from_f32(v as f32));
Ok(Self::BF16(data))
}
(Self::BF16(storage), DType::BF16) => {
let data = unary_map(storage, shape, stride, |v| v);
Ok(Self::BF16(data))
}
(Self::F16(storage), DType::BF16) => {
let data = unary_map(storage, shape, stride, |v| bf16::from_f32(v.to_f32()));
Ok(Self::BF16(data))
}
(Self::F32(storage), DType::BF16) => {
let data = unary_map(storage, shape, stride, bf16::from_f32);
Ok(Self::BF16(data))
}
(Self::F64(storage), DType::BF16) => {
let data = unary_map(storage, shape, stride, bf16::from_f64);
Ok(Self::BF16(data))
}
(Self::U32(storage), DType::F16) => {
let data = unary_map(storage, shape, stride, |v| f16::from_f32(v as f32));
Ok(Self::F16(data))
}
(Self::BF16(storage), DType::F16) => {
let data = unary_map(storage, shape, stride, |v| f16::from_f32(v.to_f32()));
Ok(Self::F16(data))
}
(Self::F16(storage), DType::F16) => {
let data = unary_map(storage, shape, stride, |v| v);
Ok(Self::F16(data))
}
(Self::F32(storage), DType::F16) => {
let data = unary_map(storage, shape, stride, f16::from_f32);
Ok(Self::F16(data))
}
(Self::F64(storage), DType::F16) => {
let data = unary_map(storage, shape, stride, f16::from_f64);
Ok(Self::F16(data))
}
(Self::U32(storage), DType::F32) => {
let data = unary_map(storage, shape, stride, |v| v as f32);
Ok(Self::F32(data))
}
(Self::BF16(storage), DType::F32) => {
let data = unary_map(storage, shape, stride, |v| v.to_f32());
Ok(Self::F32(data))
}
(Self::F16(storage), DType::F32) => {
let data = unary_map(storage, shape, stride, |v| v.to_f32());
Ok(Self::F32(data))
}
(Self::F32(storage), DType::F32) => {
let data = unary_map(storage, shape, stride, |v| v);
Ok(Self::F32(data))
}
(Self::F64(storage), DType::F32) => {
let data = unary_map(storage, shape, stride, |v| v as f32);
Ok(Self::F32(data))
}
(Self::U32(storage), DType::U32) => {
let data = unary_map(storage, shape, stride, |v| v);
Ok(Self::U32(data))
}
(Self::BF16(storage), DType::U32) => {
let data = unary_map(storage, shape, stride, |v| v.to_f32() as u32);
Ok(Self::U32(data))
}
(Self::F16(storage), DType::U32) => {
let data = unary_map(storage, shape, stride, |v| v.to_f32() as u32);
Ok(Self::U32(data))
}
(Self::F32(storage), DType::U32) => {
let data = unary_map(storage, shape, stride, |v| v as u32);
Ok(Self::U32(data))
}
(Self::F64(storage), DType::U32) => {
let data = unary_map(storage, shape, stride, |v| v as u32);
Ok(Self::U32(data))
}
(Self::U32(storage), DType::F64) => {
let data = unary_map(storage, shape, stride, |v| v as f64);
Ok(Self::F64(data))
}
(Self::BF16(storage), DType::F64) => {
let data = unary_map(storage, shape, stride, |v| v.to_f64());
Ok(Self::F64(data))
}
(Self::F16(storage), DType::F64) => {
let data = unary_map(storage, shape, stride, |v| v.to_f64());
Ok(Self::F64(data))
}
(Self::F32(storage), DType::F64) => {
let data = unary_map(storage, shape, stride, |v| v as f64);
Ok(Self::F64(data))
}
(Self::F64(storage), DType::F64) => {
let data = unary_map(storage, shape, stride, |v| v);
Ok(Self::F64(data))
}
}
}
pub(crate) fn sum(&self, shape: &Shape, stride: &[usize], sum_dims: &[usize]) -> Result<Self> {
let src_dims = shape.dims();
let mut dst_dims = src_dims.to_vec();
for &sum_dim in sum_dims.iter() {
dst_dims[sum_dim] = 1;
}
let dst_shape = Shape::from(dst_dims);
let mut sum_dims = sum_dims.to_vec();
// Sort the sum_dims as they have to be processed from left to right when converting the
// indexes.
sum_dims.sort();
let sum_dims_and_stride: Vec<_> = sum_dims
.iter()
.map(|&d| (src_dims[d], src_dims[d + 1..].iter().product::<usize>()))
.collect();
let to_dst_index = |unstr_index: usize| {
// TODO: Optimize, the following does lots of slow division.
let mut dst_index = unstr_index;
// Set the sum_dims indexes to 0.
for &(dim, stride) in sum_dims_and_stride.iter() {
// The compiler is able to optimize the following in a single divmod op.
let (pre, post) = (dst_index / stride, dst_index % stride);
dst_index = (pre / dim) * stride + post;
}
dst_index
};
// TODO: Maybe provide an implementation with higher precision accumulators?
map1!(self, sum_impl1, &dst_shape, src_dims, stride, to_dst_index)
}
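// Worked example for to_dst_index above: summing dim 1 of a [2, 3, 4] tensor gives
// dst shape [2, 1, 4] and a single (dim, stride) pair of (3, 4). For the source
// element at linear index 17, i.e. coordinates [1, 1, 1], pre = 17 / 4 = 4 and
// post = 17 % 4 = 1, so dst_index = (4 / 3) * 4 + 1 = 5, which is coordinates
// [1, 0, 1] in the destination as expected.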
pub(crate) fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) -> Result<()> {
// [self] stores data in a contiguous way.
let dims = shape.dims();
let elem_per_slice = dims[dim];
let prod_pre_dim = dims[..dim].iter().product();
let prod_post_dim = dims[dim + 1..].iter().product();
match self {
Self::BF16(storage) => {
for pre_idx in 0..prod_pre_dim {
for post_idx in 0..prod_post_dim {
let mut sum = 0f64;
let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
for _ in 0..elem_per_slice {
sum += storage[idx].to_f64();
idx += prod_post_dim
}
let sum = bf16::from_f64(sum);
let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
for _ in 0..elem_per_slice {
storage[idx] /= sum;
idx += prod_post_dim
}
}
}
}
Self::F16(storage) => {
for pre_idx in 0..prod_pre_dim {
for post_idx in 0..prod_post_dim {
let mut sum = 0f64;
let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
for _ in 0..elem_per_slice {
sum += storage[idx].to_f64();
idx += prod_post_dim
}
let sum = f16::from_f64(sum);
let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
for _ in 0..elem_per_slice {
storage[idx] /= sum;
idx += prod_post_dim
}
}
}
}
Self::F32(storage) => {
for pre_idx in 0..prod_pre_dim {
for post_idx in 0..prod_post_dim {
let mut sum = 0f64;
let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
for _ in 0..elem_per_slice {
sum += storage[idx] as f64;
idx += prod_post_dim
}
let sum = sum as f32;
let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
for _ in 0..elem_per_slice {
storage[idx] /= sum;
idx += prod_post_dim
}
}
}
}
Self::F64(storage) => {
for pre_idx in 0..prod_pre_dim {
for post_idx in 0..prod_post_dim {
let mut sum = 0f64;
let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
for _ in 0..elem_per_slice {
sum += storage[idx];
idx += prod_post_dim
}
let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
for _ in 0..elem_per_slice {
storage[idx] /= sum;
idx += prod_post_dim
}
}
}
}
Self::U32(_) => {}
}
Ok(())
}
pub(crate) fn affine_impl(
&self,
shape: &Shape,
stride: &[usize],
mul: f64,
add: f64,
) -> Result<Self> {
match self {
Self::U32(storage) => {
let mul = mul as u32;
let add = add as u32;
let data = unary_map(storage, shape, stride, |v| v * mul + add);
Ok(Self::U32(data))
}
Self::BF16(storage) => {
let mul = bf16::from_f64(mul);
let add = bf16::from_f64(add);
let data = unary_map(storage, shape, stride, |v| v * mul + add);
Ok(Self::BF16(data))
}
Self::F16(storage) => {
let mul = f16::from_f64(mul);
let add = f16::from_f64(add);
let data = unary_map(storage, shape, stride, |v| v * mul + add);
Ok(Self::F16(data))
}
Self::F32(storage) => {
let mul = mul as f32;
let add = add as f32;
let data = unary_map(storage, shape, stride, |v| v * mul + add);
Ok(Self::F32(data))
}
Self::F64(storage) => {
let data = unary_map(storage, shape, stride, |v| v * mul + add);
Ok(Self::F64(data))
}
}
}
pub(crate) fn unary_impl<B: UnaryOp>(&self, shape: &Shape, stride: &[usize]) -> Result<Self> {
match self {
Self::BF16(storage) => {
let data = unary_map(storage, shape, stride, B::bf16);
Ok(Self::BF16(data))
}
Self::F16(storage) => {
let data = unary_map(storage, shape, stride, B::f16);
Ok(Self::F16(data))
}
Self::F32(storage) => {
let data = unary_map(storage, shape, stride, B::f32);
Ok(Self::F32(data))
}
Self::F64(storage) => {
let data = unary_map(storage, shape, stride, B::f64);
Ok(Self::F64(data))
}
Self::U32(storage) => {
let data = unary_map(storage, shape, stride, B::u32);
Ok(Self::U32(data))
}
}
}
pub(crate) fn binary_impl<B: BinaryOp>(
&self,
rhs: &Self,
shape: &Shape,
lhs_stride: &[usize],
rhs_stride: &[usize],
) -> Result<Self> {
match (self, rhs) {
(Self::BF16(lhs), Self::BF16(rhs)) => {
let data = binary_map(shape, lhs_stride, rhs_stride, lhs, rhs, B::bf16);
Ok(Self::BF16(data))
}
(Self::F16(lhs), Self::F16(rhs)) => {
let data = binary_map(shape, lhs_stride, rhs_stride, lhs, rhs, B::f16);
Ok(Self::F16(data))
}
(Self::F32(lhs), Self::F32(rhs)) => {
let data = binary_map(shape, lhs_stride, rhs_stride, lhs, rhs, B::f32);
Ok(Self::F32(data))
}
(Self::F64(lhs), Self::F64(rhs)) => {
let data = binary_map(shape, lhs_stride, rhs_stride, lhs, rhs, B::f64);
Ok(Self::F64(data))
}
(Self::U32(lhs), Self::U32(rhs)) => {
let data = binary_map(shape, lhs_stride, rhs_stride, lhs, rhs, B::u32);
Ok(Self::U32(data))
}
_ => {
// This should be covered by the dtype check above.
Err(Error::DTypeMismatchBinaryOp {
lhs: self.dtype(),
rhs: rhs.dtype(),
op: B::NAME,
})
}
}
}
pub(crate) fn copy_strided_src(
&self,
dst: &mut Self,
dst_offset: usize,
src_shape: &Shape,
src_stride: &[usize],
src_offset: usize,
) -> Result<()> {
if src_shape.rank() != src_stride.len() {
panic!("incoherent shape and strides {src_shape:?} {src_stride:?}")
}
match (self, dst) {
(Self::U32(src), Self::U32(dst)) => {
copy_strided_src_(src, dst, dst_offset, src_shape, src_stride, src_offset)
}
(Self::BF16(src), Self::BF16(dst)) => {
copy_strided_src_(src, dst, dst_offset, src_shape, src_stride, src_offset)
}
(Self::F16(src), Self::F16(dst)) => {
copy_strided_src_(src, dst, dst_offset, src_shape, src_stride, src_offset)
}
(Self::F32(src), Self::F32(dst)) => {
copy_strided_src_(src, dst, dst_offset, src_shape, src_stride, src_offset)
}
(Self::F64(src), Self::F64(dst)) => {
copy_strided_src_(src, dst, dst_offset, src_shape, src_stride, src_offset)
}
(_, dst) => {
// This should be covered by the dtype check above.
return Err(Error::DTypeMismatchBinaryOp {
lhs: self.dtype(),
rhs: dst.dtype(),
op: "copy_strided",
});
}
}
Ok(())
}
pub(crate) fn where_cond(
&self,
shape: &Shape,
stride: &[usize],
t: &Self,
stride_t: &[usize],
f: &Self,
stride_f: &[usize],
) -> Result<Self> {
// TODO: Support types that could be cast to a boolean.
let pred = self.as_slice::<u32>()?;
match (t, f) {
(Self::BF16(t), Self::BF16(f)) => {
let data = wcond(pred, shape, stride, t, stride_t, f, stride_f);
Ok(Self::BF16(data))
}
(Self::F16(t), Self::F16(f)) => {
let data = wcond(pred, shape, stride, t, stride_t, f, stride_f);
Ok(Self::F16(data))
}
(Self::F32(t), Self::F32(f)) => {
let data = wcond(pred, shape, stride, t, stride_t, f, stride_f);
Ok(Self::F32(data))
}
(Self::F64(t), Self::F64(f)) => {
let data = wcond(pred, shape, stride, t, stride_t, f, stride_f);
Ok(Self::F64(data))
}
(Self::U32(t), Self::U32(f)) => {
let data = wcond(pred, shape, stride, t, stride_t, f, stride_f);
Ok(Self::U32(data))
}
_ => Err(Error::DTypeMismatchBinaryOp {
lhs: t.dtype(),
rhs: f.dtype(),
op: "where_cond",
}),
}
}
pub(crate) fn embedding_impl(
&self,
shape: &Shape,
stride: &[usize],
vs: &Self,
hidden_size: usize,
vocab_size: usize,
) -> Result<Self> {
let ids = self.as_slice::<u32>()?;
map1!(vs, take_impl1, ids, shape, stride, vocab_size, hidden_size)
}
pub(crate) fn matmul_impl(
&self,
rhs: &Self,
(b, m, n, k): (usize, usize, usize, usize),
lhs_stride: &[usize],
rhs_stride: &[usize],
) -> Result<Self> {
let a_skip: usize = m * k;
let b_skip: usize = n * k;
let c_skip: usize = m * n;
let rank = lhs_stride.len();
let lhs_cs = lhs_stride[rank - 1];
let lhs_rs = lhs_stride[rank - 2];
let rhs_cs = rhs_stride[rank - 1];
let rhs_rs = rhs_stride[rank - 2];
if lhs_stride.len() > 2 {
let lhs_batch_stride = &lhs_stride[..rank - 2];
let rhs_batch_stride = &rhs_stride[..rank - 2];
if lhs_batch_stride != [a_skip] || rhs_batch_stride != [b_skip] {
// Temporary error before we support arbitrary striding.
return Err(Error::UnexpectedStriding);
}
}
let dst_shape: Shape = (m, n).into();
let dst_strides = dst_shape.stride_contiguous();
let dst_rs = dst_strides[0];
let dst_cs = dst_strides[1];
match (self, rhs) {
(CpuStorage::F16(lhs), CpuStorage::F16(rhs)) => {
let mut dst = vec![f16::ZERO; b * m * n];
for step in 0..b {
let lhs_p = &lhs[step * a_skip..];
let rhs_p = &rhs[step * b_skip..];
let dst_p = &mut dst[step * c_skip..];
unsafe {
gemm(
// m: usize,
m,
// n: usize,
n,
// k: usize,
k,
// dst: *mut T,
dst_p.as_mut_ptr(),
// dst_cs: isize,
dst_cs as isize,
// dst_rs: isize,
dst_rs as isize,
// read_dst: bool,
false,
// lhs: *const T,
lhs_p.as_ptr(),
// lhs_cs: isize,
lhs_cs as isize,
// lhs_rs: isize,
lhs_rs as isize,
// rhs: *const T,
rhs_p.as_ptr(),
// rhs_cs: isize,
rhs_cs as isize,
// rhs_rs: isize,
rhs_rs as isize,
// alpha: T,
f16::ONE,
// beta: T,
f16::ONE,
// conj_dst: bool,
false,
// conj_lhs: bool,
false,
// conj_rhs: bool,
true,
// parallelism: Parallelism
Parallelism::None,
)
}
}
Ok(Self::F16(dst))
}
(CpuStorage::F32(lhs), CpuStorage::F32(rhs)) => {
let mut dst = vec![0f32; b * m * n];
for step in 0..b {
let lhs_p = &lhs[step * a_skip..];
let rhs_p = &rhs[step * b_skip..];
let dst_p = &mut dst[step * c_skip..];
unsafe {
gemm(
// m: usize,
m,
// n: usize,
n,
// k: usize,
k,
// dst: *mut T,
dst_p.as_mut_ptr(),
// dst_cs: isize,
dst_cs as isize,
// dst_rs: isize,
dst_rs as isize,
// read_dst: bool,
false,
// lhs: *const T,
lhs_p.as_ptr(),
// lhs_cs: isize,
lhs_cs as isize,
// lhs_rs: isize,
lhs_rs as isize,
// rhs: *const T,
rhs_p.as_ptr(),
// rhs_cs: isize,
rhs_cs as isize,
// rhs_rs: isize,
rhs_rs as isize,
// alpha: T,
1f32,
// beta: T,
1f32,
// conj_dst: bool,
false,
// conj_lhs: bool,
false,
// conj_rhs: bool,
true,
// parallelism: Parallelism
Parallelism::None,
)
}
}
Ok(Self::F32(dst))
}
(CpuStorage::F64(lhs), CpuStorage::F64(rhs)) => {
let mut dst = vec![0f64; b * m * n];
for step in 0..b {
let lhs_p = &lhs[step * a_skip..];
let rhs_p = &rhs[step * b_skip..];
let dst_p = &mut dst[step * c_skip..];
unsafe {
gemm(
// m: usize,
m,
// n: usize,
n,
// k: usize,
k,
// dst: *mut T,
dst_p.as_mut_ptr(),
// dst_cs: isize,
dst_cs as isize,
// dst_rs: isize,
dst_rs as isize,
// read_dst: bool,
false,
// lhs: *const T,
lhs_p.as_ptr(),
// lhs_cs: isize,
lhs_cs as isize,
// lhs_rs: isize,
lhs_rs as isize,
// rhs: *const T,
rhs_p.as_ptr(),
// rhs_cs: isize,
rhs_cs as isize,
// rhs_rs: isize,
rhs_rs as isize,
// alpha: T,
1f64,
// beta: T,
1f64,
// conj_dst: bool,
false,
// conj_lhs: bool,
false,
// conj_rhs: bool,
true,
// parallelism: Parallelism
Parallelism::None,
)
}
}
Ok(Self::F64(dst))
}
_ => {
// This should be covered by the dtype check above.
Err(Error::DTypeMismatchBinaryOp {
lhs: self.dtype(),
rhs: rhs.dtype(),
op: "matmul",
})
}
}
}
pub(crate) fn ones_impl(shape: &Shape, dtype: DType) -> Self {
let elem_count = shape.elem_count();
match dtype {
DType::U32 => {
let data = vec![1u32; elem_count];
Self::U32(data)
}
DType::BF16 => {
let data = vec![bf16::ONE; elem_count];
Self::BF16(data)
}
DType::F16 => {
let data = vec![f16::ONE; elem_count];
Self::F16(data)
}
DType::F32 => {
let data = vec![1f32; elem_count];
Self::F32(data)
}
DType::F64 => {
let data = vec![1f64; elem_count];
Self::F64(data)
}
}
}
pub(crate) fn zeros_impl(shape: &Shape, dtype: DType) -> Self {
let elem_count = shape.elem_count();
match dtype {
DType::U32 => {
let data = vec![0u32; elem_count];
Self::U32(data)
}
DType::BF16 => {
let data = vec![bf16::ZERO; elem_count];
Self::BF16(data)
}
DType::F16 => {
let data = vec![f16::ZERO; elem_count];
Self::F16(data)
}
DType::F32 => {
let data = vec![0f32; elem_count];
Self::F32(data)
}
DType::F64 => {
let data = vec![0f64; elem_count];
Self::F64(data)
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{Device, Tensor};
#[test]
fn simple_matmul() -> Result<()> {
let data = vec![1.0f32, 2.0, 3.0, 4.0];
let a = Tensor::from_slice(&data, (2, 2), &Device::Cpu)?;
let data = vec![1.0f32, 2.0, 3.0, 4.0];
let b = Tensor::from_slice(&data, (2, 2), &Device::Cpu)?;
let c = a.matmul(&b)?;
assert_eq!(c.to_vec2::<f32>()?, &[&[7.0f32, 10.0], &[15.0, 22.0]]);
let data = vec![1.0f32, 2.0];
let a = Tensor::from_slice(&data, (2, 1), &Device::Cpu)?;
let data = vec![3.0f32, 4.0];
let b = Tensor::from_slice(&data, (1, 2), &Device::Cpu)?;
let c = a.matmul(&b)?;
assert_eq!(c.to_vec2::<f32>()?, &[&[3.0, 4.0], &[6.0, 8.0]]);
let data: Vec<_> = (0..6).map(|i| i as f32).collect();
let a = Tensor::from_slice(&data, (2, 3), &Device::Cpu)?;
let data: Vec<_> = (0..6).map(|i| (i + 2) as f32).collect();
let b = Tensor::from_slice(&data, (3, 2), &Device::Cpu)?;
let c = a.matmul(&b)?;
assert_eq!(c.to_vec2::<f32>()?, &[&[16., 19.], &[52., 64.]]);
let data: Vec<_> = (0..12).map(|i| i as f32).collect();
let a = Tensor::from_slice(&data, (2, 2, 3), &Device::Cpu)?;
let data: Vec<_> = (0..12).map(|i| (i + 2) as f32).collect();
let b = Tensor::from_slice(&data, (2, 3, 2), &Device::Cpu)?;
let c = a.matmul(&b)?;
assert_eq!(
c.to_vec3::<f32>()?,
&[&[&[16., 19.], &[52., 64.]], &[&[214., 235.], &[304., 334.]]]
);
Ok(())
}
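#[test]
fn simple_affine() -> Result<()> {
// Hedged companion to the test above: Tensor::affine(mul, add) is assumed to map
// straight onto affine_impl, i.e. v * mul + add element-wise, as the backward
// pass relies on.
let data = vec![1.0f32, 2.0, 3.0, 4.0];
let a = Tensor::from_slice(&data, (2, 2), &Device::Cpu)?;
let b = a.affine(2., 1.)?;
assert_eq!(b.to_vec2::<f32>()?, &[&[3.0, 5.0], &[7.0, 9.0]]);
Ok(())
}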
}

candle-core/src/cuda_backend.rs (new file, 978 lines)

@@ -0,0 +1,978 @@
use crate::{CpuStorage, DType, Shape};
use candle_kernels as kernels;
use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
use cudarc::driver::{CudaFunction, CudaSlice, LaunchAsync, LaunchConfig};
use half::{bf16, f16};
use std::sync::Arc;
/// cudarc related errors
#[derive(thiserror::Error, Debug)]
pub enum CudaError {
#[error(transparent)]
Cuda(#[from] cudarc::driver::DriverError),
#[error(transparent)]
Compiler(#[from] cudarc::nvrtc::CompileError),
#[error(transparent)]
Cublas(#[from] cudarc::cublas::result::CublasError),
#[error("{op} only supports contiguous tensors")]
RequiresContiguous { op: &'static str },
#[error("missing kernel '{module_name}'")]
MissingKernel { module_name: String },
#[error("internal error '{0}'")]
InternalError(&'static str),
#[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")]
MatMulNonContiguous {
lhs_stride: Vec<usize>,
rhs_stride: Vec<usize>,
mnk: (usize, usize, usize),
},
#[error("{msg}, expected: {expected:?}, got: {got:?}")]
UnexpectedDType {
msg: &'static str,
expected: DType,
got: DType,
},
#[error("{cuda} when loading {module_name}")]
Load {
cuda: cudarc::driver::DriverError,
module_name: String,
},
}
type Result<T> = std::result::Result<T, CudaError>;
/// Unique identifier for cuda devices.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub(crate) struct DeviceId(usize);
impl DeviceId {
fn new() -> Self {
// https://users.rust-lang.org/t/idiomatic-rust-way-to-generate-unique-id/33805
use std::sync::atomic;
static COUNTER: atomic::AtomicUsize = atomic::AtomicUsize::new(1);
Self(COUNTER.fetch_add(1, atomic::Ordering::Relaxed))
}
}
#[derive(Debug, Clone)]
pub struct CudaDevice {
id: DeviceId,
device: Arc<cudarc::driver::CudaDevice>,
#[allow(dead_code)]
blas: Arc<cudarc::cublas::CudaBlas>,
}
impl std::ops::Deref for CudaDevice {
type Target = Arc<cudarc::driver::CudaDevice>;
fn deref(&self) -> &Self::Target {
&self.device
}
}
impl CudaDevice {
pub(crate) fn new(ordinal: usize) -> Result<Self> {
let device = cudarc::driver::CudaDevice::new(ordinal)?;
let blas = cudarc::cublas::CudaBlas::new(device.clone())?;
Ok(Self {
id: DeviceId::new(),
device,
blas: Arc::new(blas),
})
}
pub(crate) fn same_id(&self, rhs: &Self) -> bool {
self.id == rhs.id
}
pub(crate) fn ordinal(&self) -> usize {
self.device.ordinal()
}
pub(crate) fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
let elem_count = shape.elem_count();
let slice = match dtype {
DType::U32 => {
let data = self.alloc_zeros::<u32>(elem_count)?;
CudaStorageSlice::U32(data)
}
DType::BF16 => {
let data = self.alloc_zeros::<bf16>(elem_count)?;
CudaStorageSlice::BF16(data)
}
DType::F16 => {
let data = self.alloc_zeros::<f16>(elem_count)?;
CudaStorageSlice::F16(data)
}
DType::F32 => {
let data = self.alloc_zeros::<f32>(elem_count)?;
CudaStorageSlice::F32(data)
}
DType::F64 => {
let data = self.alloc_zeros::<f64>(elem_count)?;
CudaStorageSlice::F64(data)
}
};
Ok(CudaStorage {
slice,
device: self.clone(),
})
}
pub(crate) fn const_impl(&self, v: f64, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
let elem_count = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
let slice = match dtype {
DType::U32 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<u32>(elem_count) }?;
let func = self.get_or_load_func("fill_u32", kernels::FILL)?;
let params = (&data, v as u32, elem_count);
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::U32(data)
}
DType::BF16 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<bf16>(elem_count) }?;
let func = self.get_or_load_func("fill_bf16", kernels::FILL)?;
let params = (&data, bf16::from_f64(v), elem_count);
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::BF16(data)
}
DType::F16 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<f16>(elem_count) }?;
let func = self.get_or_load_func("fill_f16", kernels::FILL)?;
let params = (&data, f16::from_f64(v), elem_count);
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::F16(data)
}
DType::F32 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<f32>(elem_count) }?;
let func = self.get_or_load_func("fill_f32", kernels::FILL)?;
let params = (&data, v as f32, elem_count);
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::F32(data)
}
DType::F64 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<f64>(elem_count) }?;
let func = self.get_or_load_func("fill_f64", kernels::FILL)?;
let params = (&data, v, elem_count);
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::F64(data)
}
};
Ok(CudaStorage {
slice,
device: self.clone(),
})
}
pub(crate) fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
self.const_impl(1., shape, dtype)
}
pub(crate) fn cuda_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> {
let slice = match storage {
CpuStorage::U32(storage) => {
let data = self.htod_sync_copy(storage)?;
CudaStorageSlice::U32(data)
}
CpuStorage::BF16(storage) => {
let data = self.htod_sync_copy(storage)?;
CudaStorageSlice::BF16(data)
}
CpuStorage::F16(storage) => {
let data = self.htod_sync_copy(storage)?;
CudaStorageSlice::F16(data)
}
CpuStorage::F32(storage) => {
let data = self.htod_sync_copy(storage)?;
CudaStorageSlice::F32(data)
}
CpuStorage::F64(storage) => {
let data = self.htod_sync_copy(storage)?;
CudaStorageSlice::F64(data)
}
};
Ok(CudaStorage {
slice,
device: self.clone(),
})
}
fn get_or_load_func(&self, module_name: &str, ptx: &'static str) -> Result<CudaFunction> {
if !self.has_func(module_name, module_name) {
// Leaking the string here is a bit sad but we need a &'static str and this is only
// done once per kernel name.
let static_module_name = Box::leak(module_name.to_string().into_boxed_str());
self.load_ptx(ptx.into(), module_name, &[static_module_name])
.map_err(|cuda| CudaError::Load {
cuda,
module_name: module_name.to_string(),
})?;
}
self.get_func(module_name, module_name)
// Clippy recommends this `ok_or` rather than `ok_or_else` so hopefully the compiler is
// able to only build the error value if needed.
.ok_or(CudaError::MissingKernel {
module_name: module_name.to_string(),
})
}
}
#[derive(Debug)]
enum CudaStorageSlice {
U32(CudaSlice<u32>),
BF16(CudaSlice<bf16>),
F16(CudaSlice<f16>),
F32(CudaSlice<f32>),
F64(CudaSlice<f64>),
}
#[derive(Debug)]
pub struct CudaStorage {
slice: CudaStorageSlice,
device: CudaDevice,
}
fn gemm_config<T>(
alpha: T,
beta: T,
(b, m, n, k): (usize, usize, usize, usize),
lhs_stride: &[usize],
rhs_stride: &[usize],
) -> Result<StridedBatchedConfig<T>> {
// https://docs.nvidia.com/cuda/cublas/index.html#cublas-t-gemm
use cudarc::cublas::sys::cublasOperation_t;
let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
// The cublas `a` operand is our rhs, with dims (batching, k, n).
let (lda, transa) = if rhs_m1 == 1 && rhs_m2 == n {
(n as i32, cublasOperation_t::CUBLAS_OP_N)
} else if rhs_m1 == k && rhs_m2 == 1 {
(k as i32, cublasOperation_t::CUBLAS_OP_T)
} else {
Err(CudaError::MatMulNonContiguous {
lhs_stride: lhs_stride.to_vec(),
rhs_stride: rhs_stride.to_vec(),
mnk: (m, n, k),
})?
};
// The cublas `b` operand is our lhs, with dims (batching, m, k).
let (ldb, transb) = if lhs_m1 == 1 && lhs_m2 == k {
(k as i32, cublasOperation_t::CUBLAS_OP_N)
} else if lhs_m1 == m && lhs_m2 == 1 {
(m as i32, cublasOperation_t::CUBLAS_OP_T)
} else {
Err(CudaError::MatMulNonContiguous {
lhs_stride: lhs_stride.to_vec(),
rhs_stride: rhs_stride.to_vec(),
mnk: (m, n, k),
})?
};
// The setup below was copied from:
// https://github.com/lebedov/scikit-cuda/blob/7e7300474286019c917a6c8a4bca59405c64fbce/tests/test_cublas.py#L531
let gemm = GemmConfig {
alpha,
beta,
m: n as i32,
n: m as i32,
k: k as i32,
lda,
ldb,
ldc: n as i32,
transa,
transb,
};
Ok(StridedBatchedConfig {
batch_size: b as i32,
gemm,
// a is the rhs (k x n per batch) and b the lhs (m x k per batch), matching the swap above.
stride_a: (n * k) as i64,
stride_b: (m * k) as i64,
stride_c: (m * n) as i64,
})
}
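// Worked example for the column-major trick above (values are examples only): a
// row-major matmul with (b, m, n, k) = (1, 2, 4, 3), contiguous lhs (stride [3, 1])
// and contiguous rhs (stride [4, 1]) yields transa = transb = CUBLAS_OP_N with
// lda = n = 4, ldb = k = 3, ldc = n = 4. cublas then computes the column-major
// product rhs^T * lhs^T = (lhs * rhs)^T, whose buffer is exactly the row-major
// m x n result.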
impl CudaStorage {
pub fn try_clone(&self) -> Result<Self> {
let slice = match &self.slice {
CudaStorageSlice::U32(slice) => CudaStorageSlice::U32(slice.try_clone()?),
CudaStorageSlice::BF16(slice) => CudaStorageSlice::BF16(slice.try_clone()?),
CudaStorageSlice::F16(slice) => CudaStorageSlice::F16(slice.try_clone()?),
CudaStorageSlice::F32(slice) => CudaStorageSlice::F32(slice.try_clone()?),
CudaStorageSlice::F64(slice) => CudaStorageSlice::F64(slice.try_clone()?),
};
let device = self.device.clone();
Ok(Self { slice, device })
}
pub fn dtype(&self) -> DType {
match self.slice {
CudaStorageSlice::U32(_) => DType::U32,
CudaStorageSlice::BF16(_) => DType::BF16,
CudaStorageSlice::F16(_) => DType::F16,
CudaStorageSlice::F32(_) => DType::F32,
CudaStorageSlice::F64(_) => DType::F64,
}
}
pub fn device(&self) -> &CudaDevice {
&self.device
}
pub(crate) fn to_dtype(&self, shape: &Shape, stride: &[usize], dtype: DType) -> Result<Self> {
use cudarc::driver::DevicePtr;
let dims = shape.dims();
let el = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(el as u32);
let dev = self.device();
let ds = dev.htod_copy([dims, stride].concat())?;
let inp = match &self.slice {
CudaStorageSlice::U32(inp) => inp.device_ptr(),
CudaStorageSlice::BF16(inp) => inp.device_ptr(),
CudaStorageSlice::F16(inp) => inp.device_ptr(),
CudaStorageSlice::F32(inp) => inp.device_ptr(),
CudaStorageSlice::F64(inp) => inp.device_ptr(),
};
let kernel_name = format!("cast_{}_{}", self.dtype().as_str(), dtype.as_str());
let func = dev.get_or_load_func(&kernel_name, kernels::CAST)?;
let slice = match dtype {
DType::U32 => {
let out = unsafe { dev.alloc::<u32>(el) }?;
let params = (el, dims.len(), &ds, *inp, &out);
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::U32(out)
}
DType::BF16 => {
let out = unsafe { dev.alloc::<bf16>(el) }?;
let params = (el, dims.len(), &ds, *inp, &out);
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::BF16(out)
}
DType::F16 => {
let out = unsafe { dev.alloc::<f16>(el) }?;
let params = (el, dims.len(), &ds, *inp, &out);
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::F16(out)
}
DType::F32 => {
let out = unsafe { dev.alloc::<f32>(el) }?;
let params = (el, dims.len(), &ds, *inp, &out);
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::F32(out)
}
DType::F64 => {
let out = unsafe { dev.alloc::<f64>(el) }?;
let params = (el, dims.len(), &ds, *inp, &out);
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::F64(out)
}
};
Ok(Self {
slice,
device: dev.clone(),
})
}
pub(crate) fn affine_impl(
&self,
shape: &Shape,
stride: &[usize],
mul: f64,
add: f64,
) -> Result<Self> {
let dims = shape.dims();
let el_count = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(el_count as u32);
let dev = self.device();
let ds = dev.htod_copy([dims, stride].concat())?;
let slice = match &self.slice {
CudaStorageSlice::U32(arg) => {
let func = dev.get_or_load_func("affine_u32", kernels::AFFINE)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<u32>(el_count) }?;
let params = (el_count, dims.len(), &ds, arg, &out, mul as u32, add as u32);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::U32(out)
}
CudaStorageSlice::BF16(arg) => {
let func = dev.get_or_load_func("affine_bf16", kernels::AFFINE)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<bf16>(el_count) }?;
let params = (
el_count,
dims.len(),
&ds,
arg,
&out,
bf16::from_f64(mul),
bf16::from_f64(add),
);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::BF16(out)
}
CudaStorageSlice::F16(arg) => {
let func = dev.get_or_load_func("affine_f16", kernels::AFFINE)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<f16>(el_count) }?;
let params = (
el_count,
dims.len(),
&ds,
arg,
&out,
f16::from_f64(mul),
f16::from_f64(add),
);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::F16(out)
}
CudaStorageSlice::F32(arg) => {
let func = dev.get_or_load_func("affine_f32", kernels::AFFINE)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<f32>(el_count) }?;
let params = (el_count, dims.len(), &ds, arg, &out, mul as f32, add as f32);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::F32(out)
}
CudaStorageSlice::F64(arg) => {
let func = dev.get_or_load_func("affine_f64", kernels::AFFINE)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<f64>(el_count) }?;
let params = (el_count, dims.len(), &ds, arg, &out, mul, add);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::F64(out)
}
};
let device = dev.clone();
Ok(Self { slice, device })
}
pub(crate) fn sum(&self, shape: &Shape, stride: &[usize], sum_dims: &[usize]) -> Result<Self> {
let src_dims = shape.dims();
let el = shape.elem_count();
let mut dst_el = el;
for &sum_dim in sum_dims.iter() {
dst_el /= src_dims[sum_dim];
}
let mut sum_dims = sum_dims.to_vec();
// Sort the sum_dims as they have to be processed from left to right when converting the
// indexes.
sum_dims.sort();
let sum_dims_l: Vec<usize> = sum_dims.iter().map(|&d| src_dims[d]).collect();
let sum_dims_s: Vec<usize> = sum_dims
.iter()
.map(|&d| src_dims[d + 1..].iter().product::<usize>())
.collect();
let cfg = LaunchConfig::for_num_elems(el as u32);
let dev = self.device();
let ds = dev.htod_copy([src_dims, stride, &sum_dims_l, &sum_dims_s].concat())?;
let slice = match &self.slice {
CudaStorageSlice::U32(arg) => {
let func = dev.get_or_load_func("sum_u32", kernels::REDUCE)?;
let out = dev.alloc_zeros::<u32>(dst_el)?;
let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::U32(out)
}
CudaStorageSlice::BF16(arg) => {
let func = dev.get_or_load_func("sum_bf16", kernels::REDUCE)?;
let out = dev.alloc_zeros::<bf16>(dst_el)?;
let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::BF16(out)
}
CudaStorageSlice::F16(arg) => {
let func = dev.get_or_load_func("sum_f16", kernels::REDUCE)?;
let out = dev.alloc_zeros::<f16>(dst_el)?;
let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::F16(out)
}
CudaStorageSlice::F32(arg) => {
let func = dev.get_or_load_func("sum_f32", kernels::REDUCE)?;
let out = dev.alloc_zeros::<f32>(dst_el)?;
let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::F32(out)
}
CudaStorageSlice::F64(arg) => {
let func = dev.get_or_load_func("sum_f64", kernels::REDUCE)?;
let out = dev.alloc_zeros::<f64>(dst_el)?;
let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::F64(out)
}
};
let device = dev.clone();
Ok(Self { slice, device })
}
pub(crate) fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()> {
Err(CudaError::InternalError(
"TODO: implement divide_by_sum_over_dim",
))
}
pub(crate) fn unary_impl<U: crate::op::UnaryOp>(
&self,
shape: &Shape,
stride: &[usize],
) -> Result<Self> {
let dims = shape.dims();
let el_count = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(el_count as u32);
let dev = &self.device;
let ds = dev.htod_copy([dims, stride].concat())?;
let slice = match &self.slice {
CudaStorageSlice::U32(_arg) => {
todo!("No unary kernels for u32");
}
CudaStorageSlice::BF16(arg) => {
let func = dev.get_or_load_func(U::KERNEL_BF16, kernels::UNARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<bf16>(el_count) }?;
let params = (el_count, dims.len(), &ds, arg, &out);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::BF16(out)
}
CudaStorageSlice::F16(arg) => {
let func = dev.get_or_load_func(U::KERNEL_F16, kernels::UNARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<f16>(el_count) }?;
let params = (el_count, dims.len(), &ds, arg, &out);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::F16(out)
}
CudaStorageSlice::F32(arg) => {
let func = dev.get_or_load_func(U::KERNEL_F32, kernels::UNARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<f32>(el_count) }?;
let params = (el_count, dims.len(), &ds, arg, &out);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::F32(out)
}
CudaStorageSlice::F64(arg) => {
let func = dev.get_or_load_func(U::KERNEL_F64, kernels::UNARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<f64>(el_count) }?;
let params = (el_count, dims.len(), &ds, arg, &out);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::F64(out)
}
};
let device = dev.clone();
Ok(Self { slice, device })
}
pub(crate) fn binary_impl<B: crate::op::BinaryOp>(
&self,
rhs: &Self,
shape: &Shape,
lhs_stride: &[usize],
rhs_stride: &[usize],
) -> Result<Self> {
let dims = shape.dims();
let elem_count = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
let dev = self.device();
let dims_and_strides = dev.htod_copy([dims, lhs_stride, rhs_stride].concat())?;
let slice = match (&self.slice, &rhs.slice) {
(CudaStorageSlice::BF16(lhs), CudaStorageSlice::BF16(rhs)) => {
let func = dev.get_or_load_func(B::KERNEL_BF16, kernels::BINARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<bf16>(elem_count) }?;
let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out);
// SAFETY: ffi
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::BF16(out)
}
(CudaStorageSlice::F16(lhs), CudaStorageSlice::F16(rhs)) => {
let func = dev.get_or_load_func(B::KERNEL_F16, kernels::BINARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<f16>(elem_count) }?;
let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out);
// SAFETY: ffi
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::F16(out)
}
(CudaStorageSlice::F32(lhs), CudaStorageSlice::F32(rhs)) => {
let func = dev.get_or_load_func(B::KERNEL_F32, kernels::BINARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<f32>(elem_count) }?;
let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out);
// SAFETY: ffi
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::F32(out)
}
(CudaStorageSlice::F64(lhs), CudaStorageSlice::F64(rhs)) => {
// SAFETY: Set later by running the kernel.
let func = dev.get_or_load_func(B::KERNEL_F64, kernels::BINARY)?;
let out = unsafe { dev.alloc::<f64>(elem_count) }?;
let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out);
// SAFETY: ffi
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::F64(out)
}
(CudaStorageSlice::U32(lhs), CudaStorageSlice::U32(rhs)) => {
// SAFETY: Set later by running the kernel.
let func = dev.get_or_load_func(B::KERNEL_U32, kernels::BINARY)?;
let out = unsafe { dev.alloc::<u32>(elem_count) }?;
let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out);
// SAFETY: ffi
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::U32(out)
}
// The dtypes should have been checked at this point so this is an internal error.
_ => return Err(CudaError::InternalError("dtype mismatch in binary op")),
};
let device = dev.clone();
Ok(Self { slice, device })
}
pub(crate) fn to_cpu_storage(&self) -> Result<CpuStorage> {
match &self.slice {
CudaStorageSlice::U32(slice) => {
let dev = slice.device();
let cpu_storage = dev.dtoh_sync_copy(slice)?;
Ok(CpuStorage::U32(cpu_storage))
}
CudaStorageSlice::BF16(slice) => {
let dev = slice.device();
let cpu_storage = dev.dtoh_sync_copy(slice)?;
Ok(CpuStorage::BF16(cpu_storage))
}
CudaStorageSlice::F16(slice) => {
let dev = slice.device();
let cpu_storage = dev.dtoh_sync_copy(slice)?;
Ok(CpuStorage::F16(cpu_storage))
}
CudaStorageSlice::F32(slice) => {
let dev = slice.device();
let cpu_storage = dev.dtoh_sync_copy(slice)?;
Ok(CpuStorage::F32(cpu_storage))
}
CudaStorageSlice::F64(slice) => {
let dev = slice.device();
let cpu_storage = dev.dtoh_sync_copy(slice)?;
Ok(CpuStorage::F64(cpu_storage))
}
}
}
pub(crate) fn where_cond(
&self,
shape: &Shape,
stride: &[usize],
t: &Self,
stride_t: &[usize],
f: &Self,
stride_f: &[usize],
) -> Result<Self> {
let ids = match &self.slice {
CudaStorageSlice::U32(slice) => slice,
_ => Err(CudaError::UnexpectedDType {
msg: "where conditions should be u32",
expected: DType::U32,
got: self.dtype(),
})?,
};
let dims = shape.dims();
let el = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(el as u32);
let dev = self.device();
let ds = dev.htod_copy([dims, stride, stride_t, stride_f].concat())?;
let slice = match (&t.slice, &f.slice) {
(CudaStorageSlice::BF16(t), CudaStorageSlice::BF16(f)) => {
let func = dev.get_or_load_func("where_bf16", kernels::TERNARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<bf16>(el) }?;
let params = (el, dims.len(), &ds, ids, t, f, &out);
// SAFETY: ffi
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::BF16(out)
}
(CudaStorageSlice::F16(t), CudaStorageSlice::F16(f)) => {
let func = dev.get_or_load_func("where_f16", kernels::TERNARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<f16>(el) }?;
let params = (el, dims.len(), &ds, ids, t, f, &out);
// SAFETY: ffi
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::F16(out)
}
(CudaStorageSlice::F32(t), CudaStorageSlice::F32(f)) => {
let func = dev.get_or_load_func("where_f32", kernels::TERNARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<f32>(el) }?;
let params = (el, dims.len(), &ds, ids, t, f, &out);
// SAFETY: ffi
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::F32(out)
}
(CudaStorageSlice::F64(t), CudaStorageSlice::F64(f)) => {
// SAFETY: Set later by running the kernel.
let func = dev.get_or_load_func("where_f64", kernels::TERNARY)?;
let out = unsafe { dev.alloc::<f64>(el) }?;
let params = (el, dims.len(), &ds, ids, t, f, &out);
// SAFETY: ffi
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::F64(out)
}
(CudaStorageSlice::U32(t), CudaStorageSlice::U32(f)) => {
// SAFETY: Set later by running the kernel.
let func = dev.get_or_load_func("where_u32", kernels::TERNARY)?;
let out = unsafe { dev.alloc::<u32>(el) }?;
let params = (el, dims.len(), &ds, ids, t, f, &out);
// SAFETY: ffi
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::U32(out)
}
// The dtypes should have been checked at this point so this is an internal error.
_ => return Err(CudaError::InternalError("dtype mismatch in where_cond op")),
};
let device = dev.clone();
Ok(Self { slice, device })
}
pub(crate) fn embedding_impl(
&self,
shape: &Shape,
stride: &[usize],
rhs: &Self,
h_size: usize, // hidden size
v_size: usize, // vocab size
) -> Result<Self> {
let ids = match &self.slice {
CudaStorageSlice::U32(slice) => slice,
_ => Err(CudaError::UnexpectedDType {
msg: "embedding ids should be u32",
expected: DType::U32,
got: self.dtype(),
})?,
};
let dims = shape.dims();
let el = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(el as u32);
let dev = self.device();
let ds = dev.htod_copy([dims, stride].concat())?;
let slice = match &rhs.slice {
// The kernels below assume that rhs is contiguous.
CudaStorageSlice::U32(arg) => {
let func = dev.get_or_load_func("emb_u32", kernels::EMBEDDINGS)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<u32>(el * h_size) }?;
let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::U32(out)
}
CudaStorageSlice::BF16(arg) => {
let func = dev.get_or_load_func("emb_bf16", kernels::EMBEDDINGS)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<bf16>(el * h_size) }?;
let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::BF16(out)
}
CudaStorageSlice::F16(arg) => {
let func = dev.get_or_load_func("emb_f16", kernels::EMBEDDINGS)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<f16>(el * h_size) }?;
let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::F16(out)
}
CudaStorageSlice::F32(arg) => {
let func = dev.get_or_load_func("emb_f32", kernels::EMBEDDINGS)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<f32>(el * h_size) }?;
let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::F32(out)
}
CudaStorageSlice::F64(arg) => {
let func = dev.get_or_load_func("emb_f64", kernels::EMBEDDINGS)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<f64>(el * h_size) }?;
let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::F64(out)
}
};
let device = dev.clone();
Ok(Self { slice, device })
}
pub(crate) fn matmul_impl(
&self,
rhs: &Self,
(b, m, n, k): (usize, usize, usize, usize),
lhs_stride: &[usize],
rhs_stride: &[usize],
) -> Result<Self> {
let elem_count = b * m * n;
let dev = &self.device;
let slice = match (&self.slice, &rhs.slice) {
(CudaStorageSlice::BF16(_lhs), CudaStorageSlice::BF16(_rhs)) => {
todo!("bf16")
}
(CudaStorageSlice::F16(lhs), CudaStorageSlice::F16(rhs)) => {
let cfg = gemm_config(f16::ONE, f16::ZERO, (b, m, n, k), lhs_stride, rhs_stride)?;
let mut out = unsafe { dev.alloc::<f16>(elem_count) }?;
unsafe {
self.device
.blas
.gemm_strided_batched(cfg, rhs, lhs, &mut out)
}?;
CudaStorageSlice::F16(out)
}
(CudaStorageSlice::F32(lhs), CudaStorageSlice::F32(rhs)) => {
let cfg = gemm_config(1., 0., (b, m, n, k), lhs_stride, rhs_stride)?;
let mut out = unsafe { dev.alloc::<f32>(elem_count) }?;
unsafe {
self.device
.blas
.gemm_strided_batched(cfg, rhs, lhs, &mut out)
}?;
CudaStorageSlice::F32(out)
}
(CudaStorageSlice::F64(lhs), CudaStorageSlice::F64(rhs)) => {
let cfg = gemm_config(1., 0., (b, m, n, k), lhs_stride, rhs_stride)?;
let mut out = unsafe { dev.alloc::<f64>(elem_count) }?;
unsafe {
self.device
.blas
.gemm_strided_batched(cfg, rhs, lhs, &mut out)
}?;
CudaStorageSlice::F64(out)
}
_ => return Err(CudaError::InternalError("dtype mismatch in matmul op")),
};
let device = dev.clone();
Ok(Self { slice, device })
}
pub(crate) fn copy_strided_src(
&self,
dst: &mut Self,
dst_offset: usize,
src_shape: &Shape,
src_stride: &[usize],
src_offset: usize,
) -> Result<()> {
if src_shape.rank() != src_stride.len() {
panic!("incoherent shape and strides {src_shape:?} {src_stride:?}")
}
let dims = src_shape.dims();
let el_count = src_shape.elem_count();
let cfg = LaunchConfig::for_num_elems(el_count as u32);
let dev = &self.device;
let ds = dev.htod_copy([dims, src_stride].concat())?;
match (&self.slice, &mut dst.slice) {
(CudaStorageSlice::BF16(src), CudaStorageSlice::BF16(dst)) => {
let src = src.slice(src_offset..);
let mut dst = dst.slice_mut(dst_offset..);
if src_shape.is_contiguous(src_stride) {
dev.dtod_copy(&src, &mut dst)?
} else {
let func = dev.get_or_load_func("ucopy_bf16", kernels::UNARY)?;
// SAFETY: Set later by running the kernel.
let params = (el_count, dims.len(), &ds, &src, &mut dst);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?
}
}
(CudaStorageSlice::F16(src), CudaStorageSlice::F16(dst)) => {
let src = src.slice(src_offset..);
let mut dst = dst.slice_mut(dst_offset..);
if src_shape.is_contiguous(src_stride) {
dev.dtod_copy(&src, &mut dst)?
} else {
let func = dev.get_or_load_func("ucopy_f16", kernels::UNARY)?;
let params = (el_count, dims.len(), &ds, &src, &mut dst);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?
}
}
(CudaStorageSlice::F32(src), CudaStorageSlice::F32(dst)) => {
let src = src.slice(src_offset..);
let mut dst = dst.slice_mut(dst_offset..);
if src_shape.is_contiguous(src_stride) {
dev.dtod_copy(&src, &mut dst)?
} else {
let func = dev.get_or_load_func("ucopy_f32", kernels::UNARY)?;
let params = (el_count, dims.len(), &ds, &src, &mut dst);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?
}
}
(CudaStorageSlice::U32(src), CudaStorageSlice::U32(dst)) => {
let src = src.slice(src_offset..);
let mut dst = dst.slice_mut(dst_offset..);
if src_shape.is_contiguous(src_stride) {
dev.dtod_copy(&src, &mut dst)?
} else {
let func = dev.get_or_load_func("ucopy_u32", kernels::UNARY)?;
let params = (el_count, dims.len(), &ds, &src, &mut dst);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?
}
}
(CudaStorageSlice::F64(src), CudaStorageSlice::F64(dst)) => {
let src = src.slice(src_offset..);
let mut dst = dst.slice_mut(dst_offset..);
if src_shape.is_contiguous(src_stride) {
dev.dtod_copy(&src, &mut dst)?
} else {
let func = dev.get_or_load_func("ucopy_64", kernels::UNARY)?;
let params = (el_count, dims.len(), &ds, &src, &mut dst);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
}
}
_ => {
return Err(CudaError::InternalError(
"dtype mismatch in copy_strided op",
))
}
}
Ok(())
}
}
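
A note on the gemm_strided_batched calls in matmul_impl above: rhs is passed before lhs, presumably the standard trick for driving a column-major BLAS from row-major data, since a row-major matrix reinterpreted as column-major is its transpose and (A B)^T = B^T A^T. A minimal plain-Rust check of that identity (no CUDA required):

fn matmul(a: &[[f32; 2]; 2], b: &[[f32; 2]; 2]) -> [[f32; 2]; 2] {
    let mut c = [[0f32; 2]; 2];
    for i in 0..2 {
        for j in 0..2 {
            for k in 0..2 {
                c[i][j] += a[i][k] * b[k][j];
            }
        }
    }
    c
}

fn transpose(m: &[[f32; 2]; 2]) -> [[f32; 2]; 2] {
    [[m[0][0], m[1][0]], [m[0][1], m[1][1]]]
}

fn main() {
    let a = [[1f32, 2.], [3., 4.]];
    let b = [[5f32, 6.], [7., 8.]];
    // Computing B^T A^T and transposing the result recovers A B.
    assert_eq!(transpose(&matmul(&a, &b)), matmul(&transpose(&b), &transpose(&a)));
}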

159
candle-core/src/device.rs Normal file
View File

@ -0,0 +1,159 @@
use crate::{CpuStorage, DType, Result, Shape, Storage, WithDType};
/// A `DeviceLocation` represents a physical device whereas multiple `Device`
/// instances can live on the same location (typically for cuda devices).
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum DeviceLocation {
Cpu,
Cuda { gpu_id: usize },
}
#[derive(Debug, Clone)]
pub enum Device {
Cpu,
Cuda(crate::CudaDevice),
}
// TODO: Should we back the cpu implementation using the NdArray crate or similar?
pub trait NdArray {
fn shape(&self) -> Result<Shape>;
fn to_cpu_storage(&self) -> CpuStorage;
}
impl<S: WithDType> NdArray for S {
fn shape(&self) -> Result<Shape> {
Ok(Shape::from(()))
}
fn to_cpu_storage(&self) -> CpuStorage {
S::to_cpu_storage(&[*self])
}
}
impl<S: WithDType, const N: usize> NdArray for &[S; N] {
fn shape(&self) -> Result<Shape> {
Ok(Shape::from(self.len()))
}
fn to_cpu_storage(&self) -> CpuStorage {
S::to_cpu_storage(self.as_slice())
}
}
impl<S: WithDType> NdArray for &[S] {
fn shape(&self) -> Result<Shape> {
Ok(Shape::from(self.len()))
}
fn to_cpu_storage(&self) -> CpuStorage {
S::to_cpu_storage(self)
}
}
impl<S: WithDType, const N: usize, const M: usize> NdArray for &[[S; N]; M] {
fn shape(&self) -> Result<Shape> {
Ok(Shape::from((M, N)))
}
fn to_cpu_storage(&self) -> CpuStorage {
S::to_cpu_storage_owned(self.concat())
}
}
impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize> NdArray
for &[[[S; N3]; N2]; N1]
{
fn shape(&self) -> Result<Shape> {
Ok(Shape::from((N1, N2, N3)))
}
fn to_cpu_storage(&self) -> CpuStorage {
        let mut vec = Vec::with_capacity(N1 * N2 * N3);
for i1 in 0..N1 {
for i2 in 0..N2 {
vec.extend(self[i1][i2])
}
}
S::to_cpu_storage_owned(vec)
}
}
impl Device {
pub fn new_cuda(ordinal: usize) -> Result<Self> {
Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
}
pub fn same_id(&self, rhs: &Self) -> bool {
match (self, rhs) {
(Self::Cpu, Self::Cpu) => true,
(Self::Cuda(lhs), Self::Cuda(rhs)) => lhs.same_id(rhs),
_ => false,
}
}
pub fn location(&self) -> DeviceLocation {
match self {
Self::Cpu => DeviceLocation::Cpu,
Self::Cuda(device) => DeviceLocation::Cuda {
gpu_id: device.ordinal(),
},
}
}
pub fn is_cuda(&self) -> bool {
match self {
Self::Cpu => false,
Self::Cuda(_) => true,
}
}
pub(crate) fn ones(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
match self {
Device::Cpu => {
let storage = CpuStorage::ones_impl(shape, dtype);
Ok(Storage::Cpu(storage))
}
Device::Cuda(device) => {
let storage = device.ones_impl(shape, dtype)?;
Ok(Storage::Cuda(storage))
}
}
}
pub(crate) fn zeros(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
match self {
Device::Cpu => {
let storage = CpuStorage::zeros_impl(shape, dtype);
Ok(Storage::Cpu(storage))
}
Device::Cuda(device) => {
let storage = device.zeros_impl(shape, dtype)?;
Ok(Storage::Cuda(storage))
}
}
}
pub(crate) fn storage<A: NdArray>(&self, array: A) -> Result<Storage> {
match self {
Device::Cpu => Ok(Storage::Cpu(array.to_cpu_storage())),
Device::Cuda(device) => {
let storage = array.to_cpu_storage();
let storage = device.cuda_from_cpu_storage(&storage)?;
Ok(Storage::Cuda(storage))
}
}
}
pub(crate) fn storage_owned<S: WithDType>(&self, data: Vec<S>) -> Result<Storage> {
match self {
Device::Cpu => Ok(Storage::Cpu(S::to_cpu_storage_owned(data))),
Device::Cuda(device) => {
let storage = S::to_cpu_storage_owned(data);
let storage = device.cuda_from_cpu_storage(&storage)?;
Ok(Storage::Cuda(storage))
}
}
}
}
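
A short usage sketch of the public Device API above, assuming the crate is consumed under its candle_core name (see the re-exports in lib.rs below):

use candle_core::{Device, DeviceLocation};

fn main() -> candle_core::Result<()> {
    let cpu = Device::Cpu;
    assert_eq!(cpu.location(), DeviceLocation::Cpu);
    assert!(!cpu.is_cuda());
    assert!(cpu.same_id(&Device::Cpu));
    // new_cuda is fallible: it errors when built without cuda support or
    // when the ordinal does not correspond to a visible device.
    if let Ok(gpu) = Device::new_cuda(0) {
        assert_eq!(gpu.location(), DeviceLocation::Cuda { gpu_id: 0 });
    }
    Ok(())
}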

96
candle-core/src/dtype.rs Normal file
View File

@ -0,0 +1,96 @@
use crate::{CpuStorage, Error, Result};
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum DType {
U32,
BF16,
F16,
F32,
F64,
}
impl DType {
pub fn as_str(&self) -> &'static str {
match self {
Self::U32 => "u32",
Self::BF16 => "bf16",
Self::F16 => "f16",
Self::F32 => "f32",
Self::F64 => "f64",
}
}
pub fn size_in_bytes(&self) -> usize {
match self {
Self::U32 => 4,
Self::BF16 => 2,
Self::F16 => 2,
Self::F32 => 4,
Self::F64 => 8,
}
}
}
pub trait WithDType: Sized + Copy {
const DTYPE: DType;
fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage;
fn to_cpu_storage(data: &[Self]) -> CpuStorage {
Self::to_cpu_storage_owned(data.to_vec())
}
fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]>;
fn cpu_storage_as_mut_slice(s: &mut CpuStorage) -> Result<&mut [Self]>;
fn cpu_storage_data(s: CpuStorage) -> Result<Vec<Self>>;
}
macro_rules! with_dtype {
($ty:ty, $dtype:ident) => {
impl WithDType for $ty {
const DTYPE: DType = DType::$dtype;
fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage {
CpuStorage::$dtype(data)
}
fn cpu_storage_data(s: CpuStorage) -> Result<Vec<Self>> {
match s {
CpuStorage::$dtype(data) => Ok(data),
_ => Err(Error::UnexpectedDType {
expected: DType::$dtype,
got: s.dtype(),
msg: "unexpected dtype",
}),
}
}
fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]> {
match s {
CpuStorage::$dtype(data) => Ok(data),
_ => Err(Error::UnexpectedDType {
expected: DType::$dtype,
got: s.dtype(),
msg: "unexpected dtype",
}),
}
}
fn cpu_storage_as_mut_slice(s: &mut CpuStorage) -> Result<&mut [Self]> {
match s {
CpuStorage::$dtype(data) => Ok(data),
_ => Err(Error::UnexpectedDType {
expected: DType::$dtype,
got: s.dtype(),
msg: "unexpected dtype",
}),
}
}
}
};
}
with_dtype!(u32, U32);
with_dtype!(half::f16, F16);
with_dtype!(half::bf16, BF16);
with_dtype!(f32, F32);
with_dtype!(f64, F64);
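
A small sketch exercising the DType metadata and the WithDType conversions generated by with_dtype! above (again assuming the candle_core crate name):

use candle_core::{DType, WithDType};

fn main() {
    assert_eq!(DType::F32.size_in_bytes(), 4);
    assert_eq!(DType::BF16.as_str(), "bf16");

    // Round-trip a slice through CpuStorage.
    let storage = f32::to_cpu_storage(&[1.0, 2.0, 3.0]);
    let slice = f32::cpu_storage_as_slice(&storage).unwrap();
    assert_eq!(slice, &[1.0, 2.0, 3.0]);
    // Asking for the wrong dtype yields Error::UnexpectedDType.
    assert!(u32::cpu_storage_as_slice(&storage).is_err());
}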

136
candle-core/src/dummy_cuda_backend.rs Normal file
View File

@ -0,0 +1,136 @@
#![allow(dead_code)]
use crate::{CpuStorage, DType, Error, Result, Shape};
#[derive(thiserror::Error, Debug)]
pub enum DummyError {}
pub type CudaError = DummyError;
#[derive(Debug, Clone)]
pub struct CudaDevice;
macro_rules! fail {
() => {
unimplemented!("cuda support has not been enabled")
};
}
impl CudaDevice {
pub(crate) fn new(_: usize) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
pub(crate) fn same_id(&self, _: &Self) -> bool {
true
}
pub(crate) fn ordinal(&self) -> usize {
fail!()
}
pub(crate) fn zeros_impl(&self, _shape: &Shape, _dtype: DType) -> Result<CudaStorage> {
Err(Error::NotCompiledWithCudaSupport)
}
pub(crate) fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<CudaStorage> {
Err(Error::NotCompiledWithCudaSupport)
}
pub(crate) fn cuda_from_cpu_storage(&self, _: &CpuStorage) -> Result<CudaStorage> {
Err(Error::NotCompiledWithCudaSupport)
}
}
#[derive(Debug)]
pub struct CudaStorage;
impl CudaStorage {
pub fn try_clone(&self) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
pub fn dtype(&self) -> DType {
fail!()
}
pub fn device(&self) -> &CudaDevice {
fail!()
}
pub(crate) fn to_cpu_storage(&self) -> Result<CpuStorage> {
Err(Error::NotCompiledWithCudaSupport)
}
pub(crate) fn affine_impl(&self, _: &Shape, _: &[usize], _: f64, _: f64) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
pub(crate) fn sum(&self, _: &Shape, _: &[usize], _: &[usize]) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
pub(crate) fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()> {
Err(Error::NotCompiledWithCudaSupport)
}
pub(crate) fn to_dtype(&self, _: &Shape, _: &[usize], _: DType) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
pub(crate) fn unary_impl<B: crate::op::UnaryOp>(&self, _: &Shape, _: &[usize]) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
pub(crate) fn binary_impl<B: crate::op::BinaryOp>(
&self,
_: &Self,
_: &Shape,
_: &[usize],
_: &[usize],
) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
pub(crate) fn where_cond(
&self,
_: &Shape,
_: &[usize],
_: &Self,
_: &[usize],
_: &Self,
_: &[usize],
) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
pub(crate) fn embedding_impl(
&self,
_: &Shape,
_: &[usize],
_: &Self,
_: usize,
_: usize,
) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
pub(crate) fn matmul_impl(
&self,
_: &Self,
_: (usize, usize, usize, usize),
_: &[usize],
_: &[usize],
) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
pub(crate) fn copy_strided_src(
&self,
_: &mut Self,
_: usize,
_: &Shape,
_: &[usize],
_: usize,
) -> Result<()> {
Err(Error::NotCompiledWithCudaSupport)
}
}
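
With this dummy backend compiled in (i.e. without the cuda feature), the cuda entry points still exist but fail at runtime rather than at compile time; a minimal sketch:

use candle_core::{Device, Error};

fn main() {
    match Device::new_cuda(0) {
        Ok(_) => println!("cuda backend available"),
        Err(Error::NotCompiledWithCudaSupport) => {
            println!("built without cuda support, falling back to cpu")
        }
        Err(e) => println!("cuda initialization failed: {e}"),
    }
}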

102
candle-core/src/error.rs Normal file
View File

@ -0,0 +1,102 @@
use crate::{DType, DeviceLocation, Shape};
/// Main library error type.
#[derive(thiserror::Error, Debug)]
pub enum Error {
#[error("{msg}, expected: {expected:?}, got: {got:?}")]
UnexpectedDType {
msg: &'static str,
expected: DType,
got: DType,
},
#[error("{op} only supports contiguous tensors")]
RequiresContiguous { op: &'static str },
#[error("{op} expects at least one tensor")]
OpRequiresAtLeastOneTensor { op: &'static str },
#[error("backward is not supported for {op}")]
BackwardNotSupported { op: &'static str },
#[error("{op} invalid index {index} with vocab {vocab_size}")]
InvalidIndex {
op: &'static str,
index: usize,
vocab_size: usize,
},
#[error("the candle crate has not been built with cuda support")]
NotCompiledWithCudaSupport,
    #[error(
        "shape mismatch, got a buffer of size {buffer_size} which is incompatible with shape {shape:?}"
    )]
ShapeMismatch { buffer_size: usize, shape: Shape },
#[error("shape mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")]
ShapeMismatchBinaryOp {
lhs: Shape,
rhs: Shape,
op: &'static str,
},
#[error("shape mismatch in cat for dim {dim}, shape for arg 1: {first_shape:?} shape for arg {n}: {nth_shape:?}")]
ShapeMismatchCat {
dim: usize,
first_shape: Shape,
n: usize,
nth_shape: Shape,
},
#[error("device mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")]
DeviceMismatchBinaryOp {
lhs: DeviceLocation,
rhs: DeviceLocation,
op: &'static str,
},
#[error("dtype mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")]
DTypeMismatchBinaryOp {
lhs: DType,
rhs: DType,
op: &'static str,
},
#[error("unexpected rank, expected: {expected}, got: {got} ({shape:?})")]
UnexpectedNumberOfDims {
expected: usize,
got: usize,
shape: Shape,
},
    // TODO: This is temporary until matmul supports arbitrary striding.
#[error("temporary error where matmul doesn't support arbitrary striding")]
UnexpectedStriding,
#[error(transparent)]
Cuda(#[from] crate::CudaError),
#[error(transparent)]
TryFromIntError(#[from] core::num::TryFromIntError),
#[error("npy/npz error {0}")]
Npy(String),
/// Zip file format error.
#[error(transparent)]
Zip(#[from] zip::result::ZipError),
/// Integer parse error.
#[error(transparent)]
ParseInt(#[from] std::num::ParseIntError),
/// I/O error.
#[error(transparent)]
Io(#[from] std::io::Error),
#[error("cannot broadcast {src_shape:?} to {dst_shape:?}")]
BroadcastIncompatibleShapes { src_shape: Shape, dst_shape: Shape },
}
pub type Result<T> = std::result::Result<T, Error>;
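
A brief sketch of raising and matching these errors from downstream code; thiserror derives Display from the #[error(...)] attributes, so variants can be matched structurally or just printed:

use candle_core::{DType, Error, Result};

fn expect_f32(dtype: DType) -> Result<()> {
    if dtype != DType::F32 {
        return Err(Error::UnexpectedDType {
            msg: "expected a float tensor",
            expected: DType::F32,
            got: dtype,
        });
    }
    Ok(())
}

fn main() {
    match expect_f32(DType::U32) {
        Err(Error::UnexpectedDType { expected, got, .. }) => {
            println!("dtype mismatch: expected {expected:?}, got {got:?}")
        }
        _ => println!("ok"),
    }
}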

29
candle-core/src/lib.rs Normal file
View File

@ -0,0 +1,29 @@
mod backprop;
mod cpu_backend;
#[cfg(feature = "cuda")]
mod cuda_backend;
mod device;
mod dtype;
mod dummy_cuda_backend;
mod error;
mod npy;
mod op;
mod shape;
mod storage;
mod strided_index;
mod tensor;
pub use cpu_backend::CpuStorage;
pub use device::{Device, DeviceLocation};
pub use dtype::{DType, WithDType};
pub use error::{Error, Result};
pub use shape::Shape;
pub use storage::Storage;
use strided_index::StridedIndex;
pub use tensor::{Tensor, TensorId};
#[cfg(feature = "cuda")]
pub use cuda_backend::{CudaDevice, CudaError, CudaStorage};
#[cfg(not(feature = "cuda"))]
pub use dummy_cuda_backend::{CudaDevice, CudaError, CudaStorage};

401
candle-core/src/npy.rs Normal file
View File

@ -0,0 +1,401 @@
//! Numpy support for tensors.
//!
//! The spec for the npy format can be found in
//! [npy-format](https://docs.scipy.org/doc/numpy-1.14.2/neps/npy-format.html).
//! The functions from this module can be used to read tensors from npy/npz files
//! or write tensors to these files. An npy file contains a single (unnamed) tensor
//! whereas an npz file is a zip archive that can contain multiple named tensors.
//!
//! These two formats are easy to use in Python using the numpy library.
//!
//! ```python
//! import numpy as np
//! x = np.arange(10)
//!
//! # Write a npy file.
//! np.save("test.npy", x)
//!
//! # Read a value from the npy file.
//! x = np.load("test.npy")
//!
//! # Write multiple values to a npz file.
//! values = { "x": x, "x_plus_one": x + 1 }
//! np.savez("test.npz", **values)
//!
//! # Load multiple values from a npz file.
//! values = np.load("test.npz")
//! ```
use crate::{DType, Device, Error, Result, Shape, Tensor};
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use half::{bf16, f16, slice::HalfFloatSliceExt};
use std::collections::HashMap;
use std::fs::File;
use std::io::{BufReader, Read, Write};
use std::path::Path;
const NPY_MAGIC_STRING: &[u8] = b"\x93NUMPY";
const NPY_SUFFIX: &str = ".npy";
fn read_header<R: Read>(reader: &mut R) -> Result<String> {
let mut magic_string = vec![0u8; NPY_MAGIC_STRING.len()];
reader.read_exact(&mut magic_string)?;
if magic_string != NPY_MAGIC_STRING {
return Err(Error::Npy("magic string mismatch".to_string()));
}
let mut version = [0u8; 2];
reader.read_exact(&mut version)?;
let header_len_len = match version[0] {
1 => 2,
2 => 4,
otherwise => return Err(Error::Npy(format!("unsupported version {otherwise}"))),
};
let mut header_len = vec![0u8; header_len_len];
reader.read_exact(&mut header_len)?;
let header_len = header_len
.iter()
.rev()
.fold(0_usize, |acc, &v| 256 * acc + v as usize);
let mut header = vec![0u8; header_len];
reader.read_exact(&mut header)?;
Ok(String::from_utf8_lossy(&header).to_string())
}
#[derive(Debug, PartialEq)]
struct Header {
descr: DType,
fortran_order: bool,
shape: Vec<usize>,
}
impl Header {
fn shape(&self) -> Shape {
Shape::from(self.shape.as_slice())
}
fn to_string(&self) -> Result<String> {
let fortran_order = if self.fortran_order { "True" } else { "False" };
let mut shape = self
.shape
.iter()
.map(|x| x.to_string())
.collect::<Vec<_>>()
.join(",");
let descr = match self.descr {
DType::BF16 => Err(Error::Npy("bf16 is not supported".into()))?,
DType::F16 => "f2",
DType::F32 => "f4",
DType::F64 => "f8",
DType::U32 => "u4",
};
if !shape.is_empty() {
shape.push(',')
}
Ok(format!(
"{{'descr': '<{descr}', 'fortran_order': {fortran_order}, 'shape': ({shape}), }}"
))
}
// Hacky parser for the npy header, a typical example would be:
// {'descr': '<f8', 'fortran_order': False, 'shape': (128,), }
fn parse(header: &str) -> Result<Header> {
let header =
header.trim_matches(|c: char| c == '{' || c == '}' || c == ',' || c.is_whitespace());
let mut parts: Vec<String> = vec![];
let mut start_index = 0usize;
let mut cnt_parenthesis = 0i64;
for (index, c) in header.chars().enumerate() {
match c {
'(' => cnt_parenthesis += 1,
')' => cnt_parenthesis -= 1,
',' => {
if cnt_parenthesis == 0 {
parts.push(header[start_index..index].to_owned());
start_index = index + 1;
}
}
_ => {}
}
}
parts.push(header[start_index..].to_owned());
let mut part_map: HashMap<String, String> = HashMap::new();
for part in parts.iter() {
let part = part.trim();
if !part.is_empty() {
match part.split(':').collect::<Vec<_>>().as_slice() {
[key, value] => {
let key = key.trim_matches(|c: char| c == '\'' || c.is_whitespace());
let value = value.trim_matches(|c: char| c == '\'' || c.is_whitespace());
let _ = part_map.insert(key.to_owned(), value.to_owned());
}
_ => return Err(Error::Npy(format!("unable to parse header {header}"))),
}
}
}
let fortran_order = match part_map.get("fortran_order") {
None => false,
Some(fortran_order) => match fortran_order.as_ref() {
"False" => false,
"True" => true,
_ => return Err(Error::Npy(format!("unknown fortran_order {fortran_order}"))),
},
};
let descr = match part_map.get("descr") {
None => return Err(Error::Npy("no descr in header".to_string())),
Some(descr) => {
if descr.is_empty() {
return Err(Error::Npy("empty descr".to_string()));
}
if descr.starts_with('>') {
return Err(Error::Npy(format!("little-endian descr {descr}")));
}
                // Only f16, f32, f64 and u32 are handled below; the
                // commented-out descriptors cover other numpy dtypes that
                // could be mapped once matching DType variants exist.
match descr.trim_matches(|c: char| c == '=' || c == '<' || c == '|') {
"e" | "f2" => DType::F16,
"f" | "f4" => DType::F32,
"d" | "f8" => DType::F64,
// "i" | "i4" => DType::S32,
// "q" | "i8" => DType::S64,
// "h" | "i2" => DType::S16,
// "b" | "i1" => DType::S8,
// "B" | "u1" => DType::U8,
"I" | "u4" => DType::U32,
// "?" | "b1" => DType::Pred,
// "F" | "F4" => DType::C64,
// "D" | "F8" => DType::C128,
descr => return Err(Error::Npy(format!("unrecognized descr {descr}"))),
}
}
};
let shape = match part_map.get("shape") {
None => return Err(Error::Npy("no shape in header".to_string())),
Some(shape) => {
let shape = shape.trim_matches(|c: char| c == '(' || c == ')' || c == ',');
if shape.is_empty() {
vec![]
} else {
shape
.split(',')
.map(|v| v.trim().parse::<usize>())
.collect::<std::result::Result<Vec<_>, _>>()?
}
}
};
Ok(Header {
descr,
fortran_order,
shape,
})
}
}
impl Tensor {
// TODO: Add the possibility to read directly to a device?
fn from_reader<R: std::io::Read>(shape: Shape, dtype: DType, reader: &mut R) -> Result<Self> {
let elem_count = shape.elem_count();
match dtype {
DType::BF16 => {
let mut data_t = vec![bf16::ZERO; elem_count];
reader.read_u16_into::<LittleEndian>(data_t.reinterpret_cast_mut())?;
Tensor::from_vec(data_t, shape, &Device::Cpu)
}
DType::F16 => {
let mut data_t = vec![f16::ZERO; elem_count];
reader.read_u16_into::<LittleEndian>(data_t.reinterpret_cast_mut())?;
Tensor::from_vec(data_t, shape, &Device::Cpu)
}
DType::F32 => {
let mut data_t = vec![0f32; elem_count];
reader.read_f32_into::<LittleEndian>(&mut data_t)?;
Tensor::from_vec(data_t, shape, &Device::Cpu)
}
DType::F64 => {
let mut data_t = vec![0f64; elem_count];
reader.read_f64_into::<LittleEndian>(&mut data_t)?;
Tensor::from_vec(data_t, shape, &Device::Cpu)
}
DType::U32 => {
let mut data_t = vec![0u32; elem_count];
reader.read_u32_into::<LittleEndian>(&mut data_t)?;
Tensor::from_vec(data_t, shape, &Device::Cpu)
}
}
}
    /// Reads an npy file and returns the stored multi-dimensional array as a tensor.
pub fn read_npy<T: AsRef<Path>>(path: T) -> Result<Self> {
let mut reader = File::open(path.as_ref())?;
let header = read_header(&mut reader)?;
let header = Header::parse(&header)?;
if header.fortran_order {
return Err(Error::Npy("fortran order not supported".to_string()));
}
        Self::from_reader(header.shape(), header.descr, &mut reader)
}
    /// Reads an npz file and returns the stored multi-dimensional arrays together with their names.
pub fn read_npz<T: AsRef<Path>>(path: T) -> Result<Vec<(String, Self)>> {
let zip_reader = BufReader::new(File::open(path.as_ref())?);
let mut zip = zip::ZipArchive::new(zip_reader)?;
let mut result = vec![];
for i in 0..zip.len() {
let mut reader = zip.by_index(i).unwrap();
let name = {
let name = reader.name();
name.strip_suffix(NPY_SUFFIX).unwrap_or(name).to_owned()
};
let header = read_header(&mut reader)?;
let header = Header::parse(&header)?;
if header.fortran_order {
return Err(Error::Npy("fortran order not supported".to_string()));
}
let s = Self::from_reader(header.shape(), header.descr, &mut reader)?;
result.push((name, s))
}
Ok(result)
}
    /// Reads an npz file and returns the stored multi-dimensional arrays for some specified names.
pub fn read_npz_by_name<T: AsRef<Path>>(path: T, names: &[&str]) -> Result<Vec<Self>> {
let zip_reader = BufReader::new(File::open(path.as_ref())?);
let mut zip = zip::ZipArchive::new(zip_reader)?;
let mut result = vec![];
for name in names.iter() {
let mut reader = match zip.by_name(&format!("{name}{NPY_SUFFIX}")) {
Ok(reader) => reader,
Err(_) => Err(Error::Npy(format!(
"no array for {name} in {:?}",
path.as_ref()
)))?,
};
let header = read_header(&mut reader)?;
let header = Header::parse(&header)?;
if header.fortran_order {
return Err(Error::Npy("fortran order not supported".to_string()));
}
let s = Self::from_reader(header.shape(), header.descr, &mut reader)?;
result.push(s)
}
Ok(result)
}
fn write<T: Write>(&self, f: &mut T) -> Result<()> {
f.write_all(NPY_MAGIC_STRING)?;
f.write_all(&[1u8, 0u8])?;
let header = Header {
descr: self.dtype(),
fortran_order: false,
shape: self.dims().to_vec(),
};
let mut header = header.to_string()?;
let pad = 16 - (NPY_MAGIC_STRING.len() + 5 + header.len()) % 16;
for _ in 0..pad % 16 {
header.push(' ')
}
header.push('\n');
f.write_all(&[(header.len() % 256) as u8, (header.len() / 256) as u8])?;
f.write_all(header.as_bytes())?;
let elem_count = self.elem_count();
match self.dtype() {
DType::BF16 => {
let vs = self.reshape(elem_count)?.to_vec1::<bf16>()?;
for &v in vs.reinterpret_cast() {
f.write_u16::<LittleEndian>(v)?
}
}
DType::F16 => {
let vs = self.reshape(elem_count)?.to_vec1::<f16>()?;
for &v in vs.reinterpret_cast() {
f.write_u16::<LittleEndian>(v)?
}
}
DType::F32 => {
// TODO: Avoid using a buffer when data is already on the CPU.
for v in self.reshape(elem_count)?.to_vec1::<f32>()? {
f.write_f32::<LittleEndian>(v)?
}
}
DType::F64 => {
for v in self.reshape(elem_count)?.to_vec1::<f64>()? {
f.write_f64::<LittleEndian>(v)?
}
}
DType::U32 => {
for v in self.reshape(elem_count)?.to_vec1::<u32>()? {
f.write_u32::<LittleEndian>(v)?
}
}
}
Ok(())
}
/// Writes a multi-dimensional array in the npy format.
pub fn write_npy<T: AsRef<Path>>(&self, path: T) -> Result<()> {
let mut f = File::create(path.as_ref())?;
self.write(&mut f)
}
/// Writes multiple multi-dimensional arrays using the npz format.
pub fn write_npz<S: AsRef<str>, T: AsRef<Tensor>, P: AsRef<Path>>(
ts: &[(S, T)],
path: P,
) -> Result<()> {
let mut zip = zip::ZipWriter::new(File::create(path.as_ref())?);
let options =
zip::write::FileOptions::default().compression_method(zip::CompressionMethod::Stored);
for (name, tensor) in ts.iter() {
zip.start_file(format!("{}.npy", name.as_ref()), options)?;
tensor.as_ref().write(&mut zip)?
}
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::Header;
#[test]
fn parse() {
let h = "{'descr': '<f8', 'fortran_order': False, 'shape': (128,), }";
assert_eq!(
Header::parse(h).unwrap(),
Header {
descr: crate::DType::F64,
fortran_order: false,
shape: vec![128]
}
);
let h = "{'descr': '<f4', 'fortran_order': True, 'shape': (256,1,128), }";
let h = Header::parse(h).unwrap();
assert_eq!(
h,
Header {
descr: crate::DType::F32,
fortran_order: true,
shape: vec![256, 1, 128]
}
);
assert_eq!(
h.to_string().unwrap(),
"{'descr': '<f4', 'fortran_order': True, 'shape': (256,1,128,), }"
);
let h = Header {
descr: crate::DType::U32,
fortran_order: false,
shape: vec![],
};
assert_eq!(
h.to_string().unwrap(),
"{'descr': '<u4', 'fortran_order': False, 'shape': (), }"
);
}
}
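
A round-trip sketch using the read/write helpers above. Tensor::from_vec and to_vec1 are assumed to be public here; both are used by the npy code itself, but tensor.rs is suppressed in this diff:

use candle_core::{Device, Shape, Tensor};

fn main() -> candle_core::Result<()> {
    let t = Tensor::from_vec(vec![1f32, 2., 3., 4.], Shape::from(4), &Device::Cpu)?;
    t.write_npy("test.npy")?;
    let t2 = Tensor::read_npy("test.npy")?;
    assert_eq!(t2.to_vec1::<f32>()?, vec![1., 2., 3., 4.]);
    Ok(())
}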

197
candle-core/src/op.rs Normal file
View File

@ -0,0 +1,197 @@
use crate::Tensor;
use half::{bf16, f16};
use num_traits::float::Float;
#[derive(Clone)]
pub(crate) enum Op {
Add(Tensor, Tensor),
Mul(Tensor, Tensor),
Sub(Tensor, Tensor),
Div(Tensor, Tensor),
Matmul(Tensor, Tensor),
Embedding(Tensor, Tensor),
WhereCond(Tensor, Tensor, Tensor),
Cat(Vec<Tensor>, usize),
#[allow(dead_code)] // add is currently unused.
Affine {
arg: Tensor,
mul: f64,
add: f64,
},
Sum(Tensor, Vec<usize>),
ToDType(Tensor),
Broadcast(Tensor),
Exp(Tensor),
Log(Tensor),
Sin(Tensor),
Cos(Tensor),
Abs(Tensor),
Narrow(Tensor, usize, usize, usize),
Neg(Tensor),
Reshape(Tensor),
Softmax(Tensor, usize),
Sqr(Tensor),
Sqrt(Tensor),
ToDevice(Tensor),
Transpose(Tensor, usize, usize),
Gelu(Tensor),
// TODO: Support for custom ops.
}
pub(crate) trait UnaryOp {
const NAME: &'static str;
const KERNEL_BF16: &'static str;
const KERNEL_F16: &'static str;
const KERNEL_F32: &'static str;
const KERNEL_F64: &'static str;
const KERNEL_U32: &'static str;
fn bf16(v1: bf16) -> bf16;
fn f16(v1: f16) -> f16;
fn f32(v1: f32) -> f32;
fn f64(v1: f64) -> f64;
fn u32(v1: u32) -> u32;
}
pub(crate) trait BinaryOp {
const NAME: &'static str;
const KERNEL_BF16: &'static str;
const KERNEL_F16: &'static str;
const KERNEL_F32: &'static str;
const KERNEL_F64: &'static str;
const KERNEL_U32: &'static str;
fn bf16(v1: bf16, v2: bf16) -> bf16;
fn f16(v1: f16, v2: f16) -> f16;
fn f32(v1: f32, v2: f32) -> f32;
fn f64(v1: f64, v2: f64) -> f64;
fn u32(v1: u32, v2: u32) -> u32;
}
pub(crate) struct Add;
pub(crate) struct Div;
pub(crate) struct Mul;
pub(crate) struct Sub;
pub(crate) struct Exp;
pub(crate) struct Log;
pub(crate) struct Sin;
pub(crate) struct Cos;
pub(crate) struct Abs;
pub(crate) struct Neg;
pub(crate) struct Sqr;
pub(crate) struct Sqrt;
pub(crate) struct Gelu;
macro_rules! bin_op {
($op:ident, $name: literal, $e: expr) => {
impl BinaryOp for $op {
const NAME: &'static str = $name;
const KERNEL_BF16: &'static str = concat!("b", $name, "_bf16");
const KERNEL_F16: &'static str = concat!("b", $name, "_f16");
const KERNEL_F32: &'static str = concat!("b", $name, "_f32");
const KERNEL_F64: &'static str = concat!("b", $name, "_f64");
const KERNEL_U32: &'static str = concat!("b", $name, "_u32");
fn bf16(v1: bf16, v2: bf16) -> bf16 {
$e(v1, v2)
}
fn f16(v1: f16, v2: f16) -> f16 {
$e(v1, v2)
}
fn f32(v1: f32, v2: f32) -> f32 {
$e(v1, v2)
}
fn f64(v1: f64, v2: f64) -> f64 {
$e(v1, v2)
}
fn u32(v1: u32, v2: u32) -> u32 {
$e(v1, v2)
}
}
};
}
bin_op!(Add, "add", |v1, v2| v1 + v2);
bin_op!(Sub, "sub", |v1, v2| v1 - v2);
bin_op!(Mul, "mul", |v1, v2| v1 * v2);
bin_op!(Div, "div", |v1, v2| v1 / v2);
macro_rules! unary_op {
($op: ident, $name: literal, $a: ident, $e: expr) => {
impl UnaryOp for $op {
const NAME: &'static str = $name;
const KERNEL_BF16: &'static str = concat!("u", $name, "_bf16");
const KERNEL_F16: &'static str = concat!("u", $name, "_f16");
const KERNEL_F32: &'static str = concat!("u", $name, "_f32");
const KERNEL_F64: &'static str = concat!("u", $name, "_f64");
const KERNEL_U32: &'static str = concat!("u", $name, "_u32");
fn bf16($a: bf16) -> bf16 {
$e
}
fn f16($a: f16) -> f16 {
$e
}
fn f32($a: f32) -> f32 {
$e
}
fn f64($a: f64) -> f64 {
$e
}
fn u32(_: u32) -> u32 {
todo!("no unary function for u32")
}
}
};
}
unary_op!(Exp, "exp", v, v.exp());
unary_op!(Log, "log", v, v.ln());
unary_op!(Sin, "sin", v, v.sin());
unary_op!(Cos, "cos", v, v.cos());
unary_op!(Abs, "abs", v, v.abs());
unary_op!(Neg, "neg", v, -v);
unary_op!(Sqr, "sqr", v, v * v);
unary_op!(Sqrt, "sqrt", v, v.sqrt());
/// `gelu` operation
/// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions>
impl UnaryOp for Gelu {
const NAME: &'static str = "gelu";
fn bf16(v: bf16) -> bf16 {
bf16::from_f32_const(0.5)
* v
* (bf16::ONE
+ bf16::tanh(
(bf16::from_f32_const(2.0) / bf16::PI).sqrt()
* v
* (bf16::ONE + bf16::from_f32_const(0.044715) * v * v),
))
}
fn f16(v: f16) -> f16 {
f16::from_f32_const(0.5)
* v
* (f16::ONE
+ f16::tanh(
(f16::from_f32_const(2.0) / f16::PI).sqrt()
* v
* (f16::ONE + f16::from_f32_const(0.044715) * v * v),
))
}
fn f32(v: f32) -> f32 {
0.5 * v
* (1.0
+ f32::tanh((2.0f32 / std::f32::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)))
}
fn f64(v: f64) -> f64 {
0.5 * v
* (1.0
+ f64::tanh((2.0f64 / std::f64::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)))
}
fn u32(_: u32) -> u32 {
0
}
const KERNEL_BF16: &'static str = "gelu_bf16";
const KERNEL_F16: &'static str = "gelu_f16";
const KERNEL_F32: &'static str = "gelu_f32";
const KERNEL_F64: &'static str = "gelu_f64";
const KERNEL_U32: &'static str = "gelu_u32";
}
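
As a sanity check of the tanh-based gelu approximation above, here is the f32 expression extracted into a standalone function with a couple of numeric assertions:

fn gelu(v: f32) -> f32 {
    // Same expression as the f32 arm of the Gelu impl above.
    0.5 * v
        * (1.0
            + f32::tanh((2.0f32 / std::f32::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)))
}

fn main() {
    assert_eq!(gelu(0.0), 0.0);
    // gelu approaches the identity for large positive inputs...
    assert!((gelu(3.0) - 3.0).abs() < 1e-2);
    // ...and zero for large negative inputs.
    assert!(gelu(-3.0).abs() < 1e-2);
}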

199
candle-core/src/shape.rs Normal file
View File

@ -0,0 +1,199 @@
use crate::{Error, Result};
#[derive(Clone, PartialEq, Eq)]
pub struct Shape(Vec<usize>);
impl std::fmt::Debug for Shape {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{:?}", &self.dims())
}
}
impl<const C: usize> From<&[usize; C]> for Shape {
fn from(dims: &[usize; C]) -> Self {
Self(dims.to_vec())
}
}
impl From<&[usize]> for Shape {
fn from(dims: &[usize]) -> Self {
Self(dims.to_vec())
}
}
impl From<&Shape> for Shape {
fn from(shape: &Shape) -> Self {
Self(shape.0.to_vec())
}
}
impl From<()> for Shape {
fn from(_: ()) -> Self {
Self(vec![])
}
}
impl From<usize> for Shape {
fn from(d1: usize) -> Self {
Self(vec![d1])
}
}
impl From<(usize, usize)> for Shape {
fn from(d12: (usize, usize)) -> Self {
Self(vec![d12.0, d12.1])
}
}
impl From<(usize, usize, usize)> for Shape {
fn from(d123: (usize, usize, usize)) -> Self {
Self(vec![d123.0, d123.1, d123.2])
}
}
impl From<(usize, usize, usize, usize)> for Shape {
fn from(d1234: (usize, usize, usize, usize)) -> Self {
Self(vec![d1234.0, d1234.1, d1234.2, d1234.3])
}
}
impl From<(usize, usize, usize, usize, usize)> for Shape {
fn from(d12345: (usize, usize, usize, usize, usize)) -> Self {
Self(vec![d12345.0, d12345.1, d12345.2, d12345.3, d12345.4])
}
}
impl From<Vec<usize>> for Shape {
fn from(dims: Vec<usize>) -> Self {
Self(dims)
}
}
macro_rules! extract_dims {
($fn_name:ident, $cnt:tt, $dims:expr, $out_type:ty) => {
pub fn $fn_name(&self) -> Result<$out_type> {
if self.0.len() != $cnt {
Err(Error::UnexpectedNumberOfDims {
expected: $cnt,
got: self.0.len(),
shape: self.clone(),
})
} else {
Ok($dims(&self.0))
}
}
};
}
impl Shape {
pub fn from_dims(dims: &[usize]) -> Self {
Self(dims.to_vec())
}
pub fn rank(&self) -> usize {
self.0.len()
}
pub fn into_dims(self) -> Vec<usize> {
self.0
}
pub fn dims(&self) -> &[usize] {
&self.0
}
pub fn elem_count(&self) -> usize {
self.0.iter().product()
}
extract_dims!(r0, 0, |_: &Vec<usize>| (), ());
extract_dims!(r1, 1, |d: &[usize]| d[0], usize);
extract_dims!(r2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize));
extract_dims!(
r3,
3,
|d: &[usize]| (d[0], d[1], d[2]),
(usize, usize, usize)
);
extract_dims!(
r4,
4,
|d: &[usize]| (d[0], d[1], d[2], d[3]),
(usize, usize, usize, usize)
);
extract_dims!(
r5,
5,
|d: &[usize]| (d[0], d[1], d[2], d[3], d[4]),
(usize, usize, usize, usize, usize)
);
    /// The strides, given in number of elements, for a contiguous n-dimensional
    /// array using this shape.
pub(crate) fn stride_contiguous(&self) -> Vec<usize> {
let mut stride: Vec<_> = self
.0
.iter()
.rev()
.scan(1, |prod, u| {
let prod_pre_mult = *prod;
*prod *= u;
Some(prod_pre_mult)
})
.collect();
stride.reverse();
stride
}
/// Returns true if the strides are C contiguous (aka row major).
pub fn is_contiguous(&self, stride: &[usize]) -> bool {
if self.0.len() != stride.len() {
return false;
}
let mut acc = 1;
for (&stride, &dim) in stride.iter().zip(self.0.iter()).rev() {
if stride != acc {
return false;
}
acc *= dim;
}
true
}
/// Returns true if the strides are Fortran contiguous (aka column major).
pub fn is_fortran_contiguous(&self, stride: &[usize]) -> bool {
if self.0.len() != stride.len() {
return false;
}
let mut acc = 1;
for (&stride, &dim) in stride.iter().zip(self.0.iter()) {
if stride != acc {
return false;
}
acc *= dim;
}
true
}
pub fn extend(mut self, additional_dims: &[usize]) -> Self {
self.0.extend(additional_dims);
self
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn stride() {
let shape = Shape::from(());
assert_eq!(shape.stride_contiguous(), Vec::<usize>::new());
let shape = Shape::from(42);
assert_eq!(shape.stride_contiguous(), [1]);
let shape = Shape::from((42, 1337));
assert_eq!(shape.stride_contiguous(), [1337, 1]);
let shape = Shape::from((299, 792, 458));
assert_eq!(shape.stride_contiguous(), [458 * 792, 458, 1]);
}
}
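
A short usage sketch of the public Shape API above, including the two contiguity checks:

use candle_core::Shape;

fn main() {
    let shape = Shape::from((2, 3, 4));
    assert_eq!(shape.rank(), 3);
    assert_eq!(shape.elem_count(), 24);
    assert_eq!(shape.r3().unwrap(), (2, 3, 4));
    // Row-major (C) strides for (2, 3, 4) are [12, 4, 1].
    assert!(shape.is_contiguous(&[12, 4, 1]));
    // Column-major (Fortran) strides are [1, 2, 6].
    assert!(shape.is_fortran_contiguous(&[1, 2, 6]));
    assert!(!shape.is_contiguous(&[1, 2, 6]));
}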

261
candle-core/src/storage.rs Normal file
View File

@ -0,0 +1,261 @@
use crate::{op, CpuStorage, CudaStorage, DType, Device, Error, Result, Shape};
// We do not want to implement Clone on Storage as cloning may fail because of
// out of memory. Instead try_clone should be used.
#[derive(Debug)]
pub enum Storage {
Cpu(CpuStorage),
Cuda(CudaStorage),
}
impl Storage {
pub fn try_clone(&self) -> Result<Self> {
match self {
Self::Cpu(storage) => Ok(Self::Cpu(storage.clone())),
Self::Cuda(storage) => {
let storage = storage.try_clone()?;
Ok(Self::Cuda(storage))
}
}
}
pub fn device(&self) -> Device {
match self {
Self::Cpu(_) => Device::Cpu,
Self::Cuda(storage) => Device::Cuda(storage.device().clone()),
}
}
pub fn dtype(&self) -> DType {
match self {
Self::Cpu(storage) => storage.dtype(),
Self::Cuda(storage) => storage.dtype(),
}
}
pub(crate) fn same_device(&self, rhs: &Self, op: &'static str) -> Result<()> {
let lhs = self.device().location();
let rhs = rhs.device().location();
if lhs != rhs {
Err(Error::DeviceMismatchBinaryOp { lhs, rhs, op })
} else {
Ok(())
}
}
pub(crate) fn same_dtype(&self, rhs: &Self, op: &'static str) -> Result<()> {
let lhs = self.dtype();
let rhs = rhs.dtype();
if lhs != rhs {
Err(Error::DTypeMismatchBinaryOp { lhs, rhs, op })
} else {
Ok(())
}
}
pub(crate) fn affine_impl(
&self,
shape: &Shape,
stride: &[usize],
mul: f64,
add: f64,
) -> Result<Self> {
match self {
Storage::Cpu(storage) => {
let storage = storage.affine_impl(shape, stride, mul, add)?;
Ok(Self::Cpu(storage))
}
Self::Cuda(storage) => {
let storage = storage.affine_impl(shape, stride, mul, add)?;
Ok(Self::Cuda(storage))
}
}
}
pub(crate) fn sum(&self, shape: &Shape, stride: &[usize], s: &[usize]) -> Result<Self> {
match self {
Storage::Cpu(storage) => {
let storage = storage.sum(shape, stride, s)?;
Ok(Self::Cpu(storage))
}
Self::Cuda(storage) => {
let storage = storage.sum(shape, stride, s)?;
Ok(Self::Cuda(storage))
}
}
}
pub(crate) fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) -> Result<()> {
match self {
Storage::Cpu(storage) => storage.divide_by_sum_over_dim(shape, dim)?,
Self::Cuda(storage) => storage.divide_by_sum_over_dim(shape, dim)?,
}
Ok(())
}
pub(crate) fn to_dtype(&self, shape: &Shape, stride: &[usize], dtype: DType) -> Result<Self> {
match self {
Storage::Cpu(storage) => {
let storage = storage.to_dtype(shape, stride, dtype)?;
Ok(Self::Cpu(storage))
}
Self::Cuda(storage) => {
let storage = storage.to_dtype(shape, stride, dtype)?;
Ok(Self::Cuda(storage))
}
}
}
pub(crate) fn unary_impl<B: op::UnaryOp>(
&self,
shape: &Shape,
stride: &[usize],
) -> Result<Self> {
// TODO: Different code path for the contiguous case?
match self {
Storage::Cpu(storage) => {
let storage = storage.unary_impl::<B>(shape, stride)?;
Ok(Self::Cpu(storage))
}
Self::Cuda(storage) => {
let storage = storage.unary_impl::<B>(shape, stride)?;
Ok(Self::Cuda(storage))
}
}
}
pub(crate) fn binary_impl<B: op::BinaryOp>(
&self,
rhs: &Self,
shape: &Shape,
lhs_stride: &[usize],
rhs_stride: &[usize],
) -> Result<Self> {
self.same_device(rhs, B::NAME)?;
self.same_dtype(rhs, B::NAME)?;
match (self, rhs) {
(Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
let storage = lhs.binary_impl::<B>(rhs, shape, lhs_stride, rhs_stride)?;
Ok(Self::Cpu(storage))
}
(Self::Cuda(lhs), Self::Cuda(rhs)) => {
let storage = lhs.binary_impl::<B>(rhs, shape, lhs_stride, rhs_stride)?;
Ok(Self::Cuda(storage))
}
(lhs, rhs) => {
// Should not happen because of the same device check above but we're defensive
// anyway.
Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(),
rhs: rhs.device().location(),
op: B::NAME,
})
}
}
}
pub(crate) fn where_cond(
&self,
shape: &Shape,
stride: &[usize],
t: &Self,
stride_t: &[usize],
f: &Self,
stride_f: &[usize],
) -> Result<Self> {
self.same_device(t, "where")?;
self.same_device(f, "where")?;
t.same_dtype(f, "where")?;
match (self, t, f) {
(Storage::Cpu(cond), Storage::Cpu(t), Storage::Cpu(f)) => {
let storage = cond.where_cond(shape, stride, t, stride_t, f, stride_f)?;
Ok(Self::Cpu(storage))
}
(Self::Cuda(cond), Self::Cuda(t), Self::Cuda(f)) => {
let storage = cond.where_cond(shape, stride, t, stride_t, f, stride_f)?;
Ok(Self::Cuda(storage))
}
(_, lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(),
rhs: rhs.device().location(),
op: "embedding",
}),
}
}
pub(crate) fn embedding_impl(
&self,
shape: &Shape,
stride: &[usize],
rhs: &Self,
hidden_size: usize,
vocab_size: usize,
) -> Result<Self> {
self.same_device(rhs, "embedding")?;
match (self, rhs) {
(Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
let storage = lhs.embedding_impl(shape, stride, rhs, hidden_size, vocab_size)?;
Ok(Self::Cpu(storage))
}
(Self::Cuda(lhs), Self::Cuda(rhs)) => {
let storage = lhs.embedding_impl(shape, stride, rhs, hidden_size, vocab_size)?;
Ok(Self::Cuda(storage))
}
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(),
rhs: rhs.device().location(),
op: "embedding",
}),
}
}
pub(crate) fn matmul_impl(
&self,
rhs: &Self,
bmnk: (usize, usize, usize, usize),
lhs_stride: &[usize],
rhs_stride: &[usize],
) -> Result<Self> {
self.same_device(rhs, "matmul")?;
self.same_dtype(rhs, "matmul")?;
match (self, rhs) {
(Self::Cpu(lhs), Self::Cpu(rhs)) => {
let storage = lhs.matmul_impl(rhs, bmnk, lhs_stride, rhs_stride)?;
Ok(Self::Cpu(storage))
}
(Self::Cuda(lhs), Self::Cuda(rhs)) => {
let storage = lhs.matmul_impl(rhs, bmnk, lhs_stride, rhs_stride)?;
Ok(Self::Cuda(storage))
}
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(),
rhs: rhs.device().location(),
op: "matmul",
}),
}
}
// self, the source can be strided whereas dst is contiguous.
pub(crate) fn copy_strided_src(
&self,
dst: &mut Self,
dst_offset: usize,
src_shape: &Shape,
src_stride: &[usize],
src_offset: usize,
) -> Result<()> {
match (self, dst) {
(Self::Cpu(src), Self::Cpu(dst)) => {
src.copy_strided_src(dst, dst_offset, src_shape, src_stride, src_offset)
}
(Self::Cuda(src), Self::Cuda(dst)) => {
Ok(src.copy_strided_src(dst, dst_offset, src_shape, src_stride, src_offset)?)
}
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(),
rhs: rhs.device().location(),
op: "copy",
}),
}
}
}
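
The comment at the top of this file explains why Storage offers try_clone rather than implementing Clone. A standalone sketch of that pattern (DeviceBuffer is a hypothetical stand-in, not part of the crate):

struct DeviceBuffer(Vec<f32>); // stand-in for a device allocation

impl DeviceBuffer {
    // Cloning device memory can fail (e.g. out of memory), so expose a
    // fallible try_clone instead of a panicking Clone impl.
    fn try_clone(&self) -> Result<Self, String> {
        Ok(Self(self.0.clone()))
    }
}

fn main() {
    let buf = DeviceBuffer(vec![1.0, 2.0]);
    let _copy = buf.try_clone().expect("clone failed");
}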

61
candle-core/src/strided_index.rs Normal file
View File

@ -0,0 +1,61 @@
/// An iterator over the storage offsets of the items of an N-dimensional
/// array stored in a flat buffer, taking strides into account.
#[derive(Debug)]
pub(crate) struct StridedIndex<'a> {
next_storage_index: Option<usize>,
multi_index: Vec<usize>,
dims: &'a [usize],
stride: &'a [usize],
}
impl<'a> StridedIndex<'a> {
pub(crate) fn new(dims: &'a [usize], stride: &'a [usize]) -> Self {
let elem_count: usize = dims.iter().product();
let next_storage_index = if elem_count == 0 {
None
} else {
            // Start at offset 0; this also covers the scalar (rank 0) case.
Some(0)
};
StridedIndex {
next_storage_index,
multi_index: vec![0; dims.len()],
dims,
stride,
}
}
}
impl<'a> Iterator for StridedIndex<'a> {
type Item = usize;
fn next(&mut self) -> Option<Self::Item> {
let storage_index = match self.next_storage_index {
None => return None,
Some(storage_index) => storage_index,
};
let mut updated = false;
for (multi_i, max_i) in self.multi_index.iter_mut().zip(self.dims.iter()).rev() {
let next_i = *multi_i + 1;
if next_i < *max_i {
*multi_i = next_i;
updated = true;
break;
} else {
*multi_i = 0
}
}
self.next_storage_index = if updated {
let next_storage_index = self
.multi_index
.iter()
.zip(self.stride.iter())
.map(|(&x, &y)| x * y)
.sum();
Some(next_storage_index)
} else {
None
};
Some(storage_index)
}
}
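
StridedIndex is crate-private, so here is a standalone check of the offset sequence it should produce for dims [2, 3] with strides [1, 2] (a transposed, i.e. column-major, layout):

fn main() {
    let (dims, stride) = ([2usize, 3], [1usize, 2]);
    let mut offsets = Vec::new();
    // Walk the logical indices in row-major order, mapping each multi-index
    // to a storage offset via a dot product with the strides; this is what
    // StridedIndex::next computes incrementally.
    for i in 0..dims[0] {
        for j in 0..dims[1] {
            offsets.push(i * stride[0] + j * stride[1]);
        }
    }
    assert_eq!(offsets, [0, 2, 4, 1, 3, 5]);
}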

1062
candle-core/src/tensor.rs Normal file

File diff suppressed because it is too large