Handle arbitrary shapes in Tensor::new. (#718)

This commit is contained in:
Laurent Mazare
2023-09-02 20:59:21 +02:00
committed by GitHub
parent 21109e1983
commit 84d003ff53
3 changed files with 127 additions and 5 deletions

View File

@ -1618,6 +1618,90 @@ impl CpuStorage {
pub fn as_slice<D: WithDType>(&self) -> Result<&[D]> {
D::cpu_storage_as_slice(self)
}
pub fn concat(storages: &[CpuStorage]) -> Result<CpuStorage> {
let storage0 = &storages[0];
let s = match storage0 {
Self::U8(_) => {
let storages = storages
.iter()
.map(|s| match s {
Self::U8(s) => Ok(s.as_slice()),
_ => crate::bail!("dtype mismatch"),
})
.collect::<Result<Vec<_>>>()?
.concat();
Self::U8(storages)
}
Self::U32(_) => {
let storages = storages
.iter()
.map(|s| match s {
Self::U32(s) => Ok(s.as_slice()),
_ => crate::bail!("dtype mismatch"),
})
.collect::<Result<Vec<_>>>()?
.concat();
Self::U32(storages)
}
Self::I64(_) => {
let storages = storages
.iter()
.map(|s| match s {
Self::I64(s) => Ok(s.as_slice()),
_ => crate::bail!("dtype mismatch"),
})
.collect::<Result<Vec<_>>>()?
.concat();
Self::I64(storages)
}
Self::BF16(_) => {
let storages = storages
.iter()
.map(|s| match s {
Self::BF16(s) => Ok(s.as_slice()),
_ => crate::bail!("dtype mismatch"),
})
.collect::<Result<Vec<_>>>()?
.concat();
Self::BF16(storages)
}
Self::F16(_) => {
let storages = storages
.iter()
.map(|s| match s {
Self::F16(s) => Ok(s.as_slice()),
_ => crate::bail!("dtype mismatch"),
})
.collect::<Result<Vec<_>>>()?
.concat();
Self::F16(storages)
}
Self::F32(_) => {
let storages = storages
.iter()
.map(|s| match s {
Self::F32(s) => Ok(s.as_slice()),
_ => crate::bail!("dtype mismatch"),
})
.collect::<Result<Vec<_>>>()?
.concat();
Self::F32(storages)
}
Self::F64(_) => {
let storages = storages
.iter()
.map(|s| match s {
Self::F64(s) => Ok(s.as_slice()),
_ => crate::bail!("dtype mismatch"),
})
.collect::<Result<Vec<_>>>()?
.concat();
Self::F64(storages)
}
};
Ok(s)
}
}
impl BackendStorage for CpuStorage {