以下操作均在Ubuntu14.04+Anaconda中进行
导入 Python 标准库
In [ ]:
import os # 处理字符串路径
import glob # 用于查找文件
导入相关库
安装过程
conda update conda
conda update --all
conda install mingw libpython
pip install git+https://github.com/Theano/Theano.git
pip install git+https://github.com/fchollet/keras.git
cv2
OpenCV库
conda install opencv
numpy
In [ ]:
from keras.models import Sequential
from keras.layers.core import Flatten, Dense, Dropout
from keras.layers.convolutional import Convolution2D, MaxPooling2D, ZeroPadding2D
from keras.optimizers import SGD
import cv2, numpy as np
使用keras建立vgg16模型
In [ ]:
def VGG_16(weights_path=None):
    """Build the VGG16 network as a Keras ``Sequential`` model.

    The network has 13 3x3 convolution layers in five stages, each stage
    closed by a 2x2/stride-2 max-pool. These are followed by two 4096-unit
    ReLU dense layers (each with 50% dropout) and a 1000-way softmax output.
    The input is a channels-first (3, 224, 224) image.

    Args:
        weights_path: optional path to a weights file; when truthy it is
            passed to ``model.load_weights``.

    Returns:
        The constructed (and optionally weight-loaded) Sequential model.
    """
    # (number of filters, number of conv layers) for each of the five stages.
    conv_stages = [(64, 2), (128, 2), (256, 3), (512, 3), (512, 3)]

    net = Sequential()
    # Only the very first layer of the model declares the input shape.
    pad_kwargs = {'input_shape': (3, 224, 224)}
    for n_filters, n_convs in conv_stages:
        for _ in range(n_convs):
            # 1-pixel zero padding keeps the spatial size through each 3x3 conv.
            net.add(ZeroPadding2D((1, 1), **pad_kwargs))
            pad_kwargs = {}
            net.add(Convolution2D(n_filters, 3, 3, activation='relu'))
        net.add(MaxPooling2D((2, 2), strides=(2, 2)))

    net.add(Flatten())
    for units in (4096, 4096):
        net.add(Dense(units, activation='relu'))
        net.add(Dropout(0.5))
    net.add(Dense(1000, activation='softmax'))

    if weights_path:
        net.load_weights(weights_path)
    return net
载入预训练好的 VGG16 权重文件（vgg16_weights.h5）
Note:
In [ ]:
# Build VGG16 and load the pretrained weights stored in vgg16_weights.h5.
model = VGG_16(weights_path='vgg16_weights.h5')
In [ ]:
# Compile the model before use. The SGD hyper-parameters only matter for
# training; this notebook only calls model.predict() afterwards.
sgd = SGD(lr=0.1, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(optimizer=sgd, loss='categorical_crossentropy')
猫和狗对应的 ImageNet 类别索引（即模型 1000 类输出中的下标）
In [ ]:
# Indices into the model's 1000-way softmax output that are pooled together
# as "dog" (the dog-breed classes, every index from 151 to 268) and as "cat"
# (indices 281-287). The element order is kept as in the original notebook.
dogs = [251, 268, 256, 253, 255, 254, 257, 159, 211, 210,
        212, 214, 213, 216, 215, 219, 220, 221, 217, 218,
        207, 209, 206, 205, 208, 193, 202, 194, 191, 204,
        187, 203, 185, 192, 183, 199, 195, 181, 184, 201,
        186, 200, 182, 188, 189, 190, 197, 196, 198, 179,
        180, 177, 178, 175, 163, 174, 176, 160, 162, 161,
        164, 168, 173, 170, 169, 165, 166, 167, 172, 171,
        264, 263, 266, 265, 267, 262, 246, 242, 243, 248,
        247, 229, 233, 234, 228, 231, 232, 230, 227, 226,
        235, 225, 224, 223, 222, 236, 252, 237, 250, 249,
        241, 239, 238, 240, 244, 245, 259, 261, 260, 258,
        154, 153, 158, 152, 155, 151, 157, 156]
cats = [281, 282, 283, 284, 285, 286, 287]
获取待预测图片的文件列表
Note:
In [ ]:
# Glob pattern for every .jpg directly under imgs/test, joined with the OS
# path separator.
path = os.path.join('imgs', 'test', '*.jpg')
# Collect the lazily-produced matches into a list of file paths.
files = list(glob.iglob(path))
定义几个变量
In [ ]:
# Accumulates (file name, dog probability) pairs, one per scored image.
result = list()
In [ ]:
# Module-level placeholders, one assignment per line. ``temp`` later holds the
# most recently decoded image in the prediction loop; ``flbase`` and ``p``
# appear unused at module level (predict() uses locals of the same names).
flbase = 0
p = 0
temp = 0
定义图像加载函数
In [ ]:
def load_image(imageurl):
    """Read an image file and preprocess it for the VGG16 model.

    Args:
        imageurl: path of the image file to read.

    Returns:
        A float32 array of shape (1, 3, 224, 224): the image resized to
        224x224, with per-channel means subtracted, transposed to
        channels-first, and given a leading batch axis.

    Raises:
        ValueError: if ``cv2.imread`` cannot decode the file (it returns
            None instead of raising).
    """
    # Always decode the requested file itself instead of relying on a global
    # image, so the result depends only on ``imageurl``.
    img = cv2.imread(imageurl)
    if img is None:
        raise ValueError('could not read image: %s' % imageurl)
    im = cv2.resize(img, (224, 224)).astype(np.float32)
    # cv2.imread returns channels in B, G, R order, so the per-channel means
    # are subtracted in that order.
    im[:, :, 0] -= 103.939
    im[:, :, 1] -= 116.779
    im[:, :, 2] -= 123.68
    # HWC -> CHW to match the channels-first (3, 224, 224) model input.
    im = im.transpose((2, 0, 1))
    # Add the batch axis expected by model.predict.
    return np.expand_dims(im, axis=0)
定义预测函数
In [ ]:
def predict(url):
    """Score one image and append ``(file name, dog probability)`` to ``result``.

    The dog probability is the softmax mass on the ``dogs`` class indices,
    renormalised over the combined ``dogs`` + ``cats`` mass.

    Args:
        url: path of the image file to score.
    """
    scores = model.predict(load_image(url))[0]
    dog_mass = np.sum(scores[dogs])
    cat_mass = np.sum(scores[cats])
    result.append((os.path.basename(url), dog_mass / (dog_mass + cat_mass)))
开始预测
Note:
此处对读取结果是否为 None 的判断很重要：cv2.imread(fl) 遇到某些无法解码的图片时不会抛出异常，而是返回 None；如果不跳过这些图片，后续处理就会报错，导致程序中途停止，图片集无法全部检测。
在一般配置的电脑上运行这部分大约需要 20~30 分钟。这期间程序并没有卡住，请耐心等待。
In [ ]:
# Score every test image. cv2.imread returns None (rather than raising) for
# files it cannot decode; those are skipped so one bad image does not abort
# the whole run. ``temp`` stays a module-level name holding the decoded image.
for fl in files:
    temp = cv2.imread(fl)
    # ``is None`` is required: ``temp == None`` on a numpy array is an
    # elementwise comparison whose truth value is ambiguous.
    if temp is None:
        continue
    predict(fl)
按狗的概率从高到低对结果进行排序
In [ ]:
def _dog_probability(entry):
    """Sort key: the probability stored in a (file name, probability) pair."""
    return entry[1]


# Highest dog probability first; entries with equal probability keep their
# original order because sorting is stable.
result = sorted(result, key=_dog_probability, reverse=True)
打印预测结果与相应概率
In [ ]:
# Print "<file name> <dog probability>" for every scored image. A single
# pre-formatted string gives the same output as the Python 2 print statement
# and also runs on Python 3.
for name, prob in result:
    print('%s %s' % (name, prob))
预测结果
In [ ]:
# Print only the file names, in the current order of ``result``. The
# parenthesised single argument prints identically on Python 2 and 3.
for name, _ in result:
    print(name)
ps:完整的代码可以在github下载
https://github.com/KuHung/DateCastle/blob/master/catdog.ipynb
想即时收到更多信息？
关注DC官方公众号