Skip to content
This repository was archived by the owner on Jan 10, 2025. It is now read-only.

Commit 64a9f1a

Browse files
committed
gpt2 generation script
1 parent f68ef89 commit 64a9f1a

File tree

1 file changed

+20
-0
lines changed

1 file changed

+20
-0
lines changed

models_generation/gpt2.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import tensorflow as tf
2+
from transformers import TFGPT2LMHeadModel
3+
4+
model = TFGPT2LMHeadModel.from_pretrained('gpt2') # or 'distilgpt2'
5+
6+
input_spec = tf.TensorSpec([1, 64], tf.int32)
7+
model._set_inputs(input_spec, training=False)
8+
9+
print(model.inputs)
10+
print(model.outputs)
11+
12+
converter = tf.lite.TFLiteConverter.from_keras_model(model)
13+
14+
# For FP16 quantization:
15+
# converter.optimizations = [tf.lite.Optimize.DEFAULT]
16+
# converter.target_spec.supported_types = [tf.float16]
17+
18+
tflite_model = converter.convert()
19+
20+
open("gpt2-64.tflite", "wb").write(tflite_model)

0 commit comments

Comments
 (0)