diff --git a/docs/tutorials/text_generation.ipynb b/docs/tutorials/text_generation.ipynb
index a427038df..d00576e24 100644
--- a/docs/tutorials/text_generation.ipynb
+++ b/docs/tutorials/text_generation.ipynb
@@ -11,7 +11,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 1,
"metadata": {
"cellView": "form",
"id": "GCCk8_dHpuNf"
@@ -46,20 +46,20 @@
"id": "hcD2nPQvPOFM"
},
"source": [
- "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
- " \u003ctd\u003e\n",
- " \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/text/tutorials/text_generation\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
- " \u003c/td\u003e\n",
- " \u003ctd\u003e\n",
- " \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/text/blob/master/docs/tutorials/text_generation.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
- " \u003c/td\u003e\n",
- " \u003ctd\u003e\n",
- " \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/text/blob/master/docs/tutorials/text_generation.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
- " \u003c/td\u003e\n",
- " \u003ctd\u003e\n",
- " \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/text/docs/tutorials/text_generation.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
- " \u003c/td\u003e\n",
- "\u003c/table\u003e"
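+ "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
+ " <td>\n",
+ " <a target=\"_blank\" href=\"https://www.tensorflow.org/text/tutorials/text_generation\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
+ " </td>\n",
+ " <td>\n",
+ " <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/text/blob/master/docs/tutorials/text_generation.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
+ " </td>\n",
+ " <td>\n",
+ " <a target=\"_blank\" href=\"https://github.com/tensorflow/text/blob/master/docs/tutorials/text_generation.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
+ " </td>\n",
+ " <td>\n",
+ " <a href=\"https://storage.googleapis.com/tensorflow_docs/text/docs/tutorials/text_generation.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
+ " </td>\n",
+ "</table>"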
]
},
{
@@ -70,7 +70,7 @@
"source": [
"This tutorial demonstrates how to generate text using a character-based RNN. You will work with a dataset of Shakespeare's writing from Andrej Karpathy's [The Unreasonable Effectiveness of Recurrent Neural Networks](http://karpathy.github.io/2015/05/21/rnn-effectiveness/). Given a sequence of characters from this data (\"Shakespear\"), train a model to predict the next character in the sequence (\"e\"). Longer sequences of text can be generated by calling the model repeatedly.\n",
"\n",
- "Note: Enable GPU acceleration to execute this notebook faster. In Colab: *Runtime \u003e Change runtime type \u003e Hardware accelerator \u003e GPU*.\n",
+ "Note: Enable GPU acceleration to execute this notebook faster. In Colab: *Runtime > Change runtime type > Hardware accelerator > GPU*.\n",
"\n",
"This tutorial includes runnable code implemented using [tf.keras](https://www.tensorflow.org/guide/keras/sequential_model) and [eager execution](https://www.tensorflow.org/guide/eager). The following is the sample output when the model in this tutorial trained for 30 epochs, and started with the prompt \"Q\":"
]
@@ -81,7 +81,7 @@
"id": "HcygKkEVZBaa"
},
"source": [
- "\u003cpre\u003e\n",
+ "<pre>\n",
"QUEENE:\n",
"I had thought thou hadst a Roman; for the oracle,\n",
"Thus by All bids the man against the word,\n",
@@ -112,7 +112,7 @@
"His lordship pluck'd from this sentence then for prey,\n",
"And then let us twain, being the moon,\n",
"were she such a case as fills m\n",
- "\u003c/pre\u003e"
+ "</pre>"
]
},
{
@@ -150,7 +150,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 2,
"metadata": {
"id": "yG_n40gFzf9s"
},
@@ -176,7 +176,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 3,
"metadata": {
"id": "pD_55cOxLkAb"
},
@@ -198,11 +198,19 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 4,
"metadata": {
"id": "aavnuByVymwK"
},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Length of text: 1115394 characters\n"
+ ]
+ }
+ ],
"source": [
"# Read, then decode for py2 compat.\n",
"text = open(path_to_file, 'rb').read().decode(encoding='utf-8')\n",
@@ -212,11 +220,33 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 5,
"metadata": {
"id": "Duhg9NrUymwO"
},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "First Citizen:\n",
+ "Before we proceed any further, hear me speak.\n",
+ "\n",
+ "All:\n",
+ "Speak, speak.\n",
+ "\n",
+ "First Citizen:\n",
+ "You are all resolved rather to die than to famish?\n",
+ "\n",
+ "All:\n",
+ "Resolved. resolved.\n",
+ "\n",
+ "First Citizen:\n",
+ "First, you know Caius Marcius is chief enemy to the people.\n",
+ "\n"
+ ]
+ }
+ ],
"source": [
"# Take a look at the first 250 characters in text\n",
"print(text[:250])"
@@ -224,11 +254,19 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 6,
"metadata": {
"id": "IlCgQBRVymwR"
},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "65 unique characters\n"
+ ]
+ }
+ ],
"source": [
"# The unique characters in the file\n",
"vocab = sorted(set(text))\n",
@@ -259,11 +297,22 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 7,
"metadata": {
"id": "a86OoYtO01go"
},
- "outputs": [],
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "<tf.RaggedTensor [[b'a', b'b', b'c', b'd', b'e', b'f', b'g'], [b'x', b'y', b'z']]>"
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
"source": [
"example_texts = ['abcdefg', 'xyz']\n",
"\n",
@@ -282,7 +331,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 8,
"metadata": {
"id": "6GMlCe3qzaL9"
},
@@ -303,11 +352,22 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 9,
"metadata": {
"id": "WLv5Q_2TC2pc"
},
- "outputs": [],
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "<tf.RaggedTensor [[40, 41, 42, 43, 44, 45, 46], [63, 64, 65]]>"
+ ]
+ },
+ "execution_count": 9,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
"source": [
"ids = ids_from_chars(chars)\n",
"ids"
@@ -333,7 +393,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 10,
"metadata": {
"id": "Wd2m3mqkDjRj"
},
@@ -354,11 +414,22 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 11,
"metadata": {
"id": "c2GCh0ySD44s"
},
- "outputs": [],
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "<tf.RaggedTensor [[b'a', b'b', b'c', b'd', b'e', b'f', b'g'], [b'x', b'y', b'z']]>"
+ ]
+ },
+ "execution_count": 11,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
"source": [
"chars = chars_from_ids(ids)\n",
"chars"
@@ -375,18 +446,29 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 12,
"metadata": {
"id": "zxYI-PeltqKP"
},
- "outputs": [],
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "array([b'abcdefg', b'xyz'], dtype=object)"
+ ]
+ },
+ "execution_count": 12,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
"source": [
"tf.strings.reduce_join(chars, axis=-1).numpy()"
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 13,
"metadata": {
"id": "w5apvBDn9Ind"
},
@@ -435,11 +517,22 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 14,
"metadata": {
"id": "UopbsKi88tm5"
},
- "outputs": [],
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "<tf.Tensor: shape=(1115394,), dtype=int64, numpy=array([19, 48, 57, ..., 46,  9,  1])>"
+ ]
+ },
+ "execution_count": 14,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
"source": [
"all_ids = ids_from_chars(tf.strings.unicode_split(text, 'UTF-8'))\n",
"all_ids"
@@ -447,7 +540,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 15,
"metadata": {
"id": "qmxrYDCTy-eL"
},
@@ -458,11 +551,28 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 16,
"metadata": {
"id": "cjH5v45-yqqH"
},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "F\n",
+ "i\n",
+ "r\n",
+ "s\n",
+ "t\n",
+ " \n",
+ "C\n",
+ "i\n",
+ "t\n",
+ "i\n"
+ ]
+ }
+ ],
"source": [
"for ids in ids_dataset.take(10):\n",
" print(chars_from_ids(ids).numpy().decode('utf-8'))"
@@ -470,7 +580,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 17,
"metadata": {
"id": "C-G2oaTxy6km"
},
@@ -490,11 +600,27 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 18,
"metadata": {
"id": "BpdjRO2CzOfZ"
},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "tf.Tensor(\n",
+ "[b'F' b'i' b'r' b's' b't' b' ' b'C' b'i' b't' b'i' b'z' b'e' b'n' b':'\n",
+ " b'\\n' b'B' b'e' b'f' b'o' b'r' b'e' b' ' b'w' b'e' b' ' b'p' b'r' b'o'\n",
+ " b'c' b'e' b'e' b'd' b' ' b'a' b'n' b'y' b' ' b'f' b'u' b'r' b't' b'h'\n",
+ " b'e' b'r' b',' b' ' b'h' b'e' b'a' b'r' b' ' b'm' b'e' b' ' b's' b'p'\n",
+ " b'e' b'a' b'k' b'.' b'\\n' b'\\n' b'A' b'l' b'l' b':' b'\\n' b'S' b'p' b'e'\n",
+ " b'a' b'k' b',' b' ' b's' b'p' b'e' b'a' b'k' b'.' b'\\n' b'\\n' b'F' b'i'\n",
+ " b'r' b's' b't' b' ' b'C' b'i' b't' b'i' b'z' b'e' b'n' b':' b'\\n' b'Y'\n",
+ " b'o' b'u' b' '], shape=(101,), dtype=string)\n"
+ ]
+ }
+ ],
"source": [
"sequences = ids_dataset.batch(seq_length+1, drop_remainder=True)\n",
"\n",
@@ -513,11 +639,23 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 19,
"metadata": {
"id": "QO32cMWu4a06"
},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "b'First Citizen:\\nBefore we proceed any further, hear me speak.\\n\\nAll:\\nSpeak, speak.\\n\\nFirst Citizen:\\nYou '\n",
+ "b'are all resolved rather to die than to famish?\\n\\nAll:\\nResolved. resolved.\\n\\nFirst Citizen:\\nFirst, you k'\n",
+ "b\"now Caius Marcius is chief enemy to the people.\\n\\nAll:\\nWe know't, we know't.\\n\\nFirst Citizen:\\nLet us ki\"\n",
+ "b\"ll him, and we'll have corn at our own price.\\nIs't a verdict?\\n\\nAll:\\nNo more talking on't; let it be d\"\n",
+ "b'one: away, away!\\n\\nSecond Citizen:\\nOne word, good citizens.\\n\\nFirst Citizen:\\nWe are accounted poor citi'\n"
+ ]
+ }
+ ],
"source": [
"for seq in sequences.take(5):\n",
" print(text_from_ids(seq).numpy())"
@@ -537,7 +675,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 20,
"metadata": {
"id": "9NGu-FkO_kYU"
},
@@ -551,18 +689,30 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 21,
"metadata": {
"id": "WxbDTJTw5u_P"
},
- "outputs": [],
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(['T', 'e', 'n', 's', 'o', 'r', 'f', 'l', 'o'],\n",
+ " ['e', 'n', 's', 'o', 'r', 'f', 'l', 'o', 'w'])"
+ ]
+ },
+ "execution_count": 21,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
"source": [
"split_input_target(list(\"Tensorflow\"))"
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 22,
"metadata": {
"id": "B9iKPXkw5xwa"
},
@@ -573,11 +723,20 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 23,
"metadata": {
"id": "GNbw-iR0ymwj"
},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Input : b'First Citizen:\\nBefore we proceed any further, hear me speak.\\n\\nAll:\\nSpeak, speak.\\n\\nFirst Citizen:\\nYou'\n",
+ "Target: b'irst Citizen:\\nBefore we proceed any further, hear me speak.\\n\\nAll:\\nSpeak, speak.\\n\\nFirst Citizen:\\nYou '\n"
+ ]
+ }
+ ],
"source": [
"for input_example, target_example in dataset.take(1):\n",
" print(\"Input :\", text_from_ids(input_example).numpy())\n",
@@ -597,11 +756,22 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 24,
"metadata": {
"id": "p2pGotuNzf-S"
},
- "outputs": [],
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "<_PrefetchDataset element_spec=(TensorSpec(shape=(64, 100), dtype=tf.int64, name=None), TensorSpec(shape=(64, 100), dtype=tf.int64, name=None))>"
+ ]
+ },
+ "execution_count": 24,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
"source": [
"# Batch size\n",
"BATCH_SIZE = 64\n",
@@ -647,7 +817,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 25,
"metadata": {
"id": "zHT8cLh7EAsg"
},
@@ -673,7 +843,7 @@
"source": [
"class MyModel(tf.keras.Model):\n",
" def __init__(self, vocab_size, embedding_dim, rnn_units):\n",
- " super().__init__(self)\n",
+ " super().__init__()\n",
" self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)\n",
" self.gru = tf.keras.layers.GRU(rnn_units,\n",
" return_sequences=True,\n",
@@ -684,7 +854,7 @@
" x = inputs\n",
" x = self.embedding(x, training=training)\n",
" if states is None:\n",
- " states = self.gru.get_initial_state(x)\n",
+ " states = self.gru.get_initial_state(tf.shape(x)[0])\n",
" x, states = self.gru(x, initial_state=states, training=training)\n",
" x = self.dense(x, training=training)\n",
"\n",
@@ -696,7 +866,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 27,
"metadata": {
"id": "IX58Xj9z47Aw"
},
@@ -743,11 +913,25 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 28,
"metadata": {
"id": "C-_70kKAPrPU"
},
- "outputs": [],
+ "outputs": [
+ {
+ "ename": "InvalidArgumentError",
+ "evalue": "Exception encountered when calling MyModel.call().\n\n\u001b[1m{{function_node __wrapped__Pack_N_2_device_/job:localhost/replica:0/task:0/device:CPU:0}} Shapes of all inputs must match: values[0].shape = [64,100,256] != values[1].shape = [] [Op:Pack] name: \u001b[0m\n\nArguments received by MyModel.call():\n • inputs=tf.Tensor(shape=(64, 100), dtype=int64)\n • states=None\n • return_state=False\n • training=False",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
+ "\u001b[31mInvalidArgumentError\u001b[39m Traceback (most recent call last)",
+ "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[28]\u001b[39m\u001b[32m, line 2\u001b[39m\n\u001b[32m 1\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m input_example_batch, target_example_batch \u001b[38;5;129;01min\u001b[39;00m dataset.take(\u001b[32m1\u001b[39m):\n\u001b[32m----> \u001b[39m\u001b[32m2\u001b[39m example_batch_predictions = \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43minput_example_batch\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 3\u001b[39m \u001b[38;5;28mprint\u001b[39m(example_batch_predictions.shape, \u001b[33m\"\u001b[39m\u001b[33m# (batch_size, sequence_length, vocab_size)\u001b[39m\u001b[33m\"\u001b[39m)\n",
+ "\u001b[36mFile \u001b[39m\u001b[32mc:\\Users\\charl\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\keras\\src\\utils\\traceback_utils.py:122\u001b[39m, in \u001b[36mfilter_traceback..error_handler\u001b[39m\u001b[34m(*args, **kwargs)\u001b[39m\n\u001b[32m 119\u001b[39m filtered_tb = _process_traceback_frames(e.__traceback__)\n\u001b[32m 120\u001b[39m \u001b[38;5;66;03m# To get the full stack trace, call:\u001b[39;00m\n\u001b[32m 121\u001b[39m \u001b[38;5;66;03m# `keras.config.disable_traceback_filtering()`\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m122\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m e.with_traceback(filtered_tb) \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m 123\u001b[39m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[32m 124\u001b[39m \u001b[38;5;28;01mdel\u001b[39;00m filtered_tb\n",
+ "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[26]\u001b[39m\u001b[32m, line 14\u001b[39m, in \u001b[36mMyModel.call\u001b[39m\u001b[34m(self, inputs, states, return_state, training)\u001b[39m\n\u001b[32m 12\u001b[39m x = \u001b[38;5;28mself\u001b[39m.embedding(x, training=training)\n\u001b[32m 13\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m states \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m---> \u001b[39m\u001b[32m14\u001b[39m states = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mgru\u001b[49m\u001b[43m.\u001b[49m\u001b[43mget_initial_state\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 15\u001b[39m x, states = \u001b[38;5;28mself\u001b[39m.gru(x, initial_state=states, training=training)\n\u001b[32m 16\u001b[39m x = \u001b[38;5;28mself\u001b[39m.dense(x, training=training)\n",
+ "\u001b[31mInvalidArgumentError\u001b[39m: Exception encountered when calling MyModel.call().\n\n\u001b[1m{{function_node __wrapped__Pack_N_2_device_/job:localhost/replica:0/task:0/device:CPU:0}} Shapes of all inputs must match: values[0].shape = [64,100,256] != values[1].shape = [] [Op:Pack] name: \u001b[0m\n\nArguments received by MyModel.call():\n • inputs=tf.Tensor(shape=(64, 100), dtype=int64)\n • states=None\n • return_state=False\n • training=False"
+ ]
+ }
+ ],
"source": [
"for input_example_batch, target_example_batch in dataset.take(1):\n",
" example_batch_predictions = model(input_example_batch)\n",
@@ -972,11 +1156,10 @@
"# Directory where the checkpoints will be saved\n",
"checkpoint_dir = './training_checkpoints'\n",
"# Name of the checkpoint files\n",
- "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt_{epoch}\")\n",
+ "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt_{epoch}.keras\")\n",
"\n",
"checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(\n",
- " filepath=checkpoint_prefix,\n",
- " save_weights_only=True)"
+ " filepath=checkpoint_prefix)"
]
},
{
@@ -1377,6 +1560,18 @@
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.0"
}
},
"nbformat": 4,