mirror of https://github.com/huggingface/candle.git

Propagate the layout refactoring.

@@ -24,21 +24,22 @@ fn wcond<T: Copy>(
     f: &[T],
     layout_f: &Layout,
 ) -> Vec<T> {
-    if shape.is_contiguous(stride) && shape.is_contiguous(stride_t) && shape.is_contiguous(stride_f)
-    {
-        let elem_count = shape.elem_count();
-        let pred = &pred[..elem_count];
-        let t = &t[..elem_count];
-        let f = &f[..elem_count];
+    if layout.is_contiguous() && layout_t.is_contiguous() && layout_f.is_contiguous() {
+        let elem_count = layout.shape().elem_count();
+        let offset = layout.start_offset();
+        let offset_t = layout_t.start_offset();
+        let offset_f = layout_f.start_offset();
+        let pred = &pred[offset..offset + elem_count];
+        let t = &t[offset_t..offset_t + elem_count];
+        let f = &f[offset_f..offset_f + elem_count];
         pred.iter()
             .zip(t.iter().zip(f.iter()))
             .map(|(&p, (&t, &f))| if p > 0 { t } else { f })
             .collect::<Vec<_>>()
     } else {
-        let dims = shape.dims();
-        let it_p = StridedIndex::new(dims, stride);
-        let it_t = StridedIndex::new(dims, stride_t);
-        let it_f = StridedIndex::new(dims, stride_f);
+        let it_p = StridedIndex::new(layout);
+        let it_t = StridedIndex::new(layout_t);
+        let it_f = StridedIndex::new(layout_f);
         it_p.zip(it_t.zip(it_f))
             .map(|(i_p, (i_t, i_f))| if pred[i_p] > 0 { t[i_t] } else { f[i_f] })
             .collect::<Vec<_>>()
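
This hunk is the core of the change: the body of the CPU `where_cond` kernel now reads everything it needs from the per-input `Layout`s (shape, strides, and a `start_offset` into the backing buffer) instead of the old shared shape and per-input stride slices, and the contiguous fast path can finally honor a non-zero `start_offset` by slicing each buffer at its own offset. Below is a minimal standalone sketch of that fast path; `MiniLayout` and `wcond_contiguous` are illustrative stand-ins for candle's types, and for brevity a single layout is shared by all three inputs.

    // Illustrative sketch, not candle's API: a layout carries dims, strides,
    // and the start offset of a view into a flat buffer.
    struct MiniLayout {
        dims: Vec<usize>,
        stride: Vec<usize>,
        start_offset: usize,
    }

    impl MiniLayout {
        fn elem_count(&self) -> usize {
            self.dims.iter().product()
        }
        // Contiguous means the strides are exactly the row-major strides of dims.
        fn is_contiguous(&self) -> bool {
            let mut expected = 1;
            for (&d, &s) in self.dims.iter().zip(self.stride.iter()).rev() {
                if s != expected {
                    return false;
                }
                expected *= d;
            }
            true
        }
    }

    // The contiguous fast path: slice each buffer at its offset, then zip.
    fn wcond_contiguous(pred: &[u32], t: &[f32], f: &[f32], l: &MiniLayout) -> Vec<f32> {
        let (n, o) = (l.elem_count(), l.start_offset);
        pred[o..o + n]
            .iter()
            .zip(t[o..o + n].iter().zip(f[o..o + n].iter()))
            .map(|(&p, (&t, &f))| if p > 0 { t } else { f })
            .collect()
    }

    fn main() {
        let l = MiniLayout { dims: vec![2, 2], stride: vec![2, 1], start_offset: 1 };
        let pred = [7, 1, 0, 1, 0]; // element 0 is skipped by the offset
        let t = [9.0, 1.0, 2.0, 3.0, 4.0];
        let f = [9.0, -1.0, -2.0, -3.0, -4.0];
        assert!(l.is_contiguous());
        assert_eq!(wcond_contiguous(&pred, &t, &f, &l), vec![1.0, -2.0, 3.0, -4.0]);
    }
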
@@ -107,13 +108,13 @@ fn binary_map<T: Copy, F: FnMut(T, T) -> T>(
 fn take_impl1<T: Copy>(
     vs: &[T],
     ids: &[u32],
-    shape: &Shape,
-    stride: &[usize],
+    layout: &Layout,
     vocab_size: usize,
     hidden_size: usize,
 ) -> Result<Vec<T>> {
-    let mut values = Vec::with_capacity(shape.elem_count() * hidden_size);
-    for index in StridedIndex::new(shape.dims(), stride) {
+    // TODO: Optimize for the case where ids are contiguous.
+    let mut values = Vec::with_capacity(layout.shape().elem_count() * hidden_size);
+    for index in StridedIndex::new(layout) {
         let index = ids[index].try_into()?;
         if index >= vocab_size {
             return Err(Error::InvalidIndex {
@@ -132,16 +133,14 @@ fn copy_strided_src_<T: Copy + std::fmt::Display>(
     src: &[T],
     dst: &mut [T],
     dst_offset: usize,
-    src_shape: &Shape,
-    src_stride: &[usize],
-    src_offset: usize,
+    src_l: &Layout,
 ) {
-    let src = &src[src_offset..];
-    if src_shape.is_contiguous(src_stride) {
+    let src = &src[src_l.start_offset()..];
+    if src_l.is_contiguous() {
         let elem_to_copy = (dst.len() - dst_offset).min(src.len());
         dst[dst_offset..dst_offset + elem_to_copy].copy_from_slice(&src[..elem_to_copy])
     } else {
-        let src_indexes = StridedIndex::new(src_shape.dims(), src_stride);
+        let src_indexes = StridedIndex::new(src_l);
         for (dst_index, src_index) in src_indexes.enumerate() {
             let dst_index = dst_index + dst_offset;
             if dst_index >= dst.len() {
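
When the source view is not contiguous, `copy_strided_src_` falls back to a gather: it enumerates source storage offsets in logical order via `StridedIndex` and writes them densely starting at `dst_offset`, bounds-checked against the destination. A standalone sketch of the same loop, with a hand-rolled multi-index increment standing in for `StridedIndex` (all names below are illustrative, not candle's API):

    // Illustrative sketch of the strided-copy fallback: walk the source's
    // logical indices through (dims, stride, start_offset) and pack the
    // elements densely into dst.
    fn copy_strided(
        src: &[f32],
        dst: &mut [f32],
        dst_offset: usize,
        dims: &[usize],
        stride: &[usize],
        start_offset: usize,
    ) {
        let n: usize = dims.iter().product();
        let mut multi = vec![0; dims.len()];
        for i in 0..n {
            if dst_offset + i >= dst.len() {
                break; // mirrors the bounds check in the loop above
            }
            // Storage offset of the current multi-index.
            let src_index: usize =
                multi.iter().zip(stride).map(|(&m, &s)| m * s).sum::<usize>() + start_offset;
            dst[dst_offset + i] = src[src_index];
            // Increment the multi-index, rightmost dimension fastest.
            for (m, &d) in multi.iter_mut().zip(dims).rev() {
                *m += 1;
                if *m < d {
                    break;
                }
                *m = 0;
            }
        }
    }

    fn main() {
        // Copy a transposed 2x2 view (strides [1, 2]) of [0, 1, 2, 3].
        let src = [0.0, 1.0, 2.0, 3.0];
        let mut dst = [0.0f32; 4];
        copy_strided(&src, &mut dst, 0, &[2, 2], &[1, 2], 0);
        assert_eq!(dst, [0.0, 2.0, 1.0, 3.0]);
    }
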
@@ -556,29 +555,14 @@ impl CpuStorage {
         &self,
         dst: &mut Self,
         dst_offset: usize,
-        src_shape: &Shape,
-        src_stride: &[usize],
-        src_offset: usize,
+        src_l: &Layout,
     ) -> Result<()> {
-        if src_shape.rank() != src_stride.len() {
-            panic!("incoherent shape and strides {src_shape:?} {src_stride:?}")
-        }
         match (self, dst) {
-            (Self::U32(src), Self::U32(dst)) => {
-                copy_strided_src_(src, dst, dst_offset, src_shape, src_stride, src_offset)
-            }
-            (Self::BF16(src), Self::BF16(dst)) => {
-                copy_strided_src_(src, dst, dst_offset, src_shape, src_stride, src_offset)
-            }
-            (Self::F16(src), Self::F16(dst)) => {
-                copy_strided_src_(src, dst, dst_offset, src_shape, src_stride, src_offset)
-            }
-            (Self::F32(src), Self::F32(dst)) => {
-                copy_strided_src_(src, dst, dst_offset, src_shape, src_stride, src_offset)
-            }
-            (Self::F64(src), Self::F64(dst)) => {
-                copy_strided_src_(src, dst, dst_offset, src_shape, src_stride, src_offset)
-            }
+            (Self::U32(src), Self::U32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
+            (Self::BF16(src), Self::BF16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
+            (Self::F16(src), Self::F16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
+            (Self::F32(src), Self::F32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
+            (Self::F64(src), Self::F64(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
             (_, dst) => {
                 // This should be covered by the dtype check above.
                 return Err(Error::DTypeMismatchBinaryOp {
@@ -593,34 +577,33 @@ impl CpuStorage {
 
     pub(crate) fn where_cond(
         &self,
-        shape: &Shape,
-        stride: &[usize],
+        layout: &Layout,
         t: &Self,
-        stride_t: &[usize],
+        layout_t: &Layout,
         f: &Self,
-        stride_f: &[usize],
+        layout_f: &Layout,
     ) -> Result<Self> {
         // TODO: Support types that could be casted to a boolean.
         let pred = self.as_slice::<u32>()?;
         match (t, f) {
             (Self::BF16(t), Self::BF16(f)) => {
-                let data = wcond(pred, shape, stride, t, stride_t, f, stride_f);
+                let data = wcond(pred, layout, t, layout_t, f, layout_f);
                 Ok(Self::BF16(data))
             }
             (Self::F16(t), Self::F16(f)) => {
-                let data = wcond(pred, shape, stride, t, stride_t, f, stride_f);
+                let data = wcond(pred, layout, t, layout_t, f, layout_f);
                 Ok(Self::F16(data))
             }
             (Self::F32(t), Self::F32(f)) => {
-                let data = wcond(pred, shape, stride, t, stride_t, f, stride_f);
+                let data = wcond(pred, layout, t, layout_t, f, layout_f);
                 Ok(Self::F32(data))
             }
             (Self::F64(t), Self::F64(f)) => {
-                let data = wcond(pred, shape, stride, t, stride_t, f, stride_f);
+                let data = wcond(pred, layout, t, layout_t, f, layout_f);
                 Ok(Self::F64(data))
             }
             (Self::U32(t), Self::U32(f)) => {
-                let data = wcond(pred, shape, stride, t, stride_t, f, stride_f);
+                let data = wcond(pred, layout, t, layout_t, f, layout_f);
                 Ok(Self::U32(data))
             }
             _ => Err(Error::DTypeMismatchBinaryOp {
@@ -631,16 +614,15 @@ impl CpuStorage {
         }
     }
 
-    pub(crate) fn embedding_impl(
+    pub(crate) fn embedding(
         &self,
-        shape: &Shape,
-        stride: &[usize],
+        layout: &Layout,
         vs: &Self,
         hidden_size: usize,
         vocab_size: usize,
     ) -> Result<Self> {
         let ids = self.as_slice::<u32>()?;
-        map1!(vs, take_impl1, ids, shape, stride, vocab_size, hidden_size)
+        map1!(vs, take_impl1, ids, layout, vocab_size, hidden_size)
     }
 
     pub(crate) fn matmul_impl(
|
@ -9,16 +9,20 @@ pub struct Layout {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Layout {
|
impl Layout {
|
||||||
pub fn contiguous<S: Into<Shape>>(shape: S) -> Self {
|
pub fn contiguous_with_offset<S: Into<Shape>>(shape: S, start_offset: usize) -> Self {
|
||||||
let shape = shape.into();
|
let shape = shape.into();
|
||||||
let stride = shape.stride_contiguous();
|
let stride = shape.stride_contiguous();
|
||||||
Self {
|
Self {
|
||||||
shape,
|
shape,
|
||||||
stride,
|
stride,
|
||||||
start_offset: 0,
|
start_offset,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn contiguous<S: Into<Shape>>(shape: S) -> Self {
|
||||||
|
Self::contiguous_with_offset(shape, 0)
|
||||||
|
}
|
||||||
|
|
||||||
pub fn dims(&self) -> &[usize] {
|
pub fn dims(&self) -> &[usize] {
|
||||||
self.shape.dims()
|
self.shape.dims()
|
||||||
}
|
}
|
||||||
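
`contiguous_with_offset` generalizes the old constructor: same row-major strides from `Shape::stride_contiguous`, but the view may start anywhere in its buffer; `contiguous` becomes the `start_offset == 0` special case, and reshape (last hunk below) uses the offset-preserving form. For reference, a standalone sketch of the row-major stride computation, assuming the usual rightmost-dimension-fastest convention; this `stride_contiguous` is a re-implementation for illustration, not the crate's code:

    // Illustrative: row-major (C-order) strides for a given dims vector.
    fn stride_contiguous(dims: &[usize]) -> Vec<usize> {
        let mut stride: Vec<usize> = dims
            .iter()
            .rev()
            .scan(1, |prod, &d| {
                let s = *prod;
                *prod *= d;
                Some(s)
            })
            .collect();
        stride.reverse();
        stride
    }

    fn main() {
        // A (2, 3, 4) tensor: moving one step along dim 0 jumps 12 elements.
        assert_eq!(stride_contiguous(&[2, 3, 4]), vec![12, 4, 1]);
    }
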
@@ -45,7 +49,7 @@ impl Layout {
         self.shape.is_fortran_contiguous(&self.stride)
     }
 
-    pub fn narrow(&self, dim: usize, start: usize, length: usize) -> Result<Self> {
+    pub(crate) fn narrow(&self, dim: usize, start: usize, length: usize) -> Result<Self> {
         let dims = self.shape().dims();
         if dim >= dims.len() {
             Err(Error::UnexpectedNumberOfDims {
@@ -65,4 +69,61 @@ impl Layout {
             start_offset: self.start_offset + self.stride[dim] * start,
         })
     }
+
+    pub(crate) fn transpose(&self, dim1: usize, dim2: usize) -> Result<Self> {
+        let rank = self.shape.rank();
+        if rank <= dim1 || rank <= dim2 {
+            return Err(Error::UnexpectedNumberOfDims {
+                expected: usize::max(dim1, dim2),
+                got: rank,
+                shape: self.shape().clone(),
+            });
+        }
+        let mut stride = self.stride().to_vec();
+        let mut dims = self.shape().dims().to_vec();
+        dims.swap(dim1, dim2);
+        stride.swap(dim1, dim2);
+        Ok(Self {
+            shape: Shape::from(dims),
+            stride,
+            start_offset: self.start_offset,
+        })
+    }
+
+    pub fn broadcast_as<S: Into<Shape>>(&self, shape: S) -> Result<Self> {
+        let shape = shape.into();
+        if shape.rank() < self.shape().rank() {
+            Err(Error::BroadcastIncompatibleShapes {
+                src_shape: self.shape().clone(),
+                dst_shape: shape,
+            })?
+        }
+        let added_dims = shape.rank() - self.shape().rank();
+        let mut stride = vec![0; added_dims];
+        for (&dst_dim, (&src_dim, &src_stride)) in shape.dims()[added_dims..]
+            .iter()
+            .zip(self.dims().iter().zip(self.stride()))
+        {
+            let s = if dst_dim == src_dim {
+                src_stride
+            } else if src_dim != 1 {
+                return Err(Error::BroadcastIncompatibleShapes {
+                    src_shape: self.shape().clone(),
+                    dst_shape: shape,
+                });
+            } else {
+                0
+            };
+            stride.push(s)
+        }
+        Ok(Self {
+            shape,
+            stride,
+            start_offset: self.start_offset,
+        })
+    }
+
+    pub(crate) fn strided_index(&self) -> crate::StridedIndex {
+        crate::StridedIndex::new(self)
+    }
 }
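
These additions move view logic into `Layout` itself (the tensor.rs hunks below shrink accordingly): `transpose` swaps the two dims together with their strides, and `broadcast_as` prepends a stride of 0 for each fresh leading dim and assigns stride 0 to any size-1 source dim, so every index along a broadcast dimension resolves to the same storage element; no data moves in either case. A standalone sketch of the two stride computations, with the error paths omitted (names are illustrative, not candle's API):

    // Illustrative: transposing a view is just a swap of dims and strides.
    fn transpose(dims: &mut [usize], stride: &mut [usize], d1: usize, d2: usize) {
        dims.swap(d1, d2);
        stride.swap(d1, d2);
    }

    // Illustrative: broadcast strides, assuming the shapes are compatible
    // (the real method returns BroadcastIncompatibleShapes otherwise).
    fn broadcast_stride(src_dims: &[usize], src_stride: &[usize], dst_dims: &[usize]) -> Vec<usize> {
        let added = dst_dims.len() - src_dims.len();
        let mut stride = vec![0; added]; // fresh leading dims repeat the whole view
        for ((&dst, &src), &s) in dst_dims[added..].iter().zip(src_dims).zip(src_stride) {
            stride.push(if dst == src { s } else { 0 }); // size-1 dims get stride 0
        }
        stride
    }

    fn main() {
        let (mut dims, mut stride) = (vec![2, 3], vec![3, 1]);
        transpose(&mut dims, &mut stride, 0, 1);
        assert_eq!((dims, stride), (vec![3, 2], vec![1, 3]));
        // Broadcasting a (1, 3) view with strides (3, 1) up to (2, 2, 3).
        assert_eq!(broadcast_stride(&[1, 3], &[3, 1], &[2, 2, 3]), vec![0, 0, 1]);
    }
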
@@ -145,7 +145,7 @@ impl Storage {
 
     pub(crate) fn where_cond(
         &self,
-        layout: &Shape,
+        layout: &Layout,
         t: &Self,
         layout_t: &Layout,
         f: &Self,
@@ -171,7 +171,7 @@ impl Storage {
         }
     }
 
-    pub(crate) fn embedding_impl(
+    pub(crate) fn embedding(
         &self,
         layout: &Layout,
         rhs: &Self,
|
|||||||
self.same_device(rhs, "embedding")?;
|
self.same_device(rhs, "embedding")?;
|
||||||
match (self, rhs) {
|
match (self, rhs) {
|
||||||
(Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
|
(Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
|
||||||
let storage = lhs.embedding_impl(layout, rhs, hidden_size, vocab_size)?;
|
let storage = lhs.embedding(layout, rhs, hidden_size, vocab_size)?;
|
||||||
Ok(Self::Cpu(storage))
|
Ok(Self::Cpu(storage))
|
||||||
}
|
}
|
||||||
(Self::Cuda(lhs), Self::Cuda(rhs)) => {
|
(Self::Cuda(lhs), Self::Cuda(rhs)) => {
|
||||||
let storage = lhs.embedding_impl(layout, rhs, hidden_size, vocab_size)?;
|
let storage = lhs.embedding(layout, rhs, hidden_size, vocab_size)?;
|
||||||
Ok(Self::Cuda(storage))
|
Ok(Self::Cuda(storage))
|
||||||
}
|
}
|
||||||
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
|
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
|
||||||
@@ -227,15 +227,11 @@ impl Storage {
         &self,
         dst: &mut Self,
         dst_offset: usize,
-        src_layout: &Layout,
+        src_l: &Layout,
     ) -> Result<()> {
         match (self, dst) {
-            (Self::Cpu(src), Self::Cpu(dst)) => {
-                src.copy_strided_src(dst, dst_offset, src_layout, src_offset)
-            }
-            (Self::Cuda(src), Self::Cuda(dst)) => {
-                Ok(src.copy_strided_src(dst, dst_offset, src_layout, src_offset)?)
-            }
+            (Self::Cpu(src), Self::Cpu(dst)) => src.copy_strided_src(dst, dst_offset, src_l),
+            (Self::Cuda(src), Self::Cuda(dst)) => Ok(src.copy_strided_src(dst, dst_offset, src_l)?),
             (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
                 lhs: lhs.device().location(),
                 rhs: rhs.device().location(),
@@ -1,15 +1,17 @@
+use crate::Layout;
+
 /// An iterator over offset position for items of an N-dimensional arrays stored in a
 /// flat buffer using some potential strides.
 #[derive(Debug)]
 pub(crate) struct StridedIndex<'a> {
     next_storage_index: Option<usize>,
     multi_index: Vec<usize>,
-    dims: &'a [usize],
-    stride: &'a [usize],
+    layout: &'a Layout,
 }
 
 impl<'a> StridedIndex<'a> {
-    pub(crate) fn new(dims: &'a [usize], stride: &'a [usize]) -> Self {
+    pub(crate) fn new(layout: &'a Layout) -> Self {
+        let dims = layout.dims();
         let elem_count: usize = dims.iter().product();
         let next_storage_index = if elem_count == 0 {
             None
@@ -20,8 +22,7 @@ impl<'a> StridedIndex<'a> {
         StridedIndex {
             next_storage_index,
             multi_index: vec![0; dims.len()],
-            dims,
-            stride,
+            layout,
         }
     }
 }
@@ -35,7 +36,12 @@ impl<'a> Iterator for StridedIndex<'a> {
             Some(storage_index) => storage_index,
         };
         let mut updated = false;
-        for (multi_i, max_i) in self.multi_index.iter_mut().zip(self.dims.iter()).rev() {
+        for (multi_i, max_i) in self
+            .multi_index
+            .iter_mut()
+            .zip(self.layout.dims().iter())
+            .rev()
+        {
             let next_i = *multi_i + 1;
             if next_i < *max_i {
                 *multi_i = next_i;
@@ -49,9 +55,10 @@ impl<'a> Iterator for StridedIndex<'a> {
             let next_storage_index = self
                 .multi_index
                 .iter()
-                .zip(self.stride.iter())
+                .zip(self.layout.stride().iter())
                 .map(|(&x, &y)| x * y)
-                .sum();
+                .sum::<usize>()
+                + self.layout.start_offset();
             Some(next_storage_index)
         } else {
             None
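
Together these two hunks make `StridedIndex` read dims and strides from the borrowed `Layout` and, crucially, add `start_offset` to every yielded index (with a `sum::<usize>()` turbofish to keep the accumulator type unambiguous), so iterating a narrowed view starts at the right place in the buffer. The storage offset of a multi-index is `dot(multi_index, stride) + start_offset`; a standalone sketch of that arithmetic, with illustrative names:

    // Illustrative: the offset formula used by the iterator above.
    fn storage_offset(multi_index: &[usize], stride: &[usize], start_offset: usize) -> usize {
        multi_index
            .iter()
            .zip(stride)
            .map(|(&i, &s)| i * s)
            .sum::<usize>()
            + start_offset
    }

    fn main() {
        // In a transposed 2x3 view (strides [1, 3]) starting 2 elements into
        // the buffer, element (1, 2) lives at 1*1 + 2*3 + 2 = 9.
        assert_eq!(storage_offset(&[1, 2], &[1, 3], 2), 9);
    }
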
@@ -481,13 +481,9 @@ impl Tensor {
         let ids_shape = ids.shape();
         let seq_len = ids_shape.r1()?;
         let (vocab_size, hidden_size) = rhs.shape().r2()?;
-        let storage = ids.storage.embedding_impl(
-            ids.layout(),
-            &ids.stride,
-            &rhs.storage,
-            hidden_size,
-            vocab_size,
-        )?;
+        let storage = ids
+            .storage
+            .embedding(ids.layout(), &rhs.storage, hidden_size, vocab_size)?;
         let shape: Shape = (seq_len, hidden_size).into();
         let op = if ids.track_op() || rhs.track_op() {
             Some(Op::Embedding(ids.clone(), rhs.clone()))
@@ -498,7 +494,7 @@ impl Tensor {
     }
 
     pub(crate) fn strided_index(&self) -> crate::StridedIndex {
-        crate::StridedIndex::new(self.dims(), self.stride())
+        self.layout.strided_index()
     }
 
     /// Returns data from the underlying storage, this does not take the strides
@@ -591,7 +587,7 @@ impl Tensor {
     }
 
     pub fn shape(&self) -> &Shape {
-        &self.shape
+        self.layout().shape()
     }
 
     pub fn dims(&self) -> &[usize] {
|
|||||||
/// Returns a tensor that is a transposed version of the input, the given dimensions are
|
/// Returns a tensor that is a transposed version of the input, the given dimensions are
|
||||||
/// swapped.
|
/// swapped.
|
||||||
pub fn transpose(&self, dim1: usize, dim2: usize) -> Result<Tensor> {
|
pub fn transpose(&self, dim1: usize, dim2: usize) -> Result<Tensor> {
|
||||||
let rank = self.rank();
|
|
||||||
if rank <= dim1 || rank <= dim2 {
|
|
||||||
return Err(Error::UnexpectedNumberOfDims {
|
|
||||||
expected: usize::max(dim1, dim2),
|
|
||||||
got: rank,
|
|
||||||
shape: self.shape().clone(),
|
|
||||||
});
|
|
||||||
}
|
|
||||||
let mut stride = self.stride().to_vec();
|
|
||||||
let mut dims = self.shape().dims().to_vec();
|
|
||||||
dims.swap(dim1, dim2);
|
|
||||||
stride.swap(dim1, dim2);
|
|
||||||
let op = if self.track_op() {
|
let op = if self.track_op() {
|
||||||
Some(Op::Transpose(self.clone(), dim1, dim2))
|
Some(Op::Transpose(self.clone(), dim1, dim2))
|
||||||
} else {
|
} else {
|
||||||
@ -702,8 +686,7 @@ impl Tensor {
|
|||||||
let tensor_ = Tensor_ {
|
let tensor_ = Tensor_ {
|
||||||
id: TensorId::new(),
|
id: TensorId::new(),
|
||||||
storage: self.storage.clone(),
|
storage: self.storage.clone(),
|
||||||
shape: Shape::from(dims),
|
layout: self.layout.transpose(dim1, dim2)?,
|
||||||
stride,
|
|
||||||
op,
|
op,
|
||||||
is_variable: false,
|
is_variable: false,
|
||||||
};
|
};
|
||||||
@@ -795,36 +778,10 @@ impl Tensor {
         } else {
             None
         };
-        let shape = shape.into();
-        if shape.rank() < self.rank() {
-            return Err(Error::BroadcastIncompatibleShapes {
-                src_shape: self.shape().clone(),
-                dst_shape: shape,
-            });
-        }
-        let added_dims = shape.rank() - self.rank();
-        let mut stride = vec![0; added_dims];
-        for (&dst_dim, (&src_dim, &src_stride)) in shape.dims()[added_dims..]
-            .iter()
-            .zip(self.dims().iter().zip(self.stride()))
-        {
-            let s = if dst_dim == src_dim {
-                src_stride
-            } else if src_dim != 1 {
-                return Err(Error::BroadcastIncompatibleShapes {
-                    src_shape: self.shape().clone(),
-                    dst_shape: shape,
-                });
-            } else {
-                0
-            };
-            stride.push(s)
-        }
         let tensor_ = Tensor_ {
             id: TensorId::new(),
             storage: self.storage.clone(),
-            shape,
-            stride,
+            layout: self.layout.broadcast_as(shape)?,
             op,
             is_variable: false,
         };
@@ -888,12 +845,10 @@ impl Tensor {
             None
         };
         if self.is_contiguous() {
-            let stride = shape.stride_contiguous();
             let tensor_ = Tensor_ {
                 id: TensorId::new(),
                 storage: self.storage.clone(),
-                shape,
-                stride,
+                layout: Layout::contiguous_with_offset(shape, self.layout.start_offset()),
                 op,
                 is_variable: false,
             };