-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathtoken.go
More file actions
334 lines (275 loc) · 7.75 KB
/
token.go
File metadata and controls
334 lines (275 loc) · 7.75 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
package barong
import (
"crypto/rand"
"encoding/base64"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
"github.com/layer-3/barong/internal/eth"
)
// TokenType names the kind of a session token. NewToken writes this value as
// the token's audience ("aud") claim, and ParseToken/Validate compare the
// first audience entry against it.
type TokenType string

// Supported token types.
const (
	// TokenTypeChallenge is a challenge token; it carries a random nonce.
	TokenTypeChallenge TokenType = "session:challenge"
	// TokenTypeAccess is an access token.
	TokenTypeAccess TokenType = "session:access"
	// TokenTypeRefresh is a refresh token.
	TokenTypeRefresh TokenType = "session:refresh"
)

// Default lifetimes, measured from the issue time, that NewToken applies per
// token type.
const (
	// DefaultChallengeExpiry is the lifetime of a challenge token (5 minutes).
	DefaultChallengeExpiry = 5 * time.Minute
	// DefaultAccessExpiry is the lifetime of an access token (5 minutes).
	DefaultAccessExpiry = 5 * time.Minute
	// DefaultRefreshExpiry is the lifetime of a refresh token (5 days).
	DefaultRefreshExpiry = 5 * 24 * time.Hour
)
// Token is a session JWT held together with its decoded claims and type.
// Tokens created by NewToken also keep the signer that produced them, so the
// token can be re-signed after its claims change (SetRefreshID).
type Token struct {
	jwt       string     // compact serialized form returned by String
	claims    jwt.Claims // decoded claims; concrete type depends on tokenType
	signer    eth.Signer // signer used to (re-)sign; not set by ParseToken
	tokenType TokenType  // expected type, matched against the first audience
}
// ChallengeClaims is the claim set of a challenge token: the registered JWT
// claims plus a random nonce.
type ChallengeClaims struct {
	jwt.RegisteredClaims
	// Nonce is the random challenge value, serialized as "nonce".
	Nonce string `json:"nonce"`
}

// AccessClaims is the claim set of an access token: the registered JWT claims
// plus an optional refresh-token reference.
type AccessClaims struct {
	jwt.RegisteredClaims
	// RefreshID is serialized as "rid" and omitted when empty. It presumably
	// holds the jti of the paired refresh token — confirm against callers.
	RefreshID string `json:"rid,omitempty"`
}

// RefreshClaims is the claim set of a refresh token; it carries only the
// registered JWT claims.
type RefreshClaims struct {
	jwt.RegisteredClaims
}
// NewToken issues a signed session JWT of the given type for address.
//
// The claims carry sub=address, aud=[tokenType], iat=now, exp=now plus the
// type's default lifetime, and a random UUID as jti. Challenge tokens
// additionally carry a 32-byte random nonce. The token is signed with signer
// and kept in compact "header.payload.signature" form.
//
// It returns ErrInvalidToken for an unknown tokenType. Errors from nonce
// generation or signing are returned unchanged.
func NewToken(address string, tokenType TokenType, signer eth.Signer) (Token, error) {
	jti := uuid.New().String()
	now := time.Now()

	// registered builds the claims shared by every token type; only the
	// lifetime differs between types.
	registered := func(ttl time.Duration) jwt.RegisteredClaims {
		return jwt.RegisteredClaims{
			Subject:   address,
			Audience:  jwt.ClaimStrings{string(tokenType)},
			ExpiresAt: jwt.NewNumericDate(now.Add(ttl)),
			IssuedAt:  jwt.NewNumericDate(now),
			ID:        jti,
		}
	}

	var claims jwt.Claims
	switch tokenType {
	case TokenTypeChallenge:
		nonce, err := generateNonce(32)
		if err != nil {
			return Token{}, err
		}
		claims = &ChallengeClaims{
			RegisteredClaims: registered(DefaultChallengeExpiry),
			Nonce:            nonce,
		}
	case TokenTypeAccess:
		claims = &AccessClaims{RegisteredClaims: registered(DefaultAccessExpiry)}
	case TokenTypeRefresh:
		claims = &RefreshClaims{RegisteredClaims: registered(DefaultRefreshExpiry)}
	default:
		return Token{}, ErrInvalidToken
	}

	tokenString, err := signClaims(claims, signer)
	if err != nil {
		return Token{}, err
	}
	return Token{
		jwt:       tokenString,
		claims:    claims,
		signer:    signer,
		tokenType: tokenType,
	}, nil
}

