This commit is contained in:
Nicolas Patry
2023-12-15 01:41:14 +01:00
parent ece4c69a68
commit 40c3e1bd5a

View File

@ -91,7 +91,7 @@ impl MetalDevice {
metal::MTLCommandBufferStatus::Committed
| metal::MTLCommandBufferStatus::Scheduled
| metal::MTLCommandBufferStatus::Completed => {
panic!("Alredy committed");
panic!("Already committed");
}
_ => {}
}
@ -166,9 +166,6 @@ impl MetalDevice {
blit.copy_from_buffer(&tmp, 0, &real, 0, tmp.length());
blit.update_fence(&self.fence);
blit.end_encoding();
// drop(command_buffer);
// real.did_modify_range(metal::NSRange::new(0, real.length()));
// println!("Command {:?}", command.status());
// This is necessary for mmapped safetensors
// because of the unsafe slice cast we're doing.
@ -245,11 +242,7 @@ impl BackendStorage for MetalStorage {
DType::I64 => Ok(CpuStorage::I64(read_to_vec(&buffer, length / size))),
DType::F16 => Ok(CpuStorage::F16(read_to_vec(&buffer, length / size))),
DType::BF16 => Ok(CpuStorage::BF16(read_to_vec(&buffer, length / size))),
DType::F32 => {
let vec = read_to_vec(&buffer, length / size);
// println!("Got back {:?}", &vec[..1]);
Ok(CpuStorage::F32(vec))
}
DType::F32 => Ok(CpuStorage::F32(read_to_vec(&buffer, length / size))),
DType::F64 => Ok(CpuStorage::F64(read_to_vec(&buffer, length / size))),
}
}
@ -302,7 +295,6 @@ impl BackendStorage for MetalStorage {
)
.map_err(MetalError::from)?;
}
// buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
Ok(Self::new(buffer, device.clone(), dtype))
}
@ -401,7 +393,6 @@ impl BackendStorage for MetalStorage {
)
.map_err(MetalError::from)?;
}
buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
Ok(Self::new(buffer, device.clone(), dtype))
}
@ -644,21 +635,13 @@ impl BackendStorage for MetalStorage {
let kernel_name = match (B::KERNEL, dtype) {
("add", DType::F32) => contiguous::add::FLOAT,
// ("badd", DType::F32) => contiguous::add::FLOAT,
("sub", DType::F32) => contiguous::sub::FLOAT,
//("bsub", DType::F32) => contiguous::sub::FLOAT,
("mul", DType::F32) => contiguous::mul::FLOAT,
// ("bmul", DType::F32) => contiguous::mul::FLOAT,
("div", DType::F32) => contiguous::div::FLOAT,
// ("bdiv", DType::F32) => contiguous::div::FLOAT,
("add", DType::F16) => contiguous::add::HALF,
// ("badd", DType::F16) => contiguous::add::HALF,
("sub", DType::F16) => contiguous::sub::HALF,
// ("bsub", DType::F16) => contiguous::sub::HALF,
("mul", DType::F16) => contiguous::mul::HALF,
// ("bmul", DType::F16) => contiguous::mul::HALF,
("div", DType::F16) => contiguous::div::HALF,
// ("bdiv", DType::F16) => contiguous::div::HALF,
(name, dtype) => crate::bail!("Match {name} - {dtype:?}"),
};
candle_metal_kernels::call_binary_contiguous(
@ -877,8 +860,6 @@ impl BackendStorage for MetalStorage {
lhs_l: &Layout,
rhs_l: &Layout,
) -> Result<Self> {
// Create descriptors
let buffer = self.device.new_buffer(b * m * n, self.dtype, "matmul");
let name = match self.dtype {
DType::F32 => "sgemm",
@ -889,8 +870,6 @@ impl BackendStorage for MetalStorage {
};
let command_buffer = self.device.command_buffer();
// println!("MATMUL {b} {m} {n} {k}");
// println!("strides {:?} {:?}", lhs_l.stride(), rhs_l.stride());
command_buffer.set_label("matmul");
candle_metal_kernels::call_gemm(
&self.device.device,
@ -907,14 +886,11 @@ impl BackendStorage for MetalStorage {
&buffer,
)
.map_err(MetalError::from)?;
// Create kernel
Ok(Self::new(buffer, self.device.clone(), self.dtype()))
}
fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
let command_buffer = self.device.command_buffer();
// println!("Copy strided");
if src_l.is_contiguous() && self.dtype == dst.dtype() {
command_buffer.set_label("copy_contiguous");
let blit = command_buffer.new_blit_command_encoder();
@ -975,7 +951,6 @@ impl BackendDevice for MetalDevice {
type Storage = MetalStorage;
fn new(ordinal: usize) -> Result<Self> {
// println!("CREATING DEVICE");
let device = metal::Device::all().swap_remove(ordinal);
let n = 1;
@ -1024,6 +999,7 @@ impl BackendDevice for MetalDevice {
let command_buffer = self.command_buffer();
command_buffer.set_label("zeros");
let blit = command_buffer.new_blit_command_encoder();
blit.wait_for_fence(&self.fence);
blit.fill_buffer(
&buffer,
metal::NSRange {
@ -1032,6 +1008,7 @@ impl BackendDevice for MetalDevice {
},
0,
);
blit.update_fence(&self.fence);
blit.end_encoding();
Ok(MetalStorage::new(buffer, self.clone(), dtype))
}