@@ -77,6 +77,7 @@ function cannot_shuffle(
7777 )
7878 ))
7979end
# `true` when the CPU name returned by `LoopVectorization.get_cpu_name()` contains the
# substring below. `cost` consults this flag to halve `srt` and to right-shift
# `shuffle_rt` by one bit before returning.
# NOTE(review): the leading space inside the string literal looks like a formatting
# artifact -- confirm the intended substring against actual `get_cpu_name()` output.
80+ const DOUBLE_THROUGHPUT = occursin (" apple-m" , LoopVectorization. get_cpu_name ())
8081function cost (
8182 ls:: LoopSet ,
8283 op:: Operation ,
@@ -104,6 +105,7 @@ function cost(
104105 # all(opp -> (isloopvalue(opp) | isconstant(opp)), parents(op))
105106 return 0.0 , 0 , 0.0
106107 end
108+ shuffle_rt = 0
107109 opisvectorized = isvectorized (op)
108110 srt, sl, srp =
109111 opisvectorized ? vector_cost (instr, Wshift, size_T) : scalar_cost (instr)
@@ -131,6 +133,7 @@ function cost(
131133 if isload (op) & (length (loopdependencies (op)) > 1 )# vmov(a/u)pd
132134 srt += 0.5 reg_size (ls) / cache_lnsze (ls)
133135 end
136+ shuffle_rt += shifter
134137 srt += shifter # shifter == number of shuffles
135138 sl += shifter
136139 end
@@ -150,7 +153,11 @@ function cost(
150153 sl *= 3
151154 end
152155 end
153- srt, sl, Float64 (srp + 1 )
156+ if DOUBLE_THROUGHPUT
157+ srt *= 0.5
158+ shuffle_rt >>= 1
159+ end
160+ srt, sl, Float64 (srp + 1 ), shuffle_rt
154161end
155162
156163# Base._return_type()
@@ -252,6 +259,11 @@ function depchain_cost!(
252259 rtᵢ, slᵢ = cost (ls, op, (unrolled, Symbol (" " )), vloopsym, Wshift, size_T)
253260 rt += rtᵢ
254261 sl += slᵢ
262+ elseif isload (op)
263+ _, _, _, shufflecost =
264+ cost (ls, op, (unrolled, Symbol (" " )), vloopsym, Wshift, size_T)
265+ rt += shufflecost
266+ sl += shufflecost
255267 end
256268 rt, sl
257269end
@@ -357,7 +369,7 @@ function unroll_no_reductions(ls, order, vloopsym)
357369 # # (iszero(rt) ? 4 : max(1, roundpow2( min( 4, round(Int, 16 / rt) ) ))), unrolled
358370 # (iszero(rt) ? 4 : max(1, VectorizationBase.nextpow2( min( 4, round(Int, 8 / rt) ) ))), unrolled
359371end
360- function determine_unroll_factor (
372+ function rthroughput_latency (
361373 ls:: LoopSet ,
362374 order:: Vector{Symbol} ,
363375 unrolled:: Symbol ,
@@ -390,17 +402,17 @@ function determine_unroll_factor(
390402 if isouterreduction (ls, op) ≠ - 1 || unrolled ∉ reduceddependencies (op)
391403 latency = max (sl, latency)
392404 end
393- # if unrolled ∈ loopdependencies(op)
394- # compute_recip_throughput_u += rt
395- # else
396405 compute_recip_throughput += rt
397- # end
398406 elseif isload (op)
399- load_recip_throughput +=
400- first (cost (ls, op, (unrolled, Symbol (" " )), vloopsym, Wshift, size_T))
407+ lrt, _, _, shufflert =
408+ cost (ls, op, (unrolled, Symbol (" " )), vloopsym, Wshift, size_T)
409+ load_recip_throughput += lrt - shufflert
410+ # shufflert considered as part of depchain_cost!
401411 elseif isstore (op)
402- store_recip_throughput +=
403- first (cost (ls, op, (unrolled, Symbol (" " )), vloopsym, Wshift, size_T))
412+ srt, _, _, shufflert =
413+ cost (ls, op, (unrolled, Symbol (" " )), vloopsym, Wshift, size_T)
414+ store_recip_throughput += srt - shufflert
415+ compute_recip_throughput += shufflert
404416 end
405417 end
406418 recip_throughput =
@@ -447,7 +459,7 @@ function determine_unroll_factor(
447459 elseif iszero (num_reductions) # handle `BitArray` loops w/out reductions
448460 return 8 ÷ ls. vector_width, vloopsym
449461 else # handle `BitArray` loops with reductions
450- rttemp, ltemp = determine_unroll_factor (ls, order, vloopsym, vloopsym)
462+ rttemp, ltemp = rthroughput_latency (ls, order, vloopsym, vloopsym)
451463 UF =
452464 min (8 , VectorizationBase. nextpow2 (max (1 , round (Int, ltemp / (rttemp)))))
453465 UFfactor = 8 ÷ ls. vector_width
@@ -471,7 +483,7 @@ function determine_unroll_factor(
471483 best_unrolled = Symbol (" " )
472484 for unrolled ∈ order
473485 reject_reorder (ls, unrolled, false ) && continue
474- rttemp, ltemp = determine_unroll_factor (ls, order, unrolled, vloopsym)
486+ rttemp, ltemp = rthroughput_latency (ls, order, unrolled, vloopsym)
475487 rtcomptemp =
476488 rttemp + (
477489 0.01 *
@@ -1156,23 +1168,23 @@ end
11561168
"""
    update_cost_vec!(costs, cost, u₁reduces, u₂reduces)

Add `cost` to one slot of `costs`, selected by the two flags, and return the updated
slot value.

| `u₁reduces` | `u₂reduces` | slot | meaning                                   |
|:-----------:|:-----------:|:----:|:------------------------------------------|
| `true`      | `true`      | 4    | both flags set                            |
| `false`     | `true`      | 2    | cost decreased by unrolling the u₂ loop   |
| `true`      | `false`     | 3    | cost decreased by unrolling the u₁ loop   |
| `false`     | `false`     | 1    | no cost decrease; cost must be repeated   |

The write is performed under `@inbounds`, so `costs` must have at least four elements.
"""
function update_cost_vec!(costs, cost, u₁reduces, u₂reduces)
    # Pick the destination slot first so there is a single indexed update below.
    slot = if u₁reduces
        u₂reduces ? 4 : 3
    else
        u₂reduces ? 2 : 1
    end
    return @inbounds costs[slot] += cost
end
"""
    update_reg_pres!(rp, cost, u₁reduces, u₂reduces)

Adjust the vector `rp` by `cost` and return the updated slot value.

- `u₁reduces` is `true`: **subtract** `cost` from `rp[4]`. `u₂reduces` is not consulted
  in this case.
- otherwise, `u₂reduces` is `true`: add `cost` to `rp[2]` (cost decreased by unrolling
  the u₂ loop).
- otherwise: add `cost` to `rp[1]` (no cost decrease; cost must be repeated).

The write is performed under `@inbounds`, so `rp` must have at least four elements.
"""
function update_reg_pres!(rp, cost, u₁reduces, u₂reduces)
    # u₁reduces takes precedence regardless of u₂reduces.
    if u₁reduces
        return @inbounds rp[4] -= cost
    end
    slot = u₂reduces ? 2 : 1
    return @inbounds rp[slot] += cost
end
11761188function child_dependent_u₁u₂ (op:: Operation )
11771189 u₁ = u₂ = false
11781190 for opc ∈ children (op)
0 commit comments