PyTorch Quick Start Tutorial 8: Word Prediction for NLP with Word Embeddings


Data Preprocessing

```python
CONTEXT_SIZE = 2
EMBEDDING_DIM = 10
# We will use Shakespeare Sonnet 2
test_sentence = """When forty winters shall besiege thy brow,
And dig deep trenches in thy beauty's field,
Thy youth's proud livery so gazed on now,
Will be a totter'd weed of small worth held:
Then being asked, where all thy beauty lies,
Where all the treasure of thy lusty days;
To say, within thine own deep sunken eyes,
Were an all-eating shame, and thriftless praise.
How much more praise deserv'd thy beauty's use,
If thou couldst answer 'This fair child of mine
Shall sum my count, and make my old excuse,'
Proving his beauty by succession thine!
This were to be new made when thou art old,
And see thy blood warm when thou feel'st it cold.""".split()
```

`CONTEXT_SIZE` is the number of preceding words used to predict the current word. It is set to 2 here, meaning we predict each word from the two words that come before it. `EMBEDDING_DIM` is the dimensionality of the word embedding, which was introduced in the previous post.

```python
trigram = [((test_sentence[i], test_sentence[i+1]), test_sentence[i+2])
           for i in range(len(test_sentence)-2)]
```
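To see what `trigram` contains, the same construction can be run on a toy sentence (a hypothetical example, not from the tutorial):

```python
# Hypothetical toy sentence to illustrate the (context, target) pairs built above
toy = "the quick brown fox jumps".split()
toy_trigram = [((toy[i], toy[i + 1]), toy[i + 2]) for i in range(len(toy) - 2)]
print(toy_trigram[0])   # (('the', 'quick'), 'brown')
print(len(toy_trigram)) # 3 pairs from a 5-word sentence
```

Each element pairs a tuple of `CONTEXT_SIZE` context words with the word that follows them.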

```python
vocb = set(test_sentence)  # use a set to remove duplicate words
word_to_idx = {word: i for i, word in enumerate(vocb)}
idx_to_word = {word_to_idx[word]: word for word in word_to_idx}
```
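The indexing works the same way on any corpus; here is a quick sanity check of the word-to-index round trip on a hypothetical mini-corpus:

```python
# Hypothetical mini-corpus to show the index round trip
words = "to be or not to be".split()
vocb = set(words)  # duplicates removed: {'to', 'be', 'or', 'not'}
word_to_idx = {word: i for i, word in enumerate(vocb)}
idx_to_word = {idx: word for word, idx in word_to_idx.items()}
# every word maps to a unique index and back again
assert all(idx_to_word[word_to_idx[w]] == w for w in vocb)
print(len(vocb))  # 4
```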

Defining the Model

```python
import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.autograd import Variable


class NgramModel(nn.Module):
    def __init__(self, vocb_size, context_size, n_dim):
        super(NgramModel, self).__init__()
        self.n_word = vocb_size
        self.embedding = nn.Embedding(self.n_word, n_dim)
        self.linear1 = nn.Linear(context_size * n_dim, 128)
        self.linear2 = nn.Linear(128, self.n_word)

    def forward(self, x):
        emb = self.embedding(x)  # (context_size, n_dim)
        emb = emb.view(1, -1)    # flatten to (1, context_size * n_dim)
        out = self.linear1(emb)
        out = F.relu(out)
        out = self.linear2(out)
        log_prob = F.log_softmax(out, dim=1)
        return log_prob


ngrammodel = NgramModel(len(word_to_idx), CONTEXT_SIZE, 100)
criterion = nn.NLLLoss()
optimizer = optim.SGD(ngrammodel.parameters(), lr=1e-3)
```

Note that the model is instantiated with an embedding dimension of 100, which overrides the `EMBEDDING_DIM = 10` set earlier.
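Why is `linear1` sized `context_size * n_dim`? Because `forward` flattens the context-word embeddings into a single row before the linear layer. A minimal pure-Python sketch of that lookup-and-flatten step (the table values are random stand-ins, not learned weights):

```python
import random

random.seed(0)
n_word, n_dim, context_size = 5, 3, 2
# An embedding layer is a lookup table: one learnable n_dim vector per word index
table = [[random.random() for _ in range(n_dim)] for _ in range(n_word)]

def embed_and_flatten(indices):
    """Look up each index and concatenate the vectors, like emb.view(1, -1)."""
    flat = []
    for idx in indices:
        flat.extend(table[idx])
    return flat

vec = embed_and_flatten([2, 4])  # two context-word indices
print(len(vec))  # 6 == context_size * n_dim, the input width of linear1
```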

Training

```python
for epoch in range(100):
    print('epoch: {}'.format(epoch + 1))
    print('*' * 10)
    running_loss = 0
    for data in trigram:
        word, label = data
        word = Variable(torch.LongTensor([word_to_idx[i] for i in word]))
        label = Variable(torch.LongTensor([word_to_idx[label]]))
        # forward
        out = ngrammodel(word)
        loss = criterion(out, label)
        running_loss += loss.data[0]
        # backward
        optimizer.zero_grad()  # clear gradients left over from the previous step
        loss.backward()
        optimizer.step()
    print('Loss: {:.6f}'.format(running_loss / len(trigram)))
```
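`NLLLoss` simply negates the log-probability that `log_softmax` assigned to the target word, so the pair is equivalent to cross-entropy. A pure-Python sketch of one loss evaluation (the scores are made-up numbers, not real model outputs):

```python
import math

def log_softmax(scores):
    # Numerically stable log-softmax over a list of raw scores
    m = max(scores)
    log_sum = m + math.log(sum(math.exp(s - m) for s in scores))
    return [s - log_sum for s in scores]

def nll_loss(log_probs, target):
    # Negative log-likelihood: the negated log-probability of the target index
    return -log_probs[target]

scores = [2.0, 1.0, 0.1]  # hypothetical raw outputs of linear2
lp = log_softmax(scores)
loss = nll_loss(lp, 0)    # suppose the target word has index 0
print(round(loss, 3))     # 0.417
```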

Prediction

```python
word, label = trigram[3]
word = Variable(torch.LongTensor([word_to_idx[i] for i in word]))
out = ngrammodel(word)
_, predict_label = torch.max(out, 1)  # index of the highest-scoring word
predict_word = idx_to_word[predict_label.data[0]]
print('real word is {}, predict word is {}'.format(label, predict_word))
```

Pytorch Chinese site: https://ptorch.com/

Pytorch Chinese documentation: https://ptorch.com/docs/1/
