1 change: 1 addition & 0 deletions ydb/apps/ydb/CHANGELOG.md
@@ -1,4 +1,5 @@
* Added a simple progress bar for non-interactive stderr.
* The `ydb workload vector` command now supports `import files` to populate the table from CSV and Parquet files

## 2.27.0 ##

267 changes: 267 additions & 0 deletions ydb/library/workload/vector/vector_data_generator.cpp
@@ -0,0 +1,267 @@
#include "vector_data_generator.h"

#include <ydb/library/formats/arrow/csv/converter/csv_arrow.h>
#include <ydb/library/yql/udfs/common/knn/knn-serializer-shared.h>

#include <ydb/public/api/protos/ydb_formats.pb.h>

#include <contrib/libs/apache/arrow/cpp/src/arrow/array/array_binary.h>
#include <contrib/libs/apache/arrow/cpp/src/arrow/array/array_nested.h>
#include <contrib/libs/apache/arrow/cpp/src/arrow/array/array_primitive.h>
#include <contrib/libs/apache/arrow/cpp/src/arrow/array/builder_binary.h>
#include <contrib/libs/apache/arrow/cpp/src/arrow/chunked_array.h>
#include <contrib/libs/apache/arrow/cpp/src/arrow/compute/cast.h>
#include <contrib/libs/apache/arrow/cpp/src/arrow/csv/api.h>
#include <contrib/libs/apache/arrow/cpp/src/arrow/csv/reader.h>
#include <contrib/libs/apache/arrow/cpp/src/arrow/csv/writer.h>
#include <contrib/libs/apache/arrow/cpp/src/arrow/io/memory.h>
#include <contrib/libs/apache/arrow/cpp/src/arrow/ipc/dictionary.h>
#include <contrib/libs/apache/arrow/cpp/src/arrow/ipc/reader.h>
#include <contrib/libs/apache/arrow/cpp/src/arrow/ipc/writer.h>
#include <contrib/libs/apache/arrow/cpp/src/arrow/record_batch.h>
#include <contrib/libs/apache/arrow/cpp/src/arrow/table.h>
#include <contrib/libs/apache/arrow/cpp/src/arrow/type.h>
#include <contrib/libs/apache/arrow/cpp/src/arrow/type_fwd.h>

#include <util/stream/mem.h>

namespace NYdbWorkload {

namespace {

class TTransformingDataGenerator final: public IBulkDataGenerator {
private:
std::shared_ptr<IBulkDataGenerator> InnerDataGenerator;
const TString EmbeddingSourceField;

private:
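// Deserializes an Arrow IPC-serialized schema and record batch from the raw portion buffers.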
static std::pair<std::shared_ptr<arrow::Schema>, std::shared_ptr<arrow::RecordBatch>> Deserialize(TDataPortion::TArrow* data) {
arrow::ipc::DictionaryMemo dictionary;

arrow::io::BufferReader schemaBuffer(arrow::util::string_view(data->Schema.data(), data->Schema.size()));
const std::shared_ptr<arrow::Schema> schema = arrow::ipc::ReadSchema(&schemaBuffer, &dictionary).ValueOrDie();

arrow::io::BufferReader recordBatchBuffer(arrow::util::string_view(data->Data.data(), data->Data.size()));
const std::shared_ptr<arrow::RecordBatch> recordBatch = arrow::ipc::ReadRecordBatch(schema, &dictionary, {}, &recordBatchBuffer).ValueOrDie();

return std::make_pair(schema, recordBatch);
}

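// Parses the raw CSV portion into an Arrow table, honoring the CsvSettings carried in FormatString.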
std::shared_ptr<arrow::Table> Deserialize(TDataPortion::TCsv* data) {
Ydb::Formats::CsvSettings csvSettings;
if (Y_UNLIKELY(!csvSettings.ParseFromString(data->FormatString))) {
ythrow yexception() << "Unable to parse CsvSettings";
}

arrow::csv::ReadOptions readOptions = arrow::csv::ReadOptions::Defaults();
readOptions.skip_rows = csvSettings.skip_rows();
if (data->Data.size() > NKikimr::NFormats::TArrowCSV::DEFAULT_BLOCK_SIZE) {
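// Round the block size up to the next multiple of the default so the whole portion fits in one CSV read block.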
ui32 blockSize = NKikimr::NFormats::TArrowCSV::DEFAULT_BLOCK_SIZE;
blockSize *= data->Data.size() / blockSize + 1;
readOptions.block_size = blockSize;
}

arrow::csv::ParseOptions parseOptions = arrow::csv::ParseOptions::Defaults();
const auto& quoting = csvSettings.quoting();
if (Y_UNLIKELY(quoting.quote_char().length() > 1)) {
ythrow yexception() << "Cannot read CSV: Wrong quote char '" << quoting.quote_char() << "'";
}
const char qchar = quoting.quote_char().empty() ? '"' : quoting.quote_char().front();
parseOptions.quoting = false;
parseOptions.quote_char = qchar;
parseOptions.double_quote = !quoting.double_quote_disabled();
if (csvSettings.delimiter()) {
if (Y_UNLIKELY(csvSettings.delimiter().size() != 1)) {
ythrow yexception() << "Cannot read CSV: Invalid delimitr in csv: " << csvSettings.delimiter();
}
parseOptions.delimiter = csvSettings.delimiter().front();
}

arrow::csv::ConvertOptions convertOptions = arrow::csv::ConvertOptions::Defaults();
if (csvSettings.null_value()) {
convertOptions.null_values = { std::string(csvSettings.null_value().data(), csvSettings.null_value().size()) };
convertOptions.strings_can_be_null = true;
convertOptions.quoted_strings_can_be_null = false;
}

auto bufferReader = std::make_shared<arrow::io::BufferReader>(arrow::util::string_view(data->Data.data(), data->Data.size()));
auto csvReader = arrow::csv::TableReader::Make(
arrow::io::default_io_context(),
bufferReader,
readOptions,
parseOptions,
convertOptions
).ValueOrDie();

return csvReader->Read().ValueOrDie();
}

void TransformArrow(TDataPortion::TArrow* data) {
const auto [schema, batch] = Deserialize(data);

// id
const auto idColumn = batch->GetColumnByName("id");
const auto newIdColumn = arrow::compute::Cast(idColumn, arrow::uint64()).ValueOrDie().make_array();

// embedding
const auto embeddingColumn = std::dynamic_pointer_cast<arrow::ListArray>(batch->GetColumnByName(EmbeddingSourceField));
arrow::StringBuilder newEmbeddingsBuilder;
for (int64_t row = 0; row < batch->num_rows(); ++row) {
const auto embeddingFloatList = std::static_pointer_cast<arrow::FloatArray>(embeddingColumn->value_slice(row));

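// Serialize the float list into the binary KNN embedding format expected by YDB (see knn-serializer-shared.h).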
TStringBuilder buffer;
NKnnVectorSerialization::TSerializer<float> serializer(&buffer.Out);
for (int64_t i = 0; i < embeddingFloatList->length(); ++i) {
serializer.HandleElement(embeddingFloatList->Value(i));
}
serializer.Finish();

if (const auto status = newEmbeddingsBuilder.Append(buffer.MutRef()); !status.ok()) {
status.Abort();
}
}
std::shared_ptr<arrow::StringArray> newEmbeddingColumn;
if (const auto status = newEmbeddingsBuilder.Finish(&newEmbeddingColumn); !status.ok()) {
status.Abort();
}

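// Rebuild the batch with only the two columns the vector table needs: a uint64 id and the serialized embedding.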
const auto newSchema = arrow::schema({
arrow::field("id", arrow::uint64()),
arrow::field("embedding", arrow::utf8()),
});
const auto newRecordBatch = arrow::RecordBatch::Make(
newSchema,
batch->num_rows(),
{
newIdColumn,
newEmbeddingColumn,
}
);
data->Schema = arrow::ipc::SerializeSchema(*newSchema).ValueOrDie()->ToString();
data->Data = arrow::ipc::SerializeRecordBatch(*newRecordBatch, arrow::ipc::IpcWriteOptions{}).ValueOrDie()->ToString();
}

void TransformCsv(TDataPortion::TCsv* data) {
const auto table = Deserialize(data);

// id
const auto idColumn = table->GetColumnByName("id");

// embedding
const auto embeddingColumn = table->GetColumnByName(EmbeddingSourceField);
arrow::StringBuilder newEmbeddingsBuilder;
for (int64_t row = 0; row < table->num_rows(); ++row) {
const auto embeddingListString = std::static_pointer_cast<arrow::StringArray>(embeddingColumn->Slice(row, 1)->chunk(0))->Value(0);

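// The CSV carries the embedding as a bracketed list of floats, e.g. "[0.1,0.2,0.3]"; strip the brackets and parse.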
TStringBuf buffer(embeddingListString.data(), embeddingListString.size());
buffer.SkipPrefix("[");
buffer.ChopSuffix("]");
TMemoryInput input(buffer);

TStringBuilder newEmbeddingBuilder;
NKnnVectorSerialization::TSerializer<float> serializer(&newEmbeddingBuilder.Out);
while (!input.Exhausted()) {
float val;
input >> val;
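// Skip the ',' separator; skipping past the end of the buffer is a no-op.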
input.Skip(1);
serializer.HandleElement(val);
}
serializer.Finish();

if (const auto status = newEmbeddingsBuilder.Append(newEmbeddingBuilder.MutRef()); !status.ok()) {
status.Abort();
}
}
std::shared_ptr<arrow::StringArray> newEmbeddingColumn;
if (const auto status = newEmbeddingsBuilder.Finish(&newEmbeddingColumn); !status.ok()) {
status.Abort();
}

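// Assemble the output table: the original id column plus the newly serialized embedding column.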
const auto newSchema = arrow::schema({
arrow::field("id", arrow::uint64()),
arrow::field("embedding", arrow::utf8()),
});
const auto newTable = arrow::Table::Make(
newSchema,
{
idColumn,
arrow::ChunkedArray::Make({newEmbeddingColumn}).ValueOrDie(),
}
);
auto outputStream = arrow::io::BufferOutputStream::Create().ValueOrDie();
if (const auto status = arrow::csv::WriteCSV(*newTable, arrow::csv::WriteOptions::Defaults(), outputStream.get()); !status.ok()) {
status.Abort();
}
data->FormatString = "";
data->Data = outputStream->Finish().ValueOrDie()->ToString();
}

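// Dispatch on the portion payload type; formats other than Arrow and CSV pass through unchanged.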
void Transform(TDataPortion::TDataType& data) {
if (auto* value = std::get_if<TDataPortion::TArrow>(&data)) {
TransformArrow(value);
}
if (auto* value = std::get_if<TDataPortion::TCsv>(&data)) {
TransformCsv(value);
}
}

public:
TTransformingDataGenerator(std::shared_ptr<IBulkDataGenerator> innerDataGenerator, const TString& embeddingSourceField)
: IBulkDataGenerator(innerDataGenerator->GetName(), innerDataGenerator->GetSize())
, InnerDataGenerator(innerDataGenerator)
, EmbeddingSourceField(embeddingSourceField)
{}

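// Pull portions from the wrapped generator and transform each payload in place.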
virtual TDataPortions GenerateDataPortion() override {
TDataPortions portions = InnerDataGenerator->GenerateDataPortion();
for (const auto& portion : portions) {
Transform(portion->MutableData());
}
return portions;
}
};

} // anonymous namespace

TWorkloadVectorFilesDataInitializer::TWorkloadVectorFilesDataInitializer(const TVectorWorkloadParams& params)
: TWorkloadDataInitializerBase("files", "Import vectors from files", params)
, Params(params)
{ }

void TWorkloadVectorFilesDataInitializer::ConfigureOpts(NLastGetopt::TOpts& opts) {
opts.AddLongOption('i', "input",
Collaborator:
We've discussed the user interface for this functionality in person. I also asked AI for ideas, and here is what we came up with. I think these options fit much better and give more explanation:

// --input
{
    TStringStream description;
    description
        << "File or directory with the dataset to import. Only two columns are imported: "
        << colors.BoldColor() << "id" << colors.OldColor() << " and "
        << colors.BoldColor() << "embedding" << colors.OldColor() << ". "
        << "If a directory is set, all supported files inside will be used."
        << "\nSupported formats: CSV/TSV (zipped or unzipped) and Parquet."
        << "\nIn " << colors.BoldColor() << "convert" << colors.OldColor() << " mode, "
        << "embedding is converted from list of floats to YDB binary embedding format."
        << "\nIn " << colors.BoldColor() << "raw" << colors.OldColor() << " mode, "
        << "embedding must already be binary; for CSV/TSV its encoding is controlled by --input-binary-strings."
        << "\nExample dataset: https://huggingface.co/datasets/Cohere/wikipedia-22-12-simple-embeddings";
    config.Opts->AddLongOption('i', "input", description.Str())
        .RequiredArgument("PATH")
        .Required()
        .StoreResult(&DataFiles);
}

// --mode
{
    TStringStream description;
    description
        << "Import mode. Controls whether input data are converted or taken as-is."
        << "\n  " << colors.BoldColor() << "auto" << colors.OldColor()
        << "\n    " << "Detect mode from input schema/content:"
        << "\n    " << "- Parquet: list<float> embedding -> convert; binary/string embedding -> raw."
        << "\n    " << "- CSV/TSV: numeric array-like embedding -> convert; otherwise -> raw (requires --input-binary-strings)."
        << "\n  " << colors.BoldColor() << "convert" << colors.OldColor()
        << "\n    " << "Pick columns " << colors.BoldColor() << "id" << colors.OldColor() << " and "
                      << colors.BoldColor() << "embedding" << colors.OldColor()
        << ", cast id to Int64, convert embedding (list<float>) to YDB binary embedding."
        << "\n    " << "Reference: https://ydb.tech/docs/yql/reference/udf/list/knn#functions-convert"
        << "\n  " << colors.BoldColor() << "raw" << colors.OldColor()
        << "\n    " << "Load as-is: id must be Int64, embedding must be binary."
        << "\n    " << "For CSV/TSV, set embedding binary encoding with --input-binary-strings."
        << "\nDefault: " << colors.CyanColor() << "\"auto\"" << colors.OldColor() << ".";
    config.Opts->AddLongOption("mode", description.Str())
        .RequiredArgument("MODE")
        .DefaultValue("auto")
        .StoreResult(&Mode);
}

// --embedding-column-name
{
    TStringStream description;
    description
        << "Alternative source column name for the embedding field in input files."
        << "\nUsed in " << colors.BoldColor() << "convert" << colors.OldColor()
        << " (and " << colors.BoldColor() << "auto" << colors.OldColor() << " when it chooses convert)."
        << "\nIf not set, the column is expected to be named "
        << colors.BoldColor() << "\"embedding\"" << colors.OldColor() << ".";
    config.Opts->AddLongOption("embedding-column-name", description.Str())
        .RequiredArgument("NAME")
        .DefaultValue("embedding")
        .StoreResult(&EmbeddingColumnName);
}

// --input-binary-strings
{
    TStringStream description;
    description
        << "Binary encoding of the " << colors.BoldColor() << "embedding" << colors.OldColor()
        << " column in CSV/TSV when importing in " << colors.BoldColor() << "raw" << colors.OldColor()
        << " mode (or in " << colors.BoldColor() << "auto" << colors.OldColor() << " when it selects raw)."
        << "\nIgnored for Parquet and for " << colors.BoldColor() << "convert" << colors.OldColor() << " mode."
        << "\nAvailable options:"
        << "\n  " << colors.BoldColor() << "unicode" << colors.OldColor()
        << "\n    " << "Every byte in binary strings that is not a printable ASCII symbol (codes 32-126) should be encoded as UTF-8."
        << "\n  " << colors.BoldColor() << "base64" << colors.OldColor()
        << "\n    " << "Binary strings should be fully encoded with base64."
        << "\nDefault: " << colors.CyanColor() << "\"unicode\"" << colors.OldColor() << ".";
    config.Opts->AddLongOption("input-binary-strings", description.Str())
        .RequiredArgument("STRING")
        .DefaultValue("unicode")
        .StoreResult(&InputBinaryStringEncodingFormat);
}

"File or Directory with dataset. If directory is set, all its available files will be used. "
"Supports zipped and unzipped csv, tsv files and parquet ones that may be downloaded here: "
"https://huggingface.co/datasets/Cohere/wikipedia-22-12-simple-embeddings. "
"For better performance you may split it into some parts for parallel upload."
).Required().StoreResult(&DataFiles);
opts.AddLongOption('t', "transform",
"Perform transformation of input data. "
"Parquet: leave only required fields, cast to expected types, convert list of floats into serialized representation. "
"CSV: leave only required fields, parse float list from string and serialize. "
"Reference for embedding serialization: https://ydb.tech/docs/yql/reference/udf/list/knn#functions-convert"
).Optional().StoreTrue(&DoTransform);
opts.AddLongOption(
"transform-embedding-source-field",
"Specify field that contains list of floats to be converted into YDB embedding format."
).DefaultValue(EmbeddingSourceField).StoreResult(&EmbeddingSourceField);
}

TBulkDataGeneratorList TWorkloadVectorFilesDataInitializer::DoGetBulkInitialData() {
auto dataGenerator = std::make_shared<TDataGenerator>(
*this,
Params.TableName,
0,
Params.TableName,
DataFiles,
Params.GetColumns(),
TDataGenerator::EPortionSizeUnit::Line
);

if (DoTransform) {
return {std::make_shared<TTransformingDataGenerator>(dataGenerator, EmbeddingSourceField)};
}
return {dataGenerator};
}

} // namespace NYdbWorkload
24 changes: 24 additions & 0 deletions ydb/library/workload/vector/vector_data_generator.h
@@ -0,0 +1,24 @@
#pragma once

#include "vector_workload_params.h"

#include <ydb/library/workload/benchmark_base/workload.h>
#include <ydb/library/workload/benchmark_base/data_generator.h>

namespace NYdbWorkload {

class TWorkloadVectorFilesDataInitializer : public TWorkloadDataInitializerBase {
private:
const TVectorWorkloadParams& Params;
TString DataFiles;
bool DoTransform = false;
TString EmbeddingSourceField = "embedding";

public:
TWorkloadVectorFilesDataInitializer(const TVectorWorkloadParams& params);

virtual void ConfigureOpts(NLastGetopt::TOpts& opts) override;
virtual TBulkDataGeneratorList DoGetBulkInitialData() override;
};

} // namespace NYdbWorkload
19 changes: 19 additions & 0 deletions ydb/library/workload/vector/vector_workload_params.cpp
@@ -1,3 +1,4 @@
#include "vector_data_generator.h"
#include "vector_enums.h"
#include "vector_workload_params.h"
#include "vector_workload_generator.h"
@@ -55,6 +56,9 @@ void TVectorWorkloadParams::ConfigureOpts(NLastGetopt::TOpts& opts, const EComma
ConfigureCommonOpts(opts);
addInitParam();
break;
case TWorkloadParams::ECommandType::Import:
ConfigureCommonOpts(opts);
break;
case TWorkloadParams::ECommandType::Run:
ConfigureCommonOpts(opts);
switch (static_cast<EWorkloadRunType>(workloadType)) {
@@ -91,6 +95,15 @@ void TVectorWorkloadParams::ConfigureIndexOpts(NLastGetopt::TOpts& opts) {
.Required().StoreResult(&KmeansTreeClusters);
}

TVector<TString> TVectorWorkloadParams::GetColumns() const {
TVector<TString> result(KeyColumns.begin(), KeyColumns.end());
result.emplace_back(EmbeddingColumn);
if (PrefixColumn.has_value()) {
result.emplace_back(PrefixColumn.value());
}
return result;
}

void TVectorWorkloadParams::Init() {
const TString tablePath = GetFullTableName(TableName.c_str());

@@ -193,6 +206,12 @@ THolder<IWorkloadQueryGenerator> TVectorWorkloadParams::CreateGenerator() const
return MakeHolder<TVectorWorkloadGenerator>(this);
}

TWorkloadDataInitializer::TList TVectorWorkloadParams::CreateDataInitializers() const {
return {
std::make_shared<TWorkloadVectorFilesDataInitializer>(*this)
};
}

TString TVectorWorkloadParams::GetWorkloadName() const {
return "vector";
}
3 changes: 3 additions & 0 deletions ydb/library/workload/vector/vector_workload_params.h
@@ -18,6 +18,7 @@ class TVectorWorkloadParams final: public TWorkloadBaseParams {
public:
void ConfigureOpts(NLastGetopt::TOpts& opts, const ECommandType commandType, int workloadType) override;
THolder<IWorkloadQueryGenerator> CreateGenerator() const override;
TWorkloadDataInitializer::TList CreateDataInitializers() const override;
TString GetWorkloadName() const override;
void Validate(const ECommandType commandType, int workloadType) override;

@@ -26,6 +27,8 @@ class TVectorWorkloadParams final: public TWorkloadBaseParams {
void ConfigureCommonOpts(NLastGetopt::TOpts& opts);
void ConfigureIndexOpts(NLastGetopt::TOpts& opts);

TVector<TString> GetColumns() const;

TString TableName;
TString QueryTableName;
TString IndexName;
3 changes: 3 additions & 0 deletions ydb/library/workload/vector/ya.make
@@ -2,6 +2,7 @@ LIBRARY()

SRCS(
vector_command_index.cpp
vector_data_generator.cpp
vector_recall_evaluator.cpp
vector_sampler.cpp
vector_sql.cpp
@@ -11,7 +12,9 @@ SRCS(
)

PEERDIR(
contrib/libs/apache/arrow
ydb/library/workload/abstract
ydb/public/api/protos
)

GENERATE_ENUM_SERIALIZATION_WITH_HEADER(vector_enums.h)