microsoft/AI-For-Beginners
Public, mirrored from https://github.com/microsoft/AI-For-Beginners
lessons/5-NLP/17-GenerativeNetworks/GenerativePyTorch.ipynb
| 1 | { |
| 2 | "cells": [ |
| 3 | { |
| 4 | "cell_type": "markdown", |
| 5 | "metadata": {}, |
| 6 | "source": [ |
| 7 | "# Generative networks\n", |
| 8 | "\n", |
| 9 | "Recurrent Neural Networks (RNNs) and their gated cell variants, such as Long Short-Term Memory cells (LSTMs) and Gated Recurrent Units (GRUs), provide a mechanism for language modeling: they can learn word ordering and predict the next word in a sequence. This allows us to use RNNs for **generative tasks**, such as ordinary text generation, machine translation, and even image captioning.\n", |
| 10 | "\n", |
| 11 | "In the RNN architecture we discussed in the previous unit, each RNN unit produced the next hidden state as an output. However, we can also add another output to each recurrent unit, which allows us to output a **sequence** (equal in length to the original sequence). Moreover, we can use RNN units that do not accept an input at each step, and instead take some initial state vector and then produce a sequence of outputs.\n", |
| 12 | "\n", |
| 13 | "In this notebook, we will focus on simple generative models that help us generate text. For simplicity, let's build a **character-level network**, which generates text letter by letter. During training, we need to take some text corpus and split it into letter sequences. " |
| 14 | ] |
| 15 | }, |
| 16 | { |
| 17 | "cell_type": "code", |
| 18 | "execution_count": 1, |
| 19 | "metadata": {}, |
| 20 | "outputs": [ |
| 21 | { |
| 22 | "name": "stdout", |
| 23 | "output_type": "stream", |
| 24 | "text": [ |
| 25 | "Loading dataset...\n", |
| 26 | "Building vocab...\n" |
| 27 | ] |
| 28 | } |
| 29 | ], |
| 30 | "source": [ |
| 31 | "import torch\n", |
| 32 | "import torchtext\n", |
| 33 | "import numpy as np\n", |
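| 34 | "from torchnlp import * # course helper module; this notebook relies on it for load_dataset, encode, device and collections\n", |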
| 35 | "train_dataset,test_dataset,classes,vocab = load_dataset()" |
| 36 | ] |
| 37 | }, |
| 38 | { |
| 39 | "cell_type": "markdown", |
| 40 | "metadata": {}, |
| 41 | "source": [ |
| 42 | "## Building character vocabulary\n", |
| 43 | "\n", |
| 44 | "To build a character-level generative network, we need to split the text into individual characters instead of words. This can be done by defining a different tokenizer:" |
| 45 | ] |
| 46 | }, |
| 47 | { |
| 48 | "cell_type": "code", |
| 49 | "execution_count": 2, |
| 50 | "metadata": {}, |
| 51 | "outputs": [ |
| 52 | { |
| 53 | "name": "stdout", |
| 54 | "output_type": "stream", |
| 55 | "text": [ |
| 56 | "Vocabulary size = 84\n", |
| 57 | "Encoding of 'a' is 4\n", |
| 58 | "Character with code 13 is h\n" |
| 59 | ] |
| 60 | } |
| 61 | ], |
| 62 | "source": [ |
| 63 | "def char_tokenizer(words):\n", |
| 64 | "    return list(words) #[word for word in words]\n", |
| 65 | "\n", |
| 66 | "counter = collections.Counter()\n", |
| 67 | "for (label, line) in train_dataset:\n", |
| 68 | " counter.update(char_tokenizer(line))\n", |
| 69 | "vocab = torchtext.vocab.Vocab(counter)\n", |
| 70 | "\n", |
| 71 | "vocab_size = len(vocab)\n", |
| 72 | "print(f\"Vocabulary size = {vocab_size}\")\n", |
| 73 | "print(f\"Encoding of 'a' is {vocab.stoi['a']}\")\n", |
| 74 | "print(f\"Character with code 13 is {vocab.itos[13]}\")" |
| 75 | ] |
| 76 | }, |
| 77 | { |
| 78 | "cell_type": "markdown", |
| 79 | "metadata": {}, |
| 80 | "source": [ |
| 81 | "Let's look at an example of how we can encode text from our dataset:" |
| 82 | ] |
| 83 | }, |
| 84 | { |
| 85 | "cell_type": "code", |
| 86 | "execution_count": 5, |
| 87 | "metadata": {}, |
| 88 | "outputs": [ |
| 89 | { |
| 90 | "data": { |
| 91 | "text/plain": [ |
| 92 | "tensor([43, 4, 11, 11, 2, 26, 5, 23, 2, 38, 3, 4, 10, 9, 2, 31, 11, 4,\n", |
| 93 | " 21, 2, 38, 4, 14, 25, 2, 34, 8, 5, 6, 2, 5, 13, 3, 2, 38, 11,\n", |
| 94 | " 4, 14, 25, 2, 55, 37, 3, 15, 5, 3, 10, 9, 56, 2, 37, 3, 15, 5,\n", |
| 95 | " 3, 10, 9, 2, 29, 2, 26, 13, 6, 10, 5, 29, 9, 3, 11, 11, 3, 10,\n", |
| 96 | " 9, 27, 2, 43, 4, 11, 11, 2, 26, 5, 10, 3, 3, 5, 58, 9, 2, 12,\n", |
| 97 | " 21, 7, 8, 12, 11, 7, 8, 18, 61, 22, 4, 8, 12, 2, 6, 19, 2, 15,\n", |
| 98 | " 11, 5, 10, 4, 29, 14, 20, 8, 7, 14, 9, 27, 2, 4, 10, 3, 2, 9,\n", |
| 99 | " 3, 3, 7, 8, 18, 2, 18, 10, 3, 3, 8, 2, 4, 18, 4, 7, 8, 23])" |
| 100 | ] |
| 101 | }, |
| 102 | "execution_count": 5, |
| 103 | "metadata": {}, |
| 104 | "output_type": "execute_result" |
| 105 | } |
| 106 | ], |
| 107 | "source": [ |
| 108 | "def enc(x):\n", |
| 109 | "    return torch.LongTensor(encode(x,voc=vocab,tokenizer=char_tokenizer))\n", |
| 110 | "\n", |
| 111 | "enc(train_dataset[0][1])" |
| 112 | ] |
| 113 | }, |
| 114 | { |
| 115 | "cell_type": "markdown", |
| 116 | "metadata": {}, |
| 117 | "source": [ |
| 118 | "## Training a generative RNN\n", |
| 119 | "\n", |
| 120 | "We will train the RNN to generate text as follows. At each step, we take a sequence of characters of length `nchars` and ask the network to generate the next output character for each input character:\n", |
| 121 | "\n", |
| 122 | "\n", |
| 123 | "\n", |
| 124 | "Depending on the actual scenario, we may also want to include some special characters, such as *end-of-sequence* `<eos>`. In our case, we just want to train the network for endless text generation, so we fix the size of each sequence to `nchars` tokens. Consequently, each training example consists of `nchars` inputs and `nchars` outputs (the input sequence shifted one symbol to the left). A minibatch consists of several such sequences.\n", |
| 125 | "\n", |
| 126 | "To generate minibatches, we take each news text of length `l` and generate all possible input-output combinations from it (there will be `l-nchars` such combinations). They form one minibatch, so the minibatch size differs at each training step. " |
| 127 | ] |
| 128 | }, |
| 129 | { |
| 130 | "cell_type": "code", |
| 131 | "execution_count": 8, |
| 132 | "metadata": {}, |
| 133 | "outputs": [ |
| 134 | { |
| 135 | "data": { |
| 136 | "text/plain": [ |
| 137 | "(tensor([[43, 4, 11, ..., 18, 61, 22],\n", |
| 138 | " [ 4, 11, 11, ..., 61, 22, 4],\n", |
| 139 | " [11, 11, 2, ..., 22, 4, 8],\n", |
| 140 | " ...,\n", |
| 141 | " [37, 3, 15, ..., 4, 18, 4],\n", |
| 142 | " [ 3, 15, 5, ..., 18, 4, 7],\n", |
| 143 | " [15, 5, 3, ..., 4, 7, 8]], device='cuda:0'),\n", |
| 144 | " tensor([[ 4, 11, 11, ..., 61, 22, 4],\n", |
| 145 | " [11, 11, 2, ..., 22, 4, 8],\n", |
| 146 | " [11, 2, 26, ..., 4, 8, 12],\n", |
| 147 | " ...,\n", |
| 148 | " [ 3, 15, 5, ..., 18, 4, 7],\n", |
| 149 | " [15, 5, 3, ..., 4, 7, 8],\n", |
| 150 | " [ 5, 3, 10, ..., 7, 8, 23]], device='cuda:0'))" |
| 151 | ] |
| 152 | }, |
| 153 | "execution_count": 8, |
| 154 | "metadata": {}, |
| 155 | "output_type": "execute_result" |
| 156 | } |
| 157 | ], |
| 158 | "source": [ |
| 159 | "nchars = 100\n", |
| 160 | "\n", |
| 161 | "def get_batch(s,nchars=nchars):\n", |
| 162 | "    ins = torch.zeros(len(s)-nchars,nchars,dtype=torch.long,device=device)\n", |
| 163 | "    outs = torch.zeros(len(s)-nchars,nchars,dtype=torch.long,device=device)\n", |
| 164 | "    for i in range(len(s)-nchars):\n", |
| 165 | "        ins[i] = enc(s[i:i+nchars])\n", |
| 166 | "        outs[i] = enc(s[i+1:i+nchars+1])\n", |
| 167 | "    return ins,outs\n", |
| 168 | "\n", |
| 169 | "get_batch(train_dataset[0][1])" |
| 170 | ] |
| 171 | }, |
| 172 | { |
| 173 | "cell_type": "markdown", |
| 174 | "metadata": {}, |
| 175 | "source": [ |
| 176 | "Now let's define the generator network. It can be based on any recurrent cell we discussed in the previous unit (simple, LSTM or GRU). In our example we will use an LSTM.\n", |
| 177 | "\n", |
| 178 | "Because the network takes characters as input and the vocabulary size is pretty small, we do not need an embedding layer: one-hot-encoded input can go directly to the LSTM cell. However, because we pass character numbers as input, we need to one-hot-encode them before passing them to the LSTM. This is done by calling the `one_hot` function during the `forward` pass. The output layer is a linear layer that converts the hidden state into a vector of `vocab_size` unnormalized scores (logits), one per character in the vocabulary." |
| 179 | ] |
| 180 | }, |
| 181 | { |
| 182 | "cell_type": "code", |
| 183 | "execution_count": 9, |
| 184 | "metadata": {}, |
| 185 | "outputs": [], |
| 186 | "source": [ |
| 187 | "class LSTMGenerator(torch.nn.Module):\n", |
| 188 | "    def __init__(self, vocab_size, hidden_dim):\n", |
| 189 | "        super().__init__()\n", |
| 190 | "        self.rnn = torch.nn.LSTM(vocab_size,hidden_dim,batch_first=True)\n", |
| 191 | "        self.fc = torch.nn.Linear(hidden_dim, vocab_size)\n", |
| 192 | "\n", |
| 193 | "    def forward(self, x, s=None):\n", |
| 194 | "        x = torch.nn.functional.one_hot(x,vocab_size).to(torch.float32)\n", |
| 195 | "        x,s = self.rnn(x,s)\n", |
| 196 | "        return self.fc(x),s" |
| 197 | ] |
| 198 | }, |
| 199 | { |
| 200 | "cell_type": "markdown", |
| 201 | "metadata": {}, |
| 202 | "source": [ |
| 203 | "During training, we want to be able to sample generated text. To do that, we define a `generate` function that takes the initial string `start` and appends `size` generated characters to it.\n", |
| 204 | "\n", |
| 205 | "It works as follows. First, we pass the whole start string through the network and take the output state `s` and the network output `out`. The last position of `out` holds a score for every character in the vocabulary, so we take `argmax` to get the index `nc` of the most likely next character, use `itos` to look up the actual character, and append it to the resulting list of characters `chars`. We then feed `nc` back into the network together with the state `s`, and repeat this one-character step `size` times to generate the required number of characters. " |
| 206 | ] |
| 207 | }, |
| 208 | { |
| 209 | "cell_type": "code", |
| 210 | "execution_count": 13, |
| 211 | "metadata": {}, |
| 212 | "outputs": [], |
| 213 | "source": [ |
| 214 | "def generate(net,size=100,start='today '):\n", |
| 215 | "    chars = list(start)\n", |
| 216 | "    out, s = net(enc(chars).view(1,-1).to(device))\n", |
| 217 | "    for i in range(size):\n", |
| 218 | "        nc = torch.argmax(out[0][-1])\n", |
| 219 | "        chars.append(vocab.itos[nc])\n", |
| 220 | "        out, s = net(nc.view(1,-1),s)\n", |
| 221 | "    return ''.join(chars)" |
| 222 | ] |
| 223 | }, |
| 224 | { |
| 225 | "cell_type": "markdown", |
| 226 | "metadata": {}, |
| 227 | "source": [ |
| 228 | "Now let's do the training! The training loop is almost the same as in all our previous examples, but instead of accuracy we print sampled generated text every 1000 dataset samples.\n", |
| 229 | "\n", |
| 230 | "Special attention needs to be paid to the way we compute the loss. We need to compute it from the network output `out`, which holds unnormalized scores for each character, and the expected text `text_out`, which is a tensor of character indices. Luckily, the `cross_entropy` function expects unnormalized network output as its first argument and class numbers as the second, which is exactly what we have. It also averages the loss over all predictions in the minibatch.\n", |
| 231 | "\n", |
| 232 | "We also limit training to `samples_to_train` samples, so that we do not have to wait too long. We encourage you to experiment and try longer training, possibly for several epochs (in which case you would need to create another loop around this code)." |
| 233 | ] |
| 234 | }, |
| 235 | { |
| 236 | "cell_type": "code", |
| 237 | "execution_count": 14, |
| 238 | "metadata": {}, |
| 239 | "outputs": [ |
| 240 | { |
| 241 | "name": "stdout", |
| 242 | "output_type": "stream", |
| 243 | "text": [ |
| 244 | "Current loss = 4.442246913909912\n", |
| 245 | "today ggrgrgrgrgrgrgrgrgrgrgrgrgrgrgrgrgrgrgrgrgrgrgrgrgrgrgrgrgrgrgrgrgrgrgrgrgrgrgrgrgrgrgrgrgrgrgrgrgrg\n", |
| 246 | "Current loss = 2.1178359985351562\n", |
| 247 | "today and a could a the to the to the to the to the to the to the to the to the to the to the to the to th\n", |
| 248 | "Current loss = 1.6465336084365845\n", |
| 249 | "today on Tuesday the company to the United States and a policing to the United States and a policing to th\n", |
| 250 | "Current loss = 2.3716814517974854\n", |
| 251 | "today to the United States and a new men to the United States and a new men to the United States and a new\n", |
| 252 | "Current loss = 1.6844098567962646\n", |
| 253 | "today of the first the first the first the first the first the first the first the first the first the fir\n", |
| 254 | "Current loss = 1.702707052230835\n", |
| 255 | "today of the United States a said the United States a said the United States a said the United States a sa\n", |
| 256 | "Current loss = 1.9633255004882812\n", |
| 257 | "today of the first the first the first the first the first the first the first the first the first the fir\n", |
| 258 | "Current loss = 1.8642014265060425\n", |
| 259 | "today of the second a second a second a second a second a second a second a second a second a second a sec\n", |
| 260 | "Current loss = 1.7720613479614258\n", |
| 261 | "today and and and the company of the company of the company of the company of the company of the company o\n", |
| 262 | "Current loss = 1.52818763256073\n", |
| 263 | "today and the company of the company of the company of the company of the company of the company of the co\n", |
| 264 | "Current loss = 1.5444810390472412\n", |
| 265 | "today and the counters to the first the counters to the first the counters to the first the counters to th\n" |
| 266 | ] |
| 267 | } |
| 268 | ], |
| 269 | "source": [ |
| 270 | "net = LSTMGenerator(vocab_size,64).to(device)\n", |
| 271 | "\n", |
| 272 | "samples_to_train = 10000\n", |
| 273 | "optimizer = torch.optim.Adam(net.parameters(),0.01)\n", |
| 274 | "loss_fn = torch.nn.CrossEntropyLoss()\n", |
| 275 | "net.train()\n", |
| 276 | "for i,x in enumerate(train_dataset):\n", |
| 277 | "    # x[0] is class label, x[1] is text\n", |
| 278 | "    if len(x[1])-nchars<10:\n", |
| 279 | "        continue\n", |
| 280 | "    samples_to_train-=1\n", |
| 281 | "    if not samples_to_train: break\n", |
| 282 | "    text_in, text_out = get_batch(x[1])\n", |
| 283 | "    optimizer.zero_grad()\n", |
| 284 | "    out,s = net(text_in)\n", |
| 285 | "    loss = torch.nn.functional.cross_entropy(out.view(-1,vocab_size),text_out.flatten()) #cross_entropy(out,labels)\n", |
| 286 | "    loss.backward()\n", |
| 287 | "    optimizer.step()\n", |
| 288 | "    if i%1000==0:\n", |
| 289 | "        print(f\"Current loss = {loss.item()}\")\n", |
| 290 | "        print(generate(net))" |
| 291 | ] |
| 292 | }, |
| 293 | { |
| 294 | "cell_type": "markdown", |
| 295 | "metadata": {}, |
| 296 | "source": [ |
| 297 | "This example already generates some pretty good text, but it can be further improved in several ways:\n", |
| 298 | "* **Better minibatch generation**. The way we prepared data for training was to generate one minibatch from one sample. This is not ideal, because minibatches are all of different sizes, and some cannot be generated at all because the text is shorter than `nchars`. Also, small minibatches do not utilize the GPU well. It would be wiser to take one large chunk of text from all samples, generate all input-output pairs, shuffle them, and form minibatches of equal size.\n", |
| 299 | "* **Multilayer LSTM**. It makes sense to try 2 or 3 layers of LSTM cells. As we mentioned in the previous unit, each LSTM layer extracts certain patterns from the text, and for a character-level generator we can expect the lower LSTM layer to be responsible for extracting syllables, and the higher layers for words and word combinations. This can be implemented simply by passing the `num_layers` parameter to the LSTM constructor.\n", |
| 300 | "* You may also want to experiment with **GRU units** and see which ones perform better, and with **different hidden layer sizes**. A hidden layer that is too large may result in overfitting (e.g. the network will learn the exact text), while a smaller one might not produce good results." |
| 301 | ] |
| 302 | }, |
| 303 | { |
| 304 | "cell_type": "markdown", |
| 305 | "metadata": {}, |
| 306 | "source": [ |
| 307 | "## Soft text generation and temperature\n", |
| 308 | "\n", |
| 309 | "In the previous definition of `generate`, we always took the character with the highest probability as the next character in the generated text. As a result, the text often \"cycled\" through the same character sequences again and again, as in this example:\n", |
| 310 | "```\n", |
| 311 | "today of the second the company and a second the company ...\n", |
| 312 | "```\n", |
| 313 | "\n", |
| 314 | "However, if we look at the probability distribution for the next character, the difference between the few highest probabilities may be small, e.g. one character can have probability 0.2, another 0.19, etc. For example, when looking for the next character in the sequence '*play*', the next character can equally well be either a space or **e** (as in the word *player*).\n", |
| 315 | "\n", |
| 316 | "This leads us to the conclusion that it is not always \"fair\" to select the character with the highest probability, because choosing the second highest might still lead to meaningful text. It is wiser to **sample** characters from the probability distribution given by the network output.\n", |
| 317 | "\n", |
| 318 | "This sampling can be done using the `multinomial` function, which draws from the so-called **multinomial distribution**. A function that implements this **soft** text generation is defined below:" |
| 319 | ] |
| 320 | }, |
| 321 | { |
| 322 | "cell_type": "code", |
| 323 | "execution_count": 15, |
| 324 | "metadata": { |
| 325 | "scrolled": true |
| 326 | }, |
| 327 | "outputs": [ |
| 328 | { |
| 329 | "name": "stdout", |
| 330 | "output_type": "stream", |
| 331 | "text": [ |
| 332 | "--- Temperature = 0.3\n", |
| 333 | "Today and to has a software to in the first the power the gold medal was of the first and succer to the company will a report the first the and the gain the company in the and a new a report a pack of the four the first the company of the such with the half to a security to the and a success the first she\n", |
| 334 | "\n", |
| 335 | "--- Temperature = 0.8\n", |
| 336 | "Today drud out of the three-rent possiem that sales purssion has finminiaty women's from NAC Inc. (AP) -- Shimbon has weel with a may stelight first three flaw gold from their a scent, big study with a nighting sovicturner has slarh football at a hour of Angelage discression, into cubs, US year player sor\n", |
| 337 | "\n", |
| 338 | "--- Temperature = 1.0\n", |
| 339 | "Today by compoy, said to hup the couns ay rrope iist\\fill sinie-5-1- than he of a fightier Corp. the Vew, Mkli Unite Hold Austria on Tuesday resfare rextarted in the new has buy thisnillials thrust first capuration of the it larget expected the ir edulagy Airin Penny after Emonet Cuc Washieve an are Gurry\n", |
| 340 | "\n", |
| 341 | "--- Temperature = 1.3\n", |
| 342 | "Today cluscy,, wangled and-ox they, stee of as;\\seculity dillancrile inmution svanse gall ATHEYS today a first oresift 6-Jalf mangback explymate that wrook\" haffic illowbre overwage in Tecrian Hunrieleers to attowny service Adching, blanks governine? Aug. : : NE: Sir NFP (P2AAU) Bow SWDE: The ex2\"cut Pmoc\n", |
| 343 | "\n", |
| 344 | "--- Temperature = 1.8\n", |
| 345 | "Today sas gom, twing hWe a Dajfcou hamb--5 to bemolecresem ig irkembets plentll repws, scatchey: Actuss.io Theffouge, cirw biggemed Goiga propperinut you racive #5-Aeia:riato..Lf. N7TNap:,ser,wploy a Fraull tbashonerdlantuanseve /bBT -$06 Wemob-e.P EvVlaicy), ZOf0 cUSeballd sturk out houselty, TARZM) AbAe\n", |
| 346 | "\n" |
| 347 | ] |
| 348 | } |
| 349 | ], |
| 350 | "source": [ |
| 351 | "def generate_soft(net,size=100,start='today ',temperature=1.0):\n", |
| 352 | "    chars = list(start)\n", |
| 353 | "    out, s = net(enc(chars).view(1,-1).to(device))\n", |
| 354 | "    for i in range(size):\n", |
| 355 | "        #nc = torch.argmax(out[0][-1])\n", |
| 356 | "        out_dist = out[0][-1].div(temperature).exp()\n", |
| 357 | "        nc = torch.multinomial(out_dist,1)[0]\n", |
| 358 | "        chars.append(vocab.itos[nc])\n", |
| 359 | "        out, s = net(nc.view(1,-1),s)\n", |
| 360 | "    return ''.join(chars)\n", |
| 361 | " \n", |
| 362 | "for i in [0.3,0.8,1.0,1.3,1.8]:\n", |
| 363 | " print(f\"--- Temperature = {i}\\n{generate_soft(net,size=300,start='Today ',temperature=i)}\\n\")" |
| 364 | ] |
| 365 | }, |
| 366 | { |
| 367 | "cell_type": "markdown", |
| 368 | "metadata": {}, |
| 369 | "source": [ |
| 370 | "We have introduced one more parameter called **temperature**, which indicates how strongly we should stick to the highest probability. If the temperature is 1.0, we do fair multinomial sampling, and as the temperature goes to infinity, all probabilities become equal and we select the next character uniformly at random. In the example above we can observe that the text becomes meaningless when we increase the temperature too much, and it resembles \"cycled\" hard-generated text as the temperature gets closer to 0. " |
| 371 | ] |
| 372 | } |
| 373 | ], |
| 374 | "metadata": { |
| 375 | "kernelspec": { |
| 376 | "display_name": "py37_pytorch", |
| 377 | "language": "python", |
| 378 | "name": "conda-env-py37_pytorch-py" |
| 379 | }, |
| 380 | "language_info": { |
| 381 | "codemirror_mode": { |
| 382 | "name": "ipython", |
| 383 | "version": 3 |
| 384 | }, |
| 385 | "file_extension": ".py", |
| 386 | "mimetype": "text/x-python", |
| 387 | "name": "python", |
| 388 | "nbconvert_exporter": "python", |
| 389 | "pygments_lexer": "ipython3", |
| 390 | "version": "3.7.7" |
| 391 | } |
| 392 | }, |
| 393 | "nbformat": 4, |
| 394 | "nbformat_minor": 4 |
| 395 | } |
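Supplementary sketch for the "Training a generative RNN" section: the notebook's `get_batch` builds `l-nchars` input/output windows from a text of length `l`, with each output being its input shifted by one character. The snippet below illustrates that windowing on plain strings, without PyTorch or the vocabulary. The function name `make_pairs` and the sample text are hypothetical and are not part of the notebook.

```python
def make_pairs(text, nchars):
    """Return (input, target) windows; each target is its input shifted by one character."""
    pairs = []
    for i in range(len(text) - nchars):
        # The input covers positions i .. i+nchars-1, the target covers i+1 .. i+nchars.
        pairs.append((text[i:i + nchars], text[i + 1:i + nchars + 1]))
    return pairs

# Hypothetical sample text of length 11, with windows of 4 characters.
pairs = make_pairs("hello world", nchars=4)
print(len(pairs))   # 11 - 4 = 7 pairs
print(pairs[0])     # ('hell', 'ello')
print(pairs[-1])    # ('worl', 'orld')
```

In the notebook the same slices are passed through `enc` and written into tensors, so each minibatch has `l-nchars` rows of `nchars` character indices.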
| 396 | |
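Supplementary sketch for the "Soft text generation and temperature" section: `generate_soft` divides the network scores by the temperature and exponentiates them before sampling. After normalization, this is a softmax over the scaled scores. The snippet below shows, in plain Python, how the resulting probabilities change with temperature. The function name `softmax_with_temperature` and the three example scores are hypothetical and do not come from the trained network.

```python
import math

def softmax_with_temperature(logits, temperature=1.0):
    """Turn raw scores into probabilities, sharpened or flattened by temperature."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)                           # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical scores for three candidate characters; the first two are nearly tied.
logits = [2.0, 1.9, 0.5]
for t in (0.3, 1.0, 1.8):
    probs = softmax_with_temperature(logits, t)
    print(t, [round(p, 3) for p in probs])
```

Low temperatures concentrate probability on the top-scoring characters, which brings sampling closer to the `argmax` behaviour of `generate`. High temperatures flatten the distribution toward uniform, which matches the increasingly noisy samples in the notebook output.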