keras - What's wrong with my loss function? The loss is NaN as soon as I start training


I am using Keras, and my loss function is below. The shape of y_true and y_pred is supposed to be (None, 15, 15, 8, 5). The last 5 values are the x, y, w, h, and conf of a box. I hope someone can help me. Thanks!

The output while training looks like this:

8/1666 [..............................] - eta: 1662s - loss: nan - acc: 0.8994 

The loss function is:

def custom_loss(y_true, y_pred):
    ### adjust predictions
    # adjust x , y
    pred_box_xy = tf.sigmoid(y_pred[:, :, :, :, :2])

    # adjust w , h
    pred_box_wh = tf.exp(y_pred[:, :, :, :, 2:4]) * np.reshape(anchors, [1, 1, 1, box, 2])
    pred_box_wh = tf.sqrt(pred_box_wh / np.reshape([float(s_grid), float(s_grid)], [1, 1, 1, 1, 2]))

    # adjust confidence
    pred_box_conf = tf.expand_dims(tf.sigmoid(y_pred[:, :, :, :, 4]), -1)

    y_pred = tf.concat([pred_box_xy, pred_box_wh, pred_box_conf], 4)

    ### adjust ground truth
    # adjust x , y
    center_xy = 0.5 * (y_true[:, :, :, :, 0:2] + y_true[:, :, :, :, 2:4])
    center_xy = center_xy / np.reshape([(float(w_image) / s_grid), (float(w_image) / s_grid)], [1, 1, 1, 1, 2])
    true_box_xy = center_xy - tf.floor(center_xy)

    # adjust w , h
    true_box_wh = (y_true[:, :, :, :, 2:4] - y_true[:, :, :, :, 0:2])
    true_box_wh = tf.sqrt(true_box_wh / np.reshape([float(w_image), float(h_image)], [1, 1, 1, 1, 2]))

    true_box_conf = tf.expand_dims(y_true[:, :, :, :, 4], -1)

    y_true = tf.concat([true_box_xy, true_box_wh, true_box_conf], 4)

    ### compute weights
    weight_coor = tf.concat(4 * [true_box_conf], 4)
    weight_coor = lamda_cord * weight_coor

    weight_conf = lamda_noobj * (1. - true_box_conf) + lamda_conf * true_box_conf

    weight = tf.concat([weight_coor, weight_conf], 4)

    ### finalize loss
    loss = tf.pow(y_pred - y_true, 2)
    loss = loss * weight
    loss = tf.reshape(loss, [-1, s_grid * s_grid * box * 5])
    loss = tf.reduce_sum(loss, 1)
    loss = 0.5 * tf.reduce_mean(loss)

    return loss
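One place a NaN could come from (a guess, not something confirmed by the training log) is the w/h branch above. If a raw w/h output gets large, the exponential overflows to inf in float32, and an inf error multiplied by a zero weight (a cell with no object) gives NaN. The NumPy sketch below reproduces only that arithmetic. The helper name `wh_loss_term`, the sample values, and the clip bound of 10 are all made up for illustration. In TensorFlow the analogous clamp would be `tf.clip_by_value` applied before `tf.exp`.

```python
import numpy as np

def wh_loss_term(raw_wh, true_sqrt_wh, weight, clip=None):
    """Weighted squared error on sqrt(exp(raw_wh)).

    Hypothetical helper: it mirrors only the arithmetic of the w/h branch,
    without anchors or grid scaling.
    """
    raw_wh = np.asarray(raw_wh, dtype=np.float32)
    if clip is not None:
        # Assumed fix: bound the raw output before exponentiating.
        raw_wh = np.clip(raw_wh, -clip, clip)
    # Silence the overflow/invalid warnings so the result can be inspected.
    with np.errstate(over="ignore", invalid="ignore"):
        pred_sqrt_wh = np.sqrt(np.exp(raw_wh))
        return np.sum(weight * (pred_sqrt_wh - true_sqrt_wh) ** 2)

# Made-up values: the second entry stands for a cell with no object
# (weight 0) whose raw prediction has drifted to a large value.
raw = np.array([0.1, 100.0], dtype=np.float32)
true_sqrt = np.array([0.5, 0.5], dtype=np.float32)
weight = np.array([1.0, 0.0], dtype=np.float32)

unclipped = wh_loss_term(raw, true_sqrt, weight)           # exp overflows, inf * 0 -> nan
clipped = wh_loss_term(raw, true_sqrt, weight, clip=10.0)  # stays finite

print(unclipped, clipped)
```

If this is the cause, the zero weight does not protect the loss, because the overflow happens before the weighting. Other candidates worth checking are negative values under the `tf.sqrt` on the ground-truth side (boxes whose second corner is smaller than the first) and NaN values already present in the input data.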

