mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Handle arbitrary shapes in Tensor::new. (#718)
This commit is contained in:
@ -1618,6 +1618,90 @@ impl CpuStorage {
|
|||||||
/// Borrows the underlying CPU buffer as a typed slice `&[D]`.
///
/// Delegates to `D::cpu_storage_as_slice`; presumably this errors when `D`
/// does not match the dtype actually stored — TODO confirm in the
/// `WithDType` implementation.
pub fn as_slice<D: WithDType>(&self) -> Result<&[D]> {
    D::cpu_storage_as_slice(self)
}
|
pub fn concat(storages: &[CpuStorage]) -> Result<CpuStorage> {
|
||||||
|
let storage0 = &storages[0];
|
||||||
|
let s = match storage0 {
|
||||||
|
Self::U8(_) => {
|
||||||
|
let storages = storages
|
||||||
|
.iter()
|
||||||
|
.map(|s| match s {
|
||||||
|
Self::U8(s) => Ok(s.as_slice()),
|
||||||
|
_ => crate::bail!("dtype mismatch"),
|
||||||
|
})
|
||||||
|
.collect::<Result<Vec<_>>>()?
|
||||||
|
.concat();
|
||||||
|
Self::U8(storages)
|
||||||
|
}
|
||||||
|
Self::U32(_) => {
|
||||||
|
let storages = storages
|
||||||
|
.iter()
|
||||||
|
.map(|s| match s {
|
||||||
|
Self::U32(s) => Ok(s.as_slice()),
|
||||||
|
_ => crate::bail!("dtype mismatch"),
|
||||||
|
})
|
||||||
|
.collect::<Result<Vec<_>>>()?
|
||||||
|
.concat();
|
||||||
|
Self::U32(storages)
|
||||||
|
}
|
||||||
|
Self::I64(_) => {
|
||||||
|
let storages = storages
|
||||||
|
.iter()
|
||||||
|
.map(|s| match s {
|
||||||
|
Self::I64(s) => Ok(s.as_slice()),
|
||||||
|
_ => crate::bail!("dtype mismatch"),
|
||||||
|
})
|
||||||
|
.collect::<Result<Vec<_>>>()?
|
||||||
|
.concat();
|
||||||
|
Self::I64(storages)
|
||||||
|
}
|
||||||
|
Self::BF16(_) => {
|
||||||
|
let storages = storages
|
||||||
|
.iter()
|
||||||
|
.map(|s| match s {
|
||||||
|
Self::BF16(s) => Ok(s.as_slice()),
|
||||||
|
_ => crate::bail!("dtype mismatch"),
|
||||||
|
})
|
||||||
|
.collect::<Result<Vec<_>>>()?
|
||||||
|
.concat();
|
||||||
|
Self::BF16(storages)
|
||||||
|
}
|
||||||
|
Self::F16(_) => {
|
||||||
|
let storages = storages
|
||||||
|
.iter()
|
||||||
|
.map(|s| match s {
|
||||||
|
Self::F16(s) => Ok(s.as_slice()),
|
||||||
|
_ => crate::bail!("dtype mismatch"),
|
||||||
|
})
|
||||||
|
.collect::<Result<Vec<_>>>()?
|
||||||
|
.concat();
|
||||||
|
Self::F16(storages)
|
||||||
|
}
|
||||||
|
Self::F32(_) => {
|
||||||
|
let storages = storages
|
||||||
|
.iter()
|
||||||
|
.map(|s| match s {
|
||||||
|
Self::F32(s) => Ok(s.as_slice()),
|
||||||
|
_ => crate::bail!("dtype mismatch"),
|
||||||
|
})
|
||||||
|
.collect::<Result<Vec<_>>>()?
|
||||||
|
.concat();
|
||||||
|
Self::F32(storages)
|
||||||
|
}
|
||||||
|
Self::F64(_) => {
|
||||||
|
let storages = storages
|
||||||
|
.iter()
|
||||||
|
.map(|s| match s {
|
||||||
|
Self::F64(s) => Ok(s.as_slice()),
|
||||||
|
_ => crate::bail!("dtype mismatch"),
|
||||||
|
})
|
||||||
|
.collect::<Result<Vec<_>>>()?
|
||||||
|
.concat();
|
||||||
|
Self::F64(storages)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
Ok(s)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl BackendStorage for CpuStorage {
|
impl BackendStorage for CpuStorage {
|
||||||
|
@ -100,6 +100,29 @@ impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize, const N4:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<S: NdArray> NdArray for Vec<S> {
|
||||||
|
fn shape(&self) -> Result<Shape> {
|
||||||
|
if self.is_empty() {
|
||||||
|
crate::bail!("empty array")
|
||||||
|
}
|
||||||
|
let shape0 = self[0].shape()?;
|
||||||
|
let n = self.len();
|
||||||
|
for v in self.iter() {
|
||||||
|
let shape = v.shape()?;
|
||||||
|
if shape != shape0 {
|
||||||
|
crate::bail!("two elements have different shapes {shape:?} {shape0:?}")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(Shape::from([[n].as_slice(), shape0.dims()].concat()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn to_cpu_storage(&self) -> CpuStorage {
|
||||||
|
// This allocates intermediary memory and shouldn't be necessary.
|
||||||
|
let storages = self.iter().map(|v| v.to_cpu_storage()).collect::<Vec<_>>();
|
||||||
|
CpuStorage::concat(storages.as_slice()).unwrap()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl Device {
|
impl Device {
|
||||||
pub fn new_cuda(ordinal: usize) -> Result<Self> {
|
pub fn new_cuda(ordinal: usize) -> Result<Self> {
|
||||||
Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
|
Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
|
||||||
|
@ -205,14 +205,29 @@ impl PyTensor {
|
|||||||
Tensor::new(vs, &Cpu).map_err(wrap_err)?
|
Tensor::new(vs, &Cpu).map_err(wrap_err)?
|
||||||
} else if let Ok(vs) = vs.extract::<i64>(py) {
|
} else if let Ok(vs) = vs.extract::<i64>(py) {
|
||||||
Tensor::new(vs, &Cpu).map_err(wrap_err)?
|
Tensor::new(vs, &Cpu).map_err(wrap_err)?
|
||||||
} else if let Ok(vs) = vs.extract::<Vec<u32>>(py) {
|
|
||||||
Tensor::new(vs.as_slice(), &Cpu).map_err(wrap_err)?
|
|
||||||
} else if let Ok(vs) = vs.extract::<Vec<i64>>(py) {
|
|
||||||
Tensor::new(vs.as_slice(), &Cpu).map_err(wrap_err)?
|
|
||||||
} else if let Ok(vs) = vs.extract::<f32>(py) {
|
} else if let Ok(vs) = vs.extract::<f32>(py) {
|
||||||
Tensor::new(vs, &Cpu).map_err(wrap_err)?
|
Tensor::new(vs, &Cpu).map_err(wrap_err)?
|
||||||
|
} else if let Ok(vs) = vs.extract::<Vec<u32>>(py) {
|
||||||
|
let len = vs.len();
|
||||||
|
Tensor::from_vec(vs, len, &Cpu).map_err(wrap_err)?
|
||||||
|
} else if let Ok(vs) = vs.extract::<Vec<i64>>(py) {
|
||||||
|
let len = vs.len();
|
||||||
|
Tensor::from_vec(vs, len, &Cpu).map_err(wrap_err)?
|
||||||
} else if let Ok(vs) = vs.extract::<Vec<f32>>(py) {
|
} else if let Ok(vs) = vs.extract::<Vec<f32>>(py) {
|
||||||
Tensor::new(vs.as_slice(), &Cpu).map_err(wrap_err)?
|
let len = vs.len();
|
||||||
|
Tensor::from_vec(vs, len, &Cpu).map_err(wrap_err)?
|
||||||
|
} else if let Ok(vs) = vs.extract::<Vec<Vec<u32>>>(py) {
|
||||||
|
Tensor::new(vs, &Cpu).map_err(wrap_err)?
|
||||||
|
} else if let Ok(vs) = vs.extract::<Vec<Vec<i64>>>(py) {
|
||||||
|
Tensor::new(vs, &Cpu).map_err(wrap_err)?
|
||||||
|
} else if let Ok(vs) = vs.extract::<Vec<Vec<f32>>>(py) {
|
||||||
|
Tensor::new(vs, &Cpu).map_err(wrap_err)?
|
||||||
|
} else if let Ok(vs) = vs.extract::<Vec<Vec<Vec<u32>>>>(py) {
|
||||||
|
Tensor::new(vs, &Cpu).map_err(wrap_err)?
|
||||||
|
} else if let Ok(vs) = vs.extract::<Vec<Vec<Vec<i64>>>>(py) {
|
||||||
|
Tensor::new(vs, &Cpu).map_err(wrap_err)?
|
||||||
|
} else if let Ok(vs) = vs.extract::<Vec<Vec<Vec<f32>>>>(py) {
|
||||||
|
Tensor::new(vs, &Cpu).map_err(wrap_err)?
|
||||||
} else {
|
} else {
|
||||||
let ty = vs.as_ref(py).get_type();
|
let ty = vs.as_ref(py).get_type();
|
||||||
Err(PyTypeError::new_err(format!(
|
Err(PyTypeError::new_err(format!(
|
||||||
|
Reference in New Issue
Block a user