mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Support wider shapes for llama.
This commit is contained in:
34
src/shape.rs
34
src/shape.rs
@ -9,20 +9,8 @@ impl std::fmt::Debug for Shape {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&[usize; 1]> for Shape {
|
||||
fn from(dims: &[usize; 1]) -> Self {
|
||||
Self(dims.to_vec())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&[usize; 2]> for Shape {
|
||||
fn from(dims: &[usize; 2]) -> Self {
|
||||
Self(dims.to_vec())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&[usize; 3]> for Shape {
|
||||
fn from(dims: &[usize; 3]) -> Self {
|
||||
impl<const C: usize> From<&[usize; C]> for Shape {
|
||||
fn from(dims: &[usize; C]) -> Self {
|
||||
Self(dims.to_vec())
|
||||
}
|
||||
}
|
||||
@ -63,6 +51,18 @@ impl From<(usize, usize, usize)> for Shape {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<(usize, usize, usize, usize)> for Shape {
|
||||
fn from(d1234: (usize, usize, usize, usize)) -> Self {
|
||||
Self(vec![d1234.0, d1234.1, d1234.2, d1234.3])
|
||||
}
|
||||
}
|
||||
|
||||
impl From<(usize, usize, usize, usize, usize)> for Shape {
|
||||
fn from(d12345: (usize, usize, usize, usize, usize)) -> Self {
|
||||
Self(vec![d12345.0, d12345.1, d12345.2, d12345.3, d12345.4])
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Vec<usize>> for Shape {
|
||||
fn from(dims: Vec<usize>) -> Self {
|
||||
Self(dims)
|
||||
@ -121,6 +121,12 @@ impl Shape {
|
||||
|d: &[usize]| (d[0], d[1], d[2], d[3]),
|
||||
(usize, usize, usize, usize)
|
||||
);
|
||||
extract_dims!(
|
||||
r5,
|
||||
5,
|
||||
|d: &[usize]| (d[0], d[1], d[2], d[3], d[4]),
|
||||
(usize, usize, usize, usize, usize)
|
||||
);
|
||||
|
||||
/// The strides given in number of elements for a contiguous n-dimensional
|
||||
/// arrays using this shape.
|
||||
|
Reference in New Issue
Block a user