I know that torch.argmax(x, dim=0) returns the index of the first maximum value in x along dimension 0. Is there an efficient way to return the indices of the first n maximum values? If some values are duplicated, I want each duplicate's index to count among the n indices as well.
As a concrete example, say x=torch.tensor([2, 1, 4, 1, 4, 2, 1, 1]). I would like a function
generalized_argmax(x: torch.Tensor, n: int)
such that
generalized_argmax(x, 4)
returns [0, 2, 4, 5] in this example.
Solution:
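To find the n largest values you have to look at every element of the tensor anyway, so a straightforward approach is to sort the indices with argsort and keep only the first n entries. Passing stable=True makes equal values keep their original order, so among ties the lower index comes first, matching the behaviour of torch.argmax.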
>>> import torch
>>> x = torch.tensor([2, 1, 4, 1, 4, 2, 1, 1])
>>> n = 4
>>> x.argsort(dim=0, descending=True, stable=True)[:n]
tensor([2, 4, 0, 5])
Sort the result again (for example with .sort().values) if you need the indices in ascending order, which gives [0, 2, 4, 5].
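Putting both steps together, here is a minimal sketch of the generalized_argmax function from the question. It assumes x is a 1-D tensor. It uses torch.sort with stable=True (equivalent to the stable argsort call, but returning values and indices together), keeps the first n indices, and then sorts those indices into ascending order.

```python
import torch


def generalized_argmax(x: torch.Tensor, n: int) -> torch.Tensor:
    """Return the indices of the n largest entries of a 1-D tensor, in ascending index order.

    Assumes x is 1-D. Because the sort is stable, equal values keep their
    original left-to-right order, so ties are broken in favour of the lower index.
    """
    # Stable descending sort; .indices holds the original positions.
    order = torch.sort(x, dim=0, descending=True, stable=True).indices
    # Keep the positions of the n largest values.
    top = order[:n]
    # Re-sort the selected positions so they come out in ascending order.
    return torch.sort(top).values


x = torch.tensor([2, 1, 4, 1, 4, 2, 1, 1])
print(generalized_argmax(x, 4))  # tensor([0, 2, 4, 5])
print(generalized_argmax(x, 3))  # tensor([0, 2, 4])
```

With n = 3 the two 2s tie for the last slot; the stable sort picks index 0 rather than 5, which is the "first" occurrence in the same sense as torch.argmax.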