Finding elements in an array using numpy.where()

In order to find some specific elements or indexes of those elements from an array we can use np.where. It returns all those indexes that are true or associated with true condition.

Get 12 random numbers between [10,50)

import numpy as np
>>> x = np.random.randint(10,50,12)
>>> x
array([14, 38, 16, 24, 21, 14, 19, 41, 10, 40, 49, 11])

Get index where element is more than 20

>>> idx_21 = np.where(x>20)
>>> idx_21
(array([ 1,  3,  4,  7,  9, 10]),) 

Get elements which are more than 20

>>> x[idx_21]
array([38, 24, 21, 41, 40, 49])

Another famous use case of np.where is to find images from a batch for which the predictions are either right or wrong. Example:

pred = [0,0,1,0,1,1]
actual = [0,0,0,0,1,0]

def rand_mask(mask): return np.random.choice(np.where(mask)[0], min(len(pred), 2), replace=False)
def rand_correct(is_correct): return rand_mask((pred == actual) == is_correct)

Get random indexes where prediction is correct

idxs_correct = rand_correct(True)

Get random indexes where prediction is incorrect

idxs_incorrect = rand_correct(False)

Get random images based on the correct or incorrect predicted indexes

imgs_correct = [images[id] for id in idxs_correct]
imgs_incorrect = [images[id] for id in idxs_incorrect]