Skip to content

Commit

Permalink
feat: custom claims for setting grpc/server addr in token (#150)
Browse files Browse the repository at this point in the history
  • Loading branch information
abelanger5 committed Feb 5, 2024
1 parent 7bdfa78 commit 73adb77
Show file tree
Hide file tree
Showing 22 changed files with 247 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ export function CreateTokenDialog({
className="text-sm"
wrapLines={false}
maxWidth={'calc(700px - 4rem)'}
code={token}
code={'HATCHET_CLIENT_TOKEN="' + token + '"'}
copy
/>
</DialogContent>
Expand Down
2 changes: 1 addition & 1 deletion frontend/docs/pages/contributing/sdks.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ Each SDK should support the following environment variables:
| Variable | Description | Required | Default |
| -------- | ----------- | -------- | ------- |
| `HATCHET_CLIENT_TOKEN` | The tenant-scoped API token to use. | Yes | N/A |
| `HATCHET_CLIENT_HOST_PORT` | The host and port of the Hatchet server to connect to, in `host:port` format. SDKs should handle schemes and trailing slashes, i.e. `https://host:port | Yes | N/A |
| `HATCHET_CLIENT_HOST_PORT` | The host and port of the Hatchet server to connect to, in `host:port` format. SDKs should handle schemes and trailing slashes, i.e. `https://host:port | No | Automatically detected in new tokens. |
| `HATCHET_CLIENT_TLS_STRATEGY` | The TLS strategy to use. Valid values are `none`, `tls`, and `mtls`. | No | `tls` |
| `HATCHET_CLIENT_TLS_CERT_FILE` | The path to the TLS client certificate file to use. | Only if strategy is set to `mtls` | N/A |
| `HATCHET_CLIENT_TLS_CERT` | The TLS client key file to use. | Only if strategy is set to `mtls` | N/A |
Expand Down
1 change: 0 additions & 1 deletion frontend/docs/pages/home/python-sdk/setup.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ poetry add hatchet-sdk
Navigate to your Hatchet dashboard and navigate to your settings tab. You should see a section called "API Keys". Click "Create API Key", input a name for the key and copy the key. Then set the following environment variables:

```sh
HATCHET_CLIENT_HOST_PORT=<hatchet-domain>:443
HATCHET_CLIENT_TOKEN="<your-api-key>"
```

Expand Down
1 change: 0 additions & 1 deletion frontend/docs/pages/home/quickstart.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ When you get access to Hatchet, you'll be given a development tenant to use. Thi
When you get access to the tenant, navigate to your Hatchet dashboard and to your settings tab. You should see a section called "API Keys". Click "Create API Key", input a name for the key and copy the key. Then set the following environment variables:

```sh
HATCHET_CLIENT_HOST_PORT=<hatchet-domain>:443
HATCHET_CLIENT_TOKEN="<your-api-key>"
```

Expand Down
1 change: 0 additions & 1 deletion frontend/docs/pages/home/typescript-sdk/setup.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,5 @@ npm i @hatchet-dev/typescript-sdk
Navigate to your Hatchet dashboard and navigate to your settings tab. You should see a section called "API Keys". Click "Create API Key", input a name for the key and copy the key. Then set the following environment variables:

```sh
HATCHET_CLIENT_HOST_PORT=<hatchet-domain>:443
HATCHET_CLIENT_TOKEN="<your-api-key>"
```
35 changes: 32 additions & 3 deletions internal/auth/token/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@ type JWTManager interface {
}

type TokenOpts struct {
Issuer string
Audience string
Issuer string
Audience string
ServerURL string
GRPCBroadcastAddress string
}

type jwtManagerImpl struct {
Expand Down Expand Up @@ -112,6 +114,31 @@ func (j *jwtManagerImpl) ValidateTenantToken(token string) (tenantId string, err
return "", fmt.Errorf("failed to read token_id claim: %v", err)
}

// ensure the current server url and grpc broadcast address match the token, if present
if hasServerURL := verifiedJwt.HasStringClaim("server_url"); hasServerURL {
serverURL, err := verifiedJwt.StringClaim("server_url")

if err != nil {
return "", fmt.Errorf("failed to read server_url claim: %v", err)
}

if serverURL != j.opts.ServerURL {
return "", fmt.Errorf("server_url claim does not match")
}
}

if hasGRPCBroadcastAddress := verifiedJwt.HasStringClaim("grpc_broadcast_address"); hasGRPCBroadcastAddress {
grpcBroadcastAddress, err := verifiedJwt.StringClaim("grpc_broadcast_address")

if err != nil {
return "", fmt.Errorf("failed to read grpc_broadcast_address claim: %v", err)
}

if grpcBroadcastAddress != j.opts.GRPCBroadcastAddress {
return "", fmt.Errorf("grpc_broadcast_address claim does not match")
}
}

// read the token from the database
dbToken, err := j.tokenRepo.GetAPITokenById(tokenId)

Expand Down Expand Up @@ -155,7 +182,9 @@ func (j *jwtManagerImpl) getJWTOptionsForTenant(tenantId string) (tokenId string
ExpiresAt: &expiresAt,
Issuer: &issuer,
CustomClaims: map[string]interface{}{
"token_id": tokenId,
"token_id": tokenId,
"server_url": j.opts.ServerURL,
"grpc_broadcast_address": j.opts.GRPCBroadcastAddress,
},
}

Expand Down
3 changes: 3 additions & 0 deletions internal/config/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ type ClientConfig struct {
TenantId string
Token string

ServerURL string
GRPCBroadcastAddress string

TLSConfig *tls.Config
}

Expand Down
6 changes: 4 additions & 2 deletions internal/config/loader/loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,10 @@ func GetServerConfigFromConfigfile(dc *database.Config, cf *server.ServerConfigF

// create a new JWT manager
auth.JWTManager, err = token.NewJWTManager(encryptionSvc, dc.Repository.APIToken(), &token.TokenOpts{
Issuer: cf.Runtime.ServerURL,
Audience: cf.Runtime.ServerURL,
Issuer: cf.Runtime.ServerURL,
Audience: cf.Runtime.ServerURL,
GRPCBroadcastAddress: cf.Runtime.GRPCBroadcastAddress,
ServerURL: cf.Runtime.ServerURL,
})

if err != nil {
Expand Down
4 changes: 4 additions & 0 deletions internal/config/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ type ConfigFileRuntime struct {
// GRPCBindAddress is the address that the grpc server binds to. Should set to 0.0.0.0 if binding in docker container.
GRPCBindAddress string `mapstructure:"grpcBindAddress" json:"grpcBindAddress,omitempty" default:"127.0.0.1"`

// GRPCBroadcastAddress is the address that the grpc server broadcasts to, which is what clients should use when connecting.
GRPCBroadcastAddress string `mapstructure:"grpcBroadcastAddress" json:"grpcBroadcastAddress,omitempty" default:"127.0.0.1:7070"`

// GRPCInsecure controls whether the grpc server is insecure or uses certs
GRPCInsecure bool `mapstructure:"grpcInsecure" json:"grpcInsecure,omitempty" default:"false"`
}
Expand Down Expand Up @@ -190,6 +193,7 @@ func BindAllEnv(v *viper.Viper) {
_ = v.BindEnv("runtime.url", "SERVER_URL")
_ = v.BindEnv("runtime.grpcPort", "SERVER_GRPC_PORT")
_ = v.BindEnv("runtime.grpcBindAddress", "SERVER_GRPC_BIND_ADDRESS")
_ = v.BindEnv("runtime.grpcBroadcastAddress", "SERVER_GRPC_BROADCAST_ADDRESS")
_ = v.BindEnv("runtime.grpcInsecure", "SERVER_GRPC_INSECURE")
_ = v.BindEnv("services", "SERVER_SERVICES")

Expand Down
2 changes: 1 addition & 1 deletion pkg/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func defaultClientOpts() *ClientOpts {
l: &logger,
v: validator.NewDefaultValidator(),
tls: clientConfig.TLSConfig,
hostPort: "localhost:7070",
hostPort: clientConfig.GRPCBroadcastAddress,
filesLoader: types.DefaultLoader,
}
}
Expand Down
33 changes: 30 additions & 3 deletions pkg/client/loader/loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,31 @@ func LoadClientConfigFile(files ...[]byte) (*client.ClientConfigFile, error) {
}

func GetClientConfigFromConfigFile(cf *client.ClientConfigFile) (res *client.ClientConfig, err error) {
// if token is empty, throw an error
if cf.Token == "" {
return nil, fmt.Errorf("API token is required. Set it via the HATCHET_CLIENT_TOKEN environment variable.")
}

grpcBroadcastAddress := cf.HostPort
serverURL := cf.HostPort

tokenAddresses, err := getAddressesFromJWT(cf.Token)

if err == nil {
if grpcBroadcastAddress == "" && tokenAddresses.grpcBroadcastAddress != "" {
grpcBroadcastAddress = tokenAddresses.grpcBroadcastAddress
}

if tokenAddresses.serverURL != "" {
serverURL = tokenAddresses.serverURL
}
}

// if there's no broadcast address at this point, throw an error
if grpcBroadcastAddress == "" {
return nil, fmt.Errorf("GRPC broadcast address is required. Set it via the HATCHET_CLIENT_HOST_PORT environment variable.")
}

tlsServerName := cf.TLS.TLSServerName

// if the tls server name is empty, parse the domain from the host:port
Expand All @@ -64,9 +89,11 @@ func GetClientConfigFromConfigFile(cf *client.ClientConfigFile) (res *client.Cli
}

return &client.ClientConfig{
TenantId: cf.TenantId,
TLSConfig: tlsConf,
Token: cf.Token,
TenantId: cf.TenantId,
TLSConfig: tlsConf,
Token: cf.Token,
ServerURL: serverURL,
GRPCBroadcastAddress: grpcBroadcastAddress,
}, nil
}

Expand Down
55 changes: 55 additions & 0 deletions pkg/client/loader/token.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package loader

import (
"encoding/base64"
"encoding/json"
"fmt"
"strings"
)

type addresses struct {
serverURL string
grpcBroadcastAddress string
}

func getAddressesFromJWT(token string) (*addresses, error) {
claims, err := extractClaimsFromJWT(token)
if err != nil {
return nil, err
}

serverURL, ok := claims["server_url"].(string)
if !ok {
return nil, fmt.Errorf("server_url claim not found")
}

grpcBroadcastAddress, ok := claims["grpc_broadcast_address"].(string)
if !ok {
return nil, fmt.Errorf("grpc_broadcast_address claim not found")
}

return &addresses{
serverURL: serverURL,
grpcBroadcastAddress: grpcBroadcastAddress,
}, nil
}

func extractClaimsFromJWT(token string) (map[string]interface{}, error) {
parts := strings.Split(token, ".")
if len(parts) != 3 {
return nil, fmt.Errorf("invalid token format")
}

claimsData, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
return nil, err
}

var claims map[string]interface{}
err = json.Unmarshal(claimsData, &claims)
if err != nil {
return nil, err
}

return claims, nil
}
18 changes: 18 additions & 0 deletions pkg/client/loader/token_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package loader

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestExtractClaimsFromJWT(t *testing.T) {
token := "eyJhbGciOiJFUzI1NiIsICJraWQiOiJRMzNPaGcifQ.eyJhdWQiOiJodHRwczovL2FwcC5kZXYuaGF0Y2hldC10b29scy5jb20iLCAiZXhwIjoxNzE0ODc4NDEyLCAiZ3JwY19icm9hZGNhc3RfYWRkcmVzcyI6IjEyNy4wLjAuMTo3MDcwIiwgImlhdCI6MTcwNzEwMjQxMiwgImlzcyI6Imh0dHBzOi8vYXBwLmRldi5oYXRjaGV0LXRvb2xzLmNvbSIsICJzZXJ2ZXJfdXJsIjoiaHR0cHM6Ly9hcHAuZGV2LmhhdGNoZXQtdG9vbHMuY29tIiwgInN1YiI6IjcwN2QwODU1LTgwYWItNGUxZi1hMTU2LWYxYzQ1NDZjYmY1MiIsICJ0b2tlbl9pZCI6IjI1NzFkODMwLWFmNDgtNDYyZS1hNDFlLTRlZWJkMjUwN2I0NyJ9.abcdefg" // #nosec G101

claims, err := extractClaimsFromJWT(token)

assert.Nil(t, err)

assert.Equal(t, claims["server_url"], "https://app.dev.hatchet-tools.com")
assert.Equal(t, claims["grpc_broadcast_address"], "127.0.0.1:7070")
}
14 changes: 13 additions & 1 deletion python-sdk/hatchet_sdk/loader.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import yaml
from typing import Any, Optional, Dict
from .token import get_addresses_from_jwt

class ClientTLSConfig:
def __init__(self, tls_strategy: str, cert_file: str, key_file: str, ca_file: str, server_name: str):
Expand Down Expand Up @@ -34,8 +35,19 @@ def load_client_config(self) -> ClientConfig:
config_data = yaml.safe_load(file)

tenant_id = config_data['tenantId'] if 'tenantId' in config_data else self._get_env_var('HATCHET_CLIENT_TENANT_ID')
host_port = config_data['hostPort'] if 'hostPort' in config_data else self._get_env_var('HATCHET_CLIENT_HOST_PORT')
token = config_data['token'] if 'token' in config_data else self._get_env_var('HATCHET_CLIENT_TOKEN')

if not token:
raise ValueError('Token must be set via HATCHET_CLIENT_TOKEN environment variable')

host_port = config_data['hostPort'] if 'hostPort' in config_data else self._get_env_var('HATCHET_CLIENT_HOST_PORT')

if not host_port:
# extract host and port from token
server_url, grpc_broadcast_address = get_addresses_from_jwt(token)

host_port = grpc_broadcast_address

tls_config = self._load_tls_config(config_data['tls'], host_port)

return ClientConfig(tenant_id, tls_config, token, host_port)
Expand Down
19 changes: 19 additions & 0 deletions python-sdk/hatchet_sdk/token.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import base64
import json

def get_addresses_from_jwt(token: str) -> (str, str):
claims = extract_claims_from_jwt(token)

return claims.get('server_url'), claims.get('grpc_broadcast_address')

def extract_claims_from_jwt(token: str):
parts = token.split('.')
if len(parts) != 3:
raise ValueError('Invalid token format')

claims_part = parts[1]
claims_part += '=' * ((4 - len(claims_part) % 4) % 4) # Padding for base64 decoding
claims_data = base64.urlsafe_b64decode(claims_part)
claims = json.loads(claims_data)

return claims
2 changes: 1 addition & 1 deletion python-sdk/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "hatchet-sdk"
version = "0.8.0"
version = "0.9.0"
description = ""
authors = ["Alexander Belanger <[email protected]>"]
readme = "README.md"
Expand Down
4 changes: 2 additions & 2 deletions typescript-sdk/package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "@hatchet-dev/typescript-sdk",
"version": "0.1.11",
"version": "0.1.12",
"description": "Background task orchestration & visibility for developers",
"main": "dist/index.js",
"types": "dist/index.d.ts",
Expand Down Expand Up @@ -72,4 +72,4 @@
"yaml": "^2.3.4",
"zod": "^3.22.4"
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ export class HatchetClient {
// Initializes a new Client instance.
// Loads config in the following order: config param > yaml file > env vars

const loaded = ConfigLoader.load_client_config({
const loaded = ConfigLoader.loadClientConfig({
path: options?.config_path,
});

Expand Down
10 changes: 5 additions & 5 deletions typescript-sdk/src/util/config-loader/config-loader.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ fdescribe('ConfigLoader', () => {
});

it('should load from environment variables', () => {
const config = ConfigLoader.load_client_config();
const config = ConfigLoader.loadClientConfig();
expect(config).toEqual({
host_port: 'HOST_PORT',
log_level: 'INFO',
Expand All @@ -26,7 +26,7 @@ fdescribe('ConfigLoader', () => {

it('should throw an error if the file is not found', () => {
expect(() =>
ConfigLoader.load_client_config({
ConfigLoader.loadClientConfig({
path: './fixtures/not-found.yaml',
})
).toThrow();
Expand All @@ -35,14 +35,14 @@ fdescribe('ConfigLoader', () => {
xit('should throw an error if the yaml file fails validation', () => {
expect(() =>
// This test is failing because there is no invalid state of the yaml file, need to update with tls and mtls settings
ConfigLoader.load_client_config({
ConfigLoader.loadClientConfig({
path: './fixtures/.hatchet-invalid.yaml',
})
).toThrow();
});

it('should favor yaml config over env vars', () => {
const config = ConfigLoader.load_client_config({
const config = ConfigLoader.loadClientConfig({
path: './fixtures/.hatchet.yaml',
});
expect(config).toEqual({
Expand All @@ -61,7 +61,7 @@ fdescribe('ConfigLoader', () => {

xit('should attempt to load the root .hatchet.yaml config', () => {
// i'm not sure the best way to test this, maybe spy on readFileSync called with
const config = ConfigLoader.load_client_config({
const config = ConfigLoader.loadClientConfig({
path: './fixtures/.hatchet.yaml',
});
expect(config).toEqual({
Expand Down

0 comments on commit 73adb77

Please sign in to comment.