最近要用到Tensorflow了,回顾一下。

参考:
使用tf.keras自定义模型建模后model.summary()中Param的计算过程

Model的两种实例化方式

【tf官网】Model实例化方式

在这里插入图片描述

1. 功能性API

def MyModel(input_shape):
    input1 = tf.keras.Input(shape=input_shape,name="input1")
    X = tf.keras.layers.Dense(4,activation=tf.nn.relu,name="dense1")(input1)

    model = tf.keras.Model(inputs=input1,outputs=X,name="my_model") 
    return model

2. 继承tf.keras.Model

class MyModel(tf.keras.Model):
    def __init__(self,input_shape):
        super(MyModel,self).__init__()	# 必须在首行明确
        self.input1 = tf.keras.Input(shape=input_shape,name="input1")

        self.dense1 = tf.keras.layers.Dense(4,activation=tf.nn.relu,name="dense1")

        self.out1 = self.call(self.input1)
        
        # reinitialize
        super(MyModel,self).__init__(
            inputs=self.input1,
            outputs=self.out1,
            name="my_model"
        )
    
    # 前向转播过程
    def call(self,inputs):
        """
        参数:
            input           - 输入,形状必须为 self.input_shape
        """
        x = self.dense1(inputs)
        return x

summary输出

执行以下代码:

if __name__ == '__main__':
    model = MyModel((100,))
    model.summary()

输出如下,

  1. 功能性API
Model: "my_model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input1 (InputLayer)          [(None, 100)]             0
_________________________________________________________________
dense1 (Dense)               (None, 4)                 404
=================================================================
Total params: 404
Trainable params: 404
Non-trainable params: 0
_________________________________________________________________
  1. 继承Model
Model: "my_model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input1 (InputLayer)          [(None, 100)]             0
_________________________________________________________________
dense1 (Dense)               (None, 4)                 404
=================================================================
Total params: 404
Trainable params: 404
Non-trainable params: 0
_________________________________________________________________

可以看到,summary()输出相同。


model.saveload_model

  1. 功能性API
if __name__ == '__main__':
    model = MyModel((100,))
    model.save("mymodel.h5")							# 保存模型
    model = tf.keras.models.load_model("mymodel.h5")	# 加载模型
    model.summary()
  1. 继承Model

加载模型时,需要明确custom_objects

if __name__ == '__main__':
    model = MyModel((100,))
    model.save("mymodel.h5")	# 保存模型
    # 加载模型,需要明确custom_objects
    model = tf.keras.models.load_model("mymodel.h5",custom_objects={"MyModel":MyModel})	
    model.summary()
Logo

CSDN联合极客时间,共同打造面向开发者的精品内容学习社区,助力成长!

更多推荐