I am passing lists of column names of varying lengths to PySpark's groupBy().agg() function. The code I have written checks the length of the list: if it has one element, it calls .agg() with a single count() on that element; if it has two, it calls .agg() with two count() expressions, producing two aggregate columns.
Is there a more succinct way to write this than an if statement? As the list of column names grows longer, I'll have to keep adding elif branches.
For example:
from pyspark.sql.functions import count

# agg_fields: list of column names
if len(agg_fields) == 1:
    df = df.groupBy(col1, col2).agg(count(agg_fields[0]))
elif len(agg_fields) == 2:
    df = df.groupBy(col1, col2).agg(count(agg_fields[0]),
                                    count(agg_fields[1]))
Solution:
Yes, you can build the aggregate expressions with a list comprehension and unpack them into a single .agg() call:
agg_df = df.groupBy("col1","col2").agg(*[count(i).alias(i) for i in agg_fields])
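The key piece is the `*` operator, which unpacks the list comprehension into separate positional arguments, exactly as if you had typed each `count(...)` by hand. A minimal pure-Python sketch of the same pattern (no Spark needed; the `agg` function here is a stand-in, not the PySpark API):

```python
def agg(*exprs):
    # Stand-in for DataFrame.agg: accepts any number of
    # positional expressions, just as PySpark's .agg does.
    return list(exprs)

agg_fields = ["a", "b", "c"]

# Build one "count" expression per column, then unpack the
# list with * so each element becomes its own argument.
result = agg(*[f"count({c}) AS {c}" for c in agg_fields])
# result holds one expression per column in agg_fields
```

Because the comprehension runs over whatever `agg_fields` contains, the same line works for one column or twenty, with no if/elif ladder.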