18.5 Using the AST to solve more complicated problems

  • Here we use what we have learned about the AST to write recursive functions that walk it.
  • Two parts of a recursive function:
    • Recursive case: handles the nodes in the tree. Typically, you’ll do something to each child of a node, usually by calling the recursive function again, and then combine the results. For expressions, the recursive cases are calls and pairlists (function arguments).
    • Base case: handles the leaves of the tree. The base cases ensure that the function eventually terminates by solving the simplest cases directly. For expressions, the base cases are symbols and constants.
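As a minimal illustration of the two parts, here is a sketch of a hypothetical expr_depth() function (not part of the original notes) that measures how deeply calls are nested. Anything that is not a call is a leaf with depth 0 (the base case); for a call, it recurses into every component and adds one to the deepest result (the recursive case). For simplicity it treats pairlists as leaves, which the switch_expr() approach in this section does not.

```r
# Hypothetical example: depth of call nesting in an expression
expr_depth <- function(x) {
  if (!is.call(x)) {
    # Base case: symbols and constants (and, here, pairlists) are leaves
    return(0)
  }
  # Recursive case: recurse into each component of the call
  # (including the function name) and keep the deepest result
  child_depths <- vapply(as.list(x), expr_depth, numeric(1))
  1 + max(child_depths)
}

expr_depth(quote(x))
#> [1] 0
expr_depth(quote(f(g(1))))
#> [1] 2
expr_depth(quote(1 + 2 * 3))
#> [1] 2
```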

18.5.1 Two helper functions

  • First, we need an expr_type() function that returns the type of an expression element as a string.
library(rlang)  # provides expr(), as_string(), is_call(), enexpr()

expr_type <- function(x) {
  if (rlang::is_syntactic_literal(x)) {
    "constant"
  } else if (is.symbol(x)) {
    "symbol"
  } else if (is.call(x)) {
    "call"
  } else if (is.pairlist(x)) {
    "pairlist"
  } else {
    typeof(x)
  }
}
expr_type(expr("a"))
#> [1] "constant"
expr_type(expr(x))
#> [1] "symbol"
expr_type(expr(f(1, 2)))
#> [1] "call"
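The three calls above never reach the last two branches. A quick check, assuming rlang is installed (expr_type() is repeated so the sketch runs on its own), is to pass the formals of a function expression, which are stored as a pairlist, and a length-2 vector, which is not a syntactic literal and so falls through to typeof():

```r
library(rlang)

expr_type <- function(x) {
  if (rlang::is_syntactic_literal(x)) {
    "constant"
  } else if (is.symbol(x)) {
    "symbol"
  } else if (is.call(x)) {
    "call"
  } else if (is.pairlist(x)) {
    "pairlist"
  } else {
    typeof(x)
  }
}

# The second element of a `function` call holds the formals as a pairlist
fn <- expr(function(x, y = 1) x + y)
expr_type(fn[[2]])
#> [1] "pairlist"

# A length-2 vector is not a syntactic literal, so typeof() is used
expr_type(c(1, 2))
#> [1] "double"
```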
  • Second, we need a wrapper around switch() that dispatches on expr_type() and throws an informative error when it meets a type we have not handled.
switch_expr <- function(x, ...) {
  switch(expr_type(x),
    ...,
    stop("Don't know how to handle type ", typeof(x), call. = FALSE)
  )
}
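To see what the wrapper does, the sketch below (assuming rlang is installed, and repeating expr_type() so it runs on its own) supplies a value for each type it expects. A symbol dispatches to the symbol branch, while a call with no matching branch reaches the unnamed default and raises the error:

```r
library(rlang)

expr_type <- function(x) {
  if (rlang::is_syntactic_literal(x)) {
    "constant"
  } else if (is.symbol(x)) {
    "symbol"
  } else if (is.call(x)) {
    "call"
  } else if (is.pairlist(x)) {
    "pairlist"
  } else {
    typeof(x)
  }
}

switch_expr <- function(x, ...) {
  switch(expr_type(x),
    ...,
    stop("Don't know how to handle type ", typeof(x), call. = FALSE)
  )
}

switch_expr(expr(x), symbol = "a symbol", constant = "a constant")
#> [1] "a symbol"

# No `call` branch was supplied, so the default stop() runs
msg <- tryCatch(
  switch_expr(expr(f(1)), symbol = "a symbol"),
  error = conditionMessage
)
msg
#> [1] "Don't know how to handle type language"
```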
  • Lastly, we can write a skeleton that walks the AST using switch_expr(). It is only a template: each case still needs a body before the function can be called.
recurse_call <- function(x) {
  switch_expr(x,
    # Base cases
    symbol = ,
    constant = ,

    # Recursive cases
    call = ,
    pairlist =
  )
}
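The template leaves some arguments empty because of how base R's switch() handles them: an empty argument falls through to the next argument that has a value, and an unnamed final argument acts as the default. A small standalone sketch with a hypothetical node_kind() function shows the grouping the template relies on:

```r
# Hypothetical helper: classify an expr_type() string
node_kind <- function(type) {
  switch(type,
    # Empty arguments fall through to the next non-empty one
    symbol = ,
    constant = "leaf",
    call = ,
    pairlist = "node",
    # An unnamed final argument is the default
    stop("Unknown type: ", type, call. = FALSE)
  )
}

node_kind("symbol")
#> [1] "leaf"
node_kind("pairlist")
#> [1] "node"
```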

18.5.2 Specific use cases for recurse_call()

18.5.3 Example 1: Finding F and T

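  • Using F and T in our code rather than FALSE and TRUE is bad practice: TRUE and FALSE are reserved words, but T and F are ordinary variables that can be reassigned (e.g. T <- 0).
  • Say we want to walk the AST to find every place where F or T is used.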
  • Start off by finding the type of T vs TRUE.
expr_type(expr(TRUE))
#> [1] "constant"

expr_type(expr(T))
#> [1] "symbol"
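The difference matters because the parser turns TRUE into a constant, while T is just a symbol that is looked up like any other variable. This base R sketch uses local() so the reassignment does not leak into the global environment:

```r
# T is an ordinary variable bound to TRUE in the base package,
# so a local assignment silently changes its meaning
local({
  T <- 0
  isTRUE(T)
})
#> [1] FALSE

# TRUE is a reserved word, so assigning to it fails at evaluation time
tryCatch(eval(quote(TRUE <- 0)), error = function(e) "error")
#> [1] "error"
```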
  • With this knowledge, we can now write the base cases of our recursive function.
  • The logic is as follows:
    • A constant is never a logical abbreviation, and a symbol is one only if it is “F” or “T”:
logical_abbr_rec <- function(x) {
  switch_expr(x,
    constant = FALSE,
    symbol = as_string(x) %in% c("F", "T")
  )
}
logical_abbr_rec(expr(TRUE))
#> [1] FALSE
logical_abbr_rec(expr(T))
#> [1] TRUE
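  • logical_abbr_rec() assumes its input is already an expression, which keeps the recursion simple. It’s common practice to add a wrapper that quotes its argument with enexpr(), so users can pass code directly.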
logical_abbr <- function(x) {
  logical_abbr_rec(enexpr(x))
}

logical_abbr(T)
#> [1] TRUE
logical_abbr(FALSE)
#> [1] FALSE
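At this point only the base cases exist, so a call such as mean(x, na.rm = T) has nowhere to go. Repeating the definitions so the sketch runs on its own (assuming rlang is installed), the call reaches the default branch of switch_expr() and errors, which is why the recursive cases are needed:

```r
library(rlang)

expr_type <- function(x) {
  if (rlang::is_syntactic_literal(x)) {
    "constant"
  } else if (is.symbol(x)) {
    "symbol"
  } else if (is.call(x)) {
    "call"
  } else if (is.pairlist(x)) {
    "pairlist"
  } else {
    typeof(x)
  }
}

switch_expr <- function(x, ...) {
  switch(expr_type(x),
    ...,
    stop("Don't know how to handle type ", typeof(x), call. = FALSE)
  )
}

# Base cases only: no branch for calls or pairlists yet
logical_abbr_rec <- function(x) {
  switch_expr(x,
    constant = FALSE,
    symbol = as_string(x) %in% c("F", "T")
  )
}
logical_abbr <- function(x) logical_abbr_rec(enexpr(x))

msg <- tryCatch(
  logical_abbr(mean(x, na.rm = T)),
  error = conditionMessage
)
msg
#> [1] "Don't know how to handle type language"
```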

18.5.3.1 Next step: code for the recursive cases

  • Calls and pairlists are handled the same way.
  • Here’s the logic: recursively apply the function to each subcomponent, and return TRUE if any subcomponent contains a logical abbreviation.
  • This is simplified by using the purrr::some() function, which iterates over a list and returns TRUE if the predicate function is true for any element.
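purrr::some() stops as soon as the predicate returns TRUE. A hypothetical base R stand-in, some_base(), sketches that behaviour. It loops over seq_along() and indexes with [[, which works the same way for lists, call objects and pairlists:

```r
# Hypothetical stand-in for purrr::some(): TRUE if .p() is TRUE for
# any element, stopping at the first match
some_base <- function(.x, .p) {
  for (i in seq_along(.x)) {
    if (isTRUE(.p(.x[[i]]))) {
      return(TRUE)
    }
  }
  FALSE
}

some_base(list(1, "a", 3), is.character)
#> [1] TRUE
some_base(list(1, 2, 3), is.character)
#> [1] FALSE

# Works on a call too: its elements are `f`, `x` and "s"
some_base(quote(f(x, "s")), is.character)
#> [1] TRUE
```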
logical_abbr_rec <- function(x) {
  switch_expr(x,
    # Base cases
    constant = FALSE,
    symbol = as_string(x) %in% c("F", "T"),

    # Recursive cases: calls and pairlists share one body.
    # Recurse with logical_abbr_rec() rather than the wrapper
    # logical_abbr(), because each child is already an expression
    call = ,
    pairlist = purrr::some(x, logical_abbr_rec)
  )
}

logical_abbr(mean(x, na.rm = T))
#> [1] TRUE

logical_abbr(function(x, na.rm = T) FALSE)
#> [1] TRUE
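The second example returns TRUE only because of the pairlist case: the T sits in the default value of a formal argument, and the formals are stored as a pairlist inside the `function` call. Base R indexing shows this directly:

```r
f_expr <- quote(function(x, na.rm = T) FALSE)

# Element 1 is the `function` symbol, element 2 the formals
is.pairlist(f_expr[[2]])
#> [1] TRUE
names(f_expr[[2]])
#> [1] "x"     "na.rm"

# The default for na.rm is the symbol T, not the constant TRUE
is.symbol(f_expr[[2]]$na.rm)
#> [1] TRUE
```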

18.5.4 Example 2: Finding all variables created by assignment

  • Listing all the variables created by assignment is a little more complicated.
  • First, look at what an assignment looks like in the AST.
lobstr::ast(x <- 10)
#> █─`<-` 
#> ├─x 
#> └─10
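The same structure can be checked with base R indexing: the first element of the call is the `<-` function, the second is the variable being assigned, and the third is the value. This is what lets the later code read the variable name with x[[2]]. The parser also rewrites right assignment (->) into the same `<-` call:

```r
e <- quote(x <- 10)

is.symbol(e[[1]])  # the function being called, `<-`
#> [1] TRUE
is.symbol(e[[2]])  # the left-hand side, x
#> [1] TRUE
e[[3]]             # the value
#> [1] 10

# Right assignment is parsed into an identical `<-` call
identical(quote(10 -> x), e)
#> [1] TRUE
```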
  • Now we need to decide what data structure to use for the results.
    • The easiest option is to return a character vector of variable names.
    • We would need a list if we wanted to return the symbols themselves.
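The reason is that symbols cannot be stored in an atomic vector: combining them with c() produces a list. A base R sketch of the two representations and the conversion between them:

```r
# Names as a character vector
vars_chr <- c("a", "b")

# The same names as symbols have to live in a list
vars_sym <- lapply(vars_chr, as.name)
is.list(c(quote(a), quote(b)))
#> [1] TRUE

# Converting the symbols back to a character vector
vapply(vars_sym, as.character, character(1))
#> [1] "a" "b"
```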

18.5.5 Dealing with the base cases

find_assign_rec <- function(x) {
  switch_expr(x,
    constant = ,
    symbol = character()
  )
}
find_assign <- function(x) find_assign_rec(enexpr(x))

find_assign("x")
#> character(0)

find_assign(x)
#> character(0)

18.5.6 Dealing with the recursive cases

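  • The recursive cases need to combine the character vectors returned for each child, so we first write a helper that maps a function over a list and flattens the results into a single character vector.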
flat_map_chr <- function(.x, .f, ...) {
  purrr::flatten_chr(purrr::map(.x, .f, ...))
}

flat_map_chr(letters[1:3], ~ rep(., sample(3, 1)))  # random: output varies between runs
#> [1] "a" "a" "b" "b" "c"
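Because sample() makes that output random, here is a deterministic check. It uses a hypothetical base R equivalent, flat_map_chr_base(), built on lapply() and unlist(); the as.character() call makes an all-empty result come back as character(0) rather than NULL. Like the purrr version used with as.list(), it assumes list input:

```r
# Hypothetical base R equivalent of flat_map_chr()
flat_map_chr_base <- function(.x, .f, ...) {
  as.character(unlist(lapply(.x, .f, ...), use.names = FALSE))
}

# Each element can contribute zero, one, or several strings
flat_map_chr_base(list("a", c("b", "c"), character()), identity)
#> [1] "a" "b" "c"

# All-empty results still give an empty character vector
flat_map_chr_base(list(1, 2), function(x) character())
#> character(0)
```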
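  • Now the recursive cases. A pairlist is handled by recursing into each element; a call to <- returns the name on its left-hand side, and any other call is handled by recursing into its components.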
find_assign_rec <- function(x) {
  switch_expr(x,
    # Base cases
    constant = ,
    symbol = character(),

    # Recursive cases
    pairlist = flat_map_chr(as.list(x), find_assign_rec),
    call = {
      if (is_call(x, "<-")) {
        as_string(x[[2]])
      } else {
        flat_map_chr(as.list(x), find_assign_rec)
      }
    }
  )
}

find_assign(a <- 1)
#> [1] "a"

find_assign({
  a <- 1
  {
    b <- 2
  }
})
#> [1] "a" "b"
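This first version has gaps. The sketch below repeats its definitions so it runs on its own (assuming rlang and purrr are installed) and tries two inputs: a chained assignment, where the function stops at the outer <- and never looks inside the value, and a left-hand side that is a call rather than a symbol, which as_string() cannot convert:

```r
library(rlang)

expr_type <- function(x) {
  if (rlang::is_syntactic_literal(x)) {
    "constant"
  } else if (is.symbol(x)) {
    "symbol"
  } else if (is.call(x)) {
    "call"
  } else if (is.pairlist(x)) {
    "pairlist"
  } else {
    typeof(x)
  }
}

switch_expr <- function(x, ...) {
  switch(expr_type(x),
    ...,
    stop("Don't know how to handle type ", typeof(x), call. = FALSE)
  )
}

flat_map_chr <- function(.x, .f, ...) {
  purrr::flatten_chr(purrr::map(.x, .f, ...))
}

find_assign_rec <- function(x) {
  switch_expr(x,
    constant = ,
    symbol = character(),
    pairlist = flat_map_chr(as.list(x), find_assign_rec),
    call = {
      if (is_call(x, "<-")) {
        as_string(x[[2]])
      } else {
        flat_map_chr(as.list(x), find_assign_rec)
      }
    }
  )
}
find_assign <- function(x) find_assign_rec(enexpr(x))

# Chained assignment: only the outermost name is found
find_assign(a <- b <- 1)
#> [1] "a"

# Non-symbol left-hand side: as_string() cannot convert a call
tryCatch(find_assign(names(x) <- 1), error = function(e) "error")
#> [1] "error"
```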

18.5.7 Make the function more robust

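  • Test the function on inputs that might break it. A chained assignment such as a <- b <- c <- 1 returns only “a”, because the function stops at the first <- instead of recursing into its value, and a left-hand side that is not a symbol, such as names(x) <- 1, makes as_string() fail.
  • Move the call handling into a helper, find_assign_call(): if the call assigns to a symbol, record that name and recurse into the remaining arguments; otherwise, recurse into every component of the call.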
find_assign_call <- function(x) {
  if (is_call(x, "<-") && is_symbol(x[[2]])) {
    lhs <- as_string(x[[2]])
    children <- as.list(x)[-1]
  } else {
    lhs <- character()
    children <- as.list(x)
  }

  c(lhs, flat_map_chr(children, find_assign_rec))
}

find_assign_rec <- function(x) {
  switch_expr(x,
    # Base cases
    constant = ,
    symbol = character(),

    # Recursive cases
    pairlist = flat_map_chr(as.list(x), find_assign_rec),
    call = find_assign_call(x)
  )
}

find_assign(a <- b <- c <- 1)
#> [1] "a" "b" "c"

find_assign(system.time(x <- print(y <- 5)))
#> [1] "x" "y"
  • This version is more complicated than the first, but it’s best to start with the simplest function that works and add complexity only when a test case shows it is needed.
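To probe the robust version further, the sketch below repeats its definitions so it runs on its own (assuming rlang and purrr are installed) and tries three more inputs. A non-symbol left-hand side is now skipped rather than causing an error, but a reused name is reported once per assignment, and assignment with = is not detected, because the parser turns it into a call to `=` rather than `<-`:

```r
library(rlang)

expr_type <- function(x) {
  if (rlang::is_syntactic_literal(x)) {
    "constant"
  } else if (is.symbol(x)) {
    "symbol"
  } else if (is.call(x)) {
    "call"
  } else if (is.pairlist(x)) {
    "pairlist"
  } else {
    typeof(x)
  }
}

switch_expr <- function(x, ...) {
  switch(expr_type(x),
    ...,
    stop("Don't know how to handle type ", typeof(x), call. = FALSE)
  )
}

flat_map_chr <- function(.x, .f, ...) {
  purrr::flatten_chr(purrr::map(.x, .f, ...))
}

find_assign_call <- function(x) {
  if (is_call(x, "<-") && is_symbol(x[[2]])) {
    lhs <- as_string(x[[2]])
    children <- as.list(x)[-1]
  } else {
    lhs <- character()
    children <- as.list(x)
  }
  c(lhs, flat_map_chr(children, find_assign_rec))
}

find_assign_rec <- function(x) {
  switch_expr(x,
    constant = ,
    symbol = character(),
    pairlist = flat_map_chr(as.list(x), find_assign_rec),
    call = find_assign_call(x)
  )
}

find_assign <- function(x) find_assign_rec(enexpr(x))

# A non-symbol left-hand side is now skipped instead of erroring
find_assign(names(x) <- 1)
#> character(0)

# Each assignment is reported, so a reused name appears twice
find_assign({a <- 1; a <- 2})
#> [1] "a" "a"

# `=` is parsed as a call to `=`, not `<-`, so it is not detected.
# (find_assign(x = 10) would treat x as an argument name, so parse instead.)
find_assign_rec(parse(text = "x = 10")[[1]])
#> character(0)
```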