【Tensorflow2】tf.keras中Model实例化的方式
tf.keras中Model实例化的方式`Model`的两种实例化方式1. 功能性API2. 继承`tf.keras.Model`summary输出`model.save` 与 `load_model`最近要用到Tensorflow了,回顾一下。参考:使用tf.keras自定义模型建模后model.summary()中Param的计算过程Model的两种实例化方式【tf官网】Model实例化方式1
·
tf.keras中Model实例化的方式
最近要用到
Tensorflow
了,回顾一下。
Model
的两种实例化方式
1. 功能性API
def MyModel(input_shape):
input1 = tf.keras.Input(shape=input_shape,name="input1")
X = tf.keras.layers.Dense(4,activation=tf.nn.relu,name="dense1")(input1)
model = tf.keras.Model(inputs=input1,outputs=X,name="my_model")
return model
2. 继承tf.keras.Model
class MyModel(tf.keras.Model):
def __init__(self,input_shape):
super(MyModel,self).__init__() # 必须在首行明确
self.input1 = tf.keras.Input(shape=input_shape,name="input1")
self.dense1 = tf.keras.layers.Dense(4,activation=tf.nn.relu,name="dense1")
self.out1 = self.call(self.input1)
# reinitialize
super(MyModel,self).__init__(
inputs=self.input1,
outputs=self.out1,
name="my_model"
)
# 前向转播过程
def call(self,inputs):
"""
参数:
input - 输入,形状必须为 self.input_shape
"""
x = self.dense1(inputs)
return x
summary输出
执行以下代码:
if __name__ == '__main__':
model = MyModel((100,))
model.summary()
输出如下,
- 功能性API
Model: "my_model"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input1 (InputLayer) [(None, 100)] 0
_________________________________________________________________
dense1 (Dense) (None, 4) 404
=================================================================
Total params: 404
Trainable params: 404
Non-trainable params: 0
_________________________________________________________________
- 继承
Model
Model: "my_model"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input1 (InputLayer) [(None, 100)] 0
_________________________________________________________________
dense1 (Dense) (None, 4) 404
=================================================================
Total params: 404
Trainable params: 404
Non-trainable params: 0
_________________________________________________________________
可以看到,summary()
输出相同。
model.save
与 load_model
- 功能性API
if __name__ == '__main__':
model = MyModel((100,))
model.save("mymodel.h5") # 保存模型
model = tf.keras.models.load_model("mymodel.h5") # 加载模型
model.summary()
- 继承
Model
加载模型时,需要明确custom_objects
if __name__ == '__main__':
model = MyModel((100,))
model.save("mymodel.h5") # 保存模型
# 加载模型,需要明确custom_objects
model = tf.keras.models.load_model("mymodel.h5",custom_objects={"MyModel":MyModel})
model.summary()
更多推荐
已为社区贡献2条内容
所有评论(0)