Mirror of https://github.com/huggingface/candle.git (synced 2025-06-17 02:58:50 +00:00)

Commit: Get the cpu backend to compile.

The change threads the new `Layout` type (which bundles a shape with its strides) through the CPU backend, the no-CUDA stubs, `Storage`, and `Tensor`, replacing the separate `Shape`/stride arguments, and renames the `*_impl` methods.
@@ -1,5 +1,5 @@
 use crate::op::{BinaryOp, UnaryOp};
-use crate::{DType, Error, Layout, Result, Shape, StridedIndex};
+use crate::{DType, Error, Layout, Result, Shape};
 use gemm::{gemm, Parallelism};
 use half::{bf16, f16};
 
@@ -81,14 +81,13 @@ fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(vs: &[T], layout: &Layout, mut
 
 // This function maps over two strided index sequences.
 fn binary_map<T: Copy, F: FnMut(T, T) -> T>(
-    shape: &Shape,
     lhs_layout: &Layout,
     rhs_layout: &Layout,
     lhs: &[T],
     rhs: &[T],
     mut f: F,
 ) -> Vec<T> {
-    let dims = shape.dims();
+    let shape = lhs_layout.shape();
     if lhs_layout.is_contiguous() && rhs_layout.is_contiguous() {
         (0..shape.elem_count()).map(|i| f(lhs[i], rhs[i])).collect()
     } else {
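Note: the separate `shape` parameter to `binary_map` can be dropped because a `Layout` already bundles the shape with the strides, so the shape is recovered from the lhs layout instead. A minimal sketch of that idea, using stand-in `Shape`/`Layout` types rather than candle's real ones (the real `Layout` also carries strides and an offset, elided here):

struct Shape(Vec<usize>);

impl Shape {
    fn elem_count(&self) -> usize {
        self.0.iter().product()
    }
}

// Stand-in layout: it owns its shape, so callers need not pass one.
struct Layout {
    shape: Shape,
}

impl Layout {
    fn shape(&self) -> &Shape {
        &self.shape
    }
}

fn binary_map<T: Copy, F: FnMut(T, T) -> T>(
    lhs_layout: &Layout,
    lhs: &[T],
    rhs: &[T],
    mut f: F,
) -> Vec<T> {
    // Contiguous fast path: the element count comes from the layout's shape.
    (0..lhs_layout.shape().elem_count())
        .map(|i| f(lhs[i], rhs[i]))
        .collect()
}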
@@ -148,17 +147,19 @@ fn copy_strided_src_<T: Copy + std::fmt::Display>(
     }
 }
 
-fn matmul_impl<T: 'static + num_traits::Num + Copy>(
+fn matmul<T: 'static + num_traits::Num + Copy>(
     lhs: &[T],
     rhs: &[T],
     (b, m, n, k): (usize, usize, usize, usize),
-    lhs_stride: &[usize],
-    rhs_stride: &[usize],
+    lhs_layout: &Layout,
+    rhs_layout: &Layout,
 ) -> Result<Vec<T>> {
     let a_skip: usize = m * k;
     let b_skip: usize = n * k;
     let c_skip: usize = m * n;
 
+    let lhs_stride = lhs_layout.stride();
+    let rhs_stride = rhs_layout.stride();
     let rank = lhs_stride.len();
     let lhs_cs = lhs_stride[rank - 1];
     let lhs_rs = lhs_stride[rank - 2];
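Note: `matmul` now receives the layouts and pulls the stride vectors out itself via `Layout::stride()`. As the hunk shows, the row/column strides fed to gemm come from the last two entries of each stride vector; a small standalone sketch of that derivation (not candle's code):

// For a tensor whose last two dims are the matrix dims (..., rows, cols),
// the row stride is the second-to-last stride entry and the column
// stride is the last one.
fn matrix_strides(stride: &[usize]) -> (usize, usize) {
    let rank = stride.len();
    (stride[rank - 2], stride[rank - 1]) // (row stride, column stride)
}

fn main() {
    // A contiguous (2, 3, 4) tensor has strides [12, 4, 1]:
    // row stride 4, column stride 1.
    assert_eq!(matrix_strides(&[12, 4, 1]), (4, 1));
}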
@@ -512,29 +513,28 @@ impl CpuStorage {
     pub(crate) fn binary_impl<B: BinaryOp>(
         &self,
         rhs: &Self,
-        shape: &Shape,
         lhs_layout: &Layout,
         rhs_layout: &Layout,
     ) -> Result<Self> {
         match (self, rhs) {
             (Self::BF16(lhs), Self::BF16(rhs)) => {
-                let data = binary_map(shape, lhs_layout, rhs_layout, lhs, rhs, B::bf16);
+                let data = binary_map(lhs_layout, rhs_layout, lhs, rhs, B::bf16);
                 Ok(Self::BF16(data))
             }
             (Self::F16(lhs), Self::F16(rhs)) => {
-                let data = binary_map(shape, lhs_layout, rhs_layout, lhs, rhs, B::f16);
+                let data = binary_map(lhs_layout, rhs_layout, lhs, rhs, B::f16);
                 Ok(Self::F16(data))
             }
             (Self::F32(lhs), Self::F32(rhs)) => {
-                let data = binary_map(shape, lhs_layout, rhs_layout, lhs, rhs, B::f32);
+                let data = binary_map(lhs_layout, rhs_layout, lhs, rhs, B::f32);
                 Ok(Self::F32(data))
             }
             (Self::F64(lhs), Self::F64(rhs)) => {
-                let data = binary_map(shape, lhs_layout, rhs_layout, lhs, rhs, B::f64);
+                let data = binary_map(lhs_layout, rhs_layout, lhs, rhs, B::f64);
                 Ok(Self::F64(data))
             }
             (Self::U32(lhs), Self::U32(rhs)) => {
-                let data = binary_map(shape, lhs_layout, rhs_layout, lhs, rhs, B::u32);
+                let data = binary_map(lhs_layout, rhs_layout, lhs, rhs, B::u32);
                 Ok(Self::U32(data))
             }
             _ => {
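Note: apart from the dropped `shape` argument, the per-dtype dispatch is unchanged: matching storage variants are mapped with the same generic helper, and mismatched pairs fall through to an error. A stripped-down sketch of that pattern with stand-in types, specialized here to addition:

enum Storage {
    F32(Vec<f32>),
    F64(Vec<f64>),
}

fn add(lhs: &Storage, rhs: &Storage) -> Result<Storage, String> {
    match (lhs, rhs) {
        (Storage::F32(l), Storage::F32(r)) => {
            Ok(Storage::F32(l.iter().zip(r).map(|(a, b)| a + b).collect()))
        }
        (Storage::F64(l), Storage::F64(r)) => {
            Ok(Storage::F64(l.iter().zip(r).map(|(a, b)| a + b).collect()))
        }
        // Mismatched variants are a dtype error, as in the hunk above.
        _ => Err("dtype mismatch between lhs and rhs".to_string()),
    }
}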
@@ -622,24 +622,24 @@ impl CpuStorage {
         map1!(vs, take_impl1, ids, layout, vocab_size, hidden_size)
     }
 
-    pub(crate) fn matmul_impl(
+    pub(crate) fn matmul(
         &self,
         rhs: &Self,
         bmnk: (usize, usize, usize, usize),
-        lhs_stride: &[usize],
-        rhs_stride: &[usize],
+        lhs_layout: &Layout,
+        rhs_layout: &Layout,
     ) -> Result<Self> {
         match (self, rhs) {
             (CpuStorage::F16(lhs), CpuStorage::F16(rhs)) => {
-                let dst = matmul_impl(lhs, rhs, bmnk, lhs_stride, rhs_stride)?;
+                let dst = matmul(lhs, rhs, bmnk, lhs_layout, rhs_layout)?;
                 Ok(Self::F16(dst))
             }
             (CpuStorage::F32(lhs), CpuStorage::F32(rhs)) => {
-                let dst = matmul_impl(lhs, rhs, bmnk, lhs_stride, rhs_stride)?;
+                let dst = matmul(lhs, rhs, bmnk, lhs_layout, rhs_layout)?;
                 Ok(Self::F32(dst))
             }
             (CpuStorage::F64(lhs), CpuStorage::F64(rhs)) => {
-                let dst = matmul_impl(lhs, rhs, bmnk, lhs_stride, rhs_stride)?;
+                let dst = matmul(lhs, rhs, bmnk, lhs_layout, rhs_layout)?;
                 Ok(Self::F64(dst))
             }
             _ => Err(Error::DTypeMismatchBinaryOp {

@@ -1,5 +1,5 @@
 #![allow(dead_code)]
-use crate::{CpuStorage, DType, Error, Result, Shape};
+use crate::{CpuStorage, DType, Error, Layout, Result, Shape};
 
 #[derive(thiserror::Error, Debug)]
 pub enum DummyError {}
@@ -60,11 +60,11 @@ impl CudaStorage {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
-    pub(crate) fn affine_impl(&self, _: &Shape, _: &[usize], _: f64, _: f64) -> Result<Self> {
+    pub(crate) fn affine(&self, _: &Layout, _: f64, _: f64) -> Result<Self> {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
-    pub(crate) fn sum(&self, _: &Shape, _: &[usize], _: &[usize]) -> Result<Self> {
+    pub(crate) fn sum(&self, _: &Layout, _: &[usize]) -> Result<Self> {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
@@ -72,65 +72,49 @@ impl CudaStorage {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
-    pub(crate) fn to_dtype(&self, _: &Shape, _: &[usize], _: DType) -> Result<Self> {
+    pub(crate) fn to_dtype(&self, _: &Layout, _: DType) -> Result<Self> {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
-    pub(crate) fn unary_impl<B: crate::op::UnaryOp>(&self, _: &Shape, _: &[usize]) -> Result<Self> {
+    pub(crate) fn unary_impl<B: crate::op::UnaryOp>(&self, _: &Layout) -> Result<Self> {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
     pub(crate) fn binary_impl<B: crate::op::BinaryOp>(
         &self,
         _: &Self,
-        _: &Shape,
-        _: &[usize],
-        _: &[usize],
+        _: &Layout,
+        _: &Layout,
     ) -> Result<Self> {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
     pub(crate) fn where_cond(
         &self,
-        _: &Shape,
-        _: &[usize],
+        _: &Layout,
         _: &Self,
-        _: &[usize],
+        _: &Layout,
         _: &Self,
-        _: &[usize],
+        _: &Layout,
     ) -> Result<Self> {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
-    pub(crate) fn embedding_impl(
-        &self,
-        _: &Shape,
-        _: &[usize],
-        _: &Self,
-        _: usize,
-        _: usize,
-    ) -> Result<Self> {
+    pub(crate) fn embedding(&self, _: &Layout, _: &Self, _: usize, _: usize) -> Result<Self> {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
-    pub(crate) fn matmul_impl(
+    pub(crate) fn matmul(
         &self,
         _: &Self,
         _: (usize, usize, usize, usize),
-        _: &[usize],
-        _: &[usize],
+        _: &Layout,
+        _: &Layout,
     ) -> Result<Self> {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
-    pub(crate) fn copy_strided_src(
-        &self,
-        _: &mut Self,
-        _: usize,
-        _: &Shape,
-        _: &[usize],
-        _: usize,
-    ) -> Result<()> {
+    pub(crate) fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()> {
         Err(Error::NotCompiledWithCudaSupport)
     }
 }
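Note: these are the no-CUDA stubs, so their signatures must stay in lockstep with the CPU backend for the crate to compile when CUDA support is disabled; every body just reports the missing feature. The pattern, sketched with stand-in names:

// Sketch of the stub-backend pattern: each method mirrors the real
// backend's signature but fails at runtime, so callers typecheck
// identically whether or not the feature is compiled in.
#[derive(Debug)]
struct NotCompiledWithCudaSupport;

struct CudaStorage;

impl CudaStorage {
    fn to_dtype(&self) -> Result<Self, NotCompiledWithCudaSupport> {
        Err(NotCompiledWithCudaSupport)
    }
}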

@@ -106,7 +106,7 @@ impl Layout {
         if shape.rank() < self.shape().rank() {
             Err(Error::BroadcastIncompatibleShapes {
                 src_shape: self.shape().clone(),
-                dst_shape: shape,
+                dst_shape: shape.clone(),
             })?
         }
         let added_dims = shape.rank() - self.shape().rank();
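Note: `dst_shape: shape` stops compiling once the target shape is borrowed rather than owned (the error variant stores an owned `Shape`), which is presumably why the `shape.clone()` appears. A tiny self-contained sketch of the same situation, with hypothetical types:

#[derive(Clone)]
struct Shape(Vec<usize>);

struct BroadcastError {
    src_shape: Shape,
    dst_shape: Shape,
}

// With a borrowed target shape, the error must clone to own its copy.
fn check_broadcast(src: &Shape, dst: &Shape) -> Result<(), BroadcastError> {
    if dst.0.len() < src.0.len() {
        return Err(BroadcastError {
            src_shape: src.clone(),
            dst_shape: dst.clone(),
        });
    }
    Ok(())
}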
@@ -135,6 +135,6 @@ impl Layout {
     }
 
     pub(crate) fn strided_index(&self) -> crate::StridedIndex {
-        crate::StridedIndex::new(&self)
+        crate::StridedIndex::new(self)
     }
 }
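Note: inside a method taking `&self`, `self` already has type `&Layout`, so `StridedIndex::new(&self)` passed a `&&Layout` that only compiled through deref coercion; passing `self` directly is the idiomatic form (clippy flags the former as a needless borrow). Sketch with stand-in types:

struct Layout;

struct StridedIndex<'a> {
    layout: &'a Layout,
}

impl<'a> StridedIndex<'a> {
    fn new(layout: &'a Layout) -> Self {
        StridedIndex { layout }
    }
}

impl Layout {
    fn strided_index(&self) -> StridedIndex<'_> {
        // `self` is already a &Layout; `&self` would be a &&Layout.
        StridedIndex::new(self)
    }
}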

@@ -79,6 +79,7 @@ impl Storage {
         }
     }
 
+    // This assumes a contiguous layout and no offset.
     pub(crate) fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) -> Result<()> {
         match self {
             Storage::Cpu(storage) => storage.divide_by_sum_over_dim(shape, dim)?,
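Note: the added comment records an invariant rather than changing behavior. Under that contiguous, zero-offset assumption, this kind of normalization reduces to walking fixed-size chunks; a hypothetical simplification for the innermost dimension only (candle's version handles an arbitrary `dim`):

// Divide each contiguous row of length `dim_size` by that row's sum,
// i.e. normalize over the innermost dimension.
fn divide_by_sum_over_last_dim(data: &mut [f32], dim_size: usize) {
    for row in data.chunks_mut(dim_size) {
        let sum: f32 = row.iter().sum();
        for v in row.iter_mut() {
            *v /= sum;
        }
    }
}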
@@ -196,22 +197,22 @@ impl Storage {
         }
     }
 
-    pub(crate) fn matmul_impl(
+    pub(crate) fn matmul(
         &self,
         rhs: &Self,
         bmnk: (usize, usize, usize, usize),
-        lhs_stride: &[usize],
-        rhs_stride: &[usize],
+        lhs_layout: &Layout,
+        rhs_layout: &Layout,
     ) -> Result<Self> {
         self.same_device(rhs, "matmul")?;
         self.same_dtype(rhs, "matmul")?;
         match (self, rhs) {
             (Self::Cpu(lhs), Self::Cpu(rhs)) => {
-                let storage = lhs.matmul_impl(rhs, bmnk, lhs_stride, rhs_stride)?;
+                let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?;
                 Ok(Self::Cpu(storage))
             }
             (Self::Cuda(lhs), Self::Cuda(rhs)) => {
-                let storage = lhs.matmul_impl(rhs, bmnk, lhs_stride, rhs_stride)?;
+                let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?;
                 Ok(Self::Cuda(storage))
             }
             (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {

@@ -432,7 +432,7 @@ impl Tensor {
         let c_shape = Shape::from(&a_dims[..dim - 2]).extend(&[m, n]);
         let batching: usize = a_dims[..dim - 2].iter().product();
 
-        let storage = self.storage.matmul_impl(
+        let storage = self.storage.matmul(
             &rhs.storage,
             (batching, m, n, k),
             self.layout(),
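Note: the `(batching, m, n, k)` tuple above packs the batch size (the product of all leading dims) together with the matrix dimensions of a `(..., m, k) x (..., k, n) -> (..., m, n)` matmul. A standalone sketch of that computation, mirroring the call site:

fn bmnk(lhs_dims: &[usize], rhs_dims: &[usize]) -> (usize, usize, usize, usize) {
    let dim = lhs_dims.len();
    let m = lhs_dims[dim - 2];
    let k = lhs_dims[dim - 1];
    let n = rhs_dims[rhs_dims.len() - 1];
    // Batch size: product of every dim before the trailing (m, k) pair.
    let batching: usize = lhs_dims[..dim - 2].iter().product();
    (batching, m, n, k)
}

fn main() {
    // (2, 3, 4) x (2, 4, 5) -> batch 2, m 3, n 5, k 4.
    assert_eq!(bmnk(&[2, 3, 4], &[2, 4, 5]), (2, 3, 5, 4));
}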
@@ -587,7 +587,7 @@ impl Tensor {
     }
 
     pub fn shape(&self) -> &Shape {
-        &self.layout().shape()
+        self.layout().shape()
    }
 
     pub fn dims(&self) -> &[usize] {
@@ -600,7 +600,7 @@ impl Tensor {
 
     // TODO: Rename to `stride` once the PR that introduced the layout has been merged.
     pub fn stride_tmp(&self) -> &[usize] {
-        &self.layout.stride()
+        self.layout.stride()
     }
 
     pub fn rank(&self) -> usize {