Adding tons of profiling and removing the metal allocation (still slow).

Nicolas Patry
2023-11-02 17:48:07 +01:00
parent 7161002a34
commit 9a27f11c3f
4 changed files with 100 additions and 61 deletions
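Nearly every change below follows the same profiling pattern: open a `tracing` span around the region of interest, hold the guard returned by `enter()`, and drop it when the region ends. A minimal, self-contained sketch of that pattern, assuming a `tracing_subscriber::fmt` subscriber configured to report span close times (the candle example installs its own subscriber; the function name `expensive_step` is purely illustrative):

```rust
use tracing::Level;
use tracing_subscriber::fmt::format::FmtSpan;

fn expensive_step() {
    // Same span/guard pattern used throughout this commit: the span counts as
    // "busy" from `enter()` until the guard is dropped.
    let span = tracing::span!(Level::TRACE, "metal alloc");
    let _enter = span.enter();
    // ... work being measured ...
    drop(_enter);
}

fn main() {
    // Report timings (busy/idle breakdown) when each span closes.
    tracing_subscriber::fmt()
        .with_max_level(Level::TRACE)
        .with_span_events(FmtSpan::CLOSE)
        .init();
    expensive_step();
}
```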

View File

@@ -30,6 +30,7 @@ safetensors = { workspace = true }
 thiserror = { workspace = true }
 yoke = { workspace = true }
 zip = { workspace = true }
+tracing = { workspace = true }
 
 [dev-dependencies]
 anyhow = { workspace = true }

View File

@@ -73,7 +73,11 @@ impl BackendStorage for MetalStorage {
     fn to_cpu_storage(&self) -> Result<CpuStorage> {
         match self.dtype{
-            DType::F32 => Ok(CpuStorage::F32(self.buffer.read_to_vec(self.buffer.length() as usize / 4))),
+            DType::F32 => {
+                // self.buffer.read_to_vec(self.buffer.length() as usize / 4);
+                let mut buffer = vec![0.0; 32000];
+                buffer[0] = 1.0;
+                Ok(CpuStorage::F32(buffer))},
             dtype => todo!("Unsupported dtype {dtype:?}")
         }
     }
@@ -271,13 +275,16 @@ impl MetalStorage {
         let elem_count = b * m * n;
         match (self.dtype, rhs.dtype) {
             (DType::F32, DType::F32) => {
+                let span = tracing::span!(tracing::Level::TRACE, "metal alloc matmul");
+                let _enter = span.enter();
+                let out_buffer = self.device.new_buffer(
+                    (elem_count * mem::size_of::<f32>()) as u64,
+                    MTLResourceOptions::empty(),
+                );
                 if b != 1 {
                     println!("TODO implement batched matmul for B={b}");
                     // bail!("Didn't implemented strided matmul yet");
-                    let out_buffer = self.device.new_buffer(
-                        (elem_count * mem::size_of::<f32>()) as u64,
-                        MTLResourceOptions::empty(),
-                    );
                     return Ok(Self {
                         buffer: out_buffer,
                         device: self.device.clone(),
@@ -286,20 +293,17 @@ impl MetalStorage {
                 }
                 if !lhs_l.is_contiguous() || !rhs_l.is_contiguous() {
                     println!("Didn't implemented non contiguous matmul yet {:?} {:?}", lhs_l.is_contiguous(), rhs_l.is_contiguous());
-                    let out_buffer = self.device.new_buffer(
-                        (elem_count * mem::size_of::<f32>()) as u64,
-                        MTLResourceOptions::empty(),
-                    );
                     return Ok(Self {
                         buffer: out_buffer,
                         device: self.device.clone(),
                         dtype: self.dtype(),
                     });
                 }
-                let out_buffer = self.device.new_buffer(
-                    (elem_count * mem::size_of::<f32>()) as u64,
-                    MTLResourceOptions::empty(),
-                );
+                return Ok(Self {
+                    buffer: out_buffer,
+                    device: self.device.clone(),
+                    dtype: self.dtype(),
+                });
                 let m: u64 = m.try_into().expect("usize should fit u64");
                 let n: u64 = n.try_into().expect("usize should fit u64");
                 let k: u64 = k.try_into().expect("usize should fit u64");
@@ -359,6 +363,15 @@ impl MetalStorage {
     }
 }
+
+impl MetalDevice{
+    pub fn flush(&mut self){
+        self.command_buffer.commit();
+        self.command_buffer.wait_until_completed();
+        self.command_buffer = self._command_queue.new_owned_command_buffer();
+    }
+}
 
 impl BackendDevice for MetalDevice {
     type Storage = MetalStorage;
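Metal command buffers execute asynchronously, so a span wrapped around GPU work only measures encoding time unless the buffer is committed and waited on. The new `flush()` commits the device's long-lived command buffer, blocks until the GPU has finished, and then starts a fresh buffer so the spans above reflect completed work. A rough standalone sketch of the same commit-and-wait pattern using the `metal` crate directly (no candle types; `new_owned_command_buffer` in the diff is a candle-side detail not reproduced here):

```rust
use metal::Device;

fn main() {
    // Assumes a Metal-capable macOS machine.
    let device = Device::system_default().expect("no Metal device found");
    let queue = device.new_command_queue();

    // Encode work on a command buffer (elided here), then "flush":
    let command_buffer = queue.new_command_buffer();
    // ... GPU work would be encoded here ...
    command_buffer.commit();
    command_buffer.wait_until_completed();

    // Begin a fresh command buffer for subsequent work, as flush() does.
    let _next = queue.new_command_buffer();
}
```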
@@ -399,43 +412,47 @@ impl BackendDevice for MetalDevice {
     fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<Self::Storage> {
         let option = metal::MTLResourceOptions::CPUCacheModeDefaultCache;
-        let buffer = match storage {
-            CpuStorage::U8(storage) => self.device.new_buffer_with_data(
-                storage.as_ptr() as *const core::ffi::c_void,
-                (storage.len() * mem::size_of::<u8>()) as u64,
-                option,
-            ),
-            CpuStorage::U32(storage) => self.device.new_buffer_with_data(
-                storage.as_ptr() as *const core::ffi::c_void,
-                (storage.len() * mem::size_of::<u32>()) as u64,
-                option,
-            ),
-            CpuStorage::I64(storage) => self.device.new_buffer_with_data(
-                storage.as_ptr() as *const core::ffi::c_void,
-                (storage.len() * mem::size_of::<i64>()) as u64,
-                option,
-            ),
-            CpuStorage::BF16(storage) => self.device.new_buffer_with_data(
-                storage.as_ptr() as *const core::ffi::c_void,
-                (storage.len() * mem::size_of::<bf16>()) as u64,
-                option,
-            ),
-            CpuStorage::F16(storage) => self.device.new_buffer_with_data(
-                storage.as_ptr() as *const core::ffi::c_void,
-                (storage.len() * mem::size_of::<f16>()) as u64,
-                option,
-            ),
-            CpuStorage::F32(storage) => self.device.new_buffer_with_data(
-                storage.as_ptr() as *const core::ffi::c_void,
-                (storage.len() * mem::size_of::<f32>()) as u64,
-                option,
-            ),
-            CpuStorage::F64(storage) => self.device.new_buffer_with_data(
-                storage.as_ptr() as *const core::ffi::c_void,
-                (storage.len() * mem::size_of::<f64>()) as u64,
-                option,
-            ),
-        };
+        let span = tracing::span!(tracing::Level::TRACE, "metal alloc");
+        let _enter = span.enter();
+
+        let buffer = self.device.new_buffer(4, option);
+        // let buffer = match storage {
+        //     CpuStorage::U8(storage) => self.device.new_buffer_with_data(
+        //         storage.as_ptr() as *const core::ffi::c_void,
+        //         (storage.len() * mem::size_of::<u8>()) as u64,
+        //         option,
+        //     ),
+        //     CpuStorage::U32(storage) => self.device.new_buffer_with_data(
+        //         storage.as_ptr() as *const core::ffi::c_void,
+        //         (storage.len() * mem::size_of::<u32>()) as u64,
+        //         option,
+        //     ),
+        //     CpuStorage::I64(storage) => self.device.new_buffer_with_data(
+        //         storage.as_ptr() as *const core::ffi::c_void,
+        //         (storage.len() * mem::size_of::<i64>()) as u64,
+        //         option,
+        //     ),
+        //     CpuStorage::BF16(storage) => self.device.new_buffer_with_data(
+        //         storage.as_ptr() as *const core::ffi::c_void,
+        //         (storage.len() * mem::size_of::<bf16>()) as u64,
+        //         option,
+        //     ),
+        //     CpuStorage::F16(storage) => self.device.new_buffer_with_data(
+        //         storage.as_ptr() as *const core::ffi::c_void,
+        //         (storage.len() * mem::size_of::<f16>()) as u64,
+        //         option,
+        //     ),
+        //     CpuStorage::F32(storage) => self.device.new_buffer_with_data(
+        //         storage.as_ptr() as *const core::ffi::c_void,
+        //         (storage.len() * mem::size_of::<f32>()) as u64,
+        //         option,
+        //     ),
+        //     CpuStorage::F64(storage) => self.device.new_buffer_with_data(
+        //         storage.as_ptr() as *const core::ffi::c_void,
+        //         (storage.len() * mem::size_of::<f64>()) as u64,
+        //         option,
+        //     ),
+        // };
         Ok(Self::Storage {
             buffer,
             device: self.clone(),
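The placeholder above (`new_buffer(4, option)` plus the dummy 32000-element CPU buffer in `to_cpu_storage`) deliberately breaks correctness so allocation and host-to-device copy costs can be isolated while profiling. A rough way to check the same question outside candle is to time the two `metal` crate calls directly (a sketch; buffer size and iteration count are arbitrary):

```rust
use metal::{Device, MTLResourceOptions};
use std::time::Instant;

fn main() {
    // Grab the default system GPU (assumes a Metal-capable macOS machine).
    let device = Device::system_default().expect("no Metal device found");
    let data = vec![0.0f32; 32_000];
    let bytes = (data.len() * std::mem::size_of::<f32>()) as u64;
    let option = MTLResourceOptions::CPUCacheModeDefaultCache;

    // Time bare allocations, as in the stripped-down path in this commit.
    let start = Instant::now();
    for _ in 0..1_000 {
        let _buf = device.new_buffer(bytes, option);
    }
    println!("new_buffer:           {:?}", start.elapsed());

    // Time allocation plus host-to-GPU copy, as in the original path.
    let start = Instant::now();
    for _ in 0..1_000 {
        let _buf = device.new_buffer_with_data(
            data.as_ptr() as *const core::ffi::c_void,
            bytes,
            option,
        );
    }
    println!("new_buffer_with_data: {:?}", start.elapsed());
}
```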

View File

@@ -232,7 +232,7 @@ fn main() -> anyhow::Result<()> {
     use tracing_subscriber::prelude::*;
     let args = Args::parse();
-    let device = candle_examples::device(false)?;
+    let mut device = candle_examples::device(false)?;
     let temperature = if args.temperature == 0. {
         None
     } else {
@@ -384,17 +384,20 @@ fn main() -> anyhow::Result<()> {
     for index in 0..to_sample {
         let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?;
         let logits = model.forward(&input, prompt_tokens.len() + index)?;
+        if let candle::Device::Metal(device) = &mut device{
+            device.flush()
+        }
         let logits = logits.squeeze(0)?;
-        let logits = if args.repeat_penalty == 1. {
-            logits
-        } else {
-            let start_at = all_tokens.len().saturating_sub(args.repeat_last_n);
-            candle_transformers::utils::apply_repeat_penalty(
-                &logits,
-                args.repeat_penalty,
-                &all_tokens[start_at..],
-            )?
-        };
+        // let logits = if args.repeat_penalty == 1. {
+        //     logits
+        // } else {
+        //     let start_at = all_tokens.len().saturating_sub(args.repeat_last_n);
+        //     candle_transformers::utils::apply_repeat_penalty(
+        //         &logits,
+        //         args.repeat_penalty,
+        //         &all_tokens[start_at..],
+        //     )?
+        // };
         // TODO Remove this once implementation is finished.
         let logits = logits.ones_like()?;
         next_token = logits_processor.sample(&logits)?;
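The `use tracing_subscriber::prelude::*;` context above is where the example installs its subscriber, and the new spans are most useful when exported as a Chrome trace. A sketch of that setup, assuming the `tracing-chrome` crate (the surrounding structure is illustrative, not taken from the diff):

```rust
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;

fn main() {
    // Writes a trace-*.json file, flushed when `_guard` is dropped; open it in
    // chrome://tracing or https://ui.perfetto.dev to inspect span durations.
    let (chrome_layer, _guard) = ChromeLayerBuilder::new().build();
    tracing_subscriber::registry().with(chrome_layer).init();

    {
        // Any of the spans added in this commit would show up as a slice here.
        let span = tracing::span!(tracing::Level::TRACE, "metal alloc");
        let _enter = span.enter();
        // ... the work wrapped by the span ...
    }
}
```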

View File

@@ -79,6 +79,8 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor>
 impl LayerWeights {
     fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
         let _enter = self.span_rot.enter();
+        let span = tracing::span!(tracing::Level::TRACE, "attn-rot-cos");
+        let _enter = span.enter();
         let (b_sz, n_head, seq_len, n_embd) = x.dims4()?;
         let cos = self
             .cos
@@ -88,21 +90,37 @@ impl LayerWeights {
             .sin
             .narrow(0, index_pos, seq_len)?
             .reshape((seq_len, n_embd / 2, 1))?;
+        drop(_enter);
+        let span = tracing::span!(tracing::Level::TRACE, "attn-rot-broad");
+        let _enter = span.enter();
         let cos = cos.broadcast_as((b_sz, 1, seq_len, n_embd / 2, 1))?;
         let sin = sin.broadcast_as((b_sz, 1, seq_len, n_embd / 2, 1))?;
+        drop(_enter);
         // This mimics the llama.cpp behavior.
         // https://github.com/ggerganov/llama.cpp/blob/1f0bccb27929e261744c979bc75114955da49e98/ggml.c#L12104-L12105
         // The x0 and x1 value are interleaved on the n_embd (= head_dim) dimension.
         // The resulting y0 and y1 are also interleaved with:
         // y0 = x0*cos - x1*sin
         // y1 = x0*sin + x1*cos
+        let span = tracing::span!(tracing::Level::TRACE, "attn-rot-reshape");
+        let _enter = span.enter();
         let x = x.reshape((b_sz, n_head, seq_len, n_embd / 2, 2))?;
         let x0 = x.narrow(D::Minus1, 0, 1)?;
         let x1 = x.narrow(D::Minus1, 1, 1)?;
+        drop(_enter);
+        let span = tracing::span!(tracing::Level::TRACE, "attn-rot-broad-mul");
+        let _enter = span.enter();
         let y0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?;
         let y1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?;
+        drop(_enter);
+        let span = tracing::span!(tracing::Level::TRACE, "attn-rot-cat");
+        let _enter = span.enter();
         let rope = Tensor::cat(&[y0, y1], D::Minus1)?;
+        drop(_enter);
+        let span = tracing::span!(tracing::Level::TRACE, "attn-rot-flatten");
+        let _enter = span.enter();
         let rope = rope.flatten_from(D::Minus2)?;
+        drop(_enter);
         Ok(rope)
     }
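Written out, the interleaved update described in the comment is a 2D rotation applied to each even/odd pair along the head dimension (with $\theta_i$ standing for whatever angle the precomputed `cos`/`sin` tensors hold for pair $i$ at the current position):

$$
\begin{pmatrix} y_{2i} \\ y_{2i+1} \end{pmatrix}
=
\begin{pmatrix} \cos\theta_i & -\sin\theta_i \\ \sin\theta_i & \cos\theta_i \end{pmatrix}
\begin{pmatrix} x_{2i} \\ x_{2i+1} \end{pmatrix}
$$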