Sketch the candle-nn crate. (#115)

* Sketch the candle-nn crate.

* Tweak the cuda dependencies.

* More cuda tweaks.
Laurent Mazare authored on 2023-07-10 08:50:09 +01:00; committed by GitHub
parent bc3be6f9b0 · commit 9ce0f1c010
13 changed files with 230 additions and 315 deletions


@@ -1,7 +1,7 @@
 // T5 Text Encoder
 // https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
-use crate::nn::{Dropout, Embedding, HiddenAct, Linear, VarBuilder};
+use crate::nn::{linear, Dropout, Embedding, HiddenAct, Linear, VarBuilder};
 use anyhow::Result;
 use candle::Tensor;
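
The import change above swaps the `Linear::load` associated-function calls for a free `linear` constructor exported from the nn module. The helper's body is not part of this diff; the following is a minimal sketch of a signature consistent with the call sites in the hunks below, under the assumption that it simply delegates to the existing `Linear::load`:

    // Hypothetical sketch, inferred only from the call sites in this commit.
    // The real helper lives in the new candle-nn crate and may build the
    // weight tensors from the VarBuilder directly rather than delegating.
    fn linear(
        in_dim: usize,
        out_dim: usize,
        bias: bool,
        p: &str,
        vb: &VarBuilder,
    ) -> Result<Linear> {
        Linear::load(in_dim, out_dim, bias, p, vb)
    }
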
@@ -104,8 +104,8 @@ struct T5DenseActDense {
 impl T5DenseActDense {
     fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
-        let wi = Linear::load(cfg.d_model, cfg.d_ff, false, &format!("{p}.wi"), vb)?;
-        let wo = Linear::load(cfg.d_ff, cfg.d_model, false, &format!("{p}.wo"), vb)?;
+        let wi = linear(cfg.d_model, cfg.d_ff, false, &format!("{p}.wi"), vb)?;
+        let wo = linear(cfg.d_ff, cfg.d_model, false, &format!("{p}.wo"), vb)?;
         let dropout = Dropout::new(cfg.dropout_rate);
         Ok(Self {
             wi,
@@ -154,10 +154,10 @@ struct T5Attention {
 impl T5Attention {
     fn load(h: bool, p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
         let inner_dim = cfg.num_heads * cfg.d_kv;
-        let q = Linear::load(cfg.d_model, inner_dim, false, &format!("{p}.q"), vb)?;
-        let k = Linear::load(cfg.d_model, inner_dim, false, &format!("{p}.k"), vb)?;
-        let v = Linear::load(cfg.d_model, inner_dim, false, &format!("{p}.v"), vb)?;
-        let o = Linear::load(inner_dim, cfg.d_model, false, &format!("{p}.o"), vb)?;
+        let q = linear(cfg.d_model, inner_dim, false, &format!("{p}.q"), vb)?;
+        let k = linear(cfg.d_model, inner_dim, false, &format!("{p}.k"), vb)?;
+        let v = linear(cfg.d_model, inner_dim, false, &format!("{p}.v"), vb)?;
+        let o = linear(inner_dim, cfg.d_model, false, &format!("{p}.o"), vb)?;
         let relative_attention_bias = if h {
             let emb = Embedding::load(
                 cfg.relative_attention_num_buckets,
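
Design-wise the swap is mechanical: every `Linear::load(in, out, bias, prefix, vb)?` call site becomes `linear(in, out, bias, prefix, vb)?`, so the four attention projections q/k/v/o and the two feed-forward matrices change in lockstep. Presumably the point of routing construction through a free function is to give the nascent candle-nn crate a single place to change how layer weights are materialized without touching the model code again.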