Skip to content

Commit 79f9ef0

Browse files
committed
Refactor ACME support to certificate provider
1 parent 3de5fdb commit 79f9ef0

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

48 files changed

+3083
-172
lines changed

adapter/certificate/adapter.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
package certificate
2+
3+
type Adapter struct {
4+
providerType string
5+
providerTag string
6+
}
7+
8+
func NewAdapter(providerType string, providerTag string) Adapter {
9+
return Adapter{
10+
providerType: providerType,
11+
providerTag: providerTag,
12+
}
13+
}
14+
15+
func (a *Adapter) Type() string {
16+
return a.providerType
17+
}
18+
19+
func (a *Adapter) Tag() string {
20+
return a.providerTag
21+
}

adapter/certificate/manager.go

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
package certificate
2+
3+
import (
4+
"context"
5+
"os"
6+
"sync"
7+
"time"
8+
9+
"github.com/sagernet/sing-box/adapter"
10+
"github.com/sagernet/sing-box/common/taskmonitor"
11+
C "github.com/sagernet/sing-box/constant"
12+
"github.com/sagernet/sing-box/log"
13+
"github.com/sagernet/sing/common"
14+
E "github.com/sagernet/sing/common/exceptions"
15+
F "github.com/sagernet/sing/common/format"
16+
)
17+
18+
var _ adapter.CertificateProviderManager = (*Manager)(nil)
19+
20+
type Manager struct {
21+
logger log.ContextLogger
22+
registry adapter.CertificateProviderRegistry
23+
access sync.Mutex
24+
started bool
25+
stage adapter.StartStage
26+
providers []adapter.CertificateProviderService
27+
providerByTag map[string]adapter.CertificateProviderService
28+
}
29+
30+
func NewManager(logger log.ContextLogger, registry adapter.CertificateProviderRegistry) *Manager {
31+
return &Manager{
32+
logger: logger,
33+
registry: registry,
34+
providerByTag: make(map[string]adapter.CertificateProviderService),
35+
}
36+
}
37+
38+
func (m *Manager) Start(stage adapter.StartStage) error {
39+
m.access.Lock()
40+
if m.started && m.stage >= stage {
41+
panic("already started")
42+
}
43+
m.started = true
44+
m.stage = stage
45+
providers := m.providers
46+
m.access.Unlock()
47+
for _, provider := range providers {
48+
name := "certificate-provider/" + provider.Type() + "[" + provider.Tag() + "]"
49+
m.logger.Trace(stage, " ", name)
50+
startTime := time.Now()
51+
err := adapter.LegacyStart(provider, stage)
52+
if err != nil {
53+
return E.Cause(err, stage, " ", name)
54+
}
55+
m.logger.Trace(stage, " ", name, " completed (", F.Seconds(time.Since(startTime).Seconds()), "s)")
56+
}
57+
return nil
58+
}
59+
60+
func (m *Manager) Close() error {
61+
m.access.Lock()
62+
defer m.access.Unlock()
63+
if !m.started {
64+
return nil
65+
}
66+
m.started = false
67+
providers := m.providers
68+
m.providers = nil
69+
monitor := taskmonitor.New(m.logger, C.StopTimeout)
70+
var err error
71+
for _, provider := range providers {
72+
name := "certificate-provider/" + provider.Type() + "[" + provider.Tag() + "]"
73+
m.logger.Trace("close ", name)
74+
startTime := time.Now()
75+
monitor.Start("close ", name)
76+
err = E.Append(err, provider.Close(), func(err error) error {
77+
return E.Cause(err, "close ", name)
78+
})
79+
monitor.Finish()
80+
m.logger.Trace("close ", name, " completed (", F.Seconds(time.Since(startTime).Seconds()), "s)")
81+
}
82+
return err
83+
}
84+
85+
func (m *Manager) CertificateProviders() []adapter.CertificateProviderService {
86+
m.access.Lock()
87+
defer m.access.Unlock()
88+
return m.providers
89+
}
90+
91+
func (m *Manager) Get(tag string) (adapter.CertificateProviderService, bool) {
92+
m.access.Lock()
93+
provider, found := m.providerByTag[tag]
94+
m.access.Unlock()
95+
return provider, found
96+
}
97+
98+
func (m *Manager) Remove(tag string) error {
99+
m.access.Lock()
100+
provider, found := m.providerByTag[tag]
101+
if !found {
102+
m.access.Unlock()
103+
return os.ErrInvalid
104+
}
105+
delete(m.providerByTag, tag)
106+
index := common.Index(m.providers, func(it adapter.CertificateProviderService) bool {
107+
return it == provider
108+
})
109+
if index == -1 {
110+
panic("invalid certificate provider index")
111+
}
112+
m.providers = append(m.providers[:index], m.providers[index+1:]...)
113+
started := m.started
114+
m.access.Unlock()
115+
if started {
116+
return provider.Close()
117+
}
118+
return nil
119+
}
120+
121+
func (m *Manager) Create(ctx context.Context, logger log.ContextLogger, tag string, providerType string, options any) error {
122+
provider, err := m.registry.Create(ctx, logger, tag, providerType, options)
123+
if err != nil {
124+
return err
125+
}
126+
m.access.Lock()
127+
defer m.access.Unlock()
128+
if m.started {
129+
name := "certificate-provider/" + provider.Type() + "[" + provider.Tag() + "]"
130+
for _, stage := range adapter.ListStartStages {
131+
m.logger.Trace(stage, " ", name)
132+
startTime := time.Now()
133+
err = adapter.LegacyStart(provider, stage)
134+
if err != nil {
135+
return E.Cause(err, stage, " ", name)
136+
}
137+
m.logger.Trace(stage, " ", name, " completed (", F.Seconds(time.Since(startTime).Seconds()), "s)")
138+
}
139+
}
140+
if existsProvider, loaded := m.providerByTag[tag]; loaded {
141+
if m.started {
142+
err = existsProvider.Close()
143+
if err != nil {
144+
return E.Cause(err, "close certificate-provider/", existsProvider.Type(), "[", existsProvider.Tag(), "]")
145+
}
146+
}
147+
existsIndex := common.Index(m.providers, func(it adapter.CertificateProviderService) bool {
148+
return it == existsProvider
149+
})
150+
if existsIndex == -1 {
151+
panic("invalid certificate provider index")
152+
}
153+
m.providers = append(m.providers[:existsIndex], m.providers[existsIndex+1:]...)
154+
}
155+
m.providers = append(m.providers, provider)
156+
m.providerByTag[tag] = provider
157+
return nil
158+
}

