Skip to content

anuin-cat/TNET_TOMM

Repository files navigation

TNET: 基于频率重构的多模态情感分析网络

简介

本项目是一个名为 TNET 的多模态情感分析网络。该工作的最初灵感与部分代码实现源于 2023 年发表在 TMM 上的论文 "EMT-DLFR"。在此基础上,我构建了新的网络结构并进行了实验验证。

数据集说明

本项目使用了两个主流的多模态情感分析视频数据集:MOSIMOSEI。这些数据集已被预处理并保存为 .pkl 格式,包含了音频、视觉和文本三种模态。

  • 音频 (Audio) & 视觉 (Visual):这两种模态已经被提取为序列特征。
  • 文本 (Text):此模态保留了原始的英文文本,在代码中会通过 BERT 模型动态提取其序列特征。

在处理和对比基于这些数据集的论文时,请务必注意以下几个关键点,以确保公平比较:

  1. 文本特征提取器:不同的研究工作可能采用不同的模型来提取文本特征。例如,一些工作使用 BERT,而另一些则可能使用 T5、RoBERTa 等。本项目的文本特征统一使用 BERT 进行提取

  2. 数据对齐 (Alignment):数据存在“对齐 (Aligned)”和“未对齐 (Unaligned)”两个版本。

    • 对齐数据:指音频、视觉、文本三个模态的特征序列长度完全相同,且在时间戳上严格一一对应。
    • 未对齐数据:指模态间的序列长度不相等。
  3. 特征二次处理:原始数据集中,音频和视觉模态已经是序列特征。一些工作会在此基础上,使用 LSTM1D 卷积 等方法对特征序列进行二次处理,以捕捉邻近 Token 之间的依赖关系。这种操作是常见的,进行此类处理的工作与本项目进行对比时,通常被认为是公平的

数据下载

英文数据集 (MOSI & MOSEI)

数据集可以通过原项目 EMT-DLFR 仓库中分享的百度网盘链接下载。请在链接中找到 MMSA 文件夹,并下载其中的 MOSIMOSEI 数据集。

下载解压后,预期的文件结构如下:

MMSA/
├── MOSI/
│   ├── aligned_50.pkl
│   └── unaligned_50.pkl
└── MOSEI/
    ├── aligned_50.pkl
    └── unaligned_50.pkl

中文数据集 (CH-SIMS)

关于其他语言的数据集,目前主流且比较实用的一个是中文数据集:CH-SIMS: A Chinese Multimodal Sentiment Analysis Dataset with Fine-grained Annotation of Modality(论文中常简称 SIMS)。

EMT-DLFR 项目的同一个百度网盘链接中,通常也包含了这个数据集的下载。关于该数据集的使用方法,许多论文和项目都提供了清晰的示例。例如:

  • EMT-DLFR 项目本身就展示了其使用方法。
  • Self-MM 项目 (论文及代码: https://github.com/thuiar/Self-MM) 也提供了完整的加载和处理流程。

您可以参考这些项目来学习如何使用 CH-SIMS 数据集。

网络架构 (TNET)

本代码实现的核心是 TNET 网络,其主要流程如下:

  1. 频率特征重构:代码首先将每个模态的序列特征分解为多个不同频段的特征。随后,借鉴 UNet 的思想,进行跨模态的特征交互与重构。具体来说,一个模态的某个频段特征(例如,音频的低频特征)会分别与另外两个模态的原始模态序列特征(例如,视觉和文本的模态特征)进行 cross-attention 交互,以此来增强和补充信息。

  2. 模态内频率融合:对于每个模态内部分解出的多个频段特征,网络会学习一个权重,用于表示每个频段的重要性。然后,通过加权求和的方式将这些频段特征融合,得到一个信息更丰富、重点更突出的单模态表征。

  3. 序列信息聚合与分类:融合后的每个模态序列特征会经过一个 self-attention 层进行自身信息的深度整合。最后,我们仅取出每个模态输出序列的第一个位置的特征(即 [CLS] Token),将这三个模态的 [CLS] Token 拼接起来,送入分类器进行最终的情感分类。

如何运行

GPU/CPU 平台运行

  1. 激活 Conda 环境

    conda activate pytorch

    原始环境太久远我已忘记需要配置哪些,应该也是基于这个emt-dlfr来进行环境配置的。

  2. 配置路径

    • 数据集路径:打开 config/get_data_root.py 文件,在其中指定您存放数据集的根目录。
    • BERT 模型路径:打开 models/subNets/BertTextEncoder.py 文件,在大约第 26 行的位置,指定您本地存放的预训练 BERT 模型的路径。
  3. 运行训练脚本 项目包含了针对两个数据集的运行脚本,超参数已在脚本内预设好。

    • 运行 MOSI 数据集:

      1. trains/missingTask/TNET_mosi.py 文件重命名为 TNET.py
      2. 执行脚本 run_mosi_best.py
    • 运行 MOSEI 数据集:

      1. trains/missingTask/TNET_mosei.py 文件重命名为 TNET.py
      2. 执行脚本 run_mosei_best.py

昇腾 NPU 平台运行

本项目已适配华为昇腾 NPU 平台,但由于硬件特性差异,存在一些特殊的适配处理。

环境准备

  1. 安装依赖

    # 首先安装 CANN 和 torch-npu,请根据您的 CANN 版本选择对应的 PyTorch 版本
    # 例如:
    pip install torch==2.1.0
    pip install torch-npu==2.1.0.post3
    
    # 安装其他依赖
    pip install -r requirements_ascend.txt
  2. 配置路径(与 GPU/CPU 平台相同)

    • 数据集路径:打开 config/get_data_root.py 文件,指定数据集根目录。
    • BERT 模型路径:打开 models/subNets/BertTextEncoder.py 文件,在大约第 26 行的位置,指定预训练 BERT 模型路径。
  3. 运行训练脚本

    # 运行 MOSI 数据集(昇腾版本)
    python run_mosi_best_npu.py

昇腾适配的主要改动

由于昇腾 NPU 的硬件特性与 CUDA GPU 存在差异,代码进行了以下关键适配:

  1. FFT 操作的纯 PyTorch 实现
    昇腾 NPU 目前不支持 torch.fft 复数 FFT 操作。为避免 CPU 回退带来的性能损失,代码使用纯 PyTorch 矩阵乘法直接在 NPU 上实现 DFT/IDFT

    涉及文件:

    • models/missingTask/TNET.py
    • models/missingTask/TNET_v3.py
    • models/missingTask/UNET.py

    主要实现方式:

    • 通过 DFT 公式 e^{-i2πkn/N} 展开为 cossin 矩阵
    • 使用矩阵乘法 x @ cos_mat.t() 实现频域变换,所有操作均在 NPU 上执行
    • 添加了 DFT 矩阵缓存机制(_dft_cache)以提升性能
    • 使用 torch.roll 实现 fftshift/ifftshift 操作
  2. Attention Mask 设备兼容性修复
    MultiheadAttention 模块中,添加了设备一致性检查,确保 attn_maskattn_weights 在同一设备上进行运算。

    涉及文件:

    • models/subNets/transformers_encoder/multihead_attention.py
  3. Future Mask 设备适配
    在 Transformer 编码器中,修改了设备判断逻辑,不仅检查 CUDA,而是直接将 mask 移至与输入 tensor 相同的设备。

    涉及文件:

    • models/subNets/transformers_encoder/transformer.py

性能说明

⚠️ 重要提示:昇腾 NPU 版本的实验结果与 GPU 平台存在性能差异。

  • 性能差距:在 MOSI 数据集上,昇腾 NPU 版本在 ACC-7 和 ACC-2 指标上的表现相比 GPU 平台大约低 1.5% 左右

  • 可能原因

    1. DFT 手动实现的数值差异:手动实现的 DFT(基于矩阵乘法)与 GPU 上高度优化的 torch.fft 在数值精度上可能存在细微差异,这些差异在反向传播中累积可能影响最终性能。
    2. NPU 与 GPU 浮点运算特性差异:不同硬件架构在浮点运算、舍入方式、数值稳定性等方面存在固有差异。
    3. 算子优化程度不同:矩阵乘法等基础算子在 NPU 上的优化程度可能与 CUDA 有所差异,虽然都能正常运行,但计算路径的细微不同可能影响训练收敛。
  • 建议

    • 追求最佳性能:建议使用 GPU 平台进行训练和测试,以获得论文中报告的最优指标。
    • 昇腾平台部署:如果需要在昇腾平台上部署,当前版本可作为参考基线。

About

No description, website, or topics provided.

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Contributors

Languages