Sketch the candle-nn crate. (#115)

* Sketch the candle-nn crate.
* Tweak the cuda dependencies.
* More cuda tweaks.
@@ -1,7 +1,7 @@
 // T5 Text Encoder
 // https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py

-use crate::nn::{Dropout, Embedding, HiddenAct, Linear, VarBuilder};
+use crate::nn::{linear, Dropout, Embedding, HiddenAct, Linear, VarBuilder};
 use anyhow::Result;
 use candle::Tensor;

@@ -104,8 +104,8 @@ struct T5DenseActDense {

 impl T5DenseActDense {
     fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
-        let wi = Linear::load(cfg.d_model, cfg.d_ff, false, &format!("{p}.wi"), vb)?;
-        let wo = Linear::load(cfg.d_ff, cfg.d_model, false, &format!("{p}.wo"), vb)?;
+        let wi = linear(cfg.d_model, cfg.d_ff, false, &format!("{p}.wi"), vb)?;
+        let wo = linear(cfg.d_ff, cfg.d_model, false, &format!("{p}.wo"), vb)?;
         let dropout = Dropout::new(cfg.dropout_rate);
         Ok(Self {
             wi,
@@ -154,10 +154,10 @@ struct T5Attention {
 impl T5Attention {
     fn load(h: bool, p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
         let inner_dim = cfg.num_heads * cfg.d_kv;
-        let q = Linear::load(cfg.d_model, inner_dim, false, &format!("{p}.q"), vb)?;
-        let k = Linear::load(cfg.d_model, inner_dim, false, &format!("{p}.k"), vb)?;
-        let v = Linear::load(cfg.d_model, inner_dim, false, &format!("{p}.v"), vb)?;
-        let o = Linear::load(inner_dim, cfg.d_model, false, &format!("{p}.o"), vb)?;
+        let q = linear(cfg.d_model, inner_dim, false, &format!("{p}.q"), vb)?;
+        let k = linear(cfg.d_model, inner_dim, false, &format!("{p}.k"), vb)?;
+        let v = linear(cfg.d_model, inner_dim, false, &format!("{p}.v"), vb)?;
+        let o = linear(inner_dim, cfg.d_model, false, &format!("{p}.o"), vb)?;
         let relative_attention_bias = if h {
             let emb = Embedding::load(
                 cfg.relative_attention_num_buckets,
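For context on the change above: the new `linear` free function appears to be a thin constructor helper, so model code no longer spells out `Linear::load` at every call site. A minimal sketch under that assumption, with the signature inferred purely from the call sites in this diff (the actual candle-nn implementation may differ), could look like:

// Hypothetical helper inferred from the call sites in this diff; not the
// verbatim candle-nn code. It relies on the Linear, VarBuilder, and Result
// types already imported in the file above and simply forwards to
// Linear::load with the same (in_dim, out_dim, bias, prefix, vb) arguments.
pub fn linear(
    in_dim: usize,
    out_dim: usize,
    bias: bool,
    prefix: &str,
    vb: &VarBuilder,
) -> Result<Linear> {
    Linear::load(in_dim, out_dim, bias, prefix, vb)
}

Call sites then read as in the hunks above, e.g. `let wi = linear(cfg.d_model, cfg.d_ff, false, &format!("{p}.wi"), vb)?;`.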