# https://cloud.tencent.com/developer/article/1409009
# https://blog.csdn.net/qq_27158179/article/details/82717821
import os
import cv2
import time
import numpy as np
import matplotlib.pyplot as plt
from config import *
class general_mulitpose_model(object):
def __init__(self, keypoint_num):
    """Configure the pose model for a 25- or 18-keypoint layout.

    Args:
        keypoint_num: 25 selects the *_25 settings imported from config;
            any other value selects the *_18 settings.

    Side effects:
        Loads the Caffe network via self.get_model() and stores it in
        self.pose_net.
    """
    # Every per-layout table must come from the same layout: the limb
    # pairs and PAF channel indices index into num_points keypoint types.
    # NOTE(review): assumes config defines point_pairs_18, map_idx_18 and
    # colors_18 alongside point_names_18 / prototxt_18 / caffemodel_18.
    if keypoint_num == 25:
        self.point_names = point_names_25
        self.point_pairs = point_pairs_25
        self.map_idx = map_idx_25
        self.colors = colors_25
        self.num_points = 25
        self.prototxt = prototxt_25
        self.caffemodel = caffemodel_25
    else:
        self.point_names = point_names_18
        self.point_pairs = point_pairs_18
        self.map_idx = map_idx_18
        self.colors = colors_18
        self.num_points = 18
        self.prototxt = prototxt_18
        self.caffemodel = caffemodel_18
    self.pose_net = self.get_model()
def get_model(self):
    """Build the OpenCV DNN network from the configured Caffe files.

    Uses self.prototxt (network definition) and self.caffemodel (trained
    weights) and returns the object produced by cv2.dnn.readNetFromCaffe.
    """
    model_files = (self.prototxt, self.caffemodel)
    return cv2.dnn.readNetFromCaffe(*model_files)
def getKeypoints(self, probMap, threshold=0.1):
    """Extract one keypoint per high-confidence blob of a probability map.

    The map is smoothed with a 3x3 Gaussian and thresholded. Each
    connected region (contour) of the resulting mask yields the location
    of its maximum smoothed value.

    Args:
        probMap: 2-D probability map for a single keypoint type (indexed
            as probMap[y, x] below).
        threshold: minimum smoothed probability for a pixel to belong to
            a blob.

    Returns:
        List of (x, y, prob) tuples, one per blob. (x, y) is the location
        of the blob's maximum in the smoothed map, and prob is the
        unsmoothed probMap value at that location.
    """
    mapSmooth = cv2.GaussianBlur(probMap, (3, 3), 0, 0)
    mapMask = np.uint8(mapSmooth > threshold)
    # findContours returns (image, contours, hierarchy) in OpenCV 3.x but
    # (contours, hierarchy) in 2.4/4.x; [-2] is the contours in every case.
    contours = cv2.findContours(mapMask,
                                cv2.RETR_TREE,
                                cv2.CHAIN_APPROX_SIMPLE)[-2]
    keypoints = []
    for cnt in contours:
        # Restrict the search for the maximum to this blob only.
        blobMask = np.zeros(mapMask.shape)
        blobMask = cv2.fillConvexPoly(blobMask, cnt, 1)
        maskedProbMap = mapSmooth * blobMask
        _, _, _, maxLoc = cv2.minMaxLoc(maskedProbMap)
        # maxLoc is (x, y); probMap is indexed [row, col] = [y, x].
        keypoints.append(maxLoc + (probMap[maxLoc[1], maxLoc[0]],))
    return keypoints
def getValidPairs(self, output, detected_keypoints, img_width, img_height,
                  n_interp_samples=10, paf_score_th=0.1, conf_th=0.7):
    """Score candidate limb connections using Part Affinity Fields (PAFs).

    For each limb k, the two keypoint types are self.point_pairs[k]. Every
    candidate A of the first type is linked to at most one candidate B of
    the second type: the B with the highest mean PAF alignment among those
    whose connection passes the paf_score_th / conf_th test.

    Args:
        output: network output, indexed as output[0, channel, :, :];
            channels self.map_idx[k][0] and self.map_idx[k][1] are read as
            the two PAF components for limb k.
        detected_keypoints: per-keypoint-type sequences of candidates. Each
            candidate is read with x at [0], y at [1] and an id at [3].
        img_width, img_height: size the PAF maps are resized to. Keypoint
            coordinates are assumed to be in this frame (TODO confirm
            against the caller).
        n_interp_samples: number of points sampled along each candidate
            limb; must be positive.
        paf_score_th: minimum per-sample alignment for a sample to count.
        conf_th: minimum fraction of samples that must exceed paf_score_th.

    Returns:
        (valid_pairs, invalid_pairs). valid_pairs[k] is an (M, 3) array of
        (idA, idB, score) rows, or [] when either end of limb k has no
        candidates. invalid_pairs lists those k.
    """
    valid_pairs = []
    invalid_pairs = []
    for k in range(len(self.map_idx)):
        # Candidates for the two ends (A -> B) of limb k.
        candA = detected_keypoints[self.point_pairs[k][0]]
        candB = detected_keypoints[self.point_pairs[k][1]]
        nA = len(candA)
        nB = len(candB)
        if nA == 0 or nB == 0:  # If no keypoints are detected
            print("No Connection : k = {}".format(k))
            invalid_pairs.append(k)
            valid_pairs.append([])
            continue
        # PAF components for this limb, resized so that keypoint pixel
        # coordinates can index them directly.
        pafA = cv2.resize(output[0, self.map_idx[k][0], :, :],
                          (img_width, img_height))
        pafB = cv2.resize(output[0, self.map_idx[k][1], :, :],
                          (img_width, img_height))
        valid_pair = np.zeros((0, 3))
        for i in range(nA):
            max_j = -1
            # -inf rather than -1 so that any connection passing conf_th
            # can be chosen, whatever its mean score.
            maxScore = -np.inf
            for j in range(nB):
                # Unit vector d_ij from candidate A to candidate B.
                d_ij = np.subtract(candB[j][:2], candA[i][:2])
                norm = np.linalg.norm(d_ij)
                if not norm:
                    continue  # coincident points: direction undefined
                d_ij = d_ij / norm
                # Sample the PAF at points p(u) evenly spaced on A -> B.
                xs = np.linspace(candA[i][0], candB[j][0], num=n_interp_samples)
                ys = np.linspace(candA[i][1], candB[j][1], num=n_interp_samples)
                paf_interp = []
                for x, y in zip(xs, ys):
                    row, col = int(round(y)), int(round(x))
                    paf_interp.append([pafA[row, col], pafB[row, col]])
                # E: alignment of each sampled PAF vector with d_ij.
                paf_scores = np.dot(paf_interp, d_ij)
                avg_paf_score = paf_scores.mean()
                # Valid only if a large enough fraction of samples align.
                aligned_fraction = (np.count_nonzero(paf_scores > paf_score_th)
                                    / n_interp_samples)
                if aligned_fraction > conf_th and avg_paf_score > maxScore:
                    max_j = j
                    maxScore = avg_paf_score
            if max_j != -1:
                valid_pair = np.append(
                    valid_pair,
                    [[candA[i][3], candB[max_j][3], maxScore]],
                    axis=0)
        valid_pairs.append(valid_pair)
    return valid_pairs, invalid_pairs
