[FIXED] TensorFlow accuracy from model.predict does not match final epoch val_accuracy of model.fit

Issue

I am trying to match the accuracy computed from a model.predict() call to the final val_accuracy reported by model.fit(). I am using a tf.data dataset.

val_ds = tf.keras.utils.image_dataset_from_directory(
    'my_path',
    validation_split=0.2,
    subset="validation",
    seed=38,
    image_size=(SIZE,SIZE),
)

The dataset setup for train_ds is similar. I prefetch both…

train_ds = train_ds.prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.prefetch(buffer_size=AUTOTUNE)

Then I get the labels for val_ds so I can use them later

true_categories = tf.concat([y for x, y in val_ds], axis=0)

My model

inputs = tf.keras.Input(shape=(SIZE, SIZE, 3))
# ... some other layers
outputs = tf.keras.layers.Dense( len(CLASS_NAMES), activation = tf.keras.activations.softmax)(intermediate)
model = tf.keras.Model(inputs, outputs)

Compiles fine

model.compile(
  optimizer = 'adam', 
  loss=tf.keras.losses.SparseCategoricalCrossentropy(), 
  metrics = ['accuracy'])

Seems to fit fine

history = model.fit(
  train_ds,
  validation_data=val_ds,
  epochs=10, 
  class_weight=class_weights) #i do weight the classes due to imbalance

The last epoch output

Epoch 10: val_accuracy did not improve from 0.92291
176/176 [==============================] - 191s 1s/step - loss: 0.9876 - accuracy: 0.7318 - val_loss: 0.4650 - val_accuracy: 0.8580

Now I want to verify the val_accuracy == 0.8580 when I run model.predict()

from sklearn import metrics  # assumed source of accuracy_score; the import is not shown in the post

predictions = model.predict(val_ds, verbose=2)
flattened_predictions = predictions.argmax(axis=1)
accuracy = metrics.accuracy_score(true_categories, flattened_predictions)
print("Accuracy = ", accuracy)

Accuracy = 0.7980014275517487

I would have expected that to equal the last val accuracy, which was 0.8580, but it is off. My val_ds uses a seed so I should be getting the images in the same order when I shuffle, right? Getting ground truth labels is a pain using datasets, but I think (???) my method is correct.

I only have two classes and when I look at my predictions variable it looks like I am getting probabilities as I would expect, so I think I set up, compiled and fit my model correctly for sparse categorical cross entropy using softmax on my final layer output.

predictions[:3] #show the first 3 predictions, the values sum to 1.0 as expected

array([[0.42447385, 0.5755262 ],
[0.2162129 , 0.7837871 ],
[0.31917858, 0.6808214 ]], dtype=float32)

What am I missing?

Solution

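What you are missing is that your validation dataset is reshuffled every time it is iterated over. The labels in true_categories come from one pass over val_ds, while model.predict makes a separate pass in a different order, so the two arrays do not line up.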

tf.keras.utils.image_dataset_from_directory has shuffle=True by default. The shuffle method of a TensorFlow dataset has an argument reshuffle_each_iteration, which is None by default and is then treated as True. The dataset is therefore reshuffled every time it is iterated.

The seed=38 argument makes the split reproducible: it determines which samples are reserved for the training subset and which for the validation subset. It does not keep the order of the validation samples the same from one iteration to the next.

As an example:

dataset = tf.data.Dataset.range(6)
dataset = dataset.shuffle(6, reshuffle_each_iteration=None, seed=154).batch(2)

print("First time iteration:")
for x in dataset:
    print(x)
print("\n")

print("Second time iteration")  
for x in dataset:
    print(x)

This will print:

First time iteration:
tf.Tensor([2 1], shape=(2,), dtype=int64)
tf.Tensor([3 0], shape=(2,), dtype=int64)
tf.Tensor([5 4], shape=(2,), dtype=int64)


Second time iteration
tf.Tensor([4 3], shape=(2,), dtype=int64)
tf.Tensor([0 5], shape=(2,), dtype=int64)
tf.Tensor([2 1], shape=(2,), dtype=int64)
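
To see why this reshuffling breaks the comparison in the question, here is a rough pure-Python sketch that does not use TensorFlow. ReshufflingDataset and perfect_model are hypothetical stand-ins invented for illustration: the first mimics a pipeline that yields a new order on every pass, the second is a toy "model" that is always right. Labels are collected in one pass and predictions in another, as in the question, and then both are collected in a single pass.

```python
import random


class ReshufflingDataset:
    """Hypothetical stand-in for a shuffled pipeline: each pass yields a new order."""

    def __init__(self, samples, seed):
        self.samples = list(samples)
        self.rng = random.Random(seed)

    def __iter__(self):
        order = self.samples[:]
        self.rng.shuffle(order)  # new permutation on every iteration
        return iter(order)


def perfect_model(x):
    # Toy "model" that always predicts the correct class (label = x % 2).
    return x % 2


samples = [(x, x % 2) for x in range(1000)]  # (input, label) pairs
ds = ReshufflingDataset(samples, seed=38)

# Pass 1: collect labels only, as in the question.
labels_first_pass = [y for _, y in ds]
# Pass 2: collect predictions; the order has changed since pass 1.
preds_second_pass = [perfect_model(x) for x, _ in ds]

mismatched = sum(
    p == y for p, y in zip(preds_second_pass, labels_first_pass)
) / len(samples)

# Single pass: predictions and labels come from the same iteration, so they line up.
pairs = [(perfect_model(x), y) for x, y in ds]
aligned = sum(p == y for p, y in pairs) / len(pairs)

print("two passes:", mismatched)
print("one pass:  ", aligned)
```

Even though the toy model never makes a mistake, the two-pass accuracy comes out near chance level for these balanced classes, while the single-pass accuracy is 1.0. The gap in the question is smaller, but the cause is the same kind of misalignment.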

The relevant logic can be found in the source code of tf.keras.utils.image_dataset_from_directory.

If you want to match predictions with their respective labels, then you can loop over the dataset:

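import numpy as np

predictions = []
labels = []
# One pass over val_ds, so each batch of predictions stays paired with its labels.
for x, y in val_ds:
    predictions.append(np.argmax(model(x, training=False), axis=-1))
    labels.append(y.numpy())

# Merge the per-batch arrays into flat arrays in matching order.
predictions = np.concatenate(predictions, axis=0)
labels = np.concatenate(labels, axis=0)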

Then you can check accuracy.
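
One way to do that check, as a minimal sketch: once both arrays are in the same order, accuracy is the fraction of positions where they agree. The two arrays below are made-up toy values standing in for the predictions and labels arrays built by the loop above.

```python
import numpy as np

# Toy stand-ins (made-up values) for the arrays produced by the loop above.
predictions = np.array([1, 0, 1, 1, 0])
labels = np.array([1, 0, 0, 1, 0])

# Element-wise comparison gives a boolean array; its mean is the accuracy.
accuracy = np.mean(predictions == labels)
print("Accuracy =", accuracy)
```

Here four of the five toy predictions match, so the accuracy is 0.8. If scikit-learn is available, sklearn.metrics.accuracy_score(labels, predictions) should give the same number.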

Answered By – Frightera

Answer Checked By – Katrina (Easybugfix Volunteer)
