JAX(DeviceArray)

JAX是一个高性能机器学习库JAX在加速器(例如GPUTPU)上编译并运行NumPy代码。您可以使用JAX(以及为JAX构建的神经网络库FLAX)来构建和训练深度学习模型。

JAX是什么?

JAX是一个特别适合机器学习研究的框架。关于JAX的几点:

  • 它就像numpy一样,但使用编译器(XLA)编译本机Numpy代码,并在加速器(GPU/TPU)上运行。
  • 对于自动微分JAX使用Autograd。它自动区分原生PythonNumpy代码。
  • JAX用于将数值程序表示为组合,但具有某些约束,例如JAX转换和编译设计为仅适用于纯函数的Python函数。如果一个函数在使用相同的参数调用时始终返回相同的值,则该函数是纯函数,并且该函数没有副作用,例如,改变非局部变量的状态。
  • 就语法而言,JAXnumpy非常相似,但您应该注意一些细微的差异。

让我们举几个例子来看看JAX的实际应用!

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import time

import jax
import numpy as np
import jax.numpy as jnp
from jax import random

# We will create two arrays, one with numpy and other with jax
# to check the common things and the differences
array_numpy = np.arange(10, dtype=np.int32)
array_jax = jnp.arange(10, dtype=jnp.int32)

print("Array created using numpy: ", array_numpy)
print("Array created using JAX: ", array_jax)

# What types of array are these?
print(f"array_numpy is of type : {type(array_numpy)}")
print(f"array_jax is of type : {type(array_jax)}")

结果输出为:

1
2
3
4
5
Array created using numpy:  [0 1 2 3 4 5 6 7 8 9]
Array created using JAX: [0 1 2 3 4 5 6 7 8 9]

array_numpy is of type : <class 'numpy.ndarray'>
array_jax is of type : <class 'jaxlib.xla_extension.DeviceArray'>

array_numpyndarray对象,而array_jaxDeviceArray对象。

DeviceArray

有关DeviceArray的几点:

  • 它是JAX数组对象的核心,与ndarray类似,但有细微的差别。
  • ndarray不同,DeviceArray由单个设备(CPU/GPU/TPU)上的内存缓冲区支持。
  • 它与设备无关,即JAX不需要跟踪阵列所在的设备,并且可以避免数据传输。
  • 由于它与设备无关,因此可以轻松在CPU、GPUTPU上运行相同的JAX代码,而无需更改代码。
  • DeviceArray是惰性的,即JAX DeviceArray的值不会立即可用,并且仅在请求时才拉取。
  • 尽管DeviceArray是惰性的,您仍然可以执行诸如检查DeviceArray的形状或类型之类的操作,而无需等待生成它的计算完成。我们甚至可以将其传递给另一个JAX计算。

延迟计算和与设备无关的两个属性为DeviceArray提供了巨大的优势。

Numpy vs JAX-numpy

jax numpyAPI方面与numpy非常相似。您在numpy中执行的大多数操作也可以在jax numpy中使用,具有类似的语义。我只是列出了一些操作来展示这一点,但还有更多操作。注意:并非所有Numpy函数都在JAX numpy中实现。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
# Find the max element. Similarly you can find `min` as well
print(f"Maximum element in ndarray: {array_numpy.max()}")
print(f"Maximum element in DeviceArray: {array_jax.max()}")

# Reshaping
print("Original shape of ndarray: ", array_numpy.shape)
print("Original shape of DeviceArray: ", array_jax.shape)

array_numpy = array_numpy.reshape(-1, 1)
array_jax = array_jax.reshape(-1, 1)

print("\nNew shape of ndarray: ", array_numpy.shape)
print("New shape of DeviceArray: ", array_jax.shape)

# Absoulte pairwise difference
print("Absoulte pairwise difference in ndarray")
print(np.abs(array_numpy - array_numpy.T))

print("\nAbsoulte pairwise difference in DeviceArray")
print(jnp.abs(array_jax - array_jax.T))

# Are they equal?
print("\nAre all the values same?", end=" ")
print(jnp.alltrue(np.abs(array_numpy - array_numpy.T) == jnp.abs(array_jax - array_jax.T)))

# Matrix multiplication
print("Matrix multiplication of ndarray")
print(np.dot(array_numpy, array_numpy.T))

print("\nMatrix multiplication of DeviceArray")
print(jnp.dot(array_jax, array_jax.T))

结果输出为:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
Maximum element in ndarray: 9
Maximum element in DeviceArray: 9

Original shape of ndarray: (10,)
Original shape of DeviceArray: (10,)
New shape of ndarray: (10, 1)
New shape of DeviceArray: (10, 1)

