Source code for deepcell.layers.pooling

# Copyright 2016-2023 The Van Valen Lab at the California Institute of
# Technology (Caltech), with support from the Paul Allen Family Foundation,
# Google, & National Institutes of Health (NIH) under Grant U24CA224309-01.
# All rights reserved.
#
# Licensed under a modified Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.github.com/vanvalenlab/deepcell-tf/LICENSE
#
# The Work provided may be used for non-commercial academic purposes only.
# For any other use of the Work, including commercial use, please contact:
# vanvalenlab@gmail.com
#
# Neither the name of Caltech nor the names of its contributors may be used
# to endorse or promote products derived from this software without specific
# prior written permission.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Layers to encode location data"""


import tensorflow as tf
from tensorflow.python.framework import tensor_shape
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Layer
from tensorflow.keras.layers import InputSpec
from keras.utils import conv_utils


class DilatedMaxPool2D(Layer):
    """Dilated max pooling layer for 2D inputs (e.g. images).

    With ``padding='same'`` the input is reflect-padded so that every
    spatial dimension of the output has length ``ceil(input / stride)``.

    Args:
        pool_size (int): An integer or tuple/list of 2 integers:
            (pool_height, pool_width) specifying the size of the pooling
            window. Can be a single integer to specify the same value for
            all spatial dimensions.
        strides (int): An integer or tuple/list of 2 integers,
            specifying the strides of the pooling operation.
            Can be a single integer to specify the same value for
            all spatial dimensions. ``None`` means a stride of 1. The
            strides are forced to 1 whenever ``dilation_rate`` is not 1.
        dilation_rate (int): An integer or tuple/list of 2 integers,
            specifying the dilation rate for the pooling.
        padding (str): The padding method, either ``"valid"`` or ``"same"``
            (case-insensitive).
        data_format (str): A string, one of ``channels_last`` (default)
            or ``channels_first``. The ordering of the dimensions in the
            inputs. ``channels_last`` corresponds to inputs with shape
            ``(batch, height, width, channels)`` while ``channels_first``
            corresponds to inputs with shape
            ``(batch, channels, height, width)``.

    Raises:
        ValueError: If ``padding`` is neither ``"valid"`` nor ``"same"``.
    """

    def __init__(self, pool_size=(2, 2), strides=None, dilation_rate=1,
                 padding='valid', data_format=None, **kwargs):
        super().__init__(**kwargs)
        self.pool_size = conv_utils.normalize_tuple(pool_size, 2, 'pool_size')
        # Normalise the dilation rate *before* inspecting it so that 1,
        # (1, 1) and [1, 1] are all recognised as "no dilation".
        self.dilation_rate = conv_utils.normalize_tuple(
            dilation_rate, 2, 'dilation_rate')
        # A dilated pooling always uses unit strides (see the tf.nn.pool
        # documentation on combining strides with dilations).
        if strides is None or any(rate != 1 for rate in self.dilation_rate):
            strides = (1, 1)
        self.strides = conv_utils.normalize_tuple(strides, 2, 'strides')
        self.padding = conv_utils.normalize_padding(padding)
        if self.padding not in ('valid', 'same'):
            # Fail here rather than with an obscure error inside call().
            raise ValueError(
                'padding must be "valid" or "same", got '
                '"{}"'.format(self.padding))
        self.data_format = conv_utils.normalize_data_format(data_format)
        self.input_spec = InputSpec(ndim=4)

    @staticmethod
    def _same_padding(length, window, stride, dilation):
        """Return ``[before, after]`` padding emulating ``'same'`` pooling.

        Args:
            length (int): Static size of the dimension; may be ``None``
                only when ``stride`` is 1.
            window (int): Pooling window size along the dimension.
            stride (int): Pooling stride along the dimension.
            dilation (int): Dilation rate along the dimension.

        Returns:
            list: Two non-negative integers ``[before, after]``.

        Raises:
            ValueError: If ``length`` is ``None`` and ``stride`` is not 1.
        """
        effective_window = (window - 1) * dilation + 1
        if stride == 1:
            # With unit stride the padding does not depend on the length.
            total = effective_window - 1
        else:
            if length is None:
                raise ValueError(
                    'padding="same" with strides > 1 requires statically '
                    'known spatial dimensions.')
            output_length = -(-length // stride)  # ceil(length / stride)
            total = max(
                (output_length - 1) * stride + effective_window - length, 0)
        before = total // 2
        return [before, total - before]

    def compute_output_shape(self, input_shape):
        """Compute the shape of the pooled output.

        Args:
            input_shape: Shape of the 4D input (anything accepted by
                ``TensorShape``).

        Returns:
            tensor_shape.TensorShape: Shape of the layer output.
        """
        input_shape = tensor_shape.TensorShape(input_shape).as_list()
        channels_first = self.data_format == 'channels_first'
        spatial = input_shape[2:4] if channels_first else input_shape[1:3]

        # Use the real padding mode: call() pads so that 'same' really
        # yields ceil(input / stride) along every spatial dimension.
        pooled = [
            conv_utils.conv_output_length(length, window,
                                          padding=self.padding,
                                          stride=stride,
                                          dilation=dilation)
            for length, window, stride, dilation in zip(
                spatial, self.pool_size, self.strides, self.dilation_rate)
        ]

        if channels_first:
            output_shape = input_shape[:2] + pooled
        else:
            output_shape = [input_shape[0]] + pooled + [input_shape[3]]
        return tensor_shape.TensorShape(output_shape)

    def call(self, inputs):
        """Apply dilated max pooling to ``inputs``.

        Args:
            inputs: 4D tensor laid out according to ``data_format``.

        Returns:
            The pooled tensor, in the same data format as ``inputs``.
        """
        channels_first = self.data_format == 'channels_first'
        if channels_first:
            # The pooling below always runs in NHWC.
            inputs = K.permute_dimensions(inputs, pattern=[0, 2, 3, 1])

        if self.padding == 'same':
            spatial = K.int_shape(inputs)[1:3]
            pattern = [[0, 0]]
            pattern.extend(
                self._same_padding(length, window, stride, dilation)
                for length, window, stride, dilation in zip(
                    spatial, self.pool_size, self.strides,
                    self.dilation_rate))
            pattern.append([0, 0])
            # Reflect padding keeps the padded values inside the range of
            # the real data. tf.pad requires each pad to be smaller than
            # the dimension it is applied to.
            inputs = tf.pad(inputs, pattern, mode='REFLECT')

        # Any 'same' padding has been applied explicitly above, so the
        # pooling itself always runs in VALID mode.
        outputs = tf.nn.pool(inputs,
                             window_shape=self.pool_size,
                             pooling_type='MAX',
                             padding='VALID',
                             dilations=self.dilation_rate,
                             strides=self.strides,
                             data_format='NHWC')

        if channels_first:
            outputs = K.permute_dimensions(outputs, pattern=[0, 3, 1, 2])
        return outputs

    def get_config(self):
        """Return the layer configuration as a dictionary."""
        config = {
            'pool_size': self.pool_size,
            'padding': self.padding,
            'dilation_rate': self.dilation_rate,
            'strides': self.strides,
            'data_format': self.data_format
        }
        base_config = super().get_config()
        return {**base_config, **config}


class DilatedMaxPool3D(Layer):
    """Dilated max pooling layer for 3D inputs.

    With ``padding='same'`` the input is reflect-padded so that every
    pooled dimension of the output has length ``ceil(input / stride)``.

    Args:
        pool_size (int): An integer or tuple/list of 3 integers:
            (pool_time, pool_height, pool_width) specifying the size of
            the pooling window. Can be a single integer to specify the
            same value for all pooled dimensions.
        strides (int): An integer or tuple/list of 3 integers,
            specifying the strides of the pooling operation.
            Can be a single integer to specify the same value for
            all pooled dimensions. ``None`` means a stride of 1. The
            strides are forced to 1 whenever ``dilation_rate`` is not 1.
        dilation_rate (int): An integer or tuple/list of 3 integers,
            specifying the dilation rate for the pooling.
        padding (str): The padding method, either ``"valid"`` or ``"same"``
            (case-insensitive).
        data_format (str): A string, one of ``channels_last`` (default)
            or ``channels_first``. The ordering of the dimensions in the
            inputs. ``channels_last`` corresponds to inputs with shape
            ``(batch, time, height, width, channels)`` while
            ``channels_first`` corresponds to inputs with shape
            ``(batch, channels, time, height, width)``.

    Raises:
        ValueError: If ``padding`` is neither ``"valid"`` nor ``"same"``.
    """

    def __init__(self, pool_size=(1, 2, 2), strides=None, dilation_rate=1,
                 padding='valid', data_format=None, **kwargs):
        super().__init__(**kwargs)
        self.pool_size = conv_utils.normalize_tuple(pool_size, 3, 'pool_size')
        # Normalise the dilation rate *before* inspecting it so that 1,
        # (1, 1, 1) and [1, 1, 1] are all recognised as "no dilation".
        self.dilation_rate = conv_utils.normalize_tuple(
            dilation_rate, 3, 'dilation_rate')
        # A dilated pooling always uses unit strides (see the tf.nn.pool
        # documentation on combining strides with dilations).
        if strides is None or any(rate != 1 for rate in self.dilation_rate):
            strides = (1, 1, 1)
        self.strides = conv_utils.normalize_tuple(strides, 3, 'strides')
        self.padding = conv_utils.normalize_padding(padding)
        if self.padding not in ('valid', 'same'):
            # Fail here rather than with an obscure error inside call().
            raise ValueError(
                'padding must be "valid" or "same", got '
                '"{}"'.format(self.padding))
        self.data_format = conv_utils.normalize_data_format(data_format)
        self.input_spec = InputSpec(ndim=5)

    @staticmethod
    def _same_padding(length, window, stride, dilation):
        """Return ``[before, after]`` padding emulating ``'same'`` pooling.

        Args:
            length (int): Static size of the dimension; may be ``None``
                only when ``stride`` is 1.
            window (int): Pooling window size along the dimension.
            stride (int): Pooling stride along the dimension.
            dilation (int): Dilation rate along the dimension.

        Returns:
            list: Two non-negative integers ``[before, after]``.

        Raises:
            ValueError: If ``length`` is ``None`` and ``stride`` is not 1.
        """
        effective_window = (window - 1) * dilation + 1
        if stride == 1:
            # With unit stride the padding does not depend on the length.
            total = effective_window - 1
        else:
            if length is None:
                raise ValueError(
                    'padding="same" with strides > 1 requires statically '
                    'known spatial dimensions.')
            output_length = -(-length // stride)  # ceil(length / stride)
            total = max(
                (output_length - 1) * stride + effective_window - length, 0)
        before = total // 2
        return [before, total - before]

    def compute_output_shape(self, input_shape):
        """Compute the shape of the pooled output.

        Args:
            input_shape: Shape of the 5D input (anything accepted by
                ``TensorShape``).

        Returns:
            tensor_shape.TensorShape: Shape of the layer output.
        """
        input_shape = tensor_shape.TensorShape(input_shape).as_list()
        channels_first = self.data_format == 'channels_first'
        pooled_dims = input_shape[2:5] if channels_first else input_shape[1:4]

        # Use the real padding mode: call() pads so that 'same' really
        # yields ceil(input / stride) along every pooled dimension.
        pooled = [
            conv_utils.conv_output_length(length, window,
                                          padding=self.padding,
                                          stride=stride,
                                          dilation=dilation)
            for length, window, stride, dilation in zip(
                pooled_dims, self.pool_size, self.strides,
                self.dilation_rate)
        ]

        if channels_first:
            output_shape = input_shape[:2] + pooled
        else:
            output_shape = [input_shape[0]] + pooled + [input_shape[4]]
        return tensor_shape.TensorShape(output_shape)

    def call(self, inputs):
        """Apply dilated max pooling to ``inputs``.

        Args:
            inputs: 5D tensor laid out according to ``data_format``.

        Returns:
            The pooled tensor, in the same data format as ``inputs``.
        """
        channels_first = self.data_format == 'channels_first'
        if channels_first:
            # The pooling below always runs in NDHWC.
            inputs = K.permute_dimensions(inputs, pattern=[0, 2, 3, 4, 1])

        if self.padding == 'same':
            pooled_dims = K.int_shape(inputs)[1:4]
            pattern = [[0, 0]]
            # zip() pairs every dimension with its own stride, window and
            # dilation, so no axis can pick up a neighbour's stride.
            pattern.extend(
                self._same_padding(length, window, stride, dilation)
                for length, window, stride, dilation in zip(
                    pooled_dims, self.pool_size, self.strides,
                    self.dilation_rate))
            pattern.append([0, 0])
            # Reflect padding keeps the padded values inside the range of
            # the real data. tf.pad requires each pad to be smaller than
            # the dimension it is applied to.
            inputs = tf.pad(inputs, pattern, mode='REFLECT')

        # Any 'same' padding has been applied explicitly above, so the
        # pooling itself always runs in VALID mode.
        outputs = tf.nn.pool(inputs,
                             window_shape=self.pool_size,
                             pooling_type='MAX',
                             padding='VALID',
                             dilations=self.dilation_rate,
                             strides=self.strides,
                             data_format='NDHWC')

        if channels_first:
            outputs = K.permute_dimensions(outputs, pattern=[0, 4, 1, 2, 3])
        return outputs

    def get_config(self):
        """Return the layer configuration as a dictionary."""
        config = {
            'pool_size': self.pool_size,
            'padding': self.padding,
            'dilation_rate': self.dilation_rate,
            'strides': self.strides,
            'data_format': self.data_format
        }
        base_config = super().get_config()
        return {**base_config, **config}