Skip to content

Commit

Permalink
refactor: revert to an accumulator to encode a message
Browse files Browse the repository at this point in the history
This mostly reverts commit 5a34820.

It was a bad decision. Not using an accumulator shows a slowdown from ~9% to ~107% and a memory usage increase from ~20% to ~90%!
  • Loading branch information
ahamez committed Jan 31, 2025
1 parent c5c9554 commit ccf2618
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 65 deletions.
114 changes: 51 additions & 63 deletions lib/protox/define_encoder.ex
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ defmodule Protox.DefineEncoder do
alias Protox.{Field, OneOf, Scalar}

def define(fields, required_fields, syntax, opts \\ []) do
{unknown_fields_name, _opts} = Keyword.pop!(opts, :unknown_fields_name)

%{oneofs: oneofs, proto3_optionals: proto3_optionals, others: fields_without_oneofs} =
Protox.Defs.split_oneofs(fields)

Expand All @@ -16,7 +14,7 @@ defmodule Protox.DefineEncoder do
encode_oneof_funs = make_encode_oneof_funs(oneofs)
encode_field_funs = make_encode_field_funs(fields, required_fields, syntax)

encode_unknown_fields_fun = make_encode_unknown_fields_fun(unknown_fields_name)
encode_unknown_fields_fun = make_encode_unknown_fields_fun(opts)

quote do
unquote(top_level_encode_fun)
Expand All @@ -27,27 +25,10 @@ defmodule Protox.DefineEncoder do
end

defp make_top_level_encode_fun(oneofs, fields) do
ast =
quote do
[
unquote_splicing(make_encode_oneof_fun(oneofs)),
unquote_splicing(make_encode_fun_field(fields))
]
end

make_encode_fun_body(ast)
end

defp make_encode_fun_body([] = _ast) do
quote do
@spec encode(struct()) :: {:ok, iodata()}
def encode(msg) do
{:ok, encode!(msg)}
end

@spec encode!(struct()) :: iodata()
def encode!(_msg), do: []
end
quote(do: [])
|> make_encode_oneof_fun(oneofs)
|> make_encode_fun_field(fields)
|> make_encode_fun_body()
end

defp make_encode_fun_body(ast) do
Expand All @@ -67,31 +48,31 @@ defmodule Protox.DefineEncoder do
end
end

defp make_encode_fun_field(fields) do
defp make_encode_fun_field(ast, fields) do
ast =
Enum.map(fields, fn %Protox.Field{} = field ->
Enum.reduce(fields, ast, fn %Protox.Field{} = field, ast_acc ->
fun_name = String.to_atom("encode_#{field.name}")

quote(do: unquote(fun_name)(msg))
quote(do: unquote(ast_acc) |> unquote(fun_name)(msg))
end)

quote do
[unquote_splicing(ast), encode_unknown_fields(msg)]
unquote(ast) |> encode_unknown_fields(msg)
end
end

defp make_encode_oneof_fun(oneofs) do
Enum.map(oneofs, fn {parent_name, _children} ->
defp make_encode_oneof_fun(ast, oneofs) do
Enum.reduce(oneofs, ast, fn {parent_name, _children}, ast_acc ->
fun_name = String.to_atom("encode_#{parent_name}")
quote(do: unquote(fun_name)(msg))
quote(do: unquote(ast_acc) |> unquote(fun_name)(msg))
end)
end

defp make_encode_oneof_funs(oneofs) do
for {parent_name, children} <- oneofs do
nil_case =
quote do
nil -> []
nil -> acc
end

children_case_ast =
Expand All @@ -102,15 +83,15 @@ defmodule Protox.DefineEncoder do

quote do
{unquote(child_field.name), _field_value} ->
unquote(encode_child_fun_name)(msg)
unquote(encode_child_fun_name)(acc, msg)
end
end)
|> List.flatten())

encode_parent_fun_name = String.to_atom("encode_#{parent_name}")

quote do
defp unquote(encode_parent_fun_name)(msg) do
defp unquote(encode_parent_fun_name)(acc, msg) do
case msg.unquote(parent_name) do
unquote(children_case_ast)
end
Expand All @@ -121,6 +102,7 @@ defmodule Protox.DefineEncoder do

defp make_encode_field_funs(fields, required_fields, syntax) do
vars = %{
acc: Macro.var(:acc, __MODULE__),
msg: Macro.var(:msg, __MODULE__)
}

Expand All @@ -130,7 +112,7 @@ defmodule Protox.DefineEncoder do
fun_ast = make_encode_field_body(field, required, syntax, vars)

quote do
defp unquote(fun_name)(unquote(vars.msg)) do
defp unquote(fun_name)(unquote(vars.acc), unquote(vars.msg)) do
try do
unquote(fun_ast)
rescue
Expand All @@ -154,14 +136,14 @@ defmodule Protox.DefineEncoder do
quote do
case unquote(vars.msg).unquote(field.name) do
nil -> raise Protox.RequiredFieldsError.new([unquote(field.name)])
_ -> [unquote(key), unquote(encode_value_ast)]
_ -> [unquote(vars.acc), unquote(key), unquote(encode_value_ast)]
end
end
else
quote do
case unquote(var) do
nil -> []
_ -> [unquote(key), unquote(encode_value_ast)]
nil -> unquote(vars.acc)
_ -> [unquote(vars.acc), unquote(key), unquote(encode_value_ast)]
end
end
end
Expand All @@ -170,9 +152,9 @@ defmodule Protox.DefineEncoder do
quote do
# Use == rather than pattern match for float comparison
if unquote(var) == unquote(field.kind.default_value) do
[]
unquote(vars.acc)
else
[unquote(key), unquote(encode_value_ast)]
[unquote(vars.acc), unquote(key), unquote(encode_value_ast)]
end
end
end
Expand All @@ -193,8 +175,11 @@ defmodule Protox.DefineEncoder do
:proto3_optional ->
quote do
case unquote(vars.msg).unquote(field.name) do
nil -> []
unquote(var) -> [unquote(key), unquote(encode_value_ast)]
nil ->
[unquote(vars.acc)]

unquote(var) ->
[unquote(vars.acc), unquote(key), unquote(encode_value_ast)]
end
end

Expand All @@ -203,7 +188,7 @@ defmodule Protox.DefineEncoder do
# this is why we don't check if the child is set.
quote do
{_, unquote(var)} = unquote(vars.msg).unquote(field.kind.parent)
[unquote(key), unquote(encode_value_ast)]
[unquote(vars.acc), unquote(key), unquote(encode_value_ast)]
end
end
end
Expand All @@ -214,12 +199,8 @@ defmodule Protox.DefineEncoder do

quote do
case unquote(vars.msg).unquote(field.name) do
[] ->
[]

values ->
{bytes, len} = unquote(encode_packed_ast)
[unquote(key), Protox.Varint.encode(len), bytes]
[] -> unquote(vars.acc)
values -> [unquote(vars.acc), unquote(key), unquote(encode_packed_ast)]
end
end
end
Expand All @@ -229,8 +210,8 @@ defmodule Protox.DefineEncoder do

quote do
case unquote(vars.msg).unquote(field.name) do
[] -> []
values -> unquote(encode_repeated_ast)
[] -> unquote(vars.acc)
values -> [unquote(vars.acc), unquote(encode_repeated_ast)]
end
end
end
Expand All @@ -256,7 +237,8 @@ defmodule Protox.DefineEncoder do
quote do
map = Map.fetch!(unquote(vars.msg), unquote(field.name))

Enum.map(map, fn {unquote(k_var), unquote(v_var)} ->
Enum.reduce(map, unquote(vars.acc), fn {unquote(k_var), unquote(v_var)},
unquote(vars.acc) ->
map_key_value_bytes = :binary.list_to_bin([unquote(encode_map_key_ast)])
map_key_value_len = byte_size(map_key_value_bytes)

Expand All @@ -267,6 +249,7 @@ defmodule Protox.DefineEncoder do
Protox.Varint.encode(unquote(map_keys_len) + map_key_value_len + map_value_value_len)

[
unquote(vars.acc),
unquote(key),
len,
unquote(map_key_key_bytes),
Expand All @@ -278,23 +261,25 @@ defmodule Protox.DefineEncoder do
end
end

defp make_encode_unknown_fields_fun(unknown_fields_name) do
defp make_encode_unknown_fields_fun(opts) do
unknown_fields_name = Keyword.fetch!(opts, :unknown_fields_name)

quote do
defp encode_unknown_fields(msg) do
Enum.map(msg.unquote(unknown_fields_name), fn {tag, wire_type, bytes} ->
defp encode_unknown_fields(acc, msg) do
Enum.reduce(msg.unquote(unknown_fields_name), acc, fn {tag, wire_type, bytes}, acc ->
case wire_type do
0 ->
[Protox.Encode.make_key_bytes(tag, :int32), bytes]
[acc, Protox.Encode.make_key_bytes(tag, :int32), bytes]

1 ->
[Protox.Encode.make_key_bytes(tag, :double), bytes]
[acc, Protox.Encode.make_key_bytes(tag, :double), bytes]

2 ->
len_bytes = bytes |> byte_size() |> Protox.Varint.encode()
[Protox.Encode.make_key_bytes(tag, :packed), len_bytes, bytes]
[acc, Protox.Encode.make_key_bytes(tag, :packed), len_bytes, bytes]

5 ->
[Protox.Encode.make_key_bytes(tag, :float), bytes]
[acc, Protox.Encode.make_key_bytes(tag, :float), bytes]
end
end)
end
Expand All @@ -306,10 +291,13 @@ defmodule Protox.DefineEncoder do
encode_value_ast = get_encode_value_body(type, value_var)

quote do
Enum.reduce(values, {[], 0}, fn unquote(value_var), {acc, len} ->
value_bytes = :binary.list_to_bin([unquote(encode_value_ast)])
{[acc, value_bytes], len + byte_size(value_bytes)}
end)
{bytes, len} =
Enum.reduce(values, {[], 0}, fn unquote(value_var), {acc, len} ->
value_bytes = :binary.list_to_bin([unquote(encode_value_ast)])
{[acc, value_bytes], len + byte_size(value_bytes)}
end)

[Protox.Varint.encode(len), bytes]
end
end

Expand Down
3 changes: 1 addition & 2 deletions test/protox/encode_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@ defmodule Protox.EncodeTest do
end

test "Default TestAllTypesProto3, with non throwing encode/1" do
{status, bytes} = Protox.encode(%TestAllTypesProto3{})
assert {:ok, []} == {status, List.flatten(bytes)}
assert {:ok, []} == Protox.encode(%TestAllTypesProto3{})
end

test "Messsage with no fields, unknown fields are encoded back" do
Expand Down

0 comments on commit ccf2618

Please sign in to comment.