from collections import deque
from queue import Queue

import numpy as np
class State:
def __init__(self, state, directionFlag=None, parent=None, f=0):
self.state = state
self.direction = ['up', 'down', 'right', 'left']
if directionFlag:
self.direction.remove(directionFlag)
self.parent = parent
self.f = f
self.expanded_nodes = 0 # 初始化扩展节点数
self.generated_nodes = 0 # 初始化生成节点数
def getDirection(self):
    """Return the moves still allowed from this node (the live list, not a copy)."""
    allowed = self.direction
    return allowed
def setF(self, f):
    """Store *f* as this node's evaluation value (returns None)."""
    self.f = f
def showInfo(self):
    """Print the board row by row, then a ``'->'`` separator.

    Every cell is followed by a space and every row by a blank line
    (``print("\\n")`` emits two newlines), so consecutive boards printed
    along a solution path are visually separated.

    Rows and cells are iterated directly rather than using
    ``range(len(self.state))`` for both axes, which only worked for square
    boards (non-square ones were truncated or raised IndexError).
    """
    for row in self.state:
        for cell in row:
            print(cell, end=' ')
        print("\n")
    print('->')
def getZeroPos(self):
    """Locate the blank tile (value 0) in ``self.state``.

    Returns ``np.where``'s result unchanged: a tuple of index arrays, one per
    axis. On a 2-D board with a single blank, ``x, y = self.getZeroPos()``
    therefore gives two 1-element arrays.
    """
    blank_mask = self.state == 0
    return np.where(blank_mask)
def getFunctionValue(self):
    """Evaluation value f = g + h used to order successor nodes.

    h is the Manhattan distance from ``self.state`` to the goal board
    ``self.answer``. For every cell whose value differs from the goal, add
    the row + column distance to where that value sits in the goal. g is a
    constant step cost of 1.

    Note: the blank (0) is counted like any other tile. h can therefore
    exceed the true number of remaining moves; a board one move from the
    goal scores h = 2.

    If a breadth-first ordering is wanted, this function simply need not
    be called.

    Returns:
        int: h + 1.
    """
    current = self.state
    goal = self.answer
    dist = 0
    for (i, j), value in np.ndenumerate(current):
        if value != goal[i, j]:
            goal_i, goal_j = np.argwhere(goal == value)[0]
            dist += abs(int(goal_i) - i) + abs(int(goal_j) - j)
    # Return the real score; a constant would make every successor tie, so
    # the heuristic computed above would have no effect on ordering.
    return dist + 1
# 切比雪夫距离作为评估函数
# def getFunctionValue(self):
# cur_node = self.state.copy()
# fin_node = self.answer.copy()
# dist = 0
# N = len(cur_node)
#
# for i in range(N):
# for j in range(N):
# if cur_node[i][j] != fin_node[i][j]:
# index = np.argwhere(fin_node == cur_node[i][j])
# x = index[0][0] # 最终x距离
# y = index[0][1] # 最终y距离
# dist += max(abs(x - i) , abs(y - j))
# return dist +1
def nextStep(self):
    """Return the single best successor of this node.

    Successor generation and sorting (by ``compareNum``) live in
    ``nextSteps``. Delegating to it keeps the move logic in one place
    instead of a second verbatim copy.

    Returns:
        The first child after sorting, or ``[]`` when no move is possible.
        ``[]`` is what was already returned when no direction remained. It
        now also covers the case where directions remain but none stays
        inside the board, which used to raise IndexError.
    """
    candidates = self.nextSteps()
    if not candidates:
        return []
    return candidates[0]
def nextSteps(self):
    """Generate every legal successor of this node, sorted for expansion.

    The blank (value 0) is slid one cell in each direction that is both
    still listed in ``self.direction`` and inside the board. Each child:
    - is built with the opposite direction excluded, so it cannot
      immediately undo the move;
    - gets its ``f`` from its own ``getFunctionValue``.

    The children are then sorted with ``compareNum``. That is a key function
    defined elsewhere in this file, presumably reading ``f`` -- confirm.

    Returns:
        list: the sorted children; ``[]`` when no direction remains.
    """
    if not self.direction:
        return []
    last = len(self.state) - 1
    row, col = self.getZeroPos()
    # (move, row offset, col offset, move the child may not make, in bounds?)
    # Checked in this fixed order; the sort is stable, so the order decides
    # ties under compareNum.
    moves = (
        ('left', 0, -1, 'right', col > 0),
        ('up', -1, 0, 'down', row > 0),
        ('down', 1, 0, 'up', row < last),
        ('right', 0, 1, 'left', col < last),
    )
    children = []
    for name, d_row, d_col, reverse_move, in_bounds in moves:
        if name not in self.direction or not in_bounds:
            continue
        board = self.state.copy()
        to_row, to_col = row + d_row, col + d_col
        displaced = board[to_row, to_col]
        board[to_row, to_col] = board[row, col]
        board[row, col] = displaced
        child = State(board, directionFlag=reverse_move, parent=self)
        child.setF(child.getFunctionValue())
        children.append(child)
    children.sort(key=compareNum)
    return children
def a_star_search_efficiency(self):
    """Search from this node to the goal and report how much work it took.

    Nodes leave the open list in FIFO order. Despite the name, this is
    therefore a breadth-first search; ``f`` only influences the order in
    which siblings are queued (``nextSteps`` sorts them).

    Returns:
        tuple[int, int]: ``(expanded, generated)``.
        - expanded: nodes taken off the open list, including the goal
          if it is found.
        - generated: new successors added to the open list.
        Both are also stored on ``self`` as ``expanded_nodes`` and
        ``generated_nodes``.
    """
    # Count per search, not per node: every successor is a fresh object, so
    # a counter stored on each node would only ever reach 1.
    expanded = 0
    generated = 0
    open_list = deque([self])  # O(1) popleft instead of list.pop(0)
    # Boards already queued or expanded, keyed by their raw bytes. Node
    # objects define no __eq__, so membership tests on them never match.
    seen = {self.state.tobytes()}
    while open_list:
        current_node = open_list.popleft()
        expanded += 1
        if np.array_equal(current_node.state, current_node.answer):
            break
        for next_state in current_node.nextSteps():
            key = next_state.state.tobytes()
            if key not in seen:
                seen.add(key)
                open_list.append(next_state)
                generated += 1
    self.expanded_nodes = expanded
    self.generated_nodes = generated
    return expanded, generated
def solve(self):
    """Search from this node to the goal board ``self.answer``.

    Nodes leave the queue in FIFO order, so this is breadth-first. Each
    child from ``nextSteps`` is goal-tested as soon as it is generated.
    Boards already seen are not queued again, which keeps the search finite
    on the puzzle's cyclic state graph.

    The path is traced back to ``self`` (the search root). It previously
    relied on a module-level ``originState`` global instead.

    Returns:
        ``[]`` if this node is already the goal.

        Otherwise, on success, a list of the intermediate nodes from just
        after ``self`` to just before the goal node; both endpoints are
        excluded.

        ``(None, None)`` if the goal is unreachable. This form is kept for
        compatibility with existing callers.

    Side effect: ``self.expanded_nodes`` holds the number of nodes expanded.
    """
    if np.array_equal(self.state, self.answer):
        return []
    open_queue = deque([self])  # plain FIFO; no thread-safe Queue needed
    seen = {self.state.tobytes()}
    expanded = 0
    while open_queue:
        node = open_queue.popleft()
        expanded += 1
        self.expanded_nodes = expanded
        for child in node.nextSteps():
            if np.array_equal(child.state, child.answer):
                path = []
                step = child.parent
                while step is not None and step is not self:
                    path.append(step)
                    step = step.parent
                path.reverse()
                return path
            key = child.state.tobytes()
            if key not in seen:
                seen.add(key)
                open_queue.append(child)
    return None, None
# openList
# openTable = []
# # closeList
# closeTable = []
# openTable.append(self)
# while len(openTable) > 0:
# # 下一步的点移除open
# n = openTable.pop(0)
# # 加入close
# closeTable.append(n)
#