Fixing tests + matmul from MFA

This commit is contained in:
Nicolas Patry
2023-12-13 16:58:36 +01:00
parent 0404a3eb5b
commit 931432ed55
5 changed files with 128 additions and 23 deletions

View File

@ -276,17 +276,17 @@ impl BackendStorage for MetalStorage {
self.device.wait_until_completed();
match self.dtype {
DType::U8 => Ok(CpuStorage::U8(buffer.read_to_vec(length / size))),
DType::U32 => Ok(CpuStorage::U32(buffer.read_to_vec(length / size))),
DType::I64 => Ok(CpuStorage::I64(buffer.read_to_vec(length / size))),
DType::F16 => Ok(CpuStorage::F16(buffer.read_to_vec(length / size))),
DType::BF16 => Ok(CpuStorage::BF16(buffer.read_to_vec(length / size))),
DType::U8 => Ok(CpuStorage::U8(read_to_vec(&buffer, length / size))),
DType::U32 => Ok(CpuStorage::U32(read_to_vec(&buffer, length / size))),
DType::I64 => Ok(CpuStorage::I64(read_to_vec(&buffer, length / size))),
DType::F16 => Ok(CpuStorage::F16(read_to_vec(&buffer, length / size))),
DType::BF16 => Ok(CpuStorage::BF16(read_to_vec(&buffer, length / size))),
DType::F32 => {
let vec = buffer.read_to_vec(length / size);
let vec = read_to_vec(&buffer, length / size);
// println!("Got back {:?}", &vec[..1]);
Ok(CpuStorage::F32(vec))
}
DType::F64 => Ok(CpuStorage::F64(buffer.read_to_vec(length / size))),
DType::F64 => Ok(CpuStorage::F64(read_to_vec(&buffer, length / size))),
}
}
@ -944,6 +944,8 @@ impl BackendStorage for MetalStorage {
};
let command_buffer = self.device.command_buffer();
// println!("MATMUL {b} {m} {n} {k}");
// println!("strides {:?} {:?}", lhs_l.stride(), rhs_l.stride());
command_buffer.set_label("matmul");
candle_metal_kernels::call_gemm(
&self.device.device,
@ -952,16 +954,17 @@ impl BackendStorage for MetalStorage {
name,
(b, m, n, k),
&lhs_l.stride(),
lhs_l.start_offset(),
lhs_l.start_offset() * self.dtype.size_in_bytes(),
&self.buffer,
&rhs_l.stride(),
rhs_l.start_offset(),
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
&rhs.buffer,
&buffer,
)
.map_err(MetalError::from)?;
// Create kernel
command_buffer.commit();
self.device.wait_until_completed();
Ok(Self::new(buffer, self.device.clone(), self.dtype()))
}
@ -1138,3 +1141,10 @@ impl BackendDevice for MetalDevice {
self.storage_from_cpu_storage(&cpu_storage)
}
}
/// Copies the first `n` elements of type `T` out of `buffer` into a new `Vec`.
///
/// The bytes returned by `buffer.contents()` are reinterpreted as a contiguous
/// `[T]` and cloned element by element. The caller chooses `T`; it must match
/// the element type that was actually written to the buffer.
///
/// Returns an empty `Vec` without touching the buffer when `n == 0`.
///
/// # Panics
///
/// Panics if `buffer.contents()` is null, if that pointer is not aligned for
/// `T`, or if `n` elements of `T` do not fit in `buffer.length()` bytes.
fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
    // Nothing to read: avoid building a slice from a possibly-null pointer.
    if n == 0 {
        return Vec::new();
    }

    let ptr = buffer.contents() as *const T;
    assert!(!ptr.is_null());

    // `from_raw_parts` requires the pointer to be aligned for `T`.
    assert!(
        (ptr as usize) % std::mem::align_of::<T>() == 0,
        "buffer contents are not aligned for the requested element type"
    );

    // Guard against reading past the end of the allocation: the requested
    // element count, in bytes, must fit in the buffer.
    let wanted_bytes = n
        .checked_mul(std::mem::size_of::<T>())
        .expect("requested element count overflows usize");
    // If the length cannot be represented as `usize`, every `usize` byte count
    // fits, so saturating is the conservative-and-correct choice.
    let available_bytes = usize::try_from(buffer.length()).unwrap_or(usize::MAX);
    assert!(
        wanted_bytes <= available_bytes,
        "read of {wanted_bytes} bytes exceeds buffer length of {available_bytes} bytes"
    );

    // SAFETY: `ptr` is non-null and aligned for `T`, and the `n` elements
    // starting at `ptr` lie within the `buffer.length()` bytes checked above.
    // The slice is only borrowed for the duration of the copy below.
    let slice = unsafe { std::slice::from_raw_parts(ptr, n) };
    slice.to_vec()
}