Handle arbitrary shapes in Tensor::new. (#718)

Laurent Mazare
2023-09-02 20:59:21 +02:00
committed by GitHub
parent 21109e1983
commit 84d003ff53
3 changed files with 127 additions and 5 deletions
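
In rough terms, this change lets Tensor::new accept nested Vecs whose dimensions are only known at runtime, not just scalars, slices, and fixed-size arrays. A minimal usage sketch, assuming the candle_core crate name and Device::Cpu used elsewhere in the repository:

use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    // Nested Vecs with runtime-determined lengths are now accepted,
    // not only fixed-size arrays and slices.
    let rows: Vec<Vec<f32>> = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]];
    let t = Tensor::new(rows, &Device::Cpu)?;
    assert_eq!(t.dims(), &[2, 3]);
    Ok(())
}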


@@ -1618,6 +1618,90 @@ impl CpuStorage {
pub fn as_slice<D: WithDType>(&self) -> Result<&[D]> {
D::cpu_storage_as_slice(self)
}
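    /// Concatenate storages that all hold the same dtype into one contiguous
    /// storage, returning a "dtype mismatch" error otherwise.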
pub fn concat(storages: &[CpuStorage]) -> Result<CpuStorage> {
let storage0 = &storages[0];
let s = match storage0 {
Self::U8(_) => {
let storages = storages
.iter()
.map(|s| match s {
Self::U8(s) => Ok(s.as_slice()),
_ => crate::bail!("dtype mismatch"),
})
.collect::<Result<Vec<_>>>()?
.concat();
Self::U8(storages)
}
Self::U32(_) => {
let storages = storages
.iter()
.map(|s| match s {
Self::U32(s) => Ok(s.as_slice()),
_ => crate::bail!("dtype mismatch"),
})
.collect::<Result<Vec<_>>>()?
.concat();
Self::U32(storages)
}
Self::I64(_) => {
let storages = storages
.iter()
.map(|s| match s {
Self::I64(s) => Ok(s.as_slice()),
_ => crate::bail!("dtype mismatch"),
})
.collect::<Result<Vec<_>>>()?
.concat();
Self::I64(storages)
}
Self::BF16(_) => {
let storages = storages
.iter()
.map(|s| match s {
Self::BF16(s) => Ok(s.as_slice()),
_ => crate::bail!("dtype mismatch"),
})
.collect::<Result<Vec<_>>>()?
.concat();
Self::BF16(storages)
}
Self::F16(_) => {
let storages = storages
.iter()
.map(|s| match s {
Self::F16(s) => Ok(s.as_slice()),
_ => crate::bail!("dtype mismatch"),
})
.collect::<Result<Vec<_>>>()?
.concat();
Self::F16(storages)
}
Self::F32(_) => {
let storages = storages
.iter()
.map(|s| match s {
Self::F32(s) => Ok(s.as_slice()),
_ => crate::bail!("dtype mismatch"),
})
.collect::<Result<Vec<_>>>()?
.concat();
Self::F32(storages)
}
Self::F64(_) => {
let storages = storages
.iter()
.map(|s| match s {
Self::F64(s) => Ok(s.as_slice()),
_ => crate::bail!("dtype mismatch"),
})
.collect::<Result<Vec<_>>>()?
.concat();
Self::F64(storages)
}
};
Ok(s)
}
}
impl BackendStorage for CpuStorage {
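
A small sketch of how the new CpuStorage::concat helper behaves, assuming CpuStorage and its variants are re-exported at the candle_core crate root:

use candle_core::CpuStorage;

fn main() -> candle_core::Result<()> {
    // Same-dtype storages are concatenated into a single buffer.
    let a = CpuStorage::F32(vec![1.0, 2.0]);
    let b = CpuStorage::F32(vec![3.0]);
    match CpuStorage::concat(&[a, b])? {
        CpuStorage::F32(v) => assert_eq!(v, vec![1.0f32, 2.0, 3.0]),
        _ => unreachable!(),
    }
    // Mixing dtypes hits the bail!("dtype mismatch") branch and returns an error.
    let bad = CpuStorage::concat(&[CpuStorage::F32(vec![0.0]), CpuStorage::U8(vec![1])]);
    assert!(bad.is_err());
    Ok(())
}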


@@ -100,6 +100,29 @@ impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize, const N4:
}
}
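// Blanket impl so nested `Vec`s (e.g. Vec<Vec<f32>>) with runtime-determined
// lengths can be passed to Tensor::new.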
impl<S: NdArray> NdArray for Vec<S> {
fn shape(&self) -> Result<Shape> {
if self.is_empty() {
crate::bail!("empty array")
}
let shape0 = self[0].shape()?;
let n = self.len();
for v in self.iter() {
let shape = v.shape()?;
if shape != shape0 {
crate::bail!("two elements have different shapes {shape:?} {shape0:?}")
}
}
Ok(Shape::from([[n].as_slice(), shape0.dims()].concat()))
}
fn to_cpu_storage(&self) -> CpuStorage {
// This allocates intermediary memory and shouldn't be necessary.
let storages = self.iter().map(|v| v.to_cpu_storage()).collect::<Vec<_>>();
CpuStorage::concat(storages.as_slice()).unwrap()
}
}
impl Device {
pub fn new_cuda(ordinal: usize) -> Result<Self> {
Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
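
Because shape() above requires every element of the Vec to share the same shape, ragged nested vectors are rejected at construction time. A short sketch under the same assumptions as above:

use candle_core::{Device, Tensor};

fn main() {
    // Inner vectors of different lengths trigger the
    // "two elements have different shapes" error from shape().
    let ragged: Vec<Vec<f32>> = vec![vec![1.0, 2.0], vec![3.0]];
    assert!(Tensor::new(ragged, &Device::Cpu).is_err());
}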


@@ -205,14 +205,29 @@ impl PyTensor {
             Tensor::new(vs, &Cpu).map_err(wrap_err)?
         } else if let Ok(vs) = vs.extract::<i64>(py) {
             Tensor::new(vs, &Cpu).map_err(wrap_err)?
-        } else if let Ok(vs) = vs.extract::<Vec<u32>>(py) {
-            Tensor::new(vs.as_slice(), &Cpu).map_err(wrap_err)?
-        } else if let Ok(vs) = vs.extract::<Vec<i64>>(py) {
-            Tensor::new(vs.as_slice(), &Cpu).map_err(wrap_err)?
         } else if let Ok(vs) = vs.extract::<f32>(py) {
             Tensor::new(vs, &Cpu).map_err(wrap_err)?
+        } else if let Ok(vs) = vs.extract::<Vec<u32>>(py) {
+            let len = vs.len();
+            Tensor::from_vec(vs, len, &Cpu).map_err(wrap_err)?
+        } else if let Ok(vs) = vs.extract::<Vec<i64>>(py) {
+            let len = vs.len();
+            Tensor::from_vec(vs, len, &Cpu).map_err(wrap_err)?
         } else if let Ok(vs) = vs.extract::<Vec<f32>>(py) {
-            Tensor::new(vs.as_slice(), &Cpu).map_err(wrap_err)?
+            let len = vs.len();
+            Tensor::from_vec(vs, len, &Cpu).map_err(wrap_err)?
+        } else if let Ok(vs) = vs.extract::<Vec<Vec<u32>>>(py) {
+            Tensor::new(vs, &Cpu).map_err(wrap_err)?
+        } else if let Ok(vs) = vs.extract::<Vec<Vec<i64>>>(py) {
+            Tensor::new(vs, &Cpu).map_err(wrap_err)?
+        } else if let Ok(vs) = vs.extract::<Vec<Vec<f32>>>(py) {
+            Tensor::new(vs, &Cpu).map_err(wrap_err)?
+        } else if let Ok(vs) = vs.extract::<Vec<Vec<Vec<u32>>>>(py) {
+            Tensor::new(vs, &Cpu).map_err(wrap_err)?
+        } else if let Ok(vs) = vs.extract::<Vec<Vec<Vec<i64>>>>(py) {
+            Tensor::new(vs, &Cpu).map_err(wrap_err)?
+        } else if let Ok(vs) = vs.extract::<Vec<Vec<Vec<f32>>>>(py) {
+            Tensor::new(vs, &Cpu).map_err(wrap_err)?
         } else {
             let ty = vs.as_ref(py).get_type();
             Err(PyTypeError::new_err(format!(
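
The flat-list branches above now go through Tensor::from_vec with an explicit length instead of Tensor::new on a slice. A minimal Rust-side sketch of that call, assuming candle_core as above:

use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    // from_vec takes ownership of the data plus an explicit shape; a plain
    // usize is accepted as a one-dimensional shape, as in the binding above.
    let v: Vec<i64> = vec![1, 2, 3, 4];
    let len = v.len();
    let t = Tensor::from_vec(v, len, &Device::Cpu)?;
    assert_eq!(t.dims(), &[4]);
    Ok(())
}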