Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 45 additions & 1 deletion server/cmd/api/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@ import (
"fmt"
"os"
"os/exec"
"path/filepath"
"sync"
"time"

"github.com/hashicorp/go-multierror"
"github.com/onkernel/kernel-images/server/lib/devtoolsproxy"
"github.com/onkernel/kernel-images/server/lib/logger"
"github.com/onkernel/kernel-images/server/lib/nekoclient"
Expand Down Expand Up @@ -297,6 +299,48 @@ func (s *ApiService) ListRecorders(ctx context.Context, _ oapi.ListRecordersRequ
return oapi.ListRecorders200JSONResponse(infos), nil
}

// killAllProcesses sends SIGKILL to every tracked process that is still running.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is new behavior — we'll need to reach out to heavy browser pool users and make sure they don't depend on process execs carrying over between session re-use.

func (s *ApiService) killAllProcesses(ctx context.Context) error {
log := logger.FromContext(ctx)
s.procMu.RLock()
defer s.procMu.RUnlock()

var result *multierror.Error
for id, h := range s.procs {
if h.state() != "running" {
continue
}
if h.cmd.Process == nil {
continue
}
// supervisorctl handles the lifecycle of long running processes so we don't want to kill
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure we should make an exception for supervisorctl here. if someone sent a process exec for a supervisorctl command it shouldn't matter — we're going to hard reset supervisor services anyway, right? or are we doing our own supervisorctl hard-reset of things like chromium during server shutdown? that feels a little weird but i could live with it — just want to make sure the reasoning is clear.

// any active supervisorctl processes. For example it is used to restart kernel-images-api
// and killing that process would break the restart process.
if filepath.Base(h.cmd.Path) == "supervisorctl" {
continue
}
if err := h.cmd.Process.Kill(); err != nil {
result = multierror.Append(result, fmt.Errorf("process %s: %w", id, err))
log.Error("failed to kill process", "process_id", id, "err", err)
Comment thread
cursor[bot] marked this conversation as resolved.
Outdated
}
}
return result.ErrorOrNil()
}

func (s *ApiService) Shutdown(ctx context.Context) error {
return s.recordManager.StopAll(ctx)
var wg sync.WaitGroup
var killErr, stopErr error

wg.Add(2)
go func() {
defer wg.Done()
killErr = s.killAllProcesses(ctx)
}()
go func() {
defer wg.Done()
stopErr = s.recordManager.StopAll(ctx)
}()
wg.Wait()
Comment thread
cursor[bot] marked this conversation as resolved.

return multierror.Append(killErr, stopErr).ErrorOrNil()
}
3 changes: 3 additions & 0 deletions server/cmd/api/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,9 @@ func main() {
defer shutdownCancel()
g, _ := errgroup.WithContext(shutdownCtx)

g.Go(func() error {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

stz.Drain runs concurrently with srv.Shutdown here. this works but makes the state transitions harder to reason about — Drain races with in-flight requests' deferred Enable() calls. consider sequencing: drain HTTP servers first (letting all in-flight Enables run normally), then call stz.Drain after. at that point the controller is already at rest and Drain is just a safety net freeze. same outcome, easier to verify correctness.

return stz.Drain(shutdownCtx)
})
g.Go(func() error {
return srv.Shutdown(shutdownCtx)
})
Expand Down
2 changes: 2 additions & 0 deletions server/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ require (
github.com/glebarez/sqlite v1.11.0
github.com/go-chi/chi/v5 v5.2.1
github.com/google/uuid v1.6.0
github.com/hashicorp/go-multierror v1.1.1
github.com/kelseyhightower/envconfig v1.4.0
github.com/klauspost/compress v1.18.3
github.com/m1k1o/neko/server v0.0.0-20251008185748-46e2fc7d3866
Expand Down Expand Up @@ -51,6 +52,7 @@ require (
github.com/go-ole/go-ole v1.2.6 // indirect
github.com/go-openapi/jsonpointer v0.21.0 // indirect
github.com/go-openapi/swag v0.23.0 // indirect
github.com/hashicorp/errwrap v1.0.0 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/josharian/intern v1.0.0 // indirect
Expand Down
4 changes: 4 additions & 0 deletions server/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 h1:8Tjv8EJ+pM1xP8mK6egEbD1OgnVTyacbefKhmbLhIhU=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2/go.mod h1:pkJQ2tZHJ0aFOVEEot6oZmaVEZcRme73eIFmhiVuRWs=
github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA=
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo=
github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
Expand Down
22 changes: 22 additions & 0 deletions server/lib/scaletozero/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,24 @@ package scaletozero

import (
"context"
"net"
"net/http"

"github.com/onkernel/kernel-images/server/lib/logger"
)

// Middleware returns a standard net/http middleware that disables scale-to-zero
// at the start of each request and re-enables it after the handler completes.
// Connections from loopback addresses are ignored and do not affect the
// scale-to-zero state.
func Middleware(ctrl Controller) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if isLoopbackAddr(r.RemoteAddr) {
next.ServeHTTP(w, r)
return
}

if err := ctrl.Disable(r.Context()); err != nil {
logger.FromContext(r.Context()).Error("failed to disable scale-to-zero", "error", err)
http.Error(w, "failed to disable scale-to-zero", http.StatusInternalServerError)
Expand All @@ -23,3 +31,17 @@ func Middleware(ctrl Controller) func(http.Handler) http.Handler {
})
}
}

// isLoopbackAddr reports whether addr refers to a loopback IP.
// addr may be given as an "ip:port" pair or as a bare IP literal;
// anything that doesn't parse as an IP is treated as non-loopback.
func isLoopbackAddr(addr string) bool {
	host := addr
	// Strip a ":port" suffix when present; on failure, assume addr is
	// already a bare host.
	if h, _, err := net.SplitHostPort(addr); err == nil {
		host = h
	}
	if ip := net.ParseIP(host); ip != nil {
		return ip.IsLoopback()
	}
	return false
}
114 changes: 114 additions & 0 deletions server/lib/scaletozero/middleware_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
package scaletozero

import (
"net/http"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

// TestMiddlewareDisablesAndEnablesForExternalAddr verifies that a request
// from a non-loopback address disables scale-to-zero exactly once before
// the handler runs and re-enables it exactly once afterwards.
func TestMiddlewareDisablesAndEnablesForExternalAddr(t *testing.T) {
	t.Parallel()

	ctrl := &mockScaleToZeroer{}
	next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	wrapped := Middleware(ctrl)(next)

	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.RemoteAddr = "203.0.113.50:12345" // TEST-NET-3: guaranteed non-loopback
	wrapped.ServeHTTP(rec, req)

	assert.Equal(t, http.StatusOK, rec.Code)
	assert.Equal(t, 1, ctrl.disableCalls)
	assert.Equal(t, 1, ctrl.enableCalls)
}

// TestMiddlewareSkipsLoopbackAddrs verifies that loopback clients pass
// straight through to the handler without touching scale-to-zero state.
func TestMiddlewareSkipsLoopbackAddrs(t *testing.T) {
	t.Parallel()

	cases := []struct {
		name string
		addr string
	}{
		{"loopback-v4", "127.0.0.1:8080"},
		{"loopback-v6", "[::1]:8080"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			ctrl := &mockScaleToZeroer{}
			handlerRan := false
			wrapped := Middleware(ctrl)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
				handlerRan = true
				w.WriteHeader(http.StatusOK)
			}))

			rec := httptest.NewRecorder()
			req := httptest.NewRequest(http.MethodGet, "/", nil)
			req.RemoteAddr = tc.addr
			wrapped.ServeHTTP(rec, req)

			assert.True(t, handlerRan, "handler should still be called")
			assert.Equal(t, http.StatusOK, rec.Code)
			assert.Equal(t, 0, ctrl.disableCalls, "should not disable for loopback addr")
			assert.Equal(t, 0, ctrl.enableCalls, "should not enable for loopback addr")
		})
	}
}

// TestMiddlewareDisableError verifies that when disabling scale-to-zero
// fails, the middleware responds 500 and never invokes the wrapped handler
// (and therefore never re-enables). It also pins that Disable was actually
// attempted — without that, a middleware that skipped Disable entirely
// would still pass this test.
func TestMiddlewareDisableError(t *testing.T) {
	t.Parallel()
	mock := &mockScaleToZeroer{disableErr: assert.AnError}
	var called bool
	handler := Middleware(mock)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		called = true
	}))

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.RemoteAddr = "203.0.113.50:12345"
	rec := httptest.NewRecorder()

	handler.ServeHTTP(rec, req)

	assert.False(t, called, "handler should not be called on disable error")
	assert.Equal(t, http.StatusInternalServerError, rec.Code)
	assert.Equal(t, 1, mock.disableCalls, "disable should have been attempted exactly once")
	assert.Equal(t, 0, mock.enableCalls)
}

// TestIsLoopbackAddr exercises isLoopbackAddr across "ip:port" pairs,
// bare IPs, and unparseable inputs for both IPv4 and IPv6.
func TestIsLoopbackAddr(t *testing.T) {
	t.Parallel()

	cases := map[string]bool{
		// Loopback
		"127.0.0.1:80": true,
		"[::1]:80":     true,
		"127.0.0.1":    true,
		"::1":          true,
		// Non-loopback
		"10.0.0.1:80":      false,
		"172.16.0.1:80":    false,
		"192.168.1.1:80":   false,
		"203.0.113.50:80":  false,
		"8.8.8.8:53":       false,
		"[2001:db8::1]:80": false,
		// Unparseable
		"not-an-ip:80": false,
		"":             false,
	}

	for addr, want := range cases {
		t.Run(addr, func(t *testing.T) {
			t.Parallel()
			require.Equal(t, want, isLoopbackAddr(addr))
		})
	}
}
Loading
Loading