import os
import cv2
import sys
import re
import time
import numpy as np
from tqdm import tqdm
from datetime import datetime
from tabulate import tabulate
import torch
import torch.nn as nn
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data import Dataset
import pandas as pd
from torch.utils.tensorboard import SummaryWriter
from A6_submission import Classifier, Params
class NotMNIST_RGB(Dataset):
    """In-memory dataset of RGB notMNIST images read from an ``.npz`` archive.

    The archive must contain an ``images`` array and a ``labels`` array. At
    construction time the images are converted to ``float32``, divided by 255
    and transposed from channels-last to channels-first; the labels are
    converted to ``int64``.

    :param fname: path of the ``.npz`` archive to load
    """

    def __init__(self, fname='train_data_rgb.npz'):
        # NOTE(review): allow_pickle=True lets numpy unpickle object arrays, which
        # can execute arbitrary code - only load archives from a trusted source.
        # The context manager closes the archive's file handle as soon as both
        # arrays have been read into memory.
        with np.load(fname, allow_pickle=True) as train_data:
            # images is expected to be a 4D array of size n_images x 28 x 28 x 3
            images = train_data['images']  # type: np.ndarray
            labels = train_data['labels']  # type: np.ndarray

        # pixel values are converted to floating point and scaled by 1/255 so
        # that 8-bit intensities end up between 0 and 1, as CNNs expect
        images = images.astype(np.float32) / 255.0

        # switch from n_images x 28 x 28 x 3 to n_images x 3 x 28 x 28 since
        # CNNs expect the channels to come before the spatial dimensions
        self._train_images = np.transpose(images, (0, 3, 1, 2))
        self._train_labels = labels.astype(np.int64)
        self._n_train = self._train_images.shape[0]

    def __len__(self):
        """Return the number of images in the dataset."""
        return self._n_train

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair stored at position ``idx``.

        :raises IndexError: if ``idx`` is not smaller than the dataset size;
            IndexError (rather than a failed assert) is what lets plain
            iteration over the dataset terminate, and it is not stripped
            when Python runs with ``-O``
        """
        if idx >= self._n_train:
            raise IndexError(
                "Invalid idx: {} for n_train: {}".format(idx, self._n_train))
        return self._train_images[idx, ...], self._train_labels[idx]
def train(classifier, dataloader, criterion, optimizer, device):
    """Train ``classifier`` for one pass over ``dataloader``.

    :param classifier: model to optimise; called on each input batch and
        switched to training mode first
    :param dataloader: iterable yielding ``(inputs, targets)`` batches
    :param criterion: loss function called as ``criterion(outputs, targets)``
    :param optimizer: optimizer whose gradients are reset and stepped once
        per batch
    :param device: device that every batch is moved to before the forward pass
    :return: ``(mean_loss, train_acc)`` - the per-batch loss averaged over the
        number of batches, and the accuracy in percent over all samples seen
    :raises ZeroDivisionError: if ``dataloader`` yields no batches
    """
    total_loss = 0
    train_total = 0
    train_correct = 0
    n_batches = 0

    # set CNN to training mode
    classifier.train()

    # tqdm wraps the dataloader itself rather than enumerate(dataloader):
    # enumerate() has no length, which would leave the progress bar without a
    # total or an ETA
    for inputs, targets in tqdm(dataloader):
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        outputs = classifier(inputs)
        loss = criterion(outputs, targets)
        total_loss += loss.item()

        loss.backward()
        optimizer.step()

        # predicted class = index of the largest output along dimension 1
        _, predicted = outputs.max(1)
        train_total += targets.size(0)
        train_correct += predicted.eq(targets).sum().item()
        n_batches += 1

    mean_loss = total_loss / n_batches
    train_acc = 100. * train_correct / train_total
    return mean_loss, train_acc
def _render_eval_batch(inputs_np, targets, predicted, is_correct):
    """Build the visualization image for one evaluated batch.

    Every sample becomes a column of three vertically stacked tiles: the input
    image, the ground-truth letter drawn in white and the predicted letter
    drawn in green when correct and in red (BGR order) otherwise. Class index
    0 is rendered as 'A', 1 as 'B' and so on. The columns of all samples are
    concatenated side by side.

    :param inputs_np: numpy array holding the batch, channels-first per sample
    :param targets: per-sample ground-truth class indices
    :param predicted: per-sample predicted class indices
    :param is_correct: per-sample flags telling whether the prediction matched
    :return: the combined image with channels on the last axis
    """
    columns = []
    # zip over the samples that are actually present: the last batch of an
    # epoch can hold fewer than dataloader.batch_size samples
    for sample, target, output, correct in zip(inputs_np, targets, predicted, is_correct):
        # switch image from 3 x 28 x 28 to 28 x 28 x 3 since opencv expects
        # channels to be on the last axis; the copy is needed to resolve an
        # opencv issue with memory layout:
        # https://stackoverflow.com/questions/23830618/python-opencv-typeerror-layout-of-the-output-array-incompatible-with-cvmat
        input_img = np.transpose(sample.squeeze(), (1, 2, 0)).copy()

        col = (0, 1, 0) if correct else (0, 0, 1)

        pred_img = np.zeros_like(input_img)
        _text = '{}'.format(chr(65 + int(output)))
        cv2.putText(pred_img, _text, (8, 20), cv2.FONT_HERSHEY_COMPLEX_SMALL,
                    1, col, 1, cv2.LINE_AA)

        label_img = np.zeros_like(input_img)
        _text = '{}'.format(chr(65 + int(target)))
        cv2.putText(label_img, _text, (8, 20), cv2.FONT_HERSHEY_COMPLEX_SMALL,
                    1, (1, 1, 1), 1, cv2.LINE_AA)

        columns.append(np.concatenate((input_img, label_img, pred_img), axis=0))

    return np.concatenate(columns, axis=1)


def evaluate(classifier, dataloader, criterion, vis, writer, iteration, device):
    """Evaluate ``classifier`` on ``dataloader`` with gradients disabled.

    :param classifier: model to evaluate; switched to evaluation mode first
    :param dataloader: iterable yielding ``(inputs, targets)`` batches
    :param criterion: loss function called as ``criterion(outputs, targets)``
    :param vis: visualization level - a falsy value disables it, any truthy
        value renders each batch (and logs it to ``writer`` when given), and
        the value 2 additionally shows each batch in an OpenCV window.
        In the window, Esc exits the process, ``q`` closes the window and
        stops the evaluation early (the returned metrics then cover only the
        batches processed so far), Space toggles between pausing on every
        batch and playing continuously, any other key advances
    :param writer: object with an ``add_image(tag, image, step)`` method
        (presumably a tensorboard SummaryWriter - confirm against the caller)
        or None to skip logging
    :param iteration: step value passed to ``writer.add_image``
    :param device: device that every batch is moved to before the forward pass
    :return: ``(mean_loss, acc, total_images, test_speed)`` - per-batch loss
        averaged over the batches, accuracy in percent, number of samples
        evaluated and samples per second of forward-pass time (``inf`` when
        the measured time is zero)
    :raises ZeroDivisionError: if ``dataloader`` yields no batches
    """
    # one name for both imshow and destroyWindow so that the window which gets
    # closed is the one that was actually opened
    win_name = 'Press Esc to exit, Space to resume, any other key for next batch'

    total_loss = 0
    total_images = 0
    correct_images = 0
    n_batches = 0
    total_test_time = 0
    # 1: wait for a key press after every batch, 0: advance automatically
    _pause = 1

    # set CNN to evaluation mode
    classifier.eval()

    # disable gradients computation
    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            # only the forward pass is timed; perf_counter is a monotonic,
            # high-resolution clock meant for measuring short intervals
            # NOTE(review): CUDA kernels run asynchronously, so on GPU this
            # is likely only an approximation - confirm if the speed matters
            start_t = time.perf_counter()
            outputs = classifier(inputs)
            total_test_time += time.perf_counter() - start_t

            loss = criterion(outputs, targets)
            total_loss += loss.item()

            _, predicted = outputs.max(1)
            total_images += targets.size(0)
            is_correct = predicted.eq(targets)
            correct_images += is_correct.sum().item()
            n_batches += 1

            if not vis:
                continue

            vis_img = _render_eval_batch(
                inputs.detach().cpu().numpy(),
                targets.tolist(),
                predicted.tolist(),
                is_correct.tolist(),
            )

            if writer is not None:
                vis_img_uint8 = (vis_img * 255.0).astype(np.uint8)
                vis_img_tb = cv2.cvtColor(vis_img_uint8, cv2.COLOR_BGR2RGB)
                # tensorboard expects channels in the first axis
                vis_img_tb = np.transpose(vis_img_tb, axes=[2, 0, 1])
                writer.add_image('evaluation', vis_img_tb, iteration)

            if vis == 2:
                cv2.imshow(win_name, vis_img)
                # waitKey(0) blocks until a key is pressed, waitKey(1) returns
                # after about a millisecond
                k = cv2.waitKey(1 - _pause)
                if k == 27:
                    sys.exit(0)
                elif k == ord('q'):
                    vis = 0
                    cv2.destroyWindow(win_name)
                    break
                elif k == 32:
                    _pause = 1 - _pause

    mean_loss = total_loss / n_batches
    acc = 100. * float(correct_images) / float(total_images)
    # a coarse clock can report zero elapsed time for very fast models
    if total_test_time > 0:
        test_speed = float(total_images) / total_test_time
    else:
        test_speed = float('inf')

    if vis == 2:
        cv2.destroyWindow(win_name)

    return mean_loss, acc, total_images, test_speed
def main():
params = Params()
# optional command line argument parsing
try:
import paramparse
except ImportError:
pass
else:
paramparse.process(params)
# init device
if params.use_gpu and torch.cuda.is_available():
device = torch.device("cuda")
print('Running on GPU: {}'.format(torch.cuda.get_device_name(0)))
else:
device = torch.device("cpu")
print('Running on CPU')
# load dataset
train_set = NotMNIST_RGB()
num_train = len(train_set)
indices = list(range(num_train))
assert params.val_ratio > 0, "Zero validation ratio is not allowed "
split = int(np.floor((1.0 - params.val_ratio) * num_train))
train_idx, val_idx = indices[:split], indices[split:]
n_train = len(train_idx)
n_val = len(val_idx)
print('Training samples: {}\n'
'Validation samples: {}\n'
''.format(
n_train, n_val,
))
train_sampler = SubsetRandomSampler(train_idx)
val_sampler = SubsetRandomSampler(val_idx)
train_dataloader = torch.utils.data.DataLoader(train_set, batch_size=params.train_batch_size,
sampler=train_sampler, num_workers=params.n_workers)
val_dataloader = torch.utils.data.DataLoader(train_set, batch_size=params.val_batch_size,
没有合适的资源?快使用搜索试试~ 我知道了~
资源推荐
资源详情
资源评论
收起资源包目录
手写CNN.zip (263个子文件)
model.pt.47 87.9MB
.gitignore 190B
Assignment 6 Files-20230401.iml 497B
A6_run.ipynb 12.27MB
train_data_rgb.npz 6.78MB
train_data.npz 3.68MB
batch_148.png 143KB
batch_183.png 143KB
batch_76.png 139KB
batch_119.png 138KB
batch_188.png 138KB
batch_174.png 138KB
batch_145.png 138KB
batch_109.png 137KB
batch_157.png 137KB
batch_205.png 137KB
batch_67.png 136KB
batch_115.png 136KB
batch_13.png 136KB
batch_38.png 136KB
batch_88.png 135KB
batch_5.png 135KB
batch_158.png 135KB
batch_164.png 135KB
batch_113.png 135KB
batch_131.png 135KB
batch_134.png 135KB
batch_51.png 134KB
batch_163.png 134KB
batch_170.png 134KB
batch_43.png 134KB
batch_168.png 134KB
batch_177.png 134KB
batch_199.png 134KB
batch_48.png 133KB
batch_65.png 133KB
batch_114.png 133KB
batch_46.png 133KB
batch_10.png 133KB
batch_151.png 133KB
batch_16.png 133KB
batch_91.png 133KB
batch_33.png 133KB
batch_27.png 133KB
batch_132.png 133KB
batch_130.png 133KB
batch_45.png 133KB
batch_110.png 132KB
batch_8.png 132KB
batch_107.png 132KB
batch_150.png 132KB
batch_176.png 132KB
batch_128.png 132KB
batch_159.png 132KB
batch_71.png 132KB
batch_229.png 132KB
batch_226.png 132KB
batch_192.png 132KB
batch_87.png 132KB
batch_29.png 132KB
batch_223.png 132KB
batch_118.png 131KB
batch_56.png 131KB
batch_173.png 131KB
batch_12.png 131KB
batch_171.png 131KB
batch_60.png 131KB
batch_100.png 131KB
batch_175.png 131KB
batch_155.png 131KB
batch_135.png 131KB
batch_189.png 131KB
batch_181.png 131KB
batch_136.png 131KB
batch_218.png 131KB
batch_37.png 131KB
batch_216.png 131KB
batch_139.png 131KB
batch_64.png 131KB
batch_17.png 131KB
batch_228.png 131KB
batch_180.png 131KB
batch_120.png 131KB
batch_41.png 131KB
batch_99.png 131KB
batch_20.png 131KB
batch_126.png 131KB
batch_239.png 131KB
batch_233.png 131KB
batch_30.png 131KB
batch_220.png 131KB
batch_36.png 131KB
batch_31.png 131KB
batch_201.png 131KB
batch_6.png 131KB
batch_85.png 131KB
batch_129.png 130KB
batch_152.png 130KB
batch_83.png 130KB
batch_73.png 130KB
共 263 条
- 1
- 2
- 3
资源评论
sjx_alo
- 粉丝: 1w+
- 资源: 1206
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功