有时,需要对某些 label 做 mask

#!/usr/bin/env python
# coding=utf-8

"""
tf version: 1.15.0

"""
import tensorflow as tf

# 维度 [batch_size, 1]
label1 = tf.constant([[0.0],
                      [1.0],
                      [1.0]])

label2 = tf.constant([[1.0],
                      [0.0],
                      [1.0]])

loss1 = tf.constant([[0.1],
                      [0.2],
                      [0.3]])

loss2 = tf.constant([[0.4],
                      [0.5],
                      [0.6]])

stack_loss = tf.stack([loss1, loss2], axis=1)

"""
# 如果使用 logical_or,即 [label1 or label2] 的关系

label_or = tf.cast(tf.logical_or(tf.cast(label1, dtype=tf.bool), tf.cast(label2, dtype=tf.bool)), dtype=tf.float32)
stack_label = tf.stack([label_or, label_or], axis=1)

最终结果
('loss_1: ', 0.6)  # 1.0 * 0.1 + 1.0 * 0.2 + 1.0 * 0.3 = 0.6 
('loss_2: ', 1.5)  # 1.0 * 0.4 + 1.0 * 0.5 + 1.0 * 0.6 = 1.5
"""

stack_label = tf.stack([label1, label2], axis=1)

losses_masked = tf.multiply(stack_loss, stack_label)
loss_1 = tf.reduce_sum(losses_masked[:, 0])
loss_2 = tf.reduce_sum(losses_masked[:, 1])

sess = tf.Session()

print("loss_1: ", sess.run(loss_1))  # 1.0 * 0.2 + 1.0 * 0.3 = 0.5
print("loss_2: ", sess.run(loss_2))  # 1.0 * 0.4 + 1.0 * 0.6 = 1.0

输出:

('loss_1: ', 0.5)
('loss_2: ', 1.0)
Logo

CSDN联合极客时间,共同打造面向开发者的精品内容学习社区,助力成长!

更多推荐