Skip to content

Commit 6727f2a

Browse files
rongoualliepiper
authored andcommitted
reduce cudaDeviceSynchronize calls
1 parent ed6b727 commit 6727f2a

File tree

2 files changed

+20
-32
lines changed

2 files changed

+20
-32
lines changed

thrust/system/cuda/detail/par.h

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -69,31 +69,6 @@ struct execute_on_stream_base : execution_policy<Derived>
6969
{
7070
return exec.stream;
7171
}
72-
73-
friend __host__ __device__
74-
cudaError_t
75-
synchronize_stream(execute_on_stream_base &exec)
76-
{
77-
cudaError_t result;
78-
if (THRUST_IS_HOST_CODE) {
79-
#if THRUST_INCLUDE_HOST_CODE
80-
cudaStreamSynchronize(exec.stream);
81-
result = cudaGetLastError();
82-
#endif
83-
} else {
84-
#if THRUST_INCLUDE_DEVICE_CODE
85-
#if __THRUST_HAS_CUDART__
86-
THRUST_UNUSED_VAR(exec);
87-
cudaDeviceSynchronize();
88-
result = cudaGetLastError();
89-
#else
90-
THRUST_UNUSED_VAR(exec);
91-
result = cudaSuccess;
92-
#endif
93-
#endif
94-
}
95-
return result;
96-
}
9772
};
9873

9974
struct execute_on_stream : execute_on_stream_base<execute_on_stream>

thrust/system/cuda/detail/util.h

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -72,14 +72,27 @@ __thrust_exec_check_disable__
7272
template <class Derived>
7373
__host__ __device__
7474
cudaError_t
75-
synchronize_stream(execution_policy<Derived> &)
75+
synchronize_stream(execution_policy<Derived> &policy)
7676
{
77-
#if __THRUST_HAS_CUDART__
78-
cudaDeviceSynchronize();
79-
return cudaGetLastError();
80-
#else
81-
return cudaSuccess;
82-
#endif
77+
cudaError_t result;
78+
if (THRUST_IS_HOST_CODE) {
79+
#if THRUST_INCLUDE_HOST_CODE
80+
cudaStreamSynchronize(stream(policy));
81+
result = cudaGetLastError();
82+
#endif
83+
} else {
84+
#if THRUST_INCLUDE_DEVICE_CODE
85+
#if __THRUST_HAS_CUDART__
86+
THRUST_UNUSED_VAR(policy);
87+
cudaDeviceSynchronize();
88+
result = cudaGetLastError();
89+
#else
90+
THRUST_UNUSED_VAR(policy);
91+
result = cudaSuccess;
92+
#endif
93+
#endif
94+
}
95+
return result;
8396
}
8497

8598
// Entry point/interface.

0 commit comments

Comments
 (0)