# Source code for deepcell.applications.cytoplasm_segmentation

# Copyright 2016-2023 The Van Valen Lab at the California Institute of
# Technology (Caltech), with support from the Paul Allen Family Foundation,
# Google, & National Institutes of Health (NIH) under Grant U24CA224309-01.
# All rights reserved.
#
# Licensed under a modified Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.github.com/vanvalenlab/deepcell-tf/LICENSE
#
# The Work provided may be used for non-commercial academic purposes only.
# For any other use of the Work, including commercial use, please contact:
# vanvalenlab@gmail.com
#
# Neither the name of Caltech nor the names of its contributors may be used
# to endorse or promote products derived from this software without specific
# prior written permission.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Cytoplasmic segmentation application"""


import os

import tensorflow as tf

from deepcell_toolbox.processing import histogram_normalization
from deepcell_toolbox.deep_watershed import deep_watershed

from deepcell.applications import Application


# Location of the pretrained model archive.
MODEL_PATH = (
    'https://deepcell-data.s3-us-west-1.amazonaws.com/saved-models/'
    'CytoplasmSegmentation-5.tar.gz'
)
# Expected hash of the archive at MODEL_PATH.
MODEL_HASH = '97334472f59e6d85697c563ed65969ff'


class CytoplasmSegmentation(Application):
    """Loads a :mod:`deepcell.model_zoo.panopticnet.PanopticNet` model
    for cytoplasm segmentation with pretrained weights.

    The ``predict`` method handles prep and post processing steps
    to return a labeled image.

    Example:

    .. code-block:: python

        import numpy as np
        from skimage.io import imread
        from deepcell.applications import CytoplasmSegmentation

        # Load the image
        im = imread('HeLa_cytoplasm.png')

        # Expand image dimensions to rank 4
        im = np.expand_dims(im, axis=-1)
        im = np.expand_dims(im, axis=0)

        # Create the application
        app = CytoplasmSegmentation()

        # Create the labeled image
        labeled_image = app.predict(im)

    Args:
        model (tf.keras.Model): The model to load.
            If ``None``, a pre-trained model will be downloaded.
        preprocessing_fn (function): Pre-processing function handed to
            the base ``Application``. Defaults to
            ``histogram_normalization``.
        postprocessing_fn (function): Post-processing function handed to
            the base ``Application``. Defaults to ``deep_watershed``.
    """

    #: Metadata for the dataset used to train the model
    dataset_metadata = {
        'name': 'general_cyto',
        'other': 'Pooled phase and fluorescent cytoplasm data - computationally curated'
    }

    #: Metadata for the model and training process
    model_metadata = {
        'batch_size': 16,
        'lr': 1e-4,
        'lr_decay': 0.9,
        'training_seed': 0,
        'n_epochs': 8,
        'training_steps_per_epoch': 7899 // 2,
        'validation_steps_per_epoch': 1973 // 2
    }

    def __init__(self,
                 model=None,
                 preprocessing_fn=histogram_normalization,
                 postprocessing_fn=deep_watershed):

        if model is None:
            # No model supplied: fetch the pretrained archive into the
            # Keras cache (``models`` subdirectory) and extract it.
            archive_path = tf.keras.utils.get_file(
                'CytoplasmSegmentation.tgz', MODEL_PATH,
                file_hash=MODEL_HASH,
                extract=True, cache_subdir='models'
            )
            # Drop the archive extension; the resulting path is what gets
            # loaded as the saved model.
            model_path = os.path.splitext(archive_path)[0]
            model = tf.keras.models.load_model(model_path)

        super().__init__(
            model,
            # Input shape without the leading (batch) dimension.
            model_image_shape=model.input_shape[1:],
            model_mpp=0.65,
            preprocessing_fn=preprocessing_fn,
            postprocessing_fn=postprocessing_fn,
            dataset_metadata=self.dataset_metadata,
            model_metadata=self.model_metadata)

    def predict(self,
                image,
                batch_size=4,
                image_mpp=None,
                pad_mode='reflect',
                preprocess_kwargs=None,
                postprocess_kwargs=None):
        """Generates a labeled image of the input running prediction with
        appropriate pre and post processing functions.

        Input images are required to have 4 dimensions
        ``[batch, x, y, channel]``.
        Additional empty dimensions can be added using ``np.expand_dims``.

        Args:
            image (numpy.array): Input image with shape
                ``[batch, x, y, channel]``.
            batch_size (int): Number of images to predict on per batch.
            image_mpp (float): Microns per pixel for ``image``.
            pad_mode (str): The padding mode, one of "constant" or "reflect".
            preprocess_kwargs (dict): Keyword arguments to pass to the
                pre-processing function. ``None`` means no extra arguments.
            postprocess_kwargs (dict): Keyword arguments to pass to the
                post-processing function. If ``None``, the defaults
                ``radius=10``, ``maxima_threshold=0.1``,
                ``interior_threshold=0.01``, ``exclude_border=False`` and
                ``small_objects_threshold=0`` are used. A supplied dict
                replaces these defaults entirely; it is not merged with them.

        Raises:
            ValueError: Input data must match required rank of the application,
                calculated as one dimension more (batch dimension) than expected
                by the model.

            ValueError: Input data must match required number of channels.

        Returns:
            numpy.array: Labeled image
        """
        if preprocess_kwargs is None:
            preprocess_kwargs = {}

        if postprocess_kwargs is None:
            postprocess_kwargs = {
                'radius': 10,
                'maxima_threshold': 0.1,
                'interior_threshold': 0.01,
                'exclude_border': False,
                'small_objects_threshold': 0
            }

        return self._predict_segmentation(
            image,
            batch_size=batch_size,
            image_mpp=image_mpp,
            pad_mode=pad_mode,
            preprocess_kwargs=preprocess_kwargs,
            postprocess_kwargs=postprocess_kwargs)