# -*- coding: utf-8 -*-
import tensorflow as tf
import numpy as np
import os
from cell import ConvLSTMCell
# Housekeeping for interactive (REPL/notebook) re-runs: if a previous TF
# session object is still bound, close it before building a new graph so
# its resources are released.
if 'session' in locals() and session is not None:
    print('Close interactive session')
    session.close()
def main():
    """Build a ConvLSTM recurrent graph and run it over generated batches.

    NOTE(review): ``gen_epochs``, ``num_epochs`` and the ``ConvLSTMCell``
    import are defined outside this block -- presumably elsewhere in the
    project; confirm they are available before running.
    """
    shape = [132, 228]   # spatial grid (H, W) of each frame
    kernel = [3, 3]      # convolution kernel of the ConvLSTM cell
    filters = 12         # number of output feature maps per step
    batch_size = 181
    num_steps = 8        # temporal length of each input sequence
    channels = 7         # input channels per frame

    # Placeholder is batch-major: [batch, time, H, W, C].
    inputs = tf.placeholder(
        tf.float32, [batch_size, num_steps] + shape + [channels])
    cell = ConvLSTMCell(shape, filters, kernel)
    # BUG FIX: the placeholder above is batch-major, so time_major must be
    # False.  The original passed time_major=True, which makes dynamic_rnn
    # treat axis 0 as time and axis 1 as batch, silently transposing the
    # meaning of the first two dimensions.
    outputs, state = tf.nn.dynamic_rnn(
        cell, inputs, dtype=inputs.dtype, time_major=False)
    print(inputs.shape)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for idx, epoch in enumerate(gen_epochs(num_epochs, num_steps)):
            for step, (X, Y) in enumerate(epoch):
                # BUG FIX: feed the batch X produced by the generator; the
                # original fed an undefined name `inp`, which would raise
                # NameError on the first iteration.
                o, s = sess.run([outputs, state], feed_dict={inputs: X})
                # BUG FIX: Python 3 print call (original used the Python 2
                # statement form `print o.shape`, a SyntaxError under py3).
                print(o.shape)
