from pathlib import Path
import pandas as pd
import os
from PIL import Image
import numpy as np
from matplotlib import pyplot as plt
import tensorflow as tf
from tensorflow.keras import layers
from sklearn.model_selection import train_test_split
# 数据读取&预处理
def load_data():
images = []
labels = []
df = pd.read_csv('./dataset/chinese_mnist.csv')
# 数据标注的形式是,csv里存的是文件名的后缀数字,映射到正确的答案
# 图片已经是64x64的了
folder_path = './dataset/data/data'
for image_name in os.listdir(folder_path):
image_path = os.path.join(folder_path, image_name)
image = Image.open(image_path).convert('RGB')
# 归一化
image_nd_array = np.array(image) / 255
images.append(image_nd_array.flatten())
# 先去掉扩展名,
image_ids = ''.join(image_name.split('.')[0]).split('_')[1:]
# suite_id,sample_id,code
target_df = df[
(df['suite_id'] == int(image_ids[0])) &
(df['sample_id'] == int(image_ids[1])) &
(df['code'] == int(image_ids[2]))
]
label = target_df['value'].values[0]
if label <= 10:
labels.append(label)
elif label == 100:
labels.append(11)
elif label == 1000:
labels.append(12)
elif label == 10000:
labels.append(13)
elif label == 100000000:
labels.append(14)
x_train, x_test, y_train, y_test = train_test_split(
np.array(images), np.array(labels), test_size=0.2, random_state=42)
return x_train, x_test, y_train, y_test
def train(images, labels):
print('images.shape \n', images.shape)
print('labels: \n', labels)
model = tf.keras.Sequential()
model.add(layers.Dense(32, activation='relu'))
model.add(layers.Dense(32, activation='relu'))
# 有十五个类别,包括0-10,百千万亿
model.add(layers.Dense(15, activation='softmax'))
# 设置优化器、损失函数、评估函数
# labels是 one_hat, 且类型是多分类,用稀疏交叉熵SparseCategoricalCrossentropy作为损失函数就可以了
model.compile(
optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy']
)
model.fit(images, labels, epochs=10)
model.save('./model/chinese_mnist.keras')
return model
def predict(model, x_test, y_test):
y_predicted = model.predict(x_test)
y_predicted_classes = np.argmax(y_predicted, axis=1)
accuracy = np.sum(y_predicted_classes == y_test) / len(y_test)
print(f"Accuracy: {accuracy * 100:.2f}%")
return y_predicted
x_train, x_test, y_train, y_test = load_data()
model = train(x_train, y_train)
prediction_res = predict(model, x_test, y_test)
没有合适的资源?快使用搜索试试~ 我知道了~
19 - Tensorflow'实现识别中文数字项目
共2000个文件
jpg:1997个
py:1个
xml:1个
1 下载量 143 浏览量
2024-02-17
08:19:34
上传
评论 1
收藏 21.16MB ZIP 举报
温馨提示
这个项目是一个使用TensorFlow和Keras构建的图像分类项目,旨在识别中文数字。以下是对项目的描述以及数据处理的概述: 1. **项目描述:** - 该项目旨在通过深度学习技术,使用TensorFlow和Keras框架,构建一个图像分类模型,专门用于识别中文数字。 - 图像分类是机器学习领域的一个常见任务,该项目通过训练神经网络,使其能够准确地识别手写的中文数字。 2. **数据读取与预处理:** - 数据集包含两部分:一个CSV文件(`chinese_mnist.csv`)和图像文件夹(`./dataset/data/data`)。 - CSV文件中存储了图像文件名的后缀数字与正确标签的映射关系。 - 图像数据被加载、转换为RGB格式,并进行归一化处理(将像素值除以255)。 - 图像的标签根据文件名映射到CSV文件中相应的suite_id、sample_id和code来获取。 3. **模型训练:** - 使用Keras的Sequential API搭建神经网络模型。 - 模型包括两个具有ReLU激活函数的全连接层,以及
资源推荐
资源详情
资源评论
收起资源包目录
19 - Tensorflow'实现识别中文数字项目 (2000个子文件)
input_17_5_1.jpg 941B
input_17_8_1.jpg 908B
input_28_4_1.jpg 903B
input_36_9_1.jpg 890B
input_28_9_1.jpg 887B
input_46_2_1.jpg 868B
input_36_2_1.jpg 848B
input_30_6_1.jpg 841B
input_28_7_1.jpg 839B
input_27_8_1.jpg 837B
input_39_5_1.jpg 826B
input_2_2_1.jpg 824B
input_34_3_5.jpg 819B
input_4_9_1.jpg 817B
input_27_9_1.jpg 815B
input_20_1_1.jpg 814B
input_30_8_1.jpg 813B
input_20_6_12.jpg 810B
input_96_2_1.jpg 803B
input_28_1_12.jpg 800B
input_20_3_12.jpg 798B
input_36_7_15.jpg 791B
input_20_4_12.jpg 789B
input_20_5_15.jpg 784B
input_37_5_1.jpg 782B
input_36_3_12.jpg 778B
input_50_8_1.jpg 778B
input_21_3_1.jpg 773B
input_50_6_1.jpg 773B
input_34_1_1.jpg 771B
input_40_8_1.jpg 767B
input_91_3_1.jpg 767B
input_22_1_12.jpg 766B
input_18_3_15.jpg 765B
input_35_3_1.jpg 764B
input_34_6_5.jpg 763B
input_39_2_1.jpg 762B
input_45_4_1.jpg 760B
input_6_5_1.jpg 760B
input_32_8_1.jpg 755B
input_28_1_6.jpg 755B
input_50_1_1.jpg 753B
input_71_8_1.jpg 750B
input_71_3_1.jpg 749B
input_60_8_1.jpg 748B
input_2_7_1.jpg 748B
input_46_2_14.jpg 746B
input_55_8_15.jpg 746B
input_20_2_14.jpg 745B
input_96_7_1.jpg 744B
input_9_8_1.jpg 742B
input_55_9_15.jpg 742B
input_19_2_1.jpg 741B
input_18_10_15.jpg 741B
input_40_3_1.jpg 740B
input_5_5_1.jpg 740B
input_36_4_5.jpg 738B
input_2_10_1.jpg 738B
input_22_5_5.jpg 738B
input_45_9_1.jpg 737B
input_18_2_5.jpg 734B
input_18_10_12.jpg 732B
input_26_6_1.jpg 731B
input_18_8_6.jpg 727B
input_34_7_6.jpg 727B
input_91_5_1.jpg 725B
input_46_10_15.jpg 724B
input_22_4_12.jpg 722B
input_40_10_1.jpg 722B
input_36_8_14.jpg 721B
input_19_5_15.jpg 720B
input_100_6_1.jpg 716B
input_28_2_10.jpg 715B
input_30_4_12.jpg 715B
input_55_4_15.jpg 715B
input_20_3_4.jpg 713B
input_53_10_1.jpg 713B
input_76_4_5.jpg 713B
input_22_1_15.jpg 712B
input_38_7_12.jpg 712B
input_92_8_1.jpg 710B
input_17_4_14.jpg 710B
input_19_6_5.jpg 709B
input_4_2_10.jpg 707B
input_4_3_5.jpg 707B
input_5_1_1.jpg 706B
input_20_9_6.jpg 705B
input_38_1_5.jpg 705B
input_1_10_1.jpg 705B
input_84_10_1.jpg 703B
input_76_7_5.jpg 702B
input_20_1_6.jpg 702B
input_85_8_15.jpg 702B
input_22_1_5.jpg 701B
input_30_3_5.jpg 701B
input_34_7_10.jpg 700B
input_20_9_5.jpg 700B
input_36_5_10.jpg 700B
input_50_7_6.jpg 699B
input_18_6_8.jpg 699B
共 2000 条
- 1
- 2
- 3
- 4
- 5
- 6
- 20
资源评论
小夕Coding
- 粉丝: 5858
- 资源: 461
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功