Automatically Flatten & Unflatten Nested Containers
This post is about a small piece of functionality that has proven useful in TensorFlow / JAX / PyTorch.
Low-level components of these systems often use a plain list of values/tensors
as inputs & outputs.
However, end users who develop models often want to work with more
complicated data structures:
Dict[str, Any], List[Any], custom classes, and their nested combinations.
Therefore, we need bidirectional conversion between nested structures and a plain list of tensors.
I found that different libraries have come up with similar approaches to this problem, so it's interesting to list them here.
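To make the problem concrete, here is a minimal, library-agnostic sketch of such a bidirectional conversion in plain Python. It is not the API of TensorFlow, JAX, or PyTorch: the names `flatten`, `unflatten`, `register`, and the `Pair` class are hypothetical, chosen only for illustration. It traverses dicts (in sorted-key order, so the leaf order does not depend on insertion order), lists, tuples, and user-registered classes; everything else is treated as a leaf, standing in for a tensor.

```python
from typing import Any, Callable, Dict, List, Tuple

# Hypothetical registry for custom classes:
# cls -> (to_children(obj) -> list of children, from_children(children) -> obj)
_REGISTRY: Dict[type, Tuple[Callable[[Any], List[Any]], Callable[[List[Any]], Any]]] = {}


def register(cls: type, to_children: Callable, from_children: Callable) -> None:
    """Teach flatten/unflatten how to traverse a user-defined container class."""
    _REGISTRY[cls] = (to_children, from_children)


def flatten(tree: Any) -> Tuple[List[Any], Any]:
    """Return (leaves, treedef): a plain list of leaves plus a description of the nesting."""
    leaves: List[Any] = []

    def walk(node: Any) -> Any:
        # Exact type checks: subclasses (e.g. OrderedDict, namedtuple) are leaves unless registered.
        if type(node) is dict:
            keys = sorted(node)  # deterministic leaf order
            return ("dict", keys, [walk(node[k]) for k in keys])
        if type(node) is list:
            return ("list", None, [walk(c) for c in node])
        if type(node) is tuple:
            return ("tuple", None, [walk(c) for c in node])
        if type(node) in _REGISTRY:
            to_children, _ = _REGISTRY[type(node)]
            return (type(node), None, [walk(c) for c in to_children(node)])
        leaves.append(node)  # anything else is a leaf
        return None  # None marks a leaf position in the treedef

    treedef = walk(tree)
    return leaves, treedef


def unflatten(treedef: Any, leaves: List[Any]) -> Any:
    """Inverse of flatten: rebuild the nested structure from treedef and a flat list."""
    it = iter(leaves)
    missing = object()  # sentinel, since a leaf may legitimately be None

    def build(td: Any) -> Any:
        if td is None:
            leaf = next(it, missing)
            if leaf is missing:
                raise ValueError("too few leaves for this treedef")
            return leaf
        kind, keys, children = td
        built = [build(c) for c in children]
        if kind == "dict":
            return dict(zip(keys, built))
        if kind == "list":
            return built
        if kind == "tuple":
            return tuple(built)
        _, from_children = _REGISTRY[kind]  # a registered custom class
        return from_children(built)

    result = build(treedef)
    if next(it, missing) is not missing:
        raise ValueError("too many leaves for this treedef")
    return result


class Pair:
    """A hypothetical user-defined container with two children."""

    def __init__(self, a: Any, b: Any) -> None:
        self.a, self.b = a, b


register(Pair, lambda p: [p.a, p.b], lambda children: Pair(*children))

tree = {"w": [1.0, 2.0], "b": (3.0, Pair(4.0, 5.0))}
leaves, treedef = flatten(tree)
print(leaves)  # [3.0, 4.0, 5.0, 1.0, 2.0]

# A low-level function only ever sees the flat list; the structure is restored afterwards.
doubled = unflatten(treedef, [x * 2 for x in leaves])
print(doubled["w"], doubled["b"][0], doubled["b"][1].b)
```

Keeping the structure description (`treedef`) separate from the leaves is what makes the round trip possible: the flat list can be passed through a low-level component and then re-packed into the original nesting. The exact-type checks are a deliberate simplification; a real library may make different choices about subclasses.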