基于互补学习系统的时空预测模型,实现时空预测模型自适应进化
论文标题:
ComS2T: A complementary spatiotemporal learning system for data-adaptive model evolution
https://arxiv.org/pdf/2403.01738.pdf
https://github.com/hqh0728/ComS2T
动机与引言
▲ 图1 城市数据随着时间演变产生的时空偏移现象
在这项工作中,我们尝试以互补的视角动态更新模型参数,以解决时空分布偏移问题,实现数据自适应的模型进化。然而,鉴于时空数据具备复杂的依赖关系并且与环境因素之间存在相互作用,设计互补时空学习系统还面临以下挑战:
首先,如何将复杂的时空学习器(ST learner)与互补学习无缝耦合在统一而高效的框架中,即如何在给定 ST Learner 的情况下,有效地识别稳定的新皮质神经模块和动态的海马体结构,分别用于可泛化性和模型更新? 其次,如何以整体视角共同建模时空观测和环境特征,恰当地处理未见数据以使海马体结构适应新环境? 最后,如何设计训练策略以同时保留历史信息并在新模式上赋予模型在线更新能力,同时减少计算资源的消耗?
因此,我们的互补学习赋予了模型在训练和测试阶段的演进能力。沿着训练过程,我们的 CLS 能够同时利用新皮质层保留历史信息并允许海马体灵活地更新网络适应新数据。在测试过程中,利用有限测试数据开展自监督训练空间-时间提示,进而将模型更新推广到测试阶段,进一步促进了模型的适应性。我们的贡献总结如下:
这是将神经科学中的互补学习与时空模型相结合实现泛化和数据适应的首次尝试,通过两个保留良好的变化矩阵设计了高效的神经架构分解。 提出了一种自监督的提示训练方法,用于建立环境因素与主要观测分布之间的关系,不仅允许神经网络微调提示,而且使模型参数敏感于数据分布的动态性和演变。 我们的框架可以同时处理空间和时间方面的分布变化,并构建了四种 OOD 场景来模拟模型验证的数据适应。实验表明,我们的 ComS2T 可以在时间偏移(数据分布偏移)时将性能提高 0.73% 至 20.70%,同时在结构偏移时提升 0.36% 至 17.30%。
方法介绍
▲ 图2 ComS2T 模型框架图
高效的神经网络解耦
自监督时空提示学习
首先基于“提示-回答”设计了一种自监督形式的预训练机制,巧妙地以自监督的形式训练提示表征,并将提示信息传递至互补学习系统中以作为一个条件变量,与输入的主观测叠加来更新“海马结构”。我们显式建模了每个时空域中连续时间序列的观测分布,使得模型能够构建起时空 prompt 与连续时间序列摘要的潜在关联。
对于 prompt 表征学习的监督信号,我们建模了当前连续 个时间步的序列观测,训练得到具有区分性模式的 prompt 表征来预测序列观测分布,~。我们的预测模型可以形式化地表达为
无论在训练还是测试阶段,这种自监督机制均仅可以在获得空间嵌入、时间步和对应观测信息的条件下快速构建 prompt 与数据分布间的关联,使其天然地能够在分布外的场景微调 prompt 表示。同时,这种动态性能够传递至主模型的海马结构中,为回归任务的 test-time training 提供了可能(条件)。
我们将解耦得到的海马体结构用于快速学习新信息,而利用新皮质层来保留时空学习中的稳定信息。具体而言,我们对空间学习模块和时间学习模块进行分别学习,冻结新皮质层、更新海马体层。鉴于上下文生成的 prompt 能够充分表达环境信号,使得模型能够感知到环境变化,因此我们基于时空 prompts来更新海马体层。
首先,我们对 prompt 信号对齐,将 spatial prompt 和 temporal prompt 分别注入到空间、时序学习的表征输入中,使之与时空主观测的输入保持相同维度和融合,并传送至海马体层。用 表示空间 prompt 并用 表示时序 prompt,在 fine-tune 阶段,空间学习层的输入为 。随后我们将冻结皮质层并且更新海马体层,形式化地可以得到 ,同理,我们对时序学习块也进行更新。学习过程可以表示为:
测试数据微调
为了赋予模型自我进化的能力,能够根据数据实际分布进行自我进化和调整,我们在测试阶段对 prompt 信息再次微调。当测试数据到达,我们重新抽样一小批数据来构建 self-supervised learning pairs, 随后优化 spatial and temporal prompts 并更新,我们将新的 prompt 与 X 观测表征相加,输入到时空模型中,以获得具有泛化能力的输出。这一部分的学习过程可以表示为:
实验
数据集
在数据集方面,我们选取了四个典型且不同领域的时空数据集合,两个交通数据集:SIP 和 Metr-LA,一个环境(空气质量)数据集:KnowAir,及一个气象数据集:Temperature。
实验设置
如图所示,我们分别构建了具有明确时序分布偏移和空间结构偏移的实验数据。
首先,时间分布偏移可以通过根据不同数据集上的数据分布特征进行两个训练-测试分割来模拟。
天级别划分:对于 SIP 和 Metr-LA 等动态的交通数据集,一天内的演变模式完全不同。因此,我们通过收集所有相同日期间隔(例如,每天的 8:00-16:00)来组织训练集,而在其他未见过的日期间隔(例如,每天的 1:00-7:00)上进行测试。 月份级别划分:对于相对短期内相对静态但季节性变化的空气质量和气候数据集,我们将全年记录分成四个季度,其中我们使用两个季度进行训练,而在一个季度进行测试。
其次,空间分布偏移是通过引入新节点和移除现有节点实现的。
节点引入:我们在训练期间屏蔽一些现有节点,并在测试阶段将它们重新添加,以模拟图结构的新连接。
节点移除:类似地,我们在测试阶段移除一些现有节点,以模拟动态图结构中的节点消失。
预测结果
▲ 表1 实验结果
实验结果分析
总体而言,我们的 ComS2T 在大多数场景下均取得了与基线模型相比更好的预测性能,在时序数据分布变化下,性能从 0.73% 提高到 20.70%,在结构变化下,性能从 1.19% 提高到 17.30%。我们的四个主要观察如下。
观测1:与传统 ST 学习器比较。虽然传统的 ST 学习器在连续序列预测的设置上显示出令人满意的性能,但在分布变化的情况下仍然存在不足,特别是在两个交通数据集上。MTGNN 和 ST-SSL 揭示了对结构变化的一些稳健性,主要是因为通过节点复制策略可以很好地将可学习的邻接关系转移到新节点上,而逐步和节点自监督信号可能在获得区分模式以进行泛化方面发挥重要作用。因此,SSL 学习的潜在优势在于改进表征,显著改善了在 Metr-LA 数据集上的性能,这也被继承到我们的 ComS2T.
观测2:与不变学习 ST 模型比较。一些模型考虑了跨环境的不变性和可转移性,以抵消时间分布的变化,实证结果表明,捕获不变性确实可以提高 OOD 学习能力,但仍不如我们的 ComS2T。这是由于即使将不变性转移到 OOD 场景,这些方法没有针对模型更新和数据适应构建具体方案。
观测3:与 ST 持续学习比较。对于明确考虑环境变化的预测模型,如 CauSTG、CaST 和 TrafficStream,无论是利用封闭的环境划分和码本,还是利用经验重放对模型进行重新训练,仍不能充分利用现有的环境信息来提高自适应能力。相比之下,我们的 ComS2T 利用了自监督提示和互补学习的优势,通过在主要观测值和环境提示之间建立桥梁来适应空间和时间提示, 因而我们的模型在时序分布偏移下至少提高了 8.17%,在结构变化下提高了 3.16%。
观测4:结构转移下的比较。尽管 CauSTG 考虑了空间转移、PECPM 关注了路网扩张问题,我们的工作通过更新空间提示与新的观察明确建模空间结构背景。结果表明,ComS2T 显著优于 CauSTG 和 PECPM,例如在 SIP 的数据分布偏移下,它比 CauSTG 提高了 3.01%,在 Metr-LA 节点去除场景中下,它比 PECPM 提高了 17.30%。此外,ComS2T 还解决了需多次训练(CauSTG)和模式级匹配(PECPM)的计算效率问题。
综上所述,我们的 ComS2T 将主要在两个方面优于其他基准,即:
1)我们的 ComS2T 不牺牲内存存储来保存新模式,也不牺牲序列级模式匹配的计算负担,我们的 ComS2T 直接解算稳定和动态的神经架构,并在整个训练过程中主动更新神经网络,从而获得较高的效率;
2)我们的 ComS2T 结合了自监督提示与分布重构的优势和互补的学习框架,允许随着新的观测值进行灵活的提示更新,实现了精确的数据适应与时空学习框架模型。
可视化结果
我们从参数演化行为、学习过程误差等方面来更具体地验证 ComS2T 的有效性。
从可视化的结果可以发现:
参数在互补学习优化过程中经历了波动-稳定-波动-稳定的过程,也正表明了互补学习系统对参数学习的有效性;
我们互补学习的视角能够在原有唯一学习路径的基础上进一步补充、迭代地学习新知识,从而获得更小的学习误差;
总结
本工作受神经科学启发,将互补学习的思想与时空预测相结合,提出了基于提示的互补学习系统 ComS2T,赋予模型数据适应和演化能力。我们首先将时空学习神经网络解耦为两个不同的神经结构,通过显式建模可学习权重的训练行为。为了实现高效和自适应的模型演化,引入额外的环境因素,使用空间-时间提示来描述观察数据的分布,并使提示能够通过自监督信号进行学习。
然后,我们可以逐步解耦神经结构,并将信息提示纳入动态海马体进行环境感知的微调。ComS2T 允许在训练阶段基于环境提示进行模型调整,因此,在环境变化时,可将提示微调扩展至测试阶段,从而增强模型对新数据模式的拟合,提升模型进化能力。
更多阅读
#投 稿 通 道#
让你的文字被更多人看到
如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。
总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。
PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析、科研心得或竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。
📝 稿件基本要求:
• 文章确系个人原创作品,未曾在公开渠道发表,如为其他平台已发表或待发表的文章,请明确标注
• 稿件建议以 markdown 格式撰写,文中配图以附件形式发送,要求图片清晰,无版权问题
• PaperWeekly 尊重原作者署名权,并将为每篇被采纳的原创首发稿件,提供业内具有竞争力稿酬,具体依据文章阅读量和文章质量阶梯制结算
📬 投稿通道:
• 投稿邮箱:[email protected]
• 来稿请备注即时联系方式(微信),以便我们在稿件选用的第一时间联系作者
• 您也可以直接添加小编微信(pwbot02)快速投稿,备注:姓名-投稿
△长按添加PaperWeekly小编
🔍
现在,在「知乎」也能找到我们了
进入知乎首页搜索「PaperWeekly」
微信扫码关注该文公众号作者