mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
cleanup.
This commit is contained in:
@ -91,7 +91,7 @@ impl MetalDevice {
|
|||||||
metal::MTLCommandBufferStatus::Committed
|
metal::MTLCommandBufferStatus::Committed
|
||||||
| metal::MTLCommandBufferStatus::Scheduled
|
| metal::MTLCommandBufferStatus::Scheduled
|
||||||
| metal::MTLCommandBufferStatus::Completed => {
|
| metal::MTLCommandBufferStatus::Completed => {
|
||||||
panic!("Alredy committed");
|
panic!("Already committed");
|
||||||
}
|
}
|
||||||
_ => {}
|
_ => {}
|
||||||
}
|
}
|
||||||
@ -166,9 +166,6 @@ impl MetalDevice {
|
|||||||
blit.copy_from_buffer(&tmp, 0, &real, 0, tmp.length());
|
blit.copy_from_buffer(&tmp, 0, &real, 0, tmp.length());
|
||||||
blit.update_fence(&self.fence);
|
blit.update_fence(&self.fence);
|
||||||
blit.end_encoding();
|
blit.end_encoding();
|
||||||
// drop(command_buffer);
|
|
||||||
// real.did_modify_range(metal::NSRange::new(0, real.length()));
|
|
||||||
// println!("Command {:?}", command.status());
|
|
||||||
|
|
||||||
// This is necessary, for mmaped safetensors
|
// This is necessary, for mmaped safetensors
|
||||||
// Because of the unsafe slice cast we're doing.
|
// Because of the unsafe slice cast we're doing.
|
||||||
@ -245,11 +242,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
DType::I64 => Ok(CpuStorage::I64(read_to_vec(&buffer, length / size))),
|
DType::I64 => Ok(CpuStorage::I64(read_to_vec(&buffer, length / size))),
|
||||||
DType::F16 => Ok(CpuStorage::F16(read_to_vec(&buffer, length / size))),
|
DType::F16 => Ok(CpuStorage::F16(read_to_vec(&buffer, length / size))),
|
||||||
DType::BF16 => Ok(CpuStorage::BF16(read_to_vec(&buffer, length / size))),
|
DType::BF16 => Ok(CpuStorage::BF16(read_to_vec(&buffer, length / size))),
|
||||||
DType::F32 => {
|
DType::F32 => Ok(CpuStorage::F32(read_to_vec(&buffer, length / size))),
|
||||||
let vec = read_to_vec(&buffer, length / size);
|
|
||||||
// println!("Got back {:?}", &vec[..1]);
|
|
||||||
Ok(CpuStorage::F32(vec))
|
|
||||||
}
|
|
||||||
DType::F64 => Ok(CpuStorage::F64(read_to_vec(&buffer, length / size))),
|
DType::F64 => Ok(CpuStorage::F64(read_to_vec(&buffer, length / size))),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -302,7 +295,6 @@ impl BackendStorage for MetalStorage {
|
|||||||
)
|
)
|
||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
}
|
}
|
||||||
// buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
|
|
||||||
Ok(Self::new(buffer, device.clone(), dtype))
|
Ok(Self::new(buffer, device.clone(), dtype))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -401,7 +393,6 @@ impl BackendStorage for MetalStorage {
|
|||||||
)
|
)
|
||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
}
|
}
|
||||||
buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
|
|
||||||
Ok(Self::new(buffer, device.clone(), dtype))
|
Ok(Self::new(buffer, device.clone(), dtype))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -644,21 +635,13 @@ impl BackendStorage for MetalStorage {
|
|||||||
|
|
||||||
let kernel_name = match (B::KERNEL, dtype) {
|
let kernel_name = match (B::KERNEL, dtype) {
|
||||||
("add", DType::F32) => contiguous::add::FLOAT,
|
("add", DType::F32) => contiguous::add::FLOAT,
|
||||||
// ("badd", DType::F32) => contiguous::add::FLOAT,
|
|
||||||
("sub", DType::F32) => contiguous::sub::FLOAT,
|
("sub", DType::F32) => contiguous::sub::FLOAT,
|
||||||
//("bsub", DType::F32) => contiguous::sub::FLOAT,
|
|
||||||
("mul", DType::F32) => contiguous::mul::FLOAT,
|
("mul", DType::F32) => contiguous::mul::FLOAT,
|
||||||
// ("bmul", DType::F32) => contiguous::mul::FLOAT,
|
|
||||||
("div", DType::F32) => contiguous::div::FLOAT,
|
("div", DType::F32) => contiguous::div::FLOAT,
|
||||||
// ("bdiv", DType::F32) => contiguous::div::FLOAT,
|
|
||||||
("add", DType::F16) => contiguous::add::HALF,
|
("add", DType::F16) => contiguous::add::HALF,
|
||||||
// ("badd", DType::F16) => contiguous::add::HALF,
|
|
||||||
("sub", DType::F16) => contiguous::sub::HALF,
|
("sub", DType::F16) => contiguous::sub::HALF,
|
||||||
// ("bsub", DType::F16) => contiguous::sub::HALF,
|
|
||||||
("mul", DType::F16) => contiguous::mul::HALF,
|
("mul", DType::F16) => contiguous::mul::HALF,
|
||||||
// ("bmul", DType::F16) => contiguous::mul::HALF,
|
|
||||||
("div", DType::F16) => contiguous::div::HALF,
|
("div", DType::F16) => contiguous::div::HALF,
|
||||||
// ("bdiv", DType::F16) => contiguous::div::HALF,
|
|
||||||
(name, dtype) => crate::bail!("Match {name} - {dtype:?}"),
|
(name, dtype) => crate::bail!("Match {name} - {dtype:?}"),
|
||||||
};
|
};
|
||||||
candle_metal_kernels::call_binary_contiguous(
|
candle_metal_kernels::call_binary_contiguous(
|
||||||
@ -877,8 +860,6 @@ impl BackendStorage for MetalStorage {
|
|||||||
lhs_l: &Layout,
|
lhs_l: &Layout,
|
||||||
rhs_l: &Layout,
|
rhs_l: &Layout,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
// Create descriptors
|
|
||||||
|
|
||||||
let buffer = self.device.new_buffer(b * m * n, self.dtype, "matmul");
|
let buffer = self.device.new_buffer(b * m * n, self.dtype, "matmul");
|
||||||
let name = match self.dtype {
|
let name = match self.dtype {
|
||||||
DType::F32 => "sgemm",
|
DType::F32 => "sgemm",
|
||||||
@ -889,8 +870,6 @@ impl BackendStorage for MetalStorage {
|
|||||||
};
|
};
|
||||||
|
|
||||||
let command_buffer = self.device.command_buffer();
|
let command_buffer = self.device.command_buffer();
|
||||||
// println!("MATMUL {b} {m} {n} {k}");
|
|
||||||
// println!("strides {:?} {:?}", lhs_l.stride(), rhs_l.stride());
|
|
||||||
command_buffer.set_label("matmul");
|
command_buffer.set_label("matmul");
|
||||||
candle_metal_kernels::call_gemm(
|
candle_metal_kernels::call_gemm(
|
||||||
&self.device.device,
|
&self.device.device,
|
||||||
@ -907,14 +886,11 @@ impl BackendStorage for MetalStorage {
|
|||||||
&buffer,
|
&buffer,
|
||||||
)
|
)
|
||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
// Create kernel
|
|
||||||
|
|
||||||
Ok(Self::new(buffer, self.device.clone(), self.dtype()))
|
Ok(Self::new(buffer, self.device.clone(), self.dtype()))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
|
fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
|
||||||
let command_buffer = self.device.command_buffer();
|
let command_buffer = self.device.command_buffer();
|
||||||
// println!("Copy strided");
|
|
||||||
if src_l.is_contiguous() && self.dtype == dst.dtype() {
|
if src_l.is_contiguous() && self.dtype == dst.dtype() {
|
||||||
command_buffer.set_label("copy_contiguous");
|
command_buffer.set_label("copy_contiguous");
|
||||||
let blit = command_buffer.new_blit_command_encoder();
|
let blit = command_buffer.new_blit_command_encoder();
|
||||||
@ -975,7 +951,6 @@ impl BackendDevice for MetalDevice {
|
|||||||
type Storage = MetalStorage;
|
type Storage = MetalStorage;
|
||||||
|
|
||||||
fn new(ordinal: usize) -> Result<Self> {
|
fn new(ordinal: usize) -> Result<Self> {
|
||||||
// println!("CREATING DEVICE");
|
|
||||||
let device = metal::Device::all().swap_remove(ordinal);
|
let device = metal::Device::all().swap_remove(ordinal);
|
||||||
|
|
||||||
let n = 1;
|
let n = 1;
|
||||||
@ -1024,6 +999,7 @@ impl BackendDevice for MetalDevice {
|
|||||||
let command_buffer = self.command_buffer();
|
let command_buffer = self.command_buffer();
|
||||||
command_buffer.set_label("zeros");
|
command_buffer.set_label("zeros");
|
||||||
let blit = command_buffer.new_blit_command_encoder();
|
let blit = command_buffer.new_blit_command_encoder();
|
||||||
|
blit.wait_for_fence(&self.fence);
|
||||||
blit.fill_buffer(
|
blit.fill_buffer(
|
||||||
&buffer,
|
&buffer,
|
||||||
metal::NSRange {
|
metal::NSRange {
|
||||||
@ -1032,6 +1008,7 @@ impl BackendDevice for MetalDevice {
|
|||||||
},
|
},
|
||||||
0,
|
0,
|
||||||
);
|
);
|
||||||
|
blit.update_fence(&self.fence);
|
||||||
blit.end_encoding();
|
blit.end_encoding();
|
||||||
Ok(MetalStorage::new(buffer, self.clone(), dtype))
|
Ok(MetalStorage::new(buffer, self.clone(), dtype))
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user