import numpy as np

import jax
import jax.numpy as jnp
from jax import grad
from jax import jit
from jax import lax
from jax import random
# A global variable read by `add_global_value`
counter = 5
def add_global_value(x):
    """Return ``x + counter``, where ``counter`` is a module-level global.

    This function is deliberately impure: its result depends on the global
    variable ``counter`` and not only on its argument ``x``. When it is
    wrapped in ``jax.jit``, ``counter`` is read while the function is traced,
    so later changes to the global are not seen by an already-compiled
    version.
    """
    return x + counter
x = 2

# We will `JIT` the function so that it runs as a JAX transformed
# function and not like a normal Python function.
# `counter` is read only while JAX traces the function, so its current
# value (5) is baked into the compiled version.
y = jit(add_global_value)(x)
print("Global variable value: ", counter)
print(f"First call to the function with input {x} with global variable value {counter} returned {y}")

# Someone updated the global variable value later in the code
counter = 10

# Call the function again. The argument has the same type and shape, so JAX
# reuses the cached trace: the result is still computed with counter == 5.
y = jit(add_global_value)(x)
print("\nGlobal variable changed value: ", counter)
print(f"Second call to the function with input {x} with global variable value {counter} returned {y}")
结果输出为:

Global variable value:  5
First call to the function with input 2 with global variable value 5 returned 7

Global variable changed value:  10
Second call to the function with input 2 with global variable value 10 returned 7

第二次调用仍然返回 7:jit 在第一次追踪函数时就把 counter 的值 5 固化进了编译结果,之后修改全局变量并不会生效。
# Change the type of the argument passed to the function.
# In this case we will change int to float (2 -> 2.0). The new dtype forces
# JAX to re-trace the function, so the current `counter` (10) is picked up.
x = 2.0
y = jit(add_global_value)(x)
print(f"Third call to the function with input {x} with global variable value {counter} returned {y}")

# Change the shape of the argument (scalar -> 1-element array); a new shape
# also forces a re-trace.
x = jnp.array([2])

# Changing global variable value again
counter = 15

# Call the function again; because of the re-trace it sees counter == 15.
# NOTE(review): this is actually the fourth call, but the "Third call" label
# is kept so the message matches the recorded output.
y = jit(add_global_value)(x)
print(f"Third call to the function with input {x} with global variable value {counter} returned {y}")
结果输出为:

Third call to the function with input 2.0 with global variable value 10 returned 12.0
Third call to the function with input [2] with global variable value 15 returned [17]

可以看到,当参数的类型(int -> float)或形状(标量 -> 数组)发生变化时,jit 会重新追踪函数,此时读取到的是全局变量的最新值。
# NOTE(review): `apply_sin_to_global` is not defined in this excerpt; it is
# presumably a jitted, argument-free function reading the global `counter`,
# defined elsewhere in the original article -- confirm.
y = apply_sin_to_global()
print("Global variable value: ", counter)
print(f"First call to the function with global variable value {counter} returned {y}")

# Change the global value again
counter = 90
y = apply_sin_to_global()
print("\nGlobal variable value: ", counter)
print(f"Second call to the function with global variable value {counter} returned {y}")
结果输出为:

Global variable value:  15
First call to the function with global variable value 15 returned 0.20791170001029968

Global variable value:  90
Second call to the function with global variable value 90 returned 0.03489949554204941
# A function that takes an actual array object
# and adds all the elements present in it
def add_elements(array, start, end, initial_value=0):
    """Return ``initial_value + array[start] + ... + array[end - 1]``.

    The sum is computed with ``lax.fori_loop``. The loop body only indexes
    ``array`` and does not touch any external state, so the function is pure.
    """
    def loop_fn(i, val):
        return val + array[i]

    return lax.fori_loop(start, end, loop_fn, initial_value)
# Define an array object
array = jnp.arange(5)
print("Array: ", array)
print("Adding all the array elements gives: ", add_elements(array, 0, len(array), 0))
# Redefining the same function but this time it takes an
# iterator object as an input
def add_elements(iterator, start, end, initial_value=0):
    """Try to add ``end - start`` values pulled from ``iterator`` to ``initial_value``.

    This version is deliberately impure. ``next(iterator)`` mutates state that
    lives outside the function, and it executes only while JAX traces
    ``loop_fn``, not once per loop iteration. The value it returns during
    tracing is baked into the loop body as a constant, so the iterator is not
    actually walked element by element.
    """
    def loop_fn(i, val):
        return val + next(iterator)

    return lax.fori_loop(start, end, loop_fn, initial_value)


# Create an iterator object
iterator = iter(np.arange(5))
print("\n\nIterator: ", iterator)
print("Adding all the elements gives: ", add_elements(iterator, 0, 5, 0))
结果输出为:

Array:  [0 1 2 3 4]
Adding all the array elements gives:  10


Iterator:  <iterator object at 0x7ff9e82205d0>
Adding all the elements gives:  0
为什么第二种情况的结果为 0?这是因为迭代器引入了外部状态:`next(iterator)` 只在 JAX 追踪(trace)循环体时执行,取到的第一个元素 0 被当作常量编入编译后的循环中。之后每次迭代实际加上的都是这个 0,因此结果为 0。
案例3:IO
让我们再看一个不太常见的例子:即使函数既不读取全局变量,也不使用迭代器,其中的 I/O 操作(例如 `print`)同样会使函数变得不纯粹(impure)。
def return_as_it_is(x):
    """Return ``x`` unchanged, printing a message as a side effect.

    The function uses neither globals nor iterators, but the ``print`` call
    is I/O and therefore makes it impure. Under ``jit`` the message is
    emitted while the function is traced, not on every call.
    """
    print("I have received the value")
    return x
# First call to the function: JAX traces it, so the `print` inside runs.
print(f"Value returned on first call: {jit(return_as_it_is)(2)}\n")

# Second call to the function with a different value of the same type and
# shape: JAX can reuse the cached trace, so the `print` inside need not run.
print(f"Value returned on second call: {jit(return_as_it_is)(4)}")
结果输出为:

I have received the value
Value returned on first call: 2
# Function that uses stateful objects internally but is still pure
def pure_function_with_stateful_obejcts(array):
    """Return a dict mapping each index ``i`` of ``array`` to ``array[i] + 10``.

    The mutable dict is created and filled entirely inside the function, so
    no state leaks in or out; the function remains pure.
    """
    array_dict = {}
    for i in range(len(array)):
        array_dict[i] = array[i] + 10
    return array_dict
array = jnp.arange(5)

# First call to the function
print(f"Value returned on first call: {jit(pure_function_with_stateful_obejcts)(array)}")

# Second call to the function with the same value; since the function is
# pure, it returns the same result.
print(f"\nValue returned on second call: {jit(pure_function_with_stateful_obejcts)(array)}")
结果输出为:

Value returned on first call: {0: DeviceArray(10, dtype=int32), 1: DeviceArray(11, dtype=int32), 2: DeviceArray(12, dtype=int32), 3: DeviceArray(13, dtype=int32), 4: DeviceArray(14, dtype=int32)}

Value returned on second call: {0: DeviceArray(10, dtype=int32), 1: DeviceArray(11, dtype=int32), 2: DeviceArray(12, dtype=int32), 3: DeviceArray(13, dtype=int32), 4: DeviceArray(14, dtype=int32)}

两次调用的结果完全相同:只要可变的字典只在函数内部创建和使用,函数就仍然是纯函数。