Skip to content
This repository was archived by the owner on Apr 25, 2025. It is now read-only.

Commit ab35fb8

Browse files
committed
[FAB-9601] Move cert pool wrapper into its own package
Change-Id: I8e549dc957454bb15692d9285d3949c0f1b8c815 Signed-off-by: Divyank Katira <Divyank.Katira@securekey.com>
1 parent 830bdea commit ab35fb8

File tree

4 files changed

+265
-204
lines changed

4 files changed

+265
-204
lines changed
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
/*
2+
Copyright SecureKey Technologies Inc. All Rights Reserved.
3+
4+
SPDX-License-Identifier: Apache-2.0
5+
*/
6+
7+
package tls
8+
9+
import (
10+
"crypto/x509"
11+
"sync"
12+
13+
"github.com/hyperledger/fabric-sdk-go/pkg/common/logging"
14+
)
15+
16+
var logger = logging.NewLogger("fabsdk/core")
17+
18+
// CertPool is a thread safe wrapper around the x509 standard library
19+
// cert pool implementation.
20+
type CertPool interface {
21+
// Get returns the cert pool, optionally adding the provided certs
22+
Get(certs ...*x509.Certificate) (*x509.CertPool, error)
23+
}
24+
25+
// certPool is a thread safe wrapper around the x509 standard library
26+
// cert pool implementation.
27+
// It optionally allows loading the system trust store.
28+
type certPool struct {
29+
useSystemCertPool bool
30+
certs []*x509.Certificate
31+
certPool *x509.CertPool
32+
certsByName map[string][]int
33+
lock sync.RWMutex
34+
}
35+
36+
// NewCertPool new CertPool implementation
37+
func NewCertPool(useSystemCertPool bool) CertPool {
38+
return &certPool{
39+
useSystemCertPool: useSystemCertPool,
40+
certsByName: make(map[string][]int),
41+
certPool: x509.NewCertPool(),
42+
}
43+
}
44+
45+
func (c *certPool) Get(certs ...*x509.Certificate) (*x509.CertPool, error) {
46+
c.lock.RLock()
47+
if len(certs) == 0 || c.containsCerts(certs...) {
48+
defer c.lock.RUnlock()
49+
return c.certPool, nil
50+
}
51+
c.lock.RUnlock()
52+
53+
// We have a cert we have not encountered before, recreate the cert pool
54+
certPool, err := c.loadSystemCertPool()
55+
if err != nil {
56+
return nil, err
57+
}
58+
59+
c.lock.Lock()
60+
defer c.lock.Unlock()
61+
62+
//add certs to SDK cert list
63+
for _, newCert := range certs {
64+
c.addCert(newCert)
65+
}
66+
//add all certs to cert pool
67+
for _, cert := range c.certs {
68+
certPool.AddCert(cert)
69+
}
70+
c.certPool = certPool
71+
72+
return c.certPool, nil
73+
}
74+
75+
func (c *certPool) addCert(newCert *x509.Certificate) {
76+
if newCert != nil && !c.containsCert(newCert) {
77+
n := len(c.certs)
78+
// Store cert
79+
c.certs = append(c.certs, newCert)
80+
// Store cert name index
81+
name := string(newCert.RawSubject)
82+
c.certsByName[name] = append(c.certsByName[name], n)
83+
}
84+
}
85+
86+
func (c *certPool) containsCert(newCert *x509.Certificate) bool {
87+
possibilities := c.certsByName[string(newCert.RawSubject)]
88+
for _, p := range possibilities {
89+
if c.certs[p].Equal(newCert) {
90+
return true
91+
}
92+
}
93+
94+
return false
95+
}
96+
97+
func (c *certPool) containsCerts(certs ...*x509.Certificate) bool {
98+
for _, cert := range certs {
99+
if cert != nil && !c.containsCert(cert) {
100+
return false
101+
}
102+
}
103+
return true
104+
}
105+
106+
func (c *certPool) loadSystemCertPool() (*x509.CertPool, error) {
107+
if !c.useSystemCertPool {
108+
return x509.NewCertPool(), nil
109+
}
110+
systemCertPool, err := x509.SystemCertPool()
111+
if err != nil {
112+
return nil, err
113+
}
114+
logger.Debugf("Loaded system cert pool of size: %d", len(systemCertPool.Subjects()))
115+
116+
return systemCertPool, nil
117+
}
Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
/*
2+
Copyright SecureKey Technologies Inc. All Rights Reserved.
3+
4+
SPDX-License-Identifier: Apache-2.0
5+
*/
6+
7+
package tls
8+
9+
import (
10+
"crypto/x509"
11+
"strconv"
12+
"testing"
13+
"time"
14+
15+
"github.com/stretchr/testify/assert"
16+
"github.com/stretchr/testify/require"
17+
)
18+
19+
var goodCert = &x509.Certificate{
20+
RawSubject: []byte("Good header"),
21+
Raw: []byte("Good cert"),
22+
}
23+
24+
func TestTLSCAConfig(t *testing.T) {
25+
tlsCertPool := NewCertPool(true).(*certPool)
26+
_, err := tlsCertPool.Get(goodCert)
27+
require.NoError(t, err)
28+
assert.Equal(t, true, tlsCertPool.useSystemCertPool)
29+
assert.NotNil(t, tlsCertPool.certPool)
30+
assert.NotNil(t, tlsCertPool.certsByName)
31+
32+
originalLength := len(tlsCertPool.certs)
33+
//Try again with same cert
34+
_, err = tlsCertPool.Get(goodCert)
35+
assert.NoError(t, err, "TLS CA cert pool fetch failed")
36+
assert.False(t, len(tlsCertPool.certs) > originalLength, "number of certs in cert list shouldn't accept duplicates")
37+
38+
// Test with system cert pool disabled
39+
tlsCertPool = NewCertPool(false).(*certPool)
40+
_, err = tlsCertPool.Get(goodCert)
41+
require.NoError(t, err)
42+
assert.Len(t, tlsCertPool.certs, 1)
43+
assert.Len(t, tlsCertPool.certPool.Subjects(), 1)
44+
}
45+
46+
func TestTLSCAPoolManyCerts(t *testing.T) {
47+
size := 50
48+
49+
tlsCertPool := NewCertPool(true).(*certPool)
50+
_, err := tlsCertPool.Get(goodCert)
51+
require.NoError(t, err)
52+
53+
pool, err := tlsCertPool.Get()
54+
assert.NoError(t, err)
55+
originalLen := len(pool.Subjects())
56+
57+
certs := createNCerts(size)
58+
pool, err = tlsCertPool.Get(certs[0])
59+
assert.NoError(t, err)
60+
assert.Len(t, pool.Subjects(), originalLen+1)
61+
62+
pool, err = tlsCertPool.Get(certs...)
63+
assert.NoError(t, err)
64+
assert.Len(t, pool.Subjects(), originalLen+size)
65+
}
66+
67+
func TestConcurrent(t *testing.T) {
68+
concurrency := 1000
69+
certs := createNCerts(concurrency)
70+
71+
tlsCertPool := NewCertPool(false).(*certPool)
72+
73+
writeDone := make(chan bool)
74+
readDone := make(chan bool)
75+
76+
for i := 0; i < concurrency; i++ {
77+
go func(c *x509.Certificate) {
78+
_, err := tlsCertPool.Get(c)
79+
assert.NoError(t, err)
80+
writeDone <- true
81+
}(certs[i])
82+
go func() {
83+
_, err := tlsCertPool.Get()
84+
assert.NoError(t, err)
85+
readDone <- true
86+
}()
87+
}
88+
89+
for i := 0; i < concurrency; i++ {
90+
select {
91+
case b := <-writeDone:
92+
assert.True(t, b)
93+
case <-time.After(time.Second * 10):
94+
t.Fatalf("Timed out waiting for write %d", i)
95+
}
96+
97+
select {
98+
case b := <-readDone:
99+
assert.True(t, b)
100+
case <-time.After(time.Second * 10):
101+
t.Fatalf("Timed out waiting for read %d", i)
102+
}
103+
}
104+
105+
assert.Len(t, tlsCertPool.certs, concurrency)
106+
assert.Len(t, tlsCertPool.certPool.Subjects(), concurrency)
107+
}
108+
109+
func createNCerts(n int) []*x509.Certificate {
110+
var certs []*x509.Certificate
111+
for i := 0; i < n; i++ {
112+
cert := &x509.Certificate{
113+
RawSubject: []byte(strconv.Itoa(i)),
114+
Raw: []byte(strconv.Itoa(i)),
115+
}
116+
certs = append(certs, cert)
117+
}
118+
return certs
119+
}
120+
121+
func BenchmarkTLSCertPool(b *testing.B) {
122+
tlsCertPool := NewCertPool(true).(*certPool)
123+
124+
for n := 0; n < b.N; n++ {
125+
tlsCertPool.Get()
126+
}
127+
}
128+
129+
func BenchmarkTLSCertPoolSameCert(b *testing.B) {
130+
tlsCertPool := NewCertPool(true).(*certPool)
131+
132+
for n := 0; n < b.N; n++ {
133+
tlsCertPool.Get(goodCert)
134+
}
135+
}
136+
137+
func BenchmarkTLSCertPoolDifferentCert(b *testing.B) {
138+
tlsCertPool := NewCertPool(true).(*certPool)
139+
certs := createNCerts(b.N)
140+
141+
for n := 0; n < b.N; n++ {
142+
tlsCertPool.Get(certs[n])
143+
}
144+
}

