Mirror of https://github.com/huggingface/candle.git (synced 2025-06-18 11:37:11 +00:00)
Merge pull request #29 from LaurentMazare/cpu-map
Switch from a macro to a trait to make things more generic.
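In outline: before this change, every dtype-generic CPU op re-expanded a `map1!` macro to match on the dtype at each call site; after it, the dtype `match` lives once inside a `map` method on the `Map1`/`Map2` traits, and each op is a small struct that only implements a generic `f`. A minimal, self-contained sketch of the two dispatch shapes, using a two-variant stand-in for `CpuStorage` and a hypothetical `Neg` op (not the crate's real types):

    // Stand-in storage enum; candle's CpuStorage has more dtype variants.
    enum Storage {
        F32(Vec<f32>),
        F64(Vec<f64>),
    }

    // Old shape: a macro re-expands the dtype match at every call site.
    macro_rules! map1 {
        ($v:expr, $fn:ident) => {{
            match $v {
                Storage::F32(s) => Storage::F32($fn(s)),
                Storage::F64(s) => Storage::F64($fn(s)),
            }
        }};
    }

    fn neg_impl<T: Copy + std::ops::Neg<Output = T>>(vs: &[T]) -> Vec<T> {
        vs.iter().map(|&v| -v).collect()
    }

    // New shape: the match lives once in the trait; an op only writes `f`.
    trait Map1 {
        fn f<T: Copy + std::ops::Neg<Output = T>>(&self, vs: &[T]) -> Vec<T>;
        fn map(&self, vs: &Storage) -> Storage {
            match vs {
                Storage::F32(s) => Storage::F32(self.f(s)),
                Storage::F64(s) => Storage::F64(self.f(s)),
            }
        }
    }

    struct Neg;
    impl Map1 for Neg {
        fn f<T: Copy + std::ops::Neg<Output = T>>(&self, vs: &[T]) -> Vec<T> {
            vs.iter().map(|&v| -v).collect()
        }
    }

    fn main() {
        let a = Storage::F32(vec![1.0, -2.0]);
        let _old = map1!(&a, neg_impl); // macro dispatch, expanded here
        let _new = Neg.map(&a); // trait dispatch, written once in Map1
    }

The payoff, visible throughout the diff below, is that the per-op code (`Affine`, `Sum`, `Embedding`, `WCond`, `MatMul`) no longer mentions dtypes at all.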
@@ -1,10 +1,8 @@
 use crate::op::{BinaryOp, UnaryOp};
-use crate::{DType, Error, Layout, Result, Shape};
+use crate::{DType, Error, Layout, Result, Shape, WithDType};
 use gemm::{gemm, Parallelism};
 use half::{bf16, f16};
 
-// TODO: Think about whether we would be better off with a dtype and
-// a buffer as an owned slice of bytes.
 // TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator +
 // intercept the oom errors to avoid panicking and provide a proper error.
 #[derive(Debug, Clone)]
@@ -16,21 +14,69 @@ pub enum CpuStorage {
     F64(Vec<f64>),
 }
 
-fn wcond<T: Copy>(
-    pred: &[u32],
-    layout: &Layout,
-    t: &[T],
-    layout_t: &Layout,
-    f: &[T],
-    layout_f: &Layout,
-) -> Vec<T> {
-    match (
-        layout.contiguous_offsets(),
-        layout_t.contiguous_offsets(),
-        layout_f.contiguous_offsets(),
+trait Map1 {
+    fn f<T: WithDType + Copy + num_traits::NumAssign>(
+        &self,
+        vs: &[T],
+        layout: &Layout,
+    ) -> Result<Vec<T>>;
+
+    fn map(&self, vs: &CpuStorage, layout: &Layout) -> Result<CpuStorage> {
+        match vs {
+            CpuStorage::U32(vs) => Ok(CpuStorage::U32(self.f(vs, layout)?)),
+            CpuStorage::BF16(vs) => Ok(CpuStorage::BF16(self.f(vs, layout)?)),
+            CpuStorage::F16(vs) => Ok(CpuStorage::F16(self.f(vs, layout)?)),
+            CpuStorage::F32(vs) => Ok(CpuStorage::F32(self.f(vs, layout)?)),
+            CpuStorage::F64(vs) => Ok(CpuStorage::F64(self.f(vs, layout)?)),
+        }
+    }
+}
+
+type C = CpuStorage;
+trait Map2 {
+    const OP: &'static str;
+    fn f<T: WithDType + Copy + num_traits::Num + 'static>(
+        &self,
+        v1: &[T],
+        l1: &Layout,
+        v2: &[T],
+        l2: &Layout,
+    ) -> Result<Vec<T>>;
+
+    fn map(
+        &self,
+        v1: &CpuStorage,
+        l1: &Layout,
+        v2: &CpuStorage,
+        l2: &Layout,
+    ) -> Result<CpuStorage> {
+        match (v1, v2) {
+            (C::U32(v1), C::U32(v2)) => Ok(C::U32(self.f(v1, l1, v2, l2)?)),
+            (C::BF16(v1), C::BF16(v2)) => Ok(C::BF16(self.f(v1, l1, v2, l2)?)),
+            (C::F16(v1), C::F16(v2)) => Ok(C::F16(self.f(v1, l1, v2, l2)?)),
+            (C::F32(v1), C::F32(v2)) => Ok(C::F32(self.f(v1, l1, v2, l2)?)),
+            (C::F64(v1), C::F64(v2)) => Ok(C::F64(self.f(v1, l1, v2, l2)?)),
+            _ => Err(Error::DTypeMismatchBinaryOp {
+                lhs: v1.dtype(),
+                rhs: v2.dtype(),
+                op: Self::OP,
+            }),
+        }
+    }
+}
+
+struct WCond<'a>(&'a [u32], &'a Layout);
+
+impl<'a> Map2 for WCond<'a> {
+    const OP: &'static str = "where";
+    fn f<T: WithDType>(&self, t: &[T], t_l: &Layout, f: &[T], f_l: &Layout) -> Result<Vec<T>> {
+        let vs = match (
+            self.1.contiguous_offsets(),
+            t_l.contiguous_offsets(),
+            f_l.contiguous_offsets(),
         ) {
             (Some((o1, o2)), Some((o_t1, o_t2)), Some((o_f1, o_f2))) => {
-            let pred = &pred[o1..o2];
+                let pred = &self.0[o1..o2];
                 let t = &t[o_t1..o_t2];
                 let f = &f[o_f1..o_f2];
                 pred.iter()
@@ -38,39 +84,42 @@ fn wcond<T: Copy>(
                 .map(|(&p, (&t, &f))| if p > 0 { t } else { f })
                 .collect::<Vec<_>>()
             }
-        _ => layout
+            _ => self
+                .1
                 .strided_index()
-            .zip(layout_t.strided_index().zip(layout_f.strided_index()))
-            .map(|(i_p, (i_t, i_f))| if pred[i_p] > 0 { t[i_t] } else { f[i_f] })
+                .zip(t_l.strided_index().zip(f_l.strided_index()))
+                .map(|(i_p, (i_t, i_f))| if self.0[i_p] > 0 { t[i_t] } else { f[i_f] })
                 .collect::<Vec<_>>(),
-    }
-}
-
-macro_rules! map1 {
-    ($v: expr, $fn: ident, $( $args:expr ),*) => {{
-        let v = match $v {
-            CpuStorage::BF16(__s) => CpuStorage::BF16($fn::<bf16>(__s, $($args),*)?),
-            CpuStorage::F16(__s) => CpuStorage::F16($fn::<f16>(__s, $($args),*)?),
-            CpuStorage::F32(__s) => CpuStorage::F32($fn::<f32>(__s, $($args),*)?),
-            CpuStorage::F64(__s) => CpuStorage::F64($fn::<f64>(__s, $($args),*)?),
-            CpuStorage::U32(__s) => CpuStorage::U32($fn::<u32>(__s, $($args),*)?),
         };
-        Ok(v)
-    }};
+        Ok(vs)
+    }
 }
 
-fn sum_impl1<T: Copy + num_traits::NumAssign>(
-    src: &[T],
-    dst_shape: &Shape,
-    src_layout: &Layout,
-    to_dst_index: impl Fn(usize) -> usize,
-) -> Result<Vec<T>> {
-    let mut dst = vec![T::zero(); dst_shape.elem_count()];
-    for (unstr_index, src_index) in src_layout.strided_index().enumerate() {
-        dst[to_dst_index(unstr_index)] += src[src_index];
-    }
-    Ok(dst)
-}
+struct Sum<'a> {
+    dst_shape: &'a Shape,
+    sum_dims_and_stride: Vec<(usize, usize)>,
+}
+
+impl<'a> Map1 for Sum<'a> {
+    fn f<T: WithDType + Copy + num_traits::NumAssign>(
+        &self,
+        src: &[T],
+        src_layout: &Layout,
+    ) -> Result<Vec<T>> {
+        let mut dst = vec![T::zero(); self.dst_shape.elem_count()];
+        for (unstr_index, src_index) in src_layout.strided_index().enumerate() {
+            let mut dst_index = unstr_index;
+            // Set the sum_dims indexes to 0.
+            for &(dim, stride) in self.sum_dims_and_stride.iter() {
+                // The compiler is able to optimize the following in a single divmod op.
+                let (pre, post) = (dst_index / stride, dst_index % stride);
+                dst_index = (pre / dim) * stride + post;
+            }
+            dst[dst_index] += src[src_index];
+        }
+        Ok(dst)
+    }
+}
 
 fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(vs: &[T], layout: &Layout, mut f: F) -> Vec<U> {
     match layout.contiguous_offsets() {
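To see what the shared `map` buys for binary ops: an op supplies `f` plus the `OP` name, and `map` both dispatches on dtype and reports mismatches uniformly. A self-contained miniature of that shape with a hypothetical `Add` op (stand-in enums instead of the crate's `CpuStorage`, `Layout`, and `Error`):

    #[derive(Debug)]
    enum DType { F32, F64 }

    #[derive(Debug)]
    enum Storage { F32(Vec<f32>), F64(Vec<f64>) }

    impl Storage {
        fn dtype(&self) -> DType {
            match self {
                Storage::F32(_) => DType::F32,
                Storage::F64(_) => DType::F64,
            }
        }
    }

    // Stand-in for Error::DTypeMismatchBinaryOp.
    #[derive(Debug)]
    struct DTypeMismatch { lhs: DType, rhs: DType, op: &'static str }

    trait Map2 {
        const OP: &'static str;
        fn f<T: Copy + std::ops::Add<Output = T>>(&self, v1: &[T], v2: &[T]) -> Vec<T>;

        // The dtype match is written once; Self::OP tags the error.
        fn map(&self, v1: &Storage, v2: &Storage) -> Result<Storage, DTypeMismatch> {
            match (v1, v2) {
                (Storage::F32(a), Storage::F32(b)) => Ok(Storage::F32(self.f(a, b))),
                (Storage::F64(a), Storage::F64(b)) => Ok(Storage::F64(self.f(a, b))),
                _ => Err(DTypeMismatch { lhs: v1.dtype(), rhs: v2.dtype(), op: Self::OP }),
            }
        }
    }

    struct Add;
    impl Map2 for Add {
        const OP: &'static str = "add";
        fn f<T: Copy + std::ops::Add<Output = T>>(&self, v1: &[T], v2: &[T]) -> Vec<T> {
            v1.iter().zip(v2).map(|(&a, &b)| a + b).collect()
        }
    }

    fn main() {
        let x = Storage::F32(vec![1.0, 2.0]);
        let y = Storage::F64(vec![3.0]);
        // Mismatched dtypes surface as an error carrying OP = "add".
        println!("{:?}", Add.map(&x, &y));
    }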
@@ -101,24 +150,49 @@ fn binary_map<T: Copy, F: FnMut(T, T) -> T>(
     }
 }
 
-fn take_impl1<T: Copy>(vs: &[T], ids: &[u32], layout: &Layout, rhs_l: &Layout) -> Result<Vec<T>> {
+struct Affine(f64, f64);
+
+impl Map1 for Affine {
+    fn f<T: WithDType + Copy + num_traits::NumAssign>(
+        &self,
+        vs: &[T],
+        layout: &Layout,
+    ) -> Result<Vec<T>> {
+        let mul = T::from_f64(self.0);
+        let add = T::from_f64(self.1);
+        Ok(unary_map(vs, layout, |v| v * mul + add))
+    }
+}
+
+struct Embedding<'a> {
+    vocab_size: usize,
+    hidden_size: usize,
+    ids: &'a [u32],
+    ids_l: &'a Layout,
+}
+
+impl<'a> Map1 for Embedding<'a> {
+    fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
+        // TODO: We assume that vs is contiguous here.
+        let vs = &vs[layout.start_offset()..];
+        let mut values = Vec::with_capacity(self.ids_l.shape().elem_count() * self.hidden_size);
         // TODO: Optimize for the case where ids are contiguous.
-    let (vocab_size, hidden_size) = rhs_l.shape().r2()?;
-    let mut values = Vec::with_capacity(layout.shape().elem_count() * hidden_size);
-    for index in layout.strided_index() {
-        let index = ids[index].try_into()?;
-        if index >= vocab_size {
+        for index in self.ids_l.strided_index() {
+            let index = self.ids[index].try_into()?;
+            if index >= self.vocab_size {
                 return Err(Error::InvalidIndex {
                     index,
-                vocab_size,
+                    vocab_size: self.vocab_size,
                     op: "take",
                 });
             } else {
+                let hidden_size = self.hidden_size;
                 values.extend(&vs[hidden_size * index..hidden_size * (index + 1)]);
             }
         }
         Ok(values)
     }
+}
 
 fn copy_strided_src_<T: Copy + std::fmt::Display>(
     src: &[T],
@@ -143,13 +217,18 @@ fn copy_strided_src_<T: Copy + std::fmt::Display>(
     }
 }
 
-fn matmul<T: 'static + num_traits::Num + Copy>(
+struct MatMul((usize, usize, usize, usize));
+
+impl Map2 for MatMul {
+    const OP: &'static str = "mat_mul";
+    fn f<T: 'static + num_traits::Num + Copy>(
+        &self,
         lhs: &[T],
-    rhs: &[T],
-    (b, m, n, k): (usize, usize, usize, usize),
         lhs_l: &Layout,
+        rhs: &[T],
         rhs_l: &Layout,
     ) -> Result<Vec<T>> {
+        let (b, m, n, k) = self.0;
         let lhs = &lhs[lhs_l.start_offset()..];
         let rhs = &rhs[rhs_l.start_offset()..];
         let a_skip: usize = m * k;
@@ -211,6 +290,36 @@ fn matmul<T: 'static + num_traits::Num + Copy>(
         }
         Ok(dst)
     }
+}
+
+fn divide_by_sum_over_dim<T: WithDType + num_traits::NumAssign>(
+    s: &mut [T],
+    shape: &Shape,
+    dim: usize,
+) -> Result<()> {
+    // [self] stores data in a contiguous way starting at offset 0.
+    let dims = shape.dims();
+    let elem_per_slice = dims[dim];
+    let prod_pre_dim = dims[..dim].iter().product();
+    let prod_post_dim = dims[dim + 1..].iter().product();
+    for pre_idx in 0..prod_pre_dim {
+        for post_idx in 0..prod_post_dim {
+            let mut sum = 0f64;
+            let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
+            for _ in 0..elem_per_slice {
+                sum += s[idx].to_f64();
+                idx += prod_post_dim
+            }
+            let sum = T::from_f64(sum);
+            let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
+            for _ in 0..elem_per_slice {
+                s[idx] /= sum;
+                idx += prod_post_dim
+            }
+        }
+    }
+    Ok(())
+}
 
 impl CpuStorage {
     pub fn dtype(&self) -> DType {
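The free `divide_by_sum_over_dim` added above normalizes each 1-D slice taken along `dim` of a contiguous tensor: a slice starts at `pre_idx * prod_post_dim * elem_per_slice + post_idx` and its elements sit `prod_post_dim` apart in memory. A self-contained check of that walk on a hypothetical contiguous `[2, 3, 2]` tensor, normalizing over the middle dimension (plain `f32`, no candle types):

    fn main() {
        let dims = [2usize, 3, 2];
        let dim = 1;
        let mut s: Vec<f32> = (0..12).map(|v| v as f32).collect();

        let elem_per_slice = dims[dim]; // 3
        let prod_pre_dim: usize = dims[..dim].iter().product(); // 2
        let prod_post_dim: usize = dims[dim + 1..].iter().product(); // 2

        for pre_idx in 0..prod_pre_dim {
            for post_idx in 0..prod_post_dim {
                let start = pre_idx * prod_post_dim * elem_per_slice + post_idx;
                // Elements of one slice are prod_post_dim apart in memory.
                let sum: f32 = (0..elem_per_slice).map(|k| s[start + k * prod_post_dim]).sum();
                for k in 0..elem_per_slice {
                    s[start + k * prod_post_dim] /= sum;
                }
            }
        }

        // The slice through indices 0, 2, 4 now sums to 1.
        let check: f32 = [0usize, 2, 4].iter().map(|&i| s[i]).sum();
        assert!((check - 1.0).abs() < 1e-6);
    }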
@@ -348,135 +457,26 @@ impl CpuStorage {
             .iter()
             .map(|&d| (src_dims[d], src_dims[d + 1..].iter().product::<usize>()))
             .collect();
-        let to_dst_index = |unstr_index: usize| {
-            // TODO: Optimize, the following does lots of slow division.
-            let mut dst_index = unstr_index;
-            // Set the sum_dims indexes to 0.
-            for &(dim, stride) in sum_dims_and_stride.iter() {
-                // The compiler is able to optimize the following in a single divmod op.
-                let (pre, post) = (dst_index / stride, dst_index % stride);
-                dst_index = (pre / dim) * stride + post;
-            }
-            dst_index
-        };
-        // TODO: Maybe provide an implementation with higher precision accumulators?
-        map1!(self, sum_impl1, &dst_shape, layout, to_dst_index)
+        Sum {
+            dst_shape: &dst_shape,
+            sum_dims_and_stride,
+        }
+        .map(self, layout)
     }
 
     pub(crate) fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) -> Result<()> {
-        // [self] stores data in a contiguous way.
-        let dims = shape.dims();
-        let elem_per_slice = dims[dim];
-        let prod_pre_dim = dims[..dim].iter().product();
-        let prod_post_dim = dims[dim + 1..].iter().product();
+        // [self] stores data in a contiguous way starting at offset 0.
         match self {
-            Self::BF16(storage) => {
-                for pre_idx in 0..prod_pre_dim {
-                    for post_idx in 0..prod_post_dim {
-                        let mut sum = 0f64;
-                        let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
-                        for _ in 0..elem_per_slice {
-                            sum += storage[idx].to_f64();
-                            idx += prod_post_dim
-                        }
-                        let sum = bf16::from_f64(sum);
-                        let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
-                        for _ in 0..elem_per_slice {
-                            storage[idx] /= sum;
-                            idx += prod_post_dim
-                        }
-                    }
-                }
-            }
-            Self::F16(storage) => {
-                for pre_idx in 0..prod_pre_dim {
-                    for post_idx in 0..prod_post_dim {
-                        let mut sum = 0f64;
-                        let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
-                        for _ in 0..elem_per_slice {
-                            sum += storage[idx].to_f64();
-                            idx += prod_post_dim
-                        }
-                        let sum = f16::from_f64(sum);
-                        let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
-                        for _ in 0..elem_per_slice {
-                            storage[idx] /= sum;
-                            idx += prod_post_dim
-                        }
-                    }
-                }
-            }
-            Self::F32(storage) => {
-                for pre_idx in 0..prod_pre_dim {
-                    for post_idx in 0..prod_post_dim {
-                        let mut sum = 0f64;
-                        let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
-                        for _ in 0..elem_per_slice {
-                            sum += storage[idx] as f64;
-                            idx += prod_post_dim
-                        }
-                        let sum = sum as f32;
-                        let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
-                        for _ in 0..elem_per_slice {
-                            storage[idx] /= sum;
-                            idx += prod_post_dim
-                        }
-                    }
-                }
-            }
-            Self::F64(storage) => {
-                for pre_idx in 0..prod_pre_dim {
-                    for post_idx in 0..prod_post_dim {
-                        let mut sum = 0f64;
-                        let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
-                        for _ in 0..elem_per_slice {
-                            sum += storage[idx];
-                            idx += prod_post_dim
-                        }
-                        let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
-                        for _ in 0..elem_per_slice {
-                            storage[idx] /= sum;
-                            idx += prod_post_dim
-                        }
-                    }
-                }
-            }
-            Self::U32(_) => {}
+            Self::BF16(s) => divide_by_sum_over_dim(s, shape, dim),
+            Self::F16(s) => divide_by_sum_over_dim(s, shape, dim),
+            Self::F32(s) => divide_by_sum_over_dim(s, shape, dim),
+            Self::F64(s) => divide_by_sum_over_dim(s, shape, dim),
+            Self::U32(_) => Ok(()),
         }
-        Ok(())
     }
 
     pub(crate) fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
-        match self {
-            Self::U32(storage) => {
-                let mul = mul as u32;
-                let add = add as u32;
-                let data = unary_map(storage, layout, |v| v * mul + add);
-                Ok(Self::U32(data))
-            }
-            Self::BF16(storage) => {
-                let mul = bf16::from_f64(mul);
-                let add = bf16::from_f64(add);
-                let data = unary_map(storage, layout, |v| v * mul + add);
-                Ok(Self::BF16(data))
-            }
-            Self::F16(storage) => {
-                let mul = f16::from_f64(mul);
-                let add = f16::from_f64(add);
-                let data = unary_map(storage, layout, |v| v * mul + add);
-                Ok(Self::F16(data))
-            }
-            Self::F32(storage) => {
-                let mul = mul as f32;
-                let add = add as f32;
-                let data = unary_map(storage, layout, |v| v * mul + add);
-                Ok(Self::F32(data))
-            }
-            Self::F64(storage) => {
-                let data = unary_map(storage, layout, |v| v * mul + add);
-                Ok(Self::F64(data))
-            }
-        }
+        Affine(mul, add).map(self, layout)
     }
 
     pub(crate) fn unary_impl<B: UnaryOp>(&self, layout: &Layout) -> Result<Self> {
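The `sum_dims_and_stride` pairs built just above are `(size of the summed dim, product of the dims after it)`, and `Sum::f` collapses each summed coordinate with pure integer arithmetic: `(pre / dim) * stride + post` zeroes that coordinate of the destination index. A self-contained check of the collapse for a hypothetical `[2, 3]` source summed over dim 1, where `(dim, stride) = (3, 1)` and every source index must land in destination slot 0 or 1:

    fn main() {
        let sum_dims_and_stride = [(3usize, 1usize)];
        let src: Vec<u32> = vec![1, 2, 3, 10, 20, 30];
        let mut dst = vec![0u32; 2];
        for (unstr_index, &v) in src.iter().enumerate() {
            let mut dst_index = unstr_index;
            // Same divmod collapse as Sum::f in the diff.
            for &(dim, stride) in sum_dims_and_stride.iter() {
                let (pre, post) = (dst_index / stride, dst_index % stride);
                dst_index = (pre / dim) * stride + post;
            }
            dst[dst_index] += v;
        }
        assert_eq!(dst, vec![6, 60]); // the two row sums
    }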
@@ -570,44 +570,25 @@ impl CpuStorage {
         &self,
         layout: &Layout,
         t: &Self,
-        layout_t: &Layout,
+        t_l: &Layout,
         f: &Self,
-        layout_f: &Layout,
+        f_l: &Layout,
     ) -> Result<Self> {
         // TODO: Support types that could be casted to a boolean.
         let pred = self.as_slice::<u32>()?;
-        match (t, f) {
-            (Self::BF16(t), Self::BF16(f)) => {
-                let data = wcond(pred, layout, t, layout_t, f, layout_f);
-                Ok(Self::BF16(data))
-            }
-            (Self::F16(t), Self::F16(f)) => {
-                let data = wcond(pred, layout, t, layout_t, f, layout_f);
-                Ok(Self::F16(data))
-            }
-            (Self::F32(t), Self::F32(f)) => {
-                let data = wcond(pred, layout, t, layout_t, f, layout_f);
-                Ok(Self::F32(data))
-            }
-            (Self::F64(t), Self::F64(f)) => {
-                let data = wcond(pred, layout, t, layout_t, f, layout_f);
-                Ok(Self::F64(data))
-            }
-            (Self::U32(t), Self::U32(f)) => {
-                let data = wcond(pred, layout, t, layout_t, f, layout_f);
-                Ok(Self::U32(data))
-            }
-            _ => Err(Error::DTypeMismatchBinaryOp {
-                lhs: t.dtype(),
-                rhs: f.dtype(),
-                op: "where_cond",
-            }),
-        }
+        WCond(pred, layout).map(t, t_l, f, f_l)
     }
 
-    pub(crate) fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
+    pub(crate) fn embedding(&self, ids_l: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
         let ids = self.as_slice::<u32>()?;
-        map1!(rhs, take_impl1, ids, layout, rhs_l)
+        let (vocab_size, hidden_size) = rhs_l.shape().r2()?;
+        Embedding {
+            vocab_size,
+            hidden_size,
+            ids,
+            ids_l,
+        }
+        .map(rhs, rhs_l)
     }
 
     pub(crate) fn matmul(
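`where_cond` now just packs the predicate and its layout into `WCond(pred, layout)` (read back as `self.0` and `self.1` in `f`), so the extra operands ride along in struct fields instead of extra arguments. The selection rule itself is unchanged; a self-contained miniature of it for the contiguous case (hypothetical types, no layouts):

    // Pick t[i] where pred[i] > 0, else f[i].
    struct WCond<'a>(&'a [u32]);

    impl<'a> WCond<'a> {
        fn f<T: Copy>(&self, t: &[T], f: &[T]) -> Vec<T> {
            self.0
                .iter()
                .zip(t.iter().zip(f.iter()))
                .map(|(&p, (&t, &f))| if p > 0 { t } else { f })
                .collect()
        }
    }

    fn main() {
        let pred = [1u32, 0, 1];
        let on_true = [10i64, 20, 30];
        let on_false = [-1i64, -2, -3];
        assert_eq!(WCond(&pred).f(&on_true, &on_false), vec![10, -2, 30]);
    }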
@@ -617,76 +598,28 @@ impl CpuStorage {
         lhs_l: &Layout,
         rhs_l: &Layout,
     ) -> Result<Self> {
-        match (self, rhs) {
-            (CpuStorage::F16(lhs), CpuStorage::F16(rhs)) => {
-                let dst = matmul(lhs, rhs, bmnk, lhs_l, rhs_l)?;
-                Ok(Self::F16(dst))
-            }
-            (CpuStorage::F32(lhs), CpuStorage::F32(rhs)) => {
-                let dst = matmul(lhs, rhs, bmnk, lhs_l, rhs_l)?;
-                Ok(Self::F32(dst))
-            }
-            (CpuStorage::F64(lhs), CpuStorage::F64(rhs)) => {
-                let dst = matmul(lhs, rhs, bmnk, lhs_l, rhs_l)?;
-                Ok(Self::F64(dst))
-            }
-            _ => Err(Error::DTypeMismatchBinaryOp {
-                lhs: self.dtype(),
-                rhs: rhs.dtype(),
-                op: "matmul",
-            }),
-        }
+        MatMul(bmnk).map(self, lhs_l, rhs, rhs_l)
     }
 
     pub(crate) fn ones_impl(shape: &Shape, dtype: DType) -> Self {
         let elem_count = shape.elem_count();
         match dtype {
-            DType::U32 => {
-                let data = vec![1u32; elem_count];
-                Self::U32(data)
-            }
-            DType::BF16 => {
-                let data = vec![bf16::ONE; elem_count];
-                Self::BF16(data)
-            }
-            DType::F16 => {
-                let data = vec![f16::ONE; elem_count];
-                Self::F16(data)
-            }
-            DType::F32 => {
-                let data = vec![1f32; elem_count];
-                Self::F32(data)
-            }
-            DType::F64 => {
-                let data = vec![1f64; elem_count];
-                Self::F64(data)
-            }
+            DType::U32 => Self::U32(vec![1u32; elem_count]),
+            DType::BF16 => Self::BF16(vec![bf16::ONE; elem_count]),
+            DType::F16 => Self::F16(vec![f16::ONE; elem_count]),
+            DType::F32 => Self::F32(vec![1f32; elem_count]),
+            DType::F64 => Self::F64(vec![1f64; elem_count]),
         }
     }
 
     pub(crate) fn zeros_impl(shape: &Shape, dtype: DType) -> Self {
         let elem_count = shape.elem_count();
         match dtype {
-            DType::U32 => {
-                let data = vec![0u32; elem_count];
-                Self::U32(data)
-            }
-            DType::BF16 => {
-                let data = vec![bf16::ZERO; elem_count];
-                Self::BF16(data)
-            }
-            DType::F16 => {
-                let data = vec![f16::ZERO; elem_count];
-                Self::F16(data)
-            }
-            DType::F32 => {
-                let data = vec![0f32; elem_count];
-                Self::F32(data)
-            }
-            DType::F64 => {
-                let data = vec![0f64; elem_count];
-                Self::F64(data)
-            }
+            DType::U32 => Self::U32(vec![0u32; elem_count]),
+            DType::BF16 => Self::BF16(vec![bf16::ZERO; elem_count]),
+            DType::F16 => Self::F16(vec![f16::ZERO; elem_count]),
+            DType::F32 => Self::F32(vec![0f32; elem_count]),
+            DType::F64 => Self::F64(vec![0f64; elem_count]),
         }
     }
 }
@@ -34,6 +34,8 @@ impl DType {
 pub trait WithDType: Sized + Copy {
     const DTYPE: DType;
 
+    fn from_f64(v: f64) -> Self;
+    fn to_f64(self) -> f64;
     fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage;
 
     fn to_cpu_storage(data: &[Self]) -> CpuStorage {
@@ -45,10 +47,18 @@ pub trait WithDType: Sized + Copy {
 }
 
 macro_rules! with_dtype {
-    ($ty:ty, $dtype:ident) => {
+    ($ty:ty, $dtype:ident, $from_f64:expr, $to_f64:expr) => {
         impl WithDType for $ty {
             const DTYPE: DType = DType::$dtype;
 
+            fn from_f64(v: f64) -> Self {
+                $from_f64(v)
+            }
+
+            fn to_f64(self) -> f64 {
+                $to_f64(self)
+            }
+
             fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage {
                 CpuStorage::$dtype(data)
             }
@@ -77,8 +87,10 @@ macro_rules! with_dtype {
         }
     };
 }
-with_dtype!(u32, U32);
-with_dtype!(half::f16, F16);
-with_dtype!(half::bf16, BF16);
-with_dtype!(f32, F32);
-with_dtype!(f64, F64);
+use half::{bf16, f16};
+
+with_dtype!(u32, U32, |v: f64| v as u32, |v: u32| v as f64);
+with_dtype!(f16, F16, f16::from_f64, f16::to_f64);
+with_dtype!(bf16, BF16, bf16::from_f64, bf16::to_f64);
+with_dtype!(f32, F32, |v: f64| v as f32, |v: f32| v as f64);
+with_dtype!(f64, F64, |v: f64| v, |v: f64| v);