Skip to content

Commit bf267c3

Browse files
committed
Fix ConstantSpec and use overloading for const casts
1 parent 8a00fa7 commit bf267c3

File tree

2 files changed

+169
-97
lines changed

2 files changed

+169
-97
lines changed

Sources/LLVM/Constant.swift

Lines changed: 166 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -43,25 +43,30 @@ public struct Constant<Repr: ConstantRepresentation>: IRValue {
4343

4444
// MARK: Casting
4545

46-
extension Constant where Repr == Unsigned {
46+
extension Constant where Repr: IntegralConstantRepresentation {
4747

4848
/// Creates a constant cast to a given integral type.
4949
///
5050
/// - parameter type: The type to cast towards.
5151
///
5252
/// - returns: A const value representing this value cast to the given
5353
/// integral type.
54-
public func cast<T: IntegralConstantRepresentation>(to type: IntType) -> Constant<T> {
55-
let destID = ObjectIdentifier(T.self)
56-
let val = self.asLLVM()
57-
if destID == ObjectIdentifier(Unsigned.self) {
58-
return Constant<T>(llvm: LLVMConstIntCast(val, type.asLLVM(), /*signed:*/ false.llvm))
59-
} else if destID == ObjectIdentifier(Signed.self) {
60-
return Constant<T>(llvm: LLVMConstIntCast(val, type.asLLVM(), /*signed:*/ true.llvm))
61-
} else {
62-
fatalError("Invalid representation \(type(of: T.self))")
63-
}
54+
public func cast(to type: IntType) -> Constant<Signed> {
55+
return Constant<Signed>(llvm: LLVMConstIntCast(llvm, type.asLLVM(), /*signed:*/ true.llvm))
56+
}
57+
58+
/// Creates a constant cast to a given integral type.
59+
///
60+
/// - parameter type: The type to cast towards.
61+
///
62+
/// - returns: A const value representing this value cast to the given
63+
/// integral type.
64+
public func cast(to type: IntType) -> Constant<Unsigned> {
65+
return Constant<Unsigned>(llvm: LLVMConstIntCast(llvm, type.asLLVM(), /*signed:*/ false.llvm))
6466
}
67+
}
68+
69+
extension Constant where Repr == Unsigned {
6570

6671
/// Creates a constant cast to a given floating type.
6772
///
@@ -70,8 +75,7 @@ extension Constant where Repr == Unsigned {
7075
/// - returns: A const value representing this value cast to the given
7176
/// floating type.
7277
public func cast(to type: FloatType) -> Constant<Floating> {
73-
let val = self.asLLVM()
74-
return Constant<Floating>(llvm: LLVMConstUIToFP(val, type.asLLVM()))
78+
return Constant<Floating>(llvm: LLVMConstUIToFP(llvm, type.asLLVM()))
7579
}
7680
}
7781

@@ -83,16 +87,18 @@ extension Constant where Repr == Signed {
8387
///
8488
/// - returns: A const value representing this value cast to the given
8589
/// integral type.
86-
public func cast<T: IntegralConstantRepresentation>(to type: IntType) -> Constant<T> {
87-
let destID = ObjectIdentifier(T.self)
88-
let val = self.asLLVM()
89-
if destID == ObjectIdentifier(Unsigned.self) {
90-
return Constant<T>(llvm: LLVMConstIntCast(val, type.asLLVM(), /*signed:*/ false.llvm))
91-
} else if destID == ObjectIdentifier(Signed.self) {
92-
return Constant<T>(llvm: LLVMConstIntCast(val, type.asLLVM(), /*signed:*/ true.llvm))
93-
} else {
94-
fatalError("Invalid representation \(type(of: T.self))")
95-
}
90+
public func cast(to type: IntType) -> Constant<Signed> {
91+
return Constant<Signed>(llvm: LLVMConstIntCast(llvm, type.asLLVM(), /*signed:*/ true.llvm))
92+
}
93+
94+
/// Creates a constant cast to a given integral type.
95+
///
96+
/// - parameter type: The type to cast towards.
97+
///
98+
/// - returns: A const value representing this value cast to the given
99+
/// integral type.
100+
public func cast(to type: IntType) -> Constant<Unsigned> {
101+
return Constant<Unsigned>(llvm: LLVMConstIntCast(llvm, type.asLLVM(), /*signed:*/ false.llvm))
96102
}
97103

98104
/// Creates a constant cast to a given floating type.
@@ -115,16 +121,18 @@ extension Constant where Repr == Floating {
115121
///
116122
/// - returns: A const value representing this value cast to the given
117123
/// integral type.
118-
public func cast<T: IntegralConstantRepresentation>(to type: IntType) -> Constant<T> {
119-
let destID = ObjectIdentifier(T.self)
120-
let val = self.asLLVM()
121-
if destID == ObjectIdentifier(Unsigned.self) {
122-
return Constant<T>(llvm: LLVMConstFPToUI(val, type.asLLVM()))
123-
} else if destID == ObjectIdentifier(Signed.self) {
124-
return Constant<T>(llvm: LLVMConstFPToSI(val, type.asLLVM()))
125-
} else {
126-
fatalError("Invalid representation \(type(of: T.self))")
127-
}
124+
public func cast(to type: IntType) -> Constant<Signed> {
125+
return Constant<Signed>(llvm: LLVMConstFPToSI(llvm, type.asLLVM()))
126+
}
127+
128+
/// Creates a constant cast to a given integral type.
129+
///
130+
/// - parameter type: The type to cast towards.
131+
///
132+
/// - returns: A const value representing this value cast to the given
133+
/// integral type.
134+
public func cast(to type: IntType) -> Constant<Unsigned> {
135+
return Constant<Unsigned>(llvm: LLVMConstFPToUI(llvm, type.asLLVM()))
128136
}
129137

130138
/// Creates a constant cast to a given floating type.
@@ -140,6 +148,8 @@ extension Constant where Repr == Floating {
140148
}
141149

142150

151+
// NOTE: These are here to improve the error message should a user attempt to cast a const struct
152+
143153
extension Constant where Repr == Struct {
144154

145155
@available(*, unavailable, message: "You cannot cast an aggregate type. See the LLVM Reference manual's section on `bitcast`")
@@ -666,7 +676,7 @@ extension Constant where Repr == Floating {
666676

667677
// MARK: Comparison Operations
668678

669-
extension Constant {
679+
extension Constant where Repr: IntegralConstantRepresentation {
670680

671681
/// A constant equality comparison between two values.
672682
///
@@ -675,18 +685,12 @@ extension Constant {
675685
///
676686
/// - returns: A constant integral value (i1) representing the result of the
677687
/// comparision of the given operands.
678-
public static func equals<T: NumericalConstantRepresentation>(_ lhs: Constant<T>, _ rhs: Constant<T>) -> Constant<Signed> {
679-
680-
switch ObjectIdentifier(T.self) {
681-
case ObjectIdentifier(Unsigned.self): fallthrough
682-
case ObjectIdentifier(Signed.self):
683-
return Constant<Signed>(llvm: LLVMConstICmp(IntPredicate.equal.llvm, lhs.llvm, rhs.llvm))
684-
case ObjectIdentifier(Floating.self):
685-
return Constant<Signed>(llvm: LLVMConstFCmp(RealPredicate.orderedEqual.llvm, lhs.llvm, rhs.llvm))
686-
default:
687-
fatalError("Invalid representation")
688-
}
688+
public static func equals(_ lhs: Constant, _ rhs: Constant) -> Constant<Signed> {
689+
return Constant<Signed>(llvm: LLVMConstICmp(IntPredicate.equal.llvm, lhs.llvm, rhs.llvm))
689690
}
691+
}
692+
693+
extension Constant where Repr == Signed {
690694

691695
/// A constant less-than comparison between two values.
692696
///
@@ -695,18 +699,8 @@ extension Constant {
695699
///
696700
/// - returns: A constant integral value (i1) representing the result of the
697701
/// comparision of the given operands.
698-
public static func lessThan<T: NumericalConstantRepresentation>(_ lhs: Constant<T>, _ rhs: Constant<T>) -> Constant<Signed> {
699-
700-
switch ObjectIdentifier(T.self) {
701-
case ObjectIdentifier(Unsigned.self):
702-
return Constant<Signed>(llvm: LLVMConstICmp(IntPredicate.unsignedLessThan.llvm, lhs.llvm, rhs.llvm))
703-
case ObjectIdentifier(Signed.self):
704-
return Constant<Signed>(llvm: LLVMConstICmp(IntPredicate.signedLessThan.llvm, lhs.llvm, rhs.llvm))
705-
case ObjectIdentifier(Floating.self):
706-
return Constant<Signed>(llvm: LLVMConstFCmp(RealPredicate.orderedLessThan.llvm, lhs.llvm, rhs.llvm))
707-
default:
708-
fatalError("Invalid representation")
709-
}
702+
public static func lessThan(_ lhs: Constant, _ rhs: Constant) -> Constant<Signed> {
703+
return Constant<Signed>(llvm: LLVMConstICmp(IntPredicate.unsignedLessThan.llvm, lhs.llvm, rhs.llvm))
710704
}
711705

712706
/// A constant greater-than comparison between two values.
@@ -716,18 +710,8 @@ extension Constant {
716710
///
717711
/// - returns: A constant integral value (i1) representing the result of the
718712
/// comparision of the given operands.
719-
public static func greaterThan<T: NumericalConstantRepresentation>(_ lhs: Constant<T>, _ rhs: Constant<T>) -> Constant<Signed> {
720-
721-
switch ObjectIdentifier(T.self) {
722-
case ObjectIdentifier(Unsigned.self):
723-
return Constant<Signed>(llvm: LLVMConstICmp(IntPredicate.unsignedGreaterThan.llvm, lhs.llvm, rhs.llvm))
724-
case ObjectIdentifier(Signed.self):
725-
return Constant<Signed>(llvm: LLVMConstICmp(IntPredicate.signedGreaterThan.llvm, lhs.llvm, rhs.llvm))
726-
case ObjectIdentifier(Floating.self):
727-
return Constant<Signed>(llvm: LLVMConstFCmp(RealPredicate.orderedGreaterThan.llvm, lhs.llvm, rhs.llvm))
728-
default:
729-
fatalError("Invalid representation")
730-
}
713+
public static func greaterThan(_ lhs: Constant, _ rhs: Constant) -> Constant<Signed> {
714+
return Constant<Signed>(llvm: LLVMConstICmp(IntPredicate.signedGreaterThan.llvm, lhs.llvm, rhs.llvm))
731715
}
732716

733717
/// A constant less-than-or-equal comparison between two values.
@@ -737,18 +721,8 @@ extension Constant {
737721
///
738722
/// - returns: A constant integral value (i1) representing the result of the
739723
/// comparision of the given operands.
740-
public static func lessThanOrEqual <T: NumericalConstantRepresentation>(_ lhs: Constant<T>, _ rhs: Constant<T>) -> Constant<Signed> {
741-
742-
switch ObjectIdentifier(T.self) {
743-
case ObjectIdentifier(Unsigned.self):
744-
return Constant<Signed>(llvm: LLVMConstICmp(IntPredicate.unsignedLessThanOrEqual.llvm, lhs.llvm, rhs.llvm))
745-
case ObjectIdentifier(Signed.self):
746-
return Constant<Signed>(llvm: LLVMConstICmp(IntPredicate.signedLessThanOrEqual.llvm, lhs.llvm, rhs.llvm))
747-
case ObjectIdentifier(Floating.self):
748-
return Constant<Signed>(llvm: LLVMConstFCmp(RealPredicate.orderedLessThanOrEqual.llvm, lhs.llvm, rhs.llvm))
749-
default:
750-
fatalError("Invalid representation")
751-
}
724+
public static func lessThanOrEqual(_ lhs: Constant, _ rhs: Constant) -> Constant<Signed> {
725+
return Constant<Signed>(llvm: LLVMConstICmp(IntPredicate.signedLessThanOrEqual.llvm, lhs.llvm, rhs.llvm))
752726
}
753727

754728
/// A constant greater-than-or-equal comparison between two values.
@@ -758,22 +732,120 @@ extension Constant {
758732
///
759733
/// - returns: A constant integral value (i1) representing the result of the
760734
/// comparision of the given operands.
761-
public static func greaterThanOrEqual <T: NumericalConstantRepresentation>(_ lhs: Constant<T>, _ rhs: Constant<T>) -> Constant<Signed> {
762-
763-
switch ObjectIdentifier(T.self) {
764-
case ObjectIdentifier(Unsigned.self):
765-
return Constant<Signed>(llvm: LLVMConstICmp(IntPredicate.unsignedGreaterThanOrEqual.llvm, lhs.llvm, rhs.llvm))
766-
case ObjectIdentifier(Signed.self):
767-
return Constant<Signed>(llvm: LLVMConstICmp(IntPredicate.signedGreaterThanOrEqual.llvm, lhs.llvm, rhs.llvm))
768-
case ObjectIdentifier(Floating.self):
769-
return Constant<Signed>(llvm: LLVMConstFCmp(RealPredicate.orderedGreaterThanOrEqual.llvm, lhs.llvm, rhs.llvm))
770-
default:
771-
fatalError("Invalid representation")
772-
}
735+
public static func greaterThanOrEqual(_ lhs: Constant, _ rhs: Constant) -> Constant<Signed> {
736+
return Constant<Signed>(llvm: LLVMConstICmp(IntPredicate.signedGreaterThanOrEqual.llvm, lhs.llvm, rhs.llvm))
737+
}
738+
}
739+
740+
extension Constant where Repr == Unsigned {
741+
742+
/// A constant less-than comparison between two values.
743+
///
744+
/// - parameter lhs: The first value to compare.
745+
/// - parameter rhs: The second value to compare.
746+
///
747+
/// - returns: A constant integral value (i1) representing the result of the
748+
/// comparision of the given operands.
749+
public static func lessThan(_ lhs: Constant, _ rhs: Constant) -> Constant<Signed> {
750+
return Constant<Signed>(llvm: LLVMConstICmp(IntPredicate.unsignedLessThan.llvm, lhs.llvm, rhs.llvm))
751+
}
752+
753+
/// A constant greater-than comparison between two values.
754+
///
755+
/// - parameter lhs: The first value to compare.
756+
/// - parameter rhs: The second value to compare.
757+
///
758+
/// - returns: A constant integral value (i1) representing the result of the
759+
/// comparision of the given operands.
760+
public static func greaterThan(_ lhs: Constant, _ rhs: Constant) -> Constant<Signed> {
761+
return Constant<Signed>(llvm: LLVMConstICmp(IntPredicate.unsignedGreaterThan.llvm, lhs.llvm, rhs.llvm))
773762
}
774763

764+
/// A constant less-than-or-equal comparison between two values.
765+
///
766+
/// - parameter lhs: The first value to compare.
767+
/// - parameter rhs: The second value to compare.
768+
///
769+
/// - returns: A constant integral value (i1) representing the result of the
770+
/// comparision of the given operands.
771+
public static func lessThanOrEqual(_ lhs: Constant, _ rhs: Constant) -> Constant<Signed> {
772+
return Constant<Signed>(llvm: LLVMConstICmp(IntPredicate.unsignedLessThanOrEqual.llvm, lhs.llvm, rhs.llvm))
773+
}
774+
775+
/// A constant greater-than-or-equal comparison between two values.
776+
///
777+
/// - parameter lhs: The first value to compare.
778+
/// - parameter rhs: The second value to compare.
779+
///
780+
/// - returns: A constant integral value (i1) representing the result of the
781+
/// comparision of the given operands.
782+
public static func greaterThanOrEqual(_ lhs: Constant, _ rhs: Constant) -> Constant<Signed> {
783+
return Constant<Signed>(llvm: LLVMConstICmp(IntPredicate.unsignedGreaterThanOrEqual.llvm, lhs.llvm, rhs.llvm))
784+
}
785+
}
775786

776-
// MARK: Logical Operations
787+
extension Constant where Repr == Floating {
788+
789+
/// A constant equality comparison between two values.
790+
///
791+
/// - parameter lhs: The first value to compare.
792+
/// - parameter rhs: The second value to compare.
793+
///
794+
/// - returns: A constant integral value (i1) representing the result of the
795+
/// comparision of the given operands.
796+
public static func equals(_ lhs: Constant, _ rhs: Constant) -> Constant<Signed> {
797+
return Constant<Signed>(llvm: LLVMConstFCmp(RealPredicate.orderedEqual.llvm, lhs.llvm, rhs.llvm))
798+
}
799+
800+
/// A constant less-than comparison between two values.
801+
///
802+
/// - parameter lhs: The first value to compare.
803+
/// - parameter rhs: The second value to compare.
804+
///
805+
/// - returns: A constant integral value (i1) representing the result of the
806+
/// comparision of the given operands.
807+
public static func lessThan(_ lhs: Constant, _ rhs: Constant) -> Constant<Signed> {
808+
return Constant<Signed>(llvm: LLVMConstFCmp(RealPredicate.orderedLessThan.llvm, lhs.llvm, rhs.llvm))
809+
}
810+
811+
/// A constant greater-than comparison between two values.
812+
///
813+
/// - parameter lhs: The first value to compare.
814+
/// - parameter rhs: The second value to compare.
815+
///
816+
/// - returns: A constant integral value (i1) representing the result of the
817+
/// comparision of the given operands.
818+
public static func greaterThan(_ lhs: Constant, _ rhs: Constant) -> Constant<Signed> {
819+
return Constant<Signed>(llvm: LLVMConstFCmp(RealPredicate.orderedGreaterThan.llvm, lhs.llvm, rhs.llvm))
820+
}
821+
822+
/// A constant less-than-or-equal comparison between two values.
823+
///
824+
/// - parameter lhs: The first value to compare.
825+
/// - parameter rhs: The second value to compare.
826+
///
827+
/// - returns: A constant integral value (i1) representing the result of the
828+
/// comparision of the given operands.
829+
public static func lessThanOrEqual(_ lhs: Constant, _ rhs: Constant) -> Constant<Signed> {
830+
return Constant<Signed>(llvm: LLVMConstFCmp(RealPredicate.orderedLessThanOrEqual.llvm, lhs.llvm, rhs.llvm))
831+
}
832+
833+
/// A constant greater-than-or-equal comparison between two values.
834+
///
835+
/// - parameter lhs: The first value to compare.
836+
/// - parameter rhs: The second value to compare.
837+
///
838+
/// - returns: A constant integral value (i1) representing the result of the
839+
/// comparision of the given operands.
840+
public static func greaterThanOrEqual(_ lhs: Constant, _ rhs: Constant) -> Constant<Signed> {
841+
return Constant<Signed>(llvm: LLVMConstFCmp(RealPredicate.orderedGreaterThanOrEqual.llvm, lhs.llvm, rhs.llvm))
842+
}
843+
}
844+
845+
846+
// MARK: Logical Operations
847+
848+
extension Constant {
777849

778850
/// A constant bitwise logical not with the given integral value as an operand.
779851
///

Tests/LLVMTests/ConstantSpec.swift

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class ConstantSpec : XCTestCase {
2222
// SIGNEDCONST-NOT: %{{[0-9]+}} = add i64 %%{{[0-9]+}}, %%{{[0-9]+}}
2323
let val1 = builder.buildAdd(constant.adding(constant), constant.multiplying(constant))
2424
// SIGNEDCONST-NOT: %{{[0-9]+}} = sub i64 %%{{[0-9]+}}, %%{{[0-9]+}}
25-
let val2 = builder.buildSub(constant.subtracting(constant), constant.dividing(constant))
25+
let val2 = builder.buildSub(constant.subtracting(constant), constant.dividing(by: constant))
2626
// SIGNEDCONST-NOT: %{{[0-9]+}} = mul i64 %%{{[0-9]+}}, %%{{[0-9]+}}
2727
let val3 = builder.buildMul(val1, val2)
2828
// SIGNEDCONST-NOT: %{{[0-9]+}} = mul i64 %%{{[0-9]+}}, %%{{[0-9]+}}
@@ -52,7 +52,7 @@ class ConstantSpec : XCTestCase {
5252
// UNSIGNEDCONST-NOT: %{{[0-9]+}} = add i64 %%{{[0-9]+}}, %%{{[0-9]+}}
5353
let val1 = builder.buildAdd(constant.adding(constant), constant.multiplying(constant))
5454
// UNSIGNEDCONST-NOT: %{{[0-9]+}} = sub i64 %%{{[0-9]+}}, %%{{[0-9]+}}
55-
let val2 = builder.buildSub(constant.subtracting(constant), constant.dividing(constant))
55+
let val2 = builder.buildSub(constant.subtracting(constant), constant.dividing(by: constant))
5656
// UNSIGNEDCONST-NOT: %{{[0-9]+}} = mul i64 %%{{[0-9]+}}, %%{{[0-9]+}}
5757
let val3 = builder.buildMul(val1, val2)
5858

@@ -80,7 +80,7 @@ class ConstantSpec : XCTestCase {
8080
// FLOATINGCONST-NOT: %{{[0-9]+}} = add double %%{{[0-9]+}}, %%{{[0-9]+}}
8181
let val1 = builder.buildAdd(constant.adding(constant), constant.multiplying(constant))
8282
// FLOATINGCONST-NOT: %{{[0-9]+}} = sub double %%{{[0-9]+}}, %%{{[0-9]+}}
83-
let val2 = builder.buildSub(constant.subtracting(constant), constant.dividing(constant))
83+
let val2 = builder.buildSub(constant.subtracting(constant), constant.dividing(by: constant))
8484
// FLOATINGCONST-NOT: %{{[0-9]+}} = mul double %%{{[0-9]+}}, %%{{[0-9]+}}
8585
let val3 = builder.buildMul(val1, val2)
8686

0 commit comments

Comments
 (0)