JAX(Pytrees)
Pytrees
什么是Pytree?如果我们按照文档中提供的Pytress定义,那么Pytree是指由类似容器的Python对象构建的树状结构。什么是类似容器的Python对象?看到名称,您可能已经猜到类似容器的Python对象包括列表、元组、字典、namedtuple、OrderedDict和None。这些数据结构是默认的类似容器的对象,被视为Pytree。我们还可以告诉JAX将视为类似容器的对象,但我们需要首先将它们包含在Pytree注册表中。让我们看一下Pytree的几个例子。
1 | import time |
简而言之,Pytree只是节点(类似容器的Python对象)和叶子(所有其他Python对象)的组合。JAX还允许您将自定义类型注册为Pytree。此时您应该问的一个问题是JAX中的核心数据结构DeviceArray是否可以用作Pytree。答案是否定的。任何ndarray都被视为Pytree中的叶子,让我们举个例子来说明这一点。
1 | # Check if we can make a pytree from a DeviceArray |
如您所见,ndarray仅被视为一片叶子。对于数值也是如此。另一个需要注意的重要事项是Pytree是树状数据结构,而不是DAG或图状数据结构。他们假设引用透明,并且不存在引用循环。因此,不要在Pytree的多个叶子中使用相同的对象。我们可以在每一层压平树,得到叶子,以及原始的树结构。让我们看看它的实际效果。
1 | # We will use the `example_2` pytree for this purpose. |
现在我们已经提取了叶子,我们可以修改它们并使用原始树结构(treedef)再次重建树。
注意:我们可以使用tree_map(...)和tree_multimap(...)对叶子进行操作,因为这是更好方法。上面的示例是为了展示您可以通过一种想要更多地控制应用于树的不同叶子的操作的方式来实现这一点。
1 | def change_even_positioned_leaf(x, pos): |
我们现在可以使用原始的树结构来重建具有变换后的叶子的树。我们看到Pytree是类似容器的Python对象,如列表、元组、字典等。但是如果您想扩展这组被视为Pytree节点的Python对象怎么办? 例如,如果您想将您的类视为Pytree节点怎么办?要将一个类视为Pytree节点,我们需要:
- 通过在内部注册表中注册它,告诉
JAX您希望将其视为节点而不是叶子。 - 因为这是一个自定义对象,
JAX不知道如何展开和取消展开它,我们也需要告诉JAX。 - 在某些情况下,我们需要比较两个
treedef结构是否相等。因此,我们需要确保添加自定义对象不会破坏相等性检查。
让我们看一个例子。
1 | from jax.tree_util import register_pytree_node |
结果输出为:
1 | Original Pytree: Counter value = 5 |
注意:定义函数来展开自定义对象时,请重新检查传递给该函数的参数顺序。第一个位置参数始终表示辅助数据,而第二个参数始终表示子数据。请检查上面的flatten_counter(...)作为示例。
我们尚未检查的一件事是我们的Pytree是否可以进行相等性检查。让我们检查一下。
1 | # Another instance |
Pytree和JAX转换
到目前为止,我们讨论了Pytree是什么以及如何创建它,包括使用自定义的类似容器的对象。在我们开始使用Pytree进行一些高级示例之前,让我们退后一步,了解如何将grad、vmap等JAX函数转换应用于pytree。
1 | def activate(x): |
有几点需要注意:
- 我们可以组合任何
JAX转换并将其应用到Pytree。 - 某些转换(例如
vmap和pmap)采用可选参数(例如in_axes和out_axes),这些参数指定应如何处理某些输入或输出值。这些参数也可以是pytree,它们的结构必须与相应参数的Pytree结构相对应。例如,检查如何在上面的示例中传递in_axes的值。
Jaxprs 和 Pytree
与任何其他JAX代码一样,您也可以将jaxprs与Pytree一起使用。例如,我们检查一下上面例子的jaxpr。
1 | make_jaxpr(vmap(grad(activate), in_axes=({"weights":0},)))(example_pytree) |
为什么是Pytree?
到目前为止,我们讨论了Pytree,但我们尚未回答的一个问题:为什么应该了解Pytree?以及使用Pytree的常见用例是什么? 尽管Pytree有很多用例,但最常见的用例是指定模型参数。例如,如果你想构建一个DNN,你可以将每层对应的权重和偏差存储为Pytree。您甚至可以将DNN特征的整个模块定义为Pytree。让我们看一个相同的例子。这取自JAX文档。
1 | key = random.PRNGKey(111) |
1 | def initialize_params(key, dims): |
结果输出为:
1 | Layer 1 => Params shape: {'biases': (128,), 'weights': (1, 128)} |