Skip to content

Commit ec3710d

Browse files
GreenstanSolomonLakepre-commit-ci[bot]
authored
Added Confidence Argument to keypoint detection model (#354)
* Added "confidence" as argument for keypoint detection model - confidence added to KeypointDetectionModel constructor and api url string * Revert "Added "confidence" as argument for keypoint detection model" This reverts commit 583818f. * Reapply "Added "confidence" as argument for keypoint detection model" This reverts commit 00c1a11. * Added Tests for the Keypoint detection model * Fixed String concatenation and removed unnecessary dependabot file * Fixed Unit-tests for Keypoint Detection Model * Fixed confidence kwarg in Keypoint-detection model * Update tests/models/test_keypoint_detection.py Co-authored-by: Solomon Lake Giffen-Hunter <lakegh@gmail.com> * Update tests/models/test_keypoint_detection.py Co-authored-by: Solomon Lake Giffen-Hunter <lakegh@gmail.com> * Update tests/models/test_keypoint_detection.py Co-authored-by: Solomon Lake Giffen-Hunter <lakegh@gmail.com> * Update tests/models/test_keypoint_detection.py Co-authored-by: Solomon Lake Giffen-Hunter <lakegh@gmail.com> * fix(pre_commit): 🎨 auto format pre-commit hooks * Changed some file paths to relative paths --------- Co-authored-by: Solomon Lake Giffen-Hunter <lakegh@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 998d94b commit ec3710d

File tree

4 files changed

+510
-3
lines changed

4 files changed

+510
-3
lines changed

roboflow/models/keypoint_detection.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def __init__(
2626
id: str,
2727
name: Optional[str] = None,
2828
version: Optional[str] = None,
29+
confidence: Optional[int] = 40,
2930
local: Optional[str] = None,
3031
):
3132
"""
@@ -37,6 +38,7 @@ def __init__(
3738
name (str): is the name of the project
3839
version (str): version number
3940
local (str): localhost address and port if pointing towards local inference engine
41+
confidence (int): A threshold for the returned predictions on a scale of 0-100.
4042
colors (dict): colors to use for the image
4143
preprocessing (dict): preprocessing to use for the image
4244
@@ -48,6 +50,7 @@ def __init__(
4850
self.__api_key = api_key
4951
self.id = id
5052
self.name = name
53+
self.confidence = confidence
5154
self.version = version
5255
self.base_url = "https://detect.roboflow.com/"
5356

@@ -58,7 +61,7 @@ def __init__(
5861
print(f"initalizing local keypoint detection model hosted at : {local}")
5962
self.base_url = local
6063

61-
def predict(self, image_path, hosted=False): # type: ignore[override]
64+
def predict(self, image_path, hosted=False, confidence=None): # type: ignore[override]
6265
"""
6366
Run inference on an image.
6467
@@ -80,7 +83,10 @@ def predict(self, image_path, hosted=False): # type: ignore[override]
8083
8184
>>> prediction = model.predict("YOUR_IMAGE.jpg")
8285
"""
83-
self.__generate_url()
86+
if confidence is not None:
87+
self.confidence = confidence
88+
89+
self.__generate_url(confidence=confidence)
8490
self.__exception_check(image_path_check=image_path)
8591
# If image is local image
8692
if not hosted:
@@ -130,7 +136,7 @@ def load_model(self, name, version):
130136
self.version = version
131137
self.__generate_url()
132138

133-
def __generate_url(self):
139+
def __generate_url(self, confidence=None):
134140
"""
135141
Generate a Roboflow API URL on which to run inference.
136142
@@ -145,11 +151,15 @@ def __generate_url(self):
145151
if not version and len(splitted) > 2:
146152
version = splitted[2]
147153

154+
if confidence is not None:
155+
self.confidence = confidence
156+
148157
self.api_url = "".join(
149158
[
150159
self.base_url + without_workspace + "/" + str(version),
151160
"?api_key=" + self.__api_key,
152161
"&name=YOUR_IMAGE.jpg",
162+
"&confidence=" + str(self.confidence),
153163
]
154164
)
155165

@@ -175,6 +185,7 @@ def __str__(self):
175185
json_value = {
176186
"name": self.name,
177187
"version": self.version,
188+
"confidence": self.confidence,
178189
"base_url": self.base_url,
179190
}
180191

0 commit comments

Comments
 (0)