Source code for deepcell.layers.upsample

# Copyright 2016-2023 The Van Valen Lab at the California Institute of
# Technology (Caltech), with support from the Paul Allen Family Foundation,
# Google, & National Institutes of Health (NIH) under Grant U24CA224309-01.
# All rights reserved.
#
# Licensed under a modified Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.github.com/vanvalenlab/tf-keras-retinanet/LICENSE
#
# The Work provided may be used for non-commercial academic purposes only.
# For any other use of the Work, including commercial use, please contact:
# vanvalenlab@gmail.com
#
# Neither the name of Caltech nor the names of its contributors may be used
# to endorse or promote products derived from this software without specific
# prior written permission.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Upsampling layers"""

import tensorflow as tf
from tensorflow.python.framework import tensor_shape
from tensorflow.keras.layers import Layer
from tensorflow.keras import backend as K
from keras.utils import conv_utils


class UpsampleLike(Layer):
    """Layer for upsampling a Tensor to be the same shape as another Tensor.

    Adapted from https://github.com/fizyr/keras-retinanet.

    Args:
        data_format (str): A string, one of ``channels_last`` (default)
            or ``channels_first``. The ordering of the dimensions in the
            inputs. ``channels_last`` corresponds to inputs with shape
            ``(batch, height, width, channels)`` while ``channels_first``
            corresponds to inputs with shape
            ``(batch, channels, height, width)``.
    """

    def __init__(self, data_format=None, **kwargs):
        super().__init__(**kwargs)
        self.data_format = conv_utils.normalize_data_format(data_format)

    def _resize_drop_axis(self, image, size, axis):
        # Resize the two spatial axes of a 5D tensor other than ``axis`` by
        # temporarily dropping ``axis`` (absorbing it into the channel
        # dimension) so that tf.image.resize, which expects 4D inputs,
        # can be applied.
        image_shape = tf.shape(image)

        # 4D shape: ``axis`` is dropped and its size is absorbed into the
        # channel dimension.
        new_shape = []
        axes_resized = list(set([0, 1, 2, 3, 4]) - set([0, 4, axis]))
        for ax in range(K.ndim(image) - 1):
            if ax != axis:
                new_shape.append(image_shape[ax])
            if ax == 3:
                new_shape.append(image_shape[-1] * image_shape[axis])

        # 5D shape: the two resized axes take their new sizes, all other
        # axes keep their original sizes.
        new_shape_2 = []
        for ax in range(K.ndim(image)):
            if ax == 0 or ax == 4 or ax == axis:
                new_shape_2.append(image_shape[ax])
            elif ax == axes_resized[0]:
                new_shape_2.append(size[0])
            elif ax == axes_resized[1]:
                new_shape_2.append(size[1])

        new_image = tf.reshape(image, new_shape)
        new_image_resized = tf.image.resize(
            new_image, size,
            method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

        new_image_2 = tf.reshape(new_image_resized, new_shape_2)

        return new_image_2
    def resize_volumes(self, volume, size):
        # Resize a 5D tensor (3D volume) with two successive 2D resizes,
        # each holding one spatial axis fixed.
        # TODO: K.resize_volumes?
        if self.data_format == 'channels_first':
            volume = tf.transpose(volume, (0, 2, 3, 4, 1))
            new_size = (size[2], size[3], size[4])
        else:
            new_size = (size[1], size[2], size[3])

        new_shape_0 = (new_size[1], new_size[2])
        new_shape_1 = (new_size[0], new_size[1])

        resized_volume = self._resize_drop_axis(volume, new_shape_0, axis=1)
        resized_volume = self._resize_drop_axis(resized_volume, new_shape_1,
                                                axis=3)

        new_shape_static = [None, None, None, None, volume.get_shape()[-1]]
        resized_volume.set_shape(new_shape_static)

        if self.data_format == 'channels_first':
            resized_volume = tf.transpose(resized_volume, (0, 4, 1, 2, 3))

        return resized_volume
    def call(self, inputs, **kwargs):
        source, target = inputs
        target_shape = K.shape(target)
        if source.get_shape().ndims == 4:
            if self.data_format == 'channels_first':
                source = tf.transpose(source, (0, 2, 3, 1))
                new_shape = (target_shape[2], target_shape[3])
                # TODO: K.resize_images?
                output = tf.image.resize(
                    source, new_shape,
                    method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
                output = tf.transpose(output, (0, 3, 1, 2))
                return output
            new_shape = (target_shape[1], target_shape[2])
            return tf.image.resize(
                source, new_shape,
                method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

        if source.get_shape().ndims == 5:
            output = self.resize_volumes(source, target_shape)
            return output
        else:
            raise ValueError('Expected input[0] to have ndim of 4 or 5, '
                             'found %s.' % source.get_shape().ndims)
    def compute_output_shape(self, input_shape):
        in_0 = tensor_shape.TensorShape(input_shape[0]).as_list()
        in_1 = tensor_shape.TensorShape(input_shape[1]).as_list()
        if self.data_format == 'channels_first':
            return tensor_shape.TensorShape([in_0[0], in_0[1]] + in_1[2:])
        return tensor_shape.TensorShape([in_0[0]] + in_1[1:-1] + [in_0[-1]])
    def get_config(self):
        config = {'data_format': self.data_format}
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))
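
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): shows how UpsampleLike
# resizes a source tensor to the spatial shape of a target tensor, for both
# 4D (2D image) and 5D (3D volume) inputs. The tensor names and shapes below
# are illustrative assumptions, not values taken from the deepcell library.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    # 4D case: an 8x8 feature map upsampled to match a 32x32 target.
    source_2d = tf.random.normal((1, 8, 8, 16))
    target_2d = tf.random.normal((1, 32, 32, 3))
    upsampled_2d = UpsampleLike()([source_2d, target_2d])
    print(upsampled_2d.shape)  # (1, 32, 32, 16)

    # 5D case: a 4x8x8 volume upsampled to match an 8x32x32 target.
    source_3d = tf.random.normal((1, 4, 8, 8, 16))
    target_3d = tf.random.normal((1, 8, 32, 32, 3))
    upsampled_3d = UpsampleLike()([source_3d, target_3d])
    print(upsampled_3d.shape)  # (1, 8, 32, 32, 16)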