Skip to content

Commit 6f475e8

Browse files
iceychriserksch
authored andcommitted
Add tests for boolean iValues
1 parent 27afa02 commit 6f475e8

File tree

2 files changed

+54
-0
lines changed

2 files changed

+54
-0
lines changed

build_dummy_model.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,18 @@ def identity_tensor(self, x: torch.Tensor):
2929
def identity_long(self, x: int):
3030
return x
3131

32+
@torch.jit.export
33+
def identity_bool(self, x: bool):
34+
return x
35+
3236
@torch.jit.export
3337
def identity_list(self, x: List[int]):
3438
return x
3539

40+
@torch.jit.export
41+
def identity_bool_list(self, x: List[bool]):
42+
return x
43+
3644
@torch.jit.export
3745
def identity_tuple(self, x: Tuple[int]):
3846
return x

src/iosX64Test/kotlin/de/voize/pytorch_lite_multiplatform/TorchModuleIOSTest.kt

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,52 @@ class TorchModuleIOSTest {
188188
}
189189
}
190190

191+
@Test
192+
fun testIdentityBool() {
193+
plmScoped {
194+
val module = TorchModule(localModulePath)
195+
val input = IValue.from(false)
196+
val output = module.runMethod(
197+
"identity_bool",
198+
input
199+
)
200+
assertEquals(false, output.toBool())
201+
202+
val input2 = IValue.from(true)
203+
val output2 = module.runMethod(
204+
"identity_bool",
205+
input2
206+
)
207+
assertEquals(true, output2.toBool())
208+
}
209+
}
210+
211+
@Test
212+
fun testIdentityBoolList() {
213+
plmScoped {
214+
val module = TorchModule(localModulePath)
215+
val input = IValue.listFrom(IValue.from(true), IValue.from(false))
216+
val output = module.runMethod(
217+
"identity_bool_list",
218+
input
219+
)
220+
assertEquals(listOf(true, false), output.toBoolList())
221+
}
222+
}
223+
224+
@Test
225+
fun testIdentityBoolList2() {
226+
plmScoped {
227+
val module = TorchModule(localModulePath)
228+
val input = IValue.listFrom(true, false, scope = this)
229+
val output = module.runMethod(
230+
"identity_bool_list",
231+
input
232+
)
233+
assertEquals(listOf(true, false), output.toBoolList())
234+
}
235+
}
236+
191237
@Test
192238
fun testIdentityTensor() {
193239
plmScoped {

0 commit comments

Comments
 (0)