# NOTE(review): dead code preserved as an unassigned string literal.  It
# appears to walk `path` for *.npz archives and accumulate the size of the
# time axis of each 'gfs_air' array; `num` is never initialized, so it
# could not run as-is.  Kept verbatim for reference.
'''
if os.path.exists(path):
for root,dirs,files in os.walk(path):
sub_f=[root+f for f in files if f[-3:] == 'npz']
for filename in sub_f:
print(filename)
with np.load(filename) as f_d:
b_arr=np.array(f_d['gfs_air'])
#0:tm,1:rh,2:ugrd,3:vgrd,4:prate,5:tcdc,6:CO,7:PM10,8:O3,9:SO2,10:NO2,11:AQI,12:PM25
#tmp:Temperature
#rh :Relative Humidity
#ugrd:U-Component of Wind
#vgrd:V-Component of Wind
#prate:Precipitation Rate
#tcdc:Total Cloud Cover
#data_atm=np.array(b_arr[:,:,:,0])
#data_PM25=np.array(b_arr[:,:,:,12])
#data=np.stack((np.array(b_arr[:,:,:,0]),np.array(b_arr[:,:,:,1]),np.array(b_arr[:,:,:,12])))
num+=b_arr.shape[2]
print(num,b_arr.shape[2])
'''
def gen_bath(path,batch_size,num_steps,channels,shape):
    """Slice monthly .npz weather/air-quality archives into per-batch chunks.

    NOTE(review): the source file's indentation was lost; the nesting below
    is reconstructed from the arithmetic (each slice totals
    batch_partition_length == 40 frames) -- confirm against the original.
    The function looks unfinished: `data_x`/`data_y` are allocated but never
    filled, nothing is returned, and `data` may be unbound on iterations
    that only advance to the next month file.

    Args:
        path: directory/prefix for files named <path><YYYYMM>.npz.
        batch_size, num_steps, channels, shape: batching geometry; only
            batch_size, channels and shape are actually used here.
    """
    # Number of time frames contained in each monthly file, in order.
    month_d=[239,248,240,248,248,240,248,
    240,248,248,232,248,240,248,
    240,248,248,240,248,240,248,
    248,224,248,240,248,240,248,248,169]
    # Month stamps forming the file names, aligned with month_d.
    file_nameF=['201505','201506','201507','201508','201509','201510','201511','201512','201601',
    '201602','201603','201604','201605','201606','201607','201608','201609','201610',
    '201611','201612','201701','201702','201703','201704','201705','201706','201707',
    '201708','201709']
    #H:\IEEE_ICDM\gfs_air201504.npz
    # Channel index of the prediction target (12 == PM25 per the map below).
    atm_type=12
    #0:tm,1:rh,2:ugrd,3:vgrd,4:prate,5:tcdc,6:CO,7:PM10,8:O3,9:SO2,10:NO2,11:AQI,12:PM25
    #path='H:\\IEEE_ICDM\\gfs_air'
    # Total frames across all months; 7240 // 181 gives 40 frames per batch.
    time_len=7240
    batch_partition_length = time_len // batch_size
    print(batch_partition_length)
    # Pre-allocated output buffers -- never written below (unfinished code).
    data_x = np.zeros([batch_size, batch_partition_length] + shape + [channels], dtype=np.float32)
    data_y = np.zeros([batch_size, batch_partition_length] + shape + [channels], dtype=np.float32)
    print(data_x.shape)
    j=0                               # index of the current month file
    sum_time=month_d[0]               # cumulative frames up to current month
    x=month_d[0]                      # NOTE(review): decremented but unused
    filename=path+file_nameF[0]+'.npz'
    next_range=0                      # NOTE(review): assigned but unused
    ii=0                              # running frame offset into the month
    for i in range(batch_size):
        x=x-batch_partition_length
        #print(i*batch_partition_length)
        if sum_time >i*batch_partition_length:
            # Current batch slice starts inside the current month file.
            b_ii=ii
            ii=ii+batch_partition_length
            if (sum_time-i*batch_partition_length)<= 40:
                # Fewer than one partition (40 frames) left in this month:
                # stitch the month's tail onto the next month's head.
                # range1 is negative: -(frames remaining in current month).
                range1=-sum_time+i*batch_partition_length
                ii=40+range1
                if range1<0:
                    if j+1 >=len(file_nameF):
                        break
                    filename1=path+file_nameF[j+1]+'.npz'
                    print(filename,range1)
                    print(filename1,range1+40)
                    # Tail of the current month: last -range1 frames,
                    # 6 weather channels + the target channel appended.
                    f=np.load(filename)
                    b=np.array(f['gfs_air'])
                    data_b=np.array(b[:,:,range1:,0:6])
                    data_b1=np.array(b[:,:,range1:,atm_type])
                    data_b1=data_b1[:,:,:,np.newaxis]
                    data_b=np.concatenate((data_b,data_b1),axis=3)
                    f.close()
                    # Head of the next month: first 40+range1 frames,
                    # so the stitched slice totals exactly 40 frames.
                    f=np.load(filename1)
                    b=np.array(f['gfs_air'])
                    data_a=np.array(b[:,:,0:range1+40,0:6])
                    data_a1=np.array(b[:,:,0:range1+40,atm_type])
                    data_a1=data_a1[:,:,:,np.newaxis]
                    data_a=np.concatenate((data_a,data_a1),axis=3)
                    f.close()
                    data=np.concatenate((data_b,data_a),axis=2)
                    #print('data1:',data.shape)
                    #print('data_a:',data_a.shape)
                    #print('data:',data.shape)
            else:
                # Whole partition fits inside the current month: take the
                # [b_ii:a_ii] frame window from one file.
                a_ii=ii
                f=np.load(filename)
                b=np.array(f['gfs_air'])
                data=np.array(b[:,:,b_ii:a_ii,0:6])
                data1=np.array(b[:,:,b_ii:a_ii,atm_type])
                data1=data1[:,:,:,np.newaxis]
                data=np.concatenate((data,data1),axis=3)
                f.close()
                #print('data2:',data.shape)
        else:
            # Passed the end of the current month: advance to the next file.
            # NOTE(review): no slice is loaded on this path, so `data` below
            # is stale (or unbound on a first-iteration fall-through).
            j=j+1
            if j+1 > len(file_nameF):
                break
            sum_time=sum_time+month_d[j]
            filename=path+file_nameF[j]+'.npz'
        # Move the time axis to the front: (H, W, T, C) -> (T, H, W, C).
        data=data.swapaxes(2,0)
        data=data.swapaxes(2,1)
        print('data:',data.shape)
