I am trying to speed up a particular calculation on a data.table object. The table contains a value column and one or more grouping columns. For each group combination, if the sum of the values is greater than one, I would like to scale the values down proportionally so that their sum becomes one. We can assume that the values are non-negative.
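To make the rule concrete, here is a minimal base R sketch applied to a single group. The helper name cap_at_one is hypothetical and used only for illustration; it is not part of data.table.

```r
# Hypothetical helper, for illustration only: cap one group's total at 1.
cap_at_one <- function(values) {
  total <- sum(values)
  # Rescale only when the total exceeds 1; otherwise return the input as-is.
  if (total > 1) values / total else values
}

cap_at_one(c(0.5, 1.0, 0.5))  # total is 2, so every value is halved
cap_at_one(c(0.2, 0.3))       # total is 0.5, so the values are unchanged
```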
Here is an example setup with half a million rows:
library(data.table)
n_rows <- 5e5
dt <- data.table(
# grouping variables
group1 = sample(letters, size = n_rows, replace = TRUE),
group2 = sample(letters, size = n_rows, replace = TRUE),
group3 = sample(letters, size = n_rows, replace = TRUE),
group4 = sample(letters, size = n_rows, replace = TRUE),
# non-negative values
value = runif(n_rows)
)
And the solution I have:
scale_values <- function(values) {
values / sum(values)
}
dt[, {value := if (sum(value) > 1) {
scale_values(value)
} else {
value
}},
by = list(group1, group2, group3, group4)]
This is the fastest option I have found so far (after having played a bit with data.table syntax and a few alternatives to scale_values), but I would like to make this faster. In actual usage, dt will have approximately 20 million rows and five grouping variables.
Thank you in advance for any ideas on how to improve this.
EDIT: it turns out a two-step solution is much faster:
dt[, sum_values := sum(value), by = list(group1, group2, group3, group4)][
sum_values > 1, value := scale_values(value), by = list(group1, group2, group3, group4)]
Though it isn’t very clear to me why.
>Solution :
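You want data.table's GForce optimization to apply. A plain grouped call such as sum(value) in j is the kind of expression GForce can optimize, whereas an if/else expression that calls a user-defined function has to be evaluated in R once per group. Pass verbose = TRUE to check whether the optimization is used for a given call.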
library(data.table)
n_rows <- 5e5
set.seed(42)
dt <- data.table(
group1 = sample(letters, size = n_rows, replace = TRUE),
group2 = sample(letters, size = n_rows, replace = TRUE),
group3 = sample(letters, size = n_rows, replace = TRUE),
group4 = sample(letters, size = n_rows, replace = TRUE),
value = runif(n_rows)
)
dt1 <- copy(dt)
scale_values <- function(values) {
values / sum(values)
}
system.time(
dt[, value := if (sum(value) > 1) {
scale_values(value)
} else {
value
},
by = list(group1, group2, group3, group4), verbose = TRUE])
#user system elapsed
#0.39 0.09 0.39
system.time({
dt1[, sum := sum(value), by = list(group1, group2, group3, group4), verbose = TRUE]
dt1[sum > 1, value := value / sum, by = list(group1, group2, group3, group4), verbose = TRUE]
})
#user system elapsed
#0.13 0.02 0.12
all.equal(dt[["value"]], dt1[["value"]])
#[1] TRUE
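As a possible further simplification, the second step does not need by at all, because every row already carries its group's sum and the division is vectorized. The sketch below assumes a smaller two-column grouping and uses illustrative names (dt2, reference, group_cols, sum_values) that are not from the original post. It is untimed here, so benchmark it on your own data.

```r
library(data.table)

set.seed(42)
n_rows <- 1e4
group_cols <- c("group1", "group2")  # illustrative grouping columns

dt2 <- data.table(
  group1 = sample(letters[1:5], size = n_rows, replace = TRUE),
  group2 = sample(letters[1:5], size = n_rows, replace = TRUE),
  value  = runif(n_rows) / 200  # group totals land close to 1
)

# Reference result from the single-step if/else approach, kept for comparison.
reference <- copy(dt2)
reference[, value := if (sum(value) > 1) value / sum(value) else value,
          by = group_cols]

# Step 1: grouped sum stored on every row of the group.
dt2[, sum_values := sum(value), by = group_cols]
# Step 2: plain vectorized division on the affected rows, no grouping needed.
dt2[sum_values > 1, value := value / sum_values]
# Drop the helper column again.
dt2[, sum_values := NULL]

all.equal(dt2[["value"]], reference[["value"]])
```

Using a name such as sum_values for the helper column, rather than sum, also avoids a column that shares its name with the base function.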