As we reevaluate how to best support and maintain Staging Ref in the future, we encourage development teams using this environment to highlight their use cases in the following issue: https://gitlab.com/gitlab-com/gl-infra/software-delivery/framework/software-delivery-framework-issue-tracker/-/issues/36.

Skip to content
Snippets Groups Projects
Select Git revision
  • main default protected
  • 608-improve-gitlab-shell-logging-structure
  • ashmckenzie/update-golangci-setup
  • 762_use_workhorse_ssh_endpoint
  • fix-issue-708
  • id-use-workhorse-git-ssh-rpc
  • ashmckenzie/debug-yamux-issues
  • ag-remove-geo-ffs
  • aakriti.gupta-main-patch-64039
  • ashmckenzie/gssapi-fixes
  • 671-race-golang-1-x-failed-with-stdin-send-error-eof
  • igor.drozdov-main-patch-82081
  • ashmckenzie/include-metadata-in-access-finish-log-line
  • 660-job-failed-4563144016
  • id-bump-logrus
  • sh-ssh-certificates
  • tmp-geo-push-poc
  • igor.drozdov-main-patch-40896
  • tmp-kerberos-testing
  • id-test-agains-1.19
  • v14.39.0
  • v14.38.0
  • v14.37.0
  • v14.36.0
  • v14.35.0
  • v14.34.0
  • v14.33.0
  • v14.32.0
  • v14.31.0
  • v14.30.1
  • v14.30.0
  • v14.29.0
  • v14.28.0
  • v14.27.0
  • v14.26.0
  • v14.25.0
  • v14.24.1
  • v14.24.0
  • v14.23.0
  • v14.22.0
40 results

client_test.go

Code owners
Assign users and groups as approvers for specific file changes. Learn more.
client_test.go 9.00 KiB
package client

import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/http/httptest"
	"path"
	"strings"
	"sync/atomic"
	"testing"
	"time"

	"github.com/golang-jwt/jwt/v5"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"gitlab.com/gitlab-org/gitlab-shell/v14/client/testserver"
	"gitlab.com/gitlab-org/gitlab-shell/v14/internal/testhelper"
)

// secret is the API secret handed to the clients under test; the JWT
// assertions in testJWTAuthenticationHeader verify tokens against this value.
var secret = "sssh, it's a secret"

// defaultHttpOpts passes one-millisecond durations to WithHTTPRetryOpts
// (presumably the min/max retry wait — confirm) so failing requests don't slow
// the suite. TestRetryOnFailure expects three attempts with these options.
var defaultHttpOpts = []HTTPClientOpt{
	WithHTTPRetryOpts(time.Millisecond, time.Millisecond, 2),
}

// TestClients runs the shared client checks against each server flavour:
// Unix socket (with no, "/" and "/gitlab" relative URL roots), plain HTTP,
// HTTPS with the test CA, HTTPS with a newline-wrapped secret, and the retry
// test server.
func TestClients(t *testing.T) {
	testRoot := testhelper.PrepareTestRootDir(t)
	validCA := path.Join(testRoot, "certs/valid/server.crt")

	// startHTTPS adapts testserver.StartHTTPSServer to the common server signature.
	startHTTPS := func(t *testing.T, handlers []testserver.TestRequestHandler) string {
		return testserver.StartHTTPSServer(t, handlers, "")
	}

	type clientCase struct {
		desc            string
		server          func(*testing.T, []testserver.TestRequestHandler) string
		secret          string
		relativeURLRoot string
		caFile          string
	}

	cases := []clientCase{
		{desc: "Socket client", server: testserver.StartSocketHTTPServer, secret: secret},
		{desc: "Socket client with a relative URL at /", server: testserver.StartSocketHTTPServer, secret: secret, relativeURLRoot: "/"},
		{desc: "Socket client with relative URL at /gitlab", server: testserver.StartSocketHTTPServer, secret: secret, relativeURLRoot: "/gitlab"},
		{desc: "Http client", server: testserver.StartHTTPServer, secret: secret},
		{desc: "Https client", server: startHTTPS, secret: secret, caFile: validCA},
		{desc: "Secret with newlines", server: startHTTPS, secret: "\n" + secret + "\n", caFile: validCA},
		{desc: "Retry client", server: testserver.StartRetryHTTPServer, secret: secret},
	}

	// Shared checks, executed in this order against every client.
	checks := []func(*testing.T, *GitlabNetClient){
		testBrokenRequest,
		testSuccessfulGet,
		testSuccessfulPost,
		testMissing,
		testErrorMessage,
		testJWTAuthenticationHeader,
		testXForwardedForHeader,
		testHostWithTrailingSlash,
	}

	for _, c := range cases {
		t.Run(c.desc, func(t *testing.T) {
			handlers := buildRequests(t, c.relativeURLRoot)
			serverURL := c.server(t, handlers)

			httpClient, err := NewHTTPClientWithOpts(serverURL, c.relativeURLRoot, c.caFile, "", 1, defaultHttpOpts)
			require.NoError(t, err)

			gitlabClient, err := NewGitlabNetClient("", "", c.secret, httpClient)
			require.NoError(t, err)

			for _, check := range checks {
				check(t, gitlabClient)
			}
		})
	}
}

