Skip to content

Commit 53d4e1f

Browse files
committed
refactor: add SchemaValidator::get_field_name
1 parent e7e2322 commit 53d4e1f

File tree

3 files changed

+38
-13
lines changed

3 files changed

+38
-13
lines changed

src/row_metadata.rs

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ static ROW_METADATA_CACHE: OnceCell<LockedRowMetadataCache> = OnceCell::const_ne
2222
#[derive(Debug, PartialEq)]
2323
pub(crate) enum AccessType {
2424
WithSeqAccess,
25-
WithMapAccess(Vec<usize>),
25+
WithMapAccess(Vec<usize>, Vec<&'static str>),
2626
}
2727

2828
/// Contains a vector of [`Column`] objects parsed from the beginning
@@ -122,7 +122,7 @@ impl RowMetadata {
122122
}
123123
}
124124
if should_use_map {
125-
AccessType::WithMapAccess(mapping)
125+
AccessType::WithMapAccess(mapping, T::column_names().into_iter().collect())
126126
} else {
127127
AccessType::WithSeqAccess
128128
}
@@ -137,7 +137,7 @@ impl RowMetadata {
137137
#[inline]
138138
pub(crate) fn get_schema_index(&self, struct_idx: usize) -> usize {
139139
match &self.access_type {
140-
AccessType::WithMapAccess(mapping) => {
140+
AccessType::WithMapAccess(mapping, _) => {
141141
if struct_idx < mapping.len() {
142142
mapping[struct_idx]
143143
} else {
@@ -149,9 +149,22 @@ impl RowMetadata {
149149
}
150150
}
151151

152+
#[inline]
153+
pub(crate) fn get_field_name(&self, struct_idx: usize) -> Option<&'static str> {
154+
match &self.access_type {
155+
AccessType::WithMapAccess(mapping, field_names) => {
156+
let Some(mapped) = mapping.get(struct_idx) else {
157+
return None;
158+
};
159+
field_names.get(*mapped).copied()
160+
}
161+
AccessType::WithSeqAccess => None,
162+
}
163+
}
164+
152165
#[inline]
153166
pub(crate) fn is_field_order_wrong(&self) -> bool {
154-
matches!(self.access_type, AccessType::WithMapAccess(_))
167+
matches!(self.access_type, AccessType::WithMapAccess(_, _))
155168
}
156169
}
157170

src/rowbinary/de.rs

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,6 @@ where
287287
visitor.visit_map(RowBinaryStructAsMapAccess {
288288
deserializer: self,
289289
current_field_idx: 0,
290-
fields,
291290
})
292291
}
293292
}
@@ -417,7 +416,6 @@ where
417416
{
418417
deserializer: &'de mut RowBinaryDeserializer<'cursor, 'data, R, Validator>,
419418
current_field_idx: usize,
420-
fields: &'static [&'static str],
421419
}
422420

423421
struct StructFieldIdentifier(&'static str);
@@ -475,14 +473,14 @@ where
475473
where
476474
K: DeserializeSeed<'data>,
477475
{
478-
if self.current_field_idx >= self.fields.len() {
479-
return Ok(None);
480-
}
481-
let schema_index = self
476+
let Some(field_name) = self
482477
.deserializer
483478
.validator
484-
.get_schema_index(self.current_field_idx);
485-
let field_id = StructFieldIdentifier(self.fields[schema_index]);
479+
.get_field_name(self.current_field_idx)
480+
else {
481+
return Ok(None);
482+
};
483+
let field_id = StructFieldIdentifier(field_name);
486484
self.current_field_idx += 1;
487485
seed.deserialize(field_id).map(Some)
488486
}
@@ -495,7 +493,8 @@ where
495493
}
496494

497495
fn size_hint(&self) -> Option<usize> {
498-
Some(self.fields.len())
496+
// Some(self.fields.len())
497+
None
499498
}
500499
}
501500

src/rowbinary/validation.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ pub(crate) trait SchemaValidator<R: Row>: Sized {
2828
/// It is used only if the crate detects that while the field names and the types are correct,
2929
/// the field order in the struct does not match the column order in the database schema.
3030
fn get_schema_index(&self, struct_idx: usize) -> usize;
31+
fn get_field_name(&self, struct_idx: usize) -> Option<&'static str>;
3132
}
3233

3334
pub(crate) struct DataTypeValidator<'cursor, R: Row> {
@@ -184,6 +185,10 @@ impl<'cursor, R: Row> SchemaValidator<R> for DataTypeValidator<'cursor, R> {
184185
self.metadata.get_schema_index(struct_idx)
185186
}
186187

188+
fn get_field_name(&self, struct_idx: usize) -> Option<&'static str> {
189+
self.metadata.get_field_name(struct_idx)
190+
}
191+
187192
#[cold]
188193
fn validate_identifier<T: EnumOrVariantIdentifier>(&mut self, _value: T) {
189194
unreachable!()
@@ -393,6 +398,10 @@ impl<'cursor, R: Row> SchemaValidator<R> for Option<InnerDataTypeValidator<'_, '
393398
fn get_schema_index(&self, _struct_idx: usize) -> usize {
394399
unreachable!()
395400
}
401+
402+
fn get_field_name(&self, _struct_idx: usize) -> Option<&'static str> {
403+
unreachable!()
404+
}
396405
}
397406

398407
impl<R: Row> Drop for InnerDataTypeValidator<'_, '_, R> {
@@ -638,6 +647,10 @@ impl<R: Row> SchemaValidator<R> for () {
638647
fn get_schema_index(&self, _struct_idx: usize) -> usize {
639648
unreachable!()
640649
}
650+
651+
fn get_field_name(&self, _struct_idx: usize) -> Option<&'static str> {
652+
unreachable!()
653+
}
641654
}
642655

643656
/// Which Serde data type (De)serializer used for the given type.

0 commit comments

Comments
 (0)