import os
import time

# Keras picks its backend from KERAS_BACKEND when the package is first
# imported, so the variable has to be set *before* `keras` / `keras_nlp`
# are imported; assigning it afterwards has no effect.
os.environ["KERAS_BACKEND"] = "jax"  # or "tensorflow" or "torch"

import keras  # noqa: E402  (must follow the KERAS_BACKEND assignment)
import keras_nlp  # noqa: E402
import tensorflow as tf  # noqa: E402
import tensorflow_datasets as tfds  # noqa: E402

# Switch the global dtype policy to mixed precision ("mixed_float16").
keras.mixed_precision.set_global_policy("mixed_float16")
# Load the training split of the "reddit_tifu" dataset as
# (document, title) pairs.
reddit_ds = tfds.load("reddit_tifu", split="train", as_supervised=True)

# Preview only the first example instead of iterating the whole dataset.
for document, title in reddit_ds.take(1):
    print(document.numpy())
    print(title.numpy())
# Select the sampler through its string identifier.
# NOTE(review): `gpt2_lm` is not defined in the visible part of this script;
# it has to be created before this point -- TODO confirm.
gpt2_lm.compile(sampler="top_k")

output = gpt2_lm.generate("I like basketball", max_length=200)
# One call, same bytes on stdout as printing the header and the text separately.
print("\nGPT-2 output:", output, sep="\n")
# Custom sampler instance: hand `compile` a `Sampler` object rather than a
# string identifier. `GreedySampler` tends to repeat itself.
greedy_sampler = keras_nlp.samplers.GreedySampler()
gpt2_lm.compile(sampler=greedy_sampler)

output = gpt2_lm.generate("I like basketball", max_length=200)
# One call, same bytes on stdout as printing the header and the text separately.
print("\nGPT-2 output:", output, sep="\n")
# GPT-2 output:
# I like basketball, and this is a pretty good one.
# so i was playing basketball at my local high school, and i was playing with my friends.