// testSuccessfulGet checks that a GET to /hello succeeds and returns the body
// "Hello" served by the fake internal API.
func testSuccessfulGet(t *testing.T, client *GitlabNetClient) {
	t.Run("Successful get", func(t *testing.T) {
		response, err := client.Get(context.Background(), "/hello")
		require.NoError(t, err)
		require.NotNil(t, response)

		defer response.Body.Close()

		responseBody, err := io.ReadAll(response.Body)
		require.NoError(t, err)
		// testify's signature is (expected, actual); keep that order so failure
		// output labels the values correctly.
		require.Equal(t, "Hello", string(responseBody))
	})
}

// testSuccessfulPost posts a small JSON-able map to /post_endpoint and expects
// the fake server to echo the encoded body back, prefixed with "Echo: ".
func testSuccessfulPost(t *testing.T, client *GitlabNetClient) {
	t.Run("Successful Post", func(t *testing.T) {
		payload := map[string]string{"key": "value"}
		want := `Echo: {"key":"value"}`

		resp, err := client.Post(context.Background(), "/post_endpoint", payload)
		require.NoError(t, err)
		require.NotNil(t, resp)
		defer resp.Body.Close()

		got, err := io.ReadAll(resp.Body)
		require.NoError(t, err)
		require.Equal(t, want, string(got))
	})
}

// testMissing verifies that both GET and POST to an unregistered path fail
// with "Internal API error (404)" and yield no response.
func testMissing(t *testing.T, client *GitlabNetClient) {
	calls := []struct {
		name string
		do   func() (*http.Response, error)
	}{
		{
			name: "Missing error for GET",
			do: func() (*http.Response, error) {
				return client.Get(context.Background(), "/missing")
			},
		},
		{
			name: "Missing error for POST",
			do: func() (*http.Response, error) {
				return client.Post(context.Background(), "/missing", map[string]string{})
			},
		},
	}

	for _, call := range calls {
		t.Run(call.name, func(t *testing.T) {
			resp, err := call.do()
			require.EqualError(t, err, "Internal API error (404)")
			require.Nil(t, resp)
		})
	}
}

// testErrorMessage verifies that, for both GET and POST to /error, the client
// surfaces the server's JSON "message" field as the error text and returns no
// response.
func testErrorMessage(t *testing.T, client *GitlabNetClient) {
	// expectAPIError holds the assertions shared by the GET and POST variants.
	expectAPIError := func(t *testing.T, resp *http.Response, err error) {
		require.EqualError(t, err, "Don't do that")
		require.Nil(t, resp)
	}

	t.Run("Error with message for GET", func(t *testing.T) {
		resp, err := client.Get(context.Background(), "/error")
		expectAPIError(t, resp, err)
	})

	t.Run("Error with message for POST", func(t *testing.T) {
		resp, err := client.Post(context.Background(), "/error", map[string]string{})
		expectAPIError(t, resp, err)
	})
}

