This page was generated from notebooks/training/tracking/Training and Tracking with GNNs.ipynb
This notebook is part of the deepcell-tf documentation: https://deepcell.readthedocs.io/.
Training a cell tracking model
[ ]:
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras.callbacks import CSVLogger
from tensorflow_addons.optimizers import RectifiedAdam
import yaml
import deepcell
from deepcell.data.tracking import Track, random_rotate, random_translate, temporal_slice
from deepcell.losses import weighted_categorical_crossentropy
from deepcell.model_zoo.tracking import GNNTrackingModel
from deepcell.utils.tfrecord_utils import get_tracking_dataset, write_tracking_dataset_to_tfr
from deepcell.utils.train_utils import count_gpus, rate_scheduler
from deepcell_toolbox.metrics import Metrics
from deepcell_tracking import CellTracker
from deepcell_tracking.metrics import benchmark_tracking_performance, calculate_summary_stats
from deepcell_tracking.trk_io import load_trks
from deepcell_tracking.utils import get_max_cells, is_valid_lineage
The DynamicNuclearNet tracking dataset can be downloaded from https://datasets.deepcell.org/
[ ]:
# Please change these file paths to match your file system.
data_dir = "/notebooks/data"  # directory holding train/val/test .trks files
inf_model_path = "NuclearTrackingInf"  # where the inference model is saved
ne_model_path = "NuclearTrackingNE"  # where the neighborhood encoder is saved
metrics_path = "train-metrics.yaml"  # YAML file that receives the metrics
train_log_path = "train_log.csv"  # per-epoch log written by CSVLogger
prediction_dir = "output"  # directory for the .trk files written during benchmarking

# Make sure the prediction directory exists. `exist_ok=True` creates it when
# missing and is a no-op otherwise, which avoids the check-then-create race of
# testing os.path.exists() first.
os.makedirs(prediction_dir, exist_ok=True)
Prepare the data for training
Tracked data are stored as `.trks` files. These files include images and lineage data stored as NumPy arrays. To manipulate `.trks` files, use `deepcell_tracking.trk_io.load_trks` and `deepcell_tracking.trk_io.save_trks`.
To facilitate training, we transform each movie’s image and lineage data into a `Track` object. A `Track` helps to encapsulate all of the feature creation from the movie, including:
Appearances:
(num_frames, num_objects, 32, 32, 1)
Morphologies:
(num_frames, num_objects, 32, 32, 3)
Centroids:
(num_frames, num_objects, 2)
Normalized Adjacency Matrix:
(num_frames, num_objects, num_objects, 3)
Temporal Adjacency Matrix (comparing across frames):
(num_frames - 1, num_objects, num_objects, 3)
Each `Track` is then saved as a tfrecord file so that data can be loaded from disk during training, which reduces the total memory footprint.
[ ]:
# Feature-extraction settings; all three are forwarded to Track(...) when the
# tfrecords are prepared.
appearance_dim: int = 32
distance_threshold: int = 64
crop_mode: str = "resize"
[ ]:
# This cell may take ~20 minutes to run
train_trks = load_trks(os.path.join(data_dir, "train.trks"))
val_trks = load_trks(os.path.join(data_dir, "val.trks"))

# Use the largest cell count found in either split as the shared
# target_max_cells for both tfrecord files.
max_cells = max(get_max_cells(train_trks["y"]), get_max_cells(val_trks["y"]))

# Pair each split name with its data explicitly. The names must come from an
# ordered sequence: iterating a set has no guaranteed order, so zipping a set
# of names with the data list could write the validation movies to the
# "train" record (and vice versa).
for split, trks in (("train", train_trks), ("val", val_trks)):
    print(f"Preparing {split} as tf record")

    with tf.device("/cpu:0"):
        tracks = Track(
            tracked_data=trks,
            appearance_dim=appearance_dim,
            distance_threshold=distance_threshold,
            crop_mode=crop_mode,
        )

        write_tracking_dataset_to_tfr(
            tracks, target_max_cells=max_cells, filename=split
        )
