mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Start refactoring the stride.
This commit is contained in:
@ -1,4 +1,4 @@
|
||||
use crate::{op, CpuStorage, CudaStorage, DType, Device, Error, Result, Shape};
|
||||
use crate::{op, CpuStorage, CudaStorage, DType, Device, Error, Layout, Result, Shape};
|
||||
|
||||
// We do not want to implement Clone on Storage as cloning may fail because of
|
||||
// out of memory. Instead try_clone should be used.
|
||||
@ -53,33 +53,27 @@ impl Storage {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn affine_impl(
|
||||
&self,
|
||||
shape: &Shape,
|
||||
stride: &[usize],
|
||||
mul: f64,
|
||||
add: f64,
|
||||
) -> Result<Self> {
|
||||
pub(crate) fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
|
||||
match self {
|
||||
Storage::Cpu(storage) => {
|
||||
let storage = storage.affine_impl(shape, stride, mul, add)?;
|
||||
let storage = storage.affine(layout, mul, add)?;
|
||||
Ok(Self::Cpu(storage))
|
||||
}
|
||||
Self::Cuda(storage) => {
|
||||
let storage = storage.affine_impl(shape, stride, mul, add)?;
|
||||
let storage = storage.affine(layout, mul, add)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn sum(&self, shape: &Shape, stride: &[usize], s: &[usize]) -> Result<Self> {
|
||||
pub(crate) fn sum(&self, layout: &Layout, s: &[usize]) -> Result<Self> {
|
||||
match self {
|
||||
Storage::Cpu(storage) => {
|
||||
let storage = storage.sum(shape, stride, s)?;
|
||||
let storage = storage.sum(layout, s)?;
|
||||
Ok(Self::Cpu(storage))
|
||||
}
|
||||
Self::Cuda(storage) => {
|
||||
let storage = storage.sum(shape, stride, s)?;
|
||||
let storage = storage.sum(layout, s)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
}
|
||||
@ -93,32 +87,28 @@ impl Storage {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn to_dtype(&self, shape: &Shape, stride: &[usize], dtype: DType) -> Result<Self> {
|
||||
pub(crate) fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
|
||||
match self {
|
||||
Storage::Cpu(storage) => {
|
||||
let storage = storage.to_dtype(shape, stride, dtype)?;
|
||||
let storage = storage.to_dtype(layout, dtype)?;
|
||||
Ok(Self::Cpu(storage))
|
||||
}
|
||||
Self::Cuda(storage) => {
|
||||
let storage = storage.to_dtype(shape, stride, dtype)?;
|
||||
let storage = storage.to_dtype(layout, dtype)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn unary_impl<B: op::UnaryOp>(
|
||||
&self,
|
||||
shape: &Shape,
|
||||
stride: &[usize],
|
||||
) -> Result<Self> {
|
||||
pub(crate) fn unary_impl<B: op::UnaryOp>(&self, layout: &Layout) -> Result<Self> {
|
||||
// TODO: Different code path for the contiguous case?
|
||||
match self {
|
||||
Storage::Cpu(storage) => {
|
||||
let storage = storage.unary_impl::<B>(shape, stride)?;
|
||||
let storage = storage.unary_impl::<B>(layout)?;
|
||||
Ok(Self::Cpu(storage))
|
||||
}
|
||||
Self::Cuda(storage) => {
|
||||
let storage = storage.unary_impl::<B>(shape, stride)?;
|
||||
let storage = storage.unary_impl::<B>(layout)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
}
|
||||
@ -127,19 +117,18 @@ impl Storage {
|
||||
pub(crate) fn binary_impl<B: op::BinaryOp>(
|
||||
&self,
|
||||
rhs: &Self,
|
||||
shape: &Shape,
|
||||
lhs_stride: &[usize],
|
||||
rhs_stride: &[usize],
|
||||
lhs_layout: &Layout,
|
||||
rhs_layout: &Layout,
|
||||
) -> Result<Self> {
|
||||
self.same_device(rhs, B::NAME)?;
|
||||
self.same_dtype(rhs, B::NAME)?;
|
||||
match (self, rhs) {
|
||||
(Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
|
||||
let storage = lhs.binary_impl::<B>(rhs, shape, lhs_stride, rhs_stride)?;
|
||||
let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?;
|
||||
Ok(Self::Cpu(storage))
|
||||
}
|
||||
(Self::Cuda(lhs), Self::Cuda(rhs)) => {
|
||||
let storage = lhs.binary_impl::<B>(rhs, shape, lhs_stride, rhs_stride)?;
|
||||
let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
(lhs, rhs) => {
|
||||
@ -156,23 +145,22 @@ impl Storage {
|
||||
|
||||
pub(crate) fn where_cond(
|
||||
&self,
|
||||
shape: &Shape,
|
||||
stride: &[usize],
|
||||
layout: &Shape,
|
||||
t: &Self,
|
||||
stride_t: &[usize],
|
||||
layout_t: &Layout,
|
||||
f: &Self,
|
||||
stride_f: &[usize],
|
||||
layout_f: &Layout,
|
||||
) -> Result<Self> {
|
||||
self.same_device(t, "where")?;
|
||||
self.same_device(f, "where")?;
|
||||
t.same_dtype(f, "where")?;
|
||||
match (self, t, f) {
|
||||
(Storage::Cpu(cond), Storage::Cpu(t), Storage::Cpu(f)) => {
|
||||
let storage = cond.where_cond(shape, stride, t, stride_t, f, stride_f)?;
|
||||
let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?;
|
||||
Ok(Self::Cpu(storage))
|
||||
}
|
||||
(Self::Cuda(cond), Self::Cuda(t), Self::Cuda(f)) => {
|
||||
let storage = cond.where_cond(shape, stride, t, stride_t, f, stride_f)?;
|
||||
let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
(_, lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
|
||||
@ -185,8 +173,7 @@ impl Storage {
|
||||
|
||||
pub(crate) fn embedding_impl(
|
||||
&self,
|
||||
shape: &Shape,
|
||||
stride: &[usize],
|
||||
layout: &Layout,
|
||||
rhs: &Self,
|
||||
hidden_size: usize,
|
||||
vocab_size: usize,
|
||||
@ -194,11 +181,11 @@ impl Storage {
|
||||
self.same_device(rhs, "embedding")?;
|
||||
match (self, rhs) {
|
||||
(Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
|
||||
let storage = lhs.embedding_impl(shape, stride, rhs, hidden_size, vocab_size)?;
|
||||
let storage = lhs.embedding_impl(layout, rhs, hidden_size, vocab_size)?;
|
||||
Ok(Self::Cpu(storage))
|
||||
}
|
||||
(Self::Cuda(lhs), Self::Cuda(rhs)) => {
|
||||
let storage = lhs.embedding_impl(shape, stride, rhs, hidden_size, vocab_size)?;
|
||||
let storage = lhs.embedding_impl(layout, rhs, hidden_size, vocab_size)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
|
||||
|
Reference in New Issue
Block a user