generalize kafka key and headers formation (#8053)
GitOrigin-RevId: aabfc64f9c12fd393909dc10542d41f50ca98603
zxqfd555-pw authored and Manul from Pathway committed Jan 22, 2025
1 parent b9efe06 commit 12bccce
Showing 7 changed files with 93 additions and 61 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -15,6 +15,7 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm
- **BREAKING**: `pw.xpacks.llm.question_answering.BaseRAGQuestionAnswerer` now returns a dictionary from `pw_ai_answer` endpoint.
- `pw.xpacks.llm.question_answering.BaseRAGQuestionAnswerer` allows optionally returning context documents from `pw_ai_answer` endpoint.
- **BREAKING**: When using delay in temporal behavior, current time is updated immediately, not in the next batch.
- `pw.io.kafka.write` now allows specifying `key` and `headers` for JSON and CSV data formats.

### Fixed
- `generate_class` method in `Schema` now correctly renders columns of `UnionType` and `None` types.
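
The new entry in practice: a minimal sketch that mirrors the integration test added below. The input path, broker address, and topic name are placeholders; the `key` column becomes the Kafka message key, and every column listed in `headers` is attached as a message header.

```python
import pathway as pw


class InputSchema(pw.Schema):
    k: int = pw.column_definition(primary_key=True)
    v: str


table = pw.io.jsonlines.read("./input.jsonl", schema=InputSchema, mode="static")

pw.io.kafka.write(
    table,
    rdkafka_settings={"bootstrap.servers": "localhost:9092"},  # placeholder broker
    topic_name="output-topic",                                 # placeholder topic
    format="json",
    key=table["v"],                    # str column used as the message key
    headers=[table["k"], table["v"]],  # attached as message headers "k" and "v"
)

pw.run()
```
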
38 changes: 38 additions & 0 deletions integration_tests/kafka/test_simple.py
@@ -577,3 +577,41 @@ def stream_inputs():
),
30,
)


@pytest.mark.flaky(reruns=3)
def test_kafka_json_key(tmp_path, kafka_context):
input_path = tmp_path / "input.jsonl"
with open(input_path, "w") as f:
f.write(json.dumps({"k": 0, "v": "foo"}))
f.write("\n")
f.write(json.dumps({"k": 1, "v": "bar"}))
f.write("\n")
f.write(json.dumps({"k": 2, "v": "baz"}))
f.write("\n")

class InputSchema(pw.Schema):
k: int = pw.column_definition(primary_key=True)
v: str

table = pw.io.jsonlines.read(input_path, schema=InputSchema, mode="static")
pw.io.kafka.write(
table,
rdkafka_settings=kafka_context.default_rdkafka_settings(),
topic_name=kafka_context.output_topic,
format="json",
key=table["v"],
headers=[table["k"], table["v"]],
)
pw.run()
output_topic_contents = kafka_context.read_output_topic()
for message in output_topic_contents:
key = message.key
value = json.loads(message.value)
assert value["v"].encode("utf-8") == key
assert "k" in value
headers = {}
for header_key, header_value in message.headers:
headers[header_key] = header_value
assert headers["k"] == str(value["k"]).encode("utf-8")
assert headers["v"] == f'"{value["v"]}"'.encode("utf-8")
76 changes: 34 additions & 42 deletions python/pathway/io/_utils.py
@@ -389,17 +389,12 @@ def construct_s3_data_storage(
def check_raw_and_plaintext_only_kwargs_for_message_queues(f):
@functools.wraps(f)
def wrapper(*args, **kwargs):
if kwargs.get("format") not in ("raw", "plaintext"):
unexpected_params = [
"key",
"value",
"headers",
]
for param in unexpected_params:
if param in kwargs and kwargs[param] is not None:
raise ValueError(
f"Unsupported argument for {format} format: {param}"
)
data_format = kwargs.get("format")
if data_format not in ("raw", "plaintext"):
if "value" in kwargs and kwargs["value"] is not None:
raise ValueError(
f"Unsupported argument for {data_format} format: 'value'"
)

return f(*args, **kwargs)

@@ -427,31 +422,42 @@ def construct(
) -> MessageQueueOutputFormat:
key_field_index = None
header_fields: dict[str, int] = {}
if format == "json":
data_format = api.DataFormat(
format_type="jsonlines",
key_field_names=[],
value_fields=_format_output_value_fields(table),
extracted_field_indices: dict[str, int] = {}
columns_to_extract: list[ColumnReference] = []
allowed_column_types = (dt.BYTES, dt.STR, dt.ANY)

# Common part for all formats: obtain key field index and prepare header fields
if key is not None:
if table[key._name]._column.dtype not in allowed_column_types:
raise ValueError(
f"The key column should be of the type '{allowed_column_types[0]}'"
)
key_field_index = cls.add_column_reference_to_extract(
key, columns_to_extract, extracted_field_indices
)
elif format == "dsv":
if headers is not None:
for header in headers:
header_fields[header.name] = cls.add_column_reference_to_extract(
header, columns_to_extract, extracted_field_indices
)

# Format-dependent parts: handle json and dsv separately
if format == "json" or format == "dsv":
for column_name in table._columns:
cls.add_column_reference_to_extract(
table[column_name], columns_to_extract, extracted_field_indices
)
table = table.select(*columns_to_extract)
data_format = api.DataFormat(
format_type="dsv",
format_type="jsonlines" if format == "json" else "dsv",
key_field_names=[],
value_fields=_format_output_value_fields(table),
delimiter=delimiter,
)
elif format == "raw" or format == "plaintext":
value_field_index = None
extracted_field_indices: dict[str, int] = {}
columns_to_extract: list[ColumnReference] = []
allowed_column_types = (dt.BYTES if format == "raw" else dt.STR, dt.ANY)

if key is not None:
if value is None:
raise ValueError("'value' must be specified if 'key' is not None")
key_field_index = cls.add_column_reference_to_extract(
key, columns_to_extract, extracted_field_indices
)
if key is not None and value is None:
raise ValueError("'value' must be specified if 'key' is not None")
if value is not None:
value_field_index = cls.add_column_reference_to_extract(
value, columns_to_extract, extracted_field_indices
@@ -468,21 +474,7 @@ def construct(
value, columns_to_extract, extracted_field_indices
)

if headers is not None:
for header in headers:
header_fields[header.name] = cls.add_column_reference_to_extract(
header, columns_to_extract, extracted_field_indices
)

table = table.select(*columns_to_extract)

if (
key is not None
and table[key._name]._column.dtype not in allowed_column_types
):
raise ValueError(
f"The key column should be of the type '{allowed_column_types[0]}'"
)
if table[value._name]._column.dtype not in allowed_column_types:
raise ValueError(
f"The value column should be of the type '{allowed_column_types[0]}'"
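
The rewritten `check_raw_and_plaintext_only_kwargs_for_message_queues` decorator at the top of this diff now restricts only `value` to the 'raw' and 'plaintext' formats, so `key` and `headers` pass through for every format; it also fixes the error message, which previously interpolated the builtin `format` instead of the requested format name. A simplified, standalone sketch of the same validation pattern (the `write` stub is hypothetical and only illustrates how the decorator is applied):

```python
import functools


def restrict_value_kwarg_to_raw_and_plaintext(f):
    """Reject the 'value' keyword for formats that serialize whole rows."""

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        data_format = kwargs.get("format")
        if data_format not in ("raw", "plaintext"):
            if kwargs.get("value") is not None:
                raise ValueError(
                    f"Unsupported argument for {data_format} format: 'value'"
                )
        return f(*args, **kwargs)

    return wrapper


@restrict_value_kwarg_to_raw_and_plaintext
def write(table, *, format="json", key=None, value=None, headers=None):
    """Hypothetical writer stub, only for illustrating the decorator."""
    ...
```
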
13 changes: 7 additions & 6 deletions python/pathway/io/kafka/__init__.py
@@ -548,16 +548,17 @@ def write(
``table`` must either contain exactly one binary column that will be dumped as it is into the
Kafka message, or the reference to the target binary column must be specified explicitly
in the ``value`` parameter. Similarly, if "plaintext" is chosen, the table should consist
of a single column of the string type.
of a single column of the string type, or the reference to the target string column
must be specified explicitly in the ``value`` parameter.
delimiter: field delimiter to be used in case of delimiter-separated values
format.
key: reference to the column that should be used as a key in the
produced message in 'plaintext' or 'raw' format. If left empty, an internal primary key will
be used.
format 'dsv'.
key: reference to the column that should be used as a key in the produced message.
If left empty, an internal primary key will be used.
value: reference to the column that should be used as a value in
the produced message in 'plaintext' or 'raw' format. It can be deduced automatically if the
table has exactly one column. Otherwise it must be specified directly. It also has to be
explicitly specified, if ``key`` is set.
explicitly specified, if ``key`` is set. The type of the column must correspond to the
format used: ``str`` for the 'plaintext' format and ``binary`` for the 'raw' format.
headers: references to the table fields that must be provided as message
headers. These headers are named in the same way as fields that are forwarded and correspond
to the string representations of the respective values encoded in UTF-8. If a binary
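
A minimal sketch of the documented 'plaintext' behavior with an explicit `value` column plus a key and a header; the table is built with `pw.debug.table_from_markdown` purely for demonstration, and the broker address and topic name are placeholders.

```python
import pathway as pw

# Demonstration table with two str columns.
table = pw.debug.table_from_markdown(
    """
    name  | payload
    alpha | hello
    beta  | world
    """
)

pw.io.kafka.write(
    table,
    rdkafka_settings={"bootstrap.servers": "localhost:9092"},  # placeholder broker
    topic_name="plaintext-topic",                              # placeholder topic
    format="plaintext",
    key=table["name"],        # str column used as the Kafka message key
    value=table["payload"],   # str column written as the raw message payload
    headers=[table["name"]],  # header "name" carries the UTF-8 encoded value
)

pw.run()
```
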
4 changes: 2 additions & 2 deletions src/connectors/data_format.rs
@@ -1001,7 +1001,7 @@ impl Formatter for DsvFormatter {
Ok(FormatterContext::new(
payloads,
*key,
Vec::new(),
values.to_vec(),
time,
diff,
))
@@ -1852,7 +1852,7 @@ impl Formatter for JsonLinesFormatter {
Ok(FormatterContext::new_single_payload(
serializer.into_inner(),
*key,
Vec::new(),
values.to_vec(),
time,
diff,
))
2 changes: 1 addition & 1 deletion tests/integration/test_dsv_output.rs
@@ -29,7 +29,7 @@ fn test_dsv_format_ok() -> eyre::Result<()> {

let target_payloads = vec![b"b;c;time;diff".to_vec(), b"\"x\";\"y\";0;1".to_vec()];

assert_eq!(result.values.len(), 0);
assert_eq!(result.values, &[Value::from("x"), Value::from("y")]);
assert_eq!(result.payloads.len(), target_payloads.len());
for (result_payload, target_payload) in zip(result.payloads, target_payloads) {
assert_document_raw_byte_contents(&result_payload, &target_payload);
20 changes: 10 additions & 10 deletions tests/integration/test_json_output.rs
@@ -23,7 +23,7 @@ fn test_json_format_ok() -> eyre::Result<()> {
&result.payloads[0],
r#"{"a":"b","diff":1,"time":0}"#.as_bytes(),
);
assert_eq!(result.values.len(), 0);
assert_eq!(result.values, &[Value::from("b")]);

Ok(())
}
@@ -43,7 +43,7 @@ fn test_json_int_serialization() -> eyre::Result<()> {
&result.payloads[0],
r#"{"a":555,"diff":1,"time":0}"#.as_bytes(),
);
assert_eq!(result.values.len(), 0);
assert_eq!(result.values.len(), 1);

Ok(())
}
@@ -63,7 +63,7 @@ fn test_json_float_serialization() -> eyre::Result<()> {
&result.payloads[0],
r#"{"a":5.55,"diff":1,"time":0}"#.as_bytes(),
);
assert_eq!(result.values.len(), 0);
assert_eq!(result.values.len(), 1);

Ok(())
}
@@ -83,7 +83,7 @@ fn test_json_bool_serialization() -> eyre::Result<()> {
&result.payloads[0],
r#"{"a":true,"diff":1,"time":0}"#.as_bytes(),
);
assert_eq!(result.values.len(), 0);
assert_eq!(result.values.len(), 1);

Ok(())
}
@@ -103,7 +103,7 @@ fn test_json_null_serialization() -> eyre::Result<()> {
&result.payloads[0],
r#"{"a":null,"diff":1,"time":0}"#.as_bytes(),
);
assert_eq!(result.values.len(), 0);
assert_eq!(result.values.len(), 1);

Ok(())
}
@@ -124,7 +124,7 @@ fn test_json_pointer_serialization() -> eyre::Result<()> {
&result.payloads[0],
r#"{"a":"^04000000000000000000000000","diff":1,"time":0}"#.as_bytes(),
);
assert_eq!(result.values.len(), 0);
assert_eq!(result.values.len(), 1);

Ok(())
}
@@ -149,7 +149,7 @@ fn test_json_tuple_serialization() -> eyre::Result<()> {
&result.payloads[0],
r#"{"a":[true,null],"diff":1,"time":0}"#.as_bytes(),
);
assert_eq!(result.values.len(), 0);
assert_eq!(result.values.len(), 1);

Ok(())
}
@@ -171,7 +171,7 @@ fn test_json_date_time_naive_serialization() -> eyre::Result<()> {
&result.payloads[0],
r#"{"a":"2023-05-15T10:51:00.000000000","diff":1,"time":0}"#.as_bytes(),
);
assert_eq!(result.values.len(), 0);
assert_eq!(result.values.len(), 1);

Ok(())
}
@@ -191,7 +191,7 @@ fn test_json_date_time_utc_serialization() -> eyre::Result<()> {
&result.payloads[0],
r#"{"a":"2023-05-15T10:51:00.000000000+0000","diff":1,"time":0}"#.as_bytes(),
);
assert_eq!(result.values.len(), 0);
assert_eq!(result.values.len(), 1);

Ok(())
}
@@ -212,7 +212,7 @@ fn test_json_duration_serialization() -> eyre::Result<()> {
&result.payloads[0],
r#"{"a":1197780000000000,"diff":1,"time":0}"#.as_bytes(),
);
assert_eq!(result.values.len(), 0);
assert_eq!(result.values.len(), 1);

Ok(())
}
