Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions federation/handle.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func MakeJoinRequestsHandler(s *Server, w http.ResponseWriter, req *http.Request
// or dealing with HTTP responses itself.
func MakeRespMakeJoin(s *Server, room *ServerRoom, userID string) (resp fclient.RespMakeJoin, err error) {
// Generate a join event
proto, err := room.ProtoEventCreator(Event{
proto, err := room.ProtoEventCreator(room, Event{
Type: "m.room.member",
StateKey: &userID,
Content: map[string]interface{}{
Expand All @@ -84,7 +84,7 @@ func MakeRespMakeJoin(s *Server, room *ServerRoom, userID string) (resp fclient.
// or dealing with HTTP responses itself.
func MakeRespMakeKnock(s *Server, room *ServerRoom, userID string) (resp fclient.RespMakeKnock, err error) {
// Generate a knock event
proto, err := room.ProtoEventCreator(Event{
proto, err := room.ProtoEventCreator(room, Event{
Type: "m.room.member",
StateKey: &userID,
Content: map[string]interface{}{
Expand Down Expand Up @@ -159,7 +159,7 @@ func SendJoinRequestsHandler(s *Server, w http.ResponseWriter, req *http.Request
return
}

resp := room.GenerateSendJoinResponse(s, event, expectPartialState, omitServersInRoom)
resp := room.GenerateSendJoinResponse(room, s, event, expectPartialState, omitServersInRoom)
b, err := json.Marshal(resp)
if err != nil {
w.WriteHeader(500)
Expand Down
55 changes: 41 additions & 14 deletions federation/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ func (s *Server) MakeAliasMapping(aliasLocalpart, roomID string) string {

// MustMakeRoom will add a room to this server so it is accessible to other servers when prompted via federation.
// The `events` will be added to this room. Returns the created room.
func (s *Server) MustMakeRoom(t ct.TestLike, roomVer gomatrixserverlib.RoomVersion, events []Event) *ServerRoom {
func (s *Server) MustMakeRoom(t ct.TestLike, roomVer gomatrixserverlib.RoomVersion, events []Event, opts ...ServerRoomOpt) *ServerRoom {
if !s.listening {
ct.Fatalf(s.t, "MustMakeRoom() called before Listen() - this is not supported because Listen() chooses a high-numbered port and thus changes the server name and thus changes the room ID. Ensure you Listen() first!")
}
Expand All @@ -184,13 +184,16 @@ func (s *Server) MustMakeRoom(t ct.TestLike, roomVer gomatrixserverlib.RoomVersi
roomID := fmt.Sprintf("!%d-%s:%s", len(s.rooms), util.RandomString(18), s.serverName)
t.Logf("Creating room %s with version %s", roomID, roomVer)
room := NewServerRoom(roomVer, roomID)
for _, opt := range opts {
opt(room)
}

// sign all these events
for _, ev := range events {
signedEvent := s.MustCreateEvent(t, room, ev)
room.AddEvent(signedEvent)
}
s.rooms[roomID] = room
s.rooms[room.RoomID] = room
return room
}

Expand Down Expand Up @@ -303,11 +306,11 @@ func (s *Server) DoFederationRequest(
// It does not insert this event into the room however. See ServerRoom.AddEvent for that.
func (s *Server) MustCreateEvent(t ct.TestLike, room *ServerRoom, ev Event) gomatrixserverlib.PDU {
t.Helper()
proto, err := room.ProtoEventCreator(ev)
proto, err := room.ProtoEventCreator(room, ev)
if err != nil {
ct.Fatalf(t, "MustCreateEvent: failed to create proto event: %v", err)
}
pdu, err := room.EventCreator(s, proto)
pdu, err := room.EventCreator(room, s, proto)
if err != nil {
ct.Fatalf(t, "MustCreateEvent: failed to create PDU: %v", err)
}
Expand All @@ -316,8 +319,12 @@ func (s *Server) MustCreateEvent(t ct.TestLike, room *ServerRoom, ev Event) goma

// MustJoinRoom will make the server send a make_join and a send_join to join a room
// It returns the resultant room.
func (s *Server) MustJoinRoom(t ct.TestLike, deployment FederationDeployment, remoteServer spec.ServerName, roomID string, userID string, partialState ...bool) *ServerRoom {
func (s *Server) MustJoinRoom(t ct.TestLike, deployment FederationDeployment, remoteServer spec.ServerName, roomID string, userID string, opts ...JoinRoomOpt) *ServerRoom {
t.Helper()
var jr joinRoom
for _, opt := range opts {
opt(&jr)
}
origin := spec.ServerName(s.serverName)
fedClient := s.FederationClient(deployment)
makeJoinResp, err := fedClient.MakeJoin(context.Background(), origin, remoteServer, roomID, userID)
Expand Down Expand Up @@ -372,7 +379,7 @@ func (s *Server) MustJoinRoom(t ct.TestLike, deployment FederationDeployment, re
ct.Fatalf(t, "MustJoinRoom: failed to sign event: %v", err)
}
var sendJoinResp fclient.RespSendJoin
if len(partialState) == 0 || !partialState[0] {
if !jr.partialState {
// Default to doing a regular join.
sendJoinResp, err = fedClient.SendJoin(context.Background(), origOrigin, remoteServer, joinEvent)
} else {
Expand All @@ -382,10 +389,13 @@ func (s *Server) MustJoinRoom(t ct.TestLike, deployment FederationDeployment, re
ct.Fatalf(t, "MustJoinRoom: send_join failed: %v", err)
}
room := NewServerRoom(roomVer, roomID)
room.PopulateFromSendJoinResponse(joinEvent, sendJoinResp)
s.rooms[roomID] = room
for _, opt := range jr.roomOpts {
opt(room)
}
room.PopulateFromSendJoinResponse(room, joinEvent, sendJoinResp)
s.rooms[room.RoomID] = room

t.Logf("Server.MustJoinRoom joined room ID %s", roomID)
t.Logf("Server.MustJoinRoom joined room ID %s", room.RoomID)

return room
}
Expand Down Expand Up @@ -433,11 +443,6 @@ func (s *Server) MustLeaveRoom(t ct.TestLike, deployment FederationDeployment, r
t.Logf("Server.MustLeaveRoom left room ID %s", roomID)
}

// AddRoom is a low-level function to add a custom room to the server. Useful to mix custom logic with helper functions.
func (s *Server) AddRoom(room *ServerRoom) {
s.rooms[room.RoomID] = room
}

// ValidFederationRequest is a wrapper around http.HandlerFunc which automatically validates the incoming
// federation request and supports sending back JSON. Fails the test if the request is not valid.
func (s *Server) ValidFederationRequest(t ct.TestLike, handler func(fr *fclient.FederationRequest, pathParams map[string]string) util.JSONResponse) http.HandlerFunc {
Expand Down Expand Up @@ -513,6 +518,28 @@ func (s *Server) Listen() (cancel func()) {
}
}

type joinRoom struct {
partialState bool
roomOpts []ServerRoomOpt
}

// JoinRoomOpt is an option for configuring how the server should join the room
type JoinRoomOpt func(jr *joinRoom)

// WithPartialState tells the server to join the room with partial state
func WithPartialState() JoinRoomOpt {
return func(jr *joinRoom) {
jr.partialState = true
}
}

// WithRoomOpts controls how the newly joined room is created
func WithRoomOpts(opts ...ServerRoomOpt) JoinRoomOpt {
return func(jr *joinRoom) {
jr.roomOpts = opts
}
}

// federationServer creates a federation server with the given handler
func federationServer(cfg *config.Complement, h http.Handler) (*http.Server, string, string, error) {
var derBytes []byte
Expand Down
100 changes: 58 additions & 42 deletions federation/server_room.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,24 @@ type Event struct {
Redacts string
}

// ServerRoomOpt are options that can configure ServerRooms
type ServerRoomOpt func(r *ServerRoom)

// WithRoomID configures the room to have the given room ID
func WithRoomID(roomID string) ServerRoomOpt {
return func(r *ServerRoom) {
r.RoomID = roomID
}
}

// WithImpl configures the room to have the given ServerRoomImpl.
// Useful for custom rooms.
func WithImpl(impl ServerRoomImpl) ServerRoomOpt {
return func(r *ServerRoom) {
r.ServerRoomImpl = impl
}
}

// EXPERIMENTAL
// ServerRoom represents a room on this test federation server
type ServerRoom struct {
Expand Down Expand Up @@ -67,7 +85,7 @@ func NewServerRoom(roomVer gomatrixserverlib.RoomVersion, roomId string) *Server
waiters: make(map[string][]*helpers.Waiter),
waitersMu: &sync.Mutex{},
}
room.ServerRoomImpl = &ServerRoomImplDefault{Room: room}
room.ServerRoomImpl = &ServerRoomImplDefault{}
return room
}

Expand Down Expand Up @@ -354,58 +372,56 @@ type ServerRoomImpl interface {
// ProtoEventCreator converts a Complement Event into a gomatrixserverlib proto event, ready to be signed.
// This function is used in /make_x endpoints to create proto events to return to other servers.
// This function is one of two used when creating events, the other being EventCreator.
ProtoEventCreator(ev Event) (*gomatrixserverlib.ProtoEvent, error)
ProtoEventCreator(room *ServerRoom, ev Event) (*gomatrixserverlib.ProtoEvent, error)
// EventCreator converts a proto event into a signed PDU.
EventCreator(s *Server, proto *gomatrixserverlib.ProtoEvent) (gomatrixserverlib.PDU, error)
EventCreator(room *ServerRoom, s *Server, proto *gomatrixserverlib.ProtoEvent) (gomatrixserverlib.PDU, error)
// PopulateFromSendJoinResponse should replace the state of this ServerRoom with the information contained
// in RespSendJoin and the join event.
PopulateFromSendJoinResponse(joinEvent gomatrixserverlib.PDU, resp fclient.RespSendJoin)
PopulateFromSendJoinResponse(room *ServerRoom, joinEvent gomatrixserverlib.PDU, resp fclient.RespSendJoin)
// GenerateSendJoinResponse generates a /send_join response to send back to a server.
GenerateSendJoinResponse(s *Server, joinEvent gomatrixserverlib.PDU, expectPartialState, omitServersInRoom bool) fclient.RespSendJoin
GenerateSendJoinResponse(room *ServerRoom, s *Server, joinEvent gomatrixserverlib.PDU, expectPartialState, omitServersInRoom bool) fclient.RespSendJoin
}

type ServerRoomImplCustom struct {
ServerRoomImplDefault
ProtoEventCreatorFn func(def ServerRoomImpl, ev Event) (*gomatrixserverlib.ProtoEvent, error)
EventCreatorFn func(def ServerRoomImpl, s *Server, proto *gomatrixserverlib.ProtoEvent) (gomatrixserverlib.PDU, error)
PopulateFromSendJoinResponseFn func(def ServerRoomImpl, joinEvent gomatrixserverlib.PDU, resp fclient.RespSendJoin)
GenerateSendJoinResponseFn func(def ServerRoomImpl, s *Server, joinEvent gomatrixserverlib.PDU, expectPartialState, omitServersInRoom bool) fclient.RespSendJoin
ProtoEventCreatorFn func(def ServerRoomImpl, room *ServerRoom, ev Event) (*gomatrixserverlib.ProtoEvent, error)
EventCreatorFn func(def ServerRoomImpl, room *ServerRoom, s *Server, proto *gomatrixserverlib.ProtoEvent) (gomatrixserverlib.PDU, error)
PopulateFromSendJoinResponseFn func(def ServerRoomImpl, room *ServerRoom, joinEvent gomatrixserverlib.PDU, resp fclient.RespSendJoin)
GenerateSendJoinResponseFn func(def ServerRoomImpl, room *ServerRoom, s *Server, joinEvent gomatrixserverlib.PDU, expectPartialState, omitServersInRoom bool) fclient.RespSendJoin
}

func (i *ServerRoomImplCustom) ProtoEventCreator(ev Event) (*gomatrixserverlib.ProtoEvent, error) {
func (i *ServerRoomImplCustom) ProtoEventCreator(room *ServerRoom, ev Event) (*gomatrixserverlib.ProtoEvent, error) {
if i.ProtoEventCreatorFn != nil {
return i.ProtoEventCreatorFn(&i.ServerRoomImplDefault, ev)
return i.ProtoEventCreatorFn(&i.ServerRoomImplDefault, room, ev)
}
return i.ServerRoomImplDefault.ProtoEventCreator(ev)
return i.ServerRoomImplDefault.ProtoEventCreator(room, ev)
}

func (i *ServerRoomImplCustom) EventCreator(s *Server, proto *gomatrixserverlib.ProtoEvent) (gomatrixserverlib.PDU, error) {
func (i *ServerRoomImplCustom) EventCreator(room *ServerRoom, s *Server, proto *gomatrixserverlib.ProtoEvent) (gomatrixserverlib.PDU, error) {
if i.EventCreatorFn != nil {
return i.EventCreatorFn(&i.ServerRoomImplDefault, s, proto)
return i.EventCreatorFn(&i.ServerRoomImplDefault, room, s, proto)
}
return i.ServerRoomImplDefault.EventCreator(s, proto)
return i.ServerRoomImplDefault.EventCreator(room, s, proto)
}

func (i *ServerRoomImplCustom) PopulateFromSendJoinResponse(joinEvent gomatrixserverlib.PDU, resp fclient.RespSendJoin) {
func (i *ServerRoomImplCustom) PopulateFromSendJoinResponse(room *ServerRoom, joinEvent gomatrixserverlib.PDU, resp fclient.RespSendJoin) {
if i.PopulateFromSendJoinResponseFn != nil {
i.PopulateFromSendJoinResponseFn(&i.ServerRoomImplDefault, joinEvent, resp)
i.PopulateFromSendJoinResponseFn(&i.ServerRoomImplDefault, room, joinEvent, resp)
return
}
i.ServerRoomImplDefault.PopulateFromSendJoinResponse(joinEvent, resp)
i.ServerRoomImplDefault.PopulateFromSendJoinResponse(room, joinEvent, resp)
}

func (i *ServerRoomImplCustom) GenerateSendJoinResponse(s *Server, joinEvent gomatrixserverlib.PDU, expectPartialState, omitServersInRoom bool) fclient.RespSendJoin {
func (i *ServerRoomImplCustom) GenerateSendJoinResponse(room *ServerRoom, s *Server, joinEvent gomatrixserverlib.PDU, expectPartialState, omitServersInRoom bool) fclient.RespSendJoin {
if i.GenerateSendJoinResponseFn != nil {
return i.GenerateSendJoinResponseFn(&i.ServerRoomImplDefault, s, joinEvent, expectPartialState, omitServersInRoom)
return i.GenerateSendJoinResponseFn(&i.ServerRoomImplDefault, room, s, joinEvent, expectPartialState, omitServersInRoom)
}
return i.ServerRoomImplDefault.GenerateSendJoinResponse(s, joinEvent, expectPartialState, omitServersInRoom)
return i.ServerRoomImplDefault.GenerateSendJoinResponse(room, s, joinEvent, expectPartialState, omitServersInRoom)
}

type ServerRoomImplDefault struct {
Room *ServerRoom
}
type ServerRoomImplDefault struct{}

func (i *ServerRoomImplDefault) ProtoEventCreator(ev Event) (*gomatrixserverlib.ProtoEvent, error) {
func (i *ServerRoomImplDefault) ProtoEventCreator(room *ServerRoom, ev Event) (*gomatrixserverlib.ProtoEvent, error) {
var prevEvents interface{}
if ev.PrevEvents != nil {
// We deliberately want to set the prev events.
Expand All @@ -414,14 +430,14 @@ func (i *ServerRoomImplDefault) ProtoEventCreator(ev Event) (*gomatrixserverlib.
// No other prev events were supplied so we'll just
// use the forward extremities of the room, which is
// the usual behaviour.
prevEvents = i.Room.ForwardExtremities
prevEvents = room.ForwardExtremities
}
proto := gomatrixserverlib.ProtoEvent{
SenderID: ev.Sender,
Depth: int64(i.Room.Depth + 1), // depth starts at 1
Depth: int64(room.Depth + 1), // depth starts at 1
Type: ev.Type,
StateKey: ev.StateKey,
RoomID: i.Room.RoomID,
RoomID: room.RoomID,
PrevEvents: prevEvents,
AuthEvents: ev.AuthEvents,
Redacts: ev.Redacts,
Expand All @@ -438,13 +454,13 @@ func (i *ServerRoomImplDefault) ProtoEventCreator(ev Event) (*gomatrixserverlib.
if err != nil {
return nil, fmt.Errorf("EventCreator: failed to work out auth_events : %s", err)
}
proto.AuthEvents = i.Room.AuthEvents(stateNeeded)
proto.AuthEvents = room.AuthEvents(stateNeeded)
}
return &proto, nil
}

func (i *ServerRoomImplDefault) EventCreator(s *Server, proto *gomatrixserverlib.ProtoEvent) (gomatrixserverlib.PDU, error) {
verImpl, err := gomatrixserverlib.GetRoomVersion(i.Room.Version)
func (i *ServerRoomImplDefault) EventCreator(room *ServerRoom, s *Server, proto *gomatrixserverlib.ProtoEvent) (gomatrixserverlib.PDU, error) {
verImpl, err := gomatrixserverlib.GetRoomVersion(room.Version)
if err != nil {
return nil, fmt.Errorf("EventCreator: invalid room version: %s", err)
}
Expand All @@ -456,19 +472,19 @@ func (i *ServerRoomImplDefault) EventCreator(s *Server, proto *gomatrixserverlib
return signedEvent, nil
}

func (i *ServerRoomImplDefault) PopulateFromSendJoinResponse(joinEvent gomatrixserverlib.PDU, resp fclient.RespSendJoin) {
stateEvents := resp.StateEvents.UntrustedEvents(i.Room.Version)
func (i *ServerRoomImplDefault) PopulateFromSendJoinResponse(room *ServerRoom, joinEvent gomatrixserverlib.PDU, resp fclient.RespSendJoin) {
stateEvents := resp.StateEvents.UntrustedEvents(room.Version)
for _, ev := range stateEvents {
i.Room.ReplaceCurrentState(ev)
room.ReplaceCurrentState(ev)
}
i.Room.AddEvent(joinEvent)
room.AddEvent(joinEvent)
}

func (i *ServerRoomImplDefault) GenerateSendJoinResponse(s *Server, joinEvent gomatrixserverlib.PDU, expectPartialState, omitServersInRoom bool) fclient.RespSendJoin {
func (i *ServerRoomImplDefault) GenerateSendJoinResponse(room *ServerRoom, s *Server, joinEvent gomatrixserverlib.PDU, expectPartialState, omitServersInRoom bool) fclient.RespSendJoin {
// build the state list *before* we insert the new event
var stateEvents []gomatrixserverlib.PDU
i.Room.StateMutex.RLock()
for _, ev := range i.Room.State {
room.StateMutex.RLock()
for _, ev := range room.State {
// filter out non-critical memberships if this is a partial-state join
if expectPartialState {
if ev.Type() == "m.room.member" && ev.StateKey() != joinEvent.StateKey() {
Expand All @@ -477,18 +493,18 @@ func (i *ServerRoomImplDefault) GenerateSendJoinResponse(s *Server, joinEvent go
}
stateEvents = append(stateEvents, ev)
}
i.Room.StateMutex.RUnlock()
room.StateMutex.RUnlock()

authEvents := i.Room.AuthChainForEvents(stateEvents)
authEvents := room.AuthChainForEvents(stateEvents)

// get servers in room *before* the join event
serversInRoom := []string{s.serverName}
if !omitServersInRoom {
serversInRoom = i.Room.ServersInRoom()
serversInRoom = room.ServersInRoom()
}

// insert the join event into the room state
i.Room.AddEvent(joinEvent)
room.AddEvent(joinEvent)
log.Printf("Received send-join of event %s", joinEvent.EventID())

// return state and auth chain
Expand Down
Loading