Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions broker/broker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ func Setup(t *testing.T, ctx context.Context, egrp *errgroup.Group) {
})
require.NoError(t, err)

namespaceKeys = nil
LaunchNamespaceKeyMaintenance(ctx, egrp)
}

Expand Down Expand Up @@ -193,7 +194,11 @@ func TestBroker(t *testing.T) {
viper.Set("Federation.RegistryUrl", param.Server_ExternalWebUrl.GetString())
listenerChan := make(chan any)
ctxQuick, deadlineCancel := context.WithTimeout(ctx, 5*time.Second) // Have shorter timeout for this handshake
err = LaunchRequestMonitor(ctxQuick, egrp, listenerChan)

externalWebUrl, err := url.Parse(param.Server_ExternalWebUrl.GetString())
require.NoError(t, err)

err = LaunchRequestMonitor(ctxQuick, egrp, server_structs.CacheType, externalWebUrl.Hostname(), listenerChan)
require.NoError(t, err)

// Initiate the callback using the cache-based routines.
Expand All @@ -203,9 +208,9 @@ func TestBroker(t *testing.T) {
brokerUrl.Path = "/api/v1.0/broker/reverse"
query := brokerUrl.Query()
query.Set("origin", param.Server_Hostname.GetString())
query.Set("prefix", "/foo")
query.Set("prefix", "/caches/"+externalWebUrl.Hostname())
brokerUrl.RawQuery = query.Encode()
clientConn, err := ConnectToOrigin(ctxQuick, brokerUrl.String(), "/foo", param.Server_Hostname.GetString())
clientConn, err := ConnectToService(ctxQuick, brokerUrl.String(), "/caches/"+externalWebUrl.Hostname(), param.Server_Hostname.GetString())
require.NoError(t, err)
log.Debugf("Cache got reversed client connection with cache side %s and origin side %s", clientConn.LocalAddr(), clientConn.RemoteAddr())

Expand Down
86 changes: 60 additions & 26 deletions broker/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ func generateRequestId() string {
}

// Given an origin's broker URL, return a connected socket to the origin
func ConnectToOrigin(ctx context.Context, brokerUrl, prefix, originName string) (conn net.Conn, err error) {
func ConnectToService(ctx context.Context, brokerUrl, prefix, originName string) (conn net.Conn, err error) {

// Ensure we have a local CA for signing an origin host certificate.
if err = config.GenerateCACert(); err != nil {
Expand Down Expand Up @@ -226,7 +226,7 @@ func ConnectToOrigin(ctx context.Context, brokerUrl, prefix, originName string)
// Create a cloned transport which disables HTTP/2 (as that TCP string can't
// be hijacked which we will need to do below). The clone ensures that we're
// not going to be reusing TCP connections.
tr := config.GetTransport().Clone()
tr := config.GetBasicTransport().Clone()
tr.TLSNextProto = make(map[string]func(string, *tls.Conn) http.RoundTripper)
client := &http.Client{Transport: tr}

Expand Down Expand Up @@ -385,7 +385,7 @@ func ConnectToOrigin(ctx context.Context, brokerUrl, prefix, originName string)
//
// The TCP socket used for the callback will be converted to a one-shot listener
// and reused with the origin as the "server".
func doCallback(ctx context.Context, brokerResp reversalRequest) (listener net.Listener, err error) {
func doCallback(ctx context.Context, sType server_structs.ServerType, brokerResp reversalRequest) (listener net.Listener, err error) {
log.Debugln("Origin starting callback to cache at", brokerResp.CallbackUrl)

privateKey, err := privateKeyFromBytes(brokerResp.PrivateKey)
Expand Down Expand Up @@ -415,7 +415,18 @@ func doCallback(ctx context.Context, brokerResp reversalRequest) (listener net.L
}
cacheAud.Path = ""

token, err := createToken(param.Origin_FederationPrefix.GetString(), param.Server_Hostname.GetString(), cacheAud.String(), token_scopes.Broker_Callback)
servicePrefix := ""
url, err := url.Parse(param.Server_ExternalWebUrl.GetString())
if err != nil {
err = errors.Wrap(err, "failure when parsing the external web URL")
return
}
if sType.IsEnabled(server_structs.CacheType) {
servicePrefix = server_structs.GetCacheNs(url.Hostname())
} else {
servicePrefix = server_structs.GetOriginNs(url.Host)
}
token, err := createToken(servicePrefix, url.Hostname(), cacheAud.String(), token_scopes.Broker_Callback)
if err != nil {
err = errors.Wrap(err, "failure when constructing the cache callback token")
return
Expand Down Expand Up @@ -550,32 +561,59 @@ func doCallback(ctx context.Context, brokerResp reversalRequest) (listener net.L
// TLS listener where you can invoke "Accept" once before it automatically
// closes itself. It is the result of a successful connection reversal to
// a cache.
func LaunchRequestMonitor(ctx context.Context, egrp *errgroup.Group, resultChan chan any) (err error) {
//
// The request monitor is used by the "private service" (the service behind the
// firewall) to know when to setup connections requested by the "public service"
// (e.g., a cache).
func LaunchRequestMonitor(ctx context.Context, egrp *errgroup.Group, sType server_structs.ServerType, privateName string, resultChan chan any) (err error) {
fedInfo, err := config.GetFederation(ctx)
if err != nil {
return err
}

prefix := ""
if sType.IsEnabled(server_structs.CacheType) {
prefix = server_structs.GetCacheNs(privateName)
} else {
prefix = server_structs.GetOriginNs(privateName)
}

brokerUrl := fedInfo.BrokerEndpoint
if brokerUrl == "" {
return errors.New("Broker service is not set or discovered; cannot enable broker functionality. Try setting Federation.BrokerUrl")
}
brokerEndpoint := brokerUrl + "/api/v1.0/broker/retrieve"
originUrl, err := url.Parse(param.Server_ExternalWebUrl.GetString())
if err != nil {
return
}
oReq := originRequest{
Origin: originUrl.Hostname(),
Prefix: param.Origin_FederationPrefix.GetString(),
Origin: privateName,
Prefix: prefix,
}
req, err := json.Marshal(&oReq)
if err != nil {
return
}
reqReader := bytes.NewReader(req)

// Create a token that will be used to retrieve requests from the broker;
// this is done before the goroutine starts to guarantee that we are looking up
// the Viper config from a single-threaded context. Otherwise, during startup,
// we may have concurrent read and write operations to the Viper config, which
// can lead to a panic.
brokerAud, err := url.Parse(fedInfo.BrokerEndpoint)
if err != nil {
log.Errorln("Failure when parsing broker URL:", err)
return
}
brokerAud.Path = ""
token, err := createToken(prefix, param.Server_Hostname.GetString(), brokerAud.String(), token_scopes.Broker_Retrieve)
if err != nil {
log.Errorln("Failure when constructing the broker retrieve token:", err)
return
}

client := &http.Client{Transport: config.GetBasicTransport()}

egrp.Go(func() (err error) {
firstLoop := true
for {
sleepDuration := time.Second + time.Duration(mrand.Intn(500))*time.Microsecond
select {
Expand All @@ -595,25 +633,21 @@ func LaunchRequestMonitor(ctx context.Context, egrp *errgroup.Group, resultChan
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", "pelican-origin/"+config.GetVersion())

brokerAud, err := url.Parse(fedInfo.BrokerEndpoint)
if err != nil {
log.Errorln("Failure when parsing broker URL:", err)
break
}
brokerAud.Path = ""

token, err := createToken(param.Origin_FederationPrefix.GetString(), param.Server_Hostname.GetString(), brokerAud.String(), token_scopes.Broker_Retrieve)
if err != nil {
log.Errorln("Failure when constructing the broker retrieve token:", err)
break
if !firstLoop {
token, err = createToken(prefix, param.Server_Hostname.GetString(), brokerAud.String(), token_scopes.Broker_Retrieve)
if err != nil {
log.Errorln("Failure when constructing the broker retrieve token:", err)
break
}
}
firstLoop = false
req.Header.Set("Authorization", "Bearer "+token)

tr := config.GetTransport()
client := &http.Client{Transport: tr}

resp, err := client.Do(req)
if err != nil {
if errors.Is(err, context.Canceled) {
break
}
log.Errorln("Failure when invoking the broker URL for retrieving requests", err)
break
}
Expand Down Expand Up @@ -642,7 +676,7 @@ func LaunchRequestMonitor(ctx context.Context, egrp *errgroup.Group, resultChan
}

if brokerResp.Status == server_structs.RespOK {
listener, err := doCallback(ctx, brokerResp.Request)
listener, err := doCallback(ctx, sType, brokerResp.Request)
if err != nil {
log.Errorln("Failed to callback to the cache:", err)
resultChan <- err
Expand Down
105 changes: 105 additions & 0 deletions broker/dialer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
/***************************************************************
*
* Copyright (C) 2025, Pelican Project, Morgridge Institute for Research
*
* Licensed under the Apache License, Version 2.0 (the "License"); you
* may not use this file except in compliance with the License. You may
* obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
***************************************************************/

// This file contains methods for a dialer that can use the broker
// functionality to connect to a remote service.

package broker

import (
"context"
"net"
"strings"

"github.com/jellydator/ttlcache/v3"
"golang.org/x/sync/errgroup"

"github.com/pelicanplatform/pelican/param"
"github.com/pelicanplatform/pelican/server_structs"
)

// brokerPrefixInfo is the per-service record kept in the dialer's
// endpoint cache.
type brokerPrefixInfo struct {
	// ServerType is the kind of service registered for the name; DialContext
	// uses it to choose the prefix sent along with the broker request.
	ServerType server_structs.ServerType
	// BrokerUrl is the broker endpoint handed to ConnectToService.
	BrokerUrl string
}

// BrokerDialer is a dialer that can use the broker
// functionality to connect to a remote service.
type BrokerDialer struct {
	// dialerContext is the plain dial function used whenever no broker
	// endpoint is registered for the requested address.
	dialerContext func(ctx context.Context, network, addr string) (net.Conn, error)
	// brokerEndpoints maps a service name to its broker endpoint.
	// If the service name is not found in the cache, then the dialer
	// will use a normal TCP connection to the service.
	brokerEndpoints *ttlcache.Cache[string, brokerPrefixInfo]
}

// NewBrokerDialer creates a new BrokerDialer.
func NewBrokerDialer(ctx context.Context, egrp *errgroup.Group) *BrokerDialer {

dialer := &net.Dialer{
Timeout: param.Transport_DialerTimeout.GetDuration(),
KeepAlive: param.Transport_DialerKeepAlive.GetDuration(),
}
brokerEndpoints := ttlcache.New(
ttlcache.WithTTL[string, brokerPrefixInfo](param.Transport_BrokerEndpointCacheTTL.GetDuration()),
ttlcache.WithDisableTouchOnHit[string, brokerPrefixInfo](),
)

go brokerEndpoints.Start()
Comment thread
bbockelm marked this conversation as resolved.
egrp.Go(func() error {
<-ctx.Done()
brokerEndpoints.DeleteAll()
brokerEndpoints.Stop()
return nil
})

return &BrokerDialer{
dialerContext: dialer.DialContext,
brokerEndpoints: brokerEndpoints,
}
}

// UseBroker registers `brokerUrl` as the broker endpoint for the service
// `name`, recording `serverType` alongside it.
//
// The entry is stored with the cache's default TTL. DialContext looks
// entries up by the dial address, so `name` must be the address that will
// later be passed to DialContext.
func (d *BrokerDialer) UseBroker(serverType server_structs.ServerType, name, brokerUrl string) {
	entry := brokerPrefixInfo{
		ServerType: serverType,
		BrokerUrl:  brokerUrl,
	}
	d.brokerEndpoints.Set(name, entry, ttlcache.DefaultTTL)
}

// DialContext dials a connection to the given network and address, going
// through the broker when one has been registered for addr via UseBroker.
//
// If addr has no registered broker endpoint, the call is passed straight to
// the default dialer. Otherwise a broker prefix is derived from the
// registered server type:
//   - cache:  "/caches/" followed by the host portion of addr (port removed)
//   - other:  "/origins/" followed by addr unchanged
//
// and the connection is obtained from ConnectToService.
func (d *BrokerDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
	item := d.brokerEndpoints.Get(addr)
	if item == nil {
		// If the endpoint is not found in the cache, use the default dialer.
		return d.dialerContext(ctx, network, addr)
	}
	info := item.Value()

	prefix := "/origins/" + addr
	if info.ServerType.IsEnabled(server_structs.CacheType) {
		// Use net.SplitHostPort so that bracketed IPv6 literals such as
		// "[::1]:8443" yield "::1"; splitting on the first ':' would give "[".
		host, _, err := net.SplitHostPort(addr)
		if err != nil {
			// addr carries no parseable port; keep the previous behavior of
			// taking everything before the first ':' (or all of addr).
			host, _, _ = strings.Cut(addr, ":")
		}
		prefix = "/caches/" + host
	}
	return ConnectToService(ctx, info.BrokerUrl, prefix, addr)
}
8 changes: 4 additions & 4 deletions broker/request_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ type (
PrivateKey string `json:"private_key,omitempty"`
RequestId string `json:"request_id,omitempty"`
Prefix string `json:"prefix,omitempty"`
OriginName string `json:"origin,omitempty"`
OriginName string `json:"origin,omitempty"` // Name of the service for the reversal request. Originally, brokers were for origins-only (hence the inexact name of the parameter).
}

requestInfo struct {
Expand Down Expand Up @@ -68,7 +68,7 @@ func getOriginQueue(prefix, origin string) chan reversalRequest {
// Send a request to a given origin's queue.
// Return a requestTimeout error if no origin retrieved the request before the context timed out.
func handleRequest(ctx context.Context, origin string, req reversalRequest, timeout time.Duration) (err error) {
queue := getOriginQueue(req.Prefix, origin)
queue := getOriginQueue("/", origin)
maxTime := timeout - 500*time.Millisecond - time.Duration(rand.Intn(500))*time.Millisecond
if maxTime <= 0 {
maxTime = time.Millisecond
Expand All @@ -90,7 +90,7 @@ func handleRequest(ctx context.Context, origin string, req reversalRequest, time
}

// Handle the origin's request to retrieve any pending reversals.
func handleRetrieve(appCtx context.Context, ginCtx context.Context, prefix, origin string, timeout time.Duration) (req reversalRequest, err error) {
func handleRetrieve(appCtx context.Context, ginCtx context.Context, origin string, timeout time.Duration) (req reversalRequest, err error) {
// Return randomly short of the timeout.
maxTime := timeout - 500*time.Millisecond - time.Duration(rand.Intn(500))*time.Millisecond
if maxTime <= 0 {
Expand All @@ -99,7 +99,7 @@ func handleRetrieve(appCtx context.Context, ginCtx context.Context, prefix, orig
tick := time.NewTicker(maxTime)
defer tick.Stop()
select {
case req = <-getOriginQueue(prefix, origin):
case req = <-getOriginQueue("/", origin):
break
case <-tick.C:
err = errRetrieveTimeout
Expand Down
25 changes: 23 additions & 2 deletions broker/server_apis.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,11 @@ func newBrokerRespTimeout() (result brokerRetrievalResp) {
return
}

// Retrieve any pending reversal requests from the connection broker.
//
// This is long-polled by the service relying on the connection broker
// (e.g., an origin); it will return any reversal requests from a public
// service (e.g., a cache) for the origin to make a connection.
func retrieveRequest(ctx context.Context, ginCtx *gin.Context) {
timeoutStr := "5s"
if val := ginCtx.Request.Header.Get("X-Pelican-Timeout"); val != "" {
Expand Down Expand Up @@ -108,7 +113,7 @@ func retrieveRequest(ctx context.Context, ginCtx *gin.Context) {
ginCtx.AbortWithStatusJSON(http.StatusUnauthorized, newBrokerRespFail("Authorization denied"))
}

req, err := handleRetrieve(ctx, ginCtx, originReq.Prefix, originReq.Origin, timeoutVal)
req, err := handleRetrieve(ctx, ginCtx, originReq.Origin, timeoutVal)
if errors.Is(err, errRetrieveTimeout) {
ginCtx.JSON(http.StatusOK, newBrokerRespTimeout())
return
Expand All @@ -120,6 +125,10 @@ func retrieveRequest(ctx context.Context, ginCtx *gin.Context) {
ginCtx.JSON(http.StatusOK, newBrokerReqResp(req))
}

// Service a request to the broker to initiate a connection.
//
// The connection reversal request will cause a listening service (e.g., an origin)
// to connect to the endpoint provided by the public service (e.g., a cache).
func reverseRequest(ctx context.Context, ginCtx *gin.Context) {
timeoutStr := "5s"
if val := ginCtx.Request.Header.Get("X-Pelican-Timeout"); val != "" {
Expand Down Expand Up @@ -178,13 +187,25 @@ func reverseRequest(ctx context.Context, ginCtx *gin.Context) {
}
}

// Register the central broker functionality with the gin router.
//
// Typically, this is done by the director; two APIs are exposed:
// - `retrieve`: Services needing connection brokering (e.g., origins behind a firewall)
// will long-poll this endpoint to retrieve any connection brokering requests from
// a public service (e.g., a cache).
// - `reverse`: Invoked by a public service (e.g., a cache) that would like to connect
// to a service behind a firewall (e.g., an origin). Official request for the origin
// to make a connection.
func RegisterBroker(ctx context.Context, router *gin.RouterGroup) {
// Establish the routes used for cache/origin redirection
router.POST("/api/v1.0/broker/retrieve", func(ginCtx *gin.Context) { retrieveRequest(ctx, ginCtx) })
router.POST("/api/v1.0/broker/reverse", func(ginCtx *gin.Context) { reverseRequest(ctx, ginCtx) })
}

// Cache's HTTP handler function for callbacks from an origin
// Server's HTTP handler function for callbacks from a remote service behind a broker.
//
// The server will authorize the request and then hand it off to the goroutine waiting for the
// connection reversal. It will return once the other routine has completed the connection reversal.
func handleCallback(ctx context.Context, ginCtx *gin.Context) {
callbackReq := callbackRequest{}
if err := ginCtx.Bind(&callbackReq); err != nil {
Expand Down
Loading
Loading