# Building Models
TensorFlow provides three ways to build models:
- the Sequential API, which stacks layers in order
- the functional API, which can express arbitrary model topologies
- subclassing the Model base class to build a fully custom model

For a plain sequential architecture, prefer the Sequential API. If the model has multiple inputs or outputs, shares weights between layers, or contains non-sequential structure such as residual connections, use the functional API. Unless you really need it, avoid building models by subclassing Model: it offers maximum flexibility, but also far more opportunities to get things wrong.
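The third approach is only mentioned above, so as a minimal sketch: a custom binary classifier built by subclassing Model might look like the following (the class name and layer sizes are illustrative, not part of this chapter's case study).

```python
import tensorflow as tf
from tensorflow.keras import layers, models

class TinyClassifier(models.Model):
    """Minimal Model subclass: define layers in __init__, wire them in call."""
    def __init__(self):
        super().__init__()
        self.dense1 = layers.Dense(8, activation="relu")
        self.dense2 = layers.Dense(1, activation="sigmoid")

    def call(self, inputs):
        x = self.dense1(inputs)
        return self.dense2(x)

model = TinyClassifier()
out = model(tf.zeros((4, 16)))  # calling on a sample batch builds the weights
print(tuple(out.shape))  # (4, 1)
```

Note that a subclassed model has no static graph of layers, which is exactly why `summary()` and plotting utilities are less informative for it than for the other two approaches.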
# Case Study: IMDB
```py
import numpy as np
import pandas as pd
import tensorflow as tf
from tqdm import tqdm
from tensorflow.keras import *
train_token_path = "./data/imdb/train_token.csv"
test_token_path = "./data/imdb/test_token.csv"
MAX_WORDS = 10000 # We will only consider the top 10,000 words in the dataset
MAX_LEN = 200 # We will cut reviews after 200 words
BATCH_SIZE = 20
# Build the data pipeline: each line is "label<TAB>space-separated token ids"
def parse_line(line):
    t = tf.strings.split(line, "\t")
    label = tf.reshape(tf.cast(tf.strings.to_number(t[0]), tf.int32), (-1,))
    features = tf.cast(tf.strings.to_number(tf.strings.split(t[1], " ")), tf.int32)
    return (features, label)

ds_train = tf.data.TextLineDataset(filenames=[train_token_path]) \
    .map(parse_line, num_parallel_calls=tf.data.experimental.AUTOTUNE) \
    .shuffle(buffer_size=1000).batch(BATCH_SIZE) \
    .prefetch(tf.data.experimental.AUTOTUNE)

# No shuffle on the evaluation set: order does not affect the metrics.
ds_test = tf.data.TextLineDataset(filenames=[test_token_path]) \
    .map(parse_line, num_parallel_calls=tf.data.experimental.AUTOTUNE) \
    .batch(BATCH_SIZE) \
    .prefetch(tf.data.experimental.AUTOTUNE)
```
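The token files themselves are not shipped with these notes, so as a sketch we can exercise `parse_line` on a single in-memory line of the same format (label, a tab, then space-separated token ids; the ids below are made up).

```python
import tensorflow as tf

def parse_line(line):
    # Same definition as in the pipeline above.
    t = tf.strings.split(line, "\t")
    label = tf.reshape(tf.cast(tf.strings.to_number(t[0]), tf.int32), (-1,))
    features = tf.cast(tf.strings.to_number(tf.strings.split(t[1], " ")), tf.int32)
    return (features, label)

features, label = parse_line(tf.constant("1\t12 7 99 3"))
print(features.numpy().tolist())  # [12, 7, 99, 3]
print(label.numpy().tolist())     # [1]
```

This mirrors what `TextLineDataset.map` does: each element the dataset feeds to `parse_line` is a scalar string tensor, one line of the file.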
## Sequential: building a model layer by layer
```py
tf.keras.backend.clear_session()
model = models.Sequential()
model.add(layers.Embedding(MAX_WORDS,7,input_length=MAX_LEN))
model.add(layers.Conv1D(filters = 64,kernel_size = 5,activation = "relu"))
model.add(layers.MaxPool1D(2))
model.add(layers.Conv1D(filters = 32,kernel_size = 3,activation = "relu"))
model.add(layers.MaxPool1D(2))
model.add(layers.Flatten())
model.add(layers.Dense(1,activation = "sigmoid"))
model.compile(optimizer="Nadam",
              loss="binary_crossentropy",
              metrics=["accuracy", "AUC"])
model.summary()
```
![Sequential model structure](https://github.com/lyhue1991/eat_tensorflow2_in_30_days/raw/master/data/Sequential%E6%A8%A1%E5%9E%8B%E7%BB%93%E6%9E%84.png)
```py
import datetime
import matplotlib.pyplot as plt
baselogger = callbacks.BaseLogger(stateful_metrics=["AUC"])
logdir = "./data/keras_model/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(logdir, histogram_freq=1)
history = model.fit(ds_train, validation_data=ds_test,
                    epochs=6, callbacks=[baselogger, tensorboard_callback])
%matplotlib inline
%config InlineBackend.figure_format = 'svg'

def plot_metric(history, metric):
    train_metrics = history.history[metric]
    val_metrics = history.history['val_' + metric]
    epochs = range(1, len(train_metrics) + 1)
    plt.plot(epochs, train_metrics, 'bo--')
    plt.plot(epochs, val_metrics, 'ro-')
    plt.title('Training and validation ' + metric)
    plt.xlabel("Epochs")
    plt.ylabel(metric)
    plt.legend(["train_" + metric, 'val_' + metric])
    plt.show()

plot_metric(history, "AUC")
```
![AUC](https://github.com/lyhue1991/eat_tensorflow2_in_30_days/raw/master/data/6-1-fit%E6%A8%A1%E5%9E%8B.jpg)
## Functional API: building models of arbitrary structure
```py
tf.keras.backend.clear_session()
inputs = layers.Input(shape=[MAX_LEN])
x = layers.Embedding(MAX_WORDS,7)(inputs)
branch1 = layers.SeparableConv1D(64,3,activation="relu")(x)
branch1 = layers.MaxPool1D(3)(branch1)
branch1 = layers.SeparableConv1D(32,3,activation="relu")(branch1)
branch1 = layers.GlobalMaxPool1D()(branch1)
branch2 = layers.SeparableConv1D(64,5,activation="relu")(x)
branch2 = layers.MaxPool1D(5)(branch2)
branch2 = layers.SeparableConv1D(32,5,activation="relu")(branch2)
branch2 = layers.GlobalMaxPool1D()(branch2)
branch3 = layers.SeparableConv1D(64,7,activation="relu")(x)
branch3 = layers.MaxPool1D(7)(branch3)
branch3 = layers.SeparableConv1D(32,7,activation="relu")(branch3)
branch3 = layers.GlobalMaxPool1D()(branch3)
concat = layers.Concatenate()([branch1,branch2,branch3])
outputs = layers.Dense(1,activation = "sigmoid")(concat)
model = models.Model(inputs = inputs,outputs = outputs)
model.compile(optimizer="Nadam",
              loss="binary_crossentropy",
              metrics=["accuracy", "AUC"])
model.summary()
```
```sh
Model: "model"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) [(None, 200)] 0
__________________________________________________________________________________________________
embedding (Embedding) (None, 200, 7) 70000 input_1[0][0]
__________________________________________________________________________________________________
separable_conv1d (SeparableConv (None, 198, 64) 533 embedding[0][0]
__________________________________________________________________________________________________
separable_conv1d_2 (SeparableCo (None, 196, 64) 547 embedding[0][0]
__________________________________________________________________________________________________
separable_conv1d_4 (SeparableCo (None, 194, 64) 561 embedding[0][0]
__________________________________________________________________________________________________
max_pooling1d (MaxPooling1D) (None, 66, 64) 0 separable_conv1d[0][0]
__________________________________________________________________________________________________
max_pooling1d_1 (MaxPooling1D) (None, 39, 64) 0 separable_conv1d_2[0][0]
__________________________________________________________________________________________________
max_pooling1d_2 (MaxPooling1D) (None, 27, 64) 0 separable_conv1d_4[0][0]
__________________________________________________________________________________________________
separable_conv1d_1 (SeparableCo (None, 64, 32) 2272 max_pooling1d[0][0]
__________________________________________________________________________________________________
separable_conv1d_3 (SeparableCo (None, 35, 32) 2400 max_pooling1d_1[0][0]
__________________________________________________________________________________________________
separable_conv1d_5 (SeparableCo (None, 21, 32) 2528 max_pooling1d_2[0][0]
__________________________________________________________________________________________________
global_max_pooling1d (GlobalMax (None, 32) 0 separable_conv1d_1[0][0]
__________________________________________________________________________________________________
global_max_pooling1d_1 (GlobalM (None, 32) 0 separable_conv1d_3[0][0]
__________________________________________________________________________________________________
global_max_pooling1d_2 (GlobalM (None, 32) 0 separable_conv1d_5[0][0]
__________________________________________________________________________________________________
concatenate (Concatenate) (None, 96) 0 global_max_pooling1d[0][0]
global_max_pooling1d_1[0][0]
global_max_pooling1d_2[0][0]
__________________________________________________________________________________________________
dense (Dense) (None, 1) 97 concatenate[0][0]
==================================================================================================
```
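The Param # column in the summary can be checked by hand: a `SeparableConv1D` layer holds `kernel_size × in_channels` depthwise weights, `in_channels × filters` pointwise weights, and `filters` biases. A quick sketch in plain Python (the helper name is ours):

```python
def separable_conv1d_params(kernel_size, in_channels, filters):
    # depthwise kernel + pointwise 1x1 kernel + bias
    return kernel_size * in_channels + in_channels * filters + filters

# First-stage branches read the 7-dim embedding with kernel sizes 3, 5, 7.
print([separable_conv1d_params(k, 7, 64) for k in (3, 5, 7)])   # [533, 547, 561]
# Second-stage branches read 64 channels and emit 32.
print([separable_conv1d_params(k, 64, 32) for k in (3, 5, 7)])  # [2272, 2400, 2528]
```

These match the summary exactly, and show why separable convolutions are cheap: the depthwise term grows with `kernel_size` alone, while an ordinary `Conv1D` would pay `kernel_size × in_channels × filters`.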