mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Support wider shapes for llama.
This commit is contained in:
34
src/shape.rs
34
src/shape.rs
@ -9,20 +9,8 @@ impl std::fmt::Debug for Shape {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<&[usize; 1]> for Shape {
|
impl<const C: usize> From<&[usize; C]> for Shape {
|
||||||
fn from(dims: &[usize; 1]) -> Self {
|
fn from(dims: &[usize; C]) -> Self {
|
||||||
Self(dims.to_vec())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<&[usize; 2]> for Shape {
|
|
||||||
fn from(dims: &[usize; 2]) -> Self {
|
|
||||||
Self(dims.to_vec())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<&[usize; 3]> for Shape {
|
|
||||||
fn from(dims: &[usize; 3]) -> Self {
|
|
||||||
Self(dims.to_vec())
|
Self(dims.to_vec())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -63,6 +51,18 @@ impl From<(usize, usize, usize)> for Shape {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl From<(usize, usize, usize, usize)> for Shape {
|
||||||
|
fn from(d1234: (usize, usize, usize, usize)) -> Self {
|
||||||
|
Self(vec![d1234.0, d1234.1, d1234.2, d1234.3])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<(usize, usize, usize, usize, usize)> for Shape {
|
||||||
|
fn from(d12345: (usize, usize, usize, usize, usize)) -> Self {
|
||||||
|
Self(vec![d12345.0, d12345.1, d12345.2, d12345.3, d12345.4])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl From<Vec<usize>> for Shape {
|
impl From<Vec<usize>> for Shape {
|
||||||
fn from(dims: Vec<usize>) -> Self {
|
fn from(dims: Vec<usize>) -> Self {
|
||||||
Self(dims)
|
Self(dims)
|
||||||
@ -121,6 +121,12 @@ impl Shape {
|
|||||||
|d: &[usize]| (d[0], d[1], d[2], d[3]),
|
|d: &[usize]| (d[0], d[1], d[2], d[3]),
|
||||||
(usize, usize, usize, usize)
|
(usize, usize, usize, usize)
|
||||||
);
|
);
|
||||||
|
extract_dims!(
|
||||||
|
r5,
|
||||||
|
5,
|
||||||
|
|d: &[usize]| (d[0], d[1], d[2], d[3], d[4]),
|
||||||
|
(usize, usize, usize, usize, usize)
|
||||||
|
);
|
||||||
|
|
||||||
/// The strides given in number of elements for a contiguous n-dimensional
|
/// The strides given in number of elements for a contiguous n-dimensional
|
||||||
/// arrays using this shape.
|
/// arrays using this shape.
|
||||||
|
Reference in New Issue
Block a user