microsoft/AI-For-Beginners
Public · mirrored from https://github.com/microsoft/AI-For-Beginners · Available
lessons/4-ComputerVision/07-ConvNets/pytorchcv.py
169 lines · mode: code
| 1 | |
| 2 | # Script file to hide implementation details for PyTorch computer vision module |
| 3 | |
| 4 | import builtins |
| 5 | import torch |
| 6 | import torch.nn as nn |
| 7 | from torch.utils import data |
| 8 | import torchvision |
| 9 | from torchvision.transforms import ToTensor |
| 10 | import matplotlib.pyplot as plt |
| 11 | import numpy as np |
| 12 | from PIL import Image |
| 13 | import glob |
| 14 | import os |
| 15 | import zipfile |
| 16 | |
# Device every batch is moved to by the training/validation helpers:
# the GPU when PyTorch reports CUDA as available, otherwise the CPU.
if torch.cuda.is_available():
    default_device = 'cuda'
else:
    default_device = 'cpu'
| 18 | |
def load_mnist(batch_size=64):
    """Download MNIST into ./data and publish it as interpreter-wide names.

    The train/test datasets and their DataLoaders are stored on the
    ``builtins`` module as ``data_train``, ``data_test``, ``train_loader``
    and ``test_loader`` so notebook code (and helpers such as
    ``plot_convolution``) can refer to them as bare names. Nothing is
    returned.

    Args:
        batch_size: samples per batch for both loaders.
    """
    def _mnist_split(is_train):
        # Each split gets its own ToTensor transform, as in the original.
        return torchvision.datasets.MNIST('./data',
                                          download=True, train=is_train, transform=ToTensor())

    train_set = _mnist_split(True)
    builtins.data_train = train_set
    test_set = _mnist_split(False)
    builtins.data_test = test_set
    builtins.train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size)
    builtins.test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size)