Training
Define training parameters
[ ]:
# Model architecture (every value below is forwarded to GNNTrackingModel)
n_layers: int = 1  # Number of graph convolution layers
n_filters: int = 64
encoder_dim: int = 64
embedding_dim: int = 64
graph_layer: str = "gat"
norm_layer: str = "batch"
[ ]:
# Data and augmentation
seed: int = 0  # seed handed to Dataset.shuffle
track_length: int = 8  # Number of frames per track object
rotation_range: int = 180  # forwarded to random_rotate
translation_range: int = 512  # forwarded to random_translate
buffer_size: int = 128  # shuffle buffer size
[ ]:
# Training configuration
batch_size: int = 8
epochs: int = 50
steps_per_epoch: int = 1000
validation_steps: int = 200
lr: float = 1e-3  # initial learning rate for the optimizer and the rate scheduler
Load TFRecord Data
[ ]:
# Augmentation functions
def sample(X, y):
    """Apply ``temporal_slice`` to ``(X, y)`` with the configured ``track_length``.

    Written as a two-argument function so it can be passed to ``Dataset.map``.
    """
    sliced = temporal_slice(X, y, track_length=track_length)
    return sliced
def rotate(X, y):
    """Apply ``random_rotate`` to ``(X, y)`` with the configured ``rotation_range``.

    Written as a two-argument function so it can be passed to ``Dataset.map``.
    """
    rotated = random_rotate(X, y, rotation_range=rotation_range)
    return rotated
def translate(X, y):
    """Apply ``random_translate`` to ``(X, y)`` with the configured ``translation_range``.

    Written as a two-argument function so it can be passed to ``Dataset.map``.
    """
    translated = random_translate(X, y, range=translation_range)
    return translated
