mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Handle arbitrary shapes in Tensor::new. (#718)
This commit is contained in:
@ -1618,6 +1618,90 @@ impl CpuStorage {
|
||||
pub fn as_slice<D: WithDType>(&self) -> Result<&[D]> {
|
||||
D::cpu_storage_as_slice(self)
|
||||
}
|
||||
|
||||
pub fn concat(storages: &[CpuStorage]) -> Result<CpuStorage> {
|
||||
let storage0 = &storages[0];
|
||||
let s = match storage0 {
|
||||
Self::U8(_) => {
|
||||
let storages = storages
|
||||
.iter()
|
||||
.map(|s| match s {
|
||||
Self::U8(s) => Ok(s.as_slice()),
|
||||
_ => crate::bail!("dtype mismatch"),
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?
|
||||
.concat();
|
||||
Self::U8(storages)
|
||||
}
|
||||
Self::U32(_) => {
|
||||
let storages = storages
|
||||
.iter()
|
||||
.map(|s| match s {
|
||||
Self::U32(s) => Ok(s.as_slice()),
|
||||
_ => crate::bail!("dtype mismatch"),
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?
|
||||
.concat();
|
||||
Self::U32(storages)
|
||||
}
|
||||
Self::I64(_) => {
|
||||
let storages = storages
|
||||
.iter()
|
||||
.map(|s| match s {
|
||||
Self::I64(s) => Ok(s.as_slice()),
|
||||
_ => crate::bail!("dtype mismatch"),
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?
|
||||
.concat();
|
||||
Self::I64(storages)
|
||||
}
|
||||
Self::BF16(_) => {
|
||||
let storages = storages
|
||||
.iter()
|
||||
.map(|s| match s {
|
||||
Self::BF16(s) => Ok(s.as_slice()),
|
||||
_ => crate::bail!("dtype mismatch"),
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?
|
||||
.concat();
|
||||
Self::BF16(storages)
|
||||
}
|
||||
Self::F16(_) => {
|
||||
let storages = storages
|
||||
.iter()
|
||||
.map(|s| match s {
|
||||
Self::F16(s) => Ok(s.as_slice()),
|
||||
_ => crate::bail!("dtype mismatch"),
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?
|
||||
.concat();
|
||||
Self::F16(storages)
|
||||
}
|
||||
Self::F32(_) => {
|
||||
let storages = storages
|
||||
.iter()
|
||||
.map(|s| match s {
|
||||
Self::F32(s) => Ok(s.as_slice()),
|
||||
_ => crate::bail!("dtype mismatch"),
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?
|
||||
.concat();
|
||||
Self::F32(storages)
|
||||
}
|
||||
Self::F64(_) => {
|
||||
let storages = storages
|
||||
.iter()
|
||||
.map(|s| match s {
|
||||
Self::F64(s) => Ok(s.as_slice()),
|
||||
_ => crate::bail!("dtype mismatch"),
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?
|
||||
.concat();
|
||||
Self::F64(storages)
|
||||
}
|
||||
};
|
||||
Ok(s)
|
||||
}
|
||||
}
|
||||
|
||||
impl BackendStorage for CpuStorage {
|
||||
|
Reference in New Issue
Block a user