# Copyright 2016-2023 The Van Valen Lab at the California Institute of
# Technology (Caltech), with support from the Paul Allen Family Foundation,
# Google, & National Institutes of Health (NIH) under Grant U24CA224309-01.
# All rights reserved.
#
# Licensed under a modified Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.github.com/vanvalenlab/deepcell-tf/LICENSE
#
# The Work provided may be used for non-commercial academic purposes only.
# For any other use of the Work, including commercial use, please contact:
# vanvalenlab@gmail.com
#
# Neither the name of Caltech nor the names of its contributors may be used
# to endorse or promote products derived from this software without specific
# prior written permission.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base class for applications"""
import logging
import timeit
import numpy as np
from deepcell_toolbox.utils import resize, tile_image, untile_image
class Application:
    """Application object that takes a model with weights
    and manages predictions.

    Args:
        model (tensorflow.keras.Model): ``tf.keras.Model``
            with loaded weights.
        model_image_shape (tuple): Shape of input expected by ``model``.
        dataset_metadata (str or dict): Metadata for the data that
            ``model`` was trained on.
        model_metadata (str or dict): Training metadata for ``model``.
        model_mpp (float): Microns per pixel resolution of the
            training data used for ``model``.
        preprocessing_fn (function): Pre-processing function to apply
            to data prior to prediction.
        postprocessing_fn (function): Post-processing function to apply
            to data after prediction.
            Must accept an input of a list of arrays and then
            return a single array.
        format_model_output_fn (function): Convert model output
            from a list of matrices to a dictionary with keys for
            each semantic head.

    Raises:
        ValueError: ``preprocessing_fn`` must be a callable function
        ValueError: ``postprocessing_fn`` must be a callable function
        ValueError: ``model_output_fn`` must be a callable function
    """

    def __init__(self,
                 model,
                 model_image_shape=(128, 128, 1),
                 model_mpp=0.65,
                 preprocessing_fn=None,
                 postprocessing_fn=None,
                 format_model_output_fn=None,
                 dataset_metadata=None,
                 model_metadata=None):
        self.model = model

        self.model_image_shape = model_image_shape
        # Require dimension 1 larger than model_input_shape due to addition
        # of batch dimension.
        self.required_rank = len(self.model_image_shape) + 1
        self.required_channels = self.model_image_shape[-1]

        self.model_mpp = model_mpp
        self.preprocessing_fn = preprocessing_fn
        self.postprocessing_fn = postprocessing_fn
        self.format_model_output_fn = format_model_output_fn
        self.dataset_metadata = dataset_metadata
        self.model_metadata = model_metadata

        self.logger = logging.getLogger(self.__class__.__name__)

        # Test that pre and post processing functions are callable
        if self.preprocessing_fn is not None and not callable(self.preprocessing_fn):
            raise ValueError('Preprocessing_fn must be a callable function.')
        if self.postprocessing_fn is not None and not callable(self.postprocessing_fn):
            raise ValueError('Postprocessing_fn must be a callable function.')
        if self.format_model_output_fn is not None and not callable(self.format_model_output_fn):
            raise ValueError('Format_model_output_fn must be a callable function.')

    def predict(self, x):
        """Entry point for applying the model; implemented by subclasses.

        Raises:
            NotImplementedError: Always — subclasses must override.
        """
        raise NotImplementedError

    def _preprocess(self, image, **kwargs):
        """Preprocess ``image`` if ``preprocessing_fn`` is defined.
        Otherwise return ``image`` unmodified.

        Args:
            image (numpy.array): 4D stack of images
            kwargs (dict): Keyword arguments for ``preprocessing_fn``.

        Returns:
            numpy.array: The pre-processed ``image``.
        """
        if self.preprocessing_fn is not None:
            t = timeit.default_timer()
            self.logger.debug('Pre-processing data with %s and kwargs: %s',
                              self.preprocessing_fn.__name__, kwargs)

            image = self.preprocessing_fn(image, **kwargs)

            self.logger.debug('Pre-processed data with %s in %s s',
                              self.preprocessing_fn.__name__,
                              timeit.default_timer() - t)

        return image

    def _postprocess(self, image, **kwargs):
        """Applies postprocessing function to image if one has been defined.
        Otherwise returns unmodified image.

        Args:
            image (numpy.array or list): Input to postprocessing function
                either an ``numpy.array`` or list of ``numpy.arrays``.

        Returns:
            numpy.array: labeled image
        """
        if self.postprocessing_fn is not None:
            t = timeit.default_timer()
            self.logger.debug('Post-processing results with %s and kwargs: %s',
                              self.postprocessing_fn.__name__, kwargs)

            image = self.postprocessing_fn(image, **kwargs)

            # Restore channel dimension if not already there
            if len(image.shape) == self.required_rank - 1:
                image = np.expand_dims(image, axis=-1)

            self.logger.debug('Post-processed results with %s in %s s',
                              self.postprocessing_fn.__name__,
                              timeit.default_timer() - t)

        elif isinstance(image, list) and len(image) == 1:
            # No postprocessing function: unwrap a single-output list.
            image = image[0]

        return image

    def _untile_output(self, output_tiles, tiles_info):
        """Untiles either a single array or a list of arrays
        according to a dictionary of tiling specs.

        Args:
            output_tiles (numpy.array or list): Array or list of arrays.
            tiles_info (dict): Tiling specs output by the tiling function.

        Returns:
            numpy.array or list: Array or list according to input with untiled images
        """
        # If padding was used, remove padding
        if tiles_info.get('padding', False):
            def _process(im, tiles_info):
                ((xl, xh), (yl, yh)) = tiles_info['x_pad'], tiles_info['y_pad']
                # Edge-case: upper-bound == 0 - this can occur when only one of
                # either X or Y is smaller than model_img_shape while the other
                # is equal to model_image_shape. A slice upper bound of -0 would
                # produce an empty array, so use None instead.
                xh = -xh if xh != 0 else None
                yh = -yh if yh != 0 else None
                return im[:, xl:xh, yl:yh, :]
        # Otherwise untile
        else:
            def _process(im, tiles_info):
                # untile_image is imported from deepcell_toolbox at module level
                out = untile_image(im, tiles_info, model_input_shape=self.model_image_shape)
                return out

        if isinstance(output_tiles, list):
            output_images = [_process(o, tiles_info) for o in output_tiles]
        else:
            output_images = _process(output_tiles, tiles_info)

        return output_images

    def _resize_output(self, image, original_shape):
        """Rescales input if the shape does not match the original shape
        excluding the batch and channel dimensions.

        Args:
            image (numpy.array): Image to be rescaled to original shape
            original_shape (tuple): Shape of the original input image

        Returns:
            numpy.array: Rescaled image
        """
        if not isinstance(image, list):
            image = [image]

        for i in range(len(image)):
            img = image[i]
            # Compare x,y based on rank of image
            if len(img.shape) == 4:
                same = img.shape[1:-1] == original_shape[1:-1]
            elif len(img.shape) == 3:
                same = img.shape[1:] == original_shape[1:-1]
            else:
                same = img.shape == original_shape[1:-1]

            # Resize if same is false
            if not same:
                # Resize function only takes the x,y dimensions for shape
                new_shape = original_shape[1:-1]
                # resize is imported from deepcell_toolbox at module level
                img = resize(img, new_shape,
                             data_format='channels_last',
                             labeled_image=True)
            image[i] = img

        # Unwrap a single output back to a bare array
        if len(image) == 1:
            image = image[0]

        return image

    def _batch_predict(self, tiles, batch_size):
        """Batch process tiles to generate model predictions.

        The built-in keras.predict function has support for batching, but
        loads the entire image stack into GPU memory, which is prohibitive
        for large images. This function uses similar code to the underlying
        model.predict function without soaking up GPU memory.

        Args:
            tiles (numpy.array): Tiled data which will be fed to model
            batch_size (int): Number of images to predict on per batch

        Returns:
            list: Model outputs
        """
        # list to hold final output
        output_tiles = []

        # loop through each batch
        for i in range(0, tiles.shape[0], batch_size):
            batch_inputs = tiles[i:i + batch_size, ...]

            batch_outputs = self.model.predict(batch_inputs, batch_size=batch_size)

            # model with only a single output gets temporarily converted to a list
            if not isinstance(batch_outputs, list):
                batch_outputs = [batch_outputs]

            # initialize output list with empty arrays to hold all batches
            if not output_tiles:
                for batch_out in batch_outputs:
                    shape = (tiles.shape[0],) + batch_out.shape[1:]
                    # NOTE(review): output buffers use the *input* dtype; if the
                    # model emits a different dtype the outputs are cast here.
                    output_tiles.append(np.zeros(shape, dtype=tiles.dtype))

            # save each batch to corresponding index in output list
            for j, batch_out in enumerate(batch_outputs):
                output_tiles[j][i:i + batch_size, ...] = batch_out

        return output_tiles

    def _run_model(self,
                   image,
                   batch_size=4,
                   pad_mode='constant',
                   preprocess_kwargs=None):
        """Run the model to generate output probabilities on the data.

        Args:
            image (numpy.array): Image with shape ``[batch, x, y, channel]``
            batch_size (int): Number of images to predict on per batch.
            pad_mode (str): The padding mode, one of "constant" or "reflect".
            preprocess_kwargs (dict): Keyword arguments to pass to
                the preprocessing function.

        Returns:
            numpy.array: Model outputs
        """
        # Use a None sentinel to avoid the shared-mutable-default pitfall.
        if preprocess_kwargs is None:
            preprocess_kwargs = {}

        # Preprocess image if function is defined
        image = self._preprocess(image, **preprocess_kwargs)

        # Tile images, raises error if the image is not 4d
        # NOTE(review): _tile_input is defined elsewhere in this class.
        tiles, tiles_info = self._tile_input(image, pad_mode=pad_mode)

        # Run images through model
        t = timeit.default_timer()
        output_tiles = self._batch_predict(tiles=tiles, batch_size=batch_size)
        self.logger.debug('Model inference finished in %s s',
                          timeit.default_timer() - t)

        # Untile images
        output_images = self._untile_output(output_tiles, tiles_info)

        # restructure outputs into a dict if function provided
        # NOTE(review): _format_model_output is defined elsewhere in this class.
        formatted_images = self._format_model_output(output_images)

        return formatted_images

    def _predict_segmentation(self,
                              image,
                              batch_size=4,
                              image_mpp=None,
                              pad_mode='constant',
                              preprocess_kwargs=None,
                              postprocess_kwargs=None):
        """Generates a labeled image of the input running prediction with
        appropriate pre and post processing functions.

        Input images are required to have 4 dimensions
        ``[batch, x, y, channel]``. Additional empty dimensions can be added
        using ``np.expand_dims``.

        Args:
            image (numpy.array): Input image with shape
                ``[batch, x, y, channel]``.
            batch_size (int): Number of images to predict on per batch.
            image_mpp (float): Microns per pixel for ``image``.
            pad_mode (str): The padding mode, one of "constant" or "reflect".
            preprocess_kwargs (dict): Keyword arguments to pass to the
                pre-processing function.
            postprocess_kwargs (dict): Keyword arguments to pass to the
                post-processing function.

        Raises:
            ValueError: Input data must match required rank, calculated as one
                dimension more (batch dimension) than expected by the model.
            ValueError: Input data must match required number of channels.

        Returns:
            numpy.array: Labeled image
        """
        # Use None sentinels to avoid the shared-mutable-default pitfall.
        if preprocess_kwargs is None:
            preprocess_kwargs = {}
        if postprocess_kwargs is None:
            postprocess_kwargs = {}

        # Check input size of image
        if len(image.shape) != self.required_rank:
            raise ValueError(f'Input data must have {self.required_rank} dimensions. '
                             f'Input data only has {len(image.shape)} dimensions')

        if image.shape[-1] != self.required_channels:
            raise ValueError(f'Input data must have {self.required_channels} channels. '
                             f'Input data only has {image.shape[-1]} channels')

        # Resize image, returns unmodified if appropriate
        # NOTE(review): _resize_input is defined elsewhere in this class.
        resized_image = self._resize_input(image, image_mpp)

        # Generate model outputs
        output_images = self._run_model(
            image=resized_image, batch_size=batch_size,
            pad_mode=pad_mode, preprocess_kwargs=preprocess_kwargs
        )

        # Postprocess predictions to create label image
        label_image = self._postprocess(output_images, **postprocess_kwargs)

        # Resize label_image back to original resolution if necessary
        label_image = self._resize_output(label_image, image.shape)

        return label_image