# Build both input pipelines on the CPU.
with tf.device("/cpu:0"):
    # Training: shuffle, repeat, slice in time, augment, then batch.
    train_data = (
        get_tracking_dataset("train")
        .shuffle(buffer_size, seed=seed)
        .repeat()
        .map(sample, num_parallel_calls=tf.data.AUTOTUNE)
        .map(rotate, num_parallel_calls=tf.data.AUTOTUNE)
        .map(translate, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(batch_size)
        .prefetch(tf.data.AUTOTUNE)
    )

    # Validation: same pipeline without the rotate/translate augmentations.
    val_data = (
        get_tracking_dataset("val")
        .shuffle(buffer_size, seed=seed)
        .repeat()
        .map(sample, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(batch_size)
        .prefetch(tf.data.AUTOTUNE)
    )

# Read max_cells back from the data itself: axis 2 of the "appearances" input
# of the first training batch.
first_batch = list(train_data.take(1))[0]
max_cells = first_batch[0]["appearances"].shape[2]
Initialize the model
[ ]:
def filter_and_flatten(y_true, y_pred):
    """Flatten both tensors to ``(N, n_classes)`` and drop masked rows.

    ``n_classes`` is the size of the last axis of ``y_true``. Only rows whose
    ``y_true`` entries sum to exactly 1 are kept; the same rows are selected
    from ``y_pred``.

    Returns:
        Tuple ``(y_true, y_pred)`` holding the retained rows.
    """
    flat_shape = [-1, tf.shape(y_true)[-1]]
    labels = tf.reshape(y_true, flat_shape)
    predictions = tf.reshape(y_pred, flat_shape)

    # Mask out the padded cells
    row_sums = tf.reduce_sum(labels, axis=-1)
    keep_rows = tf.where(row_sums == 1)[:, 0]

    return (
        tf.gather(labels, keep_rows, axis=0),
        tf.gather(predictions, keep_rows, axis=0),
    )
class Recall(tf.keras.metrics.Recall):
    """Keras ``Recall`` that runs ``filter_and_flatten`` on its inputs first."""

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Filter and flatten the inputs, then update the parent metric.

        Returns:
            Whatever the parent ``update_state`` returns (the update op where
            Keras produces one), so the override honours the base-class
            contract instead of always returning ``None``.
        """
        y_true, y_pred = filter_and_flatten(y_true, y_pred)
        return super().update_state(y_true, y_pred, sample_weight)
class Precision(tf.keras.metrics.Precision):
    """Keras ``Precision`` that runs ``filter_and_flatten`` on its inputs first."""

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Filter and flatten the inputs, then update the parent metric.

        Returns:
            Whatever the parent ``update_state`` returns (the update op where
            Keras produces one), so the override honours the base-class
            contract instead of always returning ``None``.
        """
        y_true, y_pred = filter_and_flatten(y_true, y_pred)
        return super().update_state(y_true, y_pred, sample_weight)
def loss_function(y_true, y_pred):
    """Weighted categorical cross-entropy over the rows kept by ``filter_and_flatten``.

    The number of classes is read from the last axis of the filtered labels.
    """
    labels, predictions = filter_and_flatten(y_true, y_pred)
    n_classes = tf.shape(labels)[-1]
    return weighted_categorical_crossentropy(
        labels, predictions, n_classes=n_classes, axis=-1
    )
[ ]:
strategy = tf.distribute.MirroredStrategy()
print(f"Number of devices: {strategy.num_replicas_in_sync}")

# The model, optimizer and metrics are all created inside the strategy scope.
with strategy.scope():
    model_options = {
        "max_cells": max_cells,
        "graph_layer": graph_layer,
        "track_length": track_length,
        "n_filters": n_filters,
        "embedding_dim": embedding_dim,
        "encoder_dim": encoder_dim,
        "n_layers": n_layers,
        "norm_layer": norm_layer,
    }
    model = GNNTrackingModel(**model_options)

    # The loss is keyed by the name of the output it applies to.
    loss = {"temporal_adj_matrices": loss_function}

    optimizer = RectifiedAdam(learning_rate=lr, clipnorm=0.001)

    # One recall and one precision metric for each of the three class ids.
    recall_metrics = [
        Recall(class_id=0, name="same_recall"),
        Recall(class_id=1, name="different_recall"),
        Recall(class_id=2, name="daughter_recall"),
    ]
    precision_metrics = [
        Precision(class_id=0, name="same_precision"),
        Precision(class_id=1, name="different_precision"),
        Precision(class_id=2, name="daughter_precision"),
    ]
    training_metrics = recall_metrics + precision_metrics

    model.training_model.compile(
        loss=loss, optimizer=optimizer, metrics=training_metrics
    )
Train the model
[ ]:
# Clear clutter from previous TensorFlow graphs.
tf.keras.backend.clear_session()

monitor = "val_loss"

csv_logger = CSVLogger(train_log_path)

# Callbacks: learning-rate scheduling, learning-rate reduction on plateau, and
# CSV logging of each epoch.
# NOTE(review): LearningRateScheduler applies its schedule at the start of
# every epoch, which may override reductions made by ReduceLROnPlateau -
# confirm that combining the two is intended.
lr_schedule = tf.keras.callbacks.LearningRateScheduler(
    rate_scheduler(lr=lr, decay=0.99)
)
lr_on_plateau = tf.keras.callbacks.ReduceLROnPlateau(
    monitor=monitor,
    factor=0.1,
    patience=5,
    verbose=1,
    mode="auto",
    min_delta=0.0001,
    cooldown=0,
    min_lr=0,
)
callbacks = [lr_schedule, lr_on_plateau, csv_logger]

print(f"Training on {count_gpus()} GPUs.")

# Train model.
fit_options = {
    "steps_per_epoch": steps_per_epoch,
    "epochs": epochs,
    "validation_data": val_data,
    "validation_steps": validation_steps,
    "callbacks": callbacks,
}
history = model.training_model.fit(train_data, **fit_options)

# Report the monitored value from the last epoch.
final_value = history.history[monitor][-1]
print("Final", monitor, ":", final_value)
[ ]:
# Save models: the inference model first, then the neighborhood encoder.
# Both are written without optimizer state and overwrite any existing save.
save_options = {"include_optimizer": False, "overwrite": True}
model.inference_model.save(inf_model_path, **save_options)
model.neighborhood_encoder.save(ne_model_path, **save_options)
[ ]:
# Record training metrics: keep the last-epoch value of every entry in the
# training history, converted to a plain float.
final_epoch_values = {
    name: float(values[-1]) for name, values in history.history.items()
}
all_metrics = {"metrics": {"training": final_epoch_values}}

# Write the metrics to the YAML file at `metrics_path`.
with open(metrics_path, "w") as f:
    yaml.dump(all_metrics, f)
Evaluate model performance
Set tracking parameters and CellTracker
[ ]:
# Tracking parameters; each is forwarded to CellTracker under the same name.
death: float = 0.99
birth: float = 0.99
division: float = 0.01
Load test data
[ ]:
test_data = load_trks(os.path.join(data_dir, "test.trks"))

# Split the loaded dictionary into its three entries.
X_test, y_test, lineages_test = (
    test_data[key] for key in ("X", "y", "lineages")
)

# Load metadata array
# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects, so
# only point this at a data-source.npz file you trust.
metadata_path = os.path.join(data_dir, "data-source.npz")
with np.load(metadata_path, allow_pickle=True) as data:
    meta = data["test"]
Predict and benchmark
[ ]:
# Models used for tracking. These are the same two sub-models of the trained
# `model` that the "Save models" cell writes to disk; they were referenced
# below without ever being bound to a name.
ne_model = model.neighborhood_encoder
inf_model = model.inference_model

# NOTE(review): `iou_thresh` was used below but never defined on this page.
# 1 is assumed here because the tracker is given the ground-truth annotations,
# so the predicted and ground-truth masks are identical - confirm against the
# benchmark_tracking_performance documentation.
iou_thresh = 1


def find_frames_with_objects(y):
    """Return the indices of the frames in ``y`` that contain a non-zero label.

    NOTE(review): this helper was called below but never defined or imported
    on this page. The implementation assumes the first axis of ``y`` indexes
    frames and that label 0 is background - confirm.
    """
    return [f for f in range(len(y)) if np.any(y[f])]


metrics = {}  # benchmark results summed over every successfully tracked batch
exp_metrics = {}  # the same sums, grouped by experiment
bad_batches = []  # batches skipped for an invalid lineage or a tracking failure

for b in range(len(X_test)):
    # currently NOT saving any recall/precision information
    gt_path = os.path.join(prediction_dir, f"{b}-gt.trk")
    res_path = os.path.join(prediction_dir, f"{b}-res.trk")

    # Check that lineage is valid before proceeding
    if not is_valid_lineage(y_test[b], lineages_test[b]):
        bad_batches.append(b)
        continue

    frames = find_frames_with_objects(y_test[b])

    tracker = CellTracker(
        movie=X_test[b][frames],
        annotation=y_test[b][frames],
        track_length=track_length,
        neighborhood_encoder=ne_model,
        tracking_model=inf_model,
        death=death,
        birth=birth,
        division=division,
    )

    # Best effort: a batch that fails to track is reported and skipped so the
    # remaining batches are still benchmarked.
    try:
        tracker.track_cells()
    except Exception as err:
        print(f"Failed to track batch {b} due to {err.__class__.__name__}: {err}")
        bad_batches.append(b)
        continue

    # Write the predicted tracks, then the ground truth, so the two files can
    # be compared by benchmark_tracking_performance.
    tracker.dump(res_path)

    gt = {
        "X": X_test[b][frames],
        "y_tracked": y_test[b][frames],
        "tracks": lineages_test[b],
    }
    tracker.dump(filename=gt_path, track_review_dict=gt)

    results = benchmark_tracking_performance(
        gt_path, res_path, threshold=iou_thresh
    )

    exp = meta[b, 1]  # Grab the experiment column from metadata
    tmp_exp = exp_metrics.get(exp, {})

    # Add this batch's results to both the overall and per-experiment totals.
    for k in results:
        for totals in (metrics, tmp_exp):
            if k in totals:
                totals[k] += results[k]
            else:
                totals[k] = results[k]

    exp_metrics[exp] = tmp_exp
[ ]:
def _with_summary(raw_counts):
    """Return ``raw_counts`` merged with its ``calculate_summary_stats`` output.

    "mismatch_division" is removed before the call, as the original cell did,
    but it is kept in the returned dictionary. An empty input (no batch was
    tracked successfully) is returned unchanged instead of raising.
    """
    if not raw_counts:
        return dict(raw_counts)
    stats_input = dict(raw_counts)
    stats_input.pop("mismatch_division", None)
    return {**raw_counts, **calculate_summary_stats(**stats_input, n_digits=3)}


def _to_builtin(value):
    """Recursively turn NumPy scalars (keys and values) into plain Python objects.

    yaml.dump serializes NumPy scalars as tagged Python objects rather than as
    plain numbers/strings, so they are converted before writing.
    """
    if isinstance(value, dict):
        return {_to_builtin(k): _to_builtin(v) for k, v in value.items()}
    if isinstance(value, np.generic):
        return value.item()
    return value


# Calculate summary stats for each set of metrics
metrics = _with_summary(metrics)
exp_metrics = {exp: _with_summary(m) for exp, m in exp_metrics.items()}

# Attach the tracking results to the metrics that get written out. Previously
# `all_metrics` was dumped again unchanged, so the benchmark results computed
# above never reached the file.
# NOTE(review): the key names "tracking" and "tracking_by_experiment" are new;
# adjust them if downstream tooling expects a different layout.
all_metrics["metrics"]["tracking"] = _to_builtin(metrics)
all_metrics["metrics"]["tracking_by_experiment"] = _to_builtin(exp_metrics)

# Rewrite the YAML file at `metrics_path` with training and tracking metrics.
with open(metrics_path, "w") as f:
    yaml.dump(all_metrics, f)
[ ]: