Mirror of https://github.com/huggingface/candle.git (synced 2025-06-18 11:37:11 +00:00)
Adding tons of profiling and removing the metal allocation (still slow).
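The profiling added throughout this diff follows the standard `tracing` span pattern: open a TRACE span, hold its guard while the timed work runs, then drop the guard to close the interval. As a point of reference, here is a minimal, standalone sketch of that pattern together with a Chrome-trace subscriber to inspect the output; it assumes the `tracing`, `tracing-subscriber`, and `tracing-chrome` crates, and the span name is only illustrative:

```rust
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;

fn main() {
    // Record spans to a trace-<timestamp>.json file that chrome://tracing
    // (or Perfetto) can open; the guard flushes the file when dropped.
    let (chrome_layer, _flush_guard) = ChromeLayerBuilder::new().build();
    tracing_subscriber::registry().with(chrome_layer).init();

    // Same shape as the spans added in this commit: the recorded interval
    // ends when the `_enter` guard is dropped.
    let span = tracing::span!(tracing::Level::TRACE, "metal alloc matmul");
    let _enter = span.enter();
    std::thread::sleep(std::time::Duration::from_millis(5)); // stand-in for real work
    drop(_enter);
}
```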
@@ -30,6 +30,7 @@ safetensors = { workspace = true }
 thiserror = { workspace = true }
 yoke = { workspace = true }
 zip = { workspace = true }
+tracing = { workspace = true }

 [dev-dependencies]
 anyhow = { workspace = true }
@@ -73,7 +73,11 @@ impl BackendStorage for MetalStorage {

     fn to_cpu_storage(&self) -> Result<CpuStorage> {
         match self.dtype{
-            DType::F32 => Ok(CpuStorage::F32(self.buffer.read_to_vec(self.buffer.length() as usize / 4))),
+            DType::F32 => {
+                // self.buffer.read_to_vec(self.buffer.length() as usize / 4);
+                let mut buffer = vec![0.0; 32000];
+                buffer[0] = 1.0;
+                Ok(CpuStorage::F32(buffer))},
             dtype => todo!("Unsupported dtype {dtype:?}")
         }
     }
@@ -271,13 +275,16 @@ impl MetalStorage {
         let elem_count = b * m * n;
         match (self.dtype, rhs.dtype) {
             (DType::F32, DType::F32) => {
+                let span= tracing::span!(tracing::Level::TRACE, "metal alloc matmul");
+                let _enter = span.enter();
+
+                let out_buffer = self.device.new_buffer(
+                    (elem_count * mem::size_of::<f32>()) as u64,
+                    MTLResourceOptions::empty(),
+                );
                 if b != 1 {
                     println!("TODO implement batched matmul for B={b}");
                     // bail!("Didn't implemented strided matmul yet");
-                    let out_buffer = self.device.new_buffer(
-                        (elem_count * mem::size_of::<f32>()) as u64,
-                        MTLResourceOptions::empty(),
-                    );
                     return Ok(Self {
                         buffer: out_buffer,
                         device: self.device.clone(),
@@ -286,20 +293,17 @@ impl MetalStorage {
                 }
                 if !lhs_l.is_contiguous() || !rhs_l.is_contiguous() {
                     println!("Didn't implemented non contiguous matmul yet {:?} {:?}", lhs_l.is_contiguous(), rhs_l.is_contiguous());
-                    let out_buffer = self.device.new_buffer(
-                        (elem_count * mem::size_of::<f32>()) as u64,
-                        MTLResourceOptions::empty(),
-                    );
                     return Ok(Self {
                         buffer: out_buffer,
                         device: self.device.clone(),
                         dtype: self.dtype(),
                     });
                 }
-                let out_buffer = self.device.new_buffer(
-                    (elem_count * mem::size_of::<f32>()) as u64,
-                    MTLResourceOptions::empty(),
-                );
+                return Ok(Self {
+                    buffer: out_buffer,
+                    device: self.device.clone(),
+                    dtype: self.dtype(),
+                });
                 let m: u64 = m.try_into().expect("usize should fit u64");
                 let n: u64 = n.try_into().expect("usize should fit u64");
                 let k: u64 = k.try_into().expect("usize should fit u64");
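For reference, the buffers in the two hunks above are sized as the `b * m * n` output elements of the matmul times `size_of::<f32>()` bytes. A tiny standalone check of that sizing arithmetic (not candle code, just the same calculation):

```rust
use std::mem;

fn main() {
    // Output of a batched (m x k) x (k x n) matmul: b * m * n f32 values.
    let (b, m, n) = (1usize, 2usize, 3usize);
    let elem_count = b * m * n;
    let byte_len = (elem_count * mem::size_of::<f32>()) as u64;
    assert_eq!(byte_len, 24); // 6 elements * 4 bytes each
    println!("would allocate {byte_len} bytes");
}
```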
@@ -359,6 +363,15 @@ impl MetalStorage {
     }
 }

+impl MetalDevice{
+    pub fn flush(&mut self){
+        self.command_buffer.commit();
+        self.command_buffer.wait_until_completed();
+        self.command_buffer = self._command_queue.new_owned_command_buffer();
+    }
+
+}
+
 impl BackendDevice for MetalDevice {
     type Storage = MetalStorage;

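The new `flush` helper commits the in-flight command buffer, blocks until the GPU has drained it, and then opens a fresh buffer for subsequent work. Below is a standalone sketch of that commit-and-wait cycle using the plain `metal` crate on macOS; note that `new_owned_command_buffer` in the diff is not part of the stock crate, so this sketch uses the standard `new_command_buffer` instead:

```rust
use metal::{Device, MTLResourceOptions};

fn main() {
    // The MetalDevice wrapper holds a device, a command queue, and the
    // currently open command buffer; this mirrors that setup.
    let device = Device::system_default().expect("no Metal device found");
    let queue = device.new_command_queue();
    let command_buffer = queue.new_command_buffer();

    // A small allocation, done the same way as in the backend above.
    let _buffer = device.new_buffer(4, MTLResourceOptions::CPUCacheModeDefaultCache);

    // "flush": submit whatever has been encoded, wait for the GPU to finish,
    // then open a new command buffer for later work.
    command_buffer.commit();
    command_buffer.wait_until_completed();
    let _next = queue.new_command_buffer();
}
```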
@@ -399,43 +412,47 @@ impl BackendDevice for MetalDevice {

    fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<Self::Storage> {
        let option = metal::MTLResourceOptions::CPUCacheModeDefaultCache;
-        let buffer = match storage {
-            CpuStorage::U8(storage) => self.device.new_buffer_with_data(
-                storage.as_ptr() as *const core::ffi::c_void,
-                (storage.len() * mem::size_of::<u8>()) as u64,
-                option,
-            ),
-            CpuStorage::U32(storage) => self.device.new_buffer_with_data(
-                storage.as_ptr() as *const core::ffi::c_void,
-                (storage.len() * mem::size_of::<u32>()) as u64,
-                option,
-            ),
-            CpuStorage::I64(storage) => self.device.new_buffer_with_data(
-                storage.as_ptr() as *const core::ffi::c_void,
-                (storage.len() * mem::size_of::<i64>()) as u64,
-                option,
-            ),
-            CpuStorage::BF16(storage) => self.device.new_buffer_with_data(
-                storage.as_ptr() as *const core::ffi::c_void,
-                (storage.len() * mem::size_of::<bf16>()) as u64,
-                option,
-            ),
-            CpuStorage::F16(storage) => self.device.new_buffer_with_data(
-                storage.as_ptr() as *const core::ffi::c_void,
-                (storage.len() * mem::size_of::<f16>()) as u64,
-                option,
-            ),
-            CpuStorage::F32(storage) => self.device.new_buffer_with_data(
-                storage.as_ptr() as *const core::ffi::c_void,
-                (storage.len() * mem::size_of::<f32>()) as u64,
-                option,
-            ),
-            CpuStorage::F64(storage) => self.device.new_buffer_with_data(
-                storage.as_ptr() as *const core::ffi::c_void,
-                (storage.len() * mem::size_of::<f64>()) as u64,
-                option,
-            ),
-        };
+        let span= tracing::span!(tracing::Level::TRACE, "metal alloc");
+        let _enter = span.enter();
+
+        let buffer = self.device.new_buffer(4, option);
+        // let buffer = match storage {
+        //     CpuStorage::U8(storage) => self.device.new_buffer_with_data(
+        //         storage.as_ptr() as *const core::ffi::c_void,
+        //         (storage.len() * mem::size_of::<u8>()) as u64,
+        //         option,
+        //     ),
+        //     CpuStorage::U32(storage) => self.device.new_buffer_with_data(
+        //         storage.as_ptr() as *const core::ffi::c_void,
+        //         (storage.len() * mem::size_of::<u32>()) as u64,
+        //         option,
+        //     ),
+        //     CpuStorage::I64(storage) => self.device.new_buffer_with_data(
+        //         storage.as_ptr() as *const core::ffi::c_void,
+        //         (storage.len() * mem::size_of::<i64>()) as u64,
+        //         option,
+        //     ),
+        //     CpuStorage::BF16(storage) => self.device.new_buffer_with_data(
+        //         storage.as_ptr() as *const core::ffi::c_void,
+        //         (storage.len() * mem::size_of::<bf16>()) as u64,
+        //         option,
+        //     ),
+        //     CpuStorage::F16(storage) => self.device.new_buffer_with_data(
+        //         storage.as_ptr() as *const core::ffi::c_void,
+        //         (storage.len() * mem::size_of::<f16>()) as u64,
+        //         option,
+        //     ),
+        //     CpuStorage::F32(storage) => self.device.new_buffer_with_data(
+        //         storage.as_ptr() as *const core::ffi::c_void,
+        //         (storage.len() * mem::size_of::<f32>()) as u64,
+        //         option,
+        //     ),
+        //     CpuStorage::F64(storage) => self.device.new_buffer_with_data(
+        //         storage.as_ptr() as *const core::ffi::c_void,
+        //         (storage.len() * mem::size_of::<f64>()) as u64,
+        //         option,
+        //     ),
+        // };
        Ok(Self::Storage {
            buffer,
            device: self.clone(),
@@ -232,7 +232,7 @@ fn main() -> anyhow::Result<()> {
     use tracing_subscriber::prelude::*;

     let args = Args::parse();
-    let device = candle_examples::device(false)?;
+    let mut device = candle_examples::device(false)?;
     let temperature = if args.temperature == 0. {
         None
     } else {
@@ -384,17 +384,20 @@ fn main() -> anyhow::Result<()> {
     for index in 0..to_sample {
         let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?;
         let logits = model.forward(&input, prompt_tokens.len() + index)?;
+        if let candle::Device::Metal(device) = &mut device{
+            device.flush()
+        }
         let logits = logits.squeeze(0)?;
-        let logits = if args.repeat_penalty == 1. {
-            logits
-        } else {
-            let start_at = all_tokens.len().saturating_sub(args.repeat_last_n);
-            candle_transformers::utils::apply_repeat_penalty(
-                &logits,
-                args.repeat_penalty,
-                &all_tokens[start_at..],
-            )?
-        };
+        // let logits = if args.repeat_penalty == 1. {
+        //     logits
+        // } else {
+        //     let start_at = all_tokens.len().saturating_sub(args.repeat_last_n);
+        //     candle_transformers::utils::apply_repeat_penalty(
+        //         &logits,
+        //         args.repeat_penalty,
+        //         &all_tokens[start_at..],
+        //     )?
+        // };
         // TODO Remove this once implementation is finished.
         let logits = logits.ones_like()?;
         next_token = logits_processor.sample(&logits)?;
@@ -79,6 +79,8 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor>
 impl LayerWeights {
     fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
         let _enter = self.span_rot.enter();
+        let span = tracing::span!(tracing::Level::TRACE, "attn-rot-cos");
+        let _enter = span.enter();
         let (b_sz, n_head, seq_len, n_embd) = x.dims4()?;
         let cos = self
             .cos
@@ -88,21 +90,37 @@ impl LayerWeights {
             .sin
             .narrow(0, index_pos, seq_len)?
             .reshape((seq_len, n_embd / 2, 1))?;
+        drop(_enter);
+        let span = tracing::span!(tracing::Level::TRACE, "attn-rot-broad");
+        let _enter = span.enter();
         let cos = cos.broadcast_as((b_sz, 1, seq_len, n_embd / 2, 1))?;
         let sin = sin.broadcast_as((b_sz, 1, seq_len, n_embd / 2, 1))?;
+        drop(_enter);
         // This mimics the llama.cpp behavior.
         // https://github.com/ggerganov/llama.cpp/blob/1f0bccb27929e261744c979bc75114955da49e98/ggml.c#L12104-L12105
         // The x0 and x1 value are interleaved on the n_embd (= head_dim) dimension.
         // The resulting y0 and y1 are also interleaved with:
         // y0 = x0*cos - x1*sin
         // y1 = x0*sin + x1*cos
+        let span = tracing::span!(tracing::Level::TRACE, "attn-rot-reshape");
+        let _enter = span.enter();
         let x = x.reshape((b_sz, n_head, seq_len, n_embd / 2, 2))?;
         let x0 = x.narrow(D::Minus1, 0, 1)?;
         let x1 = x.narrow(D::Minus1, 1, 1)?;
+        drop(_enter);
+        let span = tracing::span!(tracing::Level::TRACE, "attn-rot-broad-mul");
+        let _enter = span.enter();
         let y0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?;
         let y1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?;
+        drop(_enter);
+        let span = tracing::span!(tracing::Level::TRACE, "attn-rot-cat");
+        let _enter = span.enter();
         let rope = Tensor::cat(&[y0, y1], D::Minus1)?;
+        drop(_enter);
+        let span = tracing::span!(tracing::Level::TRACE, "attn-rot-flatten");
+        let _enter = span.enter();
         let rope = rope.flatten_from(D::Minus2)?;
+        drop(_enter);
         Ok(rope)
     }
