diff --git a/deps/rabbit/src/rabbit_confirms.erl b/deps/rabbit/src/rabbit_confirms.erl index 2ea00bc9cb39..41a858f932f5 100644 --- a/deps/rabbit/src/rabbit_confirms.erl +++ b/deps/rabbit/src/rabbit_confirms.erl @@ -30,7 +30,8 @@ -opaque state() :: #?MODULE{}. -export_type([ - state/0 + state/0, + mx/0 ]). -spec init() -> state(). diff --git a/deps/rabbitmq_cli/test/plugins/disable_plugins_command_test.exs b/deps/rabbitmq_cli/test/plugins/disable_plugins_command_test.exs index 26a8c787fa2e..8876d660db7e 100644 --- a/deps/rabbitmq_cli/test/plugins/disable_plugins_command_test.exs +++ b/deps/rabbitmq_cli/test/plugins/disable_plugins_command_test.exs @@ -258,23 +258,26 @@ defmodule DisablePluginsCommandTest do test "disabling a dependency disables all plugins that depend on it", context do assert {:stream, test_stream} = @command.run(["amqp_client"], context[:opts]) result = Enum.to_list(test_stream) - expected_list = [:rabbitmq_exchange_federation, :rabbitmq_federation, :rabbitmq_federation_common, :rabbitmq_queue_federation, :rabbitmq_stomp] + expected_disabled = [:rabbitmq_exchange_federation, :rabbitmq_federation, + :rabbitmq_federation_common, :rabbitmq_queue_federation] expected = [ - [], + [:rabbitmq_stomp], %{ mode: :online, started: [], - stopped: expected_list, - disabled: expected_list, - set: [] + stopped: expected_disabled, + disabled: expected_disabled, + set: [:rabbitmq_stomp] } ] assert normalize_stream_result(expected) == normalize_stream_result(result) - assert {:ok, [[]]} == :file.consult(context[:opts][:enabled_plugins_file]) + assert {:ok, [[:rabbitmq_stomp]]} == :file.consult(context[:opts][:enabled_plugins_file]) + # Before native STOMP, this would be empty because STOMP depended on + # amqp_client. Native STOMP does not depend on amqp_client. result = :rabbit_misc.rpc_call(context[:opts][:node], :rabbit_plugins, :active, []) - assert Enum.empty?(result) + assert Enum.sort(result) == [:rabbitmq_stomp] end test "formats enabled plugins mismatch errors", context do diff --git a/deps/rabbitmq_cli/test/plugins/enable_plugins_command_test.exs b/deps/rabbitmq_cli/test/plugins/enable_plugins_command_test.exs index 6ee5b2c9571e..4b02b47400c7 100644 --- a/deps/rabbitmq_cli/test/plugins/enable_plugins_command_test.exs +++ b/deps/rabbitmq_cli/test/plugins/enable_plugins_command_test.exs @@ -205,7 +205,8 @@ defmodule EnablePluginsCommandTest do Enum.to_list(test_stream0) check_plugins_enabled([:rabbitmq_stomp], context) - assert_equal_sets([:amqp_client, :rabbitmq_stomp], currently_active_plugins(context)) + # Native STOMP does not depend on amqp_client + assert_equal_sets([:rabbitmq_stomp], currently_active_plugins(context)) {:stream, test_stream1} = @command.run(["rabbitmq_federation"], context[:opts]) diff --git a/deps/rabbitmq_cli/test/plugins/set_plugins_command_test.exs b/deps/rabbitmq_cli/test/plugins/set_plugins_command_test.exs index e78131d41e28..7ad0515d236a 100644 --- a/deps/rabbitmq_cli/test/plugins/set_plugins_command_test.exs +++ b/deps/rabbitmq_cli/test/plugins/set_plugins_command_test.exs @@ -127,7 +127,8 @@ defmodule SetPluginsCommandTest do assert {:ok, [[:rabbitmq_stomp]]} = :file.consult(context[:opts][:enabled_plugins_file]) - assert [:amqp_client, :rabbitmq_stomp] = + # Native STOMP does not depend on amqp_client + assert [:rabbitmq_stomp] = Enum.sort(:rabbit_misc.rpc_call(context[:opts][:node], :rabbit_plugins, :active, [])) assert {:stream, test_stream1} = @command.run(["rabbitmq_federation"], context[:opts]) diff --git a/deps/rabbitmq_mqtt/test/protocol_interop_SUITE.erl b/deps/rabbitmq_mqtt/test/protocol_interop_SUITE.erl index b36b902f793d..2fce216ddeec 100644 --- a/deps/rabbitmq_mqtt/test/protocol_interop_SUITE.erl +++ b/deps/rabbitmq_mqtt/test/protocol_interop_SUITE.erl @@ -482,12 +482,12 @@ amqp_mqtt(Qos, Config) -> mqtt_stomp_mqtt(Config) -> {ok, StompC0} = stomp_connect(Config), - ok = stomp_send(StompC0, "SUBSCRIBE", [{"destination", "/topic/t.1"}, - {"receipt", "my-receipt"}, - {"id", "subscription-888"}, - {"durable", "true"}]), - {#stomp_frame{command = "RECEIPT", - headers = [{"receipt-id","my-receipt"}]}, StompC1} = stomp_recv(StompC0), + ok = stomp_send(StompC0, 'SUBSCRIBE', [{<<"destination">>, <<"/topic/t.1">>}, + {<<"receipt">>, <<"my-receipt">>}, + {<<"id">>, <<"subscription-888">>}, + {<<"durable">>, <<"true">>}]), + {#stomp_frame{command = 'RECEIPT', + headers = #{<<"receipt-id">> := <<"my-receipt">>}}, StompC1} = stomp_recv(StompC0), %% MQTT 5.0 to STOMP 1.2 C = connect(<<"my-mqtt-client">>, Config), @@ -513,40 +513,39 @@ mqtt_stomp_mqtt(Config) -> 'User-Property' => UserProperty}, RequestPayload, [{qos, 1}]), - {#stomp_frame{command = "MESSAGE", - headers = Headers0, - body_iolist = Body} = Msg1, StompC2} = stomp_recv(StompC1), + {#stomp_frame{command = 'MESSAGE', + headers = Headers, + body_iolist_rev = BodyRev} = Msg1, StompC2} = stomp_recv(StompC1), + Body = lists:reverse(BodyRev), ?assertEqual(RequestPayload, iolist_to_binary(Body)), - Headers1 = maps:from_list(Headers0), - Headers = maps:map(fun(_K, V) -> unicode:characters_to_binary(V) end, Headers1), ct:pal("Received STOMP 1.2 message:~n~p~n" "with headers map:~n~p", [Msg1, Headers]), ?assertMatch( - #{"content-type" := ContentType, - "correlation-id" := Correlation, - "destination" := <<"/topic/t.1">>, + #{<<"content-type">> := ContentType, + <<"correlation-id">> := Correlation, + <<"destination">> := <<"/topic/t.1">>, %% With Native STOMP, this should be translated to %% reply-to: /topic/response.topic - "x-reply-to-topic" := <<"response.topic">>, - "subscription" := <<"subscription-888">>, - "persistent" := <<"true">>, + <<"x-reply-to-topic">> := <<"response.topic">>, + <<"subscription">> := <<"subscription-888">>, + <<"persistent">> := <<"true">>, %% The STOMP spec mandates headers to be encoded as UTF-8, but unfortunately the RabbitMQ %% STOMP implementation (as of 3.13) does not adhere and therefore does not provide UTF-8 support. - % "rabbit🐇" := <<"carrot🥕"/utf8>>, - % "x-rabbit🐇" := <<"carrot🥕"/utf8>>, - "key" := <<"val1">>, - "x-key" := <<"val1">> + % <<"rabbit🐇"/utf8>> := <<"carrot🥕"/utf8>>, + % <<"x-rabbit🐇"/utf8>> := <<"carrot🥕"/utf8>>, + <<"key">> := <<"val1">>, + <<"x-key">> := <<"val1">> }, Headers), %% STOMP 1.2 to MQTT 5.0 - ok = stomp_send(StompC2, "SEND", - [{"destination", "/topic/response.topic"}, - {"persistent", "true"}, - {"content-type", "application/json"}, - {"correlation-id", binary_to_list(Correlation)}, - {"x-key", "val4"}], - ["{\"my\" : \"response\"}"]), + ok = stomp_send(StompC2, 'SEND', + [{<<"destination">>, <<"/topic/response.topic">>}, + {<<"persistent">>, <<"true">>}, + {<<"content-type">>, <<"application/json">>}, + {<<"correlation-id">>, Correlation}, + {<<"x-key">>, <<"val4">>}], + [<<"{\"my\" : \"response\"}">>]), ok = stomp_disconnect(StompC2), receive {publish, MqttMsg} -> @@ -713,12 +712,12 @@ stomp_connect(Config) -> Port = rabbit_ct_broker_helpers:get_node_config(Config, 0, tcp_port_stomp), {ok, Sock} = gen_tcp:connect(localhost, Port, [{active, false}, binary]), Client0 = {Sock, []}, - stomp_send(Client0, "CONNECT", [{"accept-version", "1.2"}]), - {#stomp_frame{command = "CONNECTED"}, Client1} = stomp_recv(Client0), + stomp_send(Client0, 'CONNECT', [{<<"accept-version">>, <<"1.2">>}]), + {#stomp_frame{command = 'CONNECTED'}, Client1} = stomp_recv(Client0), {ok, Client1}. stomp_disconnect(Client = {Sock, _}) -> - stomp_send(Client, "DISCONNECT"), + stomp_send(Client, 'DISCONNECT'), gen_tcp:close(Sock). stomp_send(Client, Command) -> @@ -729,9 +728,9 @@ stomp_send(Client, Command, Headers) -> stomp_send({Sock, _}, Command, Headers, Body) -> Frame = rabbit_stomp_frame:serialize( - #stomp_frame{command = list_to_binary(Command), - headers = Headers, - body_iolist = Body}), + #stomp_frame{command = Command, + headers = maps:from_list(Headers), + body_iolist_rev = Body}), gen_tcp:send(Sock, Frame). stomp_recv({_Sock, []} = Client) -> diff --git a/deps/rabbitmq_stomp/Makefile b/deps/rabbitmq_stomp/Makefile index e5bb2d2959ca..31573a15e5fd 100644 --- a/deps/rabbitmq_stomp/Makefile +++ b/deps/rabbitmq_stomp/Makefile @@ -9,7 +9,7 @@ define PROJECT_ENV {passcode, <<"guest">>}]}, {default_vhost, <<"/">>}, {default_topic_exchange, <<"amq.topic">>}, - {default_nack_requeue, true}, + {default_nack_requeue, true}, {ssl_cert_login, false}, {implicit_connect, false}, {tcp_listeners, [61613]}, @@ -30,7 +30,7 @@ define PROJECT_APP_EXTRA_KEYS {broker_version_requirements, []} endef -DEPS = ranch rabbit_common rabbit amqp_client +DEPS = ranch rabbit_common rabbit TEST_DEPS = rabbitmq_ct_helpers rabbitmq_ct_client_helpers rabbitmq_management proper PLT_APPS += rabbitmq_cli elixir ssl @@ -40,3 +40,9 @@ DEP_PLUGINS = rabbit_common/mk/rabbitmq-plugin.mk include ../../rabbitmq-components.mk include ../../erlang.mk + +# Regenerate per-function CT test cases from Python test files. +# The generated .hrl is committed; this rule updates it when Python sources change. +# Runs only when python3 is available. +test/python_SUITE_generated.hrl:: $(wildcard test/python_SUITE_data/src/*.py) test/generate_python_tests.py + $(if $(shell which python3 2>/dev/null),python3 test/generate_python_tests.py test/python_SUITE_data/src $@,@echo "python3 not found, using committed $@") diff --git a/deps/rabbitmq_stomp/include/rabbit_stomp.hrl b/deps/rabbitmq_stomp/include/rabbit_stomp.hrl index 3dd542dd8a3c..fe6c796c79ae 100644 --- a/deps/rabbitmq_stomp/include/rabbit_stomp.hrl +++ b/deps/rabbitmq_stomp/include/rabbit_stomp.hrl @@ -9,16 +9,27 @@ default_passcode, force_default_creds = false, implicit_connect, - ssl_cert_login}). + ssl_cert_login, + max_header_length, + max_headers, + max_body_length}). + -define(SUPPORTED_VERSIONS, ["1.0", "1.1", "1.2"]). +-define(STOMP_PROTO_V1_0, 'STOMP 1.0'). +-define(STOMP_PROTO_V1_1, 'STOMP 1.1'). +-define(STOMP_PROTO_V1_2, 'STOMP 1.2'). + + + -define(INFO_ITEMS, [conn_name, + name, + user, connection, connection_state, session_id, - channel, version, implicit_connect, auth_login, @@ -29,6 +40,7 @@ peer_host, peer_port, protocol, + connected_at, channels, channel_max, frame_max, @@ -43,3 +55,18 @@ -define(DEFAULT_MAX_FRAME_SIZE, 4 * 1024 * 1024). -define(DEFAULT_MAX_FRAME_SIZE_UNAUTHENTICATED, 65536). + +-define(SIMPLE_METRICS, + [pid, + recv_oct, + send_oct, + reductions]). +-define(OTHER_METRICS, + [recv_cnt, + send_cnt, + send_pend, + garbage_collection, + state, + timeout]). + +-type send_fun() :: fun ((iodata()) -> ok). diff --git a/deps/rabbitmq_stomp/include/rabbit_stomp_frame.hrl b/deps/rabbitmq_stomp/include/rabbit_stomp_frame.hrl index 925cc79d2281..33246f823f33 100644 --- a/deps/rabbitmq_stomp/include/rabbit_stomp_frame.hrl +++ b/deps/rabbitmq_stomp/include/rabbit_stomp_frame.hrl @@ -5,4 +5,9 @@ %% Copyright (c) 2007-2026 Broadcom. All Rights Reserved. The term “Broadcom” refers to Broadcom Inc. and/or its subsidiaries. All rights reserved. %% --record(stomp_frame, {command, headers, body_iolist}). +-record(stomp_frame, {command, headers, body_iolist_rev}). + +-record(stomp_parser_config, {max_header_length = 1024*100, + max_headers = 100, + max_body_length = 1024*1024*100}). +-define(DEFAULT_STOMP_PARSER_CONFIG, #stomp_parser_config{}). diff --git a/deps/rabbitmq_stomp/include/rabbit_stomp_headers.hrl b/deps/rabbitmq_stomp/include/rabbit_stomp_headers.hrl index d9a72c0e49f9..82dca4d519e5 100644 --- a/deps/rabbitmq_stomp/include/rabbit_stomp_headers.hrl +++ b/deps/rabbitmq_stomp/include/rabbit_stomp_headers.hrl @@ -5,57 +5,59 @@ %% Copyright (c) 2007-2026 Broadcom. All Rights Reserved. The term “Broadcom” refers to Broadcom Inc. and/or its subsidiaries. All rights reserved. %% --define(HEADER_ACCEPT_VERSION, "accept-version"). --define(HEADER_ACK, "ack"). --define(HEADER_AMQP_MESSAGE_ID, "amqp-message-id"). --define(HEADER_APP_ID, "app-id"). --define(HEADER_AUTO_DELETE, "auto-delete"). --define(HEADER_CONTENT_ENCODING, "content-encoding"). --define(HEADER_CONTENT_LENGTH, "content-length"). --define(HEADER_CONTENT_TYPE, "content-type"). --define(HEADER_CORRELATION_ID, "correlation-id"). --define(HEADER_DESTINATION, "destination"). --define(HEADER_DURABLE, "durable"). --define(HEADER_EXPIRATION, "expiration"). --define(HEADER_EXCLUSIVE, "exclusive"). --define(HEADER_HEART_BEAT, "heart-beat"). --define(HEADER_HOST, "host"). --define(HEADER_ID, "id"). --define(HEADER_LOGIN, "login"). --define(HEADER_MESSAGE_ID, "message-id"). --define(HEADER_PASSCODE, "passcode"). --define(HEADER_PERSISTENT, "persistent"). --define(HEADER_PREFETCH_COUNT, "prefetch-count"). --define(HEADER_X_STREAM_OFFSET, "x-stream-offset"). --define(HEADER_X_STREAM_FILTER, "x-stream-filter"). --define(HEADER_X_STREAM_MATCH_UNFILTERED, "x-stream-match-unfiltered"). --define(HEADER_PRIORITY, "priority"). --define(HEADER_X_PRIORITY, "x-priority"). --define(HEADER_RECEIPT, "receipt"). --define(HEADER_REDELIVERED, "redelivered"). --define(HEADER_REPLY_TO, "reply-to"). --define(HEADER_SERVER, "server"). --define(HEADER_SESSION, "session"). --define(HEADER_SUBSCRIPTION, "subscription"). --define(HEADER_TIMESTAMP, "timestamp"). --define(HEADER_TRANSACTION, "transaction"). --define(HEADER_TYPE, "type"). --define(HEADER_USER_ID, "user-id"). --define(HEADER_VERSION, "version"). --define(HEADER_X_DEAD_LETTER_EXCHANGE, "x-dead-letter-exchange"). --define(HEADER_X_DEAD_LETTER_ROUTING_KEY, "x-dead-letter-routing-key"). --define(HEADER_X_EXPIRES, "x-expires"). --define(HEADER_X_MAX_LENGTH, "x-max-length"). --define(HEADER_X_MAX_AGE, "x-max-age"). --define(HEADER_X_MAX_LENGTH_BYTES, "x-max-length-bytes"). --define(HEADER_X_STREAM_MAX_SEGMENT_SIZE_BYTES, "x-stream-max-segment-size-bytes"). --define(HEADER_X_MAX_PRIORITY, "x-max-priority"). --define(HEADER_X_MESSAGE_TTL, "x-message-ttl"). --define(HEADER_X_QUEUE_NAME, "x-queue-name"). --define(HEADER_X_QUEUE_TYPE, "x-queue-type"). --define(HEADER_X_STREAM_FILTER_SIZE_BYTES, "x-stream-filter-size-bytes"). +-include("rabbit_stomp_routing_prefixes.hrl"). --define(MESSAGE_ID_SEPARATOR, "@@"). +-define(HEADER_ACCEPT_VERSION, <<"accept-version">>). +-define(HEADER_ACK, <<"ack">>). +-define(HEADER_AMQP_MESSAGE_ID, <<"amqp-message-id">>). +-define(HEADER_APP_ID, <<"app-id">>). +-define(HEADER_AUTO_DELETE, <<"auto-delete">>). +-define(HEADER_CONTENT_ENCODING, <<"content-encoding">>). +-define(HEADER_CONTENT_LENGTH, <<"content-length">>). +-define(HEADER_CONTENT_TYPE, <<"content-type">>). +-define(HEADER_CORRELATION_ID, <<"correlation-id">>). +-define(HEADER_DESTINATION, <<"destination">>). +-define(HEADER_DURABLE, <<"durable">>). +-define(HEADER_EXPIRATION, <<"expiration">>). +-define(HEADER_EXCLUSIVE, <<"exclusive">>). +-define(HEADER_HEART_BEAT, <<"heart-beat">>). +-define(HEADER_HOST, <<"host">>). +-define(HEADER_ID, <<"id">>). +-define(HEADER_LOGIN, <<"login">>). +-define(HEADER_MESSAGE_ID, <<"message-id">>). +-define(HEADER_PASSCODE, <<"passcode">>). +-define(HEADER_PERSISTENT, <<"persistent">>). +-define(HEADER_PREFETCH_COUNT, <<"prefetch-count">>). +-define(HEADER_X_STREAM_OFFSET, <<"x-stream-offset">>). +-define(HEADER_X_STREAM_FILTER, <<"x-stream-filter">>). +-define(HEADER_X_STREAM_MATCH_UNFILTERED, <<"x-stream-match-unfiltered">>). +-define(HEADER_PRIORITY, <<"priority">>). +-define(HEADER_X_PRIORITY, <<"x-priority">>). +-define(HEADER_RECEIPT, <<"receipt">>). +-define(HEADER_REDELIVERED, <<"redelivered">>). +-define(HEADER_REPLY_TO, <<"reply-to">>). +-define(HEADER_SERVER, <<"server">>). +-define(HEADER_SESSION, <<"session">>). +-define(HEADER_SUBSCRIPTION, <<"subscription">>). +-define(HEADER_TIMESTAMP, <<"timestamp">>). +-define(HEADER_TRANSACTION, <<"transaction">>). +-define(HEADER_TYPE, <<"type">>). +-define(HEADER_USER_ID, <<"user-id">>). +-define(HEADER_VERSION, <<"version">>). +-define(HEADER_X_DEAD_LETTER_EXCHANGE, <<"x-dead-letter-exchange">>). +-define(HEADER_X_DEAD_LETTER_ROUTING_KEY, <<"x-dead-letter-routing-key">>). +-define(HEADER_X_EXPIRES, <<"x-expires">>). +-define(HEADER_X_MAX_LENGTH, <<"x-max-length">>). +-define(HEADER_X_MAX_AGE, <<"x-max-age">>). +-define(HEADER_X_MAX_LENGTH_BYTES, <<"x-max-length-bytes">>). +-define(HEADER_X_STREAM_MAX_SEGMENT_SIZE_BYTES, <<"x-stream-max-segment-size-bytes">>). +-define(HEADER_X_MAX_PRIORITY, <<"x-max-priority">>). +-define(HEADER_X_MESSAGE_TTL, <<"x-message-ttl">>). +-define(HEADER_X_QUEUE_NAME, <<"x-queue-name">>). +-define(HEADER_X_QUEUE_TYPE, <<"x-queue-type">>). +-define(HEADER_X_STREAM_FILTER_SIZE_BYTES, <<"x-stream-filter-size-bytes">>). + +-define(MESSAGE_ID_SEPARATOR, <<"@@">>). -define(HEADERS_NOT_ON_SEND, [?HEADER_MESSAGE_ID]). @@ -81,3 +83,11 @@ ?HEADER_EXCLUSIVE, ?HEADER_PERSISTENT ]). + + +%%------------------------------------------------- + +-define(DEST_PREFIXES, [?EXCHANGE_PREFIX, ?TOPIC_PREFIX, ?QUEUE_PREFIX, + ?AMQQUEUE_PREFIX, ?REPLY_QUEUE_PREFIX]). + +-define(ALL_DEST_PREFIXES, [?TEMP_QUEUE_PREFIX | ?DEST_PREFIXES]). diff --git a/deps/rabbitmq_stomp/include/rabbit_stomp_routing_prefixes.hrl b/deps/rabbitmq_stomp/include/rabbit_stomp_routing_prefixes.hrl index d0a34e467266..0aa1f8979dd7 100644 --- a/deps/rabbitmq_stomp/include/rabbit_stomp_routing_prefixes.hrl +++ b/deps/rabbitmq_stomp/include/rabbit_stomp_routing_prefixes.hrl @@ -5,11 +5,11 @@ %% Copyright (c) 2007-2026 Broadcom. All Rights Reserved. The term "Broadcom" refers to Broadcom Inc. and/or its subsidiaries. All rights reserved. %% --define(QUEUE_PREFIX, "/queue"). --define(TOPIC_PREFIX, "/topic"). --define(EXCHANGE_PREFIX, "/exchange"). --define(AMQQUEUE_PREFIX, "/amq/queue"). --define(TEMP_QUEUE_PREFIX, "/temp-queue"). -%% reply queues names can have slashes in the content so no further +-define(QUEUE_PREFIX, <<"/queue">>). +-define(TOPIC_PREFIX, <<"/topic">>). +-define(EXCHANGE_PREFIX, <<"/exchange">>). +-define(AMQQUEUE_PREFIX, <<"/amq/queue">>). +-define(TEMP_QUEUE_PREFIX, <<"/temp-queue">>). +%% Reply queue names can have slashes in the content so no further %% parsing happens. --define(REPLY_QUEUE_PREFIX, "/reply-queue/"). +-define(REPLY_QUEUE_PREFIX, <<"/reply-queue/">>). diff --git a/deps/rabbitmq_stomp/src/rabbit_stomp.erl b/deps/rabbitmq_stomp/src/rabbit_stomp.erl index 6ce436a14d17..e7c2517c9107 100644 --- a/deps/rabbitmq_stomp/src/rabbit_stomp.erl +++ b/deps/rabbitmq_stomp/src/rabbit_stomp.erl @@ -21,14 +21,18 @@ -define(DEFAULT_CONFIGURATION, #stomp_configuration{ - default_login = undefined, - default_passcode = undefined, - implicit_connect = false, - ssl_cert_login = false}). + default_login = undefined, + default_passcode = undefined, + implicit_connect = false, + ssl_cert_login = false, + max_header_length = 1024*100, + max_headers = 100, + max_body_length = 1024*1024*100}). start(normal, []) -> Config = parse_configuration(), Listeners = parse_listener_configuration(), + init_global_counters(), Result = rabbit_stomp_sup:start_link(Listeners, Config), EMPid = case rabbit_event:start_link() of {ok, Pid} -> Pid; @@ -75,7 +79,11 @@ parse_configuration() -> {ok, SSLLogin} = application:get_env(ssl_cert_login), {ok, ImplicitConnect} = application:get_env(implicit_connect), Conf = Conf0#stomp_configuration{ssl_cert_login = SSLLogin, - implicit_connect = ImplicitConnect}, + implicit_connect = ImplicitConnect, + max_headers = application:get_env(rabbitmq_stomp, max_headers, ?DEFAULT_CONFIGURATION#stomp_configuration.max_headers), + max_header_length = application:get_env(rabbitmq_stomp, max_header_length, ?DEFAULT_CONFIGURATION#stomp_configuration.max_header_length), + max_body_length = application:get_env(rabbitmq_stomp, max_body_length, ?DEFAULT_CONFIGURATION#stomp_configuration.max_body_length)}, + report_configuration(Conf), Conf. @@ -114,6 +122,20 @@ report_configuration(#stomp_configuration{ ok. +init_global_counters() -> + lists:foreach(fun init_global_counters/1, + [?STOMP_PROTO_V1_0, + ?STOMP_PROTO_V1_1, + ?STOMP_PROTO_V1_2]). + +init_global_counters(ProtoVer) -> + Proto = #{protocol => ProtoVer}, + rabbit_global_counters:init(Proto), + rabbit_global_counters:init(Proto#{queue_type => rabbit_classic_queue}), + rabbit_global_counters:init(Proto#{queue_type => rabbit_quorum_queue}), + rabbit_global_counters:init(Proto#{queue_type => rabbit_stream_queue}), + rabbit_msg_size_metrics:init(ProtoVer). + list() -> [Client || {_, ListSup, _, _} <- supervisor:which_children(rabbit_stomp_sup), diff --git a/deps/rabbitmq_stomp/src/rabbit_stomp_frame.erl b/deps/rabbitmq_stomp/src/rabbit_stomp_frame.erl index 806a1d5d357f..007197c32b07 100644 --- a/deps/rabbitmq_stomp/src/rabbit_stomp_frame.erl +++ b/deps/rabbitmq_stomp/src/rabbit_stomp_frame.erl @@ -2,7 +2,7 @@ %% License, v. 2.0. If a copy of the MPL was not distributed with this %% file, You can obtain one at https://mozilla.org/MPL/2.0/. %% -%% Copyright (c) 2007-2026 Broadcom. All Rights Reserved. The term “Broadcom” refers to Broadcom Inc. and/or its subsidiaries. All rights reserved. +%% Copyright (c) 2007-2026 Broadcom. All Rights Reserved. The term "Broadcom" refers to Broadcom Inc. and/or its subsidiaries. All rights reserved. %% -module(rabbit_stomp_frame). @@ -10,7 +10,7 @@ -include("rabbit_stomp_frame.hrl"). -include("rabbit_stomp_headers.hrl"). --export([parse/2, initial_state/0]). +-export([parse/2, initial_state/0, initial_state/1]). -export([header/2, header/3, boolean_header/2, boolean_header/3, integer_header/2, integer_header/3, @@ -18,16 +18,19 @@ -export([stream_offset_header/1, stream_filter_header/1]). -export([serialize/1, serialize/2]). -initial_state() -> none. +%% Only used by tests. Production code uses `initial_state/1` with an explicitly provided config. +initial_state() -> {none, ?DEFAULT_STOMP_PARSER_CONFIG}. +initial_state(Config) -> {none, Config}. %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% -%% STOMP 1.1 frames basic syntax +%% STOMP 1.0/1.1/1.2 frame syntax +%% %% Rabbit modifications: -%% o CR LF is equivalent to LF in all element terminators (eol). -%% o Escape codes for header names and values include \r for CR -%% and CR is not allowed. -%% o Header names and values are not limited to UTF-8 strings. -%% o Header values may contain unescaped colons +%% - CR LF is equivalent to LF in all element terminators (eol) +%% - Escape codes for header names and values include \r for CR +%% and CR is not allowed +%% - Header names and values are not limited to UTF-8 strings +%% - Header values may contain unescaped colons %% %% frame_seq ::= *(noise frame) %% noise ::= *(NUL | eol) @@ -41,197 +44,435 @@ initial_state() -> none. %% hdrvalue ::= *esc_char %% esc_char ::= HDROCT | BACKSLASH ESCCODE %% -%% Terms in CAPS all represent sets (alternatives) of single octets. -%% They are defined here using a small extension of BNF, minus (-): -%% -%% term1 - term2 denotes any of the possibilities in term1 -%% excluding those in term2. -%% In this grammar minus is only used for sets of single octets. -%% -%% OCTET ::= '00'x..'FF'x % any octet -%% NUL ::= '00'x % the zero octet -%% LF ::= '\n' % '0a'x newline or linefeed -%% CR ::= '\r' % '0d'x carriage return -%% NOTEOL ::= OCTET - (CR | LF) % any octet except CR or LF -%% BACKSLASH ::= '\\' % '5c'x +%% OCTET ::= '00'x..'FF'x +%% NUL ::= '00'x +%% LF ::= '\n' +%% CR ::= '\r' +%% NOTEOL ::= OCTET - (CR | LF) +%% BACKSLASH ::= '\\' %% ESCCODE ::= 'c' | 'n' | 'r' | BACKSLASH %% COLON ::= ':' %% HDROCT ::= NOTEOL - (COLON | BACKSLASH) -%% % octets allowed in a header %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% -%% explicit frame characters +%% Frame characters -define(NUL, 0). -define(CR, $\r). -define(LF, $\n). -define(BSL, $\\). -define(COLON, $:). -%% header escape codes +%% Header escape codes -define(LF_ESC, $n). -define(BSL_ESC, $\\). -define(COLON_ESC, $c). -define(CR_ESC, $r). --define(MAX_HEADERS, 100). +%% Command lookup: binary -> atom for known STOMP commands. +%% Unknown commands pass through as binaries. +-define(KNOWN_COMMANDS, + #{<<"SEND">> => 'SEND', + <<"SUBSCRIBE">> => 'SUBSCRIBE', + <<"UNSUBSCRIBE">> => 'UNSUBSCRIBE', + <<"STOMP">> => 'STOMP', + <<"CONNECT">> => 'CONNECT', + <<"CONNECTED">> => 'CONNECTED', + <<"DISCONNECT">> => 'DISCONNECT', + <<"BEGIN">> => 'BEGIN', + <<"COMMIT">> => 'COMMIT', + <<"ABORT">> => 'ABORT', + <<"ACK">> => 'ACK', + <<"NACK">> => 'NACK', + <<"MESSAGE">> => 'MESSAGE', + <<"RECEIPT">> => 'RECEIPT', + <<"ERROR">> => 'ERROR'}). + +%% The longest known STOMP command is UNSUBSCRIBE (11 bytes). +%% Allow some headroom for unknown commands but bound memory usage. +-define(MAX_COMMAND_LENGTH, 32). + +%% Parser state. +%% acc is only used for header values with escape sequences. +%% Commands and headers without escapes use sub-binary extraction. +-record(ps, {acc = [] :: [byte()], + acc_len = 0 :: non_neg_integer(), + cmd :: atom() | binary() | undefined, + hdrs = #{} :: #{binary() => binary()}, + hdrname :: binary() | undefined, + config :: #stomp_parser_config{}}). -%% parser state --record(state, {acc, cmd, hdrs, seen = #{}, hdrname}). +%% +%% Public API +%% parse(Content, {resume, Continuation}) -> Continuation(Content); -parse(Content, none ) -> parser(Content, noframe, #state{}). +parse(Content, {none, Config}) -> parse_noise(Content, #ps{config = Config}). + +%% +%% Phase: noise — skip NULs and LFs between frames +%% + +parse_noise(<<>>, S) -> + more(fun(Rest) -> parse_noise(Rest, S) end); +parse_noise(<>, S) -> parse_noise(Rest, S); +parse_noise(<>, S) -> parse_noise(Rest, S); +parse_noise(<>, S) -> parse_noise(Rest, S); +parse_noise(<>, S) -> more(fun(Rest) -> parse_noise(<>, S) end); +parse_noise(<>, _S) -> {error, {unexpected_chars_between_frames, [?CR, Ch]}}; +parse_noise(Bin, S) -> parse_command(Bin, S). + +%% +%% Phase: command — scan for LF, extract as sub-binary +%% + +parse_command(Bin, S) -> + case scan_until_lf(Bin) of + {ok, CmdBin, Rest} -> + case byte_size(CmdBin) > ?MAX_COMMAND_LENGTH of + true -> {error, {command_too_long, ?MAX_COMMAND_LENGTH}}; + false -> + Cmd = maps:get(CmdBin, ?KNOWN_COMMANDS, CmdBin), + parse_headers(Rest, S#ps{cmd = Cmd, hdrs = #{}}) + end; + {more, Len} -> + case Len > ?MAX_COMMAND_LENGTH of + true -> {error, {command_too_long, ?MAX_COMMAND_LENGTH}}; + false -> more(fun(Rest) -> parse_command(<>, S) end) + end; + {error, _} = Err -> + Err + end. + +%% +%% Phase: headers — dispatch to hdrname or body +%% + +parse_headers(<>, S) -> + parse_body(Rest, S); +parse_headers(<>, S) -> + parse_body(Rest, S); +parse_headers(<<>>, S) -> + more(fun(Rest) -> parse_headers(Rest, S) end); +parse_headers(<>, S) -> + more(fun(Rest) -> parse_headers(<>, S) end); +parse_headers(Bin, S) -> + parse_hdr(Bin, S). + +%% +%% Phase: header line — scan for COLON and LF in bulk. +%% Fast path: no escapes or CR in the header line. +%% Slow path: escape sequences present, fall back to byte-by-byte. +%% + +parse_hdr(Bin, S = #ps{config = #stomp_parser_config{max_header_length = MaxHL}}) -> + case scan_header_line(Bin) of + {ok, Name, Value, Rest} -> + case byte_size(Name) > MaxHL orelse byte_size(Value) > MaxHL of + true -> {error, {max_header_length, MaxHL}}; + false -> + case insert_header(Name, Value, S) of + {ok, S1} -> parse_headers(Rest, S1); + {error, _} = E -> E + end + end; + has_escapes -> + parse_hdrname_esc(Bin, S#ps{acc = [], acc_len = 0}); + {no_value, Name} -> + {error, {header_no_value, Name}}; + {more, Len} -> + case Len > MaxHL of + true -> {error, {max_header_length, MaxHL}}; + false -> more(fun(Rest) -> parse_hdr(<>, S) end) + end; + {error, _} = Err -> + Err + end. + +%% Slow path for header names with escapes or CR +parse_hdrname_esc(<<>>, S) -> + more(fun(Rest) -> parse_hdrname_esc(Rest, S) end); +parse_hdrname_esc(<>, S) -> + more(fun(Rest) -> parse_hdrname_esc(<>, S) end); +parse_hdrname_esc(<>, #ps{acc = Acc}) -> + {error, {header_no_value, list_to_binary(lists:reverse(Acc))}}; +parse_hdrname_esc(<>, _) -> + {error, {unexpected_chars_in_header, [?CR, Ch]}}; +parse_hdrname_esc(<>, #ps{acc = Acc}) -> + {error, {header_no_value, list_to_binary(lists:reverse(Acc))}}; +parse_hdrname_esc(<>, S = #ps{acc = Acc}) -> + parse_hdrvalue_esc(Rest, S#ps{acc = [], acc_len = 0, + hdrname = list_to_binary(lists:reverse(Acc))}); +parse_hdrname_esc(<>, S) -> + more(fun(Rest) -> parse_hdrname_esc(<>, S) end); +parse_hdrname_esc(<>, S) -> + unescape(Ch, fun(Ech) -> parse_hdrname_esc(Rest, accum(Ech, S)) end); +parse_hdrname_esc(<>, S = #ps{acc_len = Len, + config = #stomp_parser_config{ + max_header_length = Max}}) -> + case Len >= Max of + true -> {error, {max_header_length, Max}}; + false -> parse_hdrname_esc(Rest, accum(Ch, S)) + end. + +%% Slow path for header values with escapes +parse_hdrvalue_esc(<<>>, S) -> + more(fun(Rest) -> parse_hdrvalue_esc(Rest, S) end); +parse_hdrvalue_esc(<>, S) -> + more(fun(Rest) -> parse_hdrvalue_esc(<>, S) end); +parse_hdrvalue_esc(<>, S) -> + finish_hdr_esc(Rest, S); +parse_hdrvalue_esc(<>, _) -> + {error, {unexpected_chars_in_header, [?CR, Ch]}}; +parse_hdrvalue_esc(<>, S) -> + finish_hdr_esc(Rest, S); +parse_hdrvalue_esc(<>, S) -> + more(fun(Rest) -> parse_hdrvalue_esc(<>, S) end); +parse_hdrvalue_esc(<>, S) -> + unescape(Ch, fun(Ech) -> parse_hdrvalue_esc(Rest, accum(Ech, S)) end); +parse_hdrvalue_esc(<>, S = #ps{acc_len = Len, + config = #stomp_parser_config{ + max_header_length = Max}}) -> + case Len >= Max of + true -> {error, {max_header_length, Max}}; + false -> parse_hdrvalue_esc(Rest, accum(Ch, S)) + end. + +finish_hdr_esc(Rest, #ps{acc = Acc, hdrname = HdrName} = S) -> + case insert_header(HdrName, list_to_binary(lists:reverse(Acc)), S) of + {ok, S1} -> parse_headers(Rest, S1); + {error, _} = E -> E + end. + +%% +%% Binary scanning helpers — bulk operations, no per-byte allocation +%% more(Continuation) -> {more, {resume, Continuation}}. -%% Single-function parser: Term :: noframe | command | headers | hdrname | hdrvalue -%% general more and line-end detection -parser(<<>>, Term , State) -> more(fun(Rest) -> parser(Rest, Term, State) end); -parser(<>, Term , State) -> more(fun(Rest) -> parser(<>, Term, State) end); -parser(<>, Term , State) -> parser(<>, Term, State); -parser(<>, Term , _State) -> {error, {unexpected_chars(Term), [?CR, Ch]}}; -%% escape processing (only in hdrname and hdrvalue terms) -parser(<>, Term , State) -> more(fun(Rest) -> parser(<>, Term, State) end); -parser(<>, Term , State) - when Term == hdrname; - Term == hdrvalue -> unescape(Ch, fun(Ech) -> parser(Rest, Term, accum(Ech, State)) end); -%% inter-frame noise -parser(<>, noframe , State) -> parser(Rest, noframe, State); -parser(<>, noframe , State) -> parser(Rest, noframe, State); -%% detect transitions -parser( Rest, noframe , State) -> goto(noframe, command, Rest, State); -parser(<>, command , State) -> goto(command, headers, Rest, State); -parser(<>, headers , State) -> goto(headers, body, Rest, State); -parser( Rest, headers , State) -> goto(headers, hdrname, Rest, State); -parser(<>, hdrname , State) -> goto(hdrname, hdrvalue, Rest, State); -parser(<>, hdrname , State) -> goto(hdrname, headers, Rest, State); -parser(<>, hdrvalue, State) -> goto(hdrvalue, headers, Rest, State); -%% accumulate -parser(<>, Term , State) -> parser(Rest, Term, accum(Ch, State)). - -%% state transitions -goto(noframe, command, Rest, State ) -> parser(Rest, command, State#state{acc = []}); -goto(command, headers, Rest, State = #state{acc = Acc} ) -> parser(Rest, headers, State#state{cmd = lists:reverse(Acc), hdrs = [], seen = #{}}); -goto(headers, body, Rest, #state{cmd = Cmd, hdrs = Hdrs}) -> parse_body(Rest, #stomp_frame{command = Cmd, headers = Hdrs}); -goto(headers, hdrname, Rest, State ) -> parser(Rest, hdrname, State#state{acc = []}); -goto(hdrname, hdrvalue, Rest, State = #state{acc = Acc} ) -> parser(Rest, hdrvalue, State#state{acc = [], hdrname = lists:reverse(Acc)}); -goto(hdrname, headers, _Rest, #state{acc = Acc} ) -> {error, {header_no_value, lists:reverse(Acc)}}; % badly formed header -- fatal error -goto(hdrvalue, headers, Rest, State = #state{acc = Acc, hdrs = Headers, seen = Seen, hdrname = HdrName}) -> - case Seen of - #{HdrName := _} -> - parser(Rest, headers, State#state{acc = []}); - _ when map_size(Seen) >= ?MAX_HEADERS -> - {error, too_many_headers}; - _ -> - Value = lists:reverse(Acc), - parser(Rest, headers, State#state{ - acc = [], - hdrs = [{HdrName, Value} | Headers], - seen = Seen#{HdrName => true} - }) +%% Scan for LF in a binary. Handles CR LF normalization. +%% Returns {ok, BeforeLF, AfterLF} | {more, CurrentLen} | {error, _} +scan_until_lf(Bin) -> + scan_until_lf(Bin, 0). + +scan_until_lf(Bin, Pos) -> + case Bin of + <<_:Pos/binary, ?LF, _/binary>> -> + <> = Bin, + {ok, Before, Rest}; + <<_:Pos/binary, ?CR, ?LF, _/binary>> -> + <> = Bin, + {ok, Before, Rest}; + <<_:Pos/binary, ?CR>> -> + {more, Pos}; + <<_:Pos/binary, ?CR, Ch:8, _/binary>> -> + {error, {unexpected_chars_in_command, [?CR, Ch]}}; + <<_:Pos/binary, _:8, _/binary>> -> + scan_until_lf(Bin, Pos + 1); + <<_:Pos/binary>> -> + {more, Pos} + end. + +%% Scan a complete header line: Name:Value\n +%% Fast path: no backslash or CR in the line. +%% Returns: +%% {ok, NameBin, ValueBin, Rest} — fast path, no escapes +%% has_escapes — contains \ or CR, use slow path +%% {no_value, NameBin} — LF before COLON +%% {more, Len} — need more data +%% {error, _} — CR not followed by LF +scan_header_line(Bin) -> + scan_hdr_name(Bin, 0). + +scan_hdr_name(Bin, Pos) -> + case Bin of + <<_:Pos/binary, ?COLON, _/binary>> -> + <> = Bin, + scan_hdr_value(Rest, Name, 0); + <<_:Pos/binary, ?LF, _/binary>> -> + <> = Bin, + {no_value, Name}; + <<_:Pos/binary, ?CR, ?LF, _/binary>> -> + <> = Bin, + {no_value, Name}; + <<_:Pos/binary, ?BSL, _/binary>> -> + has_escapes; + <<_:Pos/binary, ?CR>> -> + {more, Pos}; + <<_:Pos/binary, ?CR, Ch:8, _/binary>> -> + {error, {unexpected_chars_in_header, [?CR, Ch]}}; + <<_:Pos/binary, _:8, _/binary>> -> + scan_hdr_name(Bin, Pos + 1); + <<_:Pos/binary>> -> + {more, Pos} end. -%% error atom -unexpected_chars(noframe) -> unexpected_chars_between_frames; -unexpected_chars(command) -> unexpected_chars_in_command; -unexpected_chars(hdrname) -> unexpected_chars_in_header; -unexpected_chars(hdrvalue) -> unexpected_chars_in_header; -unexpected_chars(_Term) -> unexpected_chars. +scan_hdr_value(Bin, Name, Pos) -> + case Bin of + <<_:Pos/binary, ?LF, _/binary>> -> + <> = Bin, + {ok, Name, Value, Rest}; + <<_:Pos/binary, ?CR, ?LF, _/binary>> -> + <> = Bin, + {ok, Name, Value, Rest}; + <<_:Pos/binary, ?BSL, _/binary>> -> + has_escapes; + <<_:Pos/binary, ?CR>> -> + {more, Pos}; + <<_:Pos/binary, ?CR, Ch:8, _/binary>> -> + {error, {unexpected_chars_in_header, [?CR, Ch]}}; + <<_:Pos/binary, _:8, _/binary>> -> + scan_hdr_value(Bin, Name, Pos + 1); + <<_:Pos/binary>> -> + {more, Pos} + end. -%% general accumulation -accum(Ch, State = #state{acc = Acc}) -> State#state{acc = [Ch | Acc]}. +%% +%% Helpers +%% + +accum(Ch, S = #ps{acc = Acc, acc_len = Len}) -> + S#ps{acc = [Ch | Acc], acc_len = Len + 1}. -%% resolve escapes (with error processing) unescape(?LF_ESC, Fun) -> Fun(?LF); unescape(?BSL_ESC, Fun) -> Fun(?BSL); unescape(?COLON_ESC, Fun) -> Fun(?COLON); unescape(?CR_ESC, Fun) -> Fun(?CR); unescape(Ch, _Fun) -> {error, {bad_escape, [?BSL, Ch]}}. -parse_body(Content, Frame = #stomp_frame{command = Command}) -> - case Command of - "SEND" -> parse_body(Content, Frame, [], integer_header(Frame, ?HEADER_CONTENT_LENGTH, unknown)); - _ -> parse_body(Content, Frame, [], unknown) +%% First occurrence of a header name wins. +%% Duplicates are discarded without allocation. +%% The limit is checked only when a genuinely new header would be added. +insert_header(Name, Value, S = #ps{hdrs = Hdrs, + config = #stomp_parser_config{ + max_headers = MaxHeaders}}) -> + case Hdrs of + #{Name := _} -> + {ok, S}; + _ when map_size(Hdrs) >= MaxHeaders -> + {error, {max_headers, MaxHeaders}}; + _ -> + {ok, S#ps{hdrs = Hdrs#{Name => Value}}} + end. + +%% +%% Body parsing +%% + +parse_body(Content, #ps{cmd = Cmd, hdrs = Hdrs, + config = #stomp_parser_config{ + max_body_length = MaxBodyLength}}) -> + Frame = #stomp_frame{command = Cmd, headers = Hdrs}, + case Cmd of + 'SEND' -> + case integer_header(Frame, ?HEADER_CONTENT_LENGTH, unknown) of + ContentLength when is_integer(ContentLength), + ContentLength < 0 -> + {error, {invalid_content_length, ContentLength}}; + ContentLength when is_integer(ContentLength), + ContentLength > MaxBodyLength -> + {error, {max_body_length, ContentLength}}; + ContentLength when is_integer(ContentLength) -> + parse_known_body(Content, Frame, [], ContentLength); + _ -> + parse_unknown_body(Content, Frame, [], MaxBodyLength) + end; + _ -> + parse_unknown_body(Content, Frame, [], MaxBodyLength) + end. + +-define(MORE_BODY(Content, Frame, Chunks, Remaining), + Chunks1 = finalize_chunk(Content, Chunks), + more(fun(Rest) -> ?FUNCTION_NAME(Rest, Frame, Chunks1, Remaining) end)). + +parse_unknown_body(Content, Frame, Chunks, Remaining) -> + case firstnull(Content) of + -1 -> + ChunkSize = byte_size(Content), + case ChunkSize > Remaining of + true -> {error, {max_body_length, unknown}}; + false -> ?MORE_BODY(Content, Frame, Chunks, Remaining - ChunkSize) + end; + Pos -> + case Pos > Remaining of + true -> {error, {max_body_length, unknown}}; + false -> finish_body(Content, Frame, Chunks, Pos) + end end. -parse_body(Content, Frame, Chunks, unknown) -> - parse_body2(Content, Frame, Chunks, case firstnull(Content) of - -1 -> {more, unknown}; - Pos -> {done, Pos} - end); -parse_body(Content, Frame, Chunks, Remaining) -> +parse_known_body(Content, Frame, Chunks, Remaining) -> Size = byte_size(Content), - parse_body2(Content, Frame, Chunks, case Remaining >= Size of - true -> {more, Remaining - Size}; - false -> {done, Remaining} - end). - -parse_body2(Content, Frame, Chunks, {more, Left}) -> - Chunks1 = finalize_chunk(Content, Chunks), - more(fun(Rest) -> parse_body(Rest, Frame, Chunks1, Left) end); -parse_body2(Content, Frame, Chunks, {done, Pos}) -> - <> = Content, - Body = lists:reverse(finalize_chunk(Chunk, Chunks)), - {ok, Frame#stomp_frame{body_iolist = Body}, Rest}. + case Remaining >= Size of + true -> ?MORE_BODY(Content, Frame, Chunks, Remaining - Size); + false -> finish_body(Content, Frame, Chunks, Remaining) + end. + +finish_body(Content, Frame, Chunks, Pos) -> + case Content of + <> -> + Body = finalize_chunk(Chunk, Chunks), + {ok, Frame#stomp_frame{body_iolist_rev = Body}, Rest}; + _ -> + {error, missing_body_terminator} + end. finalize_chunk(<<>>, Chunks) -> Chunks; finalize_chunk(Chunk, Chunks) -> [Chunk | Chunks]. +firstnull(Content) -> firstnull(Content, 0). + +firstnull(<<>>, _N) -> -1; +firstnull(<<0, _Rest/binary>>, N) -> N; +firstnull(<<_Ch, Rest/binary>>, N) -> firstnull(Rest, N + 1). + +%% +%% Header accessors +%% + default_value({ok, Value}, _DefaultValue) -> Value; default_value(not_found, DefaultValue) -> DefaultValue. header(#stomp_frame{headers = Headers}, Key) -> - case lists:keysearch(Key, 1, Headers) of - {value, {_, Str}} -> {ok, Str}; - _ -> not_found + case maps:find(Key, Headers) of + {ok, _} = Ok -> Ok; + error -> not_found end. header(F, K, D) -> default_value(header(F, K), D). -boolean_header(#stomp_frame{headers = Headers}, Key) -> - case lists:keysearch(Key, 1, Headers) of - {value, {_, "true"}} -> {ok, true}; - {value, {_, "false"}} -> {ok, false}; - %% some Python clients serialize True/False as "True"/"False" - {value, {_, "True"}} -> {ok, true}; - {value, {_, "False"}} -> {ok, false}; - _ -> not_found +boolean_header(F, Key) -> + case header(F, Key) of + {ok, <<"true">>} -> {ok, true}; + {ok, <<"false">>} -> {ok, false}; + {ok, <<"True">>} -> {ok, true}; + {ok, <<"False">>} -> {ok, false}; + _ -> not_found end. boolean_header(F, K, D) -> default_value(boolean_header(F, K), D). -internal_integer_header(Headers, Key) -> - case lists:keysearch(Key, 1, Headers) of - {value, {_, Str}} -> {ok, list_to_integer(string:strip(Str))}; - _ -> not_found +integer_header(F, Key) -> + case header(F, Key) of + {ok, Str} -> + try {ok, binary_to_integer(string:trim(Str))} + catch _:_ -> not_found + end; + not_found -> not_found end. -integer_header(#stomp_frame{headers = Headers}, Key) -> - internal_integer_header(Headers, Key). - integer_header(F, K, D) -> default_value(integer_header(F, K), D). -binary_header(F, K) -> - case header(F, K) of - {ok, Str} -> {ok, list_to_binary(Str)}; - not_found -> not_found - end. +binary_header(F, K) -> header(F, K). binary_header(F, K, D) -> default_value(binary_header(F, K), D). stream_offset_header(F) -> case binary_header(F, ?HEADER_X_STREAM_OFFSET) of - {ok, <<"first">>} -> - {longstr, <<"first">>}; - {ok, <<"last">>} -> - {longstr, <<"last">>}; - {ok, <<"next">>} -> - {longstr, <<"next">>}; - {ok, <<"offset=", OffsetValue/binary>>} -> - {long, binary_to_integer(OffsetValue)}; - {ok, <<"timestamp=", TimestampValue/binary>>} -> - {timestamp, binary_to_integer(TimestampValue)}; - _ -> - not_found + {ok, <<"first">>} -> {longstr, <<"first">>}; + {ok, <<"last">>} -> {longstr, <<"last">>}; + {ok, <<"next">>} -> {longstr, <<"next">>}; + {ok, <<"offset=", V/binary>>} -> {long, binary_to_integer(V)}; + {ok, <<"timestamp=", V/binary>>} -> {timestamp, binary_to_integer(V)}; + _ -> not_found end. stream_filter_header(F) -> @@ -239,7 +480,7 @@ stream_filter_header(F) -> {ok, Str} -> {array, lists:reverse( lists:foldl(fun(V, Acc) -> - [{longstr, V}] ++ Acc + [{longstr, V} | Acc] end, [], binary:split(Str, <<",">>, [global])))}; @@ -247,45 +488,53 @@ stream_filter_header(F) -> not_found end. +%% +%% Serialization +%% + serialize(Frame) -> serialize(Frame, true). -%% second argument controls whether a trailing linefeed -%% character should be added, see rabbitmq/rabbitmq-stomp#39. serialize(Frame, true) -> serialize(Frame, false) ++ [?LF]; serialize(#stomp_frame{command = Command, headers = Headers, - body_iolist = BodyFragments}, false) -> + body_iolist_rev = BodyFragments}, false) -> Len = iolist_size(BodyFragments), - [Command, ?LF, - lists:map(fun serialize_header/1, - lists:keydelete(?HEADER_CONTENT_LENGTH, 1, Headers)), + [serialize_command(Command), ?LF, + serialize_headers(Headers), if - Len > 0 -> [?HEADER_CONTENT_LENGTH ++ ":", integer_to_list(Len), ?LF]; + Len > 0 -> [?HEADER_CONTENT_LENGTH, ?COLON, integer_to_list(Len), ?LF]; true -> [] end, - ?LF, BodyFragments, 0]. + ?LF, case BodyFragments of + _ when is_binary(BodyFragments) -> BodyFragments; + _ -> lists:reverse(BodyFragments) + end, 0]. + +serialize_headers(Headers) -> + maps:fold(fun(K, _V, Acc) when K =:= ?HEADER_CONTENT_LENGTH -> Acc; + (K, V, Acc) -> [serialize_header(K, V) | Acc] + end, [], Headers). + +serialize_command(Command) when is_atom(Command) -> + atom_to_binary(Command, utf8); +serialize_command(Command) -> Command. -serialize_header({K, V}) when is_integer(V) -> hdr(escape(K), integer_to_list(V)); -serialize_header({K, V}) when is_boolean(V) -> hdr(escape(K), boolean_to_list(V)); -serialize_header({K, V}) when is_list(V) -> hdr(escape(K), escape(V)). +serialize_header(K, V) when is_integer(V) -> hdr(escape(K), integer_to_list(V)); +serialize_header(K, V) when is_boolean(V) -> hdr(escape(K), boolean_to_list(V)); +serialize_header(K, V) when is_binary(V) -> hdr(escape(K), escape(V)). boolean_to_list(true) -> "true"; boolean_to_list(_) -> "false". hdr(K, V) -> [K, ?COLON, V, ?LF]. -escape(Str) -> [escape1(Ch) || Ch <- Str]. +escape(Bin) -> escape(Bin, []). -escape1(?COLON) -> [?BSL, ?COLON_ESC]; -escape1(?BSL) -> [?BSL, ?BSL_ESC]; -escape1(?LF) -> [?BSL, ?LF_ESC]; -escape1(?CR) -> [?BSL, ?CR_ESC]; -escape1(Ch) -> Ch. - -firstnull(Content) -> firstnull(Content, 0). - -firstnull(<<>>, _N) -> -1; -firstnull(<<0, _Rest/binary>>, N) -> N; -firstnull(<<_Ch, Rest/binary>>, N) -> firstnull(Rest, N+1). +escape(<<>>, Acc) -> lists:reverse(Acc); +escape(<>, Acc) -> escape(Rest, [?COLON_ESC, ?BSL | Acc]); +escape(<>, Acc) -> escape(Rest, [?BSL_ESC, ?BSL | Acc]); +escape(<>, Acc) -> escape(Rest, [?LF_ESC, ?BSL | Acc]); +escape(<>, Acc) -> escape(Rest, [?CR_ESC, ?BSL | Acc]); +escape(<>, Acc) -> escape(Rest, [Ch | Acc]). diff --git a/deps/rabbitmq_stomp/src/rabbit_stomp_processor.erl b/deps/rabbitmq_stomp/src/rabbit_stomp_processor.erl index 54c6999d794a..7f5ee32311e4 100644 --- a/deps/rabbitmq_stomp/src/rabbit_stomp_processor.erl +++ b/deps/rabbitmq_stomp/src/rabbit_stomp_processor.erl @@ -7,223 +7,281 @@ -module(rabbit_stomp_processor). +-feature(maybe_expr, enable). + -compile({no_auto_import, [error/3]}). --export([initial_state/2, process_frame/2, flush_and_die/1]). +-export([initial_state/2, + process_frame/2, + flush_and_die/1, + info/2]). + -export([flush_pending_receipts/3, - handle_exit/3, cancel_consumer/2, - send_delivery/5]). - --export([adapter_name/1]). --export([info/2]). + handle_down/2, + handle_queue_event/2]). --include_lib("amqp_client/include/amqp_client.hrl"). +-include_lib("kernel/include/logger.hrl"). -include("rabbit_stomp_frame.hrl"). -include("rabbit_stomp.hrl"). -include("rabbit_stomp_headers.hrl"). --include_lib("kernel/include/logger.hrl"). - --record(proc_state, {session_id, channel, connection, subscriptions, - version, start_heartbeat_fun, pending_receipts, - config, route_state, reply_queues, frame_transformer, - adapter_info, send_fun, ssl_login_name, peer_addr, - %% see rabbitmq/rabbitmq-stomp#39 - trailing_lf, auth_mechanism, auth_login, - default_topic_exchange, default_nack_requeue, - virtual_host}). - --record(subscription, {dest_hdr, ack_mode, multi_ack, description}). - --define(FLUSH_TIMEOUT, 60000). - -adapter_name(State) -> - #proc_state{adapter_info = #amqp_adapter_info{name = Name}} = State, - Name. - -%%---------------------------------------------------------------------------- - --spec initial_state( - #stomp_configuration{}, - {SendFun, AdapterInfo, SSLLoginName, PeerAddr}) - -> #proc_state{} - when SendFun :: fun((binary()) -> term()), - AdapterInfo :: #amqp_adapter_info{}, - SSLLoginName :: atom() | binary(), - PeerAddr :: inet:ip_address(). +-include_lib("rabbit/include/amqqueue.hrl"). +-include_lib("rabbit_common/include/rabbit.hrl"). +-include_lib("rabbit_common/include/rabbit_framing.hrl"). + +-import(rabbit_misc, [maps_put_truthy/3]). + +-define(QUEUE, lqueue). +-define(MAX_PERMISSION_CACHE_SIZE, 12). +-define(MAX_TRANSACTIONS, 16). +-define(MAX_TRANSACTION_ACTIONS, 1024). + +-record(conn_info, { + conn_name :: binary(), + host :: inet:ip_address(), + port :: inet:port_number(), + peer_host :: inet:ip_address(), + peer_port :: inet:port_number(), + connected_at :: integer() + }). + +-record(subscription, {dest_hdr, ack_mode, multi_ack, description, queue_name}). +-type session_id() :: string(). +-type subscriptions() :: #{rabbit_types:ctag() => #subscription{}}. + +-type frame_transformer() :: fun ((#stomp_frame{}) -> #stomp_frame{}). + +-record(pending_ack, { + %% delivery identifier used by clients + %% to acknowledge and reject deliveries + delivery_tag :: non_neg_integer(), + %% consumer tag + tag :: rabbit_types:ctag(), + delivered_at :: integer(), + %% queue name + queue :: rabbit_amqqueue:name(), + %% message ID used by queue and message store implementations + msg_id :: rabbit_amqqueue:msg_id() + }). + +-record(cfg, + { + session_id :: none | session_id(), + version :: none | string(), + proto_ver :: undefined | atom(), + default_login :: undefined | binary(), + default_passcode :: undefined | binary(), + ssl_login_name :: none | binary(), + force_default_creds :: boolean(), + implicit_connect :: boolean(), + frame_transformer :: undefined | frame_transformer(), + send_fun :: send_fun(), + conn_info :: #conn_info{}, + trailing_lf :: boolean(), + auth_mechanism :: undefined | config | ssl | stomp_headers, + auth_login :: undefined | binary(), + vhost :: undefined | binary(), + default_topic_exchange :: binary(), + default_nack_requeue = true :: boolean(), + delivery_flow :: flow | noflow, + trace_state :: undefined | rabbit_trace:state(), + msg_interceptor_ctx :: undefined | map() + }). + +-record(state, + { + cfg :: #cfg{}, + user :: undefined | #user{}, + authz_ctx :: undefined | map(), + subscriptions :: subscriptions(), + pending_receipts :: gb_trees:tree(integer(), binary()), + route_state :: sets:set(), + reply_queues :: #{binary() => binary()}, + confirmed :: [rabbit_confirms:mx()], + rejected :: [rabbit_confirms:mx()], + unconfirmed :: rabbit_confirms:state(), + %% a map of queue names to consumer tag lists + queue_consumers :: #{rabbit_amqqueue:name() => rabbit_types:ctag()}, + unacked_message_q :: ?QUEUE:?QUEUE(#pending_ack{}), + queue_states :: rabbit_queue_type:state(), + delivery_tag = 0 :: non_neg_integer(), + msg_seq_no = 1 :: pos_integer(), + publisher = false :: boolean() + }). -type process_frame_result() :: - {ok, #proc_state{}, pid() | undefined} | - {stop, term(), #proc_state{}}. - --spec process_frame(#stomp_frame{}, #proc_state{}) -> - process_frame_result(). - --spec flush_and_die(#proc_state{}) -> #proc_state{}. - --spec command({Command, Frame}, State) -> process_frame_result() - when Command :: string(), - Frame :: #stomp_frame{}, - State :: #proc_state{}. - --type process_fun() :: fun((#proc_state{}) -> - {ok, #stomp_frame{}, #proc_state{}} | - {error, string(), string(), #proc_state{}} | - {stop, term(), #proc_state{}}). --spec process_request(process_fun(), fun((#proc_state{}) -> #proc_state{}), #proc_state{}) -> - process_frame_result(). + {ok, #state{}} | + {stop, term(), #state{}}. --spec flush_pending_receipts(DeliveryTag, IsMulti, State) -> State - when State :: #proc_state{}, - DeliveryTag :: term(), - IsMulti :: boolean(). - --spec handle_exit(From, Reason, State) -> unknown_exit | {stop, Reason, State} - when State :: #proc_state{}, - From :: pid(), - Reason :: term(). - --spec cancel_consumer(binary(), #proc_state{}) -> process_frame_result(). - --spec send_delivery(#'basic.deliver'{}, term(), term(), term(), - #proc_state{}) -> #proc_state{}. - -%%---------------------------------------------------------------------------- +-export_type ([process_frame_result/0]). +-export([adapter_name/1]). %%---------------------------------------------------------------------------- %% Public API %%---------------------------------------------------------------------------- +adapter_name(#state{cfg = #cfg{conn_info = #conn_info{conn_name = Name}}}) -> + Name. +-spec initial_state( + #stomp_configuration{}, + {SendFun, SSLLoginName, ConnName, Host, Port, PeerHost, PeerPort}) + -> #state{} + when SendFun :: send_fun(), + SSLLoginName :: none | binary(), + ConnName :: binary(), + Host :: inet:ip_address(), + Port :: inet:port_number(), + PeerHost :: inet:ip_address(), + PeerPort :: inet:port_number(). +initial_state(Configuration, + {SendFun, SSLLoginName, ConnName, + Host, Port, PeerHost, PeerPort}) -> + Flow = case rabbit_misc:get_env(rabbit, mirroring_flow_control, true) of + true -> flow; + false -> noflow + end, + ConnInfo = #conn_info{conn_name = ConnName, + host = Host, + port = Port, + peer_host = PeerHost, + peer_port = PeerPort, + connected_at = os:system_time(millisecond)}, + #state { + cfg = #cfg{ + send_fun = SendFun, + ssl_login_name = SSLLoginName, + conn_info = ConnInfo, + session_id = none, + frame_transformer = undefined, + version = none, + trailing_lf = application:get_env(rabbitmq_stomp, trailing_lf, true), + default_topic_exchange = application:get_env(rabbitmq_stomp, default_topic_exchange, <<"amq.topic">>), + default_nack_requeue = application:get_env(rabbitmq_stomp, default_nack_requeue, true), + implicit_connect = Configuration#stomp_configuration.implicit_connect, + default_login = Configuration#stomp_configuration.default_login, + default_passcode = Configuration#stomp_configuration.default_passcode, + force_default_creds = Configuration#stomp_configuration.force_default_creds, + delivery_flow = Flow + }, + subscriptions = #{}, + queue_consumers = #{}, + route_state = routing_init_state(), + reply_queues = #{}, + msg_seq_no = 1, + unconfirmed = rabbit_confirms:init(), + confirmed = [], + unacked_message_q = ?QUEUE:new(), + rejected = [], + queue_states = rabbit_queue_type:init(), + pending_receipts = gb_trees:empty() + }. + +-spec process_frame(#stomp_frame{}, #state{}) -> + process_frame_result(). process_frame(Frame = #stomp_frame{command = Command}, State) -> command({Command, Frame}, State). +-spec flush_and_die(#state{}) -> #state{}. flush_and_die(State) -> close_connection(State). -info(session_id, #proc_state{session_id = Val}) -> +-spec info(Key, State) -> Result + when + Key :: atom(), + State :: #state{}, + Result :: term(). %% TODO: somewhere these values are used to render things + %% to CLI and Management UI, what types do they support? +info(session_id, #state{cfg=#cfg{session_id = Val}}) -> Val; -info(channel, #proc_state{channel = Val}) -> Val; -info(version, #proc_state{version = Val}) -> Val; -info(implicit_connect, #proc_state{config = #stomp_configuration{implicit_connect = Val}}) -> Val; -info(auth_login, #proc_state{auth_login = Val}) -> Val; -info(auth_mechanism, #proc_state{auth_mechanism = Val}) -> Val; -info(peer_addr, #proc_state{peer_addr = Val}) -> Val; -info(host, #proc_state{adapter_info = #amqp_adapter_info{host = Val}}) -> Val; -info(port, #proc_state{adapter_info = #amqp_adapter_info{port = Val}}) -> Val; -info(peer_host, #proc_state{adapter_info = #amqp_adapter_info{peer_host = Val}}) -> Val; -info(peer_port, #proc_state{adapter_info = #amqp_adapter_info{peer_port = Val}}) -> Val; -info(protocol, #proc_state{adapter_info = #amqp_adapter_info{protocol = Val}}) -> - case Val of - {Proto, Version} -> {Proto, rabbit_data_coercion:to_binary(Version)}; - Other -> Other - end; -info(channels, PState) -> additional_info(channels, PState); -info(channel_max, PState) -> additional_info(channel_max, PState); -info(frame_max, PState) -> additional_info(frame_max, PState); -info(client_properties, PState) -> additional_info(client_properties, PState); -info(ssl, PState) -> additional_info(ssl, PState); -info(ssl_protocol, PState) -> additional_info(ssl_protocol, PState); -info(ssl_key_exchange, PState) -> additional_info(ssl_key_exchange, PState); -info(ssl_cipher, PState) -> additional_info(ssl_cipher, PState); -info(ssl_hash, PState) -> additional_info(ssl_hash, PState). +info(version, #state{cfg = #cfg{version = Val}}) -> Val; +info(implicit_connect, #state{cfg = #cfg{implicit_connect = Val}}) -> Val; +info(auth_login, #state{cfg = #cfg{auth_login = Val}}) -> Val; +info(auth_mechanism, #state{cfg = #cfg{auth_mechanism = Val}}) -> Val; +info(peer_addr, #state{cfg = #cfg{conn_info = #conn_info{peer_host = Val}}}) -> Val; +info(host, #state{cfg = #cfg{conn_info = #conn_info{host = Val}}}) -> Val; +info(port, #state{cfg = #cfg{conn_info = #conn_info{port = Val}}}) -> Val; +info(peer_host, #state{cfg = #cfg{conn_info = #conn_info{peer_host = Val}}}) -> Val; +info(peer_port, #state{cfg = #cfg{conn_info = #conn_info{peer_port = Val}}}) -> Val; +info(connected_at, #state{cfg = #cfg{conn_info = #conn_info{connected_at = Val}}}) -> Val; +info(protocol, #state{cfg = #cfg{version = Version}}) -> + VersionTuple = case Version of + "1.0" -> {1, 0}; + "1.1" -> {1, 1}; + "1.2" -> {1, 2}; + _ -> none + end, + {'STOMP', VersionTuple}; +info(vhost, #state{cfg = #cfg{vhost = Val}}) -> Val; +info(user, #state{user = undefined}) -> undefined; +info(user, #state{user = #user{username = Username}}) -> Username; +info(channels, _) -> 1; +info(channel_max, _) -> 1; +info(frame_max, _) -> 0; +info(client_properties, _) -> + [{<<"product">>, longstr, <<"STOMP client">>}]; +info(user_who_performed_action, S) -> info(user, S); +info(Other, _) -> throw({bad_argument, Other}). -initial_state(Configuration, - {SendFun, AdapterInfo0 = #amqp_adapter_info{additional_info = Extra}, - SSLLoginName, PeerAddr}) -> - %% STOMP connections use exactly one channel. The frame max is not - %% applicable and there is no way to know what client is used. - AdapterInfo = AdapterInfo0#amqp_adapter_info{additional_info=[ - {channels, 1}, - {channel_max, 1}, - {frame_max, 0}, - %% TODO: can we use a header to make it possible for clients - %% to override this value? - {client_properties, [{<<"product">>, longstr, <<"STOMP client">>}]} - |Extra]}, - #proc_state { - send_fun = SendFun, - adapter_info = AdapterInfo, - ssl_login_name = SSLLoginName, - peer_addr = PeerAddr, - session_id = none, - channel = none, - connection = none, - subscriptions = #{}, - version = none, - pending_receipts = undefined, - config = Configuration, - route_state = rabbit_stomp_routing_util:init_state(), - reply_queues = #{}, - frame_transformer = undefined, - trailing_lf = application:get_env(rabbitmq_stomp, trailing_lf, true), - default_topic_exchange = application:get_env(rabbitmq_stomp, default_topic_exchange, <<"amq.topic">>), - default_nack_requeue = application:get_env(rabbitmq_stomp, default_nack_requeue, true)}. +%%---------------------------------------------------------------------------- +%% Private Parts (Including callbacks) +%%---------------------------------------------------------------------------- -command({"STOMP", Frame}, State) -> +command({'STOMP', Frame}, State) -> process_connect(no_implicit, Frame, State); -command({"CONNECT", Frame}, State) -> +command({'CONNECT', Frame}, State) -> process_connect(no_implicit, Frame, State); -command(Request, State = #proc_state{channel = none, - config = #stomp_configuration{ - implicit_connect = true}}) -> - {ok, State1 = #proc_state{channel = Ch}, _} = - process_connect(implicit, #stomp_frame{headers = []}, State), - case Ch of - none -> {stop, normal, State1}; - _ -> command(Request, State1) +command(Request, State = #state{user = undefined, + cfg = #cfg{ + implicit_connect = true}}) -> + + case process_connect(implicit, #stomp_frame{headers = #{}}, State) of + {ok, State1 = #state{user = undefined}} -> + {stop, normal, State1}; + {ok, State1 = #state{user = _User}} -> + command(Request, State1); + Res -> Res end; -command(_Request, State = #proc_state{channel = none, - config = #stomp_configuration{ - implicit_connect = false}}) -> +command(_Request, State = #state{user = undefined, + cfg = #cfg{ + implicit_connect = false}}) -> {ok, send_error("Illegal command", "You must log in using CONNECT first", - State), none}; + State)}; -command({Command, Frame}, State = #proc_state{frame_transformer = FT}) -> +command({Command, Frame}, State = #state{cfg = #cfg{frame_transformer = FT}}) + when is_function(FT) -> Frame1 = FT(Frame), process_request( fun(StateN) -> - case validate_frame(Command, Frame1, StateN) of - R = {error, _, _, _} -> R; - _ -> handle_frame(Command, Frame1, StateN) - end + case validate_frame(Command, Frame1, StateN) of + R = {error, _, _, _} -> R; + _ -> handle_frame(Command, Frame1, StateN) + end end, fun(StateM) -> ensure_receipt(Frame1, StateM) end, State). -cancel_consumer(Ctag, State) -> - process_request( - fun(StateN) -> server_cancel_consumer(Ctag, StateN) end, - State). - -handle_exit(Conn, {shutdown, {server_initiated_close, Code, Explanation}}, - State = #proc_state{connection = Conn}) -> - amqp_death(Code, Explanation, State); -handle_exit(Conn, {shutdown, {connection_closing, - {server_initiated_close, Code, Explanation}}}, - State = #proc_state{connection = Conn}) -> - amqp_death(Code, Explanation, State); -handle_exit(Conn, Reason, State = #proc_state{connection = Conn}) -> - _ = send_error("AMQP connection died", "Reason: ~tp", [Reason], State), - {stop, {conn_died, Reason}, State}; - -handle_exit(Ch, {shutdown, {server_initiated_close, Code, Explanation}}, - State = #proc_state{channel = Ch}) -> - amqp_death(Code, Explanation, State); - -handle_exit(Ch, Reason, State = #proc_state{channel = Ch}) -> - _ = send_error("AMQP channel died", "Reason: ~tp", [Reason], State), - {stop, {channel_died, Reason}, State}; -handle_exit(Ch, {shutdown, {server_initiated_close, Code, Explanation}}, - State = #proc_state{channel = Ch}) -> - amqp_death(Code, Explanation, State); -handle_exit(_, _, _) -> unknown_exit. - +handle_consuming_queue_down_or_eol(QName, + State = #state{queue_consumers = QCons}) -> + ConsumerTags = case maps:find(QName, QCons) of + error -> gb_sets:new(); + {ok, CTags} -> CTags + end, + gb_sets:fold( + fun (CTag, StateN) -> + {ok, S} = cancel_consumer(CTag, StateN), + S + end, State#state{queue_consumers = maps:remove(QName, QCons)}, ConsumerTags). + +cancel_consumer(CTag, State) -> + process_request( + fun(StateN) -> server_cancel_consumer(CTag, StateN) end, + State). process_request(ProcessFun, State) -> process_request(ProcessFun, fun (StateM) -> StateM end, State). @@ -235,23 +293,25 @@ process_request(ProcessFun, SuccessFun, State) -> {{shutdown, {server_initiated_close, ReplyCode, Explanation}}, _}} -> amqp_death(ReplyCode, Explanation, State); - {'EXIT', {amqp_error, access_refused, Msg, _}} -> - amqp_death(access_refused, Msg, State); + {'EXIT', {amqp_error, Name, Msg, _}} -> + amqp_death(Name, Msg, State); {'EXIT', Reason} -> priv_error("Processing error", "Processing error", - Reason, State); + Reason, State); Result -> Result end, case Res of - {ok, Frame, NewState = #proc_state{connection = Conn}} -> + {ok, Frame, NewState} -> _ = case Frame of none -> ok; _ -> send_frame(Frame, NewState) end, - {ok, SuccessFun(NewState), Conn}; - {error, Message, Detail, NewState = #proc_state{connection = Conn}} -> - {ok, send_error(Message, Detail, NewState), Conn}; + {ok, SuccessFun(NewState)}; + {ok, NewState} -> + {ok, SuccessFun(NewState)}; + {error, Message, Detail, NewState} -> + {ok, send_error(Message, Detail, NewState)}; {stop, normal, NewState} -> {stop, normal, SuccessFun(NewState)}; {stop, R, NewState} -> @@ -259,49 +319,112 @@ process_request(ProcessFun, SuccessFun, State) -> end. process_connect(Implicit, Frame, - State = #proc_state{channel = none, - config = Config, - ssl_login_name = SSLLoginName, - adapter_info = AdapterInfo}) -> + State = #state{user = undefined, + cfg = Config = #cfg{ + conn_info = ConnInfo, + ssl_login_name = SSLLoginName}}) -> + PeerIp = ConnInfo#conn_info.peer_host, process_request( fun(StateN) -> - case negotiate_version(Frame) of - {ok, Version} -> - FT = frame_transformer(Version), - Frame1 = FT(Frame), - {Auth, {Username, Passwd}} = creds(Frame1, SSLLoginName, Config), - {ok, DefaultVHost} = application:get_env( - rabbitmq_stomp, default_vhost), - {ProtoName, _} = AdapterInfo#amqp_adapter_info.protocol, - Res = do_login( - Username, Passwd, - login_header(Frame1, ?HEADER_HOST, DefaultVHost), - login_header(Frame1, ?HEADER_HEART_BEAT, "0,0"), - AdapterInfo#amqp_adapter_info{ - protocol = {ProtoName, Version}}, Version, - StateN#proc_state{frame_transformer = FT, - auth_mechanism = Auth, - auth_login = Username}), - case {Res, Implicit} of - {{ok, _, StateN1}, implicit} -> ok(StateN1); - _ -> Res - end; + Res1 = maybe + {ok, Version} ?= negotiate_version(Frame), + ProtoVer = stomp_proto_ver(Version), + FT = frame_transformer(Version), + Frame1 = FT(Frame), + {Auth, {Username, _}} = Creds = creds(Frame1, SSLLoginName, Config), + {ok, DefaultVHost} = application:get_env(rabbitmq_stomp, default_vhost), + VHost = login_header(Frame1, ?HEADER_HOST, DefaultVHost), + Heartbeat = login_header(Frame1, ?HEADER_HEART_BEAT, <<"0,0">>), + StateN1 = StateN#state{cfg = Config#cfg{vhost = VHost, + proto_ver = ProtoVer, + frame_transformer = FT, + auth_mechanism = Auth, + auth_login = Username}}, + {Username, AuthProps} = auth_props_for_creds(Creds, StateN1), + {ok, User} ?= rabbit_access_control:check_user_login(Username, AuthProps), + ok ?= check_vhost_exists(VHost, Username, PeerIp), + {ok, AuthzCtx} ?= check_vhost_access(VHost, User, PeerIp), + ok ?= check_vhost_connection_limit(VHost), + ok ?= check_user_loopback(Username, PeerIp), + rabbit_core_metrics:auth_attempt_succeeded(PeerIp, Username, stomp), + TraceState = rabbit_trace:init(VHost), + MsgIcptCtx = #{protocol => stomp, + vhost => VHost, + username => Username, + connection_name => ConnInfo#conn_info.conn_name}, + SessionId = rabbit_guid:string(rabbit_guid:gen_secure(), "session"), + {SendTimeout, ReceiveTimeout} = ensure_heartbeats(Heartbeat), + + Headers0 = #{?HEADER_SESSION => list_to_binary(SessionId), + ?HEADER_HEART_BEAT => + <<(integer_to_binary(SendTimeout))/binary, $,, + (integer_to_binary(ReceiveTimeout))/binary>>, + ?HEADER_VERSION => list_to_binary(Version)}, + Headers = case application:get_env(rabbitmq_stomp, hide_server_info, false) of + true -> Headers0; + false -> Headers0#{?HEADER_SERVER => iolist_to_binary(server_header())} + end, + + Res = ok("CONNECTED", Headers, + [], + StateN1#state{cfg = StateN1#state.cfg#cfg{ + session_id = SessionId, + version = Version, + trace_state = TraceState, + msg_interceptor_ctx = MsgIcptCtx + }, + user = User, + authz_ctx = AuthzCtx}), + Res + else {error, no_common_version} -> error("Version mismatch", "Supported versions are ~ts~n", [string:join(?SUPPORTED_VERSIONS, ",")], - StateN) + StateN); + {error, not_allowed, EUsername, EVHost} -> + ?LOG_WARNING("STOMP login failed for user '~ts': " + "virtual host '~ts' access not allowed", + [EUsername, EVHost]), + error("Bad CONNECT", + "Virtual host '~ts' access denied", + [EVHost], State); + {refused, Username1, _Msg, _Args} -> + ?LOG_WARNING("STOMP login failed for user '~ts': " + "authentication failed", [Username1]), + error("Bad CONNECT", + "Access refused for user '~ts'", + [Username1], State); + {error, not_loopback, EUsername} -> + ?LOG_WARNING("STOMP login failed for user '~ts': " + "this user's access is restricted to localhost", + [EUsername]), + error("Bad CONNECT", "non-loopback access denied", State); + {error, quota_exceeded} -> + error("Bad CONNECT", + "Connection refused: vhost connection limit reached", + State) + end, + case {Res1, Implicit} of + {{ok, _, StateN2}, implicit} -> + self() ! connection_created, ok(StateN2); + _ -> + self() ! connection_created, Res1 end end, State). -creds(_, _, #stomp_configuration{default_login = DefLogin, - default_passcode = DefPasscode, - force_default_creds = true}) -> +stomp_proto_ver("1.0") -> ?STOMP_PROTO_V1_0; +stomp_proto_ver("1.1") -> ?STOMP_PROTO_V1_1; +stomp_proto_ver("1.2") -> ?STOMP_PROTO_V1_2. + +creds(_, _, #cfg{default_login = DefLogin, + default_passcode = DefPasscode, + force_default_creds = true}) -> {config, {iolist_to_binary(DefLogin), iolist_to_binary(DefPasscode)}}; creds(Frame, SSLLoginName, - #stomp_configuration{default_login = DefLogin, - default_passcode = DefPasscode}) -> + #cfg{default_login = DefLogin, + default_passcode = DefPasscode}) -> PasswordCreds = {login_header(Frame, ?HEADER_LOGIN, DefLogin), login_header(Frame, ?HEADER_PASSCODE, DefPasscode)}, case {rabbit_stomp_frame:header(Frame, ?HEADER_LOGIN), SSLLoginName} of @@ -310,14 +433,19 @@ creds(Frame, SSLLoginName, _ -> {stomp_headers, PasswordCreds} end. -login_header(Frame, Key, Default) when is_binary(Default) -> - login_header(Frame, Key, binary_to_list(Default)); -login_header(Frame, Key, Default) -> - case rabbit_stomp_frame:header(Frame, Key, Default) of - undefined -> undefined; - Hdr -> list_to_binary(Hdr) +auth_props_for_creds(Creds, #state{cfg = #cfg{ + vhost = VHost}}) -> + case Creds of + {ssl, {Username0, none}}-> {Username0, []}; + {_, {Username0, Password}} -> {Username0, [{password, Password}, + {vhost, VHost}]} end. +login_header(Frame, Key, Default) when is_list(Default) -> + login_header(Frame, Key, list_to_binary(Default)); +login_header(Frame, Key, Default) -> + rabbit_stomp_frame:header(Frame, Key, Default). + %%---------------------------------------------------------------------------- %% Frame Transformation %%---------------------------------------------------------------------------- @@ -334,12 +462,12 @@ report_missing_id_header(State) -> "Header 'id' is required for durable subscriptions", State). validate_frame(Command, Frame, State) - when Command =:= "SUBSCRIBE" orelse Command =:= "UNSUBSCRIBE" -> + when Command =:= 'SUBSCRIBE' orelse Command =:= 'UNSUBSCRIBE' -> Hdr = fun(Name) -> rabbit_stomp_frame:header(Frame, Name) end, case {Hdr(?HEADER_DURABLE), Hdr(?HEADER_PERSISTENT), Hdr(?HEADER_ID)} of - {{ok, "true"}, _, not_found} -> + {{ok, <<"true">>}, _, not_found} -> report_missing_id_header(State); - {_, {ok, "true"}, not_found} -> + {_, {ok, <<"true">>}, not_found} -> report_missing_id_header(State); _ -> ok(State) @@ -351,40 +479,54 @@ validate_frame(_Command, _Frame, State) -> %% Frame handlers %%---------------------------------------------------------------------------- -handle_frame("DISCONNECT", _Frame, State) -> +handle_frame('DISCONNECT', _Frame, State) -> {stop, normal, close_connection(State)}; -handle_frame("SUBSCRIBE", Frame, State) -> - with_destination("SUBSCRIBE", Frame, State, fun do_subscribe/4); +handle_frame('SUBSCRIBE', Frame, State) -> + with_destination('SUBSCRIBE', Frame, State, fun do_subscribe/4); -handle_frame("UNSUBSCRIBE", Frame, State) -> +handle_frame('UNSUBSCRIBE', Frame, State) -> ConsumerTag = rabbit_stomp_util:consumer_tag(Frame), cancel_subscription(ConsumerTag, Frame, State); -handle_frame("SEND", Frame, State) -> - without_headers(?HEADERS_NOT_ON_SEND, "SEND", Frame, State, - fun (_Command, Frame1, State1) -> - with_destination("SEND", Frame1, State1, fun do_send/4) - end); - -handle_frame("ACK", Frame, State) -> - ack_action("ACK", Frame, State, fun create_ack_method/3); +handle_frame('SEND', Frame, State) -> + maybe_with_transaction( + Frame, + fun(State0) -> + ensure_no_headers(?HEADERS_NOT_ON_SEND, 'SEND', Frame, State0, + fun (_Command, Frame1, State1) -> + with_destination('SEND', Frame1, State1, fun do_send/4) + end) + end, State); + +handle_frame('ACK', Frame, State) -> + maybe_with_transaction( + Frame, + fun(State0) -> + ack_action('ACK', Frame, State0, fun handle_ack/4) + end, + State); -handle_frame("NACK", Frame, State) -> - ack_action("NACK", Frame, State, fun create_nack_method/3); +handle_frame('NACK', Frame, State) -> + maybe_with_transaction( + Frame, + fun(State0) -> + ack_action('NACK', Frame, State0, fun handle_nack/4) + end, + State); -handle_frame("BEGIN", Frame, State) -> - transactional_action(Frame, "BEGIN", fun begin_transaction/2, State); +handle_frame('BEGIN', Frame, State) -> + transactional_action(Frame, 'BEGIN', fun begin_transaction/2, State); -handle_frame("COMMIT", Frame, State) -> - transactional_action(Frame, "COMMIT", fun commit_transaction/2, State); +handle_frame('COMMIT', Frame, State) -> + transactional_action(Frame, 'COMMIT', fun commit_transaction/2, State); -handle_frame("ABORT", Frame, State) -> - transactional_action(Frame, "ABORT", fun abort_transaction/2, State); +handle_frame('ABORT', Frame, State) -> + transactional_action(Frame, 'ABORT', fun abort_transaction/2, State); handle_frame(Command, _Frame, State) -> error("Bad command", - "Could not interpret command ~tp~n", + "Could not interpret command \"~ts\"~n", [Command], State). @@ -393,10 +535,10 @@ handle_frame(Command, _Frame, State) -> %%---------------------------------------------------------------------------- ack_action(Command, Frame, - State = #proc_state{subscriptions = Subs, - channel = Channel, - version = Version, - default_nack_requeue = DefaultNackRequeue}, MethodFun) -> + State = #state{subscriptions = Subs, + cfg = #cfg{ + version = Version, + default_nack_requeue = DefaultNackRequeue}}, Fun) -> AckHeader = rabbit_stomp_util:ack_header_name(Version), case rabbit_stomp_frame:header(Frame, AckHeader) of {ok, AckValue} -> @@ -404,16 +546,9 @@ ack_action(Command, Frame, {ok, {ConsumerTag, _SessionId, DeliveryTag}} -> case maps:find(ConsumerTag, Subs) of {ok, Sub} -> - Requeue = rabbit_stomp_frame:boolean_header(Frame, "requeue", DefaultNackRequeue), - Method = MethodFun(DeliveryTag, Sub, Requeue), - case transactional(Frame) of - {yes, Transaction} -> - extend_transaction( - Transaction, {Method}, State); - no -> - amqp_channel:call(Channel, Method), - ok(State) - end; + Requeue = rabbit_stomp_frame:boolean_header(Frame, <<"requeue">>, DefaultNackRequeue), + State1 = Fun(DeliveryTag, Sub, Requeue, State), + ok(State1); error -> error("Subscription not found", "Message with id ~tp has no subscription", @@ -421,10 +556,10 @@ ack_action(Command, Frame, State) end; _ -> - error("Invalid header", - "~tp must include a valid ~tp header~n", - [Command, AckHeader], - State) + error("Invalid header", + "~tp must include a valid ~tp header~n", + [Command, AckHeader], + State) end; not_found -> error("Missing header", @@ -437,7 +572,7 @@ ack_action(Command, Frame, %% Internal helpers for processing frames callbacks %%---------------------------------------------------------------------------- -server_cancel_consumer(ConsumerTag, State = #proc_state{subscriptions = Subs}) -> +server_cancel_consumer(ConsumerTag, State = #state{subscriptions = Subs}) -> case maps:find(ConsumerTag, Subs) of error -> error("Server cancelled unknown subscription", @@ -447,7 +582,7 @@ server_cancel_consumer(ConsumerTag, State = #proc_state{subscriptions = Subs}) - {ok, Subscription = #subscription{description = Description}} -> Id = case rabbit_stomp_util:tag_to_id(ConsumerTag) of {ok, {_, Id1}} -> Id1; - {error, {_, Id1}} -> "Unknown[" ++ Id1 ++ "]" + {error, {_, Id1}} -> <<"Unknown[", Id1/binary, "]">> end, _ = send_error_frame("Server cancelled subscription", [{?HEADER_SUBSCRIPTION, Id}], @@ -471,8 +606,9 @@ cancel_subscription({error, _}, _Frame, State) -> State); cancel_subscription({ok, ConsumerTag, Description}, Frame, - State = #proc_state{subscriptions = Subs, - channel = Channel}) -> + State = #state{subscriptions = Subs, + user = #user{username = Username}, + queue_states = QueueStates0}) -> case maps:find(ConsumerTag, Subs) of error -> error("No subscription found", @@ -480,56 +616,93 @@ cancel_subscription({ok, ConsumerTag, Description}, Frame, "Subscription to ~tp not found.~n", [Description], State); - {ok, Subscription = #subscription{description = Descr}} -> - case amqp_channel:call(Channel, - #'basic.cancel'{ - consumer_tag = ConsumerTag}) of - #'basic.cancel_ok'{consumer_tag = ConsumerTag} -> - tidy_canceled_subscription(ConsumerTag, Subscription, - Frame, State); - _ -> - error("Failed to cancel subscription", - "UNSUBSCRIBE to ~tp failed.~n", - [Descr], - State) + {ok, Subscription = #subscription{queue_name = Queue}} -> + + case rabbit_misc:with_exit_handler( + fun () -> {error, not_found} end, + fun () -> + %% default NoWait is false, so was the basic.cancel here + %% however there is no cancel.ok in the STOMP world + %% so OkMsg is undefined + rabbit_amqqueue:with_or_die( + Queue, + fun(Q1) -> + rabbit_queue_type:cancel( + Q1, + #{consumer_tag => ConsumerTag, + ok_msg => undefined, + user => Username}, + QueueStates0) + end) + end) of + {ok, QueueStates} -> + rabbit_global_counters:consumer_deleted( + State#state.cfg#cfg.proto_ver), + {ok, _, NewState} = tidy_canceled_subscription(ConsumerTag, Subscription, + Frame, State#state{queue_states = QueueStates}), + {ok, NewState}; + {error, not_found} -> + rabbit_global_counters:consumer_deleted( + State#state.cfg#cfg.proto_ver), + + {ok, _, NewState} = tidy_canceled_subscription(ConsumerTag, Subscription, + Frame, State), + {ok, NewState} end end. %% Server-initiated cancelations will pass an undefined instead of a %% STOMP frame. In this case we know that the queue was deleted and %% thus we don't have to clean it up. -tidy_canceled_subscription(ConsumerTag, _Subscription, - undefined, State = #proc_state{subscriptions = Subs}) -> - Subs1 = maps:remove(ConsumerTag, Subs), - ok(State#proc_state{subscriptions = Subs1}); +tidy_canceled_subscription(ConsumerTag, Subscription, + undefined, State) -> + tidy_canceled_subscription_state(ConsumerTag, Subscription, State); %% Client-initiated cancelations will pass an actual frame -tidy_canceled_subscription(ConsumerTag, #subscription{dest_hdr = DestHdr}, - Frame, State = #proc_state{subscriptions = Subs}) -> - Subs1 = maps:remove(ConsumerTag, Subs), - {ok, Dest} = rabbit_routing_parser:parse_endpoint(DestHdr), - maybe_delete_durable_sub(Dest, Frame, State#proc_state{subscriptions = Subs1}). -maybe_delete_durable_sub({topic, Name}, Frame, - State = #proc_state{channel = Channel}) -> +tidy_canceled_subscription(ConsumerTag, Subscription = #subscription{dest_hdr = DestHdr}, + Frame, State0) -> + {ok, State1} = tidy_canceled_subscription_state(ConsumerTag, Subscription, State0), + {ok, Dest} = parse_endpoint(DestHdr), + maybe_delete_durable_sub_queue(Dest, Frame, State1). + +tidy_canceled_subscription_state(ConsumerTag, + _Subscription = #subscription{queue_name = QName}, + State = #state{subscriptions = Subs, + queue_consumers = QCons}) -> + Subs1 = maps:remove(ConsumerTag, Subs), + QCons1 = + case maps:find(QName, QCons) of + error -> QCons; + {ok, CTags} -> CTags1 = gb_sets:delete(ConsumerTag, CTags), + case gb_sets:is_empty(CTags1) of + true -> maps:remove(QName, QCons); + false -> maps:put(QName, CTags1, QCons) + end + end, + {ok, State#state{subscriptions = Subs1, + queue_consumers = QCons1}}. + +maybe_delete_durable_sub_queue({topic, Name}, Frame, + State = #state{cfg = #cfg{auth_login = Username, + vhost = VHost}}) -> case rabbit_stomp_util:has_durable_header(Frame) of true -> {ok, Id} = rabbit_stomp_frame:header(Frame, ?HEADER_ID), QName = rabbit_stomp_util:subscription_queue_name(Name, Id, Frame), - amqp_channel:call(Channel, - #'queue.delete'{queue = list_to_binary(QName), - nowait = false}), + QRes = rabbit_misc:r(VHost, queue, QName), + delete_queue(QRes, Username), ok(State); false -> ok(State) end; -maybe_delete_durable_sub(_Destination, _Frame, State) -> +maybe_delete_durable_sub_queue(_Destination, _Frame, State) -> ok(State). with_destination(Command, Frame, State, Fun) -> case rabbit_stomp_frame:header(Frame, ?HEADER_DESTINATION) of {ok, DestHdr} -> - case rabbit_routing_parser:parse_endpoint(DestHdr) of + case parse_endpoint(DestHdr) of {ok, Destination} -> case Fun(Destination, DestHdr, Frame, State) of {error, invalid_endpoint} -> @@ -557,8 +730,7 @@ with_destination(Command, Frame, State, Fun) -> "'~ts' is not a valid destination.~n" "Valid destination types are: ~ts.~n", [Content, - string:join(rabbit_stomp_routing_util:all_dest_prefixes(), - ", ")], State) + lists:join(<<", ">>, ?ALL_DEST_PREFIXES)], State) end; not_found -> error("Missing destination", @@ -567,7 +739,7 @@ with_destination(Command, Frame, State, Fun) -> State) end. -without_headers([Hdr | Hdrs], Command, Frame, State, Fun) -> +ensure_no_headers([Hdr | Hdrs], Command, Frame, State, Fun) -> % case rabbit_stomp_frame:header(Frame, Hdr) of {ok, _} -> error("Invalid header", @@ -575,100 +747,30 @@ without_headers([Hdr | Hdrs], Command, Frame, State, Fun) -> [Hdr, Command], State); not_found -> - without_headers(Hdrs, Command, Frame, State, Fun) + ensure_no_headers(Hdrs, Command, Frame, State, Fun) end; -without_headers([], Command, Frame, State, Fun) -> +ensure_no_headers([], Command, Frame, State, Fun) -> Fun(Command, Frame, State). -do_login(undefined, _, _, _, _, _, State) -> - error("Bad CONNECT", "Missing login or passcode header(s)", State); -do_login(Username, Passwd, VirtualHost, Heartbeat, AdapterInfo, Version, - State = #proc_state{peer_addr = Addr}) -> - case start_connection( - #amqp_params_direct{username = Username, - password = Passwd, - virtual_host = VirtualHost, - adapter_info = AdapterInfo}, Username, Addr) of - {ok, Connection} -> - rabbit_access_control:clear_max_heap_size(), - link(Connection), - {ok, Channel} = amqp_connection:open_channel(Connection), - link(Channel), - amqp_channel:enable_delivery_flow_control(Channel), - SessionId = rabbit_guid:string(rabbit_guid:gen_secure(), "session"), - {SendTimeout, ReceiveTimeout} = ensure_heartbeats(Heartbeat), - - Headers = [{?HEADER_SESSION, SessionId}, - {?HEADER_HEART_BEAT, - io_lib:format("~B,~B", [SendTimeout, ReceiveTimeout])}, - {?HEADER_VERSION, Version}], - ok("CONNECTED", - case application:get_env(rabbitmq_stomp, hide_server_info, false) of - true -> Headers; - false -> [{?HEADER_SERVER, server_header()} | Headers] - end, - "", - State#proc_state{session_id = SessionId, - channel = Channel, - connection = Connection, - version = Version, - virtual_host = VirtualHost}); - {error, {auth_failure, _}} -> - ?LOG_WARNING("STOMP login failed for user '~ts': authentication failed", [Username]), - error("Bad CONNECT", "Access refused for user '" ++ - binary_to_list(Username) ++ "'", [], State); - {error, not_allowed} -> - ?LOG_WARNING("STOMP login failed for user '~ts': " - "virtual host access not allowed", [Username]), - error("Bad CONNECT", "Virtual host '" ++ - binary_to_list(VirtualHost) ++ - "' access denied", State); - {error, access_refused} -> - ?LOG_WARNING("STOMP login failed for user '~ts': " - "virtual host access not allowed", [Username]), - error("Bad CONNECT", "Virtual host '" ++ - binary_to_list(VirtualHost) ++ - "' access denied", State); - {error, not_loopback} -> - ?LOG_WARNING("STOMP login failed for user '~ts': " - "this user's access is restricted to localhost", [Username]), - error("Bad CONNECT", "non-loopback access denied", State) - end. - -start_connection(Params, Username, Addr) -> - case amqp_connection:start(Params) of - {ok, Conn} -> case rabbit_access_control:check_user_loopback( - Username, Addr) of - ok -> {ok, Conn}; - not_allowed -> amqp_connection:close(Conn), - {error, not_loopback} - end; - {error, E} -> {error, E} - end. - server_header() -> {ok, Product} = application:get_key(rabbit, description), {ok, Version} = application:get_key(rabbit, vsn), rabbit_misc:format("~ts/~ts", [Product, Version]). do_subscribe(Destination, DestHdr, Frame, - State = #proc_state{subscriptions = Subs, - channel = Channel, - default_topic_exchange = DfltTopicEx}) -> - check_subscription_access(Destination, State), + State0 = #state{subscriptions = Subs, + cfg = #cfg{default_topic_exchange = DfltTopicEx}, + queue_consumers = QCons}) -> + check_subscription_access(Destination, State0), + + {ok, {_Global, DefaultPrefetch}} = application:get_env(rabbit, default_consumer_prefetch), Prefetch = - rabbit_stomp_frame:integer_header(Frame, ?HEADER_PREFETCH_COUNT, - undefined), + rabbit_stomp_frame:integer_header(Frame, ?HEADER_PREFETCH_COUNT, DefaultPrefetch), + %% io:format("Prefetch: ~p~n", [Prefetch]), {AckMode, IsMulti} = rabbit_stomp_util:ack_mode(Frame), - case ensure_endpoint(source, Destination, Frame, Channel, State) of - {ok, Queue, RouteState1} -> - {ok, ConsumerTag, Description} = - rabbit_stomp_util:consumer_tag(Frame), - case Prefetch of - undefined -> ok; - _ -> amqp_channel:call( - Channel, #'basic.qos'{prefetch_count = Prefetch}) - end, + case ensure_endpoint(source, Destination, Frame, State0) of + {ok, QueueName, State} -> + {ok, ConsumerTag, Description} = rabbit_stomp_util:consumer_tag(Frame), case maps:find(ConsumerTag, Subs) of {ok, _} -> Message = "Duplicated subscription identifier", @@ -680,39 +782,40 @@ do_subscribe(Destination, DestHdr, Frame, ExchangeAndKey = parse_routing(Destination, DfltTopicEx), Arguments = subscribe_arguments(Frame), try - amqp_channel:subscribe(Channel, - #'basic.consume'{ - queue = Queue, - consumer_tag = ConsumerTag, - no_local = false, - no_ack = (AckMode == auto), - exclusive = false, - arguments = Arguments}, - self()), - ok = rabbit_stomp_routing_util:ensure_binding( - Queue, ExchangeAndKey, Channel) + {ok, State1} = consume_queue(QueueName, #{no_ack => (AckMode == auto), + mode => {simple_prefetch, Prefetch}, + consumer_tag => ConsumerTag, + exclusive_consume => false, + args => Arguments}, + State), + ok = ensure_binding(QueueName, ExchangeAndKey, State1), + CTags1 = case maps:find(QueueName, QCons) of + {ok, CTags} -> gb_sets:insert(ConsumerTag, CTags); + error -> gb_sets:singleton(ConsumerTag) + end, + QCons1 = maps:put(QueueName, CTags1, QCons), + ok(State1#state{subscriptions = maps:put( + ConsumerTag, + #subscription{dest_hdr = DestHdr, + ack_mode = AckMode, + multi_ack = IsMulti, + description = Description, + queue_name = QueueName}, + Subs), + queue_consumers = QCons1}) catch exit:Err -> %% it's safe to delete this queue, it %% was server-named and declared by us case Destination of {exchange, _} -> - ok = maybe_clean_up_queue(Queue, State); + ok = maybe_clean_up_queue(QueueName, State); {topic, _} -> - ok = maybe_clean_up_queue(Queue, State); + ok = maybe_clean_up_queue(QueueName, State); _ -> ok end, exit(Err) - end, - ok(State#proc_state{subscriptions = - maps:put( - ConsumerTag, - #subscription{dest_hdr = DestHdr, - ack_mode = AckMode, - multi_ack = IsMulti, - description = Description}, - Subs), - route_state = RouteState1}) + end end; {error, _} = Err -> Err @@ -736,7 +839,7 @@ subscribe_argument(?HEADER_X_STREAM_OFFSET, Frame, Acc) -> not_found -> Acc; {OffsetType, OffsetValue} -> - [{list_to_binary(?HEADER_X_STREAM_OFFSET), OffsetType, OffsetValue}] ++ Acc + [{?HEADER_X_STREAM_OFFSET, OffsetType, OffsetValue}] ++ Acc end; subscribe_argument(?HEADER_X_STREAM_FILTER, Frame, Acc) -> StreamFilter = rabbit_stomp_frame:stream_filter_header(Frame), @@ -744,13 +847,13 @@ subscribe_argument(?HEADER_X_STREAM_FILTER, Frame, Acc) -> not_found -> Acc; {FilterType, FilterValue} -> - [{list_to_binary(?HEADER_X_STREAM_FILTER), FilterType, FilterValue}] ++ Acc + [{?HEADER_X_STREAM_FILTER, FilterType, FilterValue}] ++ Acc end; subscribe_argument(?HEADER_X_STREAM_MATCH_UNFILTERED, Frame, Acc) -> MatchUnfiltered = rabbit_stomp_frame:boolean_header(Frame, ?HEADER_X_STREAM_MATCH_UNFILTERED), case MatchUnfiltered of {ok, MU} -> - [{list_to_binary(?HEADER_X_STREAM_MATCH_UNFILTERED), bool, MU}] ++ Acc; + [{?HEADER_X_STREAM_MATCH_UNFILTERED, bool, MU}] ++ Acc; not_found -> Acc end; @@ -758,151 +861,384 @@ subscribe_argument(?HEADER_X_PRIORITY, Frame, Acc) -> Priority = rabbit_stomp_frame:integer_header(Frame, ?HEADER_X_PRIORITY), case Priority of {ok, P} -> - [{list_to_binary(?HEADER_X_PRIORITY), byte, P}] ++ Acc; + [{?HEADER_X_PRIORITY, byte, P}] ++ Acc; not_found -> Acc end. check_subscription_access(Destination = {topic, _Topic}, - #proc_state{auth_login = _User, - connection = Connection, - default_topic_exchange = DfltTopicEx}) -> - [{amqp_params, AmqpParams}, {internal_user, InternalUser = #user{username = Username}}] = - amqp_connection:info(Connection, [amqp_params, internal_user]), - #amqp_params_direct{virtual_host = VHost} = AmqpParams, + #state{user = #user{username = Username} = User, + cfg = #cfg{ + default_topic_exchange = DfltTopicEx, + vhost = VHost}}) -> {Exchange, RoutingKey} = parse_routing(Destination, DfltTopicEx), Resource = #resource{virtual_host = VHost, - kind = topic, - name = rabbit_data_coercion:to_binary(Exchange)}, + kind = topic, + name = rabbit_data_coercion:to_binary(Exchange)}, Context = #{routing_key => rabbit_data_coercion:to_binary(RoutingKey), variable_map => #{<<"vhost">> => VHost, <<"username">> => Username} - }, - rabbit_access_control:check_topic_access(InternalUser, Resource, read, Context); + }, + rabbit_access_control:check_topic_access(User, Resource, read, Context); check_subscription_access(_, _) -> authorized. -maybe_clean_up_queue(Queue, #proc_state{connection = Connection}) -> - {ok, Channel} = amqp_connection:open_channel(Connection), - catch amqp_channel:call(Channel, #'queue.delete'{queue = Queue}), - catch amqp_channel:close(Channel), +maybe_clean_up_queue(Queue, #state{cfg = #cfg{auth_login = Username}}) -> + catch delete_queue(Queue, Username), ok. do_send(Destination, _DestHdr, - Frame = #stomp_frame{body_iolist = BodyFragments}, - State = #proc_state{channel = Channel, - default_topic_exchange = DfltTopicEx}) -> - case ensure_endpoint(dest, Destination, Frame, Channel, State) of - - {ok, _Q, RouteState1} -> - + Frame = #stomp_frame{body_iolist_rev = BodyFragments}, + State00 = #state{ + user = #user{username = Username} = User, + authz_ctx = AuthzCtx, + publisher = IsPublisher, + cfg = #cfg{ + proto_ver = ProtoVer, + delivery_flow = Flow, + conn_info = #conn_info{conn_name = ConnName}, + trace_state = TraceState, + default_topic_exchange = DfltTopicEx, + vhost = VHost}}) -> + State0 = case IsPublisher of + true -> State00; + false -> rabbit_global_counters:publisher_created(ProtoVer), + State00#state{publisher = true} + end, + case ensure_endpoint(dest, Destination, Frame, State0) of + {ok, _Q, State} -> {Frame1, State1} = - ensure_reply_to(Frame, State#proc_state{route_state = RouteState1}), + ensure_reply_to(Frame, State), Props = rabbit_stomp_util:message_properties(Frame1), - {Exchange, RoutingKey} = parse_routing(Destination, DfltTopicEx), - - Method = #'basic.publish'{ - exchange = list_to_binary(Exchange), - routing_key = list_to_binary(RoutingKey), - mandatory = false, - immediate = false}, - - case transactional(Frame1) of - {yes, Transaction} -> - extend_transaction( - Transaction, - fun(StateN) -> - maybe_record_receipt(Frame1, StateN) - end, - {Method, Props, BodyFragments}, - State1); - no -> - ok(send_method(Method, Props, BodyFragments, - maybe_record_receipt(Frame1, State1))) - end; + {Exchange0, RoutingKey} = parse_routing(Destination, DfltTopicEx), + + rabbit_global_counters:messages_received(ProtoVer, 1), + + ExchangeName = rabbit_misc:r(VHost, exchange, Exchange0), + check_resource_access(User, ExchangeName, write, AuthzCtx), + Exchange = rabbit_exchange:lookup_or_die(ExchangeName), + check_internal_exchange(Exchange), + check_topic_authorisation(Exchange, User, RoutingKey, AuthzCtx, write), + + {DeliveryOptions, _MsgSeqNo, State2} = + case rabbit_stomp_frame:header(Frame, ?HEADER_RECEIPT) of + not_found -> + {maps_put_truthy(flow, Flow, #{}), undefined, State1}; + {ok, Id} -> + rabbit_global_counters:messages_received_confirm(ProtoVer, 1), + SeqNo = State1#state.msg_seq_no, + StateRR = record_receipt(true, SeqNo, Id, State1), + Opts = maps_put_truthy(flow, Flow, #{correlation => SeqNo}), + {Opts, SeqNo, StateRR#state{msg_seq_no = SeqNo + 1}} + end, - {error, _} = Err -> + {ClassId, _MethodId} = rabbit_framing_amqp_0_9_1:method_id('basic.publish'), + + Content0 = #content{ + class_id = ClassId, + properties = Props, + properties_bin = none, + protocol = none, + payload_fragments_rev = BodyFragments + }, + + {ok, Message0} = mc_amqpl:message(ExchangeName, RoutingKey, Content0), + MsgIcptCtx = State2#state.cfg#cfg.msg_interceptor_ctx, + Message = rabbit_msg_interceptor:intercept_incoming(Message0, MsgIcptCtx), + QNames = rabbit_exchange:route(Exchange, Message, #{return_binding_keys => true}), + Queues = rabbit_db_queue:get_targets(QNames), + rabbit_trace:tap_in(Message, QNames, ConnName, Username, TraceState), + + Delivery = {Message, DeliveryOptions, Queues}, + deliver_to_queues(ExchangeName, Delivery, State2); + {error, _} = Err -> Err end. -create_ack_method(DeliveryTag, #subscription{multi_ack = IsMulti}, _) -> - #'basic.ack'{delivery_tag = DeliveryTag, - multiple = IsMulti}. +deliver_to_queues(_XName, + {_Message, Options, _RoutedToQueues = []}, + #state{cfg = #cfg{proto_ver = ProtoVer}} = State) + when not is_map_key(correlation, Options) -> + rabbit_global_counters:messages_unroutable_dropped(ProtoVer, 1), + {ok, State}; + +deliver_to_queues(XName, + {Message, Options, RoutedToQNames}, + State0 = #state{cfg = #cfg{proto_ver = ProtoVer}, + queue_states = QStates0}) -> + Qs = rabbit_amqqueue:prepend_extra_bcc(RoutedToQNames), + MsgSeqNo = maps:get(correlation, Options, undefined), + case rabbit_queue_type:deliver(Qs, Message, Options, QStates0) of + {ok, QStates, Actions} -> + rabbit_global_counters:messages_routed(ProtoVer, length(Qs)), + QueueNames = rabbit_amqqueue:queue_names(Qs), + State1 = process_routing_confirm(MsgSeqNo, QueueNames, XName, State0), + %% Actions must be processed after registering confirms as actions may + %% contain rejections of publishes. + {ok, handle_queue_actions(Actions, State1#state{queue_states = QStates})}; + {error, Reason} -> + log_error("Failed to deliver message with packet_id=~p to queues: ~p", + [MsgSeqNo, Reason], none), + {error, Reason, State0} + end. + + +record_rejects([], State) -> + State; +record_rejects(MXs, State = #state{rejected = R}) -> + State#state{rejected = [MXs | R]}. + +record_confirms([], State) -> + State; +record_confirms(MXs, State = #state{confirmed = C}) -> + State#state{confirmed = [MXs | C]}. + +process_routing_confirm(undefined, _, _, State) -> + State; +process_routing_confirm(MsgSeqNo, QRefs, XName, State) -> + State#state{unconfirmed = + rabbit_confirms:insert(MsgSeqNo, QRefs, XName, State#state.unconfirmed)}. + +confirm(MsgSeqNos, QRef, State = #state{unconfirmed = UC}) -> + %% NOTE: if queue name does not exist here it's likely that the ref also + %% does not exist in unconfirmed messages. + %% Neither does the 'ignore' atom, so it's a reasonable fallback. + {ConfirmMXs, UC1} = rabbit_confirms:confirm(MsgSeqNos, QRef, UC), + %% NB: don't call noreply/1 since we don't want to send confirms. + record_confirms(ConfirmMXs, State#state{unconfirmed = UC1}). + +send_confirms_and_nacks(State = #state{confirmed = [], rejected = []}) -> + State; +send_confirms_and_nacks(State = #state{confirmed = C, rejected = R}) -> + Confirms = lists:append(C), + Rejects = lists:append(R), + ConfirmMsgSeqNos = + lists:foldl( + fun ({MsgSeqNo, _XName}, MSNs) -> + [MsgSeqNo | MSNs] + end, [], Confirms), + RejectMsgSeqNos = [MsgSeqNo || {MsgSeqNo, _} <- Rejects], + State1 = send_confirms(ConfirmMsgSeqNos, + RejectMsgSeqNos, + State#state{confirmed = []}), + State1#state{rejected = []}. + +send_confirms([], _, State) -> + State; +send_confirms([MsgSeqNo], _, State) -> + flush_pending_receipts(MsgSeqNo, false, State); +send_confirms(Cs, Rs, State) -> + coalesce_and_send(Cs, Rs, + fun(MsgSeqNo, Multiple, StateN) -> + flush_pending_receipts(MsgSeqNo, Multiple, StateN) + end, State). + +coalesce_and_send(MsgSeqNos, NegativeMsgSeqNos, MkMsgFun, State = #state{unconfirmed = UC}) -> + SMsgSeqNos = lists:usort(MsgSeqNos), + UnconfirmedCutoff = case rabbit_confirms:is_empty(UC) of + true -> lists:last(SMsgSeqNos) + 1; + false -> rabbit_confirms:smallest(UC) + end, + Cutoff = lists:min([UnconfirmedCutoff | NegativeMsgSeqNos]), + {Ms, Ss} = lists:splitwith(fun(X) -> X < Cutoff end, SMsgSeqNos), + State1 = case Ms of + [] -> State; + _ -> MkMsgFun(lists:last(Ms), true, State) + end, + lists:foldl(fun(SeqNo, StateN) -> + MkMsgFun(SeqNo, false, StateN) + end, State1, Ss). + +handle_ack(DeliveryTag, #subscription{multi_ack = IsMulti}, _, State = #state{unacked_message_q = UAMQ}) -> + {Acked, Remaining} = collect_acks(UAMQ, DeliveryTag, IsMulti), + State1 = State#state{unacked_message_q = Remaining}, + {State2, Actions} = settle_acks(Acked, State1), + handle_queue_actions(Actions, State2). + +handle_nack(DeliveryTag, #subscription{multi_ack = IsMulti}, Requeue, State = #state{unacked_message_q = UAMQ}) -> + {Acked, Remaining} = collect_acks(UAMQ, DeliveryTag, IsMulti), + State1 = State#state{unacked_message_q = Remaining}, + {State2, Actions} = internal_reject(Requeue, Acked, State1), + handle_queue_actions(Actions, State2). + +%% Records a client-sent acknowledgement. Handles both single delivery acks +%% and multi-acks. +%% +%% Returns a tuple of acknowledged pending acks and remaining pending acks. +%% Sorts each group in the youngest-first order (descending by delivery tag). +collect_acks(UAMQ, DeliveryTag, Multiple) -> + collect_acks([], [], UAMQ, DeliveryTag, Multiple). + +collect_acks(AcknowledgedAcc, RemainingAcc, UAMQ, DeliveryTag, Multiple) -> + case ?QUEUE:out(UAMQ) of + {{value, UnackedMsg = #pending_ack{delivery_tag = CurrentDT}}, + UAMQTail} -> + if CurrentDT == DeliveryTag -> + {[UnackedMsg | AcknowledgedAcc], + case RemainingAcc of + [] -> UAMQTail; + _ -> ?QUEUE:join( + ?QUEUE:from_list(lists:reverse(RemainingAcc)), + UAMQTail) + end}; + Multiple -> + collect_acks([UnackedMsg | AcknowledgedAcc], RemainingAcc, + UAMQTail, DeliveryTag, Multiple); + true -> + collect_acks(AcknowledgedAcc, [UnackedMsg | RemainingAcc], + UAMQTail, DeliveryTag, Multiple) + end; + {empty, _} -> + error("Unknown delivery tag", + "unknown delivery tag ~w", [DeliveryTag]) + end. -create_nack_method(DeliveryTag, #subscription{multi_ack = IsMulti}, Requeue) -> - #'basic.nack'{delivery_tag = DeliveryTag, - multiple = IsMulti, - requeue = Requeue}. +foreach_per_queue(F, [#pending_ack{tag = CTag, + queue = QName, + msg_id = MsgId}], Acc) -> + %% quorum queue, needs the consumer tag + F({QName, CTag}, [MsgId], Acc); +foreach_per_queue(F, UAL, Acc) -> + T = lists:foldl(fun (#pending_ack{tag = CTag, + queue = QName, + msg_id = MsgId}, T) -> + rabbit_misc:gb_trees_cons({QName, CTag}, MsgId, T) + end, gb_trees:empty(), UAL), + rabbit_misc:gb_trees_fold(fun (Key, Val, Acc0) -> F(Key, Val, Acc0) end, Acc, T). + +settle_acks(Acks, State = #state{queue_states = QueueStates0}) -> + {QueueStates, Actions} = + foreach_per_queue( + fun ({QRef, CTag}, MsgIds, {Acc0, ActionsAcc0}) -> + case rabbit_queue_type:settle(QRef, complete, CTag, + MsgIds, Acc0) of + {ok, Acc, ActionsAcc} -> + %% incr_queue_stats(QRef, MsgIds, State), + {Acc, ActionsAcc0 ++ ActionsAcc}; + {protocol_error, ErrorType, Reason, ReasonArgs} -> + rabbit_misc:protocol_error(ErrorType, Reason, ReasonArgs) + end + end, Acks, {QueueStates0, []}), + {State#state{queue_states = QueueStates}, Actions}. + +%% NB: Acked is in youngest-first order +internal_reject(Requeue, Acked, + State = #state{queue_states = QueueStates0}) -> + {QueueStates, Actions} = + foreach_per_queue( + fun({QRef, CTag}, MsgIds, {Acc0, Actions0}) -> + Op = case Requeue of + false -> discard; + true -> requeue + end, + case rabbit_queue_type:settle(QRef, Op, CTag, MsgIds, Acc0) of + {ok, Acc, Actions} -> + {Acc, Actions0 ++ Actions}; + {protocol_error, ErrorType, Reason, ReasonArgs} -> + rabbit_misc:protocol_error(ErrorType, Reason, ReasonArgs) + end + end, Acked, {QueueStates0, []}), + {State#state{queue_states = QueueStates}, Actions}. negotiate_version(Frame) -> ClientVers = re:split(rabbit_stomp_frame:header( - Frame, ?HEADER_ACCEPT_VERSION, "1.0"), - ",", [{return, list}]), + Frame, ?HEADER_ACCEPT_VERSION, <<"1.0">>), + <<",">>, [{return, list}]), rabbit_stomp_util:negotiate_version(ClientVers, ?SUPPORTED_VERSIONS). -send_delivery(Delivery = #'basic.deliver'{consumer_tag = ConsumerTag}, +deliver_to_client(ConsumerTag, Ack, Msgs, State) -> + lists:foldl(fun(Msg, S) -> + deliver_one_to_client(ConsumerTag, Ack, Msg, S) + end, State, Msgs). + +deliver_one_to_client(ConsumerTag, _Ack, + {QName, QPid, MsgId, Redelivered, MsgCont0} = Msg, + State = #state{queue_states = QStates, + delivery_tag = DeliveryTag, + cfg = #cfg{trace_state = TraceState, + msg_interceptor_ctx = MsgIcptCtx, + conn_info = #conn_info{conn_name = ConnName}}, + user = #user{username = Username}}) -> + + [RoutingKey | _] = mc:routing_keys(MsgCont0), + ExchangeNameBin = mc:exchange(MsgCont0), + MsgCont1 = rabbit_msg_interceptor:intercept_outgoing(MsgCont0, MsgIcptCtx), + MsgCont = mc:convert(mc_amqpl, MsgCont1), + Content = mc:protocol_state(MsgCont), + {Props, Payload} = rabbit_basic_common:from_content(Content), + + rabbit_trace:tap_out(Msg, ConnName, Username, TraceState), + + DeliveryCtx = case rabbit_queue_type:module(QName, QStates) of + {ok, rabbit_classic_queue} -> + {ok, QPid, ok}; + _ -> undefined + end, + + State1 = send_delivery(QName, MsgId, + ConsumerTag, DeliveryTag, + ExchangeNameBin, RoutingKey, Redelivered, + Props, Payload, DeliveryCtx, State), + + State1#state{delivery_tag = DeliveryTag + 1}. + + +send_delivery(QName, MsgId, + ConsumerTag, DeliveryTag, + ExchangeNameBin, RoutingKey, Redelivered, Properties, Body, DeliveryCtx, - State = #proc_state{ - session_id = SessionId, - subscriptions = Subs, - version = Version}) -> - NewState = case maps:find(ConsumerTag, Subs) of + State = #state{ + cfg = #cfg{ + session_id = SessionId, + version = Version + }, + subscriptions = Subs, + unacked_message_q = UAMQ}) -> + case maps:find(ConsumerTag, Subs) of {ok, #subscription{ack_mode = AckMode}} -> - send_frame( - "MESSAGE", - rabbit_stomp_util:headers(SessionId, Delivery, Properties, - AckMode, Version), - Body, - State); + NewState = send_frame( + 'MESSAGE', + rabbit_stomp_util:headers( + SessionId, ConsumerTag, DeliveryTag, + ExchangeNameBin, RoutingKey, Redelivered, + Properties, AckMode, Version), + Body, + State), + maybe_notify_sent(DeliveryCtx), + case AckMode of + client -> + DeliveredAt = os:system_time(millisecond), + NewState#state{unacked_message_q = + ?QUEUE:in(#pending_ack{delivery_tag = DeliveryTag, + tag = ConsumerTag, + delivered_at = DeliveredAt, + queue = QName, + msg_id = MsgId}, UAMQ)}; + _ -> NewState + end; error -> send_error("Subscription not found", "There is no current subscription with tag '~ts'.", [ConsumerTag], State) - end, - notify_received(DeliveryCtx), - NewState. - -notify_received(undefined) -> - %% no notification for quorum queues and streams - ok; -notify_received(DeliveryCtx) -> - %% notification for flow control - amqp_channel:notify_received(DeliveryCtx). - -send_method(Method, Channel, State) -> - amqp_channel:call(Channel, Method), - State. - -send_method(Method, State = #proc_state{channel = Channel}) -> - send_method(Method, Channel, State). + end. -send_method(Method, Properties, BodyFragments, - State = #proc_state{channel = Channel}) -> - send_method(Method, Channel, Properties, BodyFragments, State). +maybe_notify_sent(undefined) -> + ok; +maybe_notify_sent({_, QPid, _}) -> + ok = rabbit_amqqueue:notify_sent(QPid, self()). -send_method(Method = #'basic.publish'{}, Channel, Properties, BodyFragments, - State) -> - amqp_channel:cast_flow( - Channel, Method, - #amqp_msg{props = Properties, - payload = list_to_binary(BodyFragments)}), +close_connection(State = #state{publisher = IsPublisher, + cfg = #cfg{proto_ver = ProtoVer}}) -> + case IsPublisher andalso ProtoVer =/= undefined of + true -> rabbit_global_counters:publisher_deleted(ProtoVer); + false -> ok + end, State. -close_connection(State = #proc_state{connection = none}) -> - State; -%% Closing the connection will close the channel and subchannels -close_connection(State = #proc_state{connection = Connection}) -> - %% ignore noproc or other exceptions to avoid debris - catch amqp_connection:close(Connection), - State#proc_state{channel = none, connection = none, subscriptions = none}; -close_connection(undefined) -> - ?LOG_DEBUG("~ts:close_connection: undefined state", [?MODULE]), - #proc_state{channel = none, connection = none, subscriptions = none}. - %%---------------------------------------------------------------------------- %% Reply-To %%---------------------------------------------------------------------------- @@ -912,53 +1248,46 @@ ensure_reply_to(Frame = #stomp_frame{headers = Headers}, State) -> not_found -> {Frame, State}; {ok, ReplyTo} -> - {ok, Destination} = rabbit_routing_parser:parse_endpoint(ReplyTo), - case rabbit_stomp_routing_util:dest_temp_queue(Destination) of + {ok, Destination} = parse_endpoint(ReplyTo), + case dest_temp_queue(Destination) of none -> {Frame, State}; TempQueueId -> {ReplyQueue, State1} = ensure_reply_queue(TempQueueId, State), {Frame#stomp_frame{ - headers = lists:keyreplace( - ?HEADER_REPLY_TO, 1, Headers, - {?HEADER_REPLY_TO, ReplyQueue})}, + headers = Headers#{?HEADER_REPLY_TO => ReplyQueue}}, State1} end end. -ensure_reply_queue(TempQueueId, State = #proc_state{channel = Channel, - reply_queues = RQS, - subscriptions = Subs}) -> +ensure_reply_queue(TempQueueId, State = #state{reply_queues = RQS, + subscriptions = Subs}) -> case maps:find(TempQueueId, RQS) of {ok, RQ} -> - {binary_to_list(RQ), State}; + {RQ, State}; error -> - #'queue.declare_ok'{queue = Queue} = - amqp_channel:call(Channel, - #'queue.declare'{auto_delete = true, - exclusive = true}), + {ok, Queue} = create_queue(State), + #resource{name = QNameBin} = QName = amqqueue:get_name(Queue), ConsumerTag = rabbit_stomp_util:consumer_tag_reply_to(TempQueueId), - #'basic.consume_ok'{} = - amqp_channel:subscribe(Channel, - #'basic.consume'{ - queue = Queue, - consumer_tag = ConsumerTag, - no_ack = true, - nowait = false}, - self()), - Destination = binary_to_list(Queue), + {ok, {_Global, DefaultPrefetch}} = application:get_env(rabbit, default_consumer_prefetch), + Spec = #{no_ack => true, + mode => {simple_prefetch, DefaultPrefetch}, + consumer_tag => ConsumerTag, + exclusive_consume => false, + args => []}, + {ok, State1} = consume_queue(QName, Spec, State), %% synthesise a subscription to the reply queue destination Subs1 = maps:put(ConsumerTag, - #subscription{dest_hdr = Destination, + #subscription{dest_hdr = QNameBin, multi_ack = false}, Subs), - {Destination, State#proc_state{ - reply_queues = maps:put(TempQueueId, Queue, RQS), + {QNameBin, State1#state{ + reply_queues = maps:put(TempQueueId, QNameBin, RQS), subscriptions = Subs1}} end. @@ -972,39 +1301,22 @@ ensure_receipt(Frame = #stomp_frame{command = Command}, State) -> not_found -> State end. -do_receipt("SEND", _, State) -> +do_receipt('SEND', _, State) -> %% SEND frame receipts are handled when messages are confirmed State; do_receipt(_Frame, ReceiptId, State) -> - send_frame("RECEIPT", [{"receipt-id", ReceiptId}], "", State). + send_frame('RECEIPT', #{<<"receipt-id">> => ReceiptId}, <<>>, State). -maybe_record_receipt(Frame, State = #proc_state{channel = Channel, - pending_receipts = PR}) -> - case rabbit_stomp_frame:header(Frame, ?HEADER_RECEIPT) of - {ok, Id} -> - PR1 = case PR of - undefined -> - amqp_channel:register_confirm_handler( - Channel, self()), - #'confirm.select_ok'{} = - amqp_channel:call(Channel, #'confirm.select'{}), - gb_trees:empty(); - _ -> - PR - end, - SeqNo = amqp_channel:next_publish_seqno(Channel), - State#proc_state{pending_receipts = gb_trees:insert(SeqNo, Id, PR1)}; - not_found -> - State - end. +record_receipt(_DoConfirm = true, MsgSeqNo, ReceiptId, State = #state{pending_receipts = PR}) -> + State#state{pending_receipts = gb_trees:insert(MsgSeqNo, ReceiptId, PR)}. flush_pending_receipts(DeliveryTag, IsMulti, - State = #proc_state{pending_receipts = PR}) -> + State = #state{pending_receipts = PR}) -> {Receipts, PR1} = accumulate_receipts(DeliveryTag, IsMulti, PR), State1 = lists:foldl(fun(ReceiptId, StateN) -> do_receipt(none, ReceiptId, StateN) end, State, Receipts), - State1#proc_state{pending_receipts = PR1}. + State1#state{pending_receipts = PR1}. accumulate_receipts(DeliveryTag, false, PR) -> case gb_trees:lookup(DeliveryTag, PR) of @@ -1051,6 +1363,18 @@ transactional_action(Frame, Name, Fun, State) -> State) end. +maybe_with_transaction(Frame, Fun, State) -> + case transactional(Frame) of + {yes, Transaction} -> + extend_transaction( + Transaction, + Fun, + Frame, + State); + no -> + Fun(State) + end. + with_transaction(Transaction, State, Fun) -> case get({transaction, Transaction}) of undefined -> @@ -1063,45 +1387,60 @@ with_transaction(Transaction, State, Fun) -> end. begin_transaction(Transaction, State) -> - put({transaction, Transaction}, []), - ok(State). - -extend_transaction(Transaction, Callback, Action, State) -> - extend_transaction(Transaction, {callback, Callback, Action}, State). + case transaction_count() >= ?MAX_TRANSACTIONS of + true -> + error("Transaction limit exceeded", + "Too many concurrent transactions (limit ~B)", + [?MAX_TRANSACTIONS], State); + false -> + put({transaction, Transaction}, []), + ok(State) + end. -extend_transaction(Transaction, Action, State0) -> +extend_transaction(Transaction, Fun, Frame, State0) -> with_transaction( Transaction, State0, - fun (Actions, State) -> - put({transaction, Transaction}, [Action | Actions]), - ok(State) + fun (Funs, State) -> + case length(Funs) >= ?MAX_TRANSACTION_ACTIONS of + true -> + error("Transaction too large", + "Too many actions in transaction (limit ~B)", + [?MAX_TRANSACTION_ACTIONS], State); + false -> + put({transaction, Transaction}, [{Frame, Fun} | Funs]), + ok(State) + end end). +transaction_count() -> + length([1 || {{transaction, _}, _} <- get()]). + commit_transaction(Transaction, State0) -> with_transaction( Transaction, State0, - fun (Actions, State) -> + fun (Funs, State) -> FinalState = lists:foldr(fun perform_transaction_action/2, - State, - Actions), + {ok, State}, + Funs), erase({transaction, Transaction}), - ok(FinalState) + FinalState end). abort_transaction(Transaction, State0) -> with_transaction( Transaction, State0, - fun (_Actions, State) -> + fun (_Frames, State) -> erase({transaction, Transaction}), ok(State) end). -perform_transaction_action({callback, Callback, Action}, State) -> - perform_transaction_action(Action, Callback(State)); -perform_transaction_action({Method}, State) -> - send_method(Method, State); -perform_transaction_action({Method, Props, BodyFragments}, State) -> - send_method(Method, Props, BodyFragments, State). +perform_transaction_action(_, {stop, _, _} = Res) -> + Res; +perform_transaction_action({Frame, Fun}, {ok, State}) -> + process_request( + Fun, + fun(StateM) -> ensure_receipt(Frame, StateM) end, + State). %%-------------------------------------------------------------------- %% Heartbeat Management @@ -1109,8 +1448,8 @@ perform_transaction_action({Method, Props, BodyFragments}, State) -> ensure_heartbeats(Heartbeats) -> - [CX, CY] = [list_to_integer(X) || - X <- re:split(Heartbeats, ",", [{return, list}])], + [CX, CY] = [binary_to_integer(X) || + X <- binary:split(Heartbeats, <<",">>)], {SendTimeout, ReceiveTimeout} = {millis_to_seconds(CY), millis_to_seconds(CX)}, @@ -1126,34 +1465,27 @@ millis_to_seconds(M) -> M div 1000. %% Queue Setup %%---------------------------------------------------------------------------- -ensure_endpoint(_Direction, {queue, []}, _Frame, _Channel, _State) -> +ensure_endpoint(_Direction, {queue, <<>>}, _Frame, _State) -> {error, {invalid_destination, "Destination cannot be blank"}}; -ensure_endpoint(source, EndPoint, {_, _, Headers, _} = Frame, Channel, - #proc_state{virtual_host = VHost, route_state = RouteState}) -> +ensure_endpoint(source, EndPoint, {_, _, Headers, _} = Frame, State) -> Params = [{subscription_queue_name_gen, fun () -> - Id = build_subscription_id(Frame), - % Note: we discard the exchange here so there's no need to use - % the default_topic_exchange configuration key - {_, Name} = rabbit_routing_parser:parse_routing(EndPoint), - list_to_binary(rabbit_stomp_util:subscription_queue_name(Name, Id, Frame)) + Id = build_subscription_id(Frame), + % Note: we discard the exchange here so there's no need to use + % the default_topic_exchange configuration key + {_, Name} = parse_routing(EndPoint), + rabbit_stomp_util:subscription_queue_name(Name, Id, Frame) end - }, - {default_queue_type, rabbit_vhost:default_queue_type(VHost)}] - ++ rabbit_stomp_util:build_params(EndPoint, Headers), + }] ++ rabbit_stomp_util:build_params(EndPoint, Headers), Arguments = rabbit_stomp_util:build_arguments(Headers), - rabbit_stomp_routing_util:ensure_endpoint(source, Channel, EndPoint, - [Arguments | Params], RouteState); + util_ensure_endpoint(source, EndPoint, [Arguments | Params], State); -ensure_endpoint(Direction, EndPoint, {_, _, Headers, _}, Channel, - #proc_state{virtual_host = VHost, route_state = RouteState}) -> - Params = [{default_queue_type, rabbit_vhost:default_queue_type(VHost)} - | rabbit_stomp_util:build_params(EndPoint, Headers)], +ensure_endpoint(Direction, EndPoint, {_, _, Headers, _}, State) -> + Params = rabbit_stomp_util:build_params(EndPoint, Headers), Arguments = rabbit_stomp_util:build_arguments(Headers), - rabbit_stomp_routing_util:ensure_endpoint(Direction, Channel, EndPoint, - [Arguments | Params], RouteState). + util_ensure_endpoint(Direction, EndPoint, [Arguments | Params], State). build_subscription_id(Frame) -> case rabbit_stomp_util:has_durable_header(Frame) of @@ -1174,14 +1506,14 @@ ok(State) -> ok(Command, Headers, BodyFragments, State) -> {ok, #stomp_frame{command = Command, headers = Headers, - body_iolist = BodyFragments}, State}. + body_iolist_rev = lists:reverse(BodyFragments)}, State}. -amqp_death(access_refused = ErrorName, Explanation, State) -> +amqp_death(ErrorName, Explanation, State) when is_atom(ErrorName) -> ErrorDesc = rabbit_misc:format("~ts", [Explanation]), log_error(ErrorName, ErrorDesc, none), {stop, normal, close_connection(send_error(atom_to_list(ErrorName), ErrorDesc, State))}; amqp_death(ReplyCode, Explanation, State) -> - ErrorName = amqp_connection:error_atom(ReplyCode), + ErrorName = rabbit_framing_amqp_0_9_1:amqp_exception(ReplyCode), ErrorDesc = rabbit_misc:format("~ts", [Explanation]), log_error(ErrorName, ErrorDesc, none), {stop, normal, close_connection(send_error(atom_to_list(ErrorName), ErrorDesc, State))}. @@ -1214,11 +1546,11 @@ log_error(Message, Detail, ServerPrivateDetail) -> send_frame(Command, Headers, BodyFragments, State) -> send_frame(#stomp_frame{command = Command, headers = Headers, - body_iolist = BodyFragments}, + body_iolist_rev = BodyFragments}, State). -send_frame(Frame, State = #proc_state{send_fun = SendFun, - trailing_lf = TrailingLF}) -> +send_frame(Frame, State = #state{cfg = #cfg{send_fun = SendFun, + trailing_lf = TrailingLF}}) -> SendFun(rabbit_stomp_frame:serialize(Frame, TrailingLF)), State. @@ -1227,11 +1559,11 @@ send_error_frame(Message, ExtraHeaders, Format, Args, State) -> State). send_error_frame(Message, ExtraHeaders, Detail, State) -> - send_frame("ERROR", [{"message", Message}, - {"content-type", "text/plain"}, - {"version", string:join(?SUPPORTED_VERSIONS, ",")}] ++ - ExtraHeaders, - Detail, State). + BaseHeaders = #{<<"message">> => iolist_to_binary(Message), + <<"content-type">> => <<"text/plain">>, + <<"version">> => iolist_to_binary(string:join(?SUPPORTED_VERSIONS, ","))}, + Headers = maps:merge(BaseHeaders, maps:from_list(ExtraHeaders)), + send_frame('ERROR', Headers, iolist_to_binary(Detail), State). send_error(Message, Detail, State) -> send_error_frame(Message, [], Detail, State). @@ -1239,26 +1571,478 @@ send_error(Message, Detail, State) -> send_error(Message, Format, Args, State) -> send_error(Message, rabbit_misc:format(Format, Args), State). -additional_info(Key, - #proc_state{adapter_info = - #amqp_adapter_info{additional_info = AddInfo}}) -> - proplists:get_value(Key, AddInfo). - parse_routing(Destination, DefaultTopicExchange) -> - {Exchange0, RoutingKey} = rabbit_routing_parser:parse_routing(Destination), + {Exchange0, RoutingKey} = parse_routing(Destination), Exchange1 = maybe_apply_default_topic_exchange(Exchange0, DefaultTopicExchange), {Exchange1, RoutingKey}. -maybe_apply_default_topic_exchange("amq.topic"=Exchange, <<"amq.topic">>=_DefaultTopicExchange) -> +maybe_apply_default_topic_exchange(<<"amq.topic">>=Exchange, <<"amq.topic">>=_DefaultTopicExchange) -> %% This is the case where the destination is the same %% as the default of amq.topic Exchange; -maybe_apply_default_topic_exchange("amq.topic"=_Exchange, DefaultTopicExchange) -> +maybe_apply_default_topic_exchange(<<"amq.topic">>=_Exchange, DefaultTopicExchange) -> %% This is the case where the destination would have been %% amq.topic but we have configured a different default - binary_to_list(DefaultTopicExchange); + DefaultTopicExchange; maybe_apply_default_topic_exchange(Exchange, _DefaultTopicExchange) -> %% This is the case where the destination is different than %% amq.topic, so it must have been specified in the %% message headers Exchange. + +create_queue(_State = #state{authz_ctx = AuthzCtx, + user = #user{username = Username} = User, + cfg = #cfg{vhost = VHost}}) -> + QNameBin = rabbit_guid:binary(rabbit_guid:gen_secure(), "stomp.gen"), + QName = rabbit_misc:r(VHost, queue, QNameBin), + + %% configure access to queue required for queue.declare + ok = check_resource_access(User, QName, configure, AuthzCtx), + case rabbit_vhost_limit:is_over_queue_limit(VHost) of + false -> + rabbit_core_metrics:queue_declared(QName), + + case rabbit_amqqueue:declare(QName, _Durable = false, _AutoDelete = true, + [], self(), Username) of + {new, Q} when ?is_amqqueue(Q) -> + rabbit_core_metrics:queue_created(QName), + {ok, Q}; + Other -> + log_error(rabbit_misc:format("Failed to declare ~s: ~p", [rabbit_misc:rs(QName)]), Other, none), + {error, queue_declare} + end; + {true, Limit} -> + log_error(rabbit_misc:format("cannot declare ~s because ", [rabbit_misc:rs(QName)]), + rabbit_misc:format("queue limit ~p in vhost '~s' is reached", [Limit, VHost]), + none), + {error, queue_limit_exceeded} + end. + +delete_queue(QRes, Username) -> + case rabbit_amqqueue:with( + QRes, + fun (Q) -> + rabbit_queue_type:delete(Q, false, false, Username) + end, + fun (not_found) -> + ok; + ({absent, Q, crashed}) -> + rabbit_classic_queue:delete_crashed(Q, Username); + ({absent, Q, stopped}) -> + rabbit_classic_queue:delete_crashed(Q, Username); + ({absent, _Q, _Reason}) -> + ok + end) of + {ok, _N} -> + ok; + ok -> + ok + end. + +ensure_binding(#resource{name = QueueBin}, {<<>>, QueueBin}, _State) -> + %% i.e., we should only be asked to bind to the default exchange a + %% queue with its own name + ok; +ensure_binding(QName, {Exchange, RoutingKey}, _State = #state{cfg = #cfg{ + auth_login = Username, + vhost = VHost}}) -> + Binding = #binding{source = rabbit_misc:r(VHost, exchange, Exchange), + destination = QName, + key = RoutingKey}, + case rabbit_binding:add(Binding, Username) of + {error, {resources_missing, [{not_found, Name} | _]}} -> + rabbit_amqqueue:not_found(Name); + {error, {resources_missing, [{absent, Q, Reason} | _]}} -> + rabbit_amqqueue:absent(Q, Reason); + {error, {binding_invalid, Fmt, Args}} -> + rabbit_misc:protocol_error(precondition_failed, Fmt, Args); + {error, #amqp_error{} = Error} -> + rabbit_misc:protocol_error(Error); + ok -> + ok + end. + +check_resource_access(User, Resource, Perm, Context) -> + V = {Resource, Context, Perm}, + Cache = case get(permission_cache) of + undefined -> []; + Other -> Other + end, + case lists:member(V, Cache) of + true -> + ok; + false -> + rabbit_access_control:check_resource_access(User, Resource, Perm, Context), + CacheTail = lists:sublist(Cache, ?MAX_PERMISSION_CACHE_SIZE-1), + put(permission_cache, [V | CacheTail]), + ok + end. + +handle_down({{'DOWN', QName}, _MRef, process, QPid, Reason}, + State0 = #state{queue_states = QStates0} = State) -> + case rabbit_queue_type:handle_down(QPid, QName, Reason, QStates0) of + {ok, QStates1, Actions} -> + State1 = State0#state{queue_states = QStates1}, + State2 = handle_queue_actions(Actions, State1), + {ok, State2}; + {eol, QStates1, QRef} -> + State1 = handle_consuming_queue_down_or_eol(QRef, State#state{queue_states = QStates1}), + {ConfirmMXs, UC1} = + rabbit_confirms:remove_queue(QRef, State1#state.unconfirmed), + State2 = record_confirms(ConfirmMXs, + State1#state{unconfirmed = UC1}), + _ = erase_queue_stats(QRef), + {ok, State2#state{queue_states = rabbit_queue_type:remove(QRef, State2#state.queue_states)}} + end. + +handle_queue_event({queue_event, QRef, Evt}, #state{queue_states = QStates0} = State) -> + case rabbit_queue_type:handle_event(QRef, Evt, QStates0) of + {ok, QState1, Actions} -> + State1 = State#state{queue_states = QState1}, + try handle_queue_actions(Actions, State1) of + State2 -> + {ok, State2} + catch throw:Reason when Reason =:= consumer_timeout -> + {error, Reason, State1} + end; + {eol, Actions} -> + State1 = handle_queue_actions(Actions, State), + State2 = handle_consuming_queue_down_or_eol(QRef, State1), + {ConfirmMXs, UC1} = + rabbit_confirms:remove_queue(QRef, State1#state.unconfirmed), + %% Deleted queue is a special case. + %% Do not nack the "rejected" messages. + State3 = record_confirms(ConfirmMXs, + State2#state{unconfirmed = UC1}), + {ok, State3#state{queue_states = rabbit_queue_type:remove(QRef, QStates0)}}; + {protocol_error, Type, Reason, ReasonArgs} = Error -> + log_error(Type, Reason, ReasonArgs), + {error, Error, State} + end. + +handle_queue_actions(Actions, #state{} = State0) -> + lists:foldl( + fun ({deliver, ConsumerTag, Ack, Msgs}, S) -> + deliver_to_client(ConsumerTag, Ack, Msgs, S); + ({settled, QRef, MsgSeqNos}, S0) -> + S = confirm(MsgSeqNos, QRef, S0), + send_confirms_and_nacks(S); + ({rejected, _QRef, MsgSeqNos}, S0) -> + {U, Rej} = + lists:foldr( + fun(SeqNo, {U1, Acc}) -> + case rabbit_confirms:reject(SeqNo, U1) of + {ok, MX, U2} -> + {U2, [MX | Acc]}; + {error, not_found} -> + {U1, Acc} + end + end, {S0#state.unconfirmed, []}, MsgSeqNos), + S = S0#state{unconfirmed = U}, + record_rejects(Rej, S); + ({queue_down, QRef}, S0) -> + handle_consuming_queue_down_or_eol(QRef, S0); + ({released, QName, _CTag, _MsgSeqNos, timeout}, _S) -> + ?LOG_INFO("Terminating STOMP connection because consumer " + "on ~ts timed out", [rabbit_misc:rs(QName)]), + throw(consumer_timeout); + (_, S0) -> + S0 + end, State0, Actions). + + + +parse_endpoint(undefined) -> + parse_endpoint(<<"/queue">>); +parse_endpoint(Destination) when is_binary(Destination) -> + case binary:split(Destination, <<"/">>, [global]) of + [Name] -> + {ok, {queue, unescape(Name)}}; + [<<>>, <<"exchange">> | Rest] -> + parse_endpoint0(exchange, Rest); + [<<>>, <<"queue">> | Rest] -> + parse_endpoint0(queue, Rest); + [<<>>, <<"topic">> | Rest] -> + parse_endpoint0(topic, Rest); + [<<>>, <<"temp-queue">> | Rest] -> + parse_endpoint0(temp_queue, Rest); + [<<>>, <<"amq">>, <<"queue">> | Rest] -> + parse_endpoint0(amqqueue, Rest); + [<<>>, <<"reply-queue">> | [_|_]] -> + %% Reply queue names can have slashes, so take everything + %% after "/reply-queue/" + PrefixLen = byte_size(<<"/reply-queue/">>), + parse_endpoint0(reply_queue, + [binary:part(Destination, PrefixLen, byte_size(Destination) - PrefixLen)]); + _ -> + {error, {unknown_destination, Destination}} + end. + +parse_endpoint0(exchange, [<<>> | _] = Rest) -> + {error, {invalid_destination, exchange, to_url(Rest)}}; +parse_endpoint0(exchange, [Name]) -> + {ok, {exchange, {unescape(Name), undefined}}}; +parse_endpoint0(exchange, [Name, Pattern]) -> + {ok, {exchange, {unescape(Name), unescape(Pattern)}}}; +parse_endpoint0(queue, []) -> + {error, {invalid_destination, queue, []}}; +parse_endpoint0(Type, [Name]) when Name =/= <<>> -> + {ok, {Type, unescape(Name)}}; +parse_endpoint0(Type, Rest) -> + {error, {invalid_destination, Type, to_url(Rest)}}. + +%% -------------------------------------------------------------------------- + +util_ensure_endpoint(source, {exchange, {Name, _}}, Params, State = #state{cfg = #cfg{vhost = VHost}}) -> + ExchangeName = rabbit_misc:r(Name, exchange, VHost), + check_exchange(ExchangeName, proplists:get_value(check_exchange, Params, false)), + Amqqueue = new_amqqueue(undefined, exchange, Params, State), + {ok, Queue} = create_queue(Amqqueue, State), + {ok, amqqueue:get_name(Queue), State}; + +util_ensure_endpoint(source, {topic, _}, Params, State) -> + Amqqueue = new_amqqueue(undefined, topic, Params, State), + {ok, Queue} = create_queue(Amqqueue, State), + {ok, amqqueue:get_name(Queue), State}; + +util_ensure_endpoint(_Dir, {queue, undefined}, _Params, State) -> + {ok, undefined, State}; + +util_ensure_endpoint(_, {queue, Name}, Params, State=#state{route_state = RoutingState, + cfg = #cfg{vhost = VHost}}) -> + Params1 = rabbit_misc:pmerge(durable, true, Params), + QueueNameBin = Name, + RState1 = case sets:is_element(QueueNameBin, RoutingState) of + true -> RoutingState; + _ -> Amqqueue = new_amqqueue(QueueNameBin, queue, Params1, State), + {ok, Queue} = create_queue(Amqqueue, State), + #resource{name = QNameBin} = amqqueue:get_name(Queue), + sets:add_element(QNameBin, RoutingState) + end, + {ok, rabbit_misc:r(VHost, queue, QueueNameBin), State#state{route_state = RState1}}; + +util_ensure_endpoint(dest, {exchange, {Name, _}}, Params, State = #state{cfg = #cfg{vhost = VHost}}) -> + ExchangeName = rabbit_misc:r(Name, exchange, VHost), + check_exchange(ExchangeName, proplists:get_value(check_exchange, Params, false)), + {ok, undefined, State}; + +util_ensure_endpoint(dest, {topic, _}, _Params, State) -> + {ok, undefined, State}; + +util_ensure_endpoint(_, {amqqueue, Name}, _Params, State = #state{cfg = #cfg{vhost = VHost}}) -> + {ok, rabbit_misc:r(VHost, queue, Name), State}; + +util_ensure_endpoint(_, {reply_queue, Name}, _Params, State = #state{cfg = #cfg{vhost = VHost}}) -> + {ok, rabbit_misc:r(VHost, queue, Name), State}; + +util_ensure_endpoint(_Direction, _Endpoint, _Params, _State) -> + {error, invalid_endpoint}. + + +%% -------------------------------------------------------------------------- + +parse_routing({exchange, {Name, undefined}}) -> + {Name, <<>>}; +parse_routing({exchange, {Name, Pattern}}) -> + {Name, Pattern}; +parse_routing({topic, Name}) -> + {<<"amq.topic">>, Name}; +parse_routing({Type, Name}) + when Type =:= queue orelse Type =:= reply_queue orelse Type =:= amqqueue -> + {<<>>, Name}. + +dest_temp_queue({temp_queue, Name}) -> Name; +dest_temp_queue(_) -> none. + +%% -------------------------------------------------------------------------- + +check_exchange(_, false) -> + ok; +check_exchange(ExchangeName, true) -> + _ = rabbit_exchange:lookup_or_die(ExchangeName), + ok. + +new_amqqueue(QNameBin0, Type, Params0, _State = #state{user = #user{username = Username}, + cfg = #cfg{vhost = VHost}}) -> + QNameBin = case {Type, proplists:get_value(subscription_queue_name_gen, Params0)} of + {topic, SQNG} when is_function(SQNG) -> + SQNG(); + {exchange, SQNG} when is_function(SQNG) -> + SQNG(); + _ -> + QNameBin0 + end, + QName = rabbit_misc:r(VHost, queue, QNameBin), + %% defaults + Params = case proplists:get_value(durable, Params0, false) of + false -> [{auto_delete, true}, {exclusive, true} | Params0]; + true -> Params0 + end, + Args = proplists:get_value(arguments, Params, []), + + amqqueue:new(QName, + none, + proplists:get_value(durable, Params, false), + proplists:get_value(auto_delete, Params, false), + case proplists:get_value(exclusive, Params, false) of + false -> none; + true -> self() + end, + Args, + VHost, + #{user => Username}, + rabbit_amqqueue:get_queue_type(Args)). + + +to_url([]) -> <<>>; +to_url(Lol) -> iolist_to_binary([$/ | lists:join($/, Lol)]). + +unescape(Bin) -> unescape(Bin, []). + +unescape(<<>>, Acc) -> list_to_binary(lists:reverse(Acc)); +unescape(<<"%2F", Rest/binary>>, Acc) -> unescape(Rest, [$/ | Acc]); +unescape(<>, Acc) -> unescape(Rest, [C | Acc]). + + +consume_queue(QRes, Spec0, State = #state{user = #user{username = Username} = User, + authz_ctx = AuthzCtx, + queue_states = QStates0}) -> + check_resource_access(User, QRes, read, AuthzCtx), + Spec = Spec0#{channel_pid => self(), + limiter_pid => none, + limiter_active => false, + ok_msg => undefined, + acting_user => Username}, + rabbit_amqqueue:with_or_die( + QRes, + fun(Q1) -> + case rabbit_queue_type:consume(Q1, Spec, QStates0) of + {ok, QStates} -> + rabbit_global_counters:consumer_created( + State#state.cfg#cfg.proto_ver), + State1 = State#state{queue_states = QStates}, + {ok, State1}; + {error, Type, Fmt, FmtArgs} -> + error("Failed to consume", + "~ts from ~ts: " ++ Fmt, + [Type, rabbit_misc:rs(QRes) | FmtArgs], + State) + end + end). + +create_queue(Amqqueue, _State = #state{authz_ctx = AuthzCtx, + user = User, + cfg = #cfg{vhost = VHost}}) -> + QName = amqqueue:get_name(Amqqueue), + + %% configure access to queue required for queue.declare + ok = check_resource_access(User, QName, configure, AuthzCtx), + + case rabbit_vhost_limit:is_over_queue_limit(VHost) of + false -> + rabbit_core_metrics:queue_declared(QName), + + case rabbit_queue_type:declare(Amqqueue, node()) of + {new, Q} when ?is_amqqueue(Q) -> + rabbit_core_metrics:queue_created(QName), + {ok, Q}; + {existing, Q} when ?is_amqqueue(Q) -> + rabbit_core_metrics:queue_created(QName), + {ok, Q}; + Other -> + log_error(rabbit_misc:format("Failed to declare ~s: ~p", [rabbit_misc:rs(QName)]), Other, none), + {error, queue_declare} + end; + {true, Limit} -> + log_error(rabbit_misc:format("cannot declare ~s because ", [rabbit_misc:rs(QName)]), + rabbit_misc:format("queue limit ~p in vhost '~s' is reached", [Limit, VHost]), + none), + {error, queue_limit_exceeded} + end. + +routing_init_state() -> sets:new([{version, 2}]). + +check_internal_exchange(#exchange{name = Name, internal = true}) -> + rabbit_misc:protocol_error(access_refused, + "cannot publish to internal ~ts", + [rabbit_misc:rs(Name)]); +check_internal_exchange(_) -> + ok. + + +check_topic_authorisation(#exchange{name = Name = #resource{virtual_host = VHost}, type = topic}, + User = #user{username = Username}, + RoutingKey, AuthzContext, Permission) -> + Resource = Name#resource{kind = topic}, + VariableMap = build_topic_variable_map(AuthzContext, VHost, Username), + Context = #{routing_key => RoutingKey, + variable_map => VariableMap}, + Cache = case get(topic_permission_cache) of + undefined -> []; + Other -> Other + end, + case lists:member({Resource, Context, Permission}, Cache) of + true -> ok; + false -> ok = rabbit_access_control:check_topic_access( + User, Resource, Permission, Context), + CacheTail = lists:sublist(Cache, ?MAX_PERMISSION_CACHE_SIZE-1), + put(topic_permission_cache, [{Resource, Context, Permission} | CacheTail]) + end; +check_topic_authorisation(_, _, _, _, _) -> + ok. + + +build_topic_variable_map(AuthzContext, VHost, Username) when is_map(AuthzContext) -> + maps:merge(AuthzContext, #{<<"vhost">> => VHost, <<"username">> => Username}). + +check_vhost_exists(VHost, Username, PeerIp) -> + case rabbit_vhost:exists(VHost) of + true -> + ok; + false -> + rabbit_core_metrics:auth_attempt_failed(PeerIp, Username, stomp), + ?LOG_ERROR("STOMP connection failed: virtual host '~ts' does not exist", + [VHost]), + {error, not_allowed, Username, VHost} + end. + +check_vhost_access(VHost, User = #user{username = Username}, PeerIp) -> + AuthzCtx = #{}, + try rabbit_access_control:check_vhost_access( + User, VHost, {ip, PeerIp}, AuthzCtx) of + ok -> + {ok, AuthzCtx} + catch exit:#amqp_error{name = not_allowed} -> + rabbit_core_metrics:auth_attempt_failed(PeerIp, Username, stomp), + ?LOG_ERROR("STOMP connection failed: access refused for user '~ts' to vhost '~ts'", + [Username, VHost]), + {error, not_allowed, Username, VHost} + end. + +check_vhost_connection_limit(VHost) -> + case rabbit_vhost_limit:is_over_connection_limit(VHost) of + false -> + ok; + {true, Limit} -> + ?LOG_ERROR("STOMP connection failed: connection limit ~p is reached for vhost '~s'", + [Limit, VHost]), + {error, quota_exceeded} + end. + +check_user_loopback(Username, PeerIp) -> + case rabbit_access_control:check_user_loopback(Username, PeerIp) of + ok -> + ok; + not_allowed -> + rabbit_core_metrics:auth_attempt_failed(PeerIp, Username, stomp), + {error, not_loopback, Username} + end. + +erase_queue_stats(QName) -> + rabbit_core_metrics:channel_queue_down({self(), QName}), + erase({queue_stats, QName}), + [begin + rabbit_core_metrics:channel_queue_exchange_down({self(), QX}), + erase({queue_exchange_stats, QX}) + end || {{queue_exchange_stats, QX = {QName0, _}}, _} <- get(), + QName0 =:= QName]. diff --git a/deps/rabbitmq_stomp/src/rabbit_stomp_reader.erl b/deps/rabbitmq_stomp/src/rabbit_stomp_reader.erl index 6036c919ed77..22a192384d83 100644 --- a/deps/rabbitmq_stomp/src/rabbit_stomp_reader.erl +++ b/deps/rabbitmq_stomp/src/rabbit_stomp_reader.erl @@ -18,17 +18,13 @@ -include("rabbit_stomp.hrl"). -include("rabbit_stomp_frame.hrl"). --include_lib("amqp_client/include/amqp_client.hrl"). -include_lib("rabbit_common/include/logging.hrl"). -include_lib("kernel/include/logger.hrl"). --define(SIMPLE_METRICS, [pid, recv_oct, send_oct, reductions]). --define(OTHER_METRICS, [recv_cnt, send_cnt, send_pend, garbage_collection, state, - timeout]). - -record(reader_state, { socket, + proxy_socket, conn_name, parse_state, processor_state, @@ -43,7 +39,8 @@ heartbeat_sup, heartbeat, %% heartbeat timeout value used, 0 means %% heartbeats are disabled - timeout_sec + timeout_sec, + parser_config }). %%---------------------------------------------------------------------------- @@ -65,45 +62,65 @@ close_connection(Pid, Reason) -> init([SupHelperPid, Ref, Configuration]) -> - logger:set_process_metadata(#{domain => ?RMQLOG_DOMAIN_CONN}), + logger:set_process_metadata(#{domain => ?RMQLOG_DOMAIN_CONN ++ [stomp]}), process_flag(trap_exit, true), rabbit_access_control:set_max_heap_size_unauthenticated(rabbitmq_stomp), {ok, Sock} = rabbit_networking:handshake(Ref, application:get_env(rabbitmq_stomp, proxy_protocol, false)), RealSocket = rabbit_net:unwrap_socket(Sock), + ProxySocket = rabbit_net:maybe_get_proxy_socket(Sock), case rabbit_net:connection_string(Sock, inbound) of {ok, ConnStr} -> ConnName = rabbit_data_coercion:to_binary(ConnStr), - ProcInitArgs = processor_args(Configuration, Sock), - ProcState = rabbit_stomp_processor:initial_state(Configuration, - ProcInitArgs), - - ?LOG_INFO("accepting STOMP connection ~tp (~ts)", - [self(), ConnName]), - - ParseState = rabbit_stomp_frame:initial_state(), - Alarms = register_resource_alarm(), - - LoginTimeout = application:get_env(rabbitmq_stomp, login_timeout, 10_000), - MaxFrameSize = application:get_env(rabbitmq_stomp, max_frame_size, ?DEFAULT_MAX_FRAME_SIZE), - erlang:send_after(LoginTimeout, self(), login_timeout), - - gen_server2:enter_loop(?MODULE, [], - rabbit_event:init_stats_timer( - run_socket(control_throttle( - #reader_state{socket = RealSocket, - conn_name = ConnName, - parse_state = ParseState, - processor_state = ProcState, - heartbeat_sup = SupHelperPid, - heartbeat = {none, none}, - max_frame_size = MaxFrameSize, - current_frame_size = 0, - state = running, - blocked_by = sets:from_list(Alarms, [{version, 2}]), - recv_outstanding = false})), #reader_state.stats_timer), - {backoff, 1000, 1000, 10000}); + logger:update_process_metadata(#{connection => ConnName}), + case rabbit_net:socket_ends(Sock, inbound) of + {ok, {PeerHost, PeerPort, Host, Port}} -> + SSLLoginName = ssl_login_name(RealSocket, Configuration), + SendFun = mk_send_fun(RealSocket), + ProcInitArgs = {SendFun, SSLLoginName, ConnName, + Host, Port, PeerHost, PeerPort}, + ProcState = rabbit_stomp_processor:initial_state( + Configuration, ProcInitArgs), + + ?LOG_INFO("accepting STOMP connection ~tp (~ts)", + [self(), ConnName]), + + ParserConfig = #stomp_parser_config{ + max_headers = Configuration#stomp_configuration.max_headers, + max_header_length = Configuration#stomp_configuration.max_header_length, + max_body_length = Configuration#stomp_configuration.max_body_length + }, + ParseState = rabbit_stomp_frame:initial_state(ParserConfig), + Alarms = register_resource_alarm(), + + LoginTimeout = application:get_env(rabbitmq_stomp, login_timeout, 10_000), + MaxFrameSize = application:get_env(rabbitmq_stomp, max_frame_size, ?DEFAULT_MAX_FRAME_SIZE), + erlang:send_after(LoginTimeout, self(), login_timeout), + + rabbit_networking:register_non_amqp_connection(self()), + + gen_server2:enter_loop(?MODULE, [], + rabbit_event:init_stats_timer( + run_socket(control_throttle( + #reader_state{socket = RealSocket, + proxy_socket = ProxySocket, + conn_name = ConnName, + parse_state = ParseState, + parser_config = ParserConfig, + processor_state = ProcState, + heartbeat_sup = SupHelperPid, + heartbeat = {none, none}, + max_frame_size = MaxFrameSize, + current_frame_size = 0, + state = running, + blocked_by = sets:from_list(Alarms, [{version, 2}]), + recv_outstanding = false})), #reader_state.stats_timer), + {backoff, 1000, 1000, 10000}); + {error, Reason} -> + rabbit_net:fast_close(RealSocket), + terminate({network_error, {socket_ends, Reason}}, undefined) + end; {error, enotconn} -> rabbit_net:fast_close(RealSocket), terminate(shutdown, undefined); @@ -123,6 +140,19 @@ handle_call({info, InfoItems}, _From, State) -> handle_call(Msg, From, State) -> {stop, {stomp_unexpected_call, Msg, From}, State}. +handle_cast(QueueEvent = {queue_event, _, _}, State) -> + ProcState = processor_state(State), + case rabbit_stomp_processor:handle_queue_event(QueueEvent, ProcState) of + {ok, NewProcState} -> + {noreply, processor_state(NewProcState, State), hibernate}; + {error, Reason, NewProcState} -> + {stop, {shutdown, Reason}, processor_state(NewProcState, State)} + end; +handle_cast({force_event_refresh, Ref}, State) -> + Infos = infos(?INFO_ITEMS ++ ?OTHER_METRICS, State), + rabbit_event:notify(connection_created, Infos, Ref), + {noreply, rabbit_event:init_stats_timer(State, #reader_state.stats_timer), + hibernate}; handle_cast({close_connection, Reason}, State) -> {stop, {shutdown, {server_initiated_close, Reason}}, State}; handle_cast(client_timeout, State) -> @@ -130,6 +160,16 @@ handle_cast(client_timeout, State) -> handle_cast(Msg, State) -> {stop, {stomp_unexpected_cast, Msg}, State}. +handle_info(connection_created, State) -> + Infos = infos(?INFO_ITEMS ++ ?OTHER_METRICS, State), + rabbit_core_metrics:connection_created(self(), Infos), + rabbit_event:notify(connection_created, Infos), + ProcState = processor_state(State), + logger:update_process_metadata( + #{connection => rabbit_stomp_processor:adapter_name(ProcState), + vhost => rabbit_stomp_processor:info(vhost, ProcState), + user => rabbit_stomp_processor:info(user, ProcState)}), + {noreply, State, hibernate}; handle_info({Tag, Sock, Data}, State=#reader_state{socket=Sock}) when Tag =:= tcp; Tag =:= ssl -> @@ -161,6 +201,15 @@ handle_info({bump_credit, Msg}, State) -> credit_flow:handle_bump_msg(Msg), {noreply, run_socket(control_throttle(State)), hibernate}; +handle_info({{'DOWN', _QName}, _MRef, process, _Pid, _Reason} = Evt, State) -> + ProcState = processor_state(State), + {ok, NewProcState} = rabbit_stomp_processor:handle_down(Evt, ProcState), + {noreply, processor_state(NewProcState, State), hibernate}; + +handle_info({'DOWN', _MRef, process, QPid, _Reason}, State) -> + rabbit_amqqueue_common:notify_sent_queue_down(QPid), + {noreply, State, hibernate}; + %%---------------------------------------------------------------------------- handle_info(client_timeout, State) -> @@ -168,8 +217,8 @@ handle_info(client_timeout, State) -> handle_info(login_timeout, State) -> ProcState = processor_state(State), - case rabbit_stomp_processor:info(channel, ProcState) of - none -> + case rabbit_stomp_processor:info(user, ProcState) of + undefined -> {stop, {shutdown, login_timeout}, State}; _ -> {noreply, State, hibernate} @@ -177,42 +226,6 @@ handle_info(login_timeout, State) -> %%---------------------------------------------------------------------------- -handle_info(#'basic.consume_ok'{}, State) -> - {noreply, State, hibernate}; -handle_info(#'basic.cancel_ok'{}, State) -> - {noreply, State, hibernate}; -handle_info(#'basic.ack'{delivery_tag = Tag, multiple = IsMulti}, State) -> - ProcState = processor_state(State), - NewProcState = rabbit_stomp_processor:flush_pending_receipts(Tag, - IsMulti, - ProcState), - {noreply, processor_state(NewProcState, State), hibernate}; -handle_info({Delivery = #'basic.deliver'{}, - Message = #amqp_msg{}}, - State) -> - %% receiving a message from a quorum queue - %% no delivery context - handle_info({Delivery, Message, undefined}, State); -handle_info({Delivery = #'basic.deliver'{}, - #amqp_msg{props = Props, payload = Payload}, - DeliveryCtx}, - State) -> - ProcState = processor_state(State), - NewProcState = rabbit_stomp_processor:send_delivery(Delivery, - Props, - Payload, - DeliveryCtx, - ProcState), - {noreply, processor_state(NewProcState, State), hibernate}; -handle_info(#'basic.cancel'{consumer_tag = Ctag}, State) -> - ProcState = processor_state(State), - case rabbit_stomp_processor:cancel_consumer(Ctag, ProcState) of - {ok, NewProcState, _} -> - {noreply, processor_state(NewProcState, State), hibernate}; - {stop, Reason, NewProcState} -> - {stop, Reason, processor_state(NewProcState, State)} - end; - handle_info({start_heartbeats, {0, 0}}, State) -> {noreply, State#reader_state{timeout_sec = {0, 0}}}; @@ -229,14 +242,8 @@ handle_info({start_heartbeats, {SendTimeout, ReceiveTimeout}}, %%---------------------------------------------------------------------------- -handle_info({'EXIT', From, Reason}, State) -> - ProcState = processor_state(State), - case rabbit_stomp_processor:handle_exit(From, Reason, ProcState) of - {stop, NewReason, NewProcState} -> - {stop, NewReason, processor_state(NewProcState, State)}; - unknown_exit -> - {stop, {connection_died, Reason}, State} - end. +handle_info({'EXIT', _From, Reason}, State) -> + {stop, {connection_died, Reason}, State}. %%---------------------------------------------------------------------------- process_received_bytes([], State) -> @@ -266,14 +273,14 @@ process_received_bytes(Bytes, {stop, normal, State}; false -> try rabbit_stomp_processor:process_frame(Frame, ProcState) of - {ok, NewProcState, Conn} -> - PS = rabbit_stomp_frame:initial_state(), + {ok, NewProcState} -> + PS = rabbit_stomp_frame:initial_state( + State#reader_state.parser_config), NextState = maybe_block(State, Frame), process_received_bytes(Rest, NextState#reader_state{ current_frame_size = 0, processor_state = NewProcState, - parse_state = PS, - connection = Conn}); + parse_state = PS}); {stop, Reason, NewProcState} -> {stop, Reason, processor_state(NewProcState, State)} @@ -318,7 +325,7 @@ control_throttle(State = #reader_state{state = CS, end. maybe_block(State = #reader_state{state = blocking, heartbeat = Heartbeat}, - #stomp_frame{command = "SEND"}) -> + #stomp_frame{command = 'SEND'}) -> rabbit_heartbeat:pause_monitor(Heartbeat), State#reader_state{state = blocked}; maybe_block(State, _) -> @@ -337,9 +344,13 @@ terminate(Reason, undefined) -> log_reason(Reason, undefined), {stop, Reason}; terminate(Reason, State = #reader_state{processor_state = ProcState}) -> - maybe_emit_stats(State), - log_reason(Reason, State), - _ = rabbit_stomp_processor:flush_and_die(ProcState), + maybe_emit_stats(State), + rabbit_core_metrics:connection_closed(self()), + Infos = infos(?OTHER_METRICS, State), + rabbit_event:notify(connection_closed, Infos), + rabbit_networking:unregister_non_amqp_connection(self()), + log_reason(Reason, State), + _ = rabbit_stomp_processor:flush_and_die(ProcState), {stop, Reason}. code_change(_OldVsn, State, _Extra) -> @@ -413,22 +424,15 @@ log_tls_alert(Alert, ConnName) -> %%---------------------------------------------------------------------------- -processor_args(Configuration, Sock) -> - RealSocket = rabbit_net:unwrap_socket(Sock), - SendFun = fun(IoData) -> - case rabbit_net:send(RealSocket, IoData) of - ok -> - ok; - {error, Reason} -> - exit({send_failed, Reason}) - end - end, - {ok, {PeerAddr, _PeerPort}} = rabbit_net:peername(RealSocket), - {SendFun, adapter_info(Sock), - ssl_login_name(RealSocket, Configuration), PeerAddr}. - -adapter_info(Sock) -> - amqp_connection:socket_adapter_info(Sock, {'STOMP', 0}). +mk_send_fun(RealSocket) -> + fun(IoData) -> + case rabbit_net:send(RealSocket, IoData) of + ok -> + ok; + {error, Reason} -> + exit({send_failed, Reason}) + end + end. ssl_login_name(_Sock, #stomp_configuration{ssl_cert_login = false}) -> none; @@ -452,11 +456,6 @@ maybe_emit_stats(State) -> rabbit_event:if_enabled(State, #reader_state.stats_timer, fun() -> emit_stats(State) end). -emit_stats(State=#reader_state{connection = C}) when C == none; C == undefined -> - %% Avoid emitting stats on terminate when the connection has not yet been - %% established, as this causes orphan entries on the stats database - State1 = rabbit_event:reset_stats_timer(State, #reader_state.stats_timer), - ensure_stats_timer(State1); emit_stats(State) -> [{_, Pid}, {_, Recv_oct}, @@ -482,7 +481,7 @@ processor_state(ProcState, #reader_state{} = State) -> infos(Items, State) -> [{Item, info_internal(Item, State)} || Item <- Items]. -info_internal(pid, State) -> info_internal(connection, State); +info_internal(pid, _) -> self(); info_internal(SockStat, #reader_state{socket = Sock}) when SockStat =:= recv_oct; SockStat =:= recv_cnt; SockStat =:= send_oct; @@ -504,9 +503,19 @@ info_internal(timeout, #reader_state{timeout_sec = undefined}) -> 0; info_internal(conn_name, #reader_state{conn_name = Val}) -> rabbit_data_coercion:to_binary(Val); -info_internal(connection, #reader_state{connection = Val}) -> - Val; +info_internal(name, #reader_state{conn_name = Val}) -> + rabbit_data_coercion:to_binary(Val); +info_internal(connection, _) -> + self(); info_internal(connection_state, #reader_state{state = Val}) -> Val; +info_internal(ssl, #reader_state{socket = Sock, proxy_socket = ProxySock}) -> + rabbit_net:proxy_ssl_info(Sock, ProxySock) /= nossl; +info_internal(SSL, #reader_state{socket = Sock, proxy_socket = ProxySock}) + when SSL =:= ssl_protocol; + SSL =:= ssl_key_exchange; + SSL =:= ssl_cipher; + SSL =:= ssl_hash -> + rabbit_ssl:info(SSL, {Sock, ProxySock}); info_internal(Key, #reader_state{processor_state = ProcState}) -> rabbit_stomp_processor:info(Key, ProcState). diff --git a/deps/rabbitmq_stomp/src/rabbit_stomp_routing_util.erl b/deps/rabbitmq_stomp/src/rabbit_stomp_routing_util.erl deleted file mode 100644 index 66db2727ebdc..000000000000 --- a/deps/rabbitmq_stomp/src/rabbit_stomp_routing_util.erl +++ /dev/null @@ -1,165 +0,0 @@ -%% This Source Code Form is subject to the terms of the Mozilla Public -%% License, v. 2.0. If a copy of the MPL was not distributed with this -%% file, You can obtain one at https://mozilla.org/MPL/2.0/. -%% -%% Copyright (c) 2007-2026 Broadcom. All Rights Reserved. The term "Broadcom" refers to Broadcom Inc. and/or its subsidiaries. All rights reserved. -%% - --module(rabbit_stomp_routing_util). - --export([init_state/0, dest_prefixes/0, all_dest_prefixes/0]). --export([ensure_endpoint/4, ensure_endpoint/5, ensure_binding/3]). --export([dest_temp_queue/1]). - --include_lib("amqp_client/include/amqp_client.hrl"). --include("rabbit_stomp_routing_prefixes.hrl"). - -%%---------------------------------------------------------------------------- - -init_state() -> sets:new(). - -dest_prefixes() -> [?EXCHANGE_PREFIX, ?TOPIC_PREFIX, ?QUEUE_PREFIX, - ?AMQQUEUE_PREFIX, ?REPLY_QUEUE_PREFIX]. - -all_dest_prefixes() -> [?TEMP_QUEUE_PREFIX | dest_prefixes()]. - -%% -------------------------------------------------------------------------- - -ensure_endpoint(Dir, Channel, Endpoint, State) -> - ensure_endpoint(Dir, Channel, Endpoint, [], State). - -ensure_endpoint(source, Channel, {exchange, {Name, _}}, Params, State) -> - check_exchange(Name, Channel, - proplists:get_value(check_exchange, Params, false)), - Method = queue_declare_method(#'queue.declare'{}, exchange, Params), - #'queue.declare_ok'{queue = Queue} = amqp_channel:call(Channel, Method), - {ok, Queue, State}; - -ensure_endpoint(source, Channel, {topic, _}, Params, State) -> - Method = queue_declare_method(#'queue.declare'{}, topic, Params), - #'queue.declare_ok'{queue = Queue} = amqp_channel:call(Channel, Method), - {ok, Queue, State}; - -ensure_endpoint(_Dir, _Channel, {queue, undefined}, _Params, State) -> - {ok, undefined, State}; - -ensure_endpoint(_, Channel, {queue, Name}, Params, State) -> - Params1 = rabbit_misc:pmerge(durable, true, Params), - Queue = list_to_binary(Name), - State1 = case sets:is_element(Queue, State) of - true -> State; - _ -> Method = queue_declare_method( - #'queue.declare'{queue = Queue, - nowait = true}, - queue, Params1), - case Method#'queue.declare'.nowait of - true -> amqp_channel:cast(Channel, Method); - false -> amqp_channel:call(Channel, Method) - end, - sets:add_element(Queue, State) - end, - {ok, Queue, State1}; - -ensure_endpoint(dest, Channel, {exchange, {Name, _}}, Params, State) -> - check_exchange(Name, Channel, - proplists:get_value(check_exchange, Params, false)), - {ok, undefined, State}; - -ensure_endpoint(dest, _Ch, {topic, _}, _Params, State) -> - {ok, undefined, State}; - -ensure_endpoint(_, _Ch, {amqqueue, Name}, _Params, State) -> - {ok, list_to_binary(Name), State}; - -ensure_endpoint(_, _Ch, {reply_queue, Name}, _Params, State) -> - {ok, list_to_binary(Name), State}; - -ensure_endpoint(_Direction, _Ch, _Endpoint, _Params, _State) -> - {error, invalid_endpoint}. - -%% -------------------------------------------------------------------------- - -ensure_binding(QueueBin, {"", Queue}, _Channel) -> - %% i.e., we should only be asked to bind to the default exchange a - %% queue with its own name - QueueBin = list_to_binary(Queue), - ok; -ensure_binding(Queue, {Exchange, RoutingKey}, Channel) -> - #'queue.bind_ok'{} = - amqp_channel:call(Channel, - #'queue.bind'{ - queue = Queue, - exchange = list_to_binary(Exchange), - routing_key = list_to_binary(RoutingKey)}), - ok. - -%% -------------------------------------------------------------------------- - -dest_temp_queue({temp_queue, Name}) -> Name; -dest_temp_queue(_) -> none. - -%% -------------------------------------------------------------------------- - -check_exchange(_, _, false) -> - ok; -check_exchange(ExchangeName, Channel, true) -> - XDecl = #'exchange.declare'{ exchange = list_to_binary(ExchangeName), - passive = true }, - #'exchange.declare_ok'{} = amqp_channel:call(Channel, XDecl), - ok. - -update_queue_declare_arguments(Method, Params) -> - Method#'queue.declare'{arguments = - proplists:get_value(arguments, Params, [])}. - -update_queue_declare_exclusive(Method, Params) -> - case proplists:get_value(exclusive, Params) of - undefined -> Method; - Val -> Method#'queue.declare'{exclusive = Val} - end. - -update_queue_declare_auto_delete(Method, Params) -> - case proplists:get_value(auto_delete, Params) of - undefined -> Method; - Val -> Method#'queue.declare'{auto_delete = Val} - end. - -update_queue_declare_nowait(Method, Params) -> - case proplists:get_value(nowait, Params) of - undefined -> Method; - Val -> Method#'queue.declare'{nowait = Val} - end. - -queue_declare_method(#'queue.declare'{} = Method, Type, Params) -> - %% defaults - Method1 = case proplists:get_value(durable, Params, false) of - true -> Method#'queue.declare'{durable = true}; - false -> Method#'queue.declare'{auto_delete = true, - exclusive = true} - end, - - %% set the rest of queue.declare fields from Params - Method2 = lists:foldl(fun (F, Acc) -> F(Acc, Params) end, - Method1, [fun update_queue_declare_arguments/2, - fun update_queue_declare_exclusive/2, - fun update_queue_declare_auto_delete/2, - fun update_queue_declare_nowait/2]), - - Arguments = proplists:get_value(arguments, Params, []), - DefaultQueueType = proplists:get_value(default_queue_type, Params, - rabbit_queue_type:default()), - Method3 = case rabbit_amqqueue:get_queue_type(Arguments, DefaultQueueType) of - T when T =:= rabbit_stream_queue; T =:= rabbit_quorum_queue -> - Method2#'queue.declare'{durable = true, - exclusive = false}; - _ -> Method2 - end, - - case {Type, proplists:get_value(subscription_queue_name_gen, Params)} of - {topic, SQNG} when is_function(SQNG) -> - Method3#'queue.declare'{queue = SQNG()}; - {exchange, SQNG} when is_function(SQNG) -> - Method3#'queue.declare'{queue = SQNG()}; - _ -> - Method3 - end. diff --git a/deps/rabbitmq_stomp/src/rabbit_stomp_util.erl b/deps/rabbitmq_stomp/src/rabbit_stomp_util.erl index b7750cd47ca6..56327db2719f 100644 --- a/deps/rabbitmq_stomp/src/rabbit_stomp_util.erl +++ b/deps/rabbitmq_stomp/src/rabbit_stomp_util.erl @@ -10,14 +10,14 @@ -export([parse_message_id/1, subscription_queue_name/3]). -export([longstr_field/2]). -export([ack_mode/1, consumer_tag_reply_to/1, consumer_tag/1, message_headers/1, - headers_post_process/1, headers/5, message_properties/1, tag_to_id/1, + headers_post_process/1, headers/9, message_properties/1, tag_to_id/1, msg_header_name/1, ack_header_name/1, build_arguments/1, build_params/2, has_durable_header/1]). -export([negotiate_version/2]). -export([trim_headers/1]). --include_lib("amqp_client/include/amqp_client.hrl"). --include("rabbit_stomp_routing_prefixes.hrl"). +-include_lib("rabbit_common/include/rabbit.hrl"). +-include_lib("rabbit_common/include/rabbit_framing.hrl"). -include("rabbit_stomp_frame.hrl"). -include("rabbit_stomp_headers.hrl"). @@ -29,30 +29,33 @@ %%-------------------------------------------------------------------- consumer_tag_reply_to(QueueId) -> - internal_tag(?TEMP_QUEUE_ID_PREFIX ++ QueueId). + internal_tag(<>). consumer_tag(Frame) -> case rabbit_stomp_frame:header(Frame, ?HEADER_ID) of {ok, Id} -> - case lists:prefix(?TEMP_QUEUE_ID_PREFIX, Id) of - false -> {ok, internal_tag(Id), "id='" ++ Id ++ "'"}; - true -> {error, invalid_prefix} + case Id of + <<"/temp-queue/", _/binary>> -> + {error, invalid_prefix}; + _ -> + {ok, internal_tag(Id), + <<"id='", Id/binary, "'">>} end; not_found -> case rabbit_stomp_frame:header(Frame, ?HEADER_DESTINATION) of {ok, DestHdr} -> {ok, queue_tag(DestHdr), - "destination='" ++ DestHdr ++ "'"}; + <<"destination='", DestHdr/binary, "'">>}; not_found -> {error, missing_destination_header} end end. ack_mode(Frame) -> - case rabbit_stomp_frame:header(Frame, ?HEADER_ACK, "auto") of - "auto" -> {auto, false}; - "client" -> {client, true}; - "client-individual" -> {client, false} + case rabbit_stomp_frame:header(Frame, ?HEADER_ACK, <<"auto">>) of + <<"auto">> -> {auto, false}; + <<"client">> -> {client, true}; + <<"client-individual">> -> {client, false} end. message_properties(Frame = #stomp_frame{headers = Headers}) -> @@ -67,8 +70,7 @@ message_properties(Frame = #stomp_frame{headers = Headers}) -> #'P_basic'{ content_type = BinH(?HEADER_CONTENT_TYPE), content_encoding = BinH(?HEADER_CONTENT_ENCODING), - headers = [longstr_field(K, V) || - {K, V} <- Headers, user_header(K)], + headers = user_headers(Headers), delivery_mode = DeliveryMode, priority = IntH(?HEADER_PRIORITY), correlation_id = BinH(?HEADER_CORRELATION_ID), @@ -103,28 +105,24 @@ adhoc_convert_headers(undefined, Existing) -> Existing; adhoc_convert_headers(Headers, Existing) -> lists:foldr(fun ({K, longstr, V}, Acc) -> - [{binary_to_list(K), binary_to_list(V)} | Acc]; + [{K, V} | Acc]; ({K, signedint, V}, Acc) -> - [{binary_to_list(K), integer_to_list(V)} | Acc]; + [{K, integer_to_binary(V)} | Acc]; ({K, long, V}, Acc) -> - [{binary_to_list(K), integer_to_list(V)} | Acc]; + [{K, integer_to_binary(V)} | Acc]; (_, Acc) -> Acc end, Existing, Headers). -headers_extra(SessionId, AckMode, Version, - #'basic.deliver'{consumer_tag = ConsumerTag, - delivery_tag = DeliveryTag, - exchange = ExchangeBin, - routing_key = RoutingKeyBin, - redelivered = Redelivered}) -> +headers_extra(SessionId, ConsumerTag, DeliveryTag, + ExchangeBin, RoutingKeyBin, Redelivered, + AckMode, Version) -> case tag_to_id(ConsumerTag) of {ok, {internal, Id}} -> [{?HEADER_SUBSCRIPTION, Id}]; _ -> [] end ++ [{?HEADER_DESTINATION, - format_destination(binary_to_list(ExchangeBin), - binary_to_list(RoutingKeyBin))}, + format_destination(ExchangeBin, RoutingKeyBin)}, {?HEADER_MESSAGE_ID, create_message_id(ConsumerTag, SessionId, DeliveryTag)}, {?HEADER_REDELIVERED, Redelivered}] ++ @@ -135,27 +133,47 @@ headers_extra(SessionId, AckMode, Version, end. headers_post_process(Headers) -> - Prefixes = rabbit_stomp_routing_util:dest_prefixes(), [case Header of {?HEADER_REPLY_TO, V} -> - case lists:any(fun (P) -> lists:prefix(P, V) end, Prefixes) of + case lists:any(fun(P) -> + S = byte_size(P), + case V of + <> -> true; + _ -> false + end + end, ?DEST_PREFIXES) of true -> {?HEADER_REPLY_TO, V}; - false -> {?HEADER_REPLY_TO, ?REPLY_QUEUE_PREFIX ++ V} + false -> {?HEADER_REPLY_TO, <<(?REPLY_QUEUE_PREFIX)/binary, V/binary>>} end; {_, _} -> Header end || Header <- Headers]. -headers(SessionId, Delivery, Properties, AckMode, Version) -> - headers_extra(SessionId, AckMode, Version, Delivery) ++ - headers_post_process(message_headers(Properties)). +headers(SessionId, ConsumerTag, DeliveryTag, + ExchangeBin, RoutingKey, Redelivered, + Properties, AckMode, Version) -> + maps:merge( + maps:from_list( + headers_extra(SessionId, ConsumerTag, DeliveryTag, + ExchangeBin, RoutingKey, Redelivered, + AckMode, Version)), + maps:from_list( + headers_post_process(message_headers(Properties)))). tag_to_id(<>) -> - {ok, {internal, binary_to_list(Id)}}; -tag_to_id(<>) -> - {ok, {queue, binary_to_list(Id)}}; + {ok, {internal, Id}}; +tag_to_id(<>) -> + {ok, {queue, Id}}; tag_to_id(Other) when is_binary(Other) -> - {error, {unknown, binary_to_list(Other)}}. + {error, {unknown, Other}}. + +user_headers(Headers) -> + maps:fold(fun(K, V, Acc) -> + case user_header(K) of + true -> [longstr_field(K, V) | Acc]; + false -> Acc + end + end, [], Headers). user_header(Hdr) when Hdr =:= ?HEADER_CONTENT_TYPE orelse @@ -176,11 +194,11 @@ user_header(_) -> true. parse_message_id(MessageId) -> - case split(MessageId, ?MESSAGE_ID_SEPARATOR) of + case binary:split(MessageId, ?MESSAGE_ID_SEPARATOR, [global]) of [ConsumerTag, SessionId, DeliveryTag] -> - {ok, {list_to_binary(ConsumerTag), - SessionId, - list_to_integer(DeliveryTag)}}; + {ok, {ConsumerTag, + binary_to_list(SessionId), + binary_to_integer(DeliveryTag)}}; _ -> {error, invalid_message_id} end. @@ -219,36 +237,34 @@ find_max_version({V1, X}, {_V2, []}) when length(X) > 0 -> %% ---- Header processing helpers ---- longstr_field(K, V) -> - {list_to_binary(K), longstr, list_to_binary(V)}. + {K, longstr, V}. maybe_header(_Key, undefined, Acc) -> Acc; maybe_header(?HEADER_PERSISTENT, 2, Acc) -> - [{?HEADER_PERSISTENT, "true"} | Acc]; + [{?HEADER_PERSISTENT, <<"true">>} | Acc]; maybe_header(Key, Value, Acc) when is_binary(Value) -> - [{Key, binary_to_list(Value)} | Acc]; + [{Key, Value} | Acc]; maybe_header(Key, Value, Acc) when is_integer(Value) -> - [{Key, integer_to_list(Value)}| Acc]; + [{Key, integer_to_binary(Value)} | Acc]; maybe_header(_Key, _Value, Acc) -> Acc. create_message_id(ConsumerTag, SessionId, DeliveryTag) -> - [ConsumerTag, - ?MESSAGE_ID_SEPARATOR, - SessionId, - ?MESSAGE_ID_SEPARATOR, - integer_to_list(DeliveryTag)]. + iolist_to_binary([ConsumerTag, ?MESSAGE_ID_SEPARATOR, + SessionId, ?MESSAGE_ID_SEPARATOR, + integer_to_binary(DeliveryTag)]). trim_headers(Frame = #stomp_frame{headers = Hdrs}) -> - Frame#stomp_frame{headers = [{K, string:strip(V, left)} || {K, V} <- Hdrs]}. + Frame#stomp_frame{headers = maps:map(fun(_K, V) -> string:trim(V, leading) end, Hdrs)}. internal_tag(Base) -> - list_to_binary(?INTERNAL_TAG_PREFIX ++ Base). + <>. queue_tag(Base) -> - list_to_binary(?QUEUE_TAG_PREFIX ++ Base). + <>. -ack_header_name("1.2") -> ?HEADER_ID; +ack_header_name("1.2") -> ?HEADER_ACK; ack_header_name("1.1") -> ?HEADER_MESSAGE_ID; ack_header_name("1.0") -> ?HEADER_MESSAGE_ID. @@ -258,60 +274,50 @@ msg_header_name("1.0") -> ?HEADER_MESSAGE_ID. build_arguments(Headers) -> Arguments = - lists:foldl(fun({K, V}, Acc) -> - case lists:member(K, ?HEADER_ARGUMENTS) of - true -> [build_argument(K, V) | Acc]; - false -> Acc - end - end, - [], - Headers), + fold_headers(fun(K, V, Acc) -> + case lists:member(K, ?HEADER_ARGUMENTS) of + true -> [build_argument(K, V) | Acc]; + false -> Acc + end + end, [], Headers), {arguments, Arguments}. +fold_headers(Fun, Acc, Headers) -> + maps:fold(Fun, Acc, Headers). + build_argument(?HEADER_X_DEAD_LETTER_EXCHANGE, Val) -> - {list_to_binary(?HEADER_X_DEAD_LETTER_EXCHANGE), longstr, - list_to_binary(string:strip(Val))}; + {?HEADER_X_DEAD_LETTER_EXCHANGE, longstr, string:trim(Val)}; build_argument(?HEADER_X_DEAD_LETTER_ROUTING_KEY, Val) -> - {list_to_binary(?HEADER_X_DEAD_LETTER_ROUTING_KEY), longstr, - list_to_binary(string:strip(Val))}; + {?HEADER_X_DEAD_LETTER_ROUTING_KEY, longstr, string:trim(Val)}; build_argument(?HEADER_X_EXPIRES, Val) -> - {list_to_binary(?HEADER_X_EXPIRES), long, - list_to_integer(string:strip(Val))}; + {?HEADER_X_EXPIRES, long, binary_to_integer(string:trim(Val))}; build_argument(?HEADER_X_MAX_LENGTH, Val) -> - {list_to_binary(?HEADER_X_MAX_LENGTH), long, - list_to_integer(string:strip(Val))}; + {?HEADER_X_MAX_LENGTH, long, binary_to_integer(string:trim(Val))}; build_argument(?HEADER_X_MAX_LENGTH_BYTES, Val) -> - {list_to_binary(?HEADER_X_MAX_LENGTH_BYTES), long, - list_to_integer(string:strip(Val))}; + {?HEADER_X_MAX_LENGTH_BYTES, long, binary_to_integer(string:trim(Val))}; build_argument(?HEADER_X_MAX_PRIORITY, Val) -> - {list_to_binary(?HEADER_X_MAX_PRIORITY), long, - list_to_integer(string:strip(Val))}; + {?HEADER_X_MAX_PRIORITY, long, binary_to_integer(string:trim(Val))}; build_argument(?HEADER_X_MESSAGE_TTL, Val) -> - {list_to_binary(?HEADER_X_MESSAGE_TTL), long, - list_to_integer(string:strip(Val))}; + {?HEADER_X_MESSAGE_TTL, long, binary_to_integer(string:trim(Val))}; build_argument(?HEADER_X_MAX_AGE, Val) -> - {list_to_binary(?HEADER_X_MAX_AGE), longstr, - list_to_binary(string:strip(Val))}; + {?HEADER_X_MAX_AGE, longstr, string:trim(Val)}; build_argument(?HEADER_X_STREAM_MAX_SEGMENT_SIZE_BYTES, Val) -> - {list_to_binary(?HEADER_X_STREAM_MAX_SEGMENT_SIZE_BYTES), long, - list_to_integer(string:strip(Val))}; + {?HEADER_X_STREAM_MAX_SEGMENT_SIZE_BYTES, long, + binary_to_integer(string:trim(Val))}; build_argument(?HEADER_X_QUEUE_TYPE, Val) -> - {list_to_binary(?HEADER_X_QUEUE_TYPE), longstr, - list_to_binary(string:strip(Val))}; + {?HEADER_X_QUEUE_TYPE, longstr, string:trim(Val)}; build_argument(?HEADER_X_STREAM_FILTER_SIZE_BYTES, Val) -> - {list_to_binary(?HEADER_X_STREAM_FILTER_SIZE_BYTES), long, - list_to_integer(string:strip(Val))}. + {?HEADER_X_STREAM_FILTER_SIZE_BYTES, long, + binary_to_integer(string:trim(Val))}. build_params(EndPoint, Headers) -> - Params = lists:foldl(fun({K, V}, Acc) -> - case lists:member(K, ?HEADER_PARAMS) of - true -> [build_param(K, V) | Acc]; - false -> Acc - end - end, - [], - Headers), + Params = fold_headers(fun(K, V, Acc) -> + case lists:member(K, ?HEADER_PARAMS) of + true -> [build_param(K, V) | Acc]; + false -> Acc + end + end, [], Headers), rabbit_misc:plmerge(default_params(EndPoint), Params). build_param(?HEADER_PERSISTENT, Val) -> @@ -333,19 +339,18 @@ default_params({exchange, _}) -> [{exclusive, true}, {auto_delete, true}]; default_params({topic, _}) -> - [{auto_delete, true}]; + [{exclusive, false}, {auto_delete, true}]; default_params(_) -> - [{exclusive, true}, - {durable, false}]. + [{durable, false}]. -string_to_boolean("True") -> +string_to_boolean(<<"True">>) -> true; -string_to_boolean("true") -> +string_to_boolean(<<"true">>) -> true; -string_to_boolean("False") -> +string_to_boolean(<<"False">>) -> false; -string_to_boolean("false") -> +string_to_boolean(<<"false">>) -> false; string_to_boolean(_) -> undefined. @@ -360,14 +365,14 @@ has_durable_header(Frame) -> %% Destination Formatting %%-------------------------------------------------------------------- -format_destination("", RoutingKey) -> - ?QUEUE_PREFIX ++ "/" ++ escape(RoutingKey); -format_destination("amq.topic", RoutingKey) -> - ?TOPIC_PREFIX ++ "/" ++ escape(RoutingKey); -format_destination(Exchange, "") -> - ?EXCHANGE_PREFIX ++ "/" ++ escape(Exchange); +format_destination(<<>>, RoutingKey) -> + iolist_to_binary([?QUEUE_PREFIX, $/, escape_dest(RoutingKey)]); +format_destination(<<"amq.topic">>, RoutingKey) -> + iolist_to_binary([?TOPIC_PREFIX, $/, escape_dest(RoutingKey)]); +format_destination(Exchange, <<>>) -> + iolist_to_binary([?EXCHANGE_PREFIX, $/, escape_dest(Exchange)]); format_destination(Exchange, RoutingKey) -> - ?EXCHANGE_PREFIX ++ "/" ++ escape(Exchange) ++ "/" ++ escape(RoutingKey). + iolist_to_binary([?EXCHANGE_PREFIX, $/, escape_dest(Exchange), $/, escape_dest(RoutingKey)]). %%-------------------------------------------------------------------- %% Destination Parsing @@ -376,14 +381,10 @@ format_destination(Exchange, RoutingKey) -> subscription_queue_name(Destination, SubscriptionId, Frame) -> case rabbit_stomp_frame:header(Frame, ?HEADER_X_QUEUE_NAME, undefined) of undefined -> - %% We need a queue name that a) can be derived from the - %% Destination and SubscriptionId, and b) meets the constraints on - %% AMQP queue names. It doesn't need to be secure; we use md5 here - %% simply as a convenient means to bound the length. - rabbit_guid:string( - erlang:md5( - term_to_binary_compat:term_to_binary_1( - {Destination, SubscriptionId})), + rabbit_guid:binary( + erlang:md5( + term_to_binary_compat:term_to_binary_1( + {Destination, SubscriptionId})), "stomp-subscription"); Name -> Name @@ -391,32 +392,16 @@ subscription_queue_name(Destination, SubscriptionId, Frame) -> %% ---- Helpers ---- -split([], _Splitter) -> []; -split(Content, Splitter) -> split(Content, [], [], Splitter). - -split([], RPart, RParts, _Splitter) -> - lists:reverse([lists:reverse(RPart) | RParts]); -split(Content = [Elem | Rest1], RPart, RParts, Splitter) -> - case take_prefix(Splitter, Content) of - {ok, Rest2} -> - split(Rest2, [], [lists:reverse(RPart) | RParts], Splitter); - not_found -> - split(Rest1, [Elem | RPart], RParts, Splitter) - end. - -take_prefix([Char | Prefix], [Char | List]) -> take_prefix(Prefix, List); -take_prefix([], List) -> {ok, List}; -take_prefix(_Prefix, _List) -> not_found. - -escape(Str) -> escape(Str, []). +escape_dest(Bin) -> escape_dest(Bin, []). -escape([$/ | Str], Acc) -> escape(Str, "F2%" ++ Acc); %% $/ == '2F'x -escape([$% | Str], Acc) -> escape(Str, "52%" ++ Acc); %% $% == '25'x -escape([X | Str], Acc) when X < 32 orelse X > 127 -> - escape(Str, revhex(X) ++ "%" ++ Acc); -escape([C | Str], Acc) -> escape(Str, [C | Acc]); -escape([], Acc) -> lists:reverse(Acc). +escape_dest(<<>>, Acc) -> iolist_to_binary(lists:reverse(Acc)); +escape_dest(<<$/, Rest/binary>>, Acc) -> escape_dest(Rest, [<<"%2F">> | Acc]); +escape_dest(<<$%, Rest/binary>>, Acc) -> escape_dest(Rest, [<<"%25">> | Acc]); +escape_dest(<>, Acc) when X < 32; X > 127 -> + escape_dest(Rest, [revhex_bin(X) | Acc]); +escape_dest(<>, Acc) -> escape_dest(Rest, [C | Acc]). -revhex(I) -> hexdig(I) ++ hexdig(I bsr 4). +revhex_bin(I) -> + iolist_to_binary([$%, hexdig(I bsr 4), hexdig(I)]). hexdig(I) -> erlang:integer_to_list(I band 15, 16). diff --git a/deps/rabbitmq_stomp/test/command_SUITE.erl b/deps/rabbitmq_stomp/test/command_SUITE.erl index 40cf4fd65a12..196fd9549d63 100644 --- a/deps/rabbitmq_stomp/test/command_SUITE.erl +++ b/deps/rabbitmq_stomp/test/command_SUITE.erl @@ -95,7 +95,7 @@ run(Config) -> start_amqp_connection(direct, Node, Port), - %% Still two MQTT connections, one direct AMQP 0-9-1 connection + %% Still two STOMP connections, one direct AMQP 0-9-1 connection [[{session_id, _}], [{session_id, _}]] = 'Elixir.Enum':to_list(?COMMAND:run([<<"session_id">>], Opts)), diff --git a/deps/rabbitmq_stomp/test/connections_SUITE.erl b/deps/rabbitmq_stomp/test/connections_SUITE.erl index 703e06a459f1..2cf5cfb2883d 100644 --- a/deps/rabbitmq_stomp/test/connections_SUITE.erl +++ b/deps/rabbitmq_stomp/test/connections_SUITE.erl @@ -10,21 +10,25 @@ -import(rabbit_misc, [pget/2]). --include_lib("amqp_client/include/amqp_client.hrl"). +-include_lib("rabbit_common/include/rabbit.hrl"). -include_lib("rabbitmq_ct_helpers/include/rabbit_assert.hrl"). -include("rabbit_stomp_frame.hrl"). --define(DESTINATION, "/queue/bulk-test"). +-define(DESTINATION, <<"/queue/bulk-test">>). all() -> [ messages_not_dropped_on_disconnect, - direct_client_connections_are_not_leaked, + connections_not_leaked_on_parse_error, stats_are_not_leaked, stats, heartbeat, login_timeout, frame_size, - frame_size_huge + frame_size_huge, + unauthenticated_send_returns_error, + unauthenticated_subscribe_returns_error, + unauthenticated_disconnect_returns_error, + unauthenticated_error_does_not_crash_reader ]. merge_app_env(Config) -> @@ -80,7 +84,7 @@ rpc_count_connections(Config, ConnSpec) -> rabbit_ct_broker_helpers:rpc(Config, 0, ranch_server, count_connections, [ConnSpec]). -direct_client_connections_are_not_leaked(Config) -> +connections_not_leaked_on_parse_error(Config) -> StompPort = get_stomp_port(Config), N = count_connections(Config), lists:foreach(fun (_) -> @@ -88,23 +92,22 @@ direct_client_connections_are_not_leaked(Config) -> %% send garbage which trips up the parser gen_tcp:send(Socket, ?GARBAGE), rabbit_stomp_client:send( - Client, "LOL", [{"", ""}]) + Client, "LOL", [{<<"">>, <<"">>}]) end, lists:seq(1, 100)), - ?awaitMatch(N, count_connections(Config), 30_000), + timer:sleep(5000), + N = count_connections(Config), ok. messages_not_dropped_on_disconnect(Config) -> StompPort = get_stomp_port(Config), N = count_connections(Config), {ok, Client} = rabbit_stomp_client:connect(StompPort), - %% STOMP connection registration is asynchronous; wait for it to be - %% readable (as in consistency). N1 = N + 1, - ?awaitMatch(N1, count_connections(Config), 30_000), + N1 = count_connections(Config), [rabbit_stomp_client:send( - Client, "SEND", [{"destination", ?DESTINATION}], - [integer_to_list(Count)]) || Count <- lists:seq(1, 1000)], + Client, 'SEND', [{<<"destination">>, ?DESTINATION}], + [integer_to_binary(Count)]) || Count <- lists:seq(1, 1000)], rabbit_stomp_client:disconnect(Client), QName = rabbit_misc:r(<<"/">>, queue, <<"bulk-test">>), ?awaitMatch(N, count_connections(Config), 30_000), @@ -132,21 +135,21 @@ stats_are_not_leaked(Config) -> Bin = <<"GET / HTTP/1.1\r\nHost: www.rabbitmq.com\r\nUser-Agent: curl/7.43.0\r\nAccept: */*\n\n">>, gen_tcp:send(C, Bin), gen_tcp:close(C), - ?awaitMatch(N, rabbit_ct_broker_helpers:rpc(Config, 0, ets, info, [connection_metrics, size]), 5_000), + timer:sleep(1000), %% Wait for stats to be emitted, which it does every 100ms + N = rabbit_ct_broker_helpers:rpc(Config, 0, ets, info, [connection_metrics, size]), ok. stats(Config) -> StompPort = get_stomp_port(Config), {ok, Client} = rabbit_stomp_client:connect(StompPort), + timer:sleep(1000), %% Wait for stats to be emitted, which it does every 100ms %% Retrieve the connection Pid - [Reader] = ?awaitMatch([_], rabbit_ct_broker_helpers:rpc(Config, 0, rabbit_stomp, list, []), 5_000), + [Reader] = rabbit_ct_broker_helpers:rpc(Config, 0, rabbit_stomp, list, []), [{_, Pid}] = rabbit_ct_broker_helpers:rpc(Config, 0, rabbit_stomp_reader, info, [Reader, [connection]]), %% Verify the content of the metrics, garbage_collection must be present - [{Pid, Props}] = ?awaitMatch([{Pid, _}], - rabbit_ct_broker_helpers:rpc(Config, 0, ets, lookup, - [connection_metrics, Pid]), - 5_000), + [{Pid, Props}] = rabbit_ct_broker_helpers:rpc(Config, 0, ets, lookup, + [connection_metrics, Pid]), true = proplists:is_defined(garbage_collection, Props), 0 = proplists:get_value(timeout, Props), %% If the coarse entry is present, stats were successfully emitted @@ -158,16 +161,15 @@ stats(Config) -> heartbeat(Config) -> StompPort = get_stomp_port(Config), {ok, Client} = rabbit_stomp_client:connect("1.2", "guest", "guest", StompPort, - [{"heart-beat", "5000,7000"}]), + [{<<"heart-beat">>, <<"5000,7000">>}]), + timer:sleep(1000), %% Wait for stats to be emitted, which it does every 100ms %% Retrieve the connection Pid - [Reader] = ?awaitMatch([_], rabbit_ct_broker_helpers:rpc(Config, 0, rabbit_stomp, list, []), 5_000), + [Reader] = rabbit_ct_broker_helpers:rpc(Config, 0, rabbit_stomp, list, []), [{_, Pid}] = rabbit_ct_broker_helpers:rpc(Config, 0, rabbit_stomp_reader, info, [Reader, [connection]]), %% Verify the content of the heartbeat timeout - [{Pid, Props}] = ?awaitMatch([{Pid, _}], - rabbit_ct_broker_helpers:rpc(Config, 0, ets, lookup, - [connection_metrics, Pid]), - 5_000), + [{Pid, Props}] = rabbit_ct_broker_helpers:rpc(Config, 0, ets, lookup, + [connection_metrics, Pid]), 5 = proplists:get_value(timeout, Props), rabbit_stomp_client:disconnect(Client), ok. @@ -193,9 +195,9 @@ frame_size(Config) -> application, set_env, [rabbitmq_stomp, max_frame_size, 80]), StompPort = get_stomp_port(Config), {ok, Client} = rabbit_stomp_client:connect("1.2", "guest", "guest", StompPort, - [{"heart-beat", "5000,7000"}]), + [{<<"heart-beat">>, <<"5000,7000">>}]), ok = rabbit_stomp_client:send( - Client, "SEND", [{"destination", "qwe"}], + Client, 'SEND', [{<<"destination">>, <<"qwe">>}], ["Lorem ipsum dolor sit amet viverra fusce. " "Lorem ipsum dolor sit amet viverra fusce. " "Lorem ipsum dolor sit amet viverra fusce." @@ -212,10 +214,53 @@ frame_size_huge(Config) -> application, set_env, [rabbitmq_stomp, max_frame_size, 700]), StompPort = get_stomp_port(Config), {ok, Client} = rabbit_stomp_client:connect("1.2", "guest", "guest", StompPort, - [{"heart-beat", "5000,7000"}]), + [{<<"heart-beat">>, <<"5000,7000">>}]), rabbit_stomp_client:send( - Client, "SEND", [{"destination", "qwe"}], + Client, 'SEND', [{<<"destination">>, <<"qwe">>}], [base64:encode(crypto:strong_rand_bytes(100000000))]), {S, _} = Client, {error, closed} = gen_tcp:recv(S, 0, 500), ok. + +%% Sending a SEND frame before CONNECT must produce an ERROR +%% frame, not crash the reader process. +unauthenticated_send_returns_error(Config) -> + StompPort = get_stomp_port(Config), + {ok, Sock} = gen_tcp:connect(localhost, StompPort, [{active, false}, binary]), + ok = gen_tcp:send(Sock, <<"SEND\ndestination:/queue/foo\n\nhello\0">>), + {ok, Data} = gen_tcp:recv(Sock, 0, 5000), + {ok, Frame, _} = rabbit_stomp_frame:parse(Data, rabbit_stomp_frame:initial_state()), + 'ERROR' = Frame#stomp_frame.command, + gen_tcp:close(Sock). + +unauthenticated_subscribe_returns_error(Config) -> + StompPort = get_stomp_port(Config), + {ok, Sock} = gen_tcp:connect(localhost, StompPort, [{active, false}, binary]), + ok = gen_tcp:send(Sock, <<"SUBSCRIBE\ndestination:/queue/foo\nid:0\n\n\0">>), + {ok, Data} = gen_tcp:recv(Sock, 0, 5000), + {ok, Frame, _} = rabbit_stomp_frame:parse(Data, rabbit_stomp_frame:initial_state()), + 'ERROR' = Frame#stomp_frame.command, + gen_tcp:close(Sock). + +unauthenticated_disconnect_returns_error(Config) -> + StompPort = get_stomp_port(Config), + {ok, Sock} = gen_tcp:connect(localhost, StompPort, [{active, false}, binary]), + ok = gen_tcp:send(Sock, <<"DISCONNECT\n\n\0">>), + {ok, Data} = gen_tcp:recv(Sock, 0, 5000), + {ok, Frame, _} = rabbit_stomp_frame:parse(Data, rabbit_stomp_frame:initial_state()), + 'ERROR' = Frame#stomp_frame.command, + gen_tcp:close(Sock). + +%% After rejecting an unauthenticated frame, the reader must still +%% be alive and able to accept a proper CONNECT. +unauthenticated_error_does_not_crash_reader(Config) -> + StompPort = get_stomp_port(Config), + N = count_connections(Config), + {ok, Sock} = gen_tcp:connect(localhost, StompPort, [{active, false}, binary]), + ok = gen_tcp:send(Sock, <<"SEND\ndestination:/queue/foo\n\nhello\0">>), + {ok, Data} = gen_tcp:recv(Sock, 0, 5000), + {ok, ErrFrame, _} = rabbit_stomp_frame:parse(Data, rabbit_stomp_frame:initial_state()), + 'ERROR' = ErrFrame#stomp_frame.command, + gen_tcp:close(Sock), + ?awaitMatch(N, count_connections(Config), 5_000), + ok. diff --git a/deps/rabbitmq_stomp/test/frame_SUITE.erl b/deps/rabbitmq_stomp/test/frame_SUITE.erl index f3bc3937de1a..6e549c0a5b6d 100644 --- a/deps/rabbitmq_stomp/test/frame_SUITE.erl +++ b/deps/rabbitmq_stomp/test/frame_SUITE.erl @@ -8,7 +8,6 @@ -module(frame_SUITE). -include_lib("eunit/include/eunit.hrl"). --include_lib("amqp_client/include/amqp_client.hrl"). -include("rabbit_stomp_frame.hrl"). -include("rabbit_stomp_headers.hrl"). -compile(export_all). @@ -25,6 +24,8 @@ all() -> parse_carriage_return_not_ignored_interframe, parse_carriage_return_mid_command, parse_carriage_return_end_command, + parse_unknown_command, + parse_unknown_command_short, parse_resume_mid_command, parse_resume_mid_header_key, parse_resume_mid_header_val, @@ -49,168 +50,176 @@ parse_simple_frame_crlf(_) -> parse_simple_frame_gen("\r\n"). parse_simple_frame_gen(Term) -> - Headers = [{"header1", "value1"}, {"header2", "value2"}], - Content = frame_string("COMMAND", + Headers = [{<<"header1">>, <<"value1">>}, {<<"header2">>, <<"value2">>}], + Content = frame_string('CONNECT', Headers, "Body Content", Term), - {"COMMAND", Frame, _State} = parse_complete(Content), + {'CONNECT', Frame, _State} = parse_complete(Content), [?assertEqual({ok, Value}, rabbit_stomp_frame:header(Frame, Key)) || {Key, Value} <- Headers], - #stomp_frame{body_iolist = Body} = Frame, + #stomp_frame{body_iolist_rev = Body} = Frame, ?assertEqual(<<"Body Content">>, iolist_to_binary(Body)). parse_command_only(_) -> - {ok, #stomp_frame{command = "COMMAND"}, _Rest} = parse("COMMAND\n\n\0"). + {ok, #stomp_frame{command = 'CONNECT'}, _Rest} = parse("CONNECT\n\n\0"). parse_command_prefixed_with_newline(_) -> - {ok, #stomp_frame{command = "COMMAND"}, _Rest} = parse("\nCOMMAND\n\n\0"). + {ok, #stomp_frame{command = 'CONNECT'}, _Rest} = parse("\nCONNECT\n\n\0"). parse_ignore_empty_frames(_) -> - {ok, #stomp_frame{command = "COMMAND"}, _Rest} = parse("\0\0COMMAND\n\n\0"). + {ok, #stomp_frame{command = 'CONNECT'}, _Rest} = parse("\0\0CONNECT\n\n\0"). parse_heartbeat_interframe(_) -> - {ok, #stomp_frame{command = "COMMAND"}, _Rest} = parse("\nCOMMAND\n\n\0"). + {ok, #stomp_frame{command = 'CONNECT'}, _Rest} = parse("\nCONNECT\n\n\0"). parse_crlf_interframe(_) -> - {ok, #stomp_frame{command = "COMMAND"}, _Rest} = parse("\r\nCOMMAND\n\n\0"). + {ok, #stomp_frame{command = 'CONNECT'}, _Rest} = parse("\r\nCONNECT\n\n\0"). parse_carriage_return_not_ignored_interframe(_) -> - {error, {unexpected_chars_between_frames, "\rC"}} = parse("\rCOMMAND\n\n\0"). + {error, {unexpected_chars_between_frames, "\rC"}} = parse("\rCONNECT\n\n\0"). parse_carriage_return_mid_command(_) -> {error, {unexpected_chars_in_command, "\rA"}} = parse("COMM\rAND\n\n\0"). parse_carriage_return_end_command(_) -> - {error, {unexpected_chars_in_command, "\r\r"}} = parse("COMMAND\r\r\n\n\0"). + {error, {unexpected_chars_in_command, "\r\r"}} = parse("CONNECT\r\r\n\n\0"). + +parse_unknown_command(_) -> + %% CR CR triggers a parse error before we reach the command-end transition + {error, {unexpected_chars_in_command, "\r\r"}} = parse("CONNECTA\r\r\n\n\0"). + +parse_unknown_command_short(_) -> + %% Unknown commands produce a frame with the command as binary + {ok, #stomp_frame{command = <<"CONNE">>}, _Rest} = parse("CONNE\n\n\0"). parse_resume_mid_command(_) -> - First = "COMM", - Second = "AND\n\n\0", + First = "CONN", + Second = "ECT\n\n\0", {more, Resume} = parse(First), - {ok, #stomp_frame{command = "COMMAND"}, _Rest} = parse(Second, Resume). + {ok, #stomp_frame{command = 'CONNECT'}, _Rest} = parse(Second, Resume). parse_resume_mid_header_key(_) -> - First = "COMMAND\nheadꙕ", + First = "CONNECT\nheadꙕ", Second = "r1:value1\n\n\0", {more, Resume} = parse(First), - {ok, Frame = #stomp_frame{command = "COMMAND"}, _Rest} = + {ok, Frame = #stomp_frame{command = 'CONNECT'}, _Rest} = parse(Second, Resume), - ?assertEqual({ok, "value1"}, - rabbit_stomp_frame:header(Frame, binary_to_list(<<"headꙕr1"/utf8>>))). + ?assertEqual({ok, <<"value1">>}, + rabbit_stomp_frame:header(Frame, <<"headꙕr1"/utf8>>)). parse_resume_mid_header_val(_) -> - First = "COMMAND\nheader1:val", + First = "CONNECT\nheader1:val", Second = "ue1\n\n\0", {more, Resume} = parse(First), - {ok, Frame = #stomp_frame{command = "COMMAND"}, _Rest} = + {ok, Frame = #stomp_frame{command = 'CONNECT'}, _Rest} = parse(Second, Resume), - ?assertEqual({ok, "value1"}, - rabbit_stomp_frame:header(Frame, "header1")). + ?assertEqual({ok, <<"value1">>}, + rabbit_stomp_frame:header(Frame, <<"header1">>)). parse_resume_mid_body(_) -> - First = "COMMAND\n\nABC", + First = "CONNECT\n\nABC", Second = "DEF\0", {more, Resume} = parse(First), - {ok, #stomp_frame{command = "COMMAND", body_iolist = Body}, _Rest} = + {ok, #stomp_frame{command = 'CONNECT', body_iolist_rev = Body}, _Rest} = parse(Second, Resume), - ?assertEqual([<<"ABC">>, <<"DEF">>], Body). + ?assertEqual([<<"DEF">>, <<"ABC">>], Body). parse_no_header_stripping(_) -> - Content = "COMMAND\nheader: foo \n\n\0", + Content = "CONNECT\nheader: foo \n\n\0", {ok, Frame, _} = parse(Content), - {ok, Val} = rabbit_stomp_frame:header(Frame, "header"), - ?assertEqual(" foo ", Val). + {ok, Val} = rabbit_stomp_frame:header(Frame, <<"header">>), + ?assertEqual(<<" foo ">>, Val). parse_multiple_headers(_) -> - Content = "COMMAND\nheader:correct\nheader:incorrect\n\n\0", + Content = "CONNECT\nheader:correct\nheader:incorrect\n\n\0", {ok, Frame, _} = parse(Content), - {ok, Val} = rabbit_stomp_frame:header(Frame, "header"), - ?assertEqual("correct", Val). + {ok, Val} = rabbit_stomp_frame:header(Frame, <<"header">>), + ?assertEqual(<<"correct">>, Val). header_no_colon(_) -> - Content = "COMMAND\n" + Content = "CONNECT\n" "hdr1:val1\n" "hdrerror\n" "hdr2:val2\n" "\n\0", - ?assertEqual(parse(Content), {error, {header_no_value, "hdrerror"}}). + ?assertEqual(parse(Content), {error, {header_no_value, <<"hdrerror">>}}). no_nested_escapes(_) -> - Content = "COM\\\\rAND\n" % no escapes + Content = "CONNECT\n" % no escapes "hdr\\\\rname:" % one escape "hdr\\\\rval\n\n\0", % one escape {ok, Frame, _} = parse(Content), ?assertEqual(Frame, - #stomp_frame{command = "COM\\\\rAND", - headers = [{"hdr\\rname", "hdr\\rval"}], - body_iolist = []}). + #stomp_frame{command = 'CONNECT', + headers = #{<<"hdr\\rname">> => <<"hdr\\rval">>}, + body_iolist_rev = []}). header_name_with_cr(_) -> - Content = "COMMAND\nhead\rer:val\n\n\0", + Content = "CONNECT\nhead\rer:val\n\n\0", {error, {unexpected_chars_in_header, "\re"}} = parse(Content). header_value_with_cr(_) -> - Content = "COMMAND\nheader:val\rue\n\n\0", + Content = "CONNECT\nheader:val\rue\n\n\0", {error, {unexpected_chars_in_header, "\ru"}} = parse(Content). header_value_with_colon(_) -> - Content = "COMMAND\nheader:val:ue\n\n\0", + Content = "CONNECT\nheader:val:ue\n\n\0", {ok, Frame, _} = parse(Content), ?assertEqual(Frame, - #stomp_frame{ command = "COMMAND", - headers = [{"header", "val:ue"}], - body_iolist = []}). + #stomp_frame{ command = 'CONNECT', + headers = #{<<"header">> => <<"val:ue">>}, + body_iolist_rev = []}). stream_offset_header(_) -> TestCases = [ - {{"x-stream-offset", "first"}, {longstr, <<"first">>}}, - {{"x-stream-offset", "last"}, {longstr, <<"last">>}}, - {{"x-stream-offset", "next"}, {longstr, <<"next">>}}, - {{"x-stream-offset", "offset=5000"}, {long, 5000}}, - {{"x-stream-offset", "timestamp=1000"}, {timestamp, 1000}}, - {{"x-stream-offset", "foo"}, not_found}, - {{"some-header", "some value"}, not_found} + {{<<"x-stream-offset">>, <<"first">>}, {longstr, <<"first">>}}, + {{<<"x-stream-offset">>, <<"last">>}, {longstr, <<"last">>}}, + {{<<"x-stream-offset">>, <<"next">>}, {longstr, <<"next">>}}, + {{<<"x-stream-offset">>, <<"offset=5000">>}, {long, 5000}}, + {{<<"x-stream-offset">>, <<"timestamp=1000">>}, {timestamp, 1000}}, + {{<<"x-stream-offset">>, <<"foo">>}, not_found}, + {{<<"some-header">>, <<"some value">>}, not_found} ], lists:foreach(fun({Header, Expected}) -> ?assertEqual( Expected, - rabbit_stomp_frame:stream_offset_header(#stomp_frame{headers = [Header]}) + rabbit_stomp_frame:stream_offset_header(#stomp_frame{headers = maps:from_list([Header])}) ) end, TestCases). stream_filter_header(_) -> TestCases = [ - {{"x-stream-filter", "banana"}, {array, [{longstr, <<"banana">>}]}}, - {{"x-stream-filter", "banana,apple"}, {array, [{longstr, <<"banana">>}, + {{<<"x-stream-filter">>, <<"banana">>}, {array, [{longstr, <<"banana">>}]}}, + {{<<"x-stream-filter">>, <<"banana,apple">>}, {array, [{longstr, <<"banana">>}, {longstr, <<"apple">>}]}}, - {{"x-stream-filter", "banana,apple,orange"}, {array, [{longstr, <<"banana">>}, + {{<<"x-stream-filter">>, <<"banana,apple,orange">>}, {array, [{longstr, <<"banana">>}, {longstr, <<"apple">>}, {longstr, <<"orange">>}]}}, - {{"some-header", "some value"}, not_found} + {{<<"some-header">>, <<"some value">>}, not_found} ], lists:foreach(fun({Header, Expected}) -> ?assertEqual( Expected, - rabbit_stomp_frame:stream_filter_header(#stomp_frame{headers = [Header]}) + rabbit_stomp_frame:stream_filter_header(#stomp_frame{headers = maps:from_list([Header])}) ) end, TestCases). test_frame_serialization(Expected, TrailingLF) -> {ok, Frame, _} = parse(Expected), - {ok, Val} = rabbit_stomp_frame:header(Frame, "head\r:\ner"), - ?assertEqual(":\n\r\\", Val), + {ok, Val} = rabbit_stomp_frame:header(Frame, <<"head\r:\ner">>), + ?assertEqual(<<":\n\r\\">>, Val), Serialized = lists:flatten(rabbit_stomp_frame:serialize(Frame, TrailingLF)), ?assertEqual(Expected, rabbit_misc:format("~ts", [Serialized])). headers_escaping_roundtrip(_) -> - test_frame_serialization("COMMAND\nhead\\r\\c\\ner:\\c\\n\\r\\\\\n\n\0\n", true). + test_frame_serialization("CONNECT\nhead\\r\\c\\ner:\\c\\n\\r\\\\\n\n\0\n", true). headers_escaping_roundtrip_without_trailing_lf(_) -> - test_frame_serialization("COMMAND\nhead\\r\\c\\ner:\\c\\n\\r\\\\\n\n\0", false). + test_frame_serialization("CONNECT\nhead\\r\\c\\ner:\\c\\n\\r\\\\\n\n\0", false). parse(Content) -> parse(Content, rabbit_stomp_frame:initial_state()). @@ -223,6 +232,5 @@ parse_complete(Content) -> frame_string(Command, Headers, BodyContent, Term) -> HeaderString = - lists:flatten([Key ++ ":" ++ Value ++ Term || {Key, Value} <- Headers]), - Command ++ Term ++ HeaderString ++ Term ++ BodyContent ++ "\0" ++ "\n". - + lists:flatten([binary_to_list(Key) ++ ":" ++ binary_to_list(Value) ++ Term || {Key, Value} <- Headers]), + atom_to_list(Command) ++ Term ++ HeaderString ++ Term ++ BodyContent ++ "\0" ++ "\n". diff --git a/deps/rabbitmq_stomp/test/generate_python_tests.py b/deps/rabbitmq_stomp/test/generate_python_tests.py new file mode 100755 index 000000000000..40db0af88bee --- /dev/null +++ b/deps/rabbitmq_stomp/test/generate_python_tests.py @@ -0,0 +1,184 @@ +#!/usr/bin/env python3 + +"""Scan Python test files and generate an Erlang include file. + +Structure: + - Single-class file -> CT group named after the class, containing test methods + - Multi-class file -> CT group named after the file, containing class subgroups + +Usage: generate_python_tests.py +""" + +import os +import re +import sys +from collections import OrderedDict + +BROKER_GROUPS = OrderedDict([ + ('tls', ['tls_connect_disconnect']), + ('implicit_connect', ['implicit_connect']), + ('main', [ + 'parsing', + 'errors', + 'connect_disconnect', + 'ack', + 'amqp_headers', + 'queue_properties', + 'reliability', + 'transactions', + 'x_queue_name', + 'destinations', + 'redelivered', + 'topic_permissions', + 'unsubscribe', + 'x_queue_type_quorum', + 'x_queue_type_stream', + ]), +]) + +CLASS_RE = re.compile(r'^class\s+(Test\w+)\s*[\(:]') +METHOD_RE = re.compile(r'^\s+def\s+(test_\w+)\s*\(') + + +def discover_file(src_dir, module): + """Return OrderedDict: class_name -> [method_name, ...]""" + classes = OrderedDict() + path = os.path.join(src_dir, module + '.py') + current_class = None + with open(path) as f: + for line in f: + m = CLASS_RE.match(line) + if m: + current_class = m.group(1) + if current_class not in classes: + classes[current_class] = [] + continue + m = METHOD_RE.match(line) + if m and current_class: + classes[current_class].append(m.group(1)) + return classes + + +def short_class(class_name): + return class_name.removeprefix('Test') + + +def erl_atom(s): + return f"'{s}'" + + +def generate_hrl(src_dir, output_path): + # Collect structure: broker_group -> [(group_name, is_multi, {class -> methods})] + structure = OrderedDict() + for bg, modules in BROKER_GROUPS.items(): + file_groups = [] + for mod in modules: + classes = discover_file(src_dir, mod) + multi = len(classes) > 1 + file_groups.append((mod, multi, classes)) + structure[bg] = file_groups + + all_funcs = [] # (erl_func_name, python_test_id) + + with open(output_path, 'w') as out: + out.write("%% Generated by generate_python_tests.py — do not edit\n\n") + + # For each broker group, emit the subgroup list and class group definitions + for bg, file_groups in structure.items(): + # Subgroup list for this broker group + subgroups = [] + for mod, multi, classes in file_groups: + if multi: + subgroups.append(mod) + else: + cls_name = list(classes.keys())[0] + subgroups.append(short_class(cls_name)) + + macro = bg.upper() + '_SUBGROUPS' + out.write(f"-define({macro}, [\n") + for i, sg in enumerate(subgroups): + comma = ',' if i < len(subgroups) - 1 else '' + out.write(f" {{group, {erl_atom(sg)}}}{comma}\n") + out.write("]).\n\n") + + # Group definitions + for mod, multi, classes in file_groups: + if multi: + # File group containing class subgroups + file_macro = mod.upper() + '_SUBGROUPS' + out.write(f"-define({file_macro}, [\n") + cls_list = list(classes.keys()) + for i, cls in enumerate(cls_list): + sc = short_class(cls) + comma = ',' if i < len(cls_list) - 1 else '' + out.write(f" {{group, {erl_atom(sc)}}}{comma}\n") + out.write("]).\n\n") + + # Each class subgroup + for cls, methods in classes.items(): + sc = short_class(cls) + cls_macro = sc.upper() + '_TESTS' + out.write(f"-define({cls_macro}, [\n") + for i, meth in enumerate(methods): + comma = ',' if i < len(methods) - 1 else '' + fn = f"{sc}.{meth}" + out.write(f" {erl_atom(fn)}{comma}\n") + out.write("]).\n\n") + for meth in methods: + fn = f"{sc}.{meth}" + py_id = f"{mod}.{cls}.{meth}" + all_funcs.append((fn, py_id)) + else: + # Single class -> group directly contains tests + cls = list(classes.keys())[0] + sc = short_class(cls) + methods = classes[cls] + cls_macro = sc.upper() + '_TESTS' + out.write(f"-define({cls_macro}, [\n") + for i, meth in enumerate(methods): + comma = ',' if i < len(methods) - 1 else '' + fn = f"{sc}.{meth}" + out.write(f" {erl_atom(fn)}{comma}\n") + out.write("]).\n\n") + for meth in methods: + fn = f"{sc}.{meth}" + py_id = f"{mod}.{cls}.{meth}" + all_funcs.append((fn, py_id)) + + # CLASS_GROUPS macro — all group definitions for groups() + out.write("-define(CLASS_GROUPS, [\n") + all_group_defs = [] + for bg, file_groups in structure.items(): + for mod, multi, classes in file_groups: + if multi: + all_group_defs.append( + f" {{{erl_atom(mod)}, [], ?{mod.upper()}_SUBGROUPS}}") + for cls in classes: + sc = short_class(cls) + all_group_defs.append( + f" {{{erl_atom(sc)}, [], ?{sc.upper()}_TESTS}}") + else: + cls = list(classes.keys())[0] + sc = short_class(cls) + all_group_defs.append( + f" {{{erl_atom(sc)}, [], ?{sc.upper()}_TESTS}}") + out.write(',\n'.join(all_group_defs)) + out.write("\n]).\n\n") + + # Function definitions + out.write("%% Test function definitions\n\n") + seen = set() + for fn, py_id in all_funcs: + if fn not in seen: + out.write(f"{erl_atom(fn)}(Config) -> " + f"run_one_test(Config, \"{py_id}\").\n") + seen.add(fn) + + print(f"Generated {output_path}: {len(seen)} test functions") + + +if __name__ == '__main__': + if len(sys.argv) != 3: + print(f"Usage: {sys.argv[0]} ", file=sys.stderr) + sys.exit(2) + generate_hrl(sys.argv[1], sys.argv[2]) diff --git a/deps/rabbitmq_stomp/test/prop_frame_SUITE.erl b/deps/rabbitmq_stomp/test/prop_frame_SUITE.erl index 77d0536e0a44..d6ed4f695b77 100644 --- a/deps/rabbitmq_stomp/test/prop_frame_SUITE.erl +++ b/deps/rabbitmq_stomp/test/prop_frame_SUITE.erl @@ -10,10 +10,14 @@ -include_lib("proper/include/proper.hrl"). -include_lib("common_test/include/ct.hrl"). -include_lib("eunit/include/eunit.hrl"). +-include("rabbit_stomp_frame.hrl"). -compile([export_all, nowarn_export_all]). -import(rabbit_ct_proper_helpers, [run_proper/3]). +%% Default max_headers is 100 +-define(LIMIT, 100). + all() -> [{group, properties}]. @@ -22,32 +26,39 @@ groups() -> [ prop_within_limit_succeeds, prop_over_limit_fails, - prop_duplicates_do_not_exhaust_limit + prop_duplicates_do_not_exhaust_limit, + prop_negative_content_length_rejected, + prop_valid_content_length_round_trips, + prop_non_numeric_content_length_does_not_crash, + prop_round_trip, + prop_incremental_parse, + prop_max_body_length_enforced ]}]. -%% Any frame with up to 100 unique header names parses successfully. +%% Any frame with up to LIMIT unique header names parses successfully. prop_within_limit_succeeds(_Config) -> run_proper( fun() -> - ?FORALL(N, range(0, 100), + ?FORALL(N, range(0, ?LIMIT), begin {ok, _, _} = parse(make_frame(unique_headers(N))), true end) end, [], 1000). -%% Any frame with more than 100 unique header names is rejected. +%% Any frame with more than LIMIT unique header names is rejected. prop_over_limit_fails(_Config) -> run_proper( fun() -> - ?FORALL(N, range(101, 300), + ?FORALL(N, range(?LIMIT + 1, 300), begin - {error, too_many_headers} = parse(make_frame(unique_headers(N))), + {error, {max_headers, ?LIMIT}} = parse(make_frame(unique_headers(N))), true end) end, [], 1000). %% Duplicate entries for a header name already seen never trigger the limit. +%% Duplicates are discarded at O(1) with zero allocation. prop_duplicates_do_not_exhaust_limit(_Config) -> run_proper( fun() -> @@ -60,6 +71,124 @@ prop_duplicates_do_not_exhaust_limit(_Config) -> end) end, [], 1000). +%% Any negative content-length is rejected. +prop_negative_content_length_rejected(_Config) -> + run_proper( + fun() -> + ?FORALL(N, range(-10000, -1), + begin + {error, {invalid_content_length, N}} = + parse(send_frame("content-length", integer_to_list(N), "x")), + true + end) + end, [], 1000). + +%% A SEND frame with a matching content-length and body always parses. +prop_valid_content_length_round_trips(_Config) -> + run_proper( + fun() -> + ?FORALL(Body, binary(), + begin + Len = integer_to_list(byte_size(Body)), + case parse(send_frame("content-length", Len, Body)) of + {ok, #stomp_frame{command = 'SEND'}, _} -> true; + {more, _} -> true + end + end) + end, [], 1000). + +%% Non-numeric content-length values must never crash the parser. +prop_non_numeric_content_length_does_not_crash(_Config) -> + run_proper( + fun() -> + ?FORALL(Junk, non_numeric_bin(), + begin + case parse(send_frame("content-length", binary_to_list(Junk), "x")) of + {ok, _, _} -> true; + {more, _} -> true; + {error, _} -> true + end + end) + end, [], 1000). + +%% Serialize then parse preserves command, headers, and body. +%% Headers with escape-triggering characters (colon, backslash, LF, CR) +%% exercise both the fast and slow parser paths. +prop_round_trip(_Config) -> + run_proper( + fun() -> + ?FORALL({RawPairs, Body}, + {resize(5, list({stomp_hdr_name(), stomp_hdr_value()})), + resize(200, binary())}, + begin + Hdrs = maps:from_list( + [{<<"destination">>, <<"/queue/t">>} | + [{K, V} || {K, V} <- RawPairs, + K =/= <<"content-length">>, + K =/= <<"destination">>]]), + Frame = #stomp_frame{command = 'SEND', + headers = Hdrs, + body_iolist_rev = Body}, + Bin = iolist_to_binary(rabbit_stomp_frame:serialize(Frame)), + {ok, Parsed, _} = parse(Bin), + Parsed#stomp_frame.command =:= 'SEND' andalso + lists:all( + fun({K, V}) -> + maps:get(K, Parsed#stomp_frame.headers, undefined) =:= V + end, maps:to_list(Hdrs)) andalso + Body =:= body_to_binary(Parsed) + end) + end, [], 500). + +%% Splitting a valid frame at any byte boundary and parsing in two +%% chunks must produce the same frame as parsing in one call. +prop_incremental_parse(_Config) -> + run_proper( + fun() -> + ?FORALL({Body, N}, + {resize(200, binary()), non_neg_integer()}, + begin + Bin = iolist_to_binary( + rabbit_stomp_frame:serialize( + #stomp_frame{command = 'SEND', + headers = #{<<"destination">> => <<"/queue/t">>}, + body_iolist_rev = Body})), + Pos = N rem (byte_size(Bin) + 1), + <> = Bin, + {ok, Full, _} = parse(Bin), + Chunked = case parse(C1) of + {ok, F, _} -> F; + {more, St} -> + {ok, F2, _} = rabbit_stomp_frame:parse(C2, St), + F2 + end, + Full#stomp_frame.command =:= Chunked#stomp_frame.command andalso + Full#stomp_frame.headers =:= Chunked#stomp_frame.headers andalso + body_to_binary(Full) =:= body_to_binary(Chunked) + end) + end, [], 1000). + +%% A SEND frame exceeding max_body_length is always rejected. +prop_max_body_length_enforced(_Config) -> + run_proper( + fun() -> + ?FORALL({MaxLen, BodySize}, + {range(1, 500), range(0, 1000)}, + begin + Body = binary:copy(<<"x">>, BodySize), + Bin = iolist_to_binary( + rabbit_stomp_frame:serialize( + #stomp_frame{command = 'SEND', + headers = #{<<"destination">> => <<"/queue/t">>}, + body_iolist_rev = Body})), + Config = #stomp_parser_config{max_body_length = MaxLen}, + case rabbit_stomp_frame:parse(Bin, rabbit_stomp_frame:initial_state(Config)) of + {ok, _, _} -> BodySize =< MaxLen; + {error, _} -> BodySize > MaxLen + end + end) + end, [], 1000). + %%------------------------------------------------------------------- unique_headers(N) -> @@ -69,5 +198,41 @@ make_frame(Headers) -> HdrStr = lists:flatten([K ++ ":" ++ V ++ "\n" || {K, V} <- Headers]), iolist_to_binary(["CONNECT\n", HdrStr, "\n\0"]). +send_frame(HdrName, HdrValue, Body) -> + iolist_to_binary(["SEND\ndestination:/queue/t\n", + HdrName, ":", HdrValue, "\n\n", + Body, "\0"]). + +%% Produces binaries that are not valid integer strings. +non_numeric_bin() -> + ?SUCHTHAT(Bin, + ?LET(Chars, list(range(0, 255)), + list_to_binary( + [C || C <- Chars, C =/= $\n, C =/= $\r, C =/= $:, C =/= $\\, C =/= 0])), + not is_integer(catch binary_to_integer(string:trim(Bin)))). + +%% Produces non-empty, no NUL header names. Biased towards escaped chars. +stomp_hdr_name() -> + ?SUCHTHAT(Bin, + ?LET(Chars, resize(15, non_empty(list(stomp_char()))), + list_to_binary(Chars)), + Bin =/= <<"content-length">> andalso + Bin =/= <<"destination">>). + +%% Produces no-NUL header values. Biased towards escaped chars. +stomp_hdr_value() -> + ?LET(Chars, resize(20, list(stomp_char())), + list_to_binary(Chars)). + +%% Biased towards characters that trigger the escape slow path. +stomp_char() -> + frequency([{3, $:}, {3, $\\}, {3, $\n}, {3, $\r}, + {88, range(32, 126)}]). + +body_to_binary(#stomp_frame{body_iolist_rev = Rev}) when is_list(Rev) -> + iolist_to_binary(lists:reverse(Rev)); +body_to_binary(#stomp_frame{body_iolist_rev = Bin}) when is_binary(Bin) -> + Bin. + parse(Bin) -> rabbit_stomp_frame:parse(Bin, rabbit_stomp_frame:initial_state()). diff --git a/deps/rabbitmq_stomp/test/proxy_protocol_SUITE.erl b/deps/rabbitmq_stomp/test/proxy_protocol_SUITE.erl index 5f42f3c0c15c..33ff28720f35 100644 --- a/deps/rabbitmq_stomp/test/proxy_protocol_SUITE.erl +++ b/deps/rabbitmq_stomp/test/proxy_protocol_SUITE.erl @@ -10,7 +10,6 @@ -include_lib("common_test/include/ct.hrl"). -include_lib("eunit/include/eunit.hrl"). --include_lib("rabbitmq_ct_helpers/include/rabbit_assert.hrl"). -define(TIMEOUT, 5000). @@ -37,13 +36,10 @@ init_per_suite(Config) -> {rabbitmq_ct_tls_verify, verify_none} ]), StompConfig = stomp_config(), - Config2 = rabbit_ct_helpers:run_setup_steps(Config1, - [ fun(Conf) -> merge_app_env(StompConfig, Conf) end ] ++ - rabbit_ct_broker_helpers:setup_steps() ++ - rabbit_ct_client_helpers:setup_steps()), - rabbit_ct_broker_helpers:add_user(Config2, <<"proxy_test">>, <<"proxy_test">>), - rabbit_ct_broker_helpers:set_full_permissions(Config2, <<"proxy_test">>, <<"/">>), - Config2. + rabbit_ct_helpers:run_setup_steps(Config1, + [ fun(Conf) -> merge_app_env(StompConfig, Conf) end ] ++ + rabbit_ct_broker_helpers:setup_steps() ++ + rabbit_ct_client_helpers:setup_steps()). stomp_config() -> {rabbitmq_stomp, [ @@ -71,8 +67,9 @@ proxy_protocol_v1(Config) -> ok = inet:send(Socket, "PROXY TCP4 192.168.1.1 192.168.1.2 80 81\r\n"), ok = inet:send(Socket, stomp_connect_frame()), {ok, _Packet} = gen_tcp:recv(Socket, 0, ?TIMEOUT), - await_connection_name_match( - Config, <<"^192.168.1.1:80 -> 192.168.1.2:81$">>), + ConnectionName = rabbit_ct_broker_helpers:rpc(Config, 0, + ?MODULE, connection_name, []), + match = re:run(ConnectionName, <<"^192.168.1.1:80 -> 192.168.1.2:81$">>, [{capture, none}]), gen_tcp:close(Socket), ok. @@ -85,8 +82,9 @@ proxy_protocol_v1_tls(Config) -> {ok, SslSocket} = ssl:connect(Socket, [{verify, verify_none}], ?TIMEOUT), ok = ssl:send(SslSocket, stomp_connect_frame()), {ok, _Packet} = ssl:recv(SslSocket, 0, ?TIMEOUT), - await_connection_name_match( - Config, <<"^192.168.1.1:80 -> 192.168.1.2:81$">>), + ConnectionName = rabbit_ct_broker_helpers:rpc(Config, 0, + ?MODULE, connection_name, []), + match = re:run(ConnectionName, <<"^192.168.1.1:80 -> 192.168.1.2:81$">>, [{capture, none}]), gen_tcp:close(Socket), ok. @@ -101,38 +99,33 @@ proxy_protocol_v2_local(Config) -> ok = inet:send(Socket, ranch_proxy_header:header(ProxyInfo)), ok = inet:send(Socket, stomp_connect_frame()), {ok, _Packet} = gen_tcp:recv(Socket, 0, ?TIMEOUT), - await_connection_name_match( - Config, <<"^127.0.0.1:\\d+ -> 127.0.0.1:\\d+$">>), + ConnectionName = rabbit_ct_broker_helpers:rpc(Config, 0, + ?MODULE, connection_name, []), + match = re:run(ConnectionName, <<"^127.0.0.1:\\d+ -> 127.0.0.1:\\d+$">>, [{capture, none}]), gen_tcp:close(Socket), ok. -%% The `connection_created' ETS table is populated asynchronously by -%% the management agent; wait for an entry whose `name' matches the -%% pattern. -await_connection_name_match(Config, Pattern) -> - ?awaitMatch(true, - rabbit_ct_broker_helpers:rpc( - Config, 0, ?MODULE, has_connection_name_matching, [Pattern]), - 30_000). - -has_connection_name_matching(Pattern) -> - Connections = ets:tab2list(connection_created), - lists:any( - fun({_Key, Values}) -> - case lists:keyfind(name, 1, Values) of - {_, Name} -> - re:run(Name, Pattern, [{capture, none}]) =:= match; - false -> - false - end - end, Connections). +connection_name() -> + connection_name(50). + +connection_name(0) -> + error(no_stomp_connection_found); +connection_name(Retries) -> + case ets:tab2list(connection_created) of + [{_Key, Values} | _] -> + {_, Name} = lists:keyfind(conn_name, 1, Values), + Name; + [] -> + timer:sleep(50), + connection_name(Retries - 1) + end. merge_app_env(StompConfig, Config) -> rabbit_ct_helpers:merge_app_env(Config, StompConfig). stomp_connect_frame() -> <<"CONNECT\n", - "login:proxy_test\n", - "passcode:proxy_test\n", + "login:guest\n", + "passcode:guest\n", "\n", 0>>. diff --git a/deps/rabbitmq_stomp/test/python_SUITE.erl b/deps/rabbitmq_stomp/test/python_SUITE.erl index 63051d21fff3..453e07c7de89 100644 --- a/deps/rabbitmq_stomp/test/python_SUITE.erl +++ b/deps/rabbitmq_stomp/test/python_SUITE.erl @@ -2,57 +2,61 @@ %% License, v. 2.0. If a copy of the MPL was not distributed with this %% file, You can obtain one at https://mozilla.org/MPL/2.0/. %% -%% Copyright (c) 2007-2026 Broadcom. All Rights Reserved. The term “Broadcom” refers to Broadcom Inc. and/or its subsidiaries. All rights reserved. +%% Copyright (c) 2007-2026 Broadcom. All Rights Reserved. The term "Broadcom" refers to Broadcom Inc. and/or its subsidiaries. All rights reserved. %% -module(python_SUITE). -compile(export_all). -include_lib("common_test/include/ct.hrl"). +%% Generated: one CT test per Python test method, grouped by class. +%% Regenerate: python3 test/generate_python_tests.py \ +%% test/python_SUITE_data/src test/python_SUITE_generated.hrl +-include("python_SUITE_generated.hrl"). + all() -> - [ - %% This must use a dedicated node as they mess with plugin configuration in incompatible ways - {group, tls}, - {group, implicit_connect}, - {group, main} - ]. + [{group, tls}, + {group, implicit_connect}, + {group, main}]. groups() -> - [ - {main, [], [ - main - ]}, - {implicit_connect, [], [ - implicit_connect - ]}, - {tls, [], [ - tls_connections - ]} - ]. + [{tls, [], ?TLS_SUBGROUPS}, + {implicit_connect, [], ?IMPLICIT_CONNECT_SUBGROUPS}, + {main, [], ?MAIN_SUBGROUPS} + | ?CLASS_GROUPS]. init_per_suite(Config) -> - {ok, _} = rabbit_ct_helpers:exec(["pip", "install", "-r", requirements_path(Config), - "--target", deps_path(Config)]), + {ok, _} = rabbit_ct_helpers:exec( + ["pip", "install", "-r", requirements_path(Config), + "--target", deps_path(Config)]), Config. end_per_suite(Config) -> ok = file:del_dir_r(deps_path(Config)), Config. +init_per_group(main, Config) -> + Config1 = init_broker(Config), + rabbit_ct_broker_helpers:rpc( + Config1, 0, + application, set_env, [rabbitmq_stomp, max_frame_size, 17 * 1024 * 1024]), + Config1; +init_per_group(tls, Config) -> + Config1 = init_broker(Config), + ensure_ssl_auth_user(Config1), + Config1; +init_per_group(implicit_connect, Config) -> + init_broker(Config); init_per_group(_, Config) -> - Config0 = rabbit_ct_helpers:set_config(Config, - [ - {rmq_nodename_suffix, ?MODULE}, - {rmq_certspwd, "bunnychow"} - ]), - rabbit_ct_helpers:log_environment(), - rabbit_ct_helpers:run_setup_steps( - Config0, - rabbit_ct_broker_helpers:setup_steps()). + Config. -end_per_group(_, Config) -> +end_per_group(Group, Config) when Group =:= main; + Group =:= tls; + Group =:= implicit_connect -> rabbit_ct_helpers:run_teardown_steps(Config, - rabbit_ct_broker_helpers:teardown_steps()). + rabbit_ct_broker_helpers:teardown_steps()); +end_per_group(_, Config) -> + Config. init_per_testcase(Test, Config) -> rabbit_ct_helpers:testcase_started(Config, Test). @@ -60,21 +64,39 @@ init_per_testcase(Test, Config) -> end_per_testcase(Test, Config) -> rabbit_ct_helpers:testcase_finished(Config, Test). +run_one_test(Config, PythonTestId) -> + DataDir = ?config(data_dir, Config), + SrcDir = filename:join(DataDir, "src"), + setup_python_env(Config), + {ok, _} = rabbit_ct_helpers:exec( + ["python3", "-m", "unittest", "-v", PythonTestId], + [{cd, SrcDir}]). -main(Config) -> - rabbit_ct_broker_helpers:rpc( - Config, 0, - application, set_env, [rabbitmq_stomp, max_frame_size, 17 * 1024 * 1024]), - run(Config, filename:join("src", "main_runner.py")). - -implicit_connect(Config) -> - run(Config, filename:join("src", "implicit_connect_runner.py")). +%% +%% Internal +%% -tls_connections(Config) -> - run(Config, filename:join("src", "tls_runner.py")). +init_broker(Config) -> + Config0 = rabbit_ct_helpers:set_config(Config, + [{rmq_nodename_suffix, ?MODULE}, + {rmq_certspwd, "bunnychow"}]), + rabbit_ct_helpers:log_environment(), + Config1 = rabbit_ct_helpers:merge_app_env( + Config0, + {rabbit, + [{permit_deprecated_features, #{transient_nonexcl_queues => true}}]}), + rabbit_ct_helpers:run_setup_steps( + Config1, + rabbit_ct_broker_helpers:setup_steps()). +ensure_ssl_auth_user(Config) -> + Host = net_adm:localhost(), + User = "O=client,CN=" ++ Host, + rabbit_ct_broker_helpers:rabbitmqctl(Config, 0, ["add_user", User, "foo"]), + rabbit_ct_broker_helpers:rabbitmqctl(Config, 0, ["clear_password", User]), + rabbit_ct_broker_helpers:rabbitmqctl(Config, 0, ["set_permissions", User, ".*", ".*", ".*"]). -run(Config, Test) -> +setup_python_env(Config) -> CertsDir = rabbit_ct_helpers:get_config(Config, rmq_certsdir), StompPort = rabbit_ct_broker_helpers:get_node_config(Config, 0, tcp_port_stomp), StompPortTls = rabbit_ct_broker_helpers:get_node_config(Config, 0, tcp_port_stomp_tls), @@ -87,12 +109,7 @@ run(Config, Test) -> os:putenv("STOMP_PORT_TLS", integer_to_list(StompPortTls)), os:putenv("RABBITMQ_NODENAME", atom_to_list(NodeName)), os:putenv("SSL_CERTS_PATH", CertsDir), - run_python(Config, Test). - -run_python(Config, What) -> - DataDir = ?config(data_dir, Config), - os:putenv("PYTHONPATH", python_path(Config)), - {ok, _} = rabbit_ct_helpers:exec([filename:join(DataDir, What)]). + os:putenv("PYTHONPATH", python_path(Config)). deps_path(Config) -> DataDir = ?config(data_dir, Config), @@ -107,7 +124,3 @@ python_path(Config) -> false -> deps_path(Config); P -> deps_path(Config) ++ ":" ++ P end. - -cur_dir() -> - {ok, Src} = filelib:find_source(?MODULE), - filename:dirname(Src). diff --git a/deps/rabbitmq_stomp/test/python_SUITE_data/src/destinations.py b/deps/rabbitmq_stomp/test/python_SUITE_data/src/destinations.py index a4a6cea6bb9b..b7fe5c06c5c8 100644 --- a/deps/rabbitmq_stomp/test/python_SUITE_data/src/destinations.py +++ b/deps/rabbitmq_stomp/test/python_SUITE_data/src/destinations.py @@ -37,7 +37,7 @@ def test_invalid_exchange(self): self.assertListener("Expecting an error", numErrs=1) err = self.listener.errors[0] self.assertEqual("not_found", err['headers']['message']) - self.assertRegex(err['message'], r'^NOT_FOUND') + self.assertRegex(err['message'], r'^no exchange') time.sleep(1) self.assertFalse(self.conn.is_connected()) @@ -489,4 +489,4 @@ def test_durable_subscribe_no_id_and_legacy_header(self): modules = [ __name__ ] - test_runner.run_unittests(modules) \ No newline at end of file + test_runner.run_unittests(modules) diff --git a/deps/rabbitmq_stomp/test/python_SUITE_data/src/errors.py b/deps/rabbitmq_stomp/test/python_SUITE_data/src/errors.py index a9697dc1a243..b5cccf5f6b86 100644 --- a/deps/rabbitmq_stomp/test/python_SUITE_data/src/errors.py +++ b/deps/rabbitmq_stomp/test/python_SUITE_data/src/errors.py @@ -74,7 +74,7 @@ def test_unknown_destination(self): def test_send_missing_destination(self): self.__test_missing_destination("SEND") - def test_send_missing_destination(self): + def test_subscribe_missing_destination(self): self.__test_missing_destination("SUBSCRIBE") def __test_missing_destination(self, command): diff --git a/deps/rabbitmq_stomp/test/python_SUITE_data/src/parsing.py b/deps/rabbitmq_stomp/test/python_SUITE_data/src/parsing.py index da32b6afc76e..c254a85a50d4 100644 --- a/deps/rabbitmq_stomp/test/python_SUITE_data/src/parsing.py +++ b/deps/rabbitmq_stomp/test/python_SUITE_data/src/parsing.py @@ -21,12 +21,6 @@ def connect(cnames): 'passcode:guest\n' '\n' '\n\0') - resp = ('CONNECTED\n' - 'server:RabbitMQ/(.*)\n' - 'session:(.*)\n' - 'heart-beat:0,0\n' - 'version:1.0\n' - '\n\x00') def w(m): @functools.wraps(m) def wrapper(self, *args, **kwargs): @@ -35,7 +29,11 @@ def wrapper(self, *args, **kwargs): sd.settimeout(30000) sd.connect((self.host, self.port)) sd.sendall(cmd.encode('utf-8')) - self.match(resp, sd.recv(4096).decode('utf-8')) + data = sd.recv(4096).decode('utf-8') + self.assert_frame(data, 'CONNECTED', { + 'version': '1.0', + 'heart-beat': '0,0', + }) setattr(self, cname, sd) try: r = m(self, *args, **kwargs) @@ -68,6 +66,42 @@ def match(self, pattern, data): return matched.groups() self.assertTrue(False, 'No match:\n{}\n\n{}'.format(pattern, data)) + def parse_frame(self, data): + ''' Parse a STOMP frame into (command, headers_dict, body) ''' + # Strip trailing LF (server sends trailing_lf=true) + if data.endswith('\n'): + data = data[:-1] + parts = data.split('\n\n', 1) + header_section = parts[0] + body = parts[1] if len(parts) > 1 else '' + # Body ends with NUL + if body.endswith('\x00'): + body = body[:-1] + lines = header_section.split('\n') + command = lines[0] + headers = {} + for line in lines[1:]: + if ':' in line: + k, v = line.split(':', 1) + headers[k] = v + return command, headers, body + + def assert_frame(self, data, expected_command, expected_headers=None, expected_body=None): + ''' Assert a STOMP frame matches expected values (header-order independent) ''' + command, headers, body = self.parse_frame(data) + self.assertEqual(expected_command, command) + if expected_headers: + for k, v in expected_headers.items(): + self.assertIn(k, headers, f'Missing header: {k}') + if v is not None: + if isinstance(v, re.Pattern): + self.assertRegex(headers[k], v) + else: + self.assertEqual(v, headers[k]) + if expected_body is not None: + self.assertEqual(expected_body, body) + return headers + def recv_atleast(self, bufsize): recvhead = [] rl = bufsize @@ -79,6 +113,37 @@ def recv_atleast(self, bufsize): rl -= bl return ''.join(recvhead) + def recv_frame(self): + ''' Receive one complete STOMP frame, honoring content-length for binary bodies. ''' + buf = b'' + while b'\n\n' not in buf: + chunk = self.cd.recv(4096) + if not chunk: + return buf.decode('utf-8', errors='replace') + buf += chunk + hdr_end = buf.index(b'\n\n') + header_text = buf[:hdr_end].decode('utf-8') + content_length = None + for line in header_text.split('\n')[1:]: + if line.startswith('content-length:'): + content_length = int(line.split(':', 1)[1]) + break + body_start = hdr_end + 2 + if content_length is not None: + needed = body_start + content_length + 1 + while len(buf) < needed: + chunk = self.cd.recv(needed - len(buf)) + if not chunk: + break + buf += chunk + else: + while b'\x00' not in buf[body_start:]: + chunk = self.cd.recv(4096) + if not chunk: + break + buf += chunk + return buf.decode('utf-8') + @connect(['cd']) def test_newline_after_nul(self): @@ -91,15 +156,13 @@ def test_newline_after_nul(self): 'destination:/exchange/amq.fanout\n\n' 'hello\n\x00\n') self.cd.sendall(cmd.encode('utf-8')) - resp = ('MESSAGE\n' - 'destination:/exchange/amq.fanout\n' - 'message-id:Q_/exchange/amq.fanout@@session-(.*)\n' - 'redelivered:false\n' - 'content-type:text/plain\n' - 'content-length:6\n' - '\n' - 'hello\n\0') - self.match(resp, self.cd.recv(4096).decode('utf-8')) + data = self.cd.recv(4096).decode('utf-8') + self.assert_frame(data, 'MESSAGE', { + 'destination': '/exchange/amq.fanout', + 'redelivered': 'false', + 'content-type': 'text/plain', + 'content-length': '6', + }, 'hello\n') @connect(['cd']) def test_send_without_content_type(self): @@ -111,14 +174,12 @@ def test_send_without_content_type(self): 'destination:/exchange/amq.fanout\n\n' 'hello\n\x00') self.cd.sendall(cmd.encode('utf-8')) - resp = ('MESSAGE\n' - 'destination:/exchange/amq.fanout\n' - 'message-id:Q_/exchange/amq.fanout@@session-(.*)\n' - 'redelivered:false\n' - 'content-length:6\n' - '\n' - 'hello\n\0') - self.match(resp, self.cd.recv(4096).decode('utf-8')) + data = self.cd.recv(4096).decode('utf-8') + self.assert_frame(data, 'MESSAGE', { + 'destination': '/exchange/amq.fanout', + 'redelivered': 'false', + 'content-length': '6', + }, 'hello\n') @connect(['cd']) def test_unicode(self): @@ -131,15 +192,13 @@ def test_unicode(self): 'headꙕr1:valꙕe1\n\n' 'hello\n\x00') self.cd.sendall(cmd.encode('utf-8')) - resp = ('MESSAGE\n' - 'destination:/exchange/amq.fanout\n' - 'message-id:Q_/exchange/amq.fanout@@session-(.*)\n' - 'redelivered:false\n' - 'headꙕr1:valꙕe1\n' - 'content-length:6\n' - '\n' - 'hello\n\0') - self.match(resp, self.cd.recv(4096).decode('utf-8')) + data = self.cd.recv(4096).decode('utf-8') + self.assert_frame(data, 'MESSAGE', { + 'destination': '/exchange/amq.fanout', + 'redelivered': 'false', + 'headꙕr1': 'valꙕe1', + 'content-length': '6', + }, 'hello\n') @connect(['cd']) def test_send_without_content_type_binary(self): @@ -153,13 +212,12 @@ def test_send_without_content_type_binary(self): 'content-length:{}\n\n'.format(len(msg)) + '{}\x00'.format(msg)) self.cd.sendall(cmd.encode('utf-8')) - resp = ('MESSAGE\n' - 'destination:/exchange/amq.fanout\n' - 'message-id:Q_/exchange/amq.fanout@@session-(.*)\n' - 'redelivered:false\n' + - 'content-length:{}\n'.format(len(msg)) + - '\n{}\0'.format(msg)) - self.match(resp, self.cd.recv(4096).decode('utf-8')) + data = self.cd.recv(4096).decode('utf-8') + self.assert_frame(data, 'MESSAGE', { + 'destination': '/exchange/amq.fanout', + 'redelivered': 'false', + 'content-length': str(len(msg)), + }, msg) @connect(['cd']) def test_newline_after_nul_and_leading_nul(self): @@ -172,15 +230,13 @@ def test_newline_after_nul_and_leading_nul(self): 'content-type:text/plain\n' '\nhello\n\x00\n') self.cd.sendall(cmd.encode('utf-8')) - resp = ('MESSAGE\n' - 'destination:/exchange/amq.fanout\n' - 'message-id:Q_/exchange/amq.fanout@@session-(.*)\n' - 'redelivered:false\n' - 'content-type:text/plain\n' - 'content-length:6\n' - '\n' - 'hello\n\0') - self.match(resp, self.cd.recv(4096).decode('utf-8')) + data = self.cd.recv(4096).decode('utf-8') + self.assert_frame(data, 'MESSAGE', { + 'destination': '/exchange/amq.fanout', + 'redelivered': 'false', + 'content-type': 'text/plain', + 'content-length': '6', + }, 'hello\n') @connect(['cd']) def test_bad_command(self): @@ -190,15 +246,12 @@ def test_bad_command(self): 'exchange:amq.fanout\n' '\n\0') self.cd.sendall(cmd.encode('utf-8')) - resp = ('ERROR\n' - 'message:Bad command\n' - 'content-type:text/plain\n' - 'version:1.0,1.1,1.2\n' - 'content-length:43\n' - '\n' - 'Could not interpret command "WRONGCOMMAND"\n' - '\0') - self.match(resp, self.cd.recv(4096).decode('utf-8')) + data = self.cd.recv(4096).decode('utf-8') + self.assert_frame(data, 'ERROR', { + 'message': 'Bad command', + 'content-type': 'text/plain', + 'version': '1.0,1.1,1.2', + }, 'Could not interpret command "WRONGCOMMAND"\n') @connect(['sd', 'cd1', 'cd2']) def test_broadcast(self): @@ -227,18 +280,14 @@ def test_broadcast(self): '\n\0') self.sd.sendall(cmd.encode('utf-8')) - resp=('MESSAGE\n' - 'subscription:(.*)\n' - 'destination:/topic/da9d4779\n' - 'message-id:(.*)\n' - 'redelivered:false\n' - 'content-type:text/plain\n' - 'content-length:8\n' - '\n' - 'message' - '\n\x00') for cd in [self.cd1, self.cd2]: - self.match(resp, cd.recv(4096).decode('utf-8')) + data = cd.recv(4096).decode('utf-8') + self.assert_frame(data, 'MESSAGE', { + 'destination': '/topic/da9d4779', + 'redelivered': 'false', + 'content-type': 'text/plain', + 'content-length': '8', + }, 'message\n') @connect(['cd']) def test_message_with_embedded_nulls(self): @@ -268,33 +317,17 @@ def test_message_with_embedded_nulls(self): '\0' % (len(message), message)) self.cd.sendall(cmd.encode('utf-8')) - headresp=('MESSAGE\n' # 8 - 'subscription:(.*)\n' # 14 + subscription - +resp_dest+ # 44 - 'message-id:(.*)\n' # 12 + message-id - 'redelivered:false\n' # 18 - 'content-type:text/plain\n' # 24 - 'content-length:%i\n' # 16 + 4==len('1024') - '\n' # 1 - '(.*)$' # prefix of body+null (potentially) - % len(message) ) - headlen = 8 + 24 + 14 + (3) + 44 + 12 + 18 + (48) + 16 + (4) + 1 + (1) - - headbuf = self.recv_atleast(headlen) - self.assertFalse(len(headbuf) == 0) - - (sub, msg_id, bodyprefix) = self.match(headresp, headbuf) - bodyresp=( '%s\0' % message ) - bodylen = len(bodyresp); - - bodybuf = ''.join([bodyprefix, - self.recv_atleast(bodylen - len(bodyprefix))]) - - self.assertEqual(len(bodybuf), msg_len+1, - "body received not the same length as message sent") - self.assertEqual(bodybuf, bodyresp, + fullbuf = self.recv_frame() + self.assertFalse(len(fullbuf) == 0) + + command, headers, body = self.parse_frame(fullbuf) + self.assertEqual('MESSAGE', command) + self.assertEqual('/topic/test_embed_nulls_message', headers.get('destination')) + self.assertEqual('false', headers.get('redelivered')) + self.assertEqual(str(msg_len), headers.get('content-length')) + self.assertEqual(message, body, " body (...'%s')\nincorrectly returned as (...'%s')" - % (bodyresp[-10:], bodybuf[-10:])) + % (message[-10:], body[-10:])) @connect(['cd']) def test_message_in_packets(self): @@ -328,33 +361,17 @@ def test_message_in_packets(self): self.cd.sendall(part.encode('utf-8')) part_index += packet_size - headresp=('MESSAGE\n' # 8 - 'subscription:(.*)\n' # 14 + subscription - +resp_dest+ # 44 - 'message-id:(.*)\n' # 12 + message-id - 'redelivered:false\n' # 18 - 'content-type:text/plain\n' # 24 - 'content-length:%i\n' # 16 + 4==len('1024') - '\n' # 1 - '(.*)$' # prefix of body+null (potentially) - % len(message) ) - headlen = 8 + 24 + 14 + (3) + 44 + 12 + 18 + (48) + 16 + (4) + 1 + (1) - - headbuf = self.recv_atleast(headlen) - self.assertFalse(len(headbuf) == 0) - - (sub, msg_id, bodyprefix) = self.match(headresp, headbuf) - bodyresp=( '%s\0' % message ) - bodylen = len(bodyresp); - - bodybuf = ''.join([bodyprefix, - self.recv_atleast(bodylen - len(bodyprefix))]) - - self.assertEqual(len(bodybuf), msg_len+1, - "body received not the same length as message sent") - self.assertEqual(bodybuf, bodyresp, + fullbuf = self.recv_frame() + self.assertFalse(len(fullbuf) == 0) + + command, headers, body = self.parse_frame(fullbuf) + self.assertEqual('MESSAGE', command) + self.assertEqual('false', headers.get('redelivered')) + self.assertEqual('text/plain', headers.get('content-type')) + self.assertEqual(str(msg_len), headers.get('content-length')) + self.assertEqual(message, body, " body ('%s')\nincorrectly returned as ('%s')" - % (bodyresp, bodybuf)) + % (message, body)) if __name__ == '__main__': diff --git a/deps/rabbitmq_stomp/test/python_SUITE_data/src/redelivered.py b/deps/rabbitmq_stomp/test/python_SUITE_data/src/redelivered.py index e0de477223e4..1f152fb05244 100644 --- a/deps/rabbitmq_stomp/test/python_SUITE_data/src/redelivered.py +++ b/deps/rabbitmq_stomp/test/python_SUITE_data/src/redelivered.py @@ -36,6 +36,8 @@ def test_redelivered(self): self.assertTrue(listener2.wait(), "message not received again") self.assertEqual(1, len(listener2.messages)) self.assertEqual('true', listener2.messages[0]['headers']['redelivered']) + mid = listener2.messages[0]['headers'][self.ack_id_source_header] + self.ack_message(conn2, mid, None) finally: conn2.disconnect() @@ -44,4 +46,4 @@ def test_redelivered(self): modules = [ __name__ ] - test_runner.run_unittests(modules) \ No newline at end of file + test_runner.run_unittests(modules) diff --git a/deps/rabbitmq_stomp/test/python_SUITE_data/src/test_runner.py b/deps/rabbitmq_stomp/test/python_SUITE_data/src/test_runner.py old mode 100644 new mode 100755 index 6b07e7bd5b6a..630c166ab45f --- a/deps/rabbitmq_stomp/test/python_SUITE_data/src/test_runner.py +++ b/deps/rabbitmq_stomp/test/python_SUITE_data/src/test_runner.py @@ -4,7 +4,7 @@ ## License, v. 2.0. If a copy of the MPL was not distributed with this ## file, You can obtain one at https://mozilla.org/MPL/2.0/. ## -## Copyright (c) 2007-2026 Broadcom. All Rights Reserved. The term “Broadcom” refers to Broadcom Inc. and/or its subsidiaries. All rights reserved. +## Copyright (c) 2007-2026 Broadcom. All Rights Reserved. The term "Broadcom" refers to Broadcom Inc. and/or its subsidiaries. All rights reserved. ## import unittest @@ -20,7 +20,12 @@ def run_unittests(modules): if name.startswith("Test") and issubclass(obj, unittest.TestCase): suite.addTest(unittest.TestLoader().loadTestsFromTestCase(obj)) - ts = unittest.TextTestRunner(verbosity=10).run(unittest.TestSuite(suite)) + ts = unittest.TextTestRunner(verbosity=2).run(unittest.TestSuite(suite)) if ts.errors or ts.failures: sys.exit(1) +if __name__ == '__main__': + if len(sys.argv) < 2: + print("Usage: test_runner.py [module ...]", file=sys.stderr) + sys.exit(2) + run_unittests(sys.argv[1:]) diff --git a/deps/rabbitmq_stomp/test/python_SUITE_data/src/topic_permissions.py b/deps/rabbitmq_stomp/test/python_SUITE_data/src/topic_permissions.py index bd53a650050f..4eaa618dfa50 100644 --- a/deps/rabbitmq_stomp/test/python_SUITE_data/src/topic_permissions.py +++ b/deps/rabbitmq_stomp/test/python_SUITE_data/src/topic_permissions.py @@ -49,7 +49,7 @@ def test_publish_authorisation(self): # assert errors self.assertGreater(len(self.listener.errors), 0) - self.assertIn("ACCESS_REFUSED", self.listener.errors[0]['message']) + self.assertIn("access_refused", self.listener.errors[0]['headers']['message']) if __name__ == '__main__': @@ -57,4 +57,4 @@ def test_publish_authorisation(self): modules = [ __name__ ] - test_runner.run_unittests(modules) \ No newline at end of file + test_runner.run_unittests(modules) diff --git a/deps/rabbitmq_stomp/test/python_SUITE_data/src/x_queue_name.py b/deps/rabbitmq_stomp/test/python_SUITE_data/src/x_queue_name.py index 4bc7e68f86c0..305a89722d15 100644 --- a/deps/rabbitmq_stomp/test/python_SUITE_data/src/x_queue_name.py +++ b/deps/rabbitmq_stomp/test/python_SUITE_data/src/x_queue_name.py @@ -13,7 +13,7 @@ import os import test_util -class TestUserGeneratedQueueName(base.BaseTest): +class TestQueueName(base.BaseTest): def test_exchange_dest(self): queueName='my-user-generated-queue-name-exchange' diff --git a/deps/rabbitmq_stomp/test/python_SUITE_data/src/x_queue_type_quorum.py b/deps/rabbitmq_stomp/test/python_SUITE_data/src/x_queue_type_quorum.py index ffe3dde95726..f9d737b4c1a4 100644 --- a/deps/rabbitmq_stomp/test/python_SUITE_data/src/x_queue_type_quorum.py +++ b/deps/rabbitmq_stomp/test/python_SUITE_data/src/x_queue_type_quorum.py @@ -12,7 +12,7 @@ import re import rabbitman -class TestUserGeneratedQueueName(base.BaseTest): +class TestQuorumQueue(base.BaseTest): def test_quorum_queue(self): queueName = 'my-quorum-queue' diff --git a/deps/rabbitmq_stomp/test/python_SUITE_data/src/x_queue_type_stream.py b/deps/rabbitmq_stomp/test/python_SUITE_data/src/x_queue_type_stream.py index 4216e38d1178..da979cd05b48 100644 --- a/deps/rabbitmq_stomp/test/python_SUITE_data/src/x_queue_type_stream.py +++ b/deps/rabbitmq_stomp/test/python_SUITE_data/src/x_queue_type_stream.py @@ -12,7 +12,7 @@ import re import rabbitman -class TestUserGeneratedQueueName(base.BaseTest): +class TestStreamQueue(base.BaseTest): def test_stream_queue(self): queueName = 'my-stream-queue' diff --git a/deps/rabbitmq_stomp/test/python_SUITE_generated.hrl b/deps/rabbitmq_stomp/test/python_SUITE_generated.hrl new file mode 100644 index 000000000000..071449865f48 --- /dev/null +++ b/deps/rabbitmq_stomp/test/python_SUITE_generated.hrl @@ -0,0 +1,335 @@ +%% Generated by generate_python_tests.py — do not edit + +-define(TLS_SUBGROUPS, [ + {group, 'TLSConnection'} +]). + +-define(TLSCONNECTION_TESTS, [ + 'TLSConnection.test_ssl_connect', + 'TLSConnection.test_ssl_auth_connect', + 'TLSConnection.test_ssl_send_receive', + 'TLSConnection.test_ssl_auth_send_receive' +]). + +-define(IMPLICIT_CONNECT_SUBGROUPS, [ + {group, 'ImplicitConnect'} +]). + +-define(IMPLICITCONNECT_TESTS, [ + 'ImplicitConnect.test_implicit_connect' +]). + +-define(MAIN_SUBGROUPS, [ + {group, 'Parsing'}, + {group, 'errors'}, + {group, 'ConnectDisconnect'}, + {group, 'ack'}, + {group, 'AmqpHeaders'}, + {group, 'QueueProperties'}, + {group, 'Reliability'}, + {group, 'Transactions'}, + {group, 'QueueName'}, + {group, 'destinations'}, + {group, 'Redelivered'}, + {group, 'TopicPermissions'}, + {group, 'Lifecycle'}, + {group, 'QuorumQueue'}, + {group, 'StreamQueue'} +]). + +-define(PARSING_TESTS, [ + 'Parsing.test_newline_after_nul', + 'Parsing.test_send_without_content_type', + 'Parsing.test_unicode', + 'Parsing.test_send_without_content_type_binary', + 'Parsing.test_newline_after_nul_and_leading_nul', + 'Parsing.test_bad_command', + 'Parsing.test_broadcast', + 'Parsing.test_message_with_embedded_nulls', + 'Parsing.test_message_in_packets' +]). + +-define(ERRORS_SUBGROUPS, [ + {group, 'ErrorsAndCloseConnection'}, + {group, 'Errors'} +]). + +-define(ERRORSANDCLOSECONNECTION_TESTS, [ + 'ErrorsAndCloseConnection.test_duplicate_consumer_tag_with_transient_destination', + 'ErrorsAndCloseConnection.test_duplicate_consumer_tag_with_durable_destination' +]). + +-define(ERRORS_TESTS, [ + 'Errors.test_invalid_queue_destination', + 'Errors.test_invalid_empty_queue_destination', + 'Errors.test_invalid_topic_destination', + 'Errors.test_invalid_empty_topic_destination', + 'Errors.test_invalid_exchange_destination', + 'Errors.test_invalid_empty_exchange_destination', + 'Errors.test_invalid_default_exchange_destination', + 'Errors.test_unknown_destination', + 'Errors.test_send_missing_destination', + 'Errors.test_subscribe_missing_destination' +]). + +-define(CONNECTDISCONNECT_TESTS, [ + 'ConnectDisconnect.test_connect_version_1_0', + 'ConnectDisconnect.test_connect_version_1_1', + 'ConnectDisconnect.test_connect_version_1_2', + 'ConnectDisconnect.test_default_user', + 'ConnectDisconnect.test_unsupported_version', + 'ConnectDisconnect.test_bad_username', + 'ConnectDisconnect.test_bad_password', + 'ConnectDisconnect.test_bad_vhost', + 'ConnectDisconnect.test_bad_header_on_send', + 'ConnectDisconnect.test_send_recv_header', + 'ConnectDisconnect.test_disconnect', + 'ConnectDisconnect.test_disconnect_with_receipt' +]). + +-define(ACK_SUBGROUPS, [ + {group, 'Ack'}, + {group, 'Ack11'}, + {group, 'Ack12'} +]). + +-define(ACK_TESTS, [ + 'Ack.test_ack_client', + 'Ack.test_ack_client_individual', + 'Ack.test_ack_client_tx', + 'Ack.test_topic_prefetch', + 'Ack.test_nack', + 'Ack.test_nack_multi', + 'Ack.test_nack_without_requeueing' +]). + +-define(ACK11_TESTS, [ + 'Ack11.test_version' +]). + +-define(ACK12_TESTS, [ + 'Ack12.test_version' +]). + +-define(AMQPHEADERS_TESTS, [ + 'AmqpHeaders.test_headers_to_stomp' +]). + +-define(QUEUEPROPERTIES_TESTS, [ + 'QueueProperties.test_subscribe', + 'QueueProperties.test_send' +]). + +-define(RELIABILITY_TESTS, [ + 'Reliability.test_send_and_disconnect' +]). + +-define(TRANSACTIONS_TESTS, [ + 'Transactions.test_tx_commit', + 'Transactions.test_tx_abort' +]). + +-define(QUEUENAME_TESTS, [ + 'QueueName.test_exchange_dest', + 'QueueName.test_topic_dest' +]). + +-define(DESTINATIONS_SUBGROUPS, [ + {group, 'Exchange'}, + {group, 'Queue'}, + {group, 'Topic'}, + {group, 'ReplyQueue'}, + {group, 'DurableSubscription'} +]). + +-define(EXCHANGE_TESTS, [ + 'Exchange.test_amq_direct', + 'Exchange.test_amq_topic', + 'Exchange.test_amq_fanout', + 'Exchange.test_amq_fanout_no_route', + 'Exchange.test_invalid_exchange' +]). + +-define(QUEUE_TESTS, [ + 'Queue.test_send_receive', + 'Queue.test_send_recv_header', + 'Queue.test_send_receive_in_other_conn', + 'Queue.test_send_receive_in_other_conn_with_disconnect', + 'Queue.test_multi_subscribers', + 'Queue.test_send_with_receipt', + 'Queue.test_send_with_receipt_tx', + 'Queue.test_interleaved_receipt_no_receipt', + 'Queue.test_interleaved_receipt_no_receipt_tx', + 'Queue.test_interleaved_receipt_no_receipt_inverse' +]). + +-define(TOPIC_TESTS, [ + 'Topic.test_send_receive', + 'Topic.test_send_multiple', + 'Topic.test_send_multiple_with_a_large_message' +]). + +-define(REPLYQUEUE_TESTS, [ + 'ReplyQueue.test_durable_known_reply_queue' +]). + +-define(DURABLESUBSCRIPTION_TESTS, [ + 'DurableSubscription.test_durable_subscription', + 'DurableSubscription.test_durable_subscription_and_legacy_header', + 'DurableSubscription.test_share_subscription', + 'DurableSubscription.test_separate_ids', + 'DurableSubscription.test_durable_subscribe_no_id', + 'DurableSubscription.test_durable_subscribe_no_id_and_legacy_header' +]). + +-define(REDELIVERED_TESTS, [ + 'Redelivered.test_redelivered' +]). + +-define(TOPICPERMISSIONS_TESTS, [ + 'TopicPermissions.test_publish_authorisation' +]). + +-define(LIFECYCLE_TESTS, [ + 'Lifecycle.test_unsubscribe_exchange_destination', + 'Lifecycle.test_unsubscribe_exchange_destination_with_receipt', + 'Lifecycle.test_unsubscribe_queue_destination', + 'Lifecycle.test_unsubscribe_queue_destination_with_receipt', + 'Lifecycle.test_unsubscribe_exchange_id', + 'Lifecycle.test_unsubscribe_exchange_id_with_receipt', + 'Lifecycle.test_unsubscribe_queue_id', + 'Lifecycle.test_unsubscribe_queue_id_with_receipt' +]). + +-define(QUORUMQUEUE_TESTS, [ + 'QuorumQueue.test_quorum_queue' +]). + +-define(STREAMQUEUE_TESTS, [ + 'StreamQueue.test_stream_queue' +]). + +-define(CLASS_GROUPS, [ + {'TLSConnection', [], ?TLSCONNECTION_TESTS}, + {'ImplicitConnect', [], ?IMPLICITCONNECT_TESTS}, + {'Parsing', [], ?PARSING_TESTS}, + {'errors', [], ?ERRORS_SUBGROUPS}, + {'ErrorsAndCloseConnection', [], ?ERRORSANDCLOSECONNECTION_TESTS}, + {'Errors', [], ?ERRORS_TESTS}, + {'ConnectDisconnect', [], ?CONNECTDISCONNECT_TESTS}, + {'ack', [], ?ACK_SUBGROUPS}, + {'Ack', [], ?ACK_TESTS}, + {'Ack11', [], ?ACK11_TESTS}, + {'Ack12', [], ?ACK12_TESTS}, + {'AmqpHeaders', [], ?AMQPHEADERS_TESTS}, + {'QueueProperties', [], ?QUEUEPROPERTIES_TESTS}, + {'Reliability', [], ?RELIABILITY_TESTS}, + {'Transactions', [], ?TRANSACTIONS_TESTS}, + {'QueueName', [], ?QUEUENAME_TESTS}, + {'destinations', [], ?DESTINATIONS_SUBGROUPS}, + {'Exchange', [], ?EXCHANGE_TESTS}, + {'Queue', [], ?QUEUE_TESTS}, + {'Topic', [], ?TOPIC_TESTS}, + {'ReplyQueue', [], ?REPLYQUEUE_TESTS}, + {'DurableSubscription', [], ?DURABLESUBSCRIPTION_TESTS}, + {'Redelivered', [], ?REDELIVERED_TESTS}, + {'TopicPermissions', [], ?TOPICPERMISSIONS_TESTS}, + {'Lifecycle', [], ?LIFECYCLE_TESTS}, + {'QuorumQueue', [], ?QUORUMQUEUE_TESTS}, + {'StreamQueue', [], ?STREAMQUEUE_TESTS} +]). + +%% Test function definitions + +'TLSConnection.test_ssl_connect'(Config) -> run_one_test(Config, "tls_connect_disconnect.TestTLSConnection.test_ssl_connect"). +'TLSConnection.test_ssl_auth_connect'(Config) -> run_one_test(Config, "tls_connect_disconnect.TestTLSConnection.test_ssl_auth_connect"). +'TLSConnection.test_ssl_send_receive'(Config) -> run_one_test(Config, "tls_connect_disconnect.TestTLSConnection.test_ssl_send_receive"). +'TLSConnection.test_ssl_auth_send_receive'(Config) -> run_one_test(Config, "tls_connect_disconnect.TestTLSConnection.test_ssl_auth_send_receive"). +'ImplicitConnect.test_implicit_connect'(Config) -> run_one_test(Config, "implicit_connect.TestImplicitConnect.test_implicit_connect"). +'Parsing.test_newline_after_nul'(Config) -> run_one_test(Config, "parsing.TestParsing.test_newline_after_nul"). +'Parsing.test_send_without_content_type'(Config) -> run_one_test(Config, "parsing.TestParsing.test_send_without_content_type"). +'Parsing.test_unicode'(Config) -> run_one_test(Config, "parsing.TestParsing.test_unicode"). +'Parsing.test_send_without_content_type_binary'(Config) -> run_one_test(Config, "parsing.TestParsing.test_send_without_content_type_binary"). +'Parsing.test_newline_after_nul_and_leading_nul'(Config) -> run_one_test(Config, "parsing.TestParsing.test_newline_after_nul_and_leading_nul"). +'Parsing.test_bad_command'(Config) -> run_one_test(Config, "parsing.TestParsing.test_bad_command"). +'Parsing.test_broadcast'(Config) -> run_one_test(Config, "parsing.TestParsing.test_broadcast"). +'Parsing.test_message_with_embedded_nulls'(Config) -> run_one_test(Config, "parsing.TestParsing.test_message_with_embedded_nulls"). +'Parsing.test_message_in_packets'(Config) -> run_one_test(Config, "parsing.TestParsing.test_message_in_packets"). +'ErrorsAndCloseConnection.test_duplicate_consumer_tag_with_transient_destination'(Config) -> run_one_test(Config, "errors.TestErrorsAndCloseConnection.test_duplicate_consumer_tag_with_transient_destination"). +'ErrorsAndCloseConnection.test_duplicate_consumer_tag_with_durable_destination'(Config) -> run_one_test(Config, "errors.TestErrorsAndCloseConnection.test_duplicate_consumer_tag_with_durable_destination"). +'Errors.test_invalid_queue_destination'(Config) -> run_one_test(Config, "errors.TestErrors.test_invalid_queue_destination"). +'Errors.test_invalid_empty_queue_destination'(Config) -> run_one_test(Config, "errors.TestErrors.test_invalid_empty_queue_destination"). +'Errors.test_invalid_topic_destination'(Config) -> run_one_test(Config, "errors.TestErrors.test_invalid_topic_destination"). +'Errors.test_invalid_empty_topic_destination'(Config) -> run_one_test(Config, "errors.TestErrors.test_invalid_empty_topic_destination"). +'Errors.test_invalid_exchange_destination'(Config) -> run_one_test(Config, "errors.TestErrors.test_invalid_exchange_destination"). +'Errors.test_invalid_empty_exchange_destination'(Config) -> run_one_test(Config, "errors.TestErrors.test_invalid_empty_exchange_destination"). +'Errors.test_invalid_default_exchange_destination'(Config) -> run_one_test(Config, "errors.TestErrors.test_invalid_default_exchange_destination"). +'Errors.test_unknown_destination'(Config) -> run_one_test(Config, "errors.TestErrors.test_unknown_destination"). +'Errors.test_send_missing_destination'(Config) -> run_one_test(Config, "errors.TestErrors.test_send_missing_destination"). +'Errors.test_subscribe_missing_destination'(Config) -> run_one_test(Config, "errors.TestErrors.test_subscribe_missing_destination"). +'ConnectDisconnect.test_connect_version_1_0'(Config) -> run_one_test(Config, "connect_disconnect.TestConnectDisconnect.test_connect_version_1_0"). +'ConnectDisconnect.test_connect_version_1_1'(Config) -> run_one_test(Config, "connect_disconnect.TestConnectDisconnect.test_connect_version_1_1"). +'ConnectDisconnect.test_connect_version_1_2'(Config) -> run_one_test(Config, "connect_disconnect.TestConnectDisconnect.test_connect_version_1_2"). +'ConnectDisconnect.test_default_user'(Config) -> run_one_test(Config, "connect_disconnect.TestConnectDisconnect.test_default_user"). +'ConnectDisconnect.test_unsupported_version'(Config) -> run_one_test(Config, "connect_disconnect.TestConnectDisconnect.test_unsupported_version"). +'ConnectDisconnect.test_bad_username'(Config) -> run_one_test(Config, "connect_disconnect.TestConnectDisconnect.test_bad_username"). +'ConnectDisconnect.test_bad_password'(Config) -> run_one_test(Config, "connect_disconnect.TestConnectDisconnect.test_bad_password"). +'ConnectDisconnect.test_bad_vhost'(Config) -> run_one_test(Config, "connect_disconnect.TestConnectDisconnect.test_bad_vhost"). +'ConnectDisconnect.test_bad_header_on_send'(Config) -> run_one_test(Config, "connect_disconnect.TestConnectDisconnect.test_bad_header_on_send"). +'ConnectDisconnect.test_send_recv_header'(Config) -> run_one_test(Config, "connect_disconnect.TestConnectDisconnect.test_send_recv_header"). +'ConnectDisconnect.test_disconnect'(Config) -> run_one_test(Config, "connect_disconnect.TestConnectDisconnect.test_disconnect"). +'ConnectDisconnect.test_disconnect_with_receipt'(Config) -> run_one_test(Config, "connect_disconnect.TestConnectDisconnect.test_disconnect_with_receipt"). +'Ack.test_ack_client'(Config) -> run_one_test(Config, "ack.TestAck.test_ack_client"). +'Ack.test_ack_client_individual'(Config) -> run_one_test(Config, "ack.TestAck.test_ack_client_individual"). +'Ack.test_ack_client_tx'(Config) -> run_one_test(Config, "ack.TestAck.test_ack_client_tx"). +'Ack.test_topic_prefetch'(Config) -> run_one_test(Config, "ack.TestAck.test_topic_prefetch"). +'Ack.test_nack'(Config) -> run_one_test(Config, "ack.TestAck.test_nack"). +'Ack.test_nack_multi'(Config) -> run_one_test(Config, "ack.TestAck.test_nack_multi"). +'Ack.test_nack_without_requeueing'(Config) -> run_one_test(Config, "ack.TestAck.test_nack_without_requeueing"). +'Ack11.test_version'(Config) -> run_one_test(Config, "ack.TestAck11.test_version"). +'Ack12.test_version'(Config) -> run_one_test(Config, "ack.TestAck12.test_version"). +'AmqpHeaders.test_headers_to_stomp'(Config) -> run_one_test(Config, "amqp_headers.TestAmqpHeaders.test_headers_to_stomp"). +'QueueProperties.test_subscribe'(Config) -> run_one_test(Config, "queue_properties.TestQueueProperties.test_subscribe"). +'QueueProperties.test_send'(Config) -> run_one_test(Config, "queue_properties.TestQueueProperties.test_send"). +'Reliability.test_send_and_disconnect'(Config) -> run_one_test(Config, "reliability.TestReliability.test_send_and_disconnect"). +'Transactions.test_tx_commit'(Config) -> run_one_test(Config, "transactions.TestTransactions.test_tx_commit"). +'Transactions.test_tx_abort'(Config) -> run_one_test(Config, "transactions.TestTransactions.test_tx_abort"). +'QueueName.test_exchange_dest'(Config) -> run_one_test(Config, "x_queue_name.TestQueueName.test_exchange_dest"). +'QueueName.test_topic_dest'(Config) -> run_one_test(Config, "x_queue_name.TestQueueName.test_topic_dest"). +'Exchange.test_amq_direct'(Config) -> run_one_test(Config, "destinations.TestExchange.test_amq_direct"). +'Exchange.test_amq_topic'(Config) -> run_one_test(Config, "destinations.TestExchange.test_amq_topic"). +'Exchange.test_amq_fanout'(Config) -> run_one_test(Config, "destinations.TestExchange.test_amq_fanout"). +'Exchange.test_amq_fanout_no_route'(Config) -> run_one_test(Config, "destinations.TestExchange.test_amq_fanout_no_route"). +'Exchange.test_invalid_exchange'(Config) -> run_one_test(Config, "destinations.TestExchange.test_invalid_exchange"). +'Queue.test_send_receive'(Config) -> run_one_test(Config, "destinations.TestQueue.test_send_receive"). +'Queue.test_send_recv_header'(Config) -> run_one_test(Config, "destinations.TestQueue.test_send_recv_header"). +'Queue.test_send_receive_in_other_conn'(Config) -> run_one_test(Config, "destinations.TestQueue.test_send_receive_in_other_conn"). +'Queue.test_send_receive_in_other_conn_with_disconnect'(Config) -> run_one_test(Config, "destinations.TestQueue.test_send_receive_in_other_conn_with_disconnect"). +'Queue.test_multi_subscribers'(Config) -> run_one_test(Config, "destinations.TestQueue.test_multi_subscribers"). +'Queue.test_send_with_receipt'(Config) -> run_one_test(Config, "destinations.TestQueue.test_send_with_receipt"). +'Queue.test_send_with_receipt_tx'(Config) -> run_one_test(Config, "destinations.TestQueue.test_send_with_receipt_tx"). +'Queue.test_interleaved_receipt_no_receipt'(Config) -> run_one_test(Config, "destinations.TestQueue.test_interleaved_receipt_no_receipt"). +'Queue.test_interleaved_receipt_no_receipt_tx'(Config) -> run_one_test(Config, "destinations.TestQueue.test_interleaved_receipt_no_receipt_tx"). +'Queue.test_interleaved_receipt_no_receipt_inverse'(Config) -> run_one_test(Config, "destinations.TestQueue.test_interleaved_receipt_no_receipt_inverse"). +'Topic.test_send_receive'(Config) -> run_one_test(Config, "destinations.TestTopic.test_send_receive"). +'Topic.test_send_multiple'(Config) -> run_one_test(Config, "destinations.TestTopic.test_send_multiple"). +'Topic.test_send_multiple_with_a_large_message'(Config) -> run_one_test(Config, "destinations.TestTopic.test_send_multiple_with_a_large_message"). +'ReplyQueue.test_durable_known_reply_queue'(Config) -> run_one_test(Config, "destinations.TestReplyQueue.test_durable_known_reply_queue"). +'DurableSubscription.test_durable_subscription'(Config) -> run_one_test(Config, "destinations.TestDurableSubscription.test_durable_subscription"). +'DurableSubscription.test_durable_subscription_and_legacy_header'(Config) -> run_one_test(Config, "destinations.TestDurableSubscription.test_durable_subscription_and_legacy_header"). +'DurableSubscription.test_share_subscription'(Config) -> run_one_test(Config, "destinations.TestDurableSubscription.test_share_subscription"). +'DurableSubscription.test_separate_ids'(Config) -> run_one_test(Config, "destinations.TestDurableSubscription.test_separate_ids"). +'DurableSubscription.test_durable_subscribe_no_id'(Config) -> run_one_test(Config, "destinations.TestDurableSubscription.test_durable_subscribe_no_id"). +'DurableSubscription.test_durable_subscribe_no_id_and_legacy_header'(Config) -> run_one_test(Config, "destinations.TestDurableSubscription.test_durable_subscribe_no_id_and_legacy_header"). +'Redelivered.test_redelivered'(Config) -> run_one_test(Config, "redelivered.TestRedelivered.test_redelivered"). +'TopicPermissions.test_publish_authorisation'(Config) -> run_one_test(Config, "topic_permissions.TestTopicPermissions.test_publish_authorisation"). +'Lifecycle.test_unsubscribe_exchange_destination'(Config) -> run_one_test(Config, "unsubscribe.TestLifecycle.test_unsubscribe_exchange_destination"). +'Lifecycle.test_unsubscribe_exchange_destination_with_receipt'(Config) -> run_one_test(Config, "unsubscribe.TestLifecycle.test_unsubscribe_exchange_destination_with_receipt"). +'Lifecycle.test_unsubscribe_queue_destination'(Config) -> run_one_test(Config, "unsubscribe.TestLifecycle.test_unsubscribe_queue_destination"). +'Lifecycle.test_unsubscribe_queue_destination_with_receipt'(Config) -> run_one_test(Config, "unsubscribe.TestLifecycle.test_unsubscribe_queue_destination_with_receipt"). +'Lifecycle.test_unsubscribe_exchange_id'(Config) -> run_one_test(Config, "unsubscribe.TestLifecycle.test_unsubscribe_exchange_id"). +'Lifecycle.test_unsubscribe_exchange_id_with_receipt'(Config) -> run_one_test(Config, "unsubscribe.TestLifecycle.test_unsubscribe_exchange_id_with_receipt"). +'Lifecycle.test_unsubscribe_queue_id'(Config) -> run_one_test(Config, "unsubscribe.TestLifecycle.test_unsubscribe_queue_id"). +'Lifecycle.test_unsubscribe_queue_id_with_receipt'(Config) -> run_one_test(Config, "unsubscribe.TestLifecycle.test_unsubscribe_queue_id_with_receipt"). +'QuorumQueue.test_quorum_queue'(Config) -> run_one_test(Config, "x_queue_type_quorum.TestQuorumQueue.test_quorum_queue"). +'StreamQueue.test_stream_queue'(Config) -> run_one_test(Config, "x_queue_type_stream.TestStreamQueue.test_stream_queue"). diff --git a/deps/rabbitmq_stomp/test/src/rabbit_stomp_client.erl b/deps/rabbitmq_stomp/test/src/rabbit_stomp_client.erl index 510238bbbb3b..f0d12674d3ce 100644 --- a/deps/rabbitmq_stomp/test/src/rabbit_stomp_client.erl +++ b/deps/rabbitmq_stomp/test/src/rabbit_stomp_client.erl @@ -27,13 +27,15 @@ connect0(Version, Login, Pass, Port, Headers) -> %% AMQP default port. {ok, Sock} = gen_tcp:connect(localhost, Port, [{active, false}, binary]), Client0 = recv_state(Sock), - send(Client0, "CONNECT", [{"login", Login}, - {"passcode", Pass} | Version] ++ Headers), - {#stomp_frame{command = "CONNECTED"}, Client1} = recv(Client0), + send(Client0, 'CONNECT', [{<<"login">>, list_to_binary(Login)}, + {<<"passcode">>, list_to_binary(Pass)} + | [{list_to_binary(K), list_to_binary(V)} || {K, V} <- Version]] + ++ Headers), + {#stomp_frame{command = 'CONNECTED'}, Client1} = recv(Client0), {ok, Client1}. disconnect(Client = {Sock, _}) -> - send(Client, "DISCONNECT"), + send(Client, 'DISCONNECT'), gen_tcp:close(Sock). send(Client, Command) -> @@ -44,9 +46,9 @@ send(Client, Command, Headers) -> send({Sock, _}, Command, Headers, Body) -> Frame = rabbit_stomp_frame:serialize( - #stomp_frame{command = list_to_binary(Command), - headers = Headers, - body_iolist = Body}), + #stomp_frame{command = Command, + headers = maps:from_list(Headers), + body_iolist_rev = Body}), gen_tcp:send(Sock, Frame). recv_state(Sock) -> diff --git a/deps/rabbitmq_stomp/test/src/rabbit_stomp_publish_test.erl b/deps/rabbitmq_stomp/test/src/rabbit_stomp_publish_test.erl index fafaf0cd7d5e..efcb1a9662f1 100644 --- a/deps/rabbitmq_stomp/test/src/rabbit_stomp_publish_test.erl +++ b/deps/rabbitmq_stomp/test/src/rabbit_stomp_publish_test.erl @@ -11,7 +11,7 @@ -include("rabbit_stomp_frame.hrl"). --define(DESTINATION, "/queue/test"). +-define(DESTINATION, <<"/queue/test">>). -define(MICROS_PER_UPDATE, 5000000). -define(MICROS_PER_UPDATE_MSG, 100000). @@ -27,7 +27,7 @@ run() -> Self = self(), spawn(fun() -> publish(Self, Pub, 0, erlang:monotonic_time()) end), rabbit_stomp_client:send( - Recv, "SUBSCRIBE", [{"destination", ?DESTINATION}]), + Recv, 'SUBSCRIBE', [{<<"destination">>, ?DESTINATION}]), spawn(fun() -> recv(Self, Recv, 0, erlang:monotonic_time()) end), report(). @@ -53,8 +53,8 @@ report() -> publish(Owner, Client, Count, TS) -> rabbit_stomp_client:send( - Client, "SEND", [{"destination", ?DESTINATION}], - [integer_to_list(Count)]), + Client, 'SEND', [{<<"destination">>, ?DESTINATION}], + [integer_to_binary(Count)]), Diff = erlang:convert_time_unit( erlang:monotonic_time() - TS, native, microseconds), case Diff > ?MICROS_PER_UPDATE_MSG of @@ -65,9 +65,9 @@ publish(Owner, Client, Count, TS) -> end. recv(Owner, Client0, Count, TS) -> - {#stomp_frame{body_iolist = Body}, Client1} = + {#stomp_frame{body_iolist_rev = Body}, Client1} = rabbit_stomp_client:recv(Client0), - BodyInt = list_to_integer(binary_to_list(iolist_to_binary(Body))), + BodyInt = binary_to_integer(iolist_to_binary(Body)), Count = BodyInt, Diff = erlang:convert_time_unit( erlang:monotonic_time() - TS, native, microseconds), @@ -77,4 +77,3 @@ recv(Owner, Client0, Count, TS) -> erlang:monotonic_time()); false -> recv(Owner, Client1, Count + 1, TS) end. - diff --git a/deps/rabbitmq_stomp/test/system_SUITE.erl b/deps/rabbitmq_stomp/test/system_SUITE.erl index 45f9042912a0..9f603262d21d 100644 --- a/deps/rabbitmq_stomp/test/system_SUITE.erl +++ b/deps/rabbitmq_stomp/test/system_SUITE.erl @@ -18,8 +18,8 @@ -define(QUEUE, <<"TestQueue">>). -define(QUEUE_QQ, <<"TestQueueQQ">>). --define(DESTINATION, "/amq/queue/TestQueue"). --define(DESTINATION_QQ, "/amq/queue/TestQueueQQ"). +-define(DESTINATION, <<"/amq/queue/TestQueue">>). +-define(DESTINATION_QQ, <<"/amq/queue/TestQueueQQ">>). all() -> [{group, version_to_group_name(V)} || V <- ?SUPPORTED_VERSIONS]. @@ -38,7 +38,9 @@ groups() -> temp_destination_queue, temp_destination_in_send, blank_destination_in_send, - stream_filtering + stream_filtering, + transaction_limit, + global_counters ], [{version_to_group_name(V), [sequence], Tests} @@ -119,29 +121,90 @@ end_per_testcase0(publish_unauthorized_error, Config) -> end_per_testcase0(_, Config) -> Config. +transaction_limit(Config) -> + Client = ?config(stomp_client, Config), + %% Open 16 transactions (the limit) + lists:foreach(fun(I) -> + TxId = integer_to_binary(I), + rabbit_stomp_client:send(Client, 'BEGIN', + [{<<"transaction">>, TxId}]) + end, lists:seq(1, 16)), + + %% The 17th should fail + rabbit_stomp_client:send(Client, 'BEGIN', + [{<<"transaction">>, <<"17">>}]), + {ok, _Client1, Hdrs, _} = stomp_receive(Client, 'ERROR'), + <<"Transaction limit exceeded">> = maps:get(<<"message">>, Hdrs), + ok. + +global_counters(Config) -> + Version = ?config(version, Config), + ProtoVer = stomp_proto_ver(Version), + Dest = iolist_to_binary(["/topic/counter-test-", Version]), + + C0 = get_global_counters(Config, ProtoVer), + Pubs0 = maps:get(publishers, C0, 0), + Cons0 = maps:get(consumers, C0, 0), + Recv0 = maps:get(messages_received_total, C0, 0), + Routed0 = maps:get(messages_routed_total, C0, 0), + + Client = ?config(stomp_client, Config), + rabbit_stomp_client:send( + Client, 'SUBSCRIBE', + [{<<"destination">>, Dest}, {<<"id">>, <<"counter-sub">>}]), + + rabbit_stomp_client:send( + Client, 'SEND', [{<<"destination">>, Dest}], ["hello"]), + + {ok, Client1, _Hdrs, _Body} = stomp_receive(Client, 'MESSAGE'), + + C1 = get_global_counters(Config, ProtoVer), + ?assertEqual(Pubs0 + 1, maps:get(publishers, C1)), + ?assertEqual(Cons0 + 1, maps:get(consumers, C1)), + ?assertEqual(Recv0 + 1, maps:get(messages_received_total, C1)), + ?assertEqual(Routed0 + 1, maps:get(messages_routed_total, C1)), + + rabbit_stomp_client:send( + Client1, 'UNSUBSCRIBE', [{<<"id">>, <<"counter-sub">>}]), + + timer:sleep(100), + C2 = get_global_counters(Config, ProtoVer), + ?assertEqual(Cons0, maps:get(consumers, C2)), + + ok. + +get_global_counters(Config, ProtoVer) -> + maps:get(#{protocol => ProtoVer}, + rabbit_ct_broker_helpers:rpc( + Config, 0, rabbit_global_counters, overview, [])). + +stomp_proto_ver("1.0") -> 'STOMP 1.0'; +stomp_proto_ver("1.1") -> 'STOMP 1.1'; +stomp_proto_ver("1.2") -> 'STOMP 1.2'. + publish_no_dest_error(Config) -> Client = ?config(stomp_client, Config), rabbit_stomp_client:send( - Client, "SEND", [{"destination", "/exchange/non-existent"}], ["hello"]), - {ok, _Client1, Hdrs, _} = stomp_receive(Client, "ERROR"), - "not_found" = proplists:get_value("message", Hdrs), + Client, 'SEND', [{<<"destination">>, <<"/exchange/non-existent">>}], ["hello"]), + {ok, _Client1, Hdrs, _} = stomp_receive(Client, 'ERROR'), + <<"not_found">> = maps:get(<<"message">>, Hdrs), ok. publish_unauthorized_error(Config) -> ClientFoo = ?config(client_foo, Config), rabbit_stomp_client:send( - ClientFoo, "SEND", [{"destination", "/amq/queue/RestrictedQueue"}], ["hello"]), - {ok, _Client1, Hdrs, _} = stomp_receive(ClientFoo, "ERROR"), - "access_refused" = proplists:get_value("message", Hdrs), + ClientFoo, 'SEND', [{<<"destination">>, <<"/amq/queue/RestrictedQueue">>}], ["hello"]), + {ok, _Client1, Hdrs, _} = stomp_receive(ClientFoo, 'ERROR'), + <<"access_refused">> = maps:get(<<"message">>, Hdrs), ok. subscribe_error(Config) -> Client = ?config(stomp_client, Config), %% SUBSCRIBE to missing queue rabbit_stomp_client:send( - Client, "SUBSCRIBE", [{"destination", ?DESTINATION}]), - {ok, _Client1, Hdrs, _} = stomp_receive(Client, "ERROR"), - "not_found" = proplists:get_value("message", Hdrs), + Client, 'SUBSCRIBE', [{<<"destination">>, ?DESTINATION}]), + {ok, _Client1, Hdrs, _} = stomp_receive(Client, 'ERROR'), + <<"not_found">> = maps:get(<<"message">>, Hdrs), ok. subscribe(Config) -> @@ -154,8 +217,8 @@ subscribe(Config) -> %% subscribe and wait for receipt rabbit_stomp_client:send( - Client, "SUBSCRIBE", [{"destination", ?DESTINATION}, {"receipt", "foo"}]), - {ok, Client1, _, _} = stomp_receive(Client, "RECEIPT"), + Client, 'SUBSCRIBE', [{<<"destination">>, ?DESTINATION}, {<<"receipt">>, <<"foo">>}]), + {ok, Client1, _, _} = stomp_receive(Client, 'RECEIPT'), %% send from amqp Method = #'basic.publish'{exchange = <<"">>, routing_key = ?QUEUE}, @@ -163,7 +226,7 @@ subscribe(Config) -> amqp_channel:call(Channel, Method, #amqp_msg{props = #'P_basic'{}, payload = <<"hello">>}), - {ok, _Client2, _, [<<"hello">>]} = stomp_receive(Client1, "MESSAGE"), + {ok, _Client2, _, [<<"hello">>]} = stomp_receive(Client1, 'MESSAGE'), ok. subscribe_with_x_priority(Config) -> @@ -180,17 +243,17 @@ subscribe_with_x_priority(Config) -> %% subscribe and wait for receipt rabbit_stomp_client:send( - ClientA, "SUBSCRIBE", [{"destination", ?DESTINATION_QQ}, {"receipt", "foo"}]), - {ok, _ClientA1, _, _} = stomp_receive(ClientA, "RECEIPT"), + ClientA, 'SUBSCRIBE', [{<<"destination">>, ?DESTINATION_QQ}, {<<"receipt">>, <<"foo">>}]), + {ok, _ClientA1, _, _} = stomp_receive(ClientA, 'RECEIPT'), %% subscribe with a higher priority and wait for receipt {ok, ClientB} = rabbit_stomp_client:connect(Version, StompPort), rabbit_stomp_client:send( - ClientB, "SUBSCRIBE", [{"destination", ?DESTINATION_QQ}, - {"receipt", "foo"}, - {"x-priority", 10} - ]), - {ok, ClientB1, _, _} = stomp_receive(ClientB, "RECEIPT"), + ClientB, 'SUBSCRIBE', [{<<"destination">>, ?DESTINATION_QQ}, + {<<"receipt">>, <<"foo">>}, + {<<"x-priority">>, <<"10">>} + ]), + {ok, ClientB1, _, _} = stomp_receive(ClientB, 'RECEIPT'), %% send from amqp Method = #'basic.publish'{exchange = <<"">>, routing_key = ?QUEUE_QQ}, @@ -199,7 +262,7 @@ subscribe_with_x_priority(Config) -> payload = <<"hello">>}), %% ClientB should receive the message since it has a higher priority - {ok, _ClientB2, _, [<<"hello">>]} = stomp_receive(ClientB1, "MESSAGE"), + {ok, _ClientB2, _, [<<"hello">>]} = stomp_receive(ClientB1, 'MESSAGE'), #'queue.delete_ok'{} = amqp_channel:call(Channel, #'queue.delete'{queue = ?QUEUE_QQ}), ok. @@ -214,11 +277,11 @@ unsubscribe_ack(Config) -> auto_delete = true}), %% subscribe and wait for receipt rabbit_stomp_client:send( - Client, "SUBSCRIBE", [{"destination", ?DESTINATION}, - {"receipt", "rcpt1"}, - {"ack", "client"}, - {"id", "subscription-id"}]), - {ok, Client1, _, _} = stomp_receive(Client, "RECEIPT"), + Client, 'SUBSCRIBE', [{<<"destination">>, ?DESTINATION}, + {<<"receipt">>, <<"rcpt1">>}, + {<<"ack">>, <<"client">>}, + {<<"id">>, <<"subscription-id">>}]), + {ok, Client1, _, _} = stomp_receive(Client, 'RECEIPT'), %% send from amqp Method = #'basic.publish'{exchange = <<"">>, routing_key = ?QUEUE}, @@ -226,21 +289,21 @@ unsubscribe_ack(Config) -> amqp_channel:call(Channel, Method, #amqp_msg{props = #'P_basic'{}, payload = <<"hello">>}), - {ok, Client2, Hdrs1, [<<"hello">>]} = stomp_receive(Client1, "MESSAGE"), + {ok, Client2, Hdrs1, [<<"hello">>]} = stomp_receive(Client1, 'MESSAGE'), rabbit_stomp_client:send( - Client2, "UNSUBSCRIBE", [{"destination", ?DESTINATION}, - {"id", "subscription-id"}]), + Client2, 'UNSUBSCRIBE', [{<<"destination">>, ?DESTINATION}, + {<<"id">>, <<"subscription-id">>}]), rabbit_stomp_client:send( - Client2, "ACK", [{rabbit_stomp_util:ack_header_name(Version), - proplists:get_value( + Client2, 'ACK', [{rabbit_stomp_util:ack_header_name(Version), + maps:get( rabbit_stomp_util:msg_header_name(Version), Hdrs1)}, - {"receipt", "rcpt2"}]), + {<<"receipt">>, <<"rcpt2">>}]), - {ok, _Client3, Hdrs2, _Body2} = stomp_receive(Client2, "ERROR"), - ?assertEqual("Subscription not found", - proplists:get_value("message", Hdrs2)), + {ok, _Client3, Hdrs2, _Body2} = stomp_receive(Client2, 'ERROR'), + ?assertEqual(<<"Subscription not found">>, + maps:get(<<"message">>, Hdrs2)), ok. subscribe_ack(Config) -> @@ -254,10 +317,10 @@ subscribe_ack(Config) -> %% subscribe and wait for receipt rabbit_stomp_client:send( - Client, "SUBSCRIBE", [{"destination", ?DESTINATION}, - {"receipt", "foo"}, - {"ack", "client"}]), - {ok, Client1, _, _} = stomp_receive(Client, "RECEIPT"), + Client, 'SUBSCRIBE', [{<<"destination">>, ?DESTINATION}, + {<<"receipt">>, <<"foo">>}, + {<<"ack">>, <<"client">>}]), + {ok, Client1, _, _} = stomp_receive(Client, 'RECEIPT'), %% send from amqp Method = #'basic.publish'{exchange = <<"">>, routing_key = ?QUEUE}, @@ -265,14 +328,14 @@ subscribe_ack(Config) -> amqp_channel:call(Channel, Method, #amqp_msg{props = #'P_basic'{}, payload = <<"hello">>}), - {ok, _Client2, Headers, [<<"hello">>]} = stomp_receive(Client1, "MESSAGE"), - false = (Version == "1.2") xor proplists:is_defined(?HEADER_ACK, Headers), + {ok, _Client2, Headers, [<<"hello">>]} = stomp_receive(Client1, 'MESSAGE'), + false = (Version == "1.2") xor is_map_key(?HEADER_ACK, Headers), MsgHeader = rabbit_stomp_util:msg_header_name(Version), - AckValue = proplists:get_value(MsgHeader, Headers), + AckValue = maps:get(MsgHeader, Headers), AckHeader = rabbit_stomp_util:ack_header_name(Version), - rabbit_stomp_client:send(Client, "ACK", [{AckHeader, AckValue}]), + rabbit_stomp_client:send(Client, 'ACK', [{AckHeader, AckValue}]), #'basic.get_empty'{} = amqp_channel:call(Channel, #'basic.get'{queue = ?QUEUE}), ok. @@ -287,14 +350,14 @@ send(Config) -> %% subscribe and wait for receipt rabbit_stomp_client:send( - Client, "SUBSCRIBE", [{"destination", ?DESTINATION}, {"receipt", "foo"}]), - {ok, Client1, _, _} = stomp_receive(Client, "RECEIPT"), + Client, 'SUBSCRIBE', [{<<"destination">>, ?DESTINATION}, {<<"receipt">>, <<"foo">>}]), + {ok, Client1, _, _} = stomp_receive(Client, 'RECEIPT'), %% send from stomp rabbit_stomp_client:send( - Client1, "SEND", [{"destination", ?DESTINATION}], ["hello"]), + Client1, 'SEND', [{<<"destination">>, ?DESTINATION}], ["hello"]), - {ok, _Client2, _, [<<"hello">>]} = stomp_receive(Client1, "MESSAGE"), + {ok, _Client2, _, [<<"hello">>]} = stomp_receive(Client1, 'MESSAGE'), ok. delete_queue_subscribe(Config) -> @@ -307,16 +370,16 @@ delete_queue_subscribe(Config) -> %% subscribe and wait for receipt rabbit_stomp_client:send( - Client, "SUBSCRIBE", [{"destination", ?DESTINATION}, {"receipt", "bah"}]), - {ok, Client1, _, _} = stomp_receive(Client, "RECEIPT"), + Client, 'SUBSCRIBE', [{<<"destination">>, ?DESTINATION}, {<<"receipt">>, <<"bah">>}]), + {ok, Client1, _, _} = stomp_receive(Client, 'RECEIPT'), %% delete queue while subscribed #'queue.delete_ok'{} = amqp_channel:call(Channel, #'queue.delete'{queue = ?QUEUE}), - {ok, _Client2, Headers, _} = stomp_receive(Client1, "ERROR"), + {ok, _Client2, Headers, _} = stomp_receive(Client1, 'ERROR'), - ?DESTINATION = proplists:get_value("subscription", Headers), + ?DESTINATION = maps:get(<<"subscription">>, Headers), % server closes connection ok. @@ -328,8 +391,8 @@ temp_destination_queue(Config) -> amqp_channel:call(Channel, #'queue.declare'{queue = ?QUEUE, durable = true, auto_delete = true}), - rabbit_stomp_client:send( Client, "SEND", [{"destination", ?DESTINATION}, - {"reply-to", "/temp-queue/foo"}], + rabbit_stomp_client:send( Client, 'SEND', [{<<"destination">>, ?DESTINATION}, + {<<"reply-to">>, <<"/temp-queue/foo">>}], ["ping"]), amqp_channel:call(Channel,#'basic.consume'{queue = ?QUEUE, no_ack = true}), receive #'basic.consume_ok'{consumer_tag = _Tag} -> ok end, @@ -340,162 +403,163 @@ temp_destination_queue(Config) -> ok = amqp_channel:call(Channel, #'basic.publish'{routing_key = ReplyTo}, #amqp_msg{payload = <<"pong">>}), - {ok, _Client1, _, [<<"pong">>]} = stomp_receive(Client, "MESSAGE"), + {ok, _Client1, _, [<<"pong">>]} = stomp_receive(Client, 'MESSAGE'), ok. temp_destination_in_send(Config) -> Client = ?config(stomp_client, Config), - rabbit_stomp_client:send( Client, "SEND", [{"destination", "/temp-queue/foo"}], + rabbit_stomp_client:send( Client, 'SEND', [{<<"destination">>, <<"/temp-queue/foo">>}], ["poing"]), - {ok, _Client1, Hdrs, _} = stomp_receive(Client, "ERROR"), - "Invalid destination" = proplists:get_value("message", Hdrs), + {ok, _Client1, Hdrs, _} = stomp_receive(Client, 'ERROR'), + <<"Invalid destination">> = maps:get(<<"message">>, Hdrs), ok. blank_destination_in_send(Config) -> Client = ?config(stomp_client, Config), - rabbit_stomp_client:send( Client, "SEND", [{"destination", ""}], + rabbit_stomp_client:send( Client, 'SEND', [{<<"destination">>, <<"">>}], ["poing"]), - {ok, _Client1, Hdrs, _} = stomp_receive(Client, "ERROR"), - "Invalid destination" = proplists:get_value("message", Hdrs), + {ok, _Client1, Hdrs, _} = stomp_receive(Client, 'ERROR'), + <<"Invalid destination">> = maps:get(<<"message">>, Hdrs), ok. stream_filtering(Config) -> Version = ?config(version, Config), Client = ?config(stomp_client, Config), - Stream = atom_to_list(?FUNCTION_NAME) ++ "-" ++ integer_to_list(rand:uniform(10000)), + Stream = <<(atom_to_binary(?FUNCTION_NAME))/binary, $-, + (integer_to_binary(rand:uniform(10000)))/binary>>, %% subscription just to create the stream from STOMP - SubDestination = "/topic/stream-queue-test", + SubDestination = <<"/topic/stream-queue-test">>, rabbit_stomp_client:send( - Client, "SUBSCRIBE", - [{"destination", SubDestination}, - {"receipt", "foo"}, - {"x-queue-name", Stream}, - {"x-queue-type", "stream"}, - {?HEADER_X_STREAM_FILTER_SIZE_BYTES, "32"}, - {"durable", "true"}, - {"auto-delete", "false"}, - {"id", "1234"}, - {"prefetch-count", "1"}, - {"ack", "client"}]), - {ok, Client1, _, _} = stomp_receive(Client, "RECEIPT"), + Client, 'SUBSCRIBE', + [{<<"destination">>, SubDestination}, + {<<"receipt">>, <<"foo">>}, + {<<"x-queue-name">>, Stream}, + {<<"x-queue-type">>, <<"stream">>}, + {?HEADER_X_STREAM_FILTER_SIZE_BYTES, <<"32">>}, + {<<"durable">>, <<"true">>}, + {<<"auto-delete">>, <<"false">>}, + {<<"id">>, <<"1234">>}, + {<<"prefetch-count">>, <<"1">>}, + {<<"ack">>, <<"client">>}]), + {ok, Client1, _, _} = stomp_receive(Client, 'RECEIPT'), rabbit_stomp_client:send( - Client1, "UNSUBSCRIBE", [{"destination", SubDestination}, - {"id", "1234"}, - {"receipt", "bar"}]), - {ok, Client2, _, _} = stomp_receive(Client1, "RECEIPT"), + Client1, 'UNSUBSCRIBE', [{<<"destination">>, SubDestination}, + {<<"id">>, <<"1234">>}, + {<<"receipt">>, <<"bar">>}]), + {ok, Client2, _, _} = stomp_receive(Client1, 'RECEIPT'), %% we are going to publish several waves of messages with and without filter values. %% we will then create subscriptions with various filter options %% and make sure we receive only what we asked for and not all the messages. - StreamDestination = "/amq/queue/" ++ Stream, + StreamDestination = <<"/amq/queue/", Stream/binary>>, %% logic to publish a wave of messages with or without a filter value WaveCount = 1000, Publish = fun(C, FilterValue) -> lists:foldl(fun(Seq, C0) -> - Headers0 = [{"destination", StreamDestination}, - {"receipt", integer_to_list(Seq)}], + Headers0 = [{<<"destination">>, StreamDestination}, + {<<"receipt">>, integer_to_binary(Seq)}], Headers = case FilterValue of undefined -> Headers0; _ -> - [{"x-stream-filter-value", FilterValue}] ++ Headers0 + [{<<"x-stream-filter-value">>, FilterValue}] ++ Headers0 end, rabbit_stomp_client:send( - C0, "SEND", Headers, ["hello"]), - {ok, C1, _, _} = stomp_receive(C0, "RECEIPT"), + C0, 'SEND', Headers, ["hello"]), + {ok, C1, _, _} = stomp_receive(C0, 'RECEIPT'), C1 end, C, lists:seq(1, WaveCount)) end, %% publishing messages with the "apple" filter value - Client3 = Publish(Client2, "apple"), + Client3 = Publish(Client2, <<"apple">>), %% publishing messages with no filter value Client4 = Publish(Client3, undefined), %% publishing messages with the "orange" filter value - Client5 = Publish(Client4, "orange"), + Client5 = Publish(Client4, <<"orange">>), %% filtering on "apple" rabbit_stomp_client:send( - Client5, "SUBSCRIBE", - [{"destination", StreamDestination}, - {"id", "0"}, - {"ack", "client"}, - {"prefetch-count", "1"}, - {"x-stream-filter", "apple"}, - {"x-stream-offset", "first"}]), + Client5, 'SUBSCRIBE', + [{<<"destination">>, StreamDestination}, + {<<"id">>, <<"0">>}, + {<<"ack">>, <<"client">>}, + {<<"prefetch-count">>, <<"1">>}, + {<<"x-stream-filter">>, <<"apple">>}, + {<<"x-stream-offset">>, <<"first">>}]), {Client6, AppleMessages} = stomp_receive_messages(Client5, Version), %% we should get less than all the waves combined ?assert(length(AppleMessages) < WaveCount * 3), %% client-side filtering AppleFilteredMessages = lists:filter(fun(H) -> - proplists:get_value("x-stream-filter-value", H) =:= "apple" + maps:get(<<"x-stream-filter-value">>, H, undefined) =:= <<"apple">> end, AppleMessages), %% we should have only the "apple" messages ?assert(length(AppleFilteredMessages) =:= WaveCount), rabbit_stomp_client:send( - Client6, "UNSUBSCRIBE", [{"destination", StreamDestination}, - {"id", "0"}, - {"receipt", "bar"}]), - {ok, Client7, _, _} = stomp_receive(Client6, "RECEIPT"), + Client6, 'UNSUBSCRIBE', [{<<"destination">>, StreamDestination}, + {<<"id">>, <<"0">>}, + {<<"receipt">>, <<"bar">>}]), + {ok, Client7, _, _} = stomp_receive(Client6, 'RECEIPT'), %% filtering on "apple" and "orange" rabbit_stomp_client:send( - Client7, "SUBSCRIBE", - [{"destination", StreamDestination}, - {"id", "0"}, - {"ack", "client"}, - {"prefetch-count", "1"}, - {"x-stream-filter", "apple,orange"}, - {"x-stream-offset", "first"}]), + Client7, 'SUBSCRIBE', + [{<<"destination">>, StreamDestination}, + {<<"id">>, <<"0">>}, + {<<"ack">>, <<"client">>}, + {<<"prefetch-count">>, <<"1">>}, + {<<"x-stream-filter">>, <<"apple,orange">>}, + {<<"x-stream-offset">>, <<"first">>}]), {Client8, AppleOrangeMessages} = stomp_receive_messages(Client7, Version), %% we should get less than all the waves combined ?assert(length(AppleOrangeMessages) < WaveCount * 3), %% client-side filtering AppleOrangeFilteredMessages = lists:filter(fun(H) -> - proplists:get_value("x-stream-filter-value", H) =:= "apple" orelse - proplists:get_value("x-stream-filter-value", H) =:= "orange" + maps:get(<<"x-stream-filter-value">>, H, undefined) =:= <<"apple">> orelse + maps:get(<<"x-stream-filter-value">>, H, undefined) =:= <<"orange">> end, AppleOrangeMessages), %% we should have only the "apple" and "orange" messages ?assert(length(AppleOrangeFilteredMessages) =:= WaveCount * 2), rabbit_stomp_client:send( - Client8, "UNSUBSCRIBE", [{"destination", StreamDestination}, - {"id", "0"}, - {"receipt", "bar"}]), - {ok, Client9, _, _} = stomp_receive(Client8, "RECEIPT"), + Client8, 'UNSUBSCRIBE', [{<<"destination">>, StreamDestination}, + {<<"id">>, <<"0">>}, + {<<"receipt">>, <<"bar">>}]), + {ok, Client9, _, _} = stomp_receive(Client8, 'RECEIPT'), %% filtering on "apple" and messages without a filter value rabbit_stomp_client:send( - Client9, "SUBSCRIBE", - [{"destination", StreamDestination}, - {"id", "0"}, - {"ack", "client"}, - {"prefetch-count", "1"}, - {"x-stream-filter", "apple"}, - {"x-stream-match-unfiltered", "true"}, - {"x-stream-offset", "first"}]), + Client9, 'SUBSCRIBE', + [{<<"destination">>, StreamDestination}, + {<<"id">>, <<"0">>}, + {<<"ack">>, <<"client">>}, + {<<"prefetch-count">>, <<"1">>}, + {<<"x-stream-filter">>, <<"apple">>}, + {<<"x-stream-match-unfiltered">>, <<"true">>}, + {<<"x-stream-offset">>, <<"first">>}]), {Client10, AppleUnfilteredMessages} = stomp_receive_messages(Client9, Version), %% we should get less than all the waves combined ?assert(length(AppleUnfilteredMessages) < WaveCount * 3), %% client-side filtering AppleUnfilteredFilteredMessages = lists:filter(fun(H) -> - proplists:get_value("x-stream-filter-value", H) =:= "apple" orelse - proplists:get_value("x-stream-filter-value", H) =:= undefined + maps:get(<<"x-stream-filter-value">>, H, undefined) =:= <<"apple">> orelse + maps:get(<<"x-stream-filter-value">>, H, undefined) =:= undefined end, AppleUnfilteredMessages), %% we should have only the "apple" messages and messages without a filter value ?assert(length(AppleUnfilteredFilteredMessages) =:= WaveCount * 2), rabbit_stomp_client:send( - Client10, "UNSUBSCRIBE", [{"destination", StreamDestination}, - {"id", "0"}, - {"receipt", "bar"}]), - {ok, _, _, _} = stomp_receive(Client10, "RECEIPT"), + Client10, 'UNSUBSCRIBE', [{<<"destination">>, StreamDestination}, + {<<"id">>, <<"0">>}, + {<<"receipt">>, <<"bar">>}]), + {ok, _, _, _} = stomp_receive(Client10, 'RECEIPT'), Channel = ?config(amqp_channel, Config), #'queue.delete_ok'{} = amqp_channel:call(Channel, - #'queue.delete'{queue = list_to_binary(Stream)}), + #'queue.delete'{queue = Stream}), ok. stomp_receive_messages(Client, Version) -> @@ -503,12 +567,12 @@ stomp_receive_messages(Client, Version) -> stomp_receive_messages(Client, Acc, Version) -> try rabbit_stomp_client:recv(Client) of - {#stomp_frame{command = "MESSAGE", + {#stomp_frame{command = 'MESSAGE', headers = Headers}, Client1} -> MsgHeader = rabbit_stomp_util:msg_header_name(Version), - AckValue = proplists:get_value(MsgHeader, Headers), + AckValue = maps:get(MsgHeader, Headers), AckHeader = rabbit_stomp_util:ack_header_name(Version), - rabbit_stomp_client:send(Client1, "ACK", [{AckHeader, AckValue}]), + rabbit_stomp_client:send(Client1, 'ACK', [{AckHeader, AckValue}]), stomp_receive_messages(Client1, [Headers] ++ Acc, Version) catch error:{badmatch, {error, timeout}} -> @@ -518,7 +582,6 @@ stomp_receive_messages(Client, Acc, Version) -> stomp_receive(Client, Command) -> {#stomp_frame{command = Command, headers = Hdrs, - body_iolist = Body}, Client1} = + body_iolist_rev = Body}, Client1} = rabbit_stomp_client:recv(Client), {ok, Client1, Hdrs, Body}. - diff --git a/deps/rabbitmq_stomp/test/topic_SUITE.erl b/deps/rabbitmq_stomp/test/topic_SUITE.erl index 162c202d1338..7413fcbcf29a 100644 --- a/deps/rabbitmq_stomp/test/topic_SUITE.erl +++ b/deps/rabbitmq_stomp/test/topic_SUITE.erl @@ -92,53 +92,53 @@ end_per_testcase0(Config) -> publish_topic_authorisation(Config) -> ClientFoo = ?config(client_foo, Config), - AuthorisedTopic = "/topic/user.AuthorisedTopic", - RestrictedTopic = "/topic/user.RestrictedTopic", + AuthorisedTopic = <<"/topic/user.AuthorisedTopic">>, + RestrictedTopic = <<"/topic/user.RestrictedTopic">>, %% send on authorised topic rabbit_stomp_client:send( - ClientFoo, "SUBSCRIBE", [{"destination", AuthorisedTopic}, - {"id", "s0"}, - {"durable", "true"}]), + ClientFoo, 'SUBSCRIBE', [{<<"destination">>, AuthorisedTopic}, + {<<"id">>, <<"s0">>}, + {<<"durable">>, <<"true">>}]), rabbit_stomp_client:send( - ClientFoo, "SEND", [{"destination", AuthorisedTopic}], ["authorised hello"]), + ClientFoo, 'SEND', [{<<"destination">>, AuthorisedTopic}], ["authorised hello"]), - {ok, _Client1, _, Body} = stomp_receive(ClientFoo, "MESSAGE"), + {ok, _Client1, _, Body} = stomp_receive(ClientFoo, 'MESSAGE'), [<<"authorised hello">>] = Body, %% send on restricted topic rabbit_stomp_client:send( - ClientFoo, "SEND", [{"destination", RestrictedTopic}], ["hello"]), - {ok, _Client2, Hdrs2, _} = stomp_receive(ClientFoo, "ERROR"), - "access_refused" = proplists:get_value("message", Hdrs2), + ClientFoo, 'SEND', [{<<"destination">>, RestrictedTopic}], ["hello"]), + {ok, _Client2, Hdrs2, _} = stomp_receive(ClientFoo, 'ERROR'), + <<"access_refused">> = maps:get(<<"message">>, Hdrs2), ok. subscribe_topic_authorisation(Config) -> ClientFoo = ?config(client_foo, Config), - AuthorisedTopic = "/topic/user.AuthorisedTopic", - RestrictedTopic = "/topic/user.RestrictedTopic", + AuthorisedTopic = <<"/topic/user.AuthorisedTopic">>, + RestrictedTopic = <<"/topic/user.RestrictedTopic">>, %% subscribe to authorised topic rabbit_stomp_client:send( - ClientFoo, "SUBSCRIBE", [{"destination", AuthorisedTopic}, - {"id", "s0"}, - {"durable", "true"}]), + ClientFoo, 'SUBSCRIBE', [{<<"destination">>, AuthorisedTopic}, + {<<"id">>, <<"s0">>}, + {<<"durable">>, <<"true">>}]), rabbit_stomp_client:send( - ClientFoo, "SEND", [{"destination", AuthorisedTopic}], ["authorised hello"]), + ClientFoo, 'SEND', [{<<"destination">>, AuthorisedTopic}], ["authorised hello"]), - {ok, _Client1, _, Body} = stomp_receive(ClientFoo, "MESSAGE"), + {ok, _Client1, _, Body} = stomp_receive(ClientFoo, 'MESSAGE'), [<<"authorised hello">>] = Body, %% subscribe to restricted topic rabbit_stomp_client:send( - ClientFoo, "SUBSCRIBE", [{"destination", RestrictedTopic}, - {"id", "s1"}, - {"durable", "true"}]), - {ok, _Client2, Hdrs2, _} = stomp_receive(ClientFoo, "ERROR"), - "access_refused" = proplists:get_value("message", Hdrs2), + ClientFoo, 'SUBSCRIBE', [{<<"destination">>, RestrictedTopic}, + {<<"id">>, <<"s1">>}, + {<<"durable">>, <<"true">>}]), + {ok, _Client2, Hdrs2, _} = stomp_receive(ClientFoo, 'ERROR'), + <<"access_refused">> = maps:get(<<"message">>, Hdrs2), ok. publish_topic_authorisation_regex_not_injected(Config) -> @@ -155,18 +155,18 @@ publish_topic_authorisation_regex_not_injected(Config) -> {ok, ClientRegex} = rabbit_stomp_client:connect(Version, ".*", "pass", StompPort), rabbit_stomp_client:send( - ClientRegex, "SUBSCRIBE", [{"destination", "/topic/.*.Authorised"}, - {"id", "s0"}, - {"durable", "true"}]), + ClientRegex, 'SUBSCRIBE', [{<<"destination">>, <<"/topic/.*.Authorised">>}, + {<<"id">>, <<"s0">>}, + {<<"durable">>, <<"true">>}]), rabbit_stomp_client:send( - ClientRegex, "SEND", [{"destination", "/topic/.*.Authorised"}], ["allowed"]), - {ok, _Client1, _, Body} = stomp_receive(ClientRegex, "MESSAGE"), + ClientRegex, 'SEND', [{<<"destination">>, <<"/topic/.*.Authorised">>}], ["allowed"]), + {ok, _Client1, _, Body} = stomp_receive(ClientRegex, 'MESSAGE'), [<<"allowed">>] = Body, rabbit_stomp_client:send( - ClientRegex, "SEND", [{"destination", "/topic/injected.Authorised"}], ["denied"]), - {ok, _Client2, Hdrs2, _} = stomp_receive(ClientRegex, "ERROR"), - "access_refused" = proplists:get_value("message", Hdrs2), + ClientRegex, 'SEND', [{<<"destination">>, <<"/topic/injected.Authorised">>}], ["denied"]), + {ok, _Client2, Hdrs2, _} = stomp_receive(ClientRegex, 'ERROR'), + <<"access_refused">> = maps:get(<<"message">>, Hdrs2), rabbit_stomp_client:disconnect(ClientRegex), rabbit_ct_broker_helpers:rpc(Config, 0, rabbit_auth_backend_internal, delete_user, @@ -180,7 +180,7 @@ change_default_topic_exchange(Config) -> Ex = <<"my-topic-exchange">>, ok = rabbit_ct_broker_helpers:rpc(Config, 0, application, set_env, [rabbitmq_stomp, default_topic_exchange, Ex]), {ok, ClientFoo} = rabbit_stomp_client:connect(Version, StompPort), - AuthorisedTopic = "/topic/user.AuthorisedTopic", + AuthorisedTopic = <<"/topic/user.AuthorisedTopic">>, Declare = #'exchange.declare'{exchange = Ex, type = <<"topic">>}, #'exchange.declare_ok'{} = amqp_channel:call(Channel, Declare), @@ -188,9 +188,9 @@ change_default_topic_exchange(Config) -> 0 = length(rabbit_ct_broker_helpers:rpc(Config, 0, rabbit_binding, list_for_source, [#resource{virtual_host= <<"/">>, kind = exchange, name = Ex}])), rabbit_stomp_client:send( - ClientFoo, "SUBSCRIBE", [{"destination", AuthorisedTopic}, - {"id", "s0"}, - {"durable", "true"}]), + ClientFoo, 'SUBSCRIBE', [{<<"destination">>, AuthorisedTopic}, + {<<"id">>, <<"s0">>}, + {<<"durable">>, <<"true">>}]), %% STOMP SUBSCRIBE creates the binding asynchronously; wait for it %% to be readable (as in consistency). @@ -201,9 +201,9 @@ change_default_topic_exchange(Config) -> 30_000), rabbit_stomp_client:send( - ClientFoo, "SEND", [{"destination", AuthorisedTopic}], ["ohai there"]), + ClientFoo, 'SEND', [{<<"destination">>, AuthorisedTopic}], ["ohai there"]), - {ok, _Client1, _, Body} = stomp_receive(ClientFoo, "MESSAGE"), + {ok, _Client1, _, Body} = stomp_receive(ClientFoo, 'MESSAGE'), [<<"ohai there">>] = Body, Delete = #'exchange.delete'{exchange = Ex}, @@ -216,7 +216,6 @@ change_default_topic_exchange(Config) -> stomp_receive(Client, Command) -> {#stomp_frame{command = Command, headers = Hdrs, - body_iolist = Body}, Client1} = + body_iolist_rev = Body}, Client1} = rabbit_stomp_client:recv(Client), {ok, Client1, Hdrs, Body}. - diff --git a/deps/rabbitmq_stomp/test/unit_content_length_SUITE.erl b/deps/rabbitmq_stomp/test/unit_content_length_SUITE.erl new file mode 100644 index 000000000000..5ff6139c260b --- /dev/null +++ b/deps/rabbitmq_stomp/test/unit_content_length_SUITE.erl @@ -0,0 +1,62 @@ +%% This Source Code Form is subject to the terms of the Mozilla Public +%% License, v. 2.0. If a copy of the MPL was not distributed with this +%% file, You can obtain one at https://mozilla.org/MPL/2.0/. +%% +%% Copyright (c) 2007-2026 Broadcom. All Rights Reserved. The term "Broadcom" refers to Broadcom Inc. and/or its subsidiaries. All rights reserved. +%% + +-module(unit_content_length_SUITE). + +-include_lib("eunit/include/eunit.hrl"). +-include("rabbit_stomp_frame.hrl"). +-compile(export_all). + +all() -> + [ + negative_content_length_rejected, + non_numeric_content_length_treated_as_unknown, + zero_content_length_accepted, + valid_content_length_accepted, + missing_nul_terminator_rejected, + content_length_with_whitespace_accepted + ]. + +negative_content_length_rejected(_) -> + ?assertMatch({error, {invalid_content_length, -1}}, + parse(send_frame([{"content-length", "-1"}], "hello"))). + +non_numeric_content_length_treated_as_unknown(_) -> + %% Non-numeric content-length is ignored; the parser falls back + %% to scanning for the NUL terminator. + {ok, Frame, _} = parse(send_frame([{"content-length", "abc"}], "hello")), + ?assertEqual('SEND', Frame#stomp_frame.command). + +zero_content_length_accepted(_) -> + {ok, Frame, _} = parse(send_frame([{"content-length", "0"}], "")), + ?assertEqual('SEND', Frame#stomp_frame.command), + ?assertEqual([], Frame#stomp_frame.body_iolist_rev). + +valid_content_length_accepted(_) -> + {ok, Frame, _} = parse(send_frame([{"content-length", "5"}], "hello")), + ?assertEqual('SEND', Frame#stomp_frame.command), + ?assertEqual([<<"hello">>], Frame#stomp_frame.body_iolist_rev). + +%% When content-length is set and the byte at that position is not NUL, +%% the parser must return an error instead of crashing. +missing_nul_terminator_rejected(_) -> + Bin = <<"SEND\ndestination:/queue/t\ncontent-length:3\n\nhelloX">>, + ?assertMatch({error, missing_body_terminator}, parse(Bin)). + +content_length_with_whitespace_accepted(_) -> + {ok, Frame, _} = parse(send_frame([{"content-length", " 5 "}], "hello")), + ?assertEqual('SEND', Frame#stomp_frame.command). + +%%------------------------------------------------------------------- + +send_frame(Headers, Body) -> + HdrStr = lists:flatten( + [K ++ ":" ++ V ++ "\n" || {K, V} <- [{"destination", "/queue/t"} | Headers]]), + iolist_to_binary(["SEND\n", HdrStr, "\n", Body, "\0"]). + +parse(Bin) -> + rabbit_stomp_frame:parse(Bin, rabbit_stomp_frame:initial_state()). diff --git a/deps/rabbitmq_stomp/test/unit_frame_SUITE.erl b/deps/rabbitmq_stomp/test/unit_frame_SUITE.erl index 1a5129431794..f39b203cd01e 100644 --- a/deps/rabbitmq_stomp/test/unit_frame_SUITE.erl +++ b/deps/rabbitmq_stomp/test/unit_frame_SUITE.erl @@ -11,6 +11,9 @@ -include("rabbit_stomp_frame.hrl"). -compile(export_all). +%% Default max_headers is 100 +-define(LIMIT, 100). + all() -> [ max_headers_accepted, @@ -22,42 +25,43 @@ all() -> unique_and_duplicate_mix ]. -%% Exactly 100 unique headers must be accepted. +%% Exactly LIMIT unique headers must be accepted. max_headers_accepted(_) -> - {ok, _, _} = parse(make_frame(unique_headers(100))). + {ok, _, _} = parse(make_frame(unique_headers(?LIMIT))). -%% 101 unique headers must be rejected. +%% LIMIT+1 unique headers must be rejected. exceeds_max_headers_rejected(_) -> - ?assertEqual({error, too_many_headers}, parse(make_frame(unique_headers(101)))). + ?assertEqual({error, {max_headers, ?LIMIT}}, + parse(make_frame(unique_headers(?LIMIT + 1)))). -%% The same rejection must occur when the frame arrives in two TCP chunks, -%% verifying that the seen-names map is preserved across chunk boundaries. +%% The same rejection must occur when the frame arrives in two TCP chunks. exceeds_max_headers_rejected_chunked(_) -> - Full = make_frame(unique_headers(101)), + Full = make_frame(unique_headers(?LIMIT + 1)), Mid = byte_size(Full) div 2, <> = Full, {more, Resume} = parse(Chunk1), - ?assertEqual({error, too_many_headers}, parse(Chunk2, Resume)). + ?assertEqual({error, {max_headers, ?LIMIT}}, + parse(Chunk2, Resume)). -%% When seen-names is at the limit a duplicate must be discarded, not rejected. -%% The duplicate check must fire before the count check. +%% When the map is at the limit, a duplicate must be discarded, not rejected. duplicate_at_limit_boundary_accepted(_) -> - Headers = unique_headers(100) ++ [{"h1", "dup"}], + Headers = unique_headers(?LIMIT) ++ [{"h1", "dup"}], {ok, Frame, _} = parse(make_frame(Headers)), - ?assertEqual({ok, "v"}, rabbit_stomp_frame:header(Frame, "h1")). + ?assertEqual({ok, <<"v">>}, rabbit_stomp_frame:header(Frame, <<"h1">>)). %% Duplicate header names do not count toward the limit. %% A frame with 200 repetitions of the same name has only one unique name. duplicates_do_not_count(_) -> Headers = [{"h", integer_to_list(I)} || I <- lists:seq(1, 200)], {ok, Frame, _} = parse(make_frame(Headers)), - ?assertEqual({ok, "1"}, rabbit_stomp_frame:header(Frame, "h")). + ?assertEqual({ok, <<"1">>}, rabbit_stomp_frame:header(Frame, <<"h">>)). %% The first occurrence of a header name is kept; later ones are ignored. first_occurrence_wins(_) -> Headers = [{"destination", "/queue/a"}, {"destination", "/queue/b"}], {ok, Frame, _} = parse(make_frame(Headers)), - ?assertEqual({ok, "/queue/a"}, rabbit_stomp_frame:header(Frame, "destination")). + ?assertEqual({ok, <<"/queue/a">>}, + rabbit_stomp_frame:header(Frame, <<"destination">>)). %% 50 unique headers plus any number of duplicates stays within the limit. unique_and_duplicate_mix(_) -> diff --git a/deps/rabbitmq_stomp/test/util_SUITE.erl b/deps/rabbitmq_stomp/test/util_SUITE.erl index a0177e5eb1fd..eb5df3568029 100644 --- a/deps/rabbitmq_stomp/test/util_SUITE.erl +++ b/deps/rabbitmq_stomp/test/util_SUITE.erl @@ -8,9 +8,9 @@ -module(util_SUITE). -include_lib("eunit/include/eunit.hrl"). --include_lib("amqp_client/include/amqp_client.hrl"). --include("rabbit_stomp_routing_prefixes.hrl"). +-include_lib("rabbit_common/include/rabbit_framing.hrl"). -include("rabbit_stomp_frame.hrl"). +-include("rabbit_stomp_headers.hrl"). -compile(export_all). all() -> [ @@ -48,25 +48,25 @@ all() -> [ longstr_field(_) -> {<<"ABC">>, longstr, <<"DEF">>} = - rabbit_stomp_util:longstr_field("ABC", "DEF"). + rabbit_stomp_util:longstr_field(<<"ABC">>, <<"DEF">>). message_properties(_) -> - Headers = [ - {"content-type", "text/plain"}, - {"content-encoding", "UTF-8"}, - {"persistent", "true"}, - {"priority", "1"}, - {"correlation-id", "123"}, - {"reply-to", "something"}, - {"expiration", "my-expiration"}, - {"amqp-message-id", "M123"}, - {"timestamp", "123456"}, - {"type", "freshly-squeezed"}, - {"user-id", "joe"}, - {"app-id", "joe's app"}, - {"str", "foo"}, - {"int", "123"} - ], + Headers = #{ + <<"content-type">> => <<"text/plain">>, + <<"content-encoding">> => <<"UTF-8">>, + <<"persistent">> => <<"true">>, + <<"priority">> => <<"1">>, + <<"correlation-id">> => <<"123">>, + <<"reply-to">> => <<"something">>, + <<"expiration">> => <<"my-expiration">>, + <<"amqp-message-id">> => <<"M123">>, + <<"timestamp">> => <<"123456">>, + <<"type">> => <<"freshly-squeezed">>, + <<"user-id">> => <<"joe">>, + <<"app-id">> => <<"joe's app">>, + <<"str">> => <<"foo">>, + <<"int">> => <<"123">> + }, #'P_basic'{ content_type = <<"text/plain">>, @@ -105,20 +105,20 @@ message_headers(_) -> Headers = rabbit_stomp_util:message_headers(Properties), Expected = [ - {"content-type", "text/plain"}, - {"content-encoding", "UTF-8"}, - {"persistent", "true"}, - {"priority", "1"}, - {"correlation-id", "123"}, - {"reply-to", "something"}, - {"expiration", "my-expiration"}, - {"amqp-message-id", "M123"}, - {"timestamp", "123456"}, - {"type", "freshly-squeezed"}, - {"user-id", "joe"}, - {"app-id", "joe's app"}, - {"str", "foo"}, - {"int", "123"} + {<<"content-type">>, <<"text/plain">>}, + {<<"content-encoding">>, <<"UTF-8">>}, + {<<"persistent">>, <<"true">>}, + {<<"priority">>, <<"1">>}, + {<<"correlation-id">>, <<"123">>}, + {<<"reply-to">>, <<"something">>}, + {<<"expiration">>, <<"my-expiration">>}, + {<<"amqp-message-id">>, <<"M123">>}, + {<<"timestamp">>, <<"123456">>}, + {<<"type">>, <<"freshly-squeezed">>}, + {<<"user-id">>, <<"joe">>}, + {<<"app-id">>, <<"joe's app">>}, + {<<"str">>, <<"foo">>}, + {<<"int">>, <<"123">>} ], [] = lists:subtract(Headers, Expected). @@ -128,34 +128,34 @@ minimal_message_headers_with_no_custom(_) -> Headers = rabbit_stomp_util:message_headers(Properties), Expected = [ - {"content-type", "text/plain"}, - {"content-encoding", "UTF-8"}, - {"amqp-message-id", "M123"} + {<<"content-type">>, <<"text/plain">>}, + {<<"content-encoding">>, <<"UTF-8">>}, + {<<"amqp-message-id">>, <<"M123">>} ], [] = lists:subtract(Headers, Expected). headers_post_process(_) -> - Headers = [{"header1", "1"}, - {"header2", "12"}, - {"reply-to", "something"}], - Expected = [{"header1", "1"}, - {"header2", "12"}, - {"reply-to", "/reply-queue/something"}], + Headers = [{<<"header1">>, <<"1">>}, + {<<"header2">>, <<"12">>}, + {<<"reply-to">>, <<"something">>}], + Expected = [{<<"header1">>, <<"1">>}, + {<<"header2">>, <<"12">>}, + {<<"reply-to">>, <<"/reply-queue/something">>}], [] = lists:subtract( rabbit_stomp_util:headers_post_process(Headers), Expected). headers_post_process_noop_replyto(_) -> [begin - Headers = [{"reply-to", Prefix ++ "/something"}], + Headers = [{<<"reply-to">>, <>}], Headers = rabbit_stomp_util:headers_post_process(Headers) - end || Prefix <- rabbit_stomp_routing_util:dest_prefixes()]. + end || Prefix <- ?DEST_PREFIXES]. headers_post_process_noop2(_) -> - Headers = [{"header1", "1"}, - {"header2", "12"}], - Expected = [{"header1", "1"}, - {"header2", "12"}], + Headers = [{<<"header1">>, <<"1">>}, + {<<"header2">>, <<"12">>}], + Expected = [{<<"header1">>, <<"1">>}, + {<<"header2">>, <<"12">>}], [] = lists:subtract( rabbit_stomp_util:headers_post_process(Headers), Expected). @@ -191,40 +191,40 @@ negotiate_version_choice_duplicates(_) -> rabbit_stomp_util:negotiate_version(["1.2", "1.2"], ["1.2", "1.2"]). trim_headers(_) -> - #stomp_frame{headers = [{"one", "foo"}, {"two", "baz "}]} = + #stomp_frame{headers = #{<<"one">> := <<"foo">>, <<"two">> := <<"baz ">>}} = rabbit_stomp_util:trim_headers( - #stomp_frame{headers = [{"one", " foo"}, {"two", " baz "}]}). + #stomp_frame{headers = #{<<"one">> => <<" foo">>, <<"two">> => <<" baz ">>}}). %%-------------------------------------------------------------------- %% Frame Parsing Tests %%-------------------------------------------------------------------- ack_mode_auto(_) -> - Frame = #stomp_frame{headers = [{"ack", "auto"}]}, + Frame = #stomp_frame{headers = #{<<"ack">> => <<"auto">>}}, {auto, _} = rabbit_stomp_util:ack_mode(Frame). ack_mode_auto_default(_) -> - Frame = #stomp_frame{headers = []}, + Frame = #stomp_frame{headers = #{}}, {auto, _} = rabbit_stomp_util:ack_mode(Frame). ack_mode_client(_) -> - Frame = #stomp_frame{headers = [{"ack", "client"}]}, + Frame = #stomp_frame{headers = #{<<"ack">> => <<"client">>}}, {client, true} = rabbit_stomp_util:ack_mode(Frame). ack_mode_client_individual(_) -> - Frame = #stomp_frame{headers = [{"ack", "client-individual"}]}, + Frame = #stomp_frame{headers = #{<<"ack">> => <<"client-individual">>}}, {client, false} = rabbit_stomp_util:ack_mode(Frame). consumer_tag_id(_) -> - Frame = #stomp_frame{headers = [{"id", "foo"}]}, + Frame = #stomp_frame{headers = #{<<"id">> => <<"foo">>}}, {ok, <<"T_foo">>, _} = rabbit_stomp_util:consumer_tag(Frame). consumer_tag_destination(_) -> - Frame = #stomp_frame{headers = [{"destination", "foo"}]}, + Frame = #stomp_frame{headers = #{<<"destination">> => <<"foo">>}}, {ok, <<"Q_foo">>, _} = rabbit_stomp_util:consumer_tag(Frame). consumer_tag_invalid(_) -> - Frame = #stomp_frame{headers = []}, + Frame = #stomp_frame{headers = #{}}, {error, missing_destination_header} = rabbit_stomp_util:consumer_tag(Frame). %%-------------------------------------------------------------------- @@ -233,9 +233,8 @@ consumer_tag_invalid(_) -> parse_valid_message_id(_) -> {ok, {<<"bar">>, "abc", 123}} = - rabbit_stomp_util:parse_message_id("bar@@abc@@123"). + rabbit_stomp_util:parse_message_id(<<"bar@@abc@@123">>). parse_invalid_message_id(_) -> {error, invalid_message_id} = - rabbit_stomp_util:parse_message_id("blah"). - + rabbit_stomp_util:parse_message_id(<<"blah">>). diff --git a/deps/rabbitmq_web_stomp/src/rabbit_web_stomp_handler.erl b/deps/rabbitmq_web_stomp/src/rabbit_web_stomp_handler.erl index c3ee5b5bdbb5..a7f0acb451bf 100644 --- a/deps/rabbitmq_web_stomp/src/rabbit_web_stomp_handler.erl +++ b/deps/rabbitmq_web_stomp/src/rabbit_web_stomp_handler.erl @@ -12,7 +12,6 @@ -include_lib("kernel/include/logger.hrl"). -include_lib("rabbitmq_stomp/include/rabbit_stomp.hrl"). -include_lib("rabbitmq_stomp/include/rabbit_stomp_frame.hrl"). --include_lib("amqp_client/include/amqp_client.hrl"). -include_lib("rabbit_common/include/logging.hrl"). %% Websocket. @@ -36,6 +35,7 @@ heartbeat, heartbeat_sup, parse_state, + parser_config, proc_state, state, conserve_resources, @@ -145,12 +145,13 @@ init(Req0, Opts) -> websocket_init(State) -> process_flag(trap_exit, true), rabbit_access_control:set_max_heap_size_unauthenticated(rabbitmq_web_stomp), - {ok, ProcessorState} = init_processor_state(State), + {ok, ProcessorState, ParserConfig} = init_processor_state(State), LoginTimeout = application:get_env(rabbitmq_stomp, login_timeout, 10_000), erlang:send_after(LoginTimeout, self(), login_timeout), {ok, rabbit_event:init_stats_timer( State#state{proc_state = ProcessorState, - parse_state = rabbit_stomp_frame:initial_state()}, + parser_config = ParserConfig, + parse_state = rabbit_stomp_frame:initial_state(ParserConfig)}, #state.stats_timer)}. -spec close_connection(pid(), string()) -> 'ok'. @@ -163,7 +164,7 @@ close_connection(Pid, Reason) -> exit:{noproc, _} -> ok end. -init_processor_state(#state{socket=Sock, peername=PeerAddr, auth_hd=AuthHd}) -> +init_processor_state(#state{socket=Sock, auth_hd=AuthHd}) -> Self = self(), SendFun = fun(Data) -> Self ! {send, Data}, @@ -171,7 +172,19 @@ init_processor_state(#state{socket=Sock, peername=PeerAddr, auth_hd=AuthHd}) -> end, SSLLogin = application:get_env(rabbitmq_stomp, ssl_cert_login, false), - StompConfig0 = #stomp_configuration{ssl_cert_login = SSLLogin, implicit_connect = false}, + Defaults = #stomp_parser_config{}, + StompConfig0 = #stomp_configuration{ + ssl_cert_login = SSLLogin, + implicit_connect = false, + max_headers = application:get_env( + rabbitmq_stomp, max_headers, + Defaults#stomp_parser_config.max_headers), + max_header_length = application:get_env( + rabbitmq_stomp, max_header_length, + Defaults#stomp_parser_config.max_header_length), + max_body_length = application:get_env( + rabbitmq_stomp, max_body_length, + Defaults#stomp_parser_config.max_body_length)}, UseHTTPAuth = application:get_env(rabbitmq_web_stomp, use_http_auth, false), UserConfig = application:get_env(rabbitmq_stomp, default_user, undefined), StompConfig1 = rabbit_stomp:parse_default_user(UserConfig, StompConfig0), @@ -193,13 +206,21 @@ init_processor_state(#state{socket=Sock, peername=PeerAddr, auth_hd=AuthHd}) -> StompConfig1 end, - AdapterInfo = amqp_connection:socket_adapter_info(Sock, {'Web STOMP', 0}), RealSocket = rabbit_net:unwrap_socket(Sock), + {ok, ConnStr} = rabbit_net:connection_string(Sock, inbound), + ConnName = rabbit_data_coercion:to_binary(ConnStr), + {ok, {PeerHost, PeerPort, Host, Port}} = rabbit_net:socket_ends(Sock, inbound), + logger:update_process_metadata(#{connection => ConnName}), LoginNameFromCertificate = rabbit_stomp_reader:ssl_login_name(RealSocket, StompConfig2), ProcessorState = rabbit_stomp_processor:initial_state( StompConfig2, - {SendFun, AdapterInfo, LoginNameFromCertificate, PeerAddr}), - {ok, ProcessorState}. + {SendFun, LoginNameFromCertificate, ConnName, + Host, Port, PeerHost, PeerPort}), + ParserConfig = #stomp_parser_config{ + max_headers = StompConfig2#stomp_configuration.max_headers, + max_header_length = StompConfig2#stomp_configuration.max_header_length, + max_body_length = StompConfig2#stomp_configuration.max_body_length}, + {ok, ProcessorState, ParserConfig}. websocket_handle({text, Data}, State) -> handle_data(Data, State); @@ -218,43 +239,40 @@ websocket_info({bump_credit, Msg}, State) -> credit_flow:handle_bump_msg(Msg), handle_credits(control_throttle(State)); -websocket_info(#'basic.consume_ok'{}, State) -> - {ok, State}; -websocket_info(#'basic.cancel_ok'{}, State) -> - {ok, State}; -websocket_info(#'basic.ack'{delivery_tag = Tag, multiple = IsMulti}, - State=#state{ proc_state = ProcState0 }) -> - ProcState = rabbit_stomp_processor:flush_pending_receipts(Tag, - IsMulti, - ProcState0), - {ok, State#state{ proc_state = ProcState }}; -websocket_info({Delivery = #'basic.deliver'{}, - #amqp_msg{props = Props, payload = Payload}, - DeliveryCtx}, - State=#state{ proc_state = ProcState0 }) -> - ProcState = rabbit_stomp_processor:send_delivery(Delivery, - Props, - Payload, - DeliveryCtx, - ProcState0), - {ok, State#state{ proc_state = ProcState }}; -websocket_info({Delivery = #'basic.deliver'{}, - #amqp_msg{props = Props, payload = Payload}}, - State=#state{ proc_state = ProcState0 }) -> - ProcState = rabbit_stomp_processor:send_delivery(Delivery, - Props, - Payload, - undefined, - ProcState0), - {ok, State#state{ proc_state = ProcState }}; -websocket_info(#'basic.cancel'{consumer_tag = Ctag}, - State=#state{ proc_state = ProcState0 }) -> - case rabbit_stomp_processor:cancel_consumer(Ctag, ProcState0) of - {ok, ProcState, _Connection} -> - {ok, State#state{ proc_state = ProcState }}; - {stop, _Reason, ProcState} -> - stop(State#state{ proc_state = ProcState }) +websocket_info({'$gen_cast', QueueEvent = {queue_event, _, _}}, + State = #state{proc_state = ProcState0}) -> + case rabbit_stomp_processor:handle_queue_event(QueueEvent, ProcState0) of + {ok, ProcState} -> + {ok, State#state{proc_state = ProcState}}; + {error, _Reason, ProcState} -> + stop(State#state{proc_state = ProcState}) end; +websocket_info({{'DOWN', _QName}, _MRef, process, _Pid, _Reason} = Evt, + State = #state{proc_state = ProcState0}) -> + {ok, ProcState} = rabbit_stomp_processor:handle_down(Evt, ProcState0), + {ok, State#state{proc_state = ProcState}}; +websocket_info({'DOWN', _MRef, process, QPid, _Reason}, State) -> + rabbit_amqqueue_common:notify_sent_queue_down(QPid), + {ok, State}; +websocket_info(connection_created, State = #state{proc_state = ProcState}) -> + State1 = State#state{connection = self()}, + Infos = [{pid, self()}, + {name, rabbit_stomp_processor:adapter_name(ProcState)}, + {protocol, rabbit_stomp_processor:info(protocol, ProcState)}, + {peer_host, rabbit_stomp_processor:info(peer_host, ProcState)}, + {peer_port, rabbit_stomp_processor:info(peer_port, ProcState)}, + {host, rabbit_stomp_processor:info(host, ProcState)}, + {port, rabbit_stomp_processor:info(port, ProcState)}, + {user, rabbit_stomp_processor:info(user, ProcState)}, + {vhost, rabbit_stomp_processor:info(vhost, ProcState)}, + {connected_at, rabbit_stomp_processor:info(connected_at, ProcState)}], + rabbit_core_metrics:connection_created(self(), Infos), + rabbit_event:notify(connection_created, Infos), + logger:update_process_metadata( + #{connection => rabbit_stomp_processor:adapter_name(ProcState), + vhost => rabbit_stomp_processor:info(vhost, ProcState), + user => rabbit_stomp_processor:info(user, ProcState)}), + {ok, State1}; websocket_info({start_heartbeats, _}, State = #state{heartbeat_mode = no_heartbeat}) -> @@ -291,16 +309,8 @@ websocket_info(client_timeout, State) -> stop(State); %%---------------------------------------------------------------------------- -websocket_info({'EXIT', From, Reason}, - State=#state{ proc_state = ProcState0 }) -> - case rabbit_stomp_processor:handle_exit(From, Reason, ProcState0) of - {stop, _Reason, ProcState} -> - stop(State#state{ proc_state = ProcState }); - unknown_exit -> - %% Allow the server to send remaining error messages - self() ! close_websocket, - {ok, State} - end; +websocket_info({'EXIT', _From, _Reason}, State) -> + stop(State); websocket_info(close_websocket, State) -> stop(State); @@ -395,23 +405,23 @@ handle_data(Data, State0) -> handle_data1(<<>>, State) -> {ok, ensure_stats_timer(State)}; -handle_data1(Bytes, State = #state{proc_state = ProcState, - parse_state = ParseState, - connection = OldConn}) -> +handle_data1(Bytes, State = #state{proc_state = ProcState, + parse_state = ParseState, + parser_config = ParserConfig, + connection = OldConn}) -> case rabbit_stomp_frame:parse(Bytes, ParseState) of {more, ParseState1} -> {ok, ensure_stats_timer(State#state{ parse_state = ParseState1 })}; {ok, Frame, Rest} -> case rabbit_stomp_processor:process_frame(Frame, ProcState) of - {ok, ProcState1, ConnPid} -> - maybe_increase_max_frame_size(OldConn, ConnPid), - ParseState1 = rabbit_stomp_frame:initial_state(), + {ok, ProcState1} -> + maybe_increase_max_frame_size(OldConn, ProcState1), + ParseState1 = rabbit_stomp_frame:initial_state(ParserConfig), State1 = maybe_block(State, Frame), handle_data1( Rest, State1 #state{ parse_state = ParseState1, - proc_state = ProcState1, - connection = ConnPid }); + proc_state = ProcState1 }); {stop, _Reason, ProcState1} -> %% do not exit here immediately, because we need to wait for messages eventually enqueued by process_request self() ! close_websocket, @@ -421,15 +431,17 @@ handle_data1(Bytes, State = #state{proc_state = ProcState, Other end. -maybe_increase_max_frame_size(OldConn, ConnPid) - when (OldConn =:= none orelse OldConn =:= undefined) andalso - is_pid(ConnPid) -> - self() ! increase_max_frame_size; +maybe_increase_max_frame_size(OldConn, ProcState) + when OldConn =:= none; OldConn =:= undefined -> + case rabbit_stomp_processor:info(user, ProcState) of + undefined -> ok; + _ -> self() ! increase_max_frame_size + end; maybe_increase_max_frame_size(_, _) -> ok. maybe_block(State = #state{state = blocking, heartbeat = Heartbeat}, - #stomp_frame{command = "SEND"}) -> + #stomp_frame{command = 'SEND'}) -> rabbit_heartbeat:pause_monitor(Heartbeat), State#state{state = blocked}; maybe_block(State, _) -> diff --git a/deps/rabbitmq_web_stomp/test/cowboy_websocket_SUITE.erl b/deps/rabbitmq_web_stomp/test/cowboy_websocket_SUITE.erl index 875846c893d1..2cb4a895508c 100644 --- a/deps/rabbitmq_web_stomp/test/cowboy_websocket_SUITE.erl +++ b/deps/rabbitmq_web_stomp/test/cowboy_websocket_SUITE.erl @@ -198,7 +198,8 @@ sub_non_existent(Config) -> ok = raw_send(WS, "SUBSCRIBE", [{"destination", "/exchange/doesnotexist"}, {"id", "s0"}]), - {<<"ERROR">>, [{<<"message">>,<<"not_found">>} | _Tail ], <<"NOT_FOUND - no exchange 'doesnotexist' in vhost '/'">>} = raw_recv(WS), + {<<"ERROR">>, Headers, <<"no exchange 'doesnotexist' in vhost '/'">>} = raw_recv(WS), + <<"not_found">> = proplists:get_value(<<"message">>, Headers), {close, _} = raw_recv(WS), ok.