Trouble with loss function tf.nn.weighted_cross_entropy_with_logits


I am trying to train a U-Net with binary targets. The usual binary cross-entropy loss does not perform well because the labels are very imbalanced (many more 0 pixels than 1s), so I want to punish false negatives more. TensorFlow doesn't have a ready-made weighted binary cross-entropy, and since I didn't want to write a loss from scratch, I'm trying to use tf.nn.weighted_cross_entropy_with_logits. To be able to pass the loss easily to model.compile, I'm writing this wrapper:

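import tensorflow as tf

def loss_wrapper(y, x):
    # y: ground-truth labels, x: model predictions (Keras passes y_true, y_pred)
    x = tf.cast(x, 'float32')
    loss = tf.nn.weighted_cross_entropy_with_logits(y, x, pos_weight=10)
    return loss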
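For reference, the TensorFlow documentation describes weighted_cross_entropy_with_logits as ordinary sigmoid cross-entropy with only the positive term multiplied by pos_weight: labels * -log(sigmoid(logits)) * pos_weight + (1 - labels) * -log(1 - sigmoid(logits)). The plain-Python sketch below evaluates that formula for a single pixel, using the numerically stable rearrangement given in the same docs, to show what pos_weight=10 does. The helper name weighted_bce_single is hypothetical and not part of TensorFlow.

```python
import math


def weighted_bce_single(label: float, logit: float, pos_weight: float) -> float:
    """Weighted sigmoid cross-entropy for one pixel (hypothetical helper).

    Stable form: (1 - z) * x + l * (log1p(exp(-|x|)) + max(-x, 0)),
    with l = 1 + (q - 1) * z, which equals
    z * -log(sigmoid(x)) * q + (1 - z) * -log(1 - sigmoid(x)).
    """
    log_weight = 1.0 + (pos_weight - 1.0) * label
    return (1.0 - label) * logit + log_weight * (
        math.log1p(math.exp(-abs(logit))) + max(-logit, 0.0)
    )


# The same confident mistake costs pos_weight times more on a positive pixel
# (true 1, logit -3) than on a negative pixel (true 0, logit +3).
missed_positive = weighted_bce_single(label=1.0, logit=-3.0, pos_weight=10.0)
missed_negative = weighted_bce_single(label=0.0, logit=3.0, pos_weight=10.0)
print(missed_positive / missed_negative)
```

Only the false-negative side of the loss is scaled, which is why pos_weight > 1 suits the "many more 0s than 1s" case described above.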

However, even though I cast x to float32, I still get this error:

TypeError: Input 'y' of 'Mul' Op has type float32 that does not match type int32 of argument 'x'.

when the TensorFlow loss is called. Can someone explain what is happening?


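If x represents your predictions, it probably already has type float32, so casting it changes nothing. The mismatch comes from y, which is presumably your labels: the error says one input of the multiplication is int32, and weighted_cross_entropy_with_logits internally multiplies values derived from the labels with the logits, which fails when one is int32 and the other float32. So cast y instead: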

loss = tf.nn.weighted_cross_entropy_with_logits(tf.cast(y, dtype=tf.float32),x,pos_weight=10)
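A fuller version of the corrected wrapper might look like the sketch below. It assumes the model's last layer outputs raw logits (no sigmoid activation), because weighted_cross_entropy_with_logits applies the sigmoid itself; the example label and logit values are made up for illustration.

```python
import tensorflow as tf


def loss_wrapper(y_true, y_pred):
    """Weighted binary cross-entropy usable as a Keras loss.

    Assumes y_pred holds raw logits (no sigmoid on the last layer).
    """
    # Masks are often stored as integers; the op needs labels and logits
    # in the same floating-point dtype, so cast both explicitly.
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)
    # pos_weight > 1 penalises false negatives more than false positives.
    return tf.nn.weighted_cross_entropy_with_logits(
        labels=y_true, logits=y_pred, pos_weight=10.0
    )


# Integer mask with the same shape as the logits, as a segmentation
# pipeline might produce.
labels = tf.constant([[1, 0], [0, 0]], dtype=tf.int32)
logits = tf.constant([[0.0, 0.0], [2.0, -2.0]], dtype=tf.float32)
per_pixel = loss_wrapper(labels, logits)
print(per_pixel.numpy())
```

The function returns a per-pixel loss tensor; when passed as model.compile(optimizer=..., loss=loss_wrapper), Keras reduces it to a scalar for training.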

Answered By – AloneTogether

Answer Checked By – Marilyn (Easybugfix Volunteer)
