import os import cv2 import glob import time import urllib import requests import numpy as np import matplotlib.pyplot as plt import jax import jax.numpy as jnp from jax import random from jax import make_jaxpr from jax.config import config from jax import grad, vmap, pmap, jit
# What if we want to do this for a batch of vectors?
# Each batch stacks five copies of the same 4-element vector, giving shape (5, 4).
array1 = jnp.stack([jnp.array([1, 2, 3, 4]) for _ in range(5)])
array2 = jnp.stack([jnp.array([5, 6, 7, 8]) for _ in range(5)])
# First way to do batch vector product using loops:
# walk over matching rows of the two batches, apply `dot_product` to each
# pair, then stack the per-row results into a single array.
res1 = jnp.stack([dot_product(vec1, vec2) for vec1, vec2 in zip(array1, array2)])
# In numpy, we can use `einsum` for the same:
# multiply matching entries and sum over the second axis `j`.
res2 = np.einsum('ij,ij-> i', array1, array2)

# We can even simplify einsum and chain two operations to
# achieve the same
res3 = np.sum(array1 * array2, axis=1)
# Let's check the results of the three approaches, separated by a ruler line
print_results(array1, array2, res1, title="1. Dot product on a batch of vectors using loop")
print("="*70, "\n")
print_results(array1, array2, res2, title="2. Dot product on a batch of vectors in numpy using einsum")
print("="*70, "\n")
print_results(array1, array2, res3, title="3. Dot product on a batch of vectors using elementwise multiplication and sum")
# Using vmap transformed function
# NOTE(review): `batch_dot_product` is defined outside this excerpt; presumably
# it is `vmap(dot_product)` - confirm against its definition.
res4 = batch_dot_product(array1, array2)
print_results(array1, array2, res4, title="Dot product of a batch of vectors using vmap")
# The comment below and the recorded output for this example both describe the
# first array as a single vector of shape (4,), so rebind `array1` to one
# vector instead of reusing the (5, 4) batch built earlier.
array1 = jnp.array([1, 2, 3, 4])

# We have a batch of vectors as well already `array2` which looks like this
# [[5 6 7 8]
#  [5 6 7 8]
#  [5 6 7 8]
#  [5 6 7 8]
#  [5 6 7 8]]

# We will now perform the dot product of array1 (a single vector) with a batch
# of vectors (array2 in this case). We will pass `None` in the `in_axes(..)` argument
# to say that the first input doesn't have a batch dimension
res5 = vmap(dot_product, in_axes=(None, 0))(array1, array2)
print_results(array1, array2, res5, title="Only one of the inputs in batched")
结果输出为:
1 2 3 4 5 6 7 8 9 10 11 12 13
Only one of the inputs in batched First array => Shape: (4,) [1 2 3 4]
def random_rotate(img, rotate):
    """Randomly rotate an image by 90 degrees.

    Args:
        img: Array representing the image
        rotate: Boolean for rotating or not
    Returns:
        Either Rotated or an identity image
    """
    # `lax.cond` calls `rotate_img(img)` when `rotate` is true and
    # `identity(img)` otherwise.
    return jax.lax.cond(rotate, rotate_img, identity, img)
# Run the pipeline on a single image
# Get an image
img = images[0]
img_copy = img.copy()

# Pass the image copy to augmentation pipeline (1 => rotate)
augmented = random_rotate(img_copy, 1)

# Plot the original image and the augmented image
# NOTE(review): the axes are created here but nothing is drawn on them in this
# excerpt - the plotting calls are presumably missing; confirm against the source.
_, ax = plt.subplots(1, 2, figsize=(12, 8))
# Using the same original image
img_copy = img.copy()

# Batch size of the output as well as for the boolean array
# used to tell whether to rotate an input image or not
batch_size = 8

# We use seed for anything that involves `random`
key = random.PRNGKey(1234)

# Although splitting is not necessary as the key is only used once,
# I will just leave the original key as it is
# NOTE: `key` is rebound to the first half of the split and that is what
# `randint` consumes below; `subkey` is left unused.
key, subkey = random.split(key)
rotate = random.randint(key, shape=[batch_size], minval=0, maxval=2)

# Return identical or flipped image via augmentation pipeline
# We will transform the original `random_rotate(..)` function
# using vmap: the image is shared (None) while `rotate` is mapped over axis 0
augmented = vmap(random_rotate, in_axes=(None, 0))(img_copy, rotate)

print("Number of images to generate: ", batch_size)
print("Rotate-or-not array: ", rotate)
plot_images(
    augmented,
    batch_size=batch_size,
    title="Multiple augmenetd images from a single input image",
)

# Recorded output:
# Number of images to generate: 8
# Rotate-or-not array: [1 1 0 1 0 1 0 0]
# Original images
plot_images(images, batch_size=8, title="Original images")
我们现在对这批输入图片进行增强。请仔细查看 `in_axes` 参数的取值。
1 2 3
# Augment a batch of input images using the same augmentation pipeline.
# Both inputs are now mapped over axis 0: one rotate flag per image.
augmented = vmap(random_rotate, in_axes=(0, 0))(images, rotate)
plot_images(augmented, batch_size=8, title="Augmented Images")
def rotate_90(img):
    """Rotates an image by 90 degrees once, in the plane of its first two axes."""
    return jnp.rot90(img, k=1, axes=(0, 1))
def identity(img):
    """Returns an image as it is."""
    return img
def flip_left_right(img):
    """Flips an image in the left/right direction (reverses its second axis)."""
    return jnp.fliplr(img)
def flip_up_down(img):
    """Flips an image in the up/down direction (reverses its first axis)."""
    return jnp.flipud(img)
def random_rotate(img, rotate):
    """Randomly rotate an image by 90 degrees.

    Args:
        img: Array representing the image
        rotate: Boolean for rotating or not
    Returns:
        Rotated or an identity image
    """
    # `lax.cond` calls `rotate_90(img)` when `rotate` is true and
    # `identity(img)` otherwise. Without this return the function would
    # silently yield None.
    return jax.lax.cond(rotate, rotate_90, identity, img)
def random_horizontal_flip(img, flip):
    """Randomly flip an image horizontally.

    Args:
        img: Array representing the image
        flip: Boolean for flipping or not
    Returns:
        Flipped or an identity image
    """
    # `flip_left_right(img)` when `flip` is true, `identity(img)` otherwise.
    return jax.lax.cond(flip, flip_left_right, identity, img)


def random_vertical_flip(img, flip):
    """Randomly flip an image vertically.

    Args:
        img: Array representing the image
        flip: Boolean for flipping or not
    Returns:
        Flipped or an identity image
    """
    # `flip_up_down(img)` when `flip` is true, `identity(img)` otherwise.
    return jax.lax.cond(flip, flip_up_down, identity, img)
# Get the jitted version of our augmentation functions.
# Each one is first vmapped over axis 0 of both inputs (one flag per image)
# and the batched function is then wrapped in `jit`.
random_rotate_jitted = jit(vmap(random_rotate, in_axes=(0, 0)))
random_horizontal_flip_jitted = jit(vmap(random_horizontal_flip, in_axes=(0, 0)))
random_vertical_flip_jitted = jit(vmap(random_vertical_flip, in_axes=(0, 0)))
def augment_images(images, key):
    """Augment a batch of input images.

    Applies, in order, a random 90-degree rotation, a random horizontal flip
    and a random vertical flip. Each step draws one 0/1 flag per image.

    Args:
        images: Batch of input images as a jax array
        key: Seed/Key for random functions for generating booleans
    Returns:
        Augmented images with the same shape as the input images
    """
    batch_size = len(images)

    # Every draw consumes a fresh `subkey`; `key` itself is only ever split.
    # Sampling from `key` and then splitting it again would reuse the key.

    # 1. Rotation
    key, subkey = random.split(key)
    rotate = random.randint(subkey, shape=[batch_size], minval=0, maxval=2)
    augmented = random_rotate_jitted(images, rotate)

    # 2. Flip horizontally
    key, subkey = random.split(key)
    flip = random.randint(subkey, shape=[batch_size], minval=0, maxval=2)
    augmented = random_horizontal_flip_jitted(augmented, flip)

    # 3. Flip vertically
    key, subkey = random.split(key)
    flip = random.randint(subkey, shape=[batch_size], minval=0, maxval=2)
    augmented = random_vertical_flip_jitted(augmented, flip)

    # Wait for the result to be computed before returning it.
    return augmented.block_until_ready()
