pyspark dataframe limiting on multiple columns

I wonder if anyone can point me in the right direction with the following problem. In a rather large PySpark dataframe with 50-odd columns, two of them represent, say, 'make' and 'model'. Something like:

21234234322(unique id) .. .. .. Nissan  Navara .. .. ..
73647364736            .. .. .. BMW     X5     .. .. ..

What I would like to know is what the top 2 models per brand are. I can groupby both columns and add a count no problem, but how do I then limit (or filter) that result? I.e. how do I keep the (up to) 2 most popular models per brand and remove the rest?

Whatever I try, I end up iterating over the brands that exist in the original dataframe manually. Is there another way?

Solution:

You can use rank() over a Window and then filter() on the resulting rank:

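from pyspark.sql import SparkSession
from pyspark.sql import functions as func
from pyspark.sql.window import Window

# Reuse the active session (e.g. in a notebook) or create one
spark = SparkSession.builder.getOrCreate()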

df = spark.createDataFrame(
    [
        ('a', 1, 1),
        ('a', 1, 2),
        ('a', 1, 3),
        ('a', 2, 1),
        ('a', 2, 2),
        ('a', 3, 1)
    ],
    schema=['col1', 'col2', 'col3']
)

df.printSchema()
df.show(10, False)
+----+----+----+
|col1|col2|col3|
+----+----+----+
|a   |1   |1   |
|a   |1   |2   |
|a   |1   |3   |
|a   |2   |1   |
|a   |2   |2   |
|a   |3   |1   |
+----+----+----+

Here col1 and col2 are the grouping columns (make and model in your case) and col3 is your unique id:

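# Count distinct ids per (col1, col2) pair
df.groupBy(
    'col1', 'col2'
).agg(
    func.countDistinct('col3').alias('dcount')
).withColumn(
    # Rank the pairs within each col1 value, highest count first
    'rank', func.rank().over(Window.partitionBy('col1').orderBy(func.desc('dcount')))
).filter(
    # Keep only the top 2 ranks per col1
    func.col('rank') <= 2
).show(
    10, False
)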
+----+----+------+----+
|col1|col2|dcount|rank|
+----+----+------+----+
|a   |1   |3     |1   |
|a   |2   |2     |2   |
+----+----+------+----+

Applying rank() after the grouping and aggregation lets you keep only the top 2 values within each group (col1), with no need to iterate over the groups manually.
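One thing to watch for: rank() gives tied counts the same rank, so a group can return more than 2 rows when counts tie. If you need at most 2 rows per group, swapping func.rank() for func.row_number() in the query above is the usual alternative, ideally with a second orderBy column so the tie-break is deterministic. To show the difference, here is a rough plain-Python sketch of what the window step computes on a tiny list. The helper name top_n_per_group, its keep_ties flag, and the sample rows are made up for illustration. This is not something to run on a large dataframe.

```python
from collections import defaultdict

def top_n_per_group(rows, n=2, keep_ties=True):
    """rows: iterable of (group, item, unique_id) tuples (hypothetical helper)."""
    # Distinct ids per (group, item), like groupBy(...).agg(countDistinct(...)).
    ids = defaultdict(set)
    for group, item, uid in rows:
        ids[(group, item)].add(uid)

    # Collect (item, dcount) pairs per group, like Window.partitionBy(group).
    per_group = defaultdict(list)
    for (group, item), uids in ids.items():
        per_group[group].append((item, len(uids)))

    result = []
    for group, items in per_group.items():
        # Highest count first; the item name is a tie-breaker for a stable order.
        items.sort(key=lambda pair: (-pair[1], pair[0]))
        rank = 0
        prev_count = None
        for position, (item, dcount) in enumerate(items, start=1):
            if keep_ties:
                # rank(): tied counts share the rank of the first tied row.
                if dcount != prev_count:
                    rank = position
                    prev_count = dcount
            else:
                # row_number(): every row gets its own consecutive number.
                rank = position
            if rank <= n:
                result.append((group, item, dcount, rank))
    return result

# Made-up sample: X3 and i3 tie for second place within BMW.
rows = [
    ("BMW", "X5", 1), ("BMW", "X5", 2),
    ("BMW", "X3", 3),
    ("BMW", "i3", 4),
]
with_ties = top_n_per_group(rows, n=2, keep_ties=True)      # rank()-like
without_ties = top_n_per_group(rows, n=2, keep_ties=False)  # row_number()-like
print(with_ties)
print(without_ties)
```

With rank()-like behaviour the tied models X3 and i3 are both kept, so BMW returns three rows. With row_number()-like behaviour only two rows survive, and the tie-breaker alone decides which of the tied models is dropped.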
