18.5 Using the AST to solve more complicated problems
- Here we apply what we’ve learned about the AST to solve problems with recursive functions.
- Two parts of a recursive function:
- Recursive case: handles the nodes in the tree. Typically, you’ll do something to each child of a node, usually calling the recursive function again, and then combine the results. For expressions, you’ll need to handle calls and pairlists (function arguments).
- Base case: handles the leaves of the tree. The base cases ensure that the function eventually terminates, by solving the simplest cases directly. For expressions, you need to handle symbols and constants in the base case.
18.5.1 Two helper functions
- First, we need an expr_type() function that returns the type of an expression element as a string.
library(rlang)

# Classify an expression element; anything unrecognised falls back to typeof()
expr_type <- function(x) {
  if (rlang::is_syntactic_literal(x)) {
    "constant"
  } else if (is.symbol(x)) {
    "symbol"
  } else if (is.call(x)) {
    "call"
  } else if (is.pairlist(x)) {
    "pairlist"
  } else {
    typeof(x)
  }
}
expr_type(expr("a"))
#> [1] "constant"
expr_type(expr(x))
#> [1] "symbol"
expr_type(expr(f(1, 2)))
#> [1] "call"
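The examples above exercise only three of the branches. This sketch (the inputs are illustrative choices, not from the original) shows the other two: a function’s formals() are stored as a pairlist, and anything expr_type() doesn’t recognise, such as a length-3 vector, falls through to typeof().

```r
library(rlang)

# expr_type() as defined above, repeated so this block runs on its own
expr_type <- function(x) {
  if (rlang::is_syntactic_literal(x)) {
    "constant"
  } else if (is.symbol(x)) {
    "symbol"
  } else if (is.call(x)) {
    "call"
  } else if (is.pairlist(x)) {
    "pairlist"
  } else {
    typeof(x)
  }
}

# Function arguments are stored as a pairlist
expr_type(formals(function(x, y = 1) NULL))
#> [1] "pairlist"

# A length-3 vector is not a syntactic literal, so it falls through to typeof()
expr_type(1:3)
#> [1] "integer"
```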
- Second, we need a wrapper around switch() that dispatches on expr_type() and throws an error for any type we haven’t handled.
switch_expr <- function(x, ...) {
  switch(expr_type(x),
    ...,
    # No branch matched: fail loudly instead of silently returning NULL
    stop("Don't know how to handle type ", typeof(x), call. = FALSE)
  )
}
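To see the error path in action, here is a sketch using a hypothetical symbol_name() function (not from the original) that only supplies a symbol branch. A symbol works, but a call finds no matching branch, so switch() evaluates the unnamed default and stop() fires.

```r
library(rlang)

# expr_type() and switch_expr() as defined above, repeated so this block runs on its own
expr_type <- function(x) {
  if (rlang::is_syntactic_literal(x)) {
    "constant"
  } else if (is.symbol(x)) {
    "symbol"
  } else if (is.call(x)) {
    "call"
  } else if (is.pairlist(x)) {
    "pairlist"
  } else {
    typeof(x)
  }
}
switch_expr <- function(x, ...) {
  switch(expr_type(x),
    ...,
    stop("Don't know how to handle type ", typeof(x), call. = FALSE)
  )
}

# Hypothetical function that only handles symbols
symbol_name <- function(x) {
  switch_expr(x, symbol = as_string(x))
}

symbol_name(expr(x))
#> [1] "x"

# A call has type "language" and no matching branch, so stop() runs
tryCatch(symbol_name(expr(f(1))), error = conditionMessage)
#> [1] "Don't know how to handle type language"
```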
- Lastly, we can write a basic template that walks the AST using switch_expr(), with one branch per expression type.
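A sketch of that template with every branch filled in, using a hypothetical count_leaves_rec() (not from the original) that counts the symbols and constants in an expression. The base cases return 1 for a leaf; in the recursive cases, `call = ,` falls through to the pairlist branch, which applies the function to every child and adds up the results.

```r
library(rlang)

# expr_type() and switch_expr() as defined above, repeated so this block runs on its own
expr_type <- function(x) {
  if (rlang::is_syntactic_literal(x)) {
    "constant"
  } else if (is.symbol(x)) {
    "symbol"
  } else if (is.call(x)) {
    "call"
  } else if (is.pairlist(x)) {
    "pairlist"
  } else {
    typeof(x)
  }
}
switch_expr <- function(x, ...) {
  switch(expr_type(x),
    ...,
    stop("Don't know how to handle type ", typeof(x), call. = FALSE)
  )
}

# Hypothetical example: count the leaves of an expression
count_leaves_rec <- function(x) {
  switch_expr(x,
    # Base cases: each leaf counts once
    symbol = ,
    constant = 1L,
    # Recursive cases: count the leaves of every child and add them up
    call = ,
    pairlist = sum(vapply(as.list(x), count_leaves_rec, integer(1)))
  )
}

# Leaves are f, x, `+`, 1 and y
count_leaves_rec(expr(f(x, 1 + y)))
#> [1] 5
```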
18.5.3 Example 1: Finding F and T
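- Using F and T in our code rather than FALSE and TRUE is bad practice, because T and F are ordinary variables that can be reassigned, while TRUE and FALSE are reserved words.
- Say we want to walk the AST to find the places where we use F and T.
- Start off by finding the type of T vs TRUE.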
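A minimal check, reusing expr_type(), makes the difference explicit: TRUE is parsed as a constant, while T is just a symbol that happens to be bound to TRUE in the base environment.

```r
library(rlang)

# expr_type() as defined above, repeated so this block runs on its own
expr_type <- function(x) {
  if (rlang::is_syntactic_literal(x)) {
    "constant"
  } else if (is.symbol(x)) {
    "symbol"
  } else if (is.call(x)) {
    "call"
  } else if (is.pairlist(x)) {
    "pairlist"
  } else {
    typeof(x)
  }
}

expr_type(expr(TRUE))
#> [1] "constant"
expr_type(expr(T))
#> [1] "symbol"
```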
- With this knowledge, we can now write the base cases of our recursive function.
- The logic is as follows:
- A constant is never a logical abbreviation and a symbol is an abbreviation if it is “F” or “T”:
logical_abbr_rec <- function(x) {
  switch_expr(x,
    constant = FALSE,
    symbol = as_string(x) %in% c("F", "T")
  )
}
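- logical_abbr_rec() assumes its input is already an expression, which keeps the recursion simple. It’s common practice to add a wrapper that quotes its input with enexpr(), so callers can pass code directly.
logical_abbr <- function(x) {
  logical_abbr_rec(enexpr(x))
}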
logical_abbr(T)
#> [1] TRUE
logical_abbr(FALSE)
#> [1] FALSE
18.5.3.1 Next step: code for the recursive cases
- Here we want to do the same thing for calls and for pairlists.
- Here’s the logic: recursively apply the function to each subcomponent, and return TRUE if any subcomponent contains a logical abbreviation.
- This is simplified by the purrr::some() function, which iterates over a list and returns TRUE if the predicate function is true for any element.
logical_abbr_rec <- function(x) {
  switch_expr(x,
    # Base cases
    constant = FALSE,
    symbol = as_string(x) %in% c("F", "T"),
    # Recursive cases
    # `call = ,` has no value, so it falls through to the pairlist branch:
    # both are handled by calling logical_abbr_rec() on every element
    call = ,
    pairlist = purrr::some(x, logical_abbr_rec)
  )
}
logical_abbr(mean(x, na.rm = T))
#> [1] TRUE
logical_abbr(function(x, na.rm = T) FALSE)
#> [1] TRUE
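A few extra checks make the behaviour concrete (the inputs are illustrative, not from the original): the function flags T wherever it appears in the tree, ignores the full TRUE, and only matches symbols named exactly F or T.

```r
library(rlang)

# Definitions from above, repeated so this block runs on its own
expr_type <- function(x) {
  if (rlang::is_syntactic_literal(x)) {
    "constant"
  } else if (is.symbol(x)) {
    "symbol"
  } else if (is.call(x)) {
    "call"
  } else if (is.pairlist(x)) {
    "pairlist"
  } else {
    typeof(x)
  }
}
switch_expr <- function(x, ...) {
  switch(expr_type(x),
    ...,
    stop("Don't know how to handle type ", typeof(x), call. = FALSE)
  )
}
logical_abbr_rec <- function(x) {
  switch_expr(x,
    constant = FALSE,
    symbol = as_string(x) %in% c("F", "T"),
    call = ,
    pairlist = purrr::some(x, logical_abbr_rec)
  )
}
logical_abbr <- function(x) {
  logical_abbr_rec(enexpr(x))
}

# TRUE is a constant, so nothing is flagged
logical_abbr(mean(x, na.rm = TRUE))
#> [1] FALSE
# T inside the condition of an if() call is found by recursion
logical_abbr(if (T) 1 else 0)
#> [1] TRUE
# Tx is a symbol, but not exactly "T"
logical_abbr(Tx + 1)
#> [1] FALSE
```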
18.5.4 Example 2: Finding all variables created by assignment
- Listing all the variables is a little more complicated.
- Figure out what assignment looks like based on the AST.
- Now we need to decide what data structure we’re going to use for the results.
- The easiest option is to return a character vector.
- We would need to use a list if we wanted to return symbols.
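One way to see the structure of an assignment without drawing the tree is to index into the quoted call. In this sketch, the call’s first element is the function `<-`, the second is the symbol being assigned and the third is the value; this follows from how R stores every call, with the function first and its arguments after it.

```r
library(rlang)

x <- expr(a <- 1)

# An assignment is a call whose function is `<-`
is_call(x, "<-")
#> [1] TRUE
# The first argument is the symbol being assigned...
x[[2]]
#> a
# ...and the second argument is the value
x[[3]]
#> [1] 1
```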
18.5.6 Dealing with the recursive cases
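- flat_map_chr() is a helper that applies a function to each element of a list (like purrr::map()) and flattens the results into a single character vector. We’ll use it to combine the names found in each child of a call or pairlist.
flat_map_chr <- function(.x, .f, ...) {
  purrr::flatten_chr(purrr::map(.x, .f, ...))
}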
flat_map_chr(letters[1:3], ~ rep(., sample(3, 1)))
#> [1] "a" "a" "b" "b" "c"
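The example above is random because of sample(). A deterministic sketch (child_results is a hypothetical name, not from the original) shows why flattening matters here: each child can contribute zero, one, or several names, and flat_map_chr() merges them into one vector, with empty results simply disappearing.

```r
# flat_map_chr() as defined above, repeated so this block runs on its own
flat_map_chr <- function(.x, .f, ...) {
  purrr::flatten_chr(purrr::map(.x, .f, ...))
}

# Hypothetical per-child results: one name, none, and two names
child_results <- list("a", character(), c("b", "c"))

# purrr::map() alone keeps one list element per child
length(purrr::map(child_results, identity))
#> [1] 3

# flat_map_chr() returns a single character vector
flat_map_chr(child_results, identity)
#> [1] "a" "b" "c"
```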
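- Now we can write the recursive cases. A pairlist is handled by recursing into each element. A call is first checked for <-: if it’s an assignment, we return the name on its left-hand side; otherwise we recurse into its components. The find_assign() wrapper quotes its input, just like logical_abbr().
find_assign_rec <- function(x) {
  switch_expr(x,
    # Base cases
    constant = ,
    symbol = character(),
    # Recursive cases
    pairlist = flat_map_chr(as.list(x), find_assign_rec),
    call = {
      if (is_call(x, "<-")) {
        as_string(x[[2]])
      } else {
        flat_map_chr(as.list(x), find_assign_rec)
      }
    }
  )
}
find_assign <- function(x) find_assign_rec(enexpr(x))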
find_assign(a <- 1)
#> [1] "a"
find_assign({
a <- 1
{
b <- 2
}
})
#> [1] "a" "b"
18.5.7 Make the function more robust
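- Throw cases at it that we think might break the function. The version above stops at the first <- it finds, so for chained assignment such as a <- b <- c <- 1 it returns only "a", and for x <- print(y <- 5) it misses y. It also fails when the left-hand side isn’t a symbol, as in names(x) <- "a", because as_string() can’t convert a call to a string.
- Write a helper, find_assign_call(), that handles calls: it records the left-hand side only when it’s a symbol, and it always recurses into the children, including the right-hand side of an assignment.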
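Running the earlier version of find_assign_rec() on such inputs shows the failures directly. This sketch repeats the earlier definitions so it runs on its own; the test inputs are illustrative.

```r
library(rlang)

# Earlier definitions, repeated so this block runs on its own
expr_type <- function(x) {
  if (rlang::is_syntactic_literal(x)) {
    "constant"
  } else if (is.symbol(x)) {
    "symbol"
  } else if (is.call(x)) {
    "call"
  } else if (is.pairlist(x)) {
    "pairlist"
  } else {
    typeof(x)
  }
}
switch_expr <- function(x, ...) {
  switch(expr_type(x),
    ...,
    stop("Don't know how to handle type ", typeof(x), call. = FALSE)
  )
}
flat_map_chr <- function(.x, .f, ...) {
  purrr::flatten_chr(purrr::map(.x, .f, ...))
}
find_assign_rec <- function(x) {
  switch_expr(x,
    constant = ,
    symbol = character(),
    pairlist = flat_map_chr(as.list(x), find_assign_rec),
    call = {
      if (is_call(x, "<-")) {
        as_string(x[[2]])
      } else {
        flat_map_chr(as.list(x), find_assign_rec)
      }
    }
  )
}
find_assign <- function(x) find_assign_rec(enexpr(x))

# Stops at the outer <-, so b and c are missed
find_assign(a <- b <- c <- 1)
#> [1] "a"
# Never looks inside the value of x <- ..., so y is missed
find_assign(system.time(x <- print(y <- 5)))
#> [1] "x"
# The left-hand side is a call, so as_string() signals an error
inherits(tryCatch(find_assign(names(x) <- "a"), error = function(e) e), "error")
#> [1] TRUE
```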
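find_assign_call <- function(x) {
  if (is_call(x, "<-") && is_symbol(x[[2]])) {
    # Record the assigned name, then drop `<-` and search the target and value
    lhs <- as_string(x[[2]])
    children <- as.list(x)[-1]
  } else {
    # Not a simple assignment: search every component of the call
    lhs <- character()
    children <- as.list(x)
  }
  c(lhs, flat_map_chr(children, find_assign_rec))
}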
find_assign_rec <- function(x) {
  switch_expr(x,
    # Base cases
    constant = ,
    symbol = character(),
    # Recursive cases
    pairlist = flat_map_chr(as.list(x), find_assign_rec),
    call = find_assign_call(x)
  )
}
find_assign(a <- b <- c <- 1)
#> [1] "a" "b" "c"
find_assign(system.time(x <- print(y <- 5)))
#> [1] "x" "y"
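The robust version also copes with a left-hand side that isn’t a symbol. In this sketch (the inputs are illustrative, not from the original), names(x) <- ... is skipped without an error, while an ordinary assignment next to it is still found.

```r
library(rlang)

# Definitions from above, repeated so this block runs on its own
expr_type <- function(x) {
  if (rlang::is_syntactic_literal(x)) {
    "constant"
  } else if (is.symbol(x)) {
    "symbol"
  } else if (is.call(x)) {
    "call"
  } else if (is.pairlist(x)) {
    "pairlist"
  } else {
    typeof(x)
  }
}
switch_expr <- function(x, ...) {
  switch(expr_type(x),
    ...,
    stop("Don't know how to handle type ", typeof(x), call. = FALSE)
  )
}
flat_map_chr <- function(.x, .f, ...) {
  purrr::flatten_chr(purrr::map(.x, .f, ...))
}
find_assign_call <- function(x) {
  if (is_call(x, "<-") && is_symbol(x[[2]])) {
    lhs <- as_string(x[[2]])
    children <- as.list(x)[-1]
  } else {
    lhs <- character()
    children <- as.list(x)
  }
  c(lhs, flat_map_chr(children, find_assign_rec))
}
find_assign_rec <- function(x) {
  switch_expr(x,
    constant = ,
    symbol = character(),
    pairlist = flat_map_chr(as.list(x), find_assign_rec),
    call = find_assign_call(x)
  )
}
find_assign <- function(x) find_assign_rec(enexpr(x))

# A call on the left-hand side is not recorded, and no error is raised
find_assign(names(x) <- c("a", "b"))
#> character(0)
# The ordinary assignment alongside it is still found
find_assign(f(names(x) <- 1, z <- 2))
#> [1] "z"
```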
- This approach is more complicated, but it’s important to start simple and add complexity only as tests reveal cases the function doesn’t handle.