mirror of
				https://github.com/elyby/chrly.git
				synced 2025-05-31 14:11:51 +05:30 
			
		
		
		
	Implemented API endpoint to sign arbitrary data
This commit is contained in:
		
							
								
								
									
										2
									
								
								go.sum
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								go.sum
									
									
									
									
									
								
							| @@ -1,3 +1,5 @@ | |||||||
|  | github.com/SentimensRG/ctx v0.0.0-20180729130232-0bfd988c655d h1:CbB/Ef3TyBvSSJx2HDSUiw49ONTpaX6BGiI0jJEX6b8= | ||||||
|  | github.com/SentimensRG/ctx v0.0.0-20180729130232-0bfd988c655d/go.mod h1:cfn0Ycx1ASzCkl8+04zI4hrclf9YQ1QfncxzFiNtQLo= | ||||||
| github.com/brunomvsouza/singleflight v0.4.0 h1:9dNcTeYoXSus3xbZEM0EEZ11EcCRjUZOvVW8rnDMG5Y= | github.com/brunomvsouza/singleflight v0.4.0 h1:9dNcTeYoXSus3xbZEM0EEZ11EcCRjUZOvVW8rnDMG5Y= | ||||||
| github.com/brunomvsouza/singleflight v0.4.0/go.mod h1:8RYo9j5WQRupmsnUz5DlUWZxDLNi+t9Zhj3EZFmns7I= | github.com/brunomvsouza/singleflight v0.4.0/go.mod h1:8RYo9j5WQRupmsnUz5DlUWZxDLNi+t9Zhj3EZFmns7I= | ||||||
| github.com/certifi/gocertifi v0.0.0-20210507211836-431795d63e8d h1:S2NE3iHSwP0XV47EEXL8mWmRdEfGscSJ+7EgePNgt0s= | github.com/certifi/gocertifi v0.0.0-20210507211836-431795d63e8d h1:S2NE3iHSwP0XV47EEXL8mWmRdEfGscSJ+7EgePNgt0s= | ||||||
|   | |||||||
							
								
								
									
										34
									
								
								internal/client/signer/local.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										34
									
								
								internal/client/signer/local.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,34 @@ | |||||||
