TensorFlow函数:tf.where
tf.where函数
tf.where(
condition,
x=None,
y=None,
name=None
)
定义在:tensorflow/python/ops/array_ops.py.
请参阅指南:控制流程>比较运算符,数学函数>序列比较和索引
根据condition返回x或y中的元素.
如果x和y都为None,则该操作将返回condition中true元素的坐标.坐标以二维张量返回,其中第一维(行)表示真实元素的数量,第二维(列)表示真实元素的坐标.请记住,输出张量的形状可以根据输入中的真实值的多少而变化.索引以行优先顺序输出.
如果两者都不是None,则x和y必须具有相同的形状.如果x和y是标量,则condition张量必须是标量.如果x和y是更高级别的矢量,则condition必须是大小与x的第一维度相匹配的矢量,或者必须具有与x相同的形状.
condition张量作为一个可以选择的掩码(mask),它根据每个元素的值来判断输出中的相应元素/行是否应从 x (如果为 true) 或 y (如果为 false)中选择.
如果condition是向量,则x和y是更高级别的矩阵,那么它选择从x和y复制哪个行(外部维度).如果condition与x和y具有相同的形状,那么它将选择从x和y复制哪个元素.
函数参数:
- condition:一个bool类型的张量(Tensor).
- x:可能与condition具有相同形状的张量;如果condition的秩是1,则x可能有更高的排名,但其第一维度必须匹配condition的大小.
- y:与x具有相同的形状和类型的张量.
- name:操作的名称(可选).
返回值:
如果它们不是None,则返回与x,y具有相同类型与形状的张量;张量具有形状(num_true, dim_size(condition)).
可能引发的异常:
- ValueError:当一个x或y正好不是None.
更多建议: