Training a cell tracking model

Implementation of: Accurate cell tracking and lineage construction in live-cell imaging experiments with deep learning

[1]:
import os
import datetime
import errno

import numpy as np

import deepcell
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow_core/python/ops/resource_variable_ops.py:1630: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.

Load the data

Download the data from deepcell.datasets

deepcell.datasets provides access to a set of annotated live-cell imaging datasets that can be used to train cell segmentation and tracking models. All dataset objects share the load_data() method, which allows the user to specify the name of the file (path), the fraction of data reserved for testing (test_size), and a seed used to generate the random train-test split. Metadata associated with a dataset can be accessed through its metadata attribute.
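
As a quick illustration, here is a minimal sketch of this shared API, assuming the load_data() interface described above (the tracked datasets used below expose the analogous load_tracked_data() method):

[ ]:
from deepcell.datasets.tracked import hela_s3

# Metadata describing the dataset
print(hela_s3.metadata)

# path, test_size, and seed as described above
(X_train, y_train), (X_test, y_test) = hela_s3.load_data(
    path='HeLa_S3.trks', test_size=0.2, seed=0)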

Tracked data are stored in .trks files, a special format that combines image and lineage data in np.arrays. To access .trks files, use deepcell.utils.tracking_utils.load_trks and deepcell.utils.tracking_utils.save_trks.
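
For example, once a file has been downloaded, a .trks round trip looks like the following sketch (the dictionary keys match those used later in this notebook):

[ ]:
from deepcell.utils.tracking_utils import load_trks, save_trks

# load_trks returns a dict of image data and lineages
data = load_trks('3T3_NIH.trks')
print(data['X'].shape, data['y'].shape)  # raw and tracked movies as np.arrays
print(len(data['lineages']))             # one lineage dict per movie

# save_trks writes the three components back into a single file
save_trks('copy.trks', data['lineages'], data['X'], data['y'])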

Training a tracking algorithm is a complicated process that requires a lot of data. We recommend combining multiple datasets.

[2]:
# Download four different sets of data (saves to ~/.keras/datasets)
filename_3T3 = '3T3_NIH.trks'
(X_train, y_train), (X_test, y_test) = deepcell.datasets.tracked.nih_3t3.load_tracked_data(filename_3T3)
print('3T3 -\nX.shape: {}\ny.shape: {}'.format(X_train.shape, y_train.shape))

filename_HeLa = 'HeLa_S3.trks'
(X_train, y_train), (X_test, y_test) = deepcell.datasets.tracked.hela_s3.load_tracked_data(filename_HeLa)
print('HeLa -\nX.shape: {}\ny.shape: {}'.format(X_train.shape, y_train.shape))

filename_HEK = 'HEK293.trks'
(X_train, y_train), (X_test, y_test) = deepcell.datasets.tracked.hek293.load_tracked_data(filename_HEK)
print('HEK293 -\nX.shape: {}\ny.shape: {}'.format(X_train.shape, y_train.shape))

filename_RAW = 'RAW2647.trks'
(X_train, y_train), (X_test, y_test) = deepcell.datasets.tracked.raw2647.load_tracked_data(filename_RAW)
print('RAW264.7 -\nX.shape: {}\ny.shape: {}'.format(X_train.shape, y_train.shape))
Downloading data from https://deepcell-data.s3.amazonaws.com/tracked/3T3_NIH.trks
3229646848/3229644800 [==============================] - 157s 0us/step
3T3 -
X.shape: (192, 30, 154, 182, 1)
y.shape: (192, 30, 154, 182, 1)
Downloading data from https://deepcell-data.s3.amazonaws.com/tracked/HeLa_S3.trks
6370648064/6370641920 [==============================] - 280s 0us/step
HeLa -
X.shape: (144, 40, 216, 256, 1)
y.shape: (144, 40, 216, 256, 1)
Downloading data from https://deepcell-data.s3.amazonaws.com/tracked/HEK293.trks
1344610304/1344604160 [==============================] - 70s 0us/step
HEK293 -
X.shape: (207, 30, 135, 160, 1)
y.shape: (207, 30, 135, 160, 1)
Downloading data from https://deepcell-data.s3.amazonaws.com/tracked/RAW2647.trks
2164695040/2164695040 [==============================] - 204s 0us/step
RAW264.7 -
X.shape: (99, 30, 202, 240, 1)
y.shape: (99, 30, 202, 240, 1)

Preprocess the data

After downloading data from deepcell.datasets.tracked, we will compile everything into a single dataset. Neural networks require all input data to have the same dimensions, so we will identify the maximum dimensions across datasets and zero-pad the smaller ones to match. Neural networks also prefer 0-mean, unit-variance inputs, so each image will be normalized.

[3]:
from deepcell.utils.tracking_utils import load_trks
from deepcell.utils.tracking_utils import save_trks

# Define a normalization function for the raw images that can be run before padding
def image_norm(original_image):
    # NNs prefer input data that is 0 mean and unit variance
    normed_image = (original_image - np.mean(original_image)) / np.std(original_image)
    return normed_image

# Define all the trks to load
basepath = os.path.expanduser(os.path.join('~', '.keras', 'datasets'))
trks_files = [os.path.join(basepath, filename_3T3),
              os.path.join(basepath, filename_HeLa),
              os.path.join(basepath, filename_HEK),
              os.path.join(basepath, filename_RAW)]

# Each trks file may have different dimensions,
# but the model expects uniform dimensions,
# so determine the max dimensions and zero pad as necessary.
max_frames = 1
max_y = 1
max_x = 1

for trks_file in trks_files:
    trks = load_trks(trks_file)

    # Track the largest dimensions seen so far so that
    # smaller movies can be padded up to match them
    max_frames = max(max_frames, trks['X'][0].shape[0])
    max_y = max(max_y, trks['X'][0].shape[1])
    max_x = max(max_x, trks['X'][0].shape[2])
[4]:
# Load each trks file, normalizing and padding as necessary
lineages = []
X = []
y = []

k = 0
movie_counter = 0
for trks_file in trks_files:
    trks = load_trks(trks_file)
    for i, (lineage, raw, tracked) in enumerate(zip(trks['lineages'], trks['X'], trks['y'])):
        movie_counter = k + i

        # Normalize the raw images
        for frame in range(raw.shape[0]):
            raw[frame, :, :, 0] = image_norm(raw[frame, :, :, 0])

        # Pad images if necessary - this assumes that raw and tracked have the same shape
        if raw.shape[1] < max_y:
            diff2pad = max_y - raw.shape[1]
            # Split the padding between the two sides; if diff2pad is odd,
            # the extra row goes first (equivalent to (pad_width + 1, pad_width))
            pad_top = diff2pad - diff2pad // 2
            pad_bottom = diff2pad // 2
            raw = np.pad(raw, ((0, 0), (pad_top, pad_bottom), (0, 0), (0, 0)),
                         mode='constant', constant_values=0)
            tracked = np.pad(tracked, ((0, 0), (pad_top, pad_bottom), (0, 0), (0, 0)),
                             mode='constant', constant_values=0)

        if raw.shape[2] < max_x:
            diff2pad = max_x - raw.shape[2]
            # Same convention for the left/right padding
            pad_left = diff2pad - diff2pad // 2
            pad_right = diff2pad // 2
            raw = np.pad(raw, ((0, 0), (0, 0), (pad_left, pad_right), (0, 0)),
                         mode='constant', constant_values=0)
            tracked = np.pad(tracked, ((0, 0), (0, 0), (pad_left, pad_right), (0, 0)),
                             mode='constant', constant_values=0)

        if raw.shape[0] < max_frames:
            pad_frames = max_frames - raw.shape[0]
            raw = np.pad(raw, ((0, pad_frames), (0, 0), (0, 0), (0, 0)),
                         mode='constant', constant_values=0)
            tracked = np.pad(tracked, ((0, pad_frames), (0, 0), (0, 0), (0, 0)),
                             mode='constant', constant_values=0)

        lineages.append(lineage)
        X.append(raw)
        y.append(tracked)

    k = movie_counter + 1

# Save the combined datasets into one trks file
filename = 'combined_data.trks'
save_trks(os.path.join(basepath, filename), lineages, X, y)

Describe the data

Finally, we can view descriptive statistics on the complete dataset using deepcell.utils.tracking_utils.trks_stats.

[5]:
# View stats on this combined file
from deepcell.utils.tracking_utils import trks_stats
trks_stats(os.path.join(basepath, filename))
Dataset Statistics:
Image data shape:  (803, 40, 216, 256, 1)
Number of lineages (should equal batch size):  803
Total number of unique tracks (cells)      -  12697
Total number of divisions                  -  944
Average cell density (cells/100 sq pixels) -  0.017033540852301552
Average number of frames per track         -  25

Create the training data

Randomly select a portion of the data for training; the remainder is held out for testing.

[6]:
# combined_data.trks contains all of the data available

# To hold out a portion of this data for testing, we establish a random seed
test_seed = 1

# And how much of the data to hold out
test_size = .1

# Get the full dataset
trks = load_trks(os.path.join(basepath, filename))
total_data_size = trks['X'].shape[0]

# Select a portion of this dataset randomly
import random
random.seed(test_seed)
train_data_range = int(total_data_size * (1 - test_size))

idx_train = random.sample(range(total_data_size), train_data_range)

lineages, X, y = [], [], []
for i in idx_train:
    lineages.append(trks['lineages'][i])
    X.append(trks['X'][i])
    y.append(trks['y'][i])

# Resave the portion we wish to use as the training (and validation) dataset
filename_train = 'combined_training_data.trks'
save_trks(os.path.join(basepath, filename_train), lineages, X, y)

# View stats on this combined file
trks_stats(os.path.join(basepath, filename_train))
Dataset Statistics:
Image data shape:  (722, 40, 216, 256, 1)
Number of lineages (should equal batch size):  722
Total number of unique tracks (cells)      -  11510
Total number of divisions                  -  844
Average cell density (cells/100 sq pixels) -  0.017189596498441827
Average number of frames per track         -  25
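
The held-out 10% is defined only implicitly by the indices missing from idx_train. If you also want to save it for benchmarking later (e.g. in Part 2 of this series), a minimal sketch follows; the complement-index logic and the combined_test_data.trks file name are illustrative additions, not part of the original workflow:

[ ]:
# Indices not selected for training form the held-out test set
idx_train_set = set(idx_train)
idx_test = [i for i in range(total_data_size) if i not in idx_train_set]

lineages_test, X_test, y_test = [], [], []
for i in idx_test:
    lineages_test.append(trks['lineages'][i])
    X_test.append(trks['X'][i])
    y_test.append(trks['y'][i])

# Save the held-out movies to their own trks file
save_trks(os.path.join(basepath, 'combined_test_data.trks'),
          lineages_test, X_test, y_test)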

Set up file path constants

[7]:
# The path to the data file is currently required for `train_model_*()` functions

# Change DATA_DIR if you are not using `deepcell.datasets`
DATA_DIR = os.path.expanduser(os.path.join('~', '.keras', 'datasets'))

# DATA_FILE should be a trks file (contains 2 np arrays and a lineage dictionary)
DATA_FILE = os.path.join(DATA_DIR, filename_train)

# confirm the data file is available
assert os.path.isfile(DATA_FILE)
[8]:
# Set up other required filepaths

# If the data file is in a subdirectory, mirror it in MODEL_DIR and LOG_DIR
PREFIX = os.path.relpath(os.path.dirname(DATA_FILE), DATA_DIR)

ROOT_DIR = '/data'  # TODO: Change this! Usually a mounted volume
MODEL_DIR = os.path.abspath(os.path.join(ROOT_DIR, 'models', PREFIX))
LOG_DIR = os.path.abspath(os.path.join(ROOT_DIR, 'logs', PREFIX))

# create directories if they do not exist
for d in (MODEL_DIR, LOG_DIR):
    try:
        os.makedirs(d)
    except OSError as exc:  # Guard against race condition
        if exc.errno != errno.EEXIST:
            raise

Training a New Model

Set up training parameters

[9]:
from tensorflow.keras.optimizers import SGD
from deepcell.utils.train_utils import rate_scheduler

n_epoch = 10     # Number of training epochs
test_size = .20  # % of data saved as validation
train_seed = 1   # Random seed for training/validation data split

optimizer = SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
lr_sched = rate_scheduler(lr=0.01, decay=0.99)

# Tracking training settings
features = {'appearance', 'distance', 'neighborhood', 'regionprop'}  # cell features compared by the model
min_track_length = 9           # minimum number of frames a cell must appear in to be used
neighborhood_scale_size = 30   # size (in pixels) of the neighborhood extracted around each cell
batch_size = 128
crop_dim = 32                  # each cell's appearance is cropped to crop_dim x crop_dim pixels
in_shape = (crop_dim, crop_dim, 1)

model_name = 'tracking_model_seed{}_tl{}'.format(train_seed, min_track_length)

Create the Data Generators

[10]:
import deepcell.image_generators as generators
from deepcell.utils.data_utils import get_data

# Get the data
train_dict, test_dict = get_data(DATA_FILE, mode='siamese_daughters',
                                 seed=train_seed, test_size=test_size)

# Build the generators and iterators
datagen_train = generators.SiameseDataGenerator(
    rotation_range=180,    # randomly rotate images by 0 to rotation_range degrees
    shear_range=0,         # randomly shear images in the range (radians, -shear_range to shear_range)
    horizontal_flip=True,  # randomly flip images horizontally
    vertical_flip=True)    # randomly flip images vertically

train_data = datagen_train.flow(
    train_dict,
    batch_size=batch_size,
    seed=train_seed,
    crop_dim=crop_dim,
    neighborhood_scale_size=neighborhood_scale_size,
    min_track_length=min_track_length,
    features=features)

datagen_test = generators.SiameseDataGenerator(
    rotation_range=0,      # no rotation for the validation data
    shear_range=0,         # no shearing for the validation data
    horizontal_flip=False, # no horizontal flips for the validation data
    vertical_flip=False)   # no vertical flips for the validation data

test_data = datagen_test.flow(
    test_dict,
    batch_size=batch_size,
    seed=train_seed,
    crop_dim=crop_dim,
    neighborhood_scale_size=neighborhood_scale_size,
    min_track_length=min_track_length,
    features=features)

Instantiate the tracking model

[11]:
from deepcell import model_zoo

tracking_model = model_zoo.siamese_model(
    input_shape=in_shape,
    neighborhood_scale_size=neighborhood_scale_size,
    features=features)

Define the loss function

[15]:
from deepcell import losses

n_classes = tracking_model.layers[-1].output_shape[-1]

def loss_function(y_true, y_pred):
    return losses.weighted_categorical_crossentropy(y_true, y_pred,
                                                    n_classes=n_classes,
                                                    from_logits=False)

Compile the model

Before a model can be trained, it must be compiled with the chosen loss function and optimizer.

[16]:
tracking_model.compile(loss=loss_function, optimizer=optimizer, metrics=['accuracy'])

Train the model

Call fit_generator on the compiled model, passing a default set of callbacks.

[22]:
from deepcell.utils.train_utils import get_callbacks
from deepcell.utils.train_utils import count_gpus
from deepcell.utils import tracking_utils


model_path = os.path.join(MODEL_DIR, '{}.h5'.format(model_name))
loss_path = os.path.join(MODEL_DIR, '{}.npz'.format(model_name))

num_gpus = count_gpus()

print('Training on', num_gpus, 'GPUs.')

train_callbacks = get_callbacks(
    model_path,
    lr_sched=lr_sched,
    tensorboard_log_dir=LOG_DIR,
    save_weights_only=num_gpus >= 2,
    monitor='val_loss',
    verbose=1)

# rough estimate for steps_per_epoch
total_train_pairs = tracking_utils.count_pairs(train_dict['y'], same_probability=5.0)
total_test_pairs = tracking_utils.count_pairs(test_dict['y'], same_probability=5.0)

# fit the model on the batches generated by datagen.flow()
loss_history = tracking_model.fit_generator(
    train_data,
    steps_per_epoch=total_train_pairs // batch_size,
    epochs=n_epoch,
    validation_data=test_data,
    validation_steps=total_test_pairs // batch_size,
    callbacks=train_callbacks)
Training on 1 GPUs.
Epoch 1/10
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow_core/python/ops/math_grad.py:1424: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
5535/5536 [============================>.] - ETA: 0s - loss: 0.0680 - acc: 0.9831Epoch 1/10
1589/5536 [=======>......................] - ETA: 6:17 - loss: 0.0252 - acc: 0.9961
Epoch 00001: val_loss improved from inf to 0.02517, saving model to /data/models/tracking_model_seed1_tl9.h5
5536/5536 [==============================] - 4476s 808ms/step - loss: 0.0680 - acc: 0.9831 - val_loss: 0.0252 - val_acc: 0.9961
Epoch 2/10
5535/5536 [============================>.] - ETA: 0s - loss: 0.0370 - acc: 0.9943Epoch 1/10
1589/5536 [=======>......................] - ETA: 6:16 - loss: 0.0224 - acc: 0.9972
Epoch 00002: val_loss improved from 0.02517 to 0.02242, saving model to /data/models/tracking_model_seed1_tl9.h5
5536/5536 [==============================] - 4458s 805ms/step - loss: 0.0370 - acc: 0.9943 - val_loss: 0.0224 - val_acc: 0.9972
Epoch 3/10
5535/5536 [============================>.] - ETA: 0s - loss: 0.0305 - acc: 0.9959Epoch 1/10
1589/5536 [=======>......................] - ETA: 6:21 - loss: 0.0232 - acc: 0.9967
Epoch 00003: val_loss did not improve from 0.02242
5536/5536 [==============================] - 4456s 805ms/step - loss: 0.0305 - acc: 0.9959 - val_loss: 0.0232 - val_acc: 0.9967
Epoch 4/10
5535/5536 [============================>.] - ETA: 0s - loss: 0.0266 - acc: 0.9969Epoch 1/10
1589/5536 [=======>......................] - ETA: 6:16 - loss: 0.0201 - acc: 0.9980
Epoch 00004: val_loss improved from 0.02242 to 0.02006, saving model to /data/models/tracking_model_seed1_tl9.h5
5536/5536 [==============================] - 4455s 805ms/step - loss: 0.0266 - acc: 0.9969 - val_loss: 0.0201 - val_acc: 0.9980
Epoch 5/10
5535/5536 [============================>.] - ETA: 0s - loss: 0.0260 - acc: 0.9969Epoch 1/10
1589/5536 [=======>......................] - ETA: 6:06 - loss: 0.0191 - acc: 0.9979
Epoch 00005: val_loss improved from 0.02006 to 0.01907, saving model to /data/models/tracking_model_seed1_tl9.h5
5536/5536 [==============================] - 4449s 804ms/step - loss: 0.0260 - acc: 0.9969 - val_loss: 0.0191 - val_acc: 0.9979
Epoch 6/10
5535/5536 [============================>.] - ETA: 0s - loss: 0.0263 - acc: 0.9973Epoch 1/10
1589/5536 [=======>......................] - ETA: 6:17 - loss: 0.0234 - acc: 0.9968
Epoch 00006: val_loss did not improve from 0.01907
5536/5536 [==============================] - 4462s 806ms/step - loss: 0.0263 - acc: 0.9973 - val_loss: 0.0234 - val_acc: 0.9968
Epoch 7/10
5535/5536 [============================>.] - ETA: 0s - loss: 0.0232 - acc: 0.9977Epoch 1/10
1589/5536 [=======>......................] - ETA: 6:14 - loss: 0.0187 - acc: 0.9982
Epoch 00007: val_loss improved from 0.01907 to 0.01870, saving model to /data/models/tracking_model_seed1_tl9.h5
5536/5536 [==============================] - 4449s 804ms/step - loss: 0.0232 - acc: 0.9977 - val_loss: 0.0187 - val_acc: 0.9982
Epoch 8/10
5535/5536 [============================>.] - ETA: 0s - loss: 0.0218 - acc: 0.9979Epoch 1/10
1589/5536 [=======>......................] - ETA: 6:10 - loss: 0.0170 - acc: 0.9984
Epoch 00008: val_loss improved from 0.01870 to 0.01698, saving model to /data/models/tracking_model_seed1_tl9.h5
5536/5536 [==============================] - 4452s 804ms/step - loss: 0.0218 - acc: 0.9979 - val_loss: 0.0170 - val_acc: 0.9984
Epoch 9/10
5535/5536 [============================>.] - ETA: 0s - loss: 0.0231 - acc: 0.9979Epoch 1/10
1589/5536 [=======>......................] - ETA: 6:20 - loss: 0.0181 - acc: 0.9982
Epoch 00009: val_loss did not improve from 0.01698
5536/5536 [==============================] - 4468s 807ms/step - loss: 0.0231 - acc: 0.9979 - val_loss: 0.0181 - val_acc: 0.9982
Epoch 10/10
5535/5536 [============================>.] - ETA: 0s - loss: 0.0201 - acc: 0.9983Epoch 1/10
1589/5536 [=======>......................] - ETA: 6:15 - loss: 0.0150 - acc: 0.9990
Epoch 00010: val_loss improved from 0.01698 to 0.01501, saving model to /data/models/tracking_model_seed1_tl9.h5
5536/5536 [==============================] - 4456s 805ms/step - loss: 0.0201 - acc: 0.9983 - val_loss: 0.0150 - val_acc: 0.9990

Evaluate Model Performance

A seed value is required so that the test data generator yields a reproducible sequence of batches.

[23]:
from sklearn.metrics import confusion_matrix

Y = []
Y_pred = []

for i in range(1,1000):
    if i % 100 == 0:
        print(".", end="")
    lst, y_true = next(test_data)
    y_true = np.argmax(y_true, axis=-1)
    y_pred = np.argmax(tracking_model.predict(lst), axis=-1)
    Y.append(y_true)
    Y_pred.append(y_pred)

Y = np.concatenate(Y, axis=0)
Y_pred = np.concatenate(Y_pred, axis=0)

print("")
cm = confusion_matrix(Y, Y_pred)
print(cm)
.........
[[40769    38    23]
 [   18 40452    50]
 [    0     0 40490]]
[24]:
test_acc = np.sum(Y == Y_pred) / len(Y)
print('Accuracy across all three classes: ', test_acc)

# Normalize each row of the confusion matrix
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
# Diagonal entries are the accuracies of each class
print('Accuracy for each individual class [Different, Same, Daughter]: ', cm.diagonal())
Accuracy across all three classes:  0.9989412344057781
Accuracy for each individual class [Different, Same, Daughter]:  [0.998506   0.99832182 1.        ]
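
Because the callbacks saved the best checkpoint to model_path, the trained model can be reloaded later for inference. A minimal sketch, assuming the standard Keras loading API (the custom_objects mapping is needed because the model was compiled with a custom loss):

[ ]:
from tensorflow.keras.models import load_model

# Reload the best checkpoint written by the model checkpoint callback
tracking_model = load_model(model_path,
                            custom_objects={'loss_function': loss_function})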

Next Steps

This model is used within an assignment problem framework to track cells through time-lapse sequences and build cell lineages. To see how this works on example data, refer to Part 2 of this notebook series: Tracking Example with Benchmarking.