|  | package signer | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"io" | ||||||
|  | 	"strings" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type Signer interface { | ||||||
|  | 	Sign(data io.Reader) ([]byte, error) | ||||||
|  | 	GetPublicKey(format string) ([]byte, error) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type LocalSigner struct { | ||||||
|  | 	Signer | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *LocalSigner) Sign(ctx context.Context, data string) (string, error) { | ||||||
|  | 	signed, err := s.Signer.Sign(strings.NewReader(data)) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return "", err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return string(signed), nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *LocalSigner) GetPublicKey(ctx context.Context, format string) (string, error) { | ||||||
|  | 	publicKey, err := s.Signer.GetPublicKey(format) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return "", err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return string(publicKey), nil | ||||||
|  | } | ||||||
							
								
								
									
										104
									
								
								internal/client/signer/local_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										104
									
								
								internal/client/signer/local_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,104 @@ | |||||||
|  | package signer | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"errors" | ||||||
|  | 	"io" | ||||||
|  | 	"testing" | ||||||
|  |  | ||||||
|  | 	"github.com/stretchr/testify/mock" | ||||||
|  | 	"github.com/stretchr/testify/suite" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type SignerMock struct { | ||||||
|  | 	mock.Mock | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (m *SignerMock) Sign(data io.Reader) ([]byte, error) { | ||||||
|  | 	args := m.Called(data) | ||||||
|  | 	var result []byte | ||||||
|  | 	if casted, ok := args.Get(0).([]byte); ok { | ||||||
|  | 		result = casted | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return result, args.Error(1) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (m *SignerMock) GetPublicKey(format string) ([]byte, error) { | ||||||
|  | 	args := m.Called(format) | ||||||
|  | 	var result []byte | ||||||
|  | 	if casted, ok := args.Get(0).([]byte); ok { | ||||||
|  | 		result = casted | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return result, args.Error(1) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type LocalSignerServiceTestSuite struct { | ||||||
|  | 	suite.Suite | ||||||
|  |  | ||||||
|  | 	Service *LocalSigner | ||||||
|  |  | ||||||
|  | 	Signer *SignerMock | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (t *LocalSignerServiceTestSuite) SetupSubTest() { | ||||||
|  | 	t.Signer = &SignerMock{} | ||||||
|  |  | ||||||
|  | 	t.Service = &LocalSigner{ | ||||||
|  | 		Signer: t.Signer, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (t *LocalSignerServiceTestSuite) TearDownSubTest() { | ||||||
|  | 	t.Signer.AssertExpectations(t.T()) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (t *LocalSignerServiceTestSuite) TestSign() { | ||||||
|  | 	t.Run("successfully sign", func() { | ||||||
|  | 		signature := []byte("mock signature") | ||||||
|  | 		t.Signer.On("Sign", mock.Anything).Return(signature, nil).Run(func(args mock.Arguments) { | ||||||
|  | 			r, _ := io.ReadAll(args.Get(0).(io.Reader)) | ||||||
|  | 			t.Equal([]byte("mock body to sign"), r) | ||||||
|  | 		}) | ||||||
|  |  | ||||||
|  | 		result, err := t.Service.Sign(context.Background(), "mock body to sign") | ||||||
|  | 		t.NoError(err) | ||||||
|  | 		t.Equal(string(signature), result) | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	t.Run("handle error during sign", func() { | ||||||
|  | 		expectedErr := errors.New("mock error") | ||||||
|  | 		t.Signer.On("Sign", mock.Anything).Return(nil, expectedErr) | ||||||
|  |  | ||||||
|  | 		result, err := t.Service.Sign(context.Background(), "mock body to sign") | ||||||
|  | 		t.Error(err) | ||||||
|  | 		t.Same(expectedErr, err) | ||||||
|  | 		t.Empty(result) | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (t *LocalSignerServiceTestSuite) TestGetPublicKey() { | ||||||
|  | 	t.Run("successfully get", func() { | ||||||
|  | 		publicKey := []byte("mock public key") | ||||||
|  | 		t.Signer.On("GetPublicKey", "pem").Return(publicKey, nil) | ||||||
|  |  | ||||||
|  | 		result, err := t.Service.GetPublicKey(context.Background(), "pem") | ||||||
|  | 		t.NoError(err) | ||||||
|  | 		t.Equal(string(publicKey), result) | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	t.Run("handle error", func() { | ||||||
|  | 		expectedErr := errors.New("mock error") | ||||||
|  | 		t.Signer.On("GetPublicKey", "pem").Return(nil, expectedErr) | ||||||
|  |  | ||||||
|  | 		result, err := t.Service.GetPublicKey(context.Background(), "pem") | ||||||
|  | 		t.Error(err) | ||||||
|  | 		t.Same(expectedErr, err) | ||||||
|  | 		t.Empty(result) | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestLocalSignerService(t *testing.T) { | ||||||
|  | 	suite.Run(t, new(LocalSignerServiceTestSuite)) | ||||||
|  | } | ||||||
| @@ -2,13 +2,15 @@ package cmd | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"github.com/spf13/cobra" | 	"github.com/spf13/cobra" | ||||||
|  |  | ||||||
|  | 	"ely.by/chrly/internal/di" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| var serveCmd = &cobra.Command{ | var serveCmd = &cobra.Command{ | ||||||
| 	Use:   "serve", | 	Use:   "serve", | ||||||
| 	Short: "Starts HTTP handler for the skins system", | 	Short: "Starts HTTP handler for the skins system", | ||||||
| 	RunE: func(cmd *cobra.Command, args []string) error { | 	RunE: func(cmd *cobra.Command, args []string) error { | ||||||
| 		return startServer("skinsystem", "api") | 		return startServer(di.ModuleSkinsystem, di.ModuleProfiles, di.ModuleSigner) | ||||||
| 	}, | 	}, | ||||||
| } | } | ||||||
|  |  | ||||||
|   | |||||||
| @@ -9,8 +9,10 @@ import ( | |||||||
| ) | ) | ||||||
|  |  | ||||||
| var tokenCmd = &cobra.Command{ | var tokenCmd = &cobra.Command{ | ||||||
| 	Use:   "token", | 	Use:       "token scope1 ...", | ||||||
|  | 	Example:   "token profiles sign", | ||||||
| 	Short:     "Creates a new token, which allows to interact with Chrly API", | 	Short:     "Creates a new token, which allows to interact with Chrly API", | ||||||
|  | 	ValidArgs: []string{string(security.ProfilesScope), string(security.SignScope)}, | ||||||
| 	RunE: func(cmd *cobra.Command, args []string) error { | 	RunE: func(cmd *cobra.Command, args []string) error { | ||||||
| 		container := shouldGetContainer() | 		container := shouldGetContainer() | ||||||
| 		var auth *security.Jwt | 		var auth *security.Jwt | ||||||
| @@ -19,7 +21,12 @@ var tokenCmd = &cobra.Command{ | |||||||
| 			return err | 			return err | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		token, err := auth.NewToken(security.ProfileScope) | 		scopes := make([]security.Scope, len(args)) | ||||||
|  | 		for i := range args { | ||||||
|  | 			scopes[i] = security.Scope(args[i]) | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		token, err := auth.NewToken(scopes...) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return fmt.Errorf("Unable to create a new token. The error is %v\n", err) | 			return fmt.Errorf("Unable to create a new token. The error is %v\n", err) | ||||||
| 		} | 		} | ||||||
|   | |||||||
| @@ -6,9 +6,5 @@ import ( | |||||||
| ) | ) | ||||||
|  |  | ||||||
| var configDiOptions = di.Options( | var configDiOptions = di.Options( | ||||||
| 	di.Provide(newConfig), | 	di.Provide(viper.GetViper), | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func newConfig() *viper.Viper { |  | ||||||
| 	return viper.GetViper() |  | ||||||
| } |  | ||||||
|   | |||||||
| @@ -11,12 +11,18 @@ import ( | |||||||
| 	"github.com/spf13/viper" | 	"github.com/spf13/viper" | ||||||
|  |  | ||||||
| 	. "ely.by/chrly/internal/http" | 	. "ely.by/chrly/internal/http" | ||||||
|  | 	"ely.by/chrly/internal/security" | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  | const ModuleSkinsystem = "skinsystem" | ||||||
|  | const ModuleProfiles = "profiles" | ||||||
|  | const ModuleSigner = "signer" | ||||||
|  |  | ||||||
| var handlersDiOptions = di.Options( | var handlersDiOptions = di.Options( | ||||||
| 	di.Provide(newHandlerFactory, di.As(new(http.Handler))), | 	di.Provide(newHandlerFactory, di.As(new(http.Handler))), | ||||||
| 	di.Provide(newSkinsystemHandler, di.WithName("skinsystem")), | 	di.Provide(newSkinsystemHandler, di.WithName(ModuleSkinsystem)), | ||||||
| 	di.Provide(newApiHandler, di.WithName("api")), | 	di.Provide(newProfilesApiHandler, di.WithName(ModuleProfiles)), | ||||||
|  | 	di.Provide(newSignerApiHandler, di.WithName(ModuleSigner)), | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func newHandlerFactory( | func newHandlerFactory( | ||||||
| @@ -30,8 +36,8 @@ func newHandlerFactory( | |||||||
| 	// if you set an empty prefix. Since the main application should be mounted at the root prefix, | 	// if you set an empty prefix. Since the main application should be mounted at the root prefix, | ||||||
| 	// we use it as the base router | 	// we use it as the base router | ||||||
| 	var router *mux.Router | 	var router *mux.Router | ||||||
| 	if slices.Contains(enabledModules, "skinsystem") { | 	if slices.Contains(enabledModules, ModuleSkinsystem) { | ||||||
| 		if err := container.Resolve(&router, di.Name("skinsystem")); err != nil { | 		if err := container.Resolve(&router, di.Name(ModuleSkinsystem)); err != nil { | ||||||
| 			return nil, err | 			return nil, err | ||||||
| 		} | 		} | ||||||
| 	} else { | 	} else { | ||||||
| @@ -41,9 +47,9 @@ func newHandlerFactory( | |||||||
| 	router.StrictSlash(true) | 	router.StrictSlash(true) | ||||||
| 	router.NotFoundHandler = http.HandlerFunc(NotFoundHandler) | 	router.NotFoundHandler = http.HandlerFunc(NotFoundHandler) | ||||||
|  |  | ||||||
| 	if slices.Contains(enabledModules, "api") { | 	if slices.Contains(enabledModules, ModuleProfiles) { | ||||||
| 		var apiRouter *mux.Router | 		var profilesApiRouter *mux.Router | ||||||
| 		if err := container.Resolve(&apiRouter, di.Name("api")); err != nil { | 		if err := container.Resolve(&profilesApiRouter, di.Name(ModuleProfiles)); err != nil { | ||||||
| 			return nil, err | 			return nil, err | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| @@ -52,9 +58,29 @@ func newHandlerFactory( | |||||||
| 			return nil, err | 			return nil, err | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		apiRouter.Use(CreateAuthenticationMiddleware(authenticator)) | 		profilesApiRouter.Use(NewAuthenticationMiddleware(authenticator, security.ProfilesScope)) | ||||||
|  |  | ||||||
| 		mount(router, "/api", apiRouter) | 		mount(router, "/api/profiles", profilesApiRouter) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if slices.Contains(enabledModules, ModuleSigner) { | ||||||
|  | 		var signerApiRouter *mux.Router | ||||||
|  | 		if err := container.Resolve(&signerApiRouter, di.Name(ModuleSigner)); err != nil { | ||||||
|  | 			return nil, err | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		var authenticator Authenticator | ||||||
|  | 		if err := container.Resolve(&authenticator); err != nil { | ||||||
|  | 			return nil, err | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		authMiddleware := NewAuthenticationMiddleware(authenticator, security.SignScope) | ||||||
|  | 		conditionalAuth := NewConditionalMiddleware(func(req *http.Request) bool { | ||||||
|  | 			return req.Method != "GET" | ||||||
|  | 		}, authMiddleware) | ||||||
|  | 		signerApiRouter.Use(conditionalAuth) | ||||||
|  |  | ||||||
|  | 		mount(router, "/api/signer", signerApiRouter) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Resolve health checkers last, because all the services required by the application | 	// Resolve health checkers last, because all the services required by the application | ||||||
| @@ -79,25 +105,31 @@ func newHandlerFactory( | |||||||
| func newSkinsystemHandler( | func newSkinsystemHandler( | ||||||
| 	config *viper.Viper, | 	config *viper.Viper, | ||||||
| 	profilesProvider ProfilesProvider, | 	profilesProvider ProfilesProvider, | ||||||
| 	texturesSigner TexturesSigner, | 	texturesSigner SignerService, | ||||||
| ) *mux.Router { | ) *mux.Router { | ||||||
| 	config.SetDefault("textures.extra_param_name", "chrly") | 	config.SetDefault("textures.extra_param_name", "chrly") | ||||||
| 	config.SetDefault("textures.extra_param_value", "how do you tame a horse in Minecraft?") | 	config.SetDefault("textures.extra_param_value", "how do you tame a horse in Minecraft?") | ||||||
|  |  | ||||||
| 	return (&Skinsystem{ | 	return (&Skinsystem{ | ||||||
| 		ProfilesProvider:        profilesProvider, | 		ProfilesProvider:        profilesProvider, | ||||||
| 		TexturesSigner:          texturesSigner, | 		SignerService:           texturesSigner, | ||||||
| 		TexturesExtraParamName:  config.GetString("textures.extra_param_name"), | 		TexturesExtraParamName:  config.GetString("textures.extra_param_name"), | ||||||
| 		TexturesExtraParamValue: config.GetString("textures.extra_param_value"), | 		TexturesExtraParamValue: config.GetString("textures.extra_param_value"), | ||||||
| 	}).Handler() | 	}).Handler() | ||||||
| } | } | ||||||
|  |  | ||||||
| func newApiHandler(profilesManager ProfilesManager) *mux.Router { | func newProfilesApiHandler(profilesManager ProfilesManager) *mux.Router { | ||||||
| 	return (&Api{ | 	return (&ProfilesApi{ | ||||||
| 		ProfilesManager: profilesManager, | 		ProfilesManager: profilesManager, | ||||||
| 	}).Handler() | 	}).Handler() | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func newSignerApiHandler(signer Signer) *mux.Router { | ||||||
|  | 	return (&SignerApi{ | ||||||
|  | 		Signer: signer, | ||||||
|  | 	}).Handler() | ||||||
|  | } | ||||||
|  |  | ||||||
| func mount(router *mux.Router, path string, handler http.Handler) { | func mount(router *mux.Router, path string, handler http.Handler) { | ||||||
| 	router.PathPrefix(path).Handler( | 	router.PathPrefix(path).Handler( | ||||||
| 		http.StripPrefix( | 		http.StripPrefix( | ||||||
|   | |||||||
| @@ -8,6 +8,7 @@ import ( | |||||||
| 	"encoding/pem" | 	"encoding/pem" | ||||||
| 	"strings" | 	"strings" | ||||||
|  |  | ||||||
|  | 	signerClient "ely.by/chrly/internal/client/signer" | ||||||
| 	"ely.by/chrly/internal/http" | 	"ely.by/chrly/internal/http" | ||||||
| 	"ely.by/chrly/internal/security" | 	"ely.by/chrly/internal/security" | ||||||
|  |  | ||||||
| @@ -16,12 +17,14 @@ import ( | |||||||
| ) | ) | ||||||
|  |  | ||||||
| var securityDiOptions = di.Options( | var securityDiOptions = di.Options( | ||||||
| 	di.Provide(newTexturesSigner, | 	di.Provide(newSigner, | ||||||
| 		di.As(new(http.TexturesSigner)), | 		di.As(new(http.Signer)), | ||||||
|  | 		di.As(new(signerClient.Signer)), | ||||||
| 	), | 	), | ||||||
|  | 	di.Provide(newSignerService), | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func newTexturesSigner(config *viper.Viper) (*security.Signer, error) { | func newSigner(config *viper.Viper) (*security.Signer, error) { | ||||||
| 	keyStr := config.GetString("chrly.signing.key") | 	keyStr := config.GetString("chrly.signing.key") | ||||||
| 	if keyStr == "" { | 	if keyStr == "" { | ||||||
| 		// TODO: log a message about the generated signing key and the way to specify it permanently | 		// TODO: log a message about the generated signing key and the way to specify it permanently | ||||||
| @@ -54,3 +57,9 @@ func newTexturesSigner(config *viper.Viper) (*security.Signer, error) { | |||||||
|  |  | ||||||
| 	return security.NewSigner(privateKey), nil | 	return security.NewSigner(privateKey), nil | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func newSignerService(signer signerClient.Signer) http.SignerService { | ||||||
|  | 	return &signerClient.LocalSigner{ | ||||||
|  | 		Signer: signer, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|   | |||||||
| @@ -13,12 +13,10 @@ import ( | |||||||
| 	"github.com/mono83/slf" | 	"github.com/mono83/slf" | ||||||
| 	"github.com/mono83/slf/wd" | 	"github.com/mono83/slf/wd" | ||||||
|  |  | ||||||
| 	"ely.by/chrly/internal/version" | 	"ely.by/chrly/internal/security" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func StartServer(server *http.Server, logger slf.Logger) { | func StartServer(server *http.Server, logger slf.Logger) { | ||||||
| 	logger.Debug("Chrly :v (:c)", wd.StringParam("v", version.Version()), wd.StringParam("c", version.Commit())) |  | ||||||
|  |  | ||||||
| 	ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM, os.Kill) | 	ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM, os.Kill) | ||||||
| 	defer cancel() | 	defer cancel() | ||||||
|  |  | ||||||
| @@ -45,15 +43,13 @@ func StartServer(server *http.Server, logger slf.Logger) { | |||||||
| } | } | ||||||
|  |  | ||||||
| type Authenticator interface { | type Authenticator interface { | ||||||
| 	Authenticate(req *http.Request) error | 	Authenticate(req *http.Request, scope security.Scope) error | ||||||
| } | } | ||||||
|  |  | ||||||
| // The current middleware implementation doesn't check the scope assigned to the token. | func NewAuthenticationMiddleware(authenticator Authenticator, scope security.Scope) mux.MiddlewareFunc { | ||||||
| // For now there is only one scope and at this moment I don't want to spend time on it |  | ||||||
| func CreateAuthenticationMiddleware(checker Authenticator) mux.MiddlewareFunc { |  | ||||||
| 	return func(handler http.Handler) http.Handler { | 	return func(handler http.Handler) http.Handler { | ||||||
| 		return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { | 		return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { | ||||||
| 			err := checker.Authenticate(req) | 			err := authenticator.Authenticate(req, scope) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				apiForbidden(resp, err.Error()) | 				apiForbidden(resp, err.Error()) | ||||||
| 				return | 				return | ||||||
| @@ -64,6 +60,18 @@ func CreateAuthenticationMiddleware(checker Authenticator) mux.MiddlewareFunc { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func NewConditionalMiddleware(cond func(req *http.Request) bool, m mux.MiddlewareFunc) mux.MiddlewareFunc { | ||||||
|  | 	return func(handler http.Handler) http.Handler { | ||||||
|  | 		return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { | ||||||
|  | 			if cond(req) { | ||||||
|  | 				handler = m.Middleware(handler) | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			handler.ServeHTTP(resp, req) | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
| func NotFoundHandler(response http.ResponseWriter, _ *http.Request) { | func NotFoundHandler(response http.ResponseWriter, _ *http.Request) { | ||||||
| 	data, _ := json.Marshal(map[string]string{ | 	data, _ := json.Marshal(map[string]string{ | ||||||
| 		"status":  "404", | 		"status":  "404", | ||||||
| @@ -78,7 +86,7 @@ func NotFoundHandler(response http.ResponseWriter, _ *http.Request) { | |||||||
| func apiBadRequest(resp http.ResponseWriter, errorsPerField map[string][]string) { | func apiBadRequest(resp http.ResponseWriter, errorsPerField map[string][]string) { | ||||||
| 	resp.WriteHeader(http.StatusBadRequest) | 	resp.WriteHeader(http.StatusBadRequest) | ||||||
| 	resp.Header().Set("Content-Type", "application/json") | 	resp.Header().Set("Content-Type", "application/json") | ||||||
| 	result, _ := json.Marshal(map[string]interface{}{ | 	result, _ := json.Marshal(map[string]any{ | ||||||
| 		"errors": errorsPerField, | 		"errors": errorsPerField, | ||||||
| 	}) | 	}) | ||||||
| 	_, _ = resp.Write(result) | 	_, _ = resp.Write(result) | ||||||
| @@ -95,7 +103,7 @@ func apiServerError(resp http.ResponseWriter, err error) { | |||||||
| func apiForbidden(resp http.ResponseWriter, reason string) { | func apiForbidden(resp http.ResponseWriter, reason string) { | ||||||
| 	resp.WriteHeader(http.StatusForbidden) | 	resp.WriteHeader(http.StatusForbidden) | ||||||
| 	resp.Header().Set("Content-Type", "application/json") | 	resp.Header().Set("Content-Type", "application/json") | ||||||
| 	result, _ := json.Marshal(map[string]interface{}{ | 	result, _ := json.Marshal(map[string]any{ | ||||||
| 		"error": reason, | 		"error": reason, | ||||||
| 	}) | 	}) | ||||||
| 	_, _ = resp.Write(result) | 	_, _ = resp.Write(result) | ||||||
|   | |||||||
| @@ -2,34 +2,35 @@ package http | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"io/ioutil" | 	"io" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"net/http/httptest" | 	"net/http/httptest" | ||||||
| 	"testing" | 	"testing" | ||||||
|  |  | ||||||
| 	testify "github.com/stretchr/testify/assert" |  | ||||||
| 	"github.com/stretchr/testify/mock" | 	"github.com/stretchr/testify/mock" | ||||||
|  | 	testify "github.com/stretchr/testify/require" | ||||||
|  |  | ||||||
|  | 	"ely.by/chrly/internal/security" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| type authCheckerMock struct { | type authCheckerMock struct { | ||||||
| 	mock.Mock | 	mock.Mock | ||||||
| } | } | ||||||
|  |  | ||||||
| func (m *authCheckerMock) Authenticate(req *http.Request) error { | func (m *authCheckerMock) Authenticate(req *http.Request, scope security.Scope) error { | ||||||
| 	args := m.Called(req) | 	return m.Called(req, scope).Error(0) | ||||||
| 	return args.Error(0) |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func TestCreateAuthenticationMiddleware(t *testing.T) { | func TestAuthenticationMiddleware(t *testing.T) { | ||||||
| 	t.Run("pass", func(t *testing.T) { | 	t.Run("pass", func(t *testing.T) { | ||||||
| 		req := httptest.NewRequest("GET", "http://example.com", nil) | 		req := httptest.NewRequest("GET", "https://example.com", nil) | ||||||
| 		resp := httptest.NewRecorder() | 		resp := httptest.NewRecorder() | ||||||
|  |  | ||||||
| 		auth := &authCheckerMock{} | 		auth := &authCheckerMock{} | ||||||
| 		auth.On("Authenticate", req).Once().Return(nil) | 		auth.On("Authenticate", req, security.Scope("mock")).Once().Return(nil) | ||||||
|  |  | ||||||
| 		isHandlerCalled := false | 		isHandlerCalled := false | ||||||
| 		middlewareFunc := CreateAuthenticationMiddleware(auth) | 		middlewareFunc := NewAuthenticationMiddleware(auth, "mock") | ||||||
| 		middlewareFunc.Middleware(http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { | 		middlewareFunc.Middleware(http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { | ||||||
| 			isHandlerCalled = true | 			isHandlerCalled = true | ||||||
| 		})).ServeHTTP(resp, req) | 		})).ServeHTTP(resp, req) | ||||||
| @@ -40,21 +41,21 @@ func TestCreateAuthenticationMiddleware(t *testing.T) { | |||||||
| 	}) | 	}) | ||||||
|  |  | ||||||
| 	t.Run("fail", func(t *testing.T) { | 	t.Run("fail", func(t *testing.T) { | ||||||
| 		req := httptest.NewRequest("GET", "http://example.com", nil) | 		req := httptest.NewRequest("GET", "https://example.com", nil) | ||||||
| 		resp := httptest.NewRecorder() | 		resp := httptest.NewRecorder() | ||||||
|  |  | ||||||
| 		auth := &authCheckerMock{} | 		auth := &authCheckerMock{} | ||||||
| 		auth.On("Authenticate", req).Once().Return(errors.New("error reason")) | 		auth.On("Authenticate", req, security.Scope("mock")).Once().Return(errors.New("error reason")) | ||||||
|  |  | ||||||
| 		isHandlerCalled := false | 		isHandlerCalled := false | ||||||
| 		middlewareFunc := CreateAuthenticationMiddleware(auth) | 		middlewareFunc := NewAuthenticationMiddleware(auth, "mock") | ||||||
| 		middlewareFunc.Middleware(http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { | 		middlewareFunc.Middleware(http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { | ||||||
| 			isHandlerCalled = true | 			isHandlerCalled = true | ||||||
| 		})).ServeHTTP(resp, req) | 		})).ServeHTTP(resp, req) | ||||||
|  |  | ||||||
| 		testify.False(t, isHandlerCalled, "Handler shouldn't be called") | 		testify.False(t, isHandlerCalled, "Handler shouldn't be called") | ||||||
| 		testify.Equal(t, 403, resp.Code) | 		testify.Equal(t, 403, resp.Code) | ||||||
| 		body, _ := ioutil.ReadAll(resp.Body) | 		body, _ := io.ReadAll(resp.Body) | ||||||
| 		testify.JSONEq(t, `{ | 		testify.JSONEq(t, `{ | ||||||
| 			"error": "error reason" | 			"error": "error reason" | ||||||
| 		}`, string(body)) | 		}`, string(body)) | ||||||
| @@ -63,10 +64,56 @@ func TestCreateAuthenticationMiddleware(t *testing.T) { | |||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func TestConditionalMiddleware(t *testing.T) { | ||||||
|  | 	t.Run("true", func(t *testing.T) { | ||||||
|  | 		req := httptest.NewRequest("GET", "https://example.com", nil) | ||||||
|  | 		resp := httptest.NewRecorder() | ||||||
|  |  | ||||||
|  | 		isNestedMiddlewareCalled := false | ||||||
|  | 		isHandlerCalled := false | ||||||
|  | 		NewConditionalMiddleware( | ||||||
|  | 			func(req *http.Request) bool { | ||||||
|  | 				return true | ||||||
|  | 			}, | ||||||
|  | 			func(handler http.Handler) http.Handler { | ||||||
|  | 				isNestedMiddlewareCalled = true | ||||||
|  | 				return handler | ||||||
|  | 			}, | ||||||
|  | 		).Middleware(http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { | ||||||
|  | 			isHandlerCalled = true | ||||||
|  | 		})).ServeHTTP(resp, req) | ||||||
|  |  | ||||||
|  | 		testify.True(t, isNestedMiddlewareCalled, "Nested middleware wasn't called") | ||||||
|  | 		testify.True(t, isHandlerCalled, "Handler wasn't called from the middleware") | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	t.Run("false", func(t *testing.T) { | ||||||
|  | 		req := httptest.NewRequest("GET", "https://example.com", nil) | ||||||
|  | 		resp := httptest.NewRecorder() | ||||||
|  |  | ||||||
|  | 		isNestedMiddlewareCalled := false | ||||||
|  | 		isHandlerCalled := false | ||||||
|  | 		NewConditionalMiddleware( | ||||||
|  | 			func(req *http.Request) bool { | ||||||
|  | 				return false | ||||||
|  | 			}, | ||||||
|  | 			func(handler http.Handler) http.Handler { | ||||||
|  | 				isNestedMiddlewareCalled = true | ||||||
|  | 				return handler | ||||||
|  | 			}, | ||||||
|  | 		).Middleware(http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { | ||||||
|  | 			isHandlerCalled = true | ||||||
|  | 		})).ServeHTTP(resp, req) | ||||||
|  |  | ||||||
|  | 		testify.False(t, isNestedMiddlewareCalled, "Nested middleware shouldn't be called") | ||||||
|  | 		testify.True(t, isHandlerCalled, "Handler wasn't called from the middleware") | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  |  | ||||||
| func TestNotFoundHandler(t *testing.T) { | func TestNotFoundHandler(t *testing.T) { | ||||||
| 	assert := testify.New(t) | 	assert := testify.New(t) | ||||||
|  |  | ||||||
| 	req := httptest.NewRequest("GET", "http://example.com", nil) | 	req := httptest.NewRequest("GET", "https://example.com", nil) | ||||||
| 	w := httptest.NewRecorder() | 	w := httptest.NewRecorder() | ||||||
|  |  | ||||||
| 	NotFoundHandler(w, req) | 	NotFoundHandler(w, req) | ||||||
| @@ -74,7 +121,7 @@ func TestNotFoundHandler(t *testing.T) { | |||||||
| 	resp := w.Result() | 	resp := w.Result() | ||||||
| 	assert.Equal(404, resp.StatusCode) | 	assert.Equal(404, resp.StatusCode) | ||||||
| 	assert.Equal("application/json", resp.Header.Get("Content-Type")) | 	assert.Equal("application/json", resp.Header.Get("Content-Type")) | ||||||
| 	response, _ := ioutil.ReadAll(resp.Body) | 	response, _ := io.ReadAll(resp.Body) | ||||||
| 	assert.JSONEq(`{ | 	assert.JSONEq(`{ | ||||||
| 		"status": "404", | 		"status": "404", | ||||||
| 		"message": "Not Found" | 		"message": "Not Found" | ||||||
|   | |||||||
| @@ -17,19 +17,19 @@ type ProfilesManager interface { | |||||||
| 	RemoveProfileByUuid(ctx context.Context, uuid string) error | 	RemoveProfileByUuid(ctx context.Context, uuid string) error | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| type Api struct { | type ProfilesApi struct { | ||||||
| 	ProfilesManager | 	ProfilesManager | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (ctx *Api) Handler() *mux.Router { | func (ctx *ProfilesApi) Handler() *mux.Router { | ||||||
| 	router := mux.NewRouter().StrictSlash(true) | 	router := mux.NewRouter().StrictSlash(true) | ||||||
| 	router.HandleFunc("/profiles", ctx.postProfileHandler).Methods(http.MethodPost) | 	router.HandleFunc("/", ctx.postProfileHandler).Methods(http.MethodPost) | ||||||
| 	router.HandleFunc("/profiles/{uuid}", ctx.deleteProfileByUuidHandler).Methods(http.MethodDelete) | 	router.HandleFunc("/{uuid}", ctx.deleteProfileByUuidHandler).Methods(http.MethodDelete) | ||||||
| 
 | 
 | ||||||
| 	return router | 	return router | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (ctx *Api) postProfileHandler(resp http.ResponseWriter, req *http.Request) { | func (ctx *ProfilesApi) postProfileHandler(resp http.ResponseWriter, req *http.Request) { | ||||||
| 	err := req.ParseForm() | 	err := req.ParseForm() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		apiBadRequest(resp, map[string][]string{ | 		apiBadRequest(resp, map[string][]string{ | ||||||
| @@ -63,7 +63,7 @@ func (ctx *Api) postProfileHandler(resp http.ResponseWriter, req *http.Request) | |||||||
| 	resp.WriteHeader(http.StatusCreated) | 	resp.WriteHeader(http.StatusCreated) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (ctx *Api) deleteProfileByUuidHandler(resp http.ResponseWriter, req *http.Request) { | func (ctx *ProfilesApi) deleteProfileByUuidHandler(resp http.ResponseWriter, req *http.Request) { | ||||||
| 	uuid := mux.Vars(req)["uuid"] | 	uuid := mux.Vars(req)["uuid"] | ||||||
| 	err := ctx.ProfilesManager.RemoveProfileByUuid(req.Context(), uuid) | 	err := ctx.ProfilesManager.RemoveProfileByUuid(req.Context(), uuid) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| @@ -30,26 +30,26 @@ func (m *ProfilesManagerMock) RemoveProfileByUuid(ctx context.Context, uuid stri | |||||||
| 	return m.Called(ctx, uuid).Error(0) | 	return m.Called(ctx, uuid).Error(0) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| type ApiTestSuite struct { | type ProfilesTestSuite struct { | ||||||
| 	suite.Suite | 	suite.Suite | ||||||
| 
 | 
 | ||||||
| 	App *Api | 	App *ProfilesApi | ||||||
| 
 | 
 | ||||||
| 	ProfilesManager *ProfilesManagerMock | 	ProfilesManager *ProfilesManagerMock | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (t *ApiTestSuite) SetupSubTest() { | func (t *ProfilesTestSuite) SetupSubTest() { | ||||||
| 	t.ProfilesManager = &ProfilesManagerMock{} | 	t.ProfilesManager = &ProfilesManagerMock{} | ||||||
| 	t.App = &Api{ | 	t.App = &ProfilesApi{ | ||||||
| 		ProfilesManager: t.ProfilesManager, | 		ProfilesManager: t.ProfilesManager, | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (t *ApiTestSuite) TearDownSubTest() { | func (t *ProfilesTestSuite) TearDownSubTest() { | ||||||
| 	t.ProfilesManager.AssertExpectations(t.T()) | 	t.ProfilesManager.AssertExpectations(t.T()) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (t *ApiTestSuite) TestPostProfile() { | func (t *ProfilesTestSuite) TestPostProfile() { | ||||||
| 	t.Run("successfully post profile", func() { | 	t.Run("successfully post profile", func() { | ||||||
| 		t.ProfilesManager.On("PersistProfile", mock.Anything, &db.Profile{ | 		t.ProfilesManager.On("PersistProfile", mock.Anything, &db.Profile{ | ||||||
| 			Uuid:            "0f657aa8-bfbe-415d-b700-5750090d3af3", | 			Uuid:            "0f657aa8-bfbe-415d-b700-5750090d3af3", | ||||||
| @@ -61,7 +61,7 @@ func (t *ApiTestSuite) TestPostProfile() { | |||||||
| 			MojangSignature: "bW9jawo=", | 			MojangSignature: "bW9jawo=", | ||||||
| 		}).Once().Return(nil) | 		}).Once().Return(nil) | ||||||
| 
 | 
 | ||||||
| 		req := httptest.NewRequest("POST", "http://chrly/profiles", bytes.NewBufferString(url.Values{ | 		req := httptest.NewRequest("POST", "http://chrly/", bytes.NewBufferString(url.Values{ | ||||||
| 			"uuid":            {"0f657aa8-bfbe-415d-b700-5750090d3af3"}, | 			"uuid":            {"0f657aa8-bfbe-415d-b700-5750090d3af3"}, | ||||||
| 			"username":        {"mock_username"}, | 			"username":        {"mock_username"}, | ||||||
| 			"skinUrl":         {"https://example.com/skin.png"}, | 			"skinUrl":         {"https://example.com/skin.png"}, | ||||||
| @@ -82,7 +82,7 @@ func (t *ApiTestSuite) TestPostProfile() { | |||||||
| 	}) | 	}) | ||||||
| 
 | 
 | ||||||
| 	t.Run("handle malformed body", func() { | 	t.Run("handle malformed body", func() { | ||||||
| 		req := httptest.NewRequest("POST", "http://chrly/profiles", strings.NewReader("invalid;=url?encoded_string")) | 		req := httptest.NewRequest("POST", "http://chrly/", strings.NewReader("invalid;=url?encoded_string")) | ||||||
| 		req.Header.Add("Content-Type", "application/x-www-form-urlencoded") | 		req.Header.Add("Content-Type", "application/x-www-form-urlencoded") | ||||||
| 		w := httptest.NewRecorder() | 		w := httptest.NewRecorder() | ||||||
| 
 | 
 | ||||||
| @@ -107,7 +107,7 @@ func (t *ApiTestSuite) TestPostProfile() { | |||||||
| 			}, | 			}, | ||||||
| 		}) | 		}) | ||||||
| 
 | 
 | ||||||
| 		req := httptest.NewRequest("POST", "http://chrly/profiles", strings.NewReader("")) | 		req := httptest.NewRequest("POST", "http://chrly/", strings.NewReader("")) | ||||||
| 		req.Header.Add("Content-Type", "application/x-www-form-urlencoded") | 		req.Header.Add("Content-Type", "application/x-www-form-urlencoded") | ||||||
| 		w := httptest.NewRecorder() | 		w := httptest.NewRecorder() | ||||||
| 
 | 
 | ||||||
| @@ -129,7 +129,7 @@ func (t *ApiTestSuite) TestPostProfile() { | |||||||
| 	t.Run("receive other error", func() { | 	t.Run("receive other error", func() { | ||||||
| 		t.ProfilesManager.On("PersistProfile", mock.Anything, mock.Anything).Once().Return(errors.New("mock error")) | 		t.ProfilesManager.On("PersistProfile", mock.Anything, mock.Anything).Once().Return(errors.New("mock error")) | ||||||
| 
 | 
 | ||||||
| 		req := httptest.NewRequest("POST", "http://chrly/profiles", strings.NewReader("")) | 		req := httptest.NewRequest("POST", "http://chrly/", strings.NewReader("")) | ||||||
| 		req.Header.Add("Content-Type", "application/x-www-form-urlencoded") | 		req.Header.Add("Content-Type", "application/x-www-form-urlencoded") | ||||||
| 		w := httptest.NewRecorder() | 		w := httptest.NewRecorder() | ||||||
| 
 | 
 | ||||||
| @@ -140,11 +140,11 @@ func (t *ApiTestSuite) TestPostProfile() { | |||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (t *ApiTestSuite) TestDeleteProfileByUuid() { | func (t *ProfilesTestSuite) TestDeleteProfileByUuid() { | ||||||
| 	t.Run("successfully delete", func() { | 	t.Run("successfully delete", func() { | ||||||
| 		t.ProfilesManager.On("RemoveProfileByUuid", mock.Anything, "0f657aa8-bfbe-415d-b700-5750090d3af3").Once().Return(nil) | 		t.ProfilesManager.On("RemoveProfileByUuid", mock.Anything, "0f657aa8-bfbe-415d-b700-5750090d3af3").Once().Return(nil) | ||||||
| 
 | 
 | ||||||
| 		req := httptest.NewRequest("DELETE", "http://chrly/profiles/0f657aa8-bfbe-415d-b700-5750090d3af3", nil) | 		req := httptest.NewRequest("DELETE", "http://chrly/0f657aa8-bfbe-415d-b700-5750090d3af3", nil) | ||||||
| 		w := httptest.NewRecorder() | 		w := httptest.NewRecorder() | ||||||
| 
 | 
 | ||||||
| 		t.App.Handler().ServeHTTP(w, req) | 		t.App.Handler().ServeHTTP(w, req) | ||||||
| @@ -158,7 +158,7 @@ func (t *ApiTestSuite) TestDeleteProfileByUuid() { | |||||||
| 	t.Run("error from manager", func() { | 	t.Run("error from manager", func() { | ||||||
| 		t.ProfilesManager.On("RemoveProfileByUuid", mock.Anything, mock.Anything).Return(errors.New("mock error")) | 		t.ProfilesManager.On("RemoveProfileByUuid", mock.Anything, mock.Anything).Return(errors.New("mock error")) | ||||||
| 
 | 
 | ||||||
| 		req := httptest.NewRequest("DELETE", "http://chrly/profiles/0f657aa8-bfbe-415d-b700-5750090d3af3", nil) | 		req := httptest.NewRequest("DELETE", "http://chrly/0f657aa8-bfbe-415d-b700-5750090d3af3", nil) | ||||||
| 		w := httptest.NewRecorder() | 		w := httptest.NewRecorder() | ||||||
| 
 | 
 | ||||||
| 		t.App.Handler().ServeHTTP(w, req) | 		t.App.Handler().ServeHTTP(w, req) | ||||||
| @@ -168,6 +168,6 @@ func (t *ApiTestSuite) TestDeleteProfileByUuid() { | |||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestApi(t *testing.T) { | func TestProfilesApi(t *testing.T) { | ||||||
| 	suite.Run(t, new(ApiTestSuite)) | 	suite.Run(t, new(ProfilesTestSuite)) | ||||||
| } | } | ||||||
							
								
								
									
										60
									
								
								internal/http/signer.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										60
									
								
								internal/http/signer.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,60 @@ | |||||||
|  | package http | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"encoding/base64" | ||||||
|  | 	"fmt" | ||||||
|  | 	"io" | ||||||
|  | 	"net/http" | ||||||
|  |  | ||||||
|  | 	"github.com/gorilla/mux" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type Signer interface { | ||||||
|  | 	Sign(data io.Reader) ([]byte, error) | ||||||
|  | 	GetPublicKey(format string) ([]byte, error) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type SignerApi struct { | ||||||
|  | 	Signer | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *SignerApi) Handler() *mux.Router { | ||||||
|  | 	router := mux.NewRouter().StrictSlash(true) | ||||||
|  | 	router.HandleFunc("/", s.signHandler).Methods(http.MethodPost) | ||||||
|  | 	router.HandleFunc("/public-key.{format:(?:pem|der)}", s.getPublicKeyHandler).Methods(http.MethodGet) | ||||||
|  |  | ||||||
|  | 	return router | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *SignerApi) signHandler(resp http.ResponseWriter, req *http.Request) { | ||||||
|  | 	signature, err := s.Signer.Sign(req.Body) | ||||||
|  | 	if err != nil { | ||||||
|  | 		apiServerError(resp, fmt.Errorf("unable to sign the value: %w", err)) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	resp.Header().Set("Content-Type", "application/octet-stream+base64") | ||||||
|  |  | ||||||
|  | 	buf := make([]byte, base64.StdEncoding.EncodedLen(len(signature))) | ||||||
|  | 	base64.StdEncoding.Encode(buf, signature) | ||||||
|  | 	_, _ = resp.Write(buf) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *SignerApi) getPublicKeyHandler(resp http.ResponseWriter, req *http.Request) { | ||||||
|  | 	format := mux.Vars(req)["format"] | ||||||
|  | 	publicKey, err := s.Signer.GetPublicKey(format) | ||||||
|  | 	if err != nil { | ||||||
|  | 		apiServerError(resp, fmt.Errorf("unable to retrieve public key: %w", err)) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if format == "pem" { | ||||||
|  | 		resp.Header().Set("Content-Type", "application/x-pem-file") | ||||||
|  | 		resp.Header().Set("Content-Disposition", `attachment; filename="yggdrasil_session_pubkey.pem"`) | ||||||
|  | 	} else { | ||||||
|  | 		resp.Header().Set("Content-Type", "application/octet-stream") | ||||||
|  | 		resp.Header().Set("Content-Disposition", `attachment; filename="yggdrasil_session_pubkey.der"`) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	_, _ = resp.Write(publicKey) | ||||||
|  | } | ||||||
							
								
								
									
										146
									
								
								internal/http/signer_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										146
									
								
								internal/http/signer_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,146 @@ | |||||||
|  | package http | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"bytes" | ||||||
|  | 	"errors" | ||||||
|  | 	"io" | ||||||
|  | 	"net/http" | ||||||
|  | 	"net/http/httptest" | ||||||
|  | 	"strings" | ||||||
|  | 	"testing" | ||||||
|  |  | ||||||
|  | 	"github.com/stretchr/testify/mock" | ||||||
|  | 	"github.com/stretchr/testify/suite" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type SignerMock struct { | ||||||
|  | 	mock.Mock | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (m *SignerMock) Sign(data io.Reader) ([]byte, error) { | ||||||
|  | 	args := m.Called(data) | ||||||
|  | 	var result []byte | ||||||
|  | 	if casted, ok := args.Get(0).([]byte); ok { | ||||||
|  | 		result = casted | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return result, args.Error(1) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (m *SignerMock) GetPublicKey(format string) ([]byte, error) { | ||||||
|  | 	args := m.Called(format) | ||||||
|  | 	var result []byte | ||||||
|  | 	if casted, ok := args.Get(0).([]byte); ok { | ||||||
|  | 		result = casted | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return result, args.Error(1) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type SignerApiTestSuite struct { | ||||||
|  | 	suite.Suite | ||||||
|  |  | ||||||
|  | 	App *SignerApi | ||||||
|  |  | ||||||
|  | 	Signer *SignerMock | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (t *SignerApiTestSuite) SetupSubTest() { | ||||||
|  | 	t.Signer = &SignerMock{} | ||||||
|  |  | ||||||
|  | 	t.App = &SignerApi{ | ||||||
|  | 		Signer: t.Signer, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (t *SignerApiTestSuite) TearDownSubTest() { | ||||||
|  | 	t.Signer.AssertExpectations(t.T()) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (t *SignerApiTestSuite) TestSign() { | ||||||
|  | 	t.Run("successfully sign", func() { | ||||||
|  | 		signature := []byte("mock signature") | ||||||
|  | 		t.Signer.On("Sign", mock.Anything).Return(signature, nil).Run(func(args mock.Arguments) { | ||||||
|  | 			buf := &bytes.Buffer{} | ||||||
|  | 			_, _ = io.Copy(buf, args.Get(0).(io.Reader)) | ||||||
|  | 			r, _ := io.ReadAll(buf) | ||||||
|  |  | ||||||
|  | 			t.Equal([]byte("mock body to sign"), r) | ||||||
|  | 		}) | ||||||
|  |  | ||||||
|  | 		req := httptest.NewRequest("POST", "http://chrly/", strings.NewReader("mock body to sign")) | ||||||
|  | 		w := httptest.NewRecorder() | ||||||
|  |  | ||||||
|  | 		t.App.Handler().ServeHTTP(w, req) | ||||||
|  |  | ||||||
|  | 		result := w.Result() | ||||||
|  | 		t.Equal(http.StatusOK, result.StatusCode) | ||||||
|  | 		t.Equal("application/octet-stream+base64", result.Header.Get("Content-Type")) | ||||||
|  | 		body, _ := io.ReadAll(result.Body) | ||||||
|  | 		t.Equal([]byte{0x62, 0x57, 0x39, 0x6a, 0x61, 0x79, 0x42, 0x7a, 0x61, 0x57, 0x64, 0x75, 0x59, 0x58, 0x52, 0x31, 0x63, 0x6d, 0x55, 0x3d}, body) | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	t.Run("handle error during sign", func() { | ||||||
|  | 		t.Signer.On("Sign", mock.Anything).Return(nil, errors.New("mock error")) | ||||||
|  |  | ||||||
|  | 		req := httptest.NewRequest("POST", "http://chrly/", strings.NewReader("mock body to sign")) | ||||||
|  | 		w := httptest.NewRecorder() | ||||||
|  |  | ||||||
|  | 		t.App.Handler().ServeHTTP(w, req) | ||||||
|  |  | ||||||
|  | 		result := w.Result() | ||||||
|  | 		t.Equal(http.StatusInternalServerError, result.StatusCode) | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (t *SignerApiTestSuite) TestGetPublicKey() { | ||||||
|  | 	t.Run("in pem format", func() { | ||||||
|  | 		publicKey := []byte("mock public key in pem format") | ||||||
|  | 		t.Signer.On("GetPublicKey", "pem").Return(publicKey, nil) | ||||||
|  |  | ||||||
|  | 		req := httptest.NewRequest("GET", "http://chrly/public-key.pem", nil) | ||||||
|  | 		w := httptest.NewRecorder() | ||||||
|  |  | ||||||
|  | 		t.App.Handler().ServeHTTP(w, req) | ||||||
|  |  | ||||||
|  | 		result := w.Result() | ||||||
|  | 		t.Equal(http.StatusOK, result.StatusCode) | ||||||
|  | 		t.Equal("application/x-pem-file", result.Header.Get("Content-Type")) | ||||||
|  | 		t.Equal(`attachment; filename="yggdrasil_session_pubkey.pem"`, result.Header.Get("Content-Disposition")) | ||||||
|  | 		body, _ := io.ReadAll(result.Body) | ||||||
|  | 		t.Equal(publicKey, body) | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	t.Run("in der format", func() { | ||||||
|  | 		publicKey := []byte("mock public key in der format") | ||||||
|  | 		t.Signer.On("GetPublicKey", "der").Return(publicKey, nil) | ||||||
|  |  | ||||||
|  | 		req := httptest.NewRequest("GET", "http://chrly/public-key.der", nil) | ||||||
|  | 		w := httptest.NewRecorder() | ||||||
|  |  | ||||||
|  | 		t.App.Handler().ServeHTTP(w, req) | ||||||
|  |  | ||||||
|  | 		result := w.Result() | ||||||
|  | 		t.Equal(http.StatusOK, result.StatusCode) | ||||||
|  | 		t.Equal("application/octet-stream", result.Header.Get("Content-Type")) | ||||||
|  | 		t.Equal(`attachment; filename="yggdrasil_session_pubkey.der"`, result.Header.Get("Content-Disposition")) | ||||||
|  | 		body, _ := io.ReadAll(result.Body) | ||||||
|  | 		t.Equal(publicKey, body) | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	t.Run("handle error", func() { | ||||||
|  | 		t.Signer.On("GetPublicKey", "pem").Return(nil, errors.New("mock error")) | ||||||
|  |  | ||||||
|  | 		req := httptest.NewRequest("GET", "http://chrly/public-key.pem", nil) | ||||||
|  | 		w := httptest.NewRecorder() | ||||||
|  |  | ||||||
|  | 		t.App.Handler().ServeHTTP(w, req) | ||||||
|  |  | ||||||
|  | 		result := w.Result() | ||||||
|  | 		t.Equal(http.StatusInternalServerError, result.StatusCode) | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestSignerApi(t *testing.T) { | ||||||
|  | 	suite.Run(t, new(SignerApiTestSuite)) | ||||||
|  | } | ||||||
| @@ -2,12 +2,10 @@ package http | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
| 	"crypto/rsa" |  | ||||||
| 	"crypto/x509" |  | ||||||
| 	"encoding/base64" | 	"encoding/base64" | ||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| 	"encoding/pem" |  | ||||||
| 	"fmt" | 	"fmt" | ||||||
|  | 	"io" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"time" | 	"time" | ||||||
| @@ -25,40 +23,39 @@ type ProfilesProvider interface { | |||||||
| 	FindProfileByUsername(ctx context.Context, username string, allowProxy bool) (*db.Profile, error) | 	FindProfileByUsername(ctx context.Context, username string, allowProxy bool) (*db.Profile, error) | ||||||
| } | } | ||||||
|  |  | ||||||
| // TexturesSigner uses context because in the future we may separate this logic into a separate microservice | // SignerService uses context because in the future we may separate this logic as an external microservice | ||||||
| type TexturesSigner interface { | type SignerService interface { | ||||||
| 	SignTextures(ctx context.Context, textures string) (string, error) | 	Sign(ctx context.Context, data string) (string, error) | ||||||
| 	GetPublicKey(ctx context.Context) (*rsa.PublicKey, error) | 	GetPublicKey(ctx context.Context, format string) (string, error) | ||||||
| } | } | ||||||
|  |  | ||||||
| type Skinsystem struct { | type Skinsystem struct { | ||||||
| 	ProfilesProvider | 	ProfilesProvider | ||||||
| 	TexturesSigner | 	SignerService | ||||||
| 	TexturesExtraParamName  string | 	TexturesExtraParamName  string | ||||||
| 	TexturesExtraParamValue string | 	TexturesExtraParamValue string | ||||||
| } | } | ||||||
|  |  | ||||||
| func (ctx *Skinsystem) Handler() *mux.Router { | func (s *Skinsystem) Handler() *mux.Router { | ||||||
| 	router := mux.NewRouter().StrictSlash(true) | 	router := mux.NewRouter().StrictSlash(true) | ||||||
|  |  | ||||||
| 	router.HandleFunc("/skins/{username}", ctx.skinHandler).Methods(http.MethodGet) | 	router.HandleFunc("/skins/{username}", s.skinHandler).Methods(http.MethodGet) | ||||||
| 	router.HandleFunc("/cloaks/{username}", ctx.capeHandler).Methods(http.MethodGet) | 	router.HandleFunc("/cloaks/{username}", s.capeHandler).Methods(http.MethodGet) | ||||||
| 	// TODO: alias /capes/{username}? | 	// TODO: alias /capes/{username}? | ||||||
| 	router.HandleFunc("/textures/{username}", ctx.texturesHandler).Methods(http.MethodGet) | 	router.HandleFunc("/textures/{username}", s.texturesHandler).Methods(http.MethodGet) | ||||||
| 	router.HandleFunc("/textures/signed/{username}", ctx.signedTexturesHandler).Methods(http.MethodGet) | 	router.HandleFunc("/textures/signed/{username}", s.signedTexturesHandler).Methods(http.MethodGet) | ||||||
| 	router.HandleFunc("/profile/{username}", ctx.profileHandler).Methods(http.MethodGet) | 	router.HandleFunc("/profile/{username}", s.profileHandler).Methods(http.MethodGet) | ||||||
| 	// Legacy | 	// Legacy | ||||||
| 	router.HandleFunc("/skins", ctx.skinGetHandler).Methods(http.MethodGet) | 	router.HandleFunc("/skins", s.skinGetHandler).Methods(http.MethodGet) | ||||||
| 	router.HandleFunc("/cloaks", ctx.capeGetHandler).Methods(http.MethodGet) | 	router.HandleFunc("/cloaks", s.capeGetHandler).Methods(http.MethodGet) | ||||||
| 	// Utils | 	// Utils | ||||||
| 	router.HandleFunc("/signature-verification-key.der", ctx.signatureVerificationKeyHandler).Methods(http.MethodGet) | 	router.HandleFunc("/signature-verification-key.{format:(?:pem|der)}", s.signatureVerificationKeyHandler).Methods(http.MethodGet) | ||||||
| 	router.HandleFunc("/signature-verification-key.pem", ctx.signatureVerificationKeyHandler).Methods(http.MethodGet) |  | ||||||
|  |  | ||||||
| 	return router | 	return router | ||||||
| } | } | ||||||
|  |  | ||||||
| func (ctx *Skinsystem) skinHandler(response http.ResponseWriter, request *http.Request) { | func (s *Skinsystem) skinHandler(response http.ResponseWriter, request *http.Request) { | ||||||
| 	profile, err := ctx.ProfilesProvider.FindProfileByUsername(request.Context(), parseUsername(mux.Vars(request)["username"]), true) | 	profile, err := s.ProfilesProvider.FindProfileByUsername(request.Context(), parseUsername(mux.Vars(request)["username"]), true) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		apiServerError(response, fmt.Errorf("unable to retrieve a profile: %w", err)) | 		apiServerError(response, fmt.Errorf("unable to retrieve a profile: %w", err)) | ||||||
| 		return | 		return | ||||||
| @@ -71,7 +68,7 @@ func (ctx *Skinsystem) skinHandler(response http.ResponseWriter, request *http.R | |||||||
| 	http.Redirect(response, request, profile.SkinUrl, http.StatusMovedPermanently) | 	http.Redirect(response, request, profile.SkinUrl, http.StatusMovedPermanently) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (ctx *Skinsystem) skinGetHandler(response http.ResponseWriter, request *http.Request) { | func (s *Skinsystem) skinGetHandler(response http.ResponseWriter, request *http.Request) { | ||||||
| 	username := request.URL.Query().Get("name") | 	username := request.URL.Query().Get("name") | ||||||
| 	if username == "" { | 	if username == "" { | ||||||
| 		response.WriteHeader(http.StatusBadRequest) | 		response.WriteHeader(http.StatusBadRequest) | ||||||
| @@ -80,11 +77,11 @@ func (ctx *Skinsystem) skinGetHandler(response http.ResponseWriter, request *htt | |||||||
|  |  | ||||||
| 	mux.Vars(request)["username"] = username | 	mux.Vars(request)["username"] = username | ||||||
|  |  | ||||||
| 	ctx.skinHandler(response, request) | 	s.skinHandler(response, request) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (ctx *Skinsystem) capeHandler(response http.ResponseWriter, request *http.Request) { | func (s *Skinsystem) capeHandler(response http.ResponseWriter, request *http.Request) { | ||||||
| 	profile, err := ctx.ProfilesProvider.FindProfileByUsername(request.Context(), parseUsername(mux.Vars(request)["username"]), true) | 	profile, err := s.ProfilesProvider.FindProfileByUsername(request.Context(), parseUsername(mux.Vars(request)["username"]), true) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		apiServerError(response, fmt.Errorf("unable to retrieve a profile: %w", err)) | 		apiServerError(response, fmt.Errorf("unable to retrieve a profile: %w", err)) | ||||||
| 		return | 		return | ||||||
| @@ -97,7 +94,7 @@ func (ctx *Skinsystem) capeHandler(response http.ResponseWriter, request *http.R | |||||||
| 	http.Redirect(response, request, profile.CapeUrl, http.StatusMovedPermanently) | 	http.Redirect(response, request, profile.CapeUrl, http.StatusMovedPermanently) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (ctx *Skinsystem) capeGetHandler(response http.ResponseWriter, request *http.Request) { | func (s *Skinsystem) capeGetHandler(response http.ResponseWriter, request *http.Request) { | ||||||
| 	username := request.URL.Query().Get("name") | 	username := request.URL.Query().Get("name") | ||||||
| 	if username == "" { | 	if username == "" { | ||||||
| 		response.WriteHeader(http.StatusBadRequest) | 		response.WriteHeader(http.StatusBadRequest) | ||||||
| @@ -106,11 +103,11 @@ func (ctx *Skinsystem) capeGetHandler(response http.ResponseWriter, request *htt | |||||||
|  |  | ||||||
| 	mux.Vars(request)["username"] = username | 	mux.Vars(request)["username"] = username | ||||||
|  |  | ||||||
| 	ctx.capeHandler(response, request) | 	s.capeHandler(response, request) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (ctx *Skinsystem) texturesHandler(response http.ResponseWriter, request *http.Request) { | func (s *Skinsystem) texturesHandler(response http.ResponseWriter, request *http.Request) { | ||||||
| 	profile, err := ctx.ProfilesProvider.FindProfileByUsername(request.Context(), mux.Vars(request)["username"], true) | 	profile, err := s.ProfilesProvider.FindProfileByUsername(request.Context(), mux.Vars(request)["username"], true) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		apiServerError(response, fmt.Errorf("unable to retrieve a profile: %w", err)) | 		apiServerError(response, fmt.Errorf("unable to retrieve a profile: %w", err)) | ||||||
| 		return | 		return | ||||||
| @@ -133,8 +130,8 @@ func (ctx *Skinsystem) texturesHandler(response http.ResponseWriter, request *ht | |||||||
| 	_, _ = response.Write(responseData) | 	_, _ = response.Write(responseData) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (ctx *Skinsystem) signedTexturesHandler(response http.ResponseWriter, request *http.Request) { | func (s *Skinsystem) signedTexturesHandler(response http.ResponseWriter, request *http.Request) { | ||||||
| 	profile, err := ctx.ProfilesProvider.FindProfileByUsername( | 	profile, err := s.ProfilesProvider.FindProfileByUsername( | ||||||
| 		request.Context(), | 		request.Context(), | ||||||
| 		mux.Vars(request)["username"], | 		mux.Vars(request)["username"], | ||||||
| 		getToBool(request.URL.Query().Get("proxy")), | 		getToBool(request.URL.Query().Get("proxy")), | ||||||
| @@ -164,8 +161,8 @@ func (ctx *Skinsystem) signedTexturesHandler(response http.ResponseWriter, reque | |||||||
| 				Value:     profile.MojangTextures, | 				Value:     profile.MojangTextures, | ||||||
| 			}, | 			}, | ||||||
| 			{ | 			{ | ||||||
| 				Name:  ctx.TexturesExtraParamName, | 				Name:  s.TexturesExtraParamName, | ||||||
| 				Value: ctx.TexturesExtraParamValue, | 				Value: s.TexturesExtraParamValue, | ||||||
| 			}, | 			}, | ||||||
| 		}, | 		}, | ||||||
| 	} | 	} | ||||||
| @@ -175,8 +172,8 @@ func (ctx *Skinsystem) signedTexturesHandler(response http.ResponseWriter, reque | |||||||
| 	_, _ = response.Write(responseJson) | 	_, _ = response.Write(responseJson) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (ctx *Skinsystem) profileHandler(response http.ResponseWriter, request *http.Request) { | func (s *Skinsystem) profileHandler(response http.ResponseWriter, request *http.Request) { | ||||||
| 	profile, err := ctx.ProfilesProvider.FindProfileByUsername(request.Context(), mux.Vars(request)["username"], true) | 	profile, err := s.ProfilesProvider.FindProfileByUsername(request.Context(), mux.Vars(request)["username"], true) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		apiServerError(response, fmt.Errorf("unable to retrieve a profile: %w", err)) | 		apiServerError(response, fmt.Errorf("unable to retrieve a profile: %w", err)) | ||||||
| 		return | 		return | ||||||
| @@ -203,7 +200,7 @@ func (ctx *Skinsystem) profileHandler(response http.ResponseWriter, request *htt | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if request.URL.Query().Has("unsigned") && !getToBool(request.URL.Query().Get("unsigned")) { | 	if request.URL.Query().Has("unsigned") && !getToBool(request.URL.Query().Get("unsigned")) { | ||||||
| 		signature, err := ctx.TexturesSigner.SignTextures(request.Context(), texturesProp.Value) | 		signature, err := s.SignerService.Sign(request.Context(), texturesProp.Value) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			apiServerError(response, fmt.Errorf("unable to sign textures: %w", err)) | 			apiServerError(response, fmt.Errorf("unable to sign textures: %w", err)) | ||||||
| 			return | 			return | ||||||
| @@ -218,8 +215,8 @@ func (ctx *Skinsystem) profileHandler(response http.ResponseWriter, request *htt | |||||||
| 		Props: []*mojang.Property{ | 		Props: []*mojang.Property{ | ||||||
| 			texturesProp, | 			texturesProp, | ||||||
| 			{ | 			{ | ||||||
| 				Name:  ctx.TexturesExtraParamName, | 				Name:  s.TexturesExtraParamName, | ||||||
| 				Value: ctx.TexturesExtraParamValue, | 				Value: s.TexturesExtraParamValue, | ||||||
| 			}, | 			}, | ||||||
| 		}, | 		}, | ||||||
| 	} | 	} | ||||||
| @@ -229,32 +226,23 @@ func (ctx *Skinsystem) profileHandler(response http.ResponseWriter, request *htt | |||||||
| 	_, _ = response.Write(responseJson) | 	_, _ = response.Write(responseJson) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (ctx *Skinsystem) signatureVerificationKeyHandler(response http.ResponseWriter, request *http.Request) { | func (s *Skinsystem) signatureVerificationKeyHandler(response http.ResponseWriter, request *http.Request) { | ||||||
| 	publicKey, err := ctx.TexturesSigner.GetPublicKey(request.Context()) | 	format := mux.Vars(request)["format"] | ||||||
|  | 	publicKey, err := s.SignerService.GetPublicKey(request.Context(), format) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		panic(err) | 		apiServerError(response, fmt.Errorf("unable to retrieve public key: %w", err)) | ||||||
|  | 		return | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	asn1Bytes, err := x509.MarshalPKIXPublicKey(publicKey) | 	if format == "pem" { | ||||||
| 	if err != nil { | 		response.Header().Set("Content-Type", "application/x-pem-file") | ||||||
| 		panic(err) | 		response.Header().Set("Content-Disposition", `attachment; filename="yggdrasil_session_pubkey.pem"`) | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if strings.HasSuffix(request.URL.Path, ".pem") { |  | ||||||
| 		publicKeyBlock := pem.Block{ |  | ||||||
| 			Type:  "PUBLIC KEY", |  | ||||||
| 			Bytes: asn1Bytes, |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		publicKeyPemBytes := pem.EncodeToMemory(&publicKeyBlock) |  | ||||||
|  |  | ||||||
| 		response.Header().Set("Content-Disposition", "attachment; filename=\"yggdrasil_session_pubkey.pem\"") |  | ||||||
| 		_, _ = response.Write(publicKeyPemBytes) |  | ||||||
| 	} else { | 	} else { | ||||||
| 		response.Header().Set("Content-Type", "application/octet-stream") | 		response.Header().Set("Content-Type", "application/octet-stream") | ||||||
| 		response.Header().Set("Content-Disposition", "attachment; filename=\"yggdrasil_session_pubkey.der\"") | 		response.Header().Set("Content-Disposition", `attachment; filename="yggdrasil_session_pubkey.der"`) | ||||||
| 		_, _ = response.Write(asn1Bytes) |  | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	_, _ = io.WriteString(response, publicKey) | ||||||
| } | } | ||||||
|  |  | ||||||
| func parseUsername(username string) string { | func parseUsername(username string) string { | ||||||
|   | |||||||
| @@ -2,14 +2,10 @@ package http | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
| 	"crypto/rsa" |  | ||||||
| 	"crypto/x509" |  | ||||||
| 	"encoding/pem" |  | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"io" | 	"io" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"net/http/httptest" | 	"net/http/httptest" | ||||||
| 	"strings" |  | ||||||
| 	"testing" | 	"testing" | ||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
| @@ -34,23 +30,18 @@ func (m *ProfilesProviderMock) FindProfileByUsername(ctx context.Context, userna | |||||||
| 	return result, args.Error(1) | 	return result, args.Error(1) | ||||||
| } | } | ||||||
|  |  | ||||||
| type TexturesSignerMock struct { | type SignerServiceMock struct { | ||||||
| 	mock.Mock | 	mock.Mock | ||||||
| } | } | ||||||
|  |  | ||||||
| func (m *TexturesSignerMock) SignTextures(ctx context.Context, textures string) (string, error) { | func (m *SignerServiceMock) Sign(ctx context.Context, data string) (string, error) { | ||||||
| 	args := m.Called(ctx, textures) | 	args := m.Called(ctx, data) | ||||||
| 	return args.String(0), args.Error(1) | 	return args.String(0), args.Error(1) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (m *TexturesSignerMock) GetPublicKey(ctx context.Context) (*rsa.PublicKey, error) { | func (m *SignerServiceMock) GetPublicKey(ctx context.Context, format string) (string, error) { | ||||||
| 	args := m.Called(ctx) | 	args := m.Called(ctx, format) | ||||||
| 	var publicKey *rsa.PublicKey | 	return args.String(0), args.Error(1) | ||||||
| 	if casted, ok := args.Get(0).(*rsa.PublicKey); ok { |  | ||||||
| 		publicKey = casted |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	return publicKey, args.Error(1) |  | ||||||
| } | } | ||||||
|  |  | ||||||
| type SkinsystemTestSuite struct { | type SkinsystemTestSuite struct { | ||||||
| @@ -59,7 +50,7 @@ type SkinsystemTestSuite struct { | |||||||
| 	App *Skinsystem | 	App *Skinsystem | ||||||
|  |  | ||||||
| 	ProfilesProvider *ProfilesProviderMock | 	ProfilesProvider *ProfilesProviderMock | ||||||
| 	TexturesSigner   *TexturesSignerMock | 	SignerService    *SignerServiceMock | ||||||
| } | } | ||||||
|  |  | ||||||
| /******************** | /******************** | ||||||
| @@ -73,11 +64,11 @@ func (t *SkinsystemTestSuite) SetupSubTest() { | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	t.ProfilesProvider = &ProfilesProviderMock{} | 	t.ProfilesProvider = &ProfilesProviderMock{} | ||||||
| 	t.TexturesSigner = &TexturesSignerMock{} | 	t.SignerService = &SignerServiceMock{} | ||||||
|  |  | ||||||
| 	t.App = &Skinsystem{ | 	t.App = &Skinsystem{ | ||||||
| 		ProfilesProvider:        t.ProfilesProvider, | 		ProfilesProvider:        t.ProfilesProvider, | ||||||
| 		TexturesSigner:          t.TexturesSigner, | 		SignerService:           t.SignerService, | ||||||
| 		TexturesExtraParamName:  "texturesParamName", | 		TexturesExtraParamName:  "texturesParamName", | ||||||
| 		TexturesExtraParamValue: "texturesParamValue", | 		TexturesExtraParamValue: "texturesParamValue", | ||||||
| 	} | 	} | ||||||
| @@ -85,7 +76,7 @@ func (t *SkinsystemTestSuite) SetupSubTest() { | |||||||
|  |  | ||||||
| func (t *SkinsystemTestSuite) TearDownSubTest() { | func (t *SkinsystemTestSuite) TearDownSubTest() { | ||||||
| 	t.ProfilesProvider.AssertExpectations(t.T()) | 	t.ProfilesProvider.AssertExpectations(t.T()) | ||||||
| 	t.TexturesSigner.AssertExpectations(t.T()) | 	t.SignerService.AssertExpectations(t.T()) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (t *SkinsystemTestSuite) TestSkinHandler() { | func (t *SkinsystemTestSuite) TestSkinHandler() { | ||||||
| @@ -470,7 +461,7 @@ func (t *SkinsystemTestSuite) TestProfile() { | |||||||
| 			SkinUrl:   "https://example.com/skin.png", | 			SkinUrl:   "https://example.com/skin.png", | ||||||
| 			SkinModel: "slim", | 			SkinModel: "slim", | ||||||
| 		}, nil) | 		}, nil) | ||||||
| 		t.TexturesSigner.On("SignTextures", mock.Anything, "eyJ0aW1lc3RhbXAiOjE2MTQyMTQyMjMwMDAsInByb2ZpbGVJZCI6Im1vY2stdXVpZCIsInByb2ZpbGVOYW1lIjoibW9ja191c2VybmFtZSIsInRleHR1cmVzIjp7IlNLSU4iOnsidXJsIjoiaHR0cHM6Ly9leGFtcGxlLmNvbS9za2luLnBuZyIsIm1ldGFkYXRhIjp7Im1vZGVsIjoic2xpbSJ9fX19").Return("mock signature", nil) | 		t.SignerService.On("Sign", mock.Anything, "eyJ0aW1lc3RhbXAiOjE2MTQyMTQyMjMwMDAsInByb2ZpbGVJZCI6Im1vY2stdXVpZCIsInByb2ZpbGVOYW1lIjoibW9ja191c2VybmFtZSIsInRleHR1cmVzIjp7IlNLSU4iOnsidXJsIjoiaHR0cHM6Ly9leGFtcGxlLmNvbS9za2luLnBuZyIsIm1ldGFkYXRhIjp7Im1vZGVsIjoic2xpbSJ9fX19").Return("mock signature", nil) | ||||||
|  |  | ||||||
| 		t.App.Handler().ServeHTTP(w, req) | 		t.App.Handler().ServeHTTP(w, req) | ||||||
|  |  | ||||||
| @@ -526,7 +517,7 @@ func (t *SkinsystemTestSuite) TestProfile() { | |||||||
| 		w := httptest.NewRecorder() | 		w := httptest.NewRecorder() | ||||||
|  |  | ||||||
| 		t.ProfilesProvider.On("FindProfileByUsername", mock.Anything, "mock_username", true).Return(&db.Profile{}, nil) | 		t.ProfilesProvider.On("FindProfileByUsername", mock.Anything, "mock_username", true).Return(&db.Profile{}, nil) | ||||||
| 		t.TexturesSigner.On("SignTextures", mock.Anything, mock.Anything).Return("", errors.New("mock error")) | 		t.SignerService.On("Sign", mock.Anything, mock.Anything).Return("", errors.New("mock error")) | ||||||
|  |  | ||||||
| 		t.App.Handler().ServeHTTP(w, req) | 		t.App.Handler().ServeHTTP(w, req) | ||||||
|  |  | ||||||
| @@ -535,77 +526,52 @@ func (t *SkinsystemTestSuite) TestProfile() { | |||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
|  |  | ||||||
| type signingKeyTestCase struct { |  | ||||||
| 	Name       string |  | ||||||
| 	KeyFormat  string |  | ||||||
| 	BeforeTest func(suite *SkinsystemTestSuite) |  | ||||||
| 	PanicErr   string |  | ||||||
| 	AfterTest  func(suite *SkinsystemTestSuite, response *http.Response) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| var signingKeyTestsCases = []*signingKeyTestCase{ |  | ||||||
| 	{ |  | ||||||
| 		Name:      "Get public key in DER format", |  | ||||||
| 		KeyFormat: "DER", |  | ||||||
| 		BeforeTest: func(suite *SkinsystemTestSuite) { |  | ||||||
| 			pubPem, _ := pem.Decode([]byte("-----BEGIN PUBLIC KEY-----\nMFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBANbUpVCZkMKpfvYZ08W3lumdAaYxLBnm\nUDlzHBQH3DpYef5WCO32TDU6feIJ58A0lAywgtZ4wwi2dGHOz/1hAvcCAwEAAQ==\n-----END PUBLIC KEY-----")) |  | ||||||
| 			publicKey, _ := x509.ParsePKIXPublicKey(pubPem.Bytes) |  | ||||||
|  |  | ||||||
| 			suite.TexturesSigner.On("GetPublicKey", mock.Anything).Return(publicKey, nil) |  | ||||||
| 		}, |  | ||||||
| 		AfterTest: func(suite *SkinsystemTestSuite, response *http.Response) { |  | ||||||
| 			suite.Equal(200, response.StatusCode) |  | ||||||
| 			suite.Equal("application/octet-stream", response.Header.Get("Content-Type")) |  | ||||||
| 			suite.Equal("attachment; filename=\"yggdrasil_session_pubkey.der\"", response.Header.Get("Content-Disposition")) |  | ||||||
| 			body, _ := io.ReadAll(response.Body) |  | ||||||
| 			suite.Equal([]byte{48, 92, 48, 13, 6, 9, 42, 134, 72, 134, 247, 13, 1, 1, 1, 5, 0, 3, 75, 0, 48, 72, 2, 65, 0, 214, 212, 165, 80, 153, 144, 194, 169, 126, 246, 25, 211, 197, 183, 150, 233, 157, 1, 166, 49, 44, 25, 230, 80, 57, 115, 28, 20, 7, 220, 58, 88, 121, 254, 86, 8, 237, 246, 76, 53, 58, 125, 226, 9, 231, 192, 52, 148, 12, 176, 130, 214, 120, 195, 8, 182, 116, 97, 206, 207, 253, 97, 2, 247, 2, 3, 1, 0, 1}, body) |  | ||||||
| 		}, |  | ||||||
| 	}, |  | ||||||
| 	{ |  | ||||||
| 		Name:      "Get public key in PEM format", |  | ||||||
| 		KeyFormat: "PEM", |  | ||||||
| 		BeforeTest: func(suite *SkinsystemTestSuite) { |  | ||||||
| 			pubPem, _ := pem.Decode([]byte("-----BEGIN PUBLIC KEY-----\nMFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBANbUpVCZkMKpfvYZ08W3lumdAaYxLBnm\nUDlzHBQH3DpYef5WCO32TDU6feIJ58A0lAywgtZ4wwi2dGHOz/1hAvcCAwEAAQ==\n-----END PUBLIC KEY-----")) |  | ||||||
| 			publicKey, _ := x509.ParsePKIXPublicKey(pubPem.Bytes) |  | ||||||
|  |  | ||||||
| 			suite.TexturesSigner.On("GetPublicKey", mock.Anything).Return(publicKey, nil) |  | ||||||
| 		}, |  | ||||||
| 		AfterTest: func(suite *SkinsystemTestSuite, response *http.Response) { |  | ||||||
| 			suite.Equal(200, response.StatusCode) |  | ||||||
| 			suite.Equal("text/plain; charset=utf-8", response.Header.Get("Content-Type")) |  | ||||||
| 			suite.Equal("attachment; filename=\"yggdrasil_session_pubkey.pem\"", response.Header.Get("Content-Disposition")) |  | ||||||
| 			body, _ := io.ReadAll(response.Body) |  | ||||||
| 			suite.Equal("-----BEGIN PUBLIC KEY-----\nMFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBANbUpVCZkMKpfvYZ08W3lumdAaYxLBnm\nUDlzHBQH3DpYef5WCO32TDU6feIJ58A0lAywgtZ4wwi2dGHOz/1hAvcCAwEAAQ==\n-----END PUBLIC KEY-----\n", string(body)) |  | ||||||
| 		}, |  | ||||||
| 	}, |  | ||||||
| 	{ |  | ||||||
| 		Name:      "Error while obtaining public key", |  | ||||||
| 		KeyFormat: "DER", |  | ||||||
| 		BeforeTest: func(suite *SkinsystemTestSuite) { |  | ||||||
| 			suite.TexturesSigner.On("GetPublicKey", mock.Anything).Return(nil, errors.New("textures signer error")) |  | ||||||
| 		}, |  | ||||||
| 		PanicErr: "textures signer error", |  | ||||||
| 	}, |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (t *SkinsystemTestSuite) TestSignatureVerificationKey() { | func (t *SkinsystemTestSuite) TestSignatureVerificationKey() { | ||||||
| 	for _, testCase := range signingKeyTestsCases { | 	t.Run("in pem format", func() { | ||||||
| 		t.Run(testCase.Name, func() { | 		publicKey := "mock public key in pem format" | ||||||
| 			testCase.BeforeTest(t) | 		t.SignerService.On("GetPublicKey", mock.Anything, "pem").Return(publicKey, nil) | ||||||
|  |  | ||||||
| 			req := httptest.NewRequest("GET", "http://chrly/signature-verification-key."+strings.ToLower(testCase.KeyFormat), nil) | 		req := httptest.NewRequest("GET", "http://chrly/signature-verification-key.pem", nil) | ||||||
| 		w := httptest.NewRecorder() | 		w := httptest.NewRecorder() | ||||||
|  |  | ||||||
| 			if testCase.PanicErr != "" { |  | ||||||
| 				t.PanicsWithError(testCase.PanicErr, func() { |  | ||||||
| 		t.App.Handler().ServeHTTP(w, req) | 		t.App.Handler().ServeHTTP(w, req) | ||||||
|  |  | ||||||
|  | 		result := w.Result() | ||||||
|  | 		t.Equal(http.StatusOK, result.StatusCode) | ||||||
|  | 		t.Equal("application/x-pem-file", result.Header.Get("Content-Type")) | ||||||
|  | 		t.Equal(`attachment; filename="yggdrasil_session_pubkey.pem"`, result.Header.Get("Content-Disposition")) | ||||||
|  | 		body, _ := io.ReadAll(result.Body) | ||||||
|  | 		t.Equal(publicKey, string(body)) | ||||||
| 	}) | 	}) | ||||||
| 			} else { |  | ||||||
|  | 	t.Run("in der format", func() { | ||||||
|  | 		publicKey := "mock public key in der format" | ||||||
|  | 		t.SignerService.On("GetPublicKey", mock.Anything, "der").Return(publicKey, nil) | ||||||
|  |  | ||||||
|  | 		req := httptest.NewRequest("GET", "http://chrly/signature-verification-key.der", nil) | ||||||
|  | 		w := httptest.NewRecorder() | ||||||
|  |  | ||||||
| 		t.App.Handler().ServeHTTP(w, req) | 		t.App.Handler().ServeHTTP(w, req) | ||||||
| 				testCase.AfterTest(t, w.Result()) |  | ||||||
| 			} | 		result := w.Result() | ||||||
|  | 		t.Equal(http.StatusOK, result.StatusCode) | ||||||
|  | 		t.Equal("application/octet-stream", result.Header.Get("Content-Type")) | ||||||
|  | 		t.Equal(`attachment; filename="yggdrasil_session_pubkey.der"`, result.Header.Get("Content-Disposition")) | ||||||
|  | 		body, _ := io.ReadAll(result.Body) | ||||||
|  | 		t.Equal(publicKey, string(body)) | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	t.Run("handle error", func() { | ||||||
|  | 		t.SignerService.On("GetPublicKey", mock.Anything, "pem").Return("", errors.New("mock error")) | ||||||
|  |  | ||||||
|  | 		req := httptest.NewRequest("GET", "http://chrly/signature-verification-key.pem", nil) | ||||||
|  | 		w := httptest.NewRecorder() | ||||||
|  |  | ||||||
|  | 		t.App.Handler().ServeHTTP(w, req) | ||||||
|  |  | ||||||
|  | 		result := w.Result() | ||||||
|  | 		t.Equal(http.StatusInternalServerError, result.StatusCode) | ||||||
| 	}) | 	}) | ||||||
| 	} |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func TestSkinsystem(t *testing.T) { | func TestSkinsystem(t *testing.T) { | ||||||
|   | |||||||
| @@ -4,6 +4,7 @@ import ( | |||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"net/http" | 	"net/http" | ||||||
|  | 	"slices" | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
| @@ -15,14 +16,23 @@ import ( | |||||||
| var now = time.Now | var now = time.Now | ||||||
| var signingMethod = jwt.SigningMethodHS256 | var signingMethod = jwt.SigningMethodHS256 | ||||||
|  |  | ||||||
| const scopesClaim = "scopes" |  | ||||||
|  |  | ||||||
| type Scope string | type Scope string | ||||||
|  |  | ||||||
| const ( | const ( | ||||||
| 	ProfileScope Scope = "profiles" | 	ProfilesScope Scope = "profiles" | ||||||
|  | 	SignScope     Scope = "sign" | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  | var validScopes = []Scope{ | ||||||
|  | 	ProfilesScope, | ||||||
|  | 	SignScope, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type claims struct { | ||||||
|  | 	jwt.RegisteredClaims | ||||||
|  | 	Scopes []Scope `json:"scopes"` | ||||||
|  | } | ||||||
|  |  | ||||||
| func NewJwt(key []byte) *Jwt { | func NewJwt(key []byte) *Jwt { | ||||||
| 	return &Jwt{ | 	return &Jwt{ | ||||||
| 		Key: key, | 		Key: key, | ||||||
| @@ -38,11 +48,20 @@ func (t *Jwt) NewToken(scopes ...Scope) (string, error) { | |||||||
| 		return "", errors.New("you must specify at least one scope") | 		return "", errors.New("you must specify at least one scope") | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	token := jwt.NewWithClaims(signingMethod, jwt.MapClaims{ | 	for _, scope := range scopes { | ||||||
| 		"iss":       "chrly", | 		if !slices.Contains(validScopes, scope) { | ||||||
| 		"iat":       now().Unix(), | 			return "", fmt.Errorf("unknown scope %s", scope) | ||||||
| 		scopesClaim: scopes, | 		} | ||||||
| 	}) | 	} | ||||||
|  |  | ||||||
|  | 	token := jwt.New(signingMethod) | ||||||
|  | 	token.Claims = &claims{ | ||||||
|  | 		jwt.RegisteredClaims{ | ||||||
|  | 			Issuer:   "chrly", | ||||||
|  | 			IssuedAt: jwt.NewNumericDate(now()), | ||||||
|  | 		}, | ||||||
|  | 		scopes, | ||||||
|  | 	} | ||||||
| 	token.Header["v"] = version.MajorVersion | 	token.Header["v"] = version.MajorVersion | ||||||
|  |  | ||||||
| 	return token.SignedString(t.Key) | 	return token.SignedString(t.Key) | ||||||
| @@ -52,7 +71,7 @@ func (t *Jwt) NewToken(scopes ...Scope) (string, error) { | |||||||
| var MissingAuthenticationError = errors.New("authentication value not provided") | var MissingAuthenticationError = errors.New("authentication value not provided") | ||||||
| var InvalidTokenError = errors.New("passed authentication value is invalid") | var InvalidTokenError = errors.New("passed authentication value is invalid") | ||||||
|  |  | ||||||
| func (t *Jwt) Authenticate(req *http.Request) error { | func (t *Jwt) Authenticate(req *http.Request, scope Scope) error { | ||||||
| 	bearerToken := req.Header.Get("Authorization") | 	bearerToken := req.Header.Get("Authorization") | ||||||
| 	if bearerToken == "" { | 	if bearerToken == "" { | ||||||
| 		return MissingAuthenticationError | 		return MissingAuthenticationError | ||||||
| @@ -62,8 +81,8 @@ func (t *Jwt) Authenticate(req *http.Request) error { | |||||||
| 		return InvalidTokenError | 		return InvalidTokenError | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	tokenStr := bearerToken[7:] | 	tokenStr := bearerToken[7:] // trim "bearer " part | ||||||
| 	token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) { | 	token, err := jwt.ParseWithClaims(tokenStr, &claims{}, func(token *jwt.Token) (interface{}, error) { | ||||||
| 		if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { | 		if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { | ||||||
| 			return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) | 			return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) | ||||||
| 		} | 		} | ||||||
| @@ -78,5 +97,10 @@ func (t *Jwt) Authenticate(req *http.Request) error { | |||||||
| 		return errors.Join(InvalidTokenError, errors.New("missing v header")) | 		return errors.Join(InvalidTokenError, errors.New("missing v header")) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	claims := token.Claims.(*claims) | ||||||
|  | 	if !slices.Contains(claims.Scopes, scope) { | ||||||
|  | 		return errors.New("the token doesn't have the scope to perform the action") | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|   | |||||||
| @@ -16,10 +16,16 @@ func TestJwtAuth_NewToken(t *testing.T) { | |||||||
| 		return time.Date(2024, 2, 1, 11, 26, 15, 0, time.UTC) | 		return time.Date(2024, 2, 1, 11, 26, 15, 0, time.UTC) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	t.Run("with scope", func(t *testing.T) { | 	t.Run("with known scope", func(t *testing.T) { | ||||||
| 		token, err := jwt.NewToken(ProfileScope, "custom-scope") | 		token, err := jwt.NewToken(ProfilesScope, SignScope) | ||||||
| 		require.NoError(t, err) | 		require.NoError(t, err) | ||||||
| 		require.Equal(t, "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCIsInYiOjV9.eyJpYXQiOjE3MDY3ODY3NzUsImlzcyI6ImNocmx5Iiwic2NvcGVzIjpbInByb2ZpbGVzIiwiY3VzdG9tLXNjb3BlIl19.Iq673YyWWkJZjIkBmKYRN8Lx9qoD39S_e-MegG0aORM", token) | 		require.Equal(t, "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCIsInYiOjV9.eyJpc3MiOiJjaHJseSIsImlhdCI6MTcwNjc4Njc3NSwic2NvcGVzIjpbInByb2ZpbGVzIiwic2lnbiJdfQ.HkNGiDba3I_bLGN6sF0eTE5n6rMLgYfAZZEqI4xb2X4", token) | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	t.Run("with unknown scope", func(t *testing.T) { | ||||||
|  | 		token, err := jwt.NewToken("scope-123") | ||||||
|  | 		require.ErrorContains(t, err, "unknown") | ||||||
|  | 		require.Empty(t, token) | ||||||
| 	}) | 	}) | ||||||
|  |  | ||||||
| 	t.Run("no scopes", func(t *testing.T) { | 	t.Run("no scopes", func(t *testing.T) { | ||||||
| @@ -34,41 +40,48 @@ func TestJwtAuth_Authenticate(t *testing.T) { | |||||||
| 	t.Run("success", func(t *testing.T) { | 	t.Run("success", func(t *testing.T) { | ||||||
| 		req := httptest.NewRequest("POST", "http://localhost", nil) | 		req := httptest.NewRequest("POST", "http://localhost", nil) | ||||||
| 		req.Header.Add("Authorization", "Bearer "+jwtString) | 		req.Header.Add("Authorization", "Bearer "+jwtString) | ||||||
| 		err := jwt.Authenticate(req) | 		err := jwt.Authenticate(req, ProfilesScope) | ||||||
| 		require.NoError(t, err) | 		require.NoError(t, err) | ||||||
| 	}) | 	}) | ||||||
|  |  | ||||||
|  | 	t.Run("has no required scope", func(t *testing.T) { | ||||||
|  | 		req := httptest.NewRequest("POST", "http://localhost", nil) | ||||||
|  | 		req.Header.Add("Authorization", "Bearer "+jwtString) | ||||||
|  | 		err := jwt.Authenticate(req, SignScope) | ||||||
|  | 		require.ErrorContains(t, err, "scope") | ||||||
|  | 	}) | ||||||
|  |  | ||||||
| 	t.Run("request without auth header", func(t *testing.T) { | 	t.Run("request without auth header", func(t *testing.T) { | ||||||
| 		req := httptest.NewRequest("POST", "http://localhost", nil) | 		req := httptest.NewRequest("POST", "http://localhost", nil) | ||||||
| 		err := jwt.Authenticate(req) | 		err := jwt.Authenticate(req, ProfilesScope) | ||||||
| 		require.ErrorIs(t, err, MissingAuthenticationError) | 		require.ErrorIs(t, err, MissingAuthenticationError) | ||||||
| 	}) | 	}) | ||||||
|  |  | ||||||
| 	t.Run("no bearer token prefix", func(t *testing.T) { | 	t.Run("no bearer token prefix", func(t *testing.T) { | ||||||
| 		req := httptest.NewRequest("POST", "http://localhost", nil) | 		req := httptest.NewRequest("POST", "http://localhost", nil) | ||||||
| 		req.Header.Add("Authorization", "trash") | 		req.Header.Add("Authorization", "trash") | ||||||
| 		err := jwt.Authenticate(req) | 		err := jwt.Authenticate(req, ProfilesScope) | ||||||
| 		require.ErrorIs(t, err, InvalidTokenError) | 		require.ErrorIs(t, err, InvalidTokenError) | ||||||
| 	}) | 	}) | ||||||
|  |  | ||||||
| 	t.Run("bearer token but not jwt", func(t *testing.T) { | 	t.Run("bearer token but not jwt", func(t *testing.T) { | ||||||
| 		req := httptest.NewRequest("POST", "http://localhost", nil) | 		req := httptest.NewRequest("POST", "http://localhost", nil) | ||||||
| 		req.Header.Add("Authorization", "Bearer seems.like.jwt") | 		req.Header.Add("Authorization", "Bearer seems.like.jwt") | ||||||
| 		err := jwt.Authenticate(req) | 		err := jwt.Authenticate(req, ProfilesScope) | ||||||
| 		require.ErrorIs(t, err, InvalidTokenError) | 		require.ErrorIs(t, err, InvalidTokenError) | ||||||
| 	}) | 	}) | ||||||
|  |  | ||||||
| 	t.Run("invalid signature", func(t *testing.T) { | 	t.Run("invalid signature", func(t *testing.T) { | ||||||
| 		req := httptest.NewRequest("POST", "http://localhost", nil) | 		req := httptest.NewRequest("POST", "http://localhost", nil) | ||||||
| 		req.Header.Add("Authorization", "Bearer "+jwtString+"123") | 		req.Header.Add("Authorization", "Bearer "+jwtString+"123") | ||||||
| 		err := jwt.Authenticate(req) | 		err := jwt.Authenticate(req, ProfilesScope) | ||||||
| 		require.ErrorIs(t, err, InvalidTokenError) | 		require.ErrorIs(t, err, InvalidTokenError) | ||||||
| 	}) | 	}) | ||||||
|  |  | ||||||
| 	t.Run("missing v header", func(t *testing.T) { | 	t.Run("missing v header", func(t *testing.T) { | ||||||
| 		req := httptest.NewRequest("POST", "http://localhost", nil) | 		req := httptest.NewRequest("POST", "http://localhost", nil) | ||||||
| 		req.Header.Add("Authorization", "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpYXQiOjE3MDY3ODY3NzUsImlzcyI6ImNocmx5Iiwic2NvcGVzIjpbInByb2ZpbGVzIl19.zOX2ZKyU37kjwt1p9uCHxALxWQD2UC0wWcAcNvBXGq0") | 		req.Header.Add("Authorization", "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpYXQiOjE3MDY3ODY3NzUsImlzcyI6ImNocmx5Iiwic2NvcGVzIjpbInByb2ZpbGVzIl19.zOX2ZKyU37kjwt1p9uCHxALxWQD2UC0wWcAcNvBXGq0") | ||||||
| 		err := jwt.Authenticate(req) | 		err := jwt.Authenticate(req, ProfilesScope) | ||||||
| 		require.ErrorIs(t, err, InvalidTokenError) | 		require.ErrorIs(t, err, InvalidTokenError) | ||||||
| 		require.ErrorContains(t, err, "missing v header") | 		require.ErrorContains(t, err, "missing v header") | ||||||
| 	}) | 	}) | ||||||
|   | |||||||
| @@ -1,15 +1,18 @@ | |||||||
| package security | package security | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"context" |  | ||||||
| 	"crypto" | 	"crypto" | ||||||
| 	"crypto/rand" | 	"crypto/rand" | ||||||
| 	"crypto/rsa" | 	"crypto/rsa" | ||||||
| 	"crypto/sha1" | 	"crypto/sha1" | ||||||
| 	"encoding/base64" | 	"crypto/x509" | ||||||
|  | 	"encoding/pem" | ||||||
|  | 	"errors" | ||||||
|  | 	"io" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| var randomReader = rand.Reader | var randomReader = rand.Reader | ||||||
|  | var invalidKeyFormat = errors.New(`invalid key format: it should be"der" or "pem"`) | ||||||
|  |  | ||||||
| func NewSigner(key *rsa.PrivateKey) *Signer { | func NewSigner(key *rsa.PrivateKey) *Signer { | ||||||
| 	return &Signer{Key: key} | 	return &Signer{Key: key} | ||||||
| @@ -19,23 +22,38 @@ type Signer struct { | |||||||
| 	Key *rsa.PrivateKey | 	Key *rsa.PrivateKey | ||||||
| } | } | ||||||
|  |  | ||||||
| func (s *Signer) SignTextures(ctx context.Context, textures string) (string, error) { | func (s *Signer) Sign(data io.Reader) ([]byte, error) { | ||||||
| 	message := []byte(textures) |  | ||||||
| 	messageHash := sha1.New() | 	messageHash := sha1.New() | ||||||
| 	_, err := messageHash.Write(message) | 	_, err := io.Copy(messageHash, data) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return "", err | 		return nil, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	messageHashSum := messageHash.Sum(nil) | 	messageHashSum := messageHash.Sum(nil) | ||||||
| 	signature, err := rsa.SignPKCS1v15(randomReader, s.Key, crypto.SHA1, messageHashSum) | 	signature, err := rsa.SignPKCS1v15(randomReader, s.Key, crypto.SHA1, messageHashSum) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return "", err | 		return nil, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return base64.StdEncoding.EncodeToString(signature), nil | 	return signature, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func (s *Signer) GetPublicKey(ctx context.Context) (*rsa.PublicKey, error) { | func (s *Signer) GetPublicKey(format string) ([]byte, error) { | ||||||
| 	return &s.Key.PublicKey, nil | 	if format != "der" && format != "pem" { | ||||||
|  | 		return nil, invalidKeyFormat | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	asn1Bytes, err := x509.MarshalPKIXPublicKey(s.Key.Public()) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if format == "pem" { | ||||||
|  | 		return pem.EncodeToMemory(&pem.Block{ | ||||||
|  | 			Type:  "PUBLIC KEY", | ||||||
|  | 			Bytes: asn1Bytes, | ||||||
|  | 		}), nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return asn1Bytes, nil | ||||||
| } | } | ||||||
|   | |||||||
| @@ -1,10 +1,9 @@ | |||||||
| package security | package security | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"context" |  | ||||||
| 	"crypto/rsa" |  | ||||||
| 	"crypto/x509" | 	"crypto/x509" | ||||||
| 	"encoding/pem" | 	"encoding/pem" | ||||||
|  | 	"strings" | ||||||
| 	"testing" | 	"testing" | ||||||
|  |  | ||||||
| 	"github.com/stretchr/testify/require" | 	"github.com/stretchr/testify/require" | ||||||
| @@ -17,7 +16,7 @@ func (c *ConstantReader) Read(p []byte) (int, error) { | |||||||
| 	return 1, nil | 	return 1, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func TestSigner_SignTextures(t *testing.T) { | func TestSigner_Sign(t *testing.T) { | ||||||
| 	randomReader = &ConstantReader{} | 	randomReader = &ConstantReader{} | ||||||
|  |  | ||||||
| 	rawKey, _ := pem.Decode([]byte("-----BEGIN RSA PRIVATE KEY-----\nMIIBOwIBAAJBANbUpVCZkMKpfvYZ08W3lumdAaYxLBnmUDlzHBQH3DpYef5WCO32\nTDU6feIJ58A0lAywgtZ4wwi2dGHOz/1hAvcCAwEAAQJAItaxSHTe6PKbyEU/9pxj\nONdhYRYwVLLo56gnMYhkyoEqaaMsfov8hhoepkYZBMvZFB2bDOsQ2SaJ+E2eiBO4\nAQIhAPssS0+BR9w0bOdmjGqmdE9NrN5UJQcOW13s29+6QzUBAiEA2vWOepA5Apiu\npEA3pwoGdkVCrNSnnKjDQzDXBnpd3/cCIEFNd9sY4qUG4FWdXN6RnmXL7Sj0uZfH\nDMwzu8rEM5sBAiEAhvdoDNqLmbMdq3c+FsPSOeL1d21Zp/JK8kbPtFmHNf8CIQDV\n6FSZDwvWfuxaM7BsycQONkjDBTPNu+lqctJBGnBv3A==\n-----END RSA PRIVATE KEY-----\n")) | 	rawKey, _ := pem.Decode([]byte("-----BEGIN RSA PRIVATE KEY-----\nMIIBOwIBAAJBANbUpVCZkMKpfvYZ08W3lumdAaYxLBnmUDlzHBQH3DpYef5WCO32\nTDU6feIJ58A0lAywgtZ4wwi2dGHOz/1hAvcCAwEAAQJAItaxSHTe6PKbyEU/9pxj\nONdhYRYwVLLo56gnMYhkyoEqaaMsfov8hhoepkYZBMvZFB2bDOsQ2SaJ+E2eiBO4\nAQIhAPssS0+BR9w0bOdmjGqmdE9NrN5UJQcOW13s29+6QzUBAiEA2vWOepA5Apiu\npEA3pwoGdkVCrNSnnKjDQzDXBnpd3/cCIEFNd9sY4qUG4FWdXN6RnmXL7Sj0uZfH\nDMwzu8rEM5sBAiEAhvdoDNqLmbMdq3c+FsPSOeL1d21Zp/JK8kbPtFmHNf8CIQDV\n6FSZDwvWfuxaM7BsycQONkjDBTPNu+lqctJBGnBv3A==\n-----END RSA PRIVATE KEY-----\n")) | ||||||
| @@ -25,9 +24,14 @@ func TestSigner_SignTextures(t *testing.T) { | |||||||
|  |  | ||||||
| 	signer := NewSigner(key) | 	signer := NewSigner(key) | ||||||
|  |  | ||||||
| 	signature, err := signer.SignTextures(context.Background(), "eyJ0aW1lc3RhbXAiOjE2MTQzMDcxMzQsInByb2ZpbGVJZCI6ImZmYzhmZGM5NTgyNDUwOWU4YTU3Yzk5Yjk0MGZiOTk2IiwicHJvZmlsZU5hbWUiOiJFcmlja1NrcmF1Y2giLCJ0ZXh0dXJlcyI6eyJTS0lOIjp7InVybCI6Imh0dHA6Ly9lbHkuYnkvc3RvcmFnZS9za2lucy82OWM2NzQwZDI5OTNlNWQ2ZjZhN2ZjOTI0MjBlZmMyOS5wbmcifX0sImVseSI6dHJ1ZX0") | 	signature, err := signer.Sign(strings.NewReader("mock string to sign")) | ||||||
| 	require.NoError(t, err) | 	require.NoError(t, err) | ||||||
| 	require.Equal(t, "IyHCxTP5ITquEXTHcwCtLd08jWWy16JwlQeWg8naxhoAVQecHGRdzHRscuxtdq/446kmeox7h4EfRN2A2ZLL+A==", signature) | 	require.Equal(t, []byte{ | ||||||
|  | 		0xd0, 0x88, 0xc6, 0x65, 0x27, 0x5d, 0xe4, 0x86, 0x6b, 0x7a, 0x5a, 0xd, 0x94, 0x6f, 0x80, 0x88, 0x12, 0x8e, 0x65, | ||||||
|  | 		0x75, 0xfb, 0xba, 0xcb, 0x7f, 0x90, 0xf5, 0xae, 0x5d, 0x2c, 0x5d, 0x60, 0xf6, 0x83, 0x54, 0xd3, 0x40, 0xd, 0x1f, | ||||||
|  | 		0xc0, 0xbc, 0x6d, 0xa8, 0x6f, 0x6, 0xd8, 0x38, 0x74, 0x5b, 0x4f, 0x15, 0x82, 0x6d, 0x67, 0x95, 0x7b, 0xf, 0xcc, | ||||||
|  | 		0xf3, 0x51, 0xfe, 0xcd, 0xb9, 0x1e, 0xdf, | ||||||
|  | 	}, signature) | ||||||
| } | } | ||||||
|  |  | ||||||
| func TestSigner_GetPublicKey(t *testing.T) { | func TestSigner_GetPublicKey(t *testing.T) { | ||||||
| @@ -38,7 +42,40 @@ func TestSigner_GetPublicKey(t *testing.T) { | |||||||
|  |  | ||||||
| 	signer := NewSigner(key) | 	signer := NewSigner(key) | ||||||
|  |  | ||||||
| 	publicKey, err := signer.GetPublicKey(context.Background()) | 	t.Run("pem format", func(t *testing.T) { | ||||||
|  | 		publicKey, err := signer.GetPublicKey("pem") | ||||||
| 		require.NoError(t, err) | 		require.NoError(t, err) | ||||||
| 	require.IsType(t, &rsa.PublicKey{}, publicKey) | 		require.Equal(t, []byte{ | ||||||
|  | 			0x2d, 0x2d, 0x2d, 0x2d, 0x2d, 0x42, 0x45, 0x47, 0x49, 0x4e, 0x20, 0x50, 0x55, 0x42, 0x4c, 0x49, 0x43, 0x20, | ||||||
|  | 			0x4b, 0x45, 0x59, 0x2d, 0x2d, 0x2d, 0x2d, 0x2d, 0xa, 0x4d, 0x46, 0x77, 0x77, 0x44, 0x51, 0x59, 0x4a, 0x4b, | ||||||
|  | 			0x6f, 0x5a, 0x49, 0x68, 0x76, 0x63, 0x4e, 0x41, 0x51, 0x45, 0x42, 0x42, 0x51, 0x41, 0x44, 0x53, 0x77, 0x41, | ||||||
|  | 			0x77, 0x53, 0x41, 0x4a, 0x42, 0x41, 0x4e, 0x62, 0x55, 0x70, 0x56, 0x43, 0x5a, 0x6b, 0x4d, 0x4b, 0x70, 0x66, | ||||||
|  | 			0x76, 0x59, 0x5a, 0x30, 0x38, 0x57, 0x33, 0x6c, 0x75, 0x6d, 0x64, 0x41, 0x61, 0x59, 0x78, 0x4c, 0x42, 0x6e, | ||||||
|  | 			0x6d, 0xa, 0x55, 0x44, 0x6c, 0x7a, 0x48, 0x42, 0x51, 0x48, 0x33, 0x44, 0x70, 0x59, 0x65, 0x66, 0x35, 0x57, | ||||||
|  | 			0x43, 0x4f, 0x33, 0x32, 0x54, 0x44, 0x55, 0x36, 0x66, 0x65, 0x49, 0x4a, 0x35, 0x38, 0x41, 0x30, 0x6c, 0x41, | ||||||
|  | 			0x79, 0x77, 0x67, 0x74, 0x5a, 0x34, 0x77, 0x77, 0x69, 0x32, 0x64, 0x47, 0x48, 0x4f, 0x7a, 0x2f, 0x31, 0x68, | ||||||
|  | 			0x41, 0x76, 0x63, 0x43, 0x41, 0x77, 0x45, 0x41, 0x41, 0x51, 0x3d, 0x3d, 0xa, 0x2d, 0x2d, 0x2d, 0x2d, 0x2d, | ||||||
|  | 			0x45, 0x4e, 0x44, 0x20, 0x50, 0x55, 0x42, 0x4c, 0x49, 0x43, 0x20, 0x4b, 0x45, 0x59, 0x2d, 0x2d, 0x2d, 0x2d, | ||||||
|  | 			0x2d, 0xa, | ||||||
|  | 		}, publicKey) | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	t.Run("der format", func(t *testing.T) { | ||||||
|  | 		publicKey, err := signer.GetPublicKey("der") | ||||||
|  | 		require.NoError(t, err) | ||||||
|  | 		require.Equal(t, []byte{ | ||||||
|  | 			0x30, 0x5c, 0x30, 0xd, 0x6, 0x9, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0xd, 0x1, 0x1, 0x1, 0x5, 0x0, 0x3, 0x4b, 0x0, | ||||||
|  | 			0x30, 0x48, 0x2, 0x41, 0x0, 0xd6, 0xd4, 0xa5, 0x50, 0x99, 0x90, 0xc2, 0xa9, 0x7e, 0xf6, 0x19, 0xd3, 0xc5, | ||||||
|  | 			0xb7, 0x96, 0xe9, 0x9d, 0x1, 0xa6, 0x31, 0x2c, 0x19, 0xe6, 0x50, 0x39, 0x73, 0x1c, 0x14, 0x7, 0xdc, 0x3a, | ||||||
|  | 			0x58, 0x79, 0xfe, 0x56, 0x8, 0xed, 0xf6, 0x4c, 0x35, 0x3a, 0x7d, 0xe2, 0x9, 0xe7, 0xc0, 0x34, 0x94, 0xc, | ||||||
|  | 			0xb0, 0x82, 0xd6, 0x78, 0xc3, 0x8, 0xb6, 0x74, 0x61, 0xce, 0xcf, 0xfd, 0x61, 0x2, 0xf7, 0x2, 0x3, 0x1, 0x0, | ||||||
|  | 			0x1, | ||||||
|  | 		}, publicKey) | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	t.Run("unknown format", func(t *testing.T) { | ||||||
|  | 		publicKey, err := signer.GetPublicKey("unknown") | ||||||
|  | 		require.Nil(t, publicKey) | ||||||
|  | 		require.ErrorContains(t, err, "invalid") | ||||||
|  | 	}) | ||||||
| } | } | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user