# Copyright 2016-2023 The Van Valen Lab at the California Institute of
# Technology (Caltech), with support from the Paul Allen Family Foundation,
# Google, & National Institutes of Health (NIH) under Grant U24CA224309-01.
# All rights reserved.
#
# Licensed under a modified Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.github.com/vanvalenlab/deepcell-tf/LICENSE
#
# The Work provided may be used for non-commercial academic purposes only.
# For any other use of the Work, including commercial use, please contact:
# vanvalenlab@gmail.com
#
# Neither the name of Caltech nor the names of its contributors may be used
# to endorse or promote products derived from this software without specific
# prior written permission.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Layers to noramlize input images for 2D and 3D images"""
import tensorflow as tf
from tensorflow.python.framework import tensor_shape
from tensorflow.keras import backend as K
from tensorflow.keras import activations
from tensorflow.keras import constraints
from tensorflow.keras import initializers
from tensorflow.keras import regularizers
from tensorflow.keras.layers import Layer, InputSpec
from keras.utils import conv_utils
class ImageNormalization2D(Layer):
    """Image Normalization layer for 2D data.

    Applies one of the following normalizations to a rank-4 image tensor:

    * ``"std"``: subtract the local window average, then divide by the
      local window standard deviation.
    * ``"max"``: divide by the maximum of the whole input tensor, then
      subtract the local window average.
    * ``"whole_image"``: subtract the mean and divide by the standard
      deviation, both computed over the two spatial axes.
    * ``None``: return the inputs unchanged.

    The local window is a ``filter_size`` x ``filter_size`` box filter that
    is created in ``build``. The layer is non-trainable unless ``trainable``
    is passed explicitly.

    Args:
        norm_method (str): Normalization method to use, one of:
            "std", "max", "whole_image", None. Strings are matched
            case-insensitively and stored in lower case.
        filter_size (int): The length of the convolution window.
        data_format (str): A string, one of ``channels_last`` (default)
            or ``channels_first``. The ordering of the dimensions in the
            inputs. ``channels_last`` corresponds to inputs with shape
            ``(batch, height, width, channels)`` while ``channels_first``
            corresponds to inputs with shape
            ``(batch, channels, height, width)``.
        activation (function): Activation function. Stored and returned by
            ``get_config``, but not applied in ``call``.
        use_bias (bool): Whether a bias weight is created in ``build``.
            The bias is not used in ``call``.
        kernel_initializer (function): Initializer for the ``kernel``
            weights matrix. Stored for serialization only; the kernel is
            a fixed averaging filter.
        bias_initializer (function): Initializer for the bias weight.
        kernel_regularizer (function): Regularizer for the ``kernel``
            weights matrix. Stored for serialization only.
        bias_regularizer (function): Regularizer function applied to the
            bias weight.
        activity_regularizer (function): Regularizer function applied to
            the output of the layer.
        kernel_constraint (function): Constraint for the ``kernel``
            weights matrix. Stored for serialization only.
        bias_constraint (function): Constraint function applied to the
            bias weight.

    Raises:
        ValueError: If ``norm_method`` is not one of the valid modes.
    """

    def __init__(self,
                 norm_method='std',
                 filter_size=61,
                 data_format=None,
                 activation=None,
                 use_bias=False,
                 kernel_initializer='glorot_uniform',
                 bias_initializer='zeros',
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 activity_regularizer=None,
                 kernel_constraint=None,
                 bias_constraint=None,
                 **kwargs):
        # Lower-case *before* validating; otherwise the case normalization
        # is unreachable because e.g. 'STD' is rejected first.
        if isinstance(norm_method, str):
            normalized_method = norm_method.lower()
        else:
            normalized_method = norm_method

        self.valid_modes = {'std', 'max', None, 'whole_image'}
        if normalized_method not in self.valid_modes:
            raise ValueError(f'Invalid `norm_method`: "{norm_method}". '
                             f'Use one of {self.valid_modes}.')

        # Non-trainable unless the caller explicitly asks otherwise.
        kwargs.setdefault('trainable', False)

        super().__init__(
            activity_regularizer=regularizers.get(activity_regularizer),
            **kwargs)

        self.activation = activations.get(activation)
        self.use_bias = use_bias
        self.kernel_initializer = initializers.get(kernel_initializer)
        self.bias_initializer = initializers.get(bias_initializer)
        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)
        self.kernel_constraint = constraints.get(kernel_constraint)
        self.bias_constraint = constraints.get(bias_constraint)
        self.input_spec = InputSpec(ndim=4)  # hardcoded for 2D data

        self.filter_size = filter_size
        self.norm_method = normalized_method
        self.data_format = conv_utils.normalize_data_format(data_format)

        if self.data_format == 'channels_first':
            self.channel_axis = 1
        else:
            self.channel_axis = 3  # hardcoded for 2D data

    def build(self, input_shape):
        """Create the averaging kernel and, if requested, the bias weight.

        Args:
            input_shape: Shape of the inputs; must have rank 4 and a
                defined channel dimension.

        Raises:
            ValueError: If the inputs are not rank 4 or the channel
                dimension is ``None``.
        """
        input_shape = tensor_shape.TensorShape(input_shape)
        if len(input_shape) != 4:
            raise ValueError('Inputs should have rank 4, '
                             f'received input shape: {input_shape}')

        channel_axis = 1 if self.data_format == 'channels_first' else -1
        if input_shape.dims[channel_axis].value is None:
            raise ValueError('The channel dimension of the inputs '
                             'should be defined. Found `None`.')

        input_dim = int(input_shape[channel_axis])
        self.input_spec = InputSpec(ndim=4, axes={channel_axis: input_dim})

        # The kernel is built directly as a constant tensor instead of via
        # ``add_weight``. Every entry equals 1 / prod(kernel_shape), i.e.
        # 1 / (filter_size * filter_size * input_dim).
        # NOTE(review): ``_average_filter`` uses a depthwise convolution,
        # where each channel only sums filter_size ** 2 taps, so for
        # input_dim > 1 the result looks like the window mean scaled by
        # 1 / input_dim. Confirm this is intended before changing it, as
        # existing models may rely on the current scaling.
        kernel_shape = (self.filter_size, self.filter_size, input_dim, 1)
        W = K.ones(kernel_shape, dtype=self.compute_dtype)
        W = W / K.cast(K.prod(K.int_shape(W)), dtype=self.compute_dtype)
        self.kernel = W

        if self.use_bias:
            self.bias = self.add_weight(
                name='bias',
                shape=(self.filter_size, self.filter_size),
                initializer=self.bias_initializer,
                regularizer=self.bias_regularizer,
                constraint=self.bias_constraint,
                trainable=False,
                dtype=self.compute_dtype)
        else:
            self.bias = None

        self.built = True

    def compute_output_shape(self, input_shape):
        """Return the output shape, which equals the input shape."""
        input_shape = tensor_shape.TensorShape(input_shape).as_list()
        return tensor_shape.TensorShape(input_shape)

    def _average_filter(self, inputs):
        """Filter ``inputs`` with the averaging kernel (SAME padding)."""
        channels_first = self.data_format == 'channels_first'

        # Depthwise convolution on CPU is only supported for NHWC format
        if channels_first:
            inputs = K.permute_dimensions(inputs, pattern=[0, 2, 3, 1])

        outputs = tf.nn.depthwise_conv2d(inputs, self.kernel, [1, 1, 1, 1],
                                         padding='SAME', data_format='NHWC')

        # Restore the caller's channels_first layout.
        if channels_first:
            outputs = K.permute_dimensions(outputs, pattern=[0, 3, 1, 2])
        return outputs

    def _window_std_filter(self, inputs, epsilon=None):
        """Return the local window standard deviation plus ``epsilon``.

        Computed as ``sqrt(E[x^2] - E[x]^2)`` using ``_average_filter``
        for both expectations.

        Args:
            inputs: Tensor to compute the windowed deviation of.
            epsilon: Value added to the result so it can be used as a
                divisor. Defaults to ``K.epsilon()`` read at call time,
                so later ``K.set_epsilon`` calls are honored.
        """
        if epsilon is None:
            epsilon = K.epsilon()
        c1 = self._average_filter(inputs)
        c2 = self._average_filter(K.square(inputs))
        # Per the Keras backend docs, ``K.sqrt`` clips negative values to 0
        # before taking the root, so round-off cannot produce NaN here.
        return K.sqrt(c2 - c1 * c1) + epsilon

    def call(self, inputs):
        """Normalize ``inputs`` according to ``self.norm_method``.

        Args:
            inputs: Rank-4 image tensor laid out per ``data_format``.

        Returns:
            The normalized tensor, or ``inputs`` itself when
            ``norm_method`` is ``None``.

        Raises:
            NotImplementedError: If ``self.norm_method`` is not a
                supported method.
        """
        if not self.norm_method:
            return inputs

        if self.norm_method == 'whole_image':
            # Statistics are taken over the two spatial axes.
            axes = [2, 3] if self.channel_axis == 1 else [1, 2]
            outputs = inputs - K.mean(inputs, axis=axes, keepdims=True)
            return outputs / (K.std(inputs, axis=axes, keepdims=True) + K.epsilon())

        if self.norm_method == 'std':
            outputs = inputs - self._average_filter(inputs)
            return outputs / self._window_std_filter(outputs)

        if self.norm_method == 'max':
            # divide_no_nan yields 0 instead of NaN/inf when the maximum
            # of the input tensor is exactly 0 (e.g. an all-zero batch).
            outputs = tf.math.divide_no_nan(inputs, K.max(inputs))
            return outputs - self._average_filter(outputs)

        raise NotImplementedError(f'"{self.norm_method}" is not a valid norm_method')

    def get_config(self):
        """Return the layer configuration used for serialization."""
        config = {
            'norm_method': self.norm_method,
            'filter_size': self.filter_size,
            'data_format': self.data_format,
            'activation': activations.serialize(self.activation),
            'use_bias': self.use_bias,
            'kernel_initializer': initializers.serialize(self.kernel_initializer),
            'bias_initializer': initializers.serialize(self.bias_initializer),
            'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
            'bias_regularizer': regularizers.serialize(self.bias_regularizer),
            'activity_regularizer': regularizers.serialize(self.activity_regularizer),
            'kernel_constraint': constraints.serialize(self.kernel_constraint),
            'bias_constraint': constraints.serialize(self.bias_constraint)
        }
        base_config = super().get_config()
        # Layer-specific entries take precedence over the base entries.
        return {**base_config, **config}
class ImageNormalization3D(Layer):
    """Image Normalization layer for 3D data.

    Applies one of the following normalizations to a rank-5 image tensor:

    * ``"std"``: subtract the local window average, then divide by the
      local window standard deviation.
    * ``"max"``: divide by the maximum of the whole input tensor, then
      subtract the local window average.
    * ``"whole_image"``: subtract the mean and divide by the standard
      deviation, both computed over the height and width axes.
    * ``None``: return the inputs unchanged.

    The local window is a box filter spanning the full depth and
    ``filter_size`` x ``filter_size`` pixels; it is created in ``build``.
    The layer is non-trainable unless ``trainable`` is passed explicitly.

    Args:
        norm_method (str): Normalization method to use, one of:
            "std", "max", "whole_image", None. Strings are matched
            case-insensitively and stored in lower case.
        filter_size (int): The length of the convolution window along
            height and width.
        data_format (str): A string, one of ``channels_last`` (default)
            or ``channels_first``. The ordering of the dimensions in the
            inputs. ``channels_last`` corresponds to inputs with shape
            ``(batch, depth, height, width, channels)`` while
            ``channels_first`` corresponds to inputs with shape
            ``(batch, channels, depth, height, width)``.
        activation (function): Activation function. Stored and returned by
            ``get_config``, but not applied in ``call``.
        use_bias (bool): Whether a bias weight is created in ``build``.
            The bias is not used in ``call``.
        kernel_initializer (function): Initializer for the ``kernel``
            weights matrix. Stored for serialization only; the kernel is
            a fixed averaging filter.
        bias_initializer (function): Initializer for the bias weight.
        kernel_regularizer (function): Regularizer for the ``kernel``
            weights matrix. Stored for serialization only.
        bias_regularizer (function): Regularizer function applied to the
            bias weight.
        activity_regularizer (function): Regularizer function applied to
            the output of the layer.
        kernel_constraint (function): Constraint for the ``kernel``
            weights matrix. Stored for serialization only.
        bias_constraint (function): Constraint function applied to the
            bias weight.

    Raises:
        ValueError: If ``norm_method`` is not one of the valid modes.
    """

    def __init__(self,
                 norm_method='std',
                 filter_size=61,
                 data_format=None,
                 activation=None,
                 use_bias=False,
                 kernel_initializer='glorot_uniform',
                 bias_initializer='zeros',
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 activity_regularizer=None,
                 kernel_constraint=None,
                 bias_constraint=None,
                 **kwargs):
        # Lower-case *before* validating; otherwise the case normalization
        # is unreachable because e.g. 'STD' is rejected first.
        if isinstance(norm_method, str):
            normalized_method = norm_method.lower()
        else:
            normalized_method = norm_method

        self.valid_modes = {'std', 'max', None, 'whole_image'}
        if normalized_method not in self.valid_modes:
            raise ValueError(f'Invalid `norm_method`: "{norm_method}". '
                             f'Use one of {self.valid_modes}.')

        # Non-trainable unless the caller explicitly asks otherwise.
        kwargs.setdefault('trainable', False)

        super().__init__(
            activity_regularizer=regularizers.get(activity_regularizer),
            **kwargs)

        self.activation = activations.get(activation)
        self.use_bias = use_bias
        self.kernel_initializer = initializers.get(kernel_initializer)
        self.bias_initializer = initializers.get(bias_initializer)
        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)
        self.kernel_constraint = constraints.get(kernel_constraint)
        self.bias_constraint = constraints.get(bias_constraint)
        self.input_spec = InputSpec(ndim=5)  # hardcoded for 3D data

        self.filter_size = filter_size
        self.norm_method = normalized_method
        self.data_format = conv_utils.normalize_data_format(data_format)

        if self.data_format == 'channels_first':
            self.channel_axis = 1
        else:
            self.channel_axis = 4  # hardcoded for 3D data

    def build(self, input_shape):
        """Create the averaging kernel and, if requested, the bias weight.

        Args:
            input_shape: Shape of the inputs; must have rank 5 with
                defined channel and depth dimensions.

        Raises:
            ValueError: If the inputs are not rank 5, or the channel or
                depth dimension is ``None``.
        """
        input_shape = tensor_shape.TensorShape(input_shape)
        if len(input_shape) != 5:
            raise ValueError('Inputs should have rank 5, '
                             f'received input shape: {input_shape}')

        channels_first = self.data_format == 'channels_first'
        channel_axis = 1 if channels_first else -1
        if input_shape.dims[channel_axis].value is None:
            raise ValueError('The channel dimension of the inputs '
                             f'should be defined, found None: {input_shape}')

        input_dim = int(input_shape[channel_axis])
        self.input_spec = InputSpec(ndim=5, axes={channel_axis: input_dim})

        # The kernel spans the whole depth axis, so depth must be known;
        # fail with a clear message rather than an opaque int(None) error.
        depth_axis = 2 if channels_first else 1
        if input_shape.dims[depth_axis].value is None:
            raise ValueError('The depth dimension of the inputs '
                             f'should be defined, found None: {input_shape}')
        depth = int(input_shape[depth_axis])

        # The kernel is built directly as a constant tensor instead of via
        # ``add_weight``. Every entry equals 1 / prod(kernel_shape); the
        # trailing dimension of 1 is the number of output channels of the
        # ``conv3d`` in ``_average_filter``.
        kernel_shape = (depth, self.filter_size, self.filter_size, input_dim, 1)
        W = K.ones(kernel_shape, dtype=self.compute_dtype)
        W = W / K.cast(K.prod(K.int_shape(W)), dtype=self.compute_dtype)
        self.kernel = W

        if self.use_bias:
            self.bias = self.add_weight(
                name='bias',
                shape=(depth, self.filter_size, self.filter_size),
                initializer=self.bias_initializer,
                regularizer=self.bias_regularizer,
                constraint=self.bias_constraint,
                trainable=False,
                dtype=self.compute_dtype)
        else:
            self.bias = None

        self.built = True

    def compute_output_shape(self, input_shape):
        """Return the output shape, which equals the input shape."""
        input_shape = tensor_shape.TensorShape(input_shape).as_list()
        return tensor_shape.TensorShape(input_shape)

    def _average_filter(self, inputs):
        """Filter ``inputs`` with the averaging kernel (SAME padding)."""
        channels_first = self.data_format == 'channels_first'

        # The convolution below is run in NDHWC layout.
        if channels_first:
            inputs = K.permute_dimensions(inputs, pattern=[0, 2, 3, 4, 1])

        # TODO: conv3d vs depthwise_conv2d?
        outputs = tf.nn.conv3d(inputs, self.kernel, [1, 1, 1, 1, 1],
                               padding='SAME', data_format='NDHWC')

        # Restore the caller's channels_first layout.
        if channels_first:
            outputs = K.permute_dimensions(outputs, pattern=[0, 4, 1, 2, 3])
        return outputs

    def _window_std_filter(self, inputs, epsilon=None):
        """Return the local window standard deviation plus ``epsilon``.

        Computed as ``sqrt(E[x^2] - E[x]^2)`` using ``_average_filter``
        for both expectations.

        Args:
            inputs: Tensor to compute the windowed deviation of.
            epsilon: Value added to the result so it can be used as a
                divisor. Defaults to ``K.epsilon()`` read at call time,
                so later ``K.set_epsilon`` calls are honored.
        """
        if epsilon is None:
            epsilon = K.epsilon()
        c1 = self._average_filter(inputs)
        c2 = self._average_filter(K.square(inputs))
        # Per the Keras backend docs, ``K.sqrt`` clips negative values to 0
        # before taking the root, so round-off cannot produce NaN here.
        return K.sqrt(c2 - c1 * c1) + epsilon

    def call(self, inputs):
        """Normalize ``inputs`` according to ``self.norm_method``.

        Args:
            inputs: Rank-5 image tensor laid out per ``data_format``.

        Returns:
            The normalized tensor, or ``inputs`` itself when
            ``norm_method`` is ``None``.

        Raises:
            NotImplementedError: If ``self.norm_method`` is not a
                supported method.
        """
        if not self.norm_method:
            return inputs

        if self.norm_method == 'whole_image':
            # Statistics are taken over height and width only, i.e. the
            # depth axis is not reduced.
            axes = [3, 4] if self.channel_axis == 1 else [2, 3]
            outputs = inputs - K.mean(inputs, axis=axes, keepdims=True)
            return outputs / (K.std(inputs, axis=axes, keepdims=True) + K.epsilon())

        if self.norm_method == 'std':
            outputs = inputs - self._average_filter(inputs)
            return outputs / self._window_std_filter(outputs)

        if self.norm_method == 'max':
            # divide_no_nan yields 0 instead of NaN/inf when the maximum
            # of the input tensor is exactly 0 (e.g. an all-zero batch).
            outputs = tf.math.divide_no_nan(inputs, K.max(inputs))
            return outputs - self._average_filter(outputs)

        raise NotImplementedError(f'"{self.norm_method}" is not a valid norm_method')

    def get_config(self):
        """Return the layer configuration used for serialization."""
        config = {
            'norm_method': self.norm_method,
            'filter_size': self.filter_size,
            'data_format': self.data_format,
            'activation': activations.serialize(self.activation),
            'use_bias': self.use_bias,
            'kernel_initializer': initializers.serialize(self.kernel_initializer),
            'bias_initializer': initializers.serialize(self.bias_initializer),
            'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
            'bias_regularizer': regularizers.serialize(self.bias_regularizer),
            'activity_regularizer': regularizers.serialize(self.activity_regularizer),
            'kernel_constraint': constraints.serialize(self.kernel_constraint),
            'bias_constraint': constraints.serialize(self.bias_constraint)
        }
        base_config = super().get_config()
        # Layer-specific entries take precedence over the base entries.
        return {**base_config, **config}