|
20 | 20 | Base.BroadcastStyle(::Type{<:DArray}) = Broadcast.ArrayStyle{DArray}() |
21 | 21 | Base.BroadcastStyle(::Type{<:DArray}, ::Any) = Broadcast.ArrayStyle{DArray}() |
22 | 22 |
|
23 | | -function Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{DArray}}, ::Type{ElType}) where {ElType} |
24 | | - DA = find_darray(bc) |
25 | | - DArray(I -> Array{ElType}(undef, map(length,I)), DA) |
26 | | -end |
27 | | - |
28 | | -"`DA = find_darray(As)` returns the first DArray among the arguments." |
29 | | -find_darray(bc::Base.Broadcast.Broadcasted) = find_darray(bc.args) |
30 | | -find_darray(args::Tuple) = find_darray(find_darray(args[1]), Base.tail(args)) |
31 | | -find_darray(x) = x |
32 | | -find_darray(a::DArray, rest) = a |
33 | | -find_darray(::Any, rest) = find_darray(rest) |
34 | | - |
35 | | -function Base.copyto!(dest::DArray, bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{DArray}}) |
36 | | - @sync for p in procs(dest) |
37 | | - @async remotecall_fetch(p) do |
38 | | - copyto!(localpart(dest), rewrite_local(bc)) |
39 | | - end |
| 23 | +function Base.copy(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{DArray}}) |
| 24 | + T = Base.Broadcast.combine_eltypes(bc.f, bc.args) |
| 25 | + shape = Base.Broadcast.combine_axes(bc.args...) |
| 26 | + iter = Base.CartesianIndices(shape) |
| 27 | + D = DArray(map(length, shape)) do I |
| 28 | + A = map(bc.args) do a |
| 29 | + if isa(a, Union{Number,Ref}) |
| 30 | + return a |
| 31 | + else |
| 32 | + return localtype(a)( |
| 33 | + a[ntuple(i -> i > ndims(a) ? 1 : (size(a, i) == 1 ? (1:1) : I[i]), length(shape))...] |
| 34 | + ) |
| 35 | + end |
| 36 | + end |
| 37 | + broadcast(bc.f, A...) |
40 | 38 | end |
41 | | - dest |
| 39 | + return D |
42 | 40 | end |
43 | 41 |
|
44 | | -""" |
45 | | -Transform a Broadcasted{Broadcast.ArrayStyle{DArray}} object into an equivalent |
46 | | -Broadcasted{Broadcast.DefaultArrayStyle} object for the localparts. |
47 | | -""" |
48 | | -rewrite_local(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{DArray}}) = Broadcast.broadcasted(bc.f, rewrite_local(bc.args)...) |
49 | | -rewrite_local(args::Tuple) = map(rewrite_local, args) |
50 | | -rewrite_local(a::DArray) = localpart(a) |
51 | | -rewrite_local(x) = x |
52 | | - |
53 | | - |
54 | 42 | function Base.reduce(f, d::DArray) |
55 | 43 | results = asyncmap(procs(d)) do p |
56 | 44 | remotecall_fetch(p) do |
@@ -128,6 +116,7 @@ function Base.mapreducedim!(f, op, R::DArray, A::DArray) |
128 | 116 | return mapreducedim_between!(identity, op, R, B, region) |
129 | 117 | end |
130 | 118 |
|
| 119 | +## Some special cases |
131 | 120 | function Base._all(f, A::DArray, ::Colon) |
132 | 121 | B = asyncmap(procs(A)) do p |
133 | 122 | remotecall_fetch(p) do |
@@ -171,6 +160,8 @@ function Base.extrema(d::DArray) |
171 | 160 | return reduce((t,s) -> (min(t[1], s[1]), max(t[2], s[2])), r) |
172 | 161 | end |
173 | 162 |
|
| 163 | +Statistics._mean(A::DArray, region) = sum(A, dims = region) ./ prod((size(A, i) for i in region)) |
| 164 | + |
174 | 165 | # Unary vector functions |
175 | 166 | (-)(D::DArray) = map(-, D) |
176 | 167 |
|
|
0 commit comments