from config import *
import torch
import torch.utils.data as data
from transformers import BertTokenizer
import pandas as pd
import random
# 获得实体位置索引
def get_ent_pos(lst):
items = []
for i in range(len(lst)):
# B-ASP 开头对应 id=1
# I-ASP 对应 id=2
if lst[i] == 1:
item = [i]
while True:
i += 1
# 一直到 I-ASP 结束
if i >= len(lst) or lst[i] != 2:
items.append(item)
break
else:
item.append(i)
i += 1
return items
# [CLS]这个手机外观时尚,美中不足的是拍照像素低。[SEP]
# 0 0 0 0 0 1 2 0 0 0 0 0 0 0 0 0 1 2 2 2 0 0 0
# 输出结果: [[5, 6], [16, 17, 18, 19]]
# print(get_ent_pos([0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 2, 2, 0, 0, 0]))
# exit()
def get_ent_weight(max_len, ent_pos):
cdm = []
cdw = []
for i in range(max_len):
dst = min(abs(i - ent_pos[0]), abs(i - ent_pos[-1]))
if dst <= SRD:
cdm.append(1)
cdw.append(1)
else:
cdm.append(0)
cdw.append(1 - (dst - SRD + 1) / max_len)
return cdm, cdw
# print(get_ent_weight(23, [5, 6]))
# exit()
class Dataset(data.Dataset):
def __init__(self, type="train"):
super().__init__()
file_path = TRAIN_FILE_PATH if type == "train" else TEST_FILE_PATH
self.df = pd.read_csv(file_path)
self.tokenizer = BertTokenizer.from_pretrained(BERT_MODEL_NAME)
def __len__(self):
# 合并句子,当前句子与下一个句子合并,总样本数量 -1
return len(self.df) - 1
def __getitem__(self, index):
# 相邻两个句子拼接
text1, bio1, pola1 = self.df.loc[index]
text2, bio2, pola2 = self.df.loc[index + 1]
# 注意拼接符号中间的空格 " ; " " 0 " " -1 ",与原格式匹配
text = text1 + " ; " + text2
bio = bio1 + " O " + bio2
pola = pola1 + " -1 " + pola2
# print(text)
# print(bio)
# print(pola)
# exit()
# 按自己的规则分词 ***触类旁通***
# eg: "英 寸 液 晶 屏 显 示 效 果 出 色 ; forever 清 晰 度 高"
# bert 会把 forever 标记为 "[UNK]",但是情感分类的时候,会把它索引成好多值
tokens = ["[CLS]"] + text.split(" ") + ["[SEP]"]
input_ids = self.tokenizer.convert_tokens_to_ids(tokens)
# BIO 标签转 id
bio_arr = ["O"] + bio.split(" ") + ["O"]
bio_label = [BIO_MAP[l] for l in bio_arr]
# 情感值转数字
pola_arr = ["-1"] + pola.split(" ") + ["-1"]
# 同样功能 pola_label = [int(i) for i in pola_arr]
pola_label = list(map(int, pola_arr))
# print(input_ids)
# print(bio_label)
# print(pola_label)
# exit()
return input_ids, bio_label, pola_label
def collate_fn(self, batch):
# 统计最大句子长度
batch.sort(key=lambda x: len(x[0]), reverse=True)
max_len = len(batch[0][0])
# 变量初始化
batch_input_ids = []
batch_bio_label = []
batch_mask = []
batch_ent_cdm = []
batch_ent_cdw = []
batch_pola_label = []
batch_pairs = []
for input_ids, bio_label, pola_label in batch:
# 获取实体位置,没有实体跳过
# ent_pos 形状 [[实体1索引], [实体2索引],,,]
# eg:[[3, 4, 5], [13, 14, 15]]
ent_pos = get_ent_pos(bio_label)
if len(ent_pos) == 0:
continue
# 填充句子长度
pad_len = max_len - len(input_ids)
batch_input_ids.append(input_ids + [BERT_PAD_ID] * pad_len)
batch_mask.append([1] * len(input_ids) + [0] * pad_len)
batch_bio_label.append(bio_label + [BIO_O_ID] * pad_len)
# 实体和情感分类对应
pairs = []
for pos in ent_pos:
# 此时 pola = 0 or 1
pola = pola_label[pos[0]]
# 异常值替换,下面相同
pola = 0 if pola == -1 else pola
pairs.append((pos, pola))
batch_pairs.append(pairs)
# print(batch_pairs)
# exit()
# 随机取一个实体
sg_ent_pos = random.choice(ent_pos)
cdm, cdw = get_ent_weight(max_len, sg_ent_pos)
# print(sg_ent_pos)
# print(cdm, cdw)
# exit()
# 计算加权参数
batch_ent_cdm.append(cdm)
batch_ent_cdw.append(cdw)
# 实体第一个字的情感极性,不出意外的话,pola 为 0 or 1
pola = pola_label[sg_ent_pos[0]]
# 排除意外情况,总共有三种选择: 0 1 -1
pola = 0 if pola == -1 else pola
batch_pola_label.append(pola)
# print("input_ids", input_ids)
# print("bio_label", bio_label)
# print("pola_label", pola_label)
# print("ent_pos", ent_pos)
# print("pairs", pairs)
# print("batch_pairs", batch_pairs)
# print("ent_pos", ent_pos)
# print("sg_ent_pos", sg_ent_pos)
# print("cdm", cdm)
# print("cdw", cdw)
# print("pola", pola)
# print("batch_pola_label", batch_pola_label)
# exit()
return (
torch.tensor(batch_input_ids),
torch.tensor(batch_mask).bool(),
torch.tensor(batch_bio_label),
# *** 这三个对应同一个实体 ***
torch.tensor(batch_ent_cdm),
torch.tensor(batch_ent_cdw),
torch.tensor(batch_pola_label),
# *** End ***
batch_pairs,
)
def get_pola(model, input_ids, mask, ent_label):
# 变量初始化
b_input_ids = []
b_mask = []
b_ent_cdm = []
b_ent_cdw = []
b_ent_pos = []
# 根据label解析实体位置
ent_pos = get_ent_pos(ent_label)
n = len(ent_pos)
if n == 0:
return None, None
# n个实体一起预测,同一个句子复制n份,作为一个batch
# 虽然 input_ids 只是一个句子,但是这个句子可能包含很多实体
# 不同句子对应不同的实体,当前句子的情感极性就表示对应实体的情感
b_input_ids.extend([input_ids] * n)
b_mask.extend([mask] * n)
b_ent_pos.extend(ent_pos)
for sg_ent_pos in ent_pos:
cdm, cdw = get_ent_weight(len(input_ids), sg_ent_pos)
b_ent_cdm.append(cdm)
b_ent_cdw.append(cdw)
# 列表转 tensor,注意 torch.stack() 与 torch.cat() 之间的联系
b_input_ids = torch.stack(b_input_ids, dim=0).to(DEVICE)
b_mask = torch.stack(b_mask, dim=0).to(DEVICE)
b_ent_cdm = torch.tensor(b_ent_cdm).to(DEVICE)
b_ent_cdw = torch.tensor(b_ent_cdw).to(DEVICE)
b_ent_pola = model.get_pola(b_input_ids, b_mask, b_ent_cdm, b_ent_cdw)
return b_ent_pos, b_ent_pola
if __name__ == "__main__":
dataset = Dataset()
loader = data.DataLoader(dataset, batch_size=2, collate_fn=dataset.collate_fn)
print(iter(loader).__next__())
细粒度情感分类,这个是通过python pytorch实现的一个细粒度情感分类
需积分: 0 115 浏览量
更新于2023-12-10
收藏 1.55MB ZIP 举报
细粒度情感分类是一种情感分析任务,它比传统的情感分类更为深入,旨在识别文本中更为具体、微妙的情感极性。在传统的二元或三元情感分类中,我们可能只能判断文本是正面、负面还是中性,而在细粒度情感分类中,我们可以进一步区分如“非常满意”、“稍微失望”等更为细致的情感状态。这对于理解用户情绪、产品反馈或社交媒体分析具有重要意义。
在这个项目中,开发者使用Python和PyTorch框架实现了这样一个细粒度情感分类模型。PyTorch是一个强大的深度学习库,它的动态计算图机制使得模型构建和调试变得更加灵活。利用PyTorch,我们可以轻松地构建复杂的神经网络架构,如BERT模型。
BERT(Bidirectional Encoder Representations from Transformers)是一种预训练语言模型,它在大量无标注文本上进行了预训练,能够理解和生成自然语言。在情感分类任务中,BERT可以捕捉上下文中的深层语义信息,这对于识别细粒度情感至关重要。在这个实现中,BERT可能被用作特征提取器,将输入文本转换为向量表示。
除了BERT,模型还结合了条件随机场(CRF,Conditional Random Fields)。CRF是一种统计建模方法,常用于序列标注任务,如命名实体识别和词性标注。在情感分类中,CRF可以帮助模型考虑整个句子的情感一致性,而不是仅仅依赖于单个词语的预测结果。通过引入CRF,模型能够更好地处理情感标记的转移概率,从而提高整体的分类性能。
此外,提到的注意力机制可能是指自注意力(Self-Attention),这是Transformer架构的核心组成部分。自注意力允许模型对每个位置的输入给予不同的权重,使模型能够更好地聚焦于文本中关键信息,对于理解和处理长距离依赖特别有效。
在这个实现中,使用了两个联合损失函数。这可能意味着模型不仅优化了标准的交叉熵损失,还可能引入了额外的正则化项或者特定于任务的损失函数,以促进模型学习更为复杂的情感表示和提高泛化能力。
总体而言,这个项目展示了一个综合运用现代深度学习技术解决自然语言处理任务的例子。通过结合BERT的上下文理解能力、CRF的全局序列信息处理以及自注意力的动态焦点调整,模型能够对文本进行细粒度的情感分类,提供更准确的情感分析结果。这样的系统在实际应用中,如舆情分析、产品评价处理等领域,有着广泛的应用前景。
MrGao
- 粉丝: 738
- 资源: 28
最新资源
- MATLAB【面板】车辆检测.zip
- MATLAB【面板】车牌出入库计费系统.zip
- MATLAB【面板】车道线检测定位.zip
- MATLAB【面板】车牌识别.zip
- 微电网,下垂控制(三相交流) 传统阻感型下垂控制输出有功 无功 频率波形
- MATLAB【面板】车牌号码出入库管理.zip
- MATLAB【面板】车牌识别设计.zip
- MATLAB【面板】车牌识别GUI实现.zip
- MATLAB【面板】车牌识别GUI界面.zip
- MATLAB【面板】答题卡识别GUI.zip
- MATLAB【面板】虫害检测.zip
- MATLAB【面板】答题卡自动识别系统.zip
- MATLAB【面板】答题卡识别系统.zip
- MATLAB【面板】打印纸缺陷检测GUI设计.zip
- MATLAB【面板】道路桥梁裂缝检测.zip
- 八木天线计算器,如果您想制作天线,这个计算器非常好用