
I'm implementing the U-Net model as described in the original paper (Ronneberger et al., 2015). This is my model so far:

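from tensorflow.keras.layers import (Concatenate, Conv2D, Cropping2D, Input,
                                     MaxPool2D, UpSampling2D)
from tensorflow.keras.models import Model

IMAGE_SIZE = (572, 572)  # (height, width); the cropping values below assume 572x572

def create_unet_model(image_size = IMAGE_SIZE):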

    # Input layer is a 572x572 colour (3-channel) image; image_size must be a (height, width) tuple
    input_layer = Input(shape=image_size + (3,))

    """ Begin Downsampling """

    # Block 1
    conv_1 = Conv2D(64, 3, activation = 'relu')(input_layer)
    conv_2 = Conv2D(64, 3, activation = 'relu')(conv_1)

    max_pool_1 = MaxPool2D(strides=2)(conv_2)

    # Block 2
    conv_3 = Conv2D(128, 3, activation = 'relu')(max_pool_1)
    conv_4 = Conv2D(128, 3, activation = 'relu')(conv_3)

    max_pool_2 = MaxPool2D(strides=2)(conv_4)

    # Block 3
    conv_5 = Conv2D(256, 3, activation = 'relu')(max_pool_2)
    conv_6 = Conv2D(256, 3, activation = 'relu')(conv_5)

    max_pool_3 = MaxPool2D(strides=2)(conv_6)

    # Block 4
    conv_7 = Conv2D(512, 3, activation = 'relu')(max_pool_3)
    conv_8 = Conv2D(512, 3, activation = 'relu')(conv_7)

    max_pool_4 = MaxPool2D(strides=2)(conv_8)

    """ Begin Upsampling """

    # Block 5
    conv_9 = Conv2D(1024, 3, activation = 'relu')(max_pool_4)
    conv_10 = Conv2D(1024, 3, activation = 'relu')(conv_9)

    upsample_1 = UpSampling2D()(conv_10)

    # Copy and Crop
    conv_8_cropped = Cropping2D(cropping=4)(conv_8)
    merge_1 = Concatenate()([conv_8_cropped, upsample_1])

    # Block 6
    conv_11 = Conv2D(512, 3, activation = 'relu')(merge_1)
    conv_12 = Conv2D(512, 3, activation = 'relu')(conv_11)

    upsample_2 = UpSampling2D()(conv_12)

    # Copy and Crop
    conv_6_cropped = Cropping2D(cropping=16)(conv_6)
    merge_2 = Concatenate()([conv_6_cropped, upsample_2])

    # Block 7
    conv_13 = Conv2D(256, 3, activation = 'relu')(merge_2)
    conv_14 = Conv2D(256, 3, activation = 'relu')(conv_13)
    upsample_3 = UpSampling2D()(conv_14)

    # Copy and Crop
    conv_4_cropped = Cropping2D(cropping=40)(conv_4)
    merge_3 = Concatenate()([conv_4_cropped, upsample_3])

    # Block 8
    conv_15 = Conv2D(128, 3, activation = 'relu')(merge_3)
    conv_16 = Conv2D(128, 3, activation = 'relu')(conv_15)
    upsample_4 = UpSampling2D()(conv_16)

    # Copy and Crop
    conv_2_cropped = Cropping2D(cropping=88)(conv_2)
    merge_4 = Concatenate()([conv_2_cropped, upsample_4])

    # Block 9
    conv_17 = Conv2D(64, 3, activation = 'relu')(merge_4)
    conv_18 = Conv2D(64, 3, activation = 'relu')(conv_17)

    # Output layer
    output_layer = Conv2D(1, 1, activation='sigmoid')(conv_18)

    """ Define the model """
    unet = Model(input_layer, output_layer)
    
    return unet

The cropping is implemented as specified in this answer and is specific to 572x572 images.
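The crop sizes follow from the fact that every Conv2D here uses the default padding='valid', so each 3x3 convolution trims 2 pixels from each spatial axis, and MaxPool2D(strides=2) floors odd sizes. A small pure-Python sketch that traces the sizes for an arbitrary square input (the function name unet_sizes is just illustrative, not part of Keras):

```python
def unet_sizes(input_size, depth=4, convs_per_block=2, kernel=3):
    """Trace the side length of square feature maps through a U-Net that uses
    'valid' (unpadded) convolutions, 2x2 max pooling and 2x upsampling.

    Returns (skip_sizes, upsampled_sizes, output_size); both lists are ordered
    from the deepest copy-and-crop level to the shallowest.
    """
    shrink = convs_per_block * (kernel - 1)  # each valid 3x3 conv trims 2 pixels
    size = input_size
    skips = []
    for _ in range(depth):
        size -= shrink      # two valid convolutions
        skips.append(size)  # this map is copied to the decoder
        size //= 2          # MaxPool2D(strides=2) floors odd sizes
    size -= shrink          # bottleneck convolutions
    ups = []
    for _ in range(depth):
        size *= 2           # UpSampling2D() doubles the size
        ups.append(size)
        size -= shrink      # two valid convolutions after concatenation
    return skips[::-1], ups, size


for n in (572, 256):
    skips, ups, out = unet_sizes(n)
    # total pixels to crop per axis at each level, then the final output size
    print(n, [s - u for s, u in zip(skips, ups)], out)
# 572 [8, 32, 80, 176] 388
# 256 [8, 33, 82, 180] 68
```

Halving these totals gives the per-side crops of 4, 16, 40 and 88 used above; an odd total (33 for a 256x256 input) needs an asymmetric crop. The trace also shows that with valid convolutions the output is smaller than the input: 388x388 for 572x572 and 68x68 for 256x256.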

Unfortunately, training this implementation causes a ResourceExhaustedError:

Exception has occurred: ResourceExhaustedError
 OOM when allocating tensor with shape[32,64,392,392] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc
     [[node model/cropping2d_3/strided_slice (defined at c:\main.py:74) ]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info.
 [Op:__inference_train_function_3026]

Function call stack:
train_function
  File "C:\main.py", line 74, in main
    unet_model.fit(train_images, epochs=epochs, validation_data=validation_images, callbacks=CALLBACKS)
  File "C:\main.py", line 276, in <module>
    main()

My GPU is a GeForce RTX 2070 Super 8GB.

I verified that the image size is the cause by reproducing the error with another U-Net implementation that I know works.
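For scale, the single tensor named in the OOM message is already large. A back-of-the-envelope sketch, assuming 4-byte float32 values as the error reports (tensor_bytes is a hypothetical helper):

```python
def tensor_bytes(shape, bytes_per_element=4):
    """Bytes needed to hold a dense tensor of the given shape (float32 by default)."""
    n = 1
    for dim in shape:
        n *= dim
    return n * bytes_per_element


# Shape from the OOM message: batch 32, 64 channels, 392x392 pixels
size = tensor_bytes((32, 64, 392, 392))
print(size, f"{size / 2**30:.2f} GiB")  # 1258815488 1.17 GiB
```

This is only one of many activations of similar or larger size (the first Conv2D outputs are 570x570x64 per image) that are kept for backpropagation. Because the figure scales linearly with the batch dimension, a smaller batch size reduces memory as well as smaller images do.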

To work around this, I'm trying smaller images, e.g. 256x256, and have changed the Cropping2D layers to crop to the expected size at each level:

# Copy and Crop - 24 -> 16
conv_8_cropped = Cropping2D(cropping=4)(conv_8)
merge_1 = Concatenate()([conv_8_cropped, upsample_1])

# Copy and Crop - 57 -> 24
conv_6_cropped = Cropping2D(cropping=((17,16),(17,16)))(conv_6)
merge_2 = Concatenate()([conv_6_cropped, upsample_2])

# Copy and Crop - 122 -> 40
conv_4_cropped = Cropping2D(cropping=41)(conv_4)
merge_3 = Concatenate()([conv_4_cropped, upsample_3])

# Copy and Crop - 252 -> 72
conv_2_cropped = Cropping2D(cropping=90)(conv_2)
merge_4 = Concatenate()([conv_2_cropped, upsample_4])
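The crop values above can be derived instead of hard-coded for each image size. A minimal sketch, where crop_amounts is a hypothetical helper (not part of Keras) returning the ((top, bottom), (left, right)) tuple form that Cropping2D(cropping=...) accepts:

```python
def crop_amounts(skip_size, target_size):
    """Return ((top, bottom), (left, right)) crops that reduce a square skip
    connection of side skip_size to target_size.

    When the difference is odd, the extra pixel is removed from the top/left,
    matching the ((17, 16), (17, 16)) used above for 57 -> 24.
    """
    total = skip_size - target_size
    if total < 0:
        raise ValueError("skip connection is smaller than the upsampled map")
    before = total - total // 2  # larger half when total is odd
    after = total // 2
    return ((before, after), (before, after))


# Skip sizes and upsampled sizes for a 256x256 input, taken from the model summary
for skip, up in [(24, 16), (57, 24), (122, 40), (252, 72)]:
    print(f"{skip} -> {up}: {crop_amounts(skip, up)}")
# 24 -> 16: ((4, 4), (4, 4))
# 57 -> 24: ((17, 16), (17, 16))
# 122 -> 40: ((41, 41), (41, 41))
# 252 -> 72: ((90, 90), (90, 90))
```

In a Keras graph the two sizes could be read from the static shapes of the skip tensor and the upsampled tensor, so the same model code works for any input size that survives the pooling steps.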

Updated model summary:

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_1 (InputLayer)            [(None, 256, 256, 3) 0
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 254, 254, 64) 1792        input_1[0][0]
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 252, 252, 64) 36928       conv2d[0][0]
__________________________________________________________________________________________________
max_pooling2d (MaxPooling2D)    (None, 126, 126, 64) 0           conv2d_1[0][0]
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 124, 124, 128 73856       max_pooling2d[0][0]
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 122, 122, 128 147584      conv2d_2[0][0]
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D)  (None, 61, 61, 128)  0           conv2d_3[0][0]
__________________________________________________________________________________________________
conv2d_4 (Conv2D)               (None, 59, 59, 256)  295168      max_pooling2d_1[0][0]
__________________________________________________________________________________________________
conv2d_5 (Conv2D)               (None, 57, 57, 256)  590080      conv2d_4[0][0]
__________________________________________________________________________________________________
max_pooling2d_2 (MaxPooling2D)  (None, 28, 28, 256)  0           conv2d_5[0][0]
__________________________________________________________________________________________________
conv2d_6 (Conv2D)               (None, 26, 26, 512)  1180160     max_pooling2d_2[0][0]
__________________________________________________________________________________________________
conv2d_7 (Conv2D)               (None, 24, 24, 512)  2359808     conv2d_6[0][0]
__________________________________________________________________________________________________
max_pooling2d_3 (MaxPooling2D)  (None, 12, 12, 512)  0           conv2d_7[0][0]
__________________________________________________________________________________________________
conv2d_8 (Conv2D)               (None, 10, 10, 1024) 4719616     max_pooling2d_3[0][0]
__________________________________________________________________________________________________
conv2d_9 (Conv2D)               (None, 8, 8, 1024)   9438208     conv2d_8[0][0]
__________________________________________________________________________________________________
cropping2d (Cropping2D)         (None, 16, 16, 512)  0           conv2d_7[0][0]
__________________________________________________________________________________________________
up_sampling2d (UpSampling2D)    (None, 16, 16, 1024) 0           conv2d_9[0][0]
__________________________________________________________________________________________________
concatenate (Concatenate)       (None, 16, 16, 1536) 0           cropping2d[0][0]
                                                                 up_sampling2d[0][0]
__________________________________________________________________________________________________
conv2d_10 (Conv2D)              (None, 14, 14, 512)  7078400     concatenate[0][0]
__________________________________________________________________________________________________
conv2d_11 (Conv2D)              (None, 12, 12, 512)  2359808     conv2d_10[0][0]
__________________________________________________________________________________________________
cropping2d_1 (Cropping2D)       (None, 24, 24, 256)  0           conv2d_5[0][0]
__________________________________________________________________________________________________
up_sampling2d_1 (UpSampling2D)  (None, 24, 24, 512)  0           conv2d_11[0][0]
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, 24, 24, 768)  0           cropping2d_1[0][0]
                                                                 up_sampling2d_1[0][0]
__________________________________________________________________________________________________
conv2d_12 (Conv2D)              (None, 22, 22, 256)  1769728     concatenate_1[0][0]
__________________________________________________________________________________________________
conv2d_13 (Conv2D)              (None, 20, 20, 256)  590080      conv2d_12[0][0]
__________________________________________________________________________________________________
cropping2d_2 (Cropping2D)       (None, 40, 40, 128)  0           conv2d_3[0][0]
__________________________________________________________________________________________________
up_sampling2d_2 (UpSampling2D)  (None, 40, 40, 256)  0           conv2d_13[0][0]
__________________________________________________________________________________________________
concatenate_2 (Concatenate)     (None, 40, 40, 384)  0           cropping2d_2[0][0]
                                                                 up_sampling2d_2[0][0]
__________________________________________________________________________________________________
conv2d_14 (Conv2D)              (None, 38, 38, 128)  442496      concatenate_2[0][0]
__________________________________________________________________________________________________
conv2d_15 (Conv2D)              (None, 36, 36, 128)  147584      conv2d_14[0][0]
__________________________________________________________________________________________________
cropping2d_3 (Cropping2D)       (None, 72, 72, 64)   0           conv2d_1[0][0]
__________________________________________________________________________________________________
up_sampling2d_3 (UpSampling2D)  (None, 72, 72, 128)  0           conv2d_15[0][0]
__________________________________________________________________________________________________
concatenate_3 (Concatenate)     (None, 72, 72, 192)  0           cropping2d_3[0][0]
                                                                 up_sampling2d_3[0][0]
__________________________________________________________________________________________________
conv2d_16 (Conv2D)              (None, 70, 70, 64)   110656      concatenate_3[0][0]
__________________________________________________________________________________________________
conv2d_17 (Conv2D)              (None, 68, 68, 64)   36928       conv2d_16[0][0]
__________________________________________________________________________________________________
conv2d_18 (Conv2D)              (None, 68, 68, 1)    65          conv2d_17[0][0]
==================================================================================================
Total params: 31,378,945
Trainable params: 31,378,945
Non-trainable params: 0

This compiles fine but fails at train time with:

Exception has occurred: InvalidArgumentError
 Incompatible shapes: [32,68,68] vs. [32,256,256]
     [[node Equal (defined at c:\main.py:74) ]] [Op:__inference_train_function_3026]

Function call stack:
train_function

Does anyone know why the shapes don't match at train time, and how I can fix this?

Update: image loading, as part of a custom Sequence implementation:

source_image = load_img(source_image_paths[i], target_size=self.image_size, color_mode='grayscale')
target_image = load_img(target_image_paths[i], target_size=self.image_size, color_mode='grayscale')

#Start classes at 0
target_image = np.array(target_image) - 1

target_image_array.append(target_image)
source_image_array.append(np.array(source_image))

1 Answer


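The error shows that the model's predictions are 68x68 (the output shape of the last layer in the model summary) while the target masks are 256x256. The loss and metrics need both to have the same spatial size.

You can use the Keras image processing API, in particular the smart_resize function, to transform the target masks to the number of pixels the model outputs.

Something like this:

import numpy as np
from tensorflow.keras.preprocessing.image import smart_resize

target_size = (68, 68)  # output size of the model for 256x256 inputs
# smart_resize expects (height, width, channels), so add a channel axis to the 2-D mask;
# 'nearest' keeps the mask values as integer class labels
image_resized = smart_resize(image_original[..., np.newaxis], size=target_size, interpolation='nearest')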
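Rescaling the whole mask changes which region of the target each output pixel is compared against. A different option, closer to the original paper where the output map corresponds to the central part of the input tile, is to keep the input image at 256x256 and center-crop only the target mask to the model's 68x68 output. A minimal NumPy sketch; center_crop is a hypothetical helper, and treating the output as exactly centred is an approximation because pooling an odd size (57 -> 28) drops a row and column:

```python
import numpy as np


def center_crop(mask, out_h, out_w):
    """Crop the central out_h x out_w window from an (H, W) or (H, W, C) array."""
    h, w = mask.shape[:2]
    if out_h > h or out_w > w:
        raise ValueError("crop size larger than the mask")
    top = (h - out_h) // 2
    left = (w - out_w) // 2
    return mask[top:top + out_h, left:left + out_w]


# Stand-in for one 256x256 target mask produced by the Sequence (class values 0..2)
mask = np.random.randint(0, 3, size=(256, 256), dtype=np.uint8)
cropped = center_crop(mask, 68, 68)
print(cropped.shape)  # (68, 68)
```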