Mirror of https://github.com/huggingface/candle.git (synced 2025-06-17 02:58:50 +00:00)
Use candle_nn::embedding instead of local copies in a few models. (#1562)
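Each of the files below carried its own private helper for building an Embedding from a VarBuilder; this commit deletes those copies and imports the shared candle_nn::embedding constructor instead. A minimal sketch of the pattern, with the removed helper copied verbatim from the diff (the shared constructor may differ in how it initializes fresh variables, but is equivalent when weights are loaded from a checkpoint):

    use candle::Result;
    use candle_nn::{Embedding, VarBuilder};

    // The local helper this commit removes from each model file:
    fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
        let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
        Ok(Embedding::new(embeddings, hidden_size))
    }

    // After the commit, call sites use the shared constructor instead:
    // let emb = candle_nn::embedding(vocab_size, hidden_size, vb)?;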
--- a/candle-transformers/src/models/bert.rs
+++ b/candle-transformers/src/models/bert.rs
@@ -1,6 +1,6 @@
 use super::with_tracing::{layer_norm, linear, LayerNorm, Linear};
 use candle::{DType, Device, Result, Tensor};
-use candle_nn::{Embedding, Module, VarBuilder};
+use candle_nn::{embedding, Embedding, Module, VarBuilder};
 use serde::Deserialize;
 
 pub const DTYPE: DType = DType::F32;
@@ -112,11 +112,6 @@ impl Config {
     }
 }
 
-fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
-    let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
-    Ok(Embedding::new(embeddings, hidden_size))
-}
-
 struct Dropout {
     #[allow(dead_code)]
     pr: f64,
--- a/candle-transformers/src/models/bigcode.rs
+++ b/candle-transformers/src/models/bigcode.rs
@@ -1,5 +1,5 @@
 use candle::{DType, Device, IndexOp, Result, Tensor, D};
-use candle_nn::{Embedding, LayerNorm, Linear, Module, VarBuilder};
+use candle_nn::{embedding, Embedding, LayerNorm, Linear, Module, VarBuilder};
 
 fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
     let weight = vb.get((size2, size1), "weight")?;
@@ -11,11 +11,6 @@ fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
     Ok(Linear::new(weight, bias))
 }
 
-fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
-    let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
-    Ok(Embedding::new(embeddings, hidden_size))
-}
-
 fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
     let weight = vb.get(size, "weight")?;
     let bias = vb.get(size, "bias")?;
--- a/candle-transformers/src/models/falcon.rs
+++ b/candle-transformers/src/models/falcon.rs
@@ -1,5 +1,5 @@
 use candle::{DType, Device, Result, Tensor, D};
-use candle_nn::{Embedding, LayerNorm, Linear, Module, VarBuilder};
+use candle_nn::{embedding, Embedding, LayerNorm, Linear, Module, VarBuilder};
 
 const MAX_SEQ_LEN: usize = 5000;
 
@@ -27,11 +27,6 @@ fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
     Ok(LayerNorm::new(weight, bias, eps))
 }
 
-fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
-    let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
-    Ok(Embedding::new(embeddings, hidden_size))
-}
-
 // https://raw.githubusercontent.com/huggingface/transformers/030c863aaa0165e98352b61697430bf69bf33755/src/transformers/models/falcon/configuration_falcon.py
 #[derive(Debug)]
 pub struct Config {
--- a/candle-transformers/src/models/llama.rs
+++ b/candle-transformers/src/models/llama.rs
@@ -1,6 +1,6 @@
 use super::with_tracing::{linear_no_bias as linear, Linear};
 use candle::{DType, Device, IndexOp, Result, Tensor, D};
-use candle_nn::{Embedding, Module, VarBuilder};
+use candle_nn::{embedding, Embedding, Module, VarBuilder};
 use serde::Deserialize;
 use std::collections::HashMap;
 use std::sync::{Arc, Mutex};
@@ -136,11 +136,6 @@ impl Cache {
     }
 }
 
-fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
-    let embeddings = vb.get((cfg.vocab_size, cfg.hidden_size), "weight")?;
-    Ok(Embedding::new(embeddings, cfg.hidden_size))
-}
-
 struct RmsNorm {
     inner: candle_nn::RmsNorm,
     span: tracing::Span,
@@ -409,7 +404,7 @@ impl Llama {
     }
 
     pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
-        let wte = embedding(cfg, vb.pp("model.embed_tokens"))?;
+        let wte = embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("model.embed_tokens"))?;
         let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
         let ln_f = RmsNorm::load(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?;
         let blocks: Vec<_> = (0..cfg.num_hidden_layers)
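Only llama.rs needed a call-site edit: its deleted helper took the whole &Config rather than explicit sizes, so Llama::load now passes the dimensions itself. The two signatures side by side (the first copied from the diff; the second written the way the deleted helpers in the other files spell it, matching the updated call above):

    // removed local helper (llama.rs only)
    fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding>
    // shared constructor from candle_nn, as called in Llama::load
    fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding>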
--- a/candle-transformers/src/models/whisper/model.rs
+++ b/candle-transformers/src/models/whisper/model.rs
@@ -1,12 +1,7 @@
 use super::Config;
 use crate::models::with_tracing::{linear, linear_no_bias, Linear};
 use candle::{Device, IndexOp, Result, Tensor, D};
-use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder};
+use candle_nn::{embedding, Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder};
 
-fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
-    let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
-    Ok(Embedding::new(embeddings, hidden_size))
-}
-
 fn conv1d(
     in_channels: usize,
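For completeness, a minimal, self-contained sketch of using the shared constructor the way the updated Llama::load does. The VarMap-backed VarBuilder and the sizes here are hypothetical stand-ins (for cfg.vocab_size / cfg.hidden_size); with a real model the VarBuilder would come from a checkpoint instead:

    use candle::{DType, Device, Result, Tensor};
    use candle_nn::{Module, VarBuilder, VarMap};

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        // A VarMap creates missing variables on demand, so no checkpoint is needed here.
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
        // Hypothetical sizes standing in for cfg.vocab_size / cfg.hidden_size.
        let emb = candle_nn::embedding(32, 8, vb.pp("model.embed_tokens"))?;
        // Look up embeddings for a small batch of token ids.
        let ids = Tensor::new(&[[3u32, 1, 4]], &dev)?;
        let xs = emb.forward(&ids)?;
        assert_eq!(xs.dims(), &[1, 3, 8]);
        Ok(())
    }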