Skip to content

Commit 125397b

Browse files
Support delayed registration for AWS KWOK
1 parent 79eeadc commit 125397b

File tree

3 files changed

+98
-30
lines changed

3 files changed

+98
-30
lines changed

kwok/ec2/ec2.go

Lines changed: 91 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,13 @@ import (
3535
corev1 "k8s.io/api/core/v1"
3636
"k8s.io/apimachinery/pkg/api/equality"
3737
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
38-
"k8s.io/client-go/rest"
3938
"k8s.io/client-go/util/workqueue"
4039
"k8s.io/utils/clock"
40+
"k8s.io/utils/set"
4141
"sigs.k8s.io/controller-runtime/pkg/client"
4242
"sigs.k8s.io/controller-runtime/pkg/log"
4343
"sigs.k8s.io/karpenter/kwok/apis/v1alpha1"
4444
v1 "sigs.k8s.io/karpenter/pkg/apis/v1"
45-
"sigs.k8s.io/karpenter/pkg/cloudprovider"
4645

4746
k8serrors "k8s.io/apimachinery/pkg/api/errors"
4847

@@ -64,13 +63,16 @@ type Client struct {
6463
subnets []ec2types.Subnet
6564
strategy strategy.Strategy
6665

67-
instances sync.Map
66+
instances sync.Map
67+
instanceLaunchCancels sync.Map
68+
69+
backupCompleted chan struct{}
6870

6971
launchTemplates sync.Map
7072
launchTemplateNameToID sync.Map
7173
}
7274

73-
func NewClient(region, namespace string, ec2Client *ec2.Client, rateLimiterProvider RateLimiterProvider, strategy strategy.Strategy, kubeClient client.Client, clk clock.Clock, cfg *rest.Config) *Client {
75+
func NewClient(region, namespace string, ec2Client *ec2.Client, rateLimiterProvider RateLimiterProvider, strategy strategy.Strategy, kubeClient client.Client, clk clock.Clock) *Client {
7476
var instanceTypes []ec2types.InstanceTypeInfo
7577
instanceTypesPaginator := ec2.NewDescribeInstanceTypesPaginator(ec2Client, &ec2.DescribeInstanceTypesInput{
7678
MaxResults: aws.Int32(100),
@@ -100,18 +102,25 @@ func NewClient(region, namespace string, ec2Client *ec2.Client, rateLimiterProvi
100102
subnets: subnets,
101103
strategy: strategy,
102104

103-
instances: sync.Map{},
105+
instances: sync.Map{},
106+
instanceLaunchCancels: sync.Map{},
107+
108+
backupCompleted: make(chan struct{}),
104109

105110
launchTemplates: sync.Map{},
106111
launchTemplateNameToID: sync.Map{},
107112
}
108-
c.readBackup(context.Background(), cfg)
109113
return c
110114
}
111115

112-
func (c *Client) readBackup(ctx context.Context, cfg *rest.Config) {
116+
func (c *Client) ReadBackup(ctx context.Context) {
113117
configMaps := &corev1.ConfigMapList{}
114-
lo.Must0(client.IgnoreNotFound(lo.Must(client.New(cfg, client.Options{})).List(ctx, configMaps, client.InNamespace(c.namespace))))
118+
lo.Must0(c.kubeClient.List(ctx, configMaps, client.InNamespace(c.namespace)))
119+
120+
nodeList := &corev1.NodeList{}
121+
lo.Must0(c.kubeClient.List(ctx, nodeList, client.MatchingLabels{v1alpha1.KwokLabelKey: v1alpha1.KwokLabelValue}))
122+
123+
instanceIDs := set.New[string](lo.Map(nodeList.Items, func(n corev1.Node, _ int) string { return lo.Must(utils.ParseInstanceID(n.Spec.ProviderID)) })...)
115124

116125
configMaps.Items = lo.Filter(configMaps.Items, func(c corev1.ConfigMap, _ int) bool {
117126
return strings.Contains(c.Name, "kwok-aws-instances-")
@@ -123,11 +132,17 @@ func (c *Client) readBackup(ctx context.Context, cfg *rest.Config) {
123132
lo.Must0(json.Unmarshal([]byte(cm.Data["instances"]), &instances))
124133
for _, instance := range instances {
125134
c.instances.Store(lo.FromPtr(instance.InstanceId), instance)
135+
// Register nodes immediately if we killed the KWOK controller before actually registering the node
136+
if !instanceIDs.Has(lo.FromPtr(instance.InstanceId)) {
137+
log.FromContext(ctx).WithValues("instance-id", lo.FromPtr(instance.InstanceId)).Info("creating node for instance id")
138+
lo.Must0(c.kubeClient.Create(ctx, c.toNode(ctx, instance)))
139+
}
126140
}
127141
total += len(instances)
128142
}
129143
}
130144
log.FromContext(ctx).WithValues("count", total).Info("loaded instances from backup")
145+
close(c.backupCompleted)
131146
}
132147

133148
//nolint:gocyclo
@@ -175,7 +190,7 @@ func (c *Client) backupInstances(ctx context.Context) error {
175190
numConfigMaps := int(math.Ceil(float64(len(instances)) / float64(500)))
176191
if numConfigMaps < len(configMaps.Items) {
177192
errs := make([]error, numConfigMaps)
178-
workqueue.ParallelizeUntil(ctx, 10, len(configMaps.Items)-numConfigMaps, func(i int) {
193+
workqueue.ParallelizeUntil(ctx, len(configMaps.Items)-numConfigMaps, len(configMaps.Items)-numConfigMaps, func(i int) {
179194
if err := c.kubeClient.Delete(ctx, &configMaps.Items[len(configMaps.Items)-i-1]); client.IgnoreNotFound(err) != nil {
180195
errs[i] = fmt.Errorf("deleting configmap %q, %w", configMaps.Items[len(configMaps.Items)-i-1].Name, err)
181196
}
@@ -186,7 +201,7 @@ func (c *Client) backupInstances(ctx context.Context) error {
186201
}
187202

188203
errs := make([]error, numConfigMaps)
189-
workqueue.ParallelizeUntil(ctx, 10, numConfigMaps, func(i int) {
204+
workqueue.ParallelizeUntil(ctx, numConfigMaps, numConfigMaps, func(i int) {
190205
cm := &corev1.ConfigMap{
191206
ObjectMeta: metav1.ObjectMeta{
192207
Name: fmt.Sprintf("kwok-aws-instances-%d", i),
@@ -218,13 +233,14 @@ func (c *Client) backupInstances(ctx context.Context) error {
218233

219234
// StartBackupThread initiates the thread that is responsible for storing instances in ConfigMaps on the cluster
220235
func (c *Client) StartBackupThread(ctx context.Context) {
236+
<-c.backupCompleted
221237
for {
222238
if err := c.backupInstances(ctx); err != nil {
223239
log.FromContext(ctx).Error(err, "unable to backup instances")
224240
continue
225241
}
226242
select {
227-
case <-time.After(time.Second * 5):
243+
case <-time.After(time.Second):
228244
case <-ctx.Done():
229245
return
230246
}
@@ -233,6 +249,7 @@ func (c *Client) StartBackupThread(ctx context.Context) {
233249

234250
// StartKillNodeThread initiates the thread that is responsible for killing nodes on the cluster that no longer have an instance representation (similar to CCM)
235251
func (c *Client) StartKillNodeThread(ctx context.Context) {
252+
<-c.backupCompleted
236253
for {
237254
nodes := &corev1.NodeList{}
238255
if err := c.kubeClient.List(ctx, nodes, client.MatchingLabels{v1alpha1.KwokLabelKey: v1alpha1.KwokLabelValue}); err != nil {
@@ -276,6 +293,7 @@ func removeNullFields(bytes []byte) []byte {
276293

277294
//nolint:gocyclo
278295
func (c *Client) DescribeLaunchTemplates(_ context.Context, input *ec2.DescribeLaunchTemplatesInput, _ ...func(*ec2.Options)) (*ec2.DescribeLaunchTemplatesOutput, error) {
296+
<-c.backupCompleted
279297
if !c.rateLimiterProvider.DescribeLaunchTemplates().TryAccept() {
280298
return nil, &smithy.GenericAPIError{
281299
Code: errors.RateLimitingErrorCode,
@@ -372,6 +390,7 @@ func (c *Client) DescribeLaunchTemplates(_ context.Context, input *ec2.DescribeL
372390

373391
//nolint:gocyclo
374392
func (c *Client) CreateFleet(ctx context.Context, input *ec2.CreateFleetInput, _ ...func(*ec2.Options)) (*ec2.CreateFleetOutput, error) {
393+
<-c.backupCompleted
375394
if !c.rateLimiterProvider.CreateFleet().TryAccept() {
376395
return nil, &smithy.GenericAPIError{
377396
Code: errors.RateLimitingErrorCode,
@@ -586,15 +605,21 @@ func (c *Client) CreateFleet(ctx context.Context, input *ec2.CreateFleetInput, _
586605
VpcId: subnet.VpcId,
587606
}
588607
c.instances.Store(lo.FromPtr(instance.InstanceId), instance)
608+
launchCtx, cancel := context.WithCancel(ctx)
609+
c.instanceLaunchCancels.Store(lo.FromPtr(instance.InstanceId), cancel)
589610

590-
// Create the Node through the instance launch
591-
// TODO: Eventually support delayed registration
592-
nodePoolNameTag, _ := lo.Find(instance.Tags, func(t ec2types.Tag) bool {
593-
return lo.FromPtr(t.Key) == v1.NodePoolLabelKey
594-
})
595-
if err := c.kubeClient.Create(ctx, toNode(lo.FromPtr(instance.InstanceId), lo.FromPtr(nodePoolNameTag.Value), it, lo.FromPtr(subnet.AvailabilityZone), v1.CapacityTypeOnDemand)); err != nil {
596-
return nil, fmt.Errorf("creating node, %w", err)
597-
}
611+
go func() {
612+
select {
613+
case <-launchCtx.Done():
614+
return
615+
// This is meant to simulate instance startup time
616+
case <-c.clock.After(30 * time.Second):
617+
}
618+
if err := c.kubeClient.Create(launchCtx, c.toNode(ctx, instance)); err != nil {
619+
c.instances.Delete(lo.FromPtr(instance.InstanceId))
620+
c.instanceLaunchCancels.Delete(lo.FromPtr(instance.InstanceId))
621+
}
622+
}()
598623
fleetInstances = append(fleetInstances, ec2types.CreateFleetInstance{
599624
InstanceIds: []string{lo.FromPtr(instance.InstanceId)},
600625
InstanceType: instance.InstanceType,
@@ -628,6 +653,7 @@ func (c *Client) CreateFleet(ctx context.Context, input *ec2.CreateFleetInput, _
628653
}
629654

630655
func (c *Client) TerminateInstances(_ context.Context, input *ec2.TerminateInstancesInput, _ ...func(*ec2.Options)) (*ec2.TerminateInstancesOutput, error) {
656+
<-c.backupCompleted
631657
if !c.rateLimiterProvider.TerminateInstances().TryAccept() {
632658
return nil, &smithy.GenericAPIError{
633659
Code: errors.RateLimitingErrorCode,
@@ -644,6 +670,9 @@ func (c *Client) TerminateInstances(_ context.Context, input *ec2.TerminateInsta
644670

645671
for _, id := range input.InstanceIds {
646672
c.instances.Delete(id)
673+
if cancel, ok := c.instanceLaunchCancels.LoadAndDelete(id); ok {
674+
cancel.(context.CancelFunc)()
675+
}
647676
}
648677
return &ec2.TerminateInstancesOutput{
649678
TerminatingInstances: lo.Map(input.InstanceIds, func(id string, _ int) ec2types.InstanceStateChange {
@@ -663,6 +692,7 @@ func (c *Client) TerminateInstances(_ context.Context, input *ec2.TerminateInsta
663692
}
664693

665694
func (c *Client) DescribeInstances(_ context.Context, input *ec2.DescribeInstancesInput, _ ...func(*ec2.Options)) (*ec2.DescribeInstancesOutput, error) {
695+
<-c.backupCompleted
666696
if !c.rateLimiterProvider.DescribeInstances().TryAccept() {
667697
return nil, &smithy.GenericAPIError{
668698
Code: errors.RateLimitingErrorCode,
@@ -713,6 +743,7 @@ func (c *Client) DescribeInstances(_ context.Context, input *ec2.DescribeInstanc
713743
}
714744

715745
func (c *Client) RunInstances(_ context.Context, input *ec2.RunInstancesInput, _ ...func(*ec2.Options)) (*ec2.RunInstancesOutput, error) {
746+
<-c.backupCompleted
716747
if !c.rateLimiterProvider.RunInstances().TryAccept() {
717748
return nil, &smithy.GenericAPIError{
718749
Code: errors.RateLimitingErrorCode,
@@ -733,6 +764,7 @@ func (c *Client) RunInstances(_ context.Context, input *ec2.RunInstancesInput, _
733764

734765
//nolint:gocyclo
735766
func (c *Client) CreateTags(_ context.Context, input *ec2.CreateTagsInput, _ ...func(*ec2.Options)) (*ec2.CreateTagsOutput, error) {
767+
<-c.backupCompleted
736768
if !c.rateLimiterProvider.CreateTags().TryAccept() {
737769
return nil, &smithy.GenericAPIError{
738770
Code: errors.RateLimitingErrorCode,
@@ -791,6 +823,7 @@ func (c *Client) CreateTags(_ context.Context, input *ec2.CreateTagsInput, _ ...
791823
}
792824

793825
func (c *Client) CreateLaunchTemplate(_ context.Context, input *ec2.CreateLaunchTemplateInput, _ ...func(*ec2.Options)) (*ec2.CreateLaunchTemplateOutput, error) {
826+
<-c.backupCompleted
794827
if !c.rateLimiterProvider.CreateLaunchTemplate().TryAccept() {
795828
return nil, &smithy.GenericAPIError{
796829
Code: errors.RateLimitingErrorCode,
@@ -823,6 +856,7 @@ func (c *Client) CreateLaunchTemplate(_ context.Context, input *ec2.CreateLaunch
823856
}
824857

825858
func (c *Client) DeleteLaunchTemplate(_ context.Context, input *ec2.DeleteLaunchTemplateInput, _ ...func(*ec2.Options)) (*ec2.DeleteLaunchTemplateOutput, error) {
859+
<-c.backupCompleted
826860
if !c.rateLimiterProvider.DeleteLaunchTemplate().TryAccept() {
827861
return nil, &smithy.GenericAPIError{
828862
Code: errors.RateLimitingErrorCode,
@@ -862,7 +896,35 @@ func (c *Client) DeleteLaunchTemplate(_ context.Context, input *ec2.DeleteLaunch
862896
}, nil
863897
}
864898

865-
func toNode(instanceID, nodePoolName string, instanceType *cloudprovider.InstanceType, zone, capacityType string) *corev1.Node {
899+
func (c *Client) toNode(ctx context.Context, instance ec2types.Instance) *corev1.Node {
900+
nodePoolNameTag, _ := lo.Find(instance.Tags, func(t ec2types.Tag) bool {
901+
return lo.FromPtr(t.Key) == v1.NodePoolLabelKey
902+
})
903+
subnet := lo.Must(lo.Find(c.subnets, func(s ec2types.Subnet) bool {
904+
return lo.FromPtr(s.SubnetId) == lo.FromPtr(instance.SubnetId)
905+
}))
906+
instanceTypeInfo := lo.Must(lo.Find(c.instanceTypes, func(i ec2types.InstanceTypeInfo) bool {
907+
return i.InstanceType == instance.InstanceType
908+
}))
909+
// TODO: We need to get the capacity and allocatable information from the userData
910+
it := instancetype.NewInstanceType(
911+
ctx,
912+
instanceTypeInfo,
913+
c.region,
914+
nil,
915+
nil,
916+
nil,
917+
nil,
918+
nil,
919+
nil,
920+
nil,
921+
nil,
922+
nil,
923+
nil,
924+
// TODO: Eventually support different AMIFamilies from userData
925+
"al2023",
926+
nil,
927+
)
866928
nodeName := fmt.Sprintf("%s-%d", strings.ReplaceAll(namesgenerator.GetRandomName(0), "_", "-"), rand.Uint32()) //nolint:gosec
867929
return &corev1.Node{
868930
ObjectMeta: metav1.ObjectMeta{
@@ -872,25 +934,25 @@ func toNode(instanceID, nodePoolName string, instanceType *cloudprovider.Instanc
872934
},
873935
// TODO: We can eventually add all the labels from the userData but for now we just add the NodePool labels
874936
Labels: map[string]string{
875-
corev1.LabelInstanceTypeStable: instanceType.Name,
937+
corev1.LabelInstanceTypeStable: it.Name,
876938
corev1.LabelHostname: nodeName,
877-
corev1.LabelTopologyRegion: instanceType.Requirements.Get(corev1.LabelTopologyRegion).Any(),
878-
corev1.LabelTopologyZone: zone,
879-
v1.CapacityTypeLabelKey: capacityType,
880-
corev1.LabelArchStable: instanceType.Requirements.Get(corev1.LabelArchStable).Any(),
939+
corev1.LabelTopologyRegion: it.Requirements.Get(corev1.LabelTopologyRegion).Any(),
940+
corev1.LabelTopologyZone: lo.FromPtr(subnet.AvailabilityZone),
941+
v1.CapacityTypeLabelKey: v1.CapacityTypeOnDemand,
942+
corev1.LabelArchStable: it.Requirements.Get(corev1.LabelArchStable).Any(),
881943
corev1.LabelOSStable: string(corev1.Linux),
882-
v1.NodePoolLabelKey: nodePoolName,
944+
v1.NodePoolLabelKey: lo.FromPtr(nodePoolNameTag.Value),
883945
v1alpha1.KwokLabelKey: v1alpha1.KwokLabelValue,
884946
v1alpha1.KwokPartitionLabelKey: "a",
885947
},
886948
},
887949
Spec: corev1.NodeSpec{
888-
ProviderID: fmt.Sprintf("kwok-aws:///%s/%s", zone, instanceID),
950+
ProviderID: fmt.Sprintf("kwok-aws:///%s/%s", lo.FromPtr(subnet.AvailabilityZone), lo.FromPtr(instance.InstanceId)),
889951
Taints: []corev1.Taint{v1.UnregisteredNoExecuteTaint},
890952
},
891953
Status: corev1.NodeStatus{
892-
Capacity: instanceType.Capacity,
893-
Allocatable: instanceType.Allocatable(),
954+
Capacity: it.Capacity,
955+
Allocatable: it.Allocatable(),
894956
Phase: corev1.NodePending,
895957
},
896958
}

kwok/main.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,12 @@ func main() {
5959
<-op.Elected()
6060
op.EC2API.StartKillNodeThread(ctx)
6161
}()
62+
wg.Add(1)
63+
go func() {
64+
defer wg.Done()
65+
<-op.Elected()
66+
op.EC2API.ReadBackup(ctx)
67+
}()
6268

6369
op.
6470
WithControllers(ctx, corecontrollers.NewControllers(

kwok/operator/operator.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ func NewOperator(ctx context.Context, operator *operator.Operator) (context.Cont
103103
region := lo.Must(imds.NewFromConfig(cfg).GetRegion(ctx, nil))
104104
cfg.Region = region.Region
105105
}
106-
ec2api := kwokec2.NewClient(cfg.Region, option.MustGetEnv("SYSTEM_NAMESPACE"), ec2.NewFromConfig(cfg), kwokec2.NewNopRateLimiterProvider(), strategy.NewLowestPrice(pricing.NewAPI(cfg), ec2.NewFromConfig(cfg), cfg.Region), operator.GetClient(), operator.Clock, operator.GetConfig())
106+
ec2api := kwokec2.NewClient(cfg.Region, option.MustGetEnv("SYSTEM_NAMESPACE"), ec2.NewFromConfig(cfg), kwokec2.NewNopRateLimiterProvider(), strategy.NewLowestPrice(pricing.NewAPI(cfg), ec2.NewFromConfig(cfg), cfg.Region), operator.GetClient(), operator.Clock)
107107

108108
eksapi := eks.NewFromConfig(cfg)
109109
log.FromContext(ctx).WithValues("region", cfg.Region).V(1).Info("discovered region")

0 commit comments

Comments
 (0)