microsoft/AI-For-Beginners

Public

mirrored fromhttps://github.com/microsoft/AI-For-BeginnersAvailable

CodeCommitsIssuesPull requestsActionsInsightsSecurity
278c50a748972c5ee148537f45d25a9a773b32ee

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

lessons/4-ComputerVision/07-ConvNets/pytorchcv.py

169lines · modecode

1
2# Script file to hide implementation details for PyTorch computer vision module
3
4import builtins
5import torch
6import torch.nn as nn
7from torch.utils import data
8import torchvision
9from torchvision.transforms import ToTensor
10import matplotlib.pyplot as plt
11import numpy as np
12from PIL import Image
13import glob
14import os
15import zipfile
16
# Device used by the training/evaluation helpers below: the GPU whenever
# PyTorch reports CUDA as available, otherwise the CPU.
if torch.cuda.is_available():
    default_device = 'cuda'
else:
    default_device = 'cpu'
18
def load_mnist(batch_size=64):
    """Download MNIST into ``./data`` and publish it through ``builtins``.

    Afterwards the names ``data_train`` / ``data_test`` (torchvision MNIST
    datasets with a ``ToTensor`` transform) and ``train_loader`` /
    ``test_loader`` (DataLoaders over them with ``batch_size``; no shuffle
    argument is passed) are visible as bare names everywhere, which is what
    the notebooks using this module rely on.
    """
    def _mnist_split(is_train):
        # One call per split; each gets its own ToTensor instance.
        return torchvision.datasets.MNIST('./data',
                                          download=True, train=is_train, transform=ToTensor())

    builtins.data_train = _mnist_split(True)
    builtins.data_test = _mnist_split(False)
    builtins.train_loader = torch.utils.data.DataLoader(builtins.data_train, batch_size=batch_size)
    builtins.test_loader = torch.utils.data.DataLoader(builtins.data_test, batch_size=batch_size)
26
def train_epoch(net, dataloader, lr=0.01, optimizer=None, loss_fn=nn.NLLLoss(), device=None):
    """Train ``net`` for one pass over ``dataloader``.

    Args:
        net: model to train; it is put into training mode (``net.train()``).
        dataloader: iterable of ``(features, labels)`` batches.
        lr: learning rate of the Adam optimizer created when ``optimizer`` is
            not given (ignored otherwise).
        optimizer: optimizer to step; defaults to a new ``torch.optim.Adam``
            over ``net.parameters()``.
        loss_fn: loss applied as ``loss_fn(net(features), labels)``.
        device: device batches are moved to before the forward pass; defaults
            to the module-level ``default_device``.

    Returns:
        ``(loss, accuracy)`` as Python floats. ``loss`` is the sum of the
        per-batch ``loss_fn`` values divided by the number of samples (for a
        mean-reduced loss and equal batch sizes this is the average loss
        divided by the batch size). ``accuracy`` is the fraction of samples
        whose arg-max prediction equals the label.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    if optimizer is None:
        optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    if device is None:
        device = default_device
    net.train()
    total_loss, acc, count = 0, 0, 0
    for features, labels in dataloader:
        optimizer.zero_grad()
        lbls = labels.to(device)
        out = net(features.to(device))
        loss = loss_fn(out, lbls)
        loss.backward()
        optimizer.step()
        # Detach before accumulating: adding the raw ``loss`` would chain each
        # batch's autograd graph onto ``total_loss`` and keep every graph of
        # the epoch alive. ``detach`` also avoids a per-batch host sync.
        total_loss += loss.detach()
        _, predicted = torch.max(out, 1)
        acc += (predicted == lbls).sum()
        count += len(labels)
    if count == 0:
        raise ValueError("dataloader yielded no batches")
    return total_loss.item() / count, acc.item() / count
43
def validate(net, dataloader,loss_fn=nn.NLLLoss()):
    """Evaluate ``net`` on every batch of ``dataloader`` without autograd.

    The model is switched to evaluation mode (and left there). Batches are
    moved to the module-level ``default_device``.

    Returns:
        ``(loss, accuracy)`` as Python floats: the sum of per-batch
        ``loss_fn`` values divided by the number of samples, and the fraction
        of samples whose arg-max prediction matches the label.
    """
    net.eval()
    n_seen, n_correct, running_loss = 0, 0, 0
    with torch.no_grad():
        for batch_x, batch_y in dataloader:
            target = batch_y.to(default_device)
            scores = net(batch_x.to(default_device))
            running_loss += loss_fn(scores, target)
            _, guess = torch.max(scores, 1)
            n_correct += (guess == target).sum()
            n_seen += len(batch_y)
    mean_loss = running_loss.item() / n_seen
    accuracy = n_correct.item() / n_seen
    return mean_loss, accuracy
56
def train(net,train_loader,test_loader,optimizer=None,lr=0.01,epochs=10,loss_fn=nn.NLLLoss()):
    """Train ``net`` for ``epochs`` epochs, validating after each one.

    One optimizer (a new Adam with ``lr`` unless ``optimizer`` is given) is
    shared across all epochs. A one-line summary is printed per epoch.

    Returns:
        dict with keys ``'train_loss'``, ``'train_acc'``, ``'val_loss'`` and
        ``'val_acc'``, each a list holding one value per epoch.
    """
    optimizer = optimizer if optimizer else torch.optim.Adam(net.parameters(), lr=lr)
    history = {name: [] for name in ('train_loss', 'train_acc', 'val_loss', 'val_acc')}
    for ep in range(epochs):
        tl, ta = train_epoch(net, train_loader, optimizer=optimizer, lr=lr, loss_fn=loss_fn)
        vl, va = validate(net, test_loader, loss_fn=loss_fn)
        print(f"Epoch {ep:2}, Train acc={ta:.3f}, Val acc={va:.3f}, Train loss={tl:.3f}, Val loss={vl:.3f}")
        # history preserves insertion order, matching the metric tuple below.
        for name, value in zip(history, (tl, ta, vl, va)):
            history[name].append(value)
    return history
69
def train_long(net,train_loader,test_loader,epochs=5,lr=0.01,optimizer=None,loss_fn = nn.NLLLoss(),print_freq=10):
    """Train ``net`` for ``epochs`` epochs with periodic progress output.

    Every ``print_freq`` minibatches the running training accuracy and loss
    for the current epoch are printed. After each epoch the model is evaluated
    with ``validate`` and the validation metrics are printed. Batches are
    moved to the module-level ``default_device``. Returns nothing.

    Args:
        net: model to train.
        train_loader: iterable of ``(features, labels)`` training batches.
        test_loader: iterable of ``(features, labels)`` validation batches.
        epochs: number of passes over ``train_loader``.
        lr: learning rate of the Adam optimizer created when ``optimizer`` is
            not given (ignored otherwise).
        optimizer: optimizer to step; defaults to a new ``torch.optim.Adam``.
        loss_fn: loss applied as ``loss_fn(net(features), labels)``.
        print_freq: print progress every this many minibatches.
    """
    if optimizer is None:
        optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    for epoch in range(epochs):
        # validate() switches to eval mode, so re-enable training each epoch.
        net.train()
        total_loss, acc, count = 0, 0, 0
        for i, (features, labels) in enumerate(train_loader):
            lbls = labels.to(default_device)
            optimizer.zero_grad()
            out = net(features.to(default_device))
            loss = loss_fn(out, lbls)
            loss.backward()
            optimizer.step()
            # Detach so the running sum does not keep every batch's autograd
            # graph alive for the whole epoch.
            total_loss += loss.detach()
            _, predicted = torch.max(out, 1)
            acc += (predicted == lbls).sum()
            count += len(labels)
            if i % print_freq == 0:
                print("Epoch {}, minibatch {}: train acc = {}, train loss = {}".format(epoch,i,acc.item()/count,total_loss.item()/count))
        vl, va = validate(net, test_loader, loss_fn)
        print("Epoch {} done, validation acc = {}, validation loss = {}".format(epoch,va,vl))
90
91
def plot_results(hist):
    """Plot training/validation curves from a history dict such as ``train`` returns.

    ``hist`` must contain ``'train_acc'``, ``'val_acc'``, ``'train_loss'``
    and ``'val_loss'`` sequences. A new 15x5 figure is created with accuracy
    on the left and loss on the right; ``plt.show()`` is not called.
    """
    plt.figure(figsize=(15,5))
    panels = (
        (121, (('train_acc', 'Training acc'), ('val_acc', 'Validation acc'))),
        (122, (('train_loss', 'Training loss'), ('val_loss', 'Validation loss'))),
    )
    for position, curves in panels:
        plt.subplot(position)
        for key, label in curves:
            plt.plot(hist[key], label=label)
        plt.legend()
102
def plot_convolution(t,title=''):
    """Show how the 3x3 kernel ``t`` transforms the first five training images.

    Uses the ``data_train`` name that ``load_mnist`` publishes through
    ``builtins``, so ``load_mnist`` must have been called first. The top row
    shows the original images, the bottom row the convolved images, and the
    top-right cell the kernel itself. The figure is displayed with ``plt.show()``.

    Args:
        t: kernel values copied into the conv weight of shape (1, 1, 3, 3);
            presumably a 3x3 tensor (broadcast on copy) -- confirm with callers.
        title: figure title.
    """
    with torch.no_grad():
        # bias=False: otherwise Conv2d keeps a randomly initialised bias and
        # adds a random offset to every output pixel, so the result would not
        # be the effect of ``t`` alone.
        c = nn.Conv2d(kernel_size=(3,3), out_channels=1, in_channels=1, bias=False)
        c.weight.copy_(t)
        fig, ax = plt.subplots(2, 6, figsize=(8,3))
        fig.suptitle(title, fontsize=16)
        for i in range(5):
            im = data_train[i][0]
            ax[0][i].imshow(im[0])
            ax[1][i].imshow(c(im.unsqueeze(0))[0][0])
            ax[0][i].axis('off')
            ax[1][i].axis('off')
        ax[0,5].imshow(t)
        ax[0,5].axis('off')
        ax[1,5].axis('off')
        plt.show()
120
def display_dataset(dataset, n=10,classes=None):
    """Show the first ``n`` samples of ``dataset`` in a single row.

    All images are rescaled with one shared min/max, computed over the shown
    samples, so relative intensities stay comparable across images.

    Args:
        dataset: indexable returning ``(image, label)`` pairs. Images are
            assumed to be channel-first (C, H, W) -- TODO confirm for
            non-torchvision datasets.
        n: number of samples to show.
        classes: optional sequence mapping a label to a display title.
    """
    # squeeze=False keeps ``ax`` two-dimensional, so n == 1 works as well.
    _, ax = plt.subplots(1, n, figsize=(15,3), squeeze=False)
    # Fetch each sample once: indexing a dataset may re-read and re-transform
    # the item on every access.
    samples = [dataset[i] for i in range(n)]
    mn = min(img.min() for img, _ in samples)
    mx = max(img.max() for img, _ in samples)
    span = mx - mn
    if span == 0:
        span = 1  # all pixels equal: avoid 0/0 (NaN) images
    for axis, (img, label) in zip(ax[0], samples):
        axis.imshow(np.transpose((img - mn) / span, (1, 2, 0)))
        axis.axis('off')
        if classes:
            axis.set_title(classes[label])
130
131
def check_image(fn):
    """Return True if Pillow can open and verify the image file ``fn``, else False.

    Any error raised while opening or verifying (missing file, unrecognised
    format, truncated data, ...) counts as "not a valid image".
    """
    try:
        # The context manager closes the file handle. A leaked handle can block
        # the os.remove() that check_image_dir performs on some platforms
        # (e.g. Windows).
        with Image.open(fn) as im:
            im.verify()
    except Exception:
        # verify() can raise many different exception types, so catch broadly,
        # but unlike a bare ``except:`` let KeyboardInterrupt/SystemExit through.
        return False
    return True
139
def check_image_dir(path):
    """Delete every file matching glob pattern ``path`` that ``check_image`` rejects.

    Each rejected file is reported on stdout just before it is removed.
    """
    # Lazy generator: each file is checked, reported and deleted before the
    # next one is examined.
    rejected = (name for name in glob.glob(path) if not check_image(name))
    for name in rejected:
        print("Corrupt image: {}".format(name))
        os.remove(name)
145
146
def common_transform():
    """Build the standard preprocessing pipeline for pretrained models.

    Resize to 256, center-crop to 224, convert to a tensor, then normalize
    each channel with the mean/std values commonly used for ImageNet-pretrained
    torchvision models. A new ``Compose`` is returned on every call.
    """
    tv = torchvision.transforms
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    return tv.Compose([
        tv.Resize(256),
        tv.CenterCrop(224),
        tv.ToTensor(),
        tv.Normalize(mean=channel_mean, std=channel_std),
    ])
156
def load_cats_dogs_dataset():
    """Prepare the cats-vs-dogs images and return the dataset plus loaders.

    Extracts ``data/kagglecatsanddogs_3367a.zip`` into ``data`` unless
    ``data/PetImages`` already exists. It then deletes every ``.jpg`` under the
    ``Cat`` and ``Dog`` folders that Pillow cannot verify. The folder is wrapped
    in an ``ImageFolder`` using ``common_transform()`` and randomly split into
    20000 training samples, with the remainder used for testing.

    Returns:
        ``(dataset, trainloader, testloader)``: the full ImageFolder, plus
        DataLoaders (batch size 32) over the training and held-out subsets.
    """
    if not os.path.exists('data/PetImages'):
        with zipfile.ZipFile('data/kagglecatsanddogs_3367a.zip', 'r') as zip_ref:
            zip_ref.extractall('data')

    check_image_dir('data/PetImages/Cat/*.jpg')
    check_image_dir('data/PetImages/Dog/*.jpg')

    dataset = torchvision.datasets.ImageFolder('data/PetImages', transform=common_transform())
    trainset, testset = torch.utils.data.random_split(dataset, [20000, len(dataset) - 20000])
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=32)
    # Evaluate on the held-out split. Wrapping ``trainset`` here would report
    # training-set accuracy as validation accuracy.
    testloader = torch.utils.data.DataLoader(testset, batch_size=32)
    return dataset, trainloader, testloader