# Copyright 2016-2023 The Van Valen Lab at the California Institute of
# Technology (Caltech), with support from the Paul Allen Family Foundation,
# Google, & National Institutes of Health (NIH) under Grant U24CA224309-01.
# All rights reserved.
#
# Licensed under a modified Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.github.com/vanvalenlab/deepcell-tf/LICENSE
#
# The Work provided may be used for non-commercial academic purposes only.
# For any other use of the Work, including commercial use, please contact:
# vanvalenlab@gmail.com
#
# Neither the name of Caltech nor the names of its contributors may be used
# to endorse or promote products derived from this software without specific
# prior written permission.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions for creating model backbones"""
import copy
import os
import tempfile

from tensorflow.keras import backend as K
from tensorflow.keras import applications
from tensorflow.keras.backend import is_keras_tensor
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, Conv3D, BatchNormalization
from tensorflow.keras.layers import Activation, MaxPool2D, MaxPool3D
from tensorflow.keras.layers import TimeDistributed
from tensorflow.keras.utils import get_source_inputs
def featurenet_block(x, n_filters):
    """Add a set of layers that make up one unit of the featurenet backbone.

    One unit is two 3x3 same-padded convolutions (each followed by batch
    normalization and ReLU) and a final 2x2 max pool that halves the
    spatial dimensions.

    Args:
        x (tensorflow.keras.Layer): Keras layer object to pass to
            backbone unit
        n_filters (int): Number of filters to use for convolutional layers

    Returns:
        tensorflow.keras.Layer: Keras layer object
    """
    df = K.image_data_format()
    # Normalize over the channel axis, which depends on the data format.
    # (Previously hard-coded to -1, which is wrong for 'channels_first'.)
    channel_axis = 1 if df == 'channels_first' else -1

    # conv set 1
    x = Conv2D(n_filters, (3, 3), strides=(1, 1), padding='same', data_format=df)(x)
    x = BatchNormalization(axis=channel_axis)(x)
    x = Activation('relu')(x)
    # conv set 2
    x = Conv2D(n_filters, (3, 3), strides=(1, 1), padding='same', data_format=df)(x)
    x = BatchNormalization(axis=channel_axis)(x)
    x = Activation('relu')(x)
    # Final max pooling stage
    x = MaxPool2D(pool_size=(2, 2), padding='same', data_format=df)(x)
    return x
def featurenet_3D_block(x, n_filters):
    """Add a set of layers that make up one unit of the featurenet 3D backbone.

    One unit is two 3x3x3 same-padded convolutions (each followed by batch
    normalization and ReLU) and a final 2x2x2 max pool that halves each
    spatial dimension.

    Args:
        x (tensorflow.keras.Layer): Keras layer object to pass to
            backbone unit
        n_filters (int): Number of filters to use for convolutional layers

    Returns:
        tensorflow.keras.Layer: Keras layer object
    """
    df = K.image_data_format()
    # Normalize over the channel axis, which depends on the data format.
    # (Previously hard-coded to -1, which is wrong for 'channels_first'.)
    channel_axis = 1 if df == 'channels_first' else -1

    # conv set 1
    x = Conv3D(n_filters, (3, 3, 3), strides=(1, 1, 1), padding='same', data_format=df)(x)
    x = BatchNormalization(axis=channel_axis)(x)
    x = Activation('relu')(x)
    # conv set 2
    x = Conv3D(n_filters, (3, 3, 3), strides=(1, 1, 1), padding='same', data_format=df)(x)
    x = BatchNormalization(axis=channel_axis)(x)
    x = Activation('relu')(x)
    # Final max pooling stage
    # NOTE(review): unlike the 2D block, this pool uses the default
    # padding='valid' — odd-sized inputs are truncated. Possibly
    # intentional; confirm before changing.
    x = MaxPool3D(pool_size=(2, 2, 2), data_format=df)(x)
    return x
def featurenet_backbone(input_tensor=None, input_shape=None,
                        n_filters=32, **kwargs):
    """Construct the deepcell backbone with five convolutional units.

    Args:
        input_tensor (tensor): Input tensor to specify input size.
        input_shape (tuple): Shape used to build an ``Input`` layer when
            ``input_tensor`` is not provided (or is not a Keras tensor).
        n_filters (int): Number of filters for convolutional layers.
        kwargs (dict): Other keyword arguments (currently unused).

    Returns:
        tuple: ``(model, output_dict)`` where ``model`` is the backbone
        ``tensorflow.keras.Model`` and ``output_dict`` maps the pyramid
        level names ``'C1'``-``'C5'`` to the backbone feature layers.
    """
    if input_tensor is None:
        img_input = Input(shape=input_shape)
    else:
        if not is_keras_tensor(input_tensor):
            img_input = Input(tensor=input_tensor, shape=input_shape)
        else:
            img_input = input_tensor

    # Build out backbone; each block halves the spatial dimensions.
    c1 = featurenet_block(img_input, n_filters)  # 1/2 64x64
    c2 = featurenet_block(c1, n_filters)  # 1/4 32x32
    c3 = featurenet_block(c2, n_filters)  # 1/8 16x16
    c4 = featurenet_block(c3, n_filters)  # 1/16 8x8
    c5 = featurenet_block(c4, n_filters)  # 1/32 4x4

    backbone_features = [c1, c2, c3, c4, c5]
    backbone_names = ['C1', 'C2', 'C3', 'C4', 'C5']
    output_dict = dict(zip(backbone_names, backbone_features))

    if input_tensor is not None:
        inputs = get_source_inputs(input_tensor)
    else:
        inputs = img_input

    model = Model(inputs=inputs, outputs=backbone_features)
    return model, output_dict
def featurenet_3D_backbone(input_tensor=None, input_shape=None,
                           n_filters=32, **kwargs):
    """Construct the deepcell 3D backbone with five convolutional units.

    Args:
        input_tensor (tensor): Input tensor to specify input size.
        input_shape (tuple): Shape used to build an ``Input`` layer when
            ``input_tensor`` is not provided (or is not a Keras tensor).
        n_filters (int): Number of filters for convolutional layers.
        kwargs (dict): Other keyword arguments (currently unused).

    Returns:
        tuple: ``(model, output_dict)`` where ``model`` is the backbone
        ``tensorflow.keras.Model`` and ``output_dict`` maps the pyramid
        level names ``'C1'``-``'C5'`` to the backbone feature layers.
    """
    if input_tensor is None:
        img_input = Input(shape=input_shape)
    else:
        if not is_keras_tensor(input_tensor):
            img_input = Input(tensor=input_tensor, shape=input_shape)
        else:
            img_input = input_tensor

    # Build out backbone; each block halves the spatial dimensions.
    c1 = featurenet_3D_block(img_input, n_filters)  # 1/2 64x64
    c2 = featurenet_3D_block(c1, n_filters)  # 1/4 32x32
    c3 = featurenet_3D_block(c2, n_filters)  # 1/8 16x16
    c4 = featurenet_3D_block(c3, n_filters)  # 1/16 8x8
    c5 = featurenet_3D_block(c4, n_filters)  # 1/32 4x4

    backbone_features = [c1, c2, c3, c4, c5]
    backbone_names = ['C1', 'C2', 'C3', 'C4', 'C5']
    output_dict = dict(zip(backbone_names, backbone_features))

    if input_tensor is not None:
        inputs = get_source_inputs(input_tensor)
    else:
        inputs = img_input

    model = Model(inputs=inputs, outputs=backbone_features)
    return model, output_dict
def _load_pretrained_weights(model, model_cls, weights_kwargs):
    """Copy pre-trained weights into ``model`` by layer name.

    Instantiates a second copy of the backbone with pre-trained weights,
    saves them to a temporary file, and loads them into ``model`` by name.
    The temporary file is removed automatically; the previous inline
    version left a ``model_weights.h5`` file in the working directory and
    was unsafe under concurrent calls.

    Args:
        model (tensorflow.keras.Model): Model to receive the weights.
        model_cls (callable): Constructor for the backbone model.
        weights_kwargs (dict): Keyword arguments for ``model_cls``,
            including the ``'weights'`` entry selecting the weight set.
    """
    model_with_weights = model_cls(**weights_kwargs)
    with tempfile.TemporaryDirectory() as tmpdir:
        weights_path = os.path.join(tmpdir, 'model_weights.h5')
        model_with_weights.save_weights(weights_path)
        model.load_weights(weights_path, by_name=True)


def get_backbone(backbone, input_tensor=None, input_shape=None,
                 use_imagenet=False, return_dict=True,
                 frames_per_batch=1, **kwargs):
    """Retrieve backbones for the construction of feature pyramid networks.

    Args:
        backbone (str): Name of the backbone to be retrieved.
        input_tensor (tensor): The input tensor for the backbone.
            Should have channel dimension of size 3.
        input_shape (tuple): Shape used to build an ``Input`` layer when
            ``input_tensor`` is not provided.
        use_imagenet (bool): Load pre-trained weights for the backbone.
        return_dict (bool): Whether to return a dictionary of backbone layers,
            e.g. ``{'C1': C1, 'C2': C2, 'C3': C3, 'C4': C4, 'C5': C5}``.
            If false, the whole model is returned instead.
        frames_per_batch (int): Size of the frames (time) axis. If greater
            than 1, the 2D backbone outputs are wrapped in
            ``TimeDistributed`` layers so they can consume 3D (movie) data.
        kwargs (dict): Keyword dictionary for backbone constructions.
            Relevant keys include ``'include_top'``,
            ``'weights'`` (should be ``None``),
            ``'input_shape'``, and ``'pooling'``.

    Returns:
        tuple or tensorflow.keras.Model: ``(model, output_dict)`` when
        ``return_dict`` is true, otherwise the instantiated backbone model.

    Raises:
        ValueError: bad backbone name
        ValueError: featurenet backbone with pre-trained imagenet
    """
    _backbone = str(backbone).lower()

    featurenet_backbones = {
        'featurenet': featurenet_backbone,
        'featurenet3d': featurenet_3D_backbone,
        'featurenet_3d': featurenet_3D_backbone
    }
    vgg_backbones = {
        'vgg16': applications.vgg16.VGG16,
        'vgg19': applications.vgg19.VGG19,
    }
    densenet_backbones = {
        'densenet121': applications.densenet.DenseNet121,
        'densenet169': applications.densenet.DenseNet169,
        'densenet201': applications.densenet.DenseNet201,
    }
    mobilenet_backbones = {
        'mobilenet': applications.mobilenet.MobileNet,
        'mobilenetv2': applications.mobilenet_v2.MobileNetV2,
        'mobilenet_v2': applications.mobilenet_v2.MobileNetV2
    }
    resnet_backbones = {
        'resnet50': applications.resnet.ResNet50,
        'resnet101': applications.resnet.ResNet101,
        'resnet152': applications.resnet.ResNet152,
    }
    resnet_v2_backbones = {
        'resnet50v2': applications.resnet_v2.ResNet50V2,
        'resnet101v2': applications.resnet_v2.ResNet101V2,
        'resnet152v2': applications.resnet_v2.ResNet152V2,
    }
    nasnet_backbones = {
        'nasnet_large': applications.nasnet.NASNetLarge,
        'nasnet_mobile': applications.nasnet.NASNetMobile,
    }
    efficientnet_backbones = {
        'efficientnetb0': applications.efficientnet.EfficientNetB0,
        'efficientnetb1': applications.efficientnet.EfficientNetB1,
        'efficientnetb2': applications.efficientnet.EfficientNetB2,
        'efficientnetb3': applications.efficientnet.EfficientNetB3,
        'efficientnetb4': applications.efficientnet.EfficientNetB4,
        'efficientnetb5': applications.efficientnet.EfficientNetB5,
        'efficientnetb6': applications.efficientnet.EfficientNetB6,
        'efficientnetb7': applications.efficientnet.EfficientNetB7,
    }
    efficientnet_v2_backbones = {
        'efficientnetv2b0': applications.efficientnet_v2.EfficientNetV2B0,
        'efficientnetv2b1': applications.efficientnet_v2.EfficientNetV2B1,
        'efficientnetv2b2': applications.efficientnet_v2.EfficientNetV2B2,
        'efficientnetv2b3': applications.efficientnet_v2.EfficientNetV2B3,
        'efficientnetv2bl': applications.efficientnet_v2.EfficientNetV2L,
        'efficientnetv2bm': applications.efficientnet_v2.EfficientNetV2M,
        'efficientnetv2bs': applications.efficientnet_v2.EfficientNetV2S,
    }

    # TODO: Check and make sure **kwargs is in the right format.
    # 'weights' flag should be None, and 'input_shape' must have size 3
    # on the channel axis.
    if frames_per_batch == 1:
        if input_tensor is not None:
            img_input = input_tensor
        else:
            if input_shape:
                img_input = Input(shape=input_shape)
            else:
                img_input = Input(shape=(None, None, 3))
    else:
        # Using 3D data with a 2D backbone: a fresh 2D input is created
        # and the model is wrapped in TimeDistributed layers below.
        # TODO: why ignore input_tensor
        if input_shape:
            img_input = Input(shape=input_shape)
        else:
            img_input = Input(shape=(None, None, 3))

    if use_imagenet:
        kwargs_with_weights = copy.copy(kwargs)
        kwargs_with_weights['weights'] = 'imagenet'
    else:
        kwargs['weights'] = None

    if _backbone in featurenet_backbones:
        if use_imagenet:
            raise ValueError('A featurenet backbone that is pre-trained on '
                             'imagenet does not exist')
        model_cls = featurenet_backbones[_backbone]
        model, output_dict = model_cls(input_tensor=img_input, **kwargs)
        layer_outputs = [output_dict[name]
                         for name in ('C1', 'C2', 'C3', 'C4', 'C5')]
    elif _backbone in vgg_backbones:
        model_cls = vgg_backbones[_backbone]
        model = model_cls(input_tensor=img_input, **kwargs)
        if use_imagenet:
            _load_pretrained_weights(model, model_cls, kwargs_with_weights)
        layer_names = ['block1_pool', 'block2_pool', 'block3_pool',
                       'block4_pool', 'block5_pool']
        layer_outputs = [model.get_layer(name=ln).output for ln in layer_names]
    elif _backbone in densenet_backbones:
        model_cls = densenet_backbones[_backbone]
        model = model_cls(input_tensor=img_input, **kwargs)
        # Number of layers per dense block differs by DenseNet variant.
        blocks = {
            'densenet121': [6, 12, 24, 16],
            'densenet169': [6, 12, 32, 32],
            'densenet201': [6, 12, 48, 32],
        }[_backbone]
        if use_imagenet:
            _load_pretrained_weights(model, model_cls, kwargs_with_weights)
        layer_names = ['conv1/relu'] + [f'conv{idx + 2}_block{block_num}_concat'
                                        for idx, block_num in enumerate(blocks)]
        layer_outputs = [model.get_layer(name=ln).output for ln in layer_names]
    elif _backbone in resnet_backbones:
        model_cls = resnet_backbones[_backbone]
        model = model_cls(input_tensor=img_input, **kwargs)
        if use_imagenet:
            _load_pretrained_weights(model, model_cls, kwargs_with_weights)
        layer_names = {
            'resnet50': ['conv1_relu', 'conv2_block3_out', 'conv3_block4_out',
                         'conv4_block6_out', 'conv5_block3_out'],
            'resnet101': ['conv1_relu', 'conv2_block3_out', 'conv3_block4_out',
                          'conv4_block23_out', 'conv5_block3_out'],
            'resnet152': ['conv1_relu', 'conv2_block3_out', 'conv3_block8_out',
                          'conv4_block36_out', 'conv5_block3_out'],
        }[_backbone]
        layer_outputs = [model.get_layer(name=ln).output for ln in layer_names]
    elif _backbone in resnet_v2_backbones:
        model_cls = resnet_v2_backbones[_backbone]
        model = model_cls(input_tensor=img_input, **kwargs)
        if use_imagenet:
            _load_pretrained_weights(model, model_cls, kwargs_with_weights)
        layer_names = {
            'resnet50v2': ['post_relu', 'conv2_block3_out', 'conv3_block4_out',
                           'conv4_block6_out', 'conv5_block3_out'],
            'resnet101v2': ['post_relu', 'conv2_block3_out', 'conv3_block4_out',
                            'conv4_block23_out', 'conv5_block3_out'],
            'resnet152v2': ['post_relu', 'conv2_block3_out', 'conv3_block8_out',
                            'conv4_block36_out', 'conv5_block3_out'],
        }[_backbone]
        layer_outputs = [model.get_layer(name=ln).output for ln in layer_names]
    elif _backbone in mobilenet_backbones:
        model_cls = mobilenet_backbones[_backbone]
        alpha = kwargs.pop('alpha', 1.0)
        model = model_cls(alpha=alpha, input_tensor=img_input, **kwargs)
        if _backbone.endswith('v2'):
            block_ids = (2, 5, 12)
            layer_names = ['expanded_conv_project_BN'] + \
                          ['block_%s_add' % i for i in block_ids] + \
                          ['block_16_project_BN']
        else:
            block_ids = (1, 3, 5, 11, 13)
            layer_names = ['conv_pw_%s_relu' % i for i in block_ids]
        if use_imagenet:
            # kwargs_with_weights was copied before 'alpha' was popped above,
            # so override it here instead of passing alpha= alongside it,
            # which raised TypeError for a caller-supplied alpha.
            weights_kwargs = dict(kwargs_with_weights)
            weights_kwargs['alpha'] = alpha
            _load_pretrained_weights(model, model_cls, weights_kwargs)
        layer_outputs = [model.get_layer(name=ln).output for ln in layer_names]
    elif _backbone in nasnet_backbones:
        model_cls = nasnet_backbones[_backbone]
        model = model_cls(input_tensor=img_input, **kwargs)
        block_ids = [5, 12, 18] if _backbone.endswith('large') else [3, 8, 12]
        if use_imagenet:
            _load_pretrained_weights(model, model_cls, kwargs_with_weights)
        layer_names = ['stem_bn1', 'reduction_concat_stem_1']
        layer_names.extend(['normal_concat_%s' % i for i in block_ids])
        layer_outputs = [model.get_layer(name=ln).output for ln in layer_names]
    elif _backbone in efficientnet_backbones:
        model_cls = efficientnet_backbones[_backbone]
        model = model_cls(input_tensor=img_input, **kwargs)
        if use_imagenet:
            _load_pretrained_weights(model, model_cls, kwargs_with_weights)
        layer_names = ['block2a_expand_activation', 'block3a_expand_activation',
                       'block4a_expand_activation', 'block6a_expand_activation',
                       'top_activation']
        layer_outputs = [model.get_layer(name=ln).output for ln in layer_names]
    elif _backbone in efficientnet_v2_backbones:
        model_cls = efficientnet_v2_backbones[_backbone]
        # Disable the built-in preprocessing so the backbone consumes the
        # raw input tensor like the other backbones.
        kwargs['include_preprocessing'] = False
        model = model_cls(input_tensor=img_input, **kwargs)
        if use_imagenet:
            kwargs_with_weights['include_preprocessing'] = False
            _load_pretrained_weights(model, model_cls, kwargs_with_weights)
        layer_names = ['block1b_add', 'block2c_add',
                       'block4a_expand_activation', 'block6a_expand_activation',
                       'top_activation']
        layer_outputs = [model.get_layer(name=ln).output for ln in layer_names]
    else:
        backbones = [name for mapping in (featurenet_backbones,
                                          densenet_backbones,
                                          resnet_backbones,
                                          resnet_v2_backbones,
                                          vgg_backbones,
                                          nasnet_backbones,
                                          mobilenet_backbones,
                                          efficientnet_backbones,
                                          efficientnet_v2_backbones)
                     for name in mapping]
        raise ValueError('Invalid value for `backbone`. Must be one of: %s' %
                         ', '.join(backbones))

    if frames_per_batch > 1:
        # Wrap each backbone output in TimeDistributed so the 2D model is
        # mapped over the frames (time) axis of 3D data.
        # NOTE(review): the TimeDistributed layers are applied to
        # `input_tensor`, which appears to be required when
        # frames_per_batch > 1 — confirm callers always pass one.
        time_distributed_outputs = []
        for i, out in enumerate(layer_outputs):
            time_distributed_outputs.append(
                TimeDistributed(Model(model.input, out, name=f'model_{i}'),
                                name=f'td_{i}')(input_tensor))
        if time_distributed_outputs:
            layer_outputs = time_distributed_outputs

    output_dict = {f'C{i + 1}': j for i, j in enumerate(layer_outputs)}
    return (model, output_dict) if return_dict else model