def getPersonwiseKeypoints(self, valid_pairs, invalid_pairs, keypoints_list):
    """Assemble scored limb connections into one keypoint row per person.

    Limbs are processed in order. A connection whose part-A id already
    belongs to a person extends that person with part B. Otherwise, for
    limbs k < self.num_points - 1, it starts a new person.

    Args:
        valid_pairs: per-limb arrays of (idA, idB, paf_score) rows, as
            returned by getValidPairs; entries for limbs listed in
            invalid_pairs are ignored.
        invalid_pairs: indices of limbs with no connections.
        keypoints_list: array indexed by keypoint id; column 2 is read as
            that keypoint's score.

    Returns:
        Array of shape (num_persons, self.num_points + 1). Column t holds
        the keypoint id for keypoint type t (-1 if absent), and the last
        column holds the accumulated keypoint and PAF scores.
    """
    personwiseKeypoints = -1 * np.ones((0, self.num_points + 1))
    invalid = set(invalid_pairs)
    for k in range(len(self.map_idx)):
        if k in invalid:
            continue
        partAs = valid_pairs[k][:, 0]
        partBs = valid_pairs[k][:, 1]
        indexA, indexB = self.point_pairs[k]
        for i in range(len(valid_pairs[k])):
            # First person whose part A is this connection's part A.
            matches = np.nonzero(personwiseKeypoints[:, indexA] == partAs[i])[0]
            if len(matches):
                person_idx = matches[0]
                personwiseKeypoints[person_idx, indexB] = partBs[i]
                personwiseKeypoints[person_idx, -1] += (
                    keypoints_list[int(partBs[i]), 2] + valid_pairs[k][i][2])
            elif k < self.num_points - 1:
                # No owner for part A: start a new person with both parts.
                row = -1 * np.ones(self.num_points + 1)
                row[indexA] = partAs[i]
                row[indexB] = partBs[i]
                # Scores of the two keypoints plus the connection's PAF score.
                row[-1] = (keypoints_list[valid_pairs[k][i, :2].astype(int), 2].sum()
                           + valid_pairs[k][i][2])
                personwiseKeypoints = np.vstack([personwiseKeypoints, row])
    return personwiseKeypoints
def predict(self, imgfile):
img_cv2 = cv2.imread(imgfile)
self.image = img_cv2
img_width, img_height = img_cv2.shape[1], img_cv2.shape[0]
net_height = 368
net_width = int((net_height / img_height) * img_width)
start = time.time()
in_blob = cv2.dnn.blobFromImage(
img_cv2,
1.0 / 255,
(net_width, net_height),
(0, 0, 0),
swapRB=False,
crop=False)
self.pose_net.setInput(in_blob)
output = self.pose_net.forward()
print("[INFO]Time Taken in Forward p
opencv-openpose.zip的脚本
需积分: 5 61 浏览量
2021-02-23
18:01:18
上传
评论 5
收藏 279.16MB ZIP 举报
magic_ll
- 粉丝: 1w+
- 资源: 2
最新资源
- 基于CSS的响应式鲜花网站全屏效果设计源码
- 基于JavaScript的访客预约系统设计源码
- 基于Vue和ECharts的工作租房数据可视化系统设计源码
- 1040g0cg310ravpiu6ibg5pg00tsipsln3ju2d0g 2
- 基于Python的SAR图像去噪CNN-NLM设计源码
- redhat6升级到redhat7,过程redhat6.x-> redhat6.10->redhat7.9 主版本最高版本
- 基于Django的流程引擎设计源码
- 基于Node.js的Express框架与MySQL的后台管理系统设计源码
- 基于Java的Flink流批一体数据处理快速集成开发框架设计源码
- FirstFilterOrderCompare
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈
评论0