From 84d003ff530eb14597f33c8c763eeb573370e22e Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sat, 2 Sep 2023 20:59:21 +0200 Subject: [PATCH] Handle arbitrary shapes in Tensor::new. (#718) --- candle-core/src/cpu_backend.rs | 84 ++++++++++++++++++++++++++++++++++ candle-core/src/device.rs | 23 ++++++++++ candle-pyo3/src/lib.rs | 25 ++++++++-- 3 files changed, 127 insertions(+), 5 deletions(-) diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index 17d64b10..ed3dd3fc 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -1618,6 +1618,90 @@ impl CpuStorage { pub fn as_slice(&self) -> Result<&[D]> { D::cpu_storage_as_slice(self) } + + pub fn concat(storages: &[CpuStorage]) -> Result { + let storage0 = &storages[0]; + let s = match storage0 { + Self::U8(_) => { + let storages = storages + .iter() + .map(|s| match s { + Self::U8(s) => Ok(s.as_slice()), + _ => crate::bail!("dtype mismatch"), + }) + .collect::>>()? + .concat(); + Self::U8(storages) + } + Self::U32(_) => { + let storages = storages + .iter() + .map(|s| match s { + Self::U32(s) => Ok(s.as_slice()), + _ => crate::bail!("dtype mismatch"), + }) + .collect::>>()? + .concat(); + Self::U32(storages) + } + Self::I64(_) => { + let storages = storages + .iter() + .map(|s| match s { + Self::I64(s) => Ok(s.as_slice()), + _ => crate::bail!("dtype mismatch"), + }) + .collect::>>()? + .concat(); + Self::I64(storages) + } + Self::BF16(_) => { + let storages = storages + .iter() + .map(|s| match s { + Self::BF16(s) => Ok(s.as_slice()), + _ => crate::bail!("dtype mismatch"), + }) + .collect::>>()? + .concat(); + Self::BF16(storages) + } + Self::F16(_) => { + let storages = storages + .iter() + .map(|s| match s { + Self::F16(s) => Ok(s.as_slice()), + _ => crate::bail!("dtype mismatch"), + }) + .collect::>>()? + .concat(); + Self::F16(storages) + } + Self::F32(_) => { + let storages = storages + .iter() + .map(|s| match s { + Self::F32(s) => Ok(s.as_slice()), + _ => crate::bail!("dtype mismatch"), + }) + .collect::>>()? + .concat(); + Self::F32(storages) + } + Self::F64(_) => { + let storages = storages + .iter() + .map(|s| match s { + Self::F64(s) => Ok(s.as_slice()), + _ => crate::bail!("dtype mismatch"), + }) + .collect::>>()? + .concat(); + Self::F64(storages) + } + }; + Ok(s) + } } impl BackendStorage for CpuStorage { diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs index c057595f..0ed23a18 100644 --- a/candle-core/src/device.rs +++ b/candle-core/src/device.rs @@ -100,6 +100,29 @@ impl NdArray for Vec { + fn shape(&self) -> Result { + if self.is_empty() { + crate::bail!("empty array") + } + let shape0 = self[0].shape()?; + let n = self.len(); + for v in self.iter() { + let shape = v.shape()?; + if shape != shape0 { + crate::bail!("two elements have different shapes {shape:?} {shape0:?}") + } + } + Ok(Shape::from([[n].as_slice(), shape0.dims()].concat())) + } + + fn to_cpu_storage(&self) -> CpuStorage { + // This allocates intermediary memory and shouldn't be necessary. + let storages = self.iter().map(|v| v.to_cpu_storage()).collect::>(); + CpuStorage::concat(storages.as_slice()).unwrap() + } +} + impl Device { pub fn new_cuda(ordinal: usize) -> Result { Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?)) diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index 5e6f48ea..79f86479 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -205,14 +205,29 @@ impl PyTensor { Tensor::new(vs, &Cpu).map_err(wrap_err)? } else if let Ok(vs) = vs.extract::(py) { Tensor::new(vs, &Cpu).map_err(wrap_err)? - } else if let Ok(vs) = vs.extract::>(py) { - Tensor::new(vs.as_slice(), &Cpu).map_err(wrap_err)? - } else if let Ok(vs) = vs.extract::>(py) { - Tensor::new(vs.as_slice(), &Cpu).map_err(wrap_err)? } else if let Ok(vs) = vs.extract::(py) { Tensor::new(vs, &Cpu).map_err(wrap_err)? + } else if let Ok(vs) = vs.extract::>(py) { + let len = vs.len(); + Tensor::from_vec(vs, len, &Cpu).map_err(wrap_err)? + } else if let Ok(vs) = vs.extract::>(py) { + let len = vs.len(); + Tensor::from_vec(vs, len, &Cpu).map_err(wrap_err)? } else if let Ok(vs) = vs.extract::>(py) { - Tensor::new(vs.as_slice(), &Cpu).map_err(wrap_err)? + let len = vs.len(); + Tensor::from_vec(vs, len, &Cpu).map_err(wrap_err)? + } else if let Ok(vs) = vs.extract::>>(py) { + Tensor::new(vs, &Cpu).map_err(wrap_err)? + } else if let Ok(vs) = vs.extract::>>(py) { + Tensor::new(vs, &Cpu).map_err(wrap_err)? + } else if let Ok(vs) = vs.extract::>>(py) { + Tensor::new(vs, &Cpu).map_err(wrap_err)? + } else if let Ok(vs) = vs.extract::>>>(py) { + Tensor::new(vs, &Cpu).map_err(wrap_err)? + } else if let Ok(vs) = vs.extract::>>>(py) { + Tensor::new(vs, &Cpu).map_err(wrap_err)? + } else if let Ok(vs) = vs.extract::>>>(py) { + Tensor::new(vs, &Cpu).map_err(wrap_err)? } else { let ty = vs.as_ref(py).get_type(); Err(PyTypeError::new_err(format!(