
How to get the nearest neighbor in array B for every element of array A

I need to write a function nearest_neighbor(src, dst) that accepts two arrays of 2D points and, for every point in src (array A), returns the distance to and the index of its closest neighbor in dst (array B).

Example input:

src = np.array([[1,1], [2,2],[3,3],[4,4],[9,9]])
dst = np.array([[6,7],[10,10],[10,20]])

Example output:

(array([7.81024968, 6.40312424, 5.        , 3.60555128, 1.41421356]),
 array([0, 0, 0, 0, 1]))
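As a hand check of the first row: (1, 1) differs from (6, 7) by (5, 6), so the distance is sqrt(25 + 36) = sqrt(61) ≈ 7.81, while the other two dst points are farther away (sqrt(162) and sqrt(442)), so the index is 0. The last row works the same way: (9, 9) is sqrt(2) ≈ 1.41 from (10, 10), closer than sqrt(13) to (6, 7), so its index is 1. The short pure-Python sketch below (using only the standard library's math.dist) reproduces the whole expected output the slow, obvious way, as a reference to compare faster versions against:

```python
import math

src = [[1, 1], [2, 2], [3, 3], [4, 4], [9, 9]]
dst = [[6, 7], [10, 10], [10, 20]]

# For each src point, measure the Euclidean distance to every dst point
# and keep the smallest one together with its position in dst.
distances, indices = [], []
for p in src:
    d = [math.dist(p, q) for q in dst]
    j = min(range(len(dst)), key=d.__getitem__)
    distances.append(d[j])
    indices.append(j)

print([round(x, 8) for x in distances])
print(indices)  # [0, 0, 0, 0, 1]
```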

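With scikit-learn you can do it like this:

from sklearn.neighbors import NearestNeighbors

def nearest_neighbor(src, dst):
    # Build an index over dst and query the single closest point for each src row
    neigh = NearestNeighbors(n_neighbors=1)
    neigh.fit(dst)
    distances, indices = neigh.kneighbors(src, return_distance=True)
    # kneighbors returns (n_queries, 1) arrays; flatten them to 1-D
    return distances.ravel(), indices.ravel()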

But I need to implement it using only NumPy. My attempt:

import numpy as np

def nearest_neighbor(src, dst):
    distances = []
    indices = []

    for dot in src:
        # Distance from this src point to every dst point
        dists = np.linalg.norm(dst - dot, axis=1)
        # Position of the closest dst point; reuse it to read the distance
        idx = np.argmin(dists)

        distances.append(dists[idx])
        indices.append(idx)

    return np.array(distances), np.array(indices)

But it is slow because of the Python loop. How can I make it faster?

Solution:

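Use broadcasting: src[:, None] has shape (N, 1, 2), so src[:, None] - dst (with dst of shape (M, 2)) subtracts every row of dst from every row of src, giving an (N, M, 2) array. Taking the norm over the last axis yields the N x M distance matrix, and argmin along each row picks the closest dst point: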

>>> import numpy as np
>>> def nearest_neighbor(src, dst):
...     dist = np.linalg.norm(src[:, None] - dst, axis=-1)
...     indices = dist.argmin(-1)
...     return dist[np.arange(len(dist)), indices], indices
...
>>> src = np.array([[1,1], [2,2],[3,3],[4,4],[9,9]])
>>> dst = np.array([[6,7],[10,10],[10,20]])
>>> nearest_neighbor(src, dst)
(array([7.81024968, 6.40312424, 5.        , 3.60555128, 1.41421356]),
 array([0, 0, 0, 0, 1], dtype=int64))
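To make the shapes concrete, here is a standalone version of the same broadcasting approach with each intermediate shape commented, checked against the example. The intermediate (N, M, 2) array grows with len(src) * len(dst), so for large inputs it may be worth processing src in blocks; nearest_neighbor_chunked and its chunk_size parameter are a hypothetical helper sketched for that purpose, not part of the original answer.

```python
import numpy as np

def nearest_neighbor(src, dst):
    # (N, 1, 2) - (M, 2) broadcasts to (N, M, 2): every src point minus every dst point
    diff = src[:, None] - dst
    # Euclidean norm over the coordinate axis -> (N, M) distance matrix
    dist = np.linalg.norm(diff, axis=-1)
    # Column index of the closest dst point for each src row -> (N,)
    indices = dist.argmin(axis=-1)
    # Pick the matching distances with fancy indexing -> (N,)
    return dist[np.arange(len(dist)), indices], indices

def nearest_neighbor_chunked(src, dst, chunk_size=1024):
    # Hypothetical variant: same computation, but only chunk_size rows of src at a time,
    # so the intermediate array is at most (chunk_size, M, 2) instead of (N, M, 2).
    distances = np.empty(len(src))
    indices = np.empty(len(src), dtype=np.intp)
    for start in range(0, len(src), chunk_size):
        stop = start + chunk_size
        d, i = nearest_neighbor(src[start:stop], dst)
        distances[start:stop] = d
        indices[start:stop] = i
    return distances, indices

src = np.array([[1, 1], [2, 2], [3, 3], [4, 4], [9, 9]])
dst = np.array([[6, 7], [10, 10], [10, 20]])

print((src[:, None] - dst).shape)  # (5, 3, 2)
d_full, i_full = nearest_neighbor(src, dst)
d_chunk, i_chunk = nearest_neighbor_chunked(src, dst, chunk_size=2)
print(np.allclose(d_full, d_chunk), np.array_equal(i_full, i_chunk))  # True True
```

The chunk size trades memory for a few extra Python-level iterations; with chunk_size at least len(src) it reduces to a single call of the fully broadcast version.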