MPT alibi fixes. (#1120)

* MPT alibi fixes.

* Some more fixes.

* Finally get the model to return some sensible outputs.

* Add a readme.
Laurent Mazare
2023-10-18 10:58:05 +01:00
committed by GitHub
parent 662c186fd5
commit 767a6578f1
3 changed files with 64 additions and 13 deletions


@@ -0,0 +1,45 @@
# candle-replit-code: code completion specialized model.
[replit-code-v1_5-3b](https://huggingface.co/replit/replit-code-v1_5-3b) is a
language model specialized for code completion. The model uses 3.3B parameters
stored in `bfloat16` (so the GPU version will only work on recent Nvidia cards).
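As a rough illustration of what the `bfloat16` requirement means in practice (this sketch is not part of the example's code, and the dummy tensor is a placeholder), one would typically pick `DType::BF16` when a CUDA device is available and fall back to `f32` on CPU using the standard `candle_core` API:
```rust
// Illustrative only: select bf16 on CUDA, f32 on CPU. The real example loads
// the model weights with the chosen dtype; here we just build a dummy tensor.
use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::cuda_if_available(0)?;
    let dtype = if device.is_cuda() { DType::BF16 } else { DType::F32 };
    let t = Tensor::ones((2, 2), dtype, &device)?;
    println!("device: {:?}, dtype: {:?}", device, t.dtype());
    Ok(())
}
```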
## Running an example
```bash
cargo run --example replit-code --release -- --prompt 'def fibonacci(n): '
```
This produces the following output, which does not actually generate the Fibonacci
series correctly.
```
def fibonacci(n): # write Fibonacci series up to n
"""Print a Fibonacci series up to n."""
assert type(n) == int, "n must be an integer"
if (type(fib_list)==None or len==0 ):
fib_list = [1]
for i in range((len-2)): # start at 2nd element of list and go until end.
n += 1
print("Fibonacci number",n,"is:",i)
def main():
"""Call the functions."""
userInput=input('Enter a positive integer: ')
fibonacci(userInput)
if __name__ == '__main__': # only run if this file is called directly.
print("This program prints out Fibonacci numbers.")
main()
```


@@ -139,7 +139,7 @@ struct Args {
     seed: u64,
 
     /// The length of the sample to generate (in tokens).
-    #[arg(long, short = 'n', default_value_t = 100)]
+    #[arg(long, short = 'n', default_value_t = 1000)]
     sample_len: usize,
 
     #[arg(long)]


@@ -103,23 +103,25 @@ impl GroupedQueryAttention
                 (k, v)
             }
         };
-        let key = repeat_kv(key, self.n_heads / self.kv_n_heads)?;
-        let value = repeat_kv(value, self.n_heads / self.kv_n_heads)?;
+        self.kv_cache = Some((key.clone(), value.clone()));
+        let query = query.contiguous()?;
+        let key = repeat_kv(key, self.n_heads / self.kv_n_heads)?.contiguous()?;
+        let value = repeat_kv(value, self.n_heads / self.kv_n_heads)?.contiguous()?;
         let attn_weights = (query.matmul(&key)? * self.softmax_scale)?;
         let attn_bias = {
             let s_q = query.dim(D::Minus2)?;
             let s_k = key.dim(D::Minus1)?;
             let (_, _, a_q, a_k) = self.attn_bias.dims4()?;
-            self.attn_bias
-                .narrow(2, a_q - s_q, s_q)?
-                .narrow(3, a_k - s_k, s_k)?
+            let start_q = a_q.saturating_sub(s_q);
+            let start_k = a_k.saturating_sub(s_k);
+            self.attn_bias.i((.., .., start_q.., start_k..))?
         };
-        let attn_weights = (attn_weights + attn_bias)?;
+        let attn_weights = attn_weights.broadcast_add(&attn_bias)?;
         let attn_weights = match mask {
             None => attn_weights,
             Some(mask) => masked_fill(
                 &attn_weights,
-                &mask.broadcast_left(b_size * self.n_heads)?,
+                &mask.broadcast_as(attn_weights.shape())?,
                 f32::NEG_INFINITY,
             )?,
         };
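This hunk is the heart of the alibi fix: the updated keys/values are cached, `repeat_kv` expands the key/value heads to match the query heads, the tensors are made contiguous before the matmuls, and the precomputed attention bias (its last two dimensions indexed by query and key position) is sliced to its trailing `s_q` x `s_k` window so it lines up with the current query/key lengths. `saturating_sub` clamps the start offsets to zero instead of panicking on underflow when the current length exceeds the bias extent, and `broadcast_add` lets a size-1 bias dimension broadcast against the score tensor. A small standalone sketch of the window arithmetic (the helper name and the numbers are made up):
```rust
// Hypothetical helper mirroring the slicing in the hunk above: given the
// precomputed bias extent (a_q, a_k) and the current query/key lengths
// (s_q, s_k), return the start offsets of the trailing window to keep.
fn bias_window(a_q: usize, a_k: usize, s_q: usize, s_k: usize) -> (usize, usize) {
    (a_q.saturating_sub(s_q), a_k.saturating_sub(s_k))
}

fn main() {
    // Prompt processing: 5 queries against 5 keys, bias precomputed for 8 positions.
    assert_eq!(bias_window(8, 8, 5, 5), (3, 3));
    // Incremental decoding with a kv-cache: one new query token, 6 keys so far.
    assert_eq!(bias_window(8, 8, 1, 6), (7, 2));
    // Current length longer than the precomputed bias: clamp to the start.
    assert_eq!(bias_window(8, 8, 1, 10), (7, 0));
}
```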
@@ -128,7 +130,8 @@ impl GroupedQueryAttention
             .matmul(&value)?
             .transpose(1, 2)?
             .flatten_from(D::Minus2)?;
-        attn_output.apply(&self.out_proj)
+        let out = attn_output.apply(&self.out_proj)?;
+        Ok(out)
     }
 }
@@ -199,7 +202,7 @@ impl MPTBlock
         let xs = self.attn.forward(&xs, mask)?;
         let xs = (xs + residual)?;
         let residual = &xs;
-        let xs = xs.apply(&self.norm2)?.apply(&self.ffn);
+        let xs = xs.apply(&self.norm2)?.apply(&self.ffn)?;
         xs + residual
     }
 }
@@ -275,12 +278,15 @@ impl Model
             Some(get_mask(seq_len, xs.device())?)
         };
         for block in self.blocks.iter_mut() {
-            xs = block.forward(&xs, mask.as_ref())?
+            xs = block.forward(&xs, mask.as_ref())?;
         }
-        xs.narrow(1, seq_len - 1, 1)?
+        let xs = xs.apply(&self.norm_f)?;
+        let logits = xs
+            .narrow(1, seq_len - 1, 1)?
             .squeeze(1)?
             .matmul(&self.wte.embeddings().t()?)?
-            .squeeze(1)
+            .squeeze(1)?;
+        Ok(logits)
     }
 }
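After this change the tail of `Model::forward` applies the final layer norm and then projects only the last position through the transposed token-embedding matrix (the output head reuses `wte`) to produce the logits. A minimal shape-level sketch, assuming the usual `candle_core` tensor API and made-up dimensions; this is not the model code itself:
```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let (b, seq, d, vocab) = (2, 7, 16, 32);
    // Stand-ins for the final hidden states and the token-embedding table.
    let xs = Tensor::randn(0f32, 1.0, (b, seq, d), &dev)?;
    let wte = Tensor::randn(0f32, 1.0, (vocab, d), &dev)?;
    let logits = xs
        .narrow(1, seq - 1, 1)? // keep only the last position: (b, 1, d)
        .squeeze(1)? // (b, d)
        .matmul(&wte.t()?)?; // (b, vocab) via the transposed embedding table
    assert_eq!(logits.dims(), &[b, vocab]);
    println!("logits shape: {:?}", logits.dims());
    Ok(())
}
```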