@@ -645,72 +645,18 @@ breadth of relevant operator coverage isn't clear; however, the normalizations
645645that it does provide are worth using, so we'll make use of them throughout.
646646:END:
647647
648- [[grappler-normalize-function]] provides a simple means of
648+ src_python[:eval never]{symbolic_pymc.tensorflow.graph.normalize_tf_graph} provides a simple means of
649649applying src_python[:eval never]{grappler}.
650650
651- #+NAME: grappler-normalize-function
652- #+BEGIN_SRC python :exports code :results silent
653- from tensorflow.core.protobuf import config_pb2
654-
655- from tensorflow.python.framework import ops
656- from tensorflow.python.framework import importer
657- from tensorflow.python.framework import meta_graph
658-
659- from tensorflow.python.grappler import cluster
660- from tensorflow.python.grappler import tf_optimizer
661-
662-
663- try:
664- gcluster = cluster.Cluster()
665- except tf.errors.UnavailableError:
666- pass
667-
668- config = config_pb2.ConfigProto()
669-
670-
671- def normalize_tf_graph(graph_output, new_graph=True, verbose=False):
672- """Use grappler to normalize a graph.
673-
674- Arguments
675- =========
676- graph_output: Tensor
677- A tensor we want to consider as "output" of a FuncGraph.
678-
679- Returns
680- =======
681- The simplified graph.
682- """
683- train_op = graph_output.graph.get_collection_ref(ops.GraphKeys.TRAIN_OP)
684- train_op.clear()
685- train_op.extend([graph_output])
686-
687- metagraph = meta_graph.create_meta_graph_def(graph=graph_output.graph)
688-
689- optimized_graphdef = tf_optimizer.OptimizeGraph(
690- config, metagraph, verbose=verbose, cluster=gcluster)
691-
692- output_name = graph_output.name
693-
694- if new_graph:
695- optimized_graph = ops.Graph()
696- else:
697- optimized_graph = ops.get_default_graph()
698- del graph_output
699-
700- with optimized_graph.as_default():
701- importer.import_graph_def(optimized_graphdef, name="")
702-
703- opt_graph_output = optimized_graph.get_tensor_by_name(output_name)
704-
705- return opt_graph_output
706- #+END_SRC
707-
708- In [[grappler-normalize-function]] we
651+ In [[grappler-normalize-test-graph]] we
709652run src_python[:eval never]{grappler} on the log-likelihood graph for a normal
710653random variable from [[tfp-normal-log-lik-graph]].
711654
712655#+NAME: grappler-normalize-test-graph
713656#+BEGIN_SRC python :exports code :results silent :wrap
657+ from symbolic_pymc.tensorflow.graph import normalize_tf_graph
658+
659+
714660normal_log_lik_opt = normalize_tf_graph(normal_log_lik)
715661#+END_SRC
716662
0 commit comments