Comparing a tensor with a single int gives me a bool tensor:
import torch
a = torch.tensor([1,2,3])
a != 2
#tensor([ True, False, True])
Can I do the same with a list in plain PyTorch? I.e.:
import torch
a = torch.tensor([1,2,3])
a not in [2,3]  # pseudocode: plain Python `in` does not work elementwise on a tensor
#tensor([ True, False, False])
Thanks a lot for your time!
Solution:
I think you want torch.isin (added in PyTorch 1.10). It tests each element of a tensor for membership in a second tensor:
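import torch
a = torch.tensor([1, 2, 3])
# isin is True where an element of `a` appears in the test tensor; ~ turns it into "not in"
out = ~torch.isin(a, torch.tensor([2, 3]))
# or, equivalently, let isin do the inversion itself
out = torch.isin(a, torch.tensor([2, 3]), invert=True)
print(out)
# tensor([ True, False, False])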
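If torch.isin is not available in your PyTorch version, a broadcasting comparison gives the same result. This is a minimal sketch, not part of the original answer; the helper name `not_in` is hypothetical. It compares every element of `a` against every test value and then reduces with `any`:

```python
import torch

def not_in(a: torch.Tensor, values) -> torch.Tensor:
    """Hypothetical helper: elementwise "a not in values" via broadcasting."""
    # Put the test values on the same dtype/device as `a`, flattened to 1-D
    values = torch.as_tensor(values, dtype=a.dtype, device=a.device).reshape(-1)
    # Shape (*a.shape, 1) vs (len(values),) broadcasts to (*a.shape, len(values))
    matches = a.unsqueeze(-1) == values
    # An element is "in" if it matches any value; invert to get "not in"
    return ~matches.any(dim=-1)

a = torch.tensor([1, 2, 3])
print(not_in(a, [2, 3]))  # tensor([ True, False, False])
```

The intermediate `matches` tensor has `a.numel() * len(values)` elements, so for large inputs torch.isin is the better choice where it exists.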