import numpy as np
import time
import torch
from MTCNN_last_version import Tools
from MTCNN_last_version.Tools import convert_to_squre
from MTCNN_last_version.Module import RNet, PNet
from MTCNN_last_version.Train_Onet_Landmark import ONet # Landmark
from PIL import Image
from PIL import ImageDraw
from torchvision import transforms
class Detector:
    """Three-stage MTCNN face detector: P-Net -> R-Net -> O-Net (with landmarks).

    Args:
        pp, rp, op: paths to the P-Net, R-Net and O-Net ``state_dict`` files.
        isCuda: move the networks and every input batch to the GPU.
        p_cls_thresh, r_cls_thresh, o_cls_thresh: minimum confidence a
            candidate needs to survive the P-, R- and O-Net stage respectively.
        p_nms_thresh, r_nms_thresh, o_nms_thresh: threshold passed to the
            ``Tools`` NMS helper after each stage.
        scale_factor: ratio between consecutive P-Net image-pyramid levels;
            must lie strictly between 0 and 1.

    Raises:
        ValueError: if ``scale_factor`` is not in (0, 1). Otherwise the
            pyramid would never shrink below 12 px and the loop would not end.
    """

    def __init__(self, pp=r'./params/Pnet2048.pth', rp=r'./params/Rnet.pth',
                 op=r'./params/Onet_Landmark.pth', isCuda=True,
                 p_cls_thresh=0.3, p_nms_thresh=0.3,
                 r_cls_thresh=0.5, r_nms_thresh=0.3,
                 o_cls_thresh=0.001, o_nms_thresh=0.3,
                 scale_factor=0.68):
        if not 0 < scale_factor < 1:
            raise ValueError('scale_factor must be in (0, 1), got {}'.format(scale_factor))
        self.isCuda = isCuda
        self.p_cls_thresh = p_cls_thresh
        self.p_nms_thresh = p_nms_thresh
        self.r_cls_thresh = r_cls_thresh
        self.r_nms_thresh = r_nms_thresh
        # NOTE(review): 0.001 keeps almost every O-Net candidate; it looks like
        # a debugging value. It is kept as the default for compatibility.
        self.o_cls_thresh = o_cls_thresh
        self.o_nms_thresh = o_nms_thresh
        self.scale_factor = scale_factor

        self.pnet = PNet()
        self.rnet = RNet()
        self.onet = ONet()
        for net, path in ((self.pnet, pp), (self.rnet, rp), (self.onet, op)):
            if isCuda:
                net.cuda()
            net.eval()
            # Weights are loaded on the CPU. load_state_dict copies them into
            # the module's existing (possibly CUDA) parameters.
            net.load_state_dict(torch.load(path, map_location='cpu'))

        self.__transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    def detect(self, image):
        """Run the full cascade on a PIL image.

        Prints each stage's box-array shape and the per-stage timings.

        Returns:
            np.ndarray: the O-Net result after NMS. Before NMS each row is
            ``[x1, y1, x2, y2, score, lx1, ly1, ..., lx5, ly5]`` in
            original-image pixels. Presumably ``Tools.nms`` keeps that row
            layout (not visible here). An empty 1-D array is returned as soon
            as any stage yields no boxes.
        """
        empty = np.array([])
        # Inference only. Without no_grad every forward pass would record an
        # autograd graph when the caller has not disabled gradients itself.
        with torch.no_grad():
            t_start = time.perf_counter()
            p_boxes = self.__pnet_detect(image)
            print('Pbox', p_boxes.shape)
            if p_boxes.shape[0] == 0:
                return empty
            t_after_p = time.perf_counter()

            r_boxes = self.__rnet_detect(image, p_boxes)
            print('Rbox', r_boxes.shape)
            if r_boxes.shape[0] == 0:
                return empty
            t_after_r = time.perf_counter()

            o_boxes = self.__onet_detect(image, r_boxes)
            print('Obox', o_boxes.shape)
            if o_boxes.shape[0] == 0:
                return empty
            t_after_o = time.perf_counter()

        t_p = t_after_p - t_start
        t_r = t_after_r - t_after_p
        t_o = t_after_o - t_after_r
        t_sum = t_p + t_r + t_o
        print('t_sum :{} t_p: {} t_r: {} t_o: {}'.format(t_sum, t_p, t_r, t_o))
        return o_boxes

    def __pnet_detect(self, image):
        """Run P-Net over an image pyramid and NMS-merge the candidates.

        Returns an empty 1-D array when no level produces a candidate.
        """
        w, h = image.size
        level = image
        scale = 1.0
        min_side = min(w, h)
        candidates = []
        while min_side > 12:
            data = self.__transform(level).unsqueeze(0)
            if self.isCuda:
                data = data.cuda()
            cls, offset = self.pnet(data)
            candidates.append(self._pnet_candidates(
                cls[0][0].detach().cpu().numpy(),
                offset[0].detach().cpu().numpy(),
                scale, self.p_cls_thresh))
            scale *= self.scale_factor
            _w, _h = int(w * scale), int(h * scale)
            # Resize from the original so interpolation blur does not
            # accumulate across pyramid levels.
            level = image.resize((_w, _h))
            min_side = min(_w, _h)

        boxes = np.concatenate(candidates) if candidates else np.empty((0, 5))
        if boxes.shape[0] == 0:
            return np.array([])
        return Tools.nms2(boxes, thresh=self.p_nms_thresh)

    @staticmethod
    def _pnet_candidates(cls_map, offset_map, scale, thresh, stride=2, side_len=12):
        """Decode one P-Net pyramid level into candidate boxes.

        Args:
            cls_map: ``(H, W)`` confidence map for this level.
            offset_map: ``(4, H, W)`` box-regression offsets for this level.
            scale: factor the input image was resized by for this level.
            thresh: cells with confidence strictly above this are kept.
            stride, side_len: stride and window side length used to map an
                output cell back to input pixels.

        Returns:
            ``(K, 5)`` float array of ``[x1, y1, x2, y2, score]`` in
            original-image pixels (``K`` may be 0).
        """
        ys, xs = np.nonzero(cls_map > thresh)  # row = height/y, col = width/x
        x1 = xs * stride / scale
        y1 = ys * stride / scale
        x2 = (xs * stride + side_len) / scale
        y2 = (ys * stride + side_len) / scale
        w = x2 - x1
        h = y2 - y1
        off = offset_map[:, ys, xs]
        return np.stack([x1 + w * off[0], y1 + h * off[1],
                         x2 + w * off[2], y2 + h * off[3],
                         cls_map[ys, xs]], axis=1)

    @staticmethod
    def _refine(squares, cls, offset, thresh):
        """Apply R-Net/O-Net box regression to the squared candidates.

        Args:
            squares: candidate boxes; columns 0-3 are ``x1, y1, x2, y2``.
            cls: ``(N, 1)`` confidences, one per square.
            offset: ``(N, 4)`` offsets relative to each square's width/height.
            thresh: rows with confidence strictly above this are kept.

        Returns:
            ``(keep, sq, boxes)``: the indices of kept rows, their
            integer-truncated squares ``(K, 4)``, and the refined ``(K, 5)``
            boxes ``[x1, y1, x2, y2, score]``.
        """
        keep = np.nonzero(cls[:, 0] > thresh)[0]
        # astype truncates toward zero, the same as int() on each coordinate.
        sq = np.asarray(squares)[keep, :4].astype(np.int64)
        x1, y1, x2, y2 = sq.T
        w = x2 - x1
        h = y2 - y1
        off = offset[keep]
        boxes = np.stack([x1 + w * off[:, 0], y1 + h * off[:, 1],
                          x2 + w * off[:, 2], y2 + h * off[:, 3],
                          cls[keep, 0]], axis=1)
        return keep, sq, boxes

    @staticmethod
    def _landmarks(sq, landmark):
        """Map O-Net landmark outputs to image coordinates.

        Args:
            sq: ``(K, 4)`` integer squares ``x1, y1, x2, y2`` the crops came from.
            landmark: ``(K, 10)`` outputs interleaved as ``x, y`` pairs,
                relative to each square's top-left corner and size.

        Returns:
            ``(K, 10)`` float array ``lx1, ly1, ..., lx5, ly5``.
        """
        x1 = sq[:, 0:1]
        y1 = sq[:, 1:2]
        w = sq[:, 2:3] - x1
        h = sq[:, 3:4] - y1
        out = np.empty(landmark.shape, dtype=np.float64)
        out[:, 0::2] = x1 + w * landmark[:, 0::2]
        out[:, 1::2] = y1 + h * landmark[:, 1::2]
        return out

    def __crop_batch(self, image, squares, size):
        """Crop each square, resize it to ``size`` x ``size`` and stack a normalized batch."""
        crops = []
        for box in squares:
            x1, y1, x2, y2 = (int(v) for v in box[:4])
            crops.append(self.__transform(image.crop((x1, y1, x2, y2)).resize((size, size))))
        batch = torch.stack(crops)
        return batch.cuda() if self.isCuda else batch

    def __rnet_detect(self, image, pnet_boxes):
        """Re-score and regress the P-Net candidates with R-Net, then apply NMS."""
        squares = convert_to_squre(pnet_boxes)
        cls, offset = self.rnet(self.__crop_batch(image, squares, 24))
        keep, _, boxes = self._refine(squares, cls.detach().cpu().numpy(),
                                      offset.detach().cpu().numpy(), self.r_cls_thresh)
        if keep.size == 0:
            return np.array([])
        return Tools.nms(boxes, self.r_nms_thresh)

    def __onet_detect(self, image, rnet_boxes):
        """Run O-Net for final boxes plus five landmarks, then apply min-area NMS."""
        squares = convert_to_squre(rnet_boxes)
        cls, offset, landmark = self.onet(self.__crop_batch(image, squares, 48))
        keep, sq, boxes = self._refine(squares, cls.detach().cpu().numpy(),
                                       offset.detach().cpu().numpy(), self.o_cls_thresh)
        if keep.size == 0:
            return np.array([])
        points = self._landmarks(sq, landmark.detach().cpu().numpy()[keep])
        return Tools.nms(np.hstack([boxes, points]), self.o_nms_thresh, isMin=True)
if __name__ == '__main__':
t01 = time.time()
with torch.no_grad() as grad:
image_file = r'C:\Projects\MTCNN_last_version\tets_img\3.jpg'
detect = Detector()
with Image.open(image_file) as img:
boxes = detect.detect(img)
print(boxes.shape)
imgDraw = ImageDraw.Draw(img)
for box in boxes:
x1 = int(box[0])
y1 = int(box[1])
x2 = int(box[2])
y2 = int(box[3])
'''5 - 14'''
fllx1 = int(box[5])
flly1 = int(box[6])
fllx2 = int(box[7])
flly2 = int(box[8])
fllx3 = int(box[9])
flly3 = int(box[10])
fllx4 = int(box[11])
flly4 = int(box[12])
fllx5 = int(box[13])
flly5 = int(box[14])