Absoulte pairwise difference in ndarray
[[0 1 2 3 4 5 6 7 8 9]
[1 0 1 2 3 4 5 6 7 8]
[2 1 0 1 2 3 4 5 6 7]
[3 2 1 0 1 2 3 4 5 6]
[4 3 2 1 0 1 2 3 4 5]
[5 4 3 2 1 0 1 2 3 4]
[6 5 4 3 2 1 0 1 2 3]
[7 6 5 4 3 2 1 0 1 2]
[8 7 6 5 4 3 2 1 0 1]
[9 8 7 6 5 4 3 2 1 0]]
Absoulte pairwise difference in DeviceArray
[[0 1 2 3 4 5 6 7 8 9]
[1 0 1 2 3 4 5 6 7 8]
[2 1 0 1 2 3 4 5 6 7]
[3 2 1 0 1 2 3 4 5 6]
[4 3 2 1 0 1 2 3 4 5]
[5 4 3 2 1 0 1 2 3 4]
[6 5 4 3 2 1 0 1 2 3]
[7 6 5 4 3 2 1 0 1 2]
[8 7 6 5 4 3 2 1 0 1]
[9 8 7 6 5 4 3 2 1 0]]
Are all the values same? True

Matrix multiplication of ndarray
[[ 0 0 0 0 0 0 0 0 0 0]
[ 0 1 2 3 4 5 6 7 8 9]
[ 0 2 4 6 8 10 12 14 16 18]
[ 0 3 6 9 12 15 18 21 24 27]
[ 0 4 8 12 16 20 24 28 32 36]
[ 0 5 10 15 20 25 30 35 40 45]
[ 0 6 12 18 24 30 36 42 48 54]
[ 0 7 14 21 28 35 42 49 56 63]
[ 0 8 16 24 32 40 48 56 64 72]
[ 0 9 18 27 36 45 54 63 72 81]]

Matrix multiplication of DeviceArray
[[ 0 0 0 0 0 0 0 0 0 0]
[ 0 1 2 3 4 5 6 7 8 9]
[ 0 2 4 6 8 10 12 14 16 18]
[ 0 3 6 9 12 15 18 21 24 27]
[ 0 4 8 12 16 20 24 28 32 36]
[ 0 5 10 15 20 25 30 35 40 45]
[ 0 6 12 18 24 30 36 42 48 54]
[ 0 7 14 21 28 35 42 49 56 63]
[ 0 8 16 24 32 40 48 56 64 72]
[ 0 9 18 27 36 45 54 63 72 81]]

现在,让我们看一下在Numpy中可以执行但在Jax-numpy中不能执行的一些操作,反之亦然。

不变性(Immutability)

JAX数组是不可变的,就像TensorFlow张量一样。这意味着,JAX数组不支持像ndarray中那样的项目分配。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
array1 = np.arange(5, dtype=np.int32)
array2 = jnp.arange(5, dtype=jnp.int32)

print("Original ndarray: ", array1)
print("Original DeviceArray: ", array2)

# Item assignment
array1[4] = 10
print("\nModified ndarray: ", array1)
print("\nTrying to modify DeviceArray-> ", end=" ")

try:
array2[4] = 10
print("Modified DeviceArray: ", array2)
except Exception as ex:
print(type(ex).__name__, ex)

结果输出为:

1
2
3
4
5
Original ndarray:  [0 1 2 3 4]
Original DeviceArray: [0 1 2 3 4]
Modified ndarray: [ 0 1 2 3 10]

Trying to modify DeviceArray-> TypeError '<class 'jaxlib.xla_extension.DeviceArray'>' object does not support item assignment. JAX arrays are immutable; perhaps you want jax.ops.index_update or jax.ops.index_add instead?

这种情况与我们使用TensorFlow Tensors时的情况完全相同。与TensorFlow中的tf.tensor_scatter_nd_update类似,我们有索引更新运算符(之前曾经有jax.ops.index_update(..)但现在已弃用)。语法非常简单,例如DeviceArray.at[idx].op(val)。但这不会修改原始数组,而是返回一个新数组,其中元素已按指定更新一个自然而然地浮现在脑海中的问题?为什么是不变性?问题是JAX依赖于纯函数。允许项目分配或就地更新与该理念相反。但是为什么TF张量是不可变的,因为它不需要纯函数?如果您要对DAG进行任何优化,强烈建议避免更改计算中使用的操作的状态,以避免任何副作用。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# Modifying DeviceArray elements at specific index/indices
array2_modified = array2.at[4].set(10)

# Equivalent => array2_modified = jax.ops.index_update(array2, 4, 10)
print("Original DeviceArray: ", array2)
print("Modified DeviceArray: ", array2_modified)

# Of course, updates come in many forms!
print(array2.at[4].add(6))
print(array2.at[4].max(20))
print(array2.at[4].min(-1))

# Equivalent but depecated. Just to showcase the similarity to tf scatter_nd_update
print("\nEquivalent but deprecatd")
print(jax.ops.index_add(array2, 4, 6))
print(jax.ops.index_max(array2, 4, 20))
print(jax.ops.index_min(array2, 4, -1))

异步调度

ndarraysDeviceArrays之间最大的区别之一在于它们的执行力和可用性。JAX使用异步调度来隐藏Python开销。

1
2
3
4
5
6
7
8
# Create two random arrays sampled from a uniform distribution
array1 = np.random.uniform(size=(8000, 8000)).astype(np.float32)
array2 = jax.random.uniform(jax.random.PRNGKey(0), (8000, 8000), dtype=jnp.float32) # More on PRNGKey later!
print("Shape of ndarray: ", array1.shape)
print("Shape of DeviceArray: ", array2.shape)

# Shape of ndarray: (8000, 8000)
# Shape of DeviceArray: (8000, 8000)

现在,让我们对每个数组进行一些计算,看看会发生什么以及每个计算需要多少时间。

1
2
3
4
5
6
7
8
9
10
11
12
# Dot product on ndarray
start_time = time.time()
res = np.dot(array1, array1)
print(f"Time taken by dot product op on ndarrays: {time.time()-start_time:.2f} seconds")

# Dot product on DeviceArray
start_time = time.time()
res = jnp.dot(array2, array2)
print(f"Time taken by dot product op on DeviceArrays: {time.time()-start_time:.2f} seconds")

# Time taken by dot product op on ndarrays: 7.95 seconds
# Time taken by dot product op on DeviceArrays: 0.02 seconds

看来DeviceArray计算很快就完成了。

  • ndarray的结果不同,DeviceArray上完成的计算结果尚不可用。这是加速器上可用的未来值。
  • 您可以通过打印或将其转换为普通的旧numpy ndarray来检索此计算的值。
  • DeviceArray的计时是调度工作所花费的时间,而不是实际计算所花费的时间。
  • 异步调度很有用,因为它允许Python代码“运行在加速器设备之前”,从而使Python代码远离关键路径。如果Python代码在设备上排队的速度比执行速度快,并且Python代码实际上不需要检查主机上计算的输出,则Python程序可以将任意数量的工作排队并避免使用加速器等待。
  • 要衡量任何此类操作的真实成本:将其转换为普通numpy ndarray(不推荐);使用block_until_ready()等待它的计算完成(基准测试的首选方法)。

Types promotion

这是需要牢记的另一个方面。与numpy相比,JAX中的dtype提升不那么激进:

  • 在提升Python标量时,JAX始终更喜欢JAX值的精度。
  • 在针对浮点或复杂类型提升整数或布尔类型时,JAX始终优先选择浮点或复杂类型。
  • JAX使用更适合GPU/TPU等现代加速器设备的浮点提升规则。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
print("Types promotion in numpy =>", end=" ")
print((np.int8(32) + 4).dtype)

print("Types promtoion in JAX =>", end=" ")
print((jnp.int8(32) + 4).dtype)

# Types promotion in numpy => int64
# Types promtoion in JAX => int8

array1 = np.random.randint(5, size=(2), dtype=np.int32)
print("Implicit numpy casting gives: ", (array1 + 5.0).dtype)

# Check the difference in semantics of the above function in JAX
array2 = jax.random.randint(jax.random.PRNGKey(0),
minval=0,
maxval=5,
shape=[2],
dtype=jnp.int32
)
print("Implicit JAX casting gives: ", (array2 + 5.0).dtype)

# Implicit numpy casting gives: float64
# Implicit JAX casting gives: float32

自动微分(Automatic Differentiation)

自动微分是我最喜欢讨论的主题之一。与我们在TensorFlow中深入介绍AD的方式类似,这里我们将通过一个简单的示例来了解它与JAX的集成有多紧密。

1
2
3
4
5
6
7
8
9
def squared(x):
return x**2

x = 4.0
y = squared(x)

dydx = jax.grad(squared)
print("First order gradients of y wrt x: ", dydx(x))
print("Second order gradients of y wrt x: ", jax.grad(dydx)(x))

结果输出为:

1
2
First order gradients of y wrt x:  8.0
Second order gradients of y wrt x: 2.0