I wonder if anyone can point me in the right direction with the following problem. In a rather large PySpark dataframe with about 50 columns, two of them represent, say, 'make' and 'model'. Something like
21234234322(unique id) .. .. .. Nissan Navara .. .. ..
73647364736 .. .. .. BMW X5 .. .. ..
What I would like to know is what the top 2 models per brand are. I can groupby both columns and add a count no problem, but how do I then limit (or filter) that result? I.e. how do I keep the (up to) 2 most popular models per brand and remove the rest?
Whatever I try, I end up iterating over the brands that exist in the original dataframe manually. Is there another way?
>Solution :
You can use rank() over a Window and then filter():
from pyspark.sql import functions as func
from pyspark.sql.window import Window
df = spark.createDataFrame(
    [
        ('a', 1, 1),
        ('a', 1, 2),
        ('a', 1, 3),
        ('a', 2, 1),
        ('a', 2, 2),
        ('a', 3, 1)
    ],
    schema=['col1', 'col2', 'col3']
)
df.printSchema()
df.show(10, False)
+----+----+----+
|col1|col2|col3|
+----+----+----+
|a   |1   |1   |
|a   |1   |2   |
|a   |1   |3   |
|a   |2   |1   |
|a   |2   |2   |
|a   |3   |1   |
+----+----+----+
Here col1 and col2 are the grouping columns (make and model in your case) and col3 is the unique id:
df.groupBy(
    'col1', 'col2'
).agg(
    func.countDistinct('col3').alias('dcount')
).withColumn(
    'rank', func.rank().over(Window.partitionBy('col1').orderBy(func.desc('dcount')))
).filter(
    func.col('rank') <= 2
).show(
    10, False
)
+----+----+------+----+
|col1|col2|dcount|rank|
+----+----+------+----+
|a   |1   |3     |1   |
|a   |2   |2     |2   |
+----+----+------+----+
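Applying rank() after the grouping and aggregation lets you keep the top 2 values within each group (col1) without looping over the groups yourself. Note that rank() gives tied counts the same rank, so a group can return more than two rows when there are ties.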