As we reevaluate how to best support and maintain Staging Ref in the future, we encourage development teams using this environment to highlight their use cases in the following issue: https://gitlab.com/gitlab-com/gl-infra/software-delivery/framework/software-delivery-framework-issue-tracker/-/issues/36.

Skip to content
Snippets Groups Projects
Commit 1a2bfecd authored by Ash McKenzie's avatar Ash McKenzie
Browse files

Merge branch 'sh-extract-context-from-env' into 'master'

Make it possible to propagate correlation ID across processes

Closes #474

See merge request gitlab-org/gitlab-shell!413
parents f100e7e8 a487572a
No related branches found
No related tags found
No related merge requests found
Showing
with 178 additions and 66 deletions
package client
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
Loading
Loading
@@ -78,7 +79,7 @@ func TestClients(t *testing.T) {
func testSuccessfulGet(t *testing.T, client *GitlabNetClient) {
t.Run("Successful get", func(t *testing.T) {
hook := testhelper.SetupLogger()
response, err := client.Get("/hello")
response, err := client.Get(context.Background(), "/hello")
require.NoError(t, err)
require.NotNil(t, response)
Loading
Loading
@@ -104,7 +105,7 @@ func testSuccessfulPost(t *testing.T, client *GitlabNetClient) {
hook := testhelper.SetupLogger()
data := map[string]string{"key": "value"}
response, err := client.Post("/post_endpoint", data)
response, err := client.Post(context.Background(), "/post_endpoint", data)
require.NoError(t, err)
require.NotNil(t, response)
Loading
Loading
@@ -128,7 +129,7 @@ func testSuccessfulPost(t *testing.T, client *GitlabNetClient) {
func testMissing(t *testing.T, client *GitlabNetClient) {
t.Run("Missing error for GET", func(t *testing.T) {
hook := testhelper.SetupLogger()
response, err := client.Get("/missing")
response, err := client.Get(context.Background(), "/missing")
assert.EqualError(t, err, "Internal API error (404)")
assert.Nil(t, response)
Loading
Loading
@@ -144,7 +145,7 @@ func testMissing(t *testing.T, client *GitlabNetClient) {
t.Run("Missing error for POST", func(t *testing.T) {
hook := testhelper.SetupLogger()
response, err := client.Post("/missing", map[string]string{})
response, err := client.Post(context.Background(), "/missing", map[string]string{})
assert.EqualError(t, err, "Internal API error (404)")
assert.Nil(t, response)
Loading
Loading
@@ -161,13 +162,13 @@ func testMissing(t *testing.T, client *GitlabNetClient) {
func testErrorMessage(t *testing.T, client *GitlabNetClient) {
t.Run("Error with message for GET", func(t *testing.T) {
response, err := client.Get("/error")
response, err := client.Get(context.Background(), "/error")
assert.EqualError(t, err, "Don't do that")
assert.Nil(t, response)
})
t.Run("Error with message for POST", func(t *testing.T) {
response, err := client.Post("/error", map[string]string{})
response, err := client.Post(context.Background(), "/error", map[string]string{})
assert.EqualError(t, err, "Don't do that")
assert.Nil(t, response)
})
Loading
Loading
@@ -177,7 +178,7 @@ func testBrokenRequest(t *testing.T, client *GitlabNetClient) {
t.Run("Broken request for GET", func(t *testing.T) {
hook := testhelper.SetupLogger()
response, err := client.Get("/broken")
response, err := client.Get(context.Background(), "/broken")
assert.EqualError(t, err, "Internal API unreachable")
assert.Nil(t, response)
Loading
Loading
@@ -194,7 +195,7 @@ func testBrokenRequest(t *testing.T, client *GitlabNetClient) {
t.Run("Broken request for POST", func(t *testing.T) {
hook := testhelper.SetupLogger()
response, err := client.Post("/broken", map[string]string{})
response, err := client.Post(context.Background(), "/broken", map[string]string{})
assert.EqualError(t, err, "Internal API unreachable")
assert.Nil(t, response)
Loading
Loading
@@ -211,7 +212,7 @@ func testBrokenRequest(t *testing.T, client *GitlabNetClient) {
func testAuthenticationHeader(t *testing.T, client *GitlabNetClient) {
t.Run("Authentication headers for GET", func(t *testing.T) {
response, err := client.Get("/auth")
response, err := client.Get(context.Background(), "/auth")
require.NoError(t, err)
require.NotNil(t, response)
Loading
Loading
@@ -226,7 +227,7 @@ func testAuthenticationHeader(t *testing.T, client *GitlabNetClient) {
})
t.Run("Authentication headers for POST", func(t *testing.T) {
response, err := client.Post("/auth", map[string]string{})
response, err := client.Post(context.Background(), "/auth", map[string]string{})
require.NoError(t, err)
require.NotNil(t, response)
Loading
Loading
Loading
Loading
@@ -11,8 +11,9 @@ import (
"strings"
"time"
log "github.com/sirupsen/logrus"
"gitlab.com/gitlab-org/labkit/correlation"
log "github.com/sirupsen/logrus"
)
const (
Loading
Loading
@@ -59,7 +60,7 @@ func normalizePath(path string) string {
return path
}
func newRequest(method, host, path string, data interface{}) (*http.Request, string, error) {
func newRequest(ctx context.Context, method, host, path string, data interface{}) (*http.Request, string, error) {
var jsonReader io.Reader
if data != nil {
jsonData, err := json.Marshal(data)
Loading
Loading
@@ -70,20 +71,13 @@ func newRequest(method, host, path string, data interface{}) (*http.Request, str
jsonReader = bytes.NewReader(jsonData)
}
correlationID, err := correlation.RandomID()
ctx := context.Background()
if err != nil {
log.WithError(err).Warn("unable to generate correlation ID")
} else {
ctx = correlation.ContextWithCorrelation(ctx, correlationID)
}
request, err := http.NewRequestWithContext(ctx, method, host+path, jsonReader)
if err != nil {
return nil, "", err
}
correlationID := correlation.ExtractFromContext(ctx)
return request, correlationID, nil
}
Loading
Loading
@@ -102,16 +96,16 @@ func parseError(resp *http.Response) error {
}
func (c *GitlabNetClient) Get(path string) (*http.Response, error) {
return c.DoRequest(http.MethodGet, normalizePath(path), nil)
func (c *GitlabNetClient) Get(ctx context.Context, path string) (*http.Response, error) {
return c.DoRequest(ctx, http.MethodGet, normalizePath(path), nil)
}
func (c *GitlabNetClient) Post(path string, data interface{}) (*http.Response, error) {
return c.DoRequest(http.MethodPost, normalizePath(path), data)
func (c *GitlabNetClient) Post(ctx context.Context, path string, data interface{}) (*http.Response, error) {
return c.DoRequest(ctx, http.MethodPost, normalizePath(path), data)
}
func (c *GitlabNetClient) DoRequest(method, path string, data interface{}) (*http.Response, error) {
request, correlationID, err := newRequest(method, c.httpClient.Host, path, data)
func (c *GitlabNetClient) DoRequest(ctx context.Context, method, path string, data interface{}) (*http.Response, error) {
request, correlationID, err := newRequest(ctx, method, c.httpClient.Host, path, data)
if err != nil {
return nil, err
}
Loading
Loading
package client
import (
"context"
"encoding/base64"
"fmt"
"io/ioutil"
Loading
Loading
@@ -51,11 +52,11 @@ func TestBasicAuthSettings(t *testing.T) {
client, cleanup := setup(t, username, password, requests)
defer cleanup()
response, err := client.Get("/get_endpoint")
response, err := client.Get(context.Background(), "/get_endpoint")
require.NoError(t, err)
testBasicAuthHeaders(t, response)
response, err = client.Post("/post_endpoint", nil)
response, err = client.Post(context.Background(), "/post_endpoint", nil)
require.NoError(t, err)
testBasicAuthHeaders(t, response)
}
Loading
Loading
@@ -89,7 +90,7 @@ func TestEmptyBasicAuthSettings(t *testing.T) {
client, cleanup := setup(t, "", "", requests)
defer cleanup()
_, err := client.Get("/empty_basic_auth")
_, err := client.Get(context.Background(), "/empty_basic_auth")
require.NoError(t, err)
}
Loading
Loading
package client
import (
"context"
"fmt"
"io/ioutil"
"net/http"
Loading
Loading
@@ -43,7 +44,7 @@ func TestSuccessfulRequests(t *testing.T) {
client, cleanup := setupWithRequests(t, tc.caFile, tc.caPath, tc.selfSigned)
defer cleanup()
response, err := client.Get("/hello")
response, err := client.Get(context.Background(), "/hello")
require.NoError(t, err)
require.NotNil(t, response)
Loading
Loading
@@ -80,7 +81,7 @@ func TestFailedRequests(t *testing.T) {
client, cleanup := setupWithRequests(t, tc.caFile, tc.caPath, false)
defer cleanup()
_, err := client.Get("/hello")
_, err := client.Get(context.Background(), "/hello")
require.Error(t, err)
assert.Equal(t, err.Error(), "Internal API unreachable")
Loading
Loading
Loading
Loading
@@ -38,7 +38,10 @@ func main() {
os.Exit(1)
}
if err = cmd.Execute(); err != nil {
ctx, finished := command.ContextWithCorrelationID()
defer finished()
if err = cmd.Execute(ctx); err != nil {
fmt.Fprintf(readWriter.ErrOut, "%v\n", err)
os.Exit(1)
}
Loading
Loading
Loading
Loading
@@ -41,7 +41,10 @@ func main() {
os.Exit(1)
}
if err = cmd.Execute(); err != nil {
ctx, finished := command.ContextWithCorrelationID()
defer finished()
if err = cmd.Execute(ctx); err != nil {
console.DisplayWarningMessage(err.Error(), readWriter.ErrOut)
os.Exit(1)
}
Loading
Loading
Loading
Loading
@@ -41,7 +41,10 @@ func main() {
os.Exit(1)
}
if err = cmd.Execute(); err != nil {
ctx, finished := command.ContextWithCorrelationID()
defer finished()
if err = cmd.Execute(ctx); err != nil {
console.DisplayWarningMessage(err.Error(), readWriter.ErrOut)
os.Exit(1)
}
Loading
Loading
Loading
Loading
@@ -41,7 +41,10 @@ func main() {
os.Exit(1)
}
if err = cmd.Execute(); err != nil {
ctx, finished := command.ContextWithCorrelationID()
defer finished()
if err = cmd.Execute(ctx); err != nil {
console.DisplayWarningMessage(err.Error(), readWriter.ErrOut)
os.Exit(1)
}
Loading
Loading
package authorizedkeys
import (
"context"
"fmt"
"strconv"
Loading
Loading
@@ -17,7 +18,7 @@ type Command struct {
ReadWriter *readwriter.ReadWriter
}
func (c *Command) Execute() error {
func (c *Command) Execute(ctx context.Context) error {
// Do and return nothing when the expected and actual user don't match.
// This can happen when the user in sshd_config doesn't match the user
// trying to login. When nothing is printed, the user will be denied access.
Loading
Loading
@@ -27,15 +28,15 @@ func (c *Command) Execute() error {
return nil
}
if err := c.printKeyLine(); err != nil {
if err := c.printKeyLine(ctx); err != nil {
return err
}
return nil
}
func (c *Command) printKeyLine() error {
response, err := c.getAuthorizedKey()
func (c *Command) printKeyLine(ctx context.Context) error {
response, err := c.getAuthorizedKey(ctx)
if err != nil {
fmt.Fprintln(c.ReadWriter.Out, fmt.Sprintf("# No key was found for %s", c.Args.Key))
return nil
Loading
Loading
@@ -51,11 +52,11 @@ func (c *Command) printKeyLine() error {
return nil
}
func (c *Command) getAuthorizedKey() (*authorizedkeys.Response, error) {
func (c *Command) getAuthorizedKey(ctx context.Context) (*authorizedkeys.Response, error) {
client, err := authorizedkeys.NewClient(c.Config)
if err != nil {
return nil, err
}
return client.GetByKey(c.Args.Key)
return client.GetByKey(ctx, c.Args.Key)
}
Loading
Loading
@@ -2,6 +2,7 @@ package authorizedkeys
import (
"bytes"
"context"
"encoding/json"
"net/http"
"testing"
Loading
Loading
@@ -97,7 +98,7 @@ func TestExecute(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{Out: buffer},
}
err := cmd.Execute()
err := cmd.Execute(context.Background())
require.NoError(t, err)
require.Equal(t, tc.expectedOutput, buffer.String())
Loading
Loading
package authorizedprincipals
import (
"context"
"fmt"
"gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
Loading
Loading
@@ -15,7 +16,7 @@ type Command struct {
ReadWriter *readwriter.ReadWriter
}
func (c *Command) Execute() error {
func (c *Command) Execute(ctx context.Context) error {
if err := c.printPrincipalLines(); err != nil {
return err
}
Loading
Loading
Loading
Loading
@@ -2,6 +2,7 @@ package authorizedprincipals
import (
"bytes"
"context"
"testing"
"github.com/stretchr/testify/require"
Loading
Loading
@@ -54,7 +55,7 @@ func TestExecute(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{Out: buffer},
}
err := cmd.Execute()
err := cmd.Execute(context.Background())
require.NoError(t, err)
require.Equal(t, tc.expectedOutput, buffer.String())
Loading
Loading
package command
import (
"context"
"gitlab.com/gitlab-org/gitlab-shell/internal/command/authorizedkeys"
"gitlab.com/gitlab-org/gitlab-shell/internal/command/authorizedprincipals"
"gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
Loading
Loading
@@ -16,10 +18,13 @@ import (
"gitlab.com/gitlab-org/gitlab-shell/internal/command/uploadpack"
"gitlab.com/gitlab-org/gitlab-shell/internal/config"
"gitlab.com/gitlab-org/gitlab-shell/internal/executable"
"gitlab.com/gitlab-org/labkit/correlation"
"gitlab.com/gitlab-org/labkit/log"
"gitlab.com/gitlab-org/labkit/tracing"
)
type Command interface {
Execute() error
Execute(ctx context.Context) error
}
func New(e *executable.Executable, arguments []string, config *config.Config, readWriter *readwriter.ReadWriter) (Command, error) {
Loading
Loading
@@ -35,6 +40,28 @@ func New(e *executable.Executable, arguments []string, config *config.Config, re
return nil, disallowedcommand.Error
}
// ContextWithCorrelationID() will always return a background Context
// with a correlation ID. It will first attempt to extract the ID from
// an environment variable. If is not available, a random one will be
// generated.
func ContextWithCorrelationID() (context.Context, func()) {
ctx, finished := tracing.ExtractFromEnv(context.Background())
defer finished()
correlationID := correlation.ExtractFromContext(ctx)
if correlationID == "" {
correlationID, err := correlation.RandomID()
if err != nil {
log.WithError(err).Warn("unable to generate correlation ID")
} else {
log.Info("generated random correlation ID")
ctx = correlation.ContextWithCorrelation(ctx, correlationID)
}
}
return ctx, finished
}
func buildCommand(e *executable.Executable, args commandargs.CommandArgs, config *config.Config, readWriter *readwriter.ReadWriter) Command {
switch e.Name {
case executable.GitlabShell:
Loading
Loading
Loading
Loading
@@ -2,6 +2,7 @@ package command
import (
"errors"
"os"
"testing"
"github.com/stretchr/testify/require"
Loading
Loading
@@ -20,6 +21,7 @@ import (
"gitlab.com/gitlab-org/gitlab-shell/internal/config"
"gitlab.com/gitlab-org/gitlab-shell/internal/executable"
"gitlab.com/gitlab-org/gitlab-shell/internal/testhelper"
"gitlab.com/gitlab-org/labkit/correlation"
)
var (
Loading
Loading
@@ -151,3 +153,67 @@ func TestFailingNew(t *testing.T) {
})
}
}
func TestContextWithCorrelationID(t *testing.T) {
testCases := []struct {
name string
additionalEnv map[string]string
expectedCorrelationID string
}{
{
name: "no CORRELATION_ID in environment",
},
{
name: "CORRELATION_ID in environment",
additionalEnv: map[string]string{
"CORRELATION_ID": "abc123",
},
expectedCorrelationID: "abc123",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
resetEnvironment := addAdditionalEnv(tc.additionalEnv)
defer resetEnvironment()
ctx, finished := ContextWithCorrelationID()
require.NotNil(t, ctx, "ctx is nil")
require.NotNil(t, finished, "finished is nil")
correlationID := correlation.ExtractFromContext(ctx)
require.NotEmpty(t, correlationID)
if tc.expectedCorrelationID != "" {
require.Equal(t, tc.expectedCorrelationID, correlationID)
}
defer finished()
})
}
}
// addAdditionalEnv will configure additional environment values
// and return a deferrable function to reset the environment to
// it's original state after the test
func addAdditionalEnv(envMap map[string]string) func() {
prevValues := map[string]string{}
unsetValues := []string{}
for k, v := range envMap {
value, exists := os.LookupEnv(k)
if exists {
prevValues[k] = value
} else {
unsetValues = append(unsetValues, k)
}
os.Setenv(k, v)
}
return func() {
for k, v := range prevValues {
os.Setenv(k, v)
}
for _, k := range unsetValues {
os.Unsetenv(k)
}
}
}
package discover
import (
"context"
"fmt"
"gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
Loading
Loading
@@ -15,8 +16,8 @@ type Command struct {
ReadWriter *readwriter.ReadWriter
}
func (c *Command) Execute() error {
response, err := c.getUserInfo()
func (c *Command) Execute(ctx context.Context) error {
response, err := c.getUserInfo(ctx)
if err != nil {
return fmt.Errorf("Failed to get username: %v", err)
}
Loading
Loading
@@ -30,11 +31,11 @@ func (c *Command) Execute() error {
return nil
}
func (c *Command) getUserInfo() (*discover.Response, error) {
func (c *Command) getUserInfo(ctx context.Context) (*discover.Response, error) {
client, err := discover.NewClient(c.Config)
if err != nil {
return nil, err
}
return client.GetByCommandArgs(c.Args)
return client.GetByCommandArgs(ctx, c.Args)
}
Loading
Loading
@@ -2,6 +2,7 @@ package discover
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
Loading
Loading
@@ -83,7 +84,7 @@ func TestExecute(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{Out: buffer},
}
err := cmd.Execute()
err := cmd.Execute(context.Background())
require.NoError(t, err)
require.Equal(t, tc.expectedOutput, buffer.String())
Loading
Loading
@@ -126,7 +127,7 @@ func TestFailingExecute(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{Out: buffer},
}
err := cmd.Execute()
err := cmd.Execute(context.Background())
require.Empty(t, buffer.String())
require.EqualError(t, err, tc.expectedError)
Loading
Loading
package healthcheck
import (
"context"
"fmt"
"gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter"
Loading
Loading
@@ -18,8 +19,8 @@ type Command struct {
ReadWriter *readwriter.ReadWriter
}
func (c *Command) Execute() error {
response, err := c.runCheck()
func (c *Command) Execute(ctx context.Context) error {
response, err := c.runCheck(ctx)
if err != nil {
return fmt.Errorf("%v: FAILED - %v", apiMessage, err)
}
Loading
Loading
@@ -34,13 +35,13 @@ func (c *Command) Execute() error {
return nil
}
func (c *Command) runCheck() (*healthcheck.Response, error) {
func (c *Command) runCheck(ctx context.Context) (*healthcheck.Response, error) {
client, err := healthcheck.NewClient(c.Config)
if err != nil {
return nil, err
}
response, err := client.Check()
response, err := client.Check(ctx)
if err != nil {
return nil, err
}
Loading
Loading
Loading
Loading
@@ -2,6 +2,7 @@ package healthcheck
import (
"bytes"
"context"
"encoding/json"
"net/http"
"testing"
Loading
Loading
@@ -53,7 +54,7 @@ func TestExecute(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{Out: buffer},
}
err := cmd.Execute()
err := cmd.Execute(context.Background())
require.NoError(t, err)
require.Equal(t, "Internal API available: OK\nRedis available via internal API: OK\n", buffer.String())
Loading
Loading
@@ -69,7 +70,7 @@ func TestFailingRedisExecute(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{Out: buffer},
}
err := cmd.Execute()
err := cmd.Execute(context.Background())
require.Error(t, err, "Redis available via internal API: FAILED")
require.Equal(t, "Internal API available: OK\n", buffer.String())
}
Loading
Loading
@@ -84,7 +85,7 @@ func TestFailingAPIExecute(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{Out: buffer},
}
err := cmd.Execute()
err := cmd.Execute(context.Background())
require.Empty(t, buffer.String())
require.EqualError(t, err, "Internal API available: FAILED - Internal API error (500)")
}
package lfsauthenticate
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
Loading
Loading
@@ -34,7 +35,7 @@ type Payload struct {
ExpiresIn int `json:"expires_in,omitempty"`
}
func (c *Command) Execute() error {
func (c *Command) Execute(ctx context.Context) error {
args := c.Args.SshArgs
if len(args) < 3 {
return disallowedcommand.Error
Loading
Loading
@@ -49,12 +50,12 @@ func (c *Command) Execute() error {
return err
}
accessResponse, err := c.verifyAccess(action, repo)
accessResponse, err := c.verifyAccess(ctx, action, repo)
if err != nil {
return err
}
payload, err := c.authenticate(operation, repo, accessResponse.UserId)
payload, err := c.authenticate(ctx, operation, repo, accessResponse.UserId)
if err != nil {
// return nothing just like Ruby's GitlabShell#lfs_authenticate does
return nil
Loading
Loading
@@ -80,19 +81,19 @@ func actionFromOperation(operation string) (commandargs.CommandType, error) {
return action, nil
}
func (c *Command) verifyAccess(action commandargs.CommandType, repo string) (*accessverifier.Response, error) {
func (c *Command) verifyAccess(ctx context.Context, action commandargs.CommandType, repo string) (*accessverifier.Response, error) {
cmd := accessverifier.Command{c.Config, c.Args, c.ReadWriter}
return cmd.Verify(action, repo)
return cmd.Verify(ctx, action, repo)
}
func (c *Command) authenticate(operation string, repo, userId string) ([]byte, error) {
func (c *Command) authenticate(ctx context.Context, operation string, repo, userId string) ([]byte, error) {
client, err := lfsauthenticate.NewClient(c.Config, c.Args)
if err != nil {
return nil, err
}
response, err := client.Authenticate(operation, repo, userId)
response, err := client.Authenticate(ctx, operation, repo, userId)
if err != nil {
return nil, err
}
Loading
Loading
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment