Skip to content

Commit f211d34

Browse files
committed
Add troubleshooting page for TF sessions used in predict method (#1141)
Co-authored-by: David Eliahu <deliahu@users.noreply.github.com> (cherry picked from commit 54d3d26)
1 parent 2a2dc0e commit f211d34

File tree

2 files changed

+19
-0
lines changed

2 files changed

+19
-0
lines changed

docs/summary.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747

4848
* [API is stuck updating](troubleshooting/stuck-updating.md)
4949
* [NVIDIA runtime not found](troubleshooting/nvidia-container-runtime-not-found.md)
50+
* [TF session in predict()](troubleshooting/tf-session-in-predict.md)
5051

5152
## Guides
5253

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# Using TensorFlow session in predict method
2+
3+
When doing inferences with TensorFlow using the [Python Predictor](../deployments/predictors.md#python-predictor), it should be noted that your Python Predictor's `__init__()` constructor is only called on one thread, whereas its `predict()` method can run on any of the available threads (which is configured via the `threads_per_process` field in the API's `predictor` configuration). If `threads_per_process` is set to `1` (the default value), then there is no concern, since `__init__()` and `predict()` will run on the same thread. However, if `threads_per_process` is greater than `1`, then only one of the inference threads will have executed the `__init__()` function. This can cause issues with TensorFlow because the default graph is a property of the current thread, so if `__init__()` initializes the TensorFlow graph, only the thread that executed `__init__()` will have the default graph set.
4+
5+
The error you may see if the default graph is not set (as a consequence of `__init__()` and `predict()` running in separate threads) is:
6+
7+
```text
8+
TypeError: Cannot interpret feed_dict key as Tensor: Tensor Tensor("Placeholder:0", shape=(1, ?), dtype=int32) is not an element of this graph.
9+
```
10+
11+
To avoid this error, you can set the default graph before running the prediction in the `predict()` method:
12+
13+
```python
14+
15+
def predict(self, payload):
16+
with self.sess.graph.as_default():
17+
# perform your inference here
18+
```

0 commit comments

Comments
 (0)