Skip to content

Commit 0ad2e1a

Browse files
committed
fix how relationships are saved
There has been a major change here: - Recursive saves (even with circular dependencies) are now supported! And also a bug has been fixed as a result: - If the virtual column corresponding to a relationship foreign key is empty (because, for example, the relationship has not been retrieved in the query) but exists in the database, now it is ignored. Before it was deleted without checking if it existed in the database. This does not break any expected behavior because relationships should not be removed by just removing them from the object but using the Remove* methods from the store. Signed-off-by: Miguel Molina <miguel@erizocosmi.co>
1 parent 2ce780b commit 0ad2e1a

File tree

11 files changed

+2354
-501
lines changed

11 files changed

+2354
-501
lines changed

README.md

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,8 @@ if err != nil {
283283
}
284284
```
285285

286-
If our model has relationships, they will be saved (**note:** saved as in insert or update) as well. The relationships of the relationships will not, though. Relationships are only saved with one level of depth.
286+
If our model has relationships, they will be saved, and so will the relationships of the relationships and so on. TL;DR: inserts are recursive.
287+
**Note:** the relationships will be saved using `Save`, not `Insert`.
287288

288289
```go
289290
user := NewUser("foo")
@@ -318,7 +319,8 @@ if err != nil {
318319
}
319320
```
320321

321-
If our model has relationships, they will be saved (**note:** saved as in insert or update) as well. The relationships of the relationships will not, though. Relationships are only saved with one level of depth.
322+
If our model has relationships, they will be saved, and so will the relationships of the relationships and so on. TL;DR: updates are recursive.
323+
**Note:** the relationships will be saved using `Save`, not `Update`.
322324

323325
```go
324326
user := FindLastPoster()
@@ -345,7 +347,7 @@ if updated {
345347
}
346348
```
347349

348-
If our model has relationships, they will be saved as well. The relationships of the relationships will not, though. Relationships are only saved with one level of depth.
350+
If our model has relationships, they will be saved, and so will the relationships of the relationships and so on. TL;DR: saves are recursive.
349351

350352
```go
351353
user := NewUser("foo")

generator/template.go

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -134,8 +134,7 @@ func (td *TemplateData) GenColumnValues(model *Model) string {
134134
var buf bytes.Buffer
135135
td.genFieldsValues(&buf, model.Fields)
136136
for _, fk := range model.ImplicitFKs {
137-
buf.WriteString(fmt.Sprintf("case \"%s\":\n", fk.Name))
138-
buf.WriteString(fmt.Sprintf("return r.Model.VirtualColumn(col), nil\n"))
137+
buf.WriteString(fmt.Sprintf(virtualFieldValueTpl, fk.Name))
139138
}
140139
return buf.String()
141140
}
@@ -145,13 +144,20 @@ const nilPtrReturnsUntypedNilTpl = `if %s == (*%s)(nil) {
145144
}
146145
`
147146

147+
const virtualFieldValueTpl = `case "%s":
148+
v := r.Model.VirtualColumn(col)
149+
if v == nil {
150+
return nil, kallax.ErrEmptyVirtualColumn
151+
}
152+
return v, nil
153+
`
154+
148155
func (td *TemplateData) genFieldsValues(buf *bytes.Buffer, fields []*Field) {
149156
for _, f := range fields {
150157
if f.Inline() {
151158
td.genFieldsValues(buf, f.Fields)
152159
} else if isOneToOneRelationship(f) && f.IsInverse() {
153-
buf.WriteString(fmt.Sprintf("case \"%s\":\n", f.ForeignKey()))
154-
buf.WriteString(fmt.Sprintf("return r.Model.VirtualColumn(col), nil\n"))
160+
buf.WriteString(fmt.Sprintf(virtualFieldValueTpl, f.ForeignKey()))
155161
} else if f.Kind != Relationship {
156162
buf.WriteString(fmt.Sprintf("case \"%s\":\n", f.ColumnName()))
157163
if f.IsPtr {
@@ -529,7 +535,7 @@ const (
529535
// The passed values to the FindBy will be used in an kallax.ArrayContains
530536
tplFindByCollection = `
531537
// FindBy%[1]s adds a new filter to the query that will require that
532-
// the %[1]s property contains all the passed values; if no passed values,
538+
// the %[1]s property contains all the passed values; if no passed values,
533539
// it will do nothing.
534540
func (q *%[2]s) FindBy%[1]s(v ...%[3]s) *%[2]s {
535541
if len(v) == 0 {return q}
@@ -557,7 +563,7 @@ const (
557563
// The passed values to the FindBy will be used in an kallax.In condition.
558564
tplFindByID = `
559565
// FindBy%[1]s adds a new filter to the query that will require that
560-
// the %[1]s property is equal to one of the passed values; if no passed values,
566+
// the %[1]s property is equal to one of the passed values; if no passed values,
561567
// it will do nothing.
562568
func (q *%[2]s) FindBy%[1]s(v ...%[3]s) *%[2]s {
563569
if len(v) == 0 {return q}

generator/template_test.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,11 @@ return (*types.URL)(r.URL), nil
141141
case "url_no_ptr":
142142
return (*types.URL)(&r.UrlNoPtr), nil
143143
case "rel_id":
144-
return r.Model.VirtualColumn(col), nil
144+
v := r.Model.VirtualColumn(col)
145+
if v == nil {
146+
return nil, kallax.ErrEmptyVirtualColumn
147+
}
148+
return v, nil
145149
`
146150

147151
func (s *TemplateSuite) TestGenColumnValues() {

generator/templates/base.tgo

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,7 @@ import (
1515
var _ types.SQLType
1616
var _ fmt.Formatter
1717

18+
type modelSaveFunc func(*kallax.Store) error
19+
1820
{{template "model" .}}
1921
{{template "schema" .}}

generator/templates/model.tgo

Lines changed: 36 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -119,53 +119,57 @@ func (s *{{.StoreName}}) DebugWith(logger kallax.LoggerFunc) *{{.StoreName}} {
119119
}
120120

121121
{{if .HasNonInverses}}
122-
func (s *{{.StoreName}}) relationshipRecords(record *{{.Name}}) []kallax.RecordWithSchema {
123-
var records []kallax.RecordWithSchema
122+
func (s *{{.StoreName}}) relationshipRecords(record *{{.Name}}) []modelSaveFunc {
123+
var result []modelSaveFunc
124124
{{range .NonInverses}}
125125
{{if .IsOneToManyRelationship}}
126126
for i := range record.{{.Name}} {
127-
record.{{.Name}}[i].ClearVirtualColumns()
128-
record.{{.Name}}[i].AddVirtualColumn("{{.ForeignKey}}", record.GetID())
129-
records = append(records, kallax.RecordWithSchema{
130-
Schema: Schema.{{.TypeSchemaName}}.BaseSchema,
131-
Record: {{if not ($.IsPtrSlice .)}}&{{end}}record.{{.Name}}[i],
132-
})
127+
r := {{if not ($.IsPtrSlice .)}}&{{end}}record.{{.Name}}[i]
128+
if !r.IsSaving() {
129+
r.AddVirtualColumn("{{.ForeignKey}}", record.GetID())
130+
result = append(result, func(store *kallax.Store) error {
131+
_, err := (&{{.TypeSchemaName}}Store{store}).Save(r)
132+
return err
133+
})
134+
}
133135
}
134136
{{else}}
135-
if {{if .IsPtr}}record.{{.Name}} != nil{{else}}!record.{{.Name}}.GetID().IsEmpty(){{end}} {
136-
record.{{.Name}}.ClearVirtualColumns()
137-
record.{{.Name}}.AddVirtualColumn("{{.ForeignKey}}", record.GetID())
138-
records = append(records, kallax.RecordWithSchema{
139-
Schema: Schema.{{.TypeSchemaName}}.BaseSchema,
140-
Record: {{if not .IsPtr}}&{{end}}record.{{.Name}},
137+
if {{if .IsPtr}}record.{{.Name}} != nil{{else}}!record.{{.Name}}.GetID().IsEmpty(){{end}} && !record.{{.Name}}.IsSaving() {
138+
r := {{if not .IsPtr}}&{{end}}record.{{.Name}}
139+
r.AddVirtualColumn("{{.ForeignKey}}", record.GetID())
140+
result = append(result, func(store *kallax.Store) error {
141+
_, err := (&{{.TypeSchemaName}}Store{store}).Save(r)
142+
return err
141143
})
142144
}
143145
{{end}}
144146
{{end}}
145-
return records
147+
return result
146148
}
147149
{{end}}
148150

149151
{{if .HasInverses}}
150-
func (s *{{.StoreName}}) inverseRecords(record *{{.Name}}) []kallax.RecordWithSchema {
151-
record.ClearVirtualColumns()
152-
var records []kallax.RecordWithSchema
152+
func (s *{{.StoreName}}) inverseRecords(record *{{.Name}}) []modelSaveFunc {
153+
var result []modelSaveFunc
153154
{{range .Inverses}}
154-
if {{if .IsPtr}}record.{{.Name}} != nil{{else}}!record.{{.Name}}.GetID().IsEmpty(){{end}} {
155+
if {{if .IsPtr}}record.{{.Name}} != nil{{else}}!record.{{.Name}}.GetID().IsEmpty(){{end}} && !record.{{.Name}}.IsSaving() {
155156
record.AddVirtualColumn("{{.ForeignKey}}", record.{{.Name}}.GetID())
156-
records = append(records, kallax.RecordWithSchema{
157-
Schema: Schema.{{.TypeSchemaName}}.BaseSchema,
158-
Record: {{if not .IsPtr}}&{{end}}record.{{.Name}},
157+
result = append(result, func(store *kallax.Store) error {
158+
_, err := (&{{.TypeSchemaName}}Store{store}).Save(record.{{.Name}})
159+
return err
159160
})
160161
}
161162
{{end}}
162-
return records
163+
return result
163164
}
164165
{{end}}
165166

166167
// Insert inserts a {{.Name}} in the database. A non-persisted object is
167168
// required for this operation.
168169
func (s *{{.StoreName}}) Insert(record *{{.Name}}) error {
170+
record.SetSaving(true)
171+
defer record.SetSaving(false)
172+
169173
{{$.GenTimeTruncations .}}
170174
{{if .Events.Has "BeforeSave"}}
171175
if err := record.BeforeSave(); err != nil {
@@ -183,20 +187,11 @@ func (s *{{.StoreName}}) Insert(record *{{.Name}}) error {
183187
{{if .HasInverses}}
184188
inverseRecords := s.inverseRecords(record)
185189
{{end}}
186-
if {{if .HasNonInverses}}len(records) > 0{{end}} {{if and (.HasNonInverses) (.HasInverses)}}&&{{end}} {{if .HasInverses}}len(inverseRecords) > 0{{end}} {
190+
if {{if .HasNonInverses}}len(records) > 0{{end}} {{if and (.HasNonInverses) (.HasInverses)}}||{{end}} {{if .HasInverses}}len(inverseRecords) > 0{{end}} {
187191
return s.Store.Transaction(func(s *kallax.Store) error {
188192
{{if .HasInverses}}
189193
for _, r := range inverseRecords {
190-
if err := kallax.ApplyBeforeEvents(r.Record); err != nil {
191-
return err
192-
}
193-
persisted := r.Record.IsPersisted()
194-
195-
if _, err := s.Save(r.Schema, r.Record); err != nil {
196-
return err
197-
}
198-
199-
if err := kallax.ApplyAfterEvents(r.Record, persisted); err != nil {
194+
if err := r(s); err != nil {
200195
return err
201196
}
202197
}
@@ -206,16 +201,7 @@ func (s *{{.StoreName}}) Insert(record *{{.Name}}) error {
206201
}
207202
{{if .HasNonInverses}}
208203
for _, r := range records {
209-
if err := kallax.ApplyBeforeEvents(r.Record); err != nil {
210-
return err
211-
}
212-
persisted := r.Record.IsPersisted()
213-
214-
if _, err := s.Save(r.Schema, r.Record); err != nil {
215-
return err
216-
}
217-
218-
if err := kallax.ApplyAfterEvents(r.Record, persisted); err != nil {
204+
if err := r(s); err != nil {
219205
return err
220206
}
221207
}
@@ -267,6 +253,9 @@ func (s *{{.StoreName}}) Insert(record *{{.Name}}) error {
267253
// been just inserted or retrieved using a query with no custom select fields.
268254
func (s *{{.StoreName}}) Update(record *{{.Name}}, cols ...kallax.SchemaField) (updated int64, err error) {
269255
{{$.GenTimeTruncations .}}
256+
257+
record.SetSaving(true)
258+
defer record.SetSaving(false)
270259
{{if .Events.Has "BeforeSave"}}
271260
if err := record.BeforeSave(); err != nil {
272261
return 0, err
@@ -284,20 +273,11 @@ func (s *{{.StoreName}}) Update(record *{{.Name}}, cols ...kallax.SchemaField) (
284273
{{if .HasInverses}}
285274
inverseRecords := s.inverseRecords(record)
286275
{{end}}
287-
if {{if .HasNonInverses}}len(records) > 0{{end}} {{if and (.HasNonInverses) (.HasInverses)}}&&{{end}} {{if .HasInverses}}len(inverseRecords) > 0{{end}} {
276+
if {{if .HasNonInverses}}len(records) > 0{{end}} {{if and (.HasNonInverses) (.HasInverses)}}||{{end}} {{if .HasInverses}}len(inverseRecords) > 0{{end}} {
288277
err = s.Store.Transaction(func(s *kallax.Store) error {
289278
{{if .HasInverses}}
290279
for _, r := range inverseRecords {
291-
if err := kallax.ApplyBeforeEvents(r.Record); err != nil {
292-
return err
293-
}
294-
persisted := r.Record.IsPersisted()
295-
296-
if _, err := s.Save(r.Schema, r.Record); err != nil {
297-
return err
298-
}
299-
300-
if err := kallax.ApplyAfterEvents(r.Record, persisted); err != nil {
280+
if err := r(s); err != nil {
301281
return err
302282
}
303283
}
@@ -310,16 +290,7 @@ func (s *{{.StoreName}}) Update(record *{{.Name}}, cols ...kallax.SchemaField) (
310290

311291
{{if .HasNonInverses}}
312292
for _, r := range records {
313-
if err := kallax.ApplyBeforeEvents(r.Record); err != nil {
314-
return err
315-
}
316-
persisted := r.Record.IsPersisted()
317-
318-
if _, err := s.Save(r.Schema, r.Record); err != nil {
319-
return err
320-
}
321-
322-
if err := kallax.ApplyAfterEvents(r.Record, persisted); err != nil {
293+
if err := r(s); err != nil {
323294
return err
324295
}
325296
}

model.go

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ type Model struct {
3838
virtualColumns map[string]Identifier
3939
persisted bool
4040
writable bool
41+
saving bool
4142
}
4243

4344
// NewModel creates a new Model that is writable and not persisted.
@@ -46,6 +47,7 @@ func NewModel() Model {
4647
persisted: false,
4748
writable: true,
4849
virtualColumns: make(map[string]Identifier),
50+
saving: false,
4951
}
5052
}
5153

@@ -72,6 +74,19 @@ func (m *Model) setWritable(w bool) {
7274
m.writable = w
7375
}
7476

77+
// IsSaving reports whether the model is in the process of being saved or not.
78+
func (m *Model) IsSaving() bool {
79+
return m.saving
80+
}
81+
82+
// SetSaving sets the saving status of the model.
83+
// This is an internal function, even though it is exposed for the concrete
84+
// kallax generated code to use. Please, don't use it, and if you do,
85+
// bad things may happen.
86+
func (m *Model) SetSaving(saving bool) {
87+
m.saving = saving
88+
}
89+
7590
// ClearVirtualColumns clears all the previous virtual columns.
7691
// This method is only intended for internal use. It is only exposed for
7792
// technical reasons.
@@ -171,18 +186,33 @@ type VirtualColumnContainer interface {
171186
getVirtualColumns() map[string]Identifier
172187
}
173188

189+
var ErrEmptyVirtualColumn = fmt.Errorf("empty virtual column")
190+
174191
// RecordValues returns the values of a record at the given columns in the same
175192
// order as the columns.
176-
func RecordValues(record Valuer, columns ...string) ([]interface{}, error) {
177-
var values = make([]interface{}, len(columns))
178-
for i, col := range columns {
193+
// It also returns the columns with any empty virtual column removed.
194+
func RecordValues(record Valuer, columns ...string) ([]interface{}, []string, error) {
195+
var cols = make([]string, 0, len(columns))
196+
var values = make([]interface{}, 0, len(columns))
197+
for _, col := range columns {
179198
v, err := record.Value(col)
199+
if err == ErrEmptyVirtualColumn {
200+
continue
201+
}
202+
180203
if err != nil {
181-
return nil, err
204+
return nil, nil, err
182205
}
183-
values[i] = v
206+
values = append(values, v)
207+
cols = append(cols, col)
184208
}
185-
return values, nil
209+
return values, cols, nil
210+
}
211+
212+
// Saveable can report whether it's being saved or change the saving status.
213+
type Saveable interface {
214+
IsSaving() bool
215+
SetSaving(bool)
186216
}
187217

188218
// Record is something that can be stored as a row in the database.
@@ -194,6 +224,7 @@ type Record interface {
194224
ColumnAddresser
195225
Valuer
196226
VirtualColumnContainer
227+
Saveable
197228
}
198229

199230
var randPool = &sync.Pool{

store.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ func (s *Store) Insert(schema Schema, record Record) error {
153153
return ErrNoColumns
154154
}
155155

156-
values, err := RecordValues(record, cols...)
156+
values, cols, err := RecordValues(record, cols...)
157157
if err != nil {
158158
return err
159159
}
@@ -212,7 +212,7 @@ func (s *Store) Update(schema Schema, record Record, cols ...SchemaField) (int64
212212
}
213213

214214
columnNames := ColumnNames(cols)
215-
values, err := RecordValues(record, columnNames...)
215+
values, columnNames, err := RecordValues(record, columnNames...)
216216
if err != nil {
217217
return 0, err
218218
}

0 commit comments

Comments
 (0)