| 26 | |
def train_epoch(net,dataloader,lr=0.01,optimizer=None,loss_fn = nn.NLLLoss()):
    """Train ``net`` for one pass over ``dataloader``.

    Args:
        net: model to train; it is put in training mode. Each batch is moved
            to ``default_device`` (the model is presumably already there —
            TODO confirm at call sites).
        dataloader: iterable of ``(features, labels)`` batches.
        lr: learning rate of the Adam optimizer created when ``optimizer``
            is None.
        optimizer: optimizer to step after each batch; defaults to a fresh
            ``torch.optim.Adam`` over ``net.parameters()``.
        loss_fn: criterion called as ``loss_fn(out, labels)``.

    Returns:
        ``(loss, accuracy)``: the sum of per-batch losses divided by the
        number of samples, and the fraction of correctly predicted labels.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    if optimizer is None:
        optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    net.train()
    total_loss, acc, count = 0, 0, 0
    for features, labels in dataloader:
        optimizer.zero_grad()
        lbls = labels.to(default_device)
        out = net(features.to(default_device))
        loss = loss_fn(out, lbls)
        loss.backward()
        optimizer.step()
        # Detach before accumulating: adding the live loss tensor would chain
        # every batch's autograd graph onto total_loss for the whole epoch.
        total_loss += loss.detach()
        _, predicted = torch.max(out, 1)
        acc += (predicted == lbls).sum()
        count += len(labels)
    if count == 0:
        raise ValueError("dataloader yielded no samples")
    # NOTE(review): with a mean-reduced loss (NLLLoss's default) this is the
    # per-batch mean summed and divided by the sample count, i.e. roughly
    # mean loss / batch size; kept so reported numbers stay unchanged.
    return total_loss.item() / count, acc.item() / count
| 43 | |
def validate(net, dataloader,loss_fn=nn.NLLLoss()):
    """Evaluate ``net`` on ``dataloader`` without tracking gradients.

    Args:
        net: model to evaluate; it is switched to eval mode (and left there).
            Each batch is moved to ``default_device``.
        dataloader: iterable of ``(features, labels)`` batches.
        loss_fn: criterion called as ``loss_fn(out, labels)``.

    Returns:
        ``(loss, accuracy)``: the sum of per-batch losses divided by the
        number of samples, and the fraction of correctly predicted labels.

    Raises:
        ValueError: if ``dataloader`` yields no samples (previously this
            surfaced as a confusing ``AttributeError`` on ``int.item``).
    """
    net.eval()
    count, acc, loss = 0, 0, 0
    with torch.no_grad():
        for features, labels in dataloader:
            lbls = labels.to(default_device)
            out = net(features.to(default_device))
            loss += loss_fn(out, lbls)
            pred = torch.max(out, 1)[1]
            acc += (pred == lbls).sum()
            count += len(labels)
    if count == 0:
        raise ValueError("dataloader yielded no samples")
    # NOTE(review): with a mean-reduced loss this is roughly mean loss /
    # batch size rather than the per-sample mean; kept for compatibility.
    return loss.item() / count, acc.item() / count
| 56 | |
def train(net,train_loader,test_loader,optimizer=None,lr=0.01,epochs=10,loss_fn=nn.NLLLoss()):
    """Alternate one training epoch and one validation pass, ``epochs`` times.

    Each epoch calls ``train_epoch`` on ``train_loader`` and ``validate`` on
    ``test_loader``, prints a one-line summary and records the four metrics.

    Args:
        net: model to train.
        train_loader: iterable of ``(features, labels)`` training batches.
        test_loader: iterable of ``(features, labels)`` validation batches.
        optimizer: optimizer shared by all epochs; when falsy/omitted an
            Adam optimizer over ``net.parameters()`` with rate ``lr`` is used.
        lr: learning rate (also forwarded to ``train_epoch``).
        epochs: number of epochs to run.
        loss_fn: criterion used for both training and validation.

    Returns:
        dict with keys ``'train_loss'``, ``'train_acc'``, ``'val_loss'`` and
        ``'val_acc'``, each a list with one value per epoch.
    """
    opt = optimizer if optimizer else torch.optim.Adam(net.parameters(), lr=lr)
    metric_names = ('train_loss', 'train_acc', 'val_loss', 'val_acc')
    history = {name: [] for name in metric_names}
    for epoch in range(epochs):
        train_loss, train_acc = train_epoch(net, train_loader, optimizer=opt, lr=lr, loss_fn=loss_fn)
        val_loss, val_acc = validate(net, test_loader, loss_fn=loss_fn)
        print(f"Epoch {epoch:2}, Train acc={train_acc:.3f}, Val acc={val_acc:.3f}, Train loss={train_loss:.3f}, Val loss={val_loss:.3f}")
        for name, value in zip(metric_names, (train_loss, train_acc, val_loss, val_acc)):
            history[name].append(value)
    return history
| 69 | |
def train_long(net,train_loader,test_loader,epochs=5,lr=0.01,optimizer=None,loss_fn = nn.NLLLoss(),print_freq=10):
    """Train ``net`` for several epochs, printing running metrics as it goes.

    Every ``print_freq`` minibatches the running accuracy and loss of the
    current epoch are printed; after each epoch ``validate`` is run on
    ``test_loader`` and its result printed. Nothing is returned.

    Args:
        net: model to train; put in training mode at the start of each epoch.
            Each batch is moved to ``default_device``.
        train_loader: iterable of ``(features, labels)`` training batches.
        test_loader: iterable of ``(features, labels)`` validation batches.
        epochs: number of epochs to run.
        lr: learning rate of the Adam optimizer created when ``optimizer``
            is None.
        optimizer: optimizer stepped after every batch.
        loss_fn: criterion called as ``loss_fn(out, labels)``.
        print_freq: print running metrics every this many minibatches.
    """
    if optimizer is None:
        optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    for epoch in range(epochs):
        net.train()
        total_loss, acc, count = 0, 0, 0
        for i, (features, labels) in enumerate(train_loader):
            lbls = labels.to(default_device)
            optimizer.zero_grad()
            out = net(features.to(default_device))
            loss = loss_fn(out, lbls)
            loss.backward()
            optimizer.step()
            # Detach so the running total does not keep every batch's
            # autograd graph alive for the rest of the epoch.
            total_loss += loss.detach()
            _, predicted = torch.max(out, 1)
            acc += (predicted == lbls).sum()
            count += len(labels)
            if i % print_freq == 0:
                print("Epoch {}, minibatch {}: train acc = {}, train loss = {}".format(epoch, i, acc.item()/count, total_loss.item()/count))
        vl, va = validate(net, test_loader, loss_fn)
        print("Epoch {} done, validation acc = {}, validation loss = {}".format(epoch, va, vl))
| 90 | |
| 91 | |
def plot_results(hist):
    """Plot training vs. validation curves from a history dict.

    Left panel: ``hist['train_acc']`` and ``hist['val_acc']``; right panel:
    ``hist['train_loss']`` and ``hist['val_loss']``; each panel gets a
    legend. The figure is neither shown nor returned.
    """
    plt.figure(figsize=(15,5))
    panels = (
        (121, ('train_acc', 'Training acc'), ('val_acc', 'Validation acc')),
        (122, ('train_loss', 'Training loss'), ('val_loss', 'Validation loss')),
    )
    for position, *curves in panels:
        plt.subplot(position)
        for key, curve_label in curves:
            plt.plot(hist[key], label=curve_label)
        plt.legend()
| 102 | |
def plot_convolution(t,title=''):
    """Show the effect of a 3x3 convolution kernel on five training images.

    The top row shows ``data_train[0]`` .. ``data_train[4]`` (the MNIST
    training set that ``load_mnist`` stores on ``builtins``). The bottom row
    shows the same images convolved with ``t``. The last top-row cell shows
    the kernel itself.

    Args:
        t: kernel tensor copied into the (1, 1, 3, 3) weight of a
            single-channel ``Conv2d`` (must be broadcastable to that shape).
        title: figure title.
    """
    with torch.no_grad():
        # bias=False: Conv2d's default bias is randomly initialised and would
        # add an arbitrary offset to every filtered image.
        c = nn.Conv2d(kernel_size=(3,3), out_channels=1, in_channels=1, bias=False)
        c.weight.copy_(t)
        fig, ax = plt.subplots(2, 6, figsize=(8,3))
        fig.suptitle(title, fontsize=16)
        for i in range(5):
            im = data_train[i][0]
            ax[0][i].imshow(im[0])
            ax[1][i].imshow(c(im.unsqueeze(0))[0][0])
            ax[0][i].axis('off')
            ax[1][i].axis('off')
        ax[0,5].imshow(t)
        ax[0,5].axis('off')
        ax[1,5].axis('off')
        plt.show()
| 120 | |
def display_dataset(dataset, n=10,classes=None):
    """Show the first ``n`` samples of ``dataset`` side by side.

    All images are rescaled with one shared min/max, so their relative
    brightness is preserved, and are drawn with axis 0 moved to the end.

    Args:
        dataset: indexable of ``(image, label)`` pairs; images are presumably
            channel-first (C, H, W) tensors/arrays — TODO confirm, the
            transpose below assumes it.
        n: number of samples to show.
        classes: optional sequence mapping label -> title text.
    """
    # Fetch each sample once: datasets that load/transform lazily (e.g. the
    # ImageFolder used for cats vs dogs) would otherwise redo that work for
    # every index, and random transforms would make scale and image disagree.
    samples = [dataset[i] for i in range(n)]
    # squeeze=False keeps a 2-D array of axes even when n == 1.
    _, ax = plt.subplots(1, n, figsize=(15,3), squeeze=False)
    mn = min(img.min() for img, _ in samples)
    mx = max(img.max() for img, _ in samples)
    span = mx - mn
    if span == 0:
        span = 1  # all pixels equal: avoid 0/0 (NaN) and draw a flat image
    for axis, (img, label) in zip(ax[0], samples):
        axis.imshow(np.transpose((img - mn) / span, (1,2,0)))
        axis.axis('off')
        if classes:
            axis.set_title(classes[label])
| 130 | |
| 131 | |
def check_image(fn):
    """Return True if PIL can open and verify the image file ``fn``, else False.

    Any failure (missing file, unknown format, truncated or corrupt data)
    counts as "not a valid image". This is a best-effort filter used by
    ``check_image_dir`` to weed out corrupt files.
    """
    try:
        # The context manager closes the file Image.open keeps open; without
        # it every checked file stays open until garbage collection.
        with Image.open(fn) as im:
            im.verify()
        return True
    except Exception:  # unlike a bare except, lets Ctrl-C / SystemExit through
        return False
| 139 | |
def check_image_dir(path):
    """Delete every file matching glob pattern ``path`` that fails ``check_image``.

    Each deleted file is reported on stdout just before removal.
    """
    for candidate in glob.glob(path):
        if check_image(candidate):
            continue
        print("Corrupt image: {}".format(candidate))
        os.remove(candidate)
| 145 | |
| 146 | |
def common_transform():
    """Build the standard ImageNet-style preprocessing pipeline.

    Resize (smaller edge to 256), center-crop to 224x224, convert to a
    tensor, then normalize each channel with the ImageNet mean/std used by
    torchvision's pretrained models.
    """
    steps = [
        torchvision.transforms.Resize(256),
        torchvision.transforms.CenterCrop(224),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225]),
    ]
    return torchvision.transforms.Compose(steps)
| 156 | |
def load_cats_dogs_dataset():
    """Prepare the cats-vs-dogs images and return the dataset and its loaders.

    Steps:
    - If ``data/PetImages`` does not exist, extract
      ``data/kagglecatsanddogs_3367a.zip`` into ``data``.
    - Delete the Cat/Dog ``.jpg`` files that PIL cannot verify.
    - Build an ``ImageFolder`` over ``data/PetImages`` with
      ``common_transform()``.
    - ``random_split`` puts 20000 randomly chosen samples in the training
      set and the remainder in the test set.

    Returns:
        ``(dataset, trainloader, testloader)``; both loaders use batch size 32.
    """
    if not os.path.exists('data/PetImages'):
        with zipfile.ZipFile('data/kagglecatsanddogs_3367a.zip', 'r') as zip_ref:
            zip_ref.extractall('data')

    check_image_dir('data/PetImages/Cat/*.jpg')
    check_image_dir('data/PetImages/Dog/*.jpg')

    dataset = torchvision.datasets.ImageFolder('data/PetImages', transform=common_transform())
    trainset, testset = torch.utils.data.random_split(dataset, [20000, len(dataset)-20000])
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=32)
    # Evaluate on the held-out split; building this loader from trainset
    # would report training accuracy as validation accuracy.
    testloader = torch.utils.data.DataLoader(testset, batch_size=32)
    return dataset, trainloader, testloader