diff --git a/docs/tutorials/text_generation.ipynb b/docs/tutorials/text_generation.ipynb index a427038df..d00576e24 100644 --- a/docs/tutorials/text_generation.ipynb +++ b/docs/tutorials/text_generation.ipynb @@ -11,7 +11,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": { "cellView": "form", "id": "GCCk8_dHpuNf" @@ -46,20 +46,20 @@ "id": "hcD2nPQvPOFM" }, "source": [ - "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n", - " \u003ctd\u003e\n", - " \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/text/tutorials/text_generation\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n", - " \u003c/td\u003e\n", - " \u003ctd\u003e\n", - " \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/text/blob/master/docs/tutorials/text_generation.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n", - " \u003c/td\u003e\n", - " \u003ctd\u003e\n", - " \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/text/blob/master/docs/tutorials/text_generation.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n", - " \u003c/td\u003e\n", - " \u003ctd\u003e\n", - " \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/text/docs/tutorials/text_generation.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n", - " \u003c/td\u003e\n", - "\u003c/table\u003e" + "\n", + " \n", + " \n", + " \n", + " \n", + "
\n", + " View on TensorFlow.org\n", + " \n", + " Run in Google Colab\n", + " \n", + " View source on GitHub\n", + " \n", + " Download notebook\n", + "
" ] }, { @@ -70,7 +70,7 @@ "source": [ "This tutorial demonstrates how to generate text using a character-based RNN. You will work with a dataset of Shakespeare's writing from Andrej Karpathy's [The Unreasonable Effectiveness of Recurrent Neural Networks](http://karpathy.github.io/2015/05/21/rnn-effectiveness/). Given a sequence of characters from this data (\"Shakespear\"), train a model to predict the next character in the sequence (\"e\"). Longer sequences of text can be generated by calling the model repeatedly.\n", "\n", - "Note: Enable GPU acceleration to execute this notebook faster. In Colab: *Runtime \u003e Change runtime type \u003e Hardware accelerator \u003e GPU*.\n", + "Note: Enable GPU acceleration to execute this notebook faster. In Colab: *Runtime > Change runtime type > Hardware accelerator > GPU*.\n", "\n", "This tutorial includes runnable code implemented using [tf.keras](https://www.tensorflow.org/guide/keras/sequential_model) and [eager execution](https://www.tensorflow.org/guide/eager). The following is the sample output when the model in this tutorial trained for 30 epochs, and started with the prompt \"Q\":" ] @@ -81,7 +81,7 @@ "id": "HcygKkEVZBaa" }, "source": [ - "\u003cpre\u003e\n", + "
\n",
         "QUEENE:\n",
         "I had thought thou hadst a Roman; for the oracle,\n",
         "Thus by All bids the man against the word,\n",
@@ -112,7 +112,7 @@
         "His lordship pluck'd from this sentence then for prey,\n",
         "And then let us twain, being the moon,\n",
         "were she such a case as fills m\n",
-        "\u003c/pre\u003e"
+        "
" ] }, { @@ -150,7 +150,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": { "id": "yG_n40gFzf9s" }, @@ -176,7 +176,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": { "id": "pD_55cOxLkAb" }, @@ -198,11 +198,19 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": { "id": "aavnuByVymwK" }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Length of text: 1115394 characters\n" + ] + } + ], "source": [ "# Read, then decode for py2 compat.\n", "text = open(path_to_file, 'rb').read().decode(encoding='utf-8')\n", @@ -212,11 +220,33 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "metadata": { "id": "Duhg9NrUymwO" }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "First Citizen:\n", + "Before we proceed any further, hear me speak.\n", + "\n", + "All:\n", + "Speak, speak.\n", + "\n", + "First Citizen:\n", + "You are all resolved rather to die than to famish?\n", + "\n", + "All:\n", + "Resolved. resolved.\n", + "\n", + "First Citizen:\n", + "First, you know Caius Marcius is chief enemy to the people.\n", + "\n" + ] + } + ], "source": [ "# Take a look at the first 250 characters in text\n", "print(text[:250])" @@ -224,11 +254,19 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "metadata": { "id": "IlCgQBRVymwR" }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "65 unique characters\n" + ] + } + ], "source": [ "# The unique characters in the file\n", "vocab = sorted(set(text))\n", @@ -259,11 +297,22 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": { "id": "a86OoYtO01go" }, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "example_texts = ['abcdefg', 'xyz']\n", "\n", @@ -282,7 +331,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "metadata": { "id": "6GMlCe3qzaL9" }, @@ -303,11 +352,22 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "metadata": { "id": "WLv5Q_2TC2pc" }, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "ids = ids_from_chars(chars)\n", "ids" @@ -333,7 +393,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 10, "metadata": { "id": "Wd2m3mqkDjRj" }, @@ -354,11 +414,22 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "metadata": { "id": "c2GCh0ySD44s" }, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "chars = chars_from_ids(ids)\n", "chars" @@ -375,18 +446,29 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 12, "metadata": { "id": "zxYI-PeltqKP" }, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "array([b'abcdefg', b'xyz'], dtype=object)" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "tf.strings.reduce_join(chars, axis=-1).numpy()" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "metadata": { "id": "w5apvBDn9Ind" }, @@ -435,11 +517,22 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 14, "metadata": { "id": "UopbsKi88tm5" }, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "all_ids = ids_from_chars(tf.strings.unicode_split(text, 'UTF-8'))\n", "all_ids" @@ -447,7 +540,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 15, "metadata": { "id": "qmxrYDCTy-eL" }, @@ -458,11 +551,28 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 16, "metadata": { "id": "cjH5v45-yqqH" }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "F\n", + "i\n", + "r\n", + "s\n", + "t\n", + " \n", + "C\n", + "i\n", + "t\n", + "i\n" + ] + } + ], "source": [ "for ids in ids_dataset.take(10):\n", " print(chars_from_ids(ids).numpy().decode('utf-8'))" @@ -470,7 +580,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 17, "metadata": { "id": "C-G2oaTxy6km" }, @@ -490,11 +600,27 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 18, "metadata": { "id": "BpdjRO2CzOfZ" }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tf.Tensor(\n", + "[b'F' b'i' b'r' b's' b't' b' ' b'C' b'i' b't' b'i' b'z' b'e' b'n' b':'\n", + " b'\\n' b'B' b'e' b'f' b'o' b'r' b'e' b' ' b'w' b'e' b' ' b'p' b'r' b'o'\n", + " b'c' b'e' b'e' b'd' b' ' b'a' b'n' b'y' b' ' b'f' b'u' b'r' b't' b'h'\n", + " b'e' b'r' b',' b' ' b'h' b'e' b'a' b'r' b' ' b'm' b'e' b' ' b's' b'p'\n", + " b'e' b'a' b'k' b'.' b'\\n' b'\\n' b'A' b'l' b'l' b':' b'\\n' b'S' b'p' b'e'\n", + " b'a' b'k' b',' b' ' b's' b'p' b'e' b'a' b'k' b'.' b'\\n' b'\\n' b'F' b'i'\n", + " b'r' b's' b't' b' ' b'C' b'i' b't' b'i' b'z' b'e' b'n' b':' b'\\n' b'Y'\n", + " b'o' b'u' b' '], shape=(101,), dtype=string)\n" + ] + } + ], "source": [ "sequences = ids_dataset.batch(seq_length+1, drop_remainder=True)\n", "\n", @@ -513,11 +639,23 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 19, "metadata": { "id": "QO32cMWu4a06" }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "b'First Citizen:\\nBefore we proceed any further, hear me speak.\\n\\nAll:\\nSpeak, speak.\\n\\nFirst Citizen:\\nYou '\n", + "b'are all resolved rather to die than to famish?\\n\\nAll:\\nResolved. resolved.\\n\\nFirst Citizen:\\nFirst, you k'\n", + "b\"now Caius Marcius is chief enemy to the people.\\n\\nAll:\\nWe know't, we know't.\\n\\nFirst Citizen:\\nLet us ki\"\n", + "b\"ll him, and we'll have corn at our own price.\\nIs't a verdict?\\n\\nAll:\\nNo more talking on't; let it be d\"\n", + "b'one: away, away!\\n\\nSecond Citizen:\\nOne word, good citizens.\\n\\nFirst Citizen:\\nWe are accounted poor citi'\n" + ] + } + ], "source": [ "for seq in sequences.take(5):\n", " print(text_from_ids(seq).numpy())" @@ -537,7 +675,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 20, "metadata": { "id": "9NGu-FkO_kYU" }, @@ -551,18 +689,30 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 21, "metadata": { "id": "WxbDTJTw5u_P" }, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "(['T', 'e', 'n', 's', 'o', 'r', 'f', 'l', 'o'],\n", + " ['e', 'n', 's', 'o', 'r', 'f', 'l', 'o', 'w'])" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "split_input_target(list(\"Tensorflow\"))" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 22, "metadata": { "id": "B9iKPXkw5xwa" }, @@ -573,11 +723,20 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 23, "metadata": { "id": "GNbw-iR0ymwj" }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Input : b'First Citizen:\\nBefore we proceed any further, hear me speak.\\n\\nAll:\\nSpeak, speak.\\n\\nFirst Citizen:\\nYou'\n", + "Target: b'irst Citizen:\\nBefore we proceed any further, hear me speak.\\n\\nAll:\\nSpeak, speak.\\n\\nFirst Citizen:\\nYou '\n" + ] + } + ], "source": [ "for input_example, target_example in dataset.take(1):\n", " print(\"Input :\", text_from_ids(input_example).numpy())\n", @@ -597,11 +756,22 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 24, "metadata": { "id": "p2pGotuNzf-S" }, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "<_PrefetchDataset element_spec=(TensorSpec(shape=(64, 100), dtype=tf.int64, name=None), TensorSpec(shape=(64, 100), dtype=tf.int64, name=None))>" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "# Batch size\n", "BATCH_SIZE = 64\n", @@ -647,7 +817,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 25, "metadata": { "id": "zHT8cLh7EAsg" }, @@ -673,7 +843,7 @@ "source": [ "class MyModel(tf.keras.Model):\n", " def __init__(self, vocab_size, embedding_dim, rnn_units):\n", - " super().__init__(self)\n", + " super().__init__()\n", " self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)\n", " self.gru = tf.keras.layers.GRU(rnn_units,\n", " return_sequences=True,\n", @@ -684,7 +854,7 @@ " x = inputs\n", " x = self.embedding(x, training=training)\n", " if states is None:\n", - " states = self.gru.get_initial_state(x)\n", + " states = self.gru.get_initial_state(tf.shape(x)[0])\n", " x, states = self.gru(x, initial_state=states, training=training)\n", " x = self.dense(x, training=training)\n", "\n", @@ -696,7 +866,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 27, "metadata": { "id": "IX58Xj9z47Aw" }, @@ -743,11 +913,25 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 28, "metadata": { "id": "C-_70kKAPrPU" }, - "outputs": [], + "outputs": [ + { + "ename": "InvalidArgumentError", + "evalue": "Exception encountered when calling MyModel.call().\n\n\u001b[1m{{function_node __wrapped__Pack_N_2_device_/job:localhost/replica:0/task:0/device:CPU:0}} Shapes of all inputs must match: values[0].shape = [64,100,256] != values[1].shape = [] [Op:Pack] name: \u001b[0m\n\nArguments received by MyModel.call():\n • inputs=tf.Tensor(shape=(64, 100), dtype=int64)\n • states=None\n • return_state=False\n • training=False", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mInvalidArgumentError\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[28]\u001b[39m\u001b[32m, line 2\u001b[39m\n\u001b[32m 1\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m input_example_batch, target_example_batch \u001b[38;5;129;01min\u001b[39;00m dataset.take(\u001b[32m1\u001b[39m):\n\u001b[32m----> \u001b[39m\u001b[32m2\u001b[39m example_batch_predictions = \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43minput_example_batch\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 3\u001b[39m \u001b[38;5;28mprint\u001b[39m(example_batch_predictions.shape, \u001b[33m\"\u001b[39m\u001b[33m# (batch_size, sequence_length, vocab_size)\u001b[39m\u001b[33m\"\u001b[39m)\n", + "\u001b[36mFile \u001b[39m\u001b[32mc:\\Users\\charl\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\keras\\src\\utils\\traceback_utils.py:122\u001b[39m, in \u001b[36mfilter_traceback..error_handler\u001b[39m\u001b[34m(*args, **kwargs)\u001b[39m\n\u001b[32m 119\u001b[39m filtered_tb = _process_traceback_frames(e.__traceback__)\n\u001b[32m 120\u001b[39m \u001b[38;5;66;03m# To get the full stack trace, call:\u001b[39;00m\n\u001b[32m 121\u001b[39m \u001b[38;5;66;03m# `keras.config.disable_traceback_filtering()`\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m122\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m e.with_traceback(filtered_tb) \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m 123\u001b[39m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[32m 124\u001b[39m \u001b[38;5;28;01mdel\u001b[39;00m filtered_tb\n", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[26]\u001b[39m\u001b[32m, line 14\u001b[39m, in \u001b[36mMyModel.call\u001b[39m\u001b[34m(self, inputs, states, return_state, training)\u001b[39m\n\u001b[32m 12\u001b[39m x = \u001b[38;5;28mself\u001b[39m.embedding(x, training=training)\n\u001b[32m 13\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m states \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m---> \u001b[39m\u001b[32m14\u001b[39m states = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mgru\u001b[49m\u001b[43m.\u001b[49m\u001b[43mget_initial_state\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 15\u001b[39m x, states = \u001b[38;5;28mself\u001b[39m.gru(x, initial_state=states, training=training)\n\u001b[32m 16\u001b[39m x = \u001b[38;5;28mself\u001b[39m.dense(x, training=training)\n", + "\u001b[31mInvalidArgumentError\u001b[39m: Exception encountered when calling MyModel.call().\n\n\u001b[1m{{function_node __wrapped__Pack_N_2_device_/job:localhost/replica:0/task:0/device:CPU:0}} Shapes of all inputs must match: values[0].shape = [64,100,256] != values[1].shape = [] [Op:Pack] name: \u001b[0m\n\nArguments received by MyModel.call():\n • inputs=tf.Tensor(shape=(64, 100), dtype=int64)\n • states=None\n • return_state=False\n • training=False" + ] + } + ], "source": [ "for input_example_batch, target_example_batch in dataset.take(1):\n", " example_batch_predictions = model(input_example_batch)\n", @@ -972,11 +1156,10 @@ "# Directory where the checkpoints will be saved\n", "checkpoint_dir = './training_checkpoints'\n", "# Name of the checkpoint files\n", - "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt_{epoch}\")\n", + "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt_{epoch}.keras\")\n", "\n", "checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(\n", - " filepath=checkpoint_prefix,\n", - " save_weights_only=True)" + " filepath=checkpoint_prefix)" ] }, { @@ -1377,6 +1560,18 @@ "kernelspec": { "display_name": "Python 3", "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.0" } }, "nbformat": 4,