adapter/certificate/registry.go

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
package certificate
2+
3+
import (
4+
"context"
5+
"sync"
6+
7+
"github.com/sagernet/sing-box/adapter"
8+
"github.com/sagernet/sing-box/log"
9+
"github.com/sagernet/sing/common"
10+
E "github.com/sagernet/sing/common/exceptions"
11+
)
12+
13+
type ConstructorFunc[T any] func(ctx context.Context, logger log.ContextLogger, tag string, options T) (adapter.CertificateProviderService, error)
14+
15+
func Register[Options any](registry *Registry, providerType string, constructor ConstructorFunc[Options]) {
16+
registry.register(providerType, func() any {
17+
return new(Options)
18+
}, func(ctx context.Context, logger log.ContextLogger, tag string, rawOptions any) (adapter.CertificateProviderService, error) {
19+
var options *Options
20+
if rawOptions != nil {
21+
options = rawOptions.(*Options)
22+
}
23+
return constructor(ctx, logger, tag, common.PtrValueOrDefault(options))
24+
})
25+
}
26+
27+
var _ adapter.CertificateProviderRegistry = (*Registry)(nil)
28+
29+
type (
30+
optionsConstructorFunc func() any
31+
constructorFunc func(ctx context.Context, logger log.ContextLogger, tag string, options any) (adapter.CertificateProviderService, error)
32+
)
33+
34+
type Registry struct {
35+
access sync.Mutex
36+
optionsType map[string]optionsConstructorFunc
37+
constructor map[string]constructorFunc
38+
}
39+
40+
func NewRegistry() *Registry {
41+
return &Registry{
42+
optionsType: make(map[string]optionsConstructorFunc),
43+
constructor: make(map[string]constructorFunc),
44+
}
45+
}
46+
47+
func (m *Registry) CreateOptions(providerType string) (any, bool) {
48+
m.access.Lock()
49+
defer m.access.Unlock()
50+
optionsConstructor, loaded := m.optionsType[providerType]
51+
if !loaded {
52+
return nil, false
53+
}
54+
return optionsConstructor(), true
55+
}
56+
57+
func (m *Registry) Create(ctx context.Context, logger log.ContextLogger, tag string, providerType string, options any) (adapter.CertificateProviderService, error) {
58+
m.access.Lock()
59+
defer m.access.Unlock()
60+
constructor, loaded := m.constructor[providerType]
61+
if !loaded {
62+
return nil, E.New("certificate provider type not found: " + providerType)
63+
}
64+
return constructor(ctx, logger, tag, options)
65+
}
66+
67+
func (m *Registry) register(providerType string, optionsConstructor optionsConstructorFunc, constructor constructorFunc) {
68+
m.access.Lock()
69+
defer m.access.Unlock()
70+
m.optionsType[providerType] = optionsConstructor
71+
m.constructor[providerType] = constructor
72+
}

adapter/certificate_provider.go

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
package adapter
2+
3+
import (
4+
"context"
5+
"crypto/tls"
6+
7+
"github.com/sagernet/sing-box/log"
8+
"github.com/sagernet/sing-box/option"
9+
)
10+
11+
type CertificateProvider interface {
12+
GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error)
13+
}
14+
15+
type ACMECertificateProvider interface {
16+
CertificateProvider
17+
GetACMENextProtos() []string
18+
}
19+
20+
type CertificateProviderService interface {
21+
Lifecycle
22+
Type() string
23+
Tag() string
24+
CertificateProvider
25+
}
26+
27+
type CertificateProviderRegistry interface {
28+
option.CertificateProviderOptionsRegistry
29+
Create(ctx context.Context, logger log.ContextLogger, tag string, providerType string, options any) (CertificateProviderService, error)
30+
}
31+
32+
type CertificateProviderManager interface {
33+
Lifecycle
34+
CertificateProviders() []CertificateProviderService
35+
Get(tag string) (CertificateProviderService, bool)
36+
Remove(tag string) error
37+
Create(ctx context.Context, logger log.ContextLogger, tag string, providerType string, options any) error
38+
}

0 commit comments

Comments
 (0)