|
17 | 17 | import tempfile |
18 | 18 | import numpy as np |
19 | 19 |
|
| 20 | +from absl.testing import parameterized |
20 | 21 | from tensorflow.python.client import session |
21 | 22 | from tensorflow.python.data.ops import dataset_ops |
22 | 23 | from tensorflow.python.framework import dtypes |
|
26 | 27 | from tensorflow.python.ipu import loops |
27 | 28 | from tensorflow.python.ipu.config import DeviceConnectionType |
28 | 29 | from tensorflow.python.ipu.config import IPUConfig |
| 30 | +from tensorflow.python.ipu.ops import pipelining_ops |
29 | 31 | from tensorflow.python.ipu.ops.application_compile_op import experimental_application_compile_op as application_compile_op |
30 | 32 | from tensorflow.python.keras import layers |
31 | 33 | from tensorflow.python.ops import array_ops |
|
36 | 38 | from tensorflow.python.training import gradient_descent |
37 | 39 |
|
38 | 40 |
|
39 | | -class TestApplicationCompile(test_util.TensorFlowTestCase): |
| 41 | +class TestApplicationCompile(test_util.TensorFlowTestCase, |
| 42 | + parameterized.TestCase): |
40 | 43 | def setUp(self): |
41 | 44 | super().setUp() |
42 | 45 |
|
@@ -209,6 +212,39 @@ def my_net(lr): |
209 | 212 |
|
210 | 213 | self.assertGreater(os.path.getsize(compiled_path.decode()), 0) |
211 | 214 |
|
| 215 | + @parameterized.named_parameters(("resources", False), ("constants", True)) |
| 216 | + @test_util.deprecated_graph_mode_only |
| 217 | + def test_compile_pipeline(self, freeze_variables): |
| 218 | + with session.Session() as sess: |
| 219 | + |
| 220 | + dataset = dataset_ops.Dataset.from_tensor_slices((np.ones( |
| 221 | + (10, 5), dtype=np.float32),)) |
| 222 | + dataset = dataset.batch(1, drop_remainder=True) |
| 223 | + infeed_queue = ipu_infeed_queue.IPUInfeedQueue(dataset) |
| 224 | + outfeed_queue = ipu_outfeed_queue.IPUOutfeedQueue() |
| 225 | + |
| 226 | + def stage1(offset, x): |
| 227 | + return layers.Dense(5, activation="relu")(x) + offset |
| 228 | + |
| 229 | + def stage2(x): |
| 230 | + return layers.Dense(10, activation="softmax")(x) |
| 231 | + |
| 232 | + def my_net(): |
| 233 | + return pipelining_ops.pipeline(computational_stages=[stage1, stage2], |
| 234 | + gradient_accumulation_count=4, |
| 235 | + infeed_queue=infeed_queue, |
| 236 | + inputs=[42.0], |
| 237 | + outfeed_queue=outfeed_queue, |
| 238 | + device_mapping=[0, 0]) |
| 239 | + |
| 240 | + result = application_compile_op(my_net, |
| 241 | + freeze_variables=freeze_variables) |
| 242 | + |
| 243 | + sess.run(variables.global_variables_initializer()) |
| 244 | + compiled_path = sess.run(result) |
| 245 | + |
| 246 | + self.assertGreater(os.path.getsize(compiled_path.decode()), 0) |
| 247 | + |
212 | 248 |
|
213 | 249 | if __name__ == "__main__": |
214 | 250 | test.main() |
0 commit comments