Writing a function which respects `dplyr::group_by()`

I need to create a special mutate function that uses the dplyr groups if they’re defined and otherwise calculates over the entire dataset.

I couldn’t find anything online on how to do this, so I’ve come up with my own solution.

Does anyone know the correct/sanctioned way of handling this?

Thanks to:

  • romainfrancois, for pointing out that the `indices` attribute is undocumented and that I should use `group_indices()` instead.

The naive function (totally unaware of grouping)

For this example, I’m writing a clone of `dplyr::add_tally()` called `add_n()`
(my real problem is much more complicated).

add_n <- function(df) {
  df$n <- nrow(df)
  df
}

This function works great on an ungrouped dataset…

mtcars %>% 
  head() %>%
  add_n()
# A tibble: 6 x 12
    mpg   cyl  disp    hp  drat    wt  qsec    vs    am  gear  carb     n
  <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <int>
1  21       6   160   110  3.9   2.62  16.5     0     1     4     4     6
2  21       6   160   110  3.9   2.88  17.0     0     1     4     4     6
3  22.8     4   108    93  3.85  2.32  18.6     1     1     4     1     6
4  21.4     6   258   110  3.08  3.22  19.4     1     0     3     1     6
5  18.7     8   360   175  3.15  3.44  17.0     0     0     3     2     6
6  18.1     6   225   105  2.76  3.46  20.2     1     0     3     1     6

… but on a grouped dataset it still returns the overall count rather than a count per group.

mtcars %>% 
  head() %>%
  group_by(cyl) %>%
  add_n()
# A tibble: 6 x 12
# Groups:   cyl [3]
    mpg   cyl  disp    hp  drat    wt  qsec    vs    am  gear  carb     n
  <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <int>
1  21       6   160   110  3.9   2.62  16.5     0     1     4     4     6
2  21       6   160   110  3.9   2.88  17.0     0     1     4     4     6
3  22.8     4   108    93  3.85  2.32  18.6     1     1     4     1     6
4  21.4     6   258   110  3.08  3.22  19.4     1     0     3     1     6
5  18.7     8   360   175  3.15  3.44  17.0     0     0     3     2     6
6  18.1     6   225   105  2.76  3.46  20.2     1     0     3     1     6
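To see why the naive version ignores the grouping: a grouped tibble holds exactly the same rows as the ungrouped data, with the grouping stored as metadata, so base functions such as `nrow()` still see all six rows. The per-row group ids returned by `group_indices()` can be inspected directly. A small illustrative sketch, assuming dplyr is installed:

```r
library(dplyr)

# group the same six rows by cylinder count
g <- mtcars %>% head() %>% group_by(cyl)

nrow(g)            # 6: nrow() ignores the grouping metadata
group_vars(g)      # "cyl"
group_indices(g)   # 2 2 1 2 3 2: cyl 4 -> 1, cyl 6 -> 2, cyl 8 -> 3
```

Groups are numbered in sorted order of the grouping key, which is also the order in which `do()` returns them.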

The group-aware function

add_n <- function(df) {
  
  #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  # If df is grouped, use 'do()' to apply the function to each group.
  # Note: 'do()' returns the rows sorted by the grouping variable(s), so use
  #       group_indices() to put the data back in its original order.
  #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  if (dplyr::is_grouped_df(df)) {
    indices <- dplyr::group_indices(df)
    df      <- dplyr::do(df, add_n(.))
    df      <- df[order(order(indices)), , drop = FALSE]
    return(df)
  } 
  
  #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  # This is just the body of the naive version of the function
  #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  df$n <- nrow(df)
  df
}
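The `order(order(indices))` line is the least obvious step. If the group-sorted output row k is original row `order(indices)[k]`, then `order()` of that permutation is its inverse, and indexing with it puts every row back where it started. A base-R sketch of the same idea on a toy data frame, with `split()`/`lapply()`/`rbind()` standing in for `do()` (an assumption made so the example runs without dplyr; the `id` column is hypothetical and only there to track row order):

```r
# toy data: same cyl values as head(mtcars), plus an id to track row order
df <- data.frame(id = 1:6, cyl = c(6, 6, 4, 6, 8, 6))

# per-row group id, like group_indices(): groups numbered in sorted key order
indices <- match(df$cyl, sort(unique(df$cyl)))

# split-apply-combine; the result comes back sorted by group, as with do()
pieces <- lapply(split(df, indices), function(g) { g$n <- nrow(g); g })
out <- do.call(rbind, unname(pieces))

# order(indices) maps output rows to original rows;
# order() of that permutation is its inverse, restoring the original order
out <- out[order(order(indices)), , drop = FALSE]
rownames(out) <- NULL

out$id   # 1 2 3 4 5 6
out$n    # 4 4 1 4 1 4
```

This relies on rows keeping their original relative order within each group, which holds here because `split()` preserves it and `order()` breaks ties by position.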

This version still works on an ungrouped dataset…

mtcars %>% 
  head() %>%
  add_n()
# A tibble: 6 x 12
    mpg   cyl  disp    hp  drat    wt  qsec    vs    am  gear  carb     n
  <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <int>
1  21       6   160   110  3.9   2.62  16.5     0     1     4     4     6
2  21       6   160   110  3.9   2.88  17.0     0     1     4     4     6
3  22.8     4   108    93  3.85  2.32  18.6     1     1     4     1     6
4  21.4     6   258   110  3.08  3.22  19.4     1     0     3     1     6
5  18.7     8   360   175  3.15  3.44  17.0     0     0     3     2     6
6  18.1     6   225   105  2.76  3.46  20.2     1     0     3     1     6

… and also works properly to create per-group counts, with the data kept in its original order.

mtcars %>% 
  head() %>%
  group_by(cyl) %>%
  add_n()
# A tibble: 6 x 12
# Groups:   cyl [3]
    mpg   cyl  disp    hp  drat    wt  qsec    vs    am  gear  carb     n
  <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <int>
1  21       6   160   110  3.9   2.62  16.5     0     1     4     4     4
2  21       6   160   110  3.9   2.88  17.0     0     1     4     4     4
3  22.8     4   108    93  3.85  2.32  18.6     1     1     4     1     1
4  21.4     6   258   110  3.08  3.22  19.4     1     0     3     1     4
5  18.7     8   360   175  3.15  3.44  17.0     0     0     3     2     1
6  18.1     6   225   105  2.76  3.46  20.2     1     0     3     1     4
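One approach that may be closer to idiomatic dplyr, offered as a suggestion rather than a confirmed answer: build the function on top of `dplyr::mutate()`. `mutate()` already evaluates its expressions once per group when the data is grouped (and over all rows otherwise), and `dplyr::n()` returns the size of the current group, so no explicit branching, `do()`, or re-ordering is needed. This assumes the real computation can be written as column expressions; if it can’t, `dplyr::group_modify()` (added in dplyr 0.8.0) may be worth a look as a replacement for `do()`, which dplyr 1.0.0 marks as superseded.

```r
library(dplyr)

# Group-aware by construction: mutate() evaluates n() once per group
# when df is grouped, and over the whole data frame otherwise.
# The dplyr:: prefixes let this work inside a package without attaching dplyr.
add_n <- function(df) {
  dplyr::mutate(df, n = dplyr::n())
}

ungrouped <- mtcars %>% head() %>% add_n()
grouped   <- mtcars %>% head() %>% group_by(cyl) %>% add_n()

ungrouped$n   # 6 6 6 6 6 6
grouped$n     # 4 4 1 4 1 4, rows stay in their original order
```

The grouping is also left in place on the result, matching the `do()`-based version.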