# NOTE(review): dead code preserved as an unassigned string literal --
# leftover usage examples for ConvLSTMCell/ConvGRUCell (2D/1D/3D inputs,
# bidirectional RNNs) plus an abandoned batching sketch.  Kept verbatim.
'''
#print(batch_partition_length)
#with np.load(filename) as f_d
#    data_x[i] = raw_x[batch_partition_length * i:batch_partition_length * (i + 1)]
#    data_y[i] = raw_y[batch_partition_length * i:batch_partition_length * (i + 1)]
#print(data_x.shape)
global batch_size,timesteps,shape,kernel,channels,filters
batch_size = 32
timesteps = 100
shape = [640, 480]
kernel = [3, 3]
channels = 6
filters = 12
# Create a placeholder for videos.
inputs = tf.placeholder(tf.float32, [batch_size, timesteps] + shape + [channels],name='input_placeholder')
outputs = tf.placeholder(tf.float32, [batch_size, timesteps] + shape + [1],name='labels_placeholder')
print(tf.shape(inputs),inputs.shape)
cell = ConvLSTMCell(shape, filters, kernel)
outputs, state = tf.nn.dynamic_rnn(cell, inputs, dtype=inputs.dtype)
################################################
# Add the ConvLSTM step.
cell = ConvLSTMCell(shape, filters, kernel)
outputs, state = tf.nn.dynamic_rnn(cell, inputs, dtype=inputs.dtype)
# There's also a ConvGRUCell that is more memory efficient.
from cell import ConvGRUCell
cell = ConvGRUCell(shape, filters, kernel)
outputs, state = tf.nn.dynamic_rnn(cell, inputs, dtype=inputs.dtype)
# It's also possible to enter 2D input or 4D input instead of 3D.
shape = [100]
kernel = [3]
inputs = tf.placeholder(tf.float32, [batch_size, timesteps] + shape + [channels])
cell = ConvLSTMCell(shape, filters, kernel)
outputs, state = tf.nn.bidirectional_dynamic_rnn(cell, cell, inputs, dtype=inputs.dtype)
shape = [50, 50, 50]
kernel = [1, 3, 5]
inputs = tf.placeholder(tf.float32, [batch_size, timesteps] + shape + [channels])
cell = ConvGRUCell(shape, filters, kernel)
outputs, state= tf.nn.bidirectional_dynamic_rnn(cell, cell, inputs, dtype=inputs.dtype)
'''
# Script entry point.
# NOTE(review): indentation was lost in this file; the final print is
# assumed to belong inside the guard (it follows main()) -- confirm.
if __name__ == '__main__':
    main()
    print('##FINISH##')
# --- NOTE(review): the lines below are web-page residue (download-site
# --- boilerplate: title, view count, upload date, unrelated file listing)
# --- accidentally pasted into this source file.  They are fatal to the
# --- Python parser, so they are commented out verbatim here; they carry
# --- no program logic and can be removed entirely.
# convlstm.rar_ConvLSTM_ConvLstm、_conlstm实现分类_convlstm代码_卷积LSTM实现
# 版权申诉
# 178 浏览量
# 2022-09-22
# 14:10:50
# 上传
# 评论 1
# 收藏 2KB RAR 举报
# 小贝德罗
# - 粉丝: 68
# - 资源: 1万+
# 最新资源
# - Screenshot_20240427_031602.jpg
# - 网页PDF_2024年04月26日 23-46-14_QQ浏览器网页保存_QQ浏览器转格式(6).docx
# - 直接插入排序,冒泡排序,直接选择排序.zip
# - 在排序2的基础上,再次对快排进行优化,其次增加快排非递归,归并排序,归并排序非递归版.zip
# - 实现了7种排序算法.三种复杂度排序.三种nlogn复杂度排序(堆排序,归并排序,快速排序)一种线性复杂度的排序.zip
# - 冒泡排序 直接选择排序 直接插入排序 随机快速排序 归并排序 堆排序.zip
# - 课设-内部排序算法比较 包括冒泡排序、直接插入排序、简单选择排序、快速排序、希尔排序、归并排序和堆排序.zip
# - Python排序算法.zip
# - C语言实现直接插入排序、希尔排序、选择排序、冒泡排序、堆排序、快速排序、归并排序、计数排序,并带图详解.zip
# - 常用工具集参考用于图像等数据处理
# 资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
# 点击此处反馈