This notebook is part of the deepcell-tf documentation: https://deepcell.readthedocs.io/.

Training a cell tracking model

[ ]:
import os

import numpy as np
import tensorflow as tf
from tensorflow.keras.callbacks import CSVLogger
from tensorflow_addons.optimizers import RectifiedAdam
import yaml

import deepcell
from deepcell.data.tracking import Track, random_rotate, random_translate, temporal_slice
from deepcell.losses import weighted_categorical_crossentropy
from deepcell.model_zoo.tracking import GNNTrackingModel
from deepcell.utils.tfrecord_utils import get_tracking_dataset, write_tracking_dataset_to_tfr
from deepcell.utils.train_utils import count_gpus, rate_scheduler
from deepcell_toolbox.metrics import Metrics
from deepcell_tracking import CellTracker
from deepcell_tracking.metrics import benchmark_tracking_performance, calculate_summary_stats
from deepcell_tracking.trk_io import load_trks
from deepcell_tracking.utils import get_max_cells, is_valid_lineage

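The DynamicNuclearNet tracking dataset can be downloaded from https://datasets.deepcell.org/. Place `train.trks`, `val.trks`, `test.trks` and `data-source.npz` in the directory assigned to `data_dir` below.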

[ ]:
# Please change these file paths to match your file system.
# `data_dir` must contain train.trks, val.trks, test.trks and data-source.npz,
# all of which are read later in this notebook.
data_dir = '/notebooks/data'

# Output locations for the trained models, metrics and training log
# (relative to the current working directory).
inf_model_path = "NuclearTrackingInf"
ne_model_path = "NuclearTrackingNE"
metrics_path = "train-metrics.yaml"
train_log_path = "train_log.csv"

prediction_dir = 'output'
# Create the prediction directory if needed. exist_ok=True avoids the
# check-then-create race and keeps this cell safe to re-run.
os.makedirs(prediction_dir, exist_ok=True)

Prepare the data for training

Tracked data are stored in .trks files, which bundle the raw movies (`X`), the label masks (`y`) and the lineage information (`lineages`). To read and write .trks files, use deepcell_tracking.trk_io.load_trks and deepcell_tracking.trk_io.save_trks.

To facilitate training, we transform each movie’s image and lineage data into a Track object. A Track encapsulates all of the feature creation for a movie, including:

  • Appearances: (num_frames, num_objects, 32, 32, 1)

  • Morphologies: (num_frames, num_objects, 32, 32, 3)

  • Centroids: (num_frames, num_objects, 2)

  • Normalized Adjacency Matrix: (num_frames, num_objects, num_objects, 3)

  • Temporal Adjacency Matrix (comparing across frames): (num_frames - 1, num_objects, num_objects, 3)

The Tracks for each split (train and val) are then written to a TFRecord file. This lets training stream data from disk and reduces the total memory footprint.

[ ]:
# Feature-extraction settings passed to every Track built below.
appearance_dim: int = 32  # side length of the per-cell appearance crops (32 x 32)
distance_threshold: int = 64  # NOTE(review): presumably a pixel distance cutoff for neighbor adjacency — confirm
crop_mode: str = "resize"  # crop strategy forwarded to deepcell.data.tracking.Track
[ ]:
# Load the train/val movies and serialize each split as a TFRecord.
# This cell may take ~20 minutes to run
train_trks = load_trks(os.path.join(data_dir, "train.trks"))
val_trks = load_trks(os.path.join(data_dir, "val.trks"))

# Both splits are written with the same `target_max_cells`: the larger of the
# two per-split maxima, so train and val share one cell dimension.
max_cells = max(get_max_cells(train_trks["y"]), get_max_cells(val_trks["y"]))

# Pair names with data explicitly, in a fixed order. Iterating a set of strings
# has no guaranteed order (it varies with hash randomization), which could
# write the validation movies into the "train" record and vice versa.
for split, trks in (("train", train_trks), ("val", val_trks)):
    print(f"Preparing {split} as tf record")

    # Feature extraction and serialization run on the CPU.
    with tf.device("/cpu:0"):
        tracks = Track(
            tracked_data=trks,
            appearance_dim=appearance_dim,
            distance_threshold=distance_threshold,
            crop_mode=crop_mode,
        )

        write_tracking_dataset_to_tfr(
            tracks, target_max_cells=max_cells, filename=split
        )

Training

Define training parameters

[ ]:
# Model architecture hyperparameters, forwarded to GNNTrackingModel.
n_layers: int = 1  # Number of graph convolution layers
n_filters: int = 64
encoder_dim: int = 64
embedding_dim: int = 64
graph_layer: str = "gat"  # graph convolution type
norm_layer: str = "batch"  # normalization layer type
[ ]:
# Data sampling and augmentation settings.
seed: int = 0  # seed for dataset shuffling
track_length: int = 8  # Number of frames per track object
rotation_range: int = 180  # forwarded to random_rotate
translation_range: int = 512  # forwarded to random_translate
buffer_size: int = 128  # shuffle buffer size
[ ]:
# Training loop configuration.
batch_size: int = 8
epochs: int = 50
steps_per_epoch: int = 1000
validation_steps: int = 200
lr: float = 1e-3  # initial learning rate for the optimizer and LR schedule

Load TFRecord Data

[ ]:
# Augmentation functions, applied per example through Dataset.map.
def sample(X, y):
    """Cut an example down to `track_length` frames with temporal_slice."""
    return temporal_slice(X, y, track_length=track_length)


def rotate(X, y):
    """Apply random_rotate using the configured `rotation_range`."""
    return random_rotate(X, y, rotation_range=rotation_range)


def translate(X, y):
    """Apply random_translate using the configured `translation_range`."""
    return random_translate(X, y, range=translation_range)


def _mapped(dataset, *fns):
    """Apply each function in `fns`, in order, with autotuned Dataset.map."""
    for fn in fns:
        dataset = dataset.map(fn, num_parallel_calls=tf.data.AUTOTUNE)
    return dataset


# Build both input pipelines on the CPU. Training examples are sliced, rotated
# and translated; validation examples are only sliced.
with tf.device("/cpu:0"):
    train_data = _mapped(
        get_tracking_dataset("train").shuffle(buffer_size, seed=seed).repeat(),
        sample,
        rotate,
        translate,
    ).batch(batch_size).prefetch(tf.data.AUTOTUNE)

    val_data = _mapped(
        get_tracking_dataset("val").shuffle(buffer_size, seed=seed).repeat(),
        sample,
    ).batch(batch_size).prefetch(tf.data.AUTOTUNE)

# Read the padded cell count back from one batch: axis 2 of "appearances"
# (batch, frames, cells, ...) is the cell axis.
first_inputs = next(iter(train_data.take(1)))[0]
max_cells = first_inputs["appearances"].shape[2]

Initialize the model

[ ]:
def filter_and_flatten(y_true, y_pred):
    """Flatten labels and predictions to rows of classes and drop padded rows.

    Both tensors are reshaped to ``(-1, n_classes)``, where ``n_classes`` is the
    size of the last axis of ``y_true``. Only rows whose ``y_true`` entries sum
    to exactly 1 are kept. Padded cells presumably carry no class and so
    sum to 0.

    Args:
        y_true: label tensor whose last axis is the class axis.
        y_pred: prediction tensor with the same total size and class axis.

    Returns:
        Tuple ``(y_true, y_pred)`` of 2-D tensors containing only the kept rows.
    """
    class_count = tf.shape(y_true)[-1]
    flat_true = tf.reshape(y_true, [-1, class_count])
    flat_pred = tf.reshape(y_pred, [-1, class_count])

    # Indices of rows that carry a real label (mask out the padded cells).
    keep_rows = tf.where(tf.reduce_sum(flat_true, axis=-1) == 1)[:, 0]

    return (
        tf.gather(flat_true, keep_rows, axis=0),
        tf.gather(flat_pred, keep_rows, axis=0),
    )


class Recall(tf.keras.metrics.Recall):
    """Keras Recall that ignores padded cells.

    Labels and predictions go through ``filter_and_flatten`` before the
    parent class accumulates them.
    """

    def update_state(self, y_true, y_pred, sample_weight=None):
        kept_true, kept_pred = filter_and_flatten(y_true, y_pred)
        # NOTE(review): sample_weight is forwarded unfiltered. That is fine while
        # it is None, but a per-element weight would no longer line up with the
        # filtered rows.
        super().update_state(kept_true, kept_pred, sample_weight)


class Precision(tf.keras.metrics.Precision):
    """Keras Precision that ignores padded cells.

    Labels and predictions go through ``filter_and_flatten`` before the
    parent class accumulates them.
    """

    def update_state(self, y_true, y_pred, sample_weight=None):
        kept_true, kept_pred = filter_and_flatten(y_true, y_pred)
        # NOTE(review): sample_weight is forwarded unfiltered. That is fine while
        # it is None, but a per-element weight would no longer line up with the
        # filtered rows.
        super().update_state(kept_true, kept_pred, sample_weight)


def loss_function(y_true, y_pred):
    """Weighted categorical cross-entropy computed over non-padded cells only.

    Padded rows are removed with ``filter_and_flatten``. The loss is then
    evaluated along the last (class) axis with as many classes as that axis
    has entries.
    """
    kept_true, kept_pred = filter_and_flatten(y_true, y_pred)
    num_classes = tf.shape(kept_true)[-1]
    return weighted_categorical_crossentropy(
        kept_true, kept_pred, n_classes=num_classes, axis=-1
    )
[ ]:
# Build and compile the tracking model under a MirroredStrategy so it can use
# every visible device.
strategy = tf.distribute.MirroredStrategy()
print(f"Number of devices: {strategy.num_replicas_in_sync}")

with strategy.scope():
    model = GNNTrackingModel(
        max_cells=max_cells,
        graph_layer=graph_layer,
        track_length=track_length,
        n_layers=n_layers,
        n_filters=n_filters,
        encoder_dim=encoder_dim,
        embedding_dim=embedding_dim,
        norm_layer=norm_layer,
    )

    # Only the "temporal_adj_matrices" output receives a loss; padded cells are
    # removed inside loss_function.
    loss = {"temporal_adj_matrices": loss_function}

    # RectifiedAdam with gradient-norm clipping.
    optimizer = RectifiedAdam(learning_rate=lr, clipnorm=0.001)

    # Per-class metrics; class ids 0/1/2 are labelled same/different/daughter.
    recall_metrics = [
        Recall(class_id=0, name="same_recall"),
        Recall(class_id=1, name="different_recall"),
        Recall(class_id=2, name="daughter_recall"),
    ]
    precision_metrics = [
        Precision(class_id=0, name="same_precision"),
        Precision(class_id=1, name="different_precision"),
        Precision(class_id=2, name="daughter_precision"),
    ]
    training_metrics = recall_metrics + precision_metrics

    model.training_model.compile(
        loss=loss,
        optimizer=optimizer,
        metrics=training_metrics,
    )

Train the model

[ ]:
# Keras global state is deliberately not cleared here. The model above has
# already been built and compiled, so tf.keras.backend.clear_session() only
# makes sense before model construction.

monitor = "val_loss"

# Per-epoch metrics are appended to `train_log_path`.
csv_logger = CSVLogger(train_log_path)

# Callbacks: per-epoch learning-rate schedule, LR reduction on plateau, and CSV
# logging. No early-stopping callback is configured.
# NOTE(review): LearningRateScheduler sets the learning rate at the start of
# every epoch from rate_scheduler(lr=lr, decay=0.99). If that schedule depends
# only on the epoch index, it overwrites any reduction ReduceLROnPlateau made at
# the end of the previous epoch — confirm both callbacks are meant to be combined.
callbacks = [
    tf.keras.callbacks.LearningRateScheduler(rate_scheduler(lr=lr, decay=0.99)),
    tf.keras.callbacks.ReduceLROnPlateau(
        monitor=monitor,
        factor=0.1,
        patience=5,
        verbose=1,
        mode="auto",
        min_delta=0.0001,
        cooldown=0,
        min_lr=0,
    ),
    csv_logger,
]

print(f"Training on {count_gpus()} GPUs.")

# Train model. Both datasets repeat indefinitely, so epoch length is set by
# steps_per_epoch / validation_steps.
history = model.training_model.fit(
    train_data,
    steps_per_epoch=steps_per_epoch,
    epochs=epochs,
    validation_data=val_data,
    validation_steps=validation_steps,
    callbacks=callbacks,
)

print("Final", monitor, ":", history.history[monitor][-1])
[ ]:
# Save models
model.inference_model.save(inf_model_path, include_optimizer=False, overwrite=True)
model.neighborhood_encoder.save(
    ne_model_path, include_optimizer=False, overwrite=True
)
[ ]:
# Record training metrics
all_metrics = {
    "metrics": {"training": {k: float(v[-1]) for k, v in history.history.items()}}
}

# save a metadata.yaml file in the saved model directory
with open(metrics_path, "w") as f:
    yaml.dump(all_metrics, f)

Evaluate model performance

Set tracking parameters for CellTracker

[ ]:
# Tracking parameters passed to CellTracker.
death = 0.99
birth = 0.99
division = 0.01

# IoU threshold that benchmark_tracking_performance uses to match result objects
# to ground truth. The tracker below runs on the ground-truth masks, so matched
# objects are expected to overlap exactly.
iou_thresh = 1

# Load the models saved after training, so evaluation does not depend on the
# in-memory `model` object.
ne_model = tf.keras.models.load_model(ne_model_path)
inf_model = tf.keras.models.load_model(inf_model_path)


def find_frames_with_objects(y):
    """Return the indices of frames in `y` that contain any labeled pixel.

    Args:
        y: label movie indexed by frame along axis 0. Assumes background
            pixels are 0 — TODO confirm for all test movies.

    Returns:
        list[int]: ascending frame indices with at least one non-zero pixel,
        suitable for fancy indexing such as ``X[frames]``.

    NOTE(review): lineage frame numbers are not remapped. Dropping empty frames
    is only safe when they are trailing padding, not gaps inside a movie.
    """
    return [frame for frame in range(len(y)) if np.any(y[frame])]

Load test data

[ ]:
# Load the held-out test movies together with their label masks and lineages.
test_data = load_trks(os.path.join(data_dir, "test.trks"))
X_test, y_test, lineages_test = (test_data[key] for key in ("X", "y", "lineages"))

# Per-movie metadata for the test split.
# NOTE(review): allow_pickle=True can execute arbitrary code embedded in the
# file. Only load data-source.npz obtained from a trusted source.
with np.load(os.path.join(data_dir, "data-source.npz"), allow_pickle=True) as source_npz:
    meta = source_npz["test"]

Predict and benchmark

[ ]:
def _accumulate(totals, results):
    """Add each value in `results` into `totals` in place, inserting new keys."""
    for key in results:
        if key in totals:
            totals[key] += results[key]
        else:
            totals[key] = results[key]


# Track and benchmark every test movie. Counts are pooled overall (`metrics`)
# and per experiment (`exp_metrics`). Movies that are skipped or fail to track
# are listed in `bad_batches`.
metrics = {}
exp_metrics = {}
bad_batches = []
for b in range(len(X_test)):
    # currently NOT saving any recall/precision information
    gt_path = os.path.join(prediction_dir, f"{b}-gt.trk")
    res_path = os.path.join(prediction_dir, f"{b}-res.trk")

    # Check that lineage is valid before proceeding
    if not is_valid_lineage(y_test[b], lineages_test[b]):
        bad_batches.append(b)
        continue

    frames = find_frames_with_objects(y_test[b])

    # Each use indexes X_test/y_test separately, so the tracker and the
    # ground-truth record get independent copies of the frames.
    tracker = CellTracker(
        movie=X_test[b][frames],
        annotation=y_test[b][frames],
        track_length=track_length,
        neighborhood_encoder=ne_model,
        tracking_model=inf_model,
        death=death,
        birth=birth,
        division=division,
    )

    try:
        tracker.track_cells()
    except Exception as err:
        # Best effort: report the failure and continue with the next movie.
        print(
            "Failed to track batch {} due to {}: {}".format(
                b, err.__class__.__name__, err
            )
        )
        bad_batches.append(b)
        continue

    tracker.dump(res_path)

    # Write the ground truth for the same frames in the same .trk layout.
    tracker.dump(
        filename=gt_path,
        track_review_dict={
            "X": X_test[b][frames],
            "y_tracked": y_test[b][frames],
            "tracks": lineages_test[b],
        },
    )

    results = benchmark_tracking_performance(
        gt_path, res_path, threshold=iou_thresh
    )

    exp = meta[b, 1]  # Grab the experiment column from metadata
    _accumulate(metrics, results)
    _accumulate(exp_metrics.setdefault(exp, {}), results)
[ ]:
# Calculate summary stats for each set of metrics
tmp_metrics = metrics.copy()
del tmp_metrics["mismatch_division"]
summary = calculate_summary_stats(**tmp_metrics, n_digits=3)
metrics = {**metrics, **summary}

for exp, m in exp_metrics.items():
    tmp_m = m.copy()
    del tmp_m["mismatch_division"]
    summary = calculate_summary_stats(**tmp_m, n_digits=3)
    exp_metrics[exp] = {**m, **summary}

# save a metadata.yaml file in the saved model directory
with open(metrics_path, "w") as f:
    yaml.dump(all_metrics, f)
[ ]: