@@ -28,20 +28,17 @@ int RAI_InitBackendTF(int (*get_api_fn)(const char *, void *)) {
     return REDISMODULE_OK;
 }

-// Managing context for the DLManagedTensor, will manage the lifetime of
-// DLManagedTensor. When calling DLManagedTensor::deleter, it will notify the
-// original framework of destruction, and this context will be deleted also.
-struct TfDlManagedTensorCtx {
+struct TFDLManagedTensorCtx {
     TFE_TensorHandle *reference;
     int64_t ndim;
     int64_t *shape;
     int64_t *strides;
     DLManagedTensor tensor;
 };
-typedef struct TfDlManagedTensorCtx TfDlManagedTensorCtx;
+typedef struct TFDLManagedTensorCtx TFDLManagedTensorCtx;

-TfDlManagedTensorCtx *TfDlManagedTensorCtx_Create(TFE_TensorHandle *h, TF_Status *status) {
-    TfDlManagedTensorCtx *ctx = RedisModule_Alloc(sizeof(TfDlManagedTensorCtx));
+TFDLManagedTensorCtx *TFDLManagedTensorCtx_Create(TFE_TensorHandle *h, TF_Status *status) {
+    TFDLManagedTensorCtx *ctx = RedisModule_Alloc(sizeof(TFDLManagedTensorCtx));
     ctx->ndim = TFE_TensorHandleNumDims(h, status);
     ctx->shape = RedisModule_Calloc(ctx->ndim, sizeof(int64_t));
     ctx->strides = RedisModule_Calloc(ctx->ndim, sizeof(int64_t));
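Note: the part of TFDLManagedTensorCtx_Create elided above fills ctx->shape from the handle and derives ctx->strides for a contiguous row-major layout. A minimal sketch of that derivation, for illustration only (not the literal code in this patch):

    for (int i = 0; i < ctx->ndim; i++) {
        ctx->shape[i] = TFE_TensorHandleDim(h, i, status);
    }
    // Row-major, compact: innermost stride is 1, each outer stride is the
    // product of all inner dimensions.
    int64_t stride = 1;
    for (int i = ctx->ndim - 1; i >= 0; i--) {
        ctx->strides[i] = stride;
        stride *= ctx->shape[i];
    }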
@@ -55,23 +52,19 @@ TfDlManagedTensorCtx *TfDlManagedTensorCtx_Create(TFE_TensorHandle *h, TF_Status
     return ctx;
 }

-void TfDlManagedTensorCtx_Free(TfDlManagedTensorCtx *ctx) {
+void TFDLManagedTensorCtx_Free(TFDLManagedTensorCtx *ctx) {
     RedisModule_Free(ctx->shape);
     RedisModule_Free(ctx->strides);
     RedisModule_Free(ctx);
 }

-// Deleter for DLManagedTensor
 void DLManagedTensorDeleter(DLManagedTensor *arg) {
-    TfDlManagedTensorCtx *owner = (TfDlManagedTensorCtx *)(arg->manager_ctx);
-
-    // TODO: check if we need to deleted the actual tensor as well
+    TFDLManagedTensorCtx *owner = (TFDLManagedTensorCtx *)(arg->manager_ctx);
     TFE_DeleteTensorHandle(owner->reference);
-    TfDlManagedTensorCtx_Free(owner);
+    TFDLManagedTensorCtx_Free(owner);
 }

-// Converts TF_DATAType to DLPack data type.
-DLDataType GetDlDataType(TF_DataType data_type, TF_Status *status) {
+DLDataType GetDLDataType(TF_DataType data_type, TF_Status *status) {
     DLDataType dtype;
     dtype.lanes = 1;
     dtype.bits = TF_DataTypeSize(data_type) * 8;
@@ -104,8 +97,7 @@ DLDataType GetDlDataType(TF_DataType data_type, TF_Status *status) {
     return dtype;
 }

-// Gets DLPack's DLDevice from eager tensor handle.
-DLDevice GetDlDevice(TFE_TensorHandle *h, TF_Status *status) {
+DLDevice GetDLDevice(TFE_TensorHandle *h, TF_Status *status) {
     DLDevice device;
     const char *device_name = TFE_TensorHandleBackingDeviceName(h, status);

@@ -135,8 +127,7 @@ DLDevice GetDlDevice(TFE_TensorHandle *h, TF_Status *status) {
     return device;
 }

-// Converts DLContext to TF device name.
-int DeviceNameFromDlContext(const DLDevice *device, char device_name[64]) {
+int DeviceNameFromDLContext(const DLDevice *device, char device_name[64]) {
     switch (device->device_type) {
     case kDLCPU:
         strcpy(device_name, "CPU:0");
@@ -148,8 +139,7 @@ int DeviceNameFromDlContext(const DLDevice *device, char device_name[64]) {
     return 1;
 }

-// Converts DLPack data type to TF_DATATYPE.
-int TfDataTypeFromDlDataType(const DLDataType *dtype, TF_DataType *tf_dtype) {
+int TFDataTypeFromDLDataType(const DLDataType *dtype, TF_DataType *tf_dtype) {
     switch (dtype->code) {
     case kDLUInt:
         switch (dtype->bits) {
@@ -216,14 +206,10 @@ int TfDataTypeFromDlDataType(const DLDataType *dtype, TF_DataType *tf_dtype) {
     }
 }

-// Wraps the deleter function of DLManagedTensor to match the function signature
-// TFE_NewTensorHandleFromDeviceMemory.
 void DeallocatorWrapperFunc(void *data, size_t len, void *dlmt_vptr) {
     TFE_CallDLManagedTensorDeleter(dlmt_vptr);
 }

-// Checks whether the stride array matches the layout of compact, row-majored
-// data.
 bool IsValidStrideCompactRowMajorData(int64_t *shape_arr, int64_t *stride_arr, int ndim) {
     if (ndim >= 1 && stride_arr[ndim - 1] != 1) {
         return false;
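For reference, the check above accepts only strides that describe contiguous row-major data; a worked example of the rule (illustration, not code from this patch):

    int64_t shape[3] = {2, 3, 4};
    int64_t row_major[3] = {12, 4, 1};  // 3*4, 4, 1 -> accepted
    int64_t col_major[3] = {1, 2, 6};   // column-major layout -> rejected
    // IsValidStrideCompactRowMajorData(shape, row_major, 3) == true
    // IsValidStrideCompactRowMajorData(shape, col_major, 3) == false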
@@ -244,7 +230,7 @@ void TFE_CallDLManagedTensorDeleter(void *dlm_ptr) {
 }

 void *TFE_HandleToDLPack(TFE_TensorHandle *h, TF_Status *status) {
-    DLDevice tf_dlm_device = GetDlDevice(h, status);
+    DLDevice tf_dlm_device = GetDLDevice(h, status);
     if (TF_GetCode(status) != TF_OK) {
         return NULL;
     }
@@ -256,12 +242,12 @@ void *TFE_HandleToDLPack(TFE_TensorHandle *h, TF_Status *status) {

     TF_DataType data_type = TFE_TensorHandleDataType(h);

-    DLDataType tf_dlm_type = GetDlDataType(data_type, status);
+    DLDataType tf_dlm_type = GetDLDataType(data_type, status);
     if (TF_GetCode(status) != TF_OK) {
         return NULL;
     }

-    TfDlManagedTensorCtx *tf_dlm_tensor_ctx = TfDlManagedTensorCtx_Create(h, status);
+    TFDLManagedTensorCtx *tf_dlm_tensor_ctx = TFDLManagedTensorCtx_Create(h, status);

     DLManagedTensor *dlm_tensor = &tf_dlm_tensor_ctx->tensor;
     dlm_tensor->manager_ctx = tf_dlm_tensor_ctx;
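The unchanged remainder of TFE_HandleToDLPack is expected to point dlm_tensor->deleter at DLManagedTensorDeleter defined above, so a consumer releases the tensor through the DLPack deleter rather than touching the TFE handle directly. A rough usage sketch, assuming an existing TFE_TensorHandle *h:

    TF_Status *status = TF_NewStatus();
    DLManagedTensor *dlm = (DLManagedTensor *)TFE_HandleToDLPack(h, status);
    if (TF_GetCode(status) == TF_OK && dlm != NULL) {
        // ... use dlm->dl_tensor ...
        dlm->deleter(dlm); // runs DLManagedTensorDeleter: frees the referenced handle and the ctx
    }
    TF_DeleteStatus(status);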
@@ -287,15 +273,15 @@ TFE_TensorHandle *TFE_HandleFromDLPack(void *dlm, TF_Status *status, TFE_Context
     DLManagedTensor *dlmt = (DLManagedTensor *)dlm;
     DLTensor *dl_tensor = &dlmt->dl_tensor;
     char device_name[64];
-    int ret = DeviceNameFromDlContext(&dl_tensor->device, device_name);
+    int ret = DeviceNameFromDLContext(&dl_tensor->device, device_name);
     if (ret != 0) {
-        // tensorflow::errors::InvalidArgument("Unsupported Device Type");
+        // TODO Unsupported device type
         return NULL;
     }
     TF_DataType dtype;
-    ret = TfDataTypeFromDlDataType(&dl_tensor->dtype, &dtype);
+    ret = TFDataTypeFromDLDataType(&dl_tensor->dtype, &dtype);
     if (ret != 0) {
-        // status->status = std::move(s);
+        // TODO Unsupported data type
         return NULL;
     }
     int num_dims = dl_tensor->ndim;
@@ -421,8 +407,7 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod
     uint8_t config[4] = {0x32, 0x02, 0x20, 0x01};
     TFE_ContextOptionsSetConfig(context_opts, (void *)config, 4, status);

-    // TFE_ContextOptionsSetConfig(context_opts, proto, proto_len, status);
-    // TFE_ContextOptionsSetAsync(context_opts, 0);
+    TFE_ContextOptionsSetAsync(context_opts, 0);
     TFE_ContextOptionsSetDevicePlacementPolicy(context_opts, TFE_DEVICE_PLACEMENT_EXPLICIT);

     TFE_Context *context = TFE_NewContext(context_opts, status);
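For context, the four-byte blob above is a serialized ConfigProto; reading the protobuf wire format, 0x32 0x02 opens a two-byte gpu_options submessage and 0x20 0x01 sets allow_growth = true (my decoding, not stated in the patch). With the async option now actually applied, context setup amounts to:

    // Assumed decoding of the config bytes:
    //   0x32 0x02  -> field 6 (gpu_options), length-delimited, 2 bytes
    //   0x20 0x01  ->   field 4 (allow_growth), varint, true
    uint8_t config[4] = {0x32, 0x02, 0x20, 0x01};
    TFE_ContextOptionsSetConfig(context_opts, (void *)config, sizeof(config), status);
    TFE_ContextOptionsSetAsync(context_opts, 0); // synchronous execution
    TFE_ContextOptionsSetDevicePlacementPolicy(context_opts, TFE_DEVICE_PLACEMENT_EXPLICIT);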
@@ -605,6 +590,8 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
     const size_t noutputs = array_len(mctxs[0]->outputs);
     TFE_TensorHandle *inputTensorsHandles[ninputs];
     TFE_TensorHandle *outputTensorsHandles[noutputs];
+    TFE_TensorHandle *deviceInputTensorsHandles[ninputs];
+    TFE_TensorHandle *deviceOutputTensorsHandles[noutputs];

     size_t batch_sizes[nbatches];
     size_t batch_offsets[nbatches];
@@ -655,7 +642,7 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
            return 1;
        }

-        inputTensorsHandles[i] = TFE_TensorHandleCopyToDevice(
+        deviceInputTensorsHandles[i] = TFE_TensorHandleCopyToDevice(
            inputTensorsHandles[i], mctxs[0]->model->session, tf_devicestr, status);

        if (TF_GetCode(status) != TF_OK) {
@@ -676,7 +663,7 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
        return 1;
    }

-    TFE_OpAddInputList(fn_op, inputTensorsHandles, ninputs, status);
+    TFE_OpAddInputList(fn_op, deviceInputTensorsHandles, ninputs, status);
     if (TF_GetCode(status) != TF_OK) {
        char *errorMessage = RedisModule_Strdup(TF_Message(status));
        RAI_SetError(error, RAI_EMODELRUN, errorMessage);
@@ -686,7 +673,7 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
    }

     int noutputs_ = noutputs;
-    TFE_Execute(fn_op, outputTensorsHandles, &noutputs_, status);
+    TFE_Execute(fn_op, deviceOutputTensorsHandles, &noutputs_, status);
     if (TF_GetCode(status) != TF_OK) {
        char *errorMessage = RedisModule_Strdup(TF_Message(status));
        RAI_SetError(error, RAI_EMODELRUN, errorMessage);
@@ -697,6 +684,7 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {

     for (size_t i = 0; i < ninputs; ++i) {
        TFE_DeleteTensorHandle(inputTensorsHandles[i]);
+        TFE_DeleteTensorHandle(deviceInputTensorsHandles[i]);
     }

     if (TF_GetCode(status) != TF_OK) {
@@ -709,9 +697,8 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {

     for (size_t i = 0; i < noutputs; ++i) {
        outputTensorsHandles[i] = TFE_TensorHandleCopyToDevice(
-            outputTensorsHandles[i], mctxs[0]->model->session, "/device:CPU:0", status);
+            deviceOutputTensorsHandles[i], mctxs[0]->model->session, "/device:CPU:0", status);

-        // TF_Tensor* outputTensor = TFE_TensorHandleResolve(outputTensorsHandles[i], status);
        RAI_Tensor *outputTensor =
            RAI_TensorCreateFromDLTensor(TFE_HandleToDLPack(outputTensorsHandles[i], status));

@@ -728,7 +715,7 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
                continue;
            }
            if (RAI_TensorDim(outputTensor, 0) != total_batch_size) {
-                // TF_DeleteTensor(outputTensor);
+                RAI_TensorFree(outputTensor);
                TF_DeleteStatus(status);
                RAI_SetError(error, RAI_EMODELRUN,
                             "ERR Model did not generate the expected batch size");
@@ -743,7 +730,7 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
            mctxs[0]->outputs[i].tensor = RAI_TensorGetShallowCopy(outputTensor);
        }
        RAI_TensorFree(outputTensor);
-        TFE_DeleteTensorHandle(outputTensorsHandles[i]);
+        TFE_DeleteTensorHandle(deviceOutputTensorsHandles[i]);
     }

     TF_DeleteStatus(status);
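Taken together, the run path now keeps separate CPU-side and device-side handle arrays for both directions and frees each one explicitly. A condensed sketch of the resulting ownership pattern (names as in the patch, error handling omitted):

    for (size_t i = 0; i < ninputs; i++) {
        // CPU handle built from the request tensor, plus an explicit copy onto the model device
        deviceInputTensorsHandles[i] = TFE_TensorHandleCopyToDevice(
            inputTensorsHandles[i], mctxs[0]->model->session, tf_devicestr, status);
    }
    TFE_OpAddInputList(fn_op, deviceInputTensorsHandles, ninputs, status);
    TFE_Execute(fn_op, deviceOutputTensorsHandles, &noutputs_, status);
    for (size_t i = 0; i < ninputs; i++) {
        TFE_DeleteTensorHandle(inputTensorsHandles[i]);       // CPU-side input
        TFE_DeleteTensorHandle(deviceInputTensorsHandles[i]); // device-side copy
    }
    for (size_t i = 0; i < noutputs; i++) {
        // Outputs land on the device; copy back to CPU before wrapping as RAI_Tensor,
        // then drop the device-side handle (the CPU copy is released via the DLPack deleter).
        outputTensorsHandles[i] = TFE_TensorHandleCopyToDevice(
            deviceOutputTensorsHandles[i], mctxs[0]->model->session, "/device:CPU:0", status);
        TFE_DeleteTensorHandle(deviceOutputTensorsHandles[i]);
    }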