Finding the maximum elements whose sum is less than a given value in PyTorch

Given a 2D PyTorch tensor, what is the most efficient way to compute, for each row, how many of its largest values can be summed before the total exceeds a given value?

Input:

tensor([[0.6607, 0.1165, 0.0278, 0.1950],
        [0.0529, 0.4607, 0.2729, 0.2135],
        [0.3267, 0.0902, 0.4578, 0.1253]])

Required Output for the given value 0.8:

tensor([[1],   # as 0.6607 + 0.1950 > 0.8
        [2],   # as 0.4607 + 0.2729 + 0.2135 > 0.8
        [2]])  # as 0.4578 + 0.3267 + 0.1253 > 0.8
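
To make the expected output concrete, here is a minimal plain-Python sketch of the rule the comments above describe: walk each row from largest to smallest value and stop as soon as adding the next value would push the running total above the limit. The function name `count_top_values` and the variable names are illustrative only and do not come from the original post.

```python
def count_top_values(row, limit):
    """Count how many of the largest values fit before the running sum exceeds `limit`."""
    total = 0.0
    count = 0
    for value in sorted(row, reverse=True):
        if total + value > limit:
            break  # adding this value would exceed the limit
        total += value
        count += 1
    return count


rows = [
    [0.6607, 0.1165, 0.0278, 0.1950],
    [0.0529, 0.4607, 0.2729, 0.2135],
    [0.3267, 0.0902, 0.4578, 0.1253],
]
counts = [count_top_values(row, 0.8) for row in rows]
print(counts)  # [1, 2, 2]
```

This loop is only a reference for checking results on small inputs; it is not meant to be efficient.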

Solution:

You can do this with a combination of sorting, a cumulative sum, and a max reduction.

First, sort the values of each row in descending order with torch.Tensor.sort:

>>> v = x.sort(dim=1, descending=True).values
>>> v
tensor([[0.6607, 0.1950, 0.1165, 0.0278],
        [0.4607, 0.2729, 0.2135, 0.0529],
        [0.4578, 0.3267, 0.1253, 0.0902]])

Then construct a mask on the cumulative sorted values that you get from applying torch.cumsum:

>>> mask = torch.cumsum(v, dim=1) > .8
>>> mask
tensor([[False,  True,  True,  True],
        [False, False,  True,  True],
        [False, False,  True,  True]])

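Calling torch.Tensor.max on that mask along dim=1 returns the index of the first True value in each row, i.e. the position of the first cumulative sum above the threshold 0.8. Because indices start at 0, that index equals the number of top values whose sum stays within the threshold. Note that a row whose total never exceeds the threshold contains no True value, in which case the returned index is 0: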

>>> mask.max(dim=1, keepdim=True).indices
tensor([[1],
        [2],
        [2]])
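
The three steps can be wrapped into one function. The following is a minimal sketch, not the original answer's code: the name `top_k_count_below` is hypothetical, it assumes a 2D tensor with non-negative entries (as in the example), and it adds one assumption of its own, namely that a row whose total never exceeds the threshold should report the full row length instead of 0. It also casts the mask to integers before calling `argmax`, so it does not depend on reductions over boolean tensors.

```python
import torch


def top_k_count_below(x: torch.Tensor, threshold: float) -> torch.Tensor:
    """Per row, count the largest values whose running sum does not exceed `threshold`."""
    # Step 1: sort each row from largest to smallest.
    v = x.sort(dim=1, descending=True).values
    # Step 2: mark the positions where the running sum is above the threshold.
    exceeds = torch.cumsum(v, dim=1) > threshold
    # Step 3: index of the first True per row (argmax returns the first maximum).
    first_exceed = exceeds.int().argmax(dim=1, keepdim=True)
    # Assumption: rows that never exceed the threshold keep all of their values.
    never_exceeds = ~exceeds.any(dim=1, keepdim=True)
    full_length = torch.full_like(first_exceed, x.shape[1])
    return torch.where(never_exceeds, full_length, first_exceed)


x = torch.tensor([[0.6607, 0.1165, 0.0278, 0.1950],
                  [0.0529, 0.4607, 0.2729, 0.2135],
                  [0.3267, 0.0902, 0.4578, 0.1253]])
print(top_k_count_below(x, 0.8))  # tensor([[1], [2], [2]])
```

If the original behaviour (index 0 for rows that never exceed the threshold) is what you want, return `first_exceed` directly and drop the last three lines of the function.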