microsoft/AI-For-Beginners

Public

mirrored from https://github.com/microsoft/AI-For-Beginners · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
278c50a748972c5ee148537f45d25a9a773b32ee

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

lessons/4-ComputerVision/07-ConvNets/tfcv.py

86 lines · mode: code

1# Tensorflow Computer Vision Helper
2
3import tensorflow as tf
4from tensorflow import keras
5import numpy as np
6import matplotlib.pyplot as plt
7from PIL import Image
8import glob
9import os
10
def plot_convolution(data, t, title=''):
    """Show each 2-D image in *data* above its convolution with kernel *t*.

    The figure has two rows. The top row shows the input images, followed by
    the kernel itself in the last column. The bottom row shows each image
    convolved with the kernel (stride 1, 'SAME' padding).

    Args:
        data: Sequence of 2-D (single-channel) images.
        t: 2-D convolution kernel (nested list or array).
        title: Figure title.
    """
    # squeeze=False keeps ax 2-D even when data is empty (a single column),
    # so the ax[0, -1] / ax[1, -1] indexing below always works.
    fig, ax = plt.subplots(2, len(data) + 1, figsize=(8, 3), squeeze=False)
    fig.suptitle(title, fontsize=16)
    # tf.nn.conv2d expects filters shaped (height, width, in_channels,
    # out_channels) with the same dtype as the input. Casting both sides to
    # float32 lets integer images (e.g. raw uint8 pixels) and integer kernels
    # work as well as float ones.
    kernel = tf.cast(np.expand_dims(np.expand_dims(t, 2), 2), tf.float32)
    for i, im in enumerate(data):
        ax[0][i].imshow(im)
        # Add batch and channel axes: (H, W) -> (1, H, W, 1).
        ximg = tf.cast(np.expand_dims(np.expand_dims(im, 2), 0), tf.float32)
        cim = tf.nn.conv2d(ximg, kernel, 1, 'SAME')
        ax[1][i].imshow(cim[0][:, :, 0])
        ax[0][i].axis('off')
        ax[1][i].axis('off')
    ax[0, -1].imshow(t)
    ax[0, -1].axis('off')
    ax[1, -1].axis('off')
    plt.show()
27
def plot_results(hist):
    """Plot accuracy and loss curves from a Keras training history.

    The left panel shows training/validation accuracy and the right panel
    shows training/validation loss.

    Accuracy may be recorded under either ``'acc'`` or ``'accuracy'``,
    depending on the metric name given to ``compile()``. Series absent from
    the history (e.g. no validation data) are skipped instead of raising
    KeyError.

    Args:
        hist: Object with a ``history`` dict mapping metric names to
            per-epoch value lists (as returned by ``model.fit``).
    """
    history = hist.history
    _, ax = plt.subplots(1, 2, figsize=(15, 3))
    ax[0].set_title('Accuracy')
    ax[1].set_title('Loss')
    # Keras names the series after the metric string passed to compile():
    # 'acc' -> 'acc'/'val_acc', 'accuracy' -> 'accuracy'/'val_accuracy'.
    acc = 'acc' if 'acc' in history else 'accuracy'
    panels = ((ax[0], (acc, 'val_' + acc)), (ax[1], ('loss', 'val_loss')))
    for axis, keys in panels:
        present = [key for key in keys if key in history]
        for key in present:
            axis.plot(history[key], label=key)
        if present:
            axis.legend()
    plt.show()
37
def display_dataset(dataset, labels=None, n=10, classes=None):
    """Show the first *n* images of *dataset* side by side in one row.

    If both *labels* and *classes* are given, each image is titled with
    ``classes[label]``. A label may be a scalar or a length-1 array (for
    example a row of an ``(N, 1)`` label array); its first element is used
    as the class index.

    Args:
        dataset: Indexable collection of images accepted by ``imshow``.
        labels: Optional per-image labels, indexed in step with *dataset*.
        n: Number of images to show.
        classes: Optional sequence mapping a class index to a display name.
    """
    # squeeze=False keeps ax 2-D even for n == 1, where plt.subplots would
    # otherwise return a single Axes object that cannot be indexed.
    _, ax = plt.subplots(1, n, figsize=(15, 3), squeeze=False)
    for i in range(n):
        axis = ax[0][i]
        axis.imshow(dataset[i])
        axis.axis('off')
        if classes is not None and labels is not None:
            # np.ravel accepts both scalar and length-1 labels.
            axis.set_title(classes[int(np.ravel(labels[i])[0])])
45
def check_image(fn):
    """Return True if *fn* opens and verifies as an image in JPEG format.

    Any ordinary error while opening or verifying the file (missing file,
    unrecognised or corrupt data, ...) is reported as False rather than
    raised.

    Args:
        fn: Path of the file to check.
    """
    try:
        # The context manager guarantees the file handle is released on every
        # path, including when verify() raises.
        with Image.open(fn) as im:
            im.verify()
            return im.format == 'JPEG'
    except Exception:
        # Pillow signals bad files with several unrelated exception types, so
        # any ordinary error means "not a usable image". Unlike the former bare
        # except, KeyboardInterrupt and SystemExit now propagate.
        return False
53
def check_image_dir(path):
    """Delete every file matched by the glob pattern *path* that fails check_image.

    Each rejected file is reported on stdout just before it is removed.

    Args:
        path: Glob pattern, e.g. ``'data/PetImages/Cat/*.jpg'``.
    """
    for candidate in glob.glob(path):
        if check_image(candidate):
            continue
        print("Corrupt image or wrong format: {}".format(candidate))
        os.remove(candidate)
59
def load_cats_dogs_dataset(batch_size=64):
    """Load the Cats-vs-Dogs images as training and validation datasets.

    If ``data/PetImages`` does not exist yet, the archive
    ``data/kagglecatsanddogs_3367a.zip`` is first extracted into ``data/``.
    Every extracted ``Cat/*.jpg`` and ``Dog/*.jpg`` file that does not verify
    as a JPEG is then deleted via ``check_image_dir``.

    Args:
        batch_size: Number of images per batch in both returned datasets.

    Returns:
        ``(ds_train, ds_test)``: the 80% "training" and 20% "validation"
        subsets of ``data/PetImages``, produced by
        ``keras.preprocessing.image_dataset_from_directory`` with images
        resized to 224x224.
    """
    # The module header does not import zipfile; without this import the
    # first (extracting) call failed with NameError.
    import zipfile

    data_dir = 'data/PetImages'
    if not os.path.exists(data_dir):
        print("Extracting the dataset")
        with zipfile.ZipFile('data/kagglecatsanddogs_3367a.zip', 'r') as zip_ref:
            zip_ref.extractall('data')
        # Validation runs only right after a fresh extraction.
        print("Checking dataset")
        check_image_dir('data/PetImages/Cat/*.jpg')
        check_image_dir('data/PetImages/Dog/*.jpg')
    print("Loading dataset")
    # Both subsets must use the same validation_split and seed so the two
    # calls split the same file list into disjoint parts.
    common = dict(
        validation_split=0.2,
        seed=13,
        image_size=(224, 224),
        batch_size=batch_size,
    )
    ds_train = keras.preprocessing.image_dataset_from_directory(
        data_dir, subset='training', **common)
    ds_test = keras.preprocessing.image_dataset_from_directory(
        data_dir, subset='validation', **common)
    return ds_train, ds_test
87