microsoft/AI-For-Beginners

Mirrored from https://github.com/microsoft/AI-For-Beginners

cac0988687e243aef04d18c2e52ff408c9b44ea5

lessons/5-NLP/15-LanguageModeling/CBoW-PyTorch.ipynb

{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "NXTSugt6ieXh"
   },
   "source": [
    "## Training CBoW Model\n",
    "\n",
    "This notebook is part of the [AI for Beginners Curriculum](http://aka.ms/ai-beginners).\n",
    "\n",
    "In this example, we will train a CBoW language model to obtain our own Word2Vec embedding space. We will use the AG News dataset as the source of text."
   ]
  },
  {
   "cell_type": "code",
   "source": [
    "import torch\n",
    "import torchtext\n",
    "import os\n",
    "import collections\n",
    "import builtins\n",
    "import random\n",
    "import numpy as np"
   ],
   "metadata": {
    "id": "q-UiiJUKaxHj"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
   ],
   "metadata": {
    "id": "TFbR8CZaTZ1q"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "First, let's load the dataset and define the tokenizer and vocabulary. We set `vocab_size` to 5000 to limit computation, and build the vocabulary from only the first ~500 news items (`lines_cnt`)."
   ],
   "metadata": {
    "id": "HIwC7lI5T-ov"
   }
  },
  {
   "cell_type": "code",
   "source": [
    "def load_dataset(ngrams = 1, min_freq = 1, vocab_size = 5000, lines_cnt = 500):\n",
    "    tokenizer = torchtext.data.utils.get_tokenizer('basic_english')\n",
    "    print(\"Loading dataset...\")\n",
    "    # AG_NEWS returns the (train, test) splits in that order\n",
    "    train_dataset, test_dataset = torchtext.datasets.AG_NEWS(root='./data')\n",
    "    train_dataset = list(train_dataset)\n",
    "    test_dataset = list(test_dataset)\n",
    "    classes = ['World', 'Sports', 'Business', 'Sci/Tech']\n",
    "    print('Building vocab...')\n",
    "    # Count n-grams in the first news items only, to keep vocabulary building fast\n",
    "    counter = collections.Counter()\n",
    "    for i, (_, line) in enumerate(train_dataset):\n",
    "        counter.update(torchtext.data.utils.ngrams_iterator(tokenizer(line), ngrams=ngrams))\n",
    "        if i == lines_cnt:\n",
    "            break\n",
    "    # Keep the vocab_size most frequent tokens (legacy torchtext Vocab API, which also adds special tokens such as '<unk>')\n",
    "    vocab = torchtext.vocab.Vocab(collections.Counter(dict(counter.most_common(vocab_size))), min_freq=min_freq)\n",
    "    return train_dataset, test_dataset, classes, vocab, tokenizer"
   ],
   "metadata": {
    "id": "wdZuygtgiuLG"
   },
   "execution_count": null,
   "outputs": []
  },
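  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the vocabulary step more concrete, here is a minimal sketch of frequency-limited vocabulary building using only `collections.Counter`, without torchtext. The toy sentences, the helpers `build_vocab` and `encode_toy`, and the convention of reserving index 0 for an `<unk>` token are assumptions for illustration, not the torchtext implementation. The legacy torchtext `Vocab` used above also adds special tokens, which is likely why the model printed below reports 5002 embeddings rather than 5000."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import collections\n",
    "\n",
    "# Toy corpus (assumption): already lower-cased, so splitting on whitespace is enough\n",
    "toy_sentences = [\n",
    "    'i like to train networks',\n",
    "    'networks like to train',\n",
    "    'i like news',\n",
    "]\n",
    "\n",
    "def build_vocab(sentences, vocab_size):\n",
    "    # Count how often each token occurs in the corpus\n",
    "    counter = collections.Counter()\n",
    "    for s in sentences:\n",
    "        counter.update(s.split())\n",
    "    # Index 0 is reserved for unknown words; then keep the vocab_size most frequent tokens\n",
    "    itos = ['<unk>'] + [w for w, _ in counter.most_common(vocab_size)]\n",
    "    stoi = {w: i for i, w in enumerate(itos)}\n",
    "    return itos, stoi\n",
    "\n",
    "def encode_toy(sentence, stoi):\n",
    "    # Words outside the limited vocabulary map to index 0 ('<unk>')\n",
    "    return [stoi.get(w, 0) for w in sentence.split()]\n",
    "\n",
    "toy_itos, toy_stoi = build_vocab(toy_sentences, vocab_size = 3)\n",
    "print(toy_itos)  # ['<unk>', 'like', 'i', 'to']\n",
    "print(encode_toy('i like deep networks', toy_stoi))  # [2, 1, 0, 0]"
   ]
  },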
  {
   "cell_type": "code",
   "source": [
    "train_dataset, test_dataset, _, vocab, tokenizer = load_dataset()"
   ],
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "4d1nU1gsivGu",
    "outputId": "949fe272-ae0e-49f5-c373-6703458b3a74"
   },
   "execution_count": null,
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Loading dataset...\n",
      "Building vocab...\n"
     ]
    }
   ]
  },
  {
   "cell_type": "code",
   "source": [
    "def encode(x, vocabulary, tokenizer = tokenizer):\n",
    "    return [vocabulary[s] for s in tokenizer(x)]"
   ],
   "metadata": {
    "id": "1XDYNhG8ToFV"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "LIlQk6_PaHVY"
   },
   "source": [
    "## CBoW Model\n",
    "\n",
    "CBoW learns to predict a word based on its $2N$ neighboring words. For example, when $N=1$, we get the following pairs from the sentence *I like to train networks*: (like, I), (I, like), (to, like), (like, to), (train, to), (to, train), (networks, train), (train, networks). Here, the first word is the neighboring word used as input, and the second word is the one we are predicting. Note that this notebook uses a simplified form of CBoW: each neighboring word forms a separate training pair with the target word, rather than all $2N$ context embeddings being averaged into a single input.\n",
    "\n",
    "To build a network that predicts the target word, we supply a neighboring word as input and get the number of the target word as output. The architecture of the CBoW network is as follows:\n",
    "\n",
    "* The input word is passed through an embedding layer. This embedding layer is our Word2Vec embedding, so we define it separately as the `embedder` variable. We use an embedding size of 30 in this example, though you may want to experiment with higher dimensions (the widely used pretrained word2vec vectors have 300).\n",
    "* The embedding vector is then passed to a linear layer that predicts the output word, so it has `vocab_size` output neurons.\n",
    "\n",
    "Because we use `CrossEntropyLoss` as the loss function, the expected results are just word numbers (class indices), without one-hot encoding."
   ]
  },
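  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The architecture described above can be checked on toy sizes. The following sketch, assuming a 6-word vocabulary and 4-dimensional embeddings (arbitrary values chosen for illustration), shows that the model maps a batch of word indices to one row of `vocab_size` scores per word, and that `CrossEntropyLoss` accepts the target word numbers directly as class indices."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Toy sizes (assumption): 6-word vocabulary, 4-dimensional embeddings\n",
    "toy_vocab_size, toy_emb_dim = 6, 4\n",
    "toy_model = torch.nn.Sequential(\n",
    "    torch.nn.Embedding(num_embeddings = toy_vocab_size, embedding_dim = toy_emb_dim),\n",
    "    torch.nn.Linear(in_features = toy_emb_dim, out_features = toy_vocab_size),\n",
    ")\n",
    "\n",
    "toy_inputs = torch.tensor([1, 2, 3])   # indices of the neighboring (input) words\n",
    "toy_targets = torch.tensor([2, 1, 2])  # indices of the words to predict, no one-hot encoding\n",
    "toy_logits = toy_model(toy_inputs)     # one row of toy_vocab_size scores per input word\n",
    "toy_loss = torch.nn.CrossEntropyLoss()(toy_logits, toy_targets)\n",
    "print(toy_logits.shape)  # torch.Size([3, 6])\n",
    "print(toy_loss.item())"
   ]
  },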
  {
   "cell_type": "code",
   "source": [
    "vocab_size = len(vocab)\n",
    "\n",
    "embedder = torch.nn.Embedding(num_embeddings = vocab_size, embedding_dim = 30)\n",
    "model = torch.nn.Sequential(\n",
    "    embedder,\n",
    "    torch.nn.Linear(in_features = 30, out_features = vocab_size),\n",
    ").to(device)  # move to the GPU, if available, before an optimizer is created for it\n",
    "\n",
    "print(model)"
   ],
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "akKTcKQKkfl2",
    "outputId": "da687e3e-a8ec-4c1a-e456-ab8cd6ac7dad"
   },
   "execution_count": null,
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Sequential(\n",
      "  (0): Embedding(5002, 30)\n",
      "  (1): Linear(in_features=30, out_features=5002, bias=True)\n",
      ")\n"
     ]
    }
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Nud6jgGPaHVa"
   },
   "source": [
    "## Preparing Training Data\n",
    "\n",
    "Now let's write the main function that computes CBoW word pairs from text. It lets us specify the window size and returns a list of pairs, each consisting of an input word and an output word. Note that the function works on lists of words as well as on lists or tensors of token indices, which allows us to encode the text before passing it to `to_cbow`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "x-dsXygOieXn",
    "outputId": "c2218280-e540-40ba-9546-efe48d0d714f"
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "[['like', 'I'], ['to', 'I'], ['I', 'like'], ['to', 'like'], ['train', 'like'], ['I', 'to'], ['like', 'to'], ['train', 'to'], ['networks', 'to'], ['like', 'train'], ['to', 'train'], ['networks', 'train'], ['to', 'networks'], ['train', 'networks']]\n",
      "[[232, 172], [5, 172], [172, 232], [5, 232], [0, 232], [172, 5], [232, 5], [0, 5], [1202, 5], [232, 0], [5, 0], [1202, 0], [5, 1202], [0, 1202]]\n"
     ]
    }
   ],
   "source": [
    "def to_cbow(sent, window_size=2):\n",
    "    res = []\n",
    "    for i, x in enumerate(sent):\n",
    "        # Pair the center word x with every other word inside the window: [neighbor, center]\n",
    "        for j in range(max(0, i - window_size), min(i + window_size + 1, len(sent))):\n",
    "            if i != j:\n",
    "                res.append([sent[j], x])\n",
    "    return res\n",
    "\n",
    "print(to_cbow(['I', 'like', 'to', 'train', 'networks']))\n",
    "print(to_cbow(encode('I like to train networks', vocab)))"
   ]
  },
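  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For comparison, the classic CBoW formulation groups all neighbors of a position into a single context and averages their embeddings before predicting the center word. The sketch below illustrates this on token indices; the helper `to_cbow_context`, the toy sizes and the use of `torch.nn.EmbeddingBag` with `mode='mean'` are choices made for illustration and are not part of this notebook's training pipeline."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def to_cbow_context(sent, window_size = 2):\n",
    "    # Hypothetical helper: one (context, target) pair per position,\n",
    "    # where context lists all neighbors within the window\n",
    "    res = []\n",
    "    for i, x in enumerate(sent):\n",
    "        lo, hi = max(0, i - window_size), min(i + window_size + 1, len(sent))\n",
    "        res.append(([sent[j] for j in range(lo, hi) if j != i], x))\n",
    "    return res\n",
    "\n",
    "ctx_pairs = to_cbow_context([0, 1, 2, 3, 4], window_size = 1)\n",
    "print(ctx_pairs)  # [([1], 0), ([0, 2], 1), ([1, 3], 2), ([2, 4], 3), ([3], 4)]\n",
    "\n",
    "# EmbeddingBag takes all context indices as one flat tensor plus the offset where each bag starts\n",
    "flat = torch.tensor([w for ctx, _ in ctx_pairs for w in ctx])\n",
    "offsets = torch.tensor([0] + [len(ctx) for ctx, _ in ctx_pairs[:-1]]).cumsum(0)\n",
    "ctx_targets = torch.tensor([t for _, t in ctx_pairs])\n",
    "\n",
    "toy_vocab_size, toy_emb_dim = 5, 8  # toy sizes (assumption)\n",
    "bag = torch.nn.EmbeddingBag(toy_vocab_size, toy_emb_dim, mode = 'mean')  # averages each context\n",
    "head = torch.nn.Linear(toy_emb_dim, toy_vocab_size)\n",
    "ctx_logits = head(bag(flat, offsets))  # one prediction per center word\n",
    "print(ctx_logits.shape)  # torch.Size([5, 5])\n",
    "print(torch.nn.functional.cross_entropy(ctx_logits, ctx_targets).item())"
   ]
  },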
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "XVaaDLjaaHVb"
   },
   "source": [
    "Let's prepare the training dataset. We go through the news items, call `to_cbow` to get the list of word pairs (here with a window size of 5), and add those pairs to `X` and `Y`. To save time, we only consider the first 10k news items; you can easily remove this limit if you have more time to wait and want better embeddings :)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "54b-Gd9TieXo"
   },
   "outputs": [],
   "source": [
    "X = []\n",
    "Y = []\n",
    "# Each item x is a (label, text) tuple; only the text is used\n",
    "for i, x in zip(range(10000), train_dataset):\n",
    "    for w1, w2 in to_cbow(encode(x[1], vocab), window_size = 5):\n",
    "        X.append(w1)  # neighboring word (input)\n",
    "        Y.append(w2)  # center word (target)\n",
    "\n",
    "X = torch.tensor(X)\n",
    "Y = torch.tensor(Y)"
   ]
  },
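  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check on dataset size: `to_cbow` emits one pair for every ordered pair of positions at most `window_size` apart, so a sentence of $L$ tokens with window $w \\le L$ yields $2\\sum_{d=1}^{w}(L-d) = 2wL - w(w+1)$ pairs, roughly $2w$ per token. For the five-word example above with the default window of 2, this gives $20 - 6 = 14$ pairs, matching the printed output, and with `window_size = 5` each news item contributes about ten pairs per token. The helper names below are hypothetical; the code compares the closed form with a brute-force count that mirrors the loops in `to_cbow`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def count_pairs(length, window_size):\n",
    "    # Brute-force count mirroring to_cbow: words in the window, excluding the center word itself\n",
    "    return sum(min(i + window_size + 1, length) - max(0, i - window_size) - 1\n",
    "               for i in range(length))\n",
    "\n",
    "def count_pairs_formula(length, window_size):\n",
    "    # Closed form 2wL - w(w+1), valid when window_size <= length\n",
    "    return 2 * window_size * length - window_size * (window_size + 1)\n",
    "\n",
    "for length in range(1, 30):\n",
    "    for w in range(1, length + 1):\n",
    "        assert count_pairs(length, w) == count_pairs_formula(length, w)\n",
    "\n",
    "print(count_pairs(5, 2), count_pairs(5, 5))  # 14 20"
   ]
  },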
  {
   "cell_type": "markdown",
   "source": [
    "Next, we wrap the pairs into a simple iterable dataset that stores them as `(label, feature)` tuples in shuffled order:"
   ],
   "metadata": {
    "id": "cwWy0PzXWhN5"
   }
  },
  {
   "cell_type": "code",
   "source": [
    "class SimpleIterableDataset(torch.utils.data.IterableDataset):\n",
    "    def __init__(self, X, Y):\n",
    "        super().__init__()\n",
    "        # Store (label, feature) pairs, matching `for labels, features in dataloader` in the training loop\n",
    "        self.data = [(Y[i], X[i]) for i in range(len(X))]\n",
    "        # Shuffled once here, so every epoch visits the pairs in the same order\n",
    "        random.shuffle(self.data)\n",
    "\n",
    "    def __iter__(self):\n",
    "        return iter(self.data)"
   ],
   "metadata": {
    "id": "mfoAcGPFZU8p"
   },
   "execution_count": null,
   "outputs": []
  },
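  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`SimpleIterableDataset` shuffles the pairs only once, when it is created, so every epoch visits them in the same order. A common alternative, sketched below on toy tensors (an assumption standing in for `X` and `Y`), is `torch.utils.data.TensorDataset` combined with `shuffle=True` in the `DataLoader`, which draws a new order every epoch. The dataset is built with labels first so that batches unpack as `labels, features`, as the training loop in this notebook expects."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Toy (feature, label) pairs (assumption) standing in for the X and Y tensors built above\n",
    "X_toy = torch.tensor([3, 1, 4, 1, 5, 9])\n",
    "Y_toy = torch.tensor([2, 7, 1, 8, 2, 8])\n",
    "\n",
    "# Labels first, to match `for labels, features in dataloader` in the training loop\n",
    "toy_ds = torch.utils.data.TensorDataset(Y_toy, X_toy)\n",
    "toy_dl = torch.utils.data.DataLoader(toy_ds, batch_size = 4, shuffle = True)  # reshuffled every epoch\n",
    "\n",
    "for labels, features in toy_dl:\n",
    "    print(labels.shape, features.shape)"
   ]
  },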
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "e4NQ_-5waHVc"
   },
   "source": [
    "Now we create the dataset and a dataloader that yields batches of 256 pairs:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "AbLUcojlieXo"
   },
   "outputs": [],
   "source": [
    "ds = SimpleIterableDataset(X, Y)\n",
    "dl = torch.utils.data.DataLoader(ds, batch_size = 256)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "pKQr7sXeaHVc"
   },
   "source": [
    "Now let's do the actual training. We use the `SGD` optimizer with a fairly high learning rate of 0.1. You can also experiment with other optimizers, such as `Adam` (the default in `train_epoch` when no optimizer is passed). We train for 10 epochs to begin with, and you can re-run the training cell if you want an even lower loss."
   ]
  },
  {
   "cell_type": "code",
   "source": [
    "def train_epoch(net, dataloader, lr = 0.01, optimizer = None, loss_fn = torch.nn.CrossEntropyLoss(), epochs = 1, report_freq = 1):\n",
    "    optimizer = optimizer or torch.optim.Adam(net.parameters(), lr = lr)\n",
    "    loss_fn = loss_fn.to(device)\n",
    "    net.train()\n",
    "\n",
    "    for i in range(epochs):\n",
    "        total_loss, j = 0.0, 0\n",
    "        for labels, features in dataloader:\n",
    "            optimizer.zero_grad()\n",
    "            features, labels = features.to(device), labels.to(device)\n",
    "            out = net(features)\n",
    "            loss = loss_fn(out, labels)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            # Accumulate a plain float so the computation graph of each batch can be freed\n",
    "            total_loss += loss.item()\n",
    "            j += 1\n",
    "        if i % report_freq == 0:\n",
    "            print(f\"Epoch: {i+1}: loss={total_loss/j}\")\n",
    "\n",
    "    return total_loss/j"
   ],
   "metadata": {
    "id": "HeeCYKr_KF1w"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "train_epoch(net = model, dataloader = dl, optimizer = torch.optim.SGD(model.parameters(), lr = 0.1), loss_fn = torch.nn.CrossEntropyLoss(), epochs = 10)"
   ],
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "KVgwGtDHgDlT",
    "outputId": "2447833f-f0e3-4566-c33d-addbfe2f451d"
   },
   "execution_count": null,
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Epoch: 1: loss=5.664632366860172\n",
      "Epoch: 2: loss=5.632101973960962\n",
      "Epoch: 3: loss=5.610399051405015\n",
      "Epoch: 4: loss=5.594621561080262\n",
      "Epoch: 5: loss=5.582538017415446\n",
      "Epoch: 6: loss=5.572900234519603\n",
      "Epoch: 7: loss=5.564951676341915\n",
      "Epoch: 8: loss=5.558288112064614\n",
      "Epoch: 9: loss=5.552576955031129\n",
      "Epoch: 10: loss=5.547634165194347\n"
     ]
    },
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "5.547634165194347"
      ]
     },
     "metadata": {},
     "execution_count": 16
    }
   ]
  },
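  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To put these loss values in perspective: a model that assigns equal probability to all 5002 vocabulary entries has a cross-entropy of $\\ln 5002 \\approx 8.52$, so the final loss of about 5.55 is well below this uniform baseline, while still being far from confident. Its exponent (the perplexity) is roughly 257, meaning the model is about as uncertain as if it were choosing uniformly among 257 words. The snippet below only reproduces this arithmetic from the values printed above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "num_classes = 5002               # vocabulary size printed for the model above\n",
    "final_loss = 5.547634165194347  # last epoch loss printed above\n",
    "\n",
    "# Cross-entropy of a uniform prediction over all classes is ln(num_classes)\n",
    "print(round(math.log(num_classes), 2))  # 8.52\n",
    "# Perplexity: exp(cross-entropy), the effective number of equally likely choices\n",
    "print(round(math.exp(final_loss)))  # 257"
   ]
  },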
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "W8u2qXZmaHVd"
   },
   "source": [
    "## Trying out Word2Vec\n",
    "\n",
    "To use Word2Vec, let's extract vectors corresponding to all words in our vocabulary:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "r8TatcXjkU_t"
   },
   "outputs": [],
   "source": [
    "# Row i of the embedding matrix is the vector of vocab.itos[i]; copy it to the CPU for NumPy use\n",
    "vectors = embedder.weight.detach().cpu()"
   ]
  },
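  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Looking up a word in a `torch.nn.Embedding` layer simply selects the corresponding row of its weight matrix, so the matrix of all word vectors is the same as `embedder.weight`. The toy check below, using an arbitrary 10-word, 3-dimensional embedding (an assumption for illustration), confirms this."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Toy embedding (assumption): 10 words, 3 dimensions\n",
    "toy_emb = torch.nn.Embedding(num_embeddings = 10, embedding_dim = 3)\n",
    "\n",
    "# A lookup of index 7 returns exactly row 7 of the weight matrix\n",
    "print(torch.equal(toy_emb(torch.tensor(7)), toy_emb.weight[7]))  # True\n",
    "\n",
    "# Stacking lookups for all indices therefore reproduces the whole matrix\n",
    "stacked = torch.stack([toy_emb(torch.tensor(i)) for i in range(10)], 0)\n",
    "print(torch.equal(stacked, toy_emb.weight))  # True"
   ]
  },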
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "3OcX21UOaHVd"
   },
   "source": [
    "Let's see, for example, how the word **Paris** is encoded into a vector:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "bz6tAeLzieXp",
    "outputId": "5b20850e-4342-45e9-f840-cfac2b4d61d8"
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "tensor([-0.0915,  2.1224, -0.0281, -0.6819,  1.1219,  0.6458, -1.3704, -1.3314,\n",
      "        -1.1437,  0.4496,  0.2301, -0.3515, -0.8485,  1.0481,  0.4386, -0.8949,\n",
      "         0.5644,  1.0939, -2.5096,  3.2949, -0.2601, -0.8640,  0.1421, -0.0804,\n",
      "        -0.5083, -1.0560,  0.9753, -0.5949, -1.6046,  0.5774],\n",
      "       grad_fn=<EmbeddingBackward>)\n"
     ]
    }
   ],
   "source": [
    "# Create the index on the same device as the embedding weights\n",
    "paris_vec = embedder(torch.tensor(vocab['paris'], device = embedder.weight.device))\n",
    "print(paris_vec)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "pHTJlaeYaHVd"
   },
   "source": [
    "It is interesting to use Word2Vec to look for synonyms. The following function returns the `n` closest words to a given input. To find them, we compute the Euclidean distance $\\lVert w_i - v \\rVert$, where $v$ is the vector corresponding to our input word and $w_i$ is the encoding of the $i$-th word in the vocabulary. We then use `argsort` to get the indices that sort these distances in ascending order and take the first `n` of them, which are the positions of the closest words in the vocabulary. Since the input word has distance 0 to itself, it always appears first in the result."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "NlZyi-_olFar",
    "outputId": "b5dbb163-88c4-4d5a-eaf2-6751f700e98c"
   },
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "['microsoft', 'quoted', 'lp', 'rate', 'top']"
      ]
     },
     "metadata": {},
     "execution_count": 56
    }
   ],
   "source": [
    "def close_words(x, n = 5):\n",
    "    # CPU NumPy copy of all word vectors, one row per vocabulary entry\n",
    "    vecs = vectors.detach().cpu().numpy()\n",
    "    vec = vecs[vocab[x]]\n",
    "    # Euclidean distance from the query to every word; argsort puts the closest first\n",
    "    closest = np.linalg.norm(vecs - vec, axis = 1).argsort()[:n]\n",
    "    return [vocab.itos[i] for i in closest]\n",
    "\n",
    "close_words('microsoft')"
   ]
  },
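  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`close_words` ranks neighbors by Euclidean distance, which is also affected by vector length. Word embeddings are often compared with cosine similarity instead, which only looks at direction. Below is a sketch of such a variant; the function name `close_words_cosine` and the toy 2-D vectors are assumptions for illustration. With the notebook's legacy vocabulary it could be called as `close_words_cosine(vectors.detach().cpu().numpy(), vocab.itos, vocab.stoi, 'microsoft')`; whether it gives better neighbors for this small model is left to experiment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def close_words_cosine(word_vectors, itos, stoi, word, n = 5):\n",
    "    # Hypothetical variant of close_words: rank words by cosine similarity\n",
    "    unit = word_vectors / np.linalg.norm(word_vectors, axis = 1, keepdims = True)\n",
    "    sims = unit @ unit[stoi[word]]\n",
    "    best = np.argsort(-sims)[:n]  # highest similarity first\n",
    "    return [itos[i] for i in best]\n",
    "\n",
    "# Toy 2-D vectors (assumption) for illustration\n",
    "toy_words = ['paris', 'london', 'goal', 'match']\n",
    "toy_index = {w: i for i, w in enumerate(toy_words)}\n",
    "toy_vecs = np.array([[1.0, 0.1], [0.9, 0.2], [0.1, 1.0], [0.2, 0.8]])\n",
    "print(close_words_cosine(toy_vecs, toy_words, toy_index, 'paris', n = 2))  # ['paris', 'london']"
   ]
  },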
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "-dQq7xeAln0U",
    "outputId": "66f768c3-c248-4bfd-ce4f-c8ffc6d0dd0d"
   },
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "['basketball', 'lot', 'sinai', 'states', 'healthdaynews']"
      ]
     },
     "metadata": {},
     "execution_count": 51
    }
   ],
   "source": [
    "close_words('basketball')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "fJXqK26b29sa",
    "outputId": "78f0baba-ffd0-485a-dd87-0a12bedfd7fa"
   },
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "['funds', 'travel', 'sydney', 'japan', 'business']"
      ]
     },
     "metadata": {},
     "execution_count": 77
    }
   ],
   "source": [
    "close_words('funds')"
   ]
  },
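  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The takeaway below suggests trying a skip-gram model. Here is a minimal sketch of skip-gram pair generation; `to_skipgram` is a hypothetical helper that mirrors `to_cbow` but makes the center word the input and each neighbor a target. One consequence worth noting: because the window is symmetric, swapping input and target in every skip-gram pair gives exactly the one-pair-per-neighbor pairs that `to_cbow` produces, in a different order. With this simplified pairing the two approaches therefore see the same training pairs, and they only really differ when CBoW averages the whole context into a single input."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def to_skipgram(sent, window_size = 2):\n",
    "    # Hypothetical helper: center word x is the input, each neighbor is a target\n",
    "    res = []\n",
    "    for i, x in enumerate(sent):\n",
    "        for j in range(max(0, i - window_size), min(i + window_size + 1, len(sent))):\n",
    "            if i != j:\n",
    "                res.append([x, sent[j]])\n",
    "    return res\n",
    "\n",
    "sg_sentence = ['I', 'like', 'to', 'train', 'networks']\n",
    "sg_pairs = to_skipgram(sg_sentence, window_size = 1)\n",
    "print(sg_pairs)\n",
    "\n",
    "# Swapping input and target in every pair gives the CBoW-style (neighbor, center) pairs\n",
    "print(sorted(sg_pairs) == sorted([[t, c] for c, t in sg_pairs]))  # True"
   ]
  },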
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "My0VeTDd3Ji8"
   },
   "source": [
    "## Takeaway\n",
    "\n",
    "Using techniques such as CBoW, we can train a Word2Vec model. With the small setup used here (a 5000-word vocabulary, 10k news items, 30-dimensional embeddings and 10 epochs), the nearest neighbors are still quite noisy, so more data and longer training are likely to give better embeddings. You may also try training a skip-gram model, which predicts the neighboring words given the central one, and see how well it performs."
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "CBoW-PyTorch.ipynb",
   "provenance": []
  },
  "interpreter": {
   "hash": "16af2a8bbb083ea23e5e41c7f5787656b2ce26968575d8763f2c4b17f9cd711f"
  },
  "kernelspec": {
   "display_name": "Python 3.8.12 ('py38')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  },
  "orig_nbformat": 4,
  "gpuClass": "standard"
 },
 "nbformat": 4,
 "nbformat_minor": 0
}