mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
More cleanup.
This commit is contained in:
@ -1863,10 +1863,7 @@ impl Tensor {
|
|||||||
Storage::Metal(metal.storage_from_cpu_storage(storage)?)
|
Storage::Metal(metal.storage_from_cpu_storage(storage)?)
|
||||||
}
|
}
|
||||||
(Storage::Cuda(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
|
(Storage::Cuda(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
|
||||||
(Storage::Metal(storage), Device::Cpu) => {
|
(Storage::Metal(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
|
||||||
// println!("{storage:?} - {:?}", storage.to_cpu_storage()?);
|
|
||||||
Storage::Cpu(storage.to_cpu_storage()?)
|
|
||||||
}
|
|
||||||
(Storage::Cuda(storage), Device::Cuda(cuda)) => {
|
(Storage::Cuda(storage), Device::Cuda(cuda)) => {
|
||||||
// TODO: Avoid passing through the cpu storage here, especially if the gpu ids
|
// TODO: Avoid passing through the cpu storage here, especially if the gpu ids
|
||||||
// are the same.
|
// are the same.
|
||||||
|
@ -900,9 +900,7 @@ fn matmul(device: &Device) -> Result<()> {
|
|||||||
let b = Tensor::from_slice(&data, (2, 2), device)?;
|
let b = Tensor::from_slice(&data, (2, 2), device)?;
|
||||||
|
|
||||||
let c = a.matmul(&b)?;
|
let c = a.matmul(&b)?;
|
||||||
let d = a.matmul(&c)?;
|
|
||||||
assert_eq!(c.to_vec2::<f32>()?, &[[7.0f32, 10.0], [15.0, 22.0]]);
|
assert_eq!(c.to_vec2::<f32>()?, &[[7.0f32, 10.0], [15.0, 22.0]]);
|
||||||
assert_eq!(d.to_vec2::<f32>()?, &[[37.0, 54.0], [81.0, 118.0]]);
|
|
||||||
|
|
||||||
let data = vec![1.0f32, 2.0];
|
let data = vec![1.0f32, 2.0];
|
||||||
let a = Tensor::from_slice(&data, (2, 1), device)?;
|
let a = Tensor::from_slice(&data, (2, 1), device)?;
|
||||||
|
Reference in New Issue
Block a user