As we reevaluate how to best support and maintain Staging Ref in the future, we encourage development teams using this environment to highlight their use cases in the following issue: https://gitlab.com/gitlab-com/gl-infra/software-delivery/framework/software-delivery-framework-issue-tracker/-/issues/36.

Skip to content
Snippets Groups Projects
Unverified Commit 49bd5e39 authored by Ash McKenzie's avatar Ash McKenzie
Browse files

Pass ctx where needed

parent 5f3b0d3f
No related branches found
No related tags found
No related merge requests found
Showing
with 64 additions and 52 deletions
Loading
Loading
@@ -22,7 +22,7 @@ type Command struct {
ReadWriter *readwriter.ReadWriter
}
func (c *Command) Execute(ctx context.Context) error {
func (c *Command) Execute(ctx context.Context) (context.Context, error) {
ctxlog := log.ContextLogger(ctx)
ctxlog.Debug("twofactorrecover: execute: Waiting for user input")
Loading
Loading
@@ -34,7 +34,7 @@ func (c *Command) Execute(ctx context.Context) error {
fmt.Fprintln(c.ReadWriter.Out, "\nNew recovery codes have *not* been generated. Existing codes will remain valid.")
}
return nil
return ctx, nil
}
func (c *Command) getUserAnswer(ctx context.Context) string {
Loading
Loading
Loading
Loading
@@ -132,7 +132,7 @@ func TestExecute(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{Out: output, In: input},
}
err := cmd.Execute(context.Background())
_, err := cmd.Execute(context.Background())
require.NoError(t, err)
require.Equal(t, tc.expectedOutput, output.String())
Loading
Loading
Loading
Loading
@@ -25,10 +25,10 @@ type Command struct {
ReadWriter *readwriter.ReadWriter
}
func (c *Command) Execute(ctx context.Context) error {
func (c *Command) Execute(ctx context.Context) (context.Context, error) {
client, err := twofactorverify.NewClient(c.Config)
if err != nil {
return err
return ctx, err
}
ctx, cancel := context.WithTimeout(ctx, timeout)
Loading
Loading
@@ -67,7 +67,7 @@ func (c *Command) Execute(ctx context.Context) error {
log.WithContextFields(ctx, log.Fields{"message": message}).Info("Two factor verify command finished")
fmt.Fprintf(c.ReadWriter.Out, "\n%v\n", message)
return nil
return ctx, nil
}
func (c *Command) getOTP(ctx context.Context) (string, error) {
Loading
Loading
Loading
Loading
@@ -160,7 +160,7 @@ func TestExecute(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{Out: output, In: input},
}
err := cmd.Execute(context.Background())
_, err := cmd.Execute(context.Background())
require.NoError(t, err)
require.Equal(t, prompt+"\n"+tc.expectedOutput, output.String())
Loading
Loading
@@ -183,7 +183,10 @@ func TestCanceledContext(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
errCh := make(chan error)
go func() { errCh <- cmd.Execute(ctx) }()
go func() {
_, err := cmd.Execute(ctx)
errCh <- err
}()
cancel()
require.NoError(t, <-errCh)
Loading
Loading
Loading
Loading
@@ -56,7 +56,7 @@ func TestUploadArchive(t *testing.T) {
ctx := correlation.ContextWithCorrelation(context.Background(), "a-correlation-id")
ctx = correlation.ContextWithClientName(ctx, "gitlab-shell-tests")
err := cmd.Execute(ctx)
_, err := cmd.Execute(ctx)
require.NoError(t, err)
require.Equal(t, "UploadArchive: "+repo, output.String())
Loading
Loading
Loading
Loading
@@ -16,19 +16,19 @@ type Command struct {
ReadWriter *readwriter.ReadWriter
}
func (c *Command) Execute(ctx context.Context) error {
func (c *Command) Execute(ctx context.Context) (context.Context, error) {
args := c.Args.SshArgs
if len(args) != 2 {
return disallowedcommand.Error
return ctx, disallowedcommand.Error
}
repo := args[1]
response, err := c.verifyAccess(ctx, repo)
if err != nil {
return err
return ctx, err
}
return c.performGitalyCall(ctx, response)
return ctx, c.performGitalyCall(ctx, response)
}
func (c *Command) verifyAccess(ctx context.Context, repo string) (*accessverifier.Response, error) {
Loading
Loading
Loading
Loading
@@ -26,6 +26,6 @@ func TestForbiddenAccess(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{ErrOut: output, Out: output},
}
err := cmd.Execute(context.Background())
_, err := cmd.Execute(context.Background())
require.Equal(t, "Disallowed by API call", err.Error())
}
Loading
Loading
@@ -57,7 +57,7 @@ func TestUploadPack(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{ErrOut: output, Out: output, In: input},
}
err := cmd.Execute(ctx)
_, err := cmd.Execute(ctx)
require.NoError(t, err)
require.Equal(t, "SSHUploadPackWithSidechannel: "+repo, output.String())
Loading
Loading
Loading
Loading
@@ -17,16 +17,16 @@ type Command struct {
ReadWriter *readwriter.ReadWriter
}
func (c *Command) Execute(ctx context.Context) error {
func (c *Command) Execute(ctx context.Context) (context.Context, error) {
args := c.Args.SshArgs
if len(args) != 2 {
return disallowedcommand.Error
return ctx, disallowedcommand.Error
}
repo := args[1]
response, err := c.verifyAccess(ctx, repo)
if err != nil {
return err
return ctx, err
}
if response.IsCustomAction() {
Loading
Loading
@@ -35,10 +35,10 @@ func (c *Command) Execute(ctx context.Context) error {
ReadWriter: c.ReadWriter,
EOFSent: false,
}
return customAction.Execute(ctx, response)
return ctx, customAction.Execute(ctx, response)
}
return c.performGitalyCall(ctx, response)
return ctx, c.performGitalyCall(ctx, response)
}
func (c *Command) verifyAccess(ctx context.Context, repo string) (*accessverifier.Response, error) {
Loading
Loading
Loading
Loading
@@ -26,6 +26,6 @@ func TestForbiddenAccess(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{ErrOut: output, Out: output},
}
err := cmd.Execute(context.Background())
_, err := cmd.Execute(context.Background())
require.Equal(t, "Disallowed by API call", err.Error())
}
Loading
Loading
@@ -36,7 +36,7 @@ type connection struct {
remoteAddr string
}
type channelHandler func(*ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error
type channelHandler func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) (context.Context, error)
func newConnection(cfg *config.Config, nconn net.Conn) *connection {
maxSessions := cfg.Server.ConcurrentSessionsLimit
Loading
Loading
@@ -94,7 +94,7 @@ func (c *connection) initServerConn(ctx context.Context, srvCfg *ssh.ServerConfi
return sconn, chans, err
}
func (c *connection) handleRequests(ctx context.Context, sconn *ssh.ServerConn, chans <-chan ssh.NewChannel, handler channelHandler) {
func (c *connection) handleRequests(ctx context.Context, sconn *ssh.ServerConn, chans <-chan ssh.NewChannel, handler channelHandler) context.Context {
ctxlog := log.WithContextFields(ctx, log.Fields{"remote_addr": c.remoteAddr})
for newChannel := range chans {
Loading
Loading
@@ -134,7 +134,7 @@ func (c *connection) handleRequests(ctx context.Context, sconn *ssh.ServerConn,
}()
metrics.SliSshdSessionsTotal.Inc()
err := handler(sconn, channel, requests)
ctx, err = handler(ctx, sconn, channel, requests)
if err != nil {
c.trackError(ctxlog, err)
}
Loading
Loading
@@ -148,6 +148,8 @@ func (c *connection) handleRequests(ctx context.Context, sconn *ssh.ServerConn,
ctx, cancel := context.WithTimeout(ctx, EOFTimeout)
defer cancel()
c.concurrentSessions.Acquire(ctx, c.maxSessions)
return ctx
}
func (c *connection) sendKeepAliveMsg(ctx context.Context, sconn *ssh.ServerConn, ticker *time.Ticker) {
Loading
Loading
Loading
Loading
@@ -97,7 +97,7 @@ func TestPanicDuringSessionIsRecovered(t *testing.T) {
numSessions := 0
require.NotPanics(t, func() {
conn.handleRequests(context.Background(), nil, chans, func(*ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error {
conn.handleRequests(context.Background(), nil, chans, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) (context.Context, error) {
numSessions += 1
close(chans)
panic("This is a panic")
Loading
Loading
@@ -135,9 +135,9 @@ func TestTooManySessions(t *testing.T) {
defer cancel()
go func() {
conn.handleRequests(context.Background(), nil, chans, func(*ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error {
conn.handleRequests(context.Background(), nil, chans, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) (context.Context, error) {
<-ctx.Done() // Keep the accepted channel open until the end of the test
return nil
return ctx, nil
})
}()
Loading
Loading
@@ -148,12 +148,13 @@ func TestTooManySessions(t *testing.T) {
func TestAcceptSessionSucceeds(t *testing.T) {
newChannel := &fakeNewChannel{channelType: "session"}
conn, chans := setup(1, newChannel)
ctx := context.Background()
channelHandled := false
conn.handleRequests(context.Background(), nil, chans, func(*ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error {
conn.handleRequests(ctx, nil, chans, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) (context.Context, error) {
channelHandled = true
close(chans)
return nil
return ctx, nil
})
require.True(t, channelHandled)
Loading
Loading
@@ -166,12 +167,13 @@ func TestAcceptSessionFails(t *testing.T) {
acceptErr := errors.New("some failure")
newChannel := &fakeNewChannel{channelType: "session", acceptCh: acceptCh, acceptErr: acceptErr}
conn, chans := setup(1, newChannel)
ctx := context.Background()
channelHandled := false
go func() {
conn.handleRequests(context.Background(), nil, chans, func(*ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error {
conn.handleRequests(ctx, nil, chans, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) (context.Context, error) {
channelHandled = true
return nil
return ctx, nil
})
}()
Loading
Loading
@@ -203,11 +205,12 @@ func TestSessionsMetrics(t *testing.T) {
initialSessionsErrorTotal := testutil.ToFloat64(metrics.SliSshdSessionsErrorsTotal)
newChannel := &fakeNewChannel{channelType: "session"}
conn, chans := setup(1, newChannel)
conn.handleRequests(context.Background(), nil, chans, func(*ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error {
ctx := context.Background()
conn.handleRequests(ctx, nil, chans, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) (context.Context, error) {
close(chans)
return errors.New("custom error")
return ctx, errors.New("custom error")
})
eventuallyInDelta(t, initialSessionsTotal+1, testutil.ToFloat64(metrics.SliSshdSessionsTotal), 0.1)
Loading
Loading
@@ -226,9 +229,11 @@ func TestSessionsMetrics(t *testing.T) {
t.Run(ignoredError.desc, func(t *testing.T) {
conn, chans = setup(1, newChannel)
ignored := ignoredError.err
conn.handleRequests(context.Background(), nil, chans, func(*ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error {
ctx := context.Background()
conn.handleRequests(ctx, nil, chans, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) (context.Context, error) {
close(chans)
return ignored
return ctx, ignored
})
eventuallyInDelta(t, initialSessionsTotal+2+float64(i), testutil.ToFloat64(metrics.SliSshdSessionsTotal), 0.1)
Loading
Loading
Loading
Loading
@@ -49,7 +49,7 @@ type exitStatusReq struct {
ExitStatus uint32
}
func (s *session) handle(ctx context.Context, requests <-chan *ssh.Request) error {
func (s *session) handle(ctx context.Context, requests <-chan *ssh.Request) (context.Context, error) {
ctxlog := log.ContextLogger(ctx)
ctxlog.Debug("session: handle: entering request loop")
Loading
Loading
@@ -70,13 +70,13 @@ func (s *session) handle(ctx context.Context, requests <-chan *ssh.Request) erro
case "exec":
// The command has been executed as `ssh user@host command` or `exec` channel has been used
// in the app implementation
shouldContinue, err = s.handleExec(ctx, req)
ctx, shouldContinue, err = s.handleExec(ctx, req)
case "shell":
// The command has been entered into the shell or `shell` channel has been used
// in the app implementation
shouldContinue = false
var status uint32
status, err = s.handleShell(ctx, req)
ctx, status, err = s.handleShell(ctx, req)
s.exit(ctx, status)
default:
// Ignore unknown requests but don't terminate the session
Loading
Loading
@@ -99,7 +99,7 @@ func (s *session) handle(ctx context.Context, requests <-chan *ssh.Request) erro
ctxlog.Debug("session: handle: exiting request loop")
return err
return ctx, err
}
func (s *session) handleEnv(ctx context.Context, req *ssh.Request) (bool, error) {
Loading
Loading
@@ -132,21 +132,22 @@ func (s *session) handleEnv(ctx context.Context, req *ssh.Request) (bool, error)
return true, nil
}
func (s *session) handleExec(ctx context.Context, req *ssh.Request) (bool, error) {
func (s *session) handleExec(ctx context.Context, req *ssh.Request) (context.Context, bool, error) {
var execRequest execRequest
if err := ssh.Unmarshal(req.Payload, &execRequest); err != nil {
return false, err
return ctx, false, err
}
s.execCmd = execRequest.Command
status, err := s.handleShell(ctx, req)
ctx, status, err := s.handleShell(ctx, req)
s.exit(ctx, status)
return false, err
return ctx, false, err
}
func (s *session) handleShell(ctx context.Context, req *ssh.Request) (uint32, error) {
func (s *session) handleShell(ctx context.Context, req *ssh.Request) (context.Context, uint32, error) {
ctxlog := log.ContextLogger(ctx)
if req.WantReply {
Loading
Loading
@@ -183,7 +184,7 @@ func (s *session) handleShell(ctx context.Context, req *ssh.Request) (uint32, er
s.toStderr(ctx, "ERROR: Failed to parse command: %v\n", err.Error())
}
return 128, err
return ctx, 128, err
}
cmdName := reflect.TypeOf(cmd).String()
Loading
Loading
@@ -194,18 +195,19 @@ func (s *session) handleShell(ctx context.Context, req *ssh.Request) (uint32, er
}).Info("session: handleShell: executing command")
metrics.SshdSessionEstablishedDuration.Observe(establishSessionDuration)
if err := cmd.Execute(ctx); err != nil {
ctx, err = cmd.Execute(ctx)
if err != nil {
grpcStatus := grpcstatus.Convert(err)
if grpcStatus.Code() != grpccodes.Internal {
s.toStderr(ctx, "ERROR: %v\n", grpcStatus.Message())
}
return 1, err
return ctx, 1, err
}
ctxlog.Info("session: handleShell: command executed successfully")
return 0, nil
return ctx, 0, nil
}
func (s *session) toStderr(ctx context.Context, format string, args ...interface{}) {
Loading
Loading
Loading
Loading
@@ -146,7 +146,7 @@ func TestHandleExec(t *testing.T) {
r := &ssh.Request{Payload: tc.payload}
s.channel = f
shouldContinue, err := s.handleExec(context.Background(), r)
_, shouldContinue, err := s.handleExec(context.Background(), r)
require.Equal(t, tc.expectedErr, err)
require.Equal(t, false, shouldContinue)
Loading
Loading
@@ -210,7 +210,7 @@ func TestHandleShell(t *testing.T) {
}
r := &ssh.Request{}
exitCode, err := s.handleShell(context.Background(), r)
_, exitCode, err := s.handleShell(context.Background(), r)
if tc.expectedErrString != "" {
require.Equal(t, tc.expectedErrString, err.Error())
Loading
Loading
Loading
Loading
@@ -193,7 +193,7 @@ func (s *Server) handleConn(ctx context.Context, nconn net.Conn) {
started := time.Now()
conn := newConnection(s.Config, nconn)
conn.handle(ctx, s.serverConfig.get(ctx), func(sconn *ssh.ServerConn, channel ssh.Channel, requests <-chan *ssh.Request) error {
conn.handle(ctx, s.serverConfig.get(ctx), func(ctx context.Context, sconn *ssh.ServerConn, channel ssh.Channel, requests <-chan *ssh.Request) (context.Context, error) {
session := &session{
cfg: s.Config,
channel: channel,
Loading
Loading
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment