strict `case_when`

Problem: case_when() isn’t strict enough

case_when() is fantastic, but in some cases I’d like it to be a little stricter about what it allows.

I want to eliminate ways in which errors or oversights can creep in, so I’d like special handling for the following cases:

  1. A fall-through/catch-all TRUE condition should be disallowed by default.
  2. Every input value should match at least one rule.
  3. No input value should match more than one rule.

Before starting, let me state clearly that

  • case_when() is awesome
  • case_when() is 1000x better than nested ifelse() statements.

Use case

My main use case for this strict version is ensuring that continuous values are correctly turned into categories when the rules are complicated and involve multiple variables.

In these cases:

  • every input value should match a category
  • every input value should match only one category

Given these assumptions, there are quite a few problems with the following code, which I wrote to map size to a categorical variable.

size <- seq(0, 20, 0.5)

case_when(
  between(size,  1, 5 ) ~ 'small',
  between(size,  6, 12) ~ 'medium',
  between(size, 11, 20) ~ 'big',
  TRUE                  ~ 'huge'
)

Issues:

  1. There is no rule for the input values 0 and 0.5. They fall through to the default response of ‘huge’.
  2. There is no rule for the value 5.5, so it also falls through to ‘huge’.
  3. Values from 11 to 12 (e.g. 11.5) match both the ‘medium’ and ‘big’ rules; case_when() silently uses the first one, ‘medium’.
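One quick way to surface these problems is to evaluate the explicit rules directly and add the resulting logical vectors, which counts how many rules each value satisfies. The minimal base-R sketch below assumes `between(x, a, b)` can be written out as `x >= a & x <= b` (dplyr documents between() as that shortcut) and deliberately leaves out the catch-all TRUE rule, since it would hide the gaps. The names `rules` and `match_counts` are purely illustrative.

```r
# Illustrative base-R check; between(x, a, b) is written out as x >= a & x <= b
size <- seq(0, 20, 0.5)

# The three explicit rules from the example, without the catch-all TRUE
rules <- list(
  small  = size >=  1 & size <=  5,
  medium = size >=  6 & size <= 12,
  big    = size >= 11 & size <= 20
)

# Adding logical vectors counts how many rules each value satisfies
match_counts <- Reduce(`+`, rules)

size[match_counts == 0]  # values that no rule covers
size[match_counts > 1]   # values covered by more than one rule
```

The unmatched values (0, 0.5 and 5.5) are exactly the ones the TRUE rule silently turns into ‘huge’, and the doubly-matched values (11, 11.5 and 12) all come out as ‘medium’ because the first matching rule wins.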

case_when() issue #1 - use of TRUE fall-through

The use of a final TRUE value to catch all unmatched values can lead to unexpected results.

In the following example, I have made a typo in defining the rule for cat. Since cat is not matched by the first rule (or any other rules), it falls through to the default classification of bird.

animal <- c('cat', 'dog', 'eagle')

case_when(
  animal == 'catt' ~ 'mammal',
  animal == 'dog'  ~ 'mammal',
  TRUE             ~ 'bird'
)
[1] "bird"   "mammal" "bird"  
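A literal catch-all can be spotted without evaluating anything, just by inspecting the formulas. As a minimal base-R sketch (assuming only a bare TRUE on the LHS counts as a catch-all): a two-sided formula is a call of length three (`~`, LHS, RHS), so `f[[2]]` is its left-hand side, and isTRUE() flags the constant TRUE. `is_catch_all` is a hypothetical name used only for illustration.

```r
# The formulas are never evaluated, so `animal` does not need to exist here
rules <- list(
  animal == 'catt' ~ 'mammal',
  animal == 'dog'  ~ 'mammal',
  TRUE             ~ 'bird'
)

# f[[2]] is the unevaluated LHS; a bare TRUE constant marks a catch-all
is_catch_all <- vapply(rules, function(f) isTRUE(f[[2]]), logical(1))
is_catch_all
```

Note that this only catches the literal constant: a disguised catch-all such as `T ~ 'bird'` or `1 == 1 ~ 'bird'` would not be flagged by this test.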

case_when() issue #2 - Every input value should match a rule

By default, case_when() returns an NA for values that are not captured by a rule.

In the following, I’ve again made a typo. This time the input dog matches none of the rules and is returned as NA.

animal <- c('cat', 'dog', 'eagle')

case_when(
  animal == 'cat'   ~ 'mammal',
  animal == 'dogg'  ~ 'mammal',
  animal == 'eagle' ~ 'bird'
)
[1] "mammal" NA       "bird"  

A solution is proposed in this GitHub issue, but it does not appear to work. I’ve posted a follow-up on the issue and will wait to see if anyone can clarify.

case_when() issue #3 - Every input value should match only 1 rule

By default, case_when() returns the result for the first rule which matches an input value.

In the following, I’ve again made an error in the ruleset: two rules match the input eagle. The first is obviously wrong, but it is the one that gets returned, and any later matching rule is ignored.

animal <- c('cat', 'dog', 'eagle')

case_when(
  animal == 'eagle' ~ 'insect',
  animal == 'cat'   ~ 'mammal',
  animal == 'dog'   ~ 'mammal',
  animal == 'eagle' ~ 'bird'
)
[1] "mammal" "mammal" "insect"

case_when() has no check for inputs that match multiple rules; this is succinctly discussed in this GitHub issue.
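To make the first-match-wins behaviour concrete, here is a rough base-R emulation. It illustrates the semantics only and is not how dplyr implements case_when() internally: build a logical matrix with one column per rule, take the first TRUE column in each row, and flag rows where more than one column is TRUE (those are the inputs whose later rules are shadowed). The names `conds`, `values` and `first_rule` are hypothetical.

```r
animal <- c('cat', 'dog', 'eagle')

# One row per input value, one column per rule, in rule order
conds <- cbind(
  animal == 'eagle',  # rule 1: 'insect'
  animal == 'cat',    # rule 2: 'mammal'
  animal == 'dog',    # rule 3: 'mammal'
  animal == 'eagle'   # rule 4: 'bird'
)
values <- c('insect', 'mammal', 'mammal', 'bird')

# First match wins: the index of the first TRUE column in each row
first_rule <- apply(conds, 1, function(row) which(row)[1])
values[first_rule]

# Inputs satisfied by more than one rule, i.e. where later rules are ignored
which(rowSums(conds) > 1)
```

This reproduces the case_when() result above ("mammal", "mammal", "insect") and reports input 3 (eagle) as the one with competing rules.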

Solution

  • strict_case_when() is a thin wrapper around case_when().
  • It calls case_when() first, letting that function validate the inputs and compute the result.
  • If case_when() succeeds, all the arguments are known to be well-formed, which makes the checking code easier to write.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#' Stricter version of case_when()
#'
#' @param ... arguments to case_when
#' @param .strict_mode Turn on strict mode. Default: TRUE
#'   This mode 
#'     - disallows a fall-through 'TRUE' value on the LHS. 
#'     - disallows input values which do not match any rules. 
#'     - disallows input values which match more than one rule
#'    
#'
#' @return A vector of length 1 or n, matching the length of the logical input 
#' or output vectors, with the type (and attributes) of the first RHS. 
#' Inconsistent lengths or types will generate an error.
#'
#' @import dplyr 
#' @import rlang
#' @import purrr
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
strict_case_when <- function(..., .strict_mode=TRUE) {
  
  #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  # Allow case_when to do its thing!
  #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  res <- dplyr::case_when(...)
  
  if (.strict_mode) {
    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # If `case_when` runs ok, then it means I can make the assumption that all
    # its input arguments are well-formed formulas, with a proper LHS
    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    formulas <- rlang::list2(...)
    
    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Count how many times each input is matched by a rule.
    # Have to evaluate this in the right environment.
    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
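    match_counts <- formulas %>%
      purrr::map(~ {
        # Evaluate each LHS in the environment its formula was created in
        hit <- rlang::eval_bare(rlang::f_lhs(.x), env = environment(.x))
        # case_when() treats an NA condition as a non-match; do the same here,
        # otherwise the if() checks below fail on NA counts
        hit & !is.na(hit)
      }) %>%
      purrr::reduce(`+`)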
    
    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Check for fall-through 'TRUE' value
    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    lhs <- formulas %>% purrr::map(rlang::f_lhs)
    if (any(purrr::map_lgl(lhs, isTRUE))) {
      stop("strict_case_when(): fall-through 'TRUE' is not allowed", call. = FALSE)
    }
    
    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Check for input values that match no rule
    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    if (any(match_counts == 0L)) {
      stop("strict_case_when(): no matches found at the following input indices: ", 
           deparse(which(match_counts == 0L)), call. = FALSE) 
    }
    
    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Check if any input values matched multiple times
    #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    if (any(match_counts > 1L)) {
      stop("strict_case_when(): multiple matches found at the following input indices: ", 
           deparse(which(match_counts > 1L)), call. = FALSE) 
    }
  }
  
  
  res
}

Using strict_case_when()

Here again are the three test cases from earlier, preceded by a well-formed example. strict_case_when() throws an appropriate error for each of the issues.

animal <- c('cat', 'dog', 'eagle')


#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# A well-formed example. Shouldn't throw an error.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
strict_case_when(
  animal == 'cat'   ~ 'mammal',
  animal == 'dog'   ~ 'mammal',
  animal == 'eagle' ~ 'bird'
)
[1] "mammal" "mammal" "bird"  
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Issue #1 handled - no 'TRUE' LHS allowed
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
strict_case_when(
  animal == 'catt' ~ 'mammal',
  animal == 'dog'  ~ 'mammal',
  TRUE             ~ 'bird'
)
Error: strict_case_when(): fall-through 'TRUE' is not allowed
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Issue #2 handled - catch unmatched values
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
strict_case_when(
  animal == 'cat'   ~ 'mammal',
  animal == 'dogg'  ~ 'mammal',
  animal == 'eagle' ~ 'bird'
)
Error: strict_case_when(): no matches found at the following input indices: 2L
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Issue #3 handled - multiple matches raise an error
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
strict_case_when(
  animal == 'eagle' ~ 'insect',
  animal == 'cat'   ~ 'mammal',
  animal == 'dog'   ~ 'mammal',
  animal == 'eagle' ~ 'bird'
)
Error: strict_case_when(): multiple matches found at the following input indices: 3L

Future

  • Awaiting someone to point out that this is equivalent to a base R function I’ve never heard of before… ;)
  • Waiting to see if anyone can clarify this GitHub issue