FT-Mamba:一种高效的表回归的新深度学习模型

张开发
2026/4/11 4:51:16 15 分钟阅读

分享文章

FT-Mamba:一种高效的表回归的新深度学习模型
论文总结1、这篇论文主要是将FT-Transformer的功能和Mamba的高效性相结合。最近这几年mamba确实挺火的像Transformer是On2的时间复杂度Mamba模型能够降至On线性时间复杂度具有这方面优势。而且之前也有研究提出了FT-Transformer。2、作者通过随机增强、数据平衡和自蒸馏来提升模型的性能。将表格数据分为了数值型、分类型和范畴型三种。3、数据自蒸馏是指 对批次中的每个输入数据x创建两个不同的增强输入x1和x2标签相同。使用曼巴模型从每个增强输入中提取全局特征生成fMambax1和fM ambax2。 特征连接 fMambax1 和 fMambax2z ConcatfMambax1 fMambax2。在训练过程中最小化z1和z2之间的差异以鼓励模型学习更稳健的表示。 通过采用自蒸馏该方法有效提升了模型处理复杂序列建模任务的能力利用增强输入提升鲁棒性和性能。摘要深度学习在图像、音频和文本方面取得了进步但表格数据中异质性和相关性较弱等挑战依然存在。本文介绍了FT-Mamba架构该架构将FT-Transformer的功能与Mamba的高效性结合实现高效且可扩展的顺序数据处理。它还通过随机增强、数据平衡和自蒸馏来提升性能。该工作为表式数据的深度学习提供了新颖的见解和工具提升了其在现实世界中的实际应用。引言在当今以数据为中心的世界里数据在分析和决策中至关重要推动了制造业、医疗和金融等行业的发展[1]。表格数据的结构性质使其成为跨多个领域的多功能工具。图像、音频和文本深度学习的进步带来了卓越的人工智能能力促使人们探索将这些技术应用于表格数据。然而表格数据的异质性包括密集的数值特征和稀疏的类别特征以及弱的特征相关性给分析和模型构建带来了独特挑战。此外缺乏统一的基准测试和标准化的表格数据深度学习实验使得实际应用中的性能评估更加复杂。本研究综合了当前表格数据深度学习模型的最新进展并提出了基于结构化状态空间模型SSM[9]-[11]的创新Mamba模型。Mamba模型凭借其独特的架构解决了传统深度学习模型在处理表格数据时面临的效率和可扩展性问题。它利用SSM的优势处理长序列数据同时保持线性时间复杂度使Mamba在处理表格数据时展现出前所未有的效率。Mamba模型通过简化的架构促进端到端学习避免了复杂的注意力机制MLP。这种设计最大限度地减少了模型的复杂性和参数数量使其轻量化且易于部署。它高效处理数据为表格数据的深度学习带来了突破。我们的贡献包括1开发一种结合FTTransformer优势与Mamba序列处理效率的FT-mamba架构[3]。2引入随机增补方法以处理表格数据的异质性增强模型泛化性。3实施数据均衡和自蒸馏以解决阶级不平衡并提升性能。总之我们的工作不仅为深度学习领域带来了新的视角和工具也为我们的工作提供了新思路。我们的目标是在保持模型简洁的同时进一步提升模型的泛化能力和执行效率相关工作Mamba模型Mamba是状态空间模型SSM的一种新成功变体已被证明在拥有多达百万条长度序列的真实数据上实现更高性能[9]-[11]。凭借其高计算效率和远程能力它迅速引起了关注在视觉任务和语言任务领域已有许多尝试[12]-[17]。Vim [13] 是首次将 SSM 应用于强化视觉任务并将 SSM 引入通用视觉骨干克服了 Mamba 在单向建模和位置感知方面的局限。VMamba [14] 通过一种新颖的交叉扫描技术简化了从一维扫描到二维扫描的转换。U-Mamba [15] 将 Mamba 模块纳入任务特定架构展示了 SSM 在多样化视觉任务中的有效性构建基于U-net的基础[16]。Jamba 采用了混合 Transformer 和 Mamba 架构结合了 MoE 模块在保持高吞吐量的同时灵活平衡性能和内存需求[17]。本研究表明Mamba模型同样可以成功应用于表格问题。基于深度学习的表格数据模型对表格数据的深度学习研究可以分为几个关键领域以应对其固有的挑战。这些领域包括处理低质量数据和数据缺陷的方法如插值、清洗、归一化和增强以提升数据质量并确保输入的可靠性[18]-[19]。特征表示技术包括编码、构造和降维对于将原始表格数据转换为适合深度学习模型的形式至关重要。模型架构的创新如混合模型和基于注意力的方法旨在更好地捕捉和利用数据中的复杂关系[5]-[7]。正则化和优化策略包括高级优化算法如L1/L2正则化和Adam对于防止过拟合和提升训练效率至关重要[20]。此外增强对噪声、不一致性的鲁棒性以及整合多模态数据以利用不同数据源的方法对于构建具有弹性和适应性的模型至关重要。最后有效的评估指标和方法的开发确保了模型在表格数据上的表现得到准确评估和比较。在我们的实验中我们展示了所提出模型在数据质量提升和模型优化方面的优势。方法在本节中我们介绍了FT-MambaFeature TokenizeMamba这是一种专门为解决表格数据挑战而设计的新型随机增强Mamba模型。我们还将深入探讨其关键超参数这些参数对模型性能至关重要。我们的讨论将从以下三个关键方面来探讨该模型。数据增强图1随机增强深度学习依赖于丰富且高质量的数据来识别预测性关系。数据稀缺性是一个重大挑战尤其是在数据有限或昂贵时。数据增强在这种情况下至关重要因为它从现有来源生成新数据增强多样性并可能提升模型性能。它可以分为离线训练前和在线训练中两种方法每种方法都有其优势比如控制性和灵活性。与图像或文本数据不同表格数据缺乏空间相关性关系由列值决定。这对设计机器学习模型非常重要。本文提出了一种在线表格数据增强的方法将数据视为矩阵并利用掩蔽操作创建新实例。我们区分了跨列跨列和列内列内增强使用SCARF、VIME和随机遮罩等方法[7][8]。所选算法通过参数p1和p2结合随机性控制每个操作的概率。FT Mamba在本节中我们将介绍我们工作中强调的FT-Mamba设计以及现有的比较解决方案。考虑一个表格数据样本{xi yi}n i1其中xi代表第i个对象的特征yi代表其对应标签。本文重点关注Y R的回归任务。该对象的特征 xi 包含数值、二元和范畴特征二元和范式特征都被视为范式特征。该数据集分为三个不同的子集Dtrain、Dval 和 Dtest。训练子集Dtrain用于模型训练验证子集Dval用于早期停止和超参数调优测试子集Dtest保留用于最终评估。图2展示了FT-Mamba的主要架构。在特征分词器部分模块将输入 X 转换为嵌入序列T ∈ Rk∗d。它包括初始化一个分词器对象将输入数据转换为以矩阵表示的嵌入附加一个特殊的令牌[CLS]并从该令牌中汇总信息以进行分类或回归任务。这与基于变换器的模型中的标准做法一致其中[CLS]令牌在下游分类或回归任务中代表输入序列时起着关键作用。MambaBlock是Mamba模型的重要组成部分Mamba是一种受门控MLP和H3架构启发的复合神经网络块[9]。Mamba简化了H3的SSM模块和Transformers的MLP模块的组合集成了线性投影、卷积、激活函数、选择性状态空间层和残差连接。该结构通过多次变换处理输入序列高效捕捉相关模式和特征。给定输入 x x1 ... xlMambaBlock 的核心操作可表示为图2FT-Mamba回归问题方案其中⊗表示元素乘法。Mamba架构通过堆叠L个MambaBlock利用状态空间模型SSM将注意力复杂度从二次型降低到线性从而增强了训练和推理。因此Mamba在序列长度上以线性时间运行能够有效地捕捉跨扩展序列的相关信息数据蒸馏由于为所有任务寻找合适的教师模型存在挑战以及由于数据稀缺导致部分大型模型难以训练许多研究者提出了自学策略[4]。本质上这种方法使模型能够作为自己的导师通过自我提炼来优化自身。我们旨在提升类别内的稳健性并通过自蒸馏共同从同一类别样本中学习我们将以下方法应用于表格数据。• 对批次中的每个输入数据x创建两个不同的增强输入x1和x2标签相同。• 使用曼巴模型从每个增强输入中提取全局特征生成fMambax1和fM ambax2。• 特征连接 fMambax1 和 fMambax2z ConcatfMambax1 fMambax2。• 在训练过程中最小化z1和z2之间的差异以鼓励模型学习更稳健的表示。 通过采用自蒸馏该方法有效提升了模型处理复杂序列建模任务的能力利用增强输入提升鲁棒性和性能。优化我们的方法采用了一种损失函数旨在优化学生模型通过整合来自真实标签的直接督导和通过教师模型的知识蒸馏间接监督如图3所示。该损失函数主要包含两个组成部分均方误差MSE损失和蒸馏损失均采用加权焦点MSE损失进行测量[21]。均方误差MSE损耗成分衡量了模型的预测xˆ1和真实标签y1可以定义如下[21]。其中 e |xˆ − y|是绝对误差xˆ 表示预测y 表示目标。权重用w表示如果不提供则默认为1β和γ是标量参数。激活函数fz定义为该损失函数通过应用缩放激活函数和可选权重来修改均方误差从而强调更难预测的样本。蒸馏损失成分Ldistillxˆ1 xˆ2衡量使用相同加权焦点MSE损失时学生模型预测与教师模型预测之间的差异。该组成部分引导学生模型从教师的预测中学习利用精炼知识提升学生的表现和泛化能力。网络的全损耗函数将这两个组成部分结合为一个单一的优化目标由参数α调节。α平衡了真实标签的直接监督与教师模型的指导促进了拟合真实数据与纳入教师知识之间的最佳权衡。学生模型的整体优化目标可以表达如下该目标函数通过最小化MSE损失和蒸馏损失来指导训练过程。它在贴合真实标签与利用教师模式提炼知识之间取得了最佳平衡。对于任何指定的网络架构权重梯度可高效计算便于通过标准随机梯度下降SGD过程优化网络参数。其中 Wij 是网络权重矩阵 W 中的一个元素。这种方法确保模型在捕捉教师模型中提炼知识的同时有效适应训练数据从而提升表现和泛化能力。实验与结果数据集为了评估拟议FT-Mamba模型的性能我们利用多样化的基准数据集对现有替代方案进行了大量实验。图3。FT-Mamba训练的自蒸馏机制为了系统评估我们方法的有效性我们使用了多样化的公开数据集集合。具体来说我们使用了5个真实世界数据集这些数据集在大小、性质、特征数量及其分布上各不相同。这些数据集中许多此前已被[3]用于表格模型评估详见表2。数据集的详细信息和缩写见表1。我们还将所有数据集分为训练70%、验证10%和测试20%三类以确保对模型泛化和最终性能的全面评估Baseline为了对FT-Mamba模型进行全面基准测试我们特别将其与标准方法和当前最先进方法进行比较。这些包括MLP [27]、DKL [24]、ResNet [26]、FT-Transformer。超参数调优为了微调FT-Mamba模型及其基线模型我们采用了Optuna优化库详见[22]。在我们的实验中Optuna用于贝叶斯超参数优化。在模型训练阶段我们选择了AdamW优化器因为它通过无缝将权重衰减正则化融入优化过程有效减少了过拟合。在整个训练过程中保持恒定的学习率以确保稳定性所有算法的批量大小均为统一预定。此外为防止过拟合和检测收敛情况实施了早期停止机制将耐心阈值设为16个日历小时且一旦验证集无改善则停止训练。这种策略对于保持模型的普遍性、避免不必要的计算开支至关重要。我们的调优策略指标来自五种不同采样种子的平均得分以确保其稳健性。表II列出了调优FT-Mamba与不同随机种子的平均测试结果。如表二所示FT-Mamba在这些特定数据库和评估指标上表现优于其他模型。消融实验在本分析中我们通过比较FT-Mamba框架下的架构选择将其与FTTransformer进行比较后者是一种先进的表式深度学习模型能够将特征处理到嵌入中并采用自关注机制[23]。此外我们还研究了特征分词器中分位数归一化和随机增广以及数据蒸馏中关键的MSE损失对模型性能的影响。五项试验的汇总结果与表I中的“BR”数据一致见表III。这些结果凸显了曼巴骨架的优越性及其组成部分的关键作用。FT-M的RMSE显著下降至0.0757C指数提升至0.6334模型规模较FT-T减少了48.03%。C指数略有上升RMSE略有下降表明自蒸馏具有适度的正向效果。C指数略有上升RMSE略有下降表明随机增高有小但积极的影响。最显著的变化是C指数显著上升RMSE略有下降凸显了数据平衡对模型性能的重要性。消融研究显示虽然自蒸馏和随机增强能带来轻微的性能提升但数据平衡是FTMamba模型有效性不可或缺的组成部分。移除数据平衡后显著的性能提升几乎抵消了其他优化带来的好处凸显了其在模型成功中的关键作用。通过系统评估每个组成部分我们的研究展示了FT-Mamba在应对各种表格数据挑战方面的稳健性和高效性。总结总之FT-Mamba模型为复杂的表格数据深度学习任务提供了高效且简化的解决方案。通过利用结构化状态空间模型和简化架构的力量我们的方法在多种数据集上展现出更优的性能。它显著降低了处理表格数据的复杂性和计算成本从而促进了依赖此类数据的各个领域的进步。然而表数据深度学习领域也存在局限和挑战如高质量数据集稀缺以及需要更细粒度的特征表示。这些挑战凸显了进一步探索和创新的空间。

更多文章