Pyspark dataframe column value dependent on value from another row

I have a dataframe like this:

columns = ['manufacturer', 'product_id']
data = [("Factory", "AE222"), ("Sub-Factory-1", "0"), ("Sub-Factory-2", "0"),("Factory", "AE333"), ("Sub-Factory-1", "0"), ("Sub-Factory-2", "0")]
rdd = spark.sparkContext.parallelize(data)
df = rdd.toDF(columns)
 
+-------------+----------+
| manufacturer|product_id|
+-------------+----------+
|      Factory|     AE222|
|Sub-Factory-1|         0|
|Sub-Factory-2|         0|
|      Factory|     AE333|
|Sub-Factory-1|         0|
|Sub-Factory-2|         0|
+-------------+----------+

Which I want to turn into this:

+-------------+----------+
| manufacturer|product_id|
+-------------+----------+
|      Factory|     AE222|
|Sub-Factory-1|     AE222|
|Sub-Factory-2|     AE222|
|      Factory|     AE333|
|Sub-Factory-1|     AE333|
|Sub-Factory-2|     AE333|
+-------------+----------+

So each Sub-Factory row should get the product_id from the closest Factory row above it. I can solve this with a nested for loop, but that is not efficient since there could be millions of rows. I have looked into PySpark window functions but cannot really wrap my head around them. Any ideas?


Solution:

You can use the first function with ignorenulls=True over a Window, but first you need to identify the group of rows belonging to each Factory so you can partition by that group.

Since there is no ID column, I'm using monotonically_increasing_id to preserve row order and a cumulative conditional sum to create a group column:

from pyspark.sql import Window
from pyspark.sql import functions as F

df1 = df.withColumn(
    "row_id",
    F.monotonically_increasing_id()
).withColumn(
    "group",
    # cumulative count of "Factory" rows seen so far -> one group id per Factory
    F.sum(F.when(F.col("manufacturer") == "Factory", 1)).over(Window.orderBy("row_id"))
).withColumn(
    "product_id",
    F.when(
        F.col("product_id") == "0",
        # first product_id within the group is the Factory's value
        F.first("product_id", ignorenulls=True).over(Window.partitionBy("group").orderBy("row_id"))
    ).otherwise(F.col("product_id"))
).drop("row_id", "group")

df1.show()
#+-------------+----------+
#| manufacturer|product_id|
#+-------------+----------+
#|      Factory|     AE222|
#|Sub-Factory-1|     AE222|
#|Sub-Factory-2|     AE222|
#|      Factory|     AE333|
#|Sub-Factory-1|     AE333|
#|Sub-Factory-2|     AE333|
#+-------------+----------+