As we reevaluate how to best support and maintain Staging Ref in the future, we encourage development teams using this environment to highlight their use cases in the following issue: https://gitlab.com/gitlab-com/gl-infra/software-delivery/framework/software-delivery-framework-issue-tracker/-/issues/36.

Skip to content
Snippets Groups Projects
Commit a487572a authored by Stan Hu's avatar Stan Hu
Browse files

Make it possible to propagate correlation ID across processes

Previously, gitlab-shell did not pass a context through the application.
Correlation IDs were generated down the call stack instead of passed
around from the start execution.

This has several potential downsides:

1. It's easier for programming mistakes to be made in future that lead
to multiple correlation IDs being generated for a single request.
2. Correlation IDs cannot be passed in from upstream requests
3. Other advantages of context passing, such as distributed tracing is
not possible.

This commit changes the behavior:

1. Extract the correlation ID from the environment at the start of
the application.
2. If no correlation ID exists, generate a random one.
3. Pass the correlation ID to the GitLabNet API requests.

This change also enables other clients of GitLabNet (e.g. Gitaly) to
pass along the correlation ID in the internal API requests
(https://gitlab.com/gitlab-org/gitaly/-/issues/2725).

Fixes https://gitlab.com/gitlab-org/gitlab-shell/-/issues/474
parent f100e7e8
No related branches found
No related tags found
No related merge requests found
Showing
with 178 additions and 66 deletions
package client
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
Loading
Loading
@@ -78,7 +79,7 @@ func TestClients(t *testing.T) {
func testSuccessfulGet(t *testing.T, client *GitlabNetClient) {
t.Run("Successful get", func(t *testing.T) {
hook := testhelper.SetupLogger()
response, err := client.Get("/hello")
response, err := client.Get(context.Background(), "/hello")
require.NoError(t, err)
require.NotNil(t, response)
Loading
Loading
@@ -104,7 +105,7 @@ func testSuccessfulPost(t *testing.T, client *GitlabNetClient) {
hook := testhelper.SetupLogger()
data := map[string]string{"key": "value"}
response, err := client.Post("/post_endpoint", data)
response, err := client.Post(context.Background(), "/post_endpoint", data)
require.NoError(t, err)
require.NotNil(t, response)
Loading
Loading
@@ -128,7 +129,7 @@ func testSuccessfulPost(t *testing.T, client *GitlabNetClient) {
func testMissing(t *testing.T, client *GitlabNetClient) {
t.Run("Missing error for GET", func(t *testing.T) {
hook := testhelper.SetupLogger()
response, err := client.Get("/missing")
response, err := client.Get(context.Background(), "/missing")
assert.EqualError(t, err, "Internal API error (404)")
assert.Nil(t, response)
Loading
Loading
@@ -144,7 +145,7 @@ func testMissing(t *testing.T, client *GitlabNetClient) {
t.Run("Missing error for POST", func(t *testing.T) {
hook := testhelper.SetupLogger()
response, err := client.Post("/missing", map[string]string{})
response, err := client.Post(context.Background(), "/missing", map[string]string{})
assert.EqualError(t, err, "Internal API error (404)")
assert.Nil(t, response)
Loading
Loading
@@ -161,13 +162,13 @@ func testMissing(t *testing.T, client *GitlabNetClient) {
func testErrorMessage(t *testing.T, client *GitlabNetClient) {
t.Run("Error with message for GET", func(t *testing.T) {
response, err := client.Get("/error")
response, err := client.Get(context.Background(), "/error")
assert.EqualError(t, err, "Don't do that")
assert.Nil(t, response)
})
t.Run("Error with message for POST", func(t *testing.T) {
response, err := client.Post("/error", map[string]string{})
response, err := client.Post(context.Background(), "/error", map[string]string{})
assert.EqualError(t, err, "Don't do that")
assert.Nil(t, response)
})
Loading
Loading
@@ -177,7 +178,7 @@ func testBrokenRequest(t *testing.T, client *GitlabNetClient) {
t.Run("Broken request for GET", func(t *testing.T) {
hook := testhelper.SetupLogger()
response, err := client.Get("/broken")
response, err := client.Get(context.Background(), "/broken")
assert.EqualError(t, err, "Internal API unreachable")
assert.Nil(t, response)
Loading
Loading
@@ -194,7 +195,7 @@ func testBrokenRequest(t *testing.T, client *GitlabNetClient) {
t.Run("Broken request for POST", func(t *testing.T) {
hook := testhelper.SetupLogger()
response, err := client.Post("/broken", map[string]string{})
response, err := client.Post(context.Background(), "/broken", map[string]string{})
assert.EqualError(t, err, "Internal API unreachable")
assert.Nil(t, response)
Loading
Loading
@@ -211,7 +212,7 @@ func testBrokenRequest(t *testing.T, client *GitlabNetClient) {
func testAuthenticationHeader(t *testing.T, client *GitlabNetClient) {
t.Run("Authentication headers for GET", func(t *testing.T) {
response, err := client.Get("/auth")
response, err := client.Get(context.Background(), "/auth")
require.NoError(t, err)
require.NotNil(t, response)
Loading
Loading
@@ -226,7 +227,7 @@ func testAuthenticationHeader(t *testing.T, client *GitlabNetClient) {
})
t.Run("Authentication headers for POST", func(t *testing.T) {
response, err := client.Post("/auth", map[string]string{})
response, err := client.Post(context.Background(), "/auth", map[string]string{})
require.NoError(t, err)
require.NotNil(t, response)
Loading
Loading
Loading
Loading
@@ -11,8 +11,9 @@ import (
"strings"
"time"
log "github.com/sirupsen/logrus"
"gitlab.com/gitlab-org/labkit/correlation"
log "github.com/sirupsen/logrus"
)
const (
Loading
Loading
@@ -59,7 +60,7 @@ func normalizePath(path string) string {
return path
}
func newRequest(method, host, path string, data interface{}) (*http.Request, string, error) {
func newRequest(ctx context.Context, method, host, path string, data interface{}) (*http.Request, string, error) {
var jsonReader io.Reader
if data != nil {
jsonData, err := json.Marshal(data)
Loading
Loading
@@ -70,20 +71,13 @@ func newRequest(method, host, path string, data interface{}) (*http.Request, str
jsonReader = bytes.NewReader(jsonData)
}
correlationID, err := correlation.RandomID()
ctx := context.Background()
if err != nil {
log.WithError(err).Warn("unable to generate correlation ID")
} else {
ctx = correlation.ContextWithCorrelation(ctx, correlationID)
}
request, err := http.NewRequestWithContext(ctx, method, host+path, jsonReader)
if err != nil {
return nil, "", err
}
correlationID := correlation.ExtractFromContext(ctx)
return request, correlationID, nil
}
Loading
Loading
@@ -102,16 +96,16 @@ func parseError(resp *http.Response) error {
}
func (c *GitlabNetClient) Get(path string) (*http.Response, error) {
return c.DoRequest(http.MethodGet, normalizePath(path), nil)
func (c *GitlabNetClient) Get(ctx context.Context, path string) (*http.Response, error) {
return c.DoRequest(ctx, http.MethodGet, normalizePath(path), nil)
}
func (c *GitlabNetClient) Post(path string, data interface{}) (*http.Response, error) {
return c.DoRequest(http.MethodPost, normalizePath(path), data)
func (c *GitlabNetClient) Post(ctx context.Context, path string, data interface{}) (*http.Response, error) {
return c.DoRequest(ctx, http.MethodPost, normalizePath(path), data)
}
func (c *GitlabNetClient) DoRequest(method, path string, data interface{}) (*http.Response, error) {
request, correlationID, err := newRequest(method, c.httpClient.Host, path, data)
func (c *GitlabNetClient) DoRequest(ctx context.Context, method, path string, data interface{}) (*http.Response, error) {
request, correlationID, err := newRequest(ctx, method, c.httpClient.Host, path, data)
if err != nil {
return nil, err
}
Loading
Loading
package client
import (
"context"
"encoding/base64"
"fmt"
"io/ioutil"
Loading
Loading
@@ -51,11 +52,11 @@ func TestBasicAuthSettings(t *testing.T) {
client, cleanup := setup(t, username, password, requests)
defer cleanup()
response, err := client.Get("/get_endpoint")
response, err := client.Get(context.Background(), "/get_endpoint")
require.NoError(t, err)
testBasicAuthHeaders(t, response)
response, err = client.Post("/post_endpoint", nil)
response, err = client.Post(context.Background(), "/post_endpoint", nil)
require.NoError(t, err)
testBasicAuthHeaders(t, response)
}
Loading
Loading
@@ -89,7 +90,7 @@ func TestEmptyBasicAuthSettings(t *testing.T) {
client, cleanup := setup(t, "", "", requests)
defer cleanup()
_, err := client.Get("/empty_basic_auth")
_, err := client.Get(context.Background(), "/empty_basic_auth")
require.NoError(t, err)
}
Loading
Loading
package client
import (
"context"
"fmt"
"io/ioutil"
"net/http"
Loading
Loading
@@ -43,7 +44,7 @@ func TestSuccessfulRequests(t *testing.T) {
client, cleanup := setupWithRequests(t, tc.caFile, tc.caPath, tc.selfSigned)
defer cleanup()
response, err := client.Get("/hello")
response, err := client.Get(context.Background(), "/hello")
require.NoError(t, err)
require.NotNil(t, response)
Loading
Loading
@@ -80,7 +81,7 @@ func TestFailedRequests(t *testing.T) {
client, cleanup := setupWithRequests(t, tc.caFile, tc.caPath, false)
defer cleanup()
_, err := client.Get("/hello")
_, err := client.Get(context.Background(), "/hello")
require.Error(t, err)
assert.Equal(t, err.Error(), "Internal API unreachable")
Loading
Loading
Loading
Loading
@@ -38,7 +38,10 @@ func main() {
os.Exit(1)
}
if err = cmd.Execute(); err != nil {
ctx, finished := command.ContextWithCorrelationID()
defer finished()
if err = cmd.Execute(ctx); err != nil {
fmt.Fprintf(readWriter.ErrOut, "%v\n", err)
os.Exit(1)
}
Loading
Loading
Loading
Loading
@@ -41,7 +41,10 @@ func main() {
os.Exit(1)
}
if err = cmd.Execute(); err != nil {
ctx, finished := command.ContextWithCorrelationID()
defer finished()
if err = cmd.Execute(ctx); err != nil {
console.DisplayWarningMessage(err.Error(), readWriter.ErrOut)
os.Exit(1)
}
Loading
Loading
Loading
Loading
@@ -41,7 +41,10 @@ func main() {
os.Exit(1)
}
if err = cmd.Execute(); err != nil {
ctx, finished := command.ContextWithCorrelationID()
defer finished()
if err = cmd.Execute(ctx); err != nil {
console.DisplayWarningMessage(err.Error(), readWriter.ErrOut)
os.Exit(1)
}
Loading
Loading
Loading
Loading
@@ -41,7 +41,10 @@ func main() {
os.Exit(1)
}
if err = cmd.Execute(); err != nil {
ctx, finished := command.ContextWithCorrelationID()
defer finished()
if err = cmd.Execute(ctx); err != nil {
console.DisplayWarningMessage(err.Error(), readWriter.ErrOut)
os.Exit(1)
}
Loading
Loading
package authorizedkeys
import (
"context"
"fmt"
"strconv"
Loading
Loading
@@ -17,7 +18,7 @@ type Command struct {
ReadWriter *readwriter.ReadWriter
}
func (c *Command) Execute() error {
func (c *Command) Execute(ctx context.Context) error {
// Do and return nothing when the expected and actual user don't match.
// This can happen when the user in sshd_config doesn't match the user
// trying to login. When nothing is printed, the user will be denied access.
Loading
Loading
@@ -27,15 +28,15 @@ func (c *Command) Execute() error {
return nil
}
if err := c.printKeyLine(); err != nil {
if err := c.printKeyLine(ctx); err != nil {
return err
}
return nil
}
func (c *Command) printKeyLine() error {
response, err := c.getAuthorizedKey()
func (c *Command) printKeyLine(ctx context.Context) error {
response, err := c.getAuthorizedKey(ctx)
if err != nil {
fmt.Fprintln(c.ReadWriter.Out, fmt.Sprintf("# No key was found for %s", c.Args.Key))
return nil
Loading
Loading
@@ -51,11 +52,11 @@ func (c *Command) printKeyLine() error {
return nil
}
func (c *Command) getAuthorizedKey() (*authorizedkeys.Response, error) {
func (c *Command) getAuthorizedKey(ctx context.Context) (*authorizedkeys.Response, error) {
client, err := authorizedkeys.NewClient(c.Config)
if err != nil {
return nil, err
}
return client.GetByKey(c.Args.Key)
return client.GetByKey(ctx, c.Args.Key)
}
Loading
Loading
@@ -2,6 +2,7 @@ package authorizedkeys
import (
"bytes"
"context"
"encoding/json"
"net/http"
"testing"
Loading
Loading
@@ -97,7 +98,7 @@ func TestExecute(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{Out: buffer},
}
err := cmd.Execute()
err := cmd.Execute(context.Background())
require.NoError(t, err)
require.Equal(t, tc.expectedOutput, buffer.String())
Loading
Loading
package authorizedprincipals
import (
"context"
"fmt"
"gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
Loading
Loading
@@ -15,7 +16,7 @@ type Command struct {
ReadWriter *readwriter.ReadWriter
}
func (c *Command) Execute() error {
func (c *Command) Execute(ctx context.Context) error {
if err := c.printPrincipalLines(); err != nil {
return err
}
Loading
Loading
Loading
Loading
@@ -2,6 +2,7 @@ package authorizedprincipals
import (
"bytes"
"context"
"testing"
"github.com/stretchr/testify/require"
Loading
Loading
@@ -54,7 +55,7 @@ func TestExecute(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{Out: buffer},
}
err := cmd.Execute()
err := cmd.Execute(context.Background())
require.NoError(t, err)
require.Equal(t, tc.expectedOutput, buffer.String())
Loading
Loading
package command
import (
"context"
"gitlab.com/gitlab-org/gitlab-shell/internal/command/authorizedkeys"
"gitlab.com/gitlab-org/gitlab-shell/internal/command/authorizedprincipals"
"gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
Loading
Loading
@@ -16,10 +18,13 @@ import (
"gitlab.com/gitlab-org/gitlab-shell/internal/command/uploadpack"
"gitlab.com/gitlab-org/gitlab-shell/internal/config"
"gitlab.com/gitlab-org/gitlab-shell/internal/executable"
"gitlab.com/gitlab-org/labkit/correlation"
"gitlab.com/gitlab-org/labkit/log"
"gitlab.com/gitlab-org/labkit/tracing"
)
type Command interface {
Execute() error
Execute(ctx context.Context) error
}
func New(e *executable.Executable, arguments []string, config *config.Config, readWriter *readwriter.ReadWriter) (Command, error) {
Loading
Loading
@@ -35,6 +40,28 @@ func New(e *executable.Executable, arguments []string, config *config.Config, re
return nil, disallowedcommand.Error
}
// ContextWithCorrelationID() will always return a background Context
// with a correlation ID. It will first attempt to extract the ID from
// an environment variable. If is not available, a random one will be
// generated.
func ContextWithCorrelationID() (context.Context, func()) {
ctx, finished := tracing.ExtractFromEnv(context.Background())
defer finished()
correlationID := correlation.ExtractFromContext(ctx)
if correlationID == "" {
correlationID, err := correlation.RandomID()
if err != nil {
log.WithError(err).Warn("unable to generate correlation ID")
} else {
log.Info("generated random correlation ID")
ctx = correlation.ContextWithCorrelation(ctx, correlationID)
}
}
return ctx, finished
}
func buildCommand(e *executable.Executable, args commandargs.CommandArgs, config *config.Config, readWriter *readwriter.ReadWriter) Command {
switch e.Name {
case executable.GitlabShell:
Loading
Loading
Loading
Loading
@@ -2,6 +2,7 @@ package command
import (
"errors"
"os"
"testing"
"github.com/stretchr/testify/require"
Loading
Loading
@@ -20,6 +21,7 @@ import (
"gitlab.com/gitlab-org/gitlab-shell/internal/config"
"gitlab.com/gitlab-org/gitlab-shell/internal/executable"
"gitlab.com/gitlab-org/gitlab-shell/internal/testhelper"
"gitlab.com/gitlab-org/labkit/correlation"
)
var (
Loading
Loading
@@ -151,3 +153,67 @@ func TestFailingNew(t *testing.T) {
})
}
}
func TestContextWithCorrelationID(t *testing.T) {
testCases := []struct {
name string
additionalEnv map[string]string
expectedCorrelationID string
}{
{
name: "no CORRELATION_ID in environment",
},
{
name: "CORRELATION_ID in environment",
additionalEnv: map[string]string{
"CORRELATION_ID": "abc123",
},
expectedCorrelationID: "abc123",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
resetEnvironment := addAdditionalEnv(tc.additionalEnv)
defer resetEnvironment()
ctx, finished := ContextWithCorrelationID()
require.NotNil(t, ctx, "ctx is nil")
require.NotNil(t, finished, "finished is nil")
correlationID := correlation.ExtractFromContext(ctx)
require.NotEmpty(t, correlationID)
if tc.expectedCorrelationID != "" {
require.Equal(t, tc.expectedCorrelationID, correlationID)
}
defer finished()
})
}
}
// addAdditionalEnv will configure additional environment values
// and return a deferrable function to reset the environment to
// it's original state after the test
func addAdditionalEnv(envMap map[string]string) func() {
prevValues := map[string]string{}
unsetValues := []string{}
for k, v := range envMap {
value, exists := os.LookupEnv(k)
if exists {
prevValues[k] = value
} else {
unsetValues = append(unsetValues, k)
}
os.Setenv(k, v)
}
return func() {
for k, v := range prevValues {
os.Setenv(k, v)
}
for _, k := range unsetValues {
os.Unsetenv(k)
}
}
}
package discover
import (
"context"
"fmt"
"gitlab.com/gitlab-org/gitlab-shell/internal/command/commandargs"
Loading
Loading
@@ -15,8 +16,8 @@ type Command struct {
ReadWriter *readwriter.ReadWriter
}
func (c *Command) Execute() error {
response, err := c.getUserInfo()
func (c *Command) Execute(ctx context.Context) error {
response, err := c.getUserInfo(ctx)
if err != nil {
return fmt.Errorf("Failed to get username: %v", err)
}
Loading
Loading
@@ -30,11 +31,11 @@ func (c *Command) Execute() error {
return nil
}
func (c *Command) getUserInfo() (*discover.Response, error) {
func (c *Command) getUserInfo(ctx context.Context) (*discover.Response, error) {
client, err := discover.NewClient(c.Config)
if err != nil {
return nil, err
}
return client.GetByCommandArgs(c.Args)
return client.GetByCommandArgs(ctx, c.Args)
}
Loading
Loading
@@ -2,6 +2,7 @@ package discover
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
Loading
Loading
@@ -83,7 +84,7 @@ func TestExecute(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{Out: buffer},
}
err := cmd.Execute()
err := cmd.Execute(context.Background())
require.NoError(t, err)
require.Equal(t, tc.expectedOutput, buffer.String())
Loading
Loading
@@ -126,7 +127,7 @@ func TestFailingExecute(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{Out: buffer},
}
err := cmd.Execute()
err := cmd.Execute(context.Background())
require.Empty(t, buffer.String())
require.EqualError(t, err, tc.expectedError)
Loading
Loading
package healthcheck
import (
"context"
"fmt"
"gitlab.com/gitlab-org/gitlab-shell/internal/command/readwriter"
Loading
Loading
@@ -18,8 +19,8 @@ type Command struct {
ReadWriter *readwriter.ReadWriter
}
func (c *Command) Execute() error {
response, err := c.runCheck()
func (c *Command) Execute(ctx context.Context) error {
response, err := c.runCheck(ctx)
if err != nil {
return fmt.Errorf("%v: FAILED - %v", apiMessage, err)
}
Loading
Loading
@@ -34,13 +35,13 @@ func (c *Command) Execute() error {
return nil
}
func (c *Command) runCheck() (*healthcheck.Response, error) {
func (c *Command) runCheck(ctx context.Context) (*healthcheck.Response, error) {
client, err := healthcheck.NewClient(c.Config)
if err != nil {
return nil, err
}
response, err := client.Check()
response, err := client.Check(ctx)
if err != nil {
return nil, err
}
Loading
Loading
Loading
Loading
@@ -2,6 +2,7 @@ package healthcheck
import (
"bytes"
"context"
"encoding/json"
"net/http"
"testing"
Loading
Loading
@@ -53,7 +54,7 @@ func TestExecute(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{Out: buffer},
}
err := cmd.Execute()
err := cmd.Execute(context.Background())
require.NoError(t, err)
require.Equal(t, "Internal API available: OK\nRedis available via internal API: OK\n", buffer.String())
Loading
Loading
@@ -69,7 +70,7 @@ func TestFailingRedisExecute(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{Out: buffer},
}
err := cmd.Execute()
err := cmd.Execute(context.Background())
require.Error(t, err, "Redis available via internal API: FAILED")
require.Equal(t, "Internal API available: OK\n", buffer.String())
}
Loading
Loading
@@ -84,7 +85,7 @@ func TestFailingAPIExecute(t *testing.T) {
ReadWriter: &readwriter.ReadWriter{Out: buffer},
}
err := cmd.Execute()
err := cmd.Execute(context.Background())
require.Empty(t, buffer.String())
require.EqualError(t, err, "Internal API available: FAILED - Internal API error (500)")
}
package lfsauthenticate
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
Loading
Loading
@@ -34,7 +35,7 @@ type Payload struct {
ExpiresIn int `json:"expires_in,omitempty"`
}
func (c *Command) Execute() error {
func (c *Command) Execute(ctx context.Context) error {
args := c.Args.SshArgs
if len(args) < 3 {
return disallowedcommand.Error
Loading
Loading
@@ -49,12 +50,12 @@ func (c *Command) Execute() error {
return err
}
accessResponse, err := c.verifyAccess(action, repo)
accessResponse, err := c.verifyAccess(ctx, action, repo)
if err != nil {
return err
}
payload, err := c.authenticate(operation, repo, accessResponse.UserId)
payload, err := c.authenticate(ctx, operation, repo, accessResponse.UserId)
if err != nil {
// return nothing just like Ruby's GitlabShell#lfs_authenticate does
return nil
Loading
Loading
@@ -80,19 +81,19 @@ func actionFromOperation(operation string) (commandargs.CommandType, error) {
return action, nil
}
func (c *Command) verifyAccess(action commandargs.CommandType, repo string) (*accessverifier.Response, error) {
func (c *Command) verifyAccess(ctx context.Context, action commandargs.CommandType, repo string) (*accessverifier.Response, error) {
cmd := accessverifier.Command{c.Config, c.Args, c.ReadWriter}
return cmd.Verify(action, repo)
return cmd.Verify(ctx, action, repo)
}
func (c *Command) authenticate(operation string, repo, userId string) ([]byte, error) {
func (c *Command) authenticate(ctx context.Context, operation string, repo, userId string) ([]byte, error) {
client, err := lfsauthenticate.NewClient(c.Config, c.Args)
if err != nil {
return nil, err
}
response, err := client.Authenticate(operation, repo, userId)
response, err := client.Authenticate(ctx, operation, repo, userId)
if err != nil {
return nil, err
}
Loading
Loading
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment