Skip to content

Commit 8bfec75

Browse files
authored
Merge pull request #214 from src-d/fix/save-relations
fix how relationships are saved
2 parents ee647cf + 0ad2e1a commit 8bfec75

File tree

12 files changed

+2779
-403
lines changed

12 files changed

+2779
-403
lines changed

README.md

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,8 @@ if err != nil {
284284
}
285285
```
286286

287-
If our model has relationships, they will be saved (**note:** saved as in insert or update) as well. The relationships of the relationships will not, though. Relationships are only saved with one level of depth.
287+
If our model has relationships, they will be saved, and so will the relationships of the relationships and so on. TL;DR: inserts are recursive.
288+
**Note:** the relationships will be saved using `Save`, not `Insert`.
288289

289290
```go
290291
user := NewUser("foo")
@@ -319,7 +320,8 @@ if err != nil {
319320
}
320321
```
321322

322-
If our model has relationships, they will be saved (**note:** saved as in insert or update) as well. The relationships of the relationships will not, though. Relationships are only saved with one level of depth.
323+
If our model has relationships, they will be saved, and so will the relationships of the relationships and so on. TL;DR: updates are recursive.
324+
**Note:** the relationships will be saved using `Save`, not `Update`.
323325

324326
```go
325327
user := FindLastPoster()
@@ -346,7 +348,7 @@ if updated {
346348
}
347349
```
348350

349-
If our model has relationships, they will be saved as well. The relationships of the relationships will not, though. Relationships are only saved with one level of depth.
351+
If our model has relationships, they will be saved, and so will the relationships of the relationships and so on. TL;DR: saves are recursive.
350352

351353
```go
352354
user := NewUser("foo")

generator/template.go

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -134,8 +134,7 @@ func (td *TemplateData) GenColumnValues(model *Model) string {
134134
var buf bytes.Buffer
135135
td.genFieldsValues(&buf, model.Fields)
136136
for _, fk := range model.ImplicitFKs {
137-
buf.WriteString(fmt.Sprintf("case \"%s\":\n", fk.Name))
138-
buf.WriteString(fmt.Sprintf("return r.Model.VirtualColumn(col), nil\n"))
137+
buf.WriteString(fmt.Sprintf(virtualFieldValueTpl, fk.Name))
139138
}
140139
return buf.String()
141140
}
@@ -145,13 +144,20 @@ const nilPtrReturnsUntypedNilTpl = `if %s == (*%s)(nil) {
145144
}
146145
`
147146