// SetRefreshID stores refreshID in the "rid" claim of an access token and
// re-signs the token, replacing its JWT string.
//
// It returns ErrInvalidToken if the token is not an access token and
// ErrInvalidClaims if its claims are not *AccessClaims. Any error from
// re-signing is returned as-is. On any error the token is left unchanged.
//
// The token must carry a signer, as tokens from NewToken do; ParseToken does
// not set one.
func (t *Token) SetRefreshID(refreshID string) error {
	if t.tokenType != TokenTypeAccess {
		return ErrInvalidToken
	}
	accessClaims, ok := t.claims.(*AccessClaims)
	if !ok {
		return ErrInvalidClaims
	}

	// Sign an updated copy first, so a signing failure cannot leave the
	// claims and the JWT string out of sync.
	updated := *accessClaims
	updated.RefreshID = refreshID
	tokenString, err := signClaims(&updated, t.signer)
	if err != nil {
		return err
	}

	// Commit in place so the pointer returned earlier by Claims() observes the
	// change, as it did before.
	*accessClaims = updated
	t.jwt = tokenString
	return nil
}

// signClaims signs claims as an ES256 JWT and returns the result in compact
// form, "header.payload.signature".
//
// It builds the JWS signing input, base64url(header) + "." + base64url(claims),
// and signs it with signer. It returns any error from building the signing
// input, from the signer, or from encoding the signature.
//
// NOTE(review): ES256 means ECDSA over P-256 with SHA-256 (RFC 7518 §3.4). If
// eth.Signer signs with an Ethereum secp256k1 key and/or Keccak-256 hashing,
// standard ES256 verifiers will reject these tokens, and ES256K (RFC 8812) may
// be the intended "alg". Confirm against eth.Signer.
func signClaims(claims jwt.Claims, signer eth.Signer) (string, error) {
	// Token.Raw is only populated when parsing, so the signing input must come
	// from SigningString.
	signingString, err := jwt.NewWithClaims(jwt.SigningMethodES256, claims).SigningString()
	if err != nil {
		return "", err
	}
	sig, err := signer.Sign([]byte(signingString))
	if err != nil {
		return "", err
	}
	encodedSig, err := encodeES256Signature(sig.R, sig.S)
	if err != nil {
		return "", err
	}
	return signingString + "." + encodedSig, nil
}

// encodeES256Signature builds the JWS signature segment for ES256.
//
// R and S are each left-padded with zeros to 32 bytes, concatenated as R || S,
// and encoded as unpadded base64url. It returns ErrInvalidToken if either
// component is longer than 32 bytes, since that cannot form a valid ES256
// signature.
func encodeES256Signature(r, s []byte) (string, error) {
	const size = 32
	if len(r) > size || len(s) > size {
		return "", ErrInvalidToken
	}
	// A fresh buffer both applies the fixed-width padding and avoids
	// append(r, s...) writing into r's spare capacity.
	buf := make([]byte, 2*size)
	copy(buf[size-len(r):size], r)
	copy(buf[2*size-len(s):], s)
	return base64.RawURLEncoding.EncodeToString(buf), nil
}
// String returns the token in compact serialized JWT form, which makes Token
// satisfy fmt.Stringer.
func (t Token) String() string {
	return t.jwt
}

// Type reports which kind of session token this is.
func (t Token) Type() TokenType {
	return t.tokenType
}

