I need to create a special mutate-style function that uses the dplyr groups if they’re defined, and otherwise calculates over the entire data.
I couldn’t find anything online on how to do this, so I’ve come up with my own solution.
Does anyone know the correct/sanctioned way of handling this?
Thanks to:
- romainfrancois for pointing out that the `indices` attribute is undocumented and that I should use `group_indices()` instead.
The naive function (totally unaware of grouping)
For this example, I’m writing a clone of `dplyr::add_tally()` called `add_n()` (my real problem is much more complicated).
```r
add_n <- function(df) {
  df$n <- nrow(df)
  df
}
```
This function works great on an ungrouped dataset…
```r
library(dplyr)  # provides %>%, group_by(), do() and group_indices()

mtcars %>%
  head() %>%
  add_n()
```
```
# A tibble: 6 x 12
    mpg   cyl  disp    hp  drat    wt  qsec    vs    am  gear  carb     n
  <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <int>
1  21       6   160   110  3.9   2.62  16.5     0     1     4     4     6
2  21       6   160   110  3.9   2.88  17.0     0     1     4     4     6
3  22.8     4   108    93  3.85  2.32  18.6     1     1     4     1     6
4  21.4     6   258   110  3.08  3.22  19.4     1     0     3     1     6
5  18.7     8   360   175  3.15  3.44  17.0     0     0     3     2     6
6  18.1     6   225   105  2.76  3.46  20.2     1     0     3     1     6
```
… but on a grouped data set it doesn’t count per group; it just gives the overall count.
```r
mtcars %>%
  head() %>%
  group_by(cyl) %>%
  add_n()
```
```
# A tibble: 6 x 12
# Groups:   cyl [3]
    mpg   cyl  disp    hp  drat    wt  qsec    vs    am  gear  carb     n
  <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <int>
1  21       6   160   110  3.9   2.62  16.5     0     1     4     4     6
2  21       6   160   110  3.9   2.88  17.0     0     1     4     4     6
3  22.8     4   108    93  3.85  2.32  18.6     1     1     4     1     6
4  21.4     6   258   110  3.08  3.22  19.4     1     0     3     1     6
5  18.7     8   360   175  3.15  3.44  17.0     0     0     3     2     6
6  18.1     6   225   105  2.76  3.46  20.2     1     0     3     1     6
```
The group-aware function
```r
add_n <- function(df) {
  #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  # If df is grouped, use do() to apply this function to each group.
  # Note: do() re-organises the rows by the grouping variable(s), so use
  # group_indices() to put the rows back in their original order.
  #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  if (dplyr::is_grouped_df(df)) {
    indices <- dplyr::group_indices(df)
    df <- dplyr::do(df, add_n(.))
    # order(order(indices)) is the inverse of the row permutation do() applied
    df <- df[order(order(indices)), , drop = FALSE]
    return(df)
  }

  #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  # This is just the body of the naive version of the function
  #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  df$n <- nrow(df)
  df
}
```
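The least obvious line in the group-aware function is `df[order(order(indices)), , drop = FALSE]`. Here is a minimal base-R sketch (no dplyr needed) of why it restores the original row order. It makes two assumptions: `match()` on the sorted unique keys stands in for `group_indices()` (groups numbered in sorted key order), and `do()` is simulated by stacking the rows group by group while keeping the original order within each group.

```r
# Toy data: the cyl column of head(mtcars), plus an id to track row order
original <- data.frame(row_id = 1:6, cyl = c(6, 6, 4, 6, 8, 6))

# Stand-in for dplyr::group_indices() (assumption: groups numbered by sorted key)
indices <- match(original$cyl, sort(unique(original$cyl)))  # 2 2 1 2 3 2

# Simulate do(): rows come back stacked group by group.
# order() is stable, so ties keep their original relative order.
stacked <- original[order(indices), , drop = FALSE]
print(stacked$row_id)  # 3 1 2 4 6 5

# order(order(x)) is the inverse permutation of order(x): for each original
# row, it gives that row's position in the stacked result.
restored <- stacked[order(order(indices)), , drop = FALSE]
print(restored$row_id)  # 1 2 3 4 5 6
```

For a vector like `indices`, `order(order(indices))` gives the same result as `rank(indices, ties.method = "first")`, which some readers may find easier to recognise.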
The group-aware version still works on an ungrouped dataset…
```r
mtcars %>%
  head() %>%
  add_n()
```
```
# A tibble: 6 x 12
    mpg   cyl  disp    hp  drat    wt  qsec    vs    am  gear  carb     n
  <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <int>
1  21       6   160   110  3.9   2.62  16.5     0     1     4     4     6
2  21       6   160   110  3.9   2.88  17.0     0     1     4     4     6
3  22.8     4   108    93  3.85  2.32  18.6     1     1     4     1     6
4  21.4     6   258   110  3.08  3.22  19.4     1     0     3     1     6
5  18.7     8   360   175  3.15  3.44  17.0     0     0     3     2     6
6  18.1     6   225   105  2.76  3.46  20.2     1     0     3     1     6
```
… and now also creates per-group counts, keeping the rows in their original order.
```r
mtcars %>%
  head() %>%
  group_by(cyl) %>%
  add_n()
```
```
# A tibble: 6 x 12
# Groups:   cyl [3]
    mpg   cyl  disp    hp  drat    wt  qsec    vs    am  gear  carb     n
  <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <int>
1  21       6   160   110  3.9   2.62  16.5     0     1     4     4     4
2  21       6   160   110  3.9   2.88  17.0     0     1     4     4     4
3  22.8     4   108    93  3.85  2.32  18.6     1     1     4     1     1
4  21.4     6   258   110  3.08  3.22  19.4     1     0     3     1     4
5  18.7     8   360   175  3.15  3.44  17.0     0     0     3     2     1
6  18.1     6   225   105  2.76  3.46  20.2     1     0     3     1     4
```
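For comparison, here is a sketch of the same "per group if grouped, otherwise over everything" idea in base R. `add_n_base()` is a hypothetical name, and unlike `add_n()` it takes the grouping column as an explicit argument instead of reading dplyr groups. It swaps in a different technique, `stats::ave()`, which applies a function within each group and returns the results in the original row order, so no re-sorting step is needed. It is not the sanctioned dplyr answer the question asks for.

```r
# Hypothetical base-R helper: per-group row counts when `group` names a
# column, otherwise the overall row count. Uses only base/stats/utils.
add_n_base <- function(df, group = NULL) {
  if (is.null(group)) {
    # Ungrouped: every row gets the total number of rows
    df$n <- nrow(df)
  } else {
    # Grouped: ave() applies length() within each group of df[[group]]
    # and writes the results back in the original row order
    df$n <- ave(seq_len(nrow(df)), df[[group]], FUN = length)
  }
  df
}

ungrouped <- add_n_base(head(mtcars))
grouped <- add_n_base(head(mtcars), group = "cyl")

print(ungrouped$n)  # 6 6 6 6 6 6
print(grouped$n)    # 4 4 1 4 1 4
```

The grouped counts match the `n` column in the dplyr output above, and the rows stay in their original order because `ave()` never reorders them.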