Mirror of https://github.com/huggingface/candle.git (synced 2025-06-19 19:58:35 +00:00)

Commit: Refactor the hierarchy.

candle-core/src/backprop.rs (new file, 274 lines)
@@ -0,0 +1,274 @@
|
||||
use crate::{op::Op, Error, Result, Tensor, TensorId};
|
||||
use std::collections::HashMap;
|
||||
|
||||
impl Tensor {
|
||||
/// Return all the nodes that lead to this value in a topologically sorted vec, the first
/// elements having dependencies on the latter ones, e.g. the first element, if any, is the
/// tensor this function is called on.
/// This assumes that the op graph is a DAG.
|
||||
fn sorted_nodes(&self) -> Vec<&Tensor> {
|
||||
// The vec of sorted nodes is passed as an owned value rather than a mutable reference
|
||||
// to get around some lifetime limitations.
|
||||
fn walk<'a>(
|
||||
node: &'a Tensor,
|
||||
nodes: Vec<&'a Tensor>,
|
||||
already_seen: &mut HashMap<TensorId, bool>,
|
||||
) -> (bool, Vec<&'a Tensor>) {
|
||||
if let Some(&tg) = already_seen.get(&node.id()) {
|
||||
return (tg, nodes);
|
||||
}
|
||||
let mut track_grad = false;
|
||||
let mut nodes = if node.is_variable() {
|
||||
// Do not call recursively on the "leaf" nodes.
|
||||
track_grad = true;
|
||||
nodes
|
||||
} else if let Some(op) = node.op() {
|
||||
match op {
|
||||
Op::WhereCond(t1, t2, t3) => {
|
||||
let (tg, nodes) = walk(t1, nodes, already_seen);
|
||||
track_grad |= tg;
|
||||
let (tg, nodes) = walk(t2, nodes, already_seen);
|
||||
track_grad |= tg;
|
||||
let (tg, nodes) = walk(t3, nodes, already_seen);
|
||||
track_grad |= tg;
|
||||
nodes
|
||||
}
|
||||
Op::Add(lhs, rhs)
|
||||
| Op::Mul(lhs, rhs)
|
||||
| Op::Sub(lhs, rhs)
|
||||
| Op::Div(lhs, rhs)
|
||||
| Op::Embedding(lhs, rhs)
|
||||
| Op::Matmul(lhs, rhs) => {
|
||||
let (tg, nodes) = walk(lhs, nodes, already_seen);
|
||||
track_grad |= tg;
|
||||
let (tg, nodes) = walk(rhs, nodes, already_seen);
|
||||
track_grad |= tg;
|
||||
nodes
|
||||
}
|
||||
Op::Cat(args, _) => args.iter().fold(nodes, |nodes, arg| {
|
||||
let (tg, nodes) = walk(arg, nodes, already_seen);
|
||||
track_grad |= tg;
|
||||
nodes
|
||||
}),
|
||||
Op::Affine { arg, mul, .. } => {
|
||||
if *mul == 0. {
|
||||
nodes
|
||||
} else {
|
||||
let (tg, nodes) = walk(arg, nodes, already_seen);
|
||||
track_grad |= tg;
|
||||
nodes
|
||||
}
|
||||
}
|
||||
Op::Reshape(node)
|
||||
| Op::Broadcast(node)
|
||||
| Op::Sum(node, _)
|
||||
| Op::ToDType(node)
|
||||
| Op::ToDevice(node)
|
||||
| Op::Transpose(node, _, _)
|
||||
| Op::Narrow(node, _, _, _)
|
||||
| Op::Softmax(node, _)
|
||||
| Op::Sqr(node)
|
||||
| Op::Sqrt(node)
|
||||
| Op::Gelu(node)
|
||||
| Op::Exp(node)
|
||||
| Op::Log(node)
|
||||
| Op::Sin(node)
|
||||
| Op::Cos(node)
|
||||
| Op::Abs(node)
|
||||
| Op::Neg(node) => {
|
||||
let (tg, nodes) = walk(node, nodes, already_seen);
|
||||
track_grad |= tg;
|
||||
nodes
|
||||
}
|
||||
}
|
||||
} else {
|
||||
nodes
|
||||
};
|
||||
already_seen.insert(node.id(), track_grad);
|
||||
if track_grad {
|
||||
nodes.push(node);
|
||||
}
|
||||
(track_grad, nodes)
|
||||
}
|
||||
let (_tg, mut nodes) = walk(self, vec![], &mut HashMap::new());
|
||||
nodes.reverse();
|
||||
nodes
|
||||
}
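// Illustration (not part of the original source): for a small graph such as
// y = (a * b) + c where a, b and c are variables, `walk` pushes nodes in
// post-order, i.e. [a, b, a*b, c, y], and the final reverse yields
// [y, c, a*b, b, a]: the output first, the leaves last, so `backward` below
// can process each node before the nodes it depends on.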
|
||||
|
||||
pub fn backward(&self) -> Result<GradStore> {
|
||||
let sorted_nodes = self.sorted_nodes();
|
||||
let mut grads = GradStore::new();
|
||||
grads.insert(self, self.ones_like()?);
|
||||
for node in sorted_nodes.iter() {
|
||||
if node.is_variable() {
|
||||
continue;
|
||||
}
|
||||
let grad = grads.remove(node).unwrap();
|
||||
// TODO: We should perform all these operations in place (or at least not track the
|
||||
// whole graph).
|
||||
// The only drawback would be if we wanted to support grad of grad but this is out of
|
||||
// scope.
|
||||
if let Some(op) = node.op() {
|
||||
match op {
|
||||
Op::Add(lhs, rhs) => {
|
||||
let lhs_sum_grad = grads.or_insert(lhs)?;
|
||||
*lhs_sum_grad = lhs_sum_grad.add(&grad)?;
|
||||
let rhs_sum_grad = grads.or_insert(rhs)?;
|
||||
*rhs_sum_grad = rhs_sum_grad.add(&grad)?;
|
||||
}
|
||||
Op::Sub(lhs, rhs) => {
|
||||
let lhs_sum_grad = grads.or_insert(lhs)?;
|
||||
*lhs_sum_grad = lhs_sum_grad.add(&grad)?;
|
||||
let rhs_sum_grad = grads.or_insert(rhs)?;
|
||||
*rhs_sum_grad = rhs_sum_grad.sub(&grad)?;
|
||||
}
|
||||
Op::Mul(lhs, rhs) => {
|
||||
let lhs_grad = grad.mul(rhs)?;
|
||||
let lhs_sum_grad = grads.or_insert(lhs)?;
|
||||
*lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
|
||||
let rhs_grad = grad.mul(lhs)?;
|
||||
let rhs_sum_grad = grads.or_insert(rhs)?;
|
||||
*rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
|
||||
}
|
||||
Op::Div(lhs, rhs) => {
|
||||
let lhs_grad = grad.div(rhs)?;
|
||||
let lhs_sum_grad = grads.or_insert(lhs)?;
|
||||
*lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
|
||||
let rhs_grad = grad.mul(lhs)?.div(&rhs.sqr()?)?;
|
||||
let rhs_sum_grad = grads.or_insert(rhs)?;
|
||||
// d/d rhs (lhs / rhs) = -lhs / rhs^2, hence the subtraction.
*rhs_sum_grad = rhs_sum_grad.sub(&rhs_grad)?;
|
||||
}
|
||||
Op::WhereCond(_pred, _t, _f) => {
|
||||
return Err(Error::BackwardNotSupported { op: "where_cond" })
|
||||
}
|
||||
Op::Embedding(_lhs, _rhs) => {
|
||||
return Err(Error::BackwardNotSupported { op: "embedding" })
|
||||
}
|
||||
Op::Matmul(lhs, rhs) => {
|
||||
// The forward op already succeeded, so we skip the matmul size checks
// here for now.
|
||||
|
||||
let lhs_grad = grad.matmul(&rhs.t()?)?;
|
||||
let lhs_sum_grad = grads.or_insert(lhs)?;
|
||||
*lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
|
||||
|
||||
let rhs_grad = lhs.t()?.matmul(&grad)?;
|
||||
let rhs_sum_grad = grads.or_insert(rhs)?;
|
||||
*rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
|
||||
}
|
||||
Op::Cat(args, dim) => {
|
||||
let mut start_idx = 0;
|
||||
for arg in args {
|
||||
let len = arg.dims()[*dim];
|
||||
let arg_grad = grad.narrow(*dim, start_idx, len)?;
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = sum_grad.add(&arg_grad)?;
|
||||
start_idx += len;
|
||||
}
|
||||
}
|
||||
Op::Broadcast(_arg) => {
|
||||
return Err(Error::BackwardNotSupported { op: "broadcast" })
|
||||
}
|
||||
Op::Sum(_arg, _sum_dims) => {
|
||||
return Err(Error::BackwardNotSupported { op: "sum" })
|
||||
}
|
||||
Op::ToDType(arg) => {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
// Convert the gradient back to the argument's dtype before accumulating.
*sum_grad = sum_grad.add(&grad.to_dtype(arg.dtype())?)?
|
||||
}
|
||||
Op::Affine { arg, mul, .. } => {
|
||||
let arg_grad = grad.affine(*mul, 0.)?;
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = sum_grad.add(&arg_grad)?
|
||||
}
|
||||
Op::Log(arg) => {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
// d/dx ln(x) = 1/x, so the gradient w.r.t. arg is grad / arg.
*sum_grad = sum_grad.add(&(&grad / arg)?)?
|
||||
}
|
||||
Op::Sin(arg) => {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = sum_grad.add(&(&grad * arg.cos())?)?
|
||||
}
|
||||
Op::Cos(arg) => {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = sum_grad.sub(&(&grad * arg.sin())?)?
|
||||
}
|
||||
Op::Abs(_args) => return Err(Error::BackwardNotSupported { op: "abs" }),
|
||||
Op::Exp(arg) => {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
// d/dx exp(x) = exp(x), and *node already holds exp(arg).
*sum_grad = sum_grad.add(&(&grad * *node)?)?
|
||||
}
|
||||
Op::Neg(arg) => {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = sum_grad.sub(&grad)?
|
||||
}
|
||||
Op::Narrow(_arg, _, _, _) => {
|
||||
return Err(Error::BackwardNotSupported { op: "narrow" })
|
||||
}
|
||||
Op::Softmax(_arg, _) => {
|
||||
return Err(Error::BackwardNotSupported { op: "softmax" })
|
||||
}
|
||||
Op::Reshape(_arg) => return Err(Error::BackwardNotSupported { op: "reshape" }),
|
||||
Op::Gelu(_) => return Err(Error::BackwardNotSupported { op: "gelu" }),
|
||||
Op::Sqr(arg) => {
|
||||
let arg_grad = arg.mul(&grad)?.affine(2., 0.)?;
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = sum_grad.add(&arg_grad)?
|
||||
}
|
||||
Op::Sqrt(arg) => {
|
||||
// d/dx sqrt(x) = 1 / (2 * sqrt(x)), and *node already holds sqrt(arg).
let arg_grad = grad.div(*node)?.affine(0.5, 0.)?;
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = sum_grad.add(&arg_grad)?
|
||||
}
|
||||
Op::ToDevice(arg) => {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
let arg_grad = grad.to_device(&sum_grad.device())?;
|
||||
*sum_grad = sum_grad.add(&arg_grad)?
|
||||
}
|
||||
Op::Transpose(arg, dim1, dim2) => {
|
||||
let arg_grad = grad.transpose(*dim1, *dim2)?;
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = sum_grad.add(&arg_grad)?
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
Ok(grads)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct GradStore(HashMap<TensorId, Tensor>);
|
||||
|
||||
impl GradStore {
|
||||
fn new() -> Self {
|
||||
GradStore(HashMap::new())
|
||||
}
|
||||
|
||||
pub fn get_id(&self, id: TensorId) -> Option<&Tensor> {
|
||||
self.0.get(&id)
|
||||
}
|
||||
|
||||
pub fn get(&self, tensor: &Tensor) -> Option<&Tensor> {
|
||||
self.0.get(&tensor.id())
|
||||
}
|
||||
|
||||
pub fn remove(&mut self, tensor: &Tensor) -> Option<Tensor> {
|
||||
self.0.remove(&tensor.id())
|
||||
}
|
||||
|
||||
pub fn insert(&mut self, tensor: &Tensor, grad: Tensor) -> Option<Tensor> {
|
||||
self.0.insert(tensor.id(), grad)
|
||||
}
|
||||
|
||||
fn or_insert(&mut self, tensor: &Tensor) -> Result<&mut Tensor> {
|
||||
use std::collections::hash_map::Entry;
|
||||
let grad = match self.0.entry(tensor.id()) {
|
||||
Entry::Occupied(entry) => entry.into_mut(),
|
||||
Entry::Vacant(entry) => {
|
||||
let grad = tensor.zeros_like()?;
|
||||
entry.insert(grad)
|
||||
}
|
||||
};
|
||||
Ok(grad)
|
||||
}
|
||||
}
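// Usage sketch (illustrative only, not part of this commit). It assumes a
// constructor along the lines of `Tensor::var` that builds a gradient-tracked
// leaf tensor on the CPU device; the exact constructor name and signature may
// differ in this revision of the crate.
//
//     let x = Tensor::var(&[3f32], &Device::Cpu)?;
//     let y = x.sqr()?;            // y = x^2
//     let grads = y.backward()?;   // reverse-mode pass over the sorted nodes
//     let dx = grads.get(&x);      // Some(tensor holding dy/dx = 2 * x)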
|
candle-core/src/cpu_backend.rs (new file, 886 lines)
@@ -0,0 +1,886 @@
|
||||
use crate::op::{BinaryOp, UnaryOp};
|
||||
use crate::{DType, Error, Result, Shape, StridedIndex};
|
||||
use gemm::{gemm, Parallelism};
|
||||
use half::{bf16, f16};
|
||||
|
||||
// TODO: Think about whether we would be better off with a dtype and
|
||||
// a buffer as an owned slice of bytes.
|
||||
// TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator +
|
||||
// intercept the oom errors to avoid panicking and provide a proper error.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum CpuStorage {
|
||||
U32(Vec<u32>),
|
||||
BF16(Vec<bf16>),
|
||||
F16(Vec<f16>),
|
||||
F32(Vec<f32>),
|
||||
F64(Vec<f64>),
|
||||
}
|
||||
|
||||
fn wcond<T: Copy>(
|
||||
pred: &[u32],
|
||||
shape: &Shape,
|
||||
stride: &[usize],
|
||||
t: &[T],
|
||||
stride_t: &[usize],
|
||||
f: &[T],
|
||||
stride_f: &[usize],
|
||||
) -> Vec<T> {
|
||||
if shape.is_contiguous(stride) && shape.is_contiguous(stride_t) && shape.is_contiguous(stride_f)
|
||||
{
|
||||
let elem_count = shape.elem_count();
|
||||
let pred = &pred[..elem_count];
|
||||
let t = &t[..elem_count];
|
||||
let f = &f[..elem_count];
|
||||
pred.iter()
|
||||
.zip(t.iter().zip(f.iter()))
|
||||
.map(|(&p, (&t, &f))| if p > 0 { t } else { f })
|
||||
.collect::<Vec<_>>()
|
||||
} else {
|
||||
let dims = shape.dims();
|
||||
let it_p = StridedIndex::new(dims, stride);
|
||||
let it_t = StridedIndex::new(dims, stride_t);
|
||||
let it_f = StridedIndex::new(dims, stride_f);
|
||||
it_p.zip(it_t.zip(it_f))
|
||||
.map(|(i_p, (i_t, i_f))| if pred[i_p] > 0 { t[i_t] } else { f[i_f] })
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
}
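// Example of the selection semantics (comment added for clarity): with
// pred = [1, 0, 1], t = [a, b, c] and f = [x, y, z], the result is [a, y, c];
// a predicate entry is treated as "true" whenever it is strictly positive.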
|
||||
|
||||
macro_rules! map1 {
|
||||
($v: expr, $fn: ident, $( $args:expr ),*) => {{
|
||||
let v = match $v {
|
||||
CpuStorage::BF16(__s) => CpuStorage::BF16($fn::<bf16>(__s, $($args),*)?),
|
||||
CpuStorage::F16(__s) => CpuStorage::F16($fn::<f16>(__s, $($args),*)?),
|
||||
CpuStorage::F32(__s) => CpuStorage::F32($fn::<f32>(__s, $($args),*)?),
|
||||
CpuStorage::F64(__s) => CpuStorage::F64($fn::<f64>(__s, $($args),*)?),
|
||||
CpuStorage::U32(__s) => CpuStorage::U32($fn::<u32>(__s, $($args),*)?),
|
||||
};
|
||||
Ok(v)
|
||||
}};
|
||||
}
|
||||
|
||||
fn sum_impl1<T: Copy + num_traits::NumAssign>(
|
||||
src: &[T],
|
||||
dst_shape: &Shape,
|
||||
src_dims: &[usize],
|
||||
stride: &[usize],
|
||||
to_dst_index: impl Fn(usize) -> usize,
|
||||
) -> Result<Vec<T>> {
|
||||
let mut dst = vec![T::zero(); dst_shape.elem_count()];
|
||||
for (unstr_index, src_index) in StridedIndex::new(src_dims, stride).enumerate() {
|
||||
dst[to_dst_index(unstr_index)] += src[src_index];
|
||||
}
|
||||
Ok(dst)
|
||||
}
|
||||
|
||||
fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(
|
||||
vs: &[T],
|
||||
shape: &Shape,
|
||||
stride: &[usize],
|
||||
mut f: F,
|
||||
) -> Vec<U> {
|
||||
if shape.is_contiguous(stride) {
|
||||
vs[..shape.elem_count()].iter().map(|&v| f(v)).collect()
|
||||
} else {
|
||||
StridedIndex::new(shape.dims(), stride)
|
||||
.map(|i| f(vs[i]))
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
// This function maps over two strided index sequences.
|
||||
fn binary_map<T: Copy, F: FnMut(T, T) -> T>(
|
||||
shape: &Shape,
|
||||
lhs_stride: &[usize],
|
||||
rhs_stride: &[usize],
|
||||
lhs: &[T],
|
||||
rhs: &[T],
|
||||
mut f: F,
|
||||
) -> Vec<T> {
|
||||
let dims = shape.dims();
|
||||
if shape.is_contiguous(lhs_stride) && shape.is_contiguous(rhs_stride) {
|
||||
(0..shape.elem_count()).map(|i| f(lhs[i], rhs[i])).collect()
|
||||
} else {
|
||||
let lhs_index = StridedIndex::new(dims, lhs_stride);
|
||||
let rhs_index = StridedIndex::new(dims, rhs_stride);
|
||||
lhs_index
|
||||
.zip(rhs_index)
|
||||
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
|
||||
.collect()
|
||||
}
|
||||
}
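// Worked example (added for clarity, not in the original source): a (2, 3)
// view obtained by transposing a contiguous (3, 2) buffer has strides [1, 2],
// so StridedIndex::new(&[2, 3], &[1, 2]) yields the offsets 0, 2, 4, 1, 3, 5.
// binary_map pairs those offsets with the other operand's offsets so the
// closure always sees logically matching elements, whatever the layouts are.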
|
||||
|
||||
fn take_impl1<T: Copy>(
|
||||
vs: &[T],
|
||||
ids: &[u32],
|
||||
shape: &Shape,
|
||||
stride: &[usize],
|
||||
vocab_size: usize,
|
||||
hidden_size: usize,
|
||||
) -> Result<Vec<T>> {
|
||||
let mut values = Vec::with_capacity(shape.elem_count() * hidden_size);
|
||||
for index in StridedIndex::new(shape.dims(), stride) {
|
||||
let index = ids[index].try_into()?;
|
||||
if index >= vocab_size {
|
||||
return Err(Error::InvalidIndex {
|
||||
index,
|
||||
vocab_size,
|
||||
op: "take",
|
||||
});
|
||||
} else {
|
||||
values.extend(&vs[hidden_size * index..hidden_size * (index + 1)]);
|
||||
}
|
||||
}
|
||||
Ok(values)
|
||||
}
|
||||
|
||||
fn copy_strided_src_<T: Copy + std::fmt::Display>(
|
||||
src: &[T],
|
||||
dst: &mut [T],
|
||||
dst_offset: usize,
|
||||
src_shape: &Shape,
|
||||
src_stride: &[usize],
|
||||
src_offset: usize,
|
||||
) {
|
||||
let src = &src[src_offset..];
|
||||
if src_shape.is_contiguous(src_stride) {
|
||||
let elem_to_copy = (dst.len() - dst_offset).min(src.len());
|
||||
dst[dst_offset..dst_offset + elem_to_copy].copy_from_slice(&src[..elem_to_copy])
|
||||
} else {
|
||||
let src_indexes = StridedIndex::new(src_shape.dims(), src_stride);
|
||||
for (dst_index, src_index) in src_indexes.enumerate() {
|
||||
let dst_index = dst_index + dst_offset;
|
||||
if dst_index >= dst.len() {
|
||||
break;
|
||||
}
|
||||
dst[dst_index] = src[src_index]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl CpuStorage {
|
||||
pub fn dtype(&self) -> DType {
|
||||
match self {
|
||||
Self::U32(_) => DType::U32,
|
||||
Self::BF16(_) => DType::BF16,
|
||||
Self::F16(_) => DType::F16,
|
||||
Self::F32(_) => DType::F32,
|
||||
Self::F64(_) => DType::F64,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_slice<D: crate::WithDType>(&self) -> Result<&[D]> {
|
||||
D::cpu_storage_as_slice(self)
|
||||
}
|
||||
|
||||
pub fn as_mut_slice<D: crate::WithDType>(&mut self) -> Result<&mut [D]> {
|
||||
D::cpu_storage_as_mut_slice(self)
|
||||
}
|
||||
|
||||
pub(crate) fn to_dtype(&self, shape: &Shape, stride: &[usize], dtype: DType) -> Result<Self> {
|
||||
// TODO: find a way around the quadratic number of cases below.
|
||||
match (self, dtype) {
|
||||
(Self::U32(storage), DType::BF16) => {
|
||||
let data = unary_map(storage, shape, stride, |v| bf16::from_f32(v as f32));
|
||||
Ok(Self::BF16(data))
|
||||
}
|
||||
(Self::BF16(storage), DType::BF16) => {
|
||||
let data = unary_map(storage, shape, stride, |v| v);
|
||||
Ok(Self::BF16(data))
|
||||
}
|
||||
(Self::F16(storage), DType::BF16) => {
|
||||
let data = unary_map(storage, shape, stride, |v| bf16::from_f32(v.to_f32()));
|
||||
Ok(Self::BF16(data))
|
||||
}
|
||||
(Self::F32(storage), DType::BF16) => {
|
||||
let data = unary_map(storage, shape, stride, bf16::from_f32);
|
||||
Ok(Self::BF16(data))
|
||||
}
|
||||
(Self::F64(storage), DType::BF16) => {
|
||||
let data = unary_map(storage, shape, stride, bf16::from_f64);
|
||||
Ok(Self::BF16(data))
|
||||
}
|
||||
(Self::U32(storage), DType::F16) => {
|
||||
let data = unary_map(storage, shape, stride, |v| f16::from_f32(v as f32));
|
||||
Ok(Self::F16(data))
|
||||
}
|
||||
(Self::BF16(storage), DType::F16) => {
|
||||
let data = unary_map(storage, shape, stride, |v| f16::from_f32(v.to_f32()));
|
||||
Ok(Self::F16(data))
|
||||
}
|
||||
(Self::F16(storage), DType::F16) => {
|
||||
let data = unary_map(storage, shape, stride, |v| v);
|
||||
Ok(Self::F16(data))
|
||||
}
|
||||
(Self::F32(storage), DType::F16) => {
|
||||
let data = unary_map(storage, shape, stride, f16::from_f32);
|
||||
Ok(Self::F16(data))
|
||||
}
|
||||
(Self::F64(storage), DType::F16) => {
|
||||
let data = unary_map(storage, shape, stride, f16::from_f64);
|
||||
Ok(Self::F16(data))
|
||||
}
|
||||
(Self::U32(storage), DType::F32) => {
|
||||
let data = unary_map(storage, shape, stride, |v| v as f32);
|
||||
Ok(Self::F32(data))
|
||||
}
|
||||
(Self::BF16(storage), DType::F32) => {
|
||||
let data = unary_map(storage, shape, stride, |v| v.to_f32());
|
||||
Ok(Self::F32(data))
|
||||
}
|
||||
(Self::F16(storage), DType::F32) => {
|
||||
let data = unary_map(storage, shape, stride, |v| v.to_f32());
|
||||
Ok(Self::F32(data))
|
||||
}
|
||||
(Self::F32(storage), DType::F32) => {
|
||||
let data = unary_map(storage, shape, stride, |v| v);
|
||||
Ok(Self::F32(data))
|
||||
}
|
||||
(Self::F64(storage), DType::F32) => {
|
||||
let data = unary_map(storage, shape, stride, |v| v as f32);
|
||||
Ok(Self::F32(data))
|
||||
}
|
||||
(Self::U32(storage), DType::U32) => {
|
||||
let data = unary_map(storage, shape, stride, |v| v);
|
||||
Ok(Self::U32(data))
|
||||
}
|
||||
(Self::BF16(storage), DType::U32) => {
|
||||
let data = unary_map(storage, shape, stride, |v| v.to_f32() as u32);
|
||||
Ok(Self::U32(data))
|
||||
}
|
||||
(Self::F16(storage), DType::U32) => {
|
||||
let data = unary_map(storage, shape, stride, |v| v.to_f32() as u32);
|
||||
Ok(Self::U32(data))
|
||||
}
|
||||
(Self::F32(storage), DType::U32) => {
|
||||
let data = unary_map(storage, shape, stride, |v| v as u32);
|
||||
Ok(Self::U32(data))
|
||||
}
|
||||
(Self::F64(storage), DType::U32) => {
|
||||
let data = unary_map(storage, shape, stride, |v| v as u32);
|
||||
Ok(Self::U32(data))
|
||||
}
|
||||
(Self::U32(storage), DType::F64) => {
|
||||
let data = unary_map(storage, shape, stride, |v| v as f64);
|
||||
Ok(Self::F64(data))
|
||||
}
|
||||
(Self::BF16(storage), DType::F64) => {
|
||||
let data = unary_map(storage, shape, stride, |v| v.to_f64());
|
||||
Ok(Self::F64(data))
|
||||
}
|
||||
(Self::F16(storage), DType::F64) => {
|
||||
let data = unary_map(storage, shape, stride, |v| v.to_f64());
|
||||
Ok(Self::F64(data))
|
||||
}
|
||||
(Self::F32(storage), DType::F64) => {
|
||||
let data = unary_map(storage, shape, stride, |v| v as f64);
|
||||
Ok(Self::F64(data))
|
||||
}
|
||||
(Self::F64(storage), DType::F64) => {
|
||||
let data = unary_map(storage, shape, stride, |v| v);
|
||||
Ok(Self::F64(data))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn sum(&self, shape: &Shape, stride: &[usize], sum_dims: &[usize]) -> Result<Self> {
|
||||
let src_dims = shape.dims();
|
||||
let mut dst_dims = src_dims.to_vec();
|
||||
for &sum_dim in sum_dims.iter() {
|
||||
dst_dims[sum_dim] = 1;
|
||||
}
|
||||
let dst_shape = Shape::from(dst_dims);
|
||||
let mut sum_dims = sum_dims.to_vec();
|
||||
// Sort the sum_dims as they have to be processed from left to right when converting the
|
||||
// indexes.
|
||||
sum_dims.sort();
|
||||
let sum_dims_and_stride: Vec<_> = sum_dims
|
||||
.iter()
|
||||
.map(|&d| (src_dims[d], src_dims[d + 1..].iter().product::<usize>()))
|
||||
.collect();
|
||||
let to_dst_index = |unstr_index: usize| {
|
||||
// TODO: Optimize, the following does lots of slow division.
|
||||
let mut dst_index = unstr_index;
|
||||
// Set the sum_dims indexes to 0.
|
||||
for &(dim, stride) in sum_dims_and_stride.iter() {
|
||||
// The compiler is able to optimize the following in a single divmod op.
|
||||
let (pre, post) = (dst_index / stride, dst_index % stride);
|
||||
dst_index = (pre / dim) * stride + post;
|
||||
}
|
||||
dst_index
|
||||
};
|
||||
// TODO: Maybe provide an implementation with higher precision accumulators?
|
||||
map1!(self, sum_impl1, &dst_shape, src_dims, stride, to_dst_index)
|
||||
}
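// Index-remapping example (added for clarity): summing a (2, 3) tensor over
// dim 1 gives dst_shape (2, 1) and sum_dims_and_stride = [(3, 1)]. For the
// source element at unstrided index 4 (row 1, col 1), pre = 4 / 1 = 4,
// post = 4 % 1 = 0 and dst_index = (4 / 3) * 1 + 0 = 1, i.e. every element of
// row 1 accumulates into dst[1].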
|
||||
|
||||
pub(crate) fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) -> Result<()> {
|
||||
// [self] stores data in a contiguous way.
|
||||
let dims = shape.dims();
|
||||
let elem_per_slice = dims[dim];
|
||||
let prod_pre_dim = dims[..dim].iter().product();
|
||||
let prod_post_dim = dims[dim + 1..].iter().product();
|
||||
match self {
|
||||
Self::BF16(storage) => {
|
||||
for pre_idx in 0..prod_pre_dim {
|
||||
for post_idx in 0..prod_post_dim {
|
||||
let mut sum = 0f64;
|
||||
let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
|
||||
for _ in 0..elem_per_slice {
|
||||
sum += storage[idx].to_f64();
|
||||
idx += prod_post_dim
|
||||
}
|
||||
let sum = bf16::from_f64(sum);
|
||||
let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
|
||||
for _ in 0..elem_per_slice {
|
||||
storage[idx] /= sum;
|
||||
idx += prod_post_dim
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Self::F16(storage) => {
|
||||
for pre_idx in 0..prod_pre_dim {
|
||||
for post_idx in 0..prod_post_dim {
|
||||
let mut sum = 0f64;
|
||||
let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
|
||||
for _ in 0..elem_per_slice {
|
||||
sum += storage[idx].to_f64();
|
||||
idx += prod_post_dim
|
||||
}
|
||||
let sum = f16::from_f64(sum);
|
||||
let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
|
||||
for _ in 0..elem_per_slice {
|
||||
storage[idx] /= sum;
|
||||
idx += prod_post_dim
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Self::F32(storage) => {
|
||||
for pre_idx in 0..prod_pre_dim {
|
||||
for post_idx in 0..prod_post_dim {
|
||||
let mut sum = 0f64;
|
||||
let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
|
||||
for _ in 0..elem_per_slice {
|
||||
sum += storage[idx] as f64;
|
||||
idx += prod_post_dim
|
||||
}
|
||||
let sum = sum as f32;
|
||||
let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
|
||||
for _ in 0..elem_per_slice {
|
||||
storage[idx] /= sum;
|
||||
idx += prod_post_dim
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Self::F64(storage) => {
|
||||
for pre_idx in 0..prod_pre_dim {
|
||||
for post_idx in 0..prod_post_dim {
|
||||
let mut sum = 0f64;
|
||||
let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
|
||||
for _ in 0..elem_per_slice {
|
||||
sum += storage[idx];
|
||||
idx += prod_post_dim
|
||||
}
|
||||
let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
|
||||
for _ in 0..elem_per_slice {
|
||||
storage[idx] /= sum;
|
||||
idx += prod_post_dim
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Self::U32(_) => {}
|
||||
}
|
||||
Ok(())
|
||||
}
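// Layout example (added for clarity): for a contiguous (2, 3, 4) tensor and
// dim = 1, elem_per_slice = 3, prod_pre_dim = 2 and prod_post_dim = 4. The
// slice addressed by (pre_idx = 1, post_idx = 2) starts at
// 1 * 4 * 3 + 2 = 14 and then steps by prod_post_dim, visiting indices
// 14, 18 and 22; those three values are summed and each is divided by the sum.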
|
||||
|
||||
pub(crate) fn affine_impl(
|
||||
&self,
|
||||
shape: &Shape,
|
||||
stride: &[usize],
|
||||
mul: f64,
|
||||
add: f64,
|
||||
) -> Result<Self> {
|
||||
match self {
|
||||
Self::U32(storage) => {
|
||||
let mul = mul as u32;
|
||||
let add = add as u32;
|
||||
let data = unary_map(storage, shape, stride, |v| v * mul + add);
|
||||
Ok(Self::U32(data))
|
||||
}
|
||||
Self::BF16(storage) => {
|
||||
let mul = bf16::from_f64(mul);
|
||||
let add = bf16::from_f64(add);
|
||||
let data = unary_map(storage, shape, stride, |v| v * mul + add);
|
||||
Ok(Self::BF16(data))
|
||||
}
|
||||
Self::F16(storage) => {
|
||||
let mul = f16::from_f64(mul);
|
||||
let add = f16::from_f64(add);
|
||||
let data = unary_map(storage, shape, stride, |v| v * mul + add);
|
||||
Ok(Self::F16(data))
|
||||
}
|
||||
Self::F32(storage) => {
|
||||
let mul = mul as f32;
|
||||
let add = add as f32;
|
||||
let data = unary_map(storage, shape, stride, |v| v * mul + add);
|
||||
Ok(Self::F32(data))
|
||||
}
|
||||
Self::F64(storage) => {
|
||||
let data = unary_map(storage, shape, stride, |v| v * mul + add);
|
||||
Ok(Self::F64(data))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn unary_impl<B: UnaryOp>(&self, shape: &Shape, stride: &[usize]) -> Result<Self> {
|
||||
match self {
|
||||
Self::BF16(storage) => {
|
||||
let data = unary_map(storage, shape, stride, B::bf16);
|
||||
Ok(Self::BF16(data))
|
||||
}
|
||||
Self::F16(storage) => {
|
||||
let data = unary_map(storage, shape, stride, B::f16);
|
||||
Ok(Self::F16(data))
|
||||
}
|
||||
Self::F32(storage) => {
|
||||
let data = unary_map(storage, shape, stride, B::f32);
|
||||
Ok(Self::F32(data))
|
||||
}
|
||||
Self::F64(storage) => {
|
||||
let data = unary_map(storage, shape, stride, B::f64);
|
||||
Ok(Self::F64(data))
|
||||
}
|
||||
Self::U32(storage) => {
|
||||
let data = unary_map(storage, shape, stride, B::u32);
|
||||
Ok(Self::U32(data))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn binary_impl<B: BinaryOp>(
|
||||
&self,
|
||||
rhs: &Self,
|
||||
shape: &Shape,
|
||||
lhs_stride: &[usize],
|
||||
rhs_stride: &[usize],
|
||||
) -> Result<Self> {
|
||||
match (self, rhs) {
|
||||
(Self::BF16(lhs), Self::BF16(rhs)) => {
|
||||
let data = binary_map(shape, lhs_stride, rhs_stride, lhs, rhs, B::bf16);
|
||||
Ok(Self::BF16(data))
|
||||
}
|
||||
(Self::F16(lhs), Self::F16(rhs)) => {
|
||||
let data = binary_map(shape, lhs_stride, rhs_stride, lhs, rhs, B::f16);
|
||||
Ok(Self::F16(data))
|
||||
}
|
||||
(Self::F32(lhs), Self::F32(rhs)) => {
|
||||
let data = binary_map(shape, lhs_stride, rhs_stride, lhs, rhs, B::f32);
|
||||
Ok(Self::F32(data))
|
||||
}
|
||||
(Self::F64(lhs), Self::F64(rhs)) => {
|
||||
let data = binary_map(shape, lhs_stride, rhs_stride, lhs, rhs, B::f64);
|
||||
Ok(Self::F64(data))
|
||||
}
|
||||
(Self::U32(lhs), Self::U32(rhs)) => {
|
||||
let data = binary_map(shape, lhs_stride, rhs_stride, lhs, rhs, B::u32);
|
||||
Ok(Self::U32(data))
|
||||
}
|
||||
_ => {
|
||||
// This should be covered by the dtype check above.
|
||||
Err(Error::DTypeMismatchBinaryOp {
|
||||
lhs: self.dtype(),
|
||||
rhs: rhs.dtype(),
|
||||
op: B::NAME,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn copy_strided_src(
|
||||
&self,
|
||||
dst: &mut Self,
|
||||
dst_offset: usize,
|
||||
src_shape: &Shape,
|
||||
src_stride: &[usize],
|
||||
src_offset: usize,
|
||||
) -> Result<()> {
|
||||
if src_shape.rank() != src_stride.len() {
|
||||
panic!("incoherent shape and strides {src_shape:?} {src_stride:?}")
|
||||
}
|
||||
match (self, dst) {
|
||||
(Self::U32(src), Self::U32(dst)) => {
|
||||
copy_strided_src_(src, dst, dst_offset, src_shape, src_stride, src_offset)
|
||||
}
|
||||
(Self::BF16(src), Self::BF16(dst)) => {
|
||||
copy_strided_src_(src, dst, dst_offset, src_shape, src_stride, src_offset)
|
||||
}
|
||||
(Self::F16(src), Self::F16(dst)) => {
|
||||
copy_strided_src_(src, dst, dst_offset, src_shape, src_stride, src_offset)
|
||||
}
|
||||
(Self::F32(src), Self::F32(dst)) => {
|
||||
copy_strided_src_(src, dst, dst_offset, src_shape, src_stride, src_offset)
|
||||
}
|
||||
(Self::F64(src), Self::F64(dst)) => {
|
||||
copy_strided_src_(src, dst, dst_offset, src_shape, src_stride, src_offset)
|
||||
}
|
||||
(_, dst) => {
|
||||
// This should be covered by the dtype check above.
|
||||
return Err(Error::DTypeMismatchBinaryOp {
|
||||
lhs: self.dtype(),
|
||||
rhs: dst.dtype(),
|
||||
op: "copy_strided",
|
||||
});
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn where_cond(
|
||||
&self,
|
||||
shape: &Shape,
|
||||
stride: &[usize],
|
||||
t: &Self,
|
||||
stride_t: &[usize],
|
||||
f: &Self,
|
||||
stride_f: &[usize],
|
||||
) -> Result<Self> {
|
||||
// TODO: Support types that could be cast to a boolean.
|
||||
let pred = self.as_slice::<u32>()?;
|
||||
match (t, f) {
|
||||
(Self::BF16(t), Self::BF16(f)) => {
|
||||
let data = wcond(pred, shape, stride, t, stride_t, f, stride_f);
|
||||
Ok(Self::BF16(data))
|
||||
}
|
||||
(Self::F16(t), Self::F16(f)) => {
|
||||
let data = wcond(pred, shape, stride, t, stride_t, f, stride_f);
|
||||
Ok(Self::F16(data))
|
||||
}
|
||||
(Self::F32(t), Self::F32(f)) => {
|
||||
let data = wcond(pred, shape, stride, t, stride_t, f, stride_f);
|
||||
Ok(Self::F32(data))
|
||||
}
|
||||
(Self::F64(t), Self::F64(f)) => {
|
||||
let data = wcond(pred, shape, stride, t, stride_t, f, stride_f);
|
||||
Ok(Self::F64(data))
|
||||
}
|
||||
(Self::U32(t), Self::U32(f)) => {
|
||||
let data = wcond(pred, shape, stride, t, stride_t, f, stride_f);
|
||||
Ok(Self::U32(data))
|
||||
}
|
||||
_ => Err(Error::DTypeMismatchBinaryOp {
|
||||
lhs: t.dtype(),
|
||||
rhs: f.dtype(),
|
||||
op: "where_cond",
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn embedding_impl(
|
||||
&self,
|
||||
shape: &Shape,
|
||||
stride: &[usize],
|
||||
vs: &Self,
|
||||
hidden_size: usize,
|
||||
vocab_size: usize,
|
||||
) -> Result<Self> {
|
||||
let ids = self.as_slice::<u32>()?;
|
||||
map1!(vs, take_impl1, ids, shape, stride, vocab_size, hidden_size)
|
||||
}
|
||||
|
||||
pub(crate) fn matmul_impl(
|
||||
&self,
|
||||
rhs: &Self,
|
||||
(b, m, n, k): (usize, usize, usize, usize),
|
||||
lhs_stride: &[usize],
|
||||
rhs_stride: &[usize],
|
||||
) -> Result<Self> {
|
||||
let a_skip: usize = m * k;
|
||||
let b_skip: usize = n * k;
|
||||
let c_skip: usize = m * n;
|
||||
|
||||
let rank = lhs_stride.len();
|
||||
let lhs_cs = lhs_stride[rank - 1];
|
||||
let lhs_rs = lhs_stride[rank - 2];
|
||||
|
||||
let rhs_cs = rhs_stride[rank - 1];
|
||||
let rhs_rs = rhs_stride[rank - 2];
|
||||
|
||||
if lhs_stride.len() > 2 {
|
||||
let lhs_batch_stride = &lhs_stride[..rank - 2];
|
||||
let rhs_batch_stride = &rhs_stride[..rank - 2];
|
||||
|
||||
if lhs_batch_stride != [a_skip] || rhs_batch_stride != [b_skip] {
|
||||
// Temporary error before we support arbitrary striding.
|
||||
return Err(Error::UnexpectedStriding);
|
||||
}
|
||||
}
|
||||
|
||||
let dst_shape: Shape = (m, n).into();
|
||||
let dst_strides = dst_shape.stride_contiguous();
|
||||
let dst_rs = dst_strides[0];
|
||||
let dst_cs = dst_strides[1];
|
||||
|
||||
match (self, rhs) {
|
||||
(CpuStorage::F16(lhs), CpuStorage::F16(rhs)) => {
|
||||
let mut dst = vec![f16::ZERO; b * m * n];
|
||||
for step in 0..b {
|
||||
let lhs_p = &lhs[step * a_skip..];
|
||||
let rhs_p = &rhs[step * b_skip..];
|
||||
let dst_p = &mut dst[step * c_skip..];
|
||||
unsafe {
|
||||
gemm(
|
||||
// m: usize,
|
||||
m,
|
||||
// n: usize,
|
||||
n,
|
||||
// k: usize,
|
||||
k,
|
||||
// dst: *mut T,
|
||||
dst_p.as_mut_ptr(),
|
||||
// dst_cs: isize,
|
||||
dst_cs as isize,
|
||||
// dst_rs: isize,
|
||||
dst_rs as isize,
|
||||
// read_dst: bool,
|
||||
false,
|
||||
// lhs: *const T,
|
||||
lhs_p.as_ptr(),
|
||||
// lhs_cs: isize,
|
||||
lhs_cs as isize,
|
||||
// lhs_rs: isize,
|
||||
lhs_rs as isize,
|
||||
// rhs: *const T,
|
||||
rhs_p.as_ptr(),
|
||||
// rhs_cs: isize,
|
||||
rhs_cs as isize,
|
||||
// rhs_rs: isize,
|
||||
rhs_rs as isize,
|
||||
// alpha: T,
|
||||
f16::ONE,
|
||||
// beta: T,
|
||||
f16::ONE,
|
||||
// conj_dst: bool,
|
||||
false,
|
||||
// conj_lhs: bool,
|
||||
false,
|
||||
// conj_rhs: bool,
|
||||
true,
|
||||
// parallelism: Parallelism
|
||||
Parallelism::None,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Self::F16(dst))
|
||||
}
|
||||
(CpuStorage::F32(lhs), CpuStorage::F32(rhs)) => {
|
||||
let mut dst = vec![0f32; b * m * n];
|
||||
for step in 0..b {
|
||||
let lhs_p = &lhs[step * a_skip..];
|
||||
let rhs_p = &rhs[step * b_skip..];
|
||||
let dst_p = &mut dst[step * c_skip..];
|
||||
unsafe {
|
||||
gemm(
|
||||
// m: usize,
|
||||
m,
|
||||
// n: usize,
|
||||
n,
|
||||
// k: usize,
|
||||
k,
|
||||
// dst: *mut T,
|
||||
dst_p.as_mut_ptr(),
|
||||
// dst_cs: isize,
|
||||
dst_cs as isize,
|
||||
// dst_rs: isize,
|
||||
dst_rs as isize,
|
||||
// read_dst: bool,
|
||||
false,
|
||||
// lhs: *const T,
|
||||
lhs_p.as_ptr(),
|
||||
// lhs_cs: isize,
|
||||
lhs_cs as isize,
|
||||
// lhs_rs: isize,
|
||||
lhs_rs as isize,
|
||||
// rhs: *const T,
|
||||
rhs_p.as_ptr(),
|
||||
// rhs_cs: isize,
|
||||
rhs_cs as isize,
|
||||
// rhs_rs: isize,
|
||||
rhs_rs as isize,
|
||||
// alpha: T,
|
||||
1f32,
|
||||
// beta: T,
|
||||
1f32,
|
||||
// conj_dst: bool,
|
||||
false,
|
||||
// conj_lhs: bool,
|
||||
false,
|
||||
// conj_rhs: bool,
|
||||
true,
|
||||
// parallelism: Parallelism
|
||||
Parallelism::None,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Self::F32(dst))
|
||||
}
|
||||
(CpuStorage::F64(lhs), CpuStorage::F64(rhs)) => {
|
||||
let mut dst = vec![0f64; b * m * n];
|
||||
for step in 0..b {
|
||||
let lhs_p = &lhs[step * a_skip..];
|
||||
let rhs_p = &rhs[step * b_skip..];
|
||||
let dst_p = &mut dst[step * c_skip..];
|
||||
unsafe {
|
||||
gemm(
|
||||
// m: usize,
|
||||
m,
|
||||
// n: usize,
|
||||
n,
|
||||
// k: usize,
|
||||
k,
|
||||
// dst: *mut T,
|
||||
dst_p.as_mut_ptr(),
|
||||
// dst_cs: isize,
|
||||
dst_cs as isize,
|
||||
// dst_rs: isize,
|
||||
dst_rs as isize,
|
||||
// read_dst: bool,
|
||||
false,
|
||||
// lhs: *const T,
|
||||
lhs_p.as_ptr(),
|
||||
// lhs_cs: isize,
|
||||
lhs_cs as isize,
|
||||
// lhs_rs: isize,
|
||||
lhs_rs as isize,
|
||||
// rhs: *const T,
|
||||
rhs_p.as_ptr(),
|
||||
// rhs_cs: isize,
|
||||
rhs_cs as isize,
|
||||
// rhs_rs: isize,
|
||||
rhs_rs as isize,
|
||||
// alpha: T,
|
||||
1f64,
|
||||
// beta: T,
|
||||
1f64,
|
||||
// conj_dst: bool,
|
||||
false,
|
||||
// conj_lhs: bool,
|
||||
false,
|
||||
// conj_rhs: bool,
|
||||
true,
|
||||
// parallelism: Parallelism
|
||||
Parallelism::None,
|
||||
)
|
||||
}
|
||||
}
|
||||
Ok(Self::F64(dst))
|
||||
}
|
||||
_ => {
|
||||
// This should be covered by the dtype check above.
|
||||
Err(Error::DTypeMismatchBinaryOp {
|
||||
lhs: self.dtype(),
|
||||
rhs: rhs.dtype(),
|
||||
op: "matmul",
|
||||
})
|
||||
}
|
||||
}
|
||||
}
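// Shape conventions assumed above (comment added for clarity): lhs is
// (b, m, k), rhs is (b, k, n) and the output is (b, m, n), all laid out
// contiguously per batch, so a_skip = m * k, b_skip = n * k and
// c_skip = m * n are the per-batch element offsets. Each batch is handed to
// `gemm` as an independent (m, k) x (k, n) product.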
|
||||
|
||||
pub(crate) fn ones_impl(shape: &Shape, dtype: DType) -> Self {
|
||||
let elem_count = shape.elem_count();
|
||||
match dtype {
|
||||
DType::U32 => {
|
||||
let data = vec![1u32; elem_count];
|
||||
Self::U32(data)
|
||||
}
|
||||
DType::BF16 => {
|
||||
let data = vec![bf16::ONE; elem_count];
|
||||
Self::BF16(data)
|
||||
}
|
||||
DType::F16 => {
|
||||
let data = vec![f16::ONE; elem_count];
|
||||
Self::F16(data)
|
||||
}
|
||||
DType::F32 => {
|
||||
let data = vec![1f32; elem_count];
|
||||
Self::F32(data)
|
||||
}
|
||||
DType::F64 => {
|
||||
let data = vec![1f64; elem_count];
|
||||
Self::F64(data)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn zeros_impl(shape: &Shape, dtype: DType) -> Self {
|
||||
let elem_count = shape.elem_count();
|
||||
match dtype {
|
||||
DType::U32 => {
|
||||
let data = vec![0u32; elem_count];
|
||||
Self::U32(data)
|
||||
}
|
||||
DType::BF16 => {
|
||||
let data = vec![bf16::ZERO; elem_count];
|
||||
Self::BF16(data)
|
||||
}
|
||||
DType::F16 => {
|
||||
let data = vec![f16::ZERO; elem_count];
|
||||
Self::F16(data)
|
||||
}
|
||||
DType::F32 => {
|
||||
let data = vec![0f32; elem_count];
|
||||
Self::F32(data)
|
||||
}
|
||||
DType::F64 => {
|
||||
let data = vec![0f64; elem_count];
|
||||
Self::F64(data)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::{Device, Tensor};
|
||||
|
||||
#[test]
|
||||
fn simple_matmul() -> Result<()> {
|
||||
let data = vec![1.0f32, 2.0, 3.0, 4.0];
|
||||
let a = Tensor::from_slice(&data, (2, 2), &Device::Cpu)?;
|
||||
let data = vec![1.0f32, 2.0, 3.0, 4.0];
|
||||
let b = Tensor::from_slice(&data, (2, 2), &Device::Cpu)?;
|
||||
|
||||
let c = a.matmul(&b)?;
|
||||
assert_eq!(c.to_vec2::<f32>()?, &[&[7.0f32, 10.0], &[15.0, 22.0]]);
|
||||
|
||||
let data = vec![1.0f32, 2.0];
|
||||
let a = Tensor::from_slice(&data, (2, 1), &Device::Cpu)?;
|
||||
let data = vec![3.0f32, 4.0];
|
||||
let b = Tensor::from_slice(&data, (1, 2), &Device::Cpu)?;
|
||||
let c = a.matmul(&b)?;
|
||||
assert_eq!(c.to_vec2::<f32>()?, &[&[3.0, 4.0], &[6.0, 8.0]]);
|
||||
|
||||
let data: Vec<_> = (0..6).map(|i| i as f32).collect();
|
||||
let a = Tensor::from_slice(&data, (2, 3), &Device::Cpu)?;
|
||||
let data: Vec<_> = (0..6).map(|i| (i + 2) as f32).collect();
|
||||
let b = Tensor::from_slice(&data, (3, 2), &Device::Cpu)?;
|
||||
let c = a.matmul(&b)?;
|
||||
assert_eq!(c.to_vec2::<f32>()?, &[&[16., 19.], &[52., 64.]]);
|
||||
|
||||
let data: Vec<_> = (0..12).map(|i| i as f32).collect();
|
||||
let a = Tensor::from_slice(&data, (2, 2, 3), &Device::Cpu)?;
|
||||
let data: Vec<_> = (0..12).map(|i| (i + 2) as f32).collect();
|
||||
let b = Tensor::from_slice(&data, (2, 3, 2), &Device::Cpu)?;
|
||||
let c = a.matmul(&b)?;
|
||||
assert_eq!(
|
||||
c.to_vec3::<f32>()?,
|
||||
&[&[&[16., 19.], &[52., 64.]], &[&[214., 235.], &[304., 334.]]]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
}
|
candle-core/src/cuda_backend.rs (new file, 978 lines)
@@ -0,0 +1,978 @@
|
||||
use crate::{CpuStorage, DType, Shape};
|
||||
use candle_kernels as kernels;
|
||||
use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
|
||||
use cudarc::driver::{CudaFunction, CudaSlice, LaunchAsync, LaunchConfig};
|
||||
use half::{bf16, f16};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// cudarc related errors
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum CudaError {
|
||||
#[error(transparent)]
|
||||
Cuda(#[from] cudarc::driver::DriverError),
|
||||
|
||||
#[error(transparent)]
|
||||
Compiler(#[from] cudarc::nvrtc::CompileError),
|
||||
|
||||
#[error(transparent)]
|
||||
Cublas(#[from] cudarc::cublas::result::CublasError),
|
||||
|
||||
#[error("{op} only supports contiguous tensors")]
|
||||
RequiresContiguous { op: &'static str },
|
||||
|
||||
#[error("missing kernel '{module_name}'")]
|
||||
MissingKernel { module_name: String },
|
||||
|
||||
#[error("internal error '{0}'")]
|
||||
InternalError(&'static str),
|
||||
|
||||
#[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")]
|
||||
MatMulNonContiguous {
|
||||
lhs_stride: Vec<usize>,
|
||||
rhs_stride: Vec<usize>,
|
||||
mnk: (usize, usize, usize),
|
||||
},
|
||||
|
||||
#[error("{msg}, expected: {expected:?}, got: {got:?}")]
|
||||
UnexpectedDType {
|
||||
msg: &'static str,
|
||||
expected: DType,
|
||||
got: DType,
|
||||
},
|
||||
|
||||
#[error("{cuda} when loading {module_name}")]
|
||||
Load {
|
||||
cuda: cudarc::driver::DriverError,
|
||||
module_name: String,
|
||||
},
|
||||
}
|
||||
|
||||
type Result<T> = std::result::Result<T, CudaError>;
|
||||
|
||||
/// Unique identifier for cuda devices.
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
|
||||
pub(crate) struct DeviceId(usize);
|
||||
|
||||
impl DeviceId {
|
||||
fn new() -> Self {
|
||||
// https://users.rust-lang.org/t/idiomatic-rust-way-to-generate-unique-id/33805
|
||||
use std::sync::atomic;
|
||||
static COUNTER: atomic::AtomicUsize = atomic::AtomicUsize::new(1);
|
||||
Self(COUNTER.fetch_add(1, atomic::Ordering::Relaxed))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CudaDevice {
|
||||
id: DeviceId,
|
||||
device: Arc<cudarc::driver::CudaDevice>,
|
||||
#[allow(dead_code)]
|
||||
blas: Arc<cudarc::cublas::CudaBlas>,
|
||||
}
|
||||
|
||||
impl std::ops::Deref for CudaDevice {
|
||||
type Target = Arc<cudarc::driver::CudaDevice>;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.device
|
||||
}
|
||||
}
|
||||
|
||||
impl CudaDevice {
|
||||
pub(crate) fn new(ordinal: usize) -> Result<Self> {
|
||||
let device = cudarc::driver::CudaDevice::new(ordinal)?;
|
||||
let blas = cudarc::cublas::CudaBlas::new(device.clone())?;
|
||||
Ok(Self {
|
||||
id: DeviceId::new(),
|
||||
device,
|
||||
blas: Arc::new(blas),
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn same_id(&self, rhs: &Self) -> bool {
|
||||
self.id == rhs.id
|
||||
}
|
||||
|
||||
pub(crate) fn ordinal(&self) -> usize {
|
||||
self.device.ordinal()
|
||||
}
|
||||
|
||||
pub(crate) fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
|
||||
let elem_count = shape.elem_count();
|
||||
let slice = match dtype {
|
||||
DType::U32 => {
|
||||
let data = self.alloc_zeros::<u32>(elem_count)?;
|
||||
CudaStorageSlice::U32(data)
|
||||
}
|
||||
DType::BF16 => {
|
||||
let data = self.alloc_zeros::<bf16>(elem_count)?;
|
||||
CudaStorageSlice::BF16(data)
|
||||
}
|
||||
DType::F16 => {
|
||||
let data = self.alloc_zeros::<f16>(elem_count)?;
|
||||
CudaStorageSlice::F16(data)
|
||||
}
|
||||
DType::F32 => {
|
||||
let data = self.alloc_zeros::<f32>(elem_count)?;
|
||||
CudaStorageSlice::F32(data)
|
||||
}
|
||||
DType::F64 => {
|
||||
let data = self.alloc_zeros::<f64>(elem_count)?;
|
||||
CudaStorageSlice::F64(data)
|
||||
}
|
||||
};
|
||||
Ok(CudaStorage {
|
||||
slice,
|
||||
device: self.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn const_impl(&self, v: f64, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
|
||||
let elem_count = shape.elem_count();
|
||||
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
|
||||
let slice = match dtype {
|
||||
DType::U32 => {
|
||||
// SAFETY: Set later by running the fill kernel.
|
||||
let data = unsafe { self.alloc::<u32>(elem_count) }?;
|
||||
let func = self.get_or_load_func("fill_u32", kernels::FILL)?;
|
||||
let params = (&data, v as u32, elem_count);
|
||||
unsafe { func.launch(cfg, params) }?;
|
||||
CudaStorageSlice::U32(data)
|
||||
}
|
||||
DType::BF16 => {
|
||||
// SAFETY: Set later by running the fill kernel.
|
||||
let data = unsafe { self.alloc::<bf16>(elem_count) }?;
|
||||
let func = self.get_or_load_func("fill_bf16", kernels::FILL)?;
|
||||
let params = (&data, bf16::from_f64(v), elem_count);
|
||||
unsafe { func.launch(cfg, params) }?;
|
||||
CudaStorageSlice::BF16(data)
|
||||
}
|
||||
DType::F16 => {
|
||||
// SAFETY: Set later by running the fill kernel.
|
||||
let data = unsafe { self.alloc::<f16>(elem_count) }?;
|
||||
let func = self.get_or_load_func("fill_f16", kernels::FILL)?;
|
||||
let params = (&data, f16::from_f64(v), elem_count);
|
||||
unsafe { func.launch(cfg, params) }?;
|
||||
CudaStorageSlice::F16(data)
|
||||
}
|
||||
DType::F32 => {
|
||||
// SAFETY: Set later by running the fill kernel.
|
||||
let data = unsafe { self.alloc::<f32>(elem_count) }?;
|
||||
let func = self.get_or_load_func("fill_f32", kernels::FILL)?;
|
||||
let params = (&data, v as f32, elem_count);
|
||||
unsafe { func.launch(cfg, params) }?;
|
||||
CudaStorageSlice::F32(data)
|
||||
}
|
||||
DType::F64 => {
|
||||
// SAFETY: Set later by running the fill kernel.
|
||||
let data = unsafe { self.alloc::<f64>(elem_count) }?;
|
||||
let func = self.get_or_load_func("fill_f64", kernels::FILL)?;
|
||||
let params = (&data, v, elem_count);
|
||||
unsafe { func.launch(cfg, params) }?;
|
||||
CudaStorageSlice::F64(data)
|
||||
}
|
||||
};
|
||||
Ok(CudaStorage {
|
||||
slice,
|
||||
device: self.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
|
||||
self.const_impl(1., shape, dtype)
|
||||
}
|
||||
|
||||
pub(crate) fn cuda_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> {
|
||||
let slice = match storage {
|
||||
CpuStorage::U32(storage) => {
|
||||
let data = self.htod_sync_copy(storage)?;
|
||||
CudaStorageSlice::U32(data)
|
||||
}
|
||||
CpuStorage::BF16(storage) => {
|
||||
let data = self.htod_sync_copy(storage)?;
|
||||
CudaStorageSlice::BF16(data)
|
||||
}
|
||||
CpuStorage::F16(storage) => {
|
||||
let data = self.htod_sync_copy(storage)?;
|
||||
CudaStorageSlice::F16(data)
|
||||
}
|
||||
CpuStorage::F32(storage) => {
|
||||
let data = self.htod_sync_copy(storage)?;
|
||||
CudaStorageSlice::F32(data)
|
||||
}
|
||||
CpuStorage::F64(storage) => {
|
||||
let data = self.htod_sync_copy(storage)?;
|
||||
CudaStorageSlice::F64(data)
|
||||
}
|
||||
};
|
||||
Ok(CudaStorage {
|
||||
slice,
|
||||
device: self.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
fn get_or_load_func(&self, module_name: &str, ptx: &'static str) -> Result<CudaFunction> {
|
||||
if !self.has_func(module_name, module_name) {
|
||||
// Leaking the string here is a bit sad but we need a &'static str and this is only
|
||||
// done once per kernel name.
|
||||
let static_module_name = Box::leak(module_name.to_string().into_boxed_str());
|
||||
self.load_ptx(ptx.into(), module_name, &[static_module_name])
|
||||
.map_err(|cuda| CudaError::Load {
|
||||
cuda,
|
||||
module_name: module_name.to_string(),
|
||||
})?;
|
||||
}
|
||||
self.get_func(module_name, module_name)
|
||||
// Clippy recommends this `ok_or` rather than `ok_or_else` so hopefully the compiler is
|
||||
// able to only build the error value if needed.
|
||||
.ok_or(CudaError::MissingKernel {
|
||||
module_name: module_name.to_string(),
|
||||
})
|
||||
}
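// Usage pattern (comment added for clarity): callers such as `const_impl`
// above request e.g. `self.get_or_load_func("fill_f32", kernels::FILL)`; the
// PTX module is loaded at most once per kernel name and subsequent calls
// return the cached CudaFunction.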
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum CudaStorageSlice {
|
||||
U32(CudaSlice<u32>),
|
||||
BF16(CudaSlice<bf16>),
|
||||
F16(CudaSlice<f16>),
|
||||
F32(CudaSlice<f32>),
|
||||
F64(CudaSlice<f64>),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct CudaStorage {
|
||||
slice: CudaStorageSlice,
|
||||
device: CudaDevice,
|
||||
}
|
||||
|
||||
fn gemm_config<T>(
|
||||
alpha: T,
|
||||
beta: T,
|
||||
(b, m, n, k): (usize, usize, usize, usize),
|
||||
lhs_stride: &[usize],
|
||||
rhs_stride: &[usize],
|
||||
) -> Result<StridedBatchedConfig<T>> {
|
||||
// https://docs.nvidia.com/cuda/cublas/index.html#cublas-t-gemm
|
||||
use cudarc::cublas::sys::cublasOperation_t;
|
||||
|
||||
let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
|
||||
let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
|
||||
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
|
||||
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
|
||||
// The cuBLAS `a` operand is the rhs tensor, with dims (batching, k, n).
|
||||
let (lda, transa) = if rhs_m1 == 1 && rhs_m2 == n {
|
||||
(n as i32, cublasOperation_t::CUBLAS_OP_N)
|
||||
} else if rhs_m1 == k && rhs_m2 == 1 {
|
||||
(k as i32, cublasOperation_t::CUBLAS_OP_T)
|
||||
} else {
|
||||
Err(CudaError::MatMulNonContiguous {
|
||||
lhs_stride: lhs_stride.to_vec(),
|
||||
rhs_stride: rhs_stride.to_vec(),
|
||||
mnk: (m, n, k),
|
||||
})?
|
||||
};
|
||||
// The cuBLAS `b` operand is the lhs tensor, with dims (batching, m, k).
|
||||
let (ldb, transb) = if lhs_m1 == 1 && lhs_m2 == k {
|
||||
(k as i32, cublasOperation_t::CUBLAS_OP_N)
|
||||
} else if lhs_m1 == m && lhs_m2 == 1 {
|
||||
(m as i32, cublasOperation_t::CUBLAS_OP_T)
|
||||
} else {
|
||||
Err(CudaError::MatMulNonContiguous {
|
||||
lhs_stride: lhs_stride.to_vec(),
|
||||
rhs_stride: rhs_stride.to_vec(),
|
||||
mnk: (m, n, k),
|
||||
})?
|
||||
};
|
||||
// The setup below was copied from:
|
||||
// https://github.com/lebedov/scikit-cuda/blob/7e7300474286019c917a6c8a4bca59405c64fbce/tests/test_cublas.py#L531
|
||||
let gemm = GemmConfig {
|
||||
alpha,
|
||||
beta,
|
||||
m: n as i32,
|
||||
n: m as i32,
|
||||
k: k as i32,
|
||||
lda,
|
||||
ldb,
|
||||
ldc: n as i32,
|
||||
transa,
|
||||
transb,
|
||||
};
|
||||
Ok(StridedBatchedConfig {
|
||||
batch_size: b as i32,
|
||||
gemm,
|
||||
stride_a: (m * k) as i64,
|
||||
stride_b: (n * k) as i64,
|
||||
stride_c: (m * n) as i64,
|
||||
})
|
||||
}
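// Note on the operand swap (added for clarity): cuBLAS works on column-major
// matrices while these buffers are row-major. A row-major A (m x k) times
// row-major B (k x n) is computed here as the column-major product
// B^T (n x k) * A^T (k x m) = C^T (n x m), which is why `a` is the rhs, `b`
// is the lhs and the config uses m: n, n: m with ldc = n.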
|
||||
|
||||
impl CudaStorage {
|
||||
pub fn try_clone(&self) -> Result<Self> {
|
||||
let slice = match &self.slice {
|
||||
CudaStorageSlice::U32(slice) => CudaStorageSlice::U32(slice.try_clone()?),
|
||||
CudaStorageSlice::BF16(slice) => CudaStorageSlice::BF16(slice.try_clone()?),
|
||||
CudaStorageSlice::F16(slice) => CudaStorageSlice::F16(slice.try_clone()?),
|
||||
CudaStorageSlice::F32(slice) => CudaStorageSlice::F32(slice.try_clone()?),
|
||||
CudaStorageSlice::F64(slice) => CudaStorageSlice::F64(slice.try_clone()?),
|
||||
};
|
||||
let device = self.device.clone();
|
||||
Ok(Self { slice, device })
|
||||
}
|
||||
|
||||
pub fn dtype(&self) -> DType {
|
||||
match self.slice {
|
||||
CudaStorageSlice::U32(_) => DType::U32,
|
||||
CudaStorageSlice::BF16(_) => DType::BF16,
|
||||
CudaStorageSlice::F16(_) => DType::F16,
|
||||
CudaStorageSlice::F32(_) => DType::F32,
|
||||
CudaStorageSlice::F64(_) => DType::F64,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn device(&self) -> &CudaDevice {
|
||||
&self.device
|
||||
}
|
||||
|
||||
pub(crate) fn to_dtype(&self, shape: &Shape, stride: &[usize], dtype: DType) -> Result<Self> {
|
||||
use cudarc::driver::DevicePtr;
|
||||
let dims = shape.dims();
|
||||
let el = shape.elem_count();
|
||||
let cfg = LaunchConfig::for_num_elems(el as u32);
|
||||
let dev = self.device();
|
||||
let ds = dev.htod_copy([dims, stride].concat())?;
|
||||
let inp = match &self.slice {
|
||||
CudaStorageSlice::U32(inp) => inp.device_ptr(),
|
||||
CudaStorageSlice::BF16(inp) => inp.device_ptr(),
|
||||
CudaStorageSlice::F16(inp) => inp.device_ptr(),
|
||||
CudaStorageSlice::F32(inp) => inp.device_ptr(),
|
||||
CudaStorageSlice::F64(inp) => inp.device_ptr(),
|
||||
};
|
||||
let kernel_name = format!("cast_{}_{}", self.dtype().as_str(), dtype.as_str());
|
||||
let func = dev.get_or_load_func(&kernel_name, kernels::CAST)?;
|
||||
let slice = match dtype {
|
||||
DType::U32 => {
|
||||
let out = unsafe { dev.alloc::<u32>(el) }?;
|
||||
let params = (el, dims.len(), &ds, *inp, &out);
|
||||
unsafe { func.launch(cfg, params) }?;
|
||||
CudaStorageSlice::U32(out)
|
||||
}
|
||||
DType::BF16 => {
|
||||
let out = unsafe { dev.alloc::<bf16>(el) }?;
|
||||
let params = (el, dims.len(), &ds, *inp, &out);
|
||||
unsafe { func.launch(cfg, params) }?;
|
||||
CudaStorageSlice::BF16(out)
|
||||
}
|
||||
DType::F16 => {
|
||||
let out = unsafe { dev.alloc::<f16>(el) }?;
|
||||
let params = (el, dims.len(), &ds, *inp, &out);
|
||||
unsafe { func.launch(cfg, params) }?;
|
||||
CudaStorageSlice::F16(out)
|
||||
}
|
||||
DType::F32 => {
|
||||
let out = unsafe { dev.alloc::<f32>(el) }?;
|
||||
let params = (el, dims.len(), &ds, *inp, &out);
|
||||
unsafe { func.launch(cfg, params) }?;
|
||||
CudaStorageSlice::F32(out)
|
||||
}
|
||||
DType::F64 => {
|
||||
let out = unsafe { dev.alloc::<f64>(el) }?;
|
||||
let params = (el, dims.len(), &ds, *inp, &out);
|
||||
unsafe { func.launch(cfg, params) }?;
|
||||
CudaStorageSlice::F64(out)
|
||||
}
|
||||
};
|
||||
Ok(Self {
|
||||
slice,
|
||||
device: dev.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn affine_impl(
|
||||
&self,
|
||||
shape: &Shape,
|
||||
stride: &[usize],
|
||||
mul: f64,
|
||||
add: f64,
|
||||
) -> Result<Self> {
|
||||
let dims = shape.dims();
|
||||
let el_count = shape.elem_count();
|
||||
let cfg = LaunchConfig::for_num_elems(el_count as u32);
|
||||
let dev = self.device();
|
||||
let ds = dev.htod_copy([dims, stride].concat())?;
|
||||
let slice = match &self.slice {
|
||||
CudaStorageSlice::U32(arg) => {
|
||||
let func = dev.get_or_load_func("affine_u32", kernels::AFFINE)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<u32>(el_count) }?;
|
||||
let params = (el_count, dims.len(), &ds, arg, &out, mul as u32, add as u32);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }?;
|
||||
CudaStorageSlice::U32(out)
|
||||
}
|
||||
CudaStorageSlice::BF16(arg) => {
|
||||
let func = dev.get_or_load_func("affine_bf16", kernels::AFFINE)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<bf16>(el_count) }?;
|
||||
let params = (
|
||||
el_count,
|
||||
dims.len(),
|
||||
&ds,
|
||||
arg,
|
||||
&out,
|
||||
bf16::from_f64(mul),
|
||||
bf16::from_f64(add),
|
||||
);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }?;
|
||||
CudaStorageSlice::BF16(out)
|
||||
}
|
||||
CudaStorageSlice::F16(arg) => {
|
||||
let func = dev.get_or_load_func("affine_f16", kernels::AFFINE)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<f16>(el_count) }?;
|
||||
let params = (
|
||||
el_count,
|
||||
dims.len(),
|
||||
&ds,
|
||||
arg,
|
||||
&out,
|
||||
f16::from_f64(mul),
|
||||
f16::from_f64(add),
|
||||
);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }?;
|
||||
CudaStorageSlice::F16(out)
|
||||
}
|
||||
CudaStorageSlice::F32(arg) => {
|
||||
let func = dev.get_or_load_func("affine_f32", kernels::AFFINE)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<f32>(el_count) }?;
|
||||
let params = (el_count, dims.len(), &ds, arg, &out, mul as f32, add as f32);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }?;
|
||||
CudaStorageSlice::F32(out)
|
||||
}
|
||||
CudaStorageSlice::F64(arg) => {
|
||||
let func = dev.get_or_load_func("affine_f64", kernels::AFFINE)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<f64>(el_count) }?;
|
||||
let params = (el_count, dims.len(), &ds, arg, &out, mul, add);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }?;
|
||||
CudaStorageSlice::F64(out)
|
||||
}
|
||||
};
|
||||
let device = dev.clone();
|
||||
Ok(Self { slice, device })
|
||||
}
|
||||
|
||||
pub(crate) fn sum(&self, shape: &Shape, stride: &[usize], sum_dims: &[usize]) -> Result<Self> {
|
||||
let src_dims = shape.dims();
|
||||
let el = shape.elem_count();
|
||||
let mut dst_el = el;
|
||||
for &sum_dim in sum_dims.iter() {
|
||||
dst_el /= src_dims[sum_dim];
|
||||
}
|
||||
let mut sum_dims = sum_dims.to_vec();
|
||||
// Sort the sum_dims as they have to be processed from left to right when converting the
|
||||
// indexes.
|
||||
sum_dims.sort();
|
||||
let sum_dims_l: Vec<usize> = sum_dims.iter().map(|&d| src_dims[d]).collect();
|
||||
let sum_dims_s: Vec<usize> = sum_dims
|
||||
.iter()
|
||||
.map(|&d| src_dims[d + 1..].iter().product::<usize>())
|
||||
.collect();
|
||||
let cfg = LaunchConfig::for_num_elems(el as u32);
|
||||
let dev = self.device();
|
||||
let ds = dev.htod_copy([src_dims, stride, &sum_dims_l, &sum_dims_s].concat())?;
|
||||
let slice = match &self.slice {
|
||||
CudaStorageSlice::U32(arg) => {
|
||||
let func = dev.get_or_load_func("sum_u32", kernels::REDUCE)?;
|
||||
let out = dev.alloc_zeros::<u32>(dst_el)?;
|
||||
let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }?;
|
||||
CudaStorageSlice::U32(out)
|
||||
}
|
||||
CudaStorageSlice::BF16(arg) => {
|
||||
let func = dev.get_or_load_func("sum_bf16", kernels::REDUCE)?;
|
||||
let out = dev.alloc_zeros::<bf16>(dst_el)?;
|
||||
let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }?;
|
||||
CudaStorageSlice::BF16(out)
|
||||
}
|
||||
CudaStorageSlice::F16(arg) => {
|
||||
let func = dev.get_or_load_func("sum_f16", kernels::REDUCE)?;
|
||||
let out = dev.alloc_zeros::<f16>(dst_el)?;
|
||||
let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }?;
|
||||
CudaStorageSlice::F16(out)
|
||||
}
|
||||
CudaStorageSlice::F32(arg) => {
|
||||
let func = dev.get_or_load_func("sum_f32", kernels::REDUCE)?;
|
||||
let out = dev.alloc_zeros::<f32>(dst_el)?;
|
||||
let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }?;
|
||||
CudaStorageSlice::F32(out)
|
||||
}
|
||||
CudaStorageSlice::F64(arg) => {
|
||||
let func = dev.get_or_load_func("sum_f64", kernels::REDUCE)?;
|
||||
let out = dev.alloc_zeros::<f64>(dst_el)?;
|
||||
let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }?;
|
||||
CudaStorageSlice::F64(out)
|
||||
}
|
||||
};
|
||||
let device = dev.clone();
|
||||
Ok(Self { slice, device })
|
||||
}
|
||||
|
||||
pub(crate) fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()> {
|
||||
Err(CudaError::InternalError(
|
||||
"TODO: implement divide_by_sum_over_dim",
|
||||
))
|
||||
}
|
||||
|
||||
pub(crate) fn unary_impl<U: crate::op::UnaryOp>(
|
||||
&self,
|
||||
shape: &Shape,
|
||||
stride: &[usize],
|
||||
) -> Result<Self> {
|
||||
let dims = shape.dims();
|
||||
let el_count = shape.elem_count();
|
||||
let cfg = LaunchConfig::for_num_elems(el_count as u32);
|
||||
let dev = &self.device;
|
||||
let ds = dev.htod_copy([dims, stride].concat())?;
|
||||
let slice = match &self.slice {
|
||||
CudaStorageSlice::U32(_arg) => {
|
||||
todo!("No unary kernels for u32");
|
||||
}
|
||||
CudaStorageSlice::BF16(arg) => {
|
||||
let func = dev.get_or_load_func(U::KERNEL_BF16, kernels::UNARY)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<bf16>(el_count) }?;
|
||||
let params = (el_count, dims.len(), &ds, arg, &out);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }?;
|
||||
CudaStorageSlice::BF16(out)
|
||||
}
|
||||
CudaStorageSlice::F16(arg) => {
|
||||
let func = dev.get_or_load_func(U::KERNEL_F16, kernels::UNARY)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<f16>(el_count) }?;
|
||||
let params = (el_count, dims.len(), &ds, arg, &out);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }?;
|
||||
CudaStorageSlice::F16(out)
|
||||
}
|
||||
CudaStorageSlice::F32(arg) => {
|
||||
let func = dev.get_or_load_func(U::KERNEL_F32, kernels::UNARY)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<f32>(el_count) }?;
|
||||
let params = (el_count, dims.len(), &ds, arg, &out);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }?;
|
||||
CudaStorageSlice::F32(out)
|
||||
}
|
||||
CudaStorageSlice::F64(arg) => {
|
||||
let func = dev.get_or_load_func(U::KERNEL_F64, kernels::UNARY)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<f64>(el_count) }?;
|
||||
let params = (el_count, dims.len(), &ds, arg, &out);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }?;
|
||||
CudaStorageSlice::F64(out)
|
||||
}
|
||||
};
|
||||
let device = dev.clone();
|
||||
Ok(Self { slice, device })
|
||||
}
|
||||
|
||||
pub(crate) fn binary_impl<B: crate::op::BinaryOp>(
|
||||
&self,
|
||||
rhs: &Self,
|
||||
shape: &Shape,
|
||||
lhs_stride: &[usize],
|
||||
rhs_stride: &[usize],
|
||||
) -> Result<Self> {
|
||||
let dims = shape.dims();
|
||||
let elem_count = shape.elem_count();
|
||||
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
|
||||
let dev = self.device();
|
||||
let dims_and_strides = dev.htod_copy([dims, lhs_stride, rhs_stride].concat())?;
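// Note (not in the original source): the dims and both stride vectors are packed
// into a single device buffer; the kernel presumably walks it to compute strided
// source indices so that non-contiguous lhs/rhs layouts work without a prior copy.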
|
||||
let slice = match (&self.slice, &rhs.slice) {
|
||||
(CudaStorageSlice::BF16(lhs), CudaStorageSlice::BF16(rhs)) => {
|
||||
let func = dev.get_or_load_func(B::KERNEL_BF16, kernels::BINARY)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<bf16>(elem_count) }?;
|
||||
let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out);
|
||||
// SAFETY: ffi
|
||||
unsafe { func.launch(cfg, params) }?;
|
||||
CudaStorageSlice::BF16(out)
|
||||
}
|
||||
(CudaStorageSlice::F16(lhs), CudaStorageSlice::F16(rhs)) => {
|
||||
let func = dev.get_or_load_func(B::KERNEL_F16, kernels::BINARY)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<f16>(elem_count) }?;
|
||||
let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out);
|
||||
// SAFETY: ffi
|
||||
unsafe { func.launch(cfg, params) }?;
|
||||
CudaStorageSlice::F16(out)
|
||||
}
|
||||
(CudaStorageSlice::F32(lhs), CudaStorageSlice::F32(rhs)) => {
|
||||
let func = dev.get_or_load_func(B::KERNEL_F32, kernels::BINARY)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<f32>(elem_count) }?;
|
||||
let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out);
|
||||
// SAFETY: ffi
|
||||
unsafe { func.launch(cfg, params) }?;
|
||||
CudaStorageSlice::F32(out)
|
||||
}
|
||||
(CudaStorageSlice::F64(lhs), CudaStorageSlice::F64(rhs)) => {
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let func = dev.get_or_load_func(B::KERNEL_F64, kernels::BINARY)?;
|
||||
let out = unsafe { dev.alloc::<f64>(elem_count) }?;
|
||||
let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out);
|
||||
// SAFETY: ffi
|
||||
unsafe { func.launch(cfg, params) }?;
|
||||
CudaStorageSlice::F64(out)
|
||||
}
|
||||
(CudaStorageSlice::U32(lhs), CudaStorageSlice::U32(rhs)) => {
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let func = dev.get_or_load_func(B::KERNEL_U32, kernels::BINARY)?;
|
||||
let out = unsafe { dev.alloc::<u32>(elem_count) }?;
|
||||
let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out);
|
||||
// SAFETY: ffi
|
||||
unsafe { func.launch(cfg, params) }?;
|
||||
CudaStorageSlice::U32(out)
|
||||
}
|
||||
// The dtypes should have been checked at this point so this is an internal error.
|
||||
_ => return Err(CudaError::InternalError("dtype mismatch in binary op")),
|
||||
};
|
||||
let device = dev.clone();
|
||||
Ok(Self { slice, device })
|
||||
}
|
||||
|
||||
pub(crate) fn to_cpu_storage(&self) -> Result<CpuStorage> {
|
||||
match &self.slice {
|
||||
CudaStorageSlice::U32(slice) => {
|
||||
let dev = slice.device();
|
||||
let cpu_storage = dev.dtoh_sync_copy(slice)?;
|
||||
Ok(CpuStorage::U32(cpu_storage))
|
||||
}
|
||||
CudaStorageSlice::BF16(slice) => {
|
||||
let dev = slice.device();
|
||||
let cpu_storage = dev.dtoh_sync_copy(slice)?;
|
||||
Ok(CpuStorage::BF16(cpu_storage))
|
||||
}
|
||||
CudaStorageSlice::F16(slice) => {
|
||||
let dev = slice.device();
|
||||
let cpu_storage = dev.dtoh_sync_copy(slice)?;
|
||||
Ok(CpuStorage::F16(cpu_storage))
|
||||
}
|
||||
CudaStorageSlice::F32(slice) => {
|
||||
let dev = slice.device();
|
||||
let cpu_storage = dev.dtoh_sync_copy(slice)?;
|
||||
Ok(CpuStorage::F32(cpu_storage))
|
||||
}
|
||||
CudaStorageSlice::F64(slice) => {
|
||||
let dev = slice.device();
|
||||
let cpu_storage = dev.dtoh_sync_copy(slice)?;
|
||||
Ok(CpuStorage::F64(cpu_storage))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn where_cond(
|
||||
&self,
|
||||
shape: &Shape,
|
||||
stride: &[usize],
|
||||
t: &Self,
|
||||
stride_t: &[usize],
|
||||
f: &Self,
|
||||
stride_f: &[usize],
|
||||
) -> Result<Self> {
|
||||
let ids = match &self.slice {
|
||||
CudaStorageSlice::U32(slice) => slice,
|
||||
_ => Err(CudaError::UnexpectedDType {
|
||||
msg: "where conditions should be u32",
|
||||
expected: DType::U32,
|
||||
got: self.dtype(),
|
||||
})?,
|
||||
};
|
||||
let dims = shape.dims();
|
||||
let el = shape.elem_count();
|
||||
let cfg = LaunchConfig::for_num_elems(el as u32);
|
||||
let dev = self.device();
|
||||
let ds = dev.htod_copy([dims, stride, stride_t, stride_f].concat())?;
|
||||
let slice = match (&t.slice, &f.slice) {
|
||||
(CudaStorageSlice::BF16(t), CudaStorageSlice::BF16(f)) => {
|
||||
let func = dev.get_or_load_func("where_bf16", kernels::TERNARY)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<bf16>(el) }?;
|
||||
let params = (el, dims.len(), &ds, ids, t, f, &out);
|
||||
// SAFETY: ffi
|
||||
unsafe { func.launch(cfg, params) }?;
|
||||
CudaStorageSlice::BF16(out)
|
||||
}
|
||||
(CudaStorageSlice::F16(t), CudaStorageSlice::F16(f)) => {
|
||||
let func = dev.get_or_load_func("where_f16", kernels::TERNARY)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<f16>(el) }?;
|
||||
let params = (el, dims.len(), &ds, ids, t, f, &out);
|
||||
// SAFETY: ffi
|
||||
unsafe { func.launch(cfg, params) }?;
|
||||
CudaStorageSlice::F16(out)
|
||||
}
|
||||
(CudaStorageSlice::F32(t), CudaStorageSlice::F32(f)) => {
|
||||
let func = dev.get_or_load_func("where_f32", kernels::TERNARY)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<f32>(el) }?;
|
||||
let params = (el, dims.len(), &ds, ids, t, f, &out);
|
||||
// SAFETY: ffi
|
||||
unsafe { func.launch(cfg, params) }?;
|
||||
CudaStorageSlice::F32(out)
|
||||
}
|
||||
(CudaStorageSlice::F64(t), CudaStorageSlice::F64(f)) => {
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let func = dev.get_or_load_func("where_f64", kernels::TERNARY)?;
|
||||
let out = unsafe { dev.alloc::<f64>(el) }?;
|
||||
let params = (el, dims.len(), &ds, ids, t, f, &out);
|
||||
// SAFETY: ffi
|
||||
unsafe { func.launch(cfg, params) }?;
|
||||
CudaStorageSlice::F64(out)
|
||||
}
|
||||
(CudaStorageSlice::U32(t), CudaStorageSlice::U32(f)) => {
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let func = dev.get_or_load_func("where_u32", kernels::TERNARY)?;
|
||||
let out = unsafe { dev.alloc::<u32>(el) }?;
|
||||
let params = (el, dims.len(), &ds, ids, t, f, &out);
|
||||
// SAFETY: ffi
|
||||
unsafe { func.launch(cfg, params) }?;
|
||||
CudaStorageSlice::U32(out)
|
||||
}
|
||||
// The dtypes should have been checked at this point so this is an internal error.
|
||||
_ => return Err(CudaError::InternalError("dtype mismatch in binary op")),
|
||||
};
|
||||
let device = dev.clone();
|
||||
Ok(Self { slice, device })
|
||||
}
|
||||
|
||||
pub(crate) fn embedding_impl(
|
||||
&self,
|
||||
shape: &Shape,
|
||||
stride: &[usize],
|
||||
rhs: &Self,
|
||||
h_size: usize, // hidden size
|
||||
v_size: usize, // vocab size
|
||||
) -> Result<Self> {
|
||||
let ids = match &self.slice {
|
||||
CudaStorageSlice::U32(slice) => slice,
|
||||
_ => Err(CudaError::UnexpectedDType {
|
||||
msg: "embedding ids should be u32",
|
||||
expected: DType::U32,
|
||||
got: self.dtype(),
|
||||
})?,
|
||||
};
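// Note (not in the original source): `self` holds the u32 token ids laid out
// according to `shape`, `rhs` is presumably the (v_size, h_size) embedding table,
// and each id selects one h_size-long row, hence the `el * h_size` output
// allocations below.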
|
||||
let dims = shape.dims();
|
||||
let el = shape.elem_count();
|
||||
let cfg = LaunchConfig::for_num_elems(el as u32);
|
||||
let dev = self.device();
|
||||
let ds = dev.htod_copy([dims, stride].concat())?;
|
||||
let slice = match &rhs.slice {
|
||||
// The kernels below assume that rhs is contiguous.
|
||||
CudaStorageSlice::U32(arg) => {
|
||||
let func = dev.get_or_load_func("emb_u32", kernels::EMBEDDINGS)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<u32>(el * h_size) }?;
|
||||
let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }?;
|
||||
CudaStorageSlice::U32(out)
|
||||
}
|
||||
CudaStorageSlice::BF16(arg) => {
|
||||
let func = dev.get_or_load_func("emb_bf16", kernels::EMBEDDINGS)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<bf16>(el * h_size) }?;
|
||||
let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }?;
|
||||
CudaStorageSlice::BF16(out)
|
||||
}
|
||||
CudaStorageSlice::F16(arg) => {
|
||||
let func = dev.get_or_load_func("emb_f16", kernels::EMBEDDINGS)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<f16>(el * h_size) }?;
|
||||
let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }?;
|
||||
CudaStorageSlice::F16(out)
|
||||
}
|
||||
CudaStorageSlice::F32(arg) => {
|
||||
let func = dev.get_or_load_func("emb_f32", kernels::EMBEDDINGS)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<f32>(el * h_size) }?;
|
||||
let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }?;
|
||||
CudaStorageSlice::F32(out)
|
||||
}
|
||||
CudaStorageSlice::F64(arg) => {
|
||||
let func = dev.get_or_load_func("emb_f64", kernels::EMBEDDINGS)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<f64>(el * h_size) }?;
|
||||
let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }?;
|
||||
CudaStorageSlice::F64(out)
|
||||
}
|
||||
};
|
||||
let device = dev.clone();
|
||||
Ok(Self { slice, device })
|
||||
}
|
||||
|
||||
pub(crate) fn matmul_impl(
|
||||
&self,
|
||||
rhs: &Self,
|
||||
(b, m, n, k): (usize, usize, usize, usize),
|
||||
lhs_stride: &[usize],
|
||||
rhs_stride: &[usize],
|
||||
) -> Result<Self> {
|
||||
let elem_count = b * m * n;
|
||||
let dev = &self.device;
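// Note (not in the original source): the gemm calls below pass `rhs` before
// `lhs`. This appears to be the usual row-major/column-major trick: cuBLAS works
// on column-major matrices, and computing B^T * A^T in column-major has the same
// memory layout as the row-major A * B that is wanted here.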
|
||||
let slice = match (&self.slice, &rhs.slice) {
|
||||
(CudaStorageSlice::BF16(_lhs), CudaStorageSlice::BF16(_rhs)) => {
|
||||
todo!("bf16")
|
||||
}
|
||||
(CudaStorageSlice::F16(lhs), CudaStorageSlice::F16(rhs)) => {
|
||||
let cfg = gemm_config(f16::ONE, f16::ZERO, (b, m, n, k), lhs_stride, rhs_stride)?;
|
||||
let mut out = unsafe { dev.alloc::<f16>(elem_count) }?;
|
||||
unsafe {
|
||||
self.device
|
||||
.blas
|
||||
.gemm_strided_batched(cfg, rhs, lhs, &mut out)
|
||||
}?;
|
||||
CudaStorageSlice::F16(out)
|
||||
}
|
||||
(CudaStorageSlice::F32(lhs), CudaStorageSlice::F32(rhs)) => {
|
||||
let cfg = gemm_config(1., 0., (b, m, n, k), lhs_stride, rhs_stride)?;
|
||||
let mut out = unsafe { dev.alloc::<f32>(elem_count) }?;
|
||||
unsafe {
|
||||
self.device
|
||||
.blas
|
||||
.gemm_strided_batched(cfg, rhs, lhs, &mut out)
|
||||
}?;
|
||||
CudaStorageSlice::F32(out)
|
||||
}
|
||||
(CudaStorageSlice::F64(lhs), CudaStorageSlice::F64(rhs)) => {
|
||||
let cfg = gemm_config(1., 0., (b, m, n, k), lhs_stride, rhs_stride)?;
|
||||
let mut out = unsafe { dev.alloc::<f64>(elem_count) }?;
|
||||
unsafe {
|
||||
self.device
|
||||
.blas
|
||||
.gemm_strided_batched(cfg, rhs, lhs, &mut out)
|
||||
}?;
|
||||
CudaStorageSlice::F64(out)
|
||||
}
|
||||
_ => return Err(CudaError::InternalError("dtype mismatch in matmul op")),
|
||||
};
|
||||
let device = dev.clone();
|
||||
Ok(Self { slice, device })
|
||||
}
|
||||
|
||||
pub(crate) fn copy_strided_src(
|
||||
&self,
|
||||
dst: &mut Self,
|
||||
dst_offset: usize,
|
||||
src_shape: &Shape,
|
||||
src_stride: &[usize],
|
||||
src_offset: usize,
|
||||
) -> Result<()> {
|
||||
if src_shape.rank() != src_stride.len() {
|
||||
panic!("incoherent shape and strides {src_shape:?} {src_stride:?}")
|
||||
}
|
||||
let dims = src_shape.dims();
|
||||
let el_count = src_shape.elem_count();
|
||||
let cfg = LaunchConfig::for_num_elems(el_count as u32);
|
||||
let dev = &self.device;
|
||||
let ds = dev.htod_copy([dims, src_stride].concat())?;
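// Note (not in the original source): each arm below takes a fast path with a
// plain device-to-device copy when the source layout is contiguous, and otherwise
// falls back to the strided "ucopy_*" kernel driven by the dims/stride buffer in `ds`.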
|
||||
match (&self.slice, &mut dst.slice) {
|
||||
(CudaStorageSlice::BF16(src), CudaStorageSlice::BF16(dst)) => {
|
||||
let src = src.slice(src_offset..);
|
||||
let mut dst = dst.slice_mut(dst_offset..);
|
||||
if src_shape.is_contiguous(src_stride) {
|
||||
dev.dtod_copy(&src, &mut dst)?
|
||||
} else {
|
||||
let func = dev.get_or_load_func("ucopy_bf16", kernels::UNARY)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let params = (el_count, dims.len(), &ds, &src, &mut dst);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }?
|
||||
}
|
||||
}
|
||||
(CudaStorageSlice::F16(src), CudaStorageSlice::F16(dst)) => {
|
||||
let src = src.slice(src_offset..);
|
||||
let mut dst = dst.slice_mut(dst_offset..);
|
||||
if src_shape.is_contiguous(src_stride) {
|
||||
dev.dtod_copy(&src, &mut dst)?
|
||||
} else {
|
||||
let func = dev.get_or_load_func("ucopy_f16", kernels::UNARY)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let params = (el_count, dims.len(), &ds, &src, &mut dst);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }?
|
||||
}
|
||||
}
|
||||
(CudaStorageSlice::F32(src), CudaStorageSlice::F32(dst)) => {
|
||||
let src = src.slice(src_offset..);
|
||||
let mut dst = dst.slice_mut(dst_offset..);
|
||||
if src_shape.is_contiguous(src_stride) {
|
||||
dev.dtod_copy(&src, &mut dst)?
|
||||
} else {
|
||||
let func = dev.get_or_load_func("ucopy_f32", kernels::UNARY)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let params = (el_count, dims.len(), &ds, &src, &mut dst);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }?
|
||||
}
|
||||
}
|
||||
(CudaStorageSlice::U32(src), CudaStorageSlice::U32(dst)) => {
|
||||
let src = src.slice(src_offset..);
|
||||
let mut dst = dst.slice_mut(dst_offset..);
|
||||
if src_shape.is_contiguous(src_stride) {
|
||||
dev.dtod_copy(&src, &mut dst)?
|
||||
} else {
|
||||
let func = dev.get_or_load_func("ucopy_u32", kernels::UNARY)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let params = (el_count, dims.len(), &ds, &src, &mut dst);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }?
|
||||
}
|
||||
}
|
||||
(CudaStorageSlice::F64(src), CudaStorageSlice::F64(dst)) => {
|
||||
let src = src.slice(src_offset..);
|
||||
let mut dst = dst.slice_mut(dst_offset..);
|
||||
if src_shape.is_contiguous(src_stride) {
|
||||
dev.dtod_copy(&src, &mut dst)?
|
||||
} else {
|
||||
let func = dev.get_or_load_func("ucopy_64", kernels::UNARY)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let params = (el_count, dims.len(), &ds, &src, &mut dst);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }?;
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
return Err(CudaError::InternalError(
|
||||
"dtype mismatch in copy_strided op",
|
||||
))
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
159
candle-core/src/device.rs
Normal file
159
candle-core/src/device.rs
Normal file
@ -0,0 +1,159 @@
|
||||
use crate::{CpuStorage, DType, Result, Shape, Storage, WithDType};
|
||||
|
||||
/// A `DeviceLocation` represents a physical device, whereas multiple `Device`
/// instances can live on the same location (typically for cuda devices).
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
|
||||
pub enum DeviceLocation {
|
||||
Cpu,
|
||||
Cuda { gpu_id: usize },
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum Device {
|
||||
Cpu,
|
||||
Cuda(crate::CudaDevice),
|
||||
}
|
||||
|
||||
// TODO: Should we back the cpu implementation using the NdArray crate or similar?
|
||||
pub trait NdArray {
|
||||
fn shape(&self) -> Result<Shape>;
|
||||
|
||||
fn to_cpu_storage(&self) -> CpuStorage;
|
||||
}
|
||||
|
||||
impl<S: WithDType> NdArray for S {
|
||||
fn shape(&self) -> Result<Shape> {
|
||||
Ok(Shape::from(()))
|
||||
}
|
||||
|
||||
fn to_cpu_storage(&self) -> CpuStorage {
|
||||
S::to_cpu_storage(&[*self])
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: WithDType, const N: usize> NdArray for &[S; N] {
|
||||
fn shape(&self) -> Result<Shape> {
|
||||
Ok(Shape::from(self.len()))
|
||||
}
|
||||
|
||||
fn to_cpu_storage(&self) -> CpuStorage {
|
||||
S::to_cpu_storage(self.as_slice())
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: WithDType> NdArray for &[S] {
|
||||
fn shape(&self) -> Result<Shape> {
|
||||
Ok(Shape::from(self.len()))
|
||||
}
|
||||
|
||||
fn to_cpu_storage(&self) -> CpuStorage {
|
||||
S::to_cpu_storage(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: WithDType, const N: usize, const M: usize> NdArray for &[[S; N]; M] {
|
||||
fn shape(&self) -> Result<Shape> {
|
||||
Ok(Shape::from((M, N)))
|
||||
}
|
||||
|
||||
fn to_cpu_storage(&self) -> CpuStorage {
|
||||
S::to_cpu_storage_owned(self.concat())
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize> NdArray
|
||||
for &[[[S; N3]; N2]; N1]
|
||||
{
|
||||
fn shape(&self) -> Result<Shape> {
|
||||
Ok(Shape::from((N1, N2, N3)))
|
||||
}
|
||||
|
||||
fn to_cpu_storage(&self) -> CpuStorage {
|
||||
let mut vec = Vec::with_capacity(N1 * N2 * N3);
|
||||
for i1 in 0..N1 {
|
||||
for i2 in 0..N2 {
|
||||
vec.extend(self[i1][i2])
|
||||
}
|
||||
}
|
||||
S::to_cpu_storage_owned(vec)
|
||||
}
|
||||
}
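// Minimal usage sketch (not in the original source): these NdArray impls turn
// scalars and nested fixed-size arrays into a CpuStorage, which `Device::storage`
// then moves to the target device, e.g.:
//
//     let dev = Device::Cpu;
//     let storage = dev.storage(&[[1f32, 2.], [3., 4.]])?; // 2x2 f32 values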
|
||||
|
||||
impl Device {
|
||||
pub fn new_cuda(ordinal: usize) -> Result<Self> {
|
||||
Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
|
||||
}
|
||||
|
||||
pub fn same_id(&self, rhs: &Self) -> bool {
|
||||
match (self, rhs) {
|
||||
(Self::Cpu, Self::Cpu) => true,
|
||||
(Self::Cuda(lhs), Self::Cuda(rhs)) => lhs.same_id(rhs),
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn location(&self) -> DeviceLocation {
|
||||
match self {
|
||||
Self::Cpu => DeviceLocation::Cpu,
|
||||
Self::Cuda(device) => DeviceLocation::Cuda {
|
||||
gpu_id: device.ordinal(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_cuda(&self) -> bool {
|
||||
match self {
|
||||
Self::Cpu => false,
|
||||
Self::Cuda(_) => true,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn ones(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
|
||||
match self {
|
||||
Device::Cpu => {
|
||||
let storage = CpuStorage::ones_impl(shape, dtype);
|
||||
Ok(Storage::Cpu(storage))
|
||||
}
|
||||
Device::Cuda(device) => {
|
||||
let storage = device.ones_impl(shape, dtype)?;
|
||||
Ok(Storage::Cuda(storage))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn zeros(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
|
||||
match self {
|
||||
Device::Cpu => {
|
||||
let storage = CpuStorage::zeros_impl(shape, dtype);
|
||||
Ok(Storage::Cpu(storage))
|
||||
}
|
||||
Device::Cuda(device) => {
|
||||
let storage = device.zeros_impl(shape, dtype)?;
|
||||
Ok(Storage::Cuda(storage))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn storage<A: NdArray>(&self, array: A) -> Result<Storage> {
|
||||
match self {
|
||||
Device::Cpu => Ok(Storage::Cpu(array.to_cpu_storage())),
|
||||
Device::Cuda(device) => {
|
||||
let storage = array.to_cpu_storage();
|
||||
let storage = device.cuda_from_cpu_storage(&storage)?;
|
||||
Ok(Storage::Cuda(storage))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn storage_owned<S: WithDType>(&self, data: Vec<S>) -> Result<Storage> {
|
||||
match self {
|
||||
Device::Cpu => Ok(Storage::Cpu(S::to_cpu_storage_owned(data))),
|
||||
Device::Cuda(device) => {
|
||||
let storage = S::to_cpu_storage_owned(data);
|
||||
let storage = device.cuda_from_cpu_storage(&storage)?;
|
||||
Ok(Storage::Cuda(storage))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
96
candle-core/src/dtype.rs
Normal file
96
candle-core/src/dtype.rs
Normal file
@ -0,0 +1,96 @@
|
||||
use crate::{CpuStorage, Error, Result};
|
||||
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
|
||||
pub enum DType {
|
||||
U32,
|
||||
BF16,
|
||||
F16,
|
||||
F32,
|
||||
F64,
|
||||
}
|
||||
|
||||
impl DType {
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
Self::U32 => "u32",
|
||||
Self::BF16 => "bf16",
|
||||
Self::F16 => "f16",
|
||||
Self::F32 => "f32",
|
||||
Self::F64 => "f64",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn size_in_bytes(&self) -> usize {
|
||||
match self {
|
||||
Self::U32 => 4,
|
||||
Self::BF16 => 2,
|
||||
Self::F16 => 2,
|
||||
Self::F32 => 4,
|
||||
Self::F64 => 8,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait WithDType: Sized + Copy {
|
||||
const DTYPE: DType;
|
||||
|
||||
fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage;
|
||||
|
||||
fn to_cpu_storage(data: &[Self]) -> CpuStorage {
|
||||
Self::to_cpu_storage_owned(data.to_vec())
|
||||
}
|
||||
|
||||
fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]>;
|
||||
fn cpu_storage_as_mut_slice(s: &mut CpuStorage) -> Result<&mut [Self]>;
|
||||
fn cpu_storage_data(s: CpuStorage) -> Result<Vec<Self>>;
|
||||
}
|
||||
|
||||
macro_rules! with_dtype {
|
||||
($ty:ty, $dtype:ident) => {
|
||||
impl WithDType for $ty {
|
||||
const DTYPE: DType = DType::$dtype;
|
||||
|
||||
fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage {
|
||||
CpuStorage::$dtype(data)
|
||||
}
|
||||
|
||||
fn cpu_storage_data(s: CpuStorage) -> Result<Vec<Self>> {
|
||||
match s {
|
||||
CpuStorage::$dtype(data) => Ok(data),
|
||||
_ => Err(Error::UnexpectedDType {
|
||||
expected: DType::$dtype,
|
||||
got: s.dtype(),
|
||||
msg: "unexpected dtype",
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]> {
|
||||
match s {
|
||||
CpuStorage::$dtype(data) => Ok(data),
|
||||
_ => Err(Error::UnexpectedDType {
|
||||
expected: DType::$dtype,
|
||||
got: s.dtype(),
|
||||
msg: "unexpected dtype",
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
fn cpu_storage_as_mut_slice(s: &mut CpuStorage) -> Result<&mut [Self]> {
|
||||
match s {
|
||||
CpuStorage::$dtype(data) => Ok(data),
|
||||
_ => Err(Error::UnexpectedDType {
|
||||
expected: DType::$dtype,
|
||||
got: s.dtype(),
|
||||
msg: "unexpected dtype",
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
with_dtype!(u32, U32);
|
||||
with_dtype!(half::f16, F16);
|
||||
with_dtype!(half::bf16, BF16);
|
||||
with_dtype!(f32, F32);
|
||||
with_dtype!(f64, F64);
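// Illustrative sketch (not in the original source): after the macro invocations
// above, each primitive type can report its dtype and round-trip through
// CpuStorage, e.g.:
//
//     assert_eq!(f32::DTYPE, DType::F32);
//     assert_eq!(DType::F32.size_in_bytes(), 4);
//     let s = f32::to_cpu_storage(&[1.0, 2.0, 3.0]);
//     assert_eq!(f32::cpu_storage_as_slice(&s).unwrap().len(), 3);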
|
136
candle-core/src/dummy_cuda_backend.rs
Normal file
136
candle-core/src/dummy_cuda_backend.rs
Normal file
@ -0,0 +1,136 @@
|
||||
#![allow(dead_code)]
|
||||
use crate::{CpuStorage, DType, Error, Result, Shape};
|
||||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum DummyError {}
|
||||
pub type CudaError = DummyError;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CudaDevice;
|
||||
|
||||
macro_rules! fail {
|
||||
() => {
|
||||
unimplemented!("cuda support has not been enabled")
|
||||
};
|
||||
}
|
||||
|
||||
impl CudaDevice {
|
||||
pub(crate) fn new(_: usize) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
pub(crate) fn same_id(&self, _: &Self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
pub(crate) fn ordinal(&self) -> usize {
|
||||
fail!()
|
||||
}
|
||||
|
||||
pub(crate) fn zeros_impl(&self, _shape: &Shape, _dtype: DType) -> Result<CudaStorage> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
pub(crate) fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<CudaStorage> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
pub(crate) fn cuda_from_cpu_storage(&self, _: &CpuStorage) -> Result<CudaStorage> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct CudaStorage;
|
||||
|
||||
impl CudaStorage {
|
||||
pub fn try_clone(&self) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
pub fn dtype(&self) -> DType {
|
||||
fail!()
|
||||
}
|
||||
|
||||
pub fn device(&self) -> &CudaDevice {
|
||||
fail!()
|
||||
}
|
||||
|
||||
pub(crate) fn to_cpu_storage(&self) -> Result<CpuStorage> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
pub(crate) fn affine_impl(&self, _: &Shape, _: &[usize], _: f64, _: f64) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
pub(crate) fn sum(&self, _: &Shape, _: &[usize], _: &[usize]) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
pub(crate) fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
pub(crate) fn to_dtype(&self, _: &Shape, _: &[usize], _: DType) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
pub(crate) fn unary_impl<B: crate::op::UnaryOp>(&self, _: &Shape, _: &[usize]) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
pub(crate) fn binary_impl<B: crate::op::BinaryOp>(
|
||||
&self,
|
||||
_: &Self,
|
||||
_: &Shape,
|
||||
_: &[usize],
|
||||
_: &[usize],
|
||||
) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
pub(crate) fn where_cond(
|
||||
&self,
|
||||
_: &Shape,
|
||||
_: &[usize],
|
||||
_: &Self,
|
||||
_: &[usize],
|
||||
_: &Self,
|
||||
_: &[usize],
|
||||
) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
pub(crate) fn embedding_impl(
|
||||
&self,
|
||||
_: &Shape,
|
||||
_: &[usize],
|
||||
_: &Self,
|
||||
_: usize,
|
||||
_: usize,
|
||||
) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
pub(crate) fn matmul_impl(
|
||||
&self,
|
||||
_: &Self,
|
||||
_: (usize, usize, usize, usize),
|
||||
_: &[usize],
|
||||
_: &[usize],
|
||||
) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
pub(crate) fn copy_strided_src(
|
||||
&self,
|
||||
_: &mut Self,
|
||||
_: usize,
|
||||
_: &Shape,
|
||||
_: &[usize],
|
||||
_: usize,
|
||||
) -> Result<()> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
}
|
102
candle-core/src/error.rs
Normal file
102
candle-core/src/error.rs
Normal file
@ -0,0 +1,102 @@
|
||||
use crate::{DType, DeviceLocation, Shape};
|
||||
|
||||
/// Main library error type.
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum Error {
|
||||
#[error("{msg}, expected: {expected:?}, got: {got:?}")]
|
||||
UnexpectedDType {
|
||||
msg: &'static str,
|
||||
expected: DType,
|
||||
got: DType,
|
||||
},
|
||||
|
||||
#[error("{op} only supports contiguous tensors")]
|
||||
RequiresContiguous { op: &'static str },
|
||||
|
||||
#[error("{op} expects at least one tensor")]
|
||||
OpRequiresAtLeastOneTensor { op: &'static str },
|
||||
|
||||
#[error("backward is not supported for {op}")]
|
||||
BackwardNotSupported { op: &'static str },
|
||||
|
||||
#[error("{op} invalid index {index} with vocab {vocab_size}")]
|
||||
InvalidIndex {
|
||||
op: &'static str,
|
||||
index: usize,
|
||||
vocab_size: usize,
|
||||
},
|
||||
|
||||
#[error("the candle crate has not been built with cuda support")]
|
||||
NotCompiledWithCudaSupport,
|
||||
|
||||
#[error(
    "shape mismatch, got a buffer of size {buffer_size} which is not compatible with shape {shape:?}"
)]
|
||||
ShapeMismatch { buffer_size: usize, shape: Shape },
|
||||
|
||||
#[error("shape mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")]
|
||||
ShapeMismatchBinaryOp {
|
||||
lhs: Shape,
|
||||
rhs: Shape,
|
||||
op: &'static str,
|
||||
},
|
||||
|
||||
#[error("shape mismatch in cat for dim {dim}, shape for arg 1: {first_shape:?} shape for arg {n}: {nth_shape:?}")]
|
||||
ShapeMismatchCat {
|
||||
dim: usize,
|
||||
first_shape: Shape,
|
||||
n: usize,
|
||||
nth_shape: Shape,
|
||||
},
|
||||
|
||||
#[error("device mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")]
|
||||
DeviceMismatchBinaryOp {
|
||||
lhs: DeviceLocation,
|
||||
rhs: DeviceLocation,
|
||||
op: &'static str,
|
||||
},
|
||||
|
||||
#[error("dtype mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")]
|
||||
DTypeMismatchBinaryOp {
|
||||
lhs: DType,
|
||||
rhs: DType,
|
||||
op: &'static str,
|
||||
},
|
||||
|
||||
#[error("unexpected rank, expected: {expected}, got: {got} ({shape:?})")]
|
||||
UnexpectedNumberOfDims {
|
||||
expected: usize,
|
||||
got: usize,
|
||||
shape: Shape,
|
||||
},
|
||||
|
||||
// TODO: This is temporary and should be removed once matmul supports arbitrary striding.
|
||||
#[error("temporary error where matmul doesn't support arbitrary striding")]
|
||||
UnexpectedStriding,
|
||||
|
||||
#[error(transparent)]
|
||||
Cuda(#[from] crate::CudaError),
|
||||
|
||||
#[error(transparent)]
|
||||
TryFromIntError(#[from] core::num::TryFromIntError),
|
||||
|
||||
#[error("npy/npz error {0}")]
|
||||
Npy(String),
|
||||
|
||||
/// Zip file format error.
|
||||
#[error(transparent)]
|
||||
Zip(#[from] zip::result::ZipError),
|
||||
|
||||
/// Integer parse error.
|
||||
#[error(transparent)]
|
||||
ParseInt(#[from] std::num::ParseIntError),
|
||||
|
||||
/// I/O error.
|
||||
#[error(transparent)]
|
||||
Io(#[from] std::io::Error),
|
||||
|
||||
#[error("cannot broadcast {src_shape:?} to {dst_shape:?}")]
|
||||
BroadcastIncompatibleShapes { src_shape: Shape, dst_shape: Shape },
|
||||
}
|
||||
|
||||
pub type Result<T> = std::result::Result<T, Error>;
|
29
candle-core/src/lib.rs
Normal file
29
candle-core/src/lib.rs
Normal file
@ -0,0 +1,29 @@
mod backprop;
mod cpu_backend;
#[cfg(feature = "cuda")]
mod cuda_backend;
mod device;
mod dtype;
mod dummy_cuda_backend;
mod error;
mod npy;
mod op;
mod shape;
mod storage;
mod strided_index;
mod tensor;

pub use cpu_backend::CpuStorage;
pub use device::{Device, DeviceLocation};
pub use dtype::{DType, WithDType};
pub use error::{Error, Result};
pub use shape::Shape;
pub use storage::Storage;
use strided_index::StridedIndex;
pub use tensor::{Tensor, TensorId};

#[cfg(feature = "cuda")]
pub use cuda_backend::{CudaDevice, CudaError, CudaStorage};

#[cfg(not(feature = "cuda"))]
pub use dummy_cuda_backend::{CudaDevice, CudaError, CudaStorage};
|
401
candle-core/src/npy.rs
Normal file
401
candle-core/src/npy.rs
Normal file
@ -0,0 +1,401 @@
|
||||
//! Numpy support for tensors.
//!
//! The spec for the npy format can be found in
//! [npy-format](https://docs.scipy.org/doc/numpy-1.14.2/neps/npy-format.html).
//! The functions from this module can be used to read tensors from npy/npz files
//! or to write tensors to these files. A npy file contains a single (unnamed) tensor
//! whereas a npz file can contain multiple named tensors. npz files are also compressed.
//!
//! These two formats are easy to use in Python using the numpy library.
//!
//! ```python
//! import numpy as np
//! x = np.arange(10)
//!
//! # Write a npy file.
//! np.save("test.npy", x)
//!
//! # Read a value from the npy file.
//! x = np.load("test.npy")
//!
//! # Write multiple values to a npz file.
//! values = { "x": x, "x_plus_one": x + 1 }
//! np.savez("test.npz", **values)
//!
//! # Load multiple values from a npz file.
//! values = np.load("test.npz")
//! ```
|
||||
use crate::{DType, Device, Error, Result, Shape, Tensor};
|
||||
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
|
||||
use half::{bf16, f16, slice::HalfFloatSliceExt};
|
||||
use std::collections::HashMap;
|
||||
use std::fs::File;
|
||||
use std::io::{BufReader, Read, Write};
|
||||
use std::path::Path;
|
||||
|
||||
const NPY_MAGIC_STRING: &[u8] = b"\x93NUMPY";
|
||||
const NPY_SUFFIX: &str = ".npy";
|
||||
|
||||
fn read_header<R: Read>(reader: &mut R) -> Result<String> {
|
||||
let mut magic_string = vec![0u8; NPY_MAGIC_STRING.len()];
|
||||
reader.read_exact(&mut magic_string)?;
|
||||
if magic_string != NPY_MAGIC_STRING {
|
||||
return Err(Error::Npy("magic string mismatch".to_string()));
|
||||
}
|
||||
let mut version = [0u8; 2];
|
||||
reader.read_exact(&mut version)?;
|
||||
let header_len_len = match version[0] {
|
||||
1 => 2,
|
||||
2 => 4,
|
||||
otherwise => return Err(Error::Npy(format!("unsupported version {otherwise}"))),
|
||||
};
|
||||
let mut header_len = vec![0u8; header_len_len];
|
||||
reader.read_exact(&mut header_len)?;
|
||||
let header_len = header_len
|
||||
.iter()
|
||||
.rev()
|
||||
.fold(0_usize, |acc, &v| 256 * acc + v as usize);
|
||||
let mut header = vec![0u8; header_len];
|
||||
reader.read_exact(&mut header)?;
|
||||
Ok(String::from_utf8_lossy(&header).to_string())
|
||||
}
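// Worked example (not in the original source): the header-length bytes are
// little-endian, so for a version-1 file the fold above turns [0x76, 0x00] into
// 0x00 * 256 + 0x76 = 118.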
|
||||
|
||||
#[derive(Debug, PartialEq)]
|
||||
struct Header {
|
||||
descr: DType,
|
||||
fortran_order: bool,
|
||||
shape: Vec<usize>,
|
||||
}
|
||||
|
||||
impl Header {
|
||||
fn shape(&self) -> Shape {
|
||||
Shape::from(self.shape.as_slice())
|
||||
}
|
||||
|
||||
fn to_string(&self) -> Result<String> {
|
||||
let fortran_order = if self.fortran_order { "True" } else { "False" };
|
||||
let mut shape = self
|
||||
.shape
|
||||
.iter()
|
||||
.map(|x| x.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join(",");
|
||||
let descr = match self.descr {
|
||||
DType::BF16 => Err(Error::Npy("bf16 is not supported".into()))?,
|
||||
DType::F16 => "f2",
|
||||
DType::F32 => "f4",
|
||||
DType::F64 => "f8",
|
||||
DType::U32 => "u4",
|
||||
};
|
||||
if !shape.is_empty() {
|
||||
shape.push(',')
|
||||
}
|
||||
Ok(format!(
|
||||
"{{'descr': '<{descr}', 'fortran_order': {fortran_order}, 'shape': ({shape}), }}"
|
||||
))
|
||||
}
|
||||
|
||||
// Hacky parser for the npy header, a typical example would be:
|
||||
// {'descr': '<f8', 'fortran_order': False, 'shape': (128,), }
|
||||
fn parse(header: &str) -> Result<Header> {
|
||||
let header =
|
||||
header.trim_matches(|c: char| c == '{' || c == '}' || c == ',' || c.is_whitespace());
|
||||
|
||||
let mut parts: Vec<String> = vec![];
|
||||
let mut start_index = 0usize;
|
||||
let mut cnt_parenthesis = 0i64;
|
||||
for (index, c) in header.chars().enumerate() {
|
||||
match c {
|
||||
'(' => cnt_parenthesis += 1,
|
||||
')' => cnt_parenthesis -= 1,
|
||||
',' => {
|
||||
if cnt_parenthesis == 0 {
|
||||
parts.push(header[start_index..index].to_owned());
|
||||
start_index = index + 1;
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
parts.push(header[start_index..].to_owned());
|
||||
let mut part_map: HashMap<String, String> = HashMap::new();
|
||||
for part in parts.iter() {
|
||||
let part = part.trim();
|
||||
if !part.is_empty() {
|
||||
match part.split(':').collect::<Vec<_>>().as_slice() {
|
||||
[key, value] => {
|
||||
let key = key.trim_matches(|c: char| c == '\'' || c.is_whitespace());
|
||||
let value = value.trim_matches(|c: char| c == '\'' || c.is_whitespace());
|
||||
let _ = part_map.insert(key.to_owned(), value.to_owned());
|
||||
}
|
||||
_ => return Err(Error::Npy(format!("unable to parse header {header}"))),
|
||||
}
|
||||
}
|
||||
}
|
||||
let fortran_order = match part_map.get("fortran_order") {
|
||||
None => false,
|
||||
Some(fortran_order) => match fortran_order.as_ref() {
|
||||
"False" => false,
|
||||
"True" => true,
|
||||
_ => return Err(Error::Npy(format!("unknown fortran_order {fortran_order}"))),
|
||||
},
|
||||
};
|
||||
let descr = match part_map.get("descr") {
|
||||
None => return Err(Error::Npy("no descr in header".to_string())),
|
||||
Some(descr) => {
|
||||
if descr.is_empty() {
|
||||
return Err(Error::Npy("empty descr".to_string()));
|
||||
}
|
||||
if descr.starts_with('>') {
    return Err(Error::Npy(format!("big-endian descr {descr} is not supported")));
}
|
||||
// The dtypes currently supported by candle tensors are u32, bf16, f16, f32 and
// f64; the commented-out arms below correspond to numpy dtypes that are not
// handled yet.
|
||||
match descr.trim_matches(|c: char| c == '=' || c == '<' || c == '|') {
|
||||
"e" | "f2" => DType::F16,
|
||||
"f" | "f4" => DType::F32,
|
||||
"d" | "f8" => DType::F64,
|
||||
// "i" | "i4" => DType::S32,
|
||||
// "q" | "i8" => DType::S64,
|
||||
// "h" | "i2" => DType::S16,
|
||||
// "b" | "i1" => DType::S8,
|
||||
// "B" | "u1" => DType::U8,
|
||||
"I" | "u4" => DType::U32,
|
||||
// "?" | "b1" => DType::Pred,
|
||||
// "F" | "F4" => DType::C64,
|
||||
// "D" | "F8" => DType::C128,
|
||||
descr => return Err(Error::Npy(format!("unrecognized descr {descr}"))),
|
||||
}
|
||||
}
|
||||
};
|
||||
let shape = match part_map.get("shape") {
|
||||
None => return Err(Error::Npy("no shape in header".to_string())),
|
||||
Some(shape) => {
|
||||
let shape = shape.trim_matches(|c: char| c == '(' || c == ')' || c == ',');
|
||||
if shape.is_empty() {
|
||||
vec![]
|
||||
} else {
|
||||
shape
|
||||
.split(',')
|
||||
.map(|v| v.trim().parse::<usize>())
|
||||
.collect::<std::result::Result<Vec<_>, _>>()?
|
||||
}
|
||||
}
|
||||
};
|
||||
Ok(Header {
|
||||
descr,
|
||||
fortran_order,
|
||||
shape,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Tensor {
|
||||
// TODO: Add the possibility to read directly to a device?
|
||||
fn from_reader<R: std::io::Read>(shape: Shape, dtype: DType, reader: &mut R) -> Result<Self> {
|
||||
let elem_count = shape.elem_count();
|
||||
match dtype {
|
||||
DType::BF16 => {
|
||||
let mut data_t = vec![bf16::ZERO; elem_count];
|
||||
reader.read_u16_into::<LittleEndian>(data_t.reinterpret_cast_mut())?;
|
||||
Tensor::from_vec(data_t, shape, &Device::Cpu)
|
||||
}
|
||||
DType::F16 => {
|
||||
let mut data_t = vec![f16::ZERO; elem_count];
|
||||
reader.read_u16_into::<LittleEndian>(data_t.reinterpret_cast_mut())?;
|
||||
Tensor::from_vec(data_t, shape, &Device::Cpu)
|
||||
}
|
||||
DType::F32 => {
|
||||
let mut data_t = vec![0f32; elem_count];
|
||||
reader.read_f32_into::<LittleEndian>(&mut data_t)?;
|
||||
Tensor::from_vec(data_t, shape, &Device::Cpu)
|
||||
}
|
||||
DType::F64 => {
|
||||
let mut data_t = vec![0f64; elem_count];
|
||||
reader.read_f64_into::<LittleEndian>(&mut data_t)?;
|
||||
Tensor::from_vec(data_t, shape, &Device::Cpu)
|
||||
}
|
||||
DType::U32 => {
|
||||
let mut data_t = vec![0u32; elem_count];
|
||||
reader.read_u32_into::<LittleEndian>(&mut data_t)?;
|
||||
Tensor::from_vec(data_t, shape, &Device::Cpu)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Reads a npy file and returns the stored multi-dimensional array as a tensor.
|
||||
pub fn read_npy<T: AsRef<Path>>(path: T) -> Result<Self> {
|
||||
let mut reader = File::open(path.as_ref())?;
|
||||
let header = read_header(&mut reader)?;
|
||||
let header = Header::parse(&header)?;
|
||||
if header.fortran_order {
|
||||
return Err(Error::Npy("fortran order not supported".to_string()));
|
||||
}
|
||||
let mut data: Vec<u8> = vec![];
reader.read_to_end(&mut data)?;
// Parse the tensor values from the buffered bytes: the reader itself has been
// consumed by read_to_end at this point.
Self::from_reader(header.shape(), header.descr, &mut data.as_slice())
|
||||
}
|
||||
|
||||
/// Reads a npz file and returns the stored multi-dimensional arrays together with their names.
|
||||
pub fn read_npz<T: AsRef<Path>>(path: T) -> Result<Vec<(String, Self)>> {
|
||||
let zip_reader = BufReader::new(File::open(path.as_ref())?);
|
||||
let mut zip = zip::ZipArchive::new(zip_reader)?;
|
||||
let mut result = vec![];
|
||||
for i in 0..zip.len() {
|
||||
let mut reader = zip.by_index(i).unwrap();
|
||||
let name = {
|
||||
let name = reader.name();
|
||||
name.strip_suffix(NPY_SUFFIX).unwrap_or(name).to_owned()
|
||||
};
|
||||
let header = read_header(&mut reader)?;
|
||||
let header = Header::parse(&header)?;
|
||||
if header.fortran_order {
|
||||
return Err(Error::Npy("fortran order not supported".to_string()));
|
||||
}
|
||||
let s = Self::from_reader(header.shape(), header.descr, &mut reader)?;
|
||||
result.push((name, s))
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Reads a npz file and returns the stored multi-dimensional arrays for some specified names.
|
||||
pub fn read_npz_by_name<T: AsRef<Path>>(path: T, names: &[&str]) -> Result<Vec<Self>> {
|
||||
let zip_reader = BufReader::new(File::open(path.as_ref())?);
|
||||
let mut zip = zip::ZipArchive::new(zip_reader)?;
|
||||
let mut result = vec![];
|
||||
for name in names.iter() {
|
||||
let mut reader = match zip.by_name(&format!("{name}{NPY_SUFFIX}")) {
|
||||
Ok(reader) => reader,
|
||||
Err(_) => Err(Error::Npy(format!(
|
||||
"no array for {name} in {:?}",
|
||||
path.as_ref()
|
||||
)))?,
|
||||
};
|
||||
let header = read_header(&mut reader)?;
|
||||
let header = Header::parse(&header)?;
|
||||
if header.fortran_order {
|
||||
return Err(Error::Npy("fortran order not supported".to_string()));
|
||||
}
|
||||
let s = Self::from_reader(header.shape(), header.descr, &mut reader)?;
|
||||
result.push(s)
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn write<T: Write>(&self, f: &mut T) -> Result<()> {
|
||||
f.write_all(NPY_MAGIC_STRING)?;
|
||||
f.write_all(&[1u8, 0u8])?;
|
||||
let header = Header {
|
||||
descr: self.dtype(),
|
||||
fortran_order: false,
|
||||
shape: self.dims().to_vec(),
|
||||
};
|
||||
let mut header = header.to_string()?;
|
||||
let pad = 16 - (NPY_MAGIC_STRING.len() + 5 + header.len()) % 16;
|
||||
for _ in 0..pad % 16 {
|
||||
header.push(' ')
|
||||
}
|
||||
header.push('\n');
|
||||
f.write_all(&[(header.len() % 256) as u8, (header.len() / 256) as u8])?;
|
||||
f.write_all(header.as_bytes())?;
|
||||
let elem_count = self.elem_count();
|
||||
match self.dtype() {
|
||||
DType::BF16 => {
|
||||
let vs = self.reshape(elem_count)?.to_vec1::<bf16>()?;
|
||||
for &v in vs.reinterpret_cast() {
|
||||
f.write_u16::<LittleEndian>(v)?
|
||||
}
|
||||
}
|
||||
DType::F16 => {
|
||||
let vs = self.reshape(elem_count)?.to_vec1::<f16>()?;
|
||||
for &v in vs.reinterpret_cast() {
|
||||
f.write_u16::<LittleEndian>(v)?
|
||||
}
|
||||
}
|
||||
DType::F32 => {
|
||||
// TODO: Avoid using a buffer when data is already on the CPU.
|
||||
for v in self.reshape(elem_count)?.to_vec1::<f32>()? {
|
||||
f.write_f32::<LittleEndian>(v)?
|
||||
}
|
||||
}
|
||||
DType::F64 => {
|
||||
for v in self.reshape(elem_count)?.to_vec1::<f64>()? {
|
||||
f.write_f64::<LittleEndian>(v)?
|
||||
}
|
||||
}
|
||||
DType::U32 => {
|
||||
for v in self.reshape(elem_count)?.to_vec1::<u32>()? {
|
||||
f.write_u32::<LittleEndian>(v)?
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Writes a multi-dimensional array in the npy format.
|
||||
pub fn write_npy<T: AsRef<Path>>(&self, path: T) -> Result<()> {
|
||||
let mut f = File::create(path.as_ref())?;
|
||||
self.write(&mut f)
|
||||
}
|
||||
|
||||
/// Writes multiple multi-dimensional arrays using the npz format.
|
||||
pub fn write_npz<S: AsRef<str>, T: AsRef<Tensor>, P: AsRef<Path>>(
|
||||
ts: &[(S, T)],
|
||||
path: P,
|
||||
) -> Result<()> {
|
||||
let mut zip = zip::ZipWriter::new(File::create(path.as_ref())?);
|
||||
let options =
|
||||
zip::write::FileOptions::default().compression_method(zip::CompressionMethod::Stored);
|
||||
|
||||
for (name, tensor) in ts.iter() {
|
||||
zip.start_file(format!("{}.npy", name.as_ref()), options)?;
|
||||
tensor.as_ref().write(&mut zip)?
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
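// Usage sketch (not in the original source; file names are placeholders):
//
//     let t = Tensor::read_npy("weights.npy")?;
//     t.write_npy("weights-copy.npy")?;
//     for (name, t) in Tensor::read_npz("weights.npz")? {
//         println!("{name}: {:?}", t.dims());
//     }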
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::Header;
|
||||
|
||||
#[test]
|
||||
fn parse() {
|
||||
let h = "{'descr': '<f8', 'fortran_order': False, 'shape': (128,), }";
|
||||
assert_eq!(
|
||||
Header::parse(h).unwrap(),
|
||||
Header {
|
||||
descr: crate::DType::F64,
|
||||
fortran_order: false,
|
||||
shape: vec![128]
|
||||
}
|
||||
);
|
||||
let h = "{'descr': '<f4', 'fortran_order': True, 'shape': (256,1,128), }";
|
||||
let h = Header::parse(h).unwrap();
|
||||
assert_eq!(
|
||||
h,
|
||||
Header {
|
||||
descr: crate::DType::F32,
|
||||
fortran_order: true,
|
||||
shape: vec![256, 1, 128]
|
||||
}
|
||||
);
|
||||
assert_eq!(
|
||||
h.to_string().unwrap(),
|
||||
"{'descr': '<f4', 'fortran_order': True, 'shape': (256,1,128,), }"
|
||||
);
|
||||
|
||||
let h = Header {
|
||||
descr: crate::DType::U32,
|
||||
fortran_order: false,
|
||||
shape: vec![],
|
||||
};
|
||||
assert_eq!(
|
||||
h.to_string().unwrap(),
|
||||
"{'descr': '<u4', 'fortran_order': False, 'shape': (), }"
|
||||
);
|
||||
}
|
||||
}
|
197
candle-core/src/op.rs
Normal file
197
candle-core/src/op.rs
Normal file
@ -0,0 +1,197 @@
|
||||
use crate::Tensor;
|
||||
use half::{bf16, f16};
|
||||
use num_traits::float::Float;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) enum Op {
|
||||
Add(Tensor, Tensor),
|
||||
Mul(Tensor, Tensor),
|
||||
Sub(Tensor, Tensor),
|
||||
Div(Tensor, Tensor),
|
||||
Matmul(Tensor, Tensor),
|
||||
Embedding(Tensor, Tensor),
|
||||
WhereCond(Tensor, Tensor, Tensor),
|
||||
|
||||
Cat(Vec<Tensor>, usize),
|
||||
|
||||
#[allow(dead_code)] // add is currently unused.
|
||||
Affine {
|
||||
arg: Tensor,
|
||||
mul: f64,
|
||||
add: f64,
|
||||
},
|
||||
Sum(Tensor, Vec<usize>),
|
||||
ToDType(Tensor),
|
||||
Broadcast(Tensor),
|
||||
Exp(Tensor),
|
||||
Log(Tensor),
|
||||
Sin(Tensor),
|
||||
Cos(Tensor),
|
||||
Abs(Tensor),
|
||||
Narrow(Tensor, usize, usize, usize),
|
||||
Neg(Tensor),
|
||||
Reshape(Tensor),
|
||||
Softmax(Tensor, usize),
|
||||
Sqr(Tensor),
|
||||
Sqrt(Tensor),
|
||||
ToDevice(Tensor),
|
||||
Transpose(Tensor, usize, usize),
|
||||
Gelu(Tensor),
|
||||
// TODO: Support for custom ops.
|
||||
}
|
||||
|
||||
pub(crate) trait UnaryOp {
|
||||
const NAME: &'static str;
|
||||
const KERNEL_BF16: &'static str;
|
||||
const KERNEL_F16: &'static str;
|
||||
const KERNEL_F32: &'static str;
|
||||
const KERNEL_F64: &'static str;
|
||||
const KERNEL_U32: &'static str;
|
||||
fn bf16(v1: bf16) -> bf16;
|
||||
fn f16(v1: f16) -> f16;
|
||||
fn f32(v1: f32) -> f32;
|
||||
fn f64(v1: f64) -> f64;
|
||||
fn u32(v1: u32) -> u32;
|
||||
}
|
||||
|
||||
pub(crate) trait BinaryOp {
|
||||
const NAME: &'static str;
|
||||
const KERNEL_BF16: &'static str;
|
||||
const KERNEL_F16: &'static str;
|
||||
const KERNEL_F32: &'static str;
|
||||
const KERNEL_F64: &'static str;
|
||||
const KERNEL_U32: &'static str;
|
||||
fn bf16(v1: bf16, v2: bf16) -> bf16;
|
||||
fn f16(v1: f16, v2: f16) -> f16;
|
||||
fn f32(v1: f32, v2: f32) -> f32;
|
||||
fn f64(v1: f64, v2: f64) -> f64;
|
||||
fn u32(v1: u32, v2: u32) -> u32;
|
||||
}
|
||||
|
||||
pub(crate) struct Add;
|
||||
pub(crate) struct Div;
|
||||
pub(crate) struct Mul;
|
||||
pub(crate) struct Sub;
|
||||
pub(crate) struct Exp;
|
||||
pub(crate) struct Log;
|
||||
pub(crate) struct Sin;
|
||||
pub(crate) struct Cos;
|
||||
pub(crate) struct Abs;
|
||||
pub(crate) struct Neg;
|
||||
pub(crate) struct Sqr;
|
||||
pub(crate) struct Sqrt;
|
||||
pub(crate) struct Gelu;
|
||||
|
||||
macro_rules! bin_op {
|
||||
($op:ident, $name: literal, $e: expr) => {
|
||||
impl BinaryOp for $op {
|
||||
const NAME: &'static str = $name;
|
||||
const KERNEL_BF16: &'static str = concat!("b", $name, "_bf16");
|
||||
const KERNEL_F16: &'static str = concat!("b", $name, "_f16");
|
||||
const KERNEL_F32: &'static str = concat!("b", $name, "_f32");
|
||||
const KERNEL_F64: &'static str = concat!("b", $name, "_f64");
|
||||
const KERNEL_U32: &'static str = concat!("b", $name, "_u32");
|
||||
fn bf16(v1: bf16, v2: bf16) -> bf16 {
|
||||
$e(v1, v2)
|
||||
}
|
||||
fn f16(v1: f16, v2: f16) -> f16 {
|
||||
$e(v1, v2)
|
||||
}
|
||||
fn f32(v1: f32, v2: f32) -> f32 {
|
||||
$e(v1, v2)
|
||||
}
|
||||
fn f64(v1: f64, v2: f64) -> f64 {
|
||||
$e(v1, v2)
|
||||
}
|
||||
fn u32(v1: u32, v2: u32) -> u32 {
|
||||
$e(v1, v2)
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
bin_op!(Add, "add", |v1, v2| v1 + v2);
|
||||
bin_op!(Sub, "sub", |v1, v2| v1 - v2);
|
||||
bin_op!(Mul, "mul", |v1, v2| v1 * v2);
|
||||
bin_op!(Div, "div", |v1, v2| v1 / v2);
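// Note (not in the original source): the generated kernel names follow a
// "b<name>_<dtype>" scheme, e.g. Add::KERNEL_F32 is "badd_f32"; the unary macro
// below uses the matching "u<name>_<dtype>" scheme (e.g. "uexp_f32"), which is
// what the CUDA backend looks up via get_or_load_func.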
|
||||
|
||||
macro_rules! unary_op {
|
||||
($op: ident, $name: literal, $a: ident, $e: expr) => {
|
||||
impl UnaryOp for $op {
|
||||
const NAME: &'static str = $name;
|
||||
const KERNEL_BF16: &'static str = concat!("u", $name, "_bf16");
|
||||
const KERNEL_F16: &'static str = concat!("u", $name, "_f16");
|
||||
const KERNEL_F32: &'static str = concat!("u", $name, "_f32");
|
||||
const KERNEL_F64: &'static str = concat!("u", $name, "_f64");
|
||||
const KERNEL_U32: &'static str = concat!("u", $name, "_u32");
|
||||
fn bf16($a: bf16) -> bf16 {
|
||||
$e
|
||||
}
|
||||
fn f16($a: f16) -> f16 {
|
||||
$e
|
||||
}
|
||||
fn f32($a: f32) -> f32 {
|
||||
$e
|
||||
}
|
||||
fn f64($a: f64) -> f64 {
|
||||
$e
|
||||
}
|
||||
fn u32(_: u32) -> u32 {
|
||||
todo!("no unary function for u32")
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
unary_op!(Exp, "exp", v, v.exp());
|
||||
unary_op!(Log, "log", v, v.ln());
|
||||
unary_op!(Sin, "sin", v, v.sin());
|
||||
unary_op!(Cos, "cos", v, v.cos());
|
||||
unary_op!(Abs, "abs", v, v.abs());
|
||||
unary_op!(Neg, "neg", v, -v);
|
||||
unary_op!(Sqr, "sqr", v, v * v);
|
||||
unary_op!(Sqrt, "sqrt", v, v.sqrt());
|
||||
|
||||
/// `gelu` operation
|
||||
/// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions>
|
||||
impl UnaryOp for Gelu {
|
||||
const NAME: &'static str = "gelu";
|
||||
fn bf16(v: bf16) -> bf16 {
|
||||
bf16::from_f32_const(0.5)
|
||||
* v
|
||||
* (bf16::ONE
|
||||
+ bf16::tanh(
|
||||
(bf16::from_f32_const(2.0) / bf16::PI).sqrt()
|
||||
* v
|
||||
* (bf16::ONE + bf16::from_f32_const(0.044715) * v * v),
|
||||
))
|
||||
}
|
||||
fn f16(v: f16) -> f16 {
|
||||
f16::from_f32_const(0.5)
|
||||
* v
|
||||
* (f16::ONE
|
||||
+ f16::tanh(
|
||||
(f16::from_f32_const(2.0) / f16::PI).sqrt()
|
||||
* v
|
||||
* (f16::ONE + f16::from_f32_const(0.044715) * v * v),
|
||||
))
|
||||
}
|
||||
fn f32(v: f32) -> f32 {
|
||||
0.5 * v
|
||||
* (1.0
|
||||
+ f32::tanh((2.0f32 / std::f32::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)))
|
||||
}
|
||||
fn f64(v: f64) -> f64 {
|
||||
0.5 * v
|
||||
* (1.0
|
||||
+ f64::tanh((2.0f64 / std::f64::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)))
|
||||
}
|
||||
fn u32(_: u32) -> u32 {
|
||||
0
|
||||
}
|
||||
const KERNEL_BF16: &'static str = "gelu_bf16";
|
||||
const KERNEL_F16: &'static str = "gelu_f16";
|
||||
const KERNEL_F32: &'static str = "gelu_f32";
|
||||
const KERNEL_F64: &'static str = "gelu_f64";
|
||||
const KERNEL_U32: &'static str = "gelu_u32";
|
||||
}
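// Note (not in the original source): the expressions above implement the usual
// tanh approximation
//     gelu(x) ~= 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))
// with sqrt(2 / pi) * (x + 0.044715 * x^3) written as
// sqrt(2 / pi) * v * (1 + 0.044715 * v * v).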
|
199
candle-core/src/shape.rs
Normal file
199
candle-core/src/shape.rs
Normal file
@ -0,0 +1,199 @@
|
||||
use crate::{Error, Result};
|
||||
|
||||
#[derive(Clone, PartialEq, Eq)]
|
||||
pub struct Shape(Vec<usize>);
|
||||
|
||||
impl std::fmt::Debug for Shape {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{:?}", &self.dims())
|
||||
}
|
||||
}
|
||||
|
||||
impl<const C: usize> From<&[usize; C]> for Shape {
|
||||
fn from(dims: &[usize; C]) -> Self {
|
||||
Self(dims.to_vec())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&[usize]> for Shape {
|
||||
fn from(dims: &[usize]) -> Self {
|
||||
Self(dims.to_vec())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&Shape> for Shape {
|
||||
fn from(shape: &Shape) -> Self {
|
||||
Self(shape.0.to_vec())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<()> for Shape {
|
||||
fn from(_: ()) -> Self {
|
||||
Self(vec![])
|
||||
}
|
||||
}
|
||||
|
||||
impl From<usize> for Shape {
|
||||
fn from(d1: usize) -> Self {
|
||||
Self(vec![d1])
|
||||
}
|
||||
}
|
||||
|
||||
impl From<(usize, usize)> for Shape {
|
||||
fn from(d12: (usize, usize)) -> Self {
|
||||
Self(vec![d12.0, d12.1])
|
||||
}
|
||||
}
|
||||
|
||||
impl From<(usize, usize, usize)> for Shape {
|
||||
fn from(d123: (usize, usize, usize)) -> Self {
|
||||
Self(vec![d123.0, d123.1, d123.2])
|
||||
}
|
||||
}
|
||||
|
||||
impl From<(usize, usize, usize, usize)> for Shape {
|
||||
fn from(d1234: (usize, usize, usize, usize)) -> Self {
|
||||
Self(vec![d1234.0, d1234.1, d1234.2, d1234.3])
|
||||
}
|
||||
}
|
||||
|
||||
impl From<(usize, usize, usize, usize, usize)> for Shape {
|
||||
fn from(d12345: (usize, usize, usize, usize, usize)) -> Self {
|
||||
Self(vec![d12345.0, d12345.1, d12345.2, d12345.3, d12345.4])
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Vec<usize>> for Shape {
|
||||
fn from(dims: Vec<usize>) -> Self {
|
||||
Self(dims)
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! extract_dims {
|
||||
($fn_name:ident, $cnt:tt, $dims:expr, $out_type:ty) => {
|
||||
pub fn $fn_name(&self) -> Result<$out_type> {
|
||||
if self.0.len() != $cnt {
|
||||
Err(Error::UnexpectedNumberOfDims {
|
||||
expected: $cnt,
|
||||
got: self.0.len(),
|
||||
shape: self.clone(),
|
||||
})
|
||||
} else {
|
||||
Ok($dims(&self.0))
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
impl Shape {
|
||||
pub fn from_dims(dims: &[usize]) -> Self {
|
||||
Self(dims.to_vec())
|
||||
}
|
||||
|
||||
pub fn rank(&self) -> usize {
|
||||
self.0.len()
|
||||
}
|
||||
|
||||
pub fn into_dims(self) -> Vec<usize> {
|
||||
self.0
|
||||
}
|
||||
|
||||
pub fn dims(&self) -> &[usize] {
|
||||
&self.0
|
||||
}
|
||||
|
||||
pub fn elem_count(&self) -> usize {
|
||||
self.0.iter().product()
|
||||
}
|
||||
|
||||
extract_dims!(r0, 0, |_: &Vec<usize>| (), ());
|
||||
extract_dims!(r1, 1, |d: &[usize]| d[0], usize);
|
||||
extract_dims!(r2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize));
|
||||
extract_dims!(
|
||||
r3,
|
||||
3,
|
||||
|d: &[usize]| (d[0], d[1], d[2]),
|
||||
(usize, usize, usize)
|
||||
);
|
||||
extract_dims!(
|
||||
r4,
|
||||
4,
|
||||
|d: &[usize]| (d[0], d[1], d[2], d[3]),
|
||||
(usize, usize, usize, usize)
|
||||
);
|
||||
extract_dims!(
|
||||
r5,
|
||||
5,
|
||||
|d: &[usize]| (d[0], d[1], d[2], d[3], d[4]),
|
||||
(usize, usize, usize, usize, usize)
|
||||
);
|
||||
|
||||
/// The strides, given in number of elements, for a contiguous n-dimensional
/// array using this shape.
|
||||
pub(crate) fn stride_contiguous(&self) -> Vec<usize> {
|
||||
let mut stride: Vec<_> = self
|
||||
.0
|
||||
.iter()
|
||||
.rev()
|
||||
.scan(1, |prod, u| {
|
||||
let prod_pre_mult = *prod;
|
||||
*prod *= u;
|
||||
Some(prod_pre_mult)
|
||||
})
|
||||
.collect();
|
||||
stride.reverse();
|
||||
stride
|
||||
}
|
||||
|
||||
/// Returns true if the strides are C contiguous (aka row major).
|
||||
pub fn is_contiguous(&self, stride: &[usize]) -> bool {
|
||||
if self.0.len() != stride.len() {
|
||||
return false;
|
||||
}
|
||||
let mut acc = 1;
|
||||
for (&stride, &dim) in stride.iter().zip(self.0.iter()).rev() {
|
||||
if stride != acc {
|
||||
return false;
|
||||
}
|
||||
acc *= dim;
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
/// Returns true if the strides are Fortran contiguous (aka column major).
|
||||
pub fn is_fortran_contiguous(&self, stride: &[usize]) -> bool {
|
||||
if self.0.len() != stride.len() {
|
||||
return false;
|
||||
}
|
||||
let mut acc = 1;
|
||||
for (&stride, &dim) in stride.iter().zip(self.0.iter()) {
|
||||
if stride != acc {
|
||||
return false;
|
||||
}
|
||||
acc *= dim;
|
||||
}
|
||||
true
|
||||
}
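// Example (not in the original source): for a (2, 3) shape, strides [3, 1] are
// C contiguous while strides [1, 2] are Fortran contiguous.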
|
||||
|
||||
pub fn extend(mut self, additional_dims: &[usize]) -> Self {
|
||||
self.0.extend(additional_dims);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn stride() {
|
||||
let shape = Shape::from(());
|
||||
assert_eq!(shape.stride_contiguous(), Vec::<usize>::new());
|
||||
let shape = Shape::from(42);
|
||||
assert_eq!(shape.stride_contiguous(), [1]);
|
||||
let shape = Shape::from((42, 1337));
|
||||
assert_eq!(shape.stride_contiguous(), [1337, 1]);
|
||||
let shape = Shape::from((299, 792, 458));
|
||||
assert_eq!(shape.stride_contiguous(), [458 * 792, 458, 1]);
|
||||
}
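
    // A small additional check following the helpers above: for a (2, 3) shape,
    // row-major strides are [3, 1] and column-major strides are [1, 2].
    #[test]
    fn contiguity() {
        let shape = Shape::from((2, 3));
        assert!(shape.is_contiguous(&[3, 1]));
        assert!(!shape.is_fortran_contiguous(&[3, 1]));
        assert!(shape.is_fortran_contiguous(&[1, 2]));
        assert!(!shape.is_contiguous(&[1, 2]));
    }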
}
261
candle-core/src/storage.rs
Normal file
@ -0,0 +1,261 @@
use crate::{op, CpuStorage, CudaStorage, DType, Device, Error, Result, Shape};

// We do not want to implement Clone on Storage as cloning may fail because of
// out of memory. Instead try_clone should be used.
#[derive(Debug)]
pub enum Storage {
    Cpu(CpuStorage),
    Cuda(CudaStorage),
}

impl Storage {
    pub fn try_clone(&self) -> Result<Self> {
        match self {
            Self::Cpu(storage) => Ok(Self::Cpu(storage.clone())),
            Self::Cuda(storage) => {
                let storage = storage.try_clone()?;
                Ok(Self::Cuda(storage))
            }
        }
    }
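
    // A minimal usage sketch (hedged; `storage` is just an illustrative name for
    // an existing `Storage` value): cloning is fallible, e.g. a CUDA allocation
    // may run out of memory, so callers handle the `Result`.
    //
    //     let copy = storage.try_clone()?;
    //     assert_eq!(copy.dtype(), storage.dtype());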

    pub fn device(&self) -> Device {
        match self {
            Self::Cpu(_) => Device::Cpu,
            Self::Cuda(storage) => Device::Cuda(storage.device().clone()),
        }
    }

    pub fn dtype(&self) -> DType {
        match self {
            Self::Cpu(storage) => storage.dtype(),
            Self::Cuda(storage) => storage.dtype(),
        }
    }

    pub(crate) fn same_device(&self, rhs: &Self, op: &'static str) -> Result<()> {
        let lhs = self.device().location();
        let rhs = rhs.device().location();
        if lhs != rhs {
            Err(Error::DeviceMismatchBinaryOp { lhs, rhs, op })
        } else {
            Ok(())
        }
    }

    pub(crate) fn same_dtype(&self, rhs: &Self, op: &'static str) -> Result<()> {
        let lhs = self.dtype();
        let rhs = rhs.dtype();
        if lhs != rhs {
            Err(Error::DTypeMismatchBinaryOp { lhs, rhs, op })
        } else {
            Ok(())
        }
    }

    pub(crate) fn affine_impl(
        &self,
        shape: &Shape,
        stride: &[usize],
        mul: f64,
        add: f64,
    ) -> Result<Self> {
        match self {
            Storage::Cpu(storage) => {
                let storage = storage.affine_impl(shape, stride, mul, add)?;
                Ok(Self::Cpu(storage))
            }
            Self::Cuda(storage) => {
                let storage = storage.affine_impl(shape, stride, mul, add)?;
                Ok(Self::Cuda(storage))
            }
        }
    }

    pub(crate) fn sum(&self, shape: &Shape, stride: &[usize], s: &[usize]) -> Result<Self> {
        match self {
            Storage::Cpu(storage) => {
                let storage = storage.sum(shape, stride, s)?;
                Ok(Self::Cpu(storage))
            }
            Self::Cuda(storage) => {
                let storage = storage.sum(shape, stride, s)?;
                Ok(Self::Cuda(storage))
            }
        }
    }

    pub(crate) fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) -> Result<()> {
        match self {
            Storage::Cpu(storage) => storage.divide_by_sum_over_dim(shape, dim)?,
            Self::Cuda(storage) => storage.divide_by_sum_over_dim(shape, dim)?,
        }
        Ok(())
    }

    pub(crate) fn to_dtype(&self, shape: &Shape, stride: &[usize], dtype: DType) -> Result<Self> {
        match self {
            Storage::Cpu(storage) => {
                let storage = storage.to_dtype(shape, stride, dtype)?;
                Ok(Self::Cpu(storage))
            }
            Self::Cuda(storage) => {
                let storage = storage.to_dtype(shape, stride, dtype)?;
                Ok(Self::Cuda(storage))
            }
        }
    }

    pub(crate) fn unary_impl<B: op::UnaryOp>(
        &self,
        shape: &Shape,
        stride: &[usize],
    ) -> Result<Self> {
        // TODO: Different code path for the contiguous case?
        match self {
            Storage::Cpu(storage) => {
                let storage = storage.unary_impl::<B>(shape, stride)?;
                Ok(Self::Cpu(storage))
            }
            Self::Cuda(storage) => {
                let storage = storage.unary_impl::<B>(shape, stride)?;
                Ok(Self::Cuda(storage))
            }
        }
    }

    pub(crate) fn binary_impl<B: op::BinaryOp>(
        &self,
        rhs: &Self,
        shape: &Shape,
        lhs_stride: &[usize],
        rhs_stride: &[usize],
    ) -> Result<Self> {
        self.same_device(rhs, B::NAME)?;
        self.same_dtype(rhs, B::NAME)?;
        match (self, rhs) {
            (Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
                let storage = lhs.binary_impl::<B>(rhs, shape, lhs_stride, rhs_stride)?;
                Ok(Self::Cpu(storage))
            }
            (Self::Cuda(lhs), Self::Cuda(rhs)) => {
                let storage = lhs.binary_impl::<B>(rhs, shape, lhs_stride, rhs_stride)?;
                Ok(Self::Cuda(storage))
            }
            (lhs, rhs) => {
                // Should not happen because of the same device check above but we're defensive
                // anyway.
                Err(Error::DeviceMismatchBinaryOp {
                    lhs: lhs.device().location(),
                    rhs: rhs.device().location(),
                    op: B::NAME,
                })
            }
        }
    }
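
    // A rough usage sketch (hedged; `B` stands in for any `op::BinaryOp`
    // implementation, e.g. the one backing `Op::Add`, and the variable names are
    // only illustrative): the device and dtype checks above run first, then the
    // call is forwarded to the matching backend storage:
    //
    //     let out = lhs_storage.binary_impl::<B>(&rhs_storage, &shape, &lhs_stride, &rhs_stride)?;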

    pub(crate) fn where_cond(
        &self,
        shape: &Shape,
        stride: &[usize],
        t: &Self,
        stride_t: &[usize],
        f: &Self,
        stride_f: &[usize],
    ) -> Result<Self> {
        self.same_device(t, "where")?;
        self.same_device(f, "where")?;
        t.same_dtype(f, "where")?;
        match (self, t, f) {
            (Storage::Cpu(cond), Storage::Cpu(t), Storage::Cpu(f)) => {
                let storage = cond.where_cond(shape, stride, t, stride_t, f, stride_f)?;
                Ok(Self::Cpu(storage))
            }
            (Self::Cuda(cond), Self::Cuda(t), Self::Cuda(f)) => {
                let storage = cond.where_cond(shape, stride, t, stride_t, f, stride_f)?;
                Ok(Self::Cuda(storage))
            }
            (_, lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
                lhs: lhs.device().location(),
                rhs: rhs.device().location(),
                op: "where",
            }),
        }
    }

    pub(crate) fn embedding_impl(
        &self,
        shape: &Shape,
        stride: &[usize],
        rhs: &Self,
        hidden_size: usize,
        vocab_size: usize,
    ) -> Result<Self> {
        self.same_device(rhs, "embedding")?;
        match (self, rhs) {
            (Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
                let storage = lhs.embedding_impl(shape, stride, rhs, hidden_size, vocab_size)?;
                Ok(Self::Cpu(storage))
            }
            (Self::Cuda(lhs), Self::Cuda(rhs)) => {
                let storage = lhs.embedding_impl(shape, stride, rhs, hidden_size, vocab_size)?;
                Ok(Self::Cuda(storage))
            }
            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
                lhs: lhs.device().location(),
                rhs: rhs.device().location(),
                op: "embedding",
            }),
        }
    }

    pub(crate) fn matmul_impl(
        &self,
        rhs: &Self,
        bmnk: (usize, usize, usize, usize),
        lhs_stride: &[usize],
        rhs_stride: &[usize],
    ) -> Result<Self> {
        self.same_device(rhs, "matmul")?;
        self.same_dtype(rhs, "matmul")?;
        match (self, rhs) {
            (Self::Cpu(lhs), Self::Cpu(rhs)) => {
                let storage = lhs.matmul_impl(rhs, bmnk, lhs_stride, rhs_stride)?;
                Ok(Self::Cpu(storage))
            }
            (Self::Cuda(lhs), Self::Cuda(rhs)) => {
                let storage = lhs.matmul_impl(rhs, bmnk, lhs_stride, rhs_stride)?;
                Ok(Self::Cuda(storage))
            }
            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
                lhs: lhs.device().location(),
                rhs: rhs.device().location(),
                op: "matmul",
            }),
        }
    }

    // self, the source, can be strided whereas dst is contiguous.
    pub(crate) fn copy_strided_src(
        &self,
        dst: &mut Self,
        dst_offset: usize,
        src_shape: &Shape,
        src_stride: &[usize],
        src_offset: usize,
    ) -> Result<()> {
        match (self, dst) {
            (Self::Cpu(src), Self::Cpu(dst)) => {
                src.copy_strided_src(dst, dst_offset, src_shape, src_stride, src_offset)
            }
            (Self::Cuda(src), Self::Cuda(dst)) => {
                Ok(src.copy_strided_src(dst, dst_offset, src_shape, src_stride, src_offset)?)
            }
            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
                lhs: lhs.device().location(),
                rhs: rhs.device().location(),
                op: "copy",
            }),
        }
    }
}
61
candle-core/src/strided_index.rs
Normal file
@ -0,0 +1,61 @@
/// An iterator over the offset positions of the items of an N-dimensional array
/// stored in a flat buffer using some potential strides.
#[derive(Debug)]
pub(crate) struct StridedIndex<'a> {
    next_storage_index: Option<usize>,
    multi_index: Vec<usize>,
    dims: &'a [usize],
    stride: &'a [usize],
}

impl<'a> StridedIndex<'a> {
    pub(crate) fn new(dims: &'a [usize], stride: &'a [usize]) -> Self {
        let elem_count: usize = dims.iter().product();
        let next_storage_index = if elem_count == 0 {
            None
        } else {
            // This applies to the scalar case.
            Some(0)
        };
        StridedIndex {
            next_storage_index,
            multi_index: vec![0; dims.len()],
            dims,
            stride,
        }
    }
}

impl<'a> Iterator for StridedIndex<'a> {
    type Item = usize;

    fn next(&mut self) -> Option<Self::Item> {
        let storage_index = match self.next_storage_index {
            None => return None,
            Some(storage_index) => storage_index,
        };
        let mut updated = false;
        for (multi_i, max_i) in self.multi_index.iter_mut().zip(self.dims.iter()).rev() {
            let next_i = *multi_i + 1;
            if next_i < *max_i {
                *multi_i = next_i;
                updated = true;
                break;
            } else {
                *multi_i = 0
            }
        }
        self.next_storage_index = if updated {
            let next_storage_index = self
                .multi_index
                .iter()
                .zip(self.stride.iter())
                .map(|(&x, &y)| x * y)
                .sum();
            Some(next_storage_index)
        } else {
            None
        };
        Some(storage_index)
    }
}
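
#[cfg(test)]
mod tests {
    use super::*;

    // A small sketch of the traversal order: the offsets follow from walking a
    // (2, 3) array in row-major index order with the given strides.
    #[test]
    fn strided_index_order() {
        let dims = [2usize, 3];
        let contiguous: Vec<usize> = StridedIndex::new(&dims, &[3, 1]).collect();
        assert_eq!(contiguous, [0, 1, 2, 3, 4, 5]);
        let transposed: Vec<usize> = StridedIndex::new(&dims, &[1, 2]).collect();
        assert_eq!(transposed, [0, 2, 4, 1, 3, 5]);
    }
}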
1062
candle-core/src/tensor.rs
Normal file
File diff suppressed because it is too large