JAX(Autodiff)

今天,我们将研究另一个重要概念自动微分。我们已经在TensorFlow中看到了自动微分自动微分的想法在所有框架中都非常相似,但IMO JAX比所有框架都做得更好。

梯度

JAX中的grad函数用于计算梯度。 我们知道JAX背后的基本思想是使用函数组合grad也将可调用对象作为输入并返回可调用对象。因此,每当我们想要计算梯度时,我们需要首先将可调用对象传递给grad。让我们举个例子来更清楚地说明:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import time
import numpy as np
import jax
import jax.numpy as jnp
from jax import random
from jax import make_jaxpr
from jax import vmap, pmap, jit
from jax import grad, value_and_grad
from jax.test_util import check_grads

def product(x, y):
z = x * y
return z

x = 3.0
y = 4.0

z = product(x, y)

print(f"Input Variable x: {x}")
print(f"Input Variable y: {y}")
print(f"Product z: {z}\n")

# dz / dx
dx = grad(product, argnums=0)(x, y)
print(f"Gradient of z wrt x: {dx}")

# dz / dy
dy = grad(product, argnums=1)(x, y)
print(f"Gradient of z wrt y: {dy}")

# Input Variable x: 3.0
# Input Variable y: 4.0
# Product z: 12.0
# Gradient of z wrt x: 4.0
# Gradient of z wrt y: 3.0

让我们分解上面的例子并尝试逐步理解梯度计算。

  • 我们有一个名为Product(...)的函数,它接受两个位置参数作为输入并返回这些参数的乘积。
  • 我们将Product(...)函数传递给grad来计算梯度。grad中的argnums参数告诉grad区分函数与位置参数。因此,我们通过01来相应地计算xy的梯度。

您还可以一次性计算函数值和梯度。为此,我们将使用value_and_grad(...)函数。

1
2
3
4
5
6
z, dx = value_and_grad(product, argnums=0)(x, y)
print("Product z:", z)
print(f"Gradient of z wrt x: {dx}")

# Product z: 12.0
# Gradient of z wrt x: 4.0

Jaxprs 和 grad

由于我们可以在JAX中组合函数转换,因此我们可以从grad函数生成jaxprs来了解幕后发生的情况。举个例子:

1
2
3
4
5
6
7
8
# Differentiating wrt first positional argument `x`
print("Differentiating wrt x")
print(make_jaxpr(grad(product, argnums=0))(x, y))


# Differentiating wrt second positional argument `y`
print("\nDifferentiating wrt y")
print(make_jaxpr(grad(product, argnums=1))(x, y))

结果输出为:

1
2
3
4
5
Differentiating wrt x
{ lambda ; a:f32[] b:f32[]. let _:f32[] = mul a b; c:f32[] = mul 1.0 b in (c,) }

Differentiating wrt y
{ lambda ; a:f32[] b:f32[]. let _:f32[] = mul a b; c:f32[] = mul a 1.0 in (c,) }

请注意,除我们要微分的1之外的参数是值为1的常数。

停止梯度计算

有时我们不希望梯度流过特定计算中涉及的某些变量。在这种情况下,我们需要明确告诉JAX我们不希望梯度流经指定的变量集。稍后我们将研究这方面的复杂示例,但现在,我将修改我们的Product(...)函数,其中我们不希望梯度流经y

1
2
3
4
5
6
7
8
9
10
# Modified product function. Explicity stopping the
# flow of the gradients through `y`
def product_stop_grad(x, y):
z = x * jax.lax.stop_gradient(y)
return z

# Differentiating wrt y. This should return 0
grad(product_stop_grad, argnums=1)(x, y)

# DeviceArray(0., dtype=float32)

每个样本的梯度

反向模式下,仅为输出标量的函数定义梯度,例如反向传播损失值以更新机器学习模型的参数。损失始终是标量值。如果您的函数返回一个批次并且您想要计算该批次的每个样本的梯度该怎么办?
这些在JAX中非常简单。

  • 编写一个接受输入应用tanh的函数。
  • 我们将检查是否可以计算单个示例的梯度。
  • 我们将传递一批输入并计算整批的梯度。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
def activate(x):
"""Applies tanh activation."""
return jnp.tanh(x)

# Check if we can compute the gradients for a single example
grads_single_example = grad(activate)(0.5)
print("Gradient for a single input x=0.5: ", grads_single_example)

# Now we will generate a batch of random inputs, and will pass
# those inputs to our activate function. And we will also try to
# calculate the grads on the same batch in the same way as above
# Always use the PRNG
key = random.PRNGKey(1234)
x = random.normal(key=key, shape=(5,))
activations = activate(x)

print("\nTrying to compute gradients on a batch")
print("Input shape: ", x.shape)
print("Output shape: ", activations.shape)

try:
grads_batch = grad(activate)(x)
print("Gradients for the batch: ", grads_batch)
except Exception as ex:
print(type(ex).__name__, ex)

结果输出为:

1
2
3
4
5
Gradient for a single input x=0.5:  0.7864477
Trying to compute gradients on a batch
Input shape: (5,)
Output shape: (5,)
TypeError Gradient only defined for scalar-output functions. Output had shape: (5,).

那么解决办法是什么呢?vmappmap是几乎所有问题的解决方案,让我们看看它的实际效果:

1
2
3
4
grads_batch = vmap(grad(activate))(x)
print("Gradients for the batch: ", grads_batch)

# Gradients for the batch: [0.48228705 0.45585024 0.99329686 0.0953269 0.8153717 ]

让我们分解一下我们上面为达到预期结果所做的所有修改。

  • grad(activate)(...)适用于单个示例。
  • 添加vmap组合为我们的输入和输出添加批量维度(默认为0)。

从单个示例到批量示例就这么简单,反之亦然。您所需要的只是专注于使用vmap。让我们看看这个转换的jaxpr是什么样子的。

1
2
3
4
5
6
7
8
9
make_jaxpr(vmap(grad(activate)))(x)

# { lambda ; a.
# let b = tanh a
# c = sub 1.0 b
# d = mul 1.0 c
# e = mul d b
# f = add_any d e
# in (f,) }

其他变换的组合

我们可以将任何其他转换与grad组合起来。我们已经看到vmapgrad一起使用。让我们将jit应用到上面的转换中,以使其更加高效。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
jitted_grads_batch = jit(vmap(grad(activate)))

for _ in range(3):
start_time = time.time()
print("Gradients for the batch: ", jitted_grads_batch(x))
print(f"Time taken: {time.time() - start_time:.2f} seconds")
print("="*50)
print()

# Gradients for the batch: [0.48228705 0.45585027 0.99329686 0.09532695 0.8153717 ]
# Time taken: 0.03 seconds
# ==================================================

# Gradients for the batch: [0.48228705 0.45585027 0.99329686 0.09532695 0.8153717 ]
# Time taken: 0.00 seconds
# ==================================================

# Gradients for the batch: [0.48228705 0.45585027 0.99329686 0.09532695 0.8153717 ]
# Time taken: 0.00 seconds
# ==================================================

验证有限差分

很多时候,我们想要用有限差分来验证梯度的计算,以再次检查我们所做的一切是否正确。因为这是处理导数时非常常见的健全性检查,所以JAX提供了一个方便的函数check_grads来检查任意阶梯度的有限差分。让我们来看看:

1
2
3
4
5
6
7
try:
check_grads(jitted_grads_batch, (x,), order=1)
print("Gradient match with gradient calculated using finite differences")
except Exception as ex:
print(type(ex).__name__, ex)

# Gradient match with gradient calculated using finite differences

高阶梯度

grad函数接受一个可调用函数作为输入并返回另一个函数。我们可以一次又一次地将变换返回的函数与grad组合起来,以计算任意阶的高阶导数。让我们举一个例子来看看它的实际效果。我们将使用activate(...)函数来演示这一点。

