From bf8b28051d53832816baab25f92d3e117b421337 Mon Sep 17 00:00:00 2001 From: Zaman Gabdrakhmanov Date: Mon, 11 Aug 2025 15:10:31 +0300 Subject: [PATCH 1/7] add: partitionFilesLimit param, confirm func, renamed continue-generation and F flag --- doc/en/usage.md | 8 +- doc/ru/usage.md | 8 +- internal/generator/cli/commands/consts.go | 8 +- .../cli/commands/generate/generate.go | 39 ++++--- .../generator/cli/commands/serve/handlers.go | 2 +- internal/generator/cli/confirm/confirm.go | 108 ++++++++++++++++++ internal/generator/cli/progress/bar/bar.go | 5 + internal/generator/cli/progress/interfaces.go | 2 + internal/generator/cli/progress/log/log.go | 24 +++- internal/generator/cli/render/interfaces.go | 10 +- .../generator/cli/render/prompt/prompt.go | 39 ++++--- .../cli/render/prompt/prompt_test.go | 2 +- internal/generator/cli/streams/in.go | 9 +- internal/generator/cli/streams/out.go | 9 +- internal/generator/cli/utils/utils.go | 10 ++ internal/generator/models/generator_output.go | 36 ++++-- .../generator/output/general/model_writer.go | 41 ++++++- internal/generator/output/general/output.go | 10 +- internal/generator/usecase/general/task.go | 1 + 19 files changed, 303 insertions(+), 68 deletions(-) create mode 100644 internal/generator/cli/confirm/confirm.go diff --git a/doc/en/usage.md b/doc/en/usage.md index 3605b33..b27cf40 100644 --- a/doc/en/usage.md +++ b/doc/en/usage.md @@ -177,6 +177,8 @@ Structure `output.params` for format `csv`: - `datetime_format`: Date-time format. Default is `2006-01-02T15:04:05Z07:00`. - `without_headers`: Flag indicating if CSV headers should be excluded from data files. - `delimiter`: Single-character CSV delimiter. Default is `,`. +- `partition_files_limit`: Limit on the number of partition files, upon reaching which a prompt will appear asking whether to continue. + Ignored if the `--force` flag is specified. Default is `1000`. Structure `output.params` for format `parquet`: @@ -184,6 +186,8 @@ Structure `output.params` for format `parquet`: Default is `UNCOMPRESSED`. - `float_precision`: Floating-point number precision. Default is `2`. - `datetime_format`: Date-time format. Supported values: `millis`, `micros`. Default is `millis`. +- `partition_files_limit`: Limit on the number of partition files, upon reaching which a prompt will appear asking whether to continue. + Ignored if the `--force` flag is specified. Default is `1000`. Structure `output.params` for format `http`: @@ -458,7 +462,7 @@ sdvg generate ./models.yml ### Ignoring conflicts If you want to automatically remove conflicting files from the output directory -and continue generation without additional prompts, use the `-F` or `--force` flag: +and continue generation without additional prompts, use the `-f` or `--force` flag: ```shell sdvg generate --force ./models.yml @@ -469,7 +473,7 @@ sdvg generate --force ./models.yml To continue generation from the last recorded row: ```shell -sdvg generate --continue-generation ./models.yml +sdvg generate --continue ./models.yml ``` > **Important**: To correctly continue generation, you must not change the generation configuration diff --git a/doc/ru/usage.md b/doc/ru/usage.md index b88f2ec..90b3d9b 100644 --- a/doc/ru/usage.md +++ b/doc/ru/usage.md @@ -183,6 +183,8 @@ open_ai: - `datetime_format`: Формат даты и времени. По умолчанию `2006-01-02T15:04:05Z07:00`. - `without_headers`: Флаг, указывающий, исключать ли CSV заголовок из файлов с данными. - `delimiter`: Односимвольный CSV разделитель. По умолчанию `,`. +- `partition_files_limit`: Ограничение количества файлов партиций, при достижении которого всплывет вопрос о продолжении. + Игнорируется при указании флага `--force`. По умолчанию `1000` Структура `output.params` для формата `parquet`: @@ -190,6 +192,8 @@ open_ai: По умолчанию `UNCOMPRESSED`. - `float_precision`: Точность чисел с плавающей запятой. По умолчанию `2`. - `datetime_format`: Формат даты и времени. Поддерживаемые значения: `millis`, `micros`. По умолчанию `millis`. +- `partition_files_limit`: Ограничение количества файлов партиций, при достижении которого всплывет вопрос о продолжении. + Игнорируется при указании флага `--force`. По умолчанию `1000` Структура `output.params` для формата `http`: @@ -464,7 +468,7 @@ sdvg generate ./models.yml ### Игнорирование конфликтов Если вы хотите автоматически удалить конфликтующие файлы в выходной директории -и продолжить генерацию без дополнительных сообщений, используйте флаг `-F` или `--force`: +и продолжить генерацию без дополнительных сообщений, используйте флаг `-f` или `--force`: ```shell sdvg generate --force ./models.yml @@ -475,7 +479,7 @@ sdvg generate --force ./models.yml Для продолжения генерации с последней записанной строки: ```shell -sdvg generate --continue-generation ./models.yml +sdvg generate --continue ./models.yml ``` > **Важно**: для корректного продолжения генерации нельзя менять конфигурацию генерации и уже сгенерированные данные. diff --git a/internal/generator/cli/commands/consts.go b/internal/generator/cli/commands/consts.go index 35271bb..511c81a 100644 --- a/internal/generator/cli/commands/consts.go +++ b/internal/generator/cli/commands/consts.go @@ -6,15 +6,15 @@ const ( ConfigPathDefaultValue = "" ConfigPathUsage = "Location of config file" - ContinueGenerationFlag = "continue-generation" - ContinueGenerationShortFlag = "C" + ContinueGenerationFlag = "continue" + ContinueGenerationShortFlag = "c" ContinueGenerationDefaultValue = false ContinueGenerationUsage = "Continue generation from the last recorded row" ForceGenerationFlag = "force" - ForceGenerationShortFlag = "F" + ForceGenerationShortFlag = "f" ForceGenerationFlagDefaultValue = false - ForceGenerationUsage = "Force generation even if output file conflicts found" + ForceGenerationUsage = "Force generation even if output file conflicts found and partition files limit reached" TTYFlag = "tty" TTYShortFlag = "t" diff --git a/internal/generator/cli/commands/generate/generate.go b/internal/generator/cli/commands/generate/generate.go index caa2837..6a46d43 100644 --- a/internal/generator/cli/commands/generate/generate.go +++ b/internal/generator/cli/commands/generate/generate.go @@ -12,6 +12,7 @@ import ( "github.com/spf13/cobra" "github.com/spf13/pflag" "github.com/tarantool/sdvg/internal/generator/cli/commands" + "github.com/tarantool/sdvg/internal/generator/cli/confirm" "github.com/tarantool/sdvg/internal/generator/cli/options" "github.com/tarantool/sdvg/internal/generator/cli/progress" "github.com/tarantool/sdvg/internal/generator/cli/progress/bar" @@ -124,7 +125,9 @@ func runGenerate(ctx context.Context, opts *generateOptions) error { return err } - out := general.NewOutput(generationCfg, opts.continueGeneration, opts.forceGeneration) + progressTrackerManager, confirm := initProgressTrackerManager(ctx, opts.renderer, opts.useTTY) + + out := general.NewOutput(generationCfg, opts.continueGeneration, opts.forceGeneration, confirm) taskID, err := opts.useCase.CreateTask( ctx, usecase.TaskConfig{ @@ -143,12 +146,11 @@ func runGenerate(ctx context.Context, opts *generateOptions) error { ) startProgressTracking( - ctx, + progressTrackerManager, opts.useCase, taskID, &finished, &wg, - opts.useTTY, ) err = opts.useCase.WaitResult(taskID) @@ -173,26 +175,37 @@ func runGenerate(ctx context.Context, opts *generateOptions) error { return nil } +// initProgressTrackerManager inits progress bar manager (progress.Tracker) and builds streams.Confirm func based on useTTY +func initProgressTrackerManager(ctx context.Context, renderer render.Renderer, useTTY bool) (progress.Tracker, confirm.Confirm) { + var progressTrackerManager progress.Tracker + var confirmFunc confirm.Confirm + + if useTTY { + progressTrackerManager = bar.NewProgressBarManager(ctx) + + confirmFunc = confirm.BuildConfirmTTY(renderer, progressTrackerManager) + } else { + isUpdatePaused := &atomic.Bool{} + + progressTrackerManager = log.NewProgressLogManager(ctx, isUpdatePaused) + + confirmFunc = confirm.BuildConfirmNoTTY(renderer, progressTrackerManager, isUpdatePaused) + } + + return progressTrackerManager, confirmFunc +} + // startProgressTracking runs function to track progress of task // by getting progress from usecase object and displaying it. func startProgressTracking( - ctx context.Context, + progressTrackerManager progress.Tracker, uc usecase.UseCase, taskID string, finished *atomic.Bool, wg *sync.WaitGroup, - useTTY bool, ) { const delay = 500 * time.Millisecond - var progressTrackerManager progress.Tracker - - if useTTY { - progressTrackerManager = bar.NewProgressBarManager(ctx) - } else { - progressTrackerManager = log.NewProgressLogManager(ctx) - } - wg.Add(1) go func() { diff --git a/internal/generator/cli/commands/serve/handlers.go b/internal/generator/cli/commands/serve/handlers.go index 7a36fd8..0d709f0 100644 --- a/internal/generator/cli/commands/serve/handlers.go +++ b/internal/generator/cli/commands/serve/handlers.go @@ -58,7 +58,7 @@ func handleGenerate(opts handlerOptions, c echo.Context) error { generationConfig.OutputConfig.Dir = models.DefaultOutputDir - out := general.NewOutput(&generationConfig, false, true) + out := general.NewOutput(&generationConfig, false, true, nil) taskID, err := opts.useCase.CreateTask( c.Request().Context(), usecase.TaskConfig{ diff --git a/internal/generator/cli/confirm/confirm.go b/internal/generator/cli/confirm/confirm.go new file mode 100644 index 0000000..20986a9 --- /dev/null +++ b/internal/generator/cli/confirm/confirm.go @@ -0,0 +1,108 @@ +package confirm + +import ( + "context" + "fmt" + "io" + "strings" + "sync/atomic" + + "github.com/manifoldco/promptui" + "github.com/pkg/errors" + "github.com/tarantool/sdvg/internal/generator/cli/render" + "github.com/tarantool/sdvg/internal/generator/cli/utils" +) + +// Confirm asks user a yes/no question. Returns true for “yes”. +type Confirm func(ctx context.Context, question string) (bool, error) + +func BuildConfirmTTY(in io.Reader, out io.Writer) func(ctx context.Context, question string) (bool, error) { + return func(ctx context.Context, question string) (bool, error) { + fmt.Fprintln(out) + + prompt := promptui.Prompt{ + Label: question + " [y/N]: ", + Default: "y", + Stdin: utils.DummyReadWriteCloser{Reader: in}, + Stdout: utils.DummyReadWriteCloser{Writer: out}, + } + validate := func(s string) error { + if len(s) == 1 && strings.Contains("YyNn", s) || prompt.Default != "" && len(s) == 0 { + return nil + } + return errors.New("invalid input") + } + prompt.Validate = validate + + var ( + input string + err error + promptFinished = make(chan struct{}) + ) + + go func() { + input, err = prompt.Run() // goroutine will block here until user input + + promptFinished <- struct{}{} + }() + + select { + case <-ctx.Done(): + return false, ctx.Err() + case <-promptFinished: + } + + if err != nil { + return false, errors.WithMessage(err, "confirm prompt failed") + } + + return strings.Contains("Yy", input), nil + } +} + +func BuildConfirmNoTTY(in render.Renderer, out io.Writer, isUpdatePaused *atomic.Bool) func(ctx context.Context, question string) (bool, error) { + return func(ctx context.Context, question string) (bool, error) { + // here we pause ProgressLogManager to stop sending progress messages + isUpdatePaused.Store(true) + defer isUpdatePaused.Store(false) + + for { + fmt.Fprintf(out, "%s [y/N]: ", question) + + var ( + input string + err error + inputReadFinished = make(chan struct{}) + ) + + go func() { + input, err = in.ReadLine() // goroutine will block here until user input + + inputReadFinished <- struct{}{} + }() + + select { + case <-ctx.Done(): + return false, ctx.Err() + case <-inputReadFinished: + } + + if err != nil { + return false, err + } + + if !in.IsTerminal() { + fmt.Fprintln(out, input) + } + + switch strings.ToLower(strings.TrimSpace(input)) { + case "y", "yes": + return true, nil + case "", "n", "no": + return false, nil + default: + fmt.Fprintln(out, "Please enter y or n") + } + } + } +} diff --git a/internal/generator/cli/progress/bar/bar.go b/internal/generator/cli/progress/bar/bar.go index baf1fb3..d1242db 100644 --- a/internal/generator/cli/progress/bar/bar.go +++ b/internal/generator/cli/progress/bar/bar.go @@ -77,3 +77,8 @@ func (p *ProgressBarManager) UpdateProgress(name string, progress usecase.Progre func (p *ProgressBarManager) Wait() { p.progressManager.Wait() } + +// Write writes to stdout. +func (p *ProgressBarManager) Write(b []byte) (int, error) { + return p.progressManager.Write(b) +} diff --git a/internal/generator/cli/progress/interfaces.go b/internal/generator/cli/progress/interfaces.go index 07e3027..57d8329 100644 --- a/internal/generator/cli/progress/interfaces.go +++ b/internal/generator/cli/progress/interfaces.go @@ -10,4 +10,6 @@ type Tracker interface { UpdateProgress(name string, progress usecase.Progress) // Wait function should wait for all tracked tasks to complete. Wait() + // Write function should write to stdout. + Write(b []byte) (int, error) } diff --git a/internal/generator/cli/progress/log/log.go b/internal/generator/cli/progress/log/log.go index 2662a0c..d259b25 100644 --- a/internal/generator/cli/progress/log/log.go +++ b/internal/generator/cli/progress/log/log.go @@ -5,7 +5,9 @@ import ( "fmt" "log/slog" "math" + "os" "sync" + "sync/atomic" "time" "github.com/tarantool/sdvg/internal/generator/cli/progress" @@ -42,13 +44,16 @@ type ProgressLogManager struct { ctx context.Context //nolint:containedctx tasks map[string]*task wg sync.WaitGroup + + isUpdatePaused *atomic.Bool } -// NewProgressLogManager creates NewProgressLogManager object. -func NewProgressLogManager(ctx context.Context) progress.Tracker { +// NewProgressLogManager creates NewProgressLogManager object. isUpdatePaused is used to pause UpdateProgress. +func NewProgressLogManager(ctx context.Context, isUpdatePaused *atomic.Bool) progress.Tracker { return &ProgressLogManager{ - ctx: ctx, - tasks: make(map[string]*task), + ctx: ctx, + tasks: make(map[string]*task), + isUpdatePaused: isUpdatePaused, } } @@ -78,6 +83,12 @@ func (p *ProgressLogManager) UpdateProgress(name string, progress usecase.Progre return } + for p.isUpdatePaused.Load() { + if t.isDone() { + return + } + } + p.updateIntervals(t, progress.Done) t.current = progress.Done @@ -138,3 +149,8 @@ func (p *ProgressLogManager) eta(t *task) string { return fmt.Sprintf("%02d:%02d:%02d", hours, minutes, seconds) } + +// Write writes to default stdout. +func (p *ProgressLogManager) Write(b []byte) (int, error) { + return os.Stdout.Write(b) +} diff --git a/internal/generator/cli/render/interfaces.go b/internal/generator/cli/render/interfaces.go index cc479b6..1845a6d 100644 --- a/internal/generator/cli/render/interfaces.go +++ b/internal/generator/cli/render/interfaces.go @@ -1,6 +1,8 @@ package render -import "context" +import ( + "context" +) // Renderer interface implementation should render interactive menu. // @@ -16,4 +18,10 @@ type Renderer interface { TextMenu(ctx context.Context, title string) (string, error) // WithSpinner should display spinner. WithSpinner(title string, fn func()) + // IsTerminal should return true if renderer is connected to a terminal. + IsTerminal() bool + // ReadLine should read input from input stream. + ReadLine() (string, error) + // Read should read from input stream. + Read(p []byte) (int, error) } diff --git a/internal/generator/cli/render/prompt/prompt.go b/internal/generator/cli/render/prompt/prompt.go index 67832af..91206d3 100644 --- a/internal/generator/cli/render/prompt/prompt.go +++ b/internal/generator/cli/render/prompt/prompt.go @@ -90,7 +90,7 @@ func (r *Renderer) SelectionMenu(ctx context.Context, title string, items []stri for { _, _ = fmt.Fprint(r.out, "Write a number: ") - input, err := r.readLine() + input, err := r.ReadLine() if err != nil { resultChan <- result{err: err} @@ -147,7 +147,7 @@ func (r *Renderer) InputMenu(ctx context.Context, title string, validateFunc fun for { _, _ = fmt.Fprintf(r.out, "%s: ", title) - input, err := r.readLine() + input, err := r.ReadLine() if err != nil { resultChan <- result{err: err} @@ -252,6 +252,28 @@ func (r *Renderer) WithSpinner(title string, fn func()) { fn() } +// ReadLine reads input from stdin. +func (r *Renderer) ReadLine() (string, error) { + if r.scanner.Scan() { + return strings.TrimSpace(r.scanner.Text()), nil + } + + if err := r.scanner.Err(); err != nil { + return "", errors.New(err.Error()) + } + + return "", errors.New(io.EOF.Error()) +} + +// IsTerminal returns true if this stream is connected to a terminal. +func (r *Renderer) IsTerminal() bool { + return r.in.IsTerminal() +} + +func (r *Renderer) Read(p []byte) (int, error) { + return r.in.Read(p) +} + // selectionPrompt returns prompt for selection items. func (r *Renderer) selectionPrompt(title string, items []string) promptui.Select { templates := &promptui.SelectTemplates{ @@ -371,19 +393,6 @@ func (r *Renderer) readFile(filePath string) (string, error) { return strings.TrimSpace(sb.String()), nil } -// readInput reads input from stdin. -func (r *Renderer) readLine() (string, error) { - if r.scanner.Scan() { - return strings.TrimSpace(r.scanner.Text()), nil - } - - if err := r.scanner.Err(); err != nil { - return "", errors.New(err.Error()) - } - - return "", errors.New(io.EOF.Error()) -} - func (r *Renderer) readMultiline() (string, error) { var sb strings.Builder diff --git a/internal/generator/cli/render/prompt/prompt_test.go b/internal/generator/cli/render/prompt/prompt_test.go index 1e4da74..6253d5c 100644 --- a/internal/generator/cli/render/prompt/prompt_test.go +++ b/internal/generator/cli/render/prompt/prompt_test.go @@ -427,7 +427,7 @@ func readLinesTestFunc(t *testing.T, tc readLinesTestCase, mode int) { switch mode { case SingleLine: - actual, err = renderer.readLine() + actual, err = renderer.ReadLine() case MultiLine: actual, err = renderer.readMultiline() } diff --git a/internal/generator/cli/streams/in.go b/internal/generator/cli/streams/in.go index eff6e5a..38574bb 100644 --- a/internal/generator/cli/streams/in.go +++ b/internal/generator/cli/streams/in.go @@ -5,14 +5,9 @@ import ( "io" "github.com/moby/term" + "github.com/tarantool/sdvg/internal/generator/cli/utils" ) -type nopReadCloser struct { - io.Reader -} - -func (nopReadCloser) Close() error { return nil } - // In is an input stream to read user input. It implements [io.ReadCloser]. type In struct { isTerminal bool @@ -26,7 +21,7 @@ func NewIn(in io.Reader) *In { if readCloser, ok := in.(io.ReadCloser); ok { i.in = readCloser } else { - i.in = nopReadCloser{in} + i.in = utils.DummyReadWriteCloser{Reader: in} } _, i.isTerminal = term.GetFdInfo(in) diff --git a/internal/generator/cli/streams/out.go b/internal/generator/cli/streams/out.go index 58ce498..83692f8 100644 --- a/internal/generator/cli/streams/out.go +++ b/internal/generator/cli/streams/out.go @@ -5,14 +5,9 @@ import ( "io" "github.com/moby/term" + "github.com/tarantool/sdvg/internal/generator/cli/utils" ) -type nopWriteCloser struct { - io.Writer -} - -func (nopWriteCloser) Close() error { return nil } - // Out is an output stream to write normal program output. It implements an [io.WriteCloser]. type Out struct { isTerminal bool @@ -26,7 +21,7 @@ func NewOut(out io.Writer) *Out { if writeCloser, ok := out.(io.WriteCloser); ok { o.out = writeCloser } else { - o.out = nopWriteCloser{out} + o.out = utils.DummyReadWriteCloser{Writer: out} } _, o.isTerminal = term.GetFdInfo(out) diff --git a/internal/generator/cli/utils/utils.go b/internal/generator/cli/utils/utils.go index a57d26a..aeefb42 100644 --- a/internal/generator/cli/utils/utils.go +++ b/internal/generator/cli/utils/utils.go @@ -1,6 +1,7 @@ package utils import ( + "io" "path/filepath" "slices" "strings" @@ -98,3 +99,12 @@ func ChooseCommand(cmd *cobra.Command, args []string, renderer render.Renderer) return nil } + +type DummyReadWriteCloser struct { + io.Reader + io.Writer +} + +func (rwc DummyReadWriteCloser) Close() error { + return nil +} diff --git a/internal/generator/models/generator_output.go b/internal/generator/models/generator_output.go index 9f7e729..8bdcc51 100644 --- a/internal/generator/models/generator_output.go +++ b/internal/generator/models/generator_output.go @@ -18,6 +18,8 @@ const ( tcsTimeoutHeader = "x-tcs-timeout_ms" ParquetDateTimeMillisFormat = "millis" ParquetDateTimeMicrosFormat = "micros" + + PartitionFilesLimitDefault = 1000 ) // DataRow type is used to represent any data row that was generated. @@ -167,10 +169,11 @@ var _ Field = (*CSVConfig)(nil) // CSVConfig type used to describe output config for CSV implementation. type CSVConfig struct { - FloatPrecision int `json:"float_precision" yaml:"float_precision"` - DatetimeFormat string `json:"datetime_format" yaml:"datetime_format"` - Delimiter string `backup:"true" json:"delimiter" yaml:"delimiter"` - WithoutHeaders bool `backup:"true" json:"without_headers" yaml:"without_headers"` + FloatPrecision int `json:"float_precision" yaml:"float_precision"` + DatetimeFormat string `json:"datetime_format" yaml:"datetime_format"` + Delimiter string `backup:"true" json:"delimiter" yaml:"delimiter"` + WithoutHeaders bool `backup:"true" json:"without_headers" yaml:"without_headers"` + PartitionFilesLimit *int `json:"partition_files_limit" yaml:"partition_files_limit"` } func (c *CSVConfig) Parse() error { return nil } @@ -187,6 +190,11 @@ func (c *CSVConfig) FillDefaults() { if c.Delimiter == "" { c.Delimiter = "," } + + if c.PartitionFilesLimit == nil { + c.PartitionFilesLimit = new(int) + *c.PartitionFilesLimit = 1000 + } } func (c *CSVConfig) Validate() []error { @@ -200,6 +208,10 @@ func (c *CSVConfig) Validate() []error { errs = append(errs, errors.Errorf("the delimiter must consist of one character, got %v", c.Delimiter)) } + if c.PartitionFilesLimit != nil && *c.PartitionFilesLimit <= 0 { + errs = append(errs, errors.Errorf("partition files limit should be greater than 0, got: %v", *c.PartitionFilesLimit)) + } + return errs } @@ -295,9 +307,10 @@ var _ Field = (*ParquetConfig)(nil) // ParquetConfig type used to describe output config for parquet implementation. type ParquetConfig struct { - CompressionCodec string `backup:"true" json:"compression_codec" yaml:"compression_codec"` - FloatPrecision int `json:"float_precision" yaml:"float_precision"` - DateTimeFormat string `json:"datetime_format" yaml:"datetime_format"` + CompressionCodec string `backup:"true" json:"compression_codec" yaml:"compression_codec"` + FloatPrecision int `json:"float_precision" yaml:"float_precision"` + DateTimeFormat string `json:"datetime_format" yaml:"datetime_format"` + PartitionFilesLimit *int `json:"partition_files_limit" yaml:"partition_files_limit"` } //nolint:lll @@ -318,6 +331,11 @@ func (c *ParquetConfig) FillDefaults() { if c.DateTimeFormat == "" { c.DateTimeFormat = ParquetDateTimeMillisFormat } + + if c.PartitionFilesLimit == nil { + c.PartitionFilesLimit = new(int) + *c.PartitionFilesLimit = 1000 + } } func (c *ParquetConfig) Validate() []error { @@ -337,5 +355,9 @@ func (c *ParquetConfig) Validate() []error { c.DateTimeFormat, parquetSupportedDateTimeFormats)) } + if c.PartitionFilesLimit != nil && *c.PartitionFilesLimit <= 0 { + errs = append(errs, errors.Errorf("partition files limit should be greater than 0, got: %v", *c.PartitionFilesLimit)) + } + return errs } diff --git a/internal/generator/output/general/model_writer.go b/internal/generator/output/general/model_writer.go index 5c3f746..ffa473a 100644 --- a/internal/generator/output/general/model_writer.go +++ b/internal/generator/output/general/model_writer.go @@ -13,6 +13,7 @@ import ( "time" "github.com/pkg/errors" + "github.com/tarantool/sdvg/internal/generator/cli/confirm" "github.com/tarantool/sdvg/internal/generator/common" "github.com/tarantool/sdvg/internal/generator/models" "github.com/tarantool/sdvg/internal/generator/output" @@ -48,6 +49,10 @@ type ModelWriter struct { writtenRowsWg *sync.WaitGroup writtenRowsChan chan uint64 stopChan chan struct{} + + partitionFilesCount int + partitionFilesLimit *int + confirm confirm.Confirm } // NewModelWriter creates ModelWriter object. @@ -55,7 +60,15 @@ func newModelWriter( model *models.Model, config *models.OutputConfig, continueGeneration bool, -) (*ModelWriter, error) { + confirm confirm.Confirm) (*ModelWriter, error) { + var partitionFilesLimit *int + switch config.Type { + case "csv": + partitionFilesLimit = config.CSVParams.PartitionFilesLimit + case "parquet": + partitionFilesLimit = config.ParquetParams.PartitionFilesLimit + } + orderedColumnNames := make([]string, 0, len(model.Columns)) for _, column := range model.Columns { orderedColumnNames = append(orderedColumnNames, column.Name) @@ -108,6 +121,9 @@ func newModelWriter( writtenRowsWg: &sync.WaitGroup{}, writtenRowsChan: make(chan uint64, buffer), stopChan: make(chan struct{}), + partitionFilesCount: 0, + partitionFilesLimit: partitionFilesLimit, + confirm: confirm, } modelWriter.checkpointFilePath = modelWriter.getCheckpointFilePath() @@ -164,6 +180,7 @@ func (w *ModelWriter) updateCheckpoint() error { } // WriteRows function determines the partitioning key and sends the data to the appropriate writer. +// Note that this func should not be called concurrently from multiple goroutines because of confirm func call. func (w *ModelWriter) WriteRows(ctx context.Context, rows []*models.DataRow) error { for _, row := range rows { partitionPath := w.getPartitionPath(row) @@ -173,6 +190,12 @@ func (w *ModelWriter) WriteRows(ctx context.Context, rows []*models.DataRow) err w.writersMutex.RUnlock() if !ok { + w.partitionFilesCount++ + err := w.shouldContinue(ctx) + if err != nil { + return err + } + newDataWriter, err := w.newWriter(ctx, partitionPath) if err != nil { return err @@ -232,6 +255,22 @@ func (w *ModelWriter) getPartitionPath(row *models.DataRow) string { return sb.String() } +// shouldContinue returns error if user don't want to continue generation. +func (w *ModelWriter) shouldContinue(ctx context.Context) error { + if w.confirm != nil && w.partitionFilesLimit != nil && w.partitionFilesCount == *w.partitionFilesLimit { + shouldContinue, err := w.confirm(ctx, "Number of partitions files reached limit. Continue?") + if err != nil { + return err + } + + if !shouldContinue { + return errors.Errorf("number of partitions achieved limit exceeded: %v", w.partitionFilesCount) + } + } + + return nil +} + // newWriter function creates writer.Writer object based on output type from models.OutputConfig. func (w *ModelWriter) newWriter(ctx context.Context, outPath string) (writer.Writer, error) { var dataWriter writer.Writer diff --git a/internal/generator/output/general/output.go b/internal/generator/output/general/output.go index b9a7fbe..3f0f176 100644 --- a/internal/generator/output/general/output.go +++ b/internal/generator/output/general/output.go @@ -9,6 +9,7 @@ import ( "slices" "github.com/pkg/errors" + "github.com/tarantool/sdvg/internal/generator/cli/confirm" "github.com/tarantool/sdvg/internal/generator/models" "github.com/tarantool/sdvg/internal/generator/output" ) @@ -20,13 +21,15 @@ var _ output.Output = (*Output)(nil) type Output struct { config *models.OutputConfig models map[string]*models.Model + writersByModelName map[string]*ModelWriter + continueGeneration bool forceGeneration bool - writersByModelName map[string]*ModelWriter + confirm confirm.Confirm } // NewOutput function creates Output object. -func NewOutput(cfg *models.GenerationConfig, continueGeneration, forceGeneration bool) output.Output { +func NewOutput(cfg *models.GenerationConfig, continueGeneration, forceGeneration bool, confirm confirm.Confirm) output.Output { filteredModels := make(map[string]*models.Model) for modelName, model := range cfg.Models { @@ -41,6 +44,7 @@ func NewOutput(cfg *models.GenerationConfig, continueGeneration, forceGeneration continueGeneration: continueGeneration, forceGeneration: forceGeneration, writersByModelName: make(map[string]*ModelWriter), + confirm: confirm, } } @@ -56,7 +60,7 @@ func (o *Output) Setup() error { writersByModelName := make(map[string]*ModelWriter) for modelName, model := range o.models { - modelWriter, err := newModelWriter(model, o.config, o.continueGeneration) + modelWriter, err := newModelWriter(model, o.config, o.continueGeneration, o.confirm) if err != nil { return err } diff --git a/internal/generator/usecase/general/task.go b/internal/generator/usecase/general/task.go index 2bf435e..29c0d55 100644 --- a/internal/generator/usecase/general/task.go +++ b/internal/generator/usecase/general/task.go @@ -253,6 +253,7 @@ func (t *Task) skipRows() { } // generateAndSaveBatch function generate batch of values for selected column and send it to output. +// The next batch is written only after the previous one has completed saving. func (t *Task) generateAndSaveBatch( ctx context.Context, outputSync *common.WorkerSyncer, modelName string, generators []*generator.BatchGenerator, count uint64, From 0b2dd04ea1b9cb2d12239e5b88ca0f7c1d85b09b Mon Sep 17 00:00:00 2001 From: Zaman Gabdrakhmanov Date: Mon, 11 Aug 2025 21:59:18 +0300 Subject: [PATCH 2/7] add: confirm_test; upd: renderer mock --- internal/generator/cli/confirm/confirm.go | 4 +- .../generator/cli/confirm/confirm_test.go | 251 ++++++++++++++++++ .../generator/cli/render/mock/renderer.go | 75 +++++- 3 files changed, 328 insertions(+), 2 deletions(-) create mode 100644 internal/generator/cli/confirm/confirm_test.go diff --git a/internal/generator/cli/confirm/confirm.go b/internal/generator/cli/confirm/confirm.go index 20986a9..8c5c3b4 100644 --- a/internal/generator/cli/confirm/confirm.go +++ b/internal/generator/cli/confirm/confirm.go @@ -13,6 +13,8 @@ import ( "github.com/tarantool/sdvg/internal/generator/cli/utils" ) +var ErrPromptFailed = errors.New("prompt failed") + // Confirm asks user a yes/no question. Returns true for “yes”. type Confirm func(ctx context.Context, question string) (bool, error) @@ -53,7 +55,7 @@ func BuildConfirmTTY(in io.Reader, out io.Writer) func(ctx context.Context, ques } if err != nil { - return false, errors.WithMessage(err, "confirm prompt failed") + return false, fmt.Errorf("%w: %v", ErrPromptFailed, err) } return strings.Contains("Yy", input), nil diff --git a/internal/generator/cli/confirm/confirm_test.go b/internal/generator/cli/confirm/confirm_test.go new file mode 100644 index 0000000..04ece48 --- /dev/null +++ b/internal/generator/cli/confirm/confirm_test.go @@ -0,0 +1,251 @@ +package confirm + +import ( + "bytes" + "context" + "errors" + "fmt" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/require" + rendererMock "github.com/tarantool/sdvg/internal/generator/cli/render/mock" +) + +func TestConfirmTTY(t *testing.T) { + input := bytes.Buffer{} + output := bytes.Buffer{} + + confirm := BuildConfirmTTY(&input, &output) + + testCases := []struct { + name string + ctx context.Context + question string + input string + expected bool + expectedErr error + }{ + { + name: "Y", + question: "question", + input: "Y", + expected: true, + }, + { + name: "y", + question: "question", + input: "y", + expected: true, + }, + { + name: "yes", + question: "question", + input: "yes", + expectedErr: ErrPromptFailed, + }, + { + name: "N", + question: "question", + input: "N", + expected: false, + }, + { + name: "n", + question: "question", + input: "n", + expected: false, + }, + { + name: "no", + question: "question", + input: "no", + expectedErr: ErrPromptFailed, + }, + { + name: "Context canceled", + expectedErr: context.Canceled, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ctx := context.Background() + + if errors.Is(tc.expectedErr, context.Canceled) { + var cancel context.CancelFunc + + ctx, cancel = context.WithCancel(ctx) + cancel() + } + + input.WriteString(tc.input + "\n") + + res, err := confirm(ctx, tc.question) + require.True(t, errors.Is(err, tc.expectedErr), fmt.Sprintf("expected: %v, got: %v", tc.expectedErr, err)) + + require.Equal(t, tc.expected, res) + + input.Reset() + output.Reset() + }) + } +} + +var errMockTest = errors.New("mock test error") + +func TestConfirmNoTTY(t *testing.T) { + output := bytes.Buffer{} + + isUpdatePaused := atomic.Bool{} + + testCases := []struct { + name string + ctx context.Context + question string + ch chan time.Time + expected bool + expectedErr error + mockFunc func(r *rendererMock.Renderer) + }{ + { + name: "Y", + question: "question", + expected: true, + mockFunc: func(r *rendererMock.Renderer) { + r.On("ReadLine"). + Return("Y"+"\n", nil) + + r. + On("IsTerminal"). + Return(true) + }, + }, + { + name: "y", + question: "question", + expected: true, + mockFunc: func(r *rendererMock.Renderer) { + r.On("ReadLine"). + Return("y"+"\n", nil) + + r. + On("IsTerminal"). + Return(true) + }, + }, + { + name: "yes", + question: "question", + mockFunc: func(r *rendererMock.Renderer) { + r.On("ReadLine"). + Return("yes"+"\n", errMockTest) + }, + expectedErr: errMockTest, + }, + { + name: "N", + question: "question", + mockFunc: func(r *rendererMock.Renderer) { + r.On("ReadLine"). + Return("N"+"\n", nil) + + r. + On("IsTerminal"). + Return(true) + }, + expected: false, + }, + { + name: "n", + question: "question", + mockFunc: func(r *rendererMock.Renderer) { + r.On("ReadLine"). + Return("n"+"\n", nil) + + r. + On("IsTerminal"). + Return(true) + }, + expected: false, + }, + { + name: "no", + question: "question", + mockFunc: func(r *rendererMock.Renderer) { + r.On("ReadLine"). + Return("no"+"\n", errMockTest) + }, + expectedErr: errMockTest, + }, + { + name: "Context canceled", + mockFunc: func(r *rendererMock.Renderer) { + r.On("ReadLine"). + Return("", nil).Maybe() + }, + expectedErr: context.Canceled, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + r := rendererMock.NewRenderer(t) + + tc.mockFunc(r) + + confirm := BuildConfirmNoTTY(r, &output, &isUpdatePaused) + + ctx := context.Background() + + if errors.Is(tc.expectedErr, context.Canceled) { + var cancel context.CancelFunc + + ctx, cancel = context.WithCancel(ctx) + cancel() + } + + res, err := confirm(ctx, tc.question) + require.True(t, errors.Is(err, tc.expectedErr), fmt.Sprintf("expected: %v, got: %v", tc.expectedErr, err)) + + require.Equal(t, tc.expected, res) + + output.Reset() + }) + } +} + +func TestConfirmNoTTY_IsUpdatePaused(t *testing.T) { + output := bytes.Buffer{} + + isUpdatePaused := atomic.Bool{} + + r := rendererMock.NewRenderer(t) + + confirm := BuildConfirmNoTTY(r, &output, &isUpdatePaused) + + mockFunc := func(r *rendererMock.Renderer, ch chan time.Time) { + r.On("ReadLine").WaitUntil(ch). + Return("Y"+"\n", nil) + + r. + On("IsTerminal"). + Return(true) + } + + ch := make(chan time.Time) + + mockFunc(r, ch) + + go confirm(context.Background(), "") + + start := time.Now() + ch <- start + + for isUpdatePaused.Load() { + if time.Now().Sub(start) > 2*time.Second { + t.Fatal("isUpdatePaused has not been called") + } + } +} diff --git a/internal/generator/cli/render/mock/renderer.go b/internal/generator/cli/render/mock/renderer.go index 3580615..d0a953a 100644 --- a/internal/generator/cli/render/mock/renderer.go +++ b/internal/generator/cli/render/mock/renderer.go @@ -41,11 +41,85 @@ func (_m *Renderer) InputMenu(ctx context.Context, title string, validateFunc fu return r0, r1 } +// IsTerminal provides a mock function with no fields +func (_m *Renderer) IsTerminal() bool { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for IsTerminal") + } + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + // Logo provides a mock function with no fields func (_m *Renderer) Logo() { _m.Called() } +// Read provides a mock function with given fields: p +func (_m *Renderer) Read(p []byte) (int, error) { + ret := _m.Called(p) + + if len(ret) == 0 { + panic("no return value specified for Read") + } + + var r0 int + var r1 error + if rf, ok := ret.Get(0).(func([]byte) (int, error)); ok { + return rf(p) + } + if rf, ok := ret.Get(0).(func([]byte) int); ok { + r0 = rf(p) + } else { + r0 = ret.Get(0).(int) + } + + if rf, ok := ret.Get(1).(func([]byte) error); ok { + r1 = rf(p) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// ReadLine provides a mock function with no fields +func (_m *Renderer) ReadLine() (string, error) { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for ReadLine") + } + + var r0 string + var r1 error + if rf, ok := ret.Get(0).(func() (string, error)); ok { + return rf() + } + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // SelectionMenu provides a mock function with given fields: ctx, title, items func (_m *Renderer) SelectionMenu(ctx context.Context, title string, items []string) (string, error) { ret := _m.Called(ctx, title, items) @@ -104,7 +178,6 @@ func (_m *Renderer) TextMenu(ctx context.Context, title string) (string, error) // WithSpinner provides a mock function with given fields: title, fn func (_m *Renderer) WithSpinner(title string, fn func()) { - fn() _m.Called(title, fn) } From 14d09cc7532380587ed29da289a5d75d0cf572a3 Mon Sep 17 00:00:00 2001 From: Zaman Gabdrakhmanov Date: Mon, 11 Aug 2025 23:44:53 +0300 Subject: [PATCH 3/7] add: racy test confirm in different file; upd: output writers test, generator output config tests --- .../cli/commands/generate/generate_test.go | 2 +- internal/generator/cli/confirm/confirm.go | 5 +- .../cli/confirm/confirm_race_off_test.go | 90 +++++++++++++ .../generator/cli/confirm/confirm_test.go | 90 +------------ internal/generator/cli/confirm/reader.go | 37 +++++ internal/generator/models/models_test.go | 57 ++++---- .../generator/output/general/model_writer.go | 6 +- .../output/general/model_writer_test.go | 2 +- .../output/general/test/bench_test.go | 2 +- .../output/general/test/unit_test.go | 126 ++++++++++++++++-- .../usecase/general/backup/backup_test.go | 3 +- 11 files changed, 294 insertions(+), 126 deletions(-) create mode 100644 internal/generator/cli/confirm/confirm_race_off_test.go create mode 100644 internal/generator/cli/confirm/reader.go diff --git a/internal/generator/cli/commands/generate/generate_test.go b/internal/generator/cli/commands/generate/generate_test.go index 03f5798..7596225 100644 --- a/internal/generator/cli/commands/generate/generate_test.go +++ b/internal/generator/cli/commands/generate/generate_test.go @@ -256,7 +256,7 @@ func TestNewGenerateCommand(t *testing.T) { cliOpts.SetOut(streams.NewOut(os.Stdout)) cmd := NewGenerateCommand(cliOpts) - cmd.SetArgs([]string{"-F"}) + cmd.SetArgs([]string{"-f"}) err = cmd.Execute() diff --git a/internal/generator/cli/confirm/confirm.go b/internal/generator/cli/confirm/confirm.go index 8c5c3b4..1b42f55 100644 --- a/internal/generator/cli/confirm/confirm.go +++ b/internal/generator/cli/confirm/confirm.go @@ -22,10 +22,13 @@ func BuildConfirmTTY(in io.Reader, out io.Writer) func(ctx context.Context, ques return func(ctx context.Context, question string) (bool, error) { fmt.Fprintln(out) + cancelableIn := newCancelableReader(in) + defer cancelableIn.Close() + prompt := promptui.Prompt{ Label: question + " [y/N]: ", Default: "y", - Stdin: utils.DummyReadWriteCloser{Reader: in}, + Stdin: cancelableIn, Stdout: utils.DummyReadWriteCloser{Writer: out}, } validate := func(s string) error { diff --git a/internal/generator/cli/confirm/confirm_race_off_test.go b/internal/generator/cli/confirm/confirm_race_off_test.go new file mode 100644 index 0000000..a3242a8 --- /dev/null +++ b/internal/generator/cli/confirm/confirm_race_off_test.go @@ -0,0 +1,90 @@ +//go:build !race + +package confirm + +import ( + "bytes" + "context" + "errors" + "fmt" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestConfirmTTY(t *testing.T) { + testCases := []struct { + name string + ctx context.Context + question string + input string + expected bool + expectedErr error + }{ + { + name: "Y", + question: "question", + input: "Y", + expected: true, + }, + { + name: "y", + question: "question", + input: "y", + expected: true, + }, + { + name: "yes", + question: "question", + input: "yes", + expectedErr: ErrPromptFailed, + }, + { + name: "N", + question: "question", + input: "N", + expected: false, + }, + { + name: "n", + question: "question", + input: "n", + expected: false, + }, + { + name: "no", + question: "question", + input: "no", + expectedErr: ErrPromptFailed, + }, + { + name: "Context canceled", + expectedErr: context.Canceled, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + input := bytes.Buffer{} + output := bytes.Buffer{} + + confirm := BuildConfirmTTY(&input, &output) + + ctx := context.Background() + + if errors.Is(tc.expectedErr, context.Canceled) { + var cancel context.CancelFunc + + ctx, cancel = context.WithCancel(ctx) + cancel() + } + + input.WriteString(tc.input + "\n") + + res, err := confirm(ctx, tc.question) + require.True(t, errors.Is(err, tc.expectedErr), fmt.Sprintf("expected: %v, got: %v", tc.expectedErr, err)) + + require.Equal(t, tc.expected, res) + }) + } +} diff --git a/internal/generator/cli/confirm/confirm_test.go b/internal/generator/cli/confirm/confirm_test.go index 04ece48..3546c47 100644 --- a/internal/generator/cli/confirm/confirm_test.go +++ b/internal/generator/cli/confirm/confirm_test.go @@ -13,93 +13,9 @@ import ( rendererMock "github.com/tarantool/sdvg/internal/generator/cli/render/mock" ) -func TestConfirmTTY(t *testing.T) { - input := bytes.Buffer{} - output := bytes.Buffer{} - - confirm := BuildConfirmTTY(&input, &output) - - testCases := []struct { - name string - ctx context.Context - question string - input string - expected bool - expectedErr error - }{ - { - name: "Y", - question: "question", - input: "Y", - expected: true, - }, - { - name: "y", - question: "question", - input: "y", - expected: true, - }, - { - name: "yes", - question: "question", - input: "yes", - expectedErr: ErrPromptFailed, - }, - { - name: "N", - question: "question", - input: "N", - expected: false, - }, - { - name: "n", - question: "question", - input: "n", - expected: false, - }, - { - name: "no", - question: "question", - input: "no", - expectedErr: ErrPromptFailed, - }, - { - name: "Context canceled", - expectedErr: context.Canceled, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - ctx := context.Background() - - if errors.Is(tc.expectedErr, context.Canceled) { - var cancel context.CancelFunc - - ctx, cancel = context.WithCancel(ctx) - cancel() - } - - input.WriteString(tc.input + "\n") - - res, err := confirm(ctx, tc.question) - require.True(t, errors.Is(err, tc.expectedErr), fmt.Sprintf("expected: %v, got: %v", tc.expectedErr, err)) - - require.Equal(t, tc.expected, res) - - input.Reset() - output.Reset() - }) - } -} - var errMockTest = errors.New("mock test error") func TestConfirmNoTTY(t *testing.T) { - output := bytes.Buffer{} - - isUpdatePaused := atomic.Bool{} - testCases := []struct { name string ctx context.Context @@ -192,9 +108,11 @@ func TestConfirmNoTTY(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { r := rendererMock.NewRenderer(t) - tc.mockFunc(r) + output := bytes.Buffer{} + isUpdatePaused := atomic.Bool{} + confirm := BuildConfirmNoTTY(r, &output, &isUpdatePaused) ctx := context.Background() @@ -210,8 +128,6 @@ func TestConfirmNoTTY(t *testing.T) { require.True(t, errors.Is(err, tc.expectedErr), fmt.Sprintf("expected: %v, got: %v", tc.expectedErr, err)) require.Equal(t, tc.expected, res) - - output.Reset() }) } } diff --git a/internal/generator/cli/confirm/reader.go b/internal/generator/cli/confirm/reader.go new file mode 100644 index 0000000..760fcb3 --- /dev/null +++ b/internal/generator/cli/confirm/reader.go @@ -0,0 +1,37 @@ +package confirm + +import "io" + +// cancelableReader wraps an io.Reader and can be closed to make future reads fail. +type cancelableReader struct { + r io.Reader + closed chan struct{} +} + +// newCancelableReader creates a ReadCloser from an io.Reader. +// Closing it will make subsequent Read() calls return io.EOF. +func newCancelableReader(r io.Reader) io.ReadCloser { + return &cancelableReader{ + r: r, + closed: make(chan struct{}), + } +} + +func (c *cancelableReader) Read(p []byte) (int, error) { + select { + case <-c.closed: + return 0, io.EOF + default: + return c.r.Read(p) + } +} + +func (c *cancelableReader) Close() error { + select { + case <-c.closed: + // already closed + default: + close(c.closed) + } + return nil +} diff --git a/internal/generator/models/models_test.go b/internal/generator/models/models_test.go index ffb908a..66b35fd 100644 --- a/internal/generator/models/models_test.go +++ b/internal/generator/models/models_test.go @@ -225,9 +225,10 @@ func TestGeneratorConfigYAMLParse(t *testing.T) { OutputConfig: &OutputConfig{ Type: "csv", CSVParams: &CSVConfig{ - FloatPrecision: 2, - DatetimeFormat: "2006-01-02T15:04:05Z07:00", - Delimiter: ",", + FloatPrecision: 2, + DatetimeFormat: "2006-01-02T15:04:05Z07:00", + Delimiter: ",", + PartitionFilesLimit: ptr(1000), }, }, }, @@ -624,15 +625,16 @@ models: Dir: "test_output", CheckpointInterval: time.Second, CSVParams: &CSVConfig{ - FloatPrecision: 2, - DatetimeFormat: "2006-01-02T15:04:05Z07:00", - Delimiter: ",", + FloatPrecision: 2, + DatetimeFormat: "2006-01-02T15:04:05Z07:00", + Delimiter: ",", + PartitionFilesLimit: ptr(1000), }, }, }, }, { - name: "CsvFullConfig", + name: "csv full config", content: ` random_seed: 1 output: @@ -641,6 +643,7 @@ output: datetime_format: "2006-01-02" float_precision: 1 delimiter: ";" + partition_files_limit: 10 models: test: rows_count: 1 @@ -656,9 +659,10 @@ models: Dir: DefaultOutputDir, CheckpointInterval: 5 * time.Second, CSVParams: &CSVConfig{ - FloatPrecision: 1, - DatetimeFormat: "2006-01-02", - Delimiter: ";", + FloatPrecision: 1, + DatetimeFormat: "2006-01-02", + Delimiter: ";", + PartitionFilesLimit: ptr(10), }, }, }, @@ -849,6 +853,7 @@ output: datetime_format: micros float_precision: 3 compression_codec: GZIP + partition_files_limit: 1 models: test: rows_count: 1 @@ -864,9 +869,10 @@ models: Dir: DefaultOutputDir, CheckpointInterval: 5 * time.Second, ParquetParams: &ParquetConfig{ - FloatPrecision: 3, - DateTimeFormat: ParquetDateTimeMicrosFormat, - CompressionCodec: "GZIP", + FloatPrecision: 3, + DateTimeFormat: ParquetDateTimeMicrosFormat, + CompressionCodec: "GZIP", + PartitionFilesLimit: ptr(1), }, }, }, @@ -892,9 +898,10 @@ models: Dir: DefaultOutputDir, CheckpointInterval: 5 * time.Second, ParquetParams: &ParquetConfig{ - FloatPrecision: 2, - DateTimeFormat: ParquetDateTimeMillisFormat, - CompressionCodec: "UNCOMPRESSED", + FloatPrecision: 2, + DateTimeFormat: ParquetDateTimeMillisFormat, + CompressionCodec: "UNCOMPRESSED", + PartitionFilesLimit: ptr(1000), }, }, }, @@ -956,9 +963,10 @@ models: Dir: DefaultOutputDir, CheckpointInterval: 5 * time.Second, ParquetParams: &ParquetConfig{ - FloatPrecision: 2, - DateTimeFormat: ParquetDateTimeMillisFormat, - CompressionCodec: "UNCOMPRESSED", + FloatPrecision: 2, + DateTimeFormat: ParquetDateTimeMillisFormat, + CompressionCodec: "UNCOMPRESSED", + PartitionFilesLimit: ptr(1000), }, }, }, @@ -1062,9 +1070,10 @@ models: Dir: DefaultOutputDir, CheckpointInterval: 5 * time.Second, ParquetParams: &ParquetConfig{ - FloatPrecision: 2, - DateTimeFormat: ParquetDateTimeMillisFormat, - CompressionCodec: "UNCOMPRESSED", + FloatPrecision: 2, + DateTimeFormat: ParquetDateTimeMillisFormat, + CompressionCodec: "UNCOMPRESSED", + PartitionFilesLimit: ptr(1000), }, }, }, @@ -1107,6 +1116,7 @@ output: compression_codec: non-existent-codec float_precision: -1 datetime_format: non-existent-datetime-format + partition_files_limit: 0 checkpoint_interval: -1s models_to_ignore: - non-existent-column @@ -1158,7 +1168,8 @@ output config: parquet params: - unknown compression codec non-existent-codec, supported [UNCOMPRESSED SNAPPY GZIP LZ4 LZ4RAW LZO ZSTD BROTLI] - float precision should be grater than 0, got -1 -- unknown datetime format non-existent-datetime-format, supported [millis micros]`, +- unknown datetime format non-existent-datetime-format, supported [millis micros] +- partition files limit should be greater than 0, got: 0`, ), }, } diff --git a/internal/generator/output/general/model_writer.go b/internal/generator/output/general/model_writer.go index ffa473a..c7c9730 100644 --- a/internal/generator/output/general/model_writer.go +++ b/internal/generator/output/general/model_writer.go @@ -27,6 +27,8 @@ import ( const buffer = 100 +var ErrPartitionFilesLimitExceeded = errors.New("partition files limit exceeded") + // ModelWriter type implements the general logic of writing data. type ModelWriter struct { model *models.Model @@ -257,14 +259,14 @@ func (w *ModelWriter) getPartitionPath(row *models.DataRow) string { // shouldContinue returns error if user don't want to continue generation. func (w *ModelWriter) shouldContinue(ctx context.Context) error { - if w.confirm != nil && w.partitionFilesLimit != nil && w.partitionFilesCount == *w.partitionFilesLimit { + if w.confirm != nil && w.partitionFilesLimit != nil && w.partitionFilesCount == *w.partitionFilesLimit+1 { shouldContinue, err := w.confirm(ctx, "Number of partitions files reached limit. Continue?") if err != nil { return err } if !shouldContinue { - return errors.Errorf("number of partitions achieved limit exceeded: %v", w.partitionFilesCount) + return fmt.Errorf("%w: %v", ErrPartitionFilesLimitExceeded, w.partitionFilesCount) } } diff --git a/internal/generator/output/general/model_writer_test.go b/internal/generator/output/general/model_writer_test.go index e938c96..ed75dc6 100644 --- a/internal/generator/output/general/model_writer_test.go +++ b/internal/generator/output/general/model_writer_test.go @@ -220,7 +220,7 @@ func TestPartitionPaths(t *testing.T) { }, } - writer, err := newModelWriter(tCase.model, devnullConfig, false) + writer, err := newModelWriter(tCase.model, devnullConfig, false, nil) require.NoError(t, err) err = writer.WriteRows(context.Background(), tCase.data) diff --git a/internal/generator/output/general/test/bench_test.go b/internal/generator/output/general/test/bench_test.go index 450d8c2..68dc9ab 100644 --- a/internal/generator/output/general/test/bench_test.go +++ b/internal/generator/output/general/test/bench_test.go @@ -301,7 +301,7 @@ func runModelsBenches( copyCfg := *genCfg SetOutputParams(©Cfg, uint64(b.N)) - out := general.NewOutput(©Cfg, false, true) + out := general.NewOutput(©Cfg, false, true, nil) require.NoError(b, out.Setup()) b.ResetTimer() diff --git a/internal/generator/output/general/test/unit_test.go b/internal/generator/output/general/test/unit_test.go index 4e7a91f..f12ab90 100644 --- a/internal/generator/output/general/test/unit_test.go +++ b/internal/generator/output/general/test/unit_test.go @@ -7,10 +7,12 @@ import ( "math" "os" "path/filepath" + "strings" "testing" "github.com/pkg/errors" "github.com/stretchr/testify/require" + "github.com/tarantool/sdvg/internal/generator/cli/confirm" "github.com/tarantool/sdvg/internal/generator/common" "github.com/tarantool/sdvg/internal/generator/models" outputGeneral "github.com/tarantool/sdvg/internal/generator/output/general" @@ -63,6 +65,18 @@ models: range_percentage: 0.5 - type_params: to: 5 +` + oneModelConfigWithPartition = ` +models: + model1: + rows_count: 10 + columns: + - name: id + type: integer + distinct_percentage: 1 + partition_columns: + - name: id + write_to_output: true ` ) @@ -94,7 +108,7 @@ func TestContinueGeneration(t *testing.T) { // Generate expected data - require.NoError(t, generate(t, cfg, uc, false, true)) + require.NoError(t, generate(t, cfg, uc, false, true, nil)) expectedFilesData := make(map[string][][]string) @@ -117,7 +131,7 @@ func TestContinueGeneration(t *testing.T) { model.GenerateTo = model.RowsCount / 2 } - require.NoError(t, generate(t, cfg, uc, false, true)) + require.NoError(t, generate(t, cfg, uc, false, true, nil)) for _, model := range cfg.Models { filesCount := int(math.Ceil(float64(model.GenerateTo-model.GenerateFrom) / float64(model.RowsPerFile))) @@ -151,7 +165,7 @@ func TestContinueGeneration(t *testing.T) { require.NoError(t, cfg.ParseFromFile(configPath)) cfg.OutputConfig.Dir = outputDir - require.NoError(t, generate(t, cfg, uc, true, true)) + require.NoError(t, generate(t, cfg, uc, true, true, nil)) for _, model := range cfg.Models { filesCount := math.Ceil(float64(rowsCountByModel[model.Name]) / float64(model.RowsPerFile)) @@ -238,10 +252,10 @@ cause: dir for model is not empty // Generate data in empty output dir - require.NoError(t, generate(t, cfg, uc, false, false)) + require.NoError(t, generate(t, cfg, uc, false, false, nil)) // Try to init new output with conflicts - out := outputGeneral.NewOutput(cfg, false, tc.forceGeneration) + out := outputGeneral.NewOutput(cfg, false, tc.forceGeneration, nil) err := out.Setup() if tc.err != nil { @@ -264,11 +278,98 @@ cause: dir for model is not empty } } +var ( + errMockTest = errors.New("mock test error") + partitionsFileLimit = 2 +) + +func TestConfirmationAsk(t *testing.T) { + testCases := []struct { + name string + shouldContinue bool + wantErr bool + err error + confirm confirm.Confirm + }{ + { + name: "Continue", + shouldContinue: true, + confirm: func(ctx context.Context, question string) (bool, error) { + return true, nil + }, + }, + { + name: "Stop", + shouldContinue: false, + err: outputGeneral.ErrPartitionFilesLimitExceeded, + confirm: func(ctx context.Context, question string) (bool, error) { + return false, nil + }, + }, + { + name: "Error", + shouldContinue: false, + wantErr: true, + err: errMockTest, + confirm: func(ctx context.Context, question string) (bool, error) { + return false, errMockTest + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Write models config + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, configFileName) + require.NoError(t, os.WriteFile(configPath, []byte(oneModelConfigWithPartition), configFilePerm)) + + uc := useCaseGeneral.NewUseCase(useCaseGeneral.UseCaseConfig{}) + require.NoError(t, uc.Setup()) + + // Parse config + + cfg := &models.GenerationConfig{} + + require.NoError(t, cfg.ParseFromFile(configPath)) + + *cfg.OutputConfig.CSVParams.PartitionFilesLimit = partitionsFileLimit + + // Generate data in empty output dir + + err := generate(t, cfg, uc, false, true, tc.confirm) + + // check generated partitions files amount + fileNames, walkErr := common.WalkWithFilter(models.DefaultOutputDir, func(entry os.DirEntry) bool { + return entry.IsDir() && strings.HasPrefix(entry.Name(), "id=") + }) + + require.NoError(t, walkErr, "failed to walk tmpdir: %v", tmpDir) + + if tc.wantErr { + require.Error(t, err) + } else { + if tc.shouldContinue { + require.Equal(t, 10, len(fileNames), "there should be rows_amount dirs") + require.NoError(t, err) + } else { + require.True(t, errors.Is(err, tc.err), "expected error: %v, got: %v", tc.err, err) + require.Equal(t, partitionsFileLimit, len(fileNames), "there should be partitionsFileLimit dirs") + } + } + + // cleanup + + require.NoError(t, os.RemoveAll(models.DefaultOutputDir)) + }) + } +} + //nolint:lll -func generate(t *testing.T, cfg *models.GenerationConfig, uc usecase.UseCase, continueGeneration, forceGeneration bool) error { +func generate(t *testing.T, cfg *models.GenerationConfig, uc usecase.UseCase, continueGeneration, forceGeneration bool, confirm confirm.Confirm) error { t.Helper() - out := outputGeneral.NewOutput(cfg, continueGeneration, forceGeneration) + out := outputGeneral.NewOutput(cfg, continueGeneration, forceGeneration, confirm) taskID, err := uc.CreateTask(context.Background(), usecase.TaskConfig{ GenerationConfig: cfg, @@ -279,8 +380,15 @@ func generate(t *testing.T, cfg *models.GenerationConfig, uc usecase.UseCase, co return err } - require.NoError(t, uc.WaitResult(taskID)) - require.NoError(t, uc.Teardown()) + err = uc.WaitResult(taskID) + if err != nil { + return err + } + + err = uc.Teardown() + if err != nil { + return err + } return nil } diff --git a/internal/generator/usecase/general/backup/backup_test.go b/internal/generator/usecase/general/backup/backup_test.go index e29b433..3168e1f 100644 --- a/internal/generator/usecase/general/backup/backup_test.go +++ b/internal/generator/usecase/general/backup/backup_test.go @@ -35,6 +35,7 @@ func TestHandleBackup(t *testing.T) { }, false, false, + nil, ) type testCase struct { @@ -189,7 +190,7 @@ func TestHandleCheckpoint(t *testing.T) { "model4": 954, } - out := general.NewOutput(cfg, false, false) + out := general.NewOutput(cfg, false, false, nil) require.NoError(t, out.Setup()) for modelName, generateFrom := range checkpoints { From 8652cdd6894fb4ba6227c5204a6fe6843cfd70c5 Mon Sep 17 00:00:00 2001 From: Zaman Gabdrakhmanov Date: Tue, 12 Aug 2025 14:29:16 +0300 Subject: [PATCH 4/7] add: always call fn argument in WithSpinner mock --- internal/generator/cli/render/interfaces.go | 2 ++ internal/generator/cli/render/mock/renderer.go | 1 + 2 files changed, 3 insertions(+) diff --git a/internal/generator/cli/render/interfaces.go b/internal/generator/cli/render/interfaces.go index 1845a6d..6b0c3a6 100644 --- a/internal/generator/cli/render/interfaces.go +++ b/internal/generator/cli/render/interfaces.go @@ -6,6 +6,8 @@ import ( // Renderer interface implementation should render interactive menu. // +// after regenerating mock, do not forget to add call to fn() argument in WithSpinner +// //go:generate go run github.com/vektra/mockery/v2@v2.51.1 --name=Renderer --output=mock --outpkg=mock type Renderer interface { // Logo should display application logo. diff --git a/internal/generator/cli/render/mock/renderer.go b/internal/generator/cli/render/mock/renderer.go index d0a953a..1114eee 100644 --- a/internal/generator/cli/render/mock/renderer.go +++ b/internal/generator/cli/render/mock/renderer.go @@ -178,6 +178,7 @@ func (_m *Renderer) TextMenu(ctx context.Context, title string) (string, error) // WithSpinner provides a mock function with given fields: title, fn func (_m *Renderer) WithSpinner(title string, fn func()) { + fn() _m.Called(title, fn) } From bb91fa86f4a43197595e6234a0745341aa4b963d Mon Sep 17 00:00:00 2001 From: Zaman Gabdrakhmanov Date: Tue, 12 Aug 2025 14:34:39 +0300 Subject: [PATCH 5/7] upd: single test racy test without race flag --- Makefile | 1 + .../confirm/{confirm_race_off_test.go => confirm_racy_test.go} | 0 2 files changed, 1 insertion(+) rename internal/generator/cli/confirm/{confirm_race_off_test.go => confirm_racy_test.go} (100%) diff --git a/Makefile b/Makefile index 75d1ecb..4383bf3 100644 --- a/Makefile +++ b/Makefile @@ -12,6 +12,7 @@ test/lint/fix: test/unit: go test -race ./... + go test ./internal/generator/cli/confirm -run '' ./internal/generator/cli/confirm/confirm_race_off_test.go test/cover: module=./... test/cover: diff --git a/internal/generator/cli/confirm/confirm_race_off_test.go b/internal/generator/cli/confirm/confirm_racy_test.go similarity index 100% rename from internal/generator/cli/confirm/confirm_race_off_test.go rename to internal/generator/cli/confirm/confirm_racy_test.go From a3be8a75a117ec55225ee2189c914e3ea0f0fddf Mon Sep 17 00:00:00 2001 From: Zaman Gabdrakhmanov Date: Tue, 12 Aug 2025 14:55:36 +0300 Subject: [PATCH 6/7] upd: linter notes --- internal/generator/cli/commands/consts.go | 2 +- .../generator/cli/commands/generate/generate.go | 15 +++++++++++---- internal/generator/cli/confirm/confirm.go | 10 ++++++++-- .../generator/cli/confirm/confirm_racy_test.go | 6 ++---- internal/generator/cli/confirm/confirm_test.go | 9 ++++----- internal/generator/cli/confirm/reader.go | 3 ++- internal/generator/cli/progress/bar/bar.go | 2 +- internal/generator/cli/progress/log/log.go | 2 +- internal/generator/cli/streams/in.go | 1 - internal/generator/cli/streams/out.go | 1 - internal/generator/models/generator_output.go | 14 +++++++------- internal/generator/output/general/model_writer.go | 4 +++- internal/generator/output/general/output.go | 7 ++++++- .../generator/output/general/test/unit_test.go | 4 ++-- 14 files changed, 48 insertions(+), 32 deletions(-) diff --git a/internal/generator/cli/commands/consts.go b/internal/generator/cli/commands/consts.go index 511c81a..9aeb192 100644 --- a/internal/generator/cli/commands/consts.go +++ b/internal/generator/cli/commands/consts.go @@ -14,7 +14,7 @@ const ( ForceGenerationFlag = "force" ForceGenerationShortFlag = "f" ForceGenerationFlagDefaultValue = false - ForceGenerationUsage = "Force generation even if output file conflicts found and partition files limit reached" + ForceGenerationUsage = "Force generation even if output file conflicts found and partition files limit reached" //nolint:lll TTYFlag = "tty" TTYShortFlag = "t" diff --git a/internal/generator/cli/commands/generate/generate.go b/internal/generator/cli/commands/generate/generate.go index 6a46d43..6341013 100644 --- a/internal/generator/cli/commands/generate/generate.go +++ b/internal/generator/cli/commands/generate/generate.go @@ -175,10 +175,17 @@ func runGenerate(ctx context.Context, opts *generateOptions) error { return nil } -// initProgressTrackerManager inits progress bar manager (progress.Tracker) and builds streams.Confirm func based on useTTY -func initProgressTrackerManager(ctx context.Context, renderer render.Renderer, useTTY bool) (progress.Tracker, confirm.Confirm) { - var progressTrackerManager progress.Tracker - var confirmFunc confirm.Confirm +// initProgressTrackerManager inits progress bar manager (progress.Tracker) +// and builds streams.Confirm func based on useTTY. +func initProgressTrackerManager( + ctx context.Context, + renderer render.Renderer, + useTTY bool, +) (progress.Tracker, confirm.Confirm) { + var ( + progressTrackerManager progress.Tracker + confirmFunc confirm.Confirm + ) if useTTY { progressTrackerManager = bar.NewProgressBarManager(ctx) diff --git a/internal/generator/cli/confirm/confirm.go b/internal/generator/cli/confirm/confirm.go index 1b42f55..adfc66e 100644 --- a/internal/generator/cli/confirm/confirm.go +++ b/internal/generator/cli/confirm/confirm.go @@ -18,6 +18,7 @@ var ErrPromptFailed = errors.New("prompt failed") // Confirm asks user a yes/no question. Returns true for “yes”. type Confirm func(ctx context.Context, question string) (bool, error) +//nolint:gocritic func BuildConfirmTTY(in io.Reader, out io.Writer) func(ctx context.Context, question string) (bool, error) { return func(ctx context.Context, question string) (bool, error) { fmt.Fprintln(out) @@ -35,6 +36,7 @@ func BuildConfirmTTY(in io.Reader, out io.Writer) func(ctx context.Context, ques if len(s) == 1 && strings.Contains("YyNn", s) || prompt.Default != "" && len(s) == 0 { return nil } + return errors.New("invalid input") } prompt.Validate = validate @@ -58,14 +60,18 @@ func BuildConfirmTTY(in io.Reader, out io.Writer) func(ctx context.Context, ques } if err != nil { - return false, fmt.Errorf("%w: %v", ErrPromptFailed, err) + return false, errors.Wrap(ErrPromptFailed, err.Error()) } return strings.Contains("Yy", input), nil } } -func BuildConfirmNoTTY(in render.Renderer, out io.Writer, isUpdatePaused *atomic.Bool) func(ctx context.Context, question string) (bool, error) { +func BuildConfirmNoTTY( + in render.Renderer, + out io.Writer, + isUpdatePaused *atomic.Bool, +) func(ctx context.Context, question string) (bool, error) { return func(ctx context.Context, question string) (bool, error) { // here we pause ProgressLogManager to stop sending progress messages isUpdatePaused.Store(true) diff --git a/internal/generator/cli/confirm/confirm_racy_test.go b/internal/generator/cli/confirm/confirm_racy_test.go index a3242a8..668d785 100644 --- a/internal/generator/cli/confirm/confirm_racy_test.go +++ b/internal/generator/cli/confirm/confirm_racy_test.go @@ -5,17 +5,15 @@ package confirm import ( "bytes" "context" - "errors" - "fmt" "testing" + "github.com/pkg/errors" "github.com/stretchr/testify/require" ) func TestConfirmTTY(t *testing.T) { testCases := []struct { name string - ctx context.Context question string input string expected bool @@ -82,7 +80,7 @@ func TestConfirmTTY(t *testing.T) { input.WriteString(tc.input + "\n") res, err := confirm(ctx, tc.question) - require.True(t, errors.Is(err, tc.expectedErr), fmt.Sprintf("expected: %v, got: %v", tc.expectedErr, err)) + require.ErrorIs(t, err, tc.expectedErr, "expected: %v, got: %v", tc.expectedErr, err) require.Equal(t, tc.expected, res) }) diff --git a/internal/generator/cli/confirm/confirm_test.go b/internal/generator/cli/confirm/confirm_test.go index 3546c47..2f0e0ac 100644 --- a/internal/generator/cli/confirm/confirm_test.go +++ b/internal/generator/cli/confirm/confirm_test.go @@ -3,12 +3,11 @@ package confirm import ( "bytes" "context" - "errors" - "fmt" "sync/atomic" "testing" "time" + "github.com/pkg/errors" "github.com/stretchr/testify/require" rendererMock "github.com/tarantool/sdvg/internal/generator/cli/render/mock" ) @@ -18,7 +17,6 @@ var errMockTest = errors.New("mock test error") func TestConfirmNoTTY(t *testing.T) { testCases := []struct { name string - ctx context.Context question string ch chan time.Time expected bool @@ -125,7 +123,7 @@ func TestConfirmNoTTY(t *testing.T) { } res, err := confirm(ctx, tc.question) - require.True(t, errors.Is(err, tc.expectedErr), fmt.Sprintf("expected: %v, got: %v", tc.expectedErr, err)) + require.ErrorIs(t, err, tc.expectedErr, "expected: %v, got: %v", tc.expectedErr, err) require.Equal(t, tc.expected, res) }) @@ -154,13 +152,14 @@ func TestConfirmNoTTY_IsUpdatePaused(t *testing.T) { mockFunc(r, ch) + //nolint:errcheck go confirm(context.Background(), "") start := time.Now() ch <- start for isUpdatePaused.Load() { - if time.Now().Sub(start) > 2*time.Second { + if time.Since(start) > 2*time.Second { t.Fatal("isUpdatePaused has not been called") } } diff --git a/internal/generator/cli/confirm/reader.go b/internal/generator/cli/confirm/reader.go index 760fcb3..8e5b67a 100644 --- a/internal/generator/cli/confirm/reader.go +++ b/internal/generator/cli/confirm/reader.go @@ -22,7 +22,7 @@ func (c *cancelableReader) Read(p []byte) (int, error) { case <-c.closed: return 0, io.EOF default: - return c.r.Read(p) + return c.r.Read(p) //nolint:wrapcheck } } @@ -33,5 +33,6 @@ func (c *cancelableReader) Close() error { default: close(c.closed) } + return nil } diff --git a/internal/generator/cli/progress/bar/bar.go b/internal/generator/cli/progress/bar/bar.go index d1242db..e6590aa 100644 --- a/internal/generator/cli/progress/bar/bar.go +++ b/internal/generator/cli/progress/bar/bar.go @@ -80,5 +80,5 @@ func (p *ProgressBarManager) Wait() { // Write writes to stdout. func (p *ProgressBarManager) Write(b []byte) (int, error) { - return p.progressManager.Write(b) + return p.progressManager.Write(b) //nolint:wrapcheck } diff --git a/internal/generator/cli/progress/log/log.go b/internal/generator/cli/progress/log/log.go index d259b25..c9d5ed2 100644 --- a/internal/generator/cli/progress/log/log.go +++ b/internal/generator/cli/progress/log/log.go @@ -152,5 +152,5 @@ func (p *ProgressLogManager) eta(t *task) string { // Write writes to default stdout. func (p *ProgressLogManager) Write(b []byte) (int, error) { - return os.Stdout.Write(b) + return os.Stdout.Write(b) //nolint:wrapcheck } diff --git a/internal/generator/cli/streams/in.go b/internal/generator/cli/streams/in.go index 38574bb..138d6c4 100644 --- a/internal/generator/cli/streams/in.go +++ b/internal/generator/cli/streams/in.go @@ -1,4 +1,3 @@ -//nolint:dupl package streams import ( diff --git a/internal/generator/cli/streams/out.go b/internal/generator/cli/streams/out.go index 83692f8..11cddb8 100644 --- a/internal/generator/cli/streams/out.go +++ b/internal/generator/cli/streams/out.go @@ -1,4 +1,3 @@ -//nolint:dupl package streams import ( diff --git a/internal/generator/models/generator_output.go b/internal/generator/models/generator_output.go index 8bdcc51..9d16712 100644 --- a/internal/generator/models/generator_output.go +++ b/internal/generator/models/generator_output.go @@ -169,10 +169,10 @@ var _ Field = (*CSVConfig)(nil) // CSVConfig type used to describe output config for CSV implementation. type CSVConfig struct { - FloatPrecision int `json:"float_precision" yaml:"float_precision"` - DatetimeFormat string `json:"datetime_format" yaml:"datetime_format"` - Delimiter string `backup:"true" json:"delimiter" yaml:"delimiter"` - WithoutHeaders bool `backup:"true" json:"without_headers" yaml:"without_headers"` + FloatPrecision int `json:"float_precision" yaml:"float_precision"` + DatetimeFormat string `json:"datetime_format" yaml:"datetime_format"` + Delimiter string `backup:"true" json:"delimiter" yaml:"delimiter"` + WithoutHeaders bool `backup:"true" json:"without_headers" yaml:"without_headers"` PartitionFilesLimit *int `json:"partition_files_limit" yaml:"partition_files_limit"` } @@ -307,9 +307,9 @@ var _ Field = (*ParquetConfig)(nil) // ParquetConfig type used to describe output config for parquet implementation. type ParquetConfig struct { - CompressionCodec string `backup:"true" json:"compression_codec" yaml:"compression_codec"` - FloatPrecision int `json:"float_precision" yaml:"float_precision"` - DateTimeFormat string `json:"datetime_format" yaml:"datetime_format"` + CompressionCodec string `backup:"true" json:"compression_codec" yaml:"compression_codec"` + FloatPrecision int `json:"float_precision" yaml:"float_precision"` + DateTimeFormat string `json:"datetime_format" yaml:"datetime_format"` PartitionFilesLimit *int `json:"partition_files_limit" yaml:"partition_files_limit"` } diff --git a/internal/generator/output/general/model_writer.go b/internal/generator/output/general/model_writer.go index c7c9730..1bfc436 100644 --- a/internal/generator/output/general/model_writer.go +++ b/internal/generator/output/general/model_writer.go @@ -64,6 +64,7 @@ func newModelWriter( continueGeneration bool, confirm confirm.Confirm) (*ModelWriter, error) { var partitionFilesLimit *int + switch config.Type { case "csv": partitionFilesLimit = config.CSVParams.PartitionFilesLimit @@ -193,6 +194,7 @@ func (w *ModelWriter) WriteRows(ctx context.Context, rows []*models.DataRow) err if !ok { w.partitionFilesCount++ + err := w.shouldContinue(ctx) if err != nil { return err @@ -266,7 +268,7 @@ func (w *ModelWriter) shouldContinue(ctx context.Context) error { } if !shouldContinue { - return fmt.Errorf("%w: %v", ErrPartitionFilesLimitExceeded, w.partitionFilesCount) + return errors.Wrapf(ErrPartitionFilesLimitExceeded, ": %v", w.partitionFilesCount) } } diff --git a/internal/generator/output/general/output.go b/internal/generator/output/general/output.go index 3f0f176..2f2fc89 100644 --- a/internal/generator/output/general/output.go +++ b/internal/generator/output/general/output.go @@ -29,7 +29,12 @@ type Output struct { } // NewOutput function creates Output object. -func NewOutput(cfg *models.GenerationConfig, continueGeneration, forceGeneration bool, confirm confirm.Confirm) output.Output { +func NewOutput( + cfg *models.GenerationConfig, + continueGeneration, + forceGeneration bool, + confirm confirm.Confirm, +) output.Output { filteredModels := make(map[string]*models.Model) for modelName, model := range cfg.Models { diff --git a/internal/generator/output/general/test/unit_test.go b/internal/generator/output/general/test/unit_test.go index f12ab90..8ac7016 100644 --- a/internal/generator/output/general/test/unit_test.go +++ b/internal/generator/output/general/test/unit_test.go @@ -350,11 +350,11 @@ func TestConfirmationAsk(t *testing.T) { require.Error(t, err) } else { if tc.shouldContinue { - require.Equal(t, 10, len(fileNames), "there should be rows_amount dirs") + require.Len(t, fileNames, 10, "there should be rows_amount dirs") require.NoError(t, err) } else { require.True(t, errors.Is(err, tc.err), "expected error: %v, got: %v", tc.err, err) - require.Equal(t, partitionsFileLimit, len(fileNames), "there should be partitionsFileLimit dirs") + require.Len(t, fileNames, partitionsFileLimit, "there should be partitionsFileLimit dirs") } } From ce74fd608e7d7695d402b0b9964460b6b6fc8b6d Mon Sep 17 00:00:00 2001 From: Zaman Gabdrakhmanov Date: Tue, 12 Aug 2025 16:45:03 +0300 Subject: [PATCH 7/7] upd: if force flag is set, partition files limit is ignored --- internal/generator/cli/commands/generate/generate.go | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/internal/generator/cli/commands/generate/generate.go b/internal/generator/cli/commands/generate/generate.go index 6341013..fa94b28 100644 --- a/internal/generator/cli/commands/generate/generate.go +++ b/internal/generator/cli/commands/generate/generate.go @@ -125,7 +125,7 @@ func runGenerate(ctx context.Context, opts *generateOptions) error { return err } - progressTrackerManager, confirm := initProgressTrackerManager(ctx, opts.renderer, opts.useTTY) + progressTrackerManager, confirm := initProgressTrackerManager(ctx, opts.renderer, opts.useTTY, opts.forceGeneration) out := general.NewOutput(generationCfg, opts.continueGeneration, opts.forceGeneration, confirm) @@ -176,11 +176,12 @@ func runGenerate(ctx context.Context, opts *generateOptions) error { } // initProgressTrackerManager inits progress bar manager (progress.Tracker) -// and builds streams.Confirm func based on useTTY. +// and builds confirm.Confirm func based on useTTY and forceGeneration. func initProgressTrackerManager( ctx context.Context, renderer render.Renderer, useTTY bool, + forceGeneration bool, ) (progress.Tracker, confirm.Confirm) { var ( progressTrackerManager progress.Tracker @@ -199,6 +200,12 @@ func initProgressTrackerManager( confirmFunc = confirm.BuildConfirmNoTTY(renderer, progressTrackerManager, isUpdatePaused) } + if forceGeneration { + confirmFunc = func(_ context.Context, _ string) (bool, error) { + return true, nil + } + } + return progressTrackerManager, confirmFunc }