diff --git a/ydb/core/kqp/provider/yql_kikimr_datasink.cpp b/ydb/core/kqp/provider/yql_kikimr_datasink.cpp index 62171167561b..e753f704854d 100644 --- a/ydb/core/kqp/provider/yql_kikimr_datasink.cpp +++ b/ydb/core/kqp/provider/yql_kikimr_datasink.cpp @@ -13,6 +13,7 @@ namespace { using namespace NKikimr; using namespace NNodes; +namespace { bool HasUpdateIntersection(const NCommon::TWriteTableSettings& settings) { THashSet columnNames; auto equalStmts = settings.Update.Cast().Ptr()->Child(1); @@ -38,6 +39,71 @@ bool HasUpdateIntersection(const NCommon::TWriteTableSettings& settings) { return hasIntersection; } +TExprNode::TPtr CreateNodeParameter(const TString& name, const TTypeAnnotationNode* colType, TPositionHandle pos, + TExprContext& ctx) { + if (colType->GetKind() == ETypeAnnotationKind::Optional) { + colType = colType->Cast()->GetItemType(); + } + + return ctx.NewCallable(pos, "Parameter", { + ctx.NewAtom(pos, name), + ctx.NewCallable(pos, "DataType", { + ctx.NewAtom(pos, FormatType(colType)) + }) + }); +} + +TCoLambda RewriteBatchFilter(const TCoLambda& node, const TKikimrTableDescription& tableDesc, TExprContext& ctx) { + const TPositionHandle pos = node.Pos(); + const TExprNode::TPtr newLambda = ctx.DeepCopyLambda(node.Ref()); + const TExprNode::TPtr row = newLambda->ChildPtr(0)->ChildPtr(0); + const TExprNode::TPtr filter = newLambda->ChildPtr(1); + + TVector primaryColumns = tableDesc.Metadata->KeyColumnNames; + + TExprNode::TListType beginParamsList; + TExprNode::TListType endParamsList; + TExprNode::TListType primaryMembersList; + + for (size_t i = 0; i < primaryColumns.size(); ++i) { + auto colType = tableDesc.GetColumnType(primaryColumns[i]); + beginParamsList.push_back(CreateNodeParameter("_kqp_batch_begin_" + ToString(i + 1), colType, pos, ctx)); + endParamsList.push_back(CreateNodeParameter("_kqp_batch_end_" + ToString(i + 1), colType, pos, ctx)); + + primaryMembersList.push_back(ctx.NewCallable(pos, "Member", { + row, + ctx.NewAtom(pos, primaryColumns[i]) + })); + } + + TExprNode::TPtr beginNodeParams = beginParamsList.front(); + TExprNode::TPtr endNodeParams = endParamsList.front(); + TExprNode::TPtr primaryNodeMember = primaryMembersList.front(); + + if (primaryColumns.size() > 1) { + beginNodeParams = ctx.NewList(pos, std::move(beginParamsList)); + endNodeParams = ctx.NewList(pos, std::move(endParamsList)); + primaryNodeMember = ctx.NewList(pos, std::move(primaryMembersList)); + } + + TExprNode::TPtr newFilter = ctx.ChangeChild(*filter, 0, ctx.NewCallable(pos, "And", { + ctx.NewCallable(pos, "And", { + ctx.NewCallable(pos, ">=", { + primaryNodeMember, + beginNodeParams + }), + ctx.NewCallable(pos, "<", { + primaryNodeMember, + endNodeParams + }) + }), + filter->ChildPtr(0) + })); + + return TCoLambda(ctx.ChangeChild(*newLambda, 1, std::move(newFilter))); +} +} // namespace + class TKiSinkIntentDeterminationTransformer: public TKiSinkVisitorTransformer { public: TKiSinkIntentDeterminationTransformer(TIntrusivePtr sessionCtx) @@ -1084,6 +1150,16 @@ class TKikimrDataSink : public TDataProviderBase } else if (mode == "update") { if (settings.Filter) { YQL_ENSURE(settings.Update); + + if (settings.IsBatch) { + TKiDataSink dataSink(node->Child(1)); + auto tableDesc = SessionCtx->Tables().EnsureTableExists( + TString(dataSink.Cluster()), + key.GetTablePath(), node->Pos(), ctx); + + settings.Filter = RewriteBatchFilter(std::move(settings.Filter.Cast()), *tableDesc, ctx); + } + return Build(ctx, node->Pos()) .World(node->Child(0)) .DataSink(node->Child(1)) @@ -1116,6 +1192,15 @@ class TKikimrDataSink : public TDataProviderBase } else if (mode == "delete") { YQL_ENSURE(settings.Filter || settings.PgFilter); if (settings.Filter) { + if (settings.IsBatch) { + TKiDataSink dataSink(node->Child(1)); + auto tableDesc = SessionCtx->Tables().EnsureTableExists( + TString(dataSink.Cluster()), + key.GetTablePath(), node->Pos(), ctx); + + settings.Filter = RewriteBatchFilter(std::move(settings.Filter.Cast()), *tableDesc, ctx); + } + return Build(ctx, node->Pos()) .World(node->Child(0)) .DataSink(node->Child(1))