As we reevaluate how to best support and maintain Staging Ref in the future, we encourage development teams using this environment to highlight their use cases in the following issue: https://gitlab.com/gitlab-com/gl-infra/software-delivery/framework/software-delivery-framework-issue-tracker/-/issues/36.

Skip to content
Snippets Groups Projects
Commit 0d7ef238 authored by Igor Drozdov's avatar Igor Drozdov
Browse files

Merge branch 'sshd-forwarded-for' into 'main'

Pass original IP from PROXY requests to internal API calls

See merge request gitlab-org/gitlab-shell!665
parents 01f4e022 9b60ce49
No related branches found
No related tags found
No related merge requests found
Loading
Loading
@@ -76,6 +76,7 @@ func TestClients(t *testing.T) {
testErrorMessage(t, client)
testAuthenticationHeader(t, client)
testJWTAuthenticationHeader(t, client)
testXForwardedForHeader(t, client)
})
}
}
Loading
Loading
@@ -221,6 +222,21 @@ func testJWTAuthenticationHeader(t *testing.T, client *GitlabNetClient) {
})
}
func testXForwardedForHeader(t *testing.T, client *GitlabNetClient) {
t.Run("X-Forwarded-For Header inserted if original address in context", func(t *testing.T) {
ctx := context.WithValue(context.Background(), OriginalRemoteIPContextKey{}, "196.7.0.238")
response, err := client.Get(ctx, "/x_forwarded_for")
require.NoError(t, err)
require.NotNil(t, response)
defer response.Body.Close()
responseBody, err := io.ReadAll(response.Body)
require.NoError(t, err)
require.Equal(t, "196.7.0.238", string(responseBody))
})
}
func buildRequests(t *testing.T, relativeURLRoot string) []testserver.TestRequestHandler {
requests := []testserver.TestRequestHandler{
{
Loading
Loading
@@ -256,6 +272,12 @@ func buildRequests(t *testing.T, relativeURLRoot string) []testserver.TestReques
fmt.Fprint(w, r.Header.Get(apiSecretHeaderName))
},
},
{
Path: "/api/v4/internal/x_forwarded_for",
Handler: func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, r.Header.Get("X-Forwarded-For"))
},
},
{
Path: "/api/v4/internal/error",
Handler: func(w http.ResponseWriter, r *http.Request) {
Loading
Loading
Loading
Loading
@@ -41,6 +41,9 @@ type ApiError struct {
Msg string
}
// To use as the key in a Context to set an X-Forwarded-For header in a request
type OriginalRemoteIPContextKey struct{}
func (e *ApiError) Error() string {
return e.Msg
}
Loading
Loading
@@ -150,6 +153,11 @@ func (c *GitlabNetClient) DoRequest(ctx context.Context, method, path string, da
}
request.Header.Set(apiSecretHeaderName, tokenString)
originalRemoteIP, ok := ctx.Value(OriginalRemoteIPContextKey{}).(string)
if ok {
request.Header.Add("X-Forwarded-For", originalRemoteIP)
}
request.Header.Add("Content-Type", "application/json")
request.Header.Add("User-Agent", c.userAgent)
request.Close = true
Loading
Loading
Loading
Loading
@@ -3,7 +3,6 @@ package accessverifier
import (
"context"
"fmt"
"net"
"net/http"
pb "gitlab.com/gitlab-org/gitaly/v14/proto/go/gitalypb"
Loading
Loading
@@ -86,7 +85,7 @@ func (c *Client) Verify(ctx context.Context, args *commandargs.Shell, action com
request.KeyId = args.GitlabKeyId
}
request.CheckIp = parseIP(args.Env.RemoteAddr)
request.CheckIp = gitlabnet.ParseIP(args.Env.RemoteAddr)
response, err := c.client.Post(ctx, "/allowed", request)
if err != nil {
Loading
Loading
@@ -117,18 +116,3 @@ func parse(hr *http.Response, args *commandargs.Shell) (*Response, error) {
func (r *Response) IsCustomAction() bool {
return r.StatusCode == http.StatusMultipleChoices
}
func parseIP(remoteAddr string) string {
// The remoteAddr field can be filled by:
// 1. An IP address via the SSH_CONNECTION environment variable
// 2. A host:port combination via the PROXY protocol
ip, _, err := net.SplitHostPort(remoteAddr)
// If we don't have a port or can't parse this address for some reason,
// just return the original string.
if err != nil {
return remoteAddr
}
return ip
}
Loading
Loading
@@ -3,6 +3,7 @@ package gitlabnet
import (
"encoding/json"
"fmt"
"net"
"net/http"
"gitlab.com/gitlab-org/gitlab-shell/client"
Loading
Loading
@@ -34,3 +35,18 @@ func ParseJSON(hr *http.Response, response interface{}) error {
return nil
}
func ParseIP(remoteAddr string) string {
// The remoteAddr field can be filled by:
// 1. An IP address via the SSH_CONNECTION environment variable
// 2. A host:port combination via the PROXY protocol
ip, _, err := net.SplitHostPort(remoteAddr)
// If we don't have a port or can't parse this address for some reason,
// just return the original string.
if err != nil {
return remoteAddr
}
return ip
}
Loading
Loading
@@ -12,7 +12,9 @@ import (
"github.com/pires/go-proxyproto"
"golang.org/x/crypto/ssh"
"gitlab.com/gitlab-org/gitlab-shell/client"
"gitlab.com/gitlab-org/gitlab-shell/internal/config"
"gitlab.com/gitlab-org/gitlab-shell/internal/gitlabnet"
"gitlab.com/gitlab-org/gitlab-shell/internal/metrics"
"gitlab.com/gitlab-org/labkit/correlation"
Loading
Loading
@@ -145,13 +147,26 @@ func (s *Server) getStatus() status {
return s.status
}
func contextWithValues(parent context.Context, nconn net.Conn) context.Context {
ctx := correlation.ContextWithCorrelation(parent, correlation.SafeRandomID())
// If we're dealing with a PROXY connection, register the original requester's IP
mconn, ok := nconn.(*proxyproto.Conn)
if ok {
ip := gitlabnet.ParseIP(mconn.Raw().RemoteAddr().String())
ctx = context.WithValue(ctx, client.OriginalRemoteIPContextKey{}, ip)
}
return ctx
}
func (s *Server) handleConn(ctx context.Context, nconn net.Conn) {
defer s.wg.Done()
metrics.SshdConnectionsInFlight.Inc()
defer metrics.SshdConnectionsInFlight.Dec()
ctx, cancel := context.WithCancel(correlation.ContextWithCorrelation(ctx, correlation.SafeRandomID()))
ctx, cancel := context.WithCancel(contextWithValues(ctx, nconn))
defer cancel()
go func() {
<-ctx.Done()
Loading
Loading
Loading
Loading
@@ -27,6 +27,7 @@ const (
var (
correlationId = ""
xForwardedFor = ""
)
func TestListenAndServe(t *testing.T) {
Loading
Loading
@@ -63,6 +64,10 @@ func TestListenAndServeRejectsPlainConnectionsWhenProxyProtocolEnabled(t *testin
},
DestinationAddr: target,
}
xForwardedFor = "127.0.0.1"
defer func() {
xForwardedFor = "" // Cleanup for other test cases
}()
testCases := []struct {
desc string
Loading
Loading
@@ -132,9 +137,9 @@ func TestListenAndServeRejectsPlainConnectionsWhenProxyProtocolEnabled(t *testin
require.NoError(t, err)
}
sshConn, _, _, err := ssh.NewClientConn(conn, serverUrl, clientConfig(t))
sshConn, sshChans, sshRequs, err := ssh.NewClientConn(conn, serverUrl, clientConfig(t))
if sshConn != nil {
sshConn.Close()
defer sshConn.Close()
}
if tc.isRejected {
Loading
Loading
@@ -142,6 +147,10 @@ func TestListenAndServeRejectsPlainConnectionsWhenProxyProtocolEnabled(t *testin
require.Regexp(t, "ssh: handshake failed", err.Error())
} else {
require.NoError(t, err)
client := ssh.NewClient(sshConn, sshChans, sshRequs)
defer client.Close()
holdSession(t, client)
}
})
}
Loading
Loading
@@ -306,6 +315,7 @@ func setupServerWithContext(t *testing.T, cfg *config.Config, ctx context.Contex
correlationId = r.Header.Get("X-Request-Id")
require.NotEmpty(t, correlationId)
require.Equal(t, xForwardedFor, r.Header.Get("X-Forwarded-For"))
fmt.Fprint(w, `{"id": 1000, "key": "key"}`)
},
Loading
Loading
@@ -313,6 +323,7 @@ func setupServerWithContext(t *testing.T, cfg *config.Config, ctx context.Contex
Path: "/api/v4/internal/discover",
Handler: func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, correlationId, r.Header.Get("X-Request-Id"))
require.Equal(t, xForwardedFor, r.Header.Get("X-Forwarded-For"))
fmt.Fprint(w, `{"id": 1000, "name": "Test User", "username": "test-user"}`)
},
Loading
Loading
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment