strict `case_when` - bugfix

Bugfix - actually handling NAs in input

I talked about why I wanted a strict_case_when() a few weeks ago.

I’ve since found a major bug affecting inputs which contain NA, and I’m posting this update to fix it.

This code is available as a GitHub gist.

strict_case_when()

  • strict_case_when() is a thin wrapper around dplyr::case_when() which:
    • disallows a fall-through ‘TRUE’ value on the LHS.
    • disallows input values which do not match any rules.
    • disallows input values which match more than one rule.
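For contrast, here is a minimal sketch of what plain dplyr::case_when() does in the two situations the last two bullets describe. The variable names `unmatched` and `first_wins` are only for illustration and are not part of the gist.

```r
animal <- c('cat', 'dog', 'eagle')

# 'dogg' is a typo, so 'dog' matches no rule and silently becomes NA
unmatched <- dplyr::case_when(
  animal == 'cat'   ~ 'mammal',
  animal == 'dogg'  ~ 'mammal',
  animal == 'eagle' ~ 'bird'
)
unmatched
#> [1] "mammal" NA       "bird"

# 'eagle' matches two rules; the first matching rule wins without a warning
first_wins <- dplyr::case_when(
  animal == 'eagle' ~ 'insect',
  animal == 'cat'   ~ 'mammal',
  animal == 'dog'   ~ 'mammal',
  animal == 'eagle' ~ 'bird'
)
first_wins
#> [1] "mammal" "mammal" "insect"
```

Neither call raises an error, which is the behaviour the strict wrapper is meant to replace with an explicit failure.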
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#' Stricter version of case_when()
#'     - disallows a fall-through 'TRUE' value on the LHS. 
#'     - disallows input values which do not match any rules. 
#'     - disallows input values which match more than one rule.
#'
#' @param ... arguments to case_when
#'    
#' @return A vector of length 1 or n, matching the length of the logical input 
#' or output vectors, with the type (and attributes) of the first RHS. 
#' Inconsistent lengths or types will generate an error.
#'
#' @import dplyr 
#' @import rlang
#' @import purrr
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
strict_case_when <- function(...) {
  
  #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  # Allow case_when to do its thing!
  #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  res <- dplyr::case_when(...)
  
  #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  # If `case_when` runs ok, then it means I can make the assumption that all
  # its input arguments are well-formed formulas, with a proper LHS
  #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  formulas <- rlang::list2(...)
  
  #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  # Check for fall-through 'TRUE' value
  #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  lhs <- formulas %>% purrr::map(rlang::f_lhs)
  if (any(purrr::map_lgl(lhs, isTRUE))) {
    stop("strict_case_when(): fall-through 'TRUE' is not allowed", call. = FALSE)
  }
  
  #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  # Count how many times each input is matched by a rule.
  # Have to evaluate this in the right environment.
  #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  match_counts <- formulas %>%
    purrr::map(~rlang::eval_bare(rlang::f_lhs(.x), env = environment(.x))) %>%
    purrr::transpose() %>%
    purrr::map(purrr::flatten_lgl) %>%
    purrr::map_int(sum, na.rm=TRUE)
  
  #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  # Check if any input values are unmatched
  #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  if (any(match_counts == 0L)) {
    stop("strict_case_when(): no matches found at the following input indices: ", 
         deparse(which(match_counts == 0L)), call. = FALSE) 
  }
  
  #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  # Check if any input values matched multiple times
  #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  if (any(match_counts > 1L)) {
    stop("strict_case_when(): multiple matches found at the following input indices: ", 
         deparse(which(match_counts > 1L)), call. = FALSE) 
  }
  
  
  res
}

Using strict_case_when()

The following test cases show how strict_case_when() appropriately throws errors for each of the issues I care about.

A good example. Shouldn’t throw error.

animal <- c('cat', 'dog', 'eagle')

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# A good example. Shouldn't throw error.
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
strict_case_when(
  animal == 'cat'   ~ 'mammal',
  animal == 'dog'   ~ 'mammal',
  animal == 'eagle' ~ 'bird'
)
[1] "mammal" "mammal" "bird"  

No bare ‘TRUE’ LHS allowed

animal <- c('cat', 'dog', 'eagle')

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Issue #1 handled - no 'TRUE' LHS allowed
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
strict_case_when(
  animal == 'catt' ~ 'mammal',
  animal == 'dog'  ~ 'mammal',
  TRUE             ~ 'bird'
)
Error: strict_case_when(): fall-through 'TRUE' is not allowed
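The fall-through check relies on the left-hand side of each formula being stored unevaluated. A small sketch of that idea, using a hypothetical `rules` list built only for this illustration (base `lapply()`/`vapply()` stand in for the purrr calls used in the function):

```r
# Hypothetical rule list. Formulas are not evaluated when created,
# so `animal` does not need to exist here.
rules <- list(
  animal == 'cat' ~ 'mammal',
  TRUE            ~ 'bird'
)

# Pull the unevaluated left-hand side out of each formula
lhs <- lapply(rules, rlang::f_lhs)

# A comparison is stored as a call object, so isTRUE() is FALSE for it;
# a bare TRUE is stored as the logical constant itself
is_fall_through <- vapply(lhs, isTRUE, logical(1))
is_fall_through
#> [1] FALSE  TRUE
```

Because the test is on the unevaluated expression, it only flags a literal `TRUE` on the left-hand side.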

All input values must match one rule

animal <- c('cat', 'dog', 'eagle')

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Issue #2 handled - catch unmatched values
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
strict_case_when(
  animal == 'cat'   ~ 'mammal',
  animal == 'dogg'  ~ 'mammal',
  animal == 'eagle' ~ 'bird'
)
Error: strict_case_when(): no matches found at the following input indices: 2L

Input values cannot match more than one rule

animal <- c('cat', 'dog', 'eagle')

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Issue #3 handled - multiple matches raise an error
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
strict_case_when(
  animal == 'eagle' ~ 'insect',
  animal == 'cat'   ~ 'mammal',
  animal == 'dog'   ~ 'mammal',
  animal == 'eagle' ~ 'bird'
)
Error: strict_case_when(): multiple matches found at the following input indices: 3L

Bugfix 20180920 - now handling NAs correctly

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# It's an error if NAs are present but not handled - Bugfix 20180920
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
animal <- c('cat', 'dog', 'eagle', NA)

strict_case_when(
  animal == 'eagle' ~ 'insect',
  animal == 'cat'   ~ 'mammal',
  animal == 'dog'   ~ 'mammal'
)
Error: strict_case_when(): no matches found at the following input indices: 4L
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Now works if NAs handled - Bugfix 20180920
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
strict_case_when(
  animal == 'eagle' ~ 'insect',
  animal == 'cat'   ~ 'mammal',
  animal == 'dog'   ~ 'mammal',
  is.na(animal)     ~ 'unknown'
)
[1] "mammal"  "mammal"  "insect"  "unknown"
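Why an unhandled NA is reported as unmatched can be sketched in base R. This is not the code from the function: `cbind()` and `rowSums()` are swapped in for the purrr pipeline, and `conditions` and `match_matrix` are names made up for the example. The idea is the same, though: count the TRUE values per input across all rules, treating NA as "no match".

```r
animal <- c('cat', 'dog', 'eagle', NA)

# Each rule's LHS is a logical vector; comparing NA to a string gives NA
conditions <- list(
  animal == 'cat',
  animal == 'dog',
  animal == 'eagle'
)

# One row per input value, one column per rule
match_matrix <- do.call(cbind, conditions)

# Count TRUEs per input; na.rm = TRUE makes an NA comparison count as zero
match_counts <- rowSums(match_matrix, na.rm = TRUE)
match_counts
#> [1] 1 1 1 0

# The NA input has zero matches, so it is flagged as unmatched
which(match_counts == 0)
#> [1] 4
```

Adding an `is.na(animal)` rule contributes a column that is TRUE only for the fourth input, which brings its count up to exactly one.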

Future

  • More bugfixes probably :)