當前位置:
首頁 > 知識 > 博客 CIFAR10 數據預處理

博客 CIFAR10 數據預處理

本系列文章已由作者授權在AI研習社發布。

歡迎關注我的AI研習社博客:

http://www.gair.link/page/center/myPage/5104751,

或訂閱我的 CSDN:

https://blog.csdn.net/Kuo_Jun_Lin

Brief 概述

在上一章中我們使用了 MNIST 手寫數字數據集,套入一個非常簡單的線性模型中,得到了大約 90% 左右的正確率,用意在於熟悉神經網路節點的架構和框架的使用方法,接下來這章將把前一章的數據集和方法全面提升一個檔次,使用的是 CIFAR10 與 CNN 卷積神經網路的架構,同時也可以做為探討深層神經網路如 VGG19,GoogleNet,與 ResNet 的敲門磚。

CNN 卷積神經網路假設大家已經有一個大致的了解,它不像線性回歸的方法,從每個像素著手發現歸類到不同標籤的規則,而是使用卷積核逐步掃描整張圖片的方式抽取出圖像特徵,經過逐個卷積和在逐層維度上特徵的抽離處理,最終把他們與全連階層相連,通往標籤的歸類,但是說著簡單,其實操作上還有許多細節需要注意如下面幾點:

借鑒上一次的代碼運行過程,首先第一件事就是減少「類」中函數的冗長定義,因為每次呼叫類的方法時,其實類中的內容都會被重新刷新一遍,多次反覆下來就是一個有負擔的計算量。

圖片不再是簡單的手寫數字,CIFAR10 有背景與對應標籤的圖案,因此為了更好的訓練,圖片需要做預處理,隨機的:旋轉角度,灰階度,對比度,圖像尺寸調整,明暗度,色調,與裁切,都是可以嘗試的手法。

下面我們將探討一個數據集在多個維度的比較,並嘗試出最好正確率的排列組合。

p.s. 卷積神經網路搭建開始之前必須先確定自身電腦內存是否 >= 8G,雖然這個網路在 CNN 演算法中非常簡單,但如果從最開始的神經網路加總,一共也會有幾十萬個參數的量,需要注意電腦是否能夠承載。

Code[1]

import sys

import tensorflowastf

print(tf.__version__)

1.10.1


CIFAR10 Dataset

它是一個內涵六萬張圖片的數據集,比起 MNIST,它的通道數是三個,用來表示其彩色的畫面,並且圖像尺寸是 32*32,其中分成訓練集五萬張與測試集一萬張,製作人在打包數據的時候分了幾個文檔如下圖:

其內部排列方式為一個大的字典,圖片數據對應 "data" 字典鍵,標籤數據對應 "labels" 字典鍵,而單張圖片數據排布方式為一個一維列表:[...1024 red ... ...1024 green... ...1024 blue...],讀取的方式可以直接點擊官網網址。

為了使自己能夠更加熟悉數據集內部結構的解析,同時 CIFAR10 官網只告知了打開它們數據集的方法,我們需要如使用 MNIST 的情況一樣開始自己定義我們所需要的函數,不外乎數據讀取,數據格式轉換,one_hot 等大類,步驟如下:


1. Define functions without being iterated with class

定義的函數分別在如下陳列:

time_counter(): 是一個裝飾器,功能是用來計時一個函數啟動的時間

one_hot(): 用來把標籤轉換成 one hot 形式,方便後面神經網路歸類匹配使用

get_random_batch(): 隨機抽取樣本做為一個簇後,方便小批量訓練

Code[2]

importtime

importnumpyasnp

# To set a decorator used to count the time a func spent.

deftime_counter(func):

# In order to count many func"s time, arguments should be *args and **kwargs

defwrapper(*args, **kwargs):

t1 = time.time()

result = func(*args, **kwargs)

t2 = time.time() - t1

print("Took sec to run "" func".format(t2, func.__name__))

returnresult

returnwrapper

# To convert the number labels into one hot mode respectively.

defone_hot(labels, class_num=10):

convert = np.eye(class_num, dtype=float)[labels]

returnconvert

# To get a random batch so that we can easily put data to train a model.

defget_random_batch(data, batch_size=32):

random = np.random.randint(, len(data), size=batch_size)

returndata[random]


2. Define a class used to well organized take apart the dataset

由於數據是呈現 5 個批次儲存,其中的函數設定我希望把他們融合成一塊,後面處理和調用也表方便,並且其圖片大小為 32x32 的尺寸,並不至於大到沒辦法一次容納,因此設置的函數方法如下陳列:

load_binary_data(): 把二進位數據讀取出來,並依照字典鍵的要求給出一個 numpy 數組的結果,方便後面數據處理

merge_batches(): 把全部批次的訓練集數據全部融合起來成為一個大的數組

set_validation(): 設置一個驗證集在訓練集的比例,如果有不同的模型搭建可能會用到此功能

format_images(): 把一個 1D 向量表示的數據轉換成卷積方法需要用到的 4D 格式(Batch, Height, Width, Channels)