// testBrokenRequest hits /broken, whose fake handler panics, and expects both
// GET and POST to fail with "Internal API unreachable" and no response.
func testBrokenRequest(t *testing.T, client *GitlabNetClient) {
	const (
		brokenPath  = "/broken"
		unreachable = "Internal API unreachable"
	)
	ctx := context.Background()

	t.Run("Broken request for GET", func(t *testing.T) {
		resp, err := client.Get(ctx, brokenPath)
		require.EqualError(t, err, unreachable)
		require.Nil(t, resp)
	})

	t.Run("Broken request for POST", func(t *testing.T) {
		resp, err := client.Post(ctx, brokenPath, map[string]string{})
		require.EqualError(t, err, unreachable)
		require.Nil(t, resp)
	})
}

// testJWTAuthenticationHeader checks that GET and POST requests carry a JWT in
// the API secret header. The /jwt_auth fake handler echoes that header back,
// and the token is then verified against the package-level test secret.
func testJWTAuthenticationHeader(t *testing.T, client *GitlabNetClient) {
	// verifyJWTToken parses the echoed header as a JWT signed with `secret` and
	// checks its issuer, issued-at and expiry claims.
	verifyJWTToken := func(t *testing.T, response *http.Response) {
		responseBody, err := io.ReadAll(response.Body)
		require.NoError(t, err)

		claims := &jwt.RegisteredClaims{}
		token, err := jwt.ParseWithClaims(string(responseBody), claims, func(_ *jwt.Token) (any, error) {
			return []byte(secret), nil
		})
		require.NoError(t, err)
		require.True(t, token.Valid)
		require.Equal(t, "gitlab-shell", claims.Issuer)

		// IssuedAt/ExpiresAt are pointers; fail with a clear message rather than
		// panicking on a nil dereference if the client omits either claim.
		require.NotNil(t, claims.IssuedAt, "token has no iat claim")
		require.NotNil(t, claims.ExpiresAt, "token has no exp claim")
		require.WithinDuration(t, time.Now().Truncate(time.Second), claims.IssuedAt.Time, time.Second)
		require.WithinDuration(t, time.Now().Truncate(time.Second).Add(time.Minute), claims.ExpiresAt.Time, time.Second)
	}

	t.Run("JWT authentication headers for GET", func(t *testing.T) {
		response, err := client.Get(context.Background(), "/jwt_auth")
		require.NoError(t, err)
		require.NotNil(t, response)

		defer response.Body.Close()

		verifyJWTToken(t, response)
	})

	t.Run("JWT authentication headers for POST", func(t *testing.T) {
		response, err := client.Post(context.Background(), "/jwt_auth", map[string]string{})
		require.NoError(t, err)
		require.NotNil(t, response)

		defer response.Body.Close()

		verifyJWTToken(t, response)
	})
}

// testXForwardedForHeader checks that an original remote IP stored in the
// request context under OriginalRemoteIPContextKey is forwarded as the
// X-Forwarded-For header; the fake /x_forwarded_for handler echoes it back.
func testXForwardedForHeader(t *testing.T, client *GitlabNetClient) {
	const originalIP = "196.7.0.238"

	t.Run("X-Forwarded-For Header inserted if original address in context", func(t *testing.T) {
		ctx := context.WithValue(context.Background(), OriginalRemoteIPContextKey{}, originalIP)

		resp, err := client.Get(ctx, "/x_forwarded_for")
		require.NoError(t, err)
		require.NotNil(t, resp)
		defer resp.Body.Close()

		echoed, err := io.ReadAll(resp.Body)
		require.NoError(t, err)
		require.Equal(t, originalIP, string(echoed))
	})
}

// testHostWithTrailingSlash re-runs the successful GET/POST checks with a '/'
// appended to the client's configured host, then restores the original host.
// The checks run inside a dedicated subtest, so their names don't collide
// with the earlier runs (which would otherwise show up as "Successful_get#01").
func testHostWithTrailingSlash(t *testing.T, client *GitlabNetClient) {
	t.Run("Host with trailing slash", func(t *testing.T) {
		oldHost := client.httpClient.Host
		client.httpClient.Host = oldHost + "/"
		// Deferred so later checks always see the original host, even if this
		// subtest is aborted.
		defer func() { client.httpClient.Host = oldHost }()

		testSuccessfulGet(t, client)
		testSuccessfulPost(t, client)
	})
}

// buildRequests returns the fake internal-API handlers used by the client
// tests, registered under /api/v4/internal/<name>. When relativeURLRoot is
// non-empty after trimming surrounding slashes, it is prefixed to every path
// so the client's relative-URL handling is exercised.
//
// Handlers run on the test server's goroutines, so they use assert (which only
// records failures) rather than require (whose FailNow must be called from the
// test goroutine).
func buildRequests(t *testing.T, relativeURLRoot string) []testserver.TestRequestHandler {
	requests := []testserver.TestRequestHandler{
		{
			Path: "/api/v4/internal/hello",
			Handler: func(w http.ResponseWriter, r *http.Request) {
				assert.Equal(t, http.MethodGet, r.Method)

				fmt.Fprint(w, "Hello")
			},
		},
		{
			Path: "/api/v4/internal/post_endpoint",
			Handler: func(w http.ResponseWriter, r *http.Request) {
				assert.Equal(t, http.MethodPost, r.Method)

				defer r.Body.Close()
				b, err := io.ReadAll(r.Body)
				assert.NoError(t, err)

				// Echo the raw body so the test can check how the client encoded it.
				fmt.Fprint(w, "Echo: "+string(b))
			},
		},
		{
			Path: "/api/v4/internal/jwt_auth",
			Handler: func(w http.ResponseWriter, r *http.Request) {
				// Echo the auth header so the client test can verify the token.
				fmt.Fprint(w, r.Header.Get(apiSecretHeaderName))
			},
		},
		{
			Path: "/api/v4/internal/x_forwarded_for",
			Handler: func(w http.ResponseWriter, r *http.Request) {
				fmt.Fprint(w, r.Header.Get("X-Forwarded-For"))
			},
		},
		{
			Path: "/api/v4/internal/error",
			Handler: func(w http.ResponseWriter, r *http.Request) {
				w.Header().Set("Content-Type", "application/json")
				w.WriteHeader(http.StatusBadRequest)
				body := map[string]string{
					"message": "Don't do that",
				}
				// Don't drop the encode error: an empty error body would make
				// testErrorMessage fail for a confusing reason.
				assert.NoError(t, json.NewEncoder(w).Encode(body))
			},
		},
		{
			Path: "/api/v4/internal/broken",
			Handler: func(w http.ResponseWriter, r *http.Request) {
				// testBrokenRequest expects the client to report this as
				// "Internal API unreachable".
				panic("Broken")
			},
		},
	}

	relativeURLRoot = strings.Trim(relativeURLRoot, "/")
	if relativeURLRoot != "" {
		for i, r := range requests {
			requests[i].Path = fmt.Sprintf("/%s%s", relativeURLRoot, r.Path)
		}
	}

	return requests
}

// TestRetryOnFailure points a client at a server that always answers 500 and
// checks that the request is retried before failing with
// "Internal API unreachable".
func TestRetryOnFailure(t *testing.T) {
	// The handler runs on the server's goroutines while the assertion runs on
	// the test goroutine, so the counter is accessed atomically.
	var reqAttempts int32
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		atomic.AddInt32(&reqAttempts, 1)
		w.WriteHeader(http.StatusInternalServerError)
	}))
	defer srv.Close()

	httpClient, err := NewHTTPClientWithOpts(srv.URL, "/", "", "", 1, defaultHttpOpts)
	require.NoError(t, err)
	require.NotNil(t, httpClient.RetryableHTTP)

	client, err := NewGitlabNetClient("", "", "", httpClient)
	require.NoError(t, err)

	_, err = client.Get(context.Background(), "/")
	require.EqualError(t, err, "Internal API unreachable")
	// Presumably one initial attempt plus the 2 retries passed to
	// WithHTTPRetryOpts in defaultHttpOpts.
	require.Equal(t, int32(3), atomic.LoadInt32(&reqAttempts))
}