Window function with PySpark

I have a PySpark DataFrame and want to create a Flag column whose value depends on the Amount column.
For each Group, I want to know whether any of the first three months has an amount greater than 0. If so, the Flag column should be 1 for every row in the group; otherwise it should be 0.

Here is an example to make this clearer.

Initial PySpark DataFrame:

Group Month Amount
A 1 0
A 2 0
A 3 35
A 4 0
A 5 0
B 1 0
B 2 0
C 1 0
C 2 0
C 3 0
C 4 13
D 1 0
D 2 24
D 3 0

Final PySpark DataFrame:

Group Month Amount Flag
A 1 0 1
A 2 0 1
A 3 35 1
A 4 0 1
A 5 0 1
B 1 0 0
B 2 0 0
C 1 0 0
C 2 0 0
C 3 0 0
C 4 13 0
D 1 0 1
D 2 24 1
D 3 0 1

In other words, for each group I want to sum the Amount over the first 3 months. If that sum is greater than 0, the flag is 1 for every row of the group; otherwise it is 0.
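The rule can be pinned down with a small plain-Python reference, which may help when checking expected flags on sample data before involving Spark. This is only a sketch: `flag_groups` and its `first_months` parameter are hypothetical names introduced here, and the rows are the sample data from the tables above. It assumes amounts are never negative, so "sum greater than 0" and "any amount greater than 0" agree.

```python
from collections import defaultdict


def flag_groups(rows, first_months=3):
    """Return {group: flag}. The flag is 1 if the group's Amount, summed over
    months 1..first_months, is greater than 0, and 0 otherwise."""
    totals = defaultdict(int)
    for group, month, amount in rows:
        # Rows after the first months contribute nothing, but still register the group.
        totals[group] += amount if month <= first_months else 0
    return {group: int(total > 0) for group, total in totals.items()}


# Sample data from the question: (Group, Month, Amount).
rows = [
    ("A", 1, 0), ("A", 2, 0), ("A", 3, 35), ("A", 4, 0), ("A", 5, 0),
    ("B", 1, 0), ("B", 2, 0),
    ("C", 1, 0), ("C", 2, 0), ("C", 3, 0), ("C", 4, 13),
    ("D", 1, 0), ("D", 2, 24), ("D", 3, 0),
]

flags = flag_groups(rows)
print(flags)  # {'A': 1, 'B': 0, 'C': 0, 'D': 1}
```

Group C is the case worth noting: its only positive amount falls in month 4, so it stays at 0.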

Solution:

You can create the flag column with a window function. Build a pseudo-column that is 1 when a row meets the criteria (Month within the first three and Amount greater than 0) and 0 otherwise, then sum that pseudo-column over each group. If the sum is greater than 0, at least one row met the criteria, so the flag is set to 1 for the whole group.

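from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql import Window as W

# Reuse the active session (e.g. in a notebook or shell) or start a new one.
spark = SparkSession.builder.getOrCreate()

data = [("A", 1, 0),
        ("A", 2, 0),
        ("A", 3, 35),
        ("A", 4, 0),
        ("A", 5, 0),
        ("B", 1, 0),
        ("B", 2, 0),
        ("C", 1, 0),
        ("C", 2, 0),
        ("C", 3, 0),
        ("C", 4, 13),
        ("D", 1, 0),
        ("D", 2, 24),
        ("D", 3, 0)]

df = spark.createDataFrame(data, ("Group", "Month", "Amount"))

# Frame spanning every row of the group, so each row sees the group-wide sum.
ws = W.partitionBy("Group").orderBy("Month").rowsBetween(W.unboundedPreceding, W.unboundedFollowing)

# Pseudo-column: 1 if the row is in the first three months and has a positive Amount, else 0.
criteria = F.when((F.col("Month") < 4) & (F.col("Amount") > 0), F.lit(1)).otherwise(F.lit(0))

# The flag is 1 when at least one row in the group met the criteria.
df.withColumn("flag", F.when(F.sum(criteria).over(ws) > 0, F.lit(1)).otherwise(F.lit(0))).show()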

"""
+-----+-----+------+----+
|Group|Month|Amount|flag|
+-----+-----+------+----+
|    A|    1|     0|   1|
|    A|    2|     0|   1|
|    A|    3|    35|   1|
|    A|    4|     0|   1|
|    A|    5|     0|   1|
|    B|    1|     0|   0|
|    B|    2|     0|   0|
|    C|    1|     0|   0|
|    C|    2|     0|   0|
|    C|    3|     0|   0|
|    C|    4|    13|   0|
|    D|    1|     0|   1|
|    D|    2|    24|   1|
|    D|    3|     0|   1|
+-----+-----+------+----+
"""