diff --git a/lib/realtime/api.ex b/lib/realtime/api.ex index f89500d56..5396aa885 100644 --- a/lib/realtime/api.ex +++ b/lib/realtime/api.ex @@ -124,7 +124,8 @@ defmodule Realtime.Api do data: %{external_id: external_id} }) when is_map_key(changes, :jwt_jwks) or is_map_key(changes, :jwt_secret) do - Phoenix.PubSub.broadcast!(Realtime.PubSub, "realtime:operations:" <> external_id, :disconnect) + IO.inspect("trigger_disconnect") + RealtimeWeb.Endpoint.broadcast("user_socket:#{external_id}", "disconnect", %{}) end defp maybe_trigger_disconnect(_), do: nil @@ -198,7 +199,8 @@ defmodule Realtime.Api do {value, settings} = Map.pop(extension.settings, from) new_settings = Map.put(settings, to, value) - Ecto.Changeset.cast(extension, %{settings: new_settings}, [:settings]) + extension + |> Ecto.Changeset.cast(%{settings: new_settings}, [:settings]) |> Repo.update!() end end diff --git a/lib/realtime_web/channels/realtime_channel.ex b/lib/realtime_web/channels/realtime_channel.ex index 69aa1242b..b710d42c8 100644 --- a/lib/realtime_web/channels/realtime_channel.ex +++ b/lib/realtime_web/channels/realtime_channel.ex @@ -67,6 +67,7 @@ defmodule RealtimeWeb.RealtimeChannel do Realtime.UsersCounter.add(transport_pid, tenant_id) RealtimeWeb.Endpoint.subscribe(tenant_topic) Phoenix.PubSub.subscribe(Realtime.PubSub, "realtime:operations:" <> tenant_id) + Process.monitor(transport_pid) is_new_api = new_api?(params) pg_change_params = pg_change_params(is_new_api, params, channel_pid, claims, sub_topic) @@ -195,8 +196,9 @@ defmodule RealtimeWeb.RealtimeChannel do end end + @impl true def handle_info( - _any, + any, %{ assigns: %{ rate_counter: %{avg: avg}, @@ -205,32 +207,28 @@ defmodule RealtimeWeb.RealtimeChannel do } = socket ) when avg > max do + IO.inspect(any) message = "Too many messages per second" shutdown_response(socket, message) end - @impl true - def handle_info(:sync_presence = msg, socket) do PresenceHandler.track(msg, socket) end - @impl true def handle_info(%{event: "postgres_cdc_rls_down"}, socket) do pg_sub_ref = postgres_subscribe() {:noreply, assign(socket, %{pg_sub_ref: pg_sub_ref})} end - @impl true def handle_info(%{event: "postgres_cdc_down"}, socket) do pg_sub_ref = postgres_subscribe() {:noreply, assign(socket, %{pg_sub_ref: pg_sub_ref})} end - @impl true def handle_info( %{event: type, payload: payload} = msg, %{assigns: %{policies: policies}} = socket @@ -261,7 +259,6 @@ defmodule RealtimeWeb.RealtimeChannel do {:noreply, socket} end - @impl true def handle_info(:postgres_subscribe, %{assigns: %{channel_name: channel_name}} = socket) do %{ assigns: %{ @@ -308,7 +305,6 @@ defmodule RealtimeWeb.RealtimeChannel do {:noreply, assign(socket, :pg_sub_ref, postgres_subscribe(5, 10))} end - @impl true def handle_info(:confirm_token, %{assigns: %{pg_change_params: pg_change_params}} = socket) do case confirm_token(socket) do {:ok, claims, confirm_token_ref, _, _} -> @@ -326,13 +322,19 @@ defmodule RealtimeWeb.RealtimeChannel do end end - def handle_info(:disconnect, %{assigns: %{channel_name: channel_name}} = socket) do + def handle_info(%{event: "phx_leave"}, %{assigns: %{channel_name: channel_name}} = socket) do Logger.info("Received operational call to disconnect channel") push_system_message("system", socket, "ok", "Server requested disconnect", channel_name) - {:stop, :shutdown, socket} + {:stop, {:shutdown, :left}, socket} + end + + def handle_info({:shutdown, :closed}, %{assigns: %{channel_name: channel_name}} = socket) do + push_system_message("system", socket, "ok", "Server requested disconnect", channel_name) + {:stop, {:shutdown, :closed}, socket} end def handle_info(msg, socket) do + IO.inspect(msg) log_error("UnhandledSystemMessage", msg) {:noreply, socket} end @@ -432,7 +434,6 @@ defmodule RealtimeWeb.RealtimeChannel do def handle_in(type, payload, socket) do socket = count(socket) - # Log info here so that bad messages from clients won't flood Logflare # Can subscribe to a Channel with `log_level` `info` to see these messages message = "Unexpected message from client of type `#{type}` with payload: #{inspect(payload)}" @@ -442,8 +443,16 @@ defmodule RealtimeWeb.RealtimeChannel do end @impl true - def terminate(reason, _state) do + def terminate({:shutdown, :closed}, %{assigns: %{channel_name: channel_name}} = socket) do + IO.inspect("Channel terminated with reason: shutdown") + push_system_message("system", socket, "ok", "Server requested disconnect", channel_name) + :ok + end + + def terminate(reason, %{assigns: %{channel_name: channel_name}} = socket) do + IO.inspect("Channel terminated with reason: #{inspect(reason)}") Logger.debug("Channel terminated with reason: " <> inspect(reason)) + push_system_message("system", socket, "ok", "Server requested disconnect", channel_name) :telemetry.execute([:prom_ex, :plugin, :realtime, :disconnected], %{}) :ok end diff --git a/lib/realtime_web/channels/user_socket.ex b/lib/realtime_web/channels/user_socket.ex index 6fdee146f..048dc90a6 100644 --- a/lib/realtime_web/channels/user_socket.ex +++ b/lib/realtime_web/channels/user_socket.ex @@ -26,6 +26,8 @@ defmodule RealtimeWeb.UserSocket do @impl true def connect(params, socket, opts) do + IO.inspect("connect") + if Application.fetch_env!(:realtime, :secure_channels) do %{uri: %{host: host}, x_headers: headers} = opts @@ -59,6 +61,8 @@ defmodule RealtimeWeb.UserSocket do jwt_secret_dec <- Crypto.decrypt!(jwt_secret), {:ok, claims} <- ChannelsAuthorization.authorize_conn(token, jwt_secret_dec, jwt_jwks), {:ok, postgres_cdc_module} <- PostgresCdc.driver(postgres_cdc_default) do + RealtimeWeb.Endpoint.subscribe(subscribers_id(external_id)) + assigns = %RealtimeChannel.Assigns{ claims: claims, jwt_secret: jwt_secret, diff --git a/mix.lock b/mix.lock index 19dab21fa..b6fa208da 100644 --- a/mix.lock +++ b/mix.lock @@ -53,7 +53,7 @@ "octo_fetch": {:hex, :octo_fetch, "0.4.0", "074b5ecbc08be10b05b27e9db08bc20a3060142769436242702931c418695b19", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: false]}, {:ssl_verify_fun, "~> 1.1", [hex: :ssl_verify_fun, repo: "hexpm", optional: false]}], "hexpm", "cf8be6f40cd519d7000bb4e84adcf661c32e59369ca2827c4e20042eda7a7fc6"}, "open_api_spex": {:hex, :open_api_spex, "3.21.2", "6a704f3777761feeb5657340250d6d7332c545755116ca98f33d4b875777e1e5", [:mix], [{:decimal, "~> 1.0 or ~> 2.0", [hex: :decimal, repo: "hexpm", optional: true]}, {:jason, "~> 1.0", [hex: :jason, repo: "hexpm", optional: true]}, {:plug, "~> 1.7", [hex: :plug, repo: "hexpm", optional: false]}, {:poison, "~> 3.0 or ~> 4.0 or ~> 5.0 or ~> 6.0", [hex: :poison, repo: "hexpm", optional: true]}, {:ymlr, "~> 2.0 or ~> 3.0 or ~> 4.0 or ~> 5.0", [hex: :ymlr, repo: "hexpm", optional: true]}], "hexpm", "f42ae6ed668b895ebba3e02773cfb4b41050df26f803f2ef634c72a7687dc387"}, "parse_trans": {:hex, :parse_trans, "3.4.1", "6e6aa8167cb44cc8f39441d05193be6e6f4e7c2946cb2759f015f8c56b76e5ff", [:rebar3], [], "hexpm", "620a406ce75dada827b82e453c19cf06776be266f5a67cff34e1ef2cbb60e49a"}, - "phoenix": {:hex, :phoenix, "1.7.18", "5310c21443514be44ed93c422e15870aef254cf1b3619e4f91538e7529d2b2e4", [:mix], [{:castore, ">= 0.0.0", [hex: :castore, repo: "hexpm", optional: false]}, {:jason, "~> 1.0", [hex: :jason, repo: "hexpm", optional: true]}, {:phoenix_pubsub, "~> 2.1", [hex: :phoenix_pubsub, repo: "hexpm", optional: false]}, {:phoenix_template, "~> 1.0", [hex: :phoenix_template, repo: "hexpm", optional: false]}, {:phoenix_view, "~> 2.0", [hex: :phoenix_view, repo: "hexpm", optional: true]}, {:plug, "~> 1.14", [hex: :plug, repo: "hexpm", optional: false]}, {:plug_cowboy, "~> 2.7", [hex: :plug_cowboy, repo: "hexpm", optional: true]}, {:plug_crypto, "~> 1.2 or ~> 2.0", [hex: :plug_crypto, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}, {:websock_adapter, "~> 0.5.3", [hex: :websock_adapter, repo: "hexpm", optional: false]}], "hexpm", "1797fcc82108442a66f2c77a643a62980f342bfeb63d6c9a515ab8294870004e"}, + "phoenix": {:hex, :phoenix, "1.7.19", "36617efe5afbd821099a8b994ff4618a340a5bfb25531a1802c4d4c634017a57", [:mix], [{:castore, ">= 0.0.0", [hex: :castore, repo: "hexpm", optional: false]}, {:jason, "~> 1.0", [hex: :jason, repo: "hexpm", optional: true]}, {:phoenix_pubsub, "~> 2.1", [hex: :phoenix_pubsub, repo: "hexpm", optional: false]}, {:phoenix_template, "~> 1.0", [hex: :phoenix_template, repo: "hexpm", optional: false]}, {:phoenix_view, "~> 2.0", [hex: :phoenix_view, repo: "hexpm", optional: true]}, {:plug, "~> 1.14", [hex: :plug, repo: "hexpm", optional: false]}, {:plug_cowboy, "~> 2.7", [hex: :plug_cowboy, repo: "hexpm", optional: true]}, {:plug_crypto, "~> 1.2 or ~> 2.0", [hex: :plug_crypto, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}, {:websock_adapter, "~> 0.5.3", [hex: :websock_adapter, repo: "hexpm", optional: false]}], "hexpm", "ba4dc14458278773f905f8ae6c2ec743d52c3a35b6b353733f64f02dfe096cd6"}, "phoenix_ecto": {:hex, :phoenix_ecto, "4.4.3", "86e9878f833829c3f66da03d75254c155d91d72a201eb56ae83482328dc7ca93", [:mix], [{:ecto, "~> 3.5", [hex: :ecto, repo: "hexpm", optional: false]}, {:phoenix_html, "~> 2.14.2 or ~> 3.0 or ~> 4.0", [hex: :phoenix_html, repo: "hexpm", optional: true]}, {:plug, "~> 1.9", [hex: :plug, repo: "hexpm", optional: false]}], "hexpm", "d36c401206f3011fefd63d04e8ef626ec8791975d9d107f9a0817d426f61ac07"}, "phoenix_html": {:hex, :phoenix_html, "3.3.4", "42a09fc443bbc1da37e372a5c8e6755d046f22b9b11343bf885067357da21cb3", [:mix], [{:plug, "~> 1.5", [hex: :plug, repo: "hexpm", optional: true]}], "hexpm", "0249d3abec3714aff3415e7ee3d9786cb325be3151e6c4b3021502c585bf53fb"}, "phoenix_live_dashboard": {:hex, :phoenix_live_dashboard, "0.8.6", "7b1f0327f54c9eb69845fd09a77accf922f488c549a7e7b8618775eb603a62c7", [:mix], [{:ecto, "~> 3.6.2 or ~> 3.7", [hex: :ecto, repo: "hexpm", optional: true]}, {:ecto_mysql_extras, "~> 0.5", [hex: :ecto_mysql_extras, repo: "hexpm", optional: true]}, {:ecto_psql_extras, "~> 0.7", [hex: :ecto_psql_extras, repo: "hexpm", optional: true]}, {:ecto_sqlite3_extras, "~> 1.1.7 or ~> 1.2.0", [hex: :ecto_sqlite3_extras, repo: "hexpm", optional: true]}, {:mime, "~> 1.6 or ~> 2.0", [hex: :mime, repo: "hexpm", optional: false]}, {:phoenix_live_view, "~> 0.19 or ~> 1.0", [hex: :phoenix_live_view, repo: "hexpm", optional: false]}, {:telemetry_metrics, "~> 0.6 or ~> 1.0", [hex: :telemetry_metrics, repo: "hexpm", optional: false]}], "hexpm", "1681ab813ec26ca6915beb3414aa138f298e17721dc6a2bde9e6eb8a62360ff6"}, diff --git a/test/integration/integration.ex b/test/integration/integration.ex new file mode 100644 index 000000000..e0b9bebd6 --- /dev/null +++ b/test/integration/integration.ex @@ -0,0 +1,64 @@ +defmodule Integration do + import Generators + + alias Realtime.Database + alias Realtime.Integration.WebsocketClient + alias Phoenix.Socket.V1 + alias Realtime.Database + alias Realtime.Integration.WebsocketClient + + @serializer V1.JSONSerializer + @secret "secure_jwt_secret" + @external_id "dev_tenant" + defp uri(port), do: "ws://#{@external_id}.localhost:#{port}/socket/websocket" + def token_valid(role, claims \\ %{}), do: generate_token(Map.put(claims, :role, role)) + def token_no_role, do: generate_token() + + def generate_token(claims \\ %{}) do + claims = + Map.merge( + %{ + ref: "localhost", + iat: System.system_time(:second), + exp: System.system_time(:second) + 604_800 + }, + claims + ) + + {:ok, generate_jwt_token(@secret, claims)} + end + + def get_connection(port, role \\ "anon", claims \\ %{}, params \\ %{vsn: "1.0.0", log_level: :warning}) do + params = Enum.reduce(params, "", fn {k, v}, acc -> "#{acc}&#{k}=#{v}" end) + uri = "#{uri(port)}?#{params}" + + with {:ok, token} <- token_valid(role, claims), + {:ok, socket} <- + WebsocketClient.connect(self(), uri, @serializer, [{"x-api-key", token}]) do + {socket, token} + end + end + + def rls_context(%{tenant: tenant} = context) do + {:ok, db_conn} = Database.connect(tenant, "realtime_test", :stop) + + clean_table(db_conn, "realtime", "messages") + topic = Map.get(context, :topic, random_string()) + message = message_fixture(tenant, %{topic: topic}) + + if policies = context[:policies] do + create_rls_policies(db_conn, policies, message) + end + + Map.put(context, :topic, message.topic) + end + + def change_tenant_configuration(limit, value) do + @external_id + |> Realtime.Tenants.get_tenant_by_external_id() + |> Realtime.Api.Tenant.changeset(%{limit => value}) + |> Realtime.Repo.update!() + + Realtime.Tenants.Cache.invalidate_tenant_cache(@external_id) + end +end diff --git a/test/integration/rt_channel_test.exs b/test/integration/rt_channel_test.exs index 623d2f787..2172f8ca8 100644 --- a/test/integration/rt_channel_test.exs +++ b/test/integration/rt_channel_test.exs @@ -1,4 +1,5 @@ Code.require_file("../support/websocket_client.exs", __DIR__) +Code.require_file("./Integration.ex", __DIR__) defmodule Realtime.Integration.RtChannelTest do # async: false due to the fact that multiple operations against the database will use the same connection @@ -7,12 +8,12 @@ defmodule Realtime.Integration.RtChannelTest do import ExUnit.CaptureLog import Generators import Mock + import Integration require Logger alias Extensions.PostgresCdcRls alias Phoenix.Socket.Message - alias Phoenix.Socket.V1 alias Postgrex alias Realtime.Api.Tenant alias Realtime.Database @@ -26,10 +27,7 @@ defmodule Realtime.Integration.RtChannelTest do alias Realtime.Tenants.Migrations @moduletag :capture_log @port 4002 - @serializer V1.JSONSerializer @external_id "dev_tenant" - @uri "ws://#{@external_id}.localhost:#{@port}/socket/websocket" - @secret "secure_jwt_secret" Application.put_env(:phoenix, Endpoint, https: false, @@ -90,7 +88,7 @@ defmodule Realtime.Integration.RtChannelTest do end test "handle postgres extension" do - {socket, _} = get_connection() + {socket, _} = get_connection(@port) topic = "realtime:any" config = %{postgres_changes: [%{event: "*", schema: "public"}]} @@ -207,7 +205,7 @@ defmodule Realtime.Integration.RtChannelTest do setup [:rls_context] test "public broadcast" do - {socket, _} = get_connection() + {socket, _} = get_connection(@port) config = %{ broadcast: %{self: true}, @@ -233,7 +231,7 @@ defmodule Realtime.Integration.RtChannelTest do test "private broadcast with valid channel with permissions sends message", %{ topic: topic } do - {socket, _} = get_connection("authenticated") + {socket, _} = get_connection(@port, "authenticated") config = %{broadcast: %{self: true}, private: true} topic = "realtime:#{topic}" WebsocketClient.join(socket, topic, %{config: config}) @@ -269,8 +267,8 @@ defmodule Realtime.Integration.RtChannelTest do topic: "topic" test "private broadcast with valid channel a colon character sends message and won't intercept in public channels", %{topic: topic} do - {anon_socket, _} = get_connection("anon") - {socket, _} = get_connection("authenticated") + {anon_socket, _} = get_connection(@port, "anon") + {socket, _} = get_connection(@port, "authenticated") valid_topic = "realtime:#{topic}" malicious_topic = "realtime:private:#{topic}" @@ -307,13 +305,13 @@ defmodule Realtime.Integration.RtChannelTest do config = %{broadcast: %{self: true}, private: true} topic = "realtime:#{topic}" - {service_role_socket, _} = get_connection("service_role") + {service_role_socket, _} = get_connection(@port, "service_role") WebsocketClient.join(service_role_socket, topic, %{config: config}) assert_receive %Message{event: "phx_reply", topic: ^topic}, 500 assert_receive %Message{event: "presence_state"}, 500 - {socket, _} = get_connection("authenticated") + {socket, _} = get_connection(@port, "authenticated") WebsocketClient.join(socket, topic, %{config: config}) assert_receive %Message{event: "phx_reply", topic: ^topic}, 500 assert_receive %Message{event: "presence_state"}, 500 @@ -353,7 +351,7 @@ defmodule Realtime.Integration.RtChannelTest do expected = "You do not have permissions to read from this Channel topic: #{topic}" topic = "realtime:#{topic}" - {socket, _} = get_connection("authenticated") + {socket, _} = get_connection(@port, "authenticated") log = capture_log(fn -> @@ -379,7 +377,7 @@ defmodule Realtime.Integration.RtChannelTest do setup [:rls_context] test "public presence" do - {socket, _} = get_connection() + {socket, _} = get_connection(@port) config = %{presence: %{key: ""}, private: false} topic = "realtime:any" @@ -429,7 +427,7 @@ defmodule Realtime.Integration.RtChannelTest do ] test "private presence with read and write permissions will be able to track and receive presence changes", %{topic: topic} do - {socket, _} = get_connection("authenticated") + {socket, _} = get_connection(@port, "authenticated") config = %{presence: %{key: ""}, private: true} topic = "realtime:#{topic}" WebsocketClient.join(socket, topic, %{config: config}) @@ -464,8 +462,8 @@ defmodule Realtime.Integration.RtChannelTest do @tag policies: [:authenticated_read_broadcast_and_presence] test "private presence with read permissions will be able to receive presence changes but won't be able to track", %{topic: topic} do - {socket, _} = get_connection("authenticated") - {secondary_socket, _} = get_connection("service_role") + {socket, _} = get_connection(@port, "authenticated") + {secondary_socket, _} = get_connection(@port, "service_role") config = fn key -> %{presence: %{key: key}, private: true} end topic = "realtime:#{topic}" @@ -533,30 +531,13 @@ defmodule Realtime.Integration.RtChannelTest do describe "token handling" do setup [:rls_context] - @tag policies: [ - :authenticated_read_broadcast_and_presence, - :authenticated_write_broadcast_and_presence - ] - test "invalid JWT with expired token" do - assert capture_log(fn -> - get_connection("authenticated", %{:exp => System.system_time(:second) - 1000}) - end) =~ "InvalidJWTToken: Token as expired 1000 seconds ago" - end - - test "token required the role key" do - {:ok, token} = token_no_role() - - assert {:error, %{status_code: 403}} = - WebsocketClient.connect(self(), @uri, @serializer, [{"x-api-key", token}]) - end - @tag policies: [ :authenticated_read_broadcast_and_presence, :authenticated_write_broadcast_and_presence ] test "on new access_token and channel is private policies are reevaluated for read policy", %{topic: topic} do - {socket, access_token} = get_connection("authenticated") + {socket, access_token} = get_connection(@port, "authenticated") realtime_topic = "realtime:#{topic}" @@ -597,7 +578,7 @@ defmodule Realtime.Integration.RtChannelTest do ] test "on new access_token and channel is private policies are reevaluated for write policy", %{topic: topic, tenant: tenant} do - {socket, access_token} = get_connection("authenticated") + {socket, access_token} = get_connection(@port, "authenticated") realtime_topic = "realtime:#{topic}" WebsocketClient.join(socket, realtime_topic, %{ @@ -652,7 +633,7 @@ defmodule Realtime.Integration.RtChannelTest do test "on new access_token and channel is public policies are not reevaluated", %{topic: topic} do - {socket, access_token} = get_connection("authenticated") + {socket, access_token} = get_connection(@port, "authenticated") {:ok, new_token} = token_valid("anon") config = %{broadcast: %{self: true}, private: false} realtime_topic = "realtime:#{topic}" @@ -671,7 +652,7 @@ defmodule Realtime.Integration.RtChannelTest do test "on empty string access_token the socket sends an error message", %{topic: topic} do - {socket, access_token} = get_connection("authenticated") + {socket, access_token} = get_connection(@port, "authenticated") config = %{broadcast: %{self: true}, private: false} realtime_topic = "realtime:#{topic}" @@ -698,7 +679,7 @@ defmodule Realtime.Integration.RtChannelTest do sub = random_string() {socket, access_token} = - get_connection("authenticated", %{sub: sub}) + get_connection(@port, "authenticated", %{sub: sub}) config = %{broadcast: %{self: true}, private: false} realtime_topic = "realtime:#{topic}" @@ -734,7 +715,7 @@ defmodule Realtime.Integration.RtChannelTest do sub = random_string() {socket, access_token} = - get_connection("authenticated", %{sub: sub}, %{log_level: :warning}) + get_connection(@port, "authenticated", %{sub: sub}, %{log_level: :warning}) config = %{broadcast: %{self: true}, private: false} realtime_topic = "realtime:#{topic}" @@ -758,7 +739,7 @@ defmodule Realtime.Integration.RtChannelTest do test "missing claims close connection", %{topic: topic} do {socket, access_token} = - get_connection("authenticated") + get_connection(@port, "authenticated") config = %{broadcast: %{self: true}, private: false} realtime_topic = "realtime:#{topic}" @@ -789,7 +770,7 @@ defmodule Realtime.Integration.RtChannelTest do test "checks token periodically", %{topic: topic} do {socket, access_token} = - get_connection("authenticated") + get_connection(@port, "authenticated") config = %{broadcast: %{self: true}, private: false} realtime_topic = "realtime:#{topic}" @@ -824,7 +805,7 @@ defmodule Realtime.Integration.RtChannelTest do end test "token expires in between joins", %{topic: topic} do - {socket, access_token} = get_connection("authenticated") + {socket, access_token} = get_connection(@port, "authenticated") config = %{broadcast: %{self: true}, private: false} realtime_topic = "realtime:#{topic}" @@ -855,7 +836,7 @@ defmodule Realtime.Integration.RtChannelTest do end test "token loses claims in between joins", %{topic: topic} do - {socket, access_token} = get_connection("authenticated") + {socket, access_token} = get_connection(@port, "authenticated") config = %{broadcast: %{self: true}, private: false} realtime_topic = "realtime:#{topic}" @@ -887,7 +868,7 @@ defmodule Realtime.Integration.RtChannelTest do end test "token is badly formatted in between joins", %{topic: topic} do - {socket, access_token} = get_connection("authenticated") + {socket, access_token} = get_connection(@port, "authenticated") config = %{broadcast: %{self: true}, private: false} realtime_topic = "realtime:#{topic}" @@ -925,7 +906,7 @@ defmodule Realtime.Integration.RtChannelTest do in_series([:_, :_, :_], [&passthrough([&1, &2, &3]), {:error, "RPC Error"}]) ]} ] do - {socket, access_token} = get_connection("authenticated") + {socket, access_token} = get_connection(@port, "authenticated") config = %{broadcast: %{self: true}, private: true} realtime_topic = "realtime:#{topic}" @@ -977,7 +958,7 @@ defmodule Realtime.Integration.RtChannelTest do db_conn: db_conn, table_name: table_name } do - {socket, _} = get_connection("authenticated") + {socket, _} = get_connection(@port, "authenticated") config = %{broadcast: %{self: true}, private: true} topic = "realtime:#{topic}" WebsocketClient.join(socket, topic, %{config: config}) @@ -1020,7 +1001,7 @@ defmodule Realtime.Integration.RtChannelTest do } do value = random_string() - {socket, _} = get_connection("authenticated") + {socket, _} = get_connection(@port, "authenticated") config = %{broadcast: %{self: true}, private: true} topic = "realtime:#{topic}" WebsocketClient.join(socket, topic, %{config: config}) @@ -1068,7 +1049,7 @@ defmodule Realtime.Integration.RtChannelTest do db_conn: db_conn, table_name: table_name } do - {socket, _} = get_connection("authenticated") + {socket, _} = get_connection(@port, "authenticated") config = %{broadcast: %{self: true}, private: true} topic = "realtime:#{topic}" WebsocketClient.join(socket, topic, %{config: config}) @@ -1109,7 +1090,7 @@ defmodule Realtime.Integration.RtChannelTest do topic: topic, db_conn: db_conn } do - {socket, _} = get_connection("authenticated") + {socket, _} = get_connection(@port, "authenticated") config = %{broadcast: %{self: true}, private: true} full_topic = "realtime:#{topic}" @@ -1145,7 +1126,7 @@ defmodule Realtime.Integration.RtChannelTest do topic: topic, db_conn: db_conn } do - {socket, _} = get_connection("authenticated") + {socket, _} = get_connection(@port, "authenticated") config = %{broadcast: %{self: true}, private: false} full_topic = "realtime:#{topic}" @@ -1188,7 +1169,7 @@ defmodule Realtime.Integration.RtChannelTest do } do change_tenant_configuration(:private_only, true) - {socket, _} = get_connection("authenticated") + {socket, _} = get_connection(@port, "authenticated") config = %{broadcast: %{self: true}, private: false} topic = "realtime:#{topic}" WebsocketClient.join(socket, topic, %{config: config}) @@ -1218,7 +1199,7 @@ defmodule Realtime.Integration.RtChannelTest do Process.sleep(100) - {socket, _} = get_connection("authenticated") + {socket, _} = get_connection(@port, "authenticated") config = %{broadcast: %{self: true}, private: true} topic = "realtime:#{topic}" WebsocketClient.join(socket, topic, %{config: config}) @@ -1228,93 +1209,6 @@ defmodule Realtime.Integration.RtChannelTest do end end - describe "sensitive information updates" do - setup [:rls_context] - - test "on jwks the socket closes and sends a system message", %{topic: topic} do - {socket, _} = get_connection("authenticated") - config = %{broadcast: %{self: true}, private: false} - realtime_topic = "realtime:#{topic}" - - WebsocketClient.join(socket, realtime_topic, %{config: config}) - - assert_receive %Message{event: "phx_reply"}, 500 - assert_receive %Message{event: "presence_state"}, 500 - tenant = Tenants.get_tenant_by_external_id(@external_id) - Realtime.Api.update_tenant(tenant, %{jwt_jwks: %{keys: ["potato"]}}) - - assert_receive %Message{ - topic: ^realtime_topic, - event: "system", - payload: %{ - "extension" => "system", - "message" => "Server requested disconnect", - "status" => "ok" - } - }, - 500 - end - - test "on jwt_secret the socket closes and sends a system message", %{topic: topic} do - {socket, _} = get_connection("authenticated") - config = %{broadcast: %{self: true}, private: false} - realtime_topic = "realtime:#{topic}" - - WebsocketClient.join(socket, realtime_topic, %{config: config}) - - assert_receive %Message{event: "phx_reply"}, 500 - assert_receive %Message{event: "presence_state"}, 500 - - tenant = Tenants.get_tenant_by_external_id(@external_id) - Realtime.Api.update_tenant(tenant, %{jwt_secret: "potato"}) - - assert_receive %Message{ - topic: ^realtime_topic, - event: "system", - payload: %{ - "extension" => "system", - "message" => "Server requested disconnect", - "status" => "ok" - } - }, - 500 - end - - test "on other param changes the socket won't close and no message is sent", %{topic: topic} do - {socket, _} = get_connection("authenticated") - config = %{broadcast: %{self: true}, private: false} - realtime_topic = "realtime:#{topic}" - - WebsocketClient.join(socket, realtime_topic, %{config: config}) - - assert_receive %Message{event: "phx_reply"}, 500 - assert_receive %Message{event: "presence_state"}, 500 - - tenant = Tenants.get_tenant_by_external_id(@external_id) - Realtime.Api.update_tenant(tenant, %{max_concurrent_users: 100}) - - refute_receive %Message{ - topic: ^realtime_topic, - event: "system", - payload: %{ - "extension" => "system", - "message" => "Server requested disconnect", - "status" => "ok" - } - }, - 500 - end - - test "invalid JWT with expired token" do - log = - capture_log(fn -> - get_connection("authenticated", %{:exp => System.system_time(:second) - 1000}) - end) - - assert log =~ "InvalidJWTToken: Token as expired 1000 seconds ago" - end - end - describe "rate limits" do setup [:rls_context] @@ -1324,7 +1218,7 @@ defmodule Realtime.Integration.RtChannelTest do change_tenant_configuration(:max_concurrent_users, 1) - {socket, _} = get_connection("authenticated") + {socket, _} = get_connection(@port, "authenticated") config = %{broadcast: %{self: true}, private: false} realtime_topic = "realtime:#{random_string()}" WebsocketClient.join(socket, realtime_topic, %{config: config}) @@ -1352,7 +1246,7 @@ defmodule Realtime.Integration.RtChannelTest do change_tenant_configuration(:max_events_per_second, 1) - {socket, _} = get_connection("authenticated") + {socket, _} = get_connection(@port, "authenticated") config = %{broadcast: %{self: true}, private: false} realtime_topic = "realtime:#{random_string()}" WebsocketClient.join(socket, realtime_topic, %{config: config}) @@ -1383,7 +1277,7 @@ defmodule Realtime.Integration.RtChannelTest do change_tenant_configuration(:max_channels_per_client, 1) - {socket, _} = get_connection("authenticated") + {socket, _} = get_connection(@port, "authenticated") config = %{broadcast: %{self: true}, private: false} realtime_topic_1 = "realtime:#{random_string()}" realtime_topic_2 = "realtime:#{random_string()}" @@ -1427,7 +1321,7 @@ defmodule Realtime.Integration.RtChannelTest do change_tenant_configuration(:max_joins_per_second, 1) - {socket, _} = get_connection("authenticated") + {socket, _} = get_connection(@port, "authenticated") config = %{broadcast: %{self: true}, private: false} realtime_topic = "realtime:#{random_string()}" @@ -1457,7 +1351,7 @@ defmodule Realtime.Integration.RtChannelTest do @tag role: "authenticated", policies: [:broken_read_presence, :broken_write_presence] test "handle failing rls policy" do - {socket, _} = get_connection("authenticated") + {socket, _} = get_connection(@port, "authenticated") config = %{broadcast: %{self: true}, private: true} topic = random_string() realtime_topic = "realtime:#{topic}" @@ -1485,72 +1379,6 @@ defmodule Realtime.Integration.RtChannelTest do end end - test "handle empty topic by closing the socket" do - {socket, _} = get_connection("authenticated") - config = %{broadcast: %{self: true}, private: false} - realtime_topic = "realtime:" - - WebsocketClient.join(socket, realtime_topic, %{config: config}) - - assert_receive %Message{ - event: "phx_reply", - payload: %{ - "response" => %{"reason" => "You must provide a topic name"}, - "status" => "error" - } - }, - 500 - - refute_receive %Message{event: "phx_reply"} - refute_receive %Message{event: "presence_state"} - end - - defp token_valid(role, claims \\ %{}), do: generate_token(Map.put(claims, :role, role)) - defp token_no_role, do: generate_token() - - defp generate_token(claims \\ %{}) do - claims = - Map.merge( - %{ - ref: "localhost", - iat: System.system_time(:second), - exp: System.system_time(:second) + 604_800 - }, - claims - ) - - {:ok, generate_jwt_token(@secret, claims)} - end - - defp get_connection( - role \\ "anon", - claims \\ %{}, - params \\ %{vsn: "1.0.0", log_level: :warning} - ) do - params = Enum.reduce(params, "", fn {k, v}, acc -> "#{acc}&#{k}=#{v}" end) - uri = "#{@uri}?#{params}" - - with {:ok, token} <- token_valid(role, claims), - {:ok, socket} <- - WebsocketClient.connect(self(), uri, @serializer, [{"x-api-key", token}]) do - {socket, token} - end - end - - def rls_context(%{tenant: tenant} = context) do - {:ok, db_conn} = Database.connect(tenant, "realtime_test", :stop) - - clean_table(db_conn, "realtime", "messages") - topic = Map.get(context, :topic, random_string()) - message = message_fixture(tenant, %{topic: topic}) - - if policies = context[:policies] do - create_rls_policies(db_conn, policies, message) - end - - Map.put(context, :topic, message.topic) - end - def setup_trigger(%{tenant: tenant, topic: topic} = context) do Realtime.Tenants.Connect.shutdown(@external_id) Process.sleep(500) @@ -1601,13 +1429,4 @@ defmodule Realtime.Integration.RtChannelTest do |> Map.put(:db_conn, db_conn) |> Map.put(:table_name, random_name) end - - defp change_tenant_configuration(limit, value) do - @external_id - |> Realtime.Tenants.get_tenant_by_external_id() - |> Realtime.Api.Tenant.changeset(%{limit => value}) - |> Realtime.Repo.update!() - - Realtime.Tenants.Cache.invalidate_tenant_cache(@external_id) - end end diff --git a/test/integration/user_socket_test.exs b/test/integration/user_socket_test.exs new file mode 100644 index 000000000..3bc183c51 --- /dev/null +++ b/test/integration/user_socket_test.exs @@ -0,0 +1,204 @@ +Code.require_file("../support/websocket_client.exs", __DIR__) +Code.require_file("./Integration.ex", __DIR__) + +defmodule RealtimeWeb.UserSocketTest do + use RealtimeWeb.ConnCase, async: false + import ExUnit.CaptureLog + import Integration + + alias Phoenix.Socket.Message + alias Phoenix.Socket.V1 + alias Realtime.Api.Tenant + alias Realtime.Integration.WebsocketClient + alias Realtime.Repo + alias Realtime.Tenants + alias Realtime.Tenants.Cache + alias Realtime.Tenants.Migrations + alias RealtimeWeb.UserSocketTest.Endpoint + + @moduletag :capture_log + @port 4003 + @serializer V1.JSONSerializer + @external_id "dev_tenant" + @uri "ws://#{@external_id}.localhost:#{@port}/socket/websocket" + + Application.put_env(:phoenix, Endpoint, + https: false, + http: [port: @port], + debug_errors: false, + server: true, + pubsub_server: __MODULE__, + secret_key_base: String.duplicate("a", 64) + ) + + Application.delete_env(:joken, :current_time_adapter) + + defmodule Endpoint do + use Phoenix.Endpoint, otp_app: :phoenix + + @session_config store: :cookie, + key: "_hello_key", + signing_salt: "change_me" + + socket("/socket", RealtimeWeb.UserSocket, + websocket: [ + connect_info: [:peer_data, :uri, :x_headers], + fullsweep_after: 20, + max_frame_size: 8_000_000 + ], + longpoll: true + ) + + plug(Plug.Session, @session_config) + plug(:fetch_session) + plug(Plug.CSRFProtection) + plug(:put_session) + + defp put_session(conn, _) do + conn + |> put_session(:from_session, "123") + |> send_resp(200, Plug.CSRFProtection.get_csrf_token()) + end + end + + defmodule Token do + use Joken.Config + end + + setup do + Cache.invalidate_tenant_cache(@external_id) + Process.sleep(500) + [tenant] = Tenant |> Repo.all() |> Repo.preload(:extensions) + :ok = Migrations.run_migrations(tenant) + %{tenant: tenant} + end + + setup_all do + capture_log(fn -> start_supervised!(Endpoint) end) + start_supervised!({Phoenix.PubSub, name: __MODULE__}) + :ok + end + + describe "token handling on connect" do + setup [:rls_context] + + @tag policies: [ + :authenticated_read_broadcast_and_presence, + :authenticated_write_broadcast_and_presence + ] + test "invalid JWT with expired token" do + assert capture_log(fn -> + get_connection(@port, "authenticated", %{:exp => System.system_time(:second) - 1000}) + end) =~ "InvalidJWTToken: Token as expired 1000 seconds ago" + end + + test "token required the role key" do + {:ok, token} = token_no_role() + + assert {:error, %{status_code: 403}} = + WebsocketClient.connect(self(), @uri, @serializer, [{"x-api-key", token}]) + end + end + + describe "disconnecting users" do + setup do + {socket, _} = get_connection(@port, "authenticated") + on_exit(fn -> WebsocketClient.close(socket) end) + %{socket: socket} + end + + test "on jwt_jwks the socket closes and sends a system message", %{socket: socket} do + config = %{broadcast: %{self: true}, private: false} + + topics = + for _ <- 1..3, reduce: [] do + acc -> + topic = "realtime:#{random_string()}" + WebsocketClient.join(socket, topic, %{config: config}) + assert_receive %Message{topic: ^topic, event: "phx_reply"}, 500 + assert_receive %Message{topic: ^topic, event: "presence_state"}, 500 + + [topic | acc] + end + + tenant = Tenants.get_tenant_by_external_id(@external_id) + Realtime.Api.update_tenant(tenant, %{jwt_jwks: %{keys: ["potato"]}}) + IO.inspect("Sleeping for 1 second") + Process.sleep(1000) + + for topic <- topics do + WebsocketClient.send_event(socket, topic, "broadcast", %{event: random_string()}) + refute_receive %Message{topic: ^topic} + end + end + + test "on jwt_secret the socket closes and sends a system message", %{socket: socket} do + config = %{broadcast: %{self: true}, private: false} + + topics = + for _ <- 1..3, reduce: [] do + acc -> + topic = "realtime:#{random_string()}" + WebsocketClient.join(socket, topic, %{config: config}) + assert_receive %Message{topic: ^topic, event: "phx_reply"}, 500 + assert_receive %Message{topic: ^topic, event: "presence_state"}, 500 + + [topic | acc] + end + + tenant = Tenants.get_tenant_by_external_id(@external_id) + Realtime.Api.update_tenant(tenant, %{jwt_secret: random_string()}) + IO.inspect("Sleeping for 1 second") + Process.sleep(1000) + + for topic <- topics do + WebsocketClient.send_event(socket, topic, "broadcast", %{event: random_string()}) + refute_receive %Message{topic: ^topic} + end + end + + test "on other param changes the socket won't close and no message is sent", %{socket: socket} do + config = %{broadcast: %{self: true}, private: false} + + topics = + for _ <- 1..3, reduce: [] do + acc -> + topic = "realtime:#{random_string()}" + WebsocketClient.join(socket, topic, %{config: config}) + assert_receive %Message{topic: ^topic, event: "phx_reply"}, 500 + assert_receive %Message{topic: ^topic, event: "presence_state"}, 500 + + [topic | acc] + end + + tenant = Tenants.get_tenant_by_external_id(@external_id) + Realtime.Api.update_tenant(tenant, %{max_concurrent_users: 100}) + Process.sleep(1000) + + for topic <- topics do + WebsocketClient.send_event(socket, topic, "broadcast", %{event: random_string()}) + assert_receive %Message{topic: ^topic} + end + end + end + + test "handle empty topic by closing the socket" do + {socket, _} = get_connection(@port, "authenticated") + config = %{broadcast: %{self: true}, private: false} + realtime_topic = "realtime:" + + WebsocketClient.join(socket, realtime_topic, %{config: config}) + + assert_receive %Message{ + event: "phx_reply", + payload: %{ + "response" => %{"reason" => "You must provide a topic name"}, + "status" => "error" + } + }, + 500 + + refute_receive %Message{event: "phx_reply"} + refute_receive %Message{event: "presence_state"} + end +end diff --git a/test/realtime/api_test.exs b/test/realtime/api_test.exs index 0ce268acd..9b4577edd 100644 --- a/test/realtime/api_test.exs +++ b/test/realtime/api_test.exs @@ -3,6 +3,7 @@ defmodule Realtime.ApiTest do import Mock + alias Phoenix.Socket.Broadcast alias Realtime.Api alias Realtime.Api.Extensions alias Realtime.Api.Tenant @@ -60,8 +61,8 @@ defmodule Realtime.ApiTest do tenants = tenants ++ dev_tenant Enum.each(tenants, fn tenant -> - :ok = - Phoenix.PubSub.subscribe(Realtime.PubSub, "realtime:operations:" <> tenant.external_id) + RealtimeWeb.Endpoint.subscribe("user_socket:" <> tenant.external_id) + Phoenix.PubSub.subscribe(Realtime.PubSub, "realtime:operations:" <> tenant.external_id) end) %{tenants: tenants} @@ -128,14 +129,14 @@ defmodule Realtime.ApiTest do tenants: [tenant | _] } do assert {:ok, %Tenant{}} = Api.update_tenant(tenant, %{jwt_jwks: %{keys: ["test"]}}) - assert_receive :disconnect + assert_receive %Broadcast{topic: "user_socket:external_id1", event: "disconnect", payload: %{}} end test "update_tenant/2 with valid data and jwt_secret change will send disconnect event", %{ tenants: [tenant | _] } do assert {:ok, %Tenant{}} = Api.update_tenant(tenant, %{jwt_secret: "potato"}) - assert_receive :disconnect + assert_receive %Broadcast{topic: "user_socket:external_id1", event: "disconnect", payload: %{}} end test "update_tenant/2 with valid data but not updating jwt_secret or jwt_jwks won't send event", diff --git a/test/support/websocket_client.exs b/test/support/websocket_client.exs index 743a58b08..c488a4477 100644 --- a/test/support/websocket_client.exs +++ b/test/support/websocket_client.exs @@ -93,24 +93,10 @@ defmodule Realtime.Integration.WebsocketClient do @doc false def handle_call({:connect, url, headers}, from, state) do uri = URI.parse(url) - - http_scheme = - case uri.scheme do - "ws" -> :http - "wss" -> :https - end - - ws_scheme = - case uri.scheme do - "ws" -> :ws - "wss" -> :wss - end - - path = - case uri.query do - nil -> uri.path - query -> uri.path <> "?" <> query - end + http_scheme = :http + ws_scheme = :ws + query = if uri.query, do: "?" <> uri.query, else: "" + path = uri.path <> "?" <> query with {:ok, conn} <- Mint.HTTP.connect(http_scheme, uri.host, uri.port), {:ok, conn, ref} <- Mint.WebSocket.upgrade(ws_scheme, conn, path, headers) do