@@ -65,16 +65,80 @@ IMPL_NAIVE_DOUBLE_INPUT_INTERNAL(interelem) {
6565 func = interelem_div<TA, TB, TC>;
6666 break ;
6767 case Param::Mod:
68- func = interelem_mod<TA, TB, TC>;
68+ if (la.dtype .enumv () == DTypeEnum::Float32 ||
69+ la.dtype .enumv () == DTypeEnum::Float64 ||
70+ lb.dtype .enumv () == DTypeEnum::Float32 ||
71+ lb.dtype .enumv () == DTypeEnum::Float64) {
72+ return Status (StatusCategory::NUMNET, StatusCode::MISMATCHED_DTYPE,
73+ " The mod operator only accepts int or long as input" );
74+ }
75+ if (loup.dtype .enumv () == DTypeEnum::Int32) {
76+ func = interelem_mod<nn_int32, nn_int32, nn_int32>;
77+ } else if (loup.dtype .enumv () == DTypeEnum::Int64) {
78+ func = interelem_mod<nn_int64, nn_int64, nn_int64>;
79+ } else if (loup.dtype .enumv () == DTypeEnum::Bool) {
80+ func = interelem_mod<bool , bool , bool >;
81+ } else {
82+ return Status (StatusCategory::SYSTEM, StatusCode::FAIL,
83+ " Unexpected exception in interelem.cpp." );
84+ }
6985 break ;
7086 case Param::And:
71- func = interelem_and<TA, TB, TC>;
87+ if (la.dtype .enumv () == DTypeEnum::Float32 ||
88+ la.dtype .enumv () == DTypeEnum::Float64 ||
89+ lb.dtype .enumv () == DTypeEnum::Float32 ||
90+ lb.dtype .enumv () == DTypeEnum::Float64) {
91+ return Status (StatusCategory::NUMNET, StatusCode::MISMATCHED_DTYPE,
92+ " The and operator only accepts int or long as input" );
93+ }
94+ if (loup.dtype .enumv () == DTypeEnum::Int32) {
95+ func = interelem_and<nn_int32, nn_int32, nn_int32>;
96+ } else if (loup.dtype .enumv () == DTypeEnum::Int64) {
97+ func = interelem_and<nn_int64, nn_int64, nn_int64>;
98+ } else if (loup.dtype .enumv () == DTypeEnum::Bool) {
99+ func = interelem_and<bool , bool , bool >;
100+ } else {
101+ return Status (StatusCategory::SYSTEM, StatusCode::FAIL,
102+ " Unexpected exception in interelem.cpp." );
103+ }
72104 break ;
73105 case Param::Or:
74- func = interelem_or<TA, TB, TC>;
106+ if (la.dtype .enumv () == DTypeEnum::Float32 ||
107+ la.dtype .enumv () == DTypeEnum::Float64 ||
108+ lb.dtype .enumv () == DTypeEnum::Float32 ||
109+ lb.dtype .enumv () == DTypeEnum::Float64) {
110+ return Status (StatusCategory::NUMNET, StatusCode::MISMATCHED_DTYPE,
111+ " The or operator only accepts int or long as input" );
112+ }
113+ if (loup.dtype .enumv () == DTypeEnum::Int32) {
114+ func = interelem_or<nn_int32, nn_int32, nn_int32>;
115+ } else if (loup.dtype .enumv () == DTypeEnum::Int64) {
116+ func = interelem_or<nn_int64, nn_int64, nn_int64>;
117+ } else if (loup.dtype .enumv () == DTypeEnum::Bool) {
118+ func = interelem_or<bool , bool , bool >;
119+ } else {
120+ return Status (StatusCategory::SYSTEM, StatusCode::FAIL,
121+ " Unexpected exception in interelem.cpp." );
122+ }
75123 break ;
76124 case Param::Xor:
77- func = interelem_xor<TA, TB, TC>;
125+ if (la.dtype .enumv () == DTypeEnum::Float32 ||
126+ la.dtype .enumv () == DTypeEnum::Float64 ||
127+ lb.dtype .enumv () == DTypeEnum::Float32 ||
128+ lb.dtype .enumv () == DTypeEnum::Float64) {
129+ return Status (StatusCategory::NUMNET, StatusCode::MISMATCHED_DTYPE,
130+ " The xor operator only accepts int or long as input" );
131+ }
132+ if (loup.dtype .enumv () == DTypeEnum::Int32) {
133+ func = interelem_xor<nn_int32, nn_int32, nn_int32>;
134+ } else if (loup.dtype .enumv () == DTypeEnum::Int64) {
135+ func = interelem_xor<nn_int64, nn_int64, nn_int64>;
136+ } else if (loup.dtype .enumv () == DTypeEnum::Bool) {
137+ func = interelem_xor<bool , bool , bool >;
138+ } else {
139+ return Status (StatusCategory::SYSTEM, StatusCode::FAIL,
140+ " Unexpected exception in interelem.cpp." );
141+ }
78142 break ;
79143
80144 default :
0 commit comments