Source code for deepcell.applications.cell_tracking

# Copyright 2016-2023 The Van Valen Lab at the California Institute of
# Technology (Caltech), with support from the Paul Allen Family Foundation,
# Google, & National Institutes of Health (NIH) under Grant U24CA224309-01.
# All rights reserved.
#
# Licensed under a modified Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.github.com/vanvalenlab/deepcell-tf/LICENSE
#
# The Work provided may be used for non-commercial academic purposes only.
# For any other use of the Work, including commercial use, please contact:
# vanvalenlab@gmail.com
#
# Neither the name of Caltech nor the names of its contributors may be used
# to endorse or promote products derived from this software without specific
# prior written permission.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A model that can detect whether 2 cells are same, different, or related."""


from pathlib import Path

import tensorflow as tf

import deepcell_tracking

from deepcell.applications import Application
from deepcell.utils import fetch_data, extract_archive


MODEL_KEY = 'models/NuclearTrackingInf-75.tar.gz'
MODEL_NAME = 'NuclearTrackingInf'
MODEL_HASH = '5dbd8137be851a0c12557fcde5021444'

ENCODER_KEY = 'models/NuclearTrackingNE-75.tar.gz'
ENCODER_NAME = 'NuclearTrackingNE'
ENCODER_HASH = 'a466682c9d1d5e3672325bb8a13ab3e0'

MODEL_METADATA = {
    'batch_size': 8,
    'n_layers': 1,
    'graph_layer': 'gat',
    'epochs': 50,
    'steps_per_epoch': 1000,
    'validation_steps': 200,
    'rotation_range': 180,
    'translation_range': 512,
    'buffer_size': 128,
    'n_filters': 64,
    'embedding_dim': 64,
    'encoder_dim': 64,
    'lr': .001,
    'data_fraction': 1,
    'norm_layer': 'batch',
    'appearance_dim': 32,
    'distance_threshold': 64,
    'crop_mode': 'resize',
}

DISTANCE_THRESHOLD = 64
APPEARANCE_DIM = 32
CROP_MODE = 'resize'
NORM = True
BIRTH = 0.99
DEATH = 0.99
DIVISION = 0.01
TRACK_LENGTH = 8
MODEL_MPP = 0.65


class CellTracking(Application):
    """Loads a :mod:`deepcell.model_zoo.tracking.GNNTrackingModel` model for
    object tracking with pretrained weights using a simple ``predict`` interface.

    Args:
        model (``tf.keras.model``): Tracking inference model,
            defaults to the latest published model.
        neighborhood_encoder (``tf.keras.model``): Tracking neighborhood encoder,
            defaults to the latest published model.
        distance_threshold (int): Maximum distance between two cells to be
            considered adjacent.
        appearance_dim (int): Length of the appearance dimension.
        birth (float): Cost of new cell in linear assignment matrix.
        death (float): Cost of cell death in linear assignment matrix.
        division (float): Cost of cell division in linear assignment matrix.
        track_length (int): Number of frames per track.
        crop_mode (str): Type of cropping around each cell.
        norm (bool): Whether to normalize input data, defaults to ``True``.
    """

    #: Metadata for the dataset used to train the model
    dataset_metadata = {
        'name': 'tracked_nuclear_train_large',
        'other': 'Pooled tracked nuclear data from HEK293, HeLa-S3, NIH-3T3, '
                 'and RAW264.7 cells.'
    }

    #: Metadata for the model and training process
    model_metadata = MODEL_METADATA

    def __init__(self,
                 model=None,
                 neighborhood_encoder=None,
                 distance_threshold=DISTANCE_THRESHOLD,
                 appearance_dim=APPEARANCE_DIM,
                 birth=BIRTH,
                 death=DEATH,
                 division=DIVISION,
                 track_length=TRACK_LENGTH,
                 embedding_axis=0,
                 crop_mode=CROP_MODE,
                 norm=NORM):
        self.neighborhood_encoder = neighborhood_encoder
        self.distance_threshold = distance_threshold
        self.appearance_dim = appearance_dim
        self.birth = birth
        self.death = death
        self.division = division
        self.track_length = track_length
        self.embedding_axis = embedding_axis
        self.crop_mode = crop_mode
        self.norm = norm

        cache_subdir = "models"
        model_dir = Path.home() / ".deepcell" / "models"

        # Download and load the neighborhood encoder if one was not provided
        if self.neighborhood_encoder is None:
            archive_path = fetch_data(
                asset_key=ENCODER_KEY,
                cache_subdir=cache_subdir,
                file_hash=ENCODER_HASH
            )
            extract_archive(archive_path, model_dir)
            model_path = model_dir / ENCODER_NAME
            self.neighborhood_encoder = tf.keras.models.load_model(model_path)

        # Download and load the tracking inference model if one was not provided
        if model is None:
            archive_path = fetch_data(
                asset_key=MODEL_KEY,
                cache_subdir=cache_subdir,
                file_hash=MODEL_HASH
            )
            extract_archive(archive_path, model_dir)
            model_path = model_dir / MODEL_NAME
            model = tf.keras.models.load_model(model_path)

        super().__init__(
            model,
            model_mpp=MODEL_MPP,
            preprocessing_fn=None,
            postprocessing_fn=None,
            dataset_metadata=self.dataset_metadata,
            model_metadata=self.model_metadata)
    def predict(self, image, labels, **kwargs):
        """Using both raw image data and segmentation masks,
        track objects across all frames.

        Args:
            image (numpy.array): Raw image data.
            labels (numpy.array): Labels for ``image``, integer masks.

        Returns:
            dict: Tracked labels and lineage information.
        """
        cell_tracker = deepcell_tracking.CellTracker(
            image,
            labels,
            self.model,
            neighborhood_encoder=self.neighborhood_encoder,
            distance_threshold=self.distance_threshold,
            appearance_dim=self.appearance_dim,
            track_length=self.track_length,
            embedding_axis=self.embedding_axis,
            birth=self.birth,
            death=self.death,
            division=self.division,
            crop_mode=self.crop_mode,
            norm=self.norm)

        cell_tracker.track_cells()

        return cell_tracker._track_review_dict()
    def track(self, image, labels, **kwargs):
        """Wrapper around ``predict()`` for convenience."""
        return self.predict(image, labels, **kwargs)
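
# Minimal usage sketch (the variable names ``x`` and ``y`` and the upstream
# segmentation step are assumptions for illustration): ``x`` is a raw movie of
# shape (frames, height, width, 1) and ``y`` holds the matching integer label
# masks for each frame, e.g. produced by
# ``deepcell.applications.NuclearSegmentation``.
#
#     from deepcell.applications import CellTracking
#
#     app = CellTracking()
#     tracked = app.track(x, y)  # equivalent to app.predict(x, y)
#
# ``tracked`` is the dict described in ``predict``: tracked labels plus
# per-cell lineage information.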