As we reevaluate how to best support and maintain Staging Ref in the future, we encourage development teams using this environment to highlight their use cases in the following issue: https://gitlab.com/gitlab-com/gl-infra/software-delivery/framework/software-delivery-framework-issue-tracker/-/issues/36.

Skip to content
Snippets Groups Projects
Unverified Commit 1c216d02 authored by Ash McKenzie's avatar Ash McKenzie
Browse files

Fix data race with ctxWithLogMetadata

parent 5e521fe0
No related branches found
No related tags found
Loading
Loading
Loading
@@ -64,7 +64,12 @@ func (c *connection) handle(ctx context.Context, srvCfg *ssh.ServerConfig, handl
go c.sendKeepAliveMsg(ctx, sconn, ticker)
}
ctxWithLogMetadata := c.handleRequests(ctx, sconn, chans, handler)
ctxWithLogMetadataChan := make(chan context.Context)
defer close(ctxWithLogMetadataChan)
go c.handleRequests(ctx, sconn, chans, ctxWithLogMetadataChan, handler)
ctxWithLogMetadata := <-ctxWithLogMetadataChan
reason := sconn.Wait()
log.WithContextFields(ctx, log.Fields{"reason": reason}).Info("server: handleConn: done")
Loading
Loading
@@ -96,23 +101,25 @@ func (c *connection) initServerConn(ctx context.Context, srvCfg *ssh.ServerConfi
return sconn, chans, err
}
func (c *connection) handleRequests(ctx context.Context, sconn *ssh.ServerConn, chans <-chan ssh.NewChannel, handler channelHandler) context.Context {
ctxWithLogMetadata := ctx
func (c *connection) handleRequests(ctx context.Context, sconn *ssh.ServerConn, chans <-chan ssh.NewChannel, ctxWithLogMetadataChan chan<- context.Context, handler channelHandler) {
ctxlog := log.WithContextFields(ctx, log.Fields{"remote_addr": c.remoteAddr})
for newChannel := range chans {
ctxlog.WithField("channel_type", newChannel.ChannelType()).Info("connection: handle: new channel requested")
if newChannel.ChannelType() != "session" {
ctxlog.Info("connection: handleRequests: unknown channel type")
newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
continue
}
if !c.concurrentSessions.TryAcquire(1) {
ctxlog.Info("connection: handleRequests: too many concurrent sessions")
newChannel.Reject(ssh.ResourceShortage, "too many concurrent sessions")
metrics.SshdHitMaxSessions.Inc()
continue
}
channel, requests, err := newChannel.Accept()
if err != nil {
ctxlog.WithError(err).Error("connection: handleRequests: accepting channel failed")
Loading
Loading
@@ -137,11 +144,12 @@ func (c *connection) handleRequests(ctx context.Context, sconn *ssh.ServerConn,
}()
metrics.SliSshdSessionsTotal.Inc()
ctxWithLogMetadata, err = handler(ctx, sconn, channel, requests)
ctxWithLogMetadata, err := handler(ctx, sconn, channel, requests)
if err != nil {
c.trackError(ctxlog, err)
}
ctxWithLogMetadataChan <- ctxWithLogMetadata
}()
}
Loading
Loading
@@ -152,8 +160,6 @@ func (c *connection) handleRequests(ctx context.Context, sconn *ssh.ServerConn,
ctx, cancel := context.WithTimeout(ctx, EOFTimeout)
defer cancel()
c.concurrentSessions.Acquire(ctx, c.maxSessions)
return ctxWithLogMetadata
}
func (c *connection) sendKeepAliveMsg(ctx context.Context, sconn *ssh.ServerConn, ticker *time.Ticker) {
Loading
Loading
Loading
Loading
@@ -81,23 +81,25 @@ func (f *fakeConn) SendRequest(name string, wantReply bool, payload []byte) (boo
return true, nil, nil
}
func setup(sessionsNum int64, newChannel *fakeNewChannel) (*connection, chan ssh.NewChannel) {
func setup(sessionsNum int64, newChannel *fakeNewChannel) (*connection, chan ssh.NewChannel, chan<- context.Context) {
cfg := &config.Config{Server: config.ServerConfig{ConcurrentSessionsLimit: sessionsNum}}
conn := &connection{cfg: cfg, concurrentSessions: semaphore.NewWeighted(sessionsNum)}
chans := make(chan ssh.NewChannel, 1)
chans <- newChannel
return conn, chans
ctxWithLogMetadataChan := make(chan context.Context)
return conn, chans, ctxWithLogMetadataChan
}
func TestPanicDuringSessionIsRecovered(t *testing.T) {
newChannel := &fakeNewChannel{channelType: "session"}
conn, chans := setup(1, newChannel)
conn, chans, ctxWithLogMetadataChan := setup(1, newChannel)
numSessions := 0
require.NotPanics(t, func() {
conn.handleRequests(context.Background(), nil, chans, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) (context.Context, error) {
conn.handleRequests(context.Background(), nil, chans, ctxWithLogMetadataChan, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) (context.Context, error) {
numSessions += 1
close(chans)
panic("This is a panic")
Loading
Loading
@@ -112,10 +114,10 @@ func TestUnknownChannelType(t *testing.T) {
defer close(rejectCh)
newChannel := &fakeNewChannel{channelType: "unknown session", rejectCh: rejectCh}
conn, chans := setup(1, newChannel)
conn, chans, ctxWithLogMetadataChan := setup(1, newChannel)
go func() {
conn.handleRequests(context.Background(), nil, chans, nil)
conn.handleRequests(context.Background(), nil, chans, ctxWithLogMetadataChan, nil)
}()
rejectionData := <-rejectCh
Loading
Loading
@@ -129,13 +131,13 @@ func TestTooManySessions(t *testing.T) {
defer close(rejectCh)
newChannel := &fakeNewChannel{channelType: "session", rejectCh: rejectCh}
conn, chans := setup(1, newChannel)
conn, chans, ctxWithLogMetadataChan := setup(1, newChannel)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() {
conn.handleRequests(context.Background(), nil, chans, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) (context.Context, error) {
conn.handleRequests(context.Background(), nil, chans, ctxWithLogMetadataChan, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) (context.Context, error) {
<-ctx.Done() // Keep the accepted channel open until the end of the test
return ctx, nil
})
Loading
Loading
@@ -147,18 +149,17 @@ func TestTooManySessions(t *testing.T) {
func TestAcceptSessionSucceeds(t *testing.T) {
newChannel := &fakeNewChannel{channelType: "session"}
conn, chans := setup(1, newChannel)
conn, chans, ctxWithLogMetadataChan := setup(1, newChannel)
ctx := context.Background()
channelHandled := false
returnedCtx := conn.handleRequests(ctx, nil, chans, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) (context.Context, error) {
conn.handleRequests(ctx, nil, chans, ctxWithLogMetadataChan, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) (context.Context, error) {
channelHandled = true
close(chans)
return ctx, nil
})
require.True(t, channelHandled)
require.NotNil(t, returnedCtx)
}
func TestAcceptSessionFails(t *testing.T) {
Loading
Loading
@@ -167,12 +168,12 @@ func TestAcceptSessionFails(t *testing.T) {
acceptErr := errors.New("some failure")
newChannel := &fakeNewChannel{channelType: "session", acceptCh: acceptCh, acceptErr: acceptErr}
conn, chans := setup(1, newChannel)
conn, chans, ctxWithLogMetadataChan := setup(1, newChannel)
ctx := context.Background()
channelHandled := false
go func() {
conn.handleRequests(ctx, nil, chans, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) (context.Context, error) {
conn.handleRequests(ctx, nil, chans, ctxWithLogMetadataChan, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) (context.Context, error) {
channelHandled = true
return ctx, nil
})
Loading
Loading
@@ -206,10 +207,10 @@ func TestSessionsMetrics(t *testing.T) {
initialSessionsErrorTotal := testutil.ToFloat64(metrics.SliSshdSessionsErrorsTotal)
newChannel := &fakeNewChannel{channelType: "session"}
conn, chans := setup(1, newChannel)
conn, chans, ctxWithLogMetadataChan := setup(1, newChannel)
ctx := context.Background()
conn.handleRequests(ctx, nil, chans, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) (context.Context, error) {
conn.handleRequests(ctx, nil, chans, ctxWithLogMetadataChan, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) (context.Context, error) {
close(chans)
return ctx, errors.New("custom error")
})
Loading
Loading
@@ -228,11 +229,11 @@ func TestSessionsMetrics(t *testing.T) {
{"not our ref", grpcstatus.Error(grpccodes.Internal, `rpc error: code = Internal desc = cmd wait: exit status 128, stderr: "fatal: git upload-pack: not our ref 9106d18f6a1b8022f6517f479696f3e3ea5e68c1"`)},
} {
t.Run(ignoredError.desc, func(t *testing.T) {
conn, chans = setup(1, newChannel)
conn, chans, ctxWithLogMetadataChan := setup(1, newChannel)
ignored := ignoredError.err
ctx := context.Background()
conn.handleRequests(ctx, nil, chans, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) (context.Context, error) {
conn.handleRequests(ctx, nil, chans, ctxWithLogMetadataChan, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) (context.Context, error) {
close(chans)
return ctx, ignored
})
Loading
Loading
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment