Skip to content

Commit fde70eb

Browse files
Support delayed registration for AWS KWOK
1 parent 79eeadc commit fde70eb

File tree

3 files changed

+95
-29
lines changed

3 files changed

+95
-29
lines changed

kwok/ec2/ec2.go

Lines changed: 88 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,11 @@ import (
3838
"k8s.io/client-go/rest"
3939
"k8s.io/client-go/util/workqueue"
4040
"k8s.io/utils/clock"
41+
"k8s.io/utils/set"
4142
"sigs.k8s.io/controller-runtime/pkg/client"
4243
"sigs.k8s.io/controller-runtime/pkg/log"
4344
"sigs.k8s.io/karpenter/kwok/apis/v1alpha1"
4445
v1 "sigs.k8s.io/karpenter/pkg/apis/v1"
45-
"sigs.k8s.io/karpenter/pkg/cloudprovider"
4646

4747
k8serrors "k8s.io/apimachinery/pkg/api/errors"
4848

@@ -64,13 +64,16 @@ type Client struct {
6464
subnets []ec2types.Subnet
6565
strategy strategy.Strategy
6666

67-
instances sync.Map
67+
instances sync.Map
68+
instanceLaunchCancels sync.Map
69+
70+
backupCompleted chan struct{}
6871

6972
launchTemplates sync.Map
7073
launchTemplateNameToID sync.Map
7174
}
7275

73-
func NewClient(region, namespace string, ec2Client *ec2.Client, rateLimiterProvider RateLimiterProvider, strategy strategy.Strategy, kubeClient client.Client, clk clock.Clock, cfg *rest.Config) *Client {
76+
func NewClient(region, namespace string, ec2Client *ec2.Client, rateLimiterProvider RateLimiterProvider, strategy strategy.Strategy, kubeClient client.Client, clk clock.Clock) *Client {
7477
var instanceTypes []ec2types.InstanceTypeInfo
7578
instanceTypesPaginator := ec2.NewDescribeInstanceTypesPaginator(ec2Client, &ec2.DescribeInstanceTypesInput{
7679
MaxResults: aws.Int32(100),
@@ -100,18 +103,25 @@ func NewClient(region, namespace string, ec2Client *ec2.Client, rateLimiterProvi
100103
subnets: subnets,
101104
strategy: strategy,
102105

103-
instances: sync.Map{},
106+
instances: sync.Map{},
107+
instanceLaunchCancels: sync.Map{},
108+
109+
backupCompleted: make(chan struct{}),
104110

105111
launchTemplates: sync.Map{},
106112
launchTemplateNameToID: sync.Map{},
107113
}
108-
c.readBackup(context.Background(), cfg)
109114
return c
110115
}
111116

112-
func (c *Client) readBackup(ctx context.Context, cfg *rest.Config) {
117+
func (c *Client) ReadBackup(ctx context.Context) {
113118
configMaps := &corev1.ConfigMapList{}
114-
lo.Must0(client.IgnoreNotFound(lo.Must(client.New(cfg, client.Options{})).List(ctx, configMaps, client.InNamespace(c.namespace))))
119+
lo.Must0(c.kubeClient.List(ctx, configMaps, client.InNamespace(c.namespace)))
120+
121+
nodeList := &corev1.NodeList{}
122+
lo.Must0(c.kubeClient.List(ctx, nodeList, client.MatchingLabels{v1alpha1.KwokLabelKey: v1alpha1.KwokLabelValue}))
123+
124+
instanceIDs := set.New[string](lo.Map(nodeList.Items, func(n corev1.Node, _ int) string { return lo.Must(utils.ParseInstanceID(n.Spec.ProviderID)) })...)
115125

116126
configMaps.Items = lo.Filter(configMaps.Items, func(c corev1.ConfigMap, _ int) bool {
117127
return strings.Contains(c.Name, "kwok-aws-instances-")
@@ -123,11 +133,16 @@ func (c *Client) readBackup(ctx context.Context, cfg *rest.Config) {
123133
lo.Must0(json.Unmarshal([]byte(cm.Data["instances"]), &instances))
124134
for _, instance := range instances {
125135
c.instances.Store(lo.FromPtr(instance.InstanceId), instance)
136+
// Register nodes immediately if we killed the KWOK controller before actually registering the node
137+
if !instanceIDs.Has(lo.FromPtr(instance.InstanceId)) {
138+
lo.Must0(c.kubeClient.Create(ctx, c.toNode(ctx, instance)))
139+
}
126140
}
127141
total += len(instances)
128142
}
129143
}
130144
log.FromContext(ctx).WithValues("count", total).Info("loaded instances from backup")
145+
close(c.backupCompleted)
131146
}
132147

133148
//nolint:gocyclo
@@ -175,7 +190,7 @@ func (c *Client) backupInstances(ctx context.Context) error {
175190
numConfigMaps := int(math.Ceil(float64(len(instances)) / float64(500)))
176191
if numConfigMaps < len(configMaps.Items) {
177192
errs := make([]error, numConfigMaps)
178-
workqueue.ParallelizeUntil(ctx, 10, len(configMaps.Items)-numConfigMaps, func(i int) {
193+
workqueue.ParallelizeUntil(ctx, len(configMaps.Items)-numConfigMaps, len(configMaps.Items)-numConfigMaps, func(i int) {
179194
if err := c.kubeClient.Delete(ctx, &configMaps.Items[len(configMaps.Items)-i-1]); client.IgnoreNotFound(err) != nil {
180195
errs[i] = fmt.Errorf("deleting configmap %q, %w", configMaps.Items[len(configMaps.Items)-i-1].Name, err)
181196
}
@@ -186,7 +201,7 @@ func (c *Client) backupInstances(ctx context.Context) error {
186201
}
187202

188203
errs := make([]error, numConfigMaps)
189-
workqueue.ParallelizeUntil(ctx, 10, numConfigMaps, func(i int) {
204+
workqueue.ParallelizeUntil(ctx, numConfigMaps, numConfigMaps, func(i int) {
190205
cm := &corev1.ConfigMap{
191206
ObjectMeta: metav1.ObjectMeta{
192207
Name: fmt.Sprintf("kwok-aws-instances-%d", i),
@@ -224,7 +239,7 @@ func (c *Client) StartBackupThread(ctx context.Context) {
224239
continue
225240
}
226241
select {
227-
case <-time.After(time.Second * 5):
242+
case <-time.After(time.Second):
228243
case <-ctx.Done():
229244
return
230245
}
@@ -276,6 +291,7 @@ func removeNullFields(bytes []byte) []byte {
276291

277292
//nolint:gocyclo
278293
func (c *Client) DescribeLaunchTemplates(_ context.Context, input *ec2.DescribeLaunchTemplatesInput, _ ...func(*ec2.Options)) (*ec2.DescribeLaunchTemplatesOutput, error) {
294+
<-c.backupCompleted
279295
if !c.rateLimiterProvider.DescribeLaunchTemplates().TryAccept() {
280296
return nil, &smithy.GenericAPIError{
281297
Code: errors.RateLimitingErrorCode,
@@ -372,6 +388,7 @@ func (c *Client) DescribeLaunchTemplates(_ context.Context, input *ec2.DescribeL
372388

373389
//nolint:gocyclo
374390
func (c *Client) CreateFleet(ctx context.Context, input *ec2.CreateFleetInput, _ ...func(*ec2.Options)) (*ec2.CreateFleetOutput, error) {
391+
<-c.backupCompleted
375392
if !c.rateLimiterProvider.CreateFleet().TryAccept() {
376393
return nil, &smithy.GenericAPIError{
377394
Code: errors.RateLimitingErrorCode,
@@ -586,15 +603,21 @@ func (c *Client) CreateFleet(ctx context.Context, input *ec2.CreateFleetInput, _
586603
VpcId: subnet.VpcId,
587604
}
588605
c.instances.Store(lo.FromPtr(instance.InstanceId), instance)
606+
launchCtx, cancel := context.WithCancel(ctx)
607+
c.instanceLaunchCancels.Store(lo.FromPtr(instance.InstanceId), cancel)
589608

590-
// Create the Node through the instance launch
591-
// TODO: Eventually support delayed registration
592-
nodePoolNameTag, _ := lo.Find(instance.Tags, func(t ec2types.Tag) bool {
593-
return lo.FromPtr(t.Key) == v1.NodePoolLabelKey
594-
})
595-
if err := c.kubeClient.Create(ctx, toNode(lo.FromPtr(instance.InstanceId), lo.FromPtr(nodePoolNameTag.Value), it, lo.FromPtr(subnet.AvailabilityZone), v1.CapacityTypeOnDemand)); err != nil {
596-
return nil, fmt.Errorf("creating node, %w", err)
597-
}
609+
go func() {
610+
select {
611+
case <-launchCtx.Done():
612+
return
613+
// This is meant to simulate instance startup time
614+
case <-c.clock.After(30 * time.Second):
615+
}
616+
if err := c.kubeClient.Create(launchCtx, c.toNode(ctx, instance)); err != nil {
617+
c.instances.Delete(lo.FromPtr(instance.InstanceId))
618+
c.instanceLaunchCancels.Delete(lo.FromPtr(instance.InstanceId))
619+
}
620+
}()
598621
fleetInstances = append(fleetInstances, ec2types.CreateFleetInstance{
599622
InstanceIds: []string{lo.FromPtr(instance.InstanceId)},
600623
InstanceType: instance.InstanceType,
@@ -628,6 +651,7 @@ func (c *Client) CreateFleet(ctx context.Context, input *ec2.CreateFleetInput, _
628651
}
629652

630653
func (c *Client) TerminateInstances(_ context.Context, input *ec2.TerminateInstancesInput, _ ...func(*ec2.Options)) (*ec2.TerminateInstancesOutput, error) {
654+
<-c.backupCompleted
631655
if !c.rateLimiterProvider.TerminateInstances().TryAccept() {
632656
return nil, &smithy.GenericAPIError{
633657
Code: errors.RateLimitingErrorCode,
@@ -644,6 +668,9 @@ func (c *Client) TerminateInstances(_ context.Context, input *ec2.TerminateInsta
644668

645669
for _, id := range input.InstanceIds {
646670
c.instances.Delete(id)
671+
if cancel, ok := c.instanceLaunchCancels.LoadAndDelete(id); ok {
672+
cancel.(context.CancelFunc)()
673+
}
647674
}
648675
return &ec2.TerminateInstancesOutput{
649676
TerminatingInstances: lo.Map(input.InstanceIds, func(id string, _ int) ec2types.InstanceStateChange {
@@ -663,6 +690,7 @@ func (c *Client) TerminateInstances(_ context.Context, input *ec2.TerminateInsta
663690
}
664691

665692
func (c *Client) DescribeInstances(_ context.Context, input *ec2.DescribeInstancesInput, _ ...func(*ec2.Options)) (*ec2.DescribeInstancesOutput, error) {
693+
<-c.backupCompleted
666694
if !c.rateLimiterProvider.DescribeInstances().TryAccept() {
667695
return nil, &smithy.GenericAPIError{
668696
Code: errors.RateLimitingErrorCode,
@@ -713,6 +741,7 @@ func (c *Client) DescribeInstances(_ context.Context, input *ec2.DescribeInstanc
713741
}
714742

715743
func (c *Client) RunInstances(_ context.Context, input *ec2.RunInstancesInput, _ ...func(*ec2.Options)) (*ec2.RunInstancesOutput, error) {
744+
<-c.backupCompleted
716745
if !c.rateLimiterProvider.RunInstances().TryAccept() {
717746
return nil, &smithy.GenericAPIError{
718747
Code: errors.RateLimitingErrorCode,
@@ -733,6 +762,7 @@ func (c *Client) RunInstances(_ context.Context, input *ec2.RunInstancesInput, _
733762

734763
//nolint:gocyclo
735764
func (c *Client) CreateTags(_ context.Context, input *ec2.CreateTagsInput, _ ...func(*ec2.Options)) (*ec2.CreateTagsOutput, error) {
765+
<-c.backupCompleted
736766
if !c.rateLimiterProvider.CreateTags().TryAccept() {
737767
return nil, &smithy.GenericAPIError{
738768
Code: errors.RateLimitingErrorCode,
@@ -791,6 +821,7 @@ func (c *Client) CreateTags(_ context.Context, input *ec2.CreateTagsInput, _ ...
791821
}
792822

793823
func (c *Client) CreateLaunchTemplate(_ context.Context, input *ec2.CreateLaunchTemplateInput, _ ...func(*ec2.Options)) (*ec2.CreateLaunchTemplateOutput, error) {
824+
<-c.backupCompleted
794825
if !c.rateLimiterProvider.CreateLaunchTemplate().TryAccept() {
795826
return nil, &smithy.GenericAPIError{
796827
Code: errors.RateLimitingErrorCode,
@@ -823,6 +854,7 @@ func (c *Client) CreateLaunchTemplate(_ context.Context, input *ec2.CreateLaunch
823854
}
824855

825856
func (c *Client) DeleteLaunchTemplate(_ context.Context, input *ec2.DeleteLaunchTemplateInput, _ ...func(*ec2.Options)) (*ec2.DeleteLaunchTemplateOutput, error) {
857+
<-c.backupCompleted
826858
if !c.rateLimiterProvider.DeleteLaunchTemplate().TryAccept() {
827859
return nil, &smithy.GenericAPIError{
828860
Code: errors.RateLimitingErrorCode,
@@ -862,7 +894,35 @@ func (c *Client) DeleteLaunchTemplate(_ context.Context, input *ec2.DeleteLaunch
862894
}, nil
863895
}
864896

865-
func toNode(instanceID, nodePoolName string, instanceType *cloudprovider.InstanceType, zone, capacityType string) *corev1.Node {
897+
func (c *Client) toNode(ctx context.Context, instance ec2types.Instance) *corev1.Node {
898+
nodePoolNameTag, _ := lo.Find(instance.Tags, func(t ec2types.Tag) bool {
899+
return lo.FromPtr(t.Key) == v1.NodePoolLabelKey
900+
})
901+
subnet := lo.Must(lo.Find(c.subnets, func(s ec2types.Subnet) bool {
902+
return lo.FromPtr(s.SubnetId) == lo.FromPtr(instance.SubnetId)
903+
}))
904+
instanceTypeInfo := lo.Must(lo.Find(c.instanceTypes, func(i ec2types.InstanceTypeInfo) bool {
905+
return i.InstanceType == instance.InstanceType
906+
}))
907+
// TODO: We need to get the capacity and allocatable information from the userData
908+
it := instancetype.NewInstanceType(
909+
ctx,
910+
instanceTypeInfo,
911+
c.region,
912+
nil,
913+
nil,
914+
nil,
915+
nil,
916+
nil,
917+
nil,
918+
nil,
919+
nil,
920+
nil,
921+
nil,
922+
// TODO: Eventually support different AMIFamilies from userData
923+
"al2023",
924+
nil,
925+
)
866926
nodeName := fmt.Sprintf("%s-%d", strings.ReplaceAll(namesgenerator.GetRandomName(0), "_", "-"), rand.Uint32()) //nolint:gosec
867927
return &corev1.Node{
868928
ObjectMeta: metav1.ObjectMeta{
@@ -872,25 +932,25 @@ func toNode(instanceID, nodePoolName string, instanceType *cloudprovider.Instanc
872932
},
873933
// TODO: We can eventually add all the labels from the userData but for now we just add the NodePool labels
874934
Labels: map[string]string{
875-
corev1.LabelInstanceTypeStable: instanceType.Name,
935+
corev1.LabelInstanceTypeStable: it.Name,
876936
corev1.LabelHostname: nodeName,
877-
corev1.LabelTopologyRegion: instanceType.Requirements.Get(corev1.LabelTopologyRegion).Any(),
878-
corev1.LabelTopologyZone: zone,
879-
v1.CapacityTypeLabelKey: capacityType,
880-
corev1.LabelArchStable: instanceType.Requirements.Get(corev1.LabelArchStable).Any(),
937+
corev1.LabelTopologyRegion: it.Requirements.Get(corev1.LabelTopologyRegion).Any(),
938+
corev1.LabelTopologyZone: lo.FromPtr(subnet.AvailabilityZone),
939+
v1.CapacityTypeLabelKey: v1.CapacityTypeOnDemand,
940+
corev1.LabelArchStable: it.Requirements.Get(corev1.LabelArchStable).Any(),
881941
corev1.LabelOSStable: string(corev1.Linux),
882-
v1.NodePoolLabelKey: nodePoolName,
942+
v1.NodePoolLabelKey: lo.FromPtr(nodePoolNameTag.Value),
883943
v1alpha1.KwokLabelKey: v1alpha1.KwokLabelValue,
884944
v1alpha1.KwokPartitionLabelKey: "a",
885945
},
886946
},
887947
Spec: corev1.NodeSpec{
888-
ProviderID: fmt.Sprintf("kwok-aws:///%s/%s", zone, instanceID),
948+
ProviderID: fmt.Sprintf("kwok-aws:///%s/%s", lo.FromPtr(subnet.AvailabilityZone), lo.FromPtr(instance.InstanceId)),
889949
Taints: []corev1.Taint{v1.UnregisteredNoExecuteTaint},
890950
},
891951
Status: corev1.NodeStatus{
892-
Capacity: instanceType.Capacity,
893-
Allocatable: instanceType.Allocatable(),
952+
Capacity: it.Capacity,
953+
Allocatable: it.Allocatable(),
894954
Phase: corev1.NodePending,
895955
},
896956
}

kwok/main.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,12 @@ func main() {
5959
<-op.Elected()
6060
op.EC2API.StartKillNodeThread(ctx)
6161
}()
62+
wg.Add(1)
63+
go func() {
64+
defer wg.Done()
65+
<-op.Elected()
66+
op.EC2API.ReadBackup(ctx)
67+
}()
6268

6369
op.
6470
WithControllers(ctx, corecontrollers.NewControllers(

kwok/operator/operator.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ func NewOperator(ctx context.Context, operator *operator.Operator) (context.Cont
103103
region := lo.Must(imds.NewFromConfig(cfg).GetRegion(ctx, nil))
104104
cfg.Region = region.Region
105105
}
106-
ec2api := kwokec2.NewClient(cfg.Region, option.MustGetEnv("SYSTEM_NAMESPACE"), ec2.NewFromConfig(cfg), kwokec2.NewNopRateLimiterProvider(), strategy.NewLowestPrice(pricing.NewAPI(cfg), ec2.NewFromConfig(cfg), cfg.Region), operator.GetClient(), operator.Clock, operator.GetConfig())
106+
ec2api := kwokec2.NewClient(cfg.Region, option.MustGetEnv("SYSTEM_NAMESPACE"), ec2.NewFromConfig(cfg), kwokec2.NewNopRateLimiterProvider(), strategy.NewLowestPrice(pricing.NewAPI(cfg), ec2.NewFromConfig(cfg), cfg.Region), operator.GetClient(), operator.Clock)
107107

108108
eksapi := eks.NewFromConfig(cfg)
109109
log.FromContext(ctx).WithValues("region", cfg.Region).V(1).Info("discovered region")

0 commit comments

Comments
 (0)