In pandas I have the option of doing the following to get the name of the column containing the maximum value for each row:
```python
import pandas as pd

df = pd.DataFrame({'a': [1, 2, 3, 4, 5], 'b': [5, 4, 3, 2, 1]})
df['Largest'] = df.idxmax(axis=1)
```
Which gets me:
|   | a | b | Largest |
|---|---|---|---|
| 0 | 1 | 5 | b |
| 1 | 2 | 4 | b |
| 2 | 3 | 3 | a |
| 3 | 4 | 2 | a |
| 4 | 5 | 1 | a |
How can I do this kind of operation in polars? There doesn’t seem to be an `idxmax` method, and `max_horizontal` seems to return only the value rather than any indexing information.
Solution:
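You could pack each row into a list, take the position of its largest element, and map that position back to a column name. In Polars 1.x the mapping step is `replace_strict`; older releases used `map_dict`, which has since been removed:

```python
import polars as pl

dfpl = pl.DataFrame({"a": [1, 2, 3, 4, 5], "b": [5, 4, 3, 2, 1]})

def arg_max_horizontal(*columns: str) -> pl.Expr:
    return (
        pl.concat_list(columns)  # one list per row: [a, b, ...]
        .list.arg_max()          # position of the largest value in that list
        # translate the position back into the matching column name
        .replace_strict({i: col_name for i, col_name in enumerate(columns)})
    )

print(dfpl.with_columns(Largest=arg_max_horizontal("a", "b")))
```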
Then you’ll get
```
shape: (5, 3)
┌─────┬─────┬─────────┐
│ a   ┆ b   ┆ Largest │
│ --- ┆ --- ┆ ---     │
│ i64 ┆ i64 ┆ str     │
╞═════╪═════╪═════════╡
│ 1   ┆ 5   ┆ b       │
│ 2   ┆ 4   ┆ b       │
│ 3   ┆ 3   ┆ a       │
│ 4   ┆ 2   ┆ a       │
│ 5   ┆ 1   ┆ a       │
└─────┴─────┴─────────┘
```