CVPR 2024 | 迁移预训练生成模型中的知识到边缘小模型
©PaperWeekly 原创 · 作者 | 张剑清
单位 | 上海交通大学、清华大学(AIR)
研究方向 | 联邦学习
我们称这一过程为“知识迁移链路(KTL)”,并在每一轮联邦学习过程中进行一次知识迁移。此过程中生成模型只作推理不做训练,这种利用预训练模型的方式对资源的需求较少。
论文标题:
An Upload-Efficient Scheme for Transferring Knowledge From a Server-Side Pre-trained Generator to Clients in Heterogeneous Federated Learning
https://arxiv.org/abs/2403.15760
https://github.com/TsingZ0/FedKTL(含有PPT和Poster)
https://github.com/TsingZ0/PFLlib
运行实验所需仓库-异构联邦学习算法库:
https://github.com/TsingZ0/HtFLlib
动机
随着新一轮 AI 时代的到来,模型的量级越来越大,对数据的需求也越来越大。不论是哪个领域,有效且高质量的数据一直是一种稀缺存在,甚至成为了一种数据资产。与此同时,互联网上也广泛存在着能力强劲的开源生成模型。如果能够利用这些生成模型中存储的知识,来生成具体任务所需的数据,便可以让小公司和边缘智能设备都能享受到大模型带来的丰富成果。
为了解决数据稀缺问题,有以下四种常见途径:
1. 利用来自公开数据集的数据,但这类数据很难做到与具体任务相关,且任务无关数据甚至会产生负面影响;
2. 由数据请求方上传数据生成需求(比如分类任务中的标签语义)到云端的生成模型来生成数据,但这种文本很容易导致隐私问题;
3. 利用云端大模型随机生成无标签数据,但这样做依旧存在与利用公开数据集同样的问题,甚至无标签数据的引入增加了模型训练的难度;
4. 利用预训练模型引入额外知识,但适用于具体任务的预训练模型稀少且其中的额外知识不一定匹配当前任务。
换言之,不管是引入额外数据还是引入额外知识,都尽量需要与当前任务相关,才能最大限度地起到正面作用。
之后,我们将该任务相关知识作为输入,传递给预训练生成模型,并针对当前任务做了域对齐,从而生成任务相关的数据。为了有效利用该数据,我们将其传输到联邦学习参与方后,运行一个额外的有监督任务实现知识迁移。
异构联邦学习技术
传统联邦学习考虑了数据异质性,但依旧要求所有参与方训练同一个架构的模型,增加了寻找相似任务参与方的难度。于是我们考虑取消这一点要求,允许参与方采用各自的模型进行知识共享。然而,这样一来,传统联邦学习中基于模型参数共享的范式不再可用,对新型的知识共享机制提出了要求。其中包括:1)保护隐私,2)保护知识产权,3)轻量化,4)易于获得等。
▲ 图1:异构联邦学习技术
目前异构联邦学习技术还未形成统一的知识共享机制,我们考虑一种轻量化且不需要额外数据的知识共享机制:共享 prototype。本文考虑的是面向图像的多分类任务,其 prototype 的定义就是每个类别的代表性特征向量,可通过平均该类所有的特征向量获得。我们将 prototype 当作共享知识,输入到生成模型后得到相应图片数据,并将图片-向量对(image-vector pairs)传回参与者,如下图。
知识迁移链路(KTL)
上一节的最后已经简单描述了我们提出的知识迁移链路(KTL),但省略了很多细节,这里我们对重点步骤进行展开(其他步骤及细节详见论文)。下图是我们的整体框架,其中最重要的是步骤 3 和步骤 6。
步骤3:当我们在生成模型的特征空间采样时,可以生成清晰图像,但这样的图像并非任务相关。根据我们的实验观察,如果直接将参与方上传的 prototype 输入到预训练生成模型,由于参与方模型的特征空间和生成模型的特征空间不匹配(通常连维度都不一致),导致生成的图像跟随机输入一样模糊不清。
所以我们需要先将 prototype 映射到高维的生成模型特征空间,并保证这些 prototype 依旧是任务相关的。因为我们考虑的是分类问题,任务相关指的就是 prototype 映射后得到的特征向量依旧保持类别可分离特性。我们称这一过程为域对齐(domain alignment),如下图可见,对齐后的特征向量可以使生成模型产生清晰图片。
▲ 图4:生成模型在不同输入下得到的图片
为了实现域对齐,我们在服务器端额外训练了一个轻量化的特征转换器(F),并定义其训练目标为对齐特征空间(使用 MMD 损失)和保证类别可分离性(使用 MSE 损失)。
▲ 图5:域对齐实现方案
▲ 图6:域对齐的一个例子。这是一个三分类任务,其中参与方模型特征空间维度为 2,生成模型特征空间维度为 3,W 是生成模型的有效特征空间。
▲ 图6:参与方本地数据样本和四种用于生成模型预训练的数据集样本。
更多阅读
#投 稿 通 道#
让你的文字被更多人看到
如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。
总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。
PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析、科研心得或竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。
📝 稿件基本要求:
• 文章确系个人原创作品,未曾在公开渠道发表,如为其他平台已发表或待发表的文章,请明确标注
• 稿件建议以 markdown 格式撰写,文中配图以附件形式发送,要求图片清晰,无版权问题
• PaperWeekly 尊重原作者署名权,并将为每篇被采纳的原创首发稿件,提供业内具有竞争力稿酬,具体依据文章阅读量和文章质量阶梯制结算
📬 投稿通道:
• 投稿邮箱:[email protected]
• 来稿请备注即时联系方式(微信),以便我们在稿件选用的第一时间联系作者
• 您也可以直接添加小编微信(pwbot02)快速投稿,备注:姓名-投稿
△长按添加PaperWeekly小编
🔍
现在,在「知乎」也能找到我们了
进入知乎首页搜索「PaperWeekly」
点击「关注」订阅我们的专栏吧
微信扫码关注该文公众号作者