从三元表达式(ternary expression)理解 numpy.where

本文介绍了 NumPy 中的 where 函数,展示了如何利用该函数高效地处理数组数据,包括一维和多维数组的条件筛选及数值替换,并对比了传统三元表达式的局限性。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

三元表达式的一般形式:

x if condition else y

例如,我们有如下的两个数组以及条件:

xarr = np.array([1.1, 1.2, 1.3, 1.4, 1.5])

yarr = np.array([2.1, 2.2, 2.3, 2.4, 2.5])

cond = np.array([True, False, True, True, False])

当条件为 True 时,我们想从 xarr 中取值,反之从 yarr 中取值。用三元表达式需要这样来实现:

res = [(x if c else y)
       for x, y, c in zip(xarr, yarr, cond)]
res
"""
[1.1, 2.2, 1.3, 1.4, 2.5]
"""

但这个实现方法有许多问题。首先,对于大的数组执行速度不会太快。其次,不能对多维数组进行操作。


numpy.where 函数可看作是对三元表达式的向量化扩展。

numpy.where(condition, [x, y, ]/)

这里的 condition 为布尔类型的数组,x, y 也都为数组 (也可以为标量)。如果所有这些 arrays 都是一维的,那就等价于我们上面所写的三元表达式:

[(x if c else y) for c, x, y in zip(condition, x, y)]

上面的例子用 numpy.where 只需要这样写:

res = np.where(cond, xarr, yarr)
res
"""
[1.1, 2.2, 1.3, 1.4, 2.5]
"""

对于多维数组,例如,我们有如下数组:
arr = np.random.randn(5, 5)
arr
"""
array([[-0.97771187, -0.98135695, -0.13112475, -0.07527619,  1.20978508],
       [-1.12931145, -0.84098807,  2.04738178,  1.38584849, -0.51919951],
       [-0.92975612,  1.03771019, -0.08548654,  1.13116971, -0.89777143],
       [ 0.82876313, -0.73411161, -2.83065065, -1.14866989,  0.78968089],
       [ 0.03728637, -0.69337259, -1.40003486,  0.52986178,  1.34800647]])
"""

我们想把大于 0 的值替换为 2,而小于 0 的值替换为 -2:

arr > 0
"""
array([[False, False, False, False,  True],
       [False, False,  True,  True, False],
       [False,  True, False,  True, False],
       [ True, False, False, False,  True],
       [ True, False, False,  True,  True]])
"""
np.where(arr > 0, 2, -2)
"""
array([[-2, -2, -2, -2,  2],
       [-2, -2,  2,  2, -2],
       [-2,  2, -2,  2, -2],
       [ 2, -2, -2, -2,  2],
       [ 2, -2, -2,  2,  2]])
"""

只将大于 0 的值替换为 2:

np.where(arr > 0, 2, arr)
"""
array([[-0.97771187, -0.98135695, -0.13112475, -0.07527619,  2.        ],
       [-1.12931145, -0.84098807,  2.        ,  2.        , -0.51919951],
       [-0.92975612,  2.        , -0.08548654,  2.        , -0.89777143],
       [ 2.        , -0.73411161, -2.83065065, -1.14866989,  2.        ],
       [ 2.        , -0.69337259, -1.40003486,  2.        ,  2.        ]])
"""

注意

  • 使用 np.where 创建了一个新的 array。原有的 array 并未改变。
  • np.where(arr > 0, 2, arr),2 其实执行了广播操作
  • np.where(arr > 0, 2, -2),2 和 -2 都执行了广播操作

关于 NumPy 中的广播机制,可以看这篇文章:《NumPy 中的广播》


References

Python for Data Analysis, 2 n d ^{\rm nd} nd edition. Wes McKinney.

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

如松茂矣

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值