From de8724157512861ad88722a0c5a2caef88658dc6 Mon Sep 17 00:00:00 2001 From: Alexandre Hamez Date: Thu, 30 Jan 2025 14:59:44 +0100 Subject: [PATCH] refactor: directly pattern match on varint when decoding Benchmark shows - from ~1.2% to ~9.8% speedup - a reduction from ~6.2% to ~11% in memory consumption --- lib/protox/define_decoder.ex | 133 ++++++++++++++++++++++++++--------- 1 file changed, 100 insertions(+), 33 deletions(-) diff --git a/lib/protox/define_decoder.ex b/lib/protox/define_decoder.ex index 8d6c383e..1fa86180 100644 --- a/lib/protox/define_decoder.ex +++ b/lib/protox/define_decoder.ex @@ -102,30 +102,32 @@ defmodule Protox.DefineDecoder do end defp make_parse_key_value_body(keep_set_fields, fields, vars, opts) do - # Fragment to handle the (invalid) field with tag 0. - tag_0_case = make_parse_key_value_tag_0() - # Fragment to parse unknown fields. Those are identified with an unknown tag. - unknown_fields_name = Keyword.fetch!(opts, :unknown_fields_name) - - unknown_tag_case = + unknown_tag_clause = make_parse_key_value_unknown( vars, keep_set_fields, - unknown_fields_name + Keyword.fetch!(opts, :unknown_fields_name) ) - # Fragment to parse known fields. - known_tags_case = make_parse_key_value_known(vars, fields, keep_set_fields) + # Fragment to parse all regular fields. + all_fields_clase = make_parse_key_value_known(vars, fields, keep_set_fields) - all_cases = tag_0_case ++ known_tags_case ++ unknown_tag_case + all_clauses = + make_parse_key_value_invalid_varint() ++ + make_parse_key_value_tag_0() ++ + all_fields_clase ++ + unknown_tag_clause + # Note we directly pattern-match against the bytes: we don't decode the tag + # and the wire type using Varint.decode. Indeed, as we know the varint encoding + # at compile time, we can generate the appropriate clauses. + # This has the benefit of a small speedup (~1%-10%) and a decrease in memory usage (~10%) from + # the Varint.decode version. 
if keep_set_fields do quote do {new_set_fields, unquote(vars.field), rest} = - case Protox.Decode.parse_key(bytes) do - unquote(all_cases) - end + case bytes, do: unquote(all_clauses) msg_updated = struct(unquote(vars.msg), unquote(vars.field)) parse_key_value(new_set_fields, rest, msg_updated) @@ -133,9 +135,7 @@ defmodule Protox.DefineDecoder do else quote do {unquote(vars.field), rest} = - case Protox.Decode.parse_key(bytes) do - unquote(all_cases) - end + case bytes, do: unquote(all_clauses) msg_updated = struct(unquote(vars.msg), unquote(vars.field)) parse_key_value(rest, msg_updated) @@ -145,7 +145,23 @@ defmodule Protox.DefineDecoder do defp make_parse_key_value_tag_0() do quote do - {0, _, _} -> raise %Protox.IllegalTagError{} + <<0::5, _::3, _rest::binary>> -> raise %Protox.IllegalTagError{} + end + end + + defp make_parse_key_value_invalid_varint() do + quote do + <<_::5, 3::3, _rest::binary>> -> + raise Protox.DecodingError.new(bytes, "invalid wire type 3") + + <<_::5, 4::3, _rest::binary>> -> + raise Protox.DecodingError.new(bytes, "invalid wire type 4") + + <<_::5, 6::3, _rest::binary>> -> + raise Protox.DecodingError.new(bytes, "invalid wire type 6") + + <<_::5, 7::3, _rest::binary>> -> + raise Protox.DecodingError.new(bytes, "invalid wire type 7") end end @@ -178,7 +194,8 @@ defmodule Protox.DefineDecoder do end quote do - {tag, wire_type, rest} -> + <> -> + {tag, wire_type, rest} = Protox.Decode.parse_key(unquote(vars.bytes)) {unquote(vars.value), rest} = Protox.Decode.parse_unknown(tag, wire_type, rest) unquote(case_return) @@ -200,20 +217,36 @@ defmodule Protox.DefineDecoder do parse_single = make_parse_single(vars.bytes, field.type) update_field = make_update_field(vars.value, field, vars, _wrap_value = true) - # No need to maintain a list of set fields for proto3 - case_return = + # No need to maintain a list of set fields for proto3. 
+ clause_return = case keep_set_fields do - true -> - quote do: {[unquote(field.name) | set_fields], [unquote(update_field)], rest} + true -> quote do: {[unquote(field.name) | set_fields], [unquote(update_field)], rest} + false -> quote do: {[unquote(update_field)], rest} + end + + key_bytes = make_key_bytes(field) - false -> - quote do: {[unquote(update_field)], rest} + # The last 3 bits of the first byte are the wire type, which we can ignore here as we know beforehand + # how the field is encoded. + <> = key_bytes + + clause = + case tail do + "" -> + quote do + <> + end + + _ -> + quote do + <> + end end quote do - {unquote(field.tag), _, unquote(vars.bytes)} -> + unquote(clause) -> {value, rest} = unquote(parse_single) - unquote(case_return) + unquote(clause_return) end end @@ -266,17 +299,36 @@ defmodule Protox.DefineDecoder do false -> quote do: {[unquote(update_field)], rest} end - # If `single` was not generated, then we don't need the `@wire_delimited discrimant - # as there is only one clause for this `tag`. - wire_type = - case single_generated do - true -> quote do: unquote(@wire_delimited) - false -> quote do: _ + key_bytes = make_key_bytes(%Field{field | kind: :packed}) + + clause = + if single_generated do + # If the single clause was not generated for this field, we don't need the wire type + # discriminant as there is only one clause matching for this field. 
+ quote do + <> + end + else + <> = key_bytes + + case tail do + "" -> + quote do + <> + end + + _ -> + quote do + <> + end + end end quote do - {unquote(field.tag), unquote(wire_type), unquote(vars.bytes)} -> + unquote(clause) -> {len, unquote(vars.bytes)} = Protox.Varint.decode(unquote(vars.bytes)) + {unquote(vars.delimited), rest} = Protox.Decode.parse_delimited(unquote(vars.bytes), len) unquote(case_return) end @@ -608,4 +660,19 @@ defmodule Protox.DefineDecoder do _ -> make_parse_single(vars.rest, type) end end + + # Compute at compile time the varint representation of a field + # tag and wire type. + defp make_key_bytes(%Field{} = field) do + # We need to convert the type to something recognized + # by Protox.Encode.make_key_bytes/2. + ty = + case field.kind do + :map -> :map_entry + :packed -> :packed + _ -> field.type + end + + Protox.Encode.make_key_bytes(field.tag, ty) |> IO.iodata_to_binary() + end end