Skip to content
This repository was archived by the owner on Jan 24, 2024. It is now read-only.

#使用Transformer进行文本分类#代码提交#937

Open
YinHang2515 wants to merge 19 commits intoPaddlePaddle:developfrom
YinHang2515:develop
Open

#使用Transformer进行文本分类#代码提交#937
YinHang2515 wants to merge 19 commits intoPaddlePaddle:developfrom
YinHang2515:develop

Conversation

@YinHang2515
Copy link
Copy Markdown

@CLAassistant
Copy link
Copy Markdown

CLAassistant commented Nov 30, 2020

CLA assistant check
All committers have signed the CLA.

"source": [
"import paddle\n",
"import paddle.nn as nn\n",
"import paddle.fluid.dygraph as dg\n",
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Paddle2.0不建议使用fluid,默认动态图开发模式。

"pad_id = word_dict['<pad>']\r\n",
"embed_dim = 32 # Embedding size for each token\r\n",
"num_heads = 2 # Number of attention heads\r\n",
"ff_dim = 32 # Hidden layer size in feed forward network inside transformer\r\n",
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ff_dim变量命名不是很清晰。

" x = self.drop2(x)\r\n",
" x = self.soft(x)\r\n",
" return x\r\n",
"# class MyNet(paddle.nn.Layer):\r\n",
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

此处注释可删除。

},
"source": [
"可以看到经过两轮的迭代训练,可以达到85%左右的准确率,当然你也可以通过调整参数、更改优化方式等等来进一步提升性能。"
]
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

可使用model.predict进行预测,打印出句子,预测标签和实际标签,这样比较直观。

@YinHang2515
Copy link
Copy Markdown
Author

根据要求进行了相应的修改,并已同步更新至AIStudio

Copy link
Copy Markdown

@chenxiaozeng chenxiaozeng left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2 suggestions.

"class PointWiseFeedForwardNetwork(nn.Layer):\r\n",
" def __init__(self, embed_dim, feed_dim):\r\n",
" super(PointWiseFeedForwardNetwork, self).__init__()\r\n",
" self.linear1 = pd.fluid.dygraph.Linear(embed_dim, feed_dim, act='relu')\r\n",
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

多处fluid需要改成nn

" loss=nn.CrossEntropyLoss())\r\n",
"\r\n",
"# 模型训练\r\n",
"model.fit(train_loader,\r\n",
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

训练完成之后,可以调用model.predict()测试下模型在test数据集上的表现。

@@ -0,0 +1 @@

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this file, to delete?

@YinHang2515
Copy link
Copy Markdown
Author

根据要求进行了相应的修改,并已同步更新至AIStudio

Copy link
Copy Markdown

@chenxiaozeng chenxiaozeng left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

},
"outputs": [],
"source": [
"class TransformerBlock(nn.Layer):\r\n",
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Paddle中已经提供了Transformer的相关API https://www.paddlepaddle.org.cn/documentation/docs/zh/2.0-rc1/api/paddle/nn/layer/transformer/TransformerEncoder_cn.html#transformerencoder ,如果只是为了使用而不是要说明这些具体实现的话,可否直接使用这些API呢

Copy link
Copy Markdown
Contributor

@TCChenlong TCChenlong left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

除了上述问题外,还有两处需要注意下:
1、2.0已经发布了,麻烦更新到2.0版本;
2、看预测的效果不是特别好,可以再优化一下网络
感谢~

},
"outputs": [],
"source": [
"import paddle as pd\n",
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

import paddle

"source": [
"import paddle as pd\n",
"import paddle.nn as nn\n",
"import paddle.nn.functional as func\n",
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

暂时不推荐这么写

"train_dataset = IMDBDataset(train_sents, train_labels)\r\n",
"test_dataset = IMDBDataset(test_sents, test_labels)\r\n",
"\r\n",
"train_loader = pd.io.DataLoader(train_dataset, places=pd.CPUPlace(), return_list=True,\r\n",
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

places=pd.CPUPlace() 可以删除

Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants