#!/usr/bin/env python
# -*- coding: utf-8 -*-
from Digit import Digit
from optparse import OptionParser
from variantKNN import notKNN
import pickle
import os
import glob # NOQA
# Passed as the third argument to notKNN.training() in trainTheModel.
# NOTE(review): its meaning is defined by variantKNN.notKNN — confirm there.
N = 100
# Passed as the third argument to notKNN.predict() in prediction(); presumably
# the number of neighbours (k) consulted — TODO confirm in variantKNN.
K = 1
def _initializeDigitsFromFile(filename):
    """Build one ``Digit`` per line of *filename*.

    :filename: Path of a text file; every line (blank ones included) is
        handed unchanged to the ``Digit`` constructor.
    :returns: List of ``Digit`` objects in file order.
    """
    with open(filename) as source:
        digits = [Digit(record) for record in source]
    return digits
def trainTheModel(trainFile, trainObj=None):
    """Train a brand-new ``notKNN`` model on the digits in *trainFile*.

    :trainFile: Path of the file holding the training digits, one per line.
    :trainObj: Accepted for interface compatibility but currently ignored —
        a fresh ``notKNN`` instance is always created and trained.
    :returns: The newly trained ``notKNN`` model.
    """
    digits = _initializeDigitsFromFile(trainFile)
    # Feature vectors first, then labels, in the same order as the file.
    features = [digit.getPointsWithNumberList() for digit in digits]
    labels = [digit.getTarget() for digit in digits]
    model = notKNN()
    model.training(features, labels, N)
    return model
def savetofile(predfile, points, predResults):
    """Write every digit, rendered with its predicted value, to *predfile*.

    :predfile: Destination path; truncated and written as UTF-8 text.
    :points: Sequence of digit objects exposing ``imagewithpredicition``;
        ``points[i]`` is paired with ``predResults[i]``.
    :predResults: The predicted values, one per written digit.
    :returns: None
    """
    with open(predfile, 'w', encoding='utf8') as saveFile:
        for index, predicted in enumerate(predResults):
            rendered = points[index].imagewithpredicition(predicted)
            saveFile.write(rendered)
def prediction(predFile, trainedModelObj):
    """Predict the digits stored in *predFile* with *trainedModelObj*.

    The target of the FIRST record decides the mode:

    * target == -1 (unlabelled file): the predictions are rendered into the
      digits and written back to *predFile* via ``savetofile``, which
      overwrites the input file.
    * any other target: predictions are compared with the targets and the
      total, the number of matches and the accuracy percentage are printed.

    An empty *predFile* prints ``predmatrix len=0`` and returns without
    calling the model.

    :predFile: Path of the file holding the digits, one record per line.
    :trainedModelObj: Trained model; ``predict(matrix, targets, K)`` is
        called on it, with ``targets=None`` for unlabelled files.
    :returns: None
    """
    predList = _initializeDigitsFromFile(predFile)
    predMatrix = [digit.getPointsWithNumberList() for digit in predList]
    targets = [digit.getTarget() for digit in predList]
    print("predmatrix len={0}".format(len(predMatrix)))
    if not targets:
        # Guard: targets[0] below would raise IndexError on an empty file.
        return
    if targets[0] == -1:
        res = trainedModelObj.predict(predMatrix, None, K)
        savetofile(predFile, predList, res)
        return
    # The file carries targets, so report the hit rate.
    res = trainedModelObj.predict(predMatrix, targets, K)
    count = sum(1 for predicted, expected in zip(res, targets)
                if predicted == expected)
    # len() instead of truthiness: predict's return type is not visible
    # here and may be an array type whose truth value is ambiguous.
    accurate = float(count) / len(res) * 100.0 if len(res) else 0.0
    print("Test set total = {0}".format(len(res)))
    print("Test match = {0}".format(count))
    print("Accuracy = {0}%".format(accurate))
def loadTrainedModel(filename):
    """Load a pickled trained model from *filename*.

    :filename: Path of the pickle file (as written by ``saveTrainedModel``).
    :returns: The unpickled object, or None when *filename* is not an
        existing regular file.
    """
    # os.path.isfile is already False for missing paths, so a separate
    # os.path.exists check is redundant.
    if not os.path.isfile(filename):
        return None
    # SECURITY: pickle.load can execute arbitrary code while unpickling;
    # only load model files this program wrote, never untrusted ones.
    with open(filename, 'rb') as dbfile:
        return pickle.load(dbfile)
def saveTrainedModel(filename, trainedObj):
    """Pickle *trainedObj* into *filename*, replacing any previous content.

    :filename: Destination path, opened in binary write mode.
    :trainedObj: Any picklable object, normally the trained model.
    :returns: None
    """
    with open(filename, mode='wb') as handle:
        pickle.dump(trainedObj, handle)
def walkTrainFileInDir(folder, outputfilename):
    """Merge every ``*.tra`` file in *folder* into one labelled file.

    Each kept line (stripped length > 1) is written to *outputfilename* as
    ``"<line>, <target>"``. The target is ``count % 10`` for a per-file
    counter starting at 1, so kept lines are labelled 1, 2, ..., 9, 0, 1, ...
    NOTE(review): this presumes each .tra file lists the digits in that
    repeating order — confirm against the data.

    :folder: Directory to scan; a trailing path separator is optional.
    :outputfilename: Path of the merged output file (overwritten).
    :returns: None
    """
    # os.path.join works with or without a trailing separator on *folder*;
    # sorting makes the merge order independent of directory listing order.
    filearray = sorted(glob.glob(os.path.join(folder, "*.tra")))
    print(filearray)
    output = []
    for fname in filearray:
        count = 1
        with open(fname) as trainFile:
            for line in trainFile:
                line = line.strip('\n')
                if len(line.strip()) <= 1:
                    continue
                trainTarget = count % 10
                output.append(line + ", " + str(trainTarget) + '\n')
                # Advance per kept line; previously count never changed, so
                # every line was labelled 1.
                count += 1
    with open(outputfilename, 'w') as outputfile:
        outputfile.writelines(output)
def main():
    """Parse the command line and run the requested steps.

    With -w/--walkfolder, only the ``*.tra`` files are merged into
    ``output``. Otherwise the model saved in ``TrainObj-pickle`` is loaded
    (if present), replaced by a freshly trained one when -t/--trainFile is
    given, used to predict -p/--predictionFile, and finally saved back
    unless -n/--nosave is set. Unexpected errors print a hint and re-raise.
    """
    try:
        # Parse the input options
        opt = OptionParser()
        opt.add_option('-t', '--trainFile',
                       dest='trainFile',
                       type=str,
                       help='the file stored the pendigits train data')
        opt.add_option('-p', '--predictionFile',
                       dest='predictionFile',
                       type=str,
                       help='the file stored the pendigits test data')
        opt.add_option('-n', '--nosave',
                       dest='nosave',
                       action="store_true",
                       default=False,
                       # Implicit concatenation: a backslash continuation
                       # leaks the next line's indentation into the help.
                       help='if set --nosave, the program will not save the '
                            'model which has been trained.')
        opt.add_option('-w', '--walkfolder',
                       dest='walkfolder',
                       type=str,
                       help='the walk folder which locate the *.tra files')
        options, _ = opt.parse_args()
        if options.walkfolder:
            walkTrainFileInDir(options.walkfolder, "output")
            return
        saveTrainedFile = "TrainObj-pickle"
        # Start from the previously saved model (None if there is none);
        # -t replaces it with a freshly trained one.
        trainedModelObj = loadTrainedModel(saveTrainedFile)
        if options.trainFile:
            trainedModelObj = trainTheModel(options.trainFile)
        if options.predictionFile:
            if trainedModelObj is None:
                # OptionParser.error exits via SystemExit, which the
                # ``except Exception`` below does not intercept.
                opt.error("no trained model in {0}; train one with "
                          "-t/--trainFile first".format(saveTrainedFile))
            prediction(options.predictionFile, trainedModelObj)
        # Never pickle None into the save file when nothing was loaded or
        # trained.
        if not options.nosave and trainedModelObj is not None:
            saveTrainedModel(saveTrainedFile, trainedModelObj)
    except Exception:
        print("Check if TrainObj-pickle exsit, if none, train again")
        raise
# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()
tt.zip_UJUN_python/digital/KNN_rockz2j
版权申诉
75 浏览量
2022-09-21
22:13:20
上传
评论
收藏 98KB ZIP 举报
邓凌佳
- 粉丝: 65
- 资源: 1万+