Tuesday, November 27, 2018

A PyTorch Implementation of Image Segmentation Using UNet, Stratification and K-Fold Learning

This PyTorch script is the by-product of attending a Kaggle competition for image segmentation (https://www.kaggle.com/c/tgs-salt-identification-challenge). In this competition, the original images come from a geological survey. The whole area of an original image can be divided into subareas with salt under the surface and subareas without salt under the surface (see the original mask image). The goal of the competition is to segment test images into binary masks in which white means salt area and black means non-salt area.


Usually I write deep learning scripts using Keras. In this case, however, I chose PyTorch for pragmatic reasons. It is well known that UNet [1] provides good performance for segmentation tasks. The encoder part of a UNet is often based on a well-known backbone such as VGG or ResNet. Compared with Keras, PyTorch seems to provide more options for pre-trained models. For instance, a pre-trained ResNet34 is available in PyTorch but not in Keras. In order to capture the benefit of transfer learning, PyTorch was chosen over Keras for this implementation.
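
As a quick illustration of this point (not part of the competition script itself), loading an ImageNet-pre-trained ResNet34 in PyTorch is a one-liner via torchvision:

import torchvision.models as models

# A pre-trained ResNet34 is readily available in PyTorch/torchvision;
# its convolutional layers can serve as the UNet encoder.
resnet34 = models.resnet34(pretrained=True)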

For performance enhancement, when dividing the training data into a training set and a validation set, stratification is used to ensure that images with various salt coverage percentages are all well represented. In the code below, all training images are divided into 10 categories based on their salt coverage ratio, from low to high.


# Stratification: bin the training samples by the amount of salt in the mask,
# then split each bin into training and validation data.
import numpy as np

ind = np.arange(len(Y_train_shaped))
np.random.shuffle(ind)

# Salt coverage (total number of salt pixels) of each sample, in shuffled order
coverage = [np.sum(Y_train_shaped[i]) for i in ind]

# Bin the coverage values into 10 categories (np.histogram's default bin count)
hist, bin_edges = np.histogram(coverage)
# np.digitize returns index i such that bins[i-1] <= x < bins[i].
# Increase the last bin edge by 1 so the maximum coverage value does not
# generate an extra category.
bin_edges[-1] += 1
cindex = np.digitize(coverage, bin_edges)
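
A quick sanity check on this binning, for example, is to print the size of each category and confirm every sample landed in exactly one of the 10 bins:

print(np.bincount(cindex)[1:])  # sizes of the 10 coverage categories; sums to len(coverage)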

In the next step, the training images are divided into a training set and a validation set with an 8:2 ratio, and this 8:2 split is applied within each of the 10 salt-percentage-based categories. This is called stratification. In addition, we use 5-fold cross validation: all training images are divided equally into 5 sets, 0/1/2/3/4. Model 0 is then trained with set 0 as validation and sets 1/2/3/4 as training; model 1 is trained with set 1 as validation and sets 0/2/3/4 as training; and so on. When all folds are finished, we have 5 trained models, and the final test result can be an ensemble of their outputs (a sketch of this averaging follows the fold-split code below). The benefit of K-fold learning is that all data is fully utilized and the ensemble improves performance; the drawback is higher computational cost.


val_size = 2/10  # 20% of each category goes to validation in every fold
for k in range(5):  # 5-fold learning
    print('Training for fold ' + str(k) + ' of 5 starts!')
    train_idxs = []
    val_idxs = []
    for i in range(10):  # one 8:2 split per salt-coverage category
        index_temp = ind[cindex == i + 1]
        list_temp = index_temp.tolist()
        val_samples = round(len(index_temp) * val_size)
        if k == 4:
            # Last fold takes the tail so rounding never drops samples
            val_idxs += list_temp[4 * val_samples:]
            train_idxs += list_temp[:4 * val_samples]
        else:
            val_idxs += list_temp[k * val_samples:(k + 1) * val_samples]
            train_idxs += list_temp[:k * val_samples] + list_temp[(k + 1) * val_samples:]
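
Once the 5 fold models are trained, the ensemble mentioned above can be as simple as averaging the per-pixel sigmoid probabilities of the models and thresholding the mean. Below is a minimal sketch, assuming a hypothetical list models holding the 5 trained networks (already in eval() mode) and a test batch x on the GPU:

import torch

def ensemble_predict(models, x, threshold=0.5):
    # Average the sigmoid probabilities of all fold models, then
    # threshold the mean to obtain the binary salt mask.
    with torch.no_grad():
        probs = [torch.sigmoid(m(x)) for m in models]
        mean_prob = torch.stack(probs).mean(dim=0)
    return (mean_prob > threshold).float()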

In this implementation, we use AlbuNet [2], which is a variation of UNet.


    # Requires: import torch; import torch.nn as nn;
    # the AlbuNet class is taken from the TernausNet repository [2]
    model = AlbuNet(pretrained=True, is_deconv=True)
    model.cuda()

    criterion = nn.BCEWithLogitsLoss()  # per-pixel binary cross entropy on raw logits
    learning_rate = 1e-3
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
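
With this loss and optimizer, each training iteration follows the standard PyTorch pattern. A minimal sketch of one epoch, where train_loader is a hypothetical DataLoader yielding image/mask batches built from train_idxs:

    model.train()
    for images, masks in train_loader:
        # masks: float tensors of 0s and 1s, same shape as the model output
        images, masks = images.cuda(), masks.cuda()
        optimizer.zero_grad()
        logits = model(images)           # raw scores; no sigmoid here
        loss = criterion(logits, masks)  # BCEWithLogitsLoss applies the sigmoid internally
        loss.backward()
        optimizer.step()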

AlbuNet is a UNet with a ResNet34 encoder. Its topology is shown below:

[Figure: AlbuNet network topology]
The full source code of this implementation can be found here. You can run the script after downloading the dataset from Kaggle.


[1] Olaf Ronneberger et al., "U-Net: Convolutional Networks for Biomedical Image Segmentation"
[2] Vladimir Iglovikov and Alexey Shvets, "TernausNet: U-Net with VGG11 Encoder Pre-Trained on ImageNet for Image Segmentation"

