Skip to content

Commit c467bc8

Browse files
Implement serde traits for JSON (#782)
## Usage and product changes We implement serde's `Serialize` for `ConceptDocument` and our custom `JSON` type. This allows crates like `axum` that rely on `serde` to seamlessly return `Json<ConceptDocument>` from request handlers. We also implement `Deserialize` for `JSON` for convenience. ## Implementation Closes #766
1 parent e4c3eaa commit c467bc8

File tree

4 files changed

+233
-7
lines changed

4 files changed

+233
-7
lines changed

Cargo.lock

Lines changed: 6 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

rust/BUILD

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ typedb_driver_deps = [
3737
"@crates//:itertools",
3838
"@crates//:log",
3939
"@crates//:prost",
40+
"@crates//:serde",
4041
"@crates//:tokio",
4142
"@crates//:tokio-stream",
4243
"@crates//:tonic",
@@ -70,7 +71,10 @@ rust_library(
7071
rust_test(
7172
name = "typedb_driver_unit_tests",
7273
crate = ":typedb_driver",
73-
deps = ["@crates//:serde_json"],
74+
deps = [
75+
"@crates//:rand",
76+
"@crates//:serde_json",
77+
],
7478
)
7579

7680
assemble_crate(

rust/Cargo.toml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,11 @@
1616

1717
[dev-dependencies]
1818

19+
[dev-dependencies.rand]
20+
features = ["alloc", "default", "getrandom", "libc", "rand_chacha", "small_rng", "std", "std_rng"]
21+
version = "0.8.5"
22+
default-features = false
23+
1924
[dev-dependencies.smol]
2025
features = []
2126
version = "1.3.0"
@@ -69,6 +74,11 @@
6974
version = "0.4.27"
7075
default-features = false
7176

77+
[dependencies.serde]
78+
features = ["alloc", "default", "derive", "rc", "serde_derive", "std"]
79+
version = "1.0.219"
80+
default-features = false
81+
7282
[dependencies.tokio-stream]
7383
features = ["default", "net", "time"]
7484
version = "0.1.17"

rust/src/answer/json.rs

Lines changed: 212 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,16 @@ use std::{
2121
borrow::Cow,
2222
collections::HashMap,
2323
fmt::{self, Write},
24+
iter,
2425
};
2526

26-
#[derive(Clone, Debug)]
27+
use itertools::Itertools;
28+
use serde::{
29+
ser::{SerializeMap, SerializeSeq},
30+
Deserialize, Serialize,
31+
};
32+
33+
#[derive(Clone, Debug, PartialEq)]
2734
pub enum JSON {
2835
Object(HashMap<Cow<'static, str>, JSON>),
2936
Array(Vec<JSON>),
@@ -112,9 +119,154 @@ fn write_escaped_string(string: &str, f: &mut fmt::Formatter<'_>) -> fmt::Result
112119
write!(f, r#""{}""#, unsafe { String::from_utf8_unchecked(buf) })
113120
}
114121

122+
impl Serialize for JSON {
123+
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
124+
where
125+
S: serde::Serializer,
126+
{
127+
match self {
128+
Self::Object(object) => {
129+
let mut map = serializer.serialize_map(Some(object.len()))?;
130+
for (key, value) in object {
131+
map.serialize_entry(key, value)?;
132+
}
133+
map.end()
134+
}
135+
Self::Array(array) => {
136+
let mut seq = serializer.serialize_seq(Some(array.len()))?;
137+
for item in array {
138+
seq.serialize_element(item)?;
139+
}
140+
seq.end()
141+
}
142+
Self::String(string) => serializer.serialize_str(string),
143+
&Self::Number(number) => serializer.serialize_f64(number),
144+
&Self::Boolean(boolean) => serializer.serialize_bool(boolean),
145+
Self::Null => serializer.serialize_unit(),
146+
}
147+
}
148+
}
149+
150+
impl<'de> Deserialize<'de> for JSON {
151+
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
152+
where
153+
D: serde::Deserializer<'de>,
154+
{
155+
struct Visitor;
156+
157+
impl<'de> serde::de::Visitor<'de> for Visitor {
158+
type Value = JSON;
159+
160+
fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
161+
formatter.write_str("a valid JSON value")
162+
}
163+
164+
fn visit_bool<E>(self, value: bool) -> Result<Self::Value, E>
165+
where
166+
E: serde::de::Error,
167+
{
168+
Ok(JSON::Boolean(value))
169+
}
170+
171+
fn visit_i64<E>(self, value: i64) -> Result<Self::Value, E>
172+
where
173+
E: serde::de::Error,
174+
{
175+
Ok(JSON::Number(value as f64))
176+
}
177+
178+
fn visit_i128<E>(self, value: i128) -> Result<Self::Value, E>
179+
where
180+
E: serde::de::Error,
181+
{
182+
Ok(JSON::Number(value as f64))
183+
}
184+
185+
fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E>
186+
where
187+
E: serde::de::Error,
188+
{
189+
Ok(JSON::Number(value as f64))
190+
}
191+
192+
fn visit_u128<E>(self, value: u128) -> Result<Self::Value, E>
193+
where
194+
E: serde::de::Error,
195+
{
196+
Ok(JSON::Number(value as f64))
197+
}
198+
199+
fn visit_f64<E>(self, value: f64) -> Result<Self::Value, E>
200+
where
201+
E: serde::de::Error,
202+
{
203+
Ok(JSON::Number(value))
204+
}
205+
206+
fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
207+
where
208+
E: serde::de::Error,
209+
{
210+
Ok(JSON::String(Cow::Owned(value.to_owned())))
211+
}
212+
213+
fn visit_string<E>(self, value: String) -> Result<Self::Value, E>
214+
where
215+
E: serde::de::Error,
216+
{
217+
Ok(JSON::String(Cow::Owned(value)))
218+
}
219+
220+
fn visit_none<E>(self) -> Result<Self::Value, E>
221+
where
222+
E: serde::de::Error,
223+
{
224+
Ok(JSON::Null)
225+
}
226+
227+
fn visit_some<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
228+
where
229+
D: serde::Deserializer<'de>,
230+
{
231+
JSON::deserialize(deserializer)
232+
}
233+
234+
fn visit_unit<E>(self) -> Result<Self::Value, E>
235+
where
236+
E: serde::de::Error,
237+
{
238+
Ok(JSON::Null)
239+
}
240+
241+
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
242+
where
243+
A: serde::de::SeqAccess<'de>,
244+
{
245+
Ok(JSON::Array(iter::from_fn(|| seq.next_element().transpose()).try_collect()?))
246+
}
247+
248+
fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
249+
where
250+
A: serde::de::MapAccess<'de>,
251+
{
252+
Ok(JSON::Object(iter::from_fn(|| map.next_entry().transpose()).try_collect()?))
253+
}
254+
}
255+
256+
deserializer.deserialize_any(Visitor)
257+
}
258+
}
259+
115260
#[cfg(test)]
116261
mod test {
117-
use std::borrow::Cow;
262+
use std::{borrow::Cow, collections::HashMap, iter};
263+
264+
use rand::{
265+
distributions::{DistString, Distribution, Standard, WeightedIndex},
266+
rngs::ThreadRng,
267+
thread_rng, Rng,
268+
};
269+
use serde_json::json;
118270

119271
use super::JSON;
120272

@@ -126,4 +278,62 @@ mod test {
126278
let json_string = JSON::String(Cow::Owned(string));
127279
assert_eq!(serde_json::to_string(&serde_json_value).unwrap(), json_string.to_string());
128280
}
281+
282+
fn sample_json() -> JSON {
283+
JSON::Object(HashMap::from([
284+
("array".into(), JSON::Array(vec![JSON::Boolean(true), JSON::String("string".into())])),
285+
("number".into(), JSON::Number(123.4)),
286+
]))
287+
}
288+
289+
#[test]
290+
fn serialize() {
291+
let ser = serde_json::to_value(sample_json()).unwrap();
292+
let value = json!( { "array": [true, "string"], "number": 123.4 });
293+
assert_eq!(ser, value);
294+
}
295+
296+
#[test]
297+
fn deserialize() {
298+
let deser: JSON = serde_json::from_str(r#"{ "array": [true, "string"], "number": 123.4 }"#).unwrap();
299+
let json = sample_json();
300+
assert_eq!(deser, json);
301+
}
302+
303+
fn random_string(rng: &mut impl Rng) -> String {
304+
let len = rng.gen_range(0..64);
305+
Standard.sample_string(rng, len)
306+
}
307+
308+
fn random_json<R: Rng>(rng: &mut R) -> JSON {
309+
let weights = [1, 1, 3, 3, 3, 3];
310+
let generators: &[fn(&mut R) -> JSON] = &[
311+
|rng| {
312+
let len = rng.gen_range(0..12);
313+
JSON::Object(
314+
iter::from_fn(|| Some((Cow::Owned(random_string(rng)), random_json(rng)))).take(len).collect(),
315+
)
316+
},
317+
|rng| {
318+
let len = rng.gen_range(0..12);
319+
JSON::Array(iter::from_fn(|| Some(random_json(rng))).take(len).collect())
320+
},
321+
|rng| JSON::String(Cow::Owned(random_string(rng))),
322+
|rng| JSON::Number(rng.gen()),
323+
|rng| JSON::Boolean(rng.gen()),
324+
|_| JSON::Null,
325+
];
326+
let dist = WeightedIndex::new(&weights).unwrap();
327+
generators[dist.sample(rng)](rng)
328+
}
329+
330+
#[test]
331+
fn serde_roundtrip() {
332+
let mut rng = thread_rng();
333+
for _ in 0..1000 {
334+
let json = random_json(&mut rng);
335+
let deser = serde_json::from_value(serde_json::to_value(&json).unwrap()).unwrap();
336+
assert_eq!(json, deser);
337+
}
338+
}
129339
}

0 commit comments

Comments
 (0)