Code[3]

# pickle is the module to open cifar10 dataset

import pickle

import os, sys

# This class is used to refer the arranged content of CIFAR10 dataset

classCIFAR10:

# The unchangeable variables should be set here.

image_size =32

image_channels =3

def__init__(self, val_ratio=.1, data_dir="cifar-10-batches-py"):

# Validation set can also be set if it is necessary for other purposes

self.val_ratio = val_ratio

self.data_dir = data_dir

# Get the overall images data "without formatting"!

self.img_train =self.merge_batches("data")

self.img_train_main,self.img_train_val =self.set_validation(self.img_train)

self.lab_train =self.merge_batches("labels")

self.lab_train_main,self.lab_train_val =self.set_validation(self.lab_train)

self.img_test =self.load_binary_data("test_batch","data") /255.0

self.lab_test =self.load_binary_data("test_batch","labels").astype(np.int)

# The data format is binary mode and we should load them with pickle module

# which is introduced at the official web page.

defload_binary_data(self, file_name, dic_key):

path = os.path.join(self.data_dir, file_name)

with open(path,"rb") asfile:

dic = pickle.load(file, encoding="bytes")

# Those binary data are all contained by a dictionary also with

# binary type of dictionary key. The returned list should also be

# converted into np.array so that it can be indexed conveniently.

try:

dic_key = dic_key.encode(encoding="utf-8")

returnnp.array(dic[dic_key])

except:

print("dic_key argument accepts only 4 keys as follow:
",

"1.batch_label ; 2.labels ; 3.data ; 4.filenames")

# There are five separated images dataset and we will want to

# depose of them all at once.

defmerge_batches(self, dic_key):

merge = []

foriinrange(5):

filename ="data_batch_{}".format(i+1)

data =self.load_binary_data(filename, dic_key)

merge.append(data)

np_merge = np.array(merge)

ifdic_key =="data":

length =self.image_size *self.image_size *self.image_channels

np_merge = np_merge.reshape(5*len(data), length)

returnnp.array(np_merge) /255.0

else:

np_merge = np_merge.reshape(5*len(data))

returnnp.array(np_merge).astype(np.int)

defset_validation(self, data):

val_set = round(len(data) *self.val_ratio)

val_data = data[:val_set]

main_data = data[val_set:]

return[main_data, val_data]

# The 1D array representing an image should be converted to the format

# that is as same as the regular image format (H, W, C)

defformat_images(self, images_flat):

# The format of original data has (10000, 3072) shape matrix

# with conjoint red 1024, green 1024, blue 1024.

images = images_flat.reshape([-1,self.image_channels,

self.image_size,self.image_size])

# when depositing images, channels should stay at the last dimension.

images = images.transpose([,2,3,1])

returnimages

@property

defget_class_names(self):

path = os.path.join(self.data_dir,"batches.meta")

with open(path,"rb") asfile:

dic = pickle.load(file, encoding="bytes")

class_names = [w.decode("utf-8")forwindic[b"label_names"]]

fornum, labelinenumerate(class_names):

print("{}: {}".format(num, label))

returnclass_names

@property

defnum_per_batch(self):

path = os.path.join(self.data_dir,"batches.meta")

with open(path,"rb") asfile:

dic = pickle.load(file, encoding="bytes")

returndic[b"num_cases_per_batch"]

path = input("The directory of CIFAR10 dataset: ")

cifar = CIFAR1(data_dir=path)

cifar.get_class_names

print("Number per batch: {}".format(cifar.num_per_batch))

The directory of CIFAR10 dataset:/Users/kcl/Documents/Python_Projects/cifar-10-batches-py

: airplane

1: automobile

2: bird

3:cat

4: deer

5: dog

6: frog

7: horse

8: ship

9: truck

Number per batch:10000


3. Print Images and Labels respectively

為了驗證導入的數據集是否與標籤匹配,避免在模型訓練前數據集基礎就已經歪得一塌糊塗,結合了上面定義的 .format_images() 方法與 get_random_batch() 函數套入以下定義的繪圖函數中,隨機抽樣查看數據匹配的完整性,代碼如下:

Code[4]

importmatplotlib.pyplotasplt

images_flat_train = cifar.img_train

images_train = cifar.format_images(images_flat_train)

labels_train = cifar.lab_train

images_flat_test = cifar.img_test

images_test = cifar.format_images(images_flat_test)

labels_test = cifar.lab_test

# To define a universal purpose oriented plotting function here.

# It should not only be able to plot correct images, but also is

# capable of plotting the predicted labels.

defplot_images(images, labels, lab_names, size=[3,3],

pred_labels=False, random=True, smooth=True):

fig, axes = plt.subplots(size[], size[1])

fig.subplots_adjust(hspace=0.6, wspace=0.6)

forn, axinenumerate(axes.flat):

# To decide if the printed images should be smooth or not.

ifsmooth:

interpolation ="spline16"

else:

interpolation ="nearest"

# To decide if the images should be randomly picked up.

ifrandom:

i = np.random.randint(, len(labels), size=None, dtype=np.int)

else:

i = n

ax.imshow(images[i], interpolation=interpolation)

ifpred_labelsisFalse:

xlabel ="T: {}".format(lab_names[labels[i]])

else:

xlabel ="T:
P:".format(lab_names[labels[i]],

lab_names[pred_labels[i]])

ax.set_xlabel(xlabel)

ax.set_xticks([])

ax.set_yticks([])

plt.show()

plot_images(images_train, labels_train, cifar.get_class_names, size=[3,5])

: airplane

1: automobile

2: bird

3: cat

4: deer

5: dog

6: frog

7: horse

8: ship

9: truck


Data Preprocessing 數據預處理

如同概述部分提及的圖像預處理步驟,接下來要使用下面 Tensorflow 所提供的方法來實現圖像的隨機改動:

p.s. 還有很多 Tensorflow 框架支持的圖像處理方法,點擊此查看官網

對輸入數據使用上面函數方法做改動就如同給數據集加了幾個維度的數據,而豐富的數據集正是神經網路能夠達到更高歸類準確率的基本要素,同時還可減少過擬合的結果發生,換個角度思考這些產生的數據,它們就如同數據的雜訊,為過擬合可能發生的情況提供了一道保險。

不過當使用此方法在訓練的時候,產生數據的過程會添加計算的負擔,進而造成時間上的消耗,是我們應用此方法的時候一個重要的考慮要點。

結合上述方法定義的函數代碼如下:

Code[5]

def image_preprocessing(single_img, crop=[28,28], crop_only=False):

H, W = cifar.image_size, cifar.image_size

height, width = crop

single_img =tf.random_crop(single_img, size=[height, width,3])

single_img =tf.image.random_flip_left_right(single_img)

single_img =tf.image.random_flip_up_down(single_img)

single_img =tf.image.random_contrast(single_img, lower=0.5, upper=1.0)

single_img =tf.image.random_hue(single_img, max_delta=0.03)

single_img =tf.image.random_brightness(single_img, max_delta=0.2)

single_img =tf.image.random_saturation(single_img, lower=0.5, upper=1.5)

single_img =tf.minimum(single_img,1.0)

single_img =tf.maximum(single_img,0.0)

single_img =tf.image.resize_image_with_crop_or_pad(

single_img, target_height=H, target_width=W)

returnsingle_img

此函數的邏輯為下面陳列的幾點說明:

調整我們要隨機位置裁切的尺寸大小後

對裁切下來的圖像開始隨意顛倒,變化色調等等

把超出 RGB 三個單元最大值和最小值的部分抹平

把縮小尺寸的裁切團重新 padding 回到原本未裁切的大小,目的是使用數據流圖時測試機不需要預處理圖像就能夠測試,此一做法更為合理

上面定義的函數必須強調的是,它只處理 "單一張" 圖片,如果關聯到批量處理,例如我們習慣於把一整批圖像數據用 4D 張量的方式表示,格式分別為 (張數,圖高,圖寬,顏色階數),則可以使用 tf.map_fn 配合 lambda 的方式一次隨機處理整批圖像數據,並且每張圖像數據的調整係數本身都不盡相同,最後面即為詳細的搭配使用代碼與說明。


A glimpse to the Preprocessed Images

為了確定我們處理的數據完整性與效果,下面嘗試使用我們定義好的函數來隨機列印預處理圖片集的結果,步驟如下:

導入數據集,並使用定義的類方法呼叫訓練圖像

使用 Tensorflow 框架的構建方法,把導入的數據集放入我們預先定義好的函數中

啟動 tf 會話 .Session() 功能

sess.run() 了上個函數的運算結果後,才把這裡的運算結果放入繪圖函數中

等待時間約為一分半鐘,預處理好後即自行列印

Code[6]

importtensorflowastf

lab_train = cifar.lab_train

format_imgs = cifar.format_images(cifar.img_train)

# We can put every single element of a list into the argument which

# is belonging to tf.map_fn()"s fn by using lambda expression so it can

# iterate all elements to the preset function "image_preprocessing".

format_imgs = tf.map_fn(lambdaimg: image_preprocessing(img, crop=[24,24]), format_imgs)

sess = tf.Session()

format_imgs = sess.run(format_imgs)

plot_images(format_imgs, lab_train, cifar.get_class_names, size=[3,4])

: airplane

1: automobile

2: bird

3: cat

4: deer

5: dog

6: frog

7: horse

8: ship

9: truck


文章回顧

喜歡這篇文章嗎?立刻分享出去讓更多人知道吧!

本站內容充實豐富,博大精深,小編精選每日熱門資訊,隨時更新,點擊「搶先收到最新資訊」瀏覽吧!


請您繼續閱讀更多來自 AI研習社 的精彩文章:

支撐區塊鏈中的底層查詢系統
2018全球機器學習技術大會議程搶鮮看!

TAG:AI研習社 |