# Seed the PRNG explicitly so the generated weights are reproducible.
key = random.PRNGKey(1234)
# Weight matrix of standard-normal float32 samples, shape (1000, 10000).
W = random.normal(key, (1000, 10000), jnp.float32)
# Never reuse a key: split off a fresh subkey and draw X from it.
key, subkey = random.split(key)
# Input matrix of standard-normal float32 samples, shape (10000, 20000).
X = random.normal(subkey, (10000, 20000), jnp.float32)
# Wrap both helper functions with `jit` (same order: dot product, then activation).
dot_product_jit, activation_jit = map(jit, (get_dot_product, apply_activation))
# Time the jitted functions over a few iterations. Per the output below,
# the first iteration is noticeably slower than the rest (it includes the
# one-off JIT compilation); later iterations reuse the compiled code.
for i in range(3):
    # perf_counter is monotonic and high-resolution; time.time() follows the
    # wall clock and can jump if the system clock is adjusted.
    start = time.perf_counter()
    # Don't forget to use `block_until_ready(..)`
    # else you will be recording dispatch time only
    Z = dot_product_jit(W, X).block_until_ready()
    end = time.perf_counter()
    print(f"Iteration: {i+1}")
    print(f"Time taken to execute dot product: {end - start:.2f} seconds", end="")

    start = time.perf_counter()
    A = activation_jit(Z).block_until_ready()
    print(f", activation function: {time.perf_counter() - start:.2f} seconds")
输出结果如下:
Iteration: 1
Time taken to execute dot product: 6.48 seconds, activation function: 0.05 seconds
Iteration: 2
Time taken to execute dot product: 3.17 seconds, activation function: 0.03 seconds
Iteration: 3
Time taken to execute dot product: 3.19 seconds, activation function: 0.03 seconds
# Build the jaxpr of the jitted activation function for input Z, then show it.
activation_jaxpr = jax.make_jaxpr(activation_jit)(Z)
print(activation_jaxpr)
输出结果如下:
{ lambda ; a. let
    b = xla_call[ backend=None
                  call_jaxpr={ lambda ; a. let b = max 0.0 a in (b,) }
                  device=None
                  donated_invars=(False,)
                  name=apply_activation ] a
  in (b,) }
如何解释jaxpr?
第一行告诉您该函数接收一个参数a。
第二行告诉您,该计算通过 xla_call 交由 XLA 执行,其内部的 jaxpr 计算 max(0.0, a),即逐元素取 0 与 a 中的较大值(也就是 ReLU)。
最后一行告诉您返回的输出。
让我们再看一下计算点积的函数的jaxpr。
# Make jaxpr for the dot product function (the previous comment wrongly said
# "activation function").
print(jax.make_jaxpr(dot_product_jit)(W, X))
输出结果如下:
{ lambda ; a b. let c = xla_call[ backend=None call_jaxpr={ lambda ; a b. let c = dot_general[ dimension_numbers=(((1,), (0,)), ((), ())) precision=None preferred_element_type=None ] a b in (c,) } device=None donated_invars=(False, False) name=get_dot_product ] a b in (c,) }
与上面类似:
第一行告诉您该函数接收两个输入变量a和b,分别对应我们的W和X。
第二行是XLA调用,我们在其中执行点积运算(dot_general)。请注意点积所使用的维度:dimension_numbers=(((1,), (0,)), ((), ())) 表示将a的第1维与b的第0维进行收缩(即普通矩阵乘法),且没有批处理维度。
最后一行是要返回的结果,用c表示。
我们再举一个有趣的例子:
# We know that `print` introduces impurity, but it is also very useful
# for printing values while debugging. How do jaxprs handle it?
def number_squared(num):
    """Print the received value, then return its square.

    When the function is traced (e.g. by ``jit`` or ``make_jaxpr``) the
    print shows a tracer object instead of a concrete number.
    """
    print("Received: ", num)
    squared = num ** 2
    return squared
# Compile the printing version of `number_squared`.
number_squared_jit = jit(number_squared)

# Trace it with an example scalar input and show the resulting jaxpr.
squared_jaxpr = jax.make_jaxpr(number_squared_jit)(2)
print(squared_jaxpr)
输出结果如下:
Received: Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/1)> { lambda ; a. let b = xla_call[ backend=None call_jaxpr={ lambda ; a. let b = integer_pow[ y=2 ] a in (b,) } device=None donated_invars=(False,) name=number_squared ] a in (b,) }
# Call the jitted function on a few scalar inputs and show each result.
for i, num in enumerate((2, 4, 8)):
    print("Iteration: ", i + 1)
    result = number_squared_jit(num)
    print("Result: ", result)
    print("=" * 50)
# An impure function (using a global state)

# Global list the impure function appends to. Without this definition the
# `global` statement below names a variable that does not exist, and the
# first call raises NameError.
squared_numbers = []


def number_squared(num):
    """Return ``num ** 2`` and append it to the global ``squared_numbers``.

    Mutating module-level state makes this function impure. Under ``jit``
    the append happens when the function is traced, so the list ends up
    holding tracer objects rather than concrete values.
    """
    global squared_numbers
    squared = num ** 2
    squared_numbers.append(squared)
    return squared


# Re-compile: `number_squared_jit` otherwise still refers to the earlier,
# printing version of `number_squared`, and the calls that follow would
# never touch `squared_numbers`.
number_squared_jit = jit(number_squared)
# Subsequent calls to the jitted function
for i, num in enumerate([4, 8, 16]):
    print("Iteration: ", i+1)
    print("Result: ", number_squared_jit(num))
    print("="*50)

# What's in the list? As the output below shows, it holds tracer objects
# rather than the concrete squares of the inputs.
print("\n Results in the global list")
# A bare `squared_numbers` expression is only displayed in a notebook;
# print it so the contents also appear when this runs as a script.
print(squared_numbers)
Results in the global list [Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/1)>, Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>]
def forward_pass(W, X):
    """Run the dot product and the activation inside one function.

    Combining both steps lets a single ``jit`` compile them together
    instead of jitting each function separately.

    Returns:
        tuple: ``(Z, A)`` with ``Z = get_dot_product(W, X)`` and
        ``A = apply_activation(Z)``.
    """
    Z = get_dot_product(W, X)
    return Z, apply_activation(Z)
# Seed the PRNG explicitly for reproducible arrays.
key = random.PRNGKey(1234)

# A bigger weight matrix this time: shape (2000, 10000), float32.
W = random.normal(key, (2000, 10000), jnp.float32)

# Never reuse a key: split off a fresh subkey for X.
key, subkey = random.split(key)
X = random.normal(subkey, (10000, 20000), jnp.float32)

# JIT each helper individually, plus the wrapper that calls both
# (same order as before: dot product, activation, forward pass).
dot_product_jit, activation_jit, forward_pass_jit = map(
    jit, (get_dot_product, apply_activation, forward_pass)
)
# Compare jitting the two steps separately against one jitted forward pass.
# perf_counter is monotonic and high-resolution, unlike time.time().
for i in range(3):
    start = time.perf_counter()
    # Don't forget to use `block_until_ready(..)`
    # else you will be recording dispatch time only
    Z = dot_product_jit(W, X).block_until_ready()
    end = time.perf_counter()
    print(f"Iteration: {i+1}")
    print(f"Time taken to execute dot product: {end - start:.2f} seconds", end="")

    start = time.perf_counter()
    A = activation_jit(Z).block_until_ready()
    print(f", activation function: {time.perf_counter() - start:.2f} seconds")

    # Now measure the time with a single jitted function that calls
    # the other two functions. Restart the timer first: `start` was still
    # the activation timer, which made the forward-pass figure also include
    # the activation call and its print.
    start = time.perf_counter()
    Z, A = forward_pass_jit(W, X)
    Z, A = Z.block_until_ready(), A.block_until_ready()
    print(f"Time taken by the forward pass function: {time.perf_counter() - start:.2f} seconds")
    print("")
    print("=" * 50)
输出结果如下:
Iteration: 1 Time taken to execute dot product: 8.83 seconds, activation function: 0.08 seconds Time taken by the forward pass function: 6.30 seconds
================================================== Iteration: 2 Time taken to execute dot product: 6.16 seconds, activation function: 0.06 seconds Time taken by the forward pass function: 6.54 seconds
================================================== Iteration: 3 Time taken to execute dot product: 6.12 seconds, activation function: 0.06 seconds Time taken by the forward pass function: 6.17 seconds