Skip to content

Commit ac0f590

Browse files
committed
Add comptime_float support to std.math.float functions
1 parent c479d05 commit ac0f590

File tree

1 file changed

+60
-47
lines changed

1 file changed

+60
-47
lines changed

lib/std/math/float.zig

Lines changed: 60 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -117,21 +117,28 @@ pub fn FloatRepr(comptime Float: type) type {
117117

118118
/// Creates a raw "1.0" mantissa for floating point type T: the explicitly stored integer bit for 80-bit floats, and 0 for every other type. Used to dedupe f80 logic.
119119
inline fn mantissaOne(comptime T: type) comptime_int {
120+
if (T == comptime_float) return 0;
120121
return if (@typeInfo(T).float.bits == 80) 1 << floatFractionalBits(T) else 0;
121122
}
122123

123124
/// Creates floating point type T from an unbiased exponent and raw mantissa.
124125
inline fn reconstructFloat(comptime T: type, comptime exponent: comptime_int, comptime mantissa: comptime_int) T {
125-
const TBits = @Type(.{ .int = .{ .signedness = .unsigned, .bits = @bitSizeOf(T) } });
126-
const biased_exponent = @as(TBits, exponent + floatExponentMax(T));
127-
return @as(T, @bitCast((biased_exponent << floatMantissaBits(T)) | @as(TBits, mantissa)));
126+
const UBits, const FBits = switch (@typeInfo(T)) {
127+
.float => |float| .{ std.meta.Int(.unsigned, float.bits), T },
128+
.comptime_float => .{ std.meta.Int(.unsigned, 128), f128 },
129+
else => unreachable,
130+
};
131+
const biased_exponent = @as(UBits, exponent + floatExponentMax(T));
132+
return @as(T, @as(FBits, @bitCast((biased_exponent << floatMantissaBits(T)) | @as(UBits, mantissa))));
128133
}
129134

130135
/// Returns the number of bits in the exponent of floating point type T.
131136
pub inline fn floatExponentBits(comptime T: type) comptime_int {
132-
comptime assert(@typeInfo(T) == .float);
137+
const info = @typeInfo(T);
138+
comptime assert(info == .float or info == .comptime_float);
133139

134-
return switch (@typeInfo(T).float.bits) {
140+
if (info == .comptime_float) return 15;
141+
return switch (info.float.bits) {
135142
16 => 5,
136143
32 => 8,
137144
64 => 11,
@@ -143,9 +150,11 @@ pub inline fn floatExponentBits(comptime T: type) comptime_int {
143150

144151
/// Returns the number of bits in the mantissa of floating point type T.
145152
pub inline fn floatMantissaBits(comptime T: type) comptime_int {
146-
comptime assert(@typeInfo(T) == .float);
153+
const info = @typeInfo(T);
154+
comptime assert(info == .float or info == .comptime_float);
147155

148-
return switch (@typeInfo(T).float.bits) {
156+
if (info == .comptime_float) return 112;
157+
return switch (info.float.bits) {
149158
16 => 10,
150159
32 => 23,
151160
64 => 52,
@@ -157,12 +166,14 @@ pub inline fn floatMantissaBits(comptime T: type) comptime_int {
157166

158167
/// Returns the number of fractional bits in the mantissa of floating point type T.
159168
pub inline fn floatFractionalBits(comptime T: type) comptime_int {
160-
comptime assert(@typeInfo(T) == .float);
169+
const info = @typeInfo(T);
170+
comptime assert(info == .float or info == .comptime_float);
161171

162172
// standard IEEE floats have an implicit 0.m or 1.m integer part
163173
// f80 is special and has an explicitly stored bit in the MSB
164174
// this function corresponds to `MANT_DIG - 1` from C
165-
return switch (@typeInfo(T).float.bits) {
175+
if (info == .comptime_float) return 112;
176+
return switch (info.float.bits) {
166177
16 => 10,
167178
32 => 23,
168179
64 => 52,
@@ -208,58 +219,58 @@ pub inline fn floatEps(comptime T: type) T {
208219
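/// Returns the local epsilon of floating point type T at `x`: the absolute difference between `x` and the value whose bit pattern differs from it only in the lowest bit.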
209220
pub inline fn floatEpsAt(comptime T: type, x: T) T {
210221
switch (@typeInfo(T)) {
211-
.float => |F| {
212-
const U: type = @Type(.{ .int = .{ .signedness = .unsigned, .bits = F.bits } });
222+
.float => |float| {
223+
const U = std.meta.Int(.unsigned, float.bits);
213224
const u: U = @bitCast(x);
214225
const y: T = @bitCast(u ^ 1);
215226
return @abs(x - y);
216227
},
228+
.comptime_float => {
229+
const u: u128 = @bitCast(@as(f128, x));
230+
const y: f128 = @bitCast(u ^ 1);
231+
return @as(comptime_float, @abs(x - y));
232+
},
217233
else => @compileError("floatEpsAt only supports floats"),
218234
}
219235
}
220236

221237
/// Returns the inf value for a floating point `Type`.
222238
pub inline fn inf(comptime Type: type) Type {
223-
const RuntimeType = switch (Type) {
224-
else => Type,
225-
comptime_float => f128, // any float type will do
239+
return switch (@typeInfo(Type)) {
240+
.float => reconstructFloat(Type, floatExponentMax(Type) + 1, mantissaOne(Type)),
241+
.comptime_float => @compileError("comptime_float cannot be infinity"),
242+
else => @compileError("unknown floating point type " ++ @typeName(Type)),
226243
};
227-
return reconstructFloat(RuntimeType, floatExponentMax(RuntimeType) + 1, mantissaOne(RuntimeType));
228244
}
229245

230246
/// Returns the canonical quiet NaN representation for a floating point `Type`.
231247
pub inline fn nan(comptime Type: type) Type {
232-
const RuntimeType = switch (Type) {
233-
else => Type,
234-
comptime_float => f128, // any float type will do
248+
return switch (@typeInfo(Type)) {
249+
.float => reconstructFloat(Type, floatExponentMax(Type) + 1, mantissaOne(Type) | 1 << (floatFractionalBits(Type) - 1)),
250+
.comptime_float => @compileError("comptime_float cannot be NaN"),
251+
else => @compileError("unknown floating point type " ++ @typeName(Type)),
235252
};
236-
return reconstructFloat(
237-
RuntimeType,
238-
floatExponentMax(RuntimeType) + 1,
239-
mantissaOne(RuntimeType) | 1 << (floatFractionalBits(RuntimeType) - 1),
240-
);
241253
}
242254

243255
/// Returns a signalling NaN representation for a floating point `Type`.
244256
///
245257
/// TODO: LLVM is known to miscompile this to a quiet NaN on some architectures -
246258
/// this is tracked by https://github.com/ziglang/zig/issues/14366
247259
pub inline fn snan(comptime Type: type) Type {
248-
const RuntimeType = switch (Type) {
249-
else => Type,
250-
comptime_float => f128, // any float type will do
260+
return switch (@typeInfo(Type)) {
261+
.float => reconstructFloat(Type, floatExponentMax(Type) + 1, mantissaOne(Type) | 1 << (floatFractionalBits(Type) - 2)),
262+
.comptime_float => @compileError("comptime_float cannot be NaN"),
263+
else => @compileError("unknown floating point type " ++ @typeName(Type)),
251264
};
252-
return reconstructFloat(
253-
RuntimeType,
254-
floatExponentMax(RuntimeType) + 1,
255-
mantissaOne(RuntimeType) | 1 << (floatFractionalBits(RuntimeType) - 2),
256-
);
257265
}
258266

259267
fn floatBits(comptime Type: type) !void {
260268
// (1 +) for the sign bit, since it is separate from the other bits
261269
const size = 1 + floatExponentBits(Type) + floatMantissaBits(Type);
262-
try expect(@bitSizeOf(Type) == size);
270+
if (@typeInfo(Type) == .float)
271+
try expect(@bitSizeOf(Type) == size)
272+
else
273+
try expect(128 == size);
263274
try expect(floatFractionalBits(Type) <= floatMantissaBits(Type));
264275

265276
// for machine epsilon, assert expmin <= -prec <= expmax
@@ -273,6 +284,8 @@ test floatBits {
273284
try floatBits(f80);
274285
try floatBits(f128);
275286
try floatBits(c_longdouble);
287+
try floatBits(comptime_float);
288+
try comptime floatBits(comptime_float);
276289
}
277290

278291
test inf {
@@ -281,11 +294,11 @@ test inf {
281294
const inf_u64: u64 = 0x7FF0000000000000;
282295
const inf_u80: u80 = 0x7FFF8000000000000000;
283296
const inf_u128: u128 = 0x7FFF0000000000000000000000000000;
284-
try expectEqual(inf_u16, @as(u16, @bitCast(inf(f16))));
285-
try expectEqual(inf_u32, @as(u32, @bitCast(inf(f32))));
286-
try expectEqual(inf_u64, @as(u64, @bitCast(inf(f64))));
287-
try expectEqual(inf_u80, @as(u80, @bitCast(inf(f80))));
288-
try expectEqual(inf_u128, @as(u128, @bitCast(inf(f128))));
297+
try expect(inf_u16 == @as(u16, @bitCast(inf(f16))));
298+
try expect(inf_u32 == @as(u32, @bitCast(inf(f32))));
299+
try expect(inf_u64 == @as(u64, @bitCast(inf(f64))));
300+
try expect(inf_u80 == @as(u80, @bitCast(inf(f80))));
301+
try expect(inf_u128 == @as(u128, @bitCast(inf(f128))));
289302
}
290303

291304
test nan {
@@ -294,11 +307,11 @@ test nan {
294307
const qnan_u64: u64 = 0x7FF8000000000000;
295308
const qnan_u80: u80 = 0x7FFFC000000000000000;
296309
const qnan_u128: u128 = 0x7FFF8000000000000000000000000000;
297-
try expectEqual(qnan_u16, @as(u16, @bitCast(nan(f16))));
298-
try expectEqual(qnan_u32, @as(u32, @bitCast(nan(f32))));
299-
try expectEqual(qnan_u64, @as(u64, @bitCast(nan(f64))));
300-
try expectEqual(qnan_u80, @as(u80, @bitCast(nan(f80))));
301-
try expectEqual(qnan_u128, @as(u128, @bitCast(nan(f128))));
310+
try expect(qnan_u16 == @as(u16, @bitCast(nan(f16))));
311+
try expect(qnan_u32 == @as(u32, @bitCast(nan(f32))));
312+
try expect(qnan_u64 == @as(u64, @bitCast(nan(f64))));
313+
try expect(qnan_u80 == @as(u80, @bitCast(nan(f80))));
314+
try expect(qnan_u128 == @as(u128, @bitCast(nan(f128))));
302315
}
303316

304317
test snan {
@@ -307,9 +320,9 @@ test snan {
307320
const snan_u64: u64 = 0x7FF4000000000000;
308321
const snan_u80: u80 = 0x7FFFA000000000000000;
309322
const snan_u128: u128 = 0x7FFF4000000000000000000000000000;
310-
try expectEqual(snan_u16, @as(u16, @bitCast(snan(f16))));
311-
try expectEqual(snan_u32, @as(u32, @bitCast(snan(f32))));
312-
try expectEqual(snan_u64, @as(u64, @bitCast(snan(f64))));
313-
try expectEqual(snan_u80, @as(u80, @bitCast(snan(f80))));
314-
try expectEqual(snan_u128, @as(u128, @bitCast(snan(f128))));
323+
try expect(snan_u16 == @as(u16, @bitCast(snan(f16))));
324+
try expect(snan_u32 == @as(u32, @bitCast(snan(f32))));
325+
try expect(snan_u64 == @as(u64, @bitCast(snan(f64))));
326+
try expect(snan_u80 == @as(u80, @bitCast(snan(f80))));
327+
try expect(snan_u128 == @as(u128, @bitCast(snan(f128))));
315328
}

0 commit comments

Comments
 (0)