pkg/fab/endpointconfig.go

Lines changed: 4 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,14 @@ import (
1515
"sort"
1616
"strconv"
1717
"strings"
18-
"sync"
1918
"time"
2019

2120
"github.com/hyperledger/fabric-sdk-go/pkg/common/errors/status"
2221
"github.com/hyperledger/fabric-sdk-go/pkg/common/logging"
2322
"github.com/hyperledger/fabric-sdk-go/pkg/common/providers/core"
2423
"github.com/hyperledger/fabric-sdk-go/pkg/common/providers/fab"
2524
"github.com/hyperledger/fabric-sdk-go/pkg/common/providers/msp"
25+
commtls "github.com/hyperledger/fabric-sdk-go/pkg/core/config/comm/tls"
2626
"github.com/hyperledger/fabric-sdk-go/pkg/core/config/cryptoutil"
2727
"github.com/hyperledger/fabric-sdk-go/pkg/core/config/endpoint"
2828
"github.com/hyperledger/fabric-sdk-go/pkg/core/config/lookup"
@@ -60,8 +60,6 @@ const (
6060
func ConfigFromBackend(coreBackend core.ConfigBackend) (fab.EndpointConfig, error) {
6161
config := &EndpointConfig{
6262
backend: lookup.New(coreBackend),
63-
tlsCertsByName: make(map[string][]int),
64-
tlsCertPool: x509.NewCertPool(),
6563
peerMatchers: make(map[int]*regexp.Regexp),
6664
ordererMatchers: make(map[int]*regexp.Regexp),
6765
caMatchers: make(map[int]*regexp.Regexp),
@@ -72,6 +70,7 @@ func ConfigFromBackend(coreBackend core.ConfigBackend) (fab.EndpointConfig, erro
7270
return nil, errors.WithMessage(err, "network configuration load failed")
7371
}
7472

73+
config.tlsCertPool = commtls.NewCertPool(config.backend.GetBool("client.tlsCerts.systemCertPool"))
7574
// preemptively add all TLS certs to cert pool as adding them at request time
7675
// is expensive
7776
certs, err := config.loadTLSCerts()
@@ -94,16 +93,13 @@ func ConfigFromBackend(coreBackend core.ConfigBackend) (fab.EndpointConfig, erro
9493
// EndpointConfig represents the endpoint configuration for the client
9594
type EndpointConfig struct {
9695
backend *lookup.ConfigLookup
97-
tlsCerts []*x509.Certificate
9896
networkConfig *fab.NetworkConfig
97+
tlsCertPool commtls.CertPool
9998
networkConfigCached bool
100-
tlsCertPool *x509.CertPool
10199
peerMatchers map[int]*regexp.Regexp
102100
ordererMatchers map[int]*regexp.Regexp
103101
caMatchers map[int]*regexp.Regexp
104102
channelMatchers map[int]*regexp.Regexp
105-
tlsCertsByName map[string][]int
106-
certPoolLock sync.RWMutex
107103
}
108104

109105
// Timeout reads timeouts for the given timeout type, if type is not found in the config
@@ -464,33 +460,7 @@ func (c *EndpointConfig) ChannelOrderers(name string) ([]fab.OrdererConfig, erro
464460
// TLSCACertPool returns the configured cert pool. If a certConfig
465461
// is provided, the certficate is added to the pool
466462
func (c *EndpointConfig) TLSCACertPool(certs ...*x509.Certificate) (*x509.CertPool, error) {
467-
c.certPoolLock.RLock()
468-
if len(certs) == 0 || c.containsCerts(certs...) {
469-
defer c.certPoolLock.RUnlock()
470-
return c.tlsCertPool, nil
471-
}
472-
c.certPoolLock.RUnlock()
473-
474-
// We have a cert we have not encountered before, recreate the cert pool
475-
tlsCertPool, err := c.loadSystemCertPool()
476-
if err != nil {
477-
return nil, err
478-
}
479-
480-
c.certPoolLock.Lock()
481-
defer c.certPoolLock.Unlock()
482-
483-
//add certs to SDK cert list
484-
for _, newCert := range certs {
485-
c.addCert(newCert)
486-
}
487-
//add all certs to cert pool
488-
for _, cert := range c.tlsCerts {
489-
tlsCertPool.AddCert(cert)
490-
}
491-
c.tlsCertPool = tlsCertPool
492-
493-
return c.tlsCertPool, nil
463+
return c.tlsCertPool.Get(certs...)
494464
}
495465

496466
// EventServiceType returns the type of event service client to use
@@ -1107,50 +1077,6 @@ func (c *EndpointConfig) loadTLSCerts() ([]*x509.Certificate, error) {
11071077
return certs, nil
11081078
}
11091079

1110-
func (c *EndpointConfig) addCert(newCert *x509.Certificate) {
1111-
if newCert != nil && !c.containsCert(newCert) {
1112-
n := len(c.tlsCerts)
1113-
// Store cert
1114-
c.tlsCerts = append(c.tlsCerts, newCert)
1115-
// Store cert name index
1116-
name := string(newCert.RawSubject)
1117-
c.tlsCertsByName[name] = append(c.tlsCertsByName[name], n)
1118-
}
1119-
}
1120-
1121-
func (c *EndpointConfig) containsCert(newCert *x509.Certificate) bool {
1122-
possibilities := c.tlsCertsByName[string(newCert.RawSubject)]
1123-
for _, p := range possibilities {
1124-
if c.tlsCerts[p].Equal(newCert) {
1125-
return true
1126-
}
1127-
}
1128-
1129-
return false
1130-
}
1131-
1132-
func (c *EndpointConfig) containsCerts(certs ...*x509.Certificate) bool {
1133-
for _, cert := range certs {
1134-
if cert != nil && !c.containsCert(cert) {
1135-
return false
1136-
}
1137-
}
1138-
return true
1139-
}
1140-
1141-
func (c *EndpointConfig) loadSystemCertPool() (*x509.CertPool, error) {
1142-
if !c.backend.GetBool("client.tlsCerts.systemCertPool") {
1143-
return x509.NewCertPool(), nil
1144-
}
1145-
systemCertPool, err := x509.SystemCertPool()
1146-
if err != nil {
1147-
return nil, err
1148-
}
1149-
logger.Debugf("Loaded system cert pool of size: %d", len(systemCertPool.Subjects()))
1150-
1151-
return systemCertPool, nil
1152-
}
1153-
11541080
// Client returns the Client config
11551081
func (c *EndpointConfig) client() (*msp.ClientConfig, error) {
11561082
config, err := c.NetworkConfig()

0 commit comments

Comments
 (0)