147+
const virtualFieldValueTpl = `case "%s":
148+
v := r.Model.VirtualColumn(col)
149+
if v == nil {
150+
return nil, kallax.ErrEmptyVirtualColumn
151+
}
152+
return v, nil
153+
`
154+
148155
func (td *TemplateData) genFieldsValues(buf *bytes.Buffer, fields []*Field) {
149156
for _, f := range fields {
150157
if f.Inline() {
151158
td.genFieldsValues(buf, f.Fields)
152159
} else if isOneToOneRelationship(f) && f.IsInverse() {
153-
buf.WriteString(fmt.Sprintf("case \"%s\":\n", f.ForeignKey()))
154-
buf.WriteString(fmt.Sprintf("return r.Model.VirtualColumn(col), nil\n"))
160+
buf.WriteString(fmt.Sprintf(virtualFieldValueTpl, f.ForeignKey()))
155161
} else if f.Kind != Relationship {
156162
buf.WriteString(fmt.Sprintf("case \"%s\":\n", f.ColumnName()))
157163
if f.IsPtr {
@@ -529,7 +535,7 @@ const (
529535
// The passed values to the FindBy will be used in an kallax.ArrayContains
530536
tplFindByCollection = `
531537
// FindBy%[1]s adds a new filter to the query that will require that
532-
// the %[1]s property contains all the passed values; if no passed values,
538+
// the %[1]s property contains all the passed values; if no passed values,
533539
// it will do nothing.
534540
func (q *%[2]s) FindBy%[1]s(v ...%[3]s) *%[2]s {
535541
if len(v) == 0 {return q}
@@ -557,7 +563,7 @@ const (
557563
// The passed values to the FindBy will be used in an kallax.In condition.
558564
tplFindByID = `
559565
// FindBy%[1]s adds a new filter to the query that will require that
560-
// the %[1]s property is equal to one of the passed values; if no passed values,
566+
// the %[1]s property is equal to one of the passed values; if no passed values,
561567
// it will do nothing.
562568
func (q *%[2]s) FindBy%[1]s(v ...%[3]s) *%[2]s {
563569
if len(v) == 0 {return q}

generator/template_test.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,11 @@ return (*types.URL)(r.URL), nil
141141
case "url_no_ptr":
142142
return (*types.URL)(&r.UrlNoPtr), nil
143143
case "rel_id":
144-
return r.Model.VirtualColumn(col), nil
144+
v := r.Model.VirtualColumn(col)
145+
if v == nil {
146+
return nil, kallax.ErrEmptyVirtualColumn
147+
}
148+
return v, nil
145149
`
146150

147151
func (s *TemplateSuite) TestGenColumnValues() {

generator/templates/base.tgo

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,7 @@ import (
1515
var _ types.SQLType
1616
var _ fmt.Formatter
1717

18+
type modelSaveFunc func(*kallax.Store) error
19+
1820
{{template "model" .}}
1921
{{template "schema" .}}

generator/templates/model.tgo

Lines changed: 36 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -119,53 +119,57 @@ func (s *{{.StoreName}}) DebugWith(logger kallax.LoggerFunc) *{{.StoreName}} {
119119
}
120120

121121
{{if .HasNonInverses}}
122-
func (s *{{.StoreName}}) relationshipRecords(record *{{.Name}}) []kallax.RecordWithSchema {
123-
var records []kallax.RecordWithSchema
122+
func (s *{{.StoreName}}) relationshipRecords(record *{{.Name}}) []modelSaveFunc {
123+
var result []modelSaveFunc
124124
{{range .NonInverses}}
125125
{{if .IsOneToManyRelationship}}
126126
for i := range record.{{.Name}} {
127-
record.{{.Name}}[i].ClearVirtualColumns()
128-
record.{{.Name}}[i].AddVirtualColumn("{{.ForeignKey}}", record.GetID())
129-
records = append(records, kallax.RecordWithSchema{
130-
Schema: Schema.{{.TypeSchemaName}}.BaseSchema,
131-
Record: {{if not ($.IsPtrSlice .)}}&{{end}}record.{{.Name}}[i],
132-
})
127+
r := {{if not ($.IsPtrSlice .)}}&{{end}}record.{{.Name}}[i]
128+
if !r.IsSaving() {
129+
r.AddVirtualColumn("{{.ForeignKey}}", record.GetID())
130+
result = append(result, func(store *kallax.Store) error {
131+
_, err := (&{{.TypeSchemaName}}Store{store}).Save(r)
132+
return err
133+
})
134+
}
133135
}
134136
{{else}}
135-
if {{if .IsPtr}}record.{{.Name}} != nil{{else}}!record.{{.Name}}.GetID().IsEmpty(){{end}} {
136-
record.{{.Name}}.ClearVirtualColumns()
137-
record.{{.Name}}.AddVirtualColumn("{{.ForeignKey}}", record.GetID())
138-
records = append(records, kallax.RecordWithSchema{
139-
Schema: Schema.{{.TypeSchemaName}}.BaseSchema,
140-
Record: {{if not .IsPtr}}&{{end}}record.{{.Name}},
137+
if {{if .IsPtr}}record.{{.Name}} != nil{{else}}!record.{{.Name}}.GetID().IsEmpty(){{end}} && !record.{{.Name}}.IsSaving() {
138+
r := {{if not .IsPtr}}&{{end}}record.{{.Name}}
139+
r.AddVirtualColumn("{{.ForeignKey}}", record.GetID())
140+
result = append(result, func(store *kallax.Store) error {
141+
_, err := (&{{.TypeSchemaName}}Store{store}).Save(r)
142+
return err
141143
})
142144
}
143145
{{end}}
144146
{{end}}
145-
return records
147+
return result
146148
}
147149
{{end}}
148150

149151
{{if .HasInverses}}
150-
func (s *{{.StoreName}}) inverseRecords(record *{{.Name}}) []kallax.RecordWithSchema {
151-
record.ClearVirtualColumns()
152-
var records []kallax.RecordWithSchema
152+
func (s *{{.StoreName}}) inverseRecords(record *{{.Name}}) []modelSaveFunc {
153+
var result []modelSaveFunc
153154
{{range .Inverses}}
154-
if {{if .IsPtr}}record.{{.Name}} != nil{{else}}!record.{{.Name}}.GetID().IsEmpty(){{end}} {
155+
if {{if .IsPtr}}record.{{.Name}} != nil{{else}}!record.{{.Name}}.GetID().IsEmpty(){{end}} && !record.{{.Name}}.IsSaving() {
155156
record.AddVirtualColumn("{{.ForeignKey}}", record.{{.Name}}.GetID())
156-
records = append(records, kallax.RecordWithSchema{
157-
Schema: Schema.{{.TypeSchemaName}}.BaseSchema,
158-
Record: {{if not .IsPtr}}&{{end}}record.{{.Name}},
157+
result = append(result, func(store *kallax.Store) error {
158+
_, err := (&{{.TypeSchemaName}}Store{store}).Save(record.{{.Name}})
159+
return err
159160
})
160161
}
161162
{{end}}
162-
return records
163+
return result
163164
}
164165
{{end}}
165166

166167
// Insert inserts a {{.Name}} in the database. A non-persisted object is
167168
// required for this operation.
168169
func (s *{{.StoreName}}) Insert(record *{{.Name}}) error {
170+
record.SetSaving(true)
171+
defer record.SetSaving(false)
172+
169173
{{$.GenTimeTruncations .}}
170174
{{if .Events.Has "BeforeSave"}}
171175
if err := record.BeforeSave(); err != nil {
@@ -183,20 +187,11 @@ func (s *{{.StoreName}}) Insert(record *{{.Name}}) error {
183187
{{if .HasInverses}}
184188
inverseRecords := s.inverseRecords(record)
185189
{{end}}
186-
if {{if .HasNonInverses}}len(records) > 0{{end}} {{if and (.HasNonInverses) (.HasInverses)}}&&{{end}} {{if .HasInverses}}len(inverseRecords) > 0{{end}} {
190+
if {{if .HasNonInverses}}len(records) > 0{{end}} {{if and (.HasNonInverses) (.HasInverses)}}||{{end}} {{if .HasInverses}}len(inverseRecords) > 0{{end}} {
187191
return s.Store.Transaction(func(s *kallax.Store) error {
188192
{{if .HasInverses}}
189193
for _, r := range inverseRecords {
190-
if err := kallax.ApplyBeforeEvents(r.Record); err != nil {
191-
return err
192-
}
193-
persisted := r.Record.IsPersisted()
194-
195-
if _, err := s.Save(r.Schema, r.Record); err != nil {
196-
return err
197-
}
198-
199-
if err := kallax.ApplyAfterEvents(r.Record, persisted); err != nil {
194+
if err := r(s); err != nil {
200195
return err
201196
}
202197
}
@@ -206,16 +201,7 @@ func (s *{{.StoreName}}) Insert(record *{{.Name}}) error {
206201
}
207202
{{if .HasNonInverses}}
208203
for _, r := range records {
209-
if err := kallax.ApplyBeforeEvents(r.Record); err != nil {
210-
return err
211-
}
212-
persisted := r.Record.IsPersisted()
213-
214-
if _, err := s.Save(r.Schema, r.Record); err != nil {
215-
return err
216-
}
217-
218-
if err := kallax.ApplyAfterEvents(r.Record, persisted); err != nil {
204+
if err := r(s); err != nil {
219205
return err
220206
}
221207
}
@@ -267,6 +253,9 @@ func (s *{{.StoreName}}) Insert(record *{{.Name}}) error {
267253
// been just inserted or retrieved using a query with no custom select fields.
268254
func (s *{{.StoreName}}) Update(record *{{.Name}}, cols ...kallax.SchemaField) (updated int64, err error) {
269255
{{$.GenTimeTruncations .}}
256+
257+
record.SetSaving(true)
258+
defer record.SetSaving(false)
270259
{{if .Events.Has "BeforeSave"}}
271260
if err := record.BeforeSave(); err != nil {
272261
return 0, err
@@ -284,20 +273,11 @@ func (s *{{.StoreName}}) Update(record *{{.Name}}, cols ...kallax.SchemaField) (
284273
{{if .HasInverses}}
285274
inverseRecords := s.inverseRecords(record)
286275
{{end}}
287-
if {{if .HasNonInverses}}len(records) > 0{{end}} {{if and (.HasNonInverses) (.HasInverses)}}&&{{end}} {{if .HasInverses}}len(inverseRecords) > 0{{end}} {
276+
if {{if .HasNonInverses}}len(records) > 0{{end}} {{if and (.HasNonInverses) (.HasInverses)}}||{{end}} {{if .HasInverses}}len(inverseRecords) > 0{{end}} {
288277
err = s.Store.Transaction(func(s *kallax.Store) error {
289278
{{if .HasInverses}}
290279
for _, r := range inverseRecords {
291-
if err := kallax.ApplyBeforeEvents(r.Record); err != nil {
292-
return err
293-
}
294-
persisted := r.Record.IsPersisted()
295-
296-
if _, err := s.Save(r.Schema, r.Record); err != nil {
297-
return err
298-
}
299-
300-
if err := kallax.ApplyAfterEvents(r.Record, persisted); err != nil {
280+
if err := r(s); err != nil {
301281
return err
302282
}
303283
}
@@ -310,16 +290,7 @@ func (s *{{.StoreName}}) Update(record *{{.Name}}, cols ...kallax.SchemaField) (
310290

311291
{{if .HasNonInverses}}
312292
for _, r := range records {
313-
if err := kallax.ApplyBeforeEvents(r.Record); err != nil {
314-
return err
315-
}
316-
persisted := r.Record.IsPersisted()
317-
318-
if _, err := s.Save(r.Schema, r.Record); err != nil {
319-
return err
320-
}
321-
322-
if err := kallax.ApplyAfterEvents(r.Record, persisted); err != nil {
293+
if err := r(s); err != nil {
323294
return err
324295
}
325296
}

model.go

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ type Model struct {
3838
virtualColumns map[string]Identifier
3939
persisted bool
4040
writable bool
41+
saving bool
4142
}
4243

4344
// NewModel creates a new Model that is writable and not persisted.
@@ -46,6 +47,7 @@ func NewModel() Model {
4647
persisted: false,
4748
writable: true,
4849
virtualColumns: make(map[string]Identifier),
50+
saving: false,
4951
}
5052
}
5153

@@ -72,6 +74,19 @@ func (m *Model) setWritable(w bool) {
7274
m.writable = w
7375
}
7476

77+
// IsSaving reports whether the model is in the process of being saved or not.
78+
func (m *Model) IsSaving() bool {
79+
return m.saving
80+
}
81+
82+
// SetSaving sets the saving status of the model.
83+
// This is an internal function, even though it is exposed for the concrete
84+
// kallax generated code to use. Please, don't use it, and if you do,
85+
// bad things may happen.
86+
func (m *Model) SetSaving(saving bool) {
87+
m.saving = saving
88+
}
89+
7590
// ClearVirtualColumns clears all the previous virtual columns.
7691
// This method is only intended for internal use. It is only exposed for
7792
// technical reasons.
@@ -171,18 +186,33 @@ type VirtualColumnContainer interface {
171186
getVirtualColumns() map[string]Identifier
172187
}
173188

189+
var ErrEmptyVirtualColumn = fmt.Errorf("empty virtual column")
190+
174191
// RecordValues returns the values of a record at the given columns in the same
175192
// order as the columns.
176-
func RecordValues(record Valuer, columns ...string) ([]interface{}, error) {
177-
var values = make([]interface{}, len(columns))
178-
for i, col := range columns {
193+
// It also returns the columns with any empty virtual column removed.
194+
func RecordValues(record Valuer, columns ...string) ([]interface{}, []string, error) {
195+
var cols = make([]string, 0, len(columns))
196+
var values = make([]interface{}, 0, len(columns))
197+
for _, col := range columns {
179198
v, err := record.Value(col)
199+
if err == ErrEmptyVirtualColumn {
200+
continue
201+
}
202+
180203
if err != nil {
181-
return nil, err
204+
return nil, nil, err
182205
}
183-
values[i] = v
206+
values = append(values, v)
207+
cols = append(cols, col)
184208
}
185-
return values, nil
209+
return values, cols, nil
210+
}
211+
212+
// Saveable can report whether it's being saved or change the saving status.
213+
type Saveable interface {
214+
IsSaving() bool
215+
SetSaving(bool)
186216
}
187217

188218
// Record is something that can be stored as a row in the database.
@@ -194,6 +224,7 @@ type Record interface {
194224
ColumnAddresser
195225
Valuer
196226
VirtualColumnContainer
227+
Saveable
197228
}
198229

199230
var randPool = &sync.Pool{

store.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ func (s *Store) Insert(schema Schema, record Record) error {
153153
return ErrNoColumns
154154
}
155155

156-
values, err := RecordValues(record, cols...)
156+
values, cols, err := RecordValues(record, cols...)
157157
if err != nil {
158158
return err
159159
}
@@ -212,7 +212,7 @@ func (s *Store) Update(schema Schema, record Record, cols ...SchemaField) (int64
212212
}
213213

214214
columnNames := ColumnNames(cols)
215-
values, err := RecordValues(record, columnNames...)
215+
values, columnNames, err := RecordValues(record, columnNames...)
216216
if err != nil {
217217
return 0, err
218218
}

0 commit comments

Comments
 (0)