import time

import jax
import jax.numpy as jnp
import numpy as np
from jax import random
# Create the same int32 array twice -- once with NumPy, once with JAX -- so
# the two libraries can be compared side by side below.
array_numpy = np.arange(10, dtype=np.int32)
array_jax = jnp.arange(10, dtype=jnp.int32)
# Show the NumPy array and the JAX array one after the other.
print("Array created using numpy: ", array_numpy)
print("Array created using JAX: ", array_jax)
# Which Python class backs each array? Inspect type() of both.
print(f"array_numpy is of type : {type(array_numpy)}")
print(f"array_jax is of type : {type(array_jax)}")
结果输出为:
1 2 3 4 5
Array created using numpy: [0 1 2 3 4 5 6 7 8 9]
Array created using JAX: [0 1 2 3 4 5 6 7 8 9]
array_numpy is of type : <class 'numpy.ndarray'>
array_jax is of type : <class 'jaxlib.xla_extension.DeviceArray'>
# Reductions use the same method names on both array types: `.max()` here,
# and `.min()` works the same way.
print(f"Maximum element in ndarray: {array_numpy.max()}")
print(f"Maximum element in DeviceArray: {array_jax.max()}")
# Reshaping: turn both 1-D arrays into (n, 1) column vectors. On a 1-D array
# `.T` is a no-op, so without this reshape the pairwise difference
# `x - x.T` that follows would be all zeros, and `dot(x, x.T)` would be a
# scalar inner product rather than an (n, n) matrix.
print("Original shape of ndarray: ", array_numpy.shape)
print("Original shape of DeviceArray: ", array_jax.shape)

array_numpy = array_numpy.reshape(-1, 1)
array_jax = array_jax.reshape(-1, 1)

print("\nNew shape of ndarray: ", array_numpy.shape)
print("New shape of DeviceArray: ", array_jax.shape)
# Absolute pairwise difference |x_i - x_j|, computed with identical code on
# both backends. For a column vector of shape (n, 1), `x - x.T` broadcasts
# to an (n, n) matrix; for a 1-D array `.T` is a no-op and every entry is 0.
print("Absolute pairwise difference in ndarray")
print(np.abs(array_numpy - array_numpy.T))

print("\nAbsolute pairwise difference in DeviceArray")
print(jnp.abs(array_jax - array_jax.T))
# Are they equal? Compare the NumPy and JAX pairwise-difference results
# element-wise and reduce with jnp.all. (jnp.alltrue was a deprecated alias
# of jnp.all and has been removed from recent JAX releases.)
print("\nAre all the values same?", end=" ")
print(jnp.all(np.abs(array_numpy - array_numpy.T) == jnp.abs(array_jax - array_jax.T)))
# Matrix multiplication: np.dot and jnp.dot are called identically.
print("Matrix multiplication of ndarray")
print(np.dot(array_numpy, array_numpy.T))

print("\nMatrix multiplication of DeviceArray")
print(jnp.dot(array_jax, array_jax.T))
Trying to modify DeviceArray-> TypeError '<class 'jaxlib.xla_extension.DeviceArray'>' object does not support item assignment. JAX arrays are immutable; perhaps you want jax.ops.index_update or jax.ops.index_add instead?
# JAX arrays are immutable, so element updates go through the `.at[...]`
# property, which returns a new array each time. Besides `.set`, there are
# accumulating variants: add to, max with, or min with the element at index 4.
# NOTE(review): array2 must already be bound to a JAX array with at least
# 5 elements here; its definition is not visible in this part of the text.
print(array2.at[4].add(6))
print(array2.at[4].max(20))
print(array2.at[4].min(-1))
# Equivalent but deprecated: the functional jax.ops.index_* helpers that
# predate `.at[...]` (shown only for their similarity to TF's
# scatter_nd_update). They were removed in JAX 0.3.2, so calling them
# unconditionally raises AttributeError on current JAX; only use them when
# the installed version still provides them.
print("\nEquivalent but deprecated")
if hasattr(jax.ops, "index_add"):
    print(jax.ops.index_add(array2, 4, 6))
    print(jax.ops.index_max(array2, 4, 20))
    print(jax.ops.index_min(array2, 4, -1))
else:
    print(
        "jax.ops.index_add/index_max/index_min are not available in this "
        "JAX version; use array2.at[4].add/.max/.min instead."
    )
# Two large float32 matrices of uniform samples: one from NumPy's global RNG,
# one from JAX's explicit-key RNG (PRNGKey is explained later).
array1 = np.random.uniform(size=(8000, 8000)).astype(np.float32)
array2 = jax.random.uniform(
    jax.random.PRNGKey(0),
    (8000, 8000),
    dtype=jnp.float32,
)

print("Shape of ndarray: ", array1.shape)
print("Shape of DeviceArray: ", array2.shape)

# Shape of ndarray: (8000, 8000)
# Shape of DeviceArray: (8000, 8000)
现在，让我们对每个数组进行一些计算，看看会发生什么以及每个计算需要多少时间。
1 2 3 4 5 6 7 8 9 10 11 12
# Dot product on ndarray, timed with time.perf_counter(): a monotonic,
# high-resolution clock meant for measuring elapsed intervals. time.time()
# is wall-clock time and can jump or have coarse resolution.
start_time = time.perf_counter()
res = np.dot(array1, array1)
print(f"Time taken by dot product op on ndarrays: {time.perf_counter() - start_time:.2f} seconds")
# Dot product on DeviceArray.
# JAX dispatches work asynchronously: jnp.dot returns as soon as the
# computation is enqueued, before the result exists. block_until_ready()
# waits for the result, so the timing covers the matrix multiplication
# itself rather than just the dispatch. (The first call may also include
# one-time compilation; run it once beforehand for a warm measurement.)
start_time = time.perf_counter()
res = jnp.dot(array2, array2).block_until_ready()
print(f"Time taken by dot product op on DeviceArrays: {time.perf_counter() - start_time:.2f} seconds")

# Time taken by dot product op on ndarrays: 7.95 seconds
# Time taken by dot product op on DeviceArrays: 0.02 seconds
# NOTE: the DeviceArray figure above was recorded without block_until_ready(),
# so it reflects dispatch time only and understates the real cost.
# Implicit type promotion in JAX: add a Python float to an int32 JAX array
# and inspect the result dtype. JAX's promotion rules differ from NumPy's.
# With JAX's default configuration (64-bit mode disabled) the result is
# float32.
array2 = jax.random.randint(
    jax.random.PRNGKey(0),
    minval=0,
    maxval=5,
    shape=[2],
    dtype=jnp.int32,
)
print("Implicit JAX casting gives: ", (array2 + 5.0).dtype)