Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
13 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,17 @@ class TestKeyStore {
}

@Test
fun testDecryptMnemonicAes256() {
fun testDecryptMnemonicAes192Ctr() {
val keyStore = StoredKey("Test Wallet", "password".toByteArray(), StoredKeyEncryption.AES192CTR)
val result = keyStore.decryptMnemonic("wrong".toByteArray())
val result2 = keyStore.decryptMnemonic("password".toByteArray())

assertNull(result)
assertNotNull(result2)
}

@Test
fun testDecryptMnemonicAes256Ctr() {
val keyStore = StoredKey("Test Wallet", "password".toByteArray(), StoredKeyEncryption.AES256CTR)
val result = keyStore.decryptMnemonic("wrong".toByteArray())
val result2 = keyStore.decryptMnemonic("password".toByteArray())
Expand Down
1 change: 0 additions & 1 deletion include/TrustWalletCore/TWStoredKeyEncryption.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ TW_EXTERN_C_BEGIN
TW_EXPORT_ENUM(uint32_t)
enum TWStoredKeyEncryption {
TWStoredKeyEncryptionAes128Ctr = 0,
TWStoredKeyEncryptionAes128Cbc = 1,
TWStoredKeyEncryptionAes192Ctr = 2,
TWStoredKeyEncryptionAes256Ctr = 3,
};
Expand Down
15 changes: 15 additions & 0 deletions src/Data.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,19 @@ Data subData(const Data& data, size_t startIndex) {
return TW::data(data.data() + startIndex, subLength);
}

bool isEqualConstantTime(const Data& in_a, const Data& in_b) {
if (in_a.size() != in_b.size()) {
return false;
}

const volatile unsigned char *a = in_a.data();
const volatile unsigned char *b = in_b.data();
unsigned char result = 0;

for (size_t i = 0; i < in_a.size(); i++) {
result |= a[i] ^ b[i];
}
return result == 0;
}

} // namespace TW
5 changes: 5 additions & 0 deletions src/Data.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,11 @@ inline bool has_prefix(const Data& data, T& prefix) {
return std::equal(prefix.begin(), prefix.end(), data.begin(), data.begin() + std::min(data.size(), prefix.size()));
}

/// Constant-time comparison to prevent timing attacks.
/// Note: This function assumes that `a` and `b` are of the same size. If they are not, it will return false immediately.
/// https://github.com/openssl/openssl/blob/94c36852d254a626739667874587b5364ddf087e/crypto/cpuid.c#L198
bool isEqualConstantTime(const Data& in_a, const Data& in_b);

// Custom hash function for `Data` type.
struct DataHash {
std::size_t operator()(const Data& data) const {
Expand Down
6 changes: 3 additions & 3 deletions src/HDWallet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,9 +174,9 @@ PrivateKey HDWallet<seedSize>::getKeyByCurve(TWCurve curve, const DerivationPath
auto node = getNode<seedSize>(*this, curve, derivationPath);
switch (privateKeyType) {
case TWPrivateKeyTypeCardano: {
if (derivationPath.indices.size() < 4 || derivationPath.indices[3].value > 1) {
// invalid derivation path
return PrivateKey(Data(PrivateKey::cardanoKeySize), curve);
if (derivationPath.indices.size() < 5 || derivationPath.indices[3].value > 1) {
TW::memzero(&node);
throw std::invalid_argument("Invalid derivation path");
}
const DerivationPath stakingPath = cardanoStakingDerivationPath(derivationPath);

Expand Down
30 changes: 26 additions & 4 deletions src/Keystore/AESParameters.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include "../HexCoding.h"

#include <sstream>
#include <TrezorCrypto/rand.h>

using namespace TW;
Expand All @@ -21,17 +22,21 @@ Data generateIv(std::size_t blockSize = TW::Keystore::gBlockSize) {
static TWStoredKeyEncryption getCipher(const std::string& cipher) {
if (cipher == Keystore::gAes128Ctr) {
return TWStoredKeyEncryption::TWStoredKeyEncryptionAes128Ctr;
} else if (cipher == Keystore::gAes192Ctr) {
}
if (cipher == Keystore::gAes192Ctr) {
return TWStoredKeyEncryption::TWStoredKeyEncryptionAes192Ctr;
} else if (cipher == Keystore::gAes256Ctr) {
}
if (cipher == Keystore::gAes256Ctr) {
return TWStoredKeyEncryption::TWStoredKeyEncryptionAes256Ctr;
}
return TWStoredKeyEncryptionAes128Ctr;

std::stringstream ss;
ss << "Unsupported cipher: " << cipher;
throw std::invalid_argument(ss.str());
}

const std::unordered_map<TWStoredKeyEncryption, Keystore::AESParameters> gEncryptionRegistry{
{TWStoredKeyEncryptionAes128Ctr, Keystore::AESParameters{.mKeyLength = Keystore::A128, .mCipher = Keystore::gAes128Ctr, .mCipherEncryption = TWStoredKeyEncryptionAes128Ctr, .iv{}}},
{TWStoredKeyEncryptionAes128Cbc, Keystore::AESParameters{.mKeyLength = Keystore::A128, .mCipher = Keystore::gAes128Cbc, .mCipherEncryption = TWStoredKeyEncryptionAes128Cbc, .iv{}}},
{TWStoredKeyEncryptionAes192Ctr, Keystore::AESParameters{.mKeyLength = Keystore::A192, .mCipher = Keystore::gAes192Ctr, .mCipherEncryption = TWStoredKeyEncryptionAes192Ctr, .iv{}}},
{TWStoredKeyEncryptionAes256Ctr, Keystore::AESParameters{.mKeyLength = Keystore::A256, .mCipher = Keystore::gAes256Ctr, .mCipherEncryption = TWStoredKeyEncryptionAes256Ctr, .iv{}}}
};
Expand All @@ -43,6 +48,15 @@ namespace CodingKeys {
static const auto iv = "iv";
} // namespace CodingKeys

std::string toString(AESValidationError error) {
switch (error) {
case AESValidationError::InvalidIV:
return "IV must be 16 bytes long";
default:
return "Unknown error";
}
}

/// Initializes `AESParameters` with a JSON object.
AESParameters AESParameters::AESParametersFromJson(const nlohmann::json& json, const std::string& cipher) {
auto parameters = AESParameters::AESParametersFromEncryption(getCipher(cipher));
Expand All @@ -64,4 +78,12 @@ AESParameters AESParameters::AESParametersFromEncryption(TWStoredKeyEncryption e
return parameters;
}

std::optional<AESValidationError> AESParameters::validate() const noexcept {
if (iv.size() != static_cast<std::size_t>(mBlockSize)) {
return AESValidationError::InvalidIV;
}

return {};
}

} // namespace TW::Keystore
11 changes: 7 additions & 4 deletions src/Keystore/AESParameters.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,15 @@ enum AESKeySize : std::int32_t {

inline constexpr std::size_t gBlockSize{16};
inline constexpr const char* gAes128Ctr{"aes-128-ctr"};
inline constexpr const char* gAes128Cbc{"aes-128-cbc"};
inline constexpr const char* gAes192Ctr{"aes-192-ctr"};
inline constexpr const char* gAes256Ctr{"aes-256-ctr"};

enum class AESValidationError {
InvalidIV,
};

std::string toString(AESValidationError error);

// AES128/192/256 parameters.
struct AESParameters {
// For AES, your block length is always going to be 128 bits/16 bytes
Expand All @@ -43,9 +48,7 @@ struct AESParameters {
nlohmann::json json() const;

/// Validates AES parameters.
[[nodiscard]] bool isValid() const {
return iv.size() == static_cast<std::size_t>(mBlockSize);
}
[[nodiscard]] std::optional<AESValidationError> validate() const noexcept;
};

} // namespace TW::Keystore
50 changes: 27 additions & 23 deletions src/Keystore/EncryptionParameters.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include "EncryptionParameters.h"

#include "memory/memzero_wrapper.h"
#include "../Hash.h"

#include <TrezorCrypto/aes.h>
Expand Down Expand Up @@ -40,8 +41,10 @@ static const auto mac = "mac";
EncryptionParameters::EncryptionParameters(const nlohmann::json& json) {
auto cipher = json[CodingKeys::cipher].get<std::string>();
cipherParams = AESParameters::AESParametersFromJson(json[CodingKeys::cipherParams], cipher);
if (!cipherParams.isValid()) {
throw std::invalid_argument("Invalid cipher params");
if (const auto error = cipherParams.validate(); error.has_value()) {
std::stringstream ss;
ss << "Invalid cipher params: " << toString(*error);
throw std::invalid_argument(ss.str());
}

auto kdf = json[CodingKeys::kdf].get<std::string>();
Expand All @@ -68,23 +71,22 @@ nlohmann::json EncryptionParameters::json() const {
return j;
}

EncryptedPayload::EncryptedPayload(const Data& password, const Data& data, const EncryptionParameters& params)
: params(std::move(params)), _mac() {
if (!this->params.cipherParams.isValid()) {
throw std::invalid_argument("Invalid cipher params");
EncryptedPayload::EncryptedPayload(const Data& password, const Data& data, const AESParameters& cipherParams, const ScryptParameters& scryptParams) {
if (const auto error = cipherParams.validate(); error.has_value()) {
std::stringstream ss;
ss << "Invalid cipher params: " << toString(*error);
throw std::invalid_argument(ss.str());
}

auto scryptParams = std::get<ScryptParameters>(this->params.kdfParams);
auto derivedKey = Data(scryptParams.desiredKeyLength);
scrypt(reinterpret_cast<const byte*>(password.data()), password.size(), scryptParams.salt.data(),
scryptParams.salt.size(), scryptParams.n, scryptParams.r, scryptParams.p, derivedKey.data(),
scryptParams.desiredKeyLength);

aes_encrypt_ctx ctx;
auto result = 0;
switch(this->params.cipherParams.mCipherEncryption) {
switch(cipherParams.mCipherEncryption) {
case TWStoredKeyEncryptionAes128Ctr:
case TWStoredKeyEncryptionAes128Cbc:
result = aes_encrypt_key128(derivedKey.data(), &ctx);
break;
case TWStoredKeyEncryptionAes192Ctr:
Expand All @@ -96,19 +98,23 @@ EncryptedPayload::EncryptedPayload(const Data& password, const Data& data, const
}
assert(result == EXIT_SUCCESS);
if (result == EXIT_SUCCESS) {
Data iv = this->params.cipherParams.iv;
Data iv = cipherParams.iv;
// iv size should have been validated in `AESParameters::isValid()`.
assert(iv.size() == gBlockSize);

params = { cipherParams, scryptParams };
encrypted = Data(data.size());
aes_ctr_encrypt(data.data(), encrypted.data(), static_cast<int>(data.size()), iv.data(), aes_ctr_cbuf_inc, &ctx);
_mac = computeMAC(derivedKey.end() - params.getKeyBytesSize(), derivedKey.end(), encrypted);
}

memzero(&ctx, sizeof(ctx));
memzero(derivedKey.data(), derivedKey.size());
}

EncryptedPayload::~EncryptedPayload() {
std::fill(encrypted.begin(), encrypted.end(), 0);
std::fill(_mac.begin(), _mac.end(), 0);
memzero(encrypted.data(), encrypted.size());
memzero(_mac.data(), _mac.size());
}

Data EncryptedPayload::decrypt(const Data& password) const {
Expand All @@ -131,36 +137,34 @@ Data EncryptedPayload::decrypt(const Data& password) const {
throw DecryptionError::unsupportedKDF;
}

if (mac != _mac) {
if (!isEqualConstantTime(mac, _mac)) {
memzero(derivedKey.data(), derivedKey.size());
throw DecryptionError::invalidPassword;
}

// Even though the cipher params should have been validated in `EncryptedPayload` constructor,
// double check them here.
if (!params.cipherParams.isValid()) {
if (params.cipherParams.validate().has_value()) {
throw DecryptionError::invalidCipher;
}
assert(params.cipherParams.iv.size() == gBlockSize);

Data decrypted(encrypted.size());
Data iv = params.cipherParams.iv;
const auto encryption = params.cipherParams.mCipherEncryption;
if (encryption == TWStoredKeyEncryptionAes128Ctr || encryption == TWStoredKeyEncryptionAes256Ctr) {
if (encryption == TWStoredKeyEncryptionAes128Ctr
|| encryption == TWStoredKeyEncryptionAes192Ctr
|| encryption == TWStoredKeyEncryptionAes256Ctr) {
aes_encrypt_ctx ctx;
[[maybe_unused]] auto result = aes_encrypt_key(derivedKey.data(), params.getKeyBytesSize(), &ctx);
assert(result != EXIT_FAILURE);

aes_ctr_decrypt(encrypted.data(), decrypted.data(), static_cast<int>(encrypted.size()), iv.data(),
aes_ctr_cbuf_inc, &ctx);
} else if (encryption == TWStoredKeyEncryptionAes128Cbc) {
aes_decrypt_ctx ctx;
[[maybe_unused]] auto result = aes_decrypt_key(derivedKey.data(), params.getKeyBytesSize(), &ctx);
assert(result != EXIT_FAILURE);

for (auto i = 0ul; i < encrypted.size(); i += params.getKeyBytesSize()) {
aes_cbc_decrypt(encrypted.data() + i, decrypted.data() + i, params.getKeyBytesSize(), iv.data(), &ctx);
}
memzero(&ctx, sizeof(ctx));
memzero(derivedKey.data(), derivedKey.size());
} else {
memzero(derivedKey.data(), derivedKey.size());
throw DecryptionError::unsupportedCipher;
}

Expand Down
19 changes: 3 additions & 16 deletions src/Keystore/EncryptionParameters.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,6 @@ namespace TW::Keystore {

/// Set of parameters used when encoding
struct EncryptionParameters {
static EncryptionParameters getPreset(enum TWStoredKeyEncryptionLevel preset, enum TWStoredKeyEncryption encryption = TWStoredKeyEncryptionAes128Ctr) {
switch (preset) {
case TWStoredKeyEncryptionLevelMinimal:
return { AESParameters::AESParametersFromEncryption(encryption), ScryptParameters::minimal() };
case TWStoredKeyEncryptionLevelWeak:
case TWStoredKeyEncryptionLevelDefault:
default:
return { AESParameters::AESParametersFromEncryption(encryption), ScryptParameters::weak() };
case TWStoredKeyEncryptionLevelStandard:
return { AESParameters::AESParametersFromEncryption(encryption), ScryptParameters::standard() };
}
}

std::int32_t getKeyBytesSize() const noexcept {
return cipherParams.mKeyLength;
}
Expand Down Expand Up @@ -96,9 +83,9 @@ struct EncryptedPayload {
, encrypted(std::move(encrypted))
, _mac(std::move(mac)) {}

/// Initializes by encrypting data with a password
/// using standard values.
EncryptedPayload(const Data& password, const Data& data, const EncryptionParameters& params);
/// Initializes by encrypting data with a password using standard values.
/// Note that we enforce to use Scrypt as KDF for new wallets encryption.
EncryptedPayload(const Data& password, const Data& data, const AESParameters& cipherParams, const ScryptParameters& scryptParams);

/// Initializes with a JSON object.
explicit EncryptedPayload(const nlohmann::json& json);
Expand Down
13 changes: 13 additions & 0 deletions src/Keystore/ScryptParameters.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,19 @@ std::string toString(const ScryptValidationError error) {
}
}

ScryptParameters ScryptParameters::getPreset(TWStoredKeyEncryptionLevel preset) {
switch (preset) {
case TWStoredKeyEncryptionLevelMinimal:
return minimal();
case TWStoredKeyEncryptionLevelStandard:
return standard();
case TWStoredKeyEncryptionLevelWeak:
case TWStoredKeyEncryptionLevelDefault:
default:
return weak();
}
}

ScryptParameters ScryptParameters::minimal() {
return { internal::randomSalt(), minimalN, defaultR, minimalP, defaultDesiredKeyLength };
}
Expand Down
4 changes: 4 additions & 0 deletions src/Keystore/ScryptParameters.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#pragma once

#include "Data.h"
#include "TrustWalletCore/TWStoredKeyEncryptionLevel.h"
#include "../HexCoding.h"

#include <nlohmann/json.hpp>
Expand Down Expand Up @@ -63,6 +64,9 @@ struct ScryptParameters {
/// Block size factor.
uint32_t r = defaultR;

/// Returns a preset of Scrypt encryption parameters for the given encryption level.
static ScryptParameters getPreset(TWStoredKeyEncryptionLevel preset);

/// Generates Scrypt encryption parameters with the minimal sufficient level (4096), and with a random salt.
static ScryptParameters minimal();
/// Generates Scrypt encryption parameters with the weak sufficient level (16k), and with a random salt.
Expand Down
Loading
Loading