mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Rename the .r functions to .dims so as to be a bit more explicit. (#220)
This commit is contained in:
@ -1688,7 +1688,7 @@ impl BackendStorage for CpuStorage {
|
|||||||
|
|
||||||
fn embedding(&self, ids_l: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
|
fn embedding(&self, ids_l: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
|
||||||
let ids = self.as_slice::<u32>()?;
|
let ids = self.as_slice::<u32>()?;
|
||||||
let (vocab_size, hidden_size) = rhs_l.shape().r2()?;
|
let (vocab_size, hidden_size) = rhs_l.shape().dims2()?;
|
||||||
Embedding {
|
Embedding {
|
||||||
vocab_size,
|
vocab_size,
|
||||||
hidden_size,
|
hidden_size,
|
||||||
|
@ -620,7 +620,7 @@ impl<'a> Map1 for Embedding<'a> {
|
|||||||
let shape = ids_l.shape();
|
let shape = ids_l.shape();
|
||||||
let (v_size, h_size) = rhs_l
|
let (v_size, h_size) = rhs_l
|
||||||
.shape()
|
.shape()
|
||||||
.r2()
|
.dims2()
|
||||||
.map_err(|e| CudaError::WrappedError(Box::new(e)))
|
.map_err(|e| CudaError::WrappedError(Box::new(e)))
|
||||||
.w()?;
|
.w()?;
|
||||||
let dims = shape.dims();
|
let dims = shape.dims();
|
||||||
|
@ -87,6 +87,12 @@ macro_rules! extract_dims {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
impl crate::Tensor {
|
||||||
|
pub fn $fn_name(&self) -> Result<$out_type> {
|
||||||
|
self.shape().$fn_name()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl std::convert::TryInto<$out_type> for Shape {
|
impl std::convert::TryInto<$out_type> for Shape {
|
||||||
type Error = crate::Error;
|
type Error = crate::Error;
|
||||||
fn try_into(self) -> std::result::Result<$out_type, Self::Error> {
|
fn try_into(self) -> std::result::Result<$out_type, Self::Error> {
|
||||||
@ -328,23 +334,23 @@ impl<D1: Dim, D2: Dim, D3: Dim> Dims for (D1, D2, D3) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
extract_dims!(r0, 0, |_: &Vec<usize>| (), ());
|
extract_dims!(dims0, 0, |_: &Vec<usize>| (), ());
|
||||||
extract_dims!(r1, 1, |d: &[usize]| d[0], usize);
|
extract_dims!(dims1, 1, |d: &[usize]| d[0], usize);
|
||||||
extract_dims!(r2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize));
|
extract_dims!(dims2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize));
|
||||||
extract_dims!(
|
extract_dims!(
|
||||||
r3,
|
dims3,
|
||||||
3,
|
3,
|
||||||
|d: &[usize]| (d[0], d[1], d[2]),
|
|d: &[usize]| (d[0], d[1], d[2]),
|
||||||
(usize, usize, usize)
|
(usize, usize, usize)
|
||||||
);
|
);
|
||||||
extract_dims!(
|
extract_dims!(
|
||||||
r4,
|
dims4,
|
||||||
4,
|
4,
|
||||||
|d: &[usize]| (d[0], d[1], d[2], d[3]),
|
|d: &[usize]| (d[0], d[1], d[2], d[3]),
|
||||||
(usize, usize, usize, usize)
|
(usize, usize, usize, usize)
|
||||||
);
|
);
|
||||||
extract_dims!(
|
extract_dims!(
|
||||||
r5,
|
dims5,
|
||||||
5,
|
5,
|
||||||
|d: &[usize]| (d[0], d[1], d[2], d[3], d[4]),
|
|d: &[usize]| (d[0], d[1], d[2], d[3], d[4]),
|
||||||
(usize, usize, usize, usize, usize)
|
(usize, usize, usize, usize, usize)
|
||||||
|
@ -772,7 +772,7 @@ impl Tensor {
|
|||||||
|
|
||||||
/// Applies a 1D convolution over the input tensor.
|
/// Applies a 1D convolution over the input tensor.
|
||||||
pub fn conv1d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> {
|
pub fn conv1d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> {
|
||||||
let (c_out, c_in_k, k_size) = kernel.shape().r3()?;
|
let (c_out, c_in_k, k_size) = kernel.dims3()?;
|
||||||
let (b_size, c_in, l_in) = match *self.dims() {
|
let (b_size, c_in, l_in) = match *self.dims() {
|
||||||
[b_size, c_in, l_in] => (Some(b_size), c_in, l_in),
|
[b_size, c_in, l_in] => (Some(b_size), c_in, l_in),
|
||||||
[c_in, l_in] => (None, c_in, l_in),
|
[c_in, l_in] => (None, c_in, l_in),
|
||||||
@ -931,8 +931,8 @@ impl Tensor {
|
|||||||
.bt())?
|
.bt())?
|
||||||
}
|
}
|
||||||
let ids_shape = ids.shape();
|
let ids_shape = ids.shape();
|
||||||
let seq_len = ids_shape.r1()?;
|
let seq_len = ids_shape.dims1()?;
|
||||||
let (_, hidden_size) = rhs.shape().r2()?;
|
let (_, hidden_size) = rhs.dims2()?;
|
||||||
let storage = ids
|
let storage = ids
|
||||||
.storage()
|
.storage()
|
||||||
.embedding(ids.layout(), &rhs.storage(), rhs.layout())?;
|
.embedding(ids.layout(), &rhs.storage(), rhs.layout())?;
|
||||||
@ -1013,7 +1013,7 @@ impl Tensor {
|
|||||||
// The number of element in indexes must match the dimension on which the add is
|
// The number of element in indexes must match the dimension on which the add is
|
||||||
// performed on the source tensor (and the index values from `indexes` are taken from
|
// performed on the source tensor (and the index values from `indexes` are taken from
|
||||||
// the target tensor self)
|
// the target tensor self)
|
||||||
mismatch || source_dims[dim] != indexes.shape().r1()?
|
mismatch || source_dims[dim] != indexes.dims1()?
|
||||||
};
|
};
|
||||||
if mismatch {
|
if mismatch {
|
||||||
Err(Error::ShapeMismatchBinaryOp {
|
Err(Error::ShapeMismatchBinaryOp {
|
||||||
@ -1144,7 +1144,7 @@ impl Tensor {
|
|||||||
|
|
||||||
/// Returns the data contained in a 2D tensor as a vector of vector of scalar values.
|
/// Returns the data contained in a 2D tensor as a vector of vector of scalar values.
|
||||||
pub fn to_vec2<S: crate::WithDType>(&self) -> Result<Vec<Vec<S>>> {
|
pub fn to_vec2<S: crate::WithDType>(&self) -> Result<Vec<Vec<S>>> {
|
||||||
let (dim1, dim2) = self.shape().r2()?;
|
let (dim1, dim2) = self.dims2()?;
|
||||||
let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
|
let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
|
||||||
let data = S::cpu_storage_as_slice(cpu_storage)?;
|
let data = S::cpu_storage_as_slice(cpu_storage)?;
|
||||||
let mut rows = vec![];
|
let mut rows = vec![];
|
||||||
@ -1164,7 +1164,7 @@ impl Tensor {
|
|||||||
|
|
||||||
/// Returns the data contained in a 3D tensor.
|
/// Returns the data contained in a 3D tensor.
|
||||||
pub fn to_vec3<S: crate::WithDType>(&self) -> Result<Vec<Vec<Vec<S>>>> {
|
pub fn to_vec3<S: crate::WithDType>(&self) -> Result<Vec<Vec<Vec<S>>>> {
|
||||||
let (dim1, dim2, dim3) = self.shape().r3()?;
|
let (dim1, dim2, dim3) = self.dims3()?;
|
||||||
let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
|
let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
|
||||||
let data = S::cpu_storage_as_slice(cpu_storage)?;
|
let data = S::cpu_storage_as_slice(cpu_storage)?;
|
||||||
let mut top_rows = vec![];
|
let mut top_rows = vec![];
|
||||||
|
@ -4,7 +4,7 @@ use test_utils::to_vec3_round;
|
|||||||
|
|
||||||
fn zeros(device: &Device) -> Result<()> {
|
fn zeros(device: &Device) -> Result<()> {
|
||||||
let tensor = Tensor::zeros((5, 2), DType::F32, device)?;
|
let tensor = Tensor::zeros((5, 2), DType::F32, device)?;
|
||||||
let (dim1, dim2) = tensor.shape().r2()?;
|
let (dim1, dim2) = tensor.dims2()?;
|
||||||
assert_eq!(dim1, 5);
|
assert_eq!(dim1, 5);
|
||||||
assert_eq!(dim2, 2);
|
assert_eq!(dim2, 2);
|
||||||
Ok(())
|
Ok(())
|
||||||
@ -12,7 +12,7 @@ fn zeros(device: &Device) -> Result<()> {
|
|||||||
|
|
||||||
fn add_mul(device: &Device) -> Result<()> {
|
fn add_mul(device: &Device) -> Result<()> {
|
||||||
let tensor = Tensor::new(&[3f32, 1., 4.], device)?;
|
let tensor = Tensor::new(&[3f32, 1., 4.], device)?;
|
||||||
let dim1 = tensor.shape().r1()?;
|
let dim1 = tensor.dims1()?;
|
||||||
assert_eq!(dim1, 3);
|
assert_eq!(dim1, 3);
|
||||||
let content: Vec<f32> = tensor.to_vec1()?;
|
let content: Vec<f32> = tensor.to_vec1()?;
|
||||||
assert_eq!(content, [3., 1., 4.]);
|
assert_eq!(content, [3., 1., 4.]);
|
||||||
@ -28,7 +28,7 @@ fn add_mul(device: &Device) -> Result<()> {
|
|||||||
fn tensor_2d(device: &Device) -> Result<()> {
|
fn tensor_2d(device: &Device) -> Result<()> {
|
||||||
let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
|
let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
|
||||||
let tensor = Tensor::new(data, device)?;
|
let tensor = Tensor::new(data, device)?;
|
||||||
let dims = tensor.shape().r2()?;
|
let dims = tensor.dims2()?;
|
||||||
assert_eq!(dims, (2, 5));
|
assert_eq!(dims, (2, 5));
|
||||||
let content: Vec<Vec<f32>> = tensor.to_vec2()?;
|
let content: Vec<Vec<f32>> = tensor.to_vec2()?;
|
||||||
assert_eq!(content, data);
|
assert_eq!(content, data);
|
||||||
@ -41,7 +41,7 @@ fn binary_op(device: &Device) -> Result<()> {
|
|||||||
let data2 = &[[5f32, 5., 5., 5., 5.], [2., 1., 7., 8., 2.]];
|
let data2 = &[[5f32, 5., 5., 5., 5.], [2., 1., 7., 8., 2.]];
|
||||||
let tensor2 = Tensor::new(data2, device)?;
|
let tensor2 = Tensor::new(data2, device)?;
|
||||||
let tensor = (&tensor + (&tensor * &tensor)? / (&tensor + &tensor2))?;
|
let tensor = (&tensor + (&tensor * &tensor)? / (&tensor + &tensor2))?;
|
||||||
let dims = tensor.shape().r2()?;
|
let dims = tensor.dims2()?;
|
||||||
assert_eq!(dims, (2, 5));
|
assert_eq!(dims, (2, 5));
|
||||||
let content: Vec<Vec<f32>> = tensor.to_vec2()?;
|
let content: Vec<Vec<f32>> = tensor.to_vec2()?;
|
||||||
assert_eq!(content[0], [4.125, 1.1666666, 5.7777777, 1.1666666, 7.5]);
|
assert_eq!(content[0], [4.125, 1.1666666, 5.7777777, 1.1666666, 7.5]);
|
||||||
@ -56,7 +56,7 @@ fn binary_op(device: &Device) -> Result<()> {
|
|||||||
fn transpose(device: &Device) -> Result<()> {
|
fn transpose(device: &Device) -> Result<()> {
|
||||||
let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
|
let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
|
||||||
let tensor = Tensor::new(data, device)?.t()?;
|
let tensor = Tensor::new(data, device)?.t()?;
|
||||||
let dims = tensor.shape().r2()?;
|
let dims = tensor.dims2()?;
|
||||||
assert_eq!(dims, (5, 2));
|
assert_eq!(dims, (5, 2));
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
tensor.to_vec2::<f32>()?,
|
tensor.to_vec2::<f32>()?,
|
||||||
|
@ -161,7 +161,7 @@ fn main() -> Result<()> {
|
|||||||
let embeddings = model.forward(&token_ids, &token_type_ids)?;
|
let embeddings = model.forward(&token_ids, &token_type_ids)?;
|
||||||
println!("generated embeddings {:?}", embeddings.shape());
|
println!("generated embeddings {:?}", embeddings.shape());
|
||||||
// Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
|
// Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
|
||||||
let (_n_sentence, n_tokens, _hidden_size) = embeddings.shape().r3()?;
|
let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
|
||||||
let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
|
let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
|
||||||
println!("pooled embeddings {:?}", embeddings.shape());
|
println!("pooled embeddings {:?}", embeddings.shape());
|
||||||
let mut similarities = vec![];
|
let mut similarities = vec![];
|
||||||
|
@ -87,7 +87,7 @@ impl LayerNorm {
|
|||||||
DType::F16 | DType::BF16 => DType::F32,
|
DType::F16 | DType::BF16 => DType::F32,
|
||||||
d => d,
|
d => d,
|
||||||
};
|
};
|
||||||
let (_bsize, _seq_len, hidden_size) = x.shape().r3()?;
|
let (_bsize, _seq_len, hidden_size) = x.dims3()?;
|
||||||
let x = x.to_dtype(internal_dtype)?;
|
let x = x.to_dtype(internal_dtype)?;
|
||||||
let mean_x = (x.sum_keepdim(2)? / hidden_size as f64)?;
|
let mean_x = (x.sum_keepdim(2)? / hidden_size as f64)?;
|
||||||
let x = x.broadcast_sub(&mean_x)?;
|
let x = x.broadcast_sub(&mean_x)?;
|
||||||
@ -262,7 +262,7 @@ impl BertEmbeddings {
|
|||||||
|
|
||||||
fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> {
|
fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> {
|
||||||
let _enter = self.span.enter();
|
let _enter = self.span.enter();
|
||||||
let (_bsize, seq_len) = input_ids.shape().r2()?;
|
let (_bsize, seq_len) = input_ids.dims2()?;
|
||||||
let input_embeddings = self.word_embeddings.forward(input_ids)?;
|
let input_embeddings = self.word_embeddings.forward(input_ids)?;
|
||||||
let token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)?;
|
let token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)?;
|
||||||
let mut embeddings = (&input_embeddings + token_type_embeddings)?;
|
let mut embeddings = (&input_embeddings + token_type_embeddings)?;
|
||||||
|
@ -182,7 +182,7 @@ impl FalconRotaryEmbedding {
|
|||||||
key: &Tensor,
|
key: &Tensor,
|
||||||
past_kv_len: usize,
|
past_kv_len: usize,
|
||||||
) -> Result<(Tensor, Tensor)> {
|
) -> Result<(Tensor, Tensor)> {
|
||||||
let (_batch, seq_len, _head_dim) = query.shape().r3()?;
|
let (_batch, seq_len, _head_dim) = query.dims3()?;
|
||||||
let (cos, sin) = self.cos_sin(MAX_SEQ_LEN, query.device(), query.dtype())?;
|
let (cos, sin) = self.cos_sin(MAX_SEQ_LEN, query.device(), query.dtype())?;
|
||||||
let cos = cos.narrow(0, past_kv_len, seq_len)?;
|
let cos = cos.narrow(0, past_kv_len, seq_len)?;
|
||||||
let sin = sin.narrow(0, past_kv_len, seq_len)?;
|
let sin = sin.narrow(0, past_kv_len, seq_len)?;
|
||||||
@ -245,7 +245,7 @@ impl FalconAttention {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn split_heads(&self, fused_qkv: &Tensor) -> Result<(Tensor, Tensor, Tensor)> {
|
fn split_heads(&self, fused_qkv: &Tensor) -> Result<(Tensor, Tensor, Tensor)> {
|
||||||
let (b_sz, seq_len, _) = fused_qkv.shape().r3()?;
|
let (b_sz, seq_len, _) = fused_qkv.dims3()?;
|
||||||
if !self.multi_query {
|
if !self.multi_query {
|
||||||
let fused_qkv = fused_qkv.reshape((b_sz, seq_len, self.num_heads, 3, self.head_dim))?;
|
let fused_qkv = fused_qkv.reshape((b_sz, seq_len, self.num_heads, 3, self.head_dim))?;
|
||||||
let q = fused_qkv.narrow(D::Minus2, 0, 1)?.squeeze(D::Minus2)?;
|
let q = fused_qkv.narrow(D::Minus2, 0, 1)?.squeeze(D::Minus2)?;
|
||||||
@ -267,7 +267,7 @@ impl FalconAttention {
|
|||||||
let fused_qkv = self.query_key_value.forward(x)?;
|
let fused_qkv = self.query_key_value.forward(x)?;
|
||||||
let head_dim = self.head_dim;
|
let head_dim = self.head_dim;
|
||||||
let (query, key, value) = self.split_heads(&fused_qkv)?;
|
let (query, key, value) = self.split_heads(&fused_qkv)?;
|
||||||
let (b_sz, seq_len, _, _) = query.shape().r4()?;
|
let (b_sz, seq_len, _, _) = query.dims4()?;
|
||||||
let query = query
|
let query = query
|
||||||
.transpose(1, 2)?
|
.transpose(1, 2)?
|
||||||
.reshape((b_sz * self.num_heads, seq_len, head_dim))?;
|
.reshape((b_sz * self.num_heads, seq_len, head_dim))?;
|
||||||
@ -465,7 +465,7 @@ impl Falcon {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
|
pub fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
|
||||||
let (b_sz, seq_len) = input_ids.shape().r2()?;
|
let (b_sz, seq_len) = input_ids.dims2()?;
|
||||||
let mut hidden_state = self.word_embeddings.forward(input_ids)?;
|
let mut hidden_state = self.word_embeddings.forward(input_ids)?;
|
||||||
let past_kv_len = match &self.blocks[0].self_attention.kv_cache {
|
let past_kv_len = match &self.blocks[0].self_attention.kv_cache {
|
||||||
Some((k, _)) => k.dim(1)?,
|
Some((k, _)) => k.dim(1)?,
|
||||||
|
@ -116,11 +116,11 @@ impl RmsNorm {
|
|||||||
let in_dtype = x.dtype();
|
let in_dtype = x.dtype();
|
||||||
// This is a no-op if x's dtype is already f32.
|
// This is a no-op if x's dtype is already f32.
|
||||||
let x = x.to_dtype(DType::F32)?;
|
let x = x.to_dtype(DType::F32)?;
|
||||||
let (b_sz, seq_len, hidden_size) = x.shape().r3()?;
|
let (b_sz, seq_len, hidden_size) = x.dims3()?;
|
||||||
let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
|
let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
|
||||||
let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
|
let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
|
||||||
let x_normed = (x / (norm_x + 1e-6)?.sqrt()?)?;
|
let x_normed = (x / (norm_x + 1e-6)?.sqrt()?)?;
|
||||||
let size = self.scale.shape().r1()?;
|
let size = self.scale.dims1()?;
|
||||||
let scale = self
|
let scale = self
|
||||||
.scale
|
.scale
|
||||||
.to_dtype(DType::F32)?
|
.to_dtype(DType::F32)?
|
||||||
@ -144,7 +144,7 @@ struct CausalSelfAttention {
|
|||||||
|
|
||||||
impl CausalSelfAttention {
|
impl CausalSelfAttention {
|
||||||
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
|
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
|
||||||
let (b_sz, _, seq_len, n_embd) = x.shape().r4()?;
|
let (b_sz, _, seq_len, n_embd) = x.dims4()?;
|
||||||
let cos = self.cache.cos.narrow(0, index_pos, seq_len)?;
|
let cos = self.cache.cos.narrow(0, index_pos, seq_len)?;
|
||||||
let sin = self.cache.sin.narrow(0, index_pos, seq_len)?;
|
let sin = self.cache.sin.narrow(0, index_pos, seq_len)?;
|
||||||
let cos = cos.broadcast_as((b_sz, 1, seq_len, n_embd))?;
|
let cos = cos.broadcast_as((b_sz, 1, seq_len, n_embd))?;
|
||||||
@ -158,7 +158,7 @@ impl CausalSelfAttention {
|
|||||||
|
|
||||||
fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
|
fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
|
||||||
let x_dtype = x.dtype();
|
let x_dtype = x.dtype();
|
||||||
let (b_sz, seq_len, n_embd) = x.shape().r3()?;
|
let (b_sz, seq_len, n_embd) = x.dims3()?;
|
||||||
let q = self.q_proj.forward(x)?;
|
let q = self.q_proj.forward(x)?;
|
||||||
let k = self.k_proj.forward(x)?;
|
let k = self.k_proj.forward(x)?;
|
||||||
let v = self.v_proj.forward(x)?;
|
let v = self.v_proj.forward(x)?;
|
||||||
@ -219,7 +219,7 @@ impl CausalSelfAttention {
|
|||||||
if n_rep == 1 {
|
if n_rep == 1 {
|
||||||
Ok(x)
|
Ok(x)
|
||||||
} else {
|
} else {
|
||||||
let (b_sz, n_kv_head, seq_len, head_dim) = x.shape().r4()?;
|
let (b_sz, n_kv_head, seq_len, head_dim) = x.dims4()?;
|
||||||
let x = x
|
let x = x
|
||||||
.unsqueeze(2)?
|
.unsqueeze(2)?
|
||||||
.expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))?
|
.expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))?
|
||||||
@ -345,7 +345,7 @@ impl Llama {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
|
pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
|
||||||
let (_b_sz, seq_len) = x.shape().r2()?;
|
let (_b_sz, seq_len) = x.dims2()?;
|
||||||
let mut x = self.wte.forward(x)?;
|
let mut x = self.wte.forward(x)?;
|
||||||
for (block_idx, block) in self.blocks.iter().enumerate() {
|
for (block_idx, block) in self.blocks.iter().enumerate() {
|
||||||
x = block.forward(&x, index_pos, block_idx)?;
|
x = block.forward(&x, index_pos, block_idx)?;
|
||||||
|
@ -123,7 +123,7 @@ impl MusicgenSinusoidalPositionalEmbedding {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
|
fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
|
||||||
let (_b_sz, _codebooks, seq_len) = input_ids.shape().r3()?;
|
let (_b_sz, _codebooks, seq_len) = input_ids.dims3()?;
|
||||||
if seq_len > self.weights.dim(0)? {
|
if seq_len > self.weights.dim(0)? {
|
||||||
self.weights = get_embedding(seq_len, self.embedding_dim)?
|
self.weights = get_embedding(seq_len, self.embedding_dim)?
|
||||||
}
|
}
|
||||||
@ -170,7 +170,7 @@ impl MusicgenAttention {
|
|||||||
kv_states: Option<&Tensor>,
|
kv_states: Option<&Tensor>,
|
||||||
attention_mask: &Tensor,
|
attention_mask: &Tensor,
|
||||||
) -> Result<Tensor> {
|
) -> Result<Tensor> {
|
||||||
let (b_sz, tgt_len, _) = xs.shape().r3()?;
|
let (b_sz, tgt_len, _) = xs.dims3()?;
|
||||||
let query_states = (self.q_proj.forward(xs)? * self.scaling)?;
|
let query_states = (self.q_proj.forward(xs)? * self.scaling)?;
|
||||||
|
|
||||||
let kv_states = kv_states.unwrap_or(xs);
|
let kv_states = kv_states.unwrap_or(xs);
|
||||||
@ -308,7 +308,7 @@ impl MusicgenDecoder {
|
|||||||
|
|
||||||
fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
|
fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
|
||||||
let dev = input_ids.device();
|
let dev = input_ids.device();
|
||||||
let (b_sz_times_codebooks, seq_len) = input_ids.shape().r2()?;
|
let (b_sz_times_codebooks, seq_len) = input_ids.dims2()?;
|
||||||
let b_sz = b_sz_times_codebooks / self.num_codebooks;
|
let b_sz = b_sz_times_codebooks / self.num_codebooks;
|
||||||
let input = input_ids.reshape((b_sz, self.num_codebooks, seq_len))?;
|
let input = input_ids.reshape((b_sz, self.num_codebooks, seq_len))?;
|
||||||
let mut inputs_embeds = Tensor::zeros((b_sz, seq_len, self.d_model), DType::F32, dev)?;
|
let mut inputs_embeds = Tensor::zeros((b_sz, seq_len, self.d_model), DType::F32, dev)?;
|
||||||
@ -352,7 +352,7 @@ impl MusicgenForCausalLM {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
|
pub fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
|
||||||
let (b_sz, seq_len) = input_ids.shape().r2()?;
|
let (b_sz, seq_len) = input_ids.dims2()?;
|
||||||
let hidden_states = self.decoder.forward(input_ids)?;
|
let hidden_states = self.decoder.forward(input_ids)?;
|
||||||
let lm_logits = self
|
let lm_logits = self
|
||||||
.lm_heads
|
.lm_heads
|
||||||
|
@ -338,7 +338,7 @@ impl T5Stack {
|
|||||||
|
|
||||||
fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
|
fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
|
||||||
let input_embeds = self.shared.as_ref().forward(input_ids)?;
|
let input_embeds = self.shared.as_ref().forward(input_ids)?;
|
||||||
let (_b_sz, _seq_len) = input_embeds.shape().r2()?;
|
let (_b_sz, _seq_len) = input_embeds.dims2()?;
|
||||||
|
|
||||||
let mut hidden_states = self.dropout.forward(&input_embeds)?;
|
let mut hidden_states = self.dropout.forward(&input_embeds)?;
|
||||||
for block in self.block.iter() {
|
for block in self.block.iter() {
|
||||||
|
@ -52,7 +52,7 @@ pub fn main() -> Result<()> {
|
|||||||
.to_dtype(DType::F32)?
|
.to_dtype(DType::F32)?
|
||||||
.sum_all()?
|
.sum_all()?
|
||||||
.to_scalar::<f32>()?;
|
.to_scalar::<f32>()?;
|
||||||
let test_accuracy = sum_ok / test_labels.shape().r1()? as f32;
|
let test_accuracy = sum_ok / test_labels.dims1()? as f32;
|
||||||
println!(
|
println!(
|
||||||
"{epoch:4} train loss: {:8.5} test acc: {:5.2}%",
|
"{epoch:4} train loss: {:8.5} test acc: {:5.2}%",
|
||||||
loss.to_scalar::<f32>()?,
|
loss.to_scalar::<f32>()?,
|
||||||
|
@ -127,7 +127,7 @@ impl Decoder {
|
|||||||
.to_scalar::<f32>()? as f64;
|
.to_scalar::<f32>()? as f64;
|
||||||
}
|
}
|
||||||
|
|
||||||
let (seq_len, _) = logits.shape().r2()?;
|
let (seq_len, _) = logits.dims2()?;
|
||||||
let logits = logits
|
let logits = logits
|
||||||
.get(seq_len - 1)?
|
.get(seq_len - 1)?
|
||||||
.broadcast_add(&self.suppress_tokens)?;
|
.broadcast_add(&self.suppress_tokens)?;
|
||||||
@ -195,7 +195,7 @@ impl Decoder {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn run(&mut self, mel: &Tensor) -> Result<Vec<Segment>> {
|
fn run(&mut self, mel: &Tensor) -> Result<Vec<Segment>> {
|
||||||
let (_, _, content_frames) = mel.shape().r3()?;
|
let (_, _, content_frames) = mel.dims3()?;
|
||||||
let mut seek = 0;
|
let mut seek = 0;
|
||||||
let mut segments = vec![];
|
let mut segments = vec![];
|
||||||
while seek < content_frames {
|
while seek < content_frames {
|
||||||
|
@ -132,7 +132,7 @@ impl MultiHeadAttention {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn reshape_head(&self, x: &Tensor) -> Result<Tensor> {
|
fn reshape_head(&self, x: &Tensor) -> Result<Tensor> {
|
||||||
let (n_batch, n_ctx, n_state) = x.shape().r3()?;
|
let (n_batch, n_ctx, n_state) = x.dims3()?;
|
||||||
let target_dims = &[n_batch, n_ctx, self.n_head, n_state / self.n_head];
|
let target_dims = &[n_batch, n_ctx, self.n_head, n_state / self.n_head];
|
||||||
Ok(x.reshape(target_dims)?.transpose(1, 2)?)
|
Ok(x.reshape(target_dims)?.transpose(1, 2)?)
|
||||||
}
|
}
|
||||||
@ -144,7 +144,7 @@ impl MultiHeadAttention {
|
|||||||
v: &Tensor,
|
v: &Tensor,
|
||||||
mask: Option<&Tensor>,
|
mask: Option<&Tensor>,
|
||||||
) -> Result<Tensor> {
|
) -> Result<Tensor> {
|
||||||
let (_, n_ctx, n_state) = q.shape().r3()?;
|
let (_, n_ctx, n_state) = q.dims3()?;
|
||||||
let scale = ((n_state / self.n_head) as f64).powf(-0.25);
|
let scale = ((n_state / self.n_head) as f64).powf(-0.25);
|
||||||
let q = (self.reshape_head(q)? * scale)?;
|
let q = (self.reshape_head(q)? * scale)?;
|
||||||
let k = (self.reshape_head(k)?.transpose(2, 3)? * scale)?;
|
let k = (self.reshape_head(k)?.transpose(2, 3)? * scale)?;
|
||||||
@ -270,7 +270,7 @@ impl AudioEncoder {
|
|||||||
let x = self.conv1.forward(x)?.gelu()?;
|
let x = self.conv1.forward(x)?.gelu()?;
|
||||||
let x = self.conv2.forward(&x)?.gelu()?;
|
let x = self.conv2.forward(&x)?.gelu()?;
|
||||||
let x = x.transpose(1, 2)?;
|
let x = x.transpose(1, 2)?;
|
||||||
let (_bsize, seq_len, _hidden) = x.shape().r3()?;
|
let (_bsize, seq_len, _hidden) = x.dims3()?;
|
||||||
let positional_embedding = self.positional_embedding.narrow(0, 0, seq_len)?;
|
let positional_embedding = self.positional_embedding.narrow(0, 0, seq_len)?;
|
||||||
let mut x = x.broadcast_add(&positional_embedding)?;
|
let mut x = x.broadcast_add(&positional_embedding)?;
|
||||||
for block in self.blocks.iter() {
|
for block in self.blocks.iter() {
|
||||||
|
@ -41,7 +41,7 @@ impl Conv1d {
|
|||||||
match &self.bias {
|
match &self.bias {
|
||||||
None => Ok(x),
|
None => Ok(x),
|
||||||
Some(bias) => {
|
Some(bias) => {
|
||||||
let b = bias.shape().r1()?;
|
let b = bias.dims1()?;
|
||||||
let bias = bias.reshape((1, b, 1))?;
|
let bias = bias.reshape((1, b, 1))?;
|
||||||
Ok(x.broadcast_add(&bias)?)
|
Ok(x.broadcast_add(&bias)?)
|
||||||
}
|
}
|
||||||
|
@ -49,7 +49,7 @@ impl LayerNorm {
|
|||||||
DType::F16 | DType::BF16 => DType::F32,
|
DType::F16 | DType::BF16 => DType::F32,
|
||||||
d => d,
|
d => d,
|
||||||
};
|
};
|
||||||
let (_bsize, _seq_len, hidden_size) = x.shape().r3()?;
|
let (_bsize, _seq_len, hidden_size) = x.dims3()?;
|
||||||
let x = x.to_dtype(internal_dtype)?;
|
let x = x.to_dtype(internal_dtype)?;
|
||||||
let mean_x = (x.sum_keepdim(2)? / hidden_size as f64)?;
|
let mean_x = (x.sum_keepdim(2)? / hidden_size as f64)?;
|
||||||
let x = x.broadcast_sub(&mean_x)?;
|
let x = x.broadcast_sub(&mean_x)?;
|
||||||
|
@ -164,7 +164,7 @@ impl MultiHeadAttention {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn reshape_head(&self, x: &Tensor) -> Result<Tensor> {
|
fn reshape_head(&self, x: &Tensor) -> Result<Tensor> {
|
||||||
let (n_batch, n_ctx, n_state) = x.shape().r3()?;
|
let (n_batch, n_ctx, n_state) = x.dims3()?;
|
||||||
let target_dims = &[n_batch, n_ctx, self.n_head, n_state / self.n_head];
|
let target_dims = &[n_batch, n_ctx, self.n_head, n_state / self.n_head];
|
||||||
Ok(x.reshape(target_dims)?.transpose(1, 2)?)
|
Ok(x.reshape(target_dims)?.transpose(1, 2)?)
|
||||||
}
|
}
|
||||||
@ -176,7 +176,7 @@ impl MultiHeadAttention {
|
|||||||
v: &Tensor,
|
v: &Tensor,
|
||||||
mask: Option<&Tensor>,
|
mask: Option<&Tensor>,
|
||||||
) -> Result<Tensor> {
|
) -> Result<Tensor> {
|
||||||
let (_, n_ctx, n_state) = q.shape().r3()?;
|
let (_, n_ctx, n_state) = q.dims3()?;
|
||||||
let scale = ((n_state / self.n_head) as f64).powf(-0.25);
|
let scale = ((n_state / self.n_head) as f64).powf(-0.25);
|
||||||
let q = {
|
let q = {
|
||||||
let _timer = crate::Timer::new("q::reshape");
|
let _timer = crate::Timer::new("q::reshape");
|
||||||
@ -328,7 +328,7 @@ impl AudioEncoder {
|
|||||||
self.conv2.forward(&x)?.gelu()?
|
self.conv2.forward(&x)?.gelu()?
|
||||||
};
|
};
|
||||||
let x = x.transpose(1, 2)?;
|
let x = x.transpose(1, 2)?;
|
||||||
let (_bsize, seq_len, _hidden) = x.shape().r3()?;
|
let (_bsize, seq_len, _hidden) = x.dims3()?;
|
||||||
let positional_embedding = self.positional_embedding.narrow(0, 0, seq_len)?;
|
let positional_embedding = self.positional_embedding.narrow(0, 0, seq_len)?;
|
||||||
let mut x = x.broadcast_add(&positional_embedding)?;
|
let mut x = x.broadcast_add(&positional_embedding)?;
|
||||||
for block in self.blocks.iter() {
|
for block in self.blocks.iter() {
|
||||||
|
@ -134,7 +134,7 @@ impl Decoder {
|
|||||||
.to_scalar::<f32>()? as f64;
|
.to_scalar::<f32>()? as f64;
|
||||||
}
|
}
|
||||||
|
|
||||||
let (seq_len, _) = logits.shape().r2()?;
|
let (seq_len, _) = logits.dims2()?;
|
||||||
let logits = logits
|
let logits = logits
|
||||||
.get(seq_len - 1)?
|
.get(seq_len - 1)?
|
||||||
.broadcast_add(&self.suppress_tokens)?;
|
.broadcast_add(&self.suppress_tokens)?;
|
||||||
@ -207,7 +207,7 @@ impl Decoder {
|
|||||||
|
|
||||||
fn run(&self, mel: &Tensor) -> anyhow::Result<Vec<Segment>> {
|
fn run(&self, mel: &Tensor) -> anyhow::Result<Vec<Segment>> {
|
||||||
let mut rng = StdRng::seed_from_u64(299792458);
|
let mut rng = StdRng::seed_from_u64(299792458);
|
||||||
let (_, _, content_frames) = mel.shape().r3()?;
|
let (_, _, content_frames) = mel.dims3()?;
|
||||||
let mut seek = 0;
|
let mut seek = 0;
|
||||||
let mut segments = vec![];
|
let mut segments = vec![];
|
||||||
while seek < content_frames {
|
while seek < content_frames {
|
||||||
|
Reference in New Issue
Block a user