A Brief Introduction to PyTraverse

The goal of this notebook is to showcase the key aspects of PyTraverse and how they can be used to rewrite datastructures.

1. Hello World Traverser

Let’s start with the simplest possible traverser, one that does not actually traverse anything:

import pytraverse as t


@t.traverser
def my_traverser(s: str) -> str:
    return s.upper()


s = "hello world"
print(t.traverse(s, my_traverser))
HELLO WORLD

The core function offered by the pytraverse module is the traverse function. It takes an object to traverse and a traverser that should be applied to it. Here, the traverser just calls the .upper() method of the passed-in string.

Note the @t.traverser decorator. It converts our simple object mapper into a “proper” traverser - we will discuss later what this means. For now, just remember to decorate your traverser functions with it.

2. Down to Business

The previous example was maybe a bit boring. After all, we did not traverse anything. Let’s change that now…

Consider the problem of computing the sum of all the integers in a (nested) datastructure:

data = [[1, 2, [3, [4, 5], 6], 7, 8], 9]
# We want to compute 1 + 2 + 3 + 4 + 5 + 6 + 7 + 8 + 9 = 45 here...

The main idea behind the pytraverse library is to abstract away the recursion, so that you can focus on the semantics of the task at hand.

Let’s start by solving this task using a standard recursive implementation:

def recursive_sum(x: object) -> int:
    if isinstance(x, list):
        s = sum([recursive_sum(item) for item in x])
        return s
    return x


print(recursive_sum(data))
45

That was easy enough. Now, let’s rewrite this to use the pytraverse library:

from collections.abc import Callable


@t.traverser
def sum_traverser(x: object, traverse: Callable[[object], int]) -> int:
    if isinstance(x, list):
        s = sum([traverse(item) for item in x])
        return s
    return x


print(t.traverse(data, sum_traverser))
45

Not so different… indeed, the code is mostly identical to the previous version. The main difference is that the function no longer calls itself explicitly. Instead, it now receives a traverse function, which performs the recursive calls.
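
To make the role of the traverse callback concrete, here is a minimal stdlib-only sketch of a driver that hands such a callback to a processing function. The names run and sum_step are hypothetical, and this is not how PyTraverse is implemented; it only illustrates how the recursion can live outside the function that processes a node.

```python
from collections.abc import Callable


def run(obj: object, step: Callable[[object, Callable], object]) -> object:
    """Apply `step` to `obj`, handing it a callback for recursing into children."""

    def traverse(child: object) -> object:
        # The driver owns the recursion; `step` never calls itself directly.
        return step(child, traverse)

    return traverse(obj)


def sum_step(x: object, traverse: Callable[[object], int]) -> int:
    if isinstance(x, list):
        return sum(traverse(item) for item in x)
    return x


data = [[1, 2, [3, [4, 5], 6], 7, 8], 9]
print(run(data, sum_step))  # 45
```

Because sum_step only ever sees the callback, a different driver could swap in extra behavior around each recursive call without touching sum_step itself.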

For now, this second implementation is not any simpler than the one above. The main advantage of decoupling the recursive caller from the callee only becomes apparent once we introduce the main superpower of the pytraverse module: Composition.

3. Let’s compose…

In the spirit of separation of concerns, it might not always be desirable (or possible) to write a single large recursive function which processes your data to your heart’s content.

Instead, one might want to decouple the general logic of traversing (what are the children of a datastructure) from the actual processing of the nodes.

For example, let’s assume we want to multiply all numbers in a nested list datastructure by 2. Alternatively, we might want to add 1 to each element, or do something else… Using naive recursive Python, this might look something like this:

def mashed_multiply(x: object) -> object:
    if isinstance(x, list):
        return [mashed_multiply(item) for item in x]
    return x * 2


def mashed_add(x: object) -> object:
    if isinstance(x, list):
        return [mashed_add(item) for item in x]
    return x + 1


def mashed_exp(x: object) -> object:
    if isinstance(x, list):
        return [mashed_exp(item) for item in x]
    return 2**x


print("Data:", data)
print("Mul:", mashed_multiply(data))
print("Add:", mashed_add(data))
print("Pow:", mashed_exp(data))
Data: [[1, 2, [3, [4, 5], 6], 7, 8], 9]
Mul: [[2, 4, [6, [8, 10], 12], 14, 16], 18]
Add: [[2, 3, [4, [5, 6], 7], 8, 9], 10]
Pow: [[2, 4, [8, [16, 32], 64], 128, 256], 512]

Note the redundant code in mashed_multiply, mashed_add and mashed_exp. For complicated traversal logic, monolithic recursive processors lead to duplicated, non-extensible and brittle code. Let’s fix this using traversers…

@t.singledispatch_traverser
def data_traverser(x: list, traverse: Callable[[object], object]) -> list:
    return [traverse(item) for item in x]


@t.singledispatch_traverser
def mul_traverser(x: int) -> int:
    return x * 2


@t.singledispatch_traverser
def add_traverser(x: int) -> int:
    return x + 1


@t.singledispatch_traverser
def pow_traverser(x: int) -> int:
    return 2**x


print("Data:", data)
print("Mul:", t.traverse(data, t.sequential(data_traverser, mul_traverser)))
print("Add:", t.traverse(data, t.sequential(data_traverser, add_traverser)))
print("Pow:", t.traverse(data, t.sequential(data_traverser, pow_traverser)))
Data: [[1, 2, [3, [4, 5], 6], 7, 8], 9]
Mul: [[2, 4, [6, [8, 10], 12], 14, 16], 18]
Add: [[2, 3, [4, [5, 6], 7], 8, 9], 10]
Pow: [[2, 4, [8, [16, 32], 64], 128, 256], 512]

Here, two new functions are used:

  1. @t.singledispatch_traverser is used instead of @t.traverser.

  2. t.sequential is used to combine traversers.

First, let’s consider the changed decorator. For now, let’s ignore the details of what this decorator does. Similar to @t.traverser, @t.singledispatch_traverser ensures that the given traverser is a “proper” traverser (whatever that means)… The reason we use it instead of the regular traverser decorator is that the traversal logic in each function is now type-dependent.

Note that the implementation in data_traverser only works for lists, while the code in the other traversers only works for numbers. @t.singledispatch_traverser ensures that the traversal logic is only called for compatible types, i.e., it effectively does the isinstance(x, list) checks for us. To do this type matching, the decorator looks at the type annotation of the first argument of the traverser.

Remember: Type annotations are mandatory when using @t.singledispatch_traverser!
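
To see why the annotation matters, consider a minimal sketch of a decorator that reads the annotation of the first parameter and uses it as an isinstance guard. The decorator name dispatch_on_annotation is made up for illustration, and the sketch relies only on the standard inspect and typing modules; it assumes plain class annotations such as int or list and is not PyTraverse’s actual implementation.

```python
import inspect
from collections.abc import Callable
from typing import get_type_hints


def dispatch_on_annotation(func: Callable[[object], object]) -> Callable[[object], object]:
    """Run `func` only for inputs matching the annotation of its first parameter."""
    first_param = next(iter(inspect.signature(func).parameters))
    hints = get_type_hints(func)
    if first_param not in hints:
        # Without an annotation there is nothing to dispatch on.
        raise TypeError(f"{func.__name__}: first parameter needs a type annotation")
    expected_type = hints[first_param]

    def wrapper(x: object) -> object:
        if isinstance(x, expected_type):
            return func(x)
        return x  # other types pass through unchanged

    return wrapper


@dispatch_on_annotation
def double(x: int) -> int:
    return x * 2


print(double(21))    # 42
print(double("hi"))  # hi
```

In this sketch, a function without an annotation is rejected at decoration time, which mirrors the reminder above that annotations are mandatory.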

Second, consider the t.sequential calls. This function combines the given traversers into a single traverser, which applies each of them in order (left to right).

# This:
t.sequential(data_traverser, mul_traverser)


# is equivalent to this:
@t.traverser
def sequential_traverser(x: object, traverse: Callable[[object], object]) -> object:
    # data_traverser:
    if isinstance(x, list):
        return [traverse(item) for item in x]
    # mul_traverser:
    if isinstance(x, int):
        return x * 2
    # singledispatch_traverser acts like the identity function
    # for other types:
    return x
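
One way to picture this composition, assuming each type-restricted step leaves non-matching inputs unchanged, is a pipeline that feeds the result of each step into the next. The functions sequential, list_step, mul_step and run below are hypothetical stand-ins written in plain Python; PyTraverse’s own implementation may differ.

```python
from collections.abc import Callable

Step = Callable[[object, Callable[[object], object]], object]


def sequential(*steps: Step) -> Step:
    """Combine steps into one step that applies them left to right."""

    def combined(x: object, traverse: Callable[[object], object]) -> object:
        for step in steps:
            x = step(x, traverse)
        return x

    return combined


def list_step(x: object, traverse: Callable[[object], object]) -> object:
    # Recurse into lists; leave everything else untouched.
    if isinstance(x, list):
        return [traverse(item) for item in x]
    return x


def mul_step(x: object, traverse: Callable[[object], object]) -> object:
    # Double integers; leave everything else untouched.
    if isinstance(x, int):
        return x * 2
    return x


def run(obj: object, step: Step) -> object:
    def traverse(child: object) -> object:
        return step(child, traverse)

    return traverse(obj)


data = [[1, 2, [3, [4, 5], 6], 7, 8], 9]
print(run(data, sequential(list_step, mul_step)))
# [[2, 4, [6, [8, 10], 12], 14, 16], 18]
```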

4. …multiple dispatchers

Singledispatch and sequential composition are already fairly powerful, but there is more to the story.

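A fundamental problem of programming is the so-called expression problem: how to add both new data types and new operations to a program without modifying the existing code.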

In the case of data traversal, this problem comes up whenever the type hierarchy of the datatypes we want to work with is not fully known in advance or if it might even change later on. For example, assume we decide that our composed traversers should not only be able to multiply/add/exponentiate integers in nested lists, but also in mixtures of nested lists and dictionaries:

mixed_data = {"a": 1, "b": [2, 3], "c": {"d": 4, "e": 5}}

print("Mixed Data:", mixed_data)
print("Mul:", t.traverse(mixed_data, t.sequential(data_traverser, mul_traverser)))
print("Add:", t.traverse(mixed_data, t.sequential(data_traverser, add_traverser)))
print("Pow:", t.traverse(mixed_data, t.sequential(data_traverser, pow_traverser)))
Mixed Data: {'a': 1, 'b': [2, 3], 'c': {'d': 4, 'e': 5}}
Mul: {'a': 1, 'b': [2, 3], 'c': {'d': 4, 'e': 5}}
Add: {'a': 1, 'b': [2, 3], 'c': {'d': 4, 'e': 5}}
Pow: {'a': 1, 'b': [2, 3], 'c': {'d': 4, 'e': 5}}

Since data_traverser only knows how to deal with lists, we cannot process mixed_data.

Naively, we could just rewrite the data_traverser to also deal with dict inputs. But what if we do not want to do this for some reason? 🤨

[Why?] Before we continue, let’s justify why we would not just rewrite data_traverser. Assume we decide to extend data_traverser to also work on the elements of tables contained in PDF files, e.g., to multiply all numbers in a table inside a PDF by 2. A bit odd, but why not… To support this very niche use case, every user of data_traverser would have to load not only the basic list and dict traverser code but also a huge blob of PDF parsing code they might not even care about. Wouldn’t it be nice if the user could decide which data structures should be traversable by data_traverser and then only load the necessary code?

Fortunately, our singledispatch_traverser-based data_traverser already supports the extensibility we need via a technique called single dispatch: handlers are selected by the type of the first argument, and new handlers can be registered at any time. This means that we can extend the behavior of our traverser like this:

@data_traverser.register
def _(x: dict, traverse: Callable[[object], object]) -> dict:
    return {k: traverse(v) for k, v in x.items()}


print("Mixed Data:", mixed_data)
print("Mul:", t.traverse(mixed_data, t.sequential(data_traverser, mul_traverser)))
print("Add:", t.traverse(mixed_data, t.sequential(data_traverser, add_traverser)))
print("Pow:", t.traverse(mixed_data, t.sequential(data_traverser, pow_traverser)))
Mixed Data: {'a': 1, 'b': [2, 3], 'c': {'d': 4, 'e': 5}}
Mul: {'a': 2, 'b': [4, 6], 'c': {'d': 8, 'e': 10}}
Add: {'a': 2, 'b': [3, 4], 'c': {'d': 5, 'e': 6}}
Pow: {'a': 2, 'b': [4, 8], 'c': {'d': 16, 'e': 32}}

All we had to do was to register a new dispatch handler with the data_traverser (note the x: dict type hint which enables this!) and all existing users of that traverser benefit from the extended functionality. Cool!
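
The standard library offers a comparable mechanism in functools.singledispatch, which also selects a handler from the annotated type of the first argument and accepts new handlers after the fact. The sketch below uses only that stdlib function (the name double_all is made up for illustration) to show the same late-registration pattern; it is an analogy, not PyTraverse code.

```python
from functools import singledispatch


@singledispatch
def double_all(x: object) -> object:
    return x  # fallback: unknown types pass through unchanged


@double_all.register
def _(x: int) -> int:
    return x * 2


@double_all.register
def _(x: list) -> list:
    return [double_all(item) for item in x]


mixed_data = {"a": 1, "b": [2, 3], "c": {"d": 4, "e": 5}}
before = double_all(mixed_data)  # no dict handler yet, so nothing changes


# Registered later, possibly from a different module:
@double_all.register
def _(x: dict) -> dict:
    return {k: double_all(v) for k, v in x.items()}


after = double_all(mixed_data)
print(before)  # {'a': 1, 'b': [2, 3], 'c': {'d': 4, 'e': 5}}
print(after)   # {'a': 2, 'b': [4, 6], 'c': {'d': 8, 'e': 10}}
```

Every existing caller of double_all picks up the dict handler without any change on its side, which is the property the text above relies on.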

5. Learning to Count

At this point you could just compose and register dispatchers on traversers to your heart’s content. But… there’s more.

So far we just played around in functional wonderland, where all traversers independently did their little thing without much care for the rest of the world. Unfortunately, this is not sufficient.

Let’s start with a simple example: What if we want to apply the add/multiply/pow operations from the previous sections only to every second leaf node of a nested datastructure? To do this, our traverser would have to know where it is in the traversal tree. How do we do this?

Variables!

The pytraverse module comes with three types of variables:

  1. GlobalVariable

  2. StackVariable

  3. ComputedVariable

To solve the conditional transformation problem described above, we can employ a GlobalVariable. Since showing is better than telling, let’s just see how this works:

LEAF_COUNT = t.GlobalVariable[int]("LEAF_COUNT", default=0)


@data_traverser.register(object)
def leaf_count_traverser(state: t.State) -> t.State:
    state[LEAF_COUNT] += 1
    return state


traversed_data, state = t.traverse_with_state(
    mixed_data,
    t.sequential(data_traverser, mul_traverser),
)

print("Data:", mixed_data)
print("Traversed Data:", traversed_data)
print("Leaf Count:", state[LEAF_COUNT])
Data: {'a': 1, 'b': [2, 3], 'c': {'d': 4, 'e': 5}}
Traversed Data: {'a': 2, 'b': [4, 6], 'c': {'d': 8, 'e': 10}}
Leaf Count: 5

First, we define the global int variable LEAF_COUNT. Second, we register a new default dispatcher with the data_traverser. If no other dispatcher matches, i.e., if neither the previously defined list handler nor the dict handler applies to the current object, the new traverser is executed.

This newly registered leaf_count_traverser is a bit different from the ones we saw before. Instead of an object, it takes a t.State as a parameter. During the traversal, such state objects can be used to pass along and update arbitrary data.

The previously defined LEAF_COUNT variable can be accessed and changed via the state[LEAF_COUNT] syntax. Since LEAF_COUNT is a global variable, any changes to this variable will be visible to all parent and child elements in the traversed structure.

Last, to access the final state after traversal, we call t.traverse_with_state instead of t.traverse. After traversing/processing all five leaf nodes of mixed_data, the final value of LEAF_COUNT is 5 - as we would expect.
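
The idea of a traversal-wide variable can be approximated in plain Python by threading one shared mutable dictionary through the recursion. This is only a rough analogy under that assumption: the function count_and_double and the key "leaf_count" are hypothetical and say nothing about how t.State or GlobalVariable work internally.

```python
def count_and_double(obj: object, state: dict[str, int]) -> object:
    if isinstance(obj, list):
        return [count_and_double(item, state) for item in obj]
    if isinstance(obj, dict):
        return {k: count_and_double(v, state) for k, v in obj.items()}
    # Leaf: every call shares the same `state` dict, so this update is
    # visible everywhere, including to the caller after the traversal.
    state["leaf_count"] += 1
    return obj * 2


mixed_data = {"a": 1, "b": [2, 3], "c": {"d": 4, "e": 5}}
state = {"leaf_count": 0}
result = count_and_double(mixed_data, state)

print(result)               # {'a': 2, 'b': [4, 6], 'c': {'d': 8, 'e': 10}}
print(state["leaf_count"])  # 5
```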

Now, let’s use this counter to apply the mul_traverser only to every second leaf node:

def is_even_leaf(state: t.State) -> bool:
    return state[LEAF_COUNT] % 2 == 0


conditional_mul_traverser = t.traverser(mul_traverser, traverse_if=is_even_leaf)

traversed_data = t.traverse(
    mixed_data,
    t.sequential(data_traverser, conditional_mul_traverser),
)

print("Data:", mixed_data)
print("Traversed Data:", traversed_data)
Data: {'a': 1, 'b': [2, 3], 'c': {'d': 4, 'e': 5}}
Traversed Data: {'a': 1, 'b': [4, 3], 'c': {'d': 8, 'e': 5}}

To realize the conditional traversal, we make use of another neat feature of the traverser decorator: traverse_if. This optional parameter takes a predicate on the state; the resulting traverser is only applied when the predicate holds. Complementary to traverse_if, there is also an analogous skip_if parameter.
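
The gating idea can be sketched without the library as a wrapper that consults a predicate before running a step. Everything below (gated, walk, the "leaf_count" key) is hypothetical, and the sketch assumes that skip_if is simply the negated counterpart of traverse_if, which the text above suggests but does not spell out.

```python
from collections.abc import Callable
from typing import Optional

State = dict[str, int]
Predicate = Callable[[State], bool]
LeafStep = Callable[[int, State], int]


def gated(
    step: LeafStep,
    traverse_if: Optional[Predicate] = None,
    skip_if: Optional[Predicate] = None,
) -> LeafStep:
    """Wrap `step` so that it only runs when the predicates allow it."""

    def wrapper(x: int, state: State) -> int:
        if traverse_if is not None and not traverse_if(state):
            return x  # predicate not satisfied: leave the value alone
        if skip_if is not None and skip_if(state):
            return x  # explicitly skipped
        return step(x, state)

    return wrapper


def walk(obj: object, state: State, leaf_step: LeafStep) -> object:
    if isinstance(obj, list):
        return [walk(item, state, leaf_step) for item in obj]
    if isinstance(obj, dict):
        return {k: walk(v, state, leaf_step) for k, v in obj.items()}
    state["leaf_count"] += 1  # count the leaf before deciding what to do
    return leaf_step(obj, state)


def is_even_leaf(state: State) -> bool:
    return state["leaf_count"] % 2 == 0


def double(x: int, state: State) -> int:
    return x * 2


mixed_data = {"a": 1, "b": [2, 3], "c": {"d": 4, "e": 5}}
even = walk(mixed_data, {"leaf_count": 0}, gated(double, traverse_if=is_even_leaf))
odd = walk(mixed_data, {"leaf_count": 0}, gated(double, skip_if=is_even_leaf))
print(even)  # {'a': 1, 'b': [4, 3], 'c': {'d': 8, 'e': 5}}
print(odd)   # {'a': 2, 'b': [2, 6], 'c': {'d': 4, 'e': 10}}
```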

6. Too deep

Even/odd counting is nice, but what about depth counting? Let’s now only apply the multiply traverser to nodes nested more than one level below the root, i.e., nodes whose depth is greater than 1.

DEPTH_COUNT = t.StackVariable[int]("DEPTH_COUNT", default=-1)


@t.traverser
def depth_counter_traverser(state: t.State) -> t.State:
    state[DEPTH_COUNT] += 1
    print(f"DEPTH_COUNT = {state[DEPTH_COUNT]} at object {state.object}")
    return state


depth_conditional_mul_traverser = t.traverser(
    mul_traverser,
    traverse_if=lambda state: state[DEPTH_COUNT] > 1,
)

traversed_data = t.traverse(
    mixed_data,
    t.sequential(
        depth_counter_traverser,
        data_traverser,
        depth_conditional_mul_traverser,
    ),
)

print("Data:", mixed_data)
print("Traversed Data:", traversed_data)
DEPTH_COUNT = 0 at object {'a': 1, 'b': [2, 3], 'c': {'d': 4, 'e': 5}}
DEPTH_COUNT = 1 at object 1
DEPTH_COUNT = 1 at object [2, 3]
DEPTH_COUNT = 2 at object 2
DEPTH_COUNT = 2 at object 3
DEPTH_COUNT = 1 at object {'d': 4, 'e': 5}
DEPTH_COUNT = 2 at object 4
DEPTH_COUNT = 2 at object 5
Data: {'a': 1, 'b': [2, 3], 'c': {'d': 4, 'e': 5}}
Traversed Data: {'a': 1, 'b': [4, 6], 'c': {'d': 8, 'e': 10}}

Overall, this is pretty similar to what we did before, but now we used a StackVariable instead of a GlobalVariable. As the name suggests, updates to stack variables are only visible in downstream traverser calls, not upstream.
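
The difference between the two scoping rules can be illustrated in plain Python by recording the depth observed at each leaf, once with a value passed down the call stack and once with a shared counter that is never reset. Both functions are hypothetical illustrations of the scoping behavior described above, not a model of PyTraverse internals.

```python
def leaf_depths_scoped(obj: object, depth: int = 0) -> list[int]:
    """Depth travels down the call stack; siblings and parents never see the increment."""
    if isinstance(obj, (list, dict)):
        children = obj.values() if isinstance(obj, dict) else obj
        return [d for child in children for d in leaf_depths_scoped(child, depth + 1)]
    return [depth]


def leaf_depths_shared(obj: object, state: dict[str, int]) -> list[int]:
    """The increment is never undone, so the value behaves like a global counter."""
    state["depth"] += 1
    if isinstance(obj, (list, dict)):
        children = obj.values() if isinstance(obj, dict) else obj
        return [d for child in children for d in leaf_depths_shared(child, state)]
    return [state["depth"]]


mixed_data = {"a": 1, "b": [2, 3], "c": {"d": 4, "e": 5}}
print(leaf_depths_scoped(mixed_data))                 # [1, 2, 2, 2, 2]
print(leaf_depths_shared(mixed_data, {"depth": -1}))  # [1, 3, 4, 6, 7]
```

With the scoped version, the recorded values are true depths; with the shared version, they merely count how many nodes were visited so far.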

… there be dragons

Note that the above example depends on the order of the traversers in t.sequential. Let’s try the following:

traversed_data = t.traverse(
    mixed_data,
    t.sequential(
        data_traverser,
        depth_counter_traverser,  # Depth count increased after data_traverser
        depth_conditional_mul_traverser,
    ),
)

print("Data:", mixed_data)
print("Traversed Data:", traversed_data)
DEPTH_COUNT = 0 at object 1
DEPTH_COUNT = 0 at object 2
DEPTH_COUNT = 0 at object 3
DEPTH_COUNT = 0 at object [2, 3]
DEPTH_COUNT = 0 at object 4
DEPTH_COUNT = 0 at object 5
DEPTH_COUNT = 0 at object {'d': 4, 'e': 5}
DEPTH_COUNT = 0 at object {'a': 1, 'b': [2, 3], 'c': {'d': 4, 'e': 5}}
Data: {'a': 1, 'b': [2, 3], 'c': {'d': 4, 'e': 5}}
Traversed Data: {'a': 1, 'b': [2, 3], 'c': {'d': 4, 'e': 5}}

If the depth_counter_traverser is executed after the data_traverser, the recursive data traversal will reach the leaf nodes before the DEPTH_COUNT variable is increased. This highlights the importance of ordering composed traversers carefully.
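
The effect of the ordering corresponds to the difference between pre-order and post-order processing. The following stdlib-only sketch (the function visit_order is hypothetical) records the order in which nodes are handled when the per-node action runs before or after the recursion into the children.

```python
from typing import Optional


def visit_order(obj: object, pre_order: bool, visited: Optional[list] = None) -> list:
    """Return the nodes in the order in which the per-node action would run."""
    visited = [] if visited is None else visited
    if pre_order:
        visited.append(obj)  # act first, then recurse (counter before data step)
    if isinstance(obj, dict):
        for value in obj.values():
            visit_order(value, pre_order, visited)
    elif isinstance(obj, list):
        for item in obj:
            visit_order(item, pre_order, visited)
    if not pre_order:
        visited.append(obj)  # recurse first, then act (counter after data step)
    return visited


mixed_data = {"a": 1, "b": [2, 3], "c": {"d": 4, "e": 5}}
print(visit_order(mixed_data, pre_order=True))
# [{'a': 1, 'b': [2, 3], 'c': {'d': 4, 'e': 5}}, 1, [2, 3], 2, 3, {'d': 4, 'e': 5}, 4, 5]
print(visit_order(mixed_data, pre_order=False))
# [1, 2, 3, [2, 3], 4, 5, {'d': 4, 'e': 5}, {'a': 1, 'b': [2, 3], 'c': {'d': 4, 'e': 5}}]
```

The two orders match the sequence of objects printed in the two depth-counter runs above: root first when the counter precedes the data traverser, leaves first when it follows.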