# Because we are jitting the transformations, we will record the
# time taken for augmentation on subsequent calls
for i in range(3):
    print("Call: ", i + 1, end=" => ")
    key = random.PRNGKey(0)
    start_time = time.time()
    augmented = augment_images(images, key)
    print(f"Time taken to generate augmentations: {time.time()-start_time:.2f}")

# Plot the augmented images (result of the last call)
plot_images(augmented, batch_size=8, title="Augmenetd Images")
结果输出为:
1 2 3
Call: 1 => Time taken to generate augmentations: 1.58
Call: 2 => Time taken to generate augmentations: 0.02
Call: 3 => Time taken to generate augmentations: 0.02
# Augment images function without `jit`
# as jitting is not required while using pmap

# Get the vmapped version of our augmentation functions
# (both inputs mapped over axis 0: one flag per image)
random_rotate_vmapped = vmap(random_rotate, in_axes=(0, 0))
random_horizontal_flip_vmapped = vmap(random_horizontal_flip, in_axes=(0, 0))
random_vertical_flip_vmapped = vmap(random_vertical_flip, in_axes=(0, 0))
def augment_images(images, key):
    """Augment a batch of input images.

    Applies, in order, a random 90-degree rotation, a random horizontal flip
    and a random vertical flip. Each step draws one 0/1 flag per image.

    Args:
        images: Batch of input images as a jax array
        key: Seed/Key for random functions for generating booleans
    Returns:
        Augmented images with the same shape as the input images
    """
    batch_size = len(images)

    # Every draw consumes a fresh `subkey`; `key` itself is only ever split.
    # Sampling from `key` and then splitting it again would reuse the key.

    # 1. Rotation
    key, subkey = random.split(key)
    rotate = random.randint(subkey, shape=[batch_size], minval=0, maxval=2)
    augmented = random_rotate_vmapped(images, rotate)

    # 2. Flip horizontally
    key, subkey = random.split(key)
    flip = random.randint(subkey, shape=[batch_size], minval=0, maxval=2)
    augmented = random_horizontal_flip_vmapped(augmented, flip)

    # 3. Flip vertically
    key, subkey = random.split(key)
    flip = random.randint(subkey, shape=[batch_size], minval=0, maxval=2)
    augmented = random_vertical_flip_vmapped(augmented, flip)

    return augmented
# Generate a big batch of 64
# `images` is stacked 8 times along a new leading axis, so the printed
# `shape[0]` is the number of stacked copies.
big_batch_images = jnp.stack([images for _ in range(8)])
print("Number of images in batch: ", big_batch_images.shape[0])
# Generate a batch of keys as well as the augment_images
# function accepts a key as well
key = random.PRNGKey(123)

# Derive all 8 keys from one split. Chaining `key, subkey = random.split(key)`
# and collecting the successive `key`s puts a parent key and its own first
# child in the same batch, so `augment_images` (which splits its key the same
# way) would reproduce one entry's stream inside another's.
big_batch_keys = random.split(key, 8)
# We will run it more than once
for i in range(3):
    print("Call: ", i + 1, end=" => ")
    start_time = time.time()
    # JAX dispatches work asynchronously, so wait for the result before
    # stopping the timer; otherwise only the dispatch would be measured.
    augmented_parallel = pmapped_augment_images(
        big_batch_images, big_batch_keys
    ).block_until_ready()
    print(f"Time taken to generate augmentations: {time.time()-start_time:.2f}")

# Plot the augmented images from the last call, flattened into a single
# batch of 64 for plotting
augmented_parallel = augmented_parallel.reshape(64, 800, 800, 3)
plot_images(
    augmented_parallel,
    batch_size=64,
    title="Augmentation on parallel devices",
    figsize=(20, 25),
)
结果输出为:
1 2 3
Call: 1 => Time taken to generate augmentations: 4.02
Call: 2 => Time taken to generate augmentations: 0.06
Call: 3 => Time taken to generate augmentations: 0.05