This page was generated from notebooks/training/panopticnets/Nuclear Segmentation - DeepWatershed.ipynb
Training a segmentation model
deepcell-tf
leverages Jupyter Notebooks in order to train models. Example notebooks are available for most model architectures in the notebooks folder. Most notebooks are structured similarly to this example and thus this notebook serves as a core reference for the deepcell approach to model training.
[ ]:
import os
import matplotlib.pyplot as plt
import numpy as np
from skimage.feature import peak_local_max
import tensorflow as tf
from deepcell.applications import NuclearSegmentation
from deepcell.image_generators import CroppingDataGenerator
from deepcell.losses import weighted_categorical_crossentropy
from deepcell.model_zoo.panopticnet import PanopticNet
from deepcell.utils.train_utils import count_gpus, rate_scheduler
from deepcell_toolbox.deep_watershed import deep_watershed
from deepcell_toolbox.metrics import Metrics
from deepcell_toolbox.processing import histogram_normalization
File paths
[ ]:
data_dir = '/notebooks/data'
model_path = 'NuclearSegmentation'
metrics_path = 'metrics.yaml'
train_log = 'train_log.csv'
Load the data
The DynamicNuclearNet tracking dataset can be downloaded from https://datasets.deepcell.org/
[ ]:
with np.load(os.path.join(data_dir, 'train.npz')) as data:
X_train = data['X']
y_train = data['y']
with np.load(os.path.join(data_dir, 'val.npz')) as data:
X_val = data['X']
y_val = data['y']
with np.load(os.path.join(data_dir, 'test.npz')) as data:
X_test = data['X']
y_test = data['y']
Training parameters
The majority of DeepCell models support a variety backbone choices specified in the “backbone” parameter. Backbones are provided through keras_applications and can be instantiated with weights that are pretrained on ImageNet.
[ ]:
# Model architecture
backbone = "efficientnetv2bl"
location = True
pyramid_levels = ["P1","P2","P3","P4","P5","P6","P7"]
[ ]:
# Augmentation and transform parameters
seed = 0
min_objects = 1
zoom_min = 0.75
crop_size = 256
outer_erosion_width = 1
inner_distance_alpha = "auto"
inner_distance_beta = 1
inner_erosion_width = 0
[ ]:
# Post processing parameters
maxima_threshold = 0.1
interior_threshold = 0.01
exclude_border = False
small_objects_threshold = 0
min_distance = 10
[ ]:
# Training configuration
epochs = 16
batch_size = 16
lr = 1e-4
Create data generators
[ ]:
# data augmentation parameters
zoom_max = 1 / zoom_min
# Preprocess the data
X_train = histogram_normalization(X_train)
X_val = histogram_normalization(X_val)
# use augmentation for training but not validation
datagen = CroppingDataGenerator(
rotation_range=180,
zoom_range=(zoom_min, zoom_max),
horizontal_flip=True,
vertical_flip=True,
crop_size=(crop_size, crop_size),
)
datagen_val = CroppingDataGenerator(
crop_size=(crop_size, crop_size)
)
[ ]:
transforms = ["inner-distance", "outer-distance", "fgbg"]
transforms_kwargs = {
"outer-distance": {"erosion_width": outer_erosion_width},
"inner-distance": {
"alpha": inner_distance_alpha,
"beta": inner_distance_beta,
"erosion_width": inner_erosion_width,
},
}
train_data = datagen.flow(
{'X': X_train, 'y': y_train},
seed=seed,
min_objects=min_objects,
transforms=transforms,
transforms_kwargs=transforms_kwargs,
batch_size=batch_size,
)
print("Created training data generator.")
val_data = datagen_val.flow(
{'X': X_val, 'y': y_val},
seed=seed,
min_objects=min_objects,
transforms=transforms,
transforms_kwargs=transforms_kwargs,
batch_size=batch_size,
)
print("Created validation data generator.")
Visualize the data generator output.
[ ]:
inputs, outputs = train_data.next()
img = inputs[0]
inner_distance = outputs[0]
outer_distance = outputs[1]
fgbg = outputs[2]
fig, axes = plt.subplots(1, 4, figsize=(15, 15))
axes[0].imshow(img[..., 0])
axes[0].set_title('Source Image')
axes[1].imshow(inner_distance[0, ..., 0])
axes[1].set_title('Inner Distance')
axes[2].imshow(outer_distance[0, ..., 0])
axes[2].set_title('Outer Distance')
axes[3].imshow(fgbg[0, ..., 0])
axes[3].set_title('Foreground/Background')
plt.show()
Create the PanopticNet Model
Here we instantiate a PanopticNet
model from deepcell.model_zoo
using 3 semantic heads: inner distance (1 class), outer distance (1 class), foreground/background distance (2 classes)
[ ]:
input_shape = (crop_size, crop_size, 1)
model = PanopticNet(
backbone=backbone,
input_shape=input_shape,
norm_method=None,
num_semantic_classes=[1, 1, 2], # inner distance, outer distance, fgbg
location=location,
include_top=True,
backbone_levels=["C1", "C2", "C3", "C4", "C5"],
pyramid_levels=pyramid_levels,
)
Create a loss function for each semantic head
Each semantic head is trained with it’s own loss function. Mean Square Error is used for regression-based heads, whereas weighted_categorical_crossentropy
is used for classification heads.
The losses are saved as a dictionary and passed to model.compile
.
[ ]:
def semantic_loss(n_classes):
def _semantic_loss(y_pred, y_true):
if n_classes > 1:
return 0.01 * weighted_categorical_crossentropy(
y_pred, y_true, n_classes=n_classes
)
return tf.keras.losses.MSE(y_pred, y_true)
return _semantic_loss
loss = {}
# Give losses for all of the semantic heads
for layer in model.layers:
if layer.name.startswith("semantic_"):
n_classes = layer.output_shape[-1]
loss[layer.name] = semantic_loss(n_classes)
optimizer = tf.keras.optimizers.Adam(lr=lr, clipnorm=0.001)
model.compile(loss=loss, optimizer=optimizer)
Train the model
Call fit
on the compiled model, along with a default set of callbacks.
[ ]:
# Clear clutter from previous TensorFlow graphs.
tf.keras.backend.clear_session()
monitor = "val_loss"
csv_logger = tf.keras.callbacks.CSVLogger(train_log)
# Create callbacks for early stopping and pruning.
callbacks = [
tf.keras.callbacks.ModelCheckpoint(
model_path,
monitor=monitor,
save_best_only=True,
verbose=1,
save_weights_only=False,
),
tf.keras.callbacks.LearningRateScheduler(rate_scheduler(lr=lr, decay=0.99)),
tf.keras.callbacks.ReduceLROnPlateau(
monitor=monitor,
factor=0.1,
patience=5,
verbose=1,
mode="auto",
min_delta=0.0001,
cooldown=0,
min_lr=0,
),
tf.keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True),
csv_logger,
]
print(f"Training on {count_gpus()} GPUs.")
# Train model.
history = model.fit(
train_data,
steps_per_epoch=train_data.y.shape[0] // batch_size,
epochs=epochs,
validation_data=val_data,
validation_steps=val_data.y.shape[0] // batch_size,
callbacks=callbacks,
)
print("Final", monitor, ":", history.history[monitor][-1])
Save prediction model
We can now create a new prediction model without the foreground background semantic head. While this head is very useful during training, the output is unused during prediction. By using model.load_weights(path, by_name=True)
, the semantic head can be removed.
[ ]:
with tempfile.TemporaryDirectory() as tmpdirname:
weights_path = os.path.join(str(tmpdirname), "model_weights.h5")
model.save_weights(weights_path, save_format="h5")
prediction_model = PanopticNet(
backbone=backbone,
input_shape=input_shape,
norm_method=None,
num_semantic_heads=2,
num_semantic_classes=[1, 1], # inner distance, outer distance
location=location, # should always be true
include_top=True,
backbone_levels=["C1", "C2", "C3", "C4", "C5"],
pyramid_levels=pyramid_levels,
)
prediction_model.load_weights(weights_path, by_name=True)
Predict on test data
[ ]:
X_test = histograph_normalization(X_test)
test_images = prediction_model.predict(X_test)
[ ]:
index = np.random.choice(X_test.shape[0])
print(index)
fig, axes = plt.subplots(1, 4, figsize=(20, 20))
masks = deep_watershed(
test_images,
radius=radius,
maxima_threshold=maxima_threshold,
interior_threshold=interior_threshold,
exclude_border=exclude_border,
small_objects_threshold=small_objects_threshold,
min_distance=min_distance
)
# calculated in the postprocessing above, but useful for visualizing
inner_distance = test_images[0]
outer_distance = test_images[1]
coords = peak_local_max(
inner_distance[index],
min_distance=min_distance
)
# raw image with centroid
axes[0].imshow(X_test[index, ..., 0])
axes[0].scatter(coords[..., 1], coords[..., 0],
color='r', marker='.', s=10)
axes[1].imshow(inner_distance[index, ..., 0], cmap='jet')
axes[2].imshow(outer_distance[index, ..., 0], cmap='jet')
axes[3].imshow(masks[index, ...], cmap='jet')
plt.show()
Evaluate results
The deepcell.metrics
package is used to measure advanced metrics for instance segmentation predictions.
[ ]:
outputs = model.predict(X_test)
y_pred = []
for i in range(outputs[0].shape[0]):
mask = deep_watershed(
[t[[i]] for t in outputs],
radius=radius,
maxima_threshold=maxima_threshold,
interior_threshold=interior_threshold,
exclude_border=exclude_border,
small_objects_threshold=small_objects_threshold,
min_distance=min_distance)
y_pred.append(mask[0])
y_pred = np.stack(y_pred, axis=0)
y_pred = np.expand_dims(y_pred, axis=-1)
y_true = y_test.copy()
m = Metrics('DeepWatershed', seg=False)
m.calc_object_stats(y_true, y_pred)