Building Custom Networks
APIs used: keras.Sequential, Layer/Model
1. keras.Sequential
Earlier code has already used this interface many times, so here is the code directly:
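import tensorflow as tf
from tensorflow.keras import Sequential, layers

model = Sequential([
    layers.Dense(256, activation=tf.nn.relu),  # [b, 784] ==> [b, 256]
    layers.Dense(128, activation=tf.nn.relu),  # [b, 256] ==> [b, 128]
    layers.Dense(64, activation=tf.nn.relu),   # [b, 128] ==> [b, 64]
    layers.Dense(32, activation=tf.nn.relu),   # [b, 64]  ==> [b, 32]
    layers.Dense(10)                           # [b, 32]  ==> [b, 10] logits
])
model.build(input_shape=[None, 28*28])
model.summary()

Sequential also exposes APIs for managing parameters, such as model.trainable_variables and model.call(). The former returns all trainable parameters in the network. The latter runs the forward pass by applying each layer in turn, and is normally invoked implicitly through model(x).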
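A quick sketch of the two APIs just mentioned, assuming TensorFlow 2.x with tf.keras. The random batch x is hypothetical stand-in data. The sketch counts the trainable tensors (each Dense layer owns a kernel and a bias) and checks that calling the model gives the same result as feeding the input through model.layers one at a time.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import Sequential, layers

model = Sequential([
    layers.Dense(256, activation=tf.nn.relu),
    layers.Dense(128, activation=tf.nn.relu),
    layers.Dense(64, activation=tf.nn.relu),
    layers.Dense(32, activation=tf.nn.relu),
    layers.Dense(10),
])
model.build(input_shape=[None, 28 * 28])

# 5 Dense layers x (kernel + bias) = 10 trainable tensors
num_vars = len(model.trainable_variables)
param_shapes = [tuple(v.shape) for v in model.trainable_variables]

# Hypothetical input batch: 4 flattened 28x28 images
x = tf.random.normal([4, 28 * 28])
out = model(x)  # model(x) dispatches to model.call(x)

# Applying each layer by hand should reproduce the same output
h = x
for layer in model.layers:
    h = layer(h)
same = np.allclose(out.numpy(), h.numpy(), atol=1e-6)

print(num_vars, param_shapes[0], tuple(out.shape), same)
```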
2. Layer/Model
The full path of Layer is keras.layers.Layer, and the full path of Model is keras.Model. Model adds the compile, fit, and evaluate training methods on top of Layer.
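The relationship between the two classes can be checked directly. A minimal sketch, assuming TensorFlow 2.x: keras.Model is itself a subclass of keras.layers.Layer, and the training-loop methods exist only on the Model class.

```python
from tensorflow import keras

# Every Model is also a Layer, so Models can be nested inside other Models
is_layer = issubclass(keras.Model, keras.layers.Layer)

# compile/fit/evaluate are defined on Model, not on the base Layer class
methods = ("compile", "fit", "evaluate")
model_has = {m: hasattr(keras.Model, m) for m in methods}
layer_has = {m: hasattr(keras.layers.Layer, m) for m in methods}

print(is_layer, model_has, layer_has)
```

In practice this means a custom Layer only describes a computation, while a custom Model is also something you can train and evaluate.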
import tensorflow as tf
from tensorflow import keras

class MyDense(keras.layers.Layer):
    # Hand-written fully connected layer: out = inputs @ kernel + bias
    def __init__(self, inp_dim, outp_dim):
        super(MyDense, self).__init__()
        # add_weight replaces add_variable, a deprecated alias in tf.keras
        self.kernel = self.add_weight(name='w', shape=[inp_dim, outp_dim])
        self.bias = self.add_weight(name='b', shape=[outp_dim])

    def call(self, inputs, training=None):
        out = inputs @ self.kernel + self.bias
        return out

class MyModel(keras.Model):
    # Five stacked MyDense layers: 784 -> 256 -> 128 -> 64 -> 32 -> 10
    def __init__(self):
        super(MyModel, self).__init__()
        self.fc1 = MyDense(28*28, 256)
        self.fc2 = MyDense(256, 128)
        self.fc3 = MyDense(128, 64)
        self.fc4 = MyDense(64, 32)
        self.fc5 = MyDense(32, 10)

    def call(self, inputs, training=None):
        x = self.fc1(inputs)
        x = tf.nn.relu(x)
        x = self.fc2(x)
        x = tf.nn.relu(x)
        x = self.fc3(x)
        x = tf.nn.relu(x)
        x = self.fc4(x)
        x = tf.nn.relu(x)
        x = self.fc5(x)  # logits, no activation on the last layer
        return x
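A usage sketch for the custom classes, assuming TensorFlow 2.x. To stay self-contained it redefines MyDense and MyModel compactly: the five layers are kept in a list and built in a loop, which is equivalent to fc1 to fc5 above. It uses add_weight as an assumed replacement for add_variable. The data is random and purely hypothetical, standing in for flattened MNIST. The sketch shows that a keras.Model subclass gets compile, fit, and evaluate for free.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras

class MyDense(keras.layers.Layer):
    def __init__(self, inp_dim, outp_dim):
        super().__init__()
        # add_weight used in place of add_variable (assumption: TF 2.x API)
        self.kernel = self.add_weight(name="w", shape=[inp_dim, outp_dim])
        self.bias = self.add_weight(name="b", shape=[outp_dim])

    def call(self, inputs, training=None):
        return inputs @ self.kernel + self.bias

class MyModel(keras.Model):
    def __init__(self):
        super().__init__()
        dims = [28 * 28, 256, 128, 64, 32, 10]
        # Same five layers as fc1..fc5, stored in a (tracked) list
        self.fcs = [MyDense(i, o) for i, o in zip(dims[:-1], dims[1:])]

    def call(self, inputs, training=None):
        x = inputs
        for fc in self.fcs[:-1]:
            x = tf.nn.relu(fc(x))  # ReLU after every hidden layer
        return self.fcs[-1](x)     # logits from the last layer

# Hypothetical random data shaped like flattened MNIST
x = np.random.rand(64, 28 * 28).astype("float32")
y = np.random.randint(0, 10, size=(64,))

model = MyModel()
logits = model(x[:4])  # forward pass on a small batch

model.compile(optimizer=keras.optimizers.Adam(1e-3),
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=["accuracy"])
model.fit(x, y, batch_size=32, epochs=1, verbose=0)
loss, acc = model.evaluate(x, y, verbose=0)

print(tuple(logits.shape), len(model.trainable_variables), loss, acc)
```

Because the last layer outputs raw logits, the loss is configured with from_logits=True rather than adding a softmax inside call.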