@@ -55,6 +55,10 @@ const (
5555 schedulerKey = "scheduling-policy.noderesource.dev"
5656 // Deprecated: Prefix of the key used for scheduler attribute adjustment.
5757 oldSchedulerKey = "scheduling-policy.nri.io"
58+ // Prefix of the key used for memory policy adjustment.
59+ memoryPolicyKey = "memory-policy.noderesource.dev"
60+ // Deprecated: Prefix of the key used for memory policy adjustment.
61+ oldMemoryPolicyKey = "memory-policy.nri.io"
5862)
5963
6064var (
@@ -104,6 +108,13 @@ type scheduler struct {
104108 Period uint64 `json:"period"`
105109}
106110
111+ // memoryPolicy adjustment
112+ type memoryPolicy struct {
113+ Mode string `json:"mode"`
114+ Nodes string `json:"nodes"`
115+ Flags []string `json:"flags"`
116+ }
117+
107118// our injector plugin
108119type plugin struct {
109120 stub stub.Stub
@@ -135,6 +146,9 @@ func (p *plugin) CreateContainer(_ context.Context, pod *api.PodSandbox, ctr *ap
135146 if err := adjustScheduler (pod , ctr , adjust ); err != nil {
136147 return nil , nil , err
137148 }
149+ if err := adjustMemoryPolicy (pod , ctr , adjust ); err != nil {
150+ return nil , nil , err
151+ }
138152
139153 if err := injectNetDevices (pod , ctr , adjust ); err != nil {
140154 return nil , nil , err
@@ -395,6 +409,53 @@ func parseScheduler(ctr string, annotations map[string]string) (*scheduler, erro
395409 return sch , nil
396410}
397411
412+ func adjustMemoryPolicy (pod * api.PodSandbox , ctr * api.Container , a * api.ContainerAdjustment ) error {
413+ pol , err := parseMemoryPolicy (ctr .Name , pod .Annotations )
414+ if err != nil {
415+ log .Errorf ("%s: invalid memory policy annotation: %v" ,
416+ containerName (pod , ctr ), err )
417+ return err
418+ }
419+
420+ if pol == nil {
421+ log .Debugf ("%s: no memory policy annotated..." , containerName (pod , ctr ))
422+ return nil
423+ }
424+
425+ if verbose {
426+ dump (containerName (pod , ctr ), "annotated memory policy" , pol )
427+ }
428+
429+ amp , err := pol .ToNRI ()
430+ if err != nil {
431+ log .Errorf ("%s: invalid memory policy annotation: %v" ,
432+ containerName (pod , ctr ), err )
433+ return fmt .Errorf ("invalid memory policy: %w" , err )
434+ }
435+
436+ a .SetLinuxMemoryPolicy (amp .Mode , amp .Nodes , amp .Flags ... )
437+ log .Infof ("%s: adjusted memory policy to %s..." , containerName (pod , ctr ), amp )
438+
439+ return nil
440+ }
441+
442+ func parseMemoryPolicy (ctr string , annotations map [string ]string ) (* memoryPolicy , error ) {
443+ var (
444+ policy = & memoryPolicy {}
445+ )
446+
447+ annotation := getAnnotation (annotations , memoryPolicyKey , oldMemoryPolicyKey , ctr )
448+ if annotation == nil {
449+ return nil , nil
450+ }
451+
452+ if err := yaml .Unmarshal (annotation , policy ); err != nil {
453+ return nil , fmt .Errorf ("invalid memory policy annotation %q: %w" , string (annotation ), err )
454+ }
455+
456+ return policy , nil
457+ }
458+
398459func getAnnotation (annotations map [string ]string , mainKey , oldKey , ctr string ) []byte {
399460 for _ , key := range []string {
400461 mainKey + "/container." + ctr ,
@@ -519,6 +580,46 @@ func (sc *scheduler) String() string {
519580 return s
520581}
521582
583+ // Convert memoryPolicy to the NRI API representation.
584+ func (mp * memoryPolicy ) ToNRI () (* api.LinuxMemoryPolicy , error ) {
585+ apiPol := & api.LinuxMemoryPolicy {
586+ Nodes : mp .Nodes ,
587+ }
588+
589+ mode , ok := api .MpolMode_value [mp .Mode ]
590+ if ! ok {
591+ return nil , fmt .Errorf ("invalid memory policy mode %q" , mp .Mode )
592+ }
593+ apiPol .Mode = api .MpolMode (mode )
594+
595+ for _ , f := range mp .Flags {
596+ flag , ok := api .MpolFlag_value [f ]
597+ if ! ok {
598+ return nil , fmt .Errorf ("invalid memory policy flag %q" , f )
599+ }
600+ apiPol .Flags = append (apiPol .Flags , api .MpolFlag (flag ))
601+ }
602+
603+ return apiPol , nil
604+ }
605+
606+ func (mp * memoryPolicy ) String () string {
607+ if mp == nil {
608+ return "<no memory policy>"
609+ }
610+
611+ s := fmt .Sprintf ("<memory policy mode=%s" , mp .Mode )
612+ if mp .Nodes != "" {
613+ s += fmt .Sprintf (", nodes=%s" , mp .Nodes )
614+ }
615+ if len (mp .Flags ) > 0 {
616+ s += fmt .Sprintf (", flags=%v" , mp .Flags )
617+ }
618+ s += ">"
619+
620+ return s
621+ }
622+
522623// Construct a container name for log messages.
523624func containerName (pod * api.PodSandbox , container * api.Container ) string {
524625 if pod != nil {
0 commit comments