44
55mutable struct ConcreteRArray{T,N} <: RArray{T,N}
66 data:: XLA.AsyncBuffer
7- # data::XLAArray{T, N}
7+ # data::XLAArray{T, N}
88 shape:: NTuple{N,Int}
99end
1010
11- ConcreteRArray (data:: T ) where {T<: Number } = ConcreteRArray (fill (data))
11+ mutable struct ConcreteRNumber{T} <: RNumber{T}
12+ data:: XLA.AsyncBuffer
13+ end
14+
15+ function ConcreteRNumber (
16+ data:: T ; client= XLA. default_backend[], idx= XLA. default_device_idx[]
17+ ) where {T<: Number }
18+ crarray = ConcreteRArray (fill (data); client, idx)
19+ return ConcreteRNumber {T} (crarray. data)
20+ end
21+
22+ Base. size (:: ConcreteRNumber ) = ()
23+
24+ function ConcreteRArray (
25+ data:: T ; client= XLA. default_backend[], idx= XLA. default_device_idx[]
26+ ) where {T<: Number }
27+ Base. depwarn (
28+ " ConcreteRArray(data::Number) is deprecated, use ConcreteRNumber(data) instead" ,
29+ :ConcreteRArray ,
30+ )
31+ return ConcreteRArray (fill (data); client, idx)
32+ end
33+
34+ const ConcreteRScalar{T} = Union{ConcreteRArray{T,0 },ConcreteRNumber{T}}
1235
1336Adapt. adapt_storage (:: Type{T} , x:: AbstractArray ) where {T<: ConcreteRArray } = T (x)
1437
@@ -48,7 +71,7 @@ function Base.convert(::Type{T}, X::ConcreteRArray{ElType,N}) where {T<:Array,El
4871 # XLA.from_row_major(data)
4972end
5073
51- function synchronize (x:: ConcreteRArray )
74+ function synchronize (x:: Union{ ConcreteRArray,ConcreteRNumber} )
5275 XLA. synced_buffer (x. data)
5376 return nothing
5477end
6083# return ConcreteRArray{T,N}(x.data)
6184# end
6285
63- function to_float (X:: ConcreteRArray{T,0 } ) where {T}
86+ function to_number (X:: ConcreteRScalar{T } ) where {T}
6487 data = Ref {T} ()
6588 XLA. await (X. data)
6689 buf = X. data. buffer
@@ -70,36 +93,49 @@ function to_float(X::ConcreteRArray{T,0}) where {T}
7093 return data[]
7194end
7295
73- function Base. convert (:: Type{T} , x:: ConcreteRArray{T,0} ) where {T}
74- return to_float (x)
96+ Base. convert (:: Type{T} , x:: ConcreteRScalar{T} ) where {T} = to_number (x)
97+
98+ for jlop in (
99+ :(Base. isless),
100+ :(Base.:+ ),
101+ :(Base.:- ),
102+ :(Base.:* ),
103+ :(Base.:/ ),
104+ :(Base.:^ ),
105+ :(Base.:(== )),
106+ ),
107+ T in (ConcreteRNumber, ConcreteRArray{<: Any ,0 })
108+
109+ @eval begin
110+ $ (jlop)(x:: $ (T), y:: $ (T)) = $ (jlop)(to_number (x), to_number (y))
111+ $ (jlop)(x:: $ (T), y:: Number ) = $ (jlop)(to_number (x), y)
112+ $ (jlop)(x:: Number , y:: $ (T)) = $ (jlop)(x, to_number (y))
113+ end
75114end
76115
77- for jlop in (:(Base . isless), :(Base.: + ), :(Base.: - ), :(Base.: * ), :(Base.: / ), :(Base.: ^ ) )
116+ for T in (ConcreteRNumber, ConcreteRArray{ <: Any , 0 } )
78117 @eval begin
79- function $jlop (x:: ConcreteRArray{T,0} , y:: ConcreteRArray{U,0} ) where {T,U}
80- return $ jlop ( to_float (x), to_float (y) )
118+ function Base . isapprox (x:: $ (T), y:: Number ; kwargs ... )
119+ return Base . isapprox ( to_number (x), y; kwargs ... )
81120 end
82- function $jlop (x:: ConcreteRArray{T,0} , y) where {T}
83- return $ jlop (to_float (x), y)
121+
122+ function Base. isapprox (x:: Number , y:: $ (T); kwargs... )
123+ return Base. isapprox (x, to_number (y); kwargs... )
84124 end
85- function $jlop (x, y:: ConcreteRArray{U,0} ) where {U}
86- return $ jlop (x, to_float (y))
125+
126+ function Base. isapprox (x:: $ (T), y:: $ (T); kwargs... )
127+ return Base. isapprox (to_number (x), to_number (y); kwargs... )
87128 end
88129 end
89130end
90131
91- function Base. isapprox (x:: ConcreteRArray{T,0} , y; kwargs... ) where {T}
92- return Base. isapprox (to_float (x), y; kwargs... )
93- end
94-
95- function Base. isapprox (x, y:: ConcreteRArray{T,0} ; kwargs... ) where {T}
96- return Base. isapprox (x, to_float (y); kwargs... )
97- end
98-
99- function Base. isapprox (
100- x:: ConcreteRArray{T,0} , y:: ConcreteRArray{T2,0} ; kwargs...
101- ) where {T,T2}
102- return Base. isapprox (to_float (x), to_float (y); kwargs... )
132+ function Base. show (io:: IO , X:: ConcreteRScalar{T} ) where {T}
133+ if X. data == XLA. AsyncEmptyBuffer
134+ println (io, " <Empty buffer>" )
135+ return nothing
136+ end
137+ str = sprint (show, to_number (X))
138+ return print (io, " $(typeof (X)) ($(str) )" )
103139end
104140
105141function Base. print_array (io:: IO , X:: ConcreteRArray )
@@ -115,7 +151,8 @@ function Base.show(io::IO, X::ConcreteRArray)
115151 println (io, " <Empty buffer>" )
116152 return nothing
117153 end
118- return Base. show (io, convert (Array, X))
154+ str = sprint (show, convert (Array, X))
155+ return print (io, " $(typeof (X)) ($(str) )" )
119156end
120157
121158const getindex_warned = Ref (false )
0 commit comments