Merge pull request #27 from LaurentMazare/layout-refactor

Refactor the stride/shape handling
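The diff below applies one mechanical change across the CPU and CUDA backends: every function that used to take a shape/stride pair (and sometimes an explicit start offset) now takes a single `Layout`, which bundles the shape, the per-dimension strides, and the start offset of the view. A minimal in-crate sketch of the before/after signature shape (illustrative only; `op_before` and `op_after` are hypothetical names, not functions from this commit):

use crate::{Layout, Shape};

// Before the refactor: shape, strides, and offset travel as separate arguments.
fn op_before(data: &[f32], shape: &Shape, stride: &[usize], offset: usize) -> usize {
    let _ = (data, stride, offset);
    shape.elem_count()
}

// After the refactor: a single Layout carries all three.
fn op_after(data: &[f32], layout: &Layout) -> usize {
    let _ = (data, layout.stride(), layout.start_offset());
    layout.shape().elem_count()
}
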
candle-core/src/cpu_backend.rs
@@ -1,5 +1,5 @@
 use crate::op::{BinaryOp, UnaryOp};
-use crate::{DType, Error, Result, Shape, StridedIndex};
+use crate::{DType, Error, Layout, Result, Shape};
 use gemm::{gemm, Parallelism};
 use half::{bf16, f16};

@@ -18,31 +18,31 @@ pub enum CpuStorage {

 fn wcond<T: Copy>(
     pred: &[u32],
-    shape: &Shape,
-    stride: &[usize],
+    layout: &Layout,
     t: &[T],
-    stride_t: &[usize],
+    layout_t: &Layout,
     f: &[T],
-    stride_f: &[usize],
+    layout_f: &Layout,
 ) -> Vec<T> {
-    if shape.is_contiguous(stride) && shape.is_contiguous(stride_t) && shape.is_contiguous(stride_f)
-    {
-        let elem_count = shape.elem_count();
-        let pred = &pred[..elem_count];
-        let t = &t[..elem_count];
-        let f = &f[..elem_count];
+    match (
+        layout.contiguous_offsets(),
+        layout_t.contiguous_offsets(),
+        layout_f.contiguous_offsets(),
+    ) {
+        (Some((o1, o2)), Some((o_t1, o_t2)), Some((o_f1, o_f2))) => {
+            let pred = &pred[o1..o2];
+            let t = &t[o_t1..o_t2];
+            let f = &f[o_f1..o_f2];
             pred.iter()
                 .zip(t.iter().zip(f.iter()))
                 .map(|(&p, (&t, &f))| if p > 0 { t } else { f })
                 .collect::<Vec<_>>()
-    } else {
-        let dims = shape.dims();
-        let it_p = StridedIndex::new(dims, stride);
-        let it_t = StridedIndex::new(dims, stride_t);
-        let it_f = StridedIndex::new(dims, stride_f);
-        it_p.zip(it_t.zip(it_f))
+        }
+        _ => layout
+            .strided_index()
+            .zip(layout_t.strided_index().zip(layout_f.strided_index()))
             .map(|(i_p, (i_t, i_f))| if pred[i_p] > 0 { t[i_t] } else { f[i_f] })
-            .collect::<Vec<_>>()
+            .collect::<Vec<_>>(),
     }
 }

@@ -62,64 +62,50 @@ macro_rules! map1 {
 fn sum_impl1<T: Copy + num_traits::NumAssign>(
     src: &[T],
     dst_shape: &Shape,
-    src_dims: &[usize],
-    stride: &[usize],
+    src_layout: &Layout,
     to_dst_index: impl Fn(usize) -> usize,
 ) -> Result<Vec<T>> {
     let mut dst = vec![T::zero(); dst_shape.elem_count()];
-    for (unstr_index, src_index) in StridedIndex::new(src_dims, stride).enumerate() {
+    for (unstr_index, src_index) in src_layout.strided_index().enumerate() {
         dst[to_dst_index(unstr_index)] += src[src_index];
     }
     Ok(dst)
 }

-fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(
-    vs: &[T],
-    shape: &Shape,
-    stride: &[usize],
-    mut f: F,
-) -> Vec<U> {
-    if shape.is_contiguous(stride) {
-        vs[..shape.elem_count()].iter().map(|&v| f(v)).collect()
-    } else {
-        StridedIndex::new(shape.dims(), stride)
-            .map(|i| f(vs[i]))
-            .collect()
+fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(vs: &[T], layout: &Layout, mut f: F) -> Vec<U> {
+    match layout.contiguous_offsets() {
+        Some((o1, o2)) => vs[o1..o2].iter().map(|&v| f(v)).collect(),
+        None => layout.strided_index().map(|i| f(vs[i])).collect(),
     }
 }

 // This function maps over two strided index sequences.
 fn binary_map<T: Copy, F: FnMut(T, T) -> T>(
-    shape: &Shape,
-    lhs_stride: &[usize],
-    rhs_stride: &[usize],
+    lhs_l: &Layout,
+    rhs_l: &Layout,
     lhs: &[T],
     rhs: &[T],
     mut f: F,
 ) -> Vec<T> {
-    let dims = shape.dims();
-    if shape.is_contiguous(lhs_stride) && shape.is_contiguous(rhs_stride) {
-        (0..shape.elem_count()).map(|i| f(lhs[i], rhs[i])).collect()
-    } else {
-        let lhs_index = StridedIndex::new(dims, lhs_stride);
-        let rhs_index = StridedIndex::new(dims, rhs_stride);
-        lhs_index
-            .zip(rhs_index)
+    match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) {
+        (Some((o_l1, o_l2)), Some((o_r1, o_r2))) => lhs[o_l1..o_l2]
+            .iter()
+            .zip(rhs[o_r1..o_r2].iter())
+            .map(|(&l, &r)| f(l, r))
+            .collect(),
+        _ => lhs_l
+            .strided_index()
+            .zip(rhs_l.strided_index())
             .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
-            .collect()
+            .collect(),
     }
 }

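Both `unary_map` and `binary_map` now branch on `Layout::contiguous_offsets` instead of `Shape::is_contiguous`, so the contiguous fast path also respects the view's start offset rather than assuming the data begins at element 0 of the buffer. A small in-crate sketch (illustrative, not part of the diff) of what that call returns for an offset view:

use crate::{Layout, Shape};

// A contiguous (2, 3) view that starts at element 4 of its backing buffer
// reports offsets covering exactly its six elements.
fn offsets_example() {
    let layout = Layout::contiguous_with_offset(Shape::from(vec![2usize, 3]), 4);
    assert!(layout.is_contiguous());
    assert_eq!(layout.contiguous_offsets(), Some((4, 10)));
}
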
-fn take_impl1<T: Copy>(
-    vs: &[T],
-    ids: &[u32],
-    shape: &Shape,
-    stride: &[usize],
-    vocab_size: usize,
-    hidden_size: usize,
-) -> Result<Vec<T>> {
-    let mut values = Vec::with_capacity(shape.elem_count() * hidden_size);
-    for index in StridedIndex::new(shape.dims(), stride) {
+fn take_impl1<T: Copy>(vs: &[T], ids: &[u32], layout: &Layout, rhs_l: &Layout) -> Result<Vec<T>> {
+    // TODO: Optimize for the case where ids are contiguous.
+    let (vocab_size, hidden_size) = rhs_l.shape().r2()?;
+    let mut values = Vec::with_capacity(layout.shape().elem_count() * hidden_size);
+    for index in layout.strided_index() {
         let index = ids[index].try_into()?;
         if index >= vocab_size {
             return Err(Error::InvalidIndex {
@@ -138,17 +124,15 @@ fn copy_strided_src_<T: Copy + std::fmt::Display>(
     src: &[T],
     dst: &mut [T],
     dst_offset: usize,
-    src_shape: &Shape,
-    src_stride: &[usize],
-    src_offset: usize,
+    src_l: &Layout,
 ) {
-    let src = &src[src_offset..];
-    if src_shape.is_contiguous(src_stride) {
-        let elem_to_copy = (dst.len() - dst_offset).min(src.len());
-        dst[dst_offset..dst_offset + elem_to_copy].copy_from_slice(&src[..elem_to_copy])
-    } else {
-        let src_indexes = StridedIndex::new(src_shape.dims(), src_stride);
-        for (dst_index, src_index) in src_indexes.enumerate() {
+    match src_l.contiguous_offsets() {
+        Some((o_dst1, o_dst2)) => {
+            let elem_to_copy = (dst.len() - dst_offset).min(o_dst2 - o_dst1);
+            dst[dst_offset..dst_offset + elem_to_copy].copy_from_slice(&src[o_dst1..o_dst2])
+        }
+        None => {
+            for (dst_index, src_index) in src_l.strided_index().enumerate() {
                 let dst_index = dst_index + dst_offset;
                 if dst_index >= dst.len() {
                     break;
@@ -156,19 +140,24 @@ fn copy_strided_src_<T: Copy + std::fmt::Display>(
                 dst[dst_index] = src[src_index]
             }
+        }
     }
 }

-fn matmul_impl<T: 'static + num_traits::Num + Copy>(
+fn matmul<T: 'static + num_traits::Num + Copy>(
     lhs: &[T],
     rhs: &[T],
     (b, m, n, k): (usize, usize, usize, usize),
-    lhs_stride: &[usize],
-    rhs_stride: &[usize],
+    lhs_l: &Layout,
+    rhs_l: &Layout,
 ) -> Result<Vec<T>> {
+    let lhs = &lhs[lhs_l.start_offset()..];
+    let rhs = &rhs[rhs_l.start_offset()..];
     let a_skip: usize = m * k;
     let b_skip: usize = n * k;
     let c_skip: usize = m * n;

+    let lhs_stride = lhs_l.stride();
+    let rhs_stride = rhs_l.stride();
     let rank = lhs_stride.len();
     let lhs_cs = lhs_stride[rank - 1];
     let lhs_rs = lhs_stride[rank - 2];
@ -238,118 +227,114 @@ impl CpuStorage {
|
||||
D::cpu_storage_as_slice(self)
|
||||
}
|
||||
|
||||
pub fn as_mut_slice<D: crate::WithDType>(&mut self) -> Result<&mut [D]> {
|
||||
D::cpu_storage_as_mut_slice(self)
|
||||
}
|
||||
|
||||
pub(crate) fn to_dtype(&self, shape: &Shape, stride: &[usize], dtype: DType) -> Result<Self> {
|
||||
pub(crate) fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
|
||||
// TODO: find a way around the quadratic number of cases below.
|
||||
match (self, dtype) {
|
||||
(Self::U32(storage), DType::BF16) => {
|
||||
let data = unary_map(storage, shape, stride, |v| bf16::from_f32(v as f32));
|
||||
let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32));
|
||||
Ok(Self::BF16(data))
|
||||
}
|
||||
(Self::BF16(storage), DType::BF16) => {
|
||||
let data = unary_map(storage, shape, stride, |v| v);
|
||||
let data = unary_map(storage, layout, |v| v);
|
||||
Ok(Self::BF16(data))
|
||||
}
|
||||
(Self::F16(storage), DType::BF16) => {
|
||||
let data = unary_map(storage, shape, stride, |v| bf16::from_f32(v.to_f32()));
|
||||
let data = unary_map(storage, layout, |v| bf16::from_f32(v.to_f32()));
|
||||
Ok(Self::BF16(data))
|
||||
}
|
||||
(Self::F32(storage), DType::BF16) => {
|
||||
let data = unary_map(storage, shape, stride, bf16::from_f32);
|
||||
let data = unary_map(storage, layout, bf16::from_f32);
|
||||
Ok(Self::BF16(data))
|
||||
}
|
||||
(Self::F64(storage), DType::BF16) => {
|
||||
let data = unary_map(storage, shape, stride, bf16::from_f64);
|
||||
let data = unary_map(storage, layout, bf16::from_f64);
|
||||
Ok(Self::BF16(data))
|
||||
}
|
||||
(Self::U32(storage), DType::F16) => {
|
||||
let data = unary_map(storage, shape, stride, |v| f16::from_f32(v as f32));
|
||||
let data = unary_map(storage, layout, |v| f16::from_f32(v as f32));
|
||||
Ok(Self::F16(data))
|
||||
}
|
||||
(Self::BF16(storage), DType::F16) => {
|
||||
let data = unary_map(storage, shape, stride, |v| f16::from_f32(v.to_f32()));
|
||||
let data = unary_map(storage, layout, |v| f16::from_f32(v.to_f32()));
|
||||
Ok(Self::F16(data))
|
||||
}
|
||||
(Self::F16(storage), DType::F16) => {
|
||||
let data = unary_map(storage, shape, stride, |v| v);
|
||||
let data = unary_map(storage, layout, |v| v);
|
||||
Ok(Self::F16(data))
|
||||
}
|
||||
(Self::F32(storage), DType::F16) => {
|
||||
let data = unary_map(storage, shape, stride, f16::from_f32);
|
||||
let data = unary_map(storage, layout, f16::from_f32);
|
||||
Ok(Self::F16(data))
|
||||
}
|
||||
(Self::F64(storage), DType::F16) => {
|
||||
let data = unary_map(storage, shape, stride, f16::from_f64);
|
||||
let data = unary_map(storage, layout, f16::from_f64);
|
||||
Ok(Self::F16(data))
|
||||
}
|
||||
(Self::U32(storage), DType::F32) => {
|
||||
let data = unary_map(storage, shape, stride, |v| v as f32);
|
||||
let data = unary_map(storage, layout, |v| v as f32);
|
||||
Ok(Self::F32(data))
|
||||
}
|
||||
(Self::BF16(storage), DType::F32) => {
|
||||
let data = unary_map(storage, shape, stride, |v| v.to_f32());
|
||||
let data = unary_map(storage, layout, |v| v.to_f32());
|
||||
Ok(Self::F32(data))
|
||||
}
|
||||
(Self::F16(storage), DType::F32) => {
|
||||
let data = unary_map(storage, shape, stride, |v| v.to_f32());
|
||||
let data = unary_map(storage, layout, |v| v.to_f32());
|
||||
Ok(Self::F32(data))
|
||||
}
|
||||
(Self::F32(storage), DType::F32) => {
|
||||
let data = unary_map(storage, shape, stride, |v| v);
|
||||
let data = unary_map(storage, layout, |v| v);
|
||||
Ok(Self::F32(data))
|
||||
}
|
||||
(Self::F64(storage), DType::F32) => {
|
||||
let data = unary_map(storage, shape, stride, |v| v as f32);
|
||||
let data = unary_map(storage, layout, |v| v as f32);
|
||||
Ok(Self::F32(data))
|
||||
}
|
||||
(Self::U32(storage), DType::U32) => {
|
||||
let data = unary_map(storage, shape, stride, |v| v);
|
||||
let data = unary_map(storage, layout, |v| v);
|
||||
Ok(Self::U32(data))
|
||||
}
|
||||
(Self::BF16(storage), DType::U32) => {
|
||||
let data = unary_map(storage, shape, stride, |v| v.to_f32() as u32);
|
||||
let data = unary_map(storage, layout, |v| v.to_f32() as u32);
|
||||
Ok(Self::U32(data))
|
||||
}
|
||||
(Self::F16(storage), DType::U32) => {
|
||||
let data = unary_map(storage, shape, stride, |v| v.to_f32() as u32);
|
||||
let data = unary_map(storage, layout, |v| v.to_f32() as u32);
|
||||
Ok(Self::U32(data))
|
||||
}
|
||||
(Self::F32(storage), DType::U32) => {
|
||||
let data = unary_map(storage, shape, stride, |v| v as u32);
|
||||
let data = unary_map(storage, layout, |v| v as u32);
|
||||
Ok(Self::U32(data))
|
||||
}
|
||||
(Self::F64(storage), DType::U32) => {
|
||||
let data = unary_map(storage, shape, stride, |v| v as u32);
|
||||
let data = unary_map(storage, layout, |v| v as u32);
|
||||
Ok(Self::U32(data))
|
||||
}
|
||||
(Self::U32(storage), DType::F64) => {
|
||||
let data = unary_map(storage, shape, stride, |v| v as f64);
|
||||
let data = unary_map(storage, layout, |v| v as f64);
|
||||
Ok(Self::F64(data))
|
||||
}
|
||||
(Self::BF16(storage), DType::F64) => {
|
||||
let data = unary_map(storage, shape, stride, |v| v.to_f64());
|
||||
let data = unary_map(storage, layout, |v| v.to_f64());
|
||||
Ok(Self::F64(data))
|
||||
}
|
||||
(Self::F16(storage), DType::F64) => {
|
||||
let data = unary_map(storage, shape, stride, |v| v.to_f64());
|
||||
let data = unary_map(storage, layout, |v| v.to_f64());
|
||||
Ok(Self::F64(data))
|
||||
}
|
||||
(Self::F32(storage), DType::F64) => {
|
||||
let data = unary_map(storage, shape, stride, |v| v as f64);
|
||||
let data = unary_map(storage, layout, |v| v as f64);
|
||||
Ok(Self::F64(data))
|
||||
}
|
||||
(Self::F64(storage), DType::F64) => {
|
||||
let data = unary_map(storage, shape, stride, |v| v);
|
||||
let data = unary_map(storage, layout, |v| v);
|
||||
Ok(Self::F64(data))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn sum(&self, shape: &Shape, stride: &[usize], sum_dims: &[usize]) -> Result<Self> {
|
||||
let src_dims = shape.dims();
|
||||
pub(crate) fn sum(&self, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
|
||||
let src_dims = layout.dims();
|
||||
let mut dst_dims = src_dims.to_vec();
|
||||
for &sum_dim in sum_dims.iter() {
|
||||
dst_dims[sum_dim] = 1;
|
||||
@ -375,7 +360,7 @@ impl CpuStorage {
|
||||
dst_index
|
||||
};
|
||||
// TODO: Maybe provide an implementation with higher precision accumulators?
|
||||
map1!(self, sum_impl1, &dst_shape, src_dims, stride, to_dst_index)
|
||||
map1!(self, sum_impl1, &dst_shape, layout, to_dst_index)
|
||||
}
|
||||
|
||||
pub(crate) fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) -> Result<()> {
|
||||
@ -461,65 +446,59 @@ impl CpuStorage {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn affine_impl(
|
||||
&self,
|
||||
shape: &Shape,
|
||||
stride: &[usize],
|
||||
mul: f64,
|
||||
add: f64,
|
||||
) -> Result<Self> {
|
||||
pub(crate) fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
|
||||
match self {
|
||||
Self::U32(storage) => {
|
||||
let mul = mul as u32;
|
||||
let add = add as u32;
|
||||
let data = unary_map(storage, shape, stride, |v| v * mul + add);
|
||||
let data = unary_map(storage, layout, |v| v * mul + add);
|
||||
Ok(Self::U32(data))
|
||||
}
|
||||
Self::BF16(storage) => {
|
||||
let mul = bf16::from_f64(mul);
|
||||
let add = bf16::from_f64(add);
|
||||
let data = unary_map(storage, shape, stride, |v| v * mul + add);
|
||||
let data = unary_map(storage, layout, |v| v * mul + add);
|
||||
Ok(Self::BF16(data))
|
||||
}
|
||||
Self::F16(storage) => {
|
||||
let mul = f16::from_f64(mul);
|
||||
let add = f16::from_f64(add);
|
||||
let data = unary_map(storage, shape, stride, |v| v * mul + add);
|
||||
let data = unary_map(storage, layout, |v| v * mul + add);
|
||||
Ok(Self::F16(data))
|
||||
}
|
||||
Self::F32(storage) => {
|
||||
let mul = mul as f32;
|
||||
let add = add as f32;
|
||||
let data = unary_map(storage, shape, stride, |v| v * mul + add);
|
||||
let data = unary_map(storage, layout, |v| v * mul + add);
|
||||
Ok(Self::F32(data))
|
||||
}
|
||||
Self::F64(storage) => {
|
||||
let data = unary_map(storage, shape, stride, |v| v * mul + add);
|
||||
let data = unary_map(storage, layout, |v| v * mul + add);
|
||||
Ok(Self::F64(data))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn unary_impl<B: UnaryOp>(&self, shape: &Shape, stride: &[usize]) -> Result<Self> {
|
||||
pub(crate) fn unary_impl<B: UnaryOp>(&self, layout: &Layout) -> Result<Self> {
|
||||
match self {
|
||||
Self::BF16(storage) => {
|
||||
let data = unary_map(storage, shape, stride, B::bf16);
|
||||
let data = unary_map(storage, layout, B::bf16);
|
||||
Ok(Self::BF16(data))
|
||||
}
|
||||
Self::F16(storage) => {
|
||||
let data = unary_map(storage, shape, stride, B::f16);
|
||||
let data = unary_map(storage, layout, B::f16);
|
||||
Ok(Self::F16(data))
|
||||
}
|
||||
Self::F32(storage) => {
|
||||
let data = unary_map(storage, shape, stride, B::f32);
|
||||
let data = unary_map(storage, layout, B::f32);
|
||||
Ok(Self::F32(data))
|
||||
}
|
||||
Self::F64(storage) => {
|
||||
let data = unary_map(storage, shape, stride, B::f64);
|
||||
let data = unary_map(storage, layout, B::f64);
|
||||
Ok(Self::F64(data))
|
||||
}
|
||||
Self::U32(storage) => {
|
||||
let data = unary_map(storage, shape, stride, B::u32);
|
||||
let data = unary_map(storage, layout, B::u32);
|
||||
Ok(Self::U32(data))
|
||||
}
|
||||
}
|
||||
@ -528,29 +507,28 @@ impl CpuStorage {
|
||||
pub(crate) fn binary_impl<B: BinaryOp>(
|
||||
&self,
|
||||
rhs: &Self,
|
||||
shape: &Shape,
|
||||
lhs_stride: &[usize],
|
||||
rhs_stride: &[usize],
|
||||
lhs_l: &Layout,
|
||||
rhs_l: &Layout,
|
||||
) -> Result<Self> {
|
||||
match (self, rhs) {
|
||||
(Self::BF16(lhs), Self::BF16(rhs)) => {
|
||||
let data = binary_map(shape, lhs_stride, rhs_stride, lhs, rhs, B::bf16);
|
||||
let data = binary_map(lhs_l, rhs_l, lhs, rhs, B::bf16);
|
||||
Ok(Self::BF16(data))
|
||||
}
|
||||
(Self::F16(lhs), Self::F16(rhs)) => {
|
||||
let data = binary_map(shape, lhs_stride, rhs_stride, lhs, rhs, B::f16);
|
||||
let data = binary_map(lhs_l, rhs_l, lhs, rhs, B::f16);
|
||||
Ok(Self::F16(data))
|
||||
}
|
||||
(Self::F32(lhs), Self::F32(rhs)) => {
|
||||
let data = binary_map(shape, lhs_stride, rhs_stride, lhs, rhs, B::f32);
|
||||
let data = binary_map(lhs_l, rhs_l, lhs, rhs, B::f32);
|
||||
Ok(Self::F32(data))
|
||||
}
|
||||
(Self::F64(lhs), Self::F64(rhs)) => {
|
||||
let data = binary_map(shape, lhs_stride, rhs_stride, lhs, rhs, B::f64);
|
||||
let data = binary_map(lhs_l, rhs_l, lhs, rhs, B::f64);
|
||||
Ok(Self::F64(data))
|
||||
}
|
||||
(Self::U32(lhs), Self::U32(rhs)) => {
|
||||
let data = binary_map(shape, lhs_stride, rhs_stride, lhs, rhs, B::u32);
|
||||
let data = binary_map(lhs_l, rhs_l, lhs, rhs, B::u32);
|
||||
Ok(Self::U32(data))
|
||||
}
|
||||
_ => {
|
||||
@ -568,29 +546,14 @@ impl CpuStorage {
|
||||
&self,
|
||||
dst: &mut Self,
|
||||
dst_offset: usize,
|
||||
src_shape: &Shape,
|
||||
src_stride: &[usize],
|
||||
src_offset: usize,
|
||||
src_l: &Layout,
|
||||
) -> Result<()> {
|
||||
if src_shape.rank() != src_stride.len() {
|
||||
panic!("incoherent shape and strides {src_shape:?} {src_stride:?}")
|
||||
}
|
||||
match (self, dst) {
|
||||
(Self::U32(src), Self::U32(dst)) => {
|
||||
copy_strided_src_(src, dst, dst_offset, src_shape, src_stride, src_offset)
|
||||
}
|
||||
(Self::BF16(src), Self::BF16(dst)) => {
|
||||
copy_strided_src_(src, dst, dst_offset, src_shape, src_stride, src_offset)
|
||||
}
|
||||
(Self::F16(src), Self::F16(dst)) => {
|
||||
copy_strided_src_(src, dst, dst_offset, src_shape, src_stride, src_offset)
|
||||
}
|
||||
(Self::F32(src), Self::F32(dst)) => {
|
||||
copy_strided_src_(src, dst, dst_offset, src_shape, src_stride, src_offset)
|
||||
}
|
||||
(Self::F64(src), Self::F64(dst)) => {
|
||||
copy_strided_src_(src, dst, dst_offset, src_shape, src_stride, src_offset)
|
||||
}
|
||||
(Self::U32(src), Self::U32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
|
||||
(Self::BF16(src), Self::BF16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
|
||||
(Self::F16(src), Self::F16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
|
||||
(Self::F32(src), Self::F32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
|
||||
(Self::F64(src), Self::F64(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
|
||||
(_, dst) => {
|
||||
// This should be covered by the dtype check above.
|
||||
return Err(Error::DTypeMismatchBinaryOp {
|
||||
@ -605,34 +568,33 @@ impl CpuStorage {
|
||||
|
||||
pub(crate) fn where_cond(
|
||||
&self,
|
||||
shape: &Shape,
|
||||
stride: &[usize],
|
||||
layout: &Layout,
|
||||
t: &Self,
|
||||
stride_t: &[usize],
|
||||
layout_t: &Layout,
|
||||
f: &Self,
|
||||
stride_f: &[usize],
|
||||
layout_f: &Layout,
|
||||
) -> Result<Self> {
|
||||
// TODO: Support types that could be casted to a boolean.
|
||||
let pred = self.as_slice::<u32>()?;
|
||||
match (t, f) {
|
||||
(Self::BF16(t), Self::BF16(f)) => {
|
||||
let data = wcond(pred, shape, stride, t, stride_t, f, stride_f);
|
||||
let data = wcond(pred, layout, t, layout_t, f, layout_f);
|
||||
Ok(Self::BF16(data))
|
||||
}
|
||||
(Self::F16(t), Self::F16(f)) => {
|
||||
let data = wcond(pred, shape, stride, t, stride_t, f, stride_f);
|
||||
let data = wcond(pred, layout, t, layout_t, f, layout_f);
|
||||
Ok(Self::F16(data))
|
||||
}
|
||||
(Self::F32(t), Self::F32(f)) => {
|
||||
let data = wcond(pred, shape, stride, t, stride_t, f, stride_f);
|
||||
let data = wcond(pred, layout, t, layout_t, f, layout_f);
|
||||
Ok(Self::F32(data))
|
||||
}
|
||||
(Self::F64(t), Self::F64(f)) => {
|
||||
let data = wcond(pred, shape, stride, t, stride_t, f, stride_f);
|
||||
let data = wcond(pred, layout, t, layout_t, f, layout_f);
|
||||
Ok(Self::F64(data))
|
||||
}
|
||||
(Self::U32(t), Self::U32(f)) => {
|
||||
let data = wcond(pred, shape, stride, t, stride_t, f, stride_f);
|
||||
let data = wcond(pred, layout, t, layout_t, f, layout_f);
|
||||
Ok(Self::U32(data))
|
||||
}
|
||||
_ => Err(Error::DTypeMismatchBinaryOp {
|
||||
@ -643,36 +605,29 @@ impl CpuStorage {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn embedding_impl(
|
||||
&self,
|
||||
shape: &Shape,
|
||||
stride: &[usize],
|
||||
vs: &Self,
|
||||
hidden_size: usize,
|
||||
vocab_size: usize,
|
||||
) -> Result<Self> {
|
||||
pub(crate) fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
|
||||
let ids = self.as_slice::<u32>()?;
|
||||
map1!(vs, take_impl1, ids, shape, stride, vocab_size, hidden_size)
|
||||
map1!(rhs, take_impl1, ids, layout, rhs_l)
|
||||
}
|
||||
|
||||
pub(crate) fn matmul_impl(
|
||||
pub(crate) fn matmul(
|
||||
&self,
|
||||
rhs: &Self,
|
||||
bmnk: (usize, usize, usize, usize),
|
||||
lhs_stride: &[usize],
|
||||
rhs_stride: &[usize],
|
||||
lhs_l: &Layout,
|
||||
rhs_l: &Layout,
|
||||
) -> Result<Self> {
|
||||
match (self, rhs) {
|
||||
(CpuStorage::F16(lhs), CpuStorage::F16(rhs)) => {
|
||||
let dst = matmul_impl(lhs, rhs, bmnk, lhs_stride, rhs_stride)?;
|
||||
let dst = matmul(lhs, rhs, bmnk, lhs_l, rhs_l)?;
|
||||
Ok(Self::F16(dst))
|
||||
}
|
||||
(CpuStorage::F32(lhs), CpuStorage::F32(rhs)) => {
|
||||
let dst = matmul_impl(lhs, rhs, bmnk, lhs_stride, rhs_stride)?;
|
||||
let dst = matmul(lhs, rhs, bmnk, lhs_l, rhs_l)?;
|
||||
Ok(Self::F32(dst))
|
||||
}
|
||||
(CpuStorage::F64(lhs), CpuStorage::F64(rhs)) => {
|
||||
let dst = matmul_impl(lhs, rhs, bmnk, lhs_stride, rhs_stride)?;
|
||||
let dst = matmul(lhs, rhs, bmnk, lhs_l, rhs_l)?;
|
||||
Ok(Self::F64(dst))
|
||||
}
|
||||
_ => Err(Error::DTypeMismatchBinaryOp {
|
||||
|
candle-core/src/cuda_backend.rs
@@ -1,4 +1,4 @@
-use crate::{CpuStorage, DType, Shape};
+use crate::{CpuStorage, DType, Layout, Shape};
 use candle_kernels as kernels;
 use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
 use cudarc::driver::{CudaFunction, CudaSlice, DeviceSlice, LaunchAsync, LaunchConfig};
@@ -26,6 +26,9 @@ pub enum CudaError {
     #[error("internal error '{0}'")]
     InternalError(&'static str),

+    #[error("internal error '{0}'")]
+    WrappedError(Box<dyn std::error::Error + 'static + std::marker::Send + std::marker::Sync>),
+
     #[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")]
     MatMulNonContiguous {
         lhs_stride: Vec<usize>,
@ -242,13 +245,14 @@ enum CudaStorageSlice {
|
||||
|
||||
fn slice_src_and_dst<'a, T>(
|
||||
src: &'a CudaSlice<T>,
|
||||
src_offset: usize,
|
||||
src_l: &Layout,
|
||||
dst: &'a mut CudaSlice<T>,
|
||||
dst_offset: usize,
|
||||
) -> (
|
||||
cudarc::driver::CudaView<'a, T>,
|
||||
cudarc::driver::CudaViewMut<'a, T>,
|
||||
) {
|
||||
let src_offset = src_l.start_offset();
|
||||
let to_copy = dst
|
||||
.len()
|
||||
.saturating_sub(dst_offset)
|
||||
@ -268,12 +272,14 @@ fn gemm_config<T>(
|
||||
alpha: T,
|
||||
beta: T,
|
||||
(b, m, n, k): (usize, usize, usize, usize),
|
||||
lhs_stride: &[usize],
|
||||
rhs_stride: &[usize],
|
||||
lhs_l: &Layout,
|
||||
rhs_l: &Layout,
|
||||
) -> Result<StridedBatchedConfig<T>> {
|
||||
// https://docs.nvidia.com/cuda/cublas/index.html#cublas-t-gemm
|
||||
use cudarc::cublas::sys::cublasOperation_t;
|
||||
|
||||
let lhs_stride = lhs_l.stride();
|
||||
let rhs_stride = rhs_l.stride();
|
||||
let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
|
||||
let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
|
||||
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
|
||||
@ -352,20 +358,27 @@ impl CudaStorage {
|
||||
&self.device
|
||||
}
|
||||
|
||||
pub(crate) fn to_dtype(&self, shape: &Shape, stride: &[usize], dtype: DType) -> Result<Self> {
|
||||
pub(crate) fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
|
||||
use cudarc::driver::DevicePtr;
|
||||
let shape = layout.shape();
|
||||
let dims = shape.dims();
|
||||
let el = shape.elem_count();
|
||||
let cfg = LaunchConfig::for_num_elems(el as u32);
|
||||
let dev = self.device();
|
||||
let ds = dev.htod_copy([dims, stride].concat())?;
|
||||
let ds = dev.htod_copy([dims, layout.stride()].concat())?;
|
||||
let start_o = layout.start_offset();
|
||||
// This returns an i64 rather than a &i64, this is useful to get around some temporary
|
||||
// lifetime issue and is safe as long as self.slice does not go out of scope before inp
|
||||
// is used.
|
||||
let inp = match &self.slice {
|
||||
CudaStorageSlice::U32(inp) => inp.device_ptr(),
|
||||
CudaStorageSlice::BF16(inp) => inp.device_ptr(),
|
||||
CudaStorageSlice::F16(inp) => inp.device_ptr(),
|
||||
CudaStorageSlice::F32(inp) => inp.device_ptr(),
|
||||
CudaStorageSlice::F64(inp) => inp.device_ptr(),
|
||||
CudaStorageSlice::U32(inp) => *inp.slice(start_o..).device_ptr(),
|
||||
CudaStorageSlice::BF16(inp) => *inp.slice(start_o..).device_ptr(),
|
||||
CudaStorageSlice::F16(inp) => *inp.slice(start_o..).device_ptr(),
|
||||
CudaStorageSlice::F32(inp) => *inp.slice(start_o..).device_ptr(),
|
||||
CudaStorageSlice::F64(inp) => *inp.slice(start_o..).device_ptr(),
|
||||
};
|
||||
let inp = &inp;
|
||||
|
||||
let kernel_name = format!("cast_{}_{}", self.dtype().as_str(), dtype.as_str());
|
||||
let func = dev.get_or_load_func(&kernel_name, kernels::CAST)?;
|
||||
let slice = match dtype {
|
||||
@ -406,20 +419,16 @@ impl CudaStorage {
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn affine_impl(
|
||||
&self,
|
||||
shape: &Shape,
|
||||
stride: &[usize],
|
||||
mul: f64,
|
||||
add: f64,
|
||||
) -> Result<Self> {
|
||||
pub(crate) fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
|
||||
let shape = layout.shape();
|
||||
let dims = shape.dims();
|
||||
let el_count = shape.elem_count();
|
||||
let cfg = LaunchConfig::for_num_elems(el_count as u32);
|
||||
let dev = self.device();
|
||||
let ds = dev.htod_copy([dims, stride].concat())?;
|
||||
let ds = dev.htod_copy([dims, layout.stride()].concat())?;
|
||||
let slice = match &self.slice {
|
||||
CudaStorageSlice::U32(arg) => {
|
||||
let arg = &arg.slice(layout.start_offset()..);
|
||||
let func = dev.get_or_load_func("affine_u32", kernels::AFFINE)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<u32>(el_count) }?;
|
||||
@ -429,6 +438,7 @@ impl CudaStorage {
|
||||
CudaStorageSlice::U32(out)
|
||||
}
|
||||
CudaStorageSlice::BF16(arg) => {
|
||||
let arg = &arg.slice(layout.start_offset()..);
|
||||
let func = dev.get_or_load_func("affine_bf16", kernels::AFFINE)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<bf16>(el_count) }?;
|
||||
@ -446,6 +456,7 @@ impl CudaStorage {
|
||||
CudaStorageSlice::BF16(out)
|
||||
}
|
||||
CudaStorageSlice::F16(arg) => {
|
||||
let arg = &arg.slice(layout.start_offset()..);
|
||||
let func = dev.get_or_load_func("affine_f16", kernels::AFFINE)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<f16>(el_count) }?;
|
||||
@ -463,6 +474,7 @@ impl CudaStorage {
|
||||
CudaStorageSlice::F16(out)
|
||||
}
|
||||
CudaStorageSlice::F32(arg) => {
|
||||
let arg = &arg.slice(layout.start_offset()..);
|
||||
let func = dev.get_or_load_func("affine_f32", kernels::AFFINE)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<f32>(el_count) }?;
|
||||
@ -472,6 +484,7 @@ impl CudaStorage {
|
||||
CudaStorageSlice::F32(out)
|
||||
}
|
||||
CudaStorageSlice::F64(arg) => {
|
||||
let arg = &arg.slice(layout.start_offset()..);
|
||||
let func = dev.get_or_load_func("affine_f64", kernels::AFFINE)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<f64>(el_count) }?;
|
||||
@ -485,7 +498,8 @@ impl CudaStorage {
|
||||
Ok(Self { slice, device })
|
||||
}
|
||||
|
||||
pub(crate) fn sum(&self, shape: &Shape, stride: &[usize], sum_dims: &[usize]) -> Result<Self> {
|
||||
pub(crate) fn sum(&self, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
|
||||
let shape = layout.shape();
|
||||
let src_dims = shape.dims();
|
||||
let el = shape.elem_count();
|
||||
let mut dst_el = el;
|
||||
@ -503,9 +517,10 @@ impl CudaStorage {
|
||||
.collect();
|
||||
let cfg = LaunchConfig::for_num_elems(el as u32);
|
||||
let dev = self.device();
|
||||
let ds = dev.htod_copy([src_dims, stride, &sum_dims_l, &sum_dims_s].concat())?;
|
||||
let ds = dev.htod_copy([src_dims, layout.stride(), &sum_dims_l, &sum_dims_s].concat())?;
|
||||
let slice = match &self.slice {
|
||||
CudaStorageSlice::U32(arg) => {
|
||||
let arg = &arg.slice(layout.start_offset()..);
|
||||
let func = dev.get_or_load_func("sum_u32", kernels::REDUCE)?;
|
||||
let out = dev.alloc_zeros::<u32>(dst_el)?;
|
||||
let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out);
|
||||
@ -514,6 +529,7 @@ impl CudaStorage {
|
||||
CudaStorageSlice::U32(out)
|
||||
}
|
||||
CudaStorageSlice::BF16(arg) => {
|
||||
let arg = &arg.slice(layout.start_offset()..);
|
||||
let func = dev.get_or_load_func("sum_bf16", kernels::REDUCE)?;
|
||||
let out = dev.alloc_zeros::<bf16>(dst_el)?;
|
||||
let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out);
|
||||
@ -522,6 +538,7 @@ impl CudaStorage {
|
||||
CudaStorageSlice::BF16(out)
|
||||
}
|
||||
CudaStorageSlice::F16(arg) => {
|
||||
let arg = &arg.slice(layout.start_offset()..);
|
||||
let func = dev.get_or_load_func("sum_f16", kernels::REDUCE)?;
|
||||
let out = dev.alloc_zeros::<f16>(dst_el)?;
|
||||
let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out);
|
||||
@ -530,6 +547,7 @@ impl CudaStorage {
|
||||
CudaStorageSlice::F16(out)
|
||||
}
|
||||
CudaStorageSlice::F32(arg) => {
|
||||
let arg = &arg.slice(layout.start_offset()..);
|
||||
let func = dev.get_or_load_func("sum_f32", kernels::REDUCE)?;
|
||||
let out = dev.alloc_zeros::<f32>(dst_el)?;
|
||||
let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out);
|
||||
@ -538,6 +556,7 @@ impl CudaStorage {
|
||||
CudaStorageSlice::F32(out)
|
||||
}
|
||||
CudaStorageSlice::F64(arg) => {
|
||||
let arg = &arg.slice(layout.start_offset()..);
|
||||
let func = dev.get_or_load_func("sum_f64", kernels::REDUCE)?;
|
||||
let out = dev.alloc_zeros::<f64>(dst_el)?;
|
||||
let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out);
|
||||
@ -556,21 +575,19 @@ impl CudaStorage {
|
||||
))
|
||||
}
|
||||
|
||||
pub(crate) fn unary_impl<U: crate::op::UnaryOp>(
|
||||
&self,
|
||||
shape: &Shape,
|
||||
stride: &[usize],
|
||||
) -> Result<Self> {
|
||||
pub(crate) fn unary_impl<U: crate::op::UnaryOp>(&self, layout: &Layout) -> Result<Self> {
|
||||
let shape = layout.shape();
|
||||
let dims = shape.dims();
|
||||
let el_count = shape.elem_count();
|
||||
let cfg = LaunchConfig::for_num_elems(el_count as u32);
|
||||
let dev = &self.device;
|
||||
let ds = dev.htod_copy([dims, stride].concat())?;
|
||||
let ds = dev.htod_copy([dims, layout.stride()].concat())?;
|
||||
let slice = match &self.slice {
|
||||
CudaStorageSlice::U32(_arg) => {
|
||||
todo!("No unary kernels for u32");
|
||||
}
|
||||
CudaStorageSlice::BF16(arg) => {
|
||||
let arg = &arg.slice(layout.start_offset()..);
|
||||
let func = dev.get_or_load_func(U::KERNEL_BF16, kernels::UNARY)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<bf16>(el_count) }?;
|
||||
@ -580,6 +597,7 @@ impl CudaStorage {
|
||||
CudaStorageSlice::BF16(out)
|
||||
}
|
||||
CudaStorageSlice::F16(arg) => {
|
||||
let arg = &arg.slice(layout.start_offset()..);
|
||||
let func = dev.get_or_load_func(U::KERNEL_F16, kernels::UNARY)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<f16>(el_count) }?;
|
||||
@ -589,6 +607,7 @@ impl CudaStorage {
|
||||
CudaStorageSlice::F16(out)
|
||||
}
|
||||
CudaStorageSlice::F32(arg) => {
|
||||
let arg = &arg.slice(layout.start_offset()..);
|
||||
let func = dev.get_or_load_func(U::KERNEL_F32, kernels::UNARY)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<f32>(el_count) }?;
|
||||
@ -598,6 +617,7 @@ impl CudaStorage {
|
||||
CudaStorageSlice::F32(out)
|
||||
}
|
||||
CudaStorageSlice::F64(arg) => {
|
||||
let arg = &arg.slice(layout.start_offset()..);
|
||||
let func = dev.get_or_load_func(U::KERNEL_F64, kernels::UNARY)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<f64>(el_count) }?;
|
||||
@ -614,17 +634,19 @@ impl CudaStorage {
|
||||
pub(crate) fn binary_impl<B: crate::op::BinaryOp>(
|
||||
&self,
|
||||
rhs: &Self,
|
||||
shape: &Shape,
|
||||
lhs_stride: &[usize],
|
||||
rhs_stride: &[usize],
|
||||
lhs_l: &Layout,
|
||||
rhs_l: &Layout,
|
||||
) -> Result<Self> {
|
||||
let shape = lhs_l.shape();
|
||||
let dims = shape.dims();
|
||||
let elem_count = shape.elem_count();
|
||||
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
|
||||
let dev = self.device();
|
||||
let dims_and_strides = dev.htod_copy([dims, lhs_stride, rhs_stride].concat())?;
|
||||
let dims_and_strides = dev.htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat())?;
|
||||
let slice = match (&self.slice, &rhs.slice) {
|
||||
(CudaStorageSlice::BF16(lhs), CudaStorageSlice::BF16(rhs)) => {
|
||||
let lhs = &lhs.slice(lhs_l.start_offset()..);
|
||||
let rhs = &rhs.slice(rhs_l.start_offset()..);
|
||||
let func = dev.get_or_load_func(B::KERNEL_BF16, kernels::BINARY)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<bf16>(elem_count) }?;
|
||||
@ -634,6 +656,8 @@ impl CudaStorage {
|
||||
CudaStorageSlice::BF16(out)
|
||||
}
|
||||
(CudaStorageSlice::F16(lhs), CudaStorageSlice::F16(rhs)) => {
|
||||
let lhs = &lhs.slice(lhs_l.start_offset()..);
|
||||
let rhs = &rhs.slice(rhs_l.start_offset()..);
|
||||
let func = dev.get_or_load_func(B::KERNEL_F16, kernels::BINARY)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<f16>(elem_count) }?;
|
||||
@ -643,6 +667,8 @@ impl CudaStorage {
|
||||
CudaStorageSlice::F16(out)
|
||||
}
|
||||
(CudaStorageSlice::F32(lhs), CudaStorageSlice::F32(rhs)) => {
|
||||
let lhs = &lhs.slice(lhs_l.start_offset()..);
|
||||
let rhs = &rhs.slice(rhs_l.start_offset()..);
|
||||
let func = dev.get_or_load_func(B::KERNEL_F32, kernels::BINARY)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<f32>(elem_count) }?;
|
||||
@ -652,6 +678,8 @@ impl CudaStorage {
|
||||
CudaStorageSlice::F32(out)
|
||||
}
|
||||
(CudaStorageSlice::F64(lhs), CudaStorageSlice::F64(rhs)) => {
|
||||
let lhs = &lhs.slice(lhs_l.start_offset()..);
|
||||
let rhs = &rhs.slice(rhs_l.start_offset()..);
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let func = dev.get_or_load_func(B::KERNEL_F64, kernels::BINARY)?;
|
||||
let out = unsafe { dev.alloc::<f64>(elem_count) }?;
|
||||
@ -661,6 +689,8 @@ impl CudaStorage {
|
||||
CudaStorageSlice::F64(out)
|
||||
}
|
||||
(CudaStorageSlice::U32(lhs), CudaStorageSlice::U32(rhs)) => {
|
||||
let lhs = &lhs.slice(lhs_l.start_offset()..);
|
||||
let rhs = &rhs.slice(rhs_l.start_offset()..);
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let func = dev.get_or_load_func(B::KERNEL_U32, kernels::BINARY)?;
|
||||
let out = unsafe { dev.alloc::<u32>(elem_count) }?;
|
||||
@ -708,28 +738,32 @@ impl CudaStorage {
|
||||
|
||||
pub(crate) fn where_cond(
|
||||
&self,
|
||||
shape: &Shape,
|
||||
stride: &[usize],
|
||||
layout: &Layout,
|
||||
t: &Self,
|
||||
stride_t: &[usize],
|
||||
layout_t: &Layout,
|
||||
f: &Self,
|
||||
stride_f: &[usize],
|
||||
layout_f: &Layout,
|
||||
) -> Result<Self> {
|
||||
let ids = match &self.slice {
|
||||
CudaStorageSlice::U32(slice) => slice,
|
||||
CudaStorageSlice::U32(slice) => slice.slice(layout.start_offset()..),
|
||||
_ => Err(CudaError::UnexpectedDType {
|
||||
msg: "where conditions should be u32",
|
||||
expected: DType::U32,
|
||||
got: self.dtype(),
|
||||
})?,
|
||||
};
|
||||
let ids = &ids;
|
||||
let shape = layout.shape();
|
||||
let dims = shape.dims();
|
||||
let el = shape.elem_count();
|
||||
let cfg = LaunchConfig::for_num_elems(el as u32);
|
||||
let dev = self.device();
|
||||
let ds = dev.htod_copy([dims, stride, stride_t, stride_f].concat())?;
|
||||
let ds =
|
||||
dev.htod_copy([dims, layout.stride(), layout_t.stride(), layout_f.stride()].concat())?;
|
||||
let slice = match (&t.slice, &f.slice) {
|
||||
(CudaStorageSlice::BF16(t), CudaStorageSlice::BF16(f)) => {
|
||||
let t = &t.slice(layout_t.start_offset()..);
|
||||
let f = &f.slice(layout_f.start_offset()..);
|
||||
let func = dev.get_or_load_func("where_bf16", kernels::TERNARY)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<bf16>(el) }?;
|
||||
@ -739,6 +773,8 @@ impl CudaStorage {
|
||||
CudaStorageSlice::BF16(out)
|
||||
}
|
||||
(CudaStorageSlice::F16(t), CudaStorageSlice::F16(f)) => {
|
||||
let t = &t.slice(layout_t.start_offset()..);
|
||||
let f = &f.slice(layout_f.start_offset()..);
|
||||
let func = dev.get_or_load_func("where_f16", kernels::TERNARY)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<f16>(el) }?;
|
||||
@ -748,6 +784,8 @@ impl CudaStorage {
|
||||
CudaStorageSlice::F16(out)
|
||||
}
|
||||
(CudaStorageSlice::F32(t), CudaStorageSlice::F32(f)) => {
|
||||
let t = &t.slice(layout_t.start_offset()..);
|
||||
let f = &f.slice(layout_f.start_offset()..);
|
||||
let func = dev.get_or_load_func("where_f32", kernels::TERNARY)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<f32>(el) }?;
|
||||
@ -757,6 +795,8 @@ impl CudaStorage {
|
||||
CudaStorageSlice::F32(out)
|
||||
}
|
||||
(CudaStorageSlice::F64(t), CudaStorageSlice::F64(f)) => {
|
||||
let t = &t.slice(layout_t.start_offset()..);
|
||||
let f = &f.slice(layout_f.start_offset()..);
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let func = dev.get_or_load_func("where_f64", kernels::TERNARY)?;
|
||||
let out = unsafe { dev.alloc::<f64>(el) }?;
|
||||
@ -766,6 +806,8 @@ impl CudaStorage {
|
||||
CudaStorageSlice::F64(out)
|
||||
}
|
||||
(CudaStorageSlice::U32(t), CudaStorageSlice::U32(f)) => {
|
||||
let t = &t.slice(layout_t.start_offset()..);
|
||||
let f = &f.slice(layout_f.start_offset()..);
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let func = dev.get_or_load_func("where_u32", kernels::TERNARY)?;
|
||||
let out = unsafe { dev.alloc::<u32>(el) }?;
|
||||
@ -775,36 +817,36 @@ impl CudaStorage {
|
||||
CudaStorageSlice::U32(out)
|
||||
}
|
||||
// The dtypes should have been checked at this point so this is an internal error.
|
||||
_ => return Err(CudaError::InternalError("dtype mismatch in binary op")),
|
||||
_ => Err(CudaError::InternalError("dtype mismatch in binary op"))?,
|
||||
};
|
||||
let device = dev.clone();
|
||||
Ok(Self { slice, device })
|
||||
}
|
||||
|
||||
pub(crate) fn embedding_impl(
|
||||
&self,
|
||||
shape: &Shape,
|
||||
stride: &[usize],
|
||||
rhs: &Self,
|
||||
h_size: usize, // hidden size
|
||||
v_size: usize, // vocab size
|
||||
) -> Result<Self> {
|
||||
pub(crate) fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
|
||||
let ids = match &self.slice {
|
||||
CudaStorageSlice::U32(slice) => slice,
|
||||
CudaStorageSlice::U32(slice) => slice.slice(layout.start_offset()..),
|
||||
_ => Err(CudaError::UnexpectedDType {
|
||||
msg: "embedding ids should be u32",
|
||||
expected: DType::U32,
|
||||
got: self.dtype(),
|
||||
})?,
|
||||
};
|
||||
let ids = &ids;
|
||||
let shape = layout.shape();
|
||||
let (v_size, h_size) = rhs_l
|
||||
.shape()
|
||||
.r2()
|
||||
.map_err(|e| CudaError::WrappedError(Box::new(e)))?;
|
||||
let dims = shape.dims();
|
||||
let el = shape.elem_count();
|
||||
let cfg = LaunchConfig::for_num_elems(el as u32);
|
||||
let dev = self.device();
|
||||
let ds = dev.htod_copy([dims, stride].concat())?;
|
||||
let ds = dev.htod_copy([dims, layout.stride()].concat())?;
|
||||
let slice = match &rhs.slice {
|
||||
// The kernels below assume that rhs is contiguous.
|
||||
CudaStorageSlice::U32(arg) => {
|
||||
let arg = &arg.slice(rhs_l.start_offset()..);
|
||||
let func = dev.get_or_load_func("emb_u32", kernels::EMBEDDINGS)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<u32>(el * h_size) }?;
|
||||
@ -814,6 +856,7 @@ impl CudaStorage {
|
||||
CudaStorageSlice::U32(out)
|
||||
}
|
||||
CudaStorageSlice::BF16(arg) => {
|
||||
let arg = &arg.slice(rhs_l.start_offset()..);
|
||||
let func = dev.get_or_load_func("emb_bf16", kernels::EMBEDDINGS)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<bf16>(el * h_size) }?;
|
||||
@ -823,6 +866,7 @@ impl CudaStorage {
|
||||
CudaStorageSlice::BF16(out)
|
||||
}
|
||||
CudaStorageSlice::F16(arg) => {
|
||||
let arg = &arg.slice(rhs_l.start_offset()..);
|
||||
let func = dev.get_or_load_func("emb_f16", kernels::EMBEDDINGS)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<f16>(el * h_size) }?;
|
||||
@ -832,6 +876,7 @@ impl CudaStorage {
|
||||
CudaStorageSlice::F16(out)
|
||||
}
|
||||
CudaStorageSlice::F32(arg) => {
|
||||
let arg = &arg.slice(rhs_l.start_offset()..);
|
||||
let func = dev.get_or_load_func("emb_f32", kernels::EMBEDDINGS)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<f32>(el * h_size) }?;
|
||||
@ -841,6 +886,7 @@ impl CudaStorage {
|
||||
CudaStorageSlice::F32(out)
|
||||
}
|
||||
CudaStorageSlice::F64(arg) => {
|
||||
let arg = &arg.slice(rhs_l.start_offset()..);
|
||||
let func = dev.get_or_load_func("emb_f64", kernels::EMBEDDINGS)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<f64>(el * h_size) }?;
|
||||
@ -854,12 +900,12 @@ impl CudaStorage {
|
||||
Ok(Self { slice, device })
|
||||
}
|
||||
|
||||
pub(crate) fn matmul_impl(
|
||||
pub(crate) fn matmul(
|
||||
&self,
|
||||
rhs: &Self,
|
||||
(b, m, n, k): (usize, usize, usize, usize),
|
||||
lhs_stride: &[usize],
|
||||
rhs_stride: &[usize],
|
||||
lhs_l: &Layout,
|
||||
rhs_l: &Layout,
|
||||
) -> Result<Self> {
|
||||
let elem_count = b * m * n;
|
||||
let dev = &self.device;
|
||||
@ -868,7 +914,9 @@ impl CudaStorage {
|
||||
todo!("bf16")
|
||||
}
|
||||
(CudaStorageSlice::F16(lhs), CudaStorageSlice::F16(rhs)) => {
|
||||
let cfg = gemm_config(f16::ONE, f16::ZERO, (b, m, n, k), lhs_stride, rhs_stride)?;
|
||||
let lhs = &lhs.slice(lhs_l.start_offset()..);
|
||||
let rhs = &rhs.slice(rhs_l.start_offset()..);
|
||||
let cfg = gemm_config(f16::ONE, f16::ZERO, (b, m, n, k), lhs_l, rhs_l)?;
|
||||
let mut out = unsafe { dev.alloc::<f16>(elem_count) }?;
|
||||
unsafe {
|
||||
self.device
|
||||
@ -878,7 +926,9 @@ impl CudaStorage {
|
||||
CudaStorageSlice::F16(out)
|
||||
}
|
||||
(CudaStorageSlice::F32(lhs), CudaStorageSlice::F32(rhs)) => {
|
||||
let cfg = gemm_config(1., 0., (b, m, n, k), lhs_stride, rhs_stride)?;
|
||||
let lhs = &lhs.slice(lhs_l.start_offset()..);
|
||||
let rhs = &rhs.slice(rhs_l.start_offset()..);
|
||||
let cfg = gemm_config(1., 0., (b, m, n, k), lhs_l, rhs_l)?;
|
||||
let mut out = unsafe { dev.alloc::<f32>(elem_count) }?;
|
||||
unsafe {
|
||||
self.device
|
||||
@ -888,7 +938,9 @@ impl CudaStorage {
|
||||
CudaStorageSlice::F32(out)
|
||||
}
|
||||
(CudaStorageSlice::F64(lhs), CudaStorageSlice::F64(rhs)) => {
|
||||
let cfg = gemm_config(1., 0., (b, m, n, k), lhs_stride, rhs_stride)?;
|
||||
let lhs = &lhs.slice(lhs_l.start_offset()..);
|
||||
let rhs = &rhs.slice(rhs_l.start_offset()..);
|
||||
let cfg = gemm_config(1., 0., (b, m, n, k), lhs_l, rhs_l)?;
|
||||
let mut out = unsafe { dev.alloc::<f64>(elem_count) }?;
|
||||
unsafe {
|
||||
self.device
|
||||
@ -907,22 +959,18 @@ impl CudaStorage {
|
||||
&self,
|
||||
dst: &mut Self,
|
||||
dst_offset: usize,
|
||||
src_shape: &Shape,
|
||||
src_stride: &[usize],
|
||||
src_offset: usize,
|
||||
src_l: &Layout,
|
||||
) -> Result<()> {
|
||||
if src_shape.rank() != src_stride.len() {
|
||||
panic!("incoherent shape and strides {src_shape:?} {src_stride:?}")
|
||||
}
|
||||
let src_shape = src_l.shape();
|
||||
let dims = src_shape.dims();
|
||||
let el_count = src_shape.elem_count();
|
||||
let cfg = LaunchConfig::for_num_elems(el_count as u32);
|
||||
let dev = &self.device;
|
||||
let ds = dev.htod_copy([dims, src_stride].concat())?;
|
||||
let ds = dev.htod_copy([dims, src_l.stride()].concat())?;
|
||||
match (&self.slice, &mut dst.slice) {
|
||||
(CudaStorageSlice::BF16(src), CudaStorageSlice::BF16(dst)) => {
|
||||
let (src, mut dst) = slice_src_and_dst(src, src_offset, dst, dst_offset);
|
||||
if src_shape.is_contiguous(src_stride) {
|
||||
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
|
||||
if src_l.is_contiguous() {
|
||||
dev.dtod_copy(&src, &mut dst)?
|
||||
} else {
|
||||
let func = dev.get_or_load_func("ucopy_bf16", kernels::UNARY)?;
|
||||
@ -933,8 +981,8 @@ impl CudaStorage {
|
||||
}
|
||||
}
|
||||
(CudaStorageSlice::F16(src), CudaStorageSlice::F16(dst)) => {
|
||||
let (src, mut dst) = slice_src_and_dst(src, src_offset, dst, dst_offset);
|
||||
if src_shape.is_contiguous(src_stride) {
|
||||
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
|
||||
if src_l.is_contiguous() {
|
||||
dev.dtod_copy(&src, &mut dst)?
|
||||
} else {
|
||||
let func = dev.get_or_load_func("ucopy_f16", kernels::UNARY)?;
|
||||
@ -945,8 +993,8 @@ impl CudaStorage {
|
||||
}
|
||||
}
|
||||
(CudaStorageSlice::F32(src), CudaStorageSlice::F32(dst)) => {
|
||||
let (src, mut dst) = slice_src_and_dst(src, src_offset, dst, dst_offset);
|
||||
if src_shape.is_contiguous(src_stride) {
|
||||
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
|
||||
if src_l.is_contiguous() {
|
||||
dev.dtod_copy(&src, &mut dst)?
|
||||
} else {
|
||||
let func = dev.get_or_load_func("ucopy_f32", kernels::UNARY)?;
|
||||
@ -957,8 +1005,8 @@ impl CudaStorage {
|
||||
}
|
||||
}
|
||||
(CudaStorageSlice::U32(src), CudaStorageSlice::U32(dst)) => {
|
||||
let (src, mut dst) = slice_src_and_dst(src, src_offset, dst, dst_offset);
|
||||
if src_shape.is_contiguous(src_stride) {
|
||||
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
|
||||
if src_l.is_contiguous() {
|
||||
dev.dtod_copy(&src, &mut dst)?
|
||||
} else {
|
||||
let func = dev.get_or_load_func("ucopy_u32", kernels::UNARY)?;
|
||||
@ -969,8 +1017,8 @@ impl CudaStorage {
|
||||
}
|
||||
}
|
||||
(CudaStorageSlice::F64(src), CudaStorageSlice::F64(dst)) => {
|
||||
let (src, mut dst) = slice_src_and_dst(src, src_offset, dst, dst_offset);
|
||||
if src_shape.is_contiguous(src_stride) {
|
||||
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
|
||||
if src_l.is_contiguous() {
|
||||
dev.dtod_copy(&src, &mut dst)?
|
||||
} else {
|
||||
let func = dev.get_or_load_func("ucopy_64", kernels::UNARY)?;
|
||||
|
@ -41,7 +41,6 @@ pub trait WithDType: Sized + Copy {
|
||||
}
|
||||
|
||||
fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]>;
|
||||
fn cpu_storage_as_mut_slice(s: &mut CpuStorage) -> Result<&mut [Self]>;
|
||||
fn cpu_storage_data(s: CpuStorage) -> Result<Vec<Self>>;
|
||||
}
|
||||
|
||||
@ -75,17 +74,6 @@ macro_rules! with_dtype {
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
fn cpu_storage_as_mut_slice(s: &mut CpuStorage) -> Result<&mut [Self]> {
|
||||
match s {
|
||||
CpuStorage::$dtype(data) => Ok(data),
|
||||
_ => Err(Error::UnexpectedDType {
|
||||
expected: DType::$dtype,
|
||||
got: s.dtype(),
|
||||
msg: "unexpected dtype",
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
@ -1,5 +1,5 @@
|
||||
#![allow(dead_code)]
|
||||
use crate::{CpuStorage, DType, Error, Result, Shape};
|
||||
use crate::{CpuStorage, DType, Error, Layout, Result, Shape};
|
||||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum DummyError {}
|
||||
@ -60,11 +60,11 @@ impl CudaStorage {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
pub(crate) fn affine_impl(&self, _: &Shape, _: &[usize], _: f64, _: f64) -> Result<Self> {
|
||||
pub(crate) fn affine(&self, _: &Layout, _: f64, _: f64) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
pub(crate) fn sum(&self, _: &Shape, _: &[usize], _: &[usize]) -> Result<Self> {
|
||||
pub(crate) fn sum(&self, _: &Layout, _: &[usize]) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
@ -72,65 +72,49 @@ impl CudaStorage {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
pub(crate) fn to_dtype(&self, _: &Shape, _: &[usize], _: DType) -> Result<Self> {
|
||||
pub(crate) fn to_dtype(&self, _: &Layout, _: DType) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
pub(crate) fn unary_impl<B: crate::op::UnaryOp>(&self, _: &Shape, _: &[usize]) -> Result<Self> {
|
||||
pub(crate) fn unary_impl<B: crate::op::UnaryOp>(&self, _: &Layout) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
pub(crate) fn binary_impl<B: crate::op::BinaryOp>(
|
||||
&self,
|
||||
_: &Self,
|
||||
_: &Shape,
|
||||
_: &[usize],
|
||||
_: &[usize],
|
||||
_: &Layout,
|
||||
_: &Layout,
|
||||
) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
pub(crate) fn where_cond(
|
||||
&self,
|
||||
_: &Shape,
|
||||
_: &[usize],
|
||||
_: &Layout,
|
||||
_: &Self,
|
||||
_: &[usize],
|
||||
_: &Layout,
|
||||
_: &Self,
|
||||
_: &[usize],
|
||||
_: &Layout,
|
||||
) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
pub(crate) fn embedding_impl(
|
||||
&self,
|
||||
_: &Shape,
|
||||
_: &[usize],
|
||||
_: &Self,
|
||||
_: usize,
|
||||
_: usize,
|
||||
) -> Result<Self> {
|
||||
pub(crate) fn embedding(&self, _: &Layout, _: &Self, _: &Layout) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
pub(crate) fn matmul_impl(
|
||||
pub(crate) fn matmul(
|
||||
&self,
|
||||
_: &Self,
|
||||
_: (usize, usize, usize, usize),
|
||||
_: &[usize],
|
||||
_: &[usize],
|
||||
_: &Layout,
|
||||
_: &Layout,
|
||||
) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
pub(crate) fn copy_strided_src(
|
||||
&self,
|
||||
_: &mut Self,
|
||||
_: usize,
|
||||
_: &Shape,
|
||||
_: &[usize],
|
||||
_: usize,
|
||||
) -> Result<()> {
|
||||
pub(crate) fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
}
|
||||
|
candle-core/src/layout.rs (new file, 140 lines)
@@ -0,0 +1,140 @@
use crate::{Error, Result, Shape};

#[derive(Debug, PartialEq, Eq, Clone)]
pub struct Layout {
    shape: Shape,
    // The strides are given in number of elements and not in bytes.
    stride: Vec<usize>,
    start_offset: usize,
}

impl Layout {
    pub fn contiguous_with_offset<S: Into<Shape>>(shape: S, start_offset: usize) -> Self {
        let shape = shape.into();
        let stride = shape.stride_contiguous();
        Self {
            shape,
            stride,
            start_offset,
        }
    }

    pub fn contiguous<S: Into<Shape>>(shape: S) -> Self {
        Self::contiguous_with_offset(shape, 0)
    }

    pub fn dims(&self) -> &[usize] {
        self.shape.dims()
    }

    pub fn shape(&self) -> &Shape {
        &self.shape
    }

    pub fn stride(&self) -> &[usize] {
        &self.stride
    }

    pub fn start_offset(&self) -> usize {
        self.start_offset
    }

    /// Returns the appropriate start and stop offset if the data is stored in a C
    /// contiguous (aka row major) way.
    pub fn contiguous_offsets(&self) -> Option<(usize, usize)> {
        if self.is_contiguous() {
            let start_o = self.start_offset;
            Some((start_o, start_o + self.shape.elem_count()))
        } else {
            None
        }
    }

    /// Returns true if the data is stored in a C contiguous (aka row major) way.
    pub fn is_contiguous(&self) -> bool {
        self.shape.is_contiguous(&self.stride)
    }

    /// Returns true if the data is stored in a Fortran contiguous (aka column major) way.
    pub fn is_fortran_contiguous(&self) -> bool {
        self.shape.is_fortran_contiguous(&self.stride)
    }

    pub(crate) fn narrow(&self, dim: usize, start: usize, length: usize) -> Result<Self> {
        let dims = self.shape().dims();
        if dim >= dims.len() {
            Err(Error::UnexpectedNumberOfDims {
                expected: dim + 1,
                got: dims.len(),
                shape: self.shape().clone(),
            })?
        }
        if start + length > dims[dim] {
            todo!("add a proper error: out of bounds for narrow {dim} {start} {length} {dims:?}")
        }
        let mut dims = dims.to_vec();
        dims[dim] = length;
        Ok(Self {
            shape: Shape::from(dims),
            stride: self.stride.clone(),
            start_offset: self.start_offset + self.stride[dim] * start,
        })
    }

    pub(crate) fn transpose(&self, dim1: usize, dim2: usize) -> Result<Self> {
        let rank = self.shape.rank();
        if rank <= dim1 || rank <= dim2 {
            return Err(Error::UnexpectedNumberOfDims {
                expected: usize::max(dim1, dim2),
                got: rank,
                shape: self.shape().clone(),
            });
        }
        let mut stride = self.stride().to_vec();
        let mut dims = self.shape().dims().to_vec();
        dims.swap(dim1, dim2);
        stride.swap(dim1, dim2);
        Ok(Self {
            shape: Shape::from(dims),
            stride,
            start_offset: self.start_offset,
        })
    }

    pub fn broadcast_as<S: Into<Shape>>(&self, shape: S) -> Result<Self> {
        let shape = shape.into();
        if shape.rank() < self.shape().rank() {
            Err(Error::BroadcastIncompatibleShapes {
                src_shape: self.shape().clone(),
                dst_shape: shape.clone(),
            })?
        }
        let added_dims = shape.rank() - self.shape().rank();
        let mut stride = vec![0; added_dims];
        for (&dst_dim, (&src_dim, &src_stride)) in shape.dims()[added_dims..]
            .iter()
            .zip(self.dims().iter().zip(self.stride()))
        {
            let s = if dst_dim == src_dim {
                src_stride
            } else if src_dim != 1 {
                return Err(Error::BroadcastIncompatibleShapes {
                    src_shape: self.shape().clone(),
                    dst_shape: shape,
                });
            } else {
                0
            };
            stride.push(s)
        }
        Ok(Self {
            shape,
            stride,
            start_offset: self.start_offset,
        })
    }

    pub(crate) fn strided_index(&self) -> crate::StridedIndex {
        crate::StridedIndex::new(self)
    }
}
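The new `Layout` type keeps the strides in elements (not bytes) and composes the view operations that previously manipulated shape and stride separately. A short in-crate sketch of how the methods above interact (illustrative only; it relies on `transpose` being callable from inside the crate, where it is `pub(crate)`):

use crate::{Layout, Shape};

fn layout_walkthrough() -> crate::Result<()> {
    // A freshly created layout is row-major contiguous with offset 0.
    let base = Layout::contiguous(Shape::from(vec![2usize, 3]));
    assert_eq!(base.contiguous_offsets(), Some((0, 6)));

    // Transposing swaps dims and strides, so the view is no longer contiguous
    // and callers must fall back to strided iteration.
    let t = base.transpose(0, 1)?;
    assert_eq!(t.dims(), &[3, 2]);
    assert_eq!(t.contiguous_offsets(), None);

    // Broadcasting inserts a zero stride for the expanded leading dimension.
    let b = base.broadcast_as(Shape::from(vec![4usize, 2, 3]))?;
    assert_eq!(b.stride(), &[0, 3, 1]);
    Ok(())
}
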
candle-core/src/lib.rs
@@ -7,6 +7,7 @@ pub mod display;
 mod dtype;
 mod dummy_cuda_backend;
 mod error;
+mod layout;
 mod npy;
 mod op;
 mod shape;
@@ -19,6 +20,7 @@ pub use cpu_backend::CpuStorage;
 pub use device::{Device, DeviceLocation};
 pub use dtype::{DType, WithDType};
 pub use error::{Error, Result};
+pub use layout::Layout;
 pub use shape::Shape;
 pub use storage::Storage;
 use strided_index::StridedIndex;
@ -1,4 +1,4 @@
|
||||
use crate::{op, CpuStorage, CudaStorage, DType, Device, Error, Result, Shape};
|
||||
use crate::{op, CpuStorage, CudaStorage, DType, Device, Error, Layout, Result, Shape};
|
||||
|
||||
// We do not want to implement Clone on Storage as cloning may fail because of
|
||||
// out of memory. Instead try_clone should be used.
|
||||
@ -53,38 +53,33 @@ impl Storage {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn affine_impl(
|
||||
&self,
|
||||
shape: &Shape,
|
||||
stride: &[usize],
|
||||
mul: f64,
|
||||
add: f64,
|
||||
) -> Result<Self> {
|
||||
pub(crate) fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
|
||||
match self {
|
||||
Storage::Cpu(storage) => {
|
||||
let storage = storage.affine_impl(shape, stride, mul, add)?;
|
||||
let storage = storage.affine(layout, mul, add)?;
|
||||
Ok(Self::Cpu(storage))
|
||||
}
|
||||
Self::Cuda(storage) => {
|
||||
let storage = storage.affine_impl(shape, stride, mul, add)?;
|
||||
let storage = storage.affine(layout, mul, add)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn sum(&self, shape: &Shape, stride: &[usize], s: &[usize]) -> Result<Self> {
|
||||
pub(crate) fn sum(&self, layout: &Layout, s: &[usize]) -> Result<Self> {
|
||||
match self {
|
||||
Storage::Cpu(storage) => {
|
||||
let storage = storage.sum(shape, stride, s)?;
|
||||
let storage = storage.sum(layout, s)?;
|
||||
Ok(Self::Cpu(storage))
|
||||
}
|
||||
Self::Cuda(storage) => {
|
||||
let storage = storage.sum(shape, stride, s)?;
|
||||
let storage = storage.sum(layout, s)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// This assumes a contiguous layout and no offset.
|
||||
pub(crate) fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) -> Result<()> {
|
||||
match self {
|
||||
Storage::Cpu(storage) => storage.divide_by_sum_over_dim(shape, dim)?,
|
||||
@ -93,32 +88,28 @@ impl Storage {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn to_dtype(&self, shape: &Shape, stride: &[usize], dtype: DType) -> Result<Self> {
|
||||
pub(crate) fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
|
||||
match self {
|
||||
Storage::Cpu(storage) => {
|
||||
let storage = storage.to_dtype(shape, stride, dtype)?;
|
||||
let storage = storage.to_dtype(layout, dtype)?;
|
||||
Ok(Self::Cpu(storage))
|
||||
}
|
||||
Self::Cuda(storage) => {
|
||||
let storage = storage.to_dtype(shape, stride, dtype)?;
|
||||
let storage = storage.to_dtype(layout, dtype)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn unary_impl<B: op::UnaryOp>(
|
||||
&self,
|
||||
shape: &Shape,
|
||||
stride: &[usize],
|
||||
) -> Result<Self> {
|
||||
pub(crate) fn unary_impl<B: op::UnaryOp>(&self, layout: &Layout) -> Result<Self> {
|
||||
// TODO: Different code path for the contiguous case?
|
||||
match self {
|
||||
Storage::Cpu(storage) => {
|
||||
let storage = storage.unary_impl::<B>(shape, stride)?;
|
||||
let storage = storage.unary_impl::<B>(layout)?;
|
||||
Ok(Self::Cpu(storage))
|
||||
}
|
||||
Self::Cuda(storage) => {
|
||||
let storage = storage.unary_impl::<B>(shape, stride)?;
|
||||
let storage = storage.unary_impl::<B>(layout)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
}
|
||||
@ -127,19 +118,18 @@ impl Storage {
|
||||
pub(crate) fn binary_impl<B: op::BinaryOp>(
|
||||
&self,
|
||||
rhs: &Self,
|
||||
shape: &Shape,
|
||||
lhs_stride: &[usize],
|
||||
rhs_stride: &[usize],
|
||||
lhs_layout: &Layout,
|
||||
rhs_layout: &Layout,
|
||||
) -> Result<Self> {
|
||||
self.same_device(rhs, B::NAME)?;
|
||||
self.same_dtype(rhs, B::NAME)?;
|
||||
match (self, rhs) {
|
||||
(Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
|
||||
let storage = lhs.binary_impl::<B>(rhs, shape, lhs_stride, rhs_stride)?;
|
||||
let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?;
|
||||
Ok(Self::Cpu(storage))
|
||||
}
|
||||
(Self::Cuda(lhs), Self::Cuda(rhs)) => {
|
||||
let storage = lhs.binary_impl::<B>(rhs, shape, lhs_stride, rhs_stride)?;
|
||||
let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
(lhs, rhs) => {
|
||||
@ -156,49 +146,41 @@ impl Storage {
|
||||
|
||||
pub(crate) fn where_cond(
|
||||
&self,
|
||||
shape: &Shape,
|
||||
stride: &[usize],
|
||||
layout: &Layout,
|
||||
t: &Self,
|
||||
stride_t: &[usize],
|
||||
layout_t: &Layout,
|
||||
f: &Self,
|
||||
stride_f: &[usize],
|
||||
layout_f: &Layout,
|
||||
) -> Result<Self> {
|
||||
self.same_device(t, "where")?;
|
||||
self.same_device(f, "where")?;
|
||||
t.same_dtype(f, "where")?;
|
||||
match (self, t, f) {
|
||||
(Storage::Cpu(cond), Storage::Cpu(t), Storage::Cpu(f)) => {
|
||||
let storage = cond.where_cond(shape, stride, t, stride_t, f, stride_f)?;
|
||||
let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?;
|
||||
Ok(Self::Cpu(storage))
|
||||
}
|
||||
(Self::Cuda(cond), Self::Cuda(t), Self::Cuda(f)) => {
|
||||
let storage = cond.where_cond(shape, stride, t, stride_t, f, stride_f)?;
|
||||
let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
(_, lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
|
||||
lhs: lhs.device().location(),
|
||||
rhs: rhs.device().location(),
|
||||
op: "embedding",
|
||||
op: "where",
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn embedding_impl(
|
||||
&self,
|
||||
shape: &Shape,
|
||||
stride: &[usize],
|
||||
rhs: &Self,
|
||||
hidden_size: usize,
|
||||
vocab_size: usize,
|
||||
) -> Result<Self> {
|
||||
pub(crate) fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
|
||||
self.same_device(rhs, "embedding")?;
|
||||
match (self, rhs) {
|
||||
(Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
|
||||
let storage = lhs.embedding_impl(shape, stride, rhs, hidden_size, vocab_size)?;
|
||||
let storage = lhs.embedding(layout, rhs, rhs_l)?;
|
||||
Ok(Self::Cpu(storage))
|
||||
}
|
||||
(Self::Cuda(lhs), Self::Cuda(rhs)) => {
|
||||
let storage = lhs.embedding_impl(shape, stride, rhs, hidden_size, vocab_size)?;
|
||||
let storage = lhs.embedding(layout, rhs, rhs_l)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
|
||||
@ -209,22 +191,22 @@ impl Storage {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn matmul_impl(
|
||||
pub(crate) fn matmul(
|
||||
&self,
|
||||
rhs: &Self,
|
||||
bmnk: (usize, usize, usize, usize),
|
||||
lhs_stride: &[usize],
|
||||
rhs_stride: &[usize],
|
||||
lhs_layout: &Layout,
|
||||
rhs_layout: &Layout,
|
||||
) -> Result<Self> {
|
||||
self.same_device(rhs, "matmul")?;
|
||||
self.same_dtype(rhs, "matmul")?;
|
||||
match (self, rhs) {
|
||||
(Self::Cpu(lhs), Self::Cpu(rhs)) => {
|
||||
let storage = lhs.matmul_impl(rhs, bmnk, lhs_stride, rhs_stride)?;
|
||||
let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?;
|
||||
Ok(Self::Cpu(storage))
|
||||
}
|
||||
(Self::Cuda(lhs), Self::Cuda(rhs)) => {
|
||||
let storage = lhs.matmul_impl(rhs, bmnk, lhs_stride, rhs_stride)?;
|
||||
let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
|
||||
@ -240,17 +222,11 @@ impl Storage {
|
||||
&self,
|
||||
dst: &mut Self,
|
||||
dst_offset: usize,
|
||||
src_shape: &Shape,
|
||||
src_stride: &[usize],
|
||||
src_offset: usize,
|
||||
src_l: &Layout,
|
||||
) -> Result<()> {
|
||||
match (self, dst) {
|
||||
(Self::Cpu(src), Self::Cpu(dst)) => {
|
||||
src.copy_strided_src(dst, dst_offset, src_shape, src_stride, src_offset)
|
||||
}
|
||||
(Self::Cuda(src), Self::Cuda(dst)) => {
|
||||
Ok(src.copy_strided_src(dst, dst_offset, src_shape, src_stride, src_offset)?)
|
||||
}
|
||||
(Self::Cpu(src), Self::Cpu(dst)) => src.copy_strided_src(dst, dst_offset, src_l),
|
||||
(Self::Cuda(src), Self::Cuda(dst)) => Ok(src.copy_strided_src(dst, dst_offset, src_l)?),
|
||||
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
|
||||
lhs: lhs.device().location(),
|
||||
rhs: rhs.device().location(),
|
||||
|
@ -1,27 +1,28 @@
|
||||
use crate::Layout;
|
||||
|
||||
/// An iterator over offset position for items of an N-dimensional arrays stored in a
|
||||
/// flat buffer using some potential strides.
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct StridedIndex<'a> {
|
||||
next_storage_index: Option<usize>,
|
||||
multi_index: Vec<usize>,
|
||||
dims: &'a [usize],
|
||||
stride: &'a [usize],
|
||||
layout: &'a Layout,
|
||||
}
|
||||
|
||||
impl<'a> StridedIndex<'a> {
|
||||
pub(crate) fn new(dims: &'a [usize], stride: &'a [usize]) -> Self {
|
||||
pub(crate) fn new(layout: &'a Layout) -> Self {
|
||||
let dims = layout.dims();
|
||||
let elem_count: usize = dims.iter().product();
|
||||
let next_storage_index = if elem_count == 0 {
|
||||
None
|
||||
} else {
|
||||
// This applies to the scalar case.
|
||||
Some(0)
|
||||
Some(layout.start_offset())
|
||||
};
|
||||
StridedIndex {
|
||||
next_storage_index,
|
||||
multi_index: vec![0; dims.len()],
|
||||
dims,
|
||||
stride,
|
||||
layout,
|
||||
}
|
||||
}
|
||||
}
|

@ -35,7 +36,12 @@ impl<'a> Iterator for StridedIndex<'a> {
Some(storage_index) => storage_index,
};
let mut updated = false;
for (multi_i, max_i) in self.multi_index.iter_mut().zip(self.dims.iter()).rev() {
for (multi_i, max_i) in self
.multi_index
.iter_mut()
.zip(self.layout.dims().iter())
.rev()
{
let next_i = *multi_i + 1;
if next_i < *max_i {
*multi_i = next_i;

@ -49,9 +55,10 @@ impl<'a> Iterator for StridedIndex<'a> {
let next_storage_index = self
.multi_index
.iter()
.zip(self.stride.iter())
.zip(self.layout.stride().iter())
.map(|(&x, &y)| x * y)
.sum();
.sum::<usize>()
+ self.layout.start_offset();
Some(next_storage_index)
} else {
None

@ -1,4 +1,4 @@
use crate::{op::Op, storage::Storage, DType, Device, Error, Result, Shape};
use crate::{op::Op, storage::Storage, DType, Device, Error, Layout, Result, Shape};
use std::sync::Arc;

/// Unique identifier for tensors.

@ -17,9 +17,7 @@ impl TensorId {
pub struct Tensor_ {
id: TensorId,
storage: Arc<Storage>,
shape: Shape,
// The strides are given in number of elements and not in bytes.
stride: Vec<usize>,
layout: Layout,
op: Option<Op>,
is_variable: bool,
}

@ -50,7 +48,7 @@ macro_rules! unary_op {
let shape = self.shape();
let storage = self
.storage
.unary_impl::<crate::op::$op_name>(self.shape(), self.stride())?;
.unary_impl::<crate::op::$op_name>(self.layout())?;
let op = if self.track_op() {
Some(Op::$op_name(self.clone()))
} else {

@ -67,9 +65,8 @@ macro_rules! binary_op {
let shape = self.same_shape_binary_op(rhs, stringify!($fn_name))?;
let storage = self.storage.binary_impl::<crate::op::$op_name>(
&rhs.storage,
shape,
self.stride(),
rhs.stride(),
self.layout(),
rhs.layout(),
)?;
let op = if self.track_op() || rhs.track_op() {
Some(Op::$op_name(self.clone(), rhs.clone()))

@ -107,13 +104,10 @@ fn from_storage<S: Into<Shape>>(
op: Option<Op>,
is_variable: bool,
) -> Tensor {
let shape = shape.into();
let stride = shape.stride_contiguous();
let tensor_ = Tensor_ {
id: TensorId::new(),
storage: Arc::new(storage),
shape,
stride,
layout: Layout::contiguous(shape),
op,
is_variable,
};

@ -323,6 +317,7 @@ impl Tensor {
unary_op!(sqrt, Sqrt);
unary_op!(gelu, Gelu);
unary_op!(relu, Relu);

pub fn to_scalar<S: crate::WithDType>(&self) -> Result<S> {
if self.rank() != 0 {
return Err(Error::UnexpectedNumberOfDims {

@ -342,8 +337,7 @@ impl Tensor {
}

pub fn affine(&self, mul: f64, add: f64) -> Result<Self> {
let shape = self.shape();
let storage = self.storage.affine_impl(shape, self.stride(), mul, add)?;
let storage = self.storage.affine(self.layout(), mul, add)?;
let op = if self.track_op() {
Some(Op::Affine {
arg: self.clone(),

@ -353,42 +347,25 @@ impl Tensor {
} else {
None
};
Ok(from_storage(storage, shape.clone(), op, false))
Ok(from_storage(storage, self.shape(), op, false))
}

/// Returns a new tensor that is a narrowed version of the input, the dimension `dim`
/// ranges from `start` to `start + length`.
// TODO: Once we've refactored the shape and strides, make this return a view of the same data
// rather than copying.
pub fn narrow(&self, dim: usize, start: usize, length: usize) -> Result<Self> {
let dims = self.shape().dims();
if dim >= dims.len() {
return Err(Error::UnexpectedNumberOfDims {
expected: dim + 1,
got: dims.len(),
shape: self.shape().clone(),
});
}
if start + length > dims[dim] {
todo!("add a proper error: out of bounds for narrow {dim} {start} {length} {dims:?}")
}
let mut dims = dims.to_vec();
dims[dim] = length;
let adjusted_shape = Shape::from(dims);
let mut storage = self.device().zeros(&adjusted_shape, self.dtype())?;
self.storage.copy_strided_src(
&mut storage,
/* dst_offset= */ 0,
&adjusted_shape,
&self.stride,
/* src_offest= */ self.stride[dim] * start,
)?;
let op = if self.track_op() {
Some(Op::Narrow(self.clone(), dim, start, length))
} else {
None
};
Ok(from_storage(storage, adjusted_shape, op, false))
let tensor_ = Tensor_ {
id: TensorId::new(),
storage: self.storage.clone(),
layout: self.layout().narrow(dim, start, length)?,
op,
is_variable: false,
};
Ok(Tensor(Arc::new(tensor_)))
}
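
// Note on the narrow above: the returned tensor clones the reference-counted
// storage and only adjusts the layout (the shape shrinks along `dim` and the
// start offset moves by `stride[dim] * start`), so no data is copied.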

pub fn softmax(&self, dim: usize) -> Result<Self> {

@ -401,9 +378,7 @@ impl Tensor {
exp.broadcast_div(&sum_exp)
} else {
let shape = self.shape();
let mut storage = self
.storage
.unary_impl::<crate::op::Exp>(shape, self.stride())?;
let mut storage = self.storage.unary_impl::<crate::op::Exp>(self.layout())?;
// The resulting storage is contiguous.
storage.divide_by_sum_over_dim(shape, dim)?;
let op = if self.track_op() {

@ -416,7 +391,7 @@ impl Tensor {
}

pub fn sum(&self, sum_dims: &[usize]) -> Result<Self> {
let storage = self.storage.sum(self.shape(), &self.stride, sum_dims)?;
let storage = self.storage.sum(self.layout(), sum_dims)?;
let op = if self.track_op() {
Some(Op::Sum(self.clone(), sum_dims.to_vec()))
} else {

@ -458,11 +433,11 @@ impl Tensor {
let c_shape = Shape::from(&a_dims[..dim - 2]).extend(&[m, n]);
let batching: usize = a_dims[..dim - 2].iter().product();

let storage = self.storage.matmul_impl(
let storage = self.storage.matmul(
&rhs.storage,
(batching, m, n, k),
self.stride(),
rhs.stride(),
self.layout(),
rhs.layout(),
)?;
let op = if self.track_op() || rhs.track_op() {
Some(Op::Matmul(self.clone(), rhs.clone()))

@ -476,12 +451,11 @@ impl Tensor {
let _shap = self.same_shape_binary_op(on_true, "where_cond")?;
let shape = self.same_shape_binary_op(on_false, "where_cond")?;
let storage = self.storage.where_cond(
shape,
self.stride(),
self.layout(),
&on_true.storage,
on_true.stride(),
on_true.layout(),
&on_false.storage,
on_false.stride(),
on_false.layout(),
)?;
let op = if self.track_op() || on_true.track_op() || on_false.track_op() {
Some(Op::WhereCond(

@ -498,23 +472,19 @@ impl Tensor {
pub fn embedding(ids: &Self, rhs: &Self) -> Result<Self> {
if !rhs.is_contiguous() {
return Err(Error::RequiresContiguous { op: "embedding" });
} else if rhs.shape().rank() != 2 || ids.shape().rank() != 1 {
} else if rhs.rank() != 2 || ids.rank() != 1 {
return Err(Error::ShapeMismatchBinaryOp {
lhs: ids.shape.clone(),
rhs: rhs.shape.clone(),
lhs: ids.shape().clone(),
rhs: rhs.shape().clone(),
op: "embedding",
});
}
let ids_shape = ids.shape();
let seq_len = ids_shape.r1()?;
let (vocab_size, hidden_size) = rhs.shape().r2()?;
let storage = ids.storage.embedding_impl(
ids_shape,
&ids.stride,
&rhs.storage,
hidden_size,
vocab_size,
)?;
let (_, hidden_size) = rhs.shape().r2()?;
let storage = ids
.storage
.embedding(ids.layout(), &rhs.storage, rhs.layout())?;
let shape: Shape = (seq_len, hidden_size).into();
let op = if ids.track_op() || rhs.track_op() {
Some(Op::Embedding(ids.clone(), rhs.clone()))

@ -525,7 +495,7 @@ impl Tensor {
}

pub(crate) fn strided_index(&self) -> crate::StridedIndex {
crate::StridedIndex::new(self.dims(), self.stride())
self.layout.strided_index()
}

/// Returns data from the underlying storage, this does not take the strides

@ -618,15 +588,20 @@ impl Tensor {
}

pub fn shape(&self) -> &Shape {
&self.shape
self.layout().shape()
}

pub fn dims(&self) -> &[usize] {
self.shape().dims()
}

pub fn stride(&self) -> &[usize] {
&self.stride
pub fn layout(&self) -> &Layout {
&self.layout
}

// TODO: Rename to `stride` once the PR that introduced the layout has been merged.
pub fn stride_tmp(&self) -> &[usize] {
self.layout.stride()
}

pub fn rank(&self) -> usize {

@ -704,18 +679,6 @@ impl Tensor {
/// Returns a tensor that is a transposed version of the input, the given dimensions are
/// swapped.
pub fn transpose(&self, dim1: usize, dim2: usize) -> Result<Tensor> {
let rank = self.rank();
if rank <= dim1 || rank <= dim2 {
return Err(Error::UnexpectedNumberOfDims {
expected: usize::max(dim1, dim2),
got: rank,
shape: self.shape().clone(),
});
}
let mut stride = self.stride().to_vec();
let mut dims = self.shape().dims().to_vec();
dims.swap(dim1, dim2);
stride.swap(dim1, dim2);
let op = if self.track_op() {
Some(Op::Transpose(self.clone(), dim1, dim2))
} else {

@ -724,8 +687,7 @@ impl Tensor {
let tensor_ = Tensor_ {
id: TensorId::new(),
storage: self.storage.clone(),
shape: Shape::from(dims),
stride,
layout: self.layout.transpose(dim1, dim2)?,
op,
is_variable: false,
};

@ -734,12 +696,12 @@ impl Tensor {

/// Returns true if the data is stored in a C contiguous (aka row major) way.
pub fn is_contiguous(&self) -> bool {
self.shape.is_contiguous(&self.stride)
self.layout.is_contiguous()
}

/// Returns true if the data is stored in a Fortran contiguous (aka column major) way.
pub fn is_fortran_contiguous(&self) -> bool {
self.shape.is_fortran_contiguous(&self.stride)
self.layout.is_fortran_contiguous()
}

/// Compared to clone, this copies the actual storage but may fail because of running out of

@ -748,8 +710,7 @@ impl Tensor {
let tensor_ = Tensor_ {
id: TensorId::new(),
storage: Arc::new(self.storage.try_clone()?),
shape: self.shape.clone(),
stride: self.stride.clone(),
layout: self.layout.clone(),
op: None, // TODO
is_variable: false,
};

@ -762,8 +723,7 @@ impl Tensor {
let tensor_ = Tensor_ {
id: TensorId::new(),
storage: self.storage.clone(),
shape: self.shape.clone(),
stride: self.stride.clone(),
layout: self.layout.clone(),
op: None,
is_variable: false,
};

@ -796,8 +756,7 @@ impl Tensor {
let tensor_ = Tensor_ {
id: TensorId::new(),
storage: Arc::new(storage),
shape: self.shape.clone(),
stride: self.stride.clone(),
layout: self.layout.clone(),
op,
is_variable: false,
};

@ -810,7 +769,7 @@ impl Tensor {
pub fn broadcast_left<S: Into<Shape>>(&self, left_shape: S) -> Result<Self> {
let left_shape = left_shape.into();
let mut dims = left_shape.into_dims();
dims.extend(self.shape.dims());
dims.extend(self.dims());
self.broadcast_as(dims)
}
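
// Worked example for broadcast_left above: with a left shape of (2,) and a
// tensor of shape (3, 4), the target dims become (2, 3, 4) and the call
// defers to broadcast_as, so the added leading dimension ends up with a
// 0 stride in the resulting layout.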

@ -820,36 +779,10 @@ impl Tensor {
} else {
None
};
let shape = shape.into();
if shape.rank() < self.rank() {
return Err(Error::BroadcastIncompatibleShapes {
src_shape: self.shape().clone(),
dst_shape: shape,
});
}
let added_dims = shape.rank() - self.rank();
let mut stride = vec![0; added_dims];
for (&dst_dim, (&src_dim, &src_stride)) in shape.dims()[added_dims..]
.iter()
.zip(self.dims().iter().zip(self.stride()))
{
let s = if dst_dim == src_dim {
src_stride
} else if src_dim != 1 {
return Err(Error::BroadcastIncompatibleShapes {
src_shape: self.shape().clone(),
dst_shape: shape,
});
} else {
0
};
stride.push(s)
}
let tensor_ = Tensor_ {
id: TensorId::new(),
storage: self.storage.clone(),
shape,
stride,
layout: self.layout.broadcast_as(shape)?,
op,
is_variable: false,
};

@ -866,7 +799,7 @@ impl Tensor {
Ok(self.clone())
} else {
let shape = self.shape();
let storage = self.storage.to_dtype(shape, self.stride(), dtype)?;
let storage = self.storage.to_dtype(self.layout(), dtype)?;
let op = if self.track_op() {
Some(Op::ToDType(self.clone()))
} else {

@ -883,7 +816,7 @@ impl Tensor {
let shape = self.shape();
let mut storage = self.device().zeros(shape, self.dtype())?;
self.storage
.copy_strided_src(&mut storage, 0, &self.shape, &self.stride, 0)?;
.copy_strided_src(&mut storage, 0, self.layout())?;
Ok(from_storage(
storage,
shape.clone(),

@ -913,12 +846,10 @@ impl Tensor {
None
};
if self.is_contiguous() {
let stride = shape.stride_contiguous();
let tensor_ = Tensor_ {
id: TensorId::new(),
storage: self.storage.clone(),
shape,
stride,
layout: Layout::contiguous_with_offset(shape, self.layout.start_offset()),
op,
is_variable: false,
};

@ -926,7 +857,7 @@ impl Tensor {
} else {
let mut storage = self.device().zeros(&shape, self.dtype())?;
self.storage
.copy_strided_src(&mut storage, 0, &self.shape, &self.stride, 0)?;
.copy_strided_src(&mut storage, 0, self.layout())?;
Ok(from_storage(storage, shape, op, false))
}
}

@ -1063,7 +994,7 @@ impl Tensor {
for (arg, &offset) in args.iter().zip(offsets.iter()) {
let arg = arg.as_ref();
arg.storage
.copy_strided_src(&mut storage, offset, &arg.shape, &arg.stride, 0)?;
.copy_strided_src(&mut storage, offset, arg.layout())?;
}
Ok(from_storage(storage, shape, op, false))
}

@ -263,12 +263,12 @@ fn matmul(device: &Device) -> Result<()> {
let a_tt = a.t()?.contiguous()?.t()?;
assert!(!a_tt.is_contiguous());
assert_eq!(a.dims(), a_tt.dims());
assert_eq!(a_tt.stride(), &[6, 1, 2]);
assert_eq!(a_tt.stride_tmp(), &[6, 1, 2]);

let b_tt = b.t()?.contiguous()?.t()?;
assert!(!b_tt.is_contiguous());
assert_eq!(b.dims(), b_tt.dims());
assert_eq!(b_tt.stride(), &[6, 1, 3]);
assert_eq!(b_tt.stride_tmp(), &[6, 1, 3]);

assert_eq!(a_tt.matmul(&b)?.to_vec3::<f32>()?, &expected);
assert_eq!(a.matmul(&b_tt)?.to_vec3::<f32>()?, &expected);
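
// The stride assertions above follow from the transpose implementation: each
// t()? swaps the chosen dims of both the shape and the stride, so after
// contiguous()? rebuilds a row-major buffer, the final t()? only permutes the
// strides, which is why a_tt and b_tt report permuted, non-contiguous strides
// such as [6, 1, 2] while the data itself is unchanged.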