I'm seeing apparently erratic behaviour when using the PySpark SQL `sampleBy` function.
Just to understand how it works, I'm applying stratified sampling to a dataset of 100 rows whose values (0, 1 and 2) are each distributed roughly one third.
I'm using a fraction of 10% for the value 0 and 20% for the value 1 (and implicitly 0% for the value 2, since it isn't listed).
I would expect the sample to contain 3 zeros (10% of 33.3, rounded) and 7 ones (20% of 33.3, rounded), but this is not the case, and the distribution of values in the output changes when I change the seed. Am I missing something? Is this normal?
from pyspark.sql.functions import col
from pyspark.sql import SQLContext
sqlC = SQLContext(sc)
dataset = sqlC.range(0, 100).select((col("id") % 3).alias("key"))
dataset.groupBy("key").count().orderBy("key").show()
#+---+-----+
#|key|count|
#+---+-----+
#| 0| 34|
#| 1| 33|
#| 2| 33|
#+---+-----+
# With seed = 0
sampled = dataset.sampleBy("key", fractions={0: 0.1, 1: 0.2}, seed=0)
sampled.groupBy("key").count().orderBy("key").show()
#+---+-----+
#|key|count|
#+---+-----+
#| 0| 6|
#| 1| 11|
#+---+-----+
# With seed = 123
sampled = dataset.sampleBy("key", fractions={0: 0.1, 1: 0.2}, seed=123)
sampled.groupBy("key").count().orderBy("key").show()
#+---+-----+
#|key|count|
#+---+-----+
#| 0| 1|
#| 1| 11|
#+---+-----+
>Solution:
This is partly down to how Spark samples and partly down to the small sample size in your example. Spark samples by giving each row an independent probability of being selected, rather than by taking an exact fraction of the overall data. This isn't explained in the sampleBy documentation, but it is in the sample docs.
I believe this is so that a full count is not necessary and each row can be treated independently, which is far more efficient for large, distributed datasets.
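To see why this produces seed-dependent counts, here is a minimal pure-Python sketch (not Spark code) of per-row Bernoulli sampling, with a hypothetical `bernoulli_sample` helper standing in for what `sampleBy` does internally: each row is kept independently with its stratum's probability, so the resulting counts are binomially distributed around the expected values rather than exact.

```python
import random

def bernoulli_sample(rows, fractions, seed):
    # Keep each row independently with probability fractions[key],
    # mimicking per-row Bernoulli sampling (an assumption about the
    # mechanism, not Spark's actual implementation).
    rng = random.Random(seed)
    return [r for r in rows if rng.random() < fractions.get(r, 0.0)]

rows = [i % 3 for i in range(100)]
for seed in (0, 123):
    sample = bernoulli_sample(rows, {0: 0.1, 1: 0.2}, seed)
    counts = {k: sample.count(k) for k in (0, 1)}
    # Expected values are about 3 zeros and about 7 ones, but the
    # realized counts fluctuate from seed to seed.
    print(seed, counts)
```

With only ~33 rows per stratum the relative fluctuation is large, which matches the 6-vs-1 zeros seen above; as the dataset grows, the law of large numbers pulls the realized fractions toward 10% and 20%.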
If you make your sample dataset much bigger, you’ll see the differences become trivial:
from pyspark.sql.functions import col
from pyspark.sql import SQLContext
sqlC = SQLContext(sc)
dataset = sqlC.range(0, int(1e8)).select((col("id") % 3).alias("key"))  # range needs an int, not the float 1e8
dataset.groupBy("key").count().orderBy("key").show()
#+---+--------+
#|key| count|
#+---+--------+
#| 0|33333334|
#| 1|33333333|
#| 2|33333333|
#+---+--------+
# With seed = 0
sampled = dataset.sampleBy("key", fractions={0: 0.1, 1: 0.2}, seed=0)
sampled.groupBy("key").count().orderBy("key").show()
#+---+-------+
#|key| count|
#+---+-------+
#| 0|3332776|
#| 1|6662457|
#+---+-------+
# With seed = 123
sampled = dataset.sampleBy("key", fractions={0: 0.1, 1: 0.2}, seed=123)
sampled.groupBy("key").count().orderBy("key").show()
#+---+-------+
#|key| count|
#+---+-------+
#| 0|3331701|
#| 1|6667689|
#+---+-------+