[FIXED] model.predict() returning random float instead of 3 probabilities

Issue

My model takes as input a value like ‘USD’ and predicts its category, in this case 1, since I defined currency to be 1. In total there are three categories to predict: 0 = book name, 1 = currency, and 2 = security number. My training, validation, and test sets look like this, where each value is mapped to its label:

         Value          Label
0   CAT_USD_CORP          0
1   USD                   1
2   US348595EV89          2
3   ATTR_IRT_LDN_ISD_ISD  0
4   CAD                   1
              ...   

I trained and tested the model with this code, which runs fine and gives me good accuracy:

model = tf.keras.Sequential([
  feature_layer,
  layers.Dense(128, activation='relu'),
  layers.Dense(128, activation='relu'),
  layers.Dropout(.1),
  layers.Dense(1)
])

model.compile(optimizer='adam',
              loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.fit(train_ds,
          validation_data=val_ds,
          epochs=12)

loss, accuracy = model.evaluate(test_ds)
print("Accuracy", accuracy)

But when I try to predict the category of a value, it always returns an array containing a single arbitrary-looking float instead of the probabilities of the 3 categories. For example:

model.predict(np.array(['USD'])) returns array([[1195770.]], dtype=float32)
and

model.predict(np.array(['US568A45EV89'])) returns array([[1938599.4]], dtype=float32),

when they’re supposed to return an array of 3 probabilities.

Solution

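If you have 3 classes, the top (output) layer of your model should have 3 nodes. Your current last layer, layers.Dense(1) with no activation, produces a single unbounded score per input, which is why predict returns one large float. Replace it with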

layers.Dense(3, activation='softmax')
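The difference between the two output layers can be illustrated without TensorFlow. The NumPy sketch below is only an illustration: it uses random weights and a made-up hidden activation rather than a trained model, and mimics what the last layer computes for one input. The original Dense(1) head yields one unbounded number, while a 3-unit softmax head yields three probabilities.

```python
import numpy as np

def softmax(logits):
    """Row-wise softmax; subtracting the row max keeps np.exp from overflowing."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exps = np.exp(shifted)
    return exps / exps.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)

# Hypothetical activation of the last Dense(128, activation='relu') layer
# for a single input such as 'USD' (batch of 1, 128 features, all >= 0).
hidden = np.maximum(rng.normal(size=(1, 128)), 0.0)

# Original head, layers.Dense(1) with no activation:
# one unbounded score per input, not a probability.
w_old, b_old = rng.normal(size=(128, 1)), np.zeros(1)
old_output = hidden @ w_old + b_old

# Corrected head, layers.Dense(3, activation='softmax'):
# three non-negative values per input that sum to 1.
w_new, b_new = rng.normal(size=(128, 3)), np.zeros(3)
new_output = softmax(hidden @ w_new + b_new)

print(old_output.shape)  # (1, 1)
print(new_output.shape)  # (1, 3)
print(new_output.sum())  # 1.0 up to floating-point rounding
```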

In model.compile you have

loss=tf.keras.losses.BinaryCrossentropy(from_logits=True) 

With 3 classes this is NOT a binary problem. If your labels are one-hot encoded, use

loss=tf.keras.losses.CategoricalCrossentropy()

If they are integer encoded (like the 0/1/2 labels above), use

loss=tf.keras.losses.SparseCategoricalCrossentropy()
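Both losses compute the same quantity; they differ only in the label format they expect. The NumPy sketch below reproduces the underlying math for probability outputs (Keras' default from_logits=False) using made-up probabilities. It is not the Keras implementation and leaves out the numerical safeguards the library adds.

```python
import numpy as np

# Made-up softmax outputs for two inputs (rows sum to 1).
probs = np.array([[0.1, 0.8, 0.1],    # model leans towards 1 = currency
                  [0.7, 0.2, 0.1]])   # model leans towards 0 = book name
int_labels = np.array([1, 0])         # integer encoded, like the Label column
one_hot = np.eye(3)[int_labels]       # one-hot encoded: [[0,1,0],[1,0,0]]

def categorical_ce(one_hot_labels, probs):
    # Batch mean of -sum_k y_k * log(p_k); needs one-hot labels.
    return float(np.mean(-np.sum(one_hot_labels * np.log(probs), axis=-1)))

def sparse_categorical_ce(int_labels, probs):
    # Same value, but picks the true class's probability by integer index.
    picked = probs[np.arange(len(int_labels)), int_labels]
    return float(np.mean(-np.log(picked)))

cce = categorical_ce(one_hot, probs)
scce = sparse_categorical_ce(int_labels, probs)
print(cce, scce)  # the two values agree
```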

Then each prediction will be a list of 3 values, one probability per class, summing to 1.
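To turn those 3 probabilities back into one of the question's categories, take the index of the largest value. A minimal sketch, assuming predict returns an array of shape (batch, 3); the helper decode_predictions and the probability values are hypothetical, for illustration only:

```python
import numpy as np

# Class ids as defined in the question.
CLASS_NAMES = {0: "book name", 1: "currency", 2: "security number"}

def decode_predictions(probs):
    """Return (class name, confidence) for each row of a (batch, 3) probability array."""
    ids = np.argmax(probs, axis=-1)
    return [(CLASS_NAMES[int(i)], float(p[i])) for i, p in zip(ids, probs)]

# Made-up output of the corrected model for two inputs.
probs = np.array([[0.05, 0.90, 0.05],
                  [0.10, 0.15, 0.75]])
print(decode_predictions(probs))  # [('currency', 0.9), ('security number', 0.75)]
```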

Answered By – Gerry P

Answer Checked By – Candace Johnson (Easybugfix Volunteer)
