As we reevaluate how to best support and maintain Staging Ref in the future, we encourage development teams using this environment to highlight their use cases in the following issue: https://gitlab.com/gitlab-com/gl-infra/software-delivery/framework/software-delivery-framework-issue-tracker/-/issues/36.

Skip to content
Snippets Groups Projects
Commit f5d0c1c3 authored by Patrick Bajao's avatar Patrick Bajao
Browse files

Merge branch 'id-return-metadata-context' into 'main'

Return metadata context without using channels

See merge request https://gitlab.com/gitlab-org/gitlab-shell/-/merge_requests/810



Merged-by: default avatarPatrick Bajao <ebajao@gitlab.com>
Approved-by: default avatarPatrick Bajao <ebajao@gitlab.com>
Approved-by: default avatarAsh McKenzie <amckenzie@gitlab.com>
Reviewed-by: default avatarAsh McKenzie <amckenzie@gitlab.com>
Co-authored-by: default avatarIgor Drozdov <idrozdov@gitlab.com>
parents b3428942 f0553cfb
No related branches found
No related tags found
No related merge requests found
Loading
Loading
@@ -379,20 +379,11 @@ func TestTwoFactorAuthRecoveryCodesSuccess(t *testing.T) {
},
}
client := runSSHD(t, successAPI(t, handler))
session, err := client.NewSession()
require.NoError(t, err)
defer session.Close()
stdin, err := session.StdinPipe()
require.NoError(t, err)
stdout, err := session.StdoutPipe()
require.NoError(t, err)
session, stdin, stdout := newSession(t, client)
reader := bufio.NewReader(stdout)
err = session.Start("2fa_recovery_codes")
err := session.Start("2fa_recovery_codes")
require.NoError(t, err)
line, err := reader.ReadString('\n')
Loading
Loading
@@ -428,20 +419,11 @@ func TwoFactorAuthVerifySuccess(t *testing.T) {
},
}
client := runSSHD(t, successAPI(t, handler))
session, err := client.NewSession()
require.NoError(t, err)
defer session.Close()
stdin, err := session.StdinPipe()
require.NoError(t, err)
stdout, err := session.StdoutPipe()
require.NoError(t, err)
session, stdin, stdout := newSession(t, client)
reader := bufio.NewReader(stdout)
err = session.Start("2fa_verify")
err := session.Start("2fa_verify")
require.NoError(t, err)
line, err := reader.ReadString('\n')
Loading
Loading
@@ -480,17 +462,9 @@ func TestGitReceivePackSuccess(t *testing.T) {
ensureGitalyRepository(t)
client := runSSHD(t, successAPI(t))
session, err := client.NewSession()
require.NoError(t, err)
defer session.Close()
session, stdin, stdout := newSession(t, client)
stdin, err := session.StdinPipe()
require.NoError(t, err)
stdout, err := session.StdoutPipe()
require.NoError(t, err)
err = session.Start(fmt.Sprintf("git-receive-pack %s", testRepo))
err := session.Start(fmt.Sprintf("git-receive-pack %s", testRepo))
require.NoError(t, err)
// Gracefully close connection
Loading
Loading
@@ -525,17 +499,9 @@ func TestGeoGitReceivePackSuccess(t *testing.T) {
},
}
client := runSSHD(t, successAPI(t, handler))
session, err := client.NewSession()
require.NoError(t, err)
defer session.Close()
stdin, err := session.StdinPipe()
require.NoError(t, err)
stdout, err := session.StdoutPipe()
require.NoError(t, err)
session, stdin, stdout := newSession(t, client)
err = session.Start(fmt.Sprintf("git-receive-pack %s", testRepo))
err := session.Start(fmt.Sprintf("git-receive-pack %s", testRepo))
require.NoError(t, err)
// Gracefully close connection
Loading
Loading
@@ -559,59 +525,47 @@ func TestGitUploadPackSuccess(t *testing.T) {
ensureGitalyRepository(t)
client := runSSHD(t, successAPI(t))
session, err := client.NewSession()
require.NoError(t, err)
defer session.Close()
defer client.Close()
stdin, err := session.StdinPipe()
require.NoError(t, err)
numberOfSessions := 3
for sessionNumber := 0; sessionNumber < numberOfSessions; sessionNumber++ {
t.Run(fmt.Sprintf("session #%v", sessionNumber), func(t *testing.T) {
session, stdin, stdout := newSession(t, client)
reader := bufio.NewReader(stdout)
stdout, err := session.StdoutPipe()
require.NoError(t, err)
reader := bufio.NewReader(stdout)
err := session.Start(fmt.Sprintf("git-upload-pack %s", testRepo))
require.NoError(t, err)
err = session.Start(fmt.Sprintf("git-upload-pack %s", testRepo))
require.NoError(t, err)
line, err := reader.ReadString('\n')
require.NoError(t, err)
require.Regexp(t, "^[0-9a-f]{44} HEAD.+", line)
line, err := reader.ReadString('\n')
require.NoError(t, err)
require.Regexp(t, "^[0-9a-f]{44} HEAD.+", line)
// Gracefully close connection
_, err = fmt.Fprintln(stdin, "0000")
require.NoError(t, err)
// Gracefully close connection
_, err = fmt.Fprintln(stdin, "0000")
require.NoError(t, err)
output, err := io.ReadAll(stdout)
require.NoError(t, err)
output, err := io.ReadAll(stdout)
require.NoError(t, err)
outputLines := strings.Split(string(output), "\n")
outputLines := strings.Split(string(output), "\n")
for i := 1; i < (len(outputLines) - 1); i++ {
require.Regexp(t, "^[0-9a-f]{44} refs/(heads|tags)/[^ ]+", outputLines[i])
}
for i := 1; i < (len(outputLines) - 1); i++ {
require.Regexp(t, "^[0-9a-f]{44} refs/(heads|tags)/[^ ]+", outputLines[i])
require.Equal(t, "0000", outputLines[len(outputLines)-1])
})
}
require.Equal(t, "0000", outputLines[len(outputLines)-1])
}
func TestGitUploadArchiveSuccess(t *testing.T) {
ensureGitalyRepository(t)
client := runSSHD(t, successAPI(t))
session, err := client.NewSession()
require.NoError(t, err)
defer session.Close()
stdin, err := session.StdinPipe()
require.NoError(t, err)
stdout, err := session.StdoutPipe()
require.NoError(t, err)
session, stdin, stdout := newSession(t, client)
reader := bufio.NewReader(stdout)
err = session.Start(fmt.Sprintf("git-upload-archive %s", testRepo))
err := session.Start(fmt.Sprintf("git-upload-archive %s", testRepo))
require.NoError(t, err)
_, err = fmt.Fprintln(stdin, "0012argument HEAD\n0000")
Loading
Loading
@@ -631,3 +585,20 @@ func TestGitUploadArchiveSuccess(t *testing.T) {
t.Logf("output: %q", output)
require.Equal(t, []byte("0000"), output[len(output)-4:])
}
func newSession(t *testing.T, client *ssh.Client) (*ssh.Session, io.WriteCloser, io.Reader) {
session, err := client.NewSession()
require.NoError(t, err)
stdin, err := session.StdinPipe()
require.NoError(t, err)
stdout, err := session.StdoutPipe()
require.NoError(t, err)
t.Cleanup(func() {
session.Close()
})
return session, stdin, stdout
}
Loading
Loading
@@ -36,7 +36,7 @@ type connection struct {
remoteAddr string
}
type channelHandler func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) (context.Context, error)
type channelHandler func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error
func newConnection(cfg *config.Config, nconn net.Conn) *connection {
maxSessions := cfg.Server.ConcurrentSessionsLimit
Loading
Loading
@@ -50,12 +50,12 @@ func newConnection(cfg *config.Config, nconn net.Conn) *connection {
}
}
func (c *connection) handle(ctx context.Context, srvCfg *ssh.ServerConfig, handler channelHandler) context.Context {
func (c *connection) handle(ctx context.Context, srvCfg *ssh.ServerConfig, handler channelHandler) {
log.WithContextFields(ctx, log.Fields{}).Info("server: handleConn: start")
sconn, chans, err := c.initServerConn(ctx, srvCfg)
if err != nil {
return ctx
return
}
if c.cfg.Server.ClientAliveInterval > 0 {
Loading
Loading
@@ -64,17 +64,10 @@ func (c *connection) handle(ctx context.Context, srvCfg *ssh.ServerConfig, handl
go c.sendKeepAliveMsg(ctx, sconn, ticker)
}
ctxWithLogMetadataChan := make(chan context.Context)
defer close(ctxWithLogMetadataChan)
go c.handleRequests(ctx, sconn, chans, ctxWithLogMetadataChan, handler)
ctxWithLogMetadata := <-ctxWithLogMetadataChan
c.handleRequests(ctx, sconn, chans, handler)
reason := sconn.Wait()
log.WithContextFields(ctx, log.Fields{"reason": reason}).Info("server: handleConn: done")
return ctxWithLogMetadata
}
func (c *connection) initServerConn(ctx context.Context, srvCfg *ssh.ServerConfig) (*ssh.ServerConn, <-chan ssh.NewChannel, error) {
Loading
Loading
@@ -101,7 +94,7 @@ func (c *connection) initServerConn(ctx context.Context, srvCfg *ssh.ServerConfi
return sconn, chans, err
}
func (c *connection) handleRequests(ctx context.Context, sconn *ssh.ServerConn, chans <-chan ssh.NewChannel, ctxWithLogMetadataChan chan<- context.Context, handler channelHandler) {
func (c *connection) handleRequests(ctx context.Context, sconn *ssh.ServerConn, chans <-chan ssh.NewChannel, handler channelHandler) {
ctxlog := log.WithContextFields(ctx, log.Fields{"remote_addr": c.remoteAddr})
for newChannel := range chans {
Loading
Loading
@@ -144,12 +137,10 @@ func (c *connection) handleRequests(ctx context.Context, sconn *ssh.ServerConn,
}()
metrics.SliSshdSessionsTotal.Inc()
ctxWithLogMetadata, err := handler(ctx, sconn, channel, requests)
err := handler(ctx, sconn, channel, requests)
if err != nil {
c.trackError(ctxlog, err)
}
ctxWithLogMetadataChan <- ctxWithLogMetadata
}()
}
Loading
Loading
Loading
Loading
@@ -81,25 +81,23 @@ func (f *fakeConn) SendRequest(name string, wantReply bool, payload []byte) (boo
return true, nil, nil
}
func setup(sessionsNum int64, newChannel *fakeNewChannel) (*connection, chan ssh.NewChannel, chan<- context.Context) {
func setup(sessionsNum int64, newChannel *fakeNewChannel) (*connection, chan ssh.NewChannel) {
cfg := &config.Config{Server: config.ServerConfig{ConcurrentSessionsLimit: sessionsNum}}
conn := &connection{cfg: cfg, concurrentSessions: semaphore.NewWeighted(sessionsNum)}
chans := make(chan ssh.NewChannel, 1)
chans <- newChannel
ctxWithLogMetadataChan := make(chan context.Context)
return conn, chans, ctxWithLogMetadataChan
return conn, chans
}
func TestPanicDuringSessionIsRecovered(t *testing.T) {
newChannel := &fakeNewChannel{channelType: "session"}
conn, chans, ctxWithLogMetadataChan := setup(1, newChannel)
conn, chans := setup(1, newChannel)
numSessions := 0
require.NotPanics(t, func() {
conn.handleRequests(context.Background(), nil, chans, ctxWithLogMetadataChan, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) (context.Context, error) {
conn.handleRequests(context.Background(), nil, chans, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error {
numSessions += 1
close(chans)
panic("This is a panic")
Loading
Loading
@@ -114,10 +112,10 @@ func TestUnknownChannelType(t *testing.T) {
defer close(rejectCh)
newChannel := &fakeNewChannel{channelType: "unknown session", rejectCh: rejectCh}
conn, chans, ctxWithLogMetadataChan := setup(1, newChannel)
conn, chans := setup(1, newChannel)
go func() {
conn.handleRequests(context.Background(), nil, chans, ctxWithLogMetadataChan, nil)
conn.handleRequests(context.Background(), nil, chans, nil)
}()
rejectionData := <-rejectCh
Loading
Loading
@@ -131,15 +129,15 @@ func TestTooManySessions(t *testing.T) {
defer close(rejectCh)
newChannel := &fakeNewChannel{channelType: "session", rejectCh: rejectCh}
conn, chans, ctxWithLogMetadataChan := setup(1, newChannel)
conn, chans := setup(1, newChannel)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() {
conn.handleRequests(context.Background(), nil, chans, ctxWithLogMetadataChan, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) (context.Context, error) {
conn.handleRequests(context.Background(), nil, chans, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error {
<-ctx.Done() // Keep the accepted channel open until the end of the test
return ctx, nil
return nil
})
}()
Loading
Loading
@@ -149,14 +147,14 @@ func TestTooManySessions(t *testing.T) {
func TestAcceptSessionSucceeds(t *testing.T) {
newChannel := &fakeNewChannel{channelType: "session"}
conn, chans, ctxWithLogMetadataChan := setup(1, newChannel)
conn, chans := setup(1, newChannel)
ctx := context.Background()
channelHandled := false
conn.handleRequests(ctx, nil, chans, ctxWithLogMetadataChan, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) (context.Context, error) {
conn.handleRequests(ctx, nil, chans, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error {
channelHandled = true
close(chans)
return ctx, nil
return nil
})
require.True(t, channelHandled)
Loading
Loading
@@ -168,14 +166,14 @@ func TestAcceptSessionFails(t *testing.T) {
acceptErr := errors.New("some failure")
newChannel := &fakeNewChannel{channelType: "session", acceptCh: acceptCh, acceptErr: acceptErr}
conn, chans, ctxWithLogMetadataChan := setup(1, newChannel)
conn, chans := setup(1, newChannel)
ctx := context.Background()
channelHandled := false
go func() {
conn.handleRequests(ctx, nil, chans, ctxWithLogMetadataChan, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) (context.Context, error) {
conn.handleRequests(ctx, nil, chans, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error {
channelHandled = true
return ctx, nil
return nil
})
}()
Loading
Loading
@@ -207,12 +205,12 @@ func TestSessionsMetrics(t *testing.T) {
initialSessionsErrorTotal := testutil.ToFloat64(metrics.SliSshdSessionsErrorsTotal)
newChannel := &fakeNewChannel{channelType: "session"}
conn, chans, ctxWithLogMetadataChan := setup(1, newChannel)
conn, chans := setup(1, newChannel)
ctx := context.Background()
conn.handleRequests(ctx, nil, chans, ctxWithLogMetadataChan, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) (context.Context, error) {
conn.handleRequests(ctx, nil, chans, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error {
close(chans)
return ctx, errors.New("custom error")
return errors.New("custom error")
})
eventuallyInDelta(t, initialSessionsTotal+1, testutil.ToFloat64(metrics.SliSshdSessionsTotal), 0.1)
Loading
Loading
@@ -229,13 +227,13 @@ func TestSessionsMetrics(t *testing.T) {
{"not our ref", grpcstatus.Error(grpccodes.Internal, `rpc error: code = Internal desc = cmd wait: exit status 128, stderr: "fatal: git upload-pack: not our ref 9106d18f6a1b8022f6517f479696f3e3ea5e68c1"`)},
} {
t.Run(ignoredError.desc, func(t *testing.T) {
conn, chans, ctxWithLogMetadataChan := setup(1, newChannel)
conn, chans := setup(1, newChannel)
ignored := ignoredError.err
ctx := context.Background()
conn.handleRequests(ctx, nil, chans, ctxWithLogMetadataChan, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) (context.Context, error) {
conn.handleRequests(ctx, nil, chans, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error {
close(chans)
return ctx, ignored
return ignored
})
eventuallyInDelta(t, initialSessionsTotal+2+float64(i), testutil.ToFloat64(metrics.SliSshdSessionsTotal), 0.1)
Loading
Loading
Loading
Loading
@@ -194,7 +194,8 @@ func (s *Server) handleConn(ctx context.Context, nconn net.Conn) {
started := time.Now()
conn := newConnection(s.Config, nconn)
ctxWithLogMetadata := conn.handle(ctx, s.serverConfig.get(ctx), func(ctx context.Context, sconn *ssh.ServerConn, channel ssh.Channel, requests <-chan *ssh.Request) (context.Context, error) {
var ctxWithLogMetadata context.Context
conn.handle(ctx, s.serverConfig.get(ctx), func(ctx context.Context, sconn *ssh.ServerConn, channel ssh.Channel, requests <-chan *ssh.Request) error {
session := &session{
cfg: s.Config,
channel: channel,
Loading
Loading
@@ -204,7 +205,10 @@ func (s *Server) handleConn(ctx context.Context, nconn net.Conn) {
started: started,
}
return session.handle(ctx, requests)
var err error
ctxWithLogMetadata, err = session.handle(ctx, requests)
return err
})
ctxlog.WithFields(log.Fields{"duration_s": time.Since(started).Seconds(), "meta": extractMetaDataFromContext(ctxWithLogMetadata)}).Info("access: finish")
Loading
Loading
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment