Check if tensor is in list

I can do the following with a single int to retrieve a bool tensor:

import torch
a = torch.tensor([1,2,3])
a != 2
#tensor([ True, False,  True])

Can I do the same with a list in plain PyTorch? I.e., something like this, with this desired output:

import torch
a = torch.tensor([1,2,3])
a not in [2,3]
#tensor([ True, False,  False])
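For context, a minimal sketch of why the literal `a not in [2, 3]` cannot work. Python's `in` on a list compares each list item with `a` and needs a single True/False back. The result here is a multi-element bool tensor, so PyTorch refuses to convert it and raises a RuntimeError. The variable name `result` is just for illustration.

```python
import torch

a = torch.tensor([1, 2, 3])

try:
    # list membership compares each item with `a`; `2 == a` produces a bool
    # tensor, which Python then tries to collapse into one True/False value
    result = a in [2, 3]
except RuntimeError as err:
    # collapsing a multi-element tensor to bool is ambiguous, so this raises
    result = None
    print(err)
```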

Thanks a lot for your time!

Solution:

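I think you want torch.isin (available since PyTorch 1.10). It returns a bool tensor that is True wherever an element of the first argument appears in the second. Negate it with ~, or pass invert=True, to get the "not in" mask: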

import torch
a = torch.tensor([1, 2, 3])
out = ~torch.isin(a, torch.tensor([2, 3]))
# or, equivalently
out = torch.isin(a, torch.tensor([2, 3]), invert=True)
print(out)

tensor([ True, False, False])
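If you are stuck on a PyTorch version older than 1.10 (no torch.isin), a broadcasting comparison gives the same mask. This is a sketch, not part of the original answer; the helper name not_in is hypothetical. It compares every element of a against every test value and then reduces with any() over the last dimension:

```python
import torch

def not_in(a: torch.Tensor, values) -> torch.Tensor:
    # Hypothetical helper: True where an element of `a` is NOT in `values`.
    # Put the test values on the same dtype/device as `a`.
    test = torch.as_tensor(values, dtype=a.dtype, device=a.device)
    # a.unsqueeze(-1) has shape (*a.shape, 1) and test has shape (k,),
    # so the comparison broadcasts to (*a.shape, k).
    matches = a.unsqueeze(-1) == test
    # An element is "in" the list if it equals any test value.
    return ~matches.any(dim=-1)

a = torch.tensor([1, 2, 3])
print(not_in(a, [2, 3]))  # tensor([ True, False, False])
```

The intermediate matches tensor is k times larger than a, so for long value lists torch.isin is the better choice when it is available.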