1
2
3
4
5
6
7
8
9
x = 0.5

print("First order derivative: ", grad(activate)(x))
print("Second order derivative: ", grad(grad(activate))(x))
print("Third order derivative: ", grad(grad(grad(activate)))(x))

# First order derivative: 0.7864477
# Second order derivative: -0.726862
# Third order derivative: -0.5652091

梯度和数值稳定性

下溢溢出是我们多次遇到的常见问题,尤其是在计算梯度时。我们将举一个例子(这个例子直接来自JAX文档,这是一个非常好的例子)来说明我们如何遇到数值不稳定以及JAX如何尝试帮助您克服它。当您计算某个值的梯度时会发生什么?

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# An example of a mathematical operation in your workflow
def log1pexp(x):
"""Implements log(1 + exp(x))"""
return jnp.log(1. + jnp.exp(x))

# This works fine
print("Gradients for a small value of x: ", grad(log1pexp)(5.0))

# But what about for very large values of x for which the
# exponent operation will explode
print("Gradients for a large value of x: ", grad(log1pexp)(500.0))

# Gradients for a small value of x: 0.9933072

# Gradients for a large value of x: nan

刚刚发生了什么?让我们对其进行分解,以了解预期的输出以及返回nanJAX幕后gpoing的内容。我们知道上述函数的导数可以写成这样:

对于非常大的值,您会期望导数的值为1,但是当我们将grad与我们的函数实现结合起来时,它返回nan。为了获得更多信息,我们可以通过查看转换的jaxpr来分解梯度计算。

1
2
3
4
5
6
7
8
9
make_jaxpr(grad(log1pexp))(500.0)

# { lambda ; a.
# let b = exp a
# c = add b 1.0
# _ = log c
# d = div 1.0 c
# e = mul d b
# in (e,) }

如果您仔细观察,您会发现计算等效于:

对于较大的值,右侧的项将四舍五入为inf,并且梯度计算将返回nan,如我们在上面看到的。在这种情况下,我们知道如何正确计算梯度,但JAX不知道。它正在研究标准自动差异规则。那么,我们如何告诉JAX,我们的函数应该按照我们想要的方式进行区分呢?我们可以使用JAX中的custom_vjpcustom_vjp函数来实现这一点。让我们看看它的实际效果。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from jax import custom_jvp

@custom_jvp
def log1pexp(x):
"""Implements log(1 + exp(x))"""
return jnp.log(1. + jnp.exp(x))

@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
"""Tells JAX to differentiate the function in the way we want."""
x, = primals
x_dot, = tangents
ans = log1pexp(x)
# This is where we define the correct way to compute gradients
ans_dot = (1 - 1/(1 + jnp.exp(x))) * x_dot
return ans, ans_dot

# Let's now compute the gradients for large values
print("Gradients for a small value of x: ", grad(log1pexp)(500.0))

# What about the Jaxpr?
make_jaxpr(grad(log1pexp))(500.0)

# Gradients for a small value of x: 1.0
# { lambda ; a.
# let _ = custom_jvp_call_jaxpr[ fun_jaxpr={ lambda ; a.
# let b = exp a
# c = add b 1.0
# d = log c
# in (d,) }
# jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7f79cc3f2dd0>
# num_consts=0 ] a
# b = exp a
# c = add b 1.0
# d = div 1.0 c
# e = sub 1.0 d
# f = mul e 1.0
# in (f,) }

让我们分解一下步骤。

  • 我们用计算雅可比向量积(前向模式)的custom_vjp装饰了log1pexp(...)
  • 然后我们定义了log1pexp_jvp(...)来计算梯度。重点关注该函数中的这行代码:ans_dot = (1 - 1/(1 + jnp.exp(x))) * x_dot。简单来说,我们所做的就是以这种方式重新排列导数

我们用log1pexp.defjvp装饰logp1exp_jvp(...)函数,告诉JAX计算JVP,请使用我们定义的函数并返回预期的输出。