// Claims returns the token's decoded claims. It returns the stored value
// itself, not a copy; for the pointer claim types NewToken creates, mutating
// the result mutates the token's claims.
func (t Token) Claims() jwt.Claims {
	return t.claims
}
// GetNonce returns the nonce carried by a challenge token.
//
// It returns ErrInvalidToken when the token is not a challenge token, and
// ErrInvalidClaims when its claims are not *ChallengeClaims.
func (t Token) GetNonce() (string, error) {
	if t.tokenType != TokenTypeChallenge {
		return "", ErrInvalidToken
	}
	switch c := t.claims.(type) {
	case *ChallengeClaims:
		return c.Nonce, nil
	default:
		return "", ErrInvalidClaims
	}
}
// GetJTI returns the token's "jti" (JWT ID) claim; NewToken fills it with a
// random UUID.
//
// It returns ErrInvalidClaims when the token has no claims or the ID cannot be
// read from them.
func (t Token) GetJTI() (string, error) {
	// jwt/v5's Claims interface exposes no jti accessor, so read the ID from
	// the concrete claim types this package produces.
	switch c := t.claims.(type) {
	case *ChallengeClaims:
		return c.ID, nil
	case *AccessClaims:
		return c.ID, nil
	case *RefreshClaims:
		return c.ID, nil
	case *jwt.RegisteredClaims:
		return c.ID, nil
	case jwt.MapClaims:
		// Produced by ParseToken for unrecognized token types.
		if id, ok := c["jti"].(string); ok {
			return id, nil
		}
	}
	return "", ErrInvalidClaims
}
// GetSubject returns the token's "sub" claim; NewToken sets it to the address
// the token was issued for.
//
// It returns ErrInvalidClaims when the token has no claims, for example a zero
// Token, or when the subject cannot be read.
func (t Token) GetSubject() (string, error) {
	if t.claims == nil {
		return "", ErrInvalidClaims
	}
	sub, err := t.claims.GetSubject()
	if err != nil {
		return "", ErrInvalidClaims
	}
	return sub, nil
}
// GetRefreshID returns the "rid" claim of an access token; it is empty if the
// claim was never set.
//
// It returns ErrInvalidToken when the token is not an access token, and
// ErrInvalidClaims when its claims are not *AccessClaims.
func (t Token) GetRefreshID() (string, error) {
	if t.tokenType != TokenTypeAccess {
		return "", ErrInvalidToken
	}
	switch c := t.claims.(type) {
	case *AccessClaims:
		return c.RefreshID, nil
	default:
		return "", ErrInvalidClaims
	}
}
// GetExpiresAt returns the token's expiration time (its "exp" claim).
//
// It returns ErrInvalidClaims when the token has no claims, when exp cannot be
// read, or when exp is absent.
func (t Token) GetExpiresAt() (time.Time, error) {
	if t.claims == nil {
		return time.Time{}, ErrInvalidClaims
	}
	exp, err := t.claims.GetExpirationTime()
	// A missing exp is reported as a nil *NumericDate with a nil error, and
	// dereferencing it would panic.
	if err != nil || exp == nil {
		return time.Time{}, ErrInvalidClaims
	}
	return exp.Time, nil
}
// Validate checks that the token's first audience matches its type and that
// it has not expired. It returns:
//   - ErrInvalidClaims when the claims are missing, unreadable, or lack an
//     expiration time;
//   - ErrInvalidAudience when the first audience is not the token's type;
//   - ErrTokenExpired once the current time has reached exp.
//
// Validate does not check the signature.
func (t Token) Validate() error {
	if t.claims == nil {
		return ErrInvalidClaims
	}

	aud, err := t.claims.GetAudience()
	if err != nil {
		return ErrInvalidClaims
	}
	if len(aud) == 0 || aud[0] != string(t.tokenType) {
		return ErrInvalidAudience
	}

	exp, err := t.claims.GetExpirationTime()
	// Without exp the token would never expire; a nil *NumericDate would also
	// panic below.
	if err != nil || exp == nil {
		return ErrInvalidClaims
	}
	// RFC 7519 §4.1.4: the token must not be accepted on or after exp, which
	// matches golang-jwt's own validator.
	if !time.Now().Before(exp.Time) {
		return ErrTokenExpired
	}
	return nil
}
// ParseToken parses and verifies tokenString as a JWT of expectedType.
//
// The token must be ES256-signed, its first audience must equal expectedType,
// and it must carry an expiration time. It returns ErrInvalidToken when
// parsing, signature verification, or the parser's claim validation fails.
// It returns ErrInvalidAudience on an audience mismatch, and ErrInvalidClaims
// when the audience or exp cannot be read or exp is missing.
//
// The returned Token has no signer, so SetRefreshID cannot re-sign it.
//
// NOTE(review): keyFunc currently returns a nil key. golang-jwt's ECDSA
// verifier requires an *ecdsa.PublicKey, so as written every call is expected
// to fail verification and return ErrInvalidToken. Confirm, and supply the
// issuer's public key before relying on this function.
func ParseToken(tokenString string, expectedType TokenType) (Token, error) {
	token, err := jwt.ParseWithClaims(
		tokenString,
		createClaimsForType(expectedType),
		keyFunc,
		// Pin the algorithm NewToken signs with, so a token cannot select
		// another ECDSA variant through its "alg" header.
		jwt.WithValidMethods([]string{jwt.SigningMethodES256.Alg()}),
	)
	if err != nil || !token.Valid {
		return Token{}, ErrInvalidToken
	}

	claims := token.Claims
	aud, err := claims.GetAudience()
	if err != nil {
		return Token{}, ErrInvalidClaims
	}
	if len(aud) == 0 || aud[0] != string(expectedType) {
		return Token{}, ErrInvalidAudience
	}

	// The parser only checks exp when it is present. Every token NewToken
	// issues has one, so reject tokens that would otherwise never expire.
	exp, err := claims.GetExpirationTime()
	if err != nil || exp == nil {
		return Token{}, ErrInvalidClaims
	}

	return Token{
		jwt:       tokenString,
		claims:    claims,
		tokenType: expectedType,
	}, nil
}
// createClaimsForType returns a fresh, empty claims value for the parser to
// decode into, matching tokenType.
//
// Known types get their pointer claim struct. Any other type falls back to
// generic jwt.MapClaims.
func createClaimsForType(tokenType TokenType) jwt.Claims {
	switch tokenType {
	case TokenTypeChallenge:
		return new(ChallengeClaims)
	case TokenTypeAccess:
		return new(AccessClaims)
	case TokenTypeRefresh:
		return new(RefreshClaims)
	}
	return jwt.MapClaims{}
}
// keyFunc is the jwt.Keyfunc used by ParseToken. It rejects any token whose
// "alg" is not ES256, the only algorithm NewToken signs with, by returning
// ErrInvalidSigningMethod.
//
// NOTE(review): for accepted tokens it returns a nil verification key; no
// public-key lookup is implemented. golang-jwt's ECDSA Verify rejects
// non-*ecdsa.PublicKey keys, so signature verification is expected to fail
// until a real key for the token's subject is returned here.
func keyFunc(token *jwt.Token) (any, error) {
	// Compare by algorithm name rather than accepting any ECDSA method, which
	// would also admit ES384 and ES512.
	if token.Method == nil || token.Method.Alg() != jwt.SigningMethodES256.Alg() {
		return nil, ErrInvalidSigningMethod
	}
	return nil, nil
}
// generateNonce reads length bytes from crypto/rand and returns them encoded
// as padded URL-safe base64 (base64.URLEncoding).
//
// It returns the error from crypto/rand if the randomness source fails.
func generateNonce(length int) (string, error) {
	buf := make([]byte, length)
	if _, err := rand.Read(buf); err != nil {
		return "", err
	}
	return base64.URLEncoding.EncodeToString(buf), nil
}