@@ -35,6 +35,9 @@ export interface Args {
3535 model ?: string ;
3636}
3737
38+ export type RequestArgs = Args &
39+ ( { data ?: Blob | ArrayBuffer } | { inputs : unknown } ) & { parameters ?: Record < string , unknown > } ;
40+
3841export type FillMaskArgs = Args & {
3942 inputs : string ;
4043} ;
@@ -909,10 +912,7 @@ export class HfInference {
909912 args : AutomaticSpeechRecognitionArgs ,
910913 options ?: Options
911914 ) : Promise < AutomaticSpeechRecognitionReturn > {
912- const res = await this . request < AutomaticSpeechRecognitionReturn > ( args , {
913- ...options ,
914- binary : true ,
915- } ) ;
915+ const res = await this . request < AutomaticSpeechRecognitionReturn > ( args , options ) ;
916916 const isValidOutput = typeof res . text === "string" ;
917917 if ( ! isValidOutput ) {
918918 throw new TypeError ( "Invalid inference output: output must be of type <text: string>" ) ;
@@ -928,10 +928,7 @@ export class HfInference {
928928 args : AudioClassificationArgs ,
929929 options ?: Options
930930 ) : Promise < AudioClassificationReturn > {
931- const res = await this . request < AudioClassificationReturn > ( args , {
932- ...options ,
933- binary : true ,
934- } ) ;
931+ const res = await this . request < AudioClassificationReturn > ( args , options ) ;
935932 const isValidOutput =
936933 Array . isArray ( res ) && res . every ( ( x ) => typeof x . label === "string" && typeof x . score === "number" ) ;
937934 if ( ! isValidOutput ) {
@@ -948,10 +945,7 @@ export class HfInference {
948945 args : ImageClassificationArgs ,
949946 options ?: Options
950947 ) : Promise < ImageClassificationReturn > {
951- const res = await this . request < ImageClassificationReturn > ( args , {
952- ...options ,
953- binary : true ,
954- } ) ;
948+ const res = await this . request < ImageClassificationReturn > ( args , options ) ;
955949 const isValidOutput =
956950 Array . isArray ( res ) && res . every ( ( x ) => typeof x . label === "string" && typeof x . score === "number" ) ;
957951 if ( ! isValidOutput ) {
@@ -965,10 +959,7 @@ export class HfInference {
965959 * Recommended model: facebook/detr-resnet-50
966960 */
967961 public async objectDetection ( args : ObjectDetectionArgs , options ?: Options ) : Promise < ObjectDetectionReturn > {
968- const res = await this . request < ObjectDetectionReturn > ( args , {
969- ...options ,
970- binary : true ,
971- } ) ;
962+ const res = await this . request < ObjectDetectionReturn > ( args , options ) ;
972963 const isValidOutput =
973964 Array . isArray ( res ) &&
974965 res . every (
@@ -993,10 +984,7 @@ export class HfInference {
993984 * Recommended model: facebook/detr-resnet-50-panoptic
994985 */
995986 public async imageSegmentation ( args : ImageSegmentationArgs , options ?: Options ) : Promise < ImageSegmentationReturn > {
996- const res = await this . request < ImageSegmentationReturn > ( args , {
997- ...options ,
998- binary : true ,
999- } ) ;
987+ const res = await this . request < ImageSegmentationReturn > ( args , options ) ;
1000988 const isValidOutput =
1001989 Array . isArray ( res ) &&
1002990 res . every ( ( x ) => typeof x . label === "string" && typeof x . mask === "string" && typeof x . score === "number" ) ;
@@ -1013,10 +1001,7 @@ export class HfInference {
10131001 * Recommended model: stabilityai/stable-diffusion-2
10141002 */
10151003 public async textToImage ( args : TextToImageArgs , options ?: Options ) : Promise < TextToImageReturn > {
1016- const res = await this . request < TextToImageReturn > ( args , {
1017- ...options ,
1018- blob : true ,
1019- } ) ;
1004+ const res = await this . request < TextToImageReturn > ( args , options ) ;
10201005 const isValidOutput = res && res instanceof Blob ;
10211006 if ( ! isValidOutput ) {
10221007 throw new TypeError ( "Invalid inference output: output must be of type object & of instance Blob" ) ;
@@ -1028,25 +1013,18 @@ export class HfInference {
10281013 * This task reads some image input and outputs the text caption.
10291014 */
10301015 public async imageToText ( args : ImageToTextArgs , options ?: Options ) : Promise < ImageToTextReturn > {
1031- return (
1032- await this . request < [ ImageToTextReturn ] > ( args , {
1033- ...options ,
1034- binary : true ,
1035- } )
1036- ) ?. [ 0 ] ;
1016+ return ( await this . request < [ ImageToTextReturn ] > ( args , options ) ) ?. [ 0 ] ;
10371017 }
10381018
10391019 /**
10401020 * Helper that prepares request arguments
10411021 */
10421022 private makeRequestOptions (
1043- args : Args & {
1023+ args : RequestArgs & {
10441024 data ?: Blob | ArrayBuffer ;
10451025 stream ?: boolean ;
10461026 } ,
10471027 options ?: Options & {
1048- binary ?: boolean ;
1049- blob ?: boolean ;
10501028 /** For internal HF use, which is why it's not exposed in {@link Options} */
10511029 includeCredentials ?: boolean ;
10521030 }
@@ -1059,11 +1037,11 @@ export class HfInference {
10591037 headers [ "Authorization" ] = `Bearer ${ this . apiKey } ` ;
10601038 }
10611039
1062- if ( ! options ?. binary ) {
1063- headers [ "Content-Type" ] = "application/json" ;
1064- }
1040+ const binary = "data" in args && ! ! args . data ;
10651041
1066- if ( options ?. binary ) {
1042+ if ( ! binary ) {
1043+ headers [ "Content-Type" ] = "application/json" ;
1044+ } else {
10671045 if ( mergedOptions . wait_for_model ) {
10681046 headers [ "X-Wait-For-Model" ] = "true" ;
10691047 }
@@ -1082,7 +1060,7 @@ export class HfInference {
10821060 const info : RequestInit = {
10831061 headers,
10841062 method : "POST" ,
1085- body : options ?. binary
1063+ body : binary
10861064 ? args . data
10871065 : JSON . stringify ( {
10881066 ...otherArgs ,
@@ -1094,11 +1072,12 @@ export class HfInference {
10941072 return { url, info, mergedOptions } ;
10951073 }
10961074
1075+ /**
1076+ * Primitive to make custom calls to the inference API
1077+ */
10971078 public async request < T > (
1098- args : Args & { data ?: Blob | ArrayBuffer } ,
1079+ args : RequestArgs ,
10991080 options ?: Options & {
1100- binary ?: boolean ;
1101- blob ?: boolean ;
11021081 /** For internal HF use, which is why it's not exposed in {@link Options} */
11031082 includeCredentials ?: boolean ;
11041083 }
@@ -1113,34 +1092,29 @@ export class HfInference {
11131092 } ) ;
11141093 }
11151094
1116- if ( options ?. blob ) {
1117- if ( ! response . ok ) {
1118- if ( response . headers . get ( "Content-Type" ) ?. startsWith ( "application/json" ) ) {
1119- const output = await response . json ( ) ;
1120- if ( output . error ) {
1121- throw new Error ( output . error ) ;
1122- }
1095+ if ( ! response . ok ) {
1096+ if ( response . headers . get ( "Content-Type" ) ?. startsWith ( "application/json" ) ) {
1097+ const output = await response . json ( ) ;
1098+ if ( output . error ) {
1099+ throw new Error ( output . error ) ;
11231100 }
1124- throw new Error ( "An error occurred while fetching the blob" ) ;
11251101 }
1126- return ( await response . blob ( ) ) as T ;
1102+ throw new Error ( "An error occurred while fetching the blob" ) ;
11271103 }
11281104
1129- const output = await response . json ( ) ;
1130- if ( output . error ) {
1131- throw new Error ( output . error ) ;
1105+ if ( response . headers . get ( "Content-Type" ) ?. startsWith ( "application/json" ) ) {
1106+ return await response . json ( ) ;
11321107 }
1133- return output ;
1108+
1109+ return ( await response . blob ( ) ) as T ;
11341110 }
11351111
11361112 /**
1137- * Make request that uses server-sent events and returns response as a generator
1113+ * Primitive to make custom inference calls that expect server-sent events, and returns the response through a generator
11381114 */
11391115 public async * streamingRequest < T > (
1140- args : Args & { data ?: Blob | ArrayBuffer } ,
1116+ args : RequestArgs ,
11411117 options ?: Options & {
1142- binary ?: boolean ;
1143- blob ?: boolean ;
11441118 /** For internal HF use, which is why it's not exposed in {@link Options} */
11451119 includeCredentials ?: boolean ;
11461120 }
0 commit comments