As we reevaluate how to best support and maintain Staging Ref in the future, we encourage development teams using this environment to highlight their use cases in the following issue: https://gitlab.com/gitlab-com/gl-infra/software-delivery/framework/software-delivery-framework-issue-tracker/-/issues/36.

Skip to content
Snippets Groups Projects
Unverified Commit 49bd5e39 authored by Ash McKenzie's avatar Ash McKenzie
Browse files

Pass ctx where needed

parent 5f3b0d3f
No related branches found
No related tags found
No related merge requests found
Showing
with 50 additions and 48 deletions
Loading
Loading
@@ -43,7 +43,7 @@ func main() {
ctx, finished := command.Setup(executable.Name, config)
defer finished()
if err = cmd.Execute(ctx); err != nil {
if ctx, err = cmd.Execute(ctx); err != nil {
fmt.Fprintf(readWriter.ErrOut, "%v\n", err)
os.Exit(1)
}
Loading
Loading
Loading
Loading
@@ -46,7 +46,7 @@ func main() {
ctx, finished := command.Setup(executable.Name, config)
defer finished()
if err = cmd.Execute(ctx); err != nil {
if ctx, err = cmd.Execute(ctx); err != nil {
console.DisplayWarningMessage(err.Error(), readWriter.ErrOut)
os.Exit(1)
}
Loading
Loading
Loading
Loading
@@ -46,7 +46,7 @@ func main() {
ctx, finished := command.Setup(executable.Name, config)
defer finished()
if err = cmd.Execute(ctx); err != nil {
if ctx, err = cmd.Execute(ctx); err != nil {
console.DisplayWarningMessage(err.Error(), readWriter.ErrOut)
os.Exit(1)
}
Loading
Loading
Loading
Loading
@@ -76,7 +76,7 @@ func main() {
ctxlog.WithFields(log.Fields{"env": env, "command": cmdName}).Info("gitlab-shell: main: executing command")
fips.Check()
if err := cmd.Execute(ctx); err != nil {
if _, err := cmd.Execute(ctx); err != nil {
ctxlog.WithError(err).Warn("gitlab-shell: main: command execution failed")
if grpcstatus.Convert(err).Code() != grpccodes.Internal {
console.DisplayWarningMessage(err.Error(), readWriter.ErrOut)
Loading
Loading
Loading
Loading
@@ -18,21 +18,21 @@ type Command struct {
ReadWriter *readwriter.ReadWriter
}
func (c *Command) Execute(ctx context.Context) error {
func (c *Command) Execute(ctx context.Context) (context.Context, error) {
// Do and return nothing when the expected and actual user don't match.
// This can happen when the user in sshd_config doesn't match the user
// trying to login. When nothing is printed, the user will be denied access.
if c.Args.ExpectedUser != c.Args.ActualUser {
// TODO: Log this event once we have a consistent way to log in Go.
// See https://gitlab.com/gitlab-org/gitlab-shell/issues/192 for more info.
return nil
return ctx, nil
}
if err := c.printKeyLine(ctx); err != nil {
return err
return ctx, err
}
return nil
return ctx, nil
}
func (c *Command) printKeyLine(ctx context.Context) error {
Loading
Loading
Loading
Loading
@@ -84,7 +84,7 @@ func TestExecute(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{Out: buffer},
}
err := cmd.Execute(context.Background())
_, err := cmd.Execute(context.Background())
require.NoError(t, err)
require.Equal(t, tc.expectedOutput, buffer.String())
Loading
Loading
Loading
Loading
@@ -16,12 +16,12 @@ type Command struct {
ReadWriter *readwriter.ReadWriter
}
func (c *Command) Execute(ctx context.Context) error {
func (c *Command) Execute(ctx context.Context) (context.Context, error) {
if err := c.printPrincipalLines(); err != nil {
return err
return ctx, err
}
return nil
return ctx, nil
}
func (c *Command) printPrincipalLines() error {
Loading
Loading
Loading
Loading
@@ -42,7 +42,7 @@ func TestExecute(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{Out: buffer},
}
err := cmd.Execute(context.Background())
_, err := cmd.Execute(context.Background())
require.NoError(t, err)
require.Equal(t, tc.expectedOutput, buffer.String())
Loading
Loading
Loading
Loading
@@ -9,7 +9,7 @@ import (
)
type Command interface {
Execute(ctx context.Context) error
Execute(ctx context.Context) (context.Context, error)
}
// Setup() initializes tracing from the configuration file and generates a
Loading
Loading
Loading
Loading
@@ -16,10 +16,10 @@ type Command struct {
ReadWriter *readwriter.ReadWriter
}
func (c *Command) Execute(ctx context.Context) error {
func (c *Command) Execute(ctx context.Context) (context.Context, error) {
response, err := c.getUserInfo(ctx)
if err != nil {
return fmt.Errorf("Failed to get username: %v", err)
return ctx, fmt.Errorf("Failed to get username: %v", err)
}
if response.IsAnonymous() {
Loading
Loading
@@ -28,7 +28,7 @@ func (c *Command) Execute(ctx context.Context) error {
fmt.Fprintf(c.ReadWriter.Out, "Welcome to GitLab, @%s!\n", response.Username)
}
return nil
return ctx, nil
}
func (c *Command) getUserInfo(ctx context.Context) (*discover.Response, error) {
Loading
Loading
Loading
Loading
@@ -81,7 +81,7 @@ func TestExecute(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{Out: buffer},
}
err := cmd.Execute(context.Background())
_, err := cmd.Execute(context.Background())
require.NoError(t, err)
require.Equal(t, tc.expectedOutput, buffer.String())
Loading
Loading
@@ -123,7 +123,7 @@ func TestFailingExecute(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{Out: buffer},
}
err := cmd.Execute(context.Background())
_, err := cmd.Execute(context.Background())
require.Empty(t, buffer.String())
require.EqualError(t, err, tc.expectedError)
Loading
Loading
Loading
Loading
@@ -19,20 +19,20 @@ type Command struct {
ReadWriter *readwriter.ReadWriter
}
func (c *Command) Execute(ctx context.Context) error {
func (c *Command) Execute(ctx context.Context) (context.Context, error) {
response, err := c.runCheck(ctx)
if err != nil {
return fmt.Errorf("%v: FAILED - %v", apiMessage, err)
return ctx, fmt.Errorf("%v: FAILED - %v", apiMessage, err)
}
fmt.Fprintf(c.ReadWriter.Out, "%v: OK\n", apiMessage)
if !response.Redis {
return fmt.Errorf("%v: FAILED", redisMessage)
return ctx, fmt.Errorf("%v: FAILED", redisMessage)
}
fmt.Fprintf(c.ReadWriter.Out, "%v: OK\n", redisMessage)
return nil
return ctx, nil
}
func (c *Command) runCheck(ctx context.Context) (*healthcheck.Response, error) {
Loading
Loading
Loading
Loading
@@ -53,7 +53,7 @@ func TestExecute(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{Out: buffer},
}
err := cmd.Execute(context.Background())
_, err := cmd.Execute(context.Background())
require.NoError(t, err)
require.Equal(t, "Internal API available: OK\nRedis available via internal API: OK\n", buffer.String())
Loading
Loading
@@ -68,7 +68,7 @@ func TestFailingRedisExecute(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{Out: buffer},
}
err := cmd.Execute(context.Background())
_, err := cmd.Execute(context.Background())
require.Error(t, err, "Redis available via internal API: FAILED")
require.Equal(t, "Internal API available: OK\n", buffer.String())
}
Loading
Loading
@@ -82,7 +82,7 @@ func TestFailingAPIExecute(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{Out: buffer},
}
err := cmd.Execute(context.Background())
_, err := cmd.Execute(context.Background())
require.Empty(t, buffer.String())
require.EqualError(t, err, "Internal API available: FAILED - Internal API unreachable")
}
Loading
Loading
@@ -37,10 +37,10 @@ type Payload struct {
ExpiresIn int `json:"expires_in,omitempty"`
}
func (c *Command) Execute(ctx context.Context) error {
func (c *Command) Execute(ctx context.Context) (context.Context, error) {
args := c.Args.SshArgs
if len(args) < 3 {
return disallowedcommand.Error
return ctx, disallowedcommand.Error
}
// e.g. git-lfs-authenticate user/repo.git download
Loading
Loading
@@ -49,12 +49,12 @@ func (c *Command) Execute(ctx context.Context) error {
action, err := actionFromOperation(operation)
if err != nil {
return err
return ctx, err
}
accessResponse, err := c.verifyAccess(ctx, action, repo)
if err != nil {
return err
return ctx, err
}
payload, err := c.authenticate(ctx, operation, repo, accessResponse.UserId)
Loading
Loading
@@ -65,12 +65,12 @@ func (c *Command) Execute(ctx context.Context) error {
log.Fields{"operation": operation, "repo": repo, "user_id": accessResponse.UserId},
).WithError(err).Debug("lfsauthenticate: execute: LFS authentication failed")
return nil
return ctx, nil
}
fmt.Fprintf(c.ReadWriter.Out, "%s\n", payload)
return nil
return ctx, nil
}
func actionFromOperation(operation string) (commandargs.CommandType, error) {
Loading
Loading
Loading
Loading
@@ -54,7 +54,7 @@ func TestFailedRequests(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{ErrOut: output, Out: output},
}
err := cmd.Execute(context.Background())
_, err := cmd.Execute(context.Background())
require.Error(t, err)
require.Equal(t, tc.expectedOutput, err.Error())
Loading
Loading
@@ -145,7 +145,7 @@ func TestLfsAuthenticateRequests(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{ErrOut: output, Out: output},
}
err := cmd.Execute(context.Background())
_, err := cmd.Execute(context.Background())
require.NoError(t, err)
require.Equal(t, tc.expectedOutput, output.String())
Loading
Loading
Loading
Loading
@@ -34,10 +34,10 @@ type tokenArgs struct {
ExpiresDate string // Calculated, a TTL is passed from command-line.
}
func (c *Command) Execute(ctx context.Context) error {
func (c *Command) Execute(ctx context.Context) (context.Context, error) {
err := c.parseTokenArgs()
if err != nil {
return err
return ctx, err
}
log.WithContextFields(ctx, log.Fields{
Loading
Loading
@@ -46,13 +46,14 @@ func (c *Command) Execute(ctx context.Context) error {
response, err := c.getPersonalAccessToken(ctx)
if err != nil {
return err
return ctx, err
}
fmt.Fprint(c.ReadWriter.Out, "Token: "+response.Token+"\n")
fmt.Fprint(c.ReadWriter.Out, "Scopes: "+strings.Join(response.Scopes, ",")+"\n")
fmt.Fprint(c.ReadWriter.Out, "Expires: "+response.ExpiresAt+"\n")
return nil
return ctx, nil
}
func (c *Command) parseTokenArgs() error {
Loading
Loading
Loading
Loading
@@ -167,7 +167,7 @@ func TestExecute(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{Out: output, In: input},
}
err := cmd.Execute(context.Background())
_, err := cmd.Execute(context.Background())
if tc.expectedError == "" {
require.NoError(t, err)
Loading
Loading
Loading
Loading
@@ -72,7 +72,7 @@ func TestReceivePack(t *testing.T) {
ctx := correlation.ContextWithCorrelation(context.Background(), "a-correlation-id")
ctx = correlation.ContextWithClientName(ctx, "gitlab-shell-tests")
err := cmd.Execute(ctx)
_, err := cmd.Execute(ctx)
require.NoError(t, err)
if tc.username != "" {
Loading
Loading
Loading
Loading
@@ -18,16 +18,16 @@ type Command struct {
ReadWriter *readwriter.ReadWriter
}
func (c *Command) Execute(ctx context.Context) error {
func (c *Command) Execute(ctx context.Context) (context.Context, error) {
args := c.Args.SshArgs
if len(args) != 2 {
return disallowedcommand.Error
return ctx, disallowedcommand.Error
}
repo := args[1]
response, err := c.verifyAccess(ctx, repo)
if err != nil {
return err
return ctx, err
}
if response.IsCustomAction() {
Loading
Loading
@@ -42,7 +42,7 @@ func (c *Command) Execute(ctx context.Context) error {
Response: response,
}
return cmd.Execute(ctx)
return ctx, cmd.Execute(ctx)
}
customAction := customaction.Command{
Loading
Loading
@@ -50,10 +50,10 @@ func (c *Command) Execute(ctx context.Context) error {
ReadWriter: c.ReadWriter,
EOFSent: true,
}
return customAction.Execute(ctx, response)
return ctx, customAction.Execute(ctx, response)
}
return c.performGitalyCall(ctx, response)
return ctx, c.performGitalyCall(ctx, response)
}
func (c *Command) verifyAccess(ctx context.Context, repo string) (*accessverifier.Response, error) {
Loading
Loading
Loading
Loading
@@ -18,14 +18,15 @@ func TestForbiddenAccess(t *testing.T) {
requests := requesthandlers.BuildDisallowedByApiHandlers(t)
cmd, _ := setup(t, "disallowed", requests)
err := cmd.Execute(context.Background())
_, err := cmd.Execute(context.Background())
require.Equal(t, "Disallowed by API call", err.Error())
}
func TestCustomReceivePack(t *testing.T) {
cmd, output := setup(t, "1", requesthandlers.BuildAllowedWithCustomActionsHandlers(t))
require.NoError(t, cmd.Execute(context.Background()))
_, err := cmd.Execute(context.Background())
require.NoError(t, err)
require.Equal(t, "customoutput", output.String())
}
Loading
Loading
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment