# coding:utf-8
from mpl_toolkits.mplot3d import axes3d
import matplotlib.pyplot as plt
import numpy as np
import time
def createMatrix(m, n, up=100.0, down=0.0, left=75.0, right=50.0):
    """Build the initial grid with fixed values on its four edges.

    The returned array has ``n + 2`` rows and ``m + 2`` columns: an
    ``n x m`` interior initialised to zero, surrounded by one layer of
    boundary cells.

    Args:
        m: Number of interior columns.
        n: Number of interior rows.
        up: Value written to the first row (default 100).
        down: Value written to the last row (default 0).
        left: Value written to the first column (default 75).
        right: Value written to the last column (default 50).

    Returns:
        numpy.ndarray: Float array of shape ``(n + 2, m + 2)``.
    """
    A = np.zeros((n + 2, m + 2))
    # Rows are written first and columns second, so the four corner cells
    # end up holding the left/right values, exactly as before.
    A[0, :] = up
    A[n + 1, :] = down
    A[:, 0] = left
    A[:, m + 1] = right
    return A
def oneIter(A, r_lf, r_rt):
    """Advance the interior of ``A`` by one sweep that is implicit along rows.

    For every interior row a tridiagonal system is solved whose matrix has
    ``1 + r_lf`` on the diagonal and ``-r_lf / 2`` on both off-diagonals.
    The right-hand side is the explicit three-point combination
    ``r_rt/2 * above + (1 - r_rt) * centre + r_rt/2 * below`` taken from
    ``A``, plus ``r_lf / 2`` times the row's left and right boundary values
    for the first and last unknown.
    (This looks like one half-step of an ADI-type scheme -- confirm against
    the intended method.)

    Args:
        A: 2-D array of shape ``(n + 2, m + 2)``; the outermost rows and
            columns are treated as fixed boundary values.
        r_lf: Coefficient of the implicit (along-row) direction.
        r_rt: Coefficient of the explicit (across-row) direction.

    Returns:
        numpy.ndarray: A new float array of the same shape with the interior
        replaced by the solved values; ``A`` itself is not modified.
    """
    # Work in floating point so integer input is not truncated on write-back.
    A = np.asarray(A, dtype=float)
    B = A.copy()
    n = A.shape[0] - 2  # number of interior rows
    m = A.shape[1] - 2  # number of unknowns per interior row
    if m < 1 or n < 1:
        # No interior cells: nothing to solve.
        return B

    # Tridiagonal matrix of the implicit direction (identical for all rows).
    off_diag = np.full(m - 1, -0.5 * r_lf)
    M = (np.diag(np.full(m, 1.0 + r_lf))
         + np.diag(off_diag, 1)
         + np.diag(off_diag, -1))

    # Explicit part for all interior cells at once; rhs has shape (n, m).
    half_rt = 0.5 * r_rt
    rhs = (half_rt * A[:-2, 1:-1]
           + (1.0 - r_rt) * A[1:-1, 1:-1]
           + half_rt * A[2:, 1:-1])

    # Known boundary neighbours of the first and last unknown in each row.
    # The right neighbour of unknown column m is boundary column m + 1
    # (the previous code read interior column m - 1 here).
    rhs[:, 0] += 0.5 * r_lf * A[1:-1, 0]
    rhs[:, -1] += 0.5 * r_lf * A[1:-1, -1]

    # One solve with n right-hand sides (one column per interior row).
    B[1:-1, 1:-1] = np.linalg.solve(M, rhs.T).T
    return B
def computeA(m, n, rx, ry, iter):
    """Build the initial grid and advance it through repeated time steps.

    Args:
        m: Number of interior columns passed to ``createMatrix``.
        n: Number of interior rows passed to ``createMatrix``.
        rx: Coefficient used for the first (row-wise) sweep of each step.
        ry: Coefficient used for the second (column-wise) sweep of each step.
        iter: Step count; see the note on the loop below.

    Returns:
        numpy.ndarray: The grid after the last step, shape ``(n + 2, m + 2)``.
    """
    A = createMatrix(m, n)
    # The format operator must be applied to the string, not to print()'s
    # return value (None), otherwise Python 3 raises TypeError.
    print('total iter=%s' % iter)
    # NOTE(review): the loop counts from 1 and therefore performs iter - 1
    # steps, one fewer than the total announced above. Kept as-is; confirm
    # whether range(1, iter + 1) was intended.
    for i in range(1, iter):
        print('iter num=%s' % i)
        # One step = row-wise sweep followed by column-wise sweep.
        A = computeOneIter(A, m, n, rx, ry)
    return A
def computeOneIter(A, m, n, rx, ry):
    """Perform one full step as two directional sweeps and return the result.

    The first sweep runs ``oneIter`` on ``A`` with ``(rx, ry)``. The second
    runs it on the transpose with the coefficients swapped, and the outcome
    is transposed back to the original orientation.

    Args:
        A: Grid including its boundary layer.
        m: Accepted for interface compatibility; not used.
        n: Accepted for interface compatibility; not used.
        rx: Coefficient for the along-row direction.
        ry: Coefficient for the along-column direction.

    Returns:
        numpy.ndarray: The grid after both sweeps, same shape as ``A``.
    """
    after_rows = oneIter(A, rx, ry)
    after_cols = oneIter(np.transpose(after_rows), ry, rx)
    return np.transpose(after_cols)
def getStart():
    """Run the simulation with a live 3-D wireframe animation.

    Sets up the domain ``[0, 20] x [0, 30]`` and the time interval
    ``[0, 10]``, derives the grid sizes and the time step from ``deltax``
    and ``deltay``, then repeatedly advances the grid with
    ``computeOneIter`` while redrawing a wireframe of the current state.

    Returns:
        tuple: ``(A, x, y)`` where ``A`` is the final grid of shape
        ``(n, m)`` (boundary layer included) and ``x`` / ``y`` are the
        1-D coordinate arrays of length ``m`` / ``n``.
    """
    X_INTERVAL = [0, 20]
    Y_INTERVAL = [0, 30]
    T = [0, 10]
    deltax = 0.5
    deltay = 0.3
    smallest = min(deltax, deltay)
    tao = 1.0 / 3 * smallest * smallest  # time step
    m = int((X_INTERVAL[1] - X_INTERVAL[0]) / deltax - 1)
    n = int((Y_INTERVAL[1] - Y_INTERVAL[0]) / deltay - 1)
    print('m=%s,n=%s' % (m, n))

    # NOTE(review): x and y hold m and n samples across the full intervals,
    # so the plotted spacing is interval / (m - 1) and interval / (n - 1),
    # not deltax and deltay. The coefficients below still use
    # deltax / deltay. Left unchanged to keep the returned shapes; confirm
    # the intended grid.
    x = np.linspace(X_INTERVAL[0], X_INTERVAL[1], m)
    y = np.linspace(Y_INTERVAL[0], Y_INTERVAL[1], n)

    # Animation setup.
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    X, Y = np.meshgrid(x, y)

    n_steps = int((T[1] - T[0]) / tao)
    # Loop-invariant coefficients for the two sweep directions.
    rx = tao / deltax / deltax
    ry = tao / deltay / deltay

    # m - 2 / n - 2 interior cells plus the boundary layer give an (n, m)
    # array, matching the shape of the X / Y mesh used for plotting.
    A = createMatrix(m - 2, n - 2)
    wframe = None
    for i in range(n_steps):
        A = computeOneIter(A, m, n, rx, ry)
        if wframe is not None:
            # Artist.remove() detaches the previous wireframe; mutating
            # ax.collections directly is not supported by current Matplotlib.
            wframe.remove()
        wframe = ax.plot_wireframe(X, Y, A, rstride=2, cstride=2)
        plt.pause(0.01)
        print('iter=', i)
    return A, x, y
# Script entry point: run the animated simulation only when this file is
# executed directly, not when it is imported.
if __name__ == '__main__':
    getStart()
# (Stray web-page footer text "Comments 0" was here; it is not part of the program.)