Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions code/docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,27 @@ services:
networks:
- threadit-network

keycloak:
image: quay.io/keycloak/keycloak:21.1
container_name: keycloak
restart: always
command:
- start-dev
- --import-realm
environment:
KEYCLOAK_ADMIN: admin
KEYCLOAK_ADMIN_PASSWORD: ${KEYCLOAK_ADMIN_PASSWORD}
KC_HOSTNAME_STRICT: false
KC_HOSTNAME_STRICT_HTTPS: false
KC_HTTP_ENABLED: "true"
KC_PROXY: edge
volumes:
- ./keycloak/realm-export.json:/opt/keycloak/data/import/realm.json:ro
ports:
- "${KEYCLOAK_PORT}:8080"
networks:
- threadit-network

volumes:
db_data:
driver: local
Expand Down
105 changes: 77 additions & 28 deletions code/grpc-gateway/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/protobuf/types/known/emptypb"
"threadit/grpc-gateway/middleware"
)

func getGrpcServerAddress(hostEnvVar string, portEnvVar string) string {
Expand Down Expand Up @@ -113,7 +114,9 @@ func main() {
gorun.GOMAXPROCS(gorun.NumCPU())

gwmux := runtime.NewServeMux()
ctx := context.Background()

// gRPC dial options with message size configurations
opts := []grpc.DialOption{
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithDefaultCallOptions(
Expand All @@ -122,48 +125,94 @@ func main() {
),
}

err := communitypb.RegisterCommunityServiceHandlerFromEndpoint(context.Background(), gwmux, getGrpcServerAddress("COMMUNITY_SERVICE_HOST", "COMMUNITY_SERVICE_PORT"), opts)
if err != nil {
log.Fatalf("Failed to register gRPC gateway: %v", err)
}
// Initialize auth handler
authHandler := middleware.NewAuthHandler(
os.Getenv("KEYCLOAK_URL"),
os.Getenv("KEYCLOAK_CLIENT_ID"),
os.Getenv("KEYCLOAK_CLIENT_SECRET"),
os.Getenv("KEYCLOAK_REALM"),
)

err = threadpb.RegisterThreadServiceHandlerFromEndpoint(context.Background(), gwmux, getGrpcServerAddress("THREAD_SERVICE_HOST", "THREAD_SERVICE_PORT"), opts)
if err != nil {
log.Fatalf("Failed to register gRPC gateway: %v", err)
// Create a new ServeMux for both gRPC-Gateway and auth routes
httpMux := http.NewServeMux()

// Register auth routes
authHandler.RegisterRoutes(httpMux)

// Register gRPC-Gateway routes with auth middleware
httpMux.Handle("/api", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Auth middleware for API routes
authMiddleware := middleware.NewAuthMiddleware(middleware.KeycloakConfig{
Realm: os.Getenv("KEYCLOAK_REALM"),
ClientID: os.Getenv("KEYCLOAK_CLIENT_ID"),
ClientSecret: os.Getenv("KEYCLOAK_CLIENT_SECRET"),
KeycloakURL: os.Getenv("KEYCLOAK_URL"),
})

authMiddleware.Handler(gwmux).ServeHTTP(w, r)
}))

// Register service handlers
if err := registerServices(ctx, gwmux, opts); err != nil {
log.Fatalf("Failed to register services: %v", err)
}

err = commentpb.RegisterCommentServiceHandlerFromEndpoint(context.Background(), gwmux, getGrpcServerAddress("COMMENT_SERVICE_HOST", "COMMENT_SERVICE_PORT"), opts)
if err != nil {
log.Fatalf("Failed to register gRPC gateway: %v", err)
port := os.Getenv("GRPC_GATEWAY_PORT")
if port == "" {
log.Fatalf("missing GRPC_GATEWAY_PORT env var")
}

err = votepb.RegisterVoteServiceHandlerFromEndpoint(context.Background(), gwmux, getGrpcServerAddress("VOTE_SERVICE_HOST", "VOTE_SERVICE_PORT"), opts)
if err != nil {
log.Fatalf("Failed to register gRPC gateway: %v", err)
log.Printf("gRPC Gateway server listening on :%s", port)
if err := http.ListenAndServe(":"+port, httpMux); err != nil {
log.Fatalf("Failed to serve: %v", err)
}
}

err = searchpb.RegisterSearchServiceHandlerFromEndpoint(context.Background(), gwmux, getGrpcServerAddress("SEARCH_SERVICE_HOST", "SEARCH_SERVICE_PORT"), opts)
if err != nil {
log.Fatalf("Failed to register gRPC gateway: %v", err)
func registerServices(ctx context.Context, mux *runtime.ServeMux, opts []grpc.DialOption) error {
// Register Community Service
if err := communitypb.RegisterCommunityServiceHandlerFromEndpoint(
ctx, mux, getGrpcServerAddress("COMMUNITY_SERVICE_HOST", "COMMUNITY_SERVICE_PORT"), opts,
); err != nil {
return fmt.Errorf("failed to register community service: %v", err)
}

err = popularpb.RegisterPopularServiceHandlerFromEndpoint(context.Background(), gwmux, getGrpcServerAddress("POPULAR_SERVICE_HOST", "POPULAR_SERVICE_PORT"), opts)
if err != nil {
log.Fatalf("Failed to register gRPC gateway: %v", err)
// Register Thread Service
if err := threadpb.RegisterThreadServiceHandlerFromEndpoint(
ctx, mux, getGrpcServerAddress("THREAD_SERVICE_HOST", "THREAD_SERVICE_PORT"), opts,
); err != nil {
return fmt.Errorf("failed to register thread service: %v", err)
}

http.HandleFunc("/health", handleHealthCheck)
// Register Comment Service
if err := commentpb.RegisterCommentServiceHandlerFromEndpoint(
ctx, mux, getGrpcServerAddress("COMMENT_SERVICE_HOST", "COMMENT_SERVICE_PORT"), opts,
); err != nil {
return fmt.Errorf("failed to register comment service: %v", err)
}

http.Handle("/", gwmux)
// Register Vote Service
if err := votepb.RegisterVoteServiceHandlerFromEndpoint(
ctx, mux, getGrpcServerAddress("VOTE_SERVICE_HOST", "VOTE_SERVICE_PORT"), opts,
); err != nil {
return fmt.Errorf("failed to register vote service: %v", err)
}

port := os.Getenv("GRPC_GATEWAY_PORT")
if port == "" {
log.Fatalf("missing GRPC_GATEWAY_PORT env var")
// Register Search Service
if err := searchpb.RegisterSearchServiceHandlerFromEndpoint(
ctx, mux, getGrpcServerAddress("SEARCH_SERVICE_HOST", "SEARCH_SERVICE_PORT"), opts,
); err != nil {
return fmt.Errorf("failed to register search service: %v", err)
}

log.Printf("gRPC Gateway server listening on :%s", port)
err = http.ListenAndServe(fmt.Sprintf(":%s", port), nil)
if err != nil {
log.Fatalf("Failed to start HTTP server: %v", err)
// Register Popular Service
if err := popularpb.RegisterPopularServiceHandlerFromEndpoint(
ctx, mux, getGrpcServerAddress("POPULAR_SERVICE_HOST", "POPULAR_SERVICE_PORT"), opts,
); err != nil {
return fmt.Errorf("failed to register popular service: %v", err)
}

http.HandleFunc("/health", handleHealthCheck)
http.Handle("/", mux)

return nil
}
163 changes: 163 additions & 0 deletions code/grpc-gateway/middleware/auth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
package middleware

import (
"context"
"net/http"
"strings"

"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
"google.golang.org/grpc/metadata"
"your-module/code/services/auth"
)

type AuthMiddleware struct {
keycloak *auth.KeycloakClient
}

func NewAuthMiddleware(config auth.KeycloakConfig) (*AuthMiddleware, error) {
kc, err := auth.NewKeycloakClient(config)
if err != nil {
return nil, err
}
return &AuthMiddleware{keycloak: kc}, nil
}

func (am *AuthMiddleware) Handler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Skip auth for public endpoints
if isPublicEndpoint(r.URL.Path, r.Method) {
next.ServeHTTP(w, r)
return
}

// Extract token from Authorization header
token, err := auth.ExtractBearerToken(r.Header.Get("Authorization"))
if err != nil {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}

// Validate token
claims, err := am.keycloak.ValidateToken(r.Context(), token)
if err != nil {
http.Error(w, "Invalid token", http.StatusUnauthorized)
return
}

// Check required roles for protected endpoints
if !hasRequiredRole(r.URL.Path, claims) {
http.Error(w, "Forbidden", http.StatusForbidden)
return
}

// Add user info to context
ctx := context.WithValue(r.Context(), "user_claims", claims)

// Forward token to gRPC services
md := metadata.Pairs("authorization", "Bearer "+token)
ctx = metadata.NewOutgoingContext(ctx, md)

next.ServeHTTP(w, r.WithContext(ctx))
})
}

func isPublicEndpoint(path, method string) bool {
// Auth endpoints are always public
authPaths := []string{
"/auth/login",
"/auth/register",
"/auth/logout",
}
for _, ap := range authPaths {
if path == ap {
return true
}
}

// Only GET requests can be public for these paths
if method != http.MethodGet {
return false
}

publicGetPaths := []string{
"/communities",
"/threads",
"/comments",
"/search",
"/search/thread",
"/search/community",
"/popular/threads",
"/popular/comments",
}

// Check exact matches for list endpoints
for _, pp := range publicGetPaths {
if path == pp {
return true
}
}

// Check id based paths
idBasedPaths := []string{
"/communities/",
"/threads/",
"/comments/",
}

for _, pp := range idBasedPaths {
if strings.HasPrefix(path, pp) && path != pp {
return true
}
}

return false
}

func hasRequiredRole(path string, claims *auth.TokenClaims) bool {
roleRequirements := map[string]string{
// Communities
"POST /communities": "user",
"PATCH /communities/": "moderator",
"DELETE /communities/": "moderator",

// Threads
"POST /threads": "user",
"PATCH /threads/": "user",
"DELETE /threads/": "user",

// Comment sdpoints
"POST /comments": "user",
"PATCH /comments/": "user",
"DELETE /comments/": "user",

// Votes
"POST /votes/thread/": "user",
"POST /votes/comment/": "user",

// Admin
"POST /admin/": "admin",
"PUT /admin/": "admin",
"DELETE /admin/": "admin",
}

// Check each role requirement
for pathPattern, requiredRole := range roleRequirements {
parts := strings.SplitN(pathPattern, " ", 2)
method, pattern := parts[0], parts[1]
if strings.HasPrefix(path, pattern) {
return claims.RealmAccess.Roles != nil && contains(claims.RealmAccess.Roles, requiredRole)
}
}

// If no specific role requirement, allow access
return true
}

func contains(slice []string, item string) bool {
for _, s := range slice {
if s == item {
return true
}
}
return false
}
Loading
Loading