[FIXED] TF pipeline to dynamically extract patches and flatten dataset

Issue

I was trying to train an autoencoder on image patches. My training data consists of single-channel images loaded into a numpy array with shape [10000, 256, 512, 1]. I know how to extract patches from the images, but it is rather inconvenient that batching selects whole images, so the number of patches per batch depends on how many patches are extracted per image. If 32 patches were extracted per image, I'd like the dataset to behave as if it held 320000 individual patches, so that shuffling and batching pull patches from several images at a time, with the patches extracted on the fly so that they don't all have to be kept in memory.

The closest question I've found is Load tensorflow images and create patches but, as mentioned, it doesn't provide what I want.

PATCH_SIZE = 64

def extract_patches(imgs, patch_size=PATCH_SIZE, stride=PATCH_SIZE//2):
    # Extract overlapping patches and reshape them into a batch of patch images.
    # Note: the last entry of `sizes` and `strides` must be 1 -- the channel
    # dimension is never strided over.
    n_channels = imgs.shape[-1]
    if len(imgs.shape) < 4:
        imgs = tf.expand_dims(imgs, axis=0)
    return tf.reshape(tf.image.extract_patches(imgs,
                                               sizes=[1, patch_size, patch_size, 1],
                                               strides=[1, stride, stride, 1],
                                               rates=[1, 1, 1, 1],
                                               padding='VALID'),
                      (-1, patch_size, patch_size, n_channels))

batch_size = 8
dataset = (tf.data.Dataset.from_tensor_slices(tf.cast(imgs, tf.float32))
            .map(extract_patches, num_parallel_calls=tf.data.AUTOTUNE, deterministic=False)
            .shuffle(10*batch_size, reshuffle_each_iteration=True)
            .batch(batch_size)
            )

creates a dataset that returns batches with shape (batch_size, 105, 64, 64, 1), whereas I want a rank-4 tensor with shape (batch_size, 64, 64, 1), with shuffling operating on individual patches (rather than on each image's collection of patches). If I instead put .map at the end of the pipeline

batch_size = 8
dataset = (tf.data.Dataset.from_tensor_slices(tf.cast(imgs, tf.float32))
            .shuffle(10*batch_size, reshuffle_each_iteration=True)
            .batch(batch_size)
            .map(extract_patches, num_parallel_calls=tf.data.AUTOTUNE, deterministic=False)
            )

this does flatten the batches and returns a rank-4 tensor, but now each batch has shape (840, 64, 64, 1), i.e. all batch_size * 105 patches at once.
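For reference, the 105 and 840 in the shapes above follow directly from the VALID-padding patch count; a quick sanity check of the arithmetic:

```python
# Patch-count arithmetic for the shapes above (sizes taken from the question).
PATCH_SIZE, STRIDE = 64, 32
H, W = 256, 512                          # single-channel image dimensions

rows = (H - PATCH_SIZE) // STRIDE + 1    # patch positions vertically -> 7
cols = (W - PATCH_SIZE) // STRIDE + 1    # patch positions horizontally -> 15
patches_per_image = rows * cols          # 7 * 15 = 105

print(patches_per_image)                 # 105
print(8 * patches_per_image)             # 840 patches when 8 whole images are batched
```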

Solution

My first thought was that this isn't possible, because you want to shuffle all the patches from all the images and generate them on the fly without keeping them in memory. Each image yields 105 patches after extract_patches (with a 256×512 image, a 64×64 patch size and a stride of 32, you get ((256−64)/32 + 1) × ((512−64)/32 + 1) = 7 × 15 = 105), so one thing you can try to get a rank-4 tensor after .batch() is to reshape it as follows,

dataset = (tf.data.Dataset.from_tensor_slices(tf.cast(imgs, tf.float32))
            .map(extract_patches, num_parallel_calls=tf.data.AUTOTUNE, deterministic=False)
            .shuffle(10*batch_size, reshuffle_each_iteration=True)
            .batch(batch_size)
            .map(lambda x: tf.reshape(x, (batch_size * 105, 64, 64, 1)))
            .batch(batch_size)
            )

I wasn't sure about this approach, but it seemed worth a try.

Correction: this approach won't work, because .batch() always returns dataset elements of higher rank. As the tf.data.Dataset.batch() documentation says:

Combines consecutive elements of this dataset into batches.

Final update:
You can do this with the .unbatch() transformation placed right after .map(): it splits each rank-4 dataset element of shape (105, 64, 64, 1) into 105 rank-3 elements of shape (64, 64, 1), and then you can use .shuffle(), .batch(), or any other transformation just like a regular dataset.

dataset = (tf.data.Dataset.from_tensor_slices(tf.cast(imgs, tf.float32))
            .map(extract_patches, num_parallel_calls=tf.data.AUTOTUNE, deterministic=False)
            .unbatch()
            .shuffle(10*batch_size, reshuffle_each_iteration=True)
            .batch(batch_size)
            )
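As an in-memory cross-check of what map → unbatch → shuffle → batch produces, here is a NumPy sketch; the synthetic data and extract_patches_np (a hypothetical single-channel analogue of tf.image.extract_patches, built on sliding_window_view, NumPy ≥ 1.20) are illustrative, not part of the pipeline above:

```python
import numpy as np

rng = np.random.default_rng(0)
imgs = rng.standard_normal((4, 256, 512, 1)).astype(np.float32)  # 4 synthetic images

def extract_patches_np(img, patch_size=64, stride=32):
    # Hypothetical NumPy analogue of tf.image.extract_patches for one
    # single-channel image: all 64x64 windows, sampled every `stride` pixels.
    windows = np.lib.stride_tricks.sliding_window_view(img[..., 0],
                                                       (patch_size, patch_size))
    return windows[::stride, ::stride].reshape(-1, patch_size, patch_size, 1)

# map + unbatch: one flat pool of patches drawn from all images ...
patches = np.concatenate([extract_patches_np(img) for img in imgs])  # (4*105, 64, 64, 1)

# ... then shuffle at the patch level and take a batch.
rng.shuffle(patches)
batch = patches[:8]  # (8, 64, 64, 1), mixing patches from different images
```

Unlike the tf.data pipeline, this materializes every patch in memory at once; it only serves to confirm the element shapes and the patch-level shuffling that .unbatch() gives you in streaming form.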

Answered By – abhi_khoyani
