Skip to content

Commit 4ac6ae5

Browse files
committed
Refactor internal protocol layer.
1 parent 61b30c2 commit 4ac6ae5

4 files changed

Lines changed: 190 additions & 285 deletions

File tree

paseto/src/main/kotlin/net/aholbrook/paseto/protocol/PasetoV1.kt

Lines changed: 64 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -16,62 +16,50 @@ import net.aholbrook.paseto.crypto.pae
1616
import net.aholbrook.paseto.crypto.rsaSign
1717
import net.aholbrook.paseto.crypto.rsaVerify
1818
import net.aholbrook.paseto.decodeOrNull
19+
import net.aholbrook.paseto.exception.InvalidHeaderException
1920
import net.aholbrook.paseto.exception.PasetoParseException
2021
import net.aholbrook.paseto.exception.SignatureVerificationException
2122
import kotlin.io.encoding.Base64
2223

23-
private const val VERSION = "v1"
24-
private const val HEADER_LOCAL: String = VERSION + SEPARATOR + PURPOSE_LOCAL + SEPARATOR // v1.local.
25-
private const val HEADER_PUBLIC: String = VERSION + SEPARATOR + PURPOSE_PUBLIC + SEPARATOR // v1.public.
2624
private val HKDF_INFO_EK: ByteArray = "paseto-encryption-key".toByteArray(Charsets.UTF_8)
2725
private val HKDF_INFO_AK: ByteArray = "paseto-auth-key-for-aead".toByteArray(Charsets.UTF_8)
2826

2927
internal object PasetoV1 : Paseto {
3028
override val version: Version = Version.V1
3129
override val supportsImplicitAssertion: Boolean = false
3230

31+
private fun getNonce(m: ByteArray, n: ByteArray) = hmacSha384(m, n).copyOfRange(0, 32)
32+
3333
override fun encrypt(m: ByteArray, key: SymmetricKey, footer: String, implicitAssertion: String): String {
3434
val cleanup = mutableListOf<Runnable>()
3535

3636
try {
37-
// Verify key version.
38-
val keyMaterial = key.getKeyMaterialFor(Version.V1, Purpose.LOCAL)
39-
40-
val footerBytes = footer.toByteArray(Charsets.UTF_8)
37+
val k = key.getKeyMaterialFor(Version.V1, Purpose.LOCAL)
38+
val h = "v1.local."
39+
val f = footer.toByteArray(Charsets.UTF_8)
40+
val b = generateNonce(32)
41+
cleanup.add { b.fill(0) }
42+
val n = getNonce(m, b)
4143

42-
// Generate n
43-
val random = generateNonce(32)
44-
cleanup.add { random.fill(0) }
45-
val n = ByteArray(32)
46-
System.arraycopy(hmacSha384(m, random), 0, n, 0, n.size)
47-
48-
// Split N into salt/nonce
49-
val salt = ByteArray(HKDF_SALT_LEN)
50-
val nonce = ByteArray(HKDF_SALT_LEN)
51-
System.arraycopy(n, 0, salt, 0, salt.size)
52-
System.arraycopy(n, salt.size, nonce, 0, nonce.size)
44+
val salt = n.copyOfRange(0, HKDF_SALT_LEN)
45+
val nonce = n.copyOfRange(HKDF_SALT_LEN, HKDF_SALT_LEN * 2)
5346

5447
// Create ek/ak for AEAD
55-
val ek = hkdfExtractAndExpand(salt, keyMaterial, HKDF_INFO_EK)
48+
val ek = hkdfExtractAndExpand(salt, k, HKDF_INFO_EK)
5649
cleanup.add { ek.fill(0) }
57-
val ak = hkdfExtractAndExpand(salt, keyMaterial, HKDF_INFO_AK)
50+
val ak = hkdfExtractAndExpand(salt, k, HKDF_INFO_AK)
5851
cleanup.add { ak.fill(0) }
5952

6053
val c = aes256CtrEncrypt(m, ek, nonce)
61-
val preAuth = pae(HEADER_LOCAL.toByteArray(Charsets.UTF_8), n, c, footerBytes)
54+
val preAuth = pae(h.toByteArray(Charsets.UTF_8), n, c, f)
6255
val t = hmacSha384(preAuth, ak)
6356

64-
val nct = ByteArray(n.size + c.size + t.size)
65-
System.arraycopy(n, 0, nct, 0, n.size)
66-
System.arraycopy(c, 0, nct, n.size, c.size)
67-
System.arraycopy(t, 0, nct, n.size + c.size, t.size)
68-
69-
return if (footerBytes.isNotEmpty()) {
70-
HEADER_LOCAL + Base64.UrlSafeNoPadding.encode(nct) + SEPARATOR +
71-
Base64.UrlSafeNoPadding.encode(footerBytes)
72-
} else {
73-
HEADER_LOCAL + Base64.UrlSafeNoPadding.encode(nct)
74-
}
57+
return h + Base64.UrlSafeNoPadding.encode(n + c + t) +
58+
if (f.isEmpty()) {
59+
""
60+
} else {
61+
".${Base64.UrlSafeNoPadding.encode(f)}"
62+
}
7563
} finally {
7664
key.clear()
7765
cleanup.forEach { it.run() }
@@ -87,60 +75,48 @@ internal object PasetoV1 : Paseto {
8775
val cleanup = mutableListOf<Runnable>()
8876

8977
try {
90-
// Verify key version.
91-
val keyMaterial = key.getKeyMaterialFor(Version.V1, Purpose.LOCAL)
92-
93-
// Split token into sections
78+
val k = key.getKeyMaterialFor(Version.V1, Purpose.LOCAL)
79+
val h = "v1.local."
9480
val sections = split(token)
81+
val f = decodeFooter(token, sections, footer) // TODO review
9582

9683
// Check header
97-
checkHeader(token, sections, HEADER_LOCAL)
98-
99-
// Decode footer
100-
val decodedFooter = decodeFooter(token, sections, footer)
84+
if (!token.startsWith(h)) {
85+
throw InvalidHeaderException(sections.version + SEPARATOR + sections.purpose + SEPARATOR, h, token)
86+
}
10187

10288
// Decrypt
10389
val nct = Base64.UrlSafeNoPadding.decodeOrNull(sections.payload)
10490
?: throw PasetoParseException(PasetoParseException.Reason.INVALID_BASE64, token)
105-
val n = ByteArray(32)
106-
val t = ByteArray(SHA384_OUT_LEN)
10791
// verify length
108-
if (nct.size < n.size + t.size + 1) {
92+
if (nct.size < 32 + SHA384_OUT_LEN + 1) {
10993
throw PasetoParseException(PasetoParseException.Reason.PAYLOAD_LENGTH, token).apply {
110-
minLength = n.size + t.size + 1
94+
minLength = 32 + SHA384_OUT_LEN + 1
11195
}
11296
}
113-
val c = ByteArray(nct.size - n.size - t.size)
114-
System.arraycopy(nct, 0, n, 0, n.size)
115-
System.arraycopy(nct, n.size, c, 0, c.size)
116-
System.arraycopy(nct, n.size + c.size, t, 0, t.size)
97+
val n = nct.copyOfRange(0, 32)
98+
val t = nct.copyOfRange(nct.size - SHA384_OUT_LEN, nct.size)
99+
val c = nct.copyOfRange(32, nct.size - SHA384_OUT_LEN)
117100

118101
// Split N into salt/nonce
119-
val salt = ByteArray(HKDF_SALT_LEN)
120-
val nonce = ByteArray(HKDF_SALT_LEN)
121-
System.arraycopy(n, 0, salt, 0, salt.size)
122-
System.arraycopy(n, salt.size, nonce, 0, nonce.size)
102+
val salt = n.copyOfRange(0, HKDF_SALT_LEN)
103+
val nonce = n.copyOfRange(HKDF_SALT_LEN, HKDF_SALT_LEN * 2)
123104

124105
// Create ek/ak for AEAD
125-
val ek = hkdfExtractAndExpand(salt, keyMaterial, HKDF_INFO_EK)
106+
val ek = hkdfExtractAndExpand(salt, k, HKDF_INFO_EK)
126107
cleanup.add { ek.fill(0) }
127-
val ak = hkdfExtractAndExpand(salt, keyMaterial, HKDF_INFO_AK)
108+
val ak = hkdfExtractAndExpand(salt, k, HKDF_INFO_AK)
128109
cleanup.add { ak.fill(0) }
129110

130-
val preAuth = pae(
131-
HEADER_LOCAL.toByteArray(Charsets.UTF_8),
132-
n,
133-
c,
134-
decodedFooter.toByteArray(Charsets.UTF_8),
135-
)
111+
val preAuth = pae(h.toByteArray(Charsets.UTF_8), n, c, f.toByteArray(Charsets.UTF_8))
136112
val t2 = hmacSha384(preAuth, ak)
137113
if (!t.constantTimeEquals(t2)) {
138114
throw SignatureVerificationException(token)
139115
}
140116

141117
val m = aes256CtrDecrypt(c, ek, nonce)
142118

143-
return Pair(m.toString(Charsets.UTF_8), decodedFooter)
119+
return Pair(m.toString(Charsets.UTF_8), f)
144120
} finally {
145121
key.clear()
146122
cleanup.forEach { it.run() }
@@ -154,24 +130,19 @@ internal object PasetoV1 : Paseto {
154130
implicitAssertion: String,
155131
): String {
156132
try {
157-
// Verify key version.
158-
val keyMaterial = secretKey.getKeyMaterialFor(Version.V1, Purpose.PUBLIC)
159-
160-
val footerBytes = footer.toByteArray(Charsets.UTF_8)
161-
162-
val m2 = pae(HEADER_PUBLIC.toByteArray(Charsets.UTF_8), m, footerBytes)
163-
val sig = rsaSign(m2, keyMaterial)
164-
165-
val msig = ByteArray(sig.size + m.size)
166-
System.arraycopy(m, 0, msig, 0, m.size)
167-
System.arraycopy(sig, 0, msig, m.size, sig.size)
168-
169-
return if (footerBytes.isNotEmpty()) {
170-
HEADER_PUBLIC + Base64.UrlSafeNoPadding.encode(msig) + SEPARATOR +
171-
Base64.UrlSafeNoPadding.encode(footerBytes)
172-
} else {
173-
HEADER_PUBLIC + Base64.UrlSafeNoPadding.encode(msig)
174-
}
133+
val k = secretKey.getKeyMaterialFor(Version.V1, Purpose.PUBLIC)
134+
val h = "v1.public."
135+
val f = footer.toByteArray(Charsets.UTF_8)
136+
137+
val m2 = pae(h.toByteArray(Charsets.UTF_8), m, f)
138+
val sig = rsaSign(m2, k)
139+
140+
return h + Base64.UrlSafeNoPadding.encode(m + sig) +
141+
if (f.isEmpty()) {
142+
""
143+
} else {
144+
".${Base64.UrlSafeNoPadding.encode(f)}"
145+
}
175146
} finally {
176147
secretKey.clear()
177148
}
@@ -183,37 +154,33 @@ internal object PasetoV1 : Paseto {
183154
footer: String,
184155
implicitAssertion: String,
185156
): Pair<String, String> {
186-
// Verify key version.
187-
val keyMaterial = publicKey.getKeyMaterialFor(Version.V1, Purpose.PUBLIC)
188-
189-
// Split token into sections
157+
val pk = publicKey.getKeyMaterialFor(Version.V1, Purpose.PUBLIC)
158+
val h = "v1.public."
190159
val sections = split(token)
160+
val f = decodeFooter(token, sections, footer) // TODO review
191161

192162
// Check header
193-
checkHeader(token, sections, HEADER_PUBLIC)
194-
195-
// Decode footer
196-
val decodedFooter = decodeFooter(token, sections, footer)
163+
if (!token.startsWith(h)) {
164+
throw InvalidHeaderException(sections.version + SEPARATOR + sections.purpose + SEPARATOR, h, token)
165+
}
197166

198167
// Verify
199-
val msig = Base64.UrlSafeNoPadding.decodeOrNull(sections.payload)
168+
val sm = Base64.UrlSafeNoPadding.decodeOrNull(sections.payload)
200169
?: throw PasetoParseException(PasetoParseException.Reason.INVALID_BASE64, token)
201-
val s = ByteArray(RSA_SIGNATURE_LEN)
202170
// verify length
203-
if (msig.size < s.size + 1) {
171+
if (sm.size < RSA_SIGNATURE_LEN + 1) {
204172
throw PasetoParseException(PasetoParseException.Reason.PAYLOAD_LENGTH, token).apply {
205-
minLength = s.size + 1
173+
minLength = RSA_SIGNATURE_LEN + 1
206174
}
207175
}
208-
val m = ByteArray(msig.size - s.size)
209-
System.arraycopy(msig, msig.size - s.size, s, 0, s.size)
210-
System.arraycopy(msig, 0, m, 0, m.size)
176+
val s = sm.copyOfRange(sm.size - RSA_SIGNATURE_LEN, sm.size)
177+
val m = sm.copyOfRange(0, sm.size - RSA_SIGNATURE_LEN)
211178

212-
val m2 = pae(HEADER_PUBLIC.toByteArray(Charsets.UTF_8), m, decodedFooter.toByteArray(Charsets.UTF_8))
213-
if (!rsaVerify(m2, s, keyMaterial)) {
179+
val m2 = pae(h.toByteArray(Charsets.UTF_8), m, f.toByteArray(Charsets.UTF_8))
180+
if (!rsaVerify(m2, s, pk)) {
214181
throw SignatureVerificationException(token)
215182
}
216183

217-
return Pair(m.toString(Charsets.UTF_8), decodedFooter)
184+
return Pair(m.toString(Charsets.UTF_8), f)
218185
}
219186
}

0 commit comments

Comments
 (0)