代码:
def _conv_bn_stack(tensor, filters, repeats):
    """Apply `repeats` rounds of 3x3 Conv2D (ReLU, 'same' padding) followed by BatchNormalization."""
    for _ in range(repeats):
        tensor = Conv2D(filters, 3, activation='relu', padding='same',
                        kernel_initializer='he_normal')(tensor)
        tensor = BatchNormalization()(tensor)
    return tensor


def segnet(pretrained_weights=None, input_size=(512, 512, 3), classNum=2, learning_rate=1e-5):
    """Build a SegNet-style encoder/decoder segmentation model.

    The encoder is five stages of stacked 3x3 convolutions (each followed by
    batch normalization) with a 2x2 max-pool after every stage.  The decoder
    mirrors it: five stages, each starting with a 2x2 ``UpSampling2D`` and
    followed by stacked convolutions.  A final 1x1 convolution with a sigmoid
    activation produces a single-channel map.

    Args:
        pretrained_weights: Optional path to a weights file.  When truthy, the
            weights are loaded into the model before it is returned.
        input_size: Shape of one input sample, passed to ``Input``.  Because
            the network pools by 2 five times and upsamples by 2 five times,
            height and width should be multiples of 32 for the output to have
            the same spatial size as the input.
        classNum: Accepted for interface compatibility; not used here.
            NOTE(review): the output head is always one sigmoid channel
            (binary segmentation) -- confirm whether a multi-class head
            driven by this argument is wanted.
        learning_rate: Accepted for interface compatibility; not used here
            (the model is returned uncompiled, so the caller picks the
            optimizer and loss).

    Returns:
        An uncompiled ``Model`` mapping ``inputs`` to the sigmoid output map.
    """
    # (filters, number of conv+BN pairs) for each stage, in execution order.
    encoder_plan = ((64, 2), (128, 2), (256, 3), (512, 3), (512, 3))
    decoder_plan = ((512, 3), (512, 3), (256, 3), (128, 2), (64, 2))

    inputs = Input(input_size)
    features = inputs

    # Encoder: conv stack, then halve the spatial resolution.
    for filters, repeats in encoder_plan:
        features = _conv_bn_stack(features, filters, repeats)
        features = MaxPooling2D(pool_size=(2, 2))(features)

    # Decoder: double the spatial resolution, then conv stack.
    for filters, repeats in decoder_plan:
        features = UpSampling2D(size=(2, 2))(features)
        features = _conv_bn_stack(features, filters, repeats)

    # Output head: 1x1 convolution with a sigmoid (one channel, not a softmax).
    outputs = Conv2D(1, 1, padding='same', activation='sigmoid')(features)
    model = Model(inputs=inputs, outputs=outputs)

    # Previously this argument was silently ignored; honor it when provided.
    if pretrained_weights:
        model.load_weights(pretrained_weights)
    return model
来源: http://www.bubuko.com/infodetail-3823794.html