Source code for deepcell.applications.mesmer

# Copyright 2016-2023 The Van Valen Lab at the California Institute of
# Technology (Caltech), with support from the Paul Allen Family Foundation,
# Google, & National Institutes of Health (NIH) under Grant U24CA224309-01.
# All rights reserved.
#
# Licensed under a modified Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.github.com/vanvalenlab/deepcell-tf/LICENSE
#
# The Work provided may be used for non-commercial academic purposes only.
# For any other use of the Work, including commercial use, please contact:
# vanvalenlab@gmail.com
#
# Neither the name of Caltech nor the names of its contributors may be used
# to endorse or promote products derived from this software without specific
# prior written permission.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Mesmer application"""


from pathlib import Path

import numpy as np
import tensorflow as tf

from deepcell_toolbox.deep_watershed import deep_watershed
from deepcell_toolbox.processing import percentile_threshold
from deepcell_toolbox.processing import histogram_normalization

from deepcell.applications import Application
from deepcell.utils import fetch_data, extract_archive


MODEL_KEY = 'models/MultiplexSegmentation-9.tar.gz'
MODEL_NAME = 'MultiplexSegmentation'
MODEL_HASH = 'a1dfbce2594f927b9112f23a0a1739e0'


# pre- and post-processing functions
def mesmer_preprocess(image, **kwargs):
    """Preprocess input data for Mesmer model.

    Args:
        image: array to be processed

    Returns:
        np.array: processed image array
    """

    if len(image.shape) != 4:
        raise ValueError(f"Image data must be 4D, got image of shape {image.shape}")

    output = np.copy(image)
    threshold = kwargs.get('threshold', True)
    if threshold:
        percentile = kwargs.get('percentile', 99.9)
        output = percentile_threshold(image=output, percentile=percentile)

    normalize = kwargs.get('normalize', True)
    if normalize:
        kernel_size = kwargs.get('kernel_size', 128)
        output = histogram_normalization(image=output, kernel_size=kernel_size)

    return output


def format_output_mesmer(output_list):
    """Takes list of model outputs and formats into a dictionary for better readability

    Args:
        output_list (list): predictions from semantic heads

    Returns:
        dict: Dict of predictions for whole cell and nuclear.

    Raises:
        ValueError: if model output list is not len(4)
    """
    expected_length = 4
    if len(output_list) != expected_length:
        raise ValueError('output_list was length {}, expecting length {}'.format(
            len(output_list), expected_length))

    formatted_dict = {
        'whole-cell': [output_list[0], output_list[1][..., 1:2]],
        'nuclear': [output_list[2], output_list[3][..., 1:2]],
    }

    return formatted_dict


def mesmer_postprocess(model_output, compartment='whole-cell',
                       whole_cell_kwargs=None, nuclear_kwargs=None):
    """Postprocess Mesmer output to generate predictions for distinct cellular compartments

    Args:
        model_output (dict): Output from the Mesmer model. A dict with a key corresponding to
            each cellular compartment with a model prediction. Each key maps to a subsequent dict
            with the following keys entries
            - inner-distance: Prediction for the inner distance transform.
            - outer-distance: Prediction for the outer distance transform
            - fgbg-fg: prediction for the foreground/background transform
            - pixelwise-interior: Prediction for the interior/border/background transform.
        compartment: which cellular compartments to generate predictions for.
            must be one of 'whole_cell', 'nuclear', 'both'
        whole_cell_kwargs (dict): Optional list of post-processing kwargs for whole-cell prediction
        nuclear_kwargs (dict): Optional list of post-processing kwargs for nuclear prediction

    Returns:
        numpy.array: Uniquely labeled mask for each compartment

    Raises:
        ValueError: for invalid compartment flag
    """

    valid_compartments = ['whole-cell', 'nuclear', 'both']

    if whole_cell_kwargs is None:
        whole_cell_kwargs = {}

    if nuclear_kwargs is None:
        nuclear_kwargs = {}

    if compartment not in valid_compartments:
        raise ValueError(f'Invalid compartment supplied: {compartment}. '
                         f'Must be one of {valid_compartments}')

    if compartment == 'whole-cell':
        label_images = deep_watershed(model_output['whole-cell'],
                                      **whole_cell_kwargs)
    elif compartment == 'nuclear':
        label_images = deep_watershed(model_output['nuclear'],
                                      **nuclear_kwargs)
    elif compartment == 'both':
        label_images_cell = deep_watershed(model_output['whole-cell'],
                                           **whole_cell_kwargs)

        label_images_nucleus = deep_watershed(model_output['nuclear'],
                                              **nuclear_kwargs)

        label_images = np.concatenate([
            label_images_cell,
            label_images_nucleus
        ], axis=-1)

    else:
        raise ValueError(f'Invalid compartment supplied: {compartment}. '
                         f'Must be one of {valid_compartments}')

    return label_images


[docs] class Mesmer(Application): """Loads a :mod:`deepcell.model_zoo.panopticnet.PanopticNet` model for tissue segmentation with pretrained weights. The ``predict`` method handles prep and post processing steps to return a labeled image. Example: .. code-block:: python from skimage.io import imread from deepcell.applications import Mesmer # Load the images im1 = imread('TNBC_DNA.tiff') im2 = imread('TNBC_Membrane.tiff') # Combined together and expand to 4D im = np.stack((im1, im2), axis=-1) im = np.expand_dims(im,0) # Create the application app = Mesmer() # create the lab labeled_image = app.predict(image) Args: model (tf.keras.Model): The model to load. If ``None``, a pre-trained model will be downloaded. """ #: Metadata for the dataset used to train the model dataset_metadata = { 'name': '20200315_IF_Training_6.npz', 'other': 'Pooled whole-cell data across tissue types' } #: Metadata for the model and training process model_metadata = { 'batch_size': 1, 'lr': 1e-5, 'lr_decay': 0.99, 'training_seed': 0, 'n_epochs': 30, 'training_steps_per_epoch': 1739 // 1, 'validation_steps_per_epoch': 193 // 1 } def __init__(self, model=None): if model is None: cache_subdir = "models" model_dir = Path.home() / ".deepcell" / "models" archive_path = fetch_data( asset_key=MODEL_KEY, cache_subdir=cache_subdir, file_hash=MODEL_HASH ) extract_archive(archive_path, model_dir) model_path = model_dir / MODEL_NAME model = tf.keras.models.load_model(model_path) super().__init__( model, model_image_shape=model.input_shape[1:], model_mpp=0.5, preprocessing_fn=mesmer_preprocess, postprocessing_fn=mesmer_postprocess, format_model_output_fn=format_output_mesmer, dataset_metadata=self.dataset_metadata, model_metadata=self.model_metadata)
[docs] def predict(self, image, batch_size=4, image_mpp=None, preprocess_kwargs={}, compartment='whole-cell', pad_mode='constant', postprocess_kwargs_whole_cell={}, postprocess_kwargs_nuclear={}): """Generates a labeled image of the input running prediction with appropriate pre and post processing functions. Input images are required to have 4 dimensions ``[batch, x, y, channel]``. Additional empty dimensions can be added using ``np.expand_dims``. Args: image (numpy.array): Input image with shape ``[batch, x, y, channel]``. batch_size (int): Number of images to predict on per batch. image_mpp (float): Microns per pixel for ``image``. compartment (str): Specify type of segmentation to predict. Must be one of ``"whole-cell"``, ``"nuclear"``, ``"both"``. preprocess_kwargs (dict): Keyword arguments to pass to the pre-processing function. postprocess_kwargs (dict): Keyword arguments to pass to the post-processing function. Raises: ValueError: Input data must match required rank of the application, calculated as one dimension more (batch dimension) than expected by the model. ValueError: Input data must match required number of channels. Returns: numpy.array: Instance segmentation mask. """ default_kwargs_cell = { 'maxima_threshold': 0.075, 'maxima_smooth': 0, 'interior_threshold': 0.2, 'interior_smooth': 2, 'small_objects_threshold': 15, 'fill_holes_threshold': 15, 'radius': 2 } default_kwargs_nuc = { 'maxima_threshold': 0.1, 'maxima_smooth': 0, 'interior_threshold': 0.2, 'interior_smooth': 2, 'small_objects_threshold': 15, 'fill_holes_threshold': 15, 'radius': 2 } # overwrite defaults with any user-provided values postprocess_kwargs_whole_cell = {**default_kwargs_cell, **postprocess_kwargs_whole_cell} postprocess_kwargs_nuclear = {**default_kwargs_nuc, **postprocess_kwargs_nuclear} # create dict to hold all of the post-processing kwargs postprocess_kwargs = { 'whole_cell_kwargs': postprocess_kwargs_whole_cell, 'nuclear_kwargs': postprocess_kwargs_nuclear, 'compartment': compartment } return self._predict_segmentation(image, batch_size=batch_size, image_mpp=image_mpp, pad_mode=pad_mode, preprocess_kwargs=preprocess_kwargs, postprocess_kwargs=postprocess_kwargs)