Groupby and return the row label of the maximum value in PySpark Dataframe

I have a dataframe:

data = [('I ran home', 3, 1, 10), 
       ('I went home', 3, 1, 11),
       ('I looked at the cat', 4, 2, 19),
       ('The cat looked at me', 5, 3, 20),
       ('I ran home', 3, 4, 10),
       ('I went homes', 3, 4, 12)]

from pyspark.sql.types import StructType, StructField, StringType, IntegerType

schema = StructType([
    StructField("text", StringType(), True),
    StructField("word_count", IntegerType(), True),
    StructField("group", IntegerType(), True),
    StructField("len_text", IntegerType(), True)])

 
df = spark.createDataFrame(data=data, schema=schema)
df.show(truncate=False)
+--------------------+----------+-----+--------+
|text                |word_count|group|len_text|
+--------------------+----------+-----+--------+
|I ran home          |3         |1    |10      |
|I went home         |3         |1    |11      |
|I looked at the cat |4         |2    |19      |
|The cat looked at me|5         |3    |20      |
|I ran home          |3         |4    |10      |
|I went homes        |3         |4    |12      |
+--------------------+----------+-----+--------+

I want to filter rows with two conditions: if the values in the word_count column are the same and the value in the len_text column is greater than the next row's, keep the row with the greater value. In pandas I can do this with idxmax():

df1 = df.loc[df.groupby('group')['len_text'].idxmax()]

Is there any analogue for pyspark? I want this result:


+--------------------+----------+-----+--------+
|text                |word_count|group|len_text|
+--------------------+----------+-----+--------+
|I went home         |3         |1    |11      |
|I looked at the cat |4         |2    |19      |
|The cat looked at me|5         |3    |20      |
|I went homes        |3         |4    |12      |
+--------------------+----------+-----+--------+

>Solution:

You can use a window function, e.g. row_number:

from pyspark.sql import functions as F, Window as W

w = W.partitionBy('group').orderBy(F.desc('len_text'))
df = df.withColumn('_rn', F.row_number().over(w))
df = df.filter('_rn=1').drop('_rn')

df.show()
# +--------------------+----------+-----+--------+
# |                text|word_count|group|len_text|
# +--------------------+----------+-----+--------+
# |         I went home|         3|    1|      11|
# | I looked at the cat|         4|    2|      19|
# |The cat looked at me|         5|    3|      20|
# |        I went homes|         3|    4|      12|
# +--------------------+----------+-----+--------+