import tensorflow as tf
from tensorflow import keras
import cv2
import numpy as np
import glob
import os
import sys
import imutils
import random
import pymssql
import pyodbc
from PIL import Image, ImageTk
from tensorflow.keras import losses,layers,optimizers
from tensorflow.keras.callbacks import EarlyStopping
from PyQt5.QtWidgets import QApplication, QMainWindow, QDockWidget, QTextEdit, QLabel, QWidget, QPushButton, QVBoxLayout, QListWidget, QLineEdit, QStyleFactory
from PyQt5.QtCore import Qt, QTimer
from PyQt5 import QtCore, QtGui, QtWidgets, uic
from PyQt5.QtGui import QImage, QPixmap
from PyQt5.QtSql import QSqlDatabase , QSqlQuery
from pathlib import Path, PureWindowsPath
from pyecharts import options as opts
from pyecharts.charts import Scatter
from PyQt5.QtWebEngineWidgets import QWebEngineView, QWebEngineSettings
from PyQt5.QtCore import *
from PyQt5.QtGui import *
from PyQt5.QtWidgets import *
from PyQt5.QtWebEngineWidgets import QWebEngineView
# Use Qt's "Fusion" widget style application-wide.
QApplication.setStyle(QStyleFactory.create('Fusion'))
# Module-level state shared between the functions in this file (via `global`).
# Target folder path; trainingModel reads it as the training-data folder.
targetfilePath = ''
# NOTE(review): isPredicting, temp and count are not used in the visible code.
isPredicting = True
temp = ''
count = 0
# Keras model set by loadModel. The placeholder is the `object` class itself
# (not an instance), so it has no model methods until loadModel assigns it.
model = object
# NOTE(review): trainingModel assigns a *local* variable with this name, so
# this global is not updated by it in the visible code.
trainingfilePath = ''
# Network built and trained by trainingModel and saved by saveModel; same
# `object` placeholder as `model` until training runs.
mynet = object
# NOTE(review): the globals below are not used in the visible code.
Camera_status = False
saveImage = object
xdata = []
ydata = []
#--------------------------------------------------------------- Training module ---------------------------------------------------------------#
# Training setup: fix the TensorFlow and NumPy seeds for reproducibility.
# (Python's `random`, used by Data_Generation to shuffle, is not seeded here.)
tf.random.set_seed(2222)
np.random.seed(2222)
# Hide TensorFlow C++ INFO and WARNING log messages.
# NOTE(review): tensorflow is already imported by the time this runs, so
# messages emitted during import are unaffected and the level may already have
# been read -- consider setting this before `import tensorflow`.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# Require TensorFlow 2.x (tf.data / tf.keras APIs are used below).
# NOTE(review): `assert` is stripped when Python runs with -O.
assert tf.__version__.startswith('2.')
# Augmentation + preprocessing applied to every training sample.
def preprocessTraining(x, y):
    """Randomly augment one training image and one-hot encode its label.

    Intended for ``tf.data.Dataset.map``. The image is resized to 244x244,
    randomly flipped left/right, randomly cropped to 224x224x3 (so it must
    have three channels), divided by 255 and passed through ``normalize``.
    The integer label becomes a one-hot vector of length 10.
    """
    resized = tf.image.resize(x, [244, 244])
    flipped = tf.image.random_flip_left_right(resized)
    cropped = tf.image.random_crop(flipped, [224, 224, 3])
    image = normalize(tf.cast(cropped, dtype=tf.float32) / 255.)
    label = tf.one_hot(tf.convert_to_tensor(y), depth=10)
    return image, label
# Load training images, label them from their path and resize them.

# Class keywords searched for (as substrings of the image path) in this order;
# the index of the first keyword found is the integer label.
_CLASS_KEYWORDS = (
    'EIS板厚Sample',
    'EIS面辐Sample',
    'EIS水污Sample',
    'LIS点图Sample',
    'LIS水污Sample',
    'Particle_AST',
    'FeO_AST',
    'Bacteria_AST',
    'Other_AST',
)


def _label_for_path(path):
    """Return the label of the first class keyword found in *path*, else None."""
    for label, keyword in enumerate(_CLASS_KEYWORDS):
        if keyword in path:
            return label
    return None


def _collect_png_paths(filePath):
    """Return each PNG directly in *filePath* or in an immediate subfolder, once."""
    # glob.escape keeps '[', '*', '?' in folder names from acting as wildcards.
    paths = glob.glob(os.path.join(glob.escape(filePath), '*.png'))
    for entry in os.listdir(filePath):
        subdir = os.path.join(filePath, entry)
        if os.path.isdir(subdir):
            paths.extend(glob.glob(os.path.join(glob.escape(subdir), '*.png')))
    return paths


def Data_Generation(filePath):
    """Load labelled PNG images from *filePath* and split them 60/20/20.

    PNGs directly inside *filePath* and inside its immediate subfolders are
    collected (each file once), shuffled, labelled by the first matching
    keyword in ``_CLASS_KEYWORDS`` (labels 0-8), decoded as 3-channel RGB and
    resized to 224x224. Images whose path matches no class keyword, or that
    OpenCV cannot decode, are skipped so images and labels stay aligned.

    Returns:
        (X_train, Y_train, X_valid, Y_valid, X_test, Y_test, L): the X arrays
        are float arrays of shape (n, 224, 224, 3) holding 0-255 values, the
        Y arrays are uint8 labels, and L is the number of usable samples.
    """
    path_data = _collect_png_paths(filePath)
    random.shuffle(path_data)

    images = []
    labels = []
    for path in path_data:
        label = _label_for_path(path)
        if label is None:
            continue
        # Decoding from bytes (instead of cv2.imread) works around cv2.imread's
        # trouble with non-ASCII paths on Windows. IMREAD_COLOR always yields
        # 8-bit 3-channel BGR, so grayscale, RGBA and 16-bit PNGs are handled.
        img = cv2.imdecode(np.fromfile(path, dtype=np.uint8), cv2.IMREAD_COLOR)
        if img is None:
            continue
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        images.append(cv2.resize(img, (224, 224)))
        labels.append(label)

    L = len(images)
    X_data = np.array(images, dtype=float)
    Y_data = np.array(labels, dtype='uint8')
    train_end = int(L * 0.6)
    valid_end = int(L * 0.8)
    return (
        X_data[:train_end], Y_data[:train_end],
        X_data[train_end:valid_end], Y_data[train_end:valid_end],
        X_data[valid_end:], Y_data[valid_end:],
        L,
    )
# Build the datasets, then train and evaluate the classifier.

def _preprocessEval(x, y):
    """Scale, normalize and one-hot encode a sample with no random augmentation.

    Used for the validation and test splits so their metrics (and the early
    stopping driven by val_accuracy) are deterministic. It mirrors the
    inference-time ``preprocess`` step (no resize), and so assumes *x* is
    already 224x224x3 as produced by Data_Generation.
    """
    x = tf.cast(x, dtype=tf.float32) / 255.
    x = normalize(x)
    y = tf.one_hot(y, depth=10)
    return x, y


def trainingModel(self):
    """Train the DenseNet121-based classifier on the images in ``targetfilePath``.

    Builds train/valid/test tf.data pipelines from Data_Generation, stacks a
    dense head on a frozen ImageNet-pretrained DenseNet121, trains it with SGD
    and early stopping on validation accuracy, evaluates it on the test split,
    and stores the network in the global ``mynet``. Status messages are
    written to ``self.logText``.
    """
    global mynet
    # Local name only; the module-level `trainingfilePath` is not updated.
    trainingfilePath = targetfilePath
    if len(trainingfilePath) == 0:
        self.logText.setText('请先选择文件夹再开始训练!')
        return
    self.logText.setText('开始训练!')
    # Training runs synchronously in the GUI thread; let Qt repaint first so
    # the status message is actually shown before the event loop blocks.
    QApplication.processEvents()

    X_train, Y_train, X_valid, Y_valid, X_test, Y_test, _ = Data_Generation(trainingfilePath)
    batchsz = 32
    # Random flip/crop augmentation is applied to the training split only.
    train_db = (tf.data.Dataset.from_tensor_slices((X_train, Y_train))
                .shuffle(10000).map(preprocessTraining).batch(batchsz))
    valid_db = (tf.data.Dataset.from_tensor_slices((X_valid, Y_valid))
                .map(_preprocessEval).batch(batchsz))
    test_db = (tf.data.Dataset.from_tensor_slices((X_test, Y_test))
               .map(_preprocessEval).batch(batchsz))

    # Built-in DenseNet121 backbone with ImageNet weights, frozen so only the
    # dense head below is trained (a custom keras.Sequential model also works).
    net = keras.applications.DenseNet121(weights='imagenet', include_top=False, pooling='max')
    net.trainable = False
    # NOTE(review): 10 output logits, but Data_Generation only produces
    # labels 0-8 (9 classes); kept at 10 to match preprocessTraining's depth.
    mynet = keras.Sequential([
        net,
        layers.Dense(1024, activation='relu'),
        layers.Dense(512, activation='relu'),
        layers.Dense(256, activation='relu'),
        layers.Dense(128, activation='relu'),
        layers.BatchNormalization(),  # standardize the activations
        layers.Dropout(rate=0.2),
        layers.Dense(10)])
    mynet.build(input_shape=(None, 224, 224, 3))
    mynet.summary()

    # Stop once validation accuracy stops improving (guards against overfitting).
    early_stopping = EarlyStopping(monitor='val_accuracy', min_delta=0.01, patience=3)
    # `learning_rate` replaces the deprecated `lr` keyword, which newer Keras
    # versions ignore or reject.
    mynet.compile(optimizer=optimizers.SGD(learning_rate=1e-3),
                  loss=losses.CategoricalCrossentropy(from_logits=True),
                  metrics=['accuracy'])
    mynet.fit(train_db, validation_data=valid_db, validation_freq=1,
              epochs=50, callbacks=[early_stopping])
    mynet.evaluate(test_db)
    self.logText.setText('训练完成!')
# Save the trained model.
def saveModel(self):
    """Ask for a folder and save the trained network there as ``densenet.h5``.

    The .h5 file contains both the model architecture and the trained
    weights, so it can be loaded directly later for image classification.
    Status messages are written to ``self.logText``.
    """
    FolderName = QtWidgets.QFileDialog.getExistingDirectory()
    # An empty string means the dialog was cancelled.
    if len(FolderName) == 0:
        self.logText.setText('未保存模型文件!')
        return
    # `mynet` is the `object` placeholder until trainingModel assigns a model;
    # calling .save on it would raise inside the Qt slot.
    if not hasattr(mynet, 'save'):
        self.logText.setText('请先训练模型再保存!')
        return
    # normpath converts Qt's '/' separators to the native ones (backslashes on
    # Windows); os.path.join avoids the invalid '\d' escape in a hard-coded path.
    mynet.save(os.path.join(os.path.normpath(FolderName), 'densenet.h5'))
    self.logText.setText('模型文件已保存!')
#--------------------------------------------------------------- Recognition module ---------------------------------------------------------------#
# Per-channel image normalization.
def normalize(x):
    """Standardize RGB values with the usual ImageNet per-channel mean and std.

    The three-element constants broadcast over the last axis of *x*, so that
    axis must have length 3. Expects values already scaled to [0, 1] (both
    callers in this file divide by 255 first).
    """
    channel_mean = tf.constant([0.485, 0.456, 0.406])
    channel_std = tf.constant([0.229, 0.224, 0.225])
    return (x - channel_mean) / channel_std
# Prepare a single image for inference.
def preprocess(x):
    """Return *x* as a normalized float32 batch of one image.

    A leading batch axis is added, values are cast to float32 and divided by
    255, and the result is passed through ``normalize``.
    """
    return normalize(tf.cast(tf.expand_dims(x, axis=0), dtype=tf.float32) / 255.)
# Load a saved model.
def loadModel(self):
    """Ask for a saved Keras model file, load it, and store it in the global ``model``.

    The chosen path is shown in ``self.modelSelectText``. Status and error
    messages are written to ``self.logText``.
    """
    global model
    modelFile = QtWidgets.QFileDialog.getOpenFileName()[0]
    # An empty string means the dialog was cancelled.
    if len(modelFile) == 0:
        self.logText.setText('未加载模型文件!')
        return
    # normpath converts Qt's '/' separators to the native ones (backslashes on
    # Windows); the earlier str.replace discarded its result and did nothing.
    modelFile = os.path.normpath(modelFile)
    self.modelSelectText.setText(modelFile)
    try:
        network = keras.models.load_model(modelFile)
    except (OSError, ValueError) as err:
        # A missing or non-model file would otherwise raise inside the Qt slot.
        self.logText.setText('模型文件加载失败:' + str(err))
        return
    network.summary()
    self.logText.setText('检测模型加载成功:' + modelFile)
    model = network
#打开目标文件夹
def openFileFolder(self):
FolderName = QtWidgets.QFileDialog.getExistingDirectory()
if '/' in FolderName :
# 用\替换/,注意'\\'的用法,
FolderName.replace('/', '\\')
if len(FolderName) == 0 :