Bendi新闻
>
进我的收藏夹吃灰吧:大模型加速超全指南来了
进我的收藏夹吃灰吧:大模型加速超全指南来了
9月前
最近,一位名为 Theia Vogel 的博主整理撰写了一篇长文博客,对加速 LLM 推理的方法进行了全面的总结,对各种方法展开了详细的介绍,值得 LLM 研究人员收藏查阅。
以下是博客原文内容。
之前,我使用经典的自回归采样器手动制作了一个 transformer,大致如下:
def generate(prompt: str, tokens_to_generate: int) -> str:
tokens = tokenize(prompt)
for i in range(tokens_to_generate):
next_token = model(tokens)
tokens.append(next_token)
return detokenize(tokens)
这种推理方法很优雅,是 LLM 工作机制的核心。自回归 LLM 在只有数千个参数的情况下运行得很好,但对于实际模型来说就太慢了。为什么会这样,我们怎样才能让它更快?
为什么简单推理这么慢?
Time to First Token(TtFT)—— 收到 prompt 和返回第一个 token 之间需要多长时间? 生成延迟 —— 收到 prompt 和返回最终 token 之间需要多长时间? 吞吐量 硬件利用率 —— 我们使用硬件的计算、内存带宽和其他功能的效率如何?
硬件
def foo(x):
s = torch.sin(x)
c = torch.cos(x)
return s + c
"trace.enabled": True, "trace.graph_diagram": True}) > compiled_foo = torch.compile(foo, options={
# call with an arbitrary value to trigger JIT >
10))) > compiled_foo(torch.tensor(range(
Writing FX graph to file: .../graph_diagram.svg
[2023-11-25 17:31:09,833] [6/0] torch._inductor.debug: [WARNING] model__24_inference_60 debug trace: /tmp/...zfa7e2jl.debug
tensor([ 1.0000, 1.3818, 0.4932, -0.8489, -1.4104, -0.6753, 0.6808, 1.4109,
0.8439, -0.4990])
extern "C" void kernel(const long* in_ptr0,
float* out_ptr0)
{
{
for(long i0=static_cast<long>(0L); i0<static_cast<long>(10L); i0+=static_cast<long>(1L))
{
auto tmp0 = in_ptr0[static_cast<long>(i0)];
auto tmp1 = static_cast<float>(tmp0);
auto tmp2 = std::sin(tmp1);
auto tmp3 = std::cos(tmp1);
auto tmp4 = tmp2 + tmp3;
out_ptr0[static_cast<long>(i0)] = tmp4;
}
}
}
10_000, 10_000)) > x = torch.rand((
> %timeit foo(x)
246 ms ± 8.89 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
> %timeit compiled_foo(x)
91.3 ms ± 14 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# (for small inputs `compiled_foo` was actually slower--not sure why)
10_000, 10_000)) > x = torch.rand((
> %timeit foo(x)
246 ms ± 8.89 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
> %timeit compiled_foo(x)
91.3 ms ± 14 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# (for small inputs `compiled_foo` was actually slower--not sure why)
"trace.enabled": True, "trace.graph_diagram": True}) compiled_gbreak = torch.compile(gbreak, options={
10))) compiled_gbreak(torch.tensor(range(
Writing FX graph to file: .../model__27_inference_63.9/graph_diagram.svg
[2023-11-25 17:59:32,823] [9/0] torch._inductor.debug: [WARNING] model__27_inference_63 debug trace: /tmp/torchinductor_user/p3/cp3the7mcowef7zjn7p5rugyrjdm6bhi36hf5fl4nqhqpfdqaczp.debug
Writing FX graph to file: .../graph_diagram.svg
[2023-11-25 17:59:34,815] [10/0] torch._inductor.debug: [WARNING] model__28_inference_64 debug trace: /tmp/torchinductor_user/nk/cnkikooz2z5sms2emkvwj5sml5ik67aqigynt7mp72k3muuvodlu.debug
tensor([ 1.0000, -0.1756, 2.6782, -0.7063, -2.5683, 2.7053, 0.9718, 0.5394,
7.6436, -0.0467])
extern "C" void kernel(const long* in_ptr0,
float* out_ptr0,
float* out_ptr1,
bool* out_ptr2)
{
{
{
float tmp_acc0 = 0;
for(long i0=static_cast<long>(0L); i0<static_cast<long>(10L); i0+=static_cast<long>(1L))
{
auto tmp0 = in_ptr0[static_cast<long>(i0)];
auto tmp1 = static_cast<float>(tmp0);
auto tmp2 = std::sin(tmp1);
auto tmp3 = std::cos(tmp1);
auto tmp4 = tmp2 + tmp3;
out_ptr0[static_cast<long>(i0)] = tmp4;
tmp_acc0 = tmp_acc0 + tmp4;
}
out_ptr1[static_cast<long>(0L)] = tmp_acc0;
}
}
{
auto tmp0 = out_ptr1[static_cast<long>(0L)];
auto tmp1 = static_cast<float>(0.0);
auto tmp2 = tmp0 < tmp1;
out_ptr2[static_cast<long>(0L)] = tmp2;
}
}
extern "C" void kernel(const float* in_ptr0,
const long* in_ptr1,
float* out_ptr0)
{
{
for(long i0=static_cast<long>(0L); i0<static_cast<long>(10L); i0+=static_cast<long>(1L))
{
auto tmp0 = in_ptr0[static_cast<long>(i0)];
auto tmp1 = in_ptr1[static_cast<long>(i0)];
auto tmp2 = static_cast<float>(tmp1);
auto tmp3 = std::cos(tmp2);
auto tmp4 = tmp0 - tmp3;
out_ptr0[static_cast<long>(i0)] = tmp4;
}
}
}
# get an explanation for a given input
>>> explained = torch._dynamo.explain(gbreak)(torch.tensor(range(10)))
# there's a break, because of a jump (if) on line 3
>>> explained.break_reasons
[GraphCompileReason(reason='generic_jump TensorVariable()', user_stack=[<FrameSummary file <stdin>, line 3 in gbreak>], graph_break=True)]
# there are two graphs, since there's a break
>>> explained.graphs
[GraphModule(), GraphModule()]
# let's see what each graph implements, without needing to dive into the kernels!
>>> for g in explained.graphs:
... g.graph.print_tabular()
... print()
...
opcode name target args kwargs
------------- ------ ------------------------------------------------------ ------------ --------
placeholder l_x_ L_x_ () {}
call_function sin <built-in method sin of type object at 0x7fd57167aaa0> (l_x_,) {}
call_function cos <built-in method cos of type object at 0x7fd57167aaa0> (l_x_,) {}
call_function add <built-in function add> (sin, cos) {}
call_method sum_1 sum (add,) {}
call_function lt <built-in function lt> (sum_1, 0) {}
output output output ((add, lt),) {}
opcode name target args kwargs
------------- ------ ------------------------------------------------------ ----------- --------
placeholder l_x_ L_x_ () {}
placeholder l_r_ L_r_ () {}
call_function tan <built-in method tan of type object at 0x7fd57167aaa0> (l_x_,) {}
call_function sub <built-in function sub> (l_r_, tan) {}
output output output ((sub,),) {}
# pretty cool!
批处理
20 tokens x 1 sequence = ~70ms 20 tokens x 5 sequences = ~220ms (线性扩展~350ms) 20 tokens x 10 sequences = ~400ms (线性扩展~700ms)
>>> gpt2.transformer.h[0].attn.c_attn.weight.dtype
torch.float32
KV cache
# the gpt2 tokenizer produces 3 tokens for this string
" A B C").input_ids > tokens = tokenizer(
> tokens
[317, 347, 327]
# if we put that into the model, we get 3 rows of logits
> logits = gpt2(input_ids=torch.tensor(tokens)).logits.squeeze()
> logits.shape
torch.Size([3, 50257])
# and if we argmax those, we see the model is predicting a next token
# for _every_ prompt token!
1)): > for i, y in enumerate(logits.argmax(-
... print(f"{tokenizer.decode(tokens[:i+1])!r} -> {tokenizer.decode(y)!r}")
' A' -> '.'
' A B' -> ' C'
' A B C' -> ' D'
tokens
[317, 347, 327] # the " A B C" string from before
key_values = gpt2(input_ids=torch.tensor(tokens)).past_key_values
for x in t) for t in key_values) tuple(tuple(x.shape
((torch.Size([1, 12, 3, 64]), torch.Size([1, 12, 3, 64])),
(torch.Size([1, 12, 3, 64]), torch.Size([1, 12, 3, 64])),
(torch.Size([1, 12, 3, 64]), torch.Size([1, 12, 3, 64])),
(torch.Size([1, 12, 3, 64]), torch.Size([1, 12, 3, 64])),
(torch.Size([1, 12, 3, 64]), torch.Size([1, 12, 3, 64])),
(torch.Size([1, 12, 3, 64]), torch.Size([1, 12, 3, 64])),
(torch.Size([1, 12, 3, 64]), torch.Size([1, 12, 3, 64])),
(torch.Size([1, 12, 3, 64]), torch.Size([1, 12, 3, 64])),
(torch.Size([1, 12, 3, 64]), torch.Size([1, 12, 3, 64])),
(torch.Size([1, 12, 3, 64]), torch.Size([1, 12, 3, 64])),
(torch.Size([1, 12, 3, 64]), torch.Size([1, 12, 3, 64])),
(torch.Size([1, 12, 3, 64]), torch.Size([1, 12, 3, 64])))
需要预先分配比所需更多的空间; 该保留空间不能被其他请求使用,即使还不需要它; 具有相同前缀的请求不能共享该前缀的 KV 缓存。
猜测解码
>>> for i, y in enumerate(logits.argmax(-1)):
... print(f"{tokenizer.decode(tokens[:i+1])!r} -> {tokenizer.decode(y)!r}")
' A' -> '.'
' A B' -> ' C'
' A B C' -> ' D'
def generate(prompt: str, tokens_to_generate: int) -> str:
tokens: list[int] = tokenize(prompt)
TO = tokenize(" going to")
for i in range(tokens_to_generate):
if tokens[-1] == GOING:
# do our speculative decoding trick
logits = model.forward(tokens + [TO])
# the token the model predicts will follow "... going"
going_pred = argmax(logits[-2, :])
# the token the model predicts will follow "... going to"
to_pred = argmax(logits[-1, :])
if going_pred == TO:
# if our guess was correct, accept "to" and the next token after
tokens += [TO, to_pred]
else:
# otherwise, accept the real next token
# (e.g. "for" if the true generation was "going for broke")
tokens += [going_pred]
else:
# do normal single-token generation
logits = model.forward(tokens)
tokens += [argmax(logits[-1])]
return detokenize(tokens)
def generate(prompt: str, tokens_to_generate: int, n_draft: int = 8) -> str:
tokens: list[int] = tokenize(prompt)
for i in range(tokens_to_generate):
# generate `n_draft` draft tokens in the usual autoregressive way
draft = tokens[:]
for _ in range(n_draft):
logits = draft_model.forward(draft)
draft.append(argmax(logits[-1]))
# run the draft tokens through the oracle model all at once
logits = model.forward(draft)
checked = logits[len(tokens) - 1 :].argmax(-1)
# find the index of the first draft/oracle mismatch—we'll accept every
# token before it
# (the index might be past the end of the draft, if every draft token
# was correct)
n_accepted = next(
idx + 1
for idx, (checked, draft) in enumerate(
# we add None here because the oracle model generates one extra
# token (the prediction for the last draft token)
draft[len(tokens) :] + [None])
)
if checked != draft
)
n_accepted]) :
return detokenize(tokens)
def speculative_threshold(
prompt: str,
max_draft: int = 16,
threshold: float = 0.4,
threshold_all_correct_boost: float = 0.1,
:
tokens = encoder.encode(prompt)
# homegrown KV cache setup has an `n_tokens` method that returns the length
# of the cached sequence, and a `truncate` method to truncate that sequence
# to a specific token
model_kv = gpt2.KVCache()
draft_kv = gpt2.KVCache()
while True:
# generate up to `max_draft` draft tokens autoregressively, stopping
# early if we fall below `threshold`
draft = tokens[:]
drafted_probs = []
for _ in range(max_draft):
logits = draft_model.forward(draft[draft_kv.n_tokens() :], draft_kv)
next_id = np.argmax(logits[-1])
next_prob = gpt2.softmax(logits[-1])[next_id]
if not len(drafted_probs):
drafted_probs.append(next_prob)
else:
* drafted_probs[-1])
draft.append(int(next_id))
if drafted_probs[-1] < threshold:
break
n_draft = len(draft) - len(tokens)
# run draft tokens through the oracle model
logits = model.forward(draft[model_kv.n_tokens() :], model_kv)
checked = logits[-n_draft - 1 :].argmax(-1)
n_accepted = next(
idx + 1
for idx, (checked, draft) in enumerate(
draft[len(tokens) :] + [None])
)
if checked != draft
)
yield from checked[:n_accepted]
n_accepted]) :
if n_accepted <= n_draft:
# adjust threshold towards prob of last accepted token, if we
# ignored any draft tokens
threshold = (threshold + drafted_probs[n_accepted - 1]) / 2
else:
# otherwise, lower the threshold slightly, we're probably being
# too conservative
threshold -= threshold_all_correct_boost
# clamp to avoid pathological thresholds
threshold = min(max(threshold, 0.05), 0.95)
# don't include oracle token in kv cache
- 1)
- 1)
扫描二维码添加小助手微信
关于我们
微信扫码关注该文公众号作者
来源:机器学习算法与自然语言处理
相关新闻
大模型推理速度飙升3.6倍,「美杜莎」论文来了,贾扬清:最优雅加速推理方案之一会颠勺的国产机器人来了:大模型加持,家务能力满分Meta无限长文本大模型来了:参数仅7B,已开源马斯克大模型Grok1.5来了:推理能力大升级,支持128k上下文国内首个开源千亿参数MoE大模型来了!性能超Grok-1,单张GPU可跑防暑降温指南来了 独立日全场大促 低至5折再整单减$20重磅!新西兰减税政策今日生效,超全解答来了!NZ总理:我减的税全捐了!最高飙升$130万!大量NZ学区房逆市猛涨!超全分析来了!@所有新西兰华人:看看有你家吗?下一代 RAG 技术来了!微软正式开源 GraphRAG:大模型行业将迎来新的升级?重磅!新西兰房市大地震:超严贷款新规要来了!?房价将大变?五种资源类别,如何提高大语言模型的资源效率,超详细综述来了中文创意写作能力超GPT-4,「最会写」的中文大模型Weaver来了超强阵容集结!小红书大模型论文分享会来了,四大国际顶会作者强势来袭【收藏】万人挤爆悉尼CBD!Vivid Sydney强势来袭!光影艺术、音乐大秀、港湾大桥...最全攻略来了!2024傅盛开年大课:企业“私有化大模型的时代”来了?锅来了:外国人咖啡价格暴涨,全怪中国人爱吃榴莲今年北京可能650进不了顶尖公办国际部?最全帝都国际化高中择校指南来了!日均tokens使用量超5000亿,AI生图玩法猛猛上新:豆包大模型为什么越来越「香」了?苹果开源7B大模型,训练过程数据集一口气全给了,网友:开放得不像苹果教程来了!3分钟教你搭建:AI大模型前端界面周鸿祎向李彦宏“开炮”:有些名人胡说八道别被忽悠了;全球最强开源大模型Llama 3发布:最大模型参数将超4000亿丨AI周报技术吃瓜:金句还是鸡汤,我们用大模型训练了一个报警器吵翻天!全网群嘲澳洲大学“充多了”!2025QS世界大学排名公开!墨大、悉大力压清华,世界TOP20?网友:表情包来了...神秘大模型一夜“征服”所有人,超GPT-4却无人认领?网友:OpenAI 要有大麻烦了