[Python]Numpyのwhere関数について

Numpyのwhere関数について、具体的な使用例も含めて解説します。

スポンサーリンク

概要と使用例

Numpyのwhere関数は指定した条件に合っているかを調べ、真の時と偽の時で出力を変える関数です。

他の言語をやった事ある人だと、三項演算子の?:を思い浮かべる人もいると思います。

where関数の書式は以下の通りです。

Numpy.where(条件, x, y)

第1引数で条件を指定します。

指定した条件が真ならばxを出力し、偽ならばyを出力します。

a = 5
x = "aは3より大きい"
y = "aは3より小さい"

out = np.where(a > 3, x, y)

print(out)

#出力
//aは3より大きい

具体的な使用例としては、機械学習で用いるReLu関数の実装があります。

ReLu関数とは数値が0以下の時は0を出力し、0よりも大きいときはその数値自体を出力する関数です。

a = 6
b = -9
c = 0

def ReLu(x):
    return np.where(x <= 0, 0, x)

print(ReLu(a))
print(ReLu(b))
print(ReLu(c))

#出力
//6
//0
//0
スポンサーリンク

まとめ

  • where関数は、条件が真か偽かで出力を変える関数
  • ReLu関数の実装に使用できる

コメント

タイトルとURLをコピーしました