diff --git a/doc/release-notes-6017.md b/doc/release-notes-6017.md new file mode 100644 index 0000000000..03d05a42aa --- /dev/null +++ b/doc/release-notes-6017.md @@ -0,0 +1,3 @@ +RPC changes +----------- +- A new `sethdseed` RPC allows users to initialize their blank HD wallets with an HD seed. **A new backup must be made when a new HD seed is set.** This command cannot replace an existing HD seed if one is already set. `sethdseed` uses WIF private key as a seed. If you have a mnemonic, use the `upgradetohd` RPC. diff --git a/src/checkqueue.h b/src/checkqueue.h index 68357c2177..e94103f416 100644 --- a/src/checkqueue.h +++ b/src/checkqueue.h @@ -65,7 +65,7 @@ class CCheckQueue bool m_request_stop GUARDED_BY(m_mutex){false}; /** Internal function that does bulk of the verification work. */ - bool Loop(bool fMaster) + bool Loop(bool fMaster) EXCLUSIVE_LOCKS_REQUIRED(!m_mutex) { std::condition_variable& cond = fMaster ? m_master_cv : m_worker_cv; std::vector vChecks; @@ -139,7 +139,7 @@ class CCheckQueue } //! Create a pool of new worker threads. - void StartWorkerThreads(const int threads_num) + void StartWorkerThreads(const int threads_num) EXCLUSIVE_LOCKS_REQUIRED(!m_mutex) { { LOCK(m_mutex); @@ -157,13 +157,13 @@ class CCheckQueue } //! Wait until execution finishes, and return whether all evaluations were successful. - bool Wait() + bool Wait() EXCLUSIVE_LOCKS_REQUIRED(!m_mutex) { return Loop(true /* master thread */); } //! Add a batch of checks to the queue - void Add(std::vector& vChecks) + void Add(std::vector& vChecks) EXCLUSIVE_LOCKS_REQUIRED(!m_mutex) { if (vChecks.empty()) { return; @@ -186,7 +186,7 @@ class CCheckQueue } //! Stop all of the worker threads. - void StopWorkerThreads() + void StopWorkerThreads() EXCLUSIVE_LOCKS_REQUIRED(!m_mutex) { WITH_LOCK(m_mutex, m_request_stop = true); m_worker_cv.notify_all(); diff --git a/src/coinjoin/coinjoin.h b/src/coinjoin/coinjoin.h index bb42917d32..06dd934a83 100644 --- a/src/coinjoin/coinjoin.h +++ b/src/coinjoin/coinjoin.h @@ -368,15 +368,22 @@ class CDSTXManager void AddDSTX(const CCoinJoinBroadcastTx& dstx) EXCLUSIVE_LOCKS_REQUIRED(!cs_mapdstx); CCoinJoinBroadcastTx GetDSTX(const uint256& hash) EXCLUSIVE_LOCKS_REQUIRED(!cs_mapdstx); - void UpdatedBlockTip(const CBlockIndex* pindex, const llmq::CChainLocksHandler& clhandler, const CMasternodeSync& mn_sync); - void NotifyChainLock(const CBlockIndex* pindex, const llmq::CChainLocksHandler& clhandler, const CMasternodeSync& mn_sync); + void UpdatedBlockTip(const CBlockIndex* pindex, const llmq::CChainLocksHandler& clhandler, + const CMasternodeSync& mn_sync) + EXCLUSIVE_LOCKS_REQUIRED(!cs_mapdstx); + void NotifyChainLock(const CBlockIndex* pindex, const llmq::CChainLocksHandler& clhandler, + const CMasternodeSync& mn_sync) + EXCLUSIVE_LOCKS_REQUIRED(!cs_mapdstx); void TransactionAddedToMempool(const CTransactionRef& tx) EXCLUSIVE_LOCKS_REQUIRED(!cs_mapdstx); - void BlockConnected(const std::shared_ptr& pblock, const CBlockIndex* pindex) EXCLUSIVE_LOCKS_REQUIRED(!cs_mapdstx); - void BlockDisconnected(const std::shared_ptr& pblock, const CBlockIndex*) EXCLUSIVE_LOCKS_REQUIRED(!cs_mapdstx); + void BlockConnected(const std::shared_ptr& pblock, const CBlockIndex* pindex) + EXCLUSIVE_LOCKS_REQUIRED(!cs_mapdstx); + void BlockDisconnected(const std::shared_ptr& pblock, const CBlockIndex*) + EXCLUSIVE_LOCKS_REQUIRED(!cs_mapdstx); private: - void CheckDSTXes(const CBlockIndex* pindex, const llmq::CChainLocksHandler& clhandler); + void CheckDSTXes(const CBlockIndex* pindex, const llmq::CChainLocksHandler& clhandler) + EXCLUSIVE_LOCKS_REQUIRED(!cs_mapdstx); void UpdateDSTXConfirmedHeight(const CTransactionRef& tx, std::optional nHeight); }; diff --git a/src/governance/governance.cpp b/src/governance/governance.cpp index 023d0603e9..bb1f654cc0 100644 --- a/src/governance/governance.cpp +++ b/src/governance/governance.cpp @@ -1223,11 +1223,11 @@ void CGovernanceManager::RequestGovernanceObject(CNode* pfrom, const uint256& nH int CGovernanceManager::RequestGovernanceObjectVotes(CNode& peer, CConnman& connman) const { - std::array nodeCopy{&peer}; - return RequestGovernanceObjectVotes(nodeCopy, connman); + const std::vector vNodeCopy{&peer}; + return RequestGovernanceObjectVotes(vNodeCopy, connman); } -int CGovernanceManager::RequestGovernanceObjectVotes(Span vNodesCopy, CConnman& connman) const +int CGovernanceManager::RequestGovernanceObjectVotes(const std::vector& vNodesCopy, CConnman& connman) const { static std::map > mapAskedRecently; @@ -1501,7 +1501,7 @@ void CGovernanceManager::UpdatedBlockTip(const CBlockIndex* pindex, CConnman& co void CGovernanceManager::RequestOrphanObjects(CConnman& connman) { - std::vector vNodesCopy = connman.CopyNodeVector(CConnman::FullyConnectedOnly); + const CConnman::NodesSnapshot snap{connman, /* filter = */ CConnman::FullyConnectedOnly}; std::vector vecHashesFiltered; { @@ -1517,15 +1517,13 @@ void CGovernanceManager::RequestOrphanObjects(CConnman& connman) LogPrint(BCLog::GOBJECT, "CGovernanceObject::RequestOrphanObjects -- number objects = %d\n", vecHashesFiltered.size()); for (const uint256& nHash : vecHashesFiltered) { - for (CNode* pnode : vNodesCopy) { + for (CNode* pnode : snap.Nodes()) { if (!pnode->CanRelay()) { continue; } RequestGovernanceObject(pnode, nHash, connman); } } - - connman.ReleaseNodeVector(vNodesCopy); } void CGovernanceManager::CleanOrphanObjects() diff --git a/src/governance/governance.h b/src/governance/governance.h index 0747e5dafa..03e8d01b56 100644 --- a/src/governance/governance.h +++ b/src/governance/governance.h @@ -358,7 +358,7 @@ class CGovernanceManager : public GovernanceStore void InitOnLoad(); int RequestGovernanceObjectVotes(CNode& peer, CConnman& connman) const; - int RequestGovernanceObjectVotes(Span vNodesCopy, CConnman& connman) const; + int RequestGovernanceObjectVotes(const std::vector& vNodesCopy, CConnman& connman) const; /* * Trigger Management (formerly CGovernanceTriggerManager) diff --git a/src/httpserver.cpp b/src/httpserver.cpp index 2572df0cf3..a749a02699 100644 --- a/src/httpserver.cpp +++ b/src/httpserver.cpp @@ -83,7 +83,7 @@ class WorkQueue { } /** Enqueue a work item */ - bool Enqueue(WorkItem* item) + bool Enqueue(WorkItem* item) EXCLUSIVE_LOCKS_REQUIRED(!cs) { LOCK(cs); if (!running || queue.size() >= maxDepth) { @@ -94,7 +94,7 @@ class WorkQueue return true; } /** Thread function */ - void Run() + void Run() EXCLUSIVE_LOCKS_REQUIRED(!cs) { while (true) { std::unique_ptr i; @@ -111,7 +111,7 @@ class WorkQueue } } /** Interrupt and exit loops */ - void Interrupt() + void Interrupt() EXCLUSIVE_LOCKS_REQUIRED(!cs) { LOCK(cs); running = false; diff --git a/src/i2p.h b/src/i2p.h index cb2efedba8..3899dbf81f 100644 --- a/src/i2p.h +++ b/src/i2p.h @@ -84,7 +84,7 @@ class Session * to the listening socket and address. * @return true on success */ - bool Listen(Connection& conn); + bool Listen(Connection& conn) EXCLUSIVE_LOCKS_REQUIRED(!m_mutex); /** * Wait for and accept a new incoming connection. @@ -103,7 +103,7 @@ class Session * it is set to `false`. Only set if `false` is returned. * @return true on success */ - bool Connect(const CService& to, Connection& conn, bool& proxy_error); + bool Connect(const CService& to, Connection& conn, bool& proxy_error) EXCLUSIVE_LOCKS_REQUIRED(!m_mutex); private: /** @@ -172,7 +172,7 @@ class Session /** * Check the control socket for errors and possibly disconnect. */ - void CheckControlSock(); + void CheckControlSock() EXCLUSIVE_LOCKS_REQUIRED(!m_mutex); /** * Generate a new destination with the SAM proxy and set `m_private_key` to it. diff --git a/src/index/blockfilterindex.h b/src/index/blockfilterindex.h index 96c4936f8d..760a712349 100644 --- a/src/index/blockfilterindex.h +++ b/src/index/blockfilterindex.h @@ -63,7 +63,7 @@ class BlockFilterIndex final : public BaseIndex bool LookupFilter(const CBlockIndex* block_index, BlockFilter& filter_out) const; /** Get a single filter header by block. */ - bool LookupFilterHeader(const CBlockIndex* block_index, uint256& header_out); + bool LookupFilterHeader(const CBlockIndex* block_index, uint256& header_out) EXCLUSIVE_LOCKS_REQUIRED(!m_cs_headers_cache); /** Get a range of filters between two heights on a chain. */ bool LookupFilterRange(int start_height, const CBlockIndex* stop_index, diff --git a/src/llmq/dkgsession.cpp b/src/llmq/dkgsession.cpp index bcbf91fe7e..167be04380 100644 --- a/src/llmq/dkgsession.cpp +++ b/src/llmq/dkgsession.cpp @@ -1336,7 +1336,7 @@ void CDKGSession::RelayInvToParticipants(const CInv& inv) const if (pnode->GetVerifiedProRegTxHash().IsNull()) { logger.Batch("node[%d:%s] not mn", pnode->GetId(), - pnode->GetAddrName()); + pnode->m_addr_name); } else if (relayMembers.count(pnode->GetVerifiedProRegTxHash()) == 0) { ss2 << pnode->GetVerifiedProRegTxHash().ToString().substr(0, 4) << " | "; } diff --git a/src/llmq/signing_shares.cpp b/src/llmq/signing_shares.cpp index 548602e622..3667e593bc 100644 --- a/src/llmq/signing_shares.cpp +++ b/src/llmq/signing_shares.cpp @@ -1092,14 +1092,13 @@ bool CSigSharesManager::SendMessages() return session->sendSessionId; }; - std::vector vNodesCopy = connman.CopyNodeVector(CConnman::FullyConnectedOnly); - + const CConnman::NodesSnapshot snap{connman, /* filter = */ CConnman::FullyConnectedOnly}; { LOCK(cs); CollectSigSharesToRequest(sigSharesToRequest); CollectSigSharesToSend(sigShareBatchesToSend); CollectSigSharesToAnnounce(sigSharesToAnnounce); - CollectSigSharesToSendConcentrated(sigSharesToSend, vNodesCopy); + CollectSigSharesToSendConcentrated(sigSharesToSend, snap.Nodes()); for (auto& [nodeId, sigShareMap] : sigSharesToRequest) { for (auto& [hash, sigShareInv] : sigShareMap) { @@ -1120,7 +1119,7 @@ bool CSigSharesManager::SendMessages() bool didSend = false; - for (auto& pnode : vNodesCopy) { + for (auto& pnode : snap.Nodes()) { CNetMsgMaker msgMaker(pnode->GetCommonVersion()); if (const auto it1 = sigSessionAnnouncements.find(pnode->GetId()); it1 != sigSessionAnnouncements.end()) { @@ -1222,9 +1221,6 @@ bool CSigSharesManager::SendMessages() } } - // looped through all nodes, release them - connman.ReleaseNodeVector(vNodesCopy); - return didSend; } diff --git a/src/masternode/sync.cpp b/src/masternode/sync.cpp index f690693b45..f067fcb800 100644 --- a/src/masternode/sync.cpp +++ b/src/masternode/sync.cpp @@ -140,12 +140,11 @@ void CMasternodeSync::ProcessTick() } nTimeLastProcess = GetTime(); - std::vector vNodesCopy = connman.CopyNodeVector(CConnman::FullyConnectedOnly); + const CConnman::NodesSnapshot snap{connman, /* filter = */ CConnman::FullyConnectedOnly}; // gradually request the rest of the votes after sync finished if(IsSynced()) { - m_govman.RequestGovernanceObjectVotes(vNodesCopy, connman); - connman.ReleaseNodeVector(vNodesCopy); + m_govman.RequestGovernanceObjectVotes(snap.Nodes(), connman); return; } @@ -154,7 +153,7 @@ void CMasternodeSync::ProcessTick() LogPrint(BCLog::MNSYNC, "CMasternodeSync::ProcessTick -- nTick %d nCurrentAsset %d nTriedPeerCount %d nSyncProgress %f\n", nTick, nCurrentAsset, nTriedPeerCount, nSyncProgress); uiInterface.NotifyAdditionalDataSyncProgressChanged(nSyncProgress); - for (auto& pnode : vNodesCopy) + for (auto& pnode : snap.Nodes()) { CNetMsgMaker msgMaker(pnode->GetCommonVersion()); @@ -189,7 +188,7 @@ void CMasternodeSync::ProcessTick() } if (nCurrentAsset == MASTERNODE_SYNC_BLOCKCHAIN) { - int64_t nTimeSyncTimeout = vNodesCopy.size() > 3 ? MASTERNODE_SYNC_TICK_SECONDS : MASTERNODE_SYNC_TIMEOUT_SECONDS; + int64_t nTimeSyncTimeout = snap.Nodes().size() > 3 ? MASTERNODE_SYNC_TICK_SECONDS : MASTERNODE_SYNC_TIMEOUT_SECONDS; if (fReachedBestHeader && (GetTime() - nTimeLastBumped > nTimeSyncTimeout)) { // At this point we know that: // a) there are peers (because we are looping on at least one of them); @@ -205,7 +204,7 @@ void CMasternodeSync::ProcessTick() if (gArgs.GetBoolArg("-syncmempool", DEFAULT_SYNC_MEMPOOL)) { // Now that the blockchain is synced request the mempool from the connected outbound nodes if possible - for (auto pNodeTmp : vNodesCopy) { + for (auto pNodeTmp : snap.Nodes()) { bool fRequestedEarlier = m_netfulfilledman.HasFulfilledRequest(pNodeTmp->addr, "mempool-sync"); if (pNodeTmp->nVersion >= 70216 && !pNodeTmp->IsInboundConn() && !fRequestedEarlier && !pNodeTmp->IsBlockRelayOnly()) { m_netfulfilledman.AddFulfilledRequest(pNodeTmp->addr, "mempool-sync"); @@ -222,7 +221,6 @@ void CMasternodeSync::ProcessTick() if(nCurrentAsset == MASTERNODE_SYNC_GOVERNANCE) { if (!m_govman.IsValid()) { SwitchToNextAsset(); - connman.ReleaseNodeVector(vNodesCopy); return; } LogPrint(BCLog::GOBJECT, "CMasternodeSync::ProcessTick -- nTick %d nCurrentAsset %d nTimeLastBumped %lld GetTime() %lld diff %lld\n", nTick, nCurrentAsset, nTimeLastBumped, GetTime(), GetTime() - nTimeLastBumped); @@ -235,7 +233,6 @@ void CMasternodeSync::ProcessTick() // it's kind of ok to skip this for now, hopefully we'll catch up later? } SwitchToNextAsset(); - connman.ReleaseNodeVector(vNodesCopy); return; } @@ -259,12 +256,11 @@ void CMasternodeSync::ProcessTick() if (nCurrentAsset != MASTERNODE_SYNC_GOVERNANCE) { // looped through all nodes and not syncing governance yet/already, release them - connman.ReleaseNodeVector(vNodesCopy); return; } // request votes on per-obj basis from each node - for (const auto& pnode : vNodesCopy) { + for (const auto& pnode : snap.Nodes()) { if(!m_netfulfilledman.HasFulfilledRequest(pnode->addr, "governance-sync")) { continue; // to early for this node } @@ -291,16 +287,12 @@ void CMasternodeSync::ProcessTick() // reset nTimeNoObjectsLeft to be able to use the same condition on resync nTimeNoObjectsLeft = 0; SwitchToNextAsset(); - connman.ReleaseNodeVector(vNodesCopy); return; } nLastTick = nTick; nLastVotes = m_govman.GetVoteCount(); } } - - // looped through all nodes, release them - connman.ReleaseNodeVector(vNodesCopy); } void CMasternodeSync::SendGovernanceSyncRequest(CNode* pnode) const diff --git a/src/net.cpp b/src/net.cpp index e4680255e2..90957d5f53 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -331,8 +331,8 @@ bool IsLocal(const CService& addr) CNode* CConnman::FindNode(const CNetAddr& ip, bool fExcludeDisconnecting) { - LOCK(cs_vNodes); - for (CNode* pnode : vNodes) { + LOCK(m_nodes_mutex); + for (CNode* pnode : m_nodes) { if (fExcludeDisconnecting && pnode->fDisconnect) { continue; } @@ -345,8 +345,8 @@ CNode* CConnman::FindNode(const CNetAddr& ip, bool fExcludeDisconnecting) CNode* CConnman::FindNode(const CSubNet& subNet, bool fExcludeDisconnecting) { - LOCK(cs_vNodes); - for (CNode* pnode : vNodes) { + LOCK(m_nodes_mutex); + for (CNode* pnode : m_nodes) { if (fExcludeDisconnecting && pnode->fDisconnect) { continue; } @@ -359,12 +359,12 @@ CNode* CConnman::FindNode(const CSubNet& subNet, bool fExcludeDisconnecting) CNode* CConnman::FindNode(const std::string& addrName, bool fExcludeDisconnecting) { - LOCK(cs_vNodes); - for (CNode* pnode : vNodes) { + LOCK(m_nodes_mutex); + for (CNode* pnode : m_nodes) { if (fExcludeDisconnecting && pnode->fDisconnect) { continue; } - if (pnode->GetAddrName() == addrName) { + if (pnode->m_addr_name == addrName) { return pnode; } } @@ -373,8 +373,8 @@ CNode* CConnman::FindNode(const std::string& addrName, bool fExcludeDisconnectin CNode* CConnman::FindNode(const CService& addr, bool fExcludeDisconnecting) { - LOCK(cs_vNodes); - for (CNode* pnode : vNodes) { + LOCK(m_nodes_mutex); + for (CNode* pnode : m_nodes) { if (fExcludeDisconnecting && pnode->fDisconnect) { continue; } @@ -392,8 +392,8 @@ bool CConnman::AlreadyConnectedToAddress(const CAddress& addr) bool CConnman::CheckIncomingNonce(uint64_t nonce) { - LOCK(cs_vNodes); - for (const CNode* pnode : vNodes) { + LOCK(m_nodes_mutex); + for (const CNode* pnode : m_nodes) { if (!pnode->fSuccessfullyConnected && !pnode->IsInboundConn() && pnode->GetLocalNonce() == nonce) return false; } @@ -457,14 +457,10 @@ CNode* CConnman::ConnectNode(CAddress addrConnect, const char *pszDest, bool fCo return nullptr; } // It is possible that we already have a connection to the IP/port pszDest resolved to. - // In that case, drop the connection that was just created, and return the existing CNode instead. - // Also store the name we used to connect in that CNode, so that future FindNode() calls to that - // name catch this early. - LOCK(cs_vNodes); + // In that case, drop the connection that was just created. + LOCK(m_nodes_mutex); CNode* pnode = FindNode(static_cast(addrConnect)); - if (pnode) - { - pnode->MaybeSetAddrName(std::string(pszDest)); + if (pnode) { LogPrintf("Failed to open new connection, already connected\n"); return nullptr; } @@ -530,7 +526,7 @@ CNode* CConnman::ConnectNode(CAddress addrConnect, const char *pszDest, bool fCo if (!addr_bind.IsValid()) { addr_bind = GetBindAddress(sock->Get()); } - CNode* pnode = new CNode(id, nLocalServices, sock->Release(), addrConnect, CalculateKeyedNetGroup(addrConnect), nonce, addr_bind, pszDest ? pszDest : "", conn_type); + CNode* pnode = new CNode(id, nLocalServices, sock->Release(), addrConnect, CalculateKeyedNetGroup(addrConnect), nonce, addr_bind, pszDest ? pszDest : "", conn_type, /* inbound_onion */ false); pnode->AddRef(); statsClient.inc("peers.connect", 1.0f); @@ -542,7 +538,7 @@ CNode* CConnman::ConnectNode(CAddress addrConnect, const char *pszDest, bool fCo void CNode::CloseSocketDisconnect(CConnman* connman) { - AssertLockHeld(connman->cs_vNodes); + AssertLockHeld(connman->m_nodes_mutex); fDisconnect = true; LOCK2(connman->cs_mapSocketToNode, cs_hSocket); @@ -608,25 +604,16 @@ std::string ConnectionTypeAsString(ConnectionType conn_type) assert(false); } -std::string CNode::GetAddrName() const { - LOCK(cs_addrName); - return addrName; -} - -void CNode::MaybeSetAddrName(const std::string& addrNameIn) { - LOCK(cs_addrName); - if (addrName.empty()) { - addrName = addrNameIn; - } -} - -CService CNode::GetAddrLocal() const { - LOCK(cs_addrLocal); +CService CNode::GetAddrLocal() const +{ + AssertLockNotHeld(m_addr_local_mutex); + LOCK(m_addr_local_mutex); return addrLocal; } void CNode::SetAddrLocal(const CService& addrLocalIn) { - LOCK(cs_addrLocal); + AssertLockNotHeld(m_addr_local_mutex); + LOCK(m_addr_local_mutex); if (addrLocal.IsValid()) { error("Addr local already set for node: %i. Refusing to change from %s to %s", id, addrLocal.ToString(), addrLocalIn.ToString()); } else { @@ -660,10 +647,10 @@ void CNode::copyStats(CNodeStats &stats, const std::vector &m_asmap) X(nLastBlockTime); X(nTimeConnected); X(nTimeOffset); - stats.addrName = GetAddrName(); + X(m_addr_name); X(nVersion); { - LOCK(cs_SubVer); + LOCK(m_subver_mutex); X(cleanSubVer); } stats.fInbound = IsInboundConn(); @@ -1085,9 +1072,9 @@ bool CConnman::AttemptToEvictConnection() { std::vector vEvictionCandidates; { - LOCK(cs_vNodes); + LOCK(m_nodes_mutex); - for (const CNode* node : vNodes) { + for (const CNode* node : m_nodes) { if (node->HasPermission(NetPermissionFlags::NoBan)) continue; if (!node->IsInboundConn()) @@ -1127,8 +1114,8 @@ bool CConnman::AttemptToEvictConnection() if (!node_id_to_evict) { return false; } - LOCK(cs_vNodes); - for (CNode* pnode : vNodes) { + LOCK(m_nodes_mutex); + for (CNode* pnode : m_nodes) { if (pnode->GetId() == *node_id_to_evict) { LogPrint(BCLog::NET_NETCONN, "selected %s connection for eviction peer=%d; disconnecting\n", pnode->ConnectionTypeAsString(), pnode->GetId()); pnode->fDisconnect = true; @@ -1185,8 +1172,8 @@ void CConnman::CreateNodeFromAcceptedSocket(SOCKET hSocket, } { - LOCK(cs_vNodes); - for (const CNode* pnode : vNodes) { + LOCK(m_nodes_mutex); + for (const CNode* pnode : m_nodes) { if (pnode->IsInboundConn()) { nInbound++; if (!pnode->GetVerifiedProRegTxHash().IsNull()) { @@ -1286,8 +1273,8 @@ void CConnman::CreateNodeFromAcceptedSocket(SOCKET hSocket, } { - LOCK(cs_vNodes); - vNodes.push_back(pnode); + LOCK(m_nodes_mutex); + m_nodes.push_back(pnode); WITH_LOCK(cs_mapSocketToNode, mapSocketToNode.emplace(hSocket, pnode)); RegisterEvents(pnode); WakeSelect(); @@ -1304,8 +1291,8 @@ bool CConnman::AddConnection(const std::string& address, ConnectionType conn_typ const int max_connections = conn_type == ConnectionType::OUTBOUND_FULL_RELAY ? m_max_outbound_full_relay : m_max_outbound_block_relay; // Count existing connections - int existing_connections = WITH_LOCK(cs_vNodes, - return std::count_if(vNodes.begin(), vNodes.end(), [conn_type](CNode* node) { return node->m_conn_type == conn_type; });); + int existing_connections = WITH_LOCK(m_nodes_mutex, + return std::count_if(m_nodes.begin(), m_nodes.end(), [conn_type](CNode* node) { return node->m_conn_type == conn_type; });); // Max connections of specified type already exist if (existing_connections >= max_connections) return false; @@ -1321,11 +1308,11 @@ bool CConnman::AddConnection(const std::string& address, ConnectionType conn_typ void CConnman::DisconnectNodes() { { - LOCK(cs_vNodes); + LOCK(m_nodes_mutex); if (!fNetworkActive) { // Disconnect any connected nodes - for (CNode* pnode : vNodes) { + for (CNode* pnode : m_nodes) { if (!pnode->fDisconnect) { LogPrint(BCLog::NET, "Network not active, dropping peer=%d\n", pnode->GetId()); pnode->fDisconnect = true; @@ -1334,7 +1321,7 @@ void CConnman::DisconnectNodes() } // Disconnect unused nodes - for (auto it = vNodes.begin(); it != vNodes.end(); ) + for (auto it = m_nodes.begin(); it != m_nodes.end(); ) { CNode* pnode = *it; if (pnode->fDisconnect) @@ -1377,8 +1364,8 @@ void CConnman::DisconnectNodes() pnode->GetId(), pnode->GetRefCount(), pnode->IsInboundConn(), pnode->m_masternode_connection, pnode->m_masternode_iqr_connection); } - // remove from vNodes - it = vNodes.erase(it); + // remove from m_nodes + it = m_nodes.erase(it); // release outbound grant (if any) pnode->grantOutbound.Release(); @@ -1388,7 +1375,7 @@ void CConnman::DisconnectNodes() // hold in disconnected pool until all refs are released pnode->Release(); - vNodesDisconnected.push_back(pnode); + m_nodes_disconnected.push_back(pnode); } else { ++it; } @@ -1396,8 +1383,8 @@ void CConnman::DisconnectNodes() } { // Delete disconnected nodes - std::list vNodesDisconnectedCopy = vNodesDisconnected; - for (auto it = vNodesDisconnected.begin(); it != vNodesDisconnected.end(); ) + std::list nodes_disconnected_copy = m_nodes_disconnected; + for (auto it = m_nodes_disconnected.begin(); it != m_nodes_disconnected.end(); ) { CNode* pnode = *it; // wait until threads are done using it @@ -1410,7 +1397,7 @@ void CConnman::DisconnectNodes() } } if (fDelete) { - it = vNodesDisconnected.erase(it); + it = m_nodes_disconnected.erase(it); DeleteNode(pnode); } } @@ -1423,22 +1410,22 @@ void CConnman::DisconnectNodes() void CConnman::NotifyNumConnectionsChanged(CMasternodeSync& mn_sync) { - size_t vNodesSize; + size_t nodes_size; { - LOCK(cs_vNodes); - vNodesSize = vNodes.size(); + LOCK(m_nodes_mutex); + nodes_size = m_nodes.size(); } // If we had zero connections before and new connections now or if we just dropped // to zero connections reset the sync process if its outdated. - if ((vNodesSize > 0 && nPrevNodeCount == 0) || (vNodesSize == 0 && nPrevNodeCount > 0)) { + if ((nodes_size > 0 && nPrevNodeCount == 0) || (nodes_size == 0 && nPrevNodeCount > 0)) { mn_sync.Reset(); } - if(vNodesSize != nPrevNodeCount) { - nPrevNodeCount = vNodesSize; + if(nodes_size != nPrevNodeCount) { + nPrevNodeCount = nodes_size; if(clientInterface) - clientInterface->NotifyNumConnectionsChanged(vNodesSize); + clientInterface->NotifyNumConnectionsChanged(nodes_size); CalculateNumConnectionsChangedStats(); } @@ -1466,8 +1453,8 @@ void CConnman::CalculateNumConnectionsChangedStats() } mapRecvBytesMsgStats[NET_MESSAGE_COMMAND_OTHER] = 0; mapSentBytesMsgStats[NET_MESSAGE_COMMAND_OTHER] = 0; - auto vNodesCopy = CopyNodeVector(CConnman::FullyConnectedOnly); - for (auto pnode : vNodesCopy) { + const NodesSnapshot snap{*this, /* filter = */ CConnman::FullyConnectedOnly}; + for (auto pnode : snap.Nodes()) { { LOCK(pnode->cs_vRecv); for (const mapMsgCmdSize::value_type &i : pnode->mapRecvBytesPerMsgCmd) @@ -1496,7 +1483,6 @@ void CConnman::CalculateNumConnectionsChangedStats() if (last_ping_time > 0) statsClient.timing("peers.ping_us", last_ping_time, 1.0f); } - ReleaseNodeVector(vNodesCopy); for (const std::string &msg : getAllNetMessageTypes()) { statsClient.gauge("bandwidth.message." + msg + ".totalBytesReceived", mapRecvBytesMsgStats[msg], 1.0f); statsClient.gauge("bandwidth.message." + msg + ".totalBytesSent", mapSentBytesMsgStats[msg], 1.0f); @@ -1549,30 +1535,30 @@ bool CConnman::InactivityCheck(const CNode& node) const return false; } -bool CConnman::GenerateSelectSet(std::set &recv_set, std::set &send_set, std::set &error_set) +bool CConnman::GenerateSelectSet(const std::vector& nodes, + std::set& recv_set, + std::set& send_set, + std::set& error_set) { for (const ListenSocket& hListenSocket : vhListenSocket) { recv_set.insert(hListenSocket.socket); } + for (CNode* pnode : nodes) { - LOCK(cs_vNodes); - for (CNode* pnode : vNodes) - { - bool select_recv = !pnode->fHasRecvData; - bool select_send = !pnode->fCanSendData; + bool select_recv = !pnode->fHasRecvData; + bool select_send = !pnode->fCanSendData; - LOCK(pnode->cs_hSocket); - if (pnode->hSocket == INVALID_SOCKET) - continue; + LOCK(pnode->cs_hSocket); + if (pnode->hSocket == INVALID_SOCKET) + continue; - error_set.insert(pnode->hSocket); - if (select_send) { - send_set.insert(pnode->hSocket); - } - if (select_recv) { - recv_set.insert(pnode->hSocket); - } + error_set.insert(pnode->hSocket); + if (select_send) { + send_set.insert(pnode->hSocket); + } + if (select_recv) { + recv_set.insert(pnode->hSocket); } } @@ -1589,14 +1575,17 @@ bool CConnman::GenerateSelectSet(std::set &recv_set, std::set &s } #ifdef USE_KQUEUE -void CConnman::SocketEventsKqueue(std::set &recv_set, std::set &send_set, std::set &error_set, bool fOnlyPoll) +void CConnman::SocketEventsKqueue(std::set& recv_set, + std::set& send_set, + std::set& error_set, + bool only_poll) { const size_t maxEvents = 64; struct kevent events[maxEvents]; struct timespec timeout; - timeout.tv_sec = fOnlyPoll ? 0 : SELECT_TIMEOUT_MILLISECONDS / 1000; - timeout.tv_nsec = (fOnlyPoll ? 0 : SELECT_TIMEOUT_MILLISECONDS % 1000) * 1000 * 1000; + timeout.tv_sec = only_poll ? 0 : SELECT_TIMEOUT_MILLISECONDS / 1000; + timeout.tv_nsec = (only_poll ? 0 : SELECT_TIMEOUT_MILLISECONDS % 1000) * 1000 * 1000; wakeupSelectNeeded = true; int n = kevent(kqueuefd, nullptr, 0, events, maxEvents, &timeout); @@ -1624,13 +1613,16 @@ void CConnman::SocketEventsKqueue(std::set &recv_set, std::set & #endif #ifdef USE_EPOLL -void CConnman::SocketEventsEpoll(std::set &recv_set, std::set &send_set, std::set &error_set, bool fOnlyPoll) +void CConnman::SocketEventsEpoll(std::set& recv_set, + std::set& send_set, + std::set& error_set, + bool only_poll) { const size_t maxEvents = 64; epoll_event events[maxEvents]; wakeupSelectNeeded = true; - int n = epoll_wait(epollfd, events, maxEvents, fOnlyPoll ? 0 : SELECT_TIMEOUT_MILLISECONDS); + int n = epoll_wait(epollfd, events, maxEvents, only_poll ? 0 : SELECT_TIMEOUT_MILLISECONDS); wakeupSelectNeeded = false; for (int i = 0; i < n; i++) { auto& e = events[i]; @@ -1651,11 +1643,15 @@ void CConnman::SocketEventsEpoll(std::set &recv_set, std::set &s #endif #ifdef USE_POLL -void CConnman::SocketEventsPoll(std::set &recv_set, std::set &send_set, std::set &error_set, bool fOnlyPoll) +void CConnman::SocketEventsPoll(const std::vector& nodes, + std::set& recv_set, + std::set& send_set, + std::set& error_set, + bool only_poll) { std::set recv_select_set, send_select_set, error_select_set; - if (!GenerateSelectSet(recv_select_set, send_select_set, error_select_set)) { - if (!fOnlyPoll) interruptNet.sleep_for(std::chrono::milliseconds(SELECT_TIMEOUT_MILLISECONDS)); + if (!GenerateSelectSet(nodes, recv_select_set, send_select_set, error_select_set)) { + if (!only_poll) interruptNet.sleep_for(std::chrono::milliseconds(SELECT_TIMEOUT_MILLISECONDS)); return; } @@ -1683,7 +1679,7 @@ void CConnman::SocketEventsPoll(std::set &recv_set, std::set &se } wakeupSelectNeeded = true; - int r = poll(vpollfds.data(), vpollfds.size(), fOnlyPoll ? 0 : SELECT_TIMEOUT_MILLISECONDS); + int r = poll(vpollfds.data(), vpollfds.size(), only_poll ? 0 : SELECT_TIMEOUT_MILLISECONDS); wakeupSelectNeeded = false; if (r < 0) { return; @@ -1699,10 +1695,14 @@ void CConnman::SocketEventsPoll(std::set &recv_set, std::set &se } #endif -void CConnman::SocketEventsSelect(std::set &recv_set, std::set &send_set, std::set &error_set, bool fOnlyPoll) +void CConnman::SocketEventsSelect(const std::vector& nodes, + std::set& recv_set, + std::set& send_set, + std::set& error_set, + bool only_poll) { std::set recv_select_set, send_select_set, error_select_set; - if (!GenerateSelectSet(recv_select_set, send_select_set, error_select_set)) { + if (!GenerateSelectSet(nodes, recv_select_set, send_select_set, error_select_set)) { interruptNet.sleep_for(std::chrono::milliseconds(SELECT_TIMEOUT_MILLISECONDS)); return; } @@ -1712,7 +1712,7 @@ void CConnman::SocketEventsSelect(std::set &recv_set, std::set & // struct timeval timeout; timeout.tv_sec = 0; - timeout.tv_usec = fOnlyPoll ? 0 : SELECT_TIMEOUT_MILLISECONDS * 1000; // frequency to poll pnode->vSend + timeout.tv_usec = only_poll ? 0 : SELECT_TIMEOUT_MILLISECONDS * 1000; // frequency to poll pnode->vSend fd_set fdsetRecv; fd_set fdsetSend; @@ -1774,26 +1774,30 @@ void CConnman::SocketEventsSelect(std::set &recv_set, std::set & } } -void CConnman::SocketEvents(std::set &recv_set, std::set &send_set, std::set &error_set, bool fOnlyPoll) +void CConnman::SocketEvents(const std::vector& nodes, + std::set& recv_set, + std::set& send_set, + std::set& error_set, + bool only_poll) { switch (socketEventsMode) { #ifdef USE_KQUEUE case SOCKETEVENTS_KQUEUE: - SocketEventsKqueue(recv_set, send_set, error_set, fOnlyPoll); + SocketEventsKqueue(recv_set, send_set, error_set, only_poll); break; #endif #ifdef USE_EPOLL case SOCKETEVENTS_EPOLL: - SocketEventsEpoll(recv_set, send_set, error_set, fOnlyPoll); + SocketEventsEpoll(recv_set, send_set, error_set, only_poll); break; #endif #ifdef USE_POLL case SOCKETEVENTS_POLL: - SocketEventsPoll(recv_set, send_set, error_set, fOnlyPoll); + SocketEventsPoll(nodes, recv_set, send_set, error_set, only_poll); break; #endif case SOCKETEVENTS_SELECT: - SocketEventsSelect(recv_set, send_set, error_set, fOnlyPoll); + SocketEventsSelect(nodes, recv_set, send_set, error_set, only_poll); break; default: assert(false); @@ -1802,15 +1806,22 @@ void CConnman::SocketEvents(std::set &recv_set, std::set &send_s void CConnman::SocketHandler(CMasternodeSync& mn_sync) { - bool fOnlyPoll = [this]() { - // check if we have work to do and thus should avoid waiting for events - LOCK2(cs_vNodes, cs_sendable_receivable_nodes); + AssertLockNotHeld(m_total_bytes_sent_mutex); + + std::set recv_set; + std::set send_set; + std::set error_set; + + bool only_poll = [this]() { + // Check if we have work to do and thus should avoid waiting for events + LOCK2(m_nodes_mutex, cs_sendable_receivable_nodes); if (!mapReceivableNodes.empty()) { return true; } else if (!mapSendableNodes.empty()) { if (LOCK(cs_mapNodesWithDataToSend); !mapNodesWithDataToSend.empty()) { - // we must check if at least one of the nodes with pending messages is also sendable, as otherwise a single - // node would be able to make the network thread busy with polling + // We must check if at least one of the nodes with pending messages is also + // sendable, as otherwise a single node would be able to make the network + // thread busy with polling. for (auto& p : mapNodesWithDataToSend) { if (mapSendableNodes.count(p.first)) { return true; @@ -1822,8 +1833,14 @@ void CConnman::SocketHandler(CMasternodeSync& mn_sync) return false; }(); - std::set recv_set, send_set, error_set; - SocketEvents(recv_set, send_set, error_set, fOnlyPoll); + { + const NodesSnapshot snap{*this, /* filter = */ CConnman::AllNodes, /* shuffle = */ false}; + + // Check for the readiness of the already connected sockets and the + // listening sockets in one call ("readiness" as in poll(2) or + // select(2)). If none are ready, wait for a short while and return + // empty sets. + SocketEvents(snap.Nodes(), recv_set, send_set, error_set, only_poll); #ifdef USE_WAKEUP_PIPE // drain the wakeup pipe @@ -1838,19 +1855,22 @@ void CConnman::SocketHandler(CMasternodeSync& mn_sync) } #endif - if (interruptNet) return; - - // - // Accept new connections - // - for (const ListenSocket& hListenSocket : vhListenSocket) - { - if (recv_set.count(hListenSocket.socket) > 0) - { - AcceptConnection(hListenSocket, mn_sync); - } + // Service (send/receive) each of the already connected nodes. + SocketHandlerConnected(recv_set, send_set, error_set); } + // Accept new connections from listening sockets. + SocketHandlerListening(recv_set, mn_sync); +} + +void CConnman::SocketHandlerConnected(const std::set& recv_set, + const std::set& send_set, + const std::set& error_set) +{ + AssertLockNotHeld(m_total_bytes_sent_mutex); + + if (interruptNet) return; + std::vector vErrorNodes; std::vector vReceivableNodes; std::vector vSendableNodes; @@ -1970,9 +1990,15 @@ void CConnman::SocketHandler(CMasternodeSync& mn_sync) if (bytes_sent) RecordBytesSent(bytes_sent); } - ReleaseNodeVector(vErrorNodes); - ReleaseNodeVector(vReceivableNodes); - ReleaseNodeVector(vSendableNodes); + for (auto& node : vErrorNodes) { + node->Release(); + } + for (auto& node : vReceivableNodes) { + node->Release(); + } + for (auto& node : vSendableNodes) { + node->Release(); + } if (interruptNet) { return; @@ -1993,6 +2019,18 @@ void CConnman::SocketHandler(CMasternodeSync& mn_sync) } } +void CConnman::SocketHandlerListening(const std::set& recv_set, CMasternodeSync& mn_sync) +{ + for (const ListenSocket& listen_socket : vhListenSocket) { + if (interruptNet) { + return; + } + if (recv_set.count(listen_socket.socket) > 0) { + AcceptConnection(listen_socket, mn_sync); + } + } +} + size_t CConnman::SocketRecvData(CNode *pnode) { // typical socket buffer is 8K-64K @@ -2011,7 +2049,7 @@ size_t CConnman::SocketRecvData(CNode *pnode) { bool notify = false; if (!pnode->ReceiveMsgBytes(Span(pchBuf, nBytes), notify)) { - LOCK(cs_vNodes); + LOCK(m_nodes_mutex); pnode->CloseSocketDisconnect(this); } RecordBytesRecv(nBytes); @@ -2038,7 +2076,7 @@ size_t CConnman::SocketRecvData(CNode *pnode) if (!pnode->fDisconnect) { LogPrint(BCLog::NET, "socket closed for peer=%d\n", pnode->GetId()); } - LOCK(cs_vNodes); + LOCK(m_nodes_mutex); pnode->fOtherSideDisconnected = true; // avoid lingering pnode->CloseSocketDisconnect(this); } @@ -2051,7 +2089,7 @@ size_t CConnman::SocketRecvData(CNode *pnode) if (!pnode->fDisconnect){ LogPrint(BCLog::NET, "socket recv error for peer=%d: %s\n", pnode->GetId(), NetworkErrorString(nErr)); } - LOCK(cs_vNodes); + LOCK(m_nodes_mutex); pnode->fOtherSideDisconnected = true; // avoid lingering pnode->CloseSocketDisconnect(this); } @@ -2064,6 +2102,8 @@ size_t CConnman::SocketRecvData(CNode *pnode) void CConnman::ThreadSocketHandler(CMasternodeSync& mn_sync) { + AssertLockNotHeld(m_total_bytes_sent_mutex); + int64_t nLastCleanupNodes = 0; while (!interruptNet) @@ -2156,8 +2196,8 @@ void CConnman::ThreadDNSAddressSeed() int nRelevant = 0; { - LOCK(cs_vNodes); - for (const CNode* pnode : vNodes) { + LOCK(m_nodes_mutex); + for (const CNode* pnode : m_nodes) { if (pnode->fSuccessfullyConnected && !pnode->IsFullOutboundConn() && !pnode->m_masternode_probe_connection) ++nRelevant; } } @@ -2266,8 +2306,8 @@ int CConnman::GetExtraFullOutboundCount() const { int full_outbound_peers = 0; { - LOCK(cs_vNodes); - for (const CNode* pnode : vNodes) { + LOCK(m_nodes_mutex); + for (const CNode* pnode : m_nodes) { // don't count outbound masternodes if (pnode->m_masternode_connection) { continue; @@ -2284,8 +2324,8 @@ int CConnman::GetExtraBlockRelayCount() const { int block_relay_peers = 0; { - LOCK(cs_vNodes); - for (const CNode* pnode : vNodes) { + LOCK(m_nodes_mutex); + for (const CNode* pnode : m_nodes) { if (pnode->fSuccessfullyConnected && !pnode->fDisconnect && pnode->IsBlockOnlyConn()) { ++block_relay_peers; } @@ -2356,8 +2396,8 @@ void CConnman::ThreadOpenConnections(const std::vector connect, CDe // Checking !dnsseed is cheaper before locking 2 mutexes. if (!add_fixed_seeds_now && !dnsseed) { - LOCK2(m_addr_fetches_mutex, cs_vAddedNodes); - if (m_addr_fetches.empty() && vAddedNodes.empty()) { + LOCK2(m_addr_fetches_mutex, m_added_nodes_mutex); + if (m_addr_fetches.empty() && m_added_nodes.empty()) { add_fixed_seeds_now = true; LogPrintf("Adding fixed seeds as -dnsseed=0, -addnode is not provided and all -seednode(s) attempted\n"); } @@ -2382,8 +2422,8 @@ void CConnman::ThreadOpenConnections(const std::vector connect, CDe int nOutboundBlockRelay = 0; std::set > setConnected; if (!Params().AllowMultipleAddressesFromGroup()) { - LOCK(cs_vNodes); - for (const CNode* pnode : vNodes) { + LOCK(m_nodes_mutex); + for (const CNode* pnode : m_nodes) { if (pnode->IsFullOutboundConn() && !pnode->m_masternode_connection) nOutboundFullRelay++; if (pnode->IsBlockOnlyConn()) nOutboundBlockRelay++; @@ -2407,8 +2447,8 @@ void CConnman::ThreadOpenConnections(const std::vector connect, CDe std::set setConnectedMasternodes; { - LOCK(cs_vNodes); - for (CNode* pnode : vNodes) { + LOCK(m_nodes_mutex); + for (CNode* pnode : m_nodes) { auto verifiedProRegTxHash = pnode->GetVerifiedProRegTxHash(); if (!verifiedProRegTxHash.IsNull()) { setConnectedMasternodes.emplace(verifiedProRegTxHash); @@ -2598,8 +2638,8 @@ void CConnman::ThreadOpenConnections(const std::vector connect, CDe std::vector CConnman::GetCurrentBlockRelayOnlyConns() const { std::vector ret; - LOCK(cs_vNodes); - for (const CNode* pnode : vNodes) { + LOCK(m_nodes_mutex); + for (const CNode* pnode : m_nodes) { if (pnode->IsBlockRelayOnly()) { ret.push_back(pnode->addr); } @@ -2614,9 +2654,9 @@ std::vector CConnman::GetAddedNodeInfo() const std::list lAddresses(0); { - LOCK(cs_vAddedNodes); - ret.reserve(vAddedNodes.size()); - std::copy(vAddedNodes.cbegin(), vAddedNodes.cend(), std::back_inserter(lAddresses)); + LOCK(m_added_nodes_mutex); + ret.reserve(m_added_nodes.size()); + std::copy(m_added_nodes.cbegin(), m_added_nodes.cend(), std::back_inserter(lAddresses)); } @@ -2624,12 +2664,12 @@ std::vector CConnman::GetAddedNodeInfo() const std::map mapConnected; std::map> mapConnectedByName; { - LOCK(cs_vNodes); - for (const CNode* pnode : vNodes) { + LOCK(m_nodes_mutex); + for (const CNode* pnode : m_nodes) { if (pnode->addr.IsValid()) { mapConnected[pnode->addr] = pnode->IsInboundConn(); } - std::string addrName = pnode->GetAddrName(); + std::string addrName{pnode->m_addr_name}; if (!addrName.empty()) { mapConnectedByName[std::move(addrName)] = std::make_pair(pnode->IsInboundConn(), static_cast(pnode->addr)); } @@ -2805,7 +2845,7 @@ void CConnman::ThreadOpenMasternodeConnections(CDeterministicMNManager& dmnman, auto getConnectToDmn = [&]() -> CDeterministicMNCPtr { // don't hold lock while calling OpenMasternodeConnection as cs_main is locked deep inside - LOCK2(cs_vNodes, cs_vPendingMasternodes); + LOCK2(m_nodes_mutex, cs_vPendingMasternodes); if (!vPendingMasternodes.empty()) { auto dmn = mnList.GetValidMN(vPendingMasternodes.front()); @@ -2938,15 +2978,14 @@ void CConnman::OpenNetworkConnection(const CAddress& addrConnect, bool fCountFai m_msgproc->InitializeNode(pnode); { - LOCK(cs_vNodes); - vNodes.push_back(pnode); + LOCK(m_nodes_mutex); + m_nodes.push_back(pnode); RegisterEvents(pnode); WakeSelect(); } } void CConnman::OpenMasternodeConnection(const CAddress &addrConnect, MasternodeProbeConn probe) { - OpenNetworkConnection(addrConnect, false, nullptr, nullptr, ConnectionType::OUTBOUND_FULL_RELAY, MasternodeConn::IsConnection, probe); } @@ -2957,8 +2996,6 @@ void CConnman::ThreadMessageHandler() FastRandomContext rng; while (!flagInterruptMsgProc) { - std::vector vNodesCopy = CopyNodeVector(); - bool fMoreWork = false; bool fSkipSendMessagesForMasternodes = true; @@ -2966,13 +3003,13 @@ void CConnman::ThreadMessageHandler() fSkipSendMessagesForMasternodes = false; nLastSendMessagesTimeMasternodes = GetTimeMillis(); } + // Randomize the order in which we process messages from/to our peers. // This prevents attacks in which an attacker exploits having multiple - // consecutive connections in the vNodes list. - Shuffle(vNodesCopy.begin(), vNodesCopy.end(), rng); + // consecutive connections in the m_nodes list. + const NodesSnapshot snap{*this, /* filter = */ CConnman::AllNodes, /* shuffle = */ true}; - for (CNode* pnode : vNodesCopy) - { + for (CNode* pnode : snap.Nodes()) { if (pnode->fDisconnect) continue; @@ -2991,8 +3028,6 @@ void CConnman::ThreadMessageHandler() return; } - ReleaseNodeVector(vNodesCopy); - WAIT_LOCK(mutexMsgProc, lock); if (!fMoreWork) { condMsgProc.wait_until(lock, std::chrono::steady_clock::now() + std::chrono::milliseconds(100), [this]() EXCLUSIVE_LOCKS_REQUIRED(mutexMsgProc) { return fMsgProcWake; }); @@ -3263,6 +3298,7 @@ bool CConnman::InitBinds( bool CConnman::Start(CDeterministicMNManager& dmnman, CMasternodeMetaMan& mn_metaman, CMasternodeSync& mn_sync, CScheduler& scheduler, const Options& connOptions) { + AssertLockNotHeld(m_total_bytes_sent_mutex); Init(connOptions); #ifdef USE_KQUEUE @@ -3522,10 +3558,10 @@ void CConnman::StopNodes() } { - LOCK(cs_vNodes); + LOCK(m_nodes_mutex); // Close sockets - for (CNode *pnode : vNodes) + for (CNode *pnode : m_nodes) pnode->CloseSocketDisconnect(this); } for (ListenSocket& hListenSocket : vhListenSocket) @@ -3548,11 +3584,11 @@ void CConnman::StopNodes() // clean up some globals (to help leak detection) std::vector nodes; - WITH_LOCK(cs_vNodes, nodes.swap(vNodes)); + WITH_LOCK(m_nodes_mutex, nodes.swap(m_nodes)); for (CNode* pnode : nodes) { DeleteNode(pnode); } - for (CNode* pnode : vNodesDisconnected) { + for (CNode* pnode : m_nodes_disconnected) { DeleteNode(pnode); } WITH_LOCK(cs_mapSocketToNode, mapSocketToNode.clear()); @@ -3564,7 +3600,7 @@ void CConnman::StopNodes() LOCK(cs_mapNodesWithDataToSend); mapNodesWithDataToSend.clear(); } - vNodesDisconnected.clear(); + m_nodes_disconnected.clear(); vhListenSocket.clear(); semOutbound.reset(); semAddnode.reset(); @@ -3667,21 +3703,21 @@ std::vector CConnman::GetAddresses(CNode& requestor, size_t max_addres bool CConnman::AddNode(const std::string& strNode) { - LOCK(cs_vAddedNodes); - for (const std::string& it : vAddedNodes) { + LOCK(m_added_nodes_mutex); + for (const std::string& it : m_added_nodes) { if (strNode == it) return false; } - vAddedNodes.push_back(strNode); + m_added_nodes.push_back(strNode); return true; } bool CConnman::RemoveAddedNode(const std::string& strNode) { - LOCK(cs_vAddedNodes); - for(std::vector::iterator it = vAddedNodes.begin(); it != vAddedNodes.end(); ++it) { + LOCK(m_added_nodes_mutex); + for(std::vector::iterator it = m_added_nodes.begin(); it != m_added_nodes.end(); ++it) { if (strNode == *it) { - vAddedNodes.erase(it); + m_added_nodes.erase(it); return true; } } @@ -3754,7 +3790,7 @@ std::set CConnman::GetMasternodeQuorums(Consensus::LLMQType llmqType) std::set CConnman::GetMasternodeQuorumNodes(Consensus::LLMQType llmqType, const uint256& quorumHash) const { - LOCK2(cs_vNodes, cs_vPendingMasternodes); + LOCK2(m_nodes_mutex, cs_vPendingMasternodes); auto it = masternodeQuorumNodes.find(std::make_pair(llmqType, quorumHash)); if (it == masternodeQuorumNodes.end()) { return {}; @@ -3762,7 +3798,7 @@ std::set CConnman::GetMasternodeQuorumNodes(Consensus::LLMQType llmqType const auto& proRegTxHashes = it->second; std::set nodes; - for (const auto pnode : vNodes) { + for (const auto pnode : m_nodes) { if (pnode->fDisconnect) { continue; } @@ -3833,10 +3869,10 @@ void CConnman::AddPendingProbeConnections(const std::set &proTxHashes) size_t CConnman::GetNodeCount(ConnectionDirection flags) const { - LOCK(cs_vNodes); + LOCK(m_nodes_mutex); int nNum = 0; - for (const auto& pnode : vNodes) { + for (const auto& pnode : m_nodes) { if (pnode->fDisconnect) { continue; } @@ -3861,9 +3897,9 @@ size_t CConnman::GetMaxOutboundNodeCount() void CConnman::GetNodeStats(std::vector& vstats) const { vstats.clear(); - LOCK(cs_vNodes); - vstats.reserve(vNodes.size()); - for (CNode* pnode : vNodes) { + LOCK(m_nodes_mutex); + vstats.reserve(m_nodes.size()); + for (CNode* pnode : m_nodes) { if (pnode->fDisconnect) { continue; } @@ -3874,7 +3910,7 @@ void CConnman::GetNodeStats(std::vector& vstats) const bool CConnman::DisconnectNode(const std::string& strNode) { - LOCK(cs_vNodes); + LOCK(m_nodes_mutex); if (CNode* pnode = FindNode(strNode)) { LogPrint(BCLog::NET_NETCONN, "disconnect by address%s matched peer=%d; disconnecting\n", (fLogIPs ? strprintf("=%s", strNode) : ""), pnode->GetId()); pnode->fDisconnect = true; @@ -3886,8 +3922,8 @@ bool CConnman::DisconnectNode(const std::string& strNode) bool CConnman::DisconnectNode(const CSubNet& subnet) { bool disconnected = false; - LOCK(cs_vNodes); - for (CNode* pnode : vNodes) { + LOCK(m_nodes_mutex); + for (CNode* pnode : m_nodes) { if (subnet.Match(pnode->addr)) { LogPrint(BCLog::NET_NETCONN, "disconnect by subnet%s matched peer=%d; disconnecting\n", (fLogIPs ? strprintf("=%s", subnet.ToString()) : ""), pnode->GetId()); pnode->fDisconnect = true; @@ -3904,8 +3940,8 @@ bool CConnman::DisconnectNode(const CNetAddr& addr) bool CConnman::DisconnectNode(NodeId id) { - LOCK(cs_vNodes); - for(CNode* pnode : vNodes) { + LOCK(m_nodes_mutex); + for(CNode* pnode : m_nodes) { if (id == pnode->GetId()) { LogPrint(BCLog::NET_NETCONN, "disconnect by id peer=%d; disconnecting\n", pnode->GetId()); pnode->fDisconnect = true; @@ -3917,7 +3953,6 @@ bool CConnman::DisconnectNode(NodeId id) void CConnman::RecordBytesRecv(uint64_t bytes) { - LOCK(cs_totalBytesRecv); nTotalBytesRecv += bytes; statsClient.count("bandwidth.bytesReceived", bytes, 0.1f); statsClient.gauge("bandwidth.totalBytesReceived", nTotalBytesRecv, 0.01f); @@ -3925,7 +3960,9 @@ void CConnman::RecordBytesRecv(uint64_t bytes) void CConnman::RecordBytesSent(uint64_t bytes) { - LOCK(cs_totalBytesSent); + AssertLockNotHeld(m_total_bytes_sent_mutex); + LOCK(m_total_bytes_sent_mutex); + nTotalBytesSent += bytes; statsClient.count("bandwidth.bytesSent", bytes, 0.01f); statsClient.gauge("bandwidth.totalBytesSent", nTotalBytesSent, 0.01f); @@ -3943,7 +3980,8 @@ void CConnman::RecordBytesSent(uint64_t bytes) uint64_t CConnman::GetMaxOutboundTarget() const { - LOCK(cs_totalBytesSent); + AssertLockNotHeld(m_total_bytes_sent_mutex); + LOCK(m_total_bytes_sent_mutex); return nMaxOutboundLimit; } @@ -3954,7 +3992,15 @@ std::chrono::seconds CConnman::GetMaxOutboundTimeframe() const std::chrono::seconds CConnman::GetMaxOutboundTimeLeftInCycle() const { - LOCK(cs_totalBytesSent); + AssertLockNotHeld(m_total_bytes_sent_mutex); + LOCK(m_total_bytes_sent_mutex); + return GetMaxOutboundTimeLeftInCycle_(); +} + +std::chrono::seconds CConnman::GetMaxOutboundTimeLeftInCycle_() const +{ + AssertLockHeld(m_total_bytes_sent_mutex); + if (nMaxOutboundLimit == 0) return 0s; @@ -3968,14 +4014,15 @@ std::chrono::seconds CConnman::GetMaxOutboundTimeLeftInCycle() const bool CConnman::OutboundTargetReached(bool historicalBlockServingLimit) const { - LOCK(cs_totalBytesSent); + AssertLockNotHeld(m_total_bytes_sent_mutex); + LOCK(m_total_bytes_sent_mutex); if (nMaxOutboundLimit == 0) return false; if (historicalBlockServingLimit) { // keep a large enough buffer to at least relay each block once - const std::chrono::seconds timeLeftInCycle = GetMaxOutboundTimeLeftInCycle(); + const std::chrono::seconds timeLeftInCycle = GetMaxOutboundTimeLeftInCycle_(); const uint64_t buffer = timeLeftInCycle / std::chrono::minutes{10} * MaxBlockSize(fDIP0001ActiveAtTip); if (buffer >= nMaxOutboundLimit || nMaxOutboundTotalBytesSentInCycle >= nMaxOutboundLimit - buffer) return true; @@ -3988,7 +4035,8 @@ bool CConnman::OutboundTargetReached(bool historicalBlockServingLimit) const uint64_t CConnman::GetOutboundTargetBytesLeft() const { - LOCK(cs_totalBytesSent); + AssertLockNotHeld(m_total_bytes_sent_mutex); + LOCK(m_total_bytes_sent_mutex); if (nMaxOutboundLimit == 0) return 0; @@ -3997,13 +4045,13 @@ uint64_t CConnman::GetOutboundTargetBytesLeft() const uint64_t CConnman::GetTotalBytesRecv() const { - LOCK(cs_totalBytesRecv); return nTotalBytesRecv; } uint64_t CConnman::GetTotalBytesSent() const { - LOCK(cs_totalBytesSent); + AssertLockNotHeld(m_total_bytes_sent_mutex); + LOCK(m_total_bytes_sent_mutex); return nTotalBytesSent; } @@ -4016,25 +4064,25 @@ unsigned int CConnman::GetReceiveFloodSize() const { return nReceiveFloodSize; } CNode::CNode(NodeId idIn, ServiceFlags nLocalServicesIn, SOCKET hSocketIn, const CAddress& addrIn, uint64_t nKeyedNetGroupIn, uint64_t nLocalHostNonceIn, const CAddress& addrBindIn, const std::string& addrNameIn, ConnectionType conn_type_in, bool inbound_onion) : nTimeConnected(GetTimeSeconds()), - addr(addrIn), - addrBind(addrBindIn), - nKeyedNetGroup(nKeyedNetGroupIn), - id(idIn), - nLocalHostNonce(nLocalHostNonceIn), - m_conn_type(conn_type_in), - nLocalServices(nLocalServicesIn), - m_inbound_onion(inbound_onion) + addr(addrIn), + addrBind(addrBindIn), + m_addr_name{addrNameIn.empty() ? addr.ToStringIPPort() : addrNameIn}, + m_inbound_onion(inbound_onion), + nKeyedNetGroup(nKeyedNetGroupIn), + id(idIn), + nLocalHostNonce(nLocalHostNonceIn), + m_conn_type(conn_type_in), + nLocalServices(nLocalServicesIn) { if (inbound_onion) assert(conn_type_in == ConnectionType::INBOUND); hSocket = hSocketIn; - addrName = addrNameIn == "" ? addr.ToStringIPPort() : addrNameIn; for (const std::string &msg : getAllNetMessageTypes()) mapRecvBytesPerMsgCmd[msg] = 0; mapRecvBytesPerMsgCmd[NET_MESSAGE_COMMAND_OTHER] = 0; if (fLogIPs) { - LogPrint(BCLog::NET, "Added connection to %s peer=%d\n", addrName, id); + LogPrint(BCLog::NET, "Added connection to %s peer=%d\n", m_addr_name, id); } else { LogPrint(BCLog::NET, "Added connection peer=%d\n", id); } @@ -4055,6 +4103,7 @@ bool CConnman::NodeFullyConnected(const CNode* pnode) void CConnman::PushMessage(CNode* pnode, CSerializedNetMsg&& msg) { + AssertLockNotHeld(m_total_bytes_sent_mutex); size_t nMessageSize = msg.data.size(); LogPrint(BCLog::NET, "sending %s (%d bytes) peer=%d\n", SanitizeString(msg.m_type), nMessageSize, pnode->GetId()); if (gArgs.GetBoolArg("-capturemessages", false)) { @@ -4084,9 +4133,9 @@ void CConnman::PushMessage(CNode* pnode, CSerializedNetMsg&& msg) { LOCK(cs_mapNodesWithDataToSend); - // we're not holding cs_vNodes here, so there is a chance of this node being disconnected shortly before + // we're not holding m_nodes_mutex here, so there is a chance of this node being disconnected shortly before // we get here. Whoever called PushMessage still has a ref to CNode*, but will later Release() it, so we - // might end up having an entry in mapNodesWithDataToSend that is not in vNodes anymore. We need to + // might end up having an entry in mapNodesWithDataToSend that is not in m_nodes anymore. We need to // Add/Release refs when adding/erasing mapNodesWithDataToSend. if (mapNodesWithDataToSend.emplace(pnode->GetId(), pnode).second) { pnode->AddRef(); @@ -4102,8 +4151,8 @@ void CConnman::PushMessage(CNode* pnode, CSerializedNetMsg&& msg) bool CConnman::ForNode(const CService& addr, std::function cond, std::function func) { CNode* found = nullptr; - LOCK(cs_vNodes); - for (auto&& pnode : vNodes) { + LOCK(m_nodes_mutex); + for (auto&& pnode : m_nodes) { if((CService)pnode->addr == addr) { found = pnode; break; @@ -4115,8 +4164,8 @@ bool CConnman::ForNode(const CService& addr, std::function cond, std::function func) { CNode* found = nullptr; - LOCK(cs_vNodes); - for (auto&& pnode : vNodes) { + LOCK(m_nodes_mutex); + for (auto&& pnode : m_nodes) { if(pnode->GetId() == id) { found = pnode; break; @@ -4148,26 +4197,28 @@ std::chrono::microseconds PoissonNextSend(std::chrono::microseconds now, std::ch return now + std::chrono::duration_cast(unscaled * average_interval + 0.5us); } -std::vector CConnman::CopyNodeVector(std::function cond) +CConnman::NodesSnapshot::NodesSnapshot(const CConnman& connman, std::function filter, + bool shuffle) { - std::vector vecNodesCopy; - LOCK(cs_vNodes); - vecNodesCopy.reserve(vNodes.size()); - for(size_t i = 0; i < vNodes.size(); ++i) { - CNode* pnode = vNodes[i]; - if (!cond(pnode)) + LOCK(connman.m_nodes_mutex); + m_nodes_copy.reserve(connman.m_nodes.size()); + + for (auto& node : connman.m_nodes) { + if (!filter(node)) continue; - pnode->AddRef(); - vecNodesCopy.push_back(pnode); + node->AddRef(); + m_nodes_copy.push_back(node); + } + + if (shuffle) { + Shuffle(m_nodes_copy.begin(), m_nodes_copy.end(), FastRandomContext{}); } - return vecNodesCopy; } -void CConnman::ReleaseNodeVector(const std::vector& vecNodes) +CConnman::NodesSnapshot::~NodesSnapshot() { - for(size_t i = 0; i < vecNodes.size(); ++i) { - CNode* pnode = vecNodes[i]; - pnode->Release(); + for (auto& node : m_nodes_copy) { + node->Release(); } } diff --git a/src/net.h b/src/net.h index 5905760325..a8b6505099 100644 --- a/src/net.h +++ b/src/net.h @@ -281,7 +281,7 @@ class CNodeStats int64_t nLastBlockTime; int64_t nTimeConnected; int64_t nTimeOffset; - std::string addrName; + std::string m_addr_name; int nVersion; std::string cleanSubVer; bool fInbound; @@ -471,14 +471,17 @@ class CNode const CAddress addr; // Bind address of our side of the connection const CAddress addrBind; + const std::string m_addr_name; + //! Whether this peer is an inbound onion, i.e. connected via our Tor onion service. + const bool m_inbound_onion; std::atomic nNumWarningsSkipped{0}; std::atomic nVersion{0}; + Mutex m_subver_mutex; /** * cleanSubVer is a sanitized string of the user agent byte array we read * from the wire. This cleaned string can safely be logged or displayed. */ - std::string cleanSubVer GUARDED_BY(cs_SubVer){}; - RecursiveMutex cs_SubVer; // used for both cleanSubVer and strSubVer + std::string cleanSubVer GUARDED_BY(m_subver_mutex){}; bool m_prefer_evict{false}; // This peer is preferred for eviction. bool HasPermission(NetPermissionFlags permission) const { return NetPermissions::HasFlag(m_permissionFlags, permission); @@ -621,7 +624,7 @@ class CNode bool IsBlockRelayOnly() const; - CNode(NodeId id, ServiceFlags nLocalServicesIn, SOCKET hSocketIn, const CAddress &addrIn, uint64_t nKeyedNetGroupIn, uint64_t nLocalHostNonceIn, const CAddress &addrBindIn, const std::string &addrNameIn, ConnectionType conn_type_in, bool inbound_onion = false); + CNode(NodeId id, ServiceFlags nLocalServicesIn, SOCKET hSocketIn, const CAddress &addrIn, uint64_t nKeyedNetGroupIn, uint64_t nLocalHostNonceIn, const CAddress &addrBindIn, const std::string &addrNameIn, ConnectionType conn_type_in, bool inbound_onion); ~CNode(); CNode(const CNode&) = delete; CNode& operator=(const CNode&) = delete; @@ -649,7 +652,7 @@ class CNode * @return True if the peer should stay connected, * False if the peer should be disconnected from. */ - bool ReceiveMsgBytes(Span msg_bytes, bool& complete); + bool ReceiveMsgBytes(Span msg_bytes, bool& complete) EXCLUSIVE_LOCKS_REQUIRED(!cs_vRecv); void SetCommonVersion(int greatest_common_version) { @@ -661,9 +664,9 @@ class CNode return m_greatest_common_version; } - CService GetAddrLocal() const; + CService GetAddrLocal() const EXCLUSIVE_LOCKS_REQUIRED(!m_addr_local_mutex); //! May not be called more than once - void SetAddrLocal(const CService& addrLocalIn); + void SetAddrLocal(const CService& addrLocalIn) EXCLUSIVE_LOCKS_REQUIRED(!m_addr_local_mutex); CNode* AddRef() { @@ -676,20 +679,15 @@ class CNode nRefCount--; } - void CloseSocketDisconnect(CConnman* connman); + void CloseSocketDisconnect(CConnman* connman) EXCLUSIVE_LOCKS_REQUIRED(!cs_hSocket); - void copyStats(CNodeStats &stats, const std::vector &m_asmap); + void copyStats(CNodeStats &stats, const std::vector &m_asmap) EXCLUSIVE_LOCKS_REQUIRED(!m_subver_mutex, !m_addr_local_mutex, !cs_vSend, !cs_vRecv); ServiceFlags GetLocalServices() const { return nLocalServices; } - std::string GetAddrName() const; - //! Sets the addrName only if it was not previously set - void MaybeSetAddrName(const std::string& addrNameIn); - - std::string ConnectionTypeAsString() const { return ::ConnectionTypeAsString(m_conn_type); } /** A ping-pong round trip has completed successfully. Update latest and minimum ping times. */ @@ -698,8 +696,6 @@ class CNode m_min_ping_time = std::min(m_min_ping_time.load(), ping_time); } - /** Whether this peer is an inbound onion, e.g. connected via our Tor onion service. */ - bool IsInboundOnion() const { return m_inbound_onion; } std::string GetLogString() const; bool CanRelay() const { return !m_masternode_connection || m_masternode_iqr_connection; } @@ -769,15 +765,9 @@ class CNode std::list vRecvMsg; // Used only by SocketHandler thread - mutable RecursiveMutex cs_addrName; - std::string addrName GUARDED_BY(cs_addrName); - // Our address, as reported by the peer - CService addrLocal GUARDED_BY(cs_addrLocal); - mutable RecursiveMutex cs_addrLocal; - - //! Whether this peer is an inbound onion, e.g. connected via our Tor onion service. - const bool m_inbound_onion{false}; + CService addrLocal GUARDED_BY(m_addr_local_mutex); + mutable Mutex m_addr_local_mutex; // Challenge sent in VERSION to be answered with MNAUTH (only happens between MNs) mutable Mutex cs_mnauth; @@ -867,7 +857,10 @@ friend class CNode; bool m_i2p_accept_incoming; }; - void Init(const Options& connOptions) { + void Init(const Options& connOptions) EXCLUSIVE_LOCKS_REQUIRED(!m_added_nodes_mutex, !m_total_bytes_sent_mutex) + { + AssertLockNotHeld(m_total_bytes_sent_mutex); + nLocalServices = connOptions.nLocalServices; nMaxConnections = connOptions.nMaxConnections; m_max_outbound_full_relay = std::min(connOptions.m_max_outbound_full_relay, connOptions.nMaxConnections); @@ -883,13 +876,13 @@ friend class CNode; nReceiveFloodSize = connOptions.nReceiveFloodSize; m_peer_connect_timeout = std::chrono::seconds{connOptions.m_peer_connect_timeout}; { - LOCK(cs_totalBytesSent); + LOCK(m_total_bytes_sent_mutex); nMaxOutboundLimit = connOptions.nMaxOutboundLimit; } vWhitelistedRange = connOptions.vWhitelistedRange; { - LOCK(cs_vAddedNodes); - vAddedNodes = connOptions.m_added_nodes; + LOCK(m_added_nodes_mutex); + m_added_nodes = connOptions.m_added_nodes; } socketEventsMode = connOptions.socketEventsMode; m_onion_binds = connOptions.onion_binds; @@ -898,7 +891,8 @@ friend class CNode; CConnman(uint64_t seed0, uint64_t seed1, CAddrMan& addrman, bool network_active = true); ~CConnman(); bool Start(CDeterministicMNManager& dmnman, CMasternodeMetaMan& mn_metaman, CMasternodeSync& mn_sync, - CScheduler& scheduler, const Options& options); + CScheduler& scheduler, const Options& options) + EXCLUSIVE_LOCKS_REQUIRED(!m_total_bytes_sent_mutex, !m_added_nodes_mutex, !m_addr_fetches_mutex, !mutexMsgProc); void StopThreads(); void StopNodes(); @@ -908,7 +902,7 @@ friend class CNode; StopNodes(); }; - void Interrupt(); + void Interrupt() EXCLUSIVE_LOCKS_REQUIRED(!mutexMsgProc); bool GetNetworkActive() const { return fNetworkActive; }; bool GetUseAddrmanOutgoing() const { return m_use_addrman_outgoing; }; void SetNetworkActive(bool active, CMasternodeSync* const mn_sync); @@ -924,8 +918,13 @@ friend class CNode; IsConnection, }; - void OpenNetworkConnection(const CAddress& addrConnect, bool fCountFailure, CSemaphoreGrant* grantOutbound, const char* strDest, ConnectionType conn_type, MasternodeConn masternode_connection = MasternodeConn::IsNotConnection, MasternodeProbeConn masternode_probe_connection = MasternodeProbeConn::IsNotConnection); - void OpenMasternodeConnection(const CAddress& addrConnect, MasternodeProbeConn probe = MasternodeProbeConn::IsConnection); + void OpenNetworkConnection(const CAddress& addrConnect, bool fCountFailure, CSemaphoreGrant* grantOutbound, + const char* strDest, ConnectionType conn_type, + MasternodeConn masternode_connection = MasternodeConn::IsNotConnection, + MasternodeProbeConn masternode_probe_connection = MasternodeProbeConn::IsNotConnection) + EXCLUSIVE_LOCKS_REQUIRED(!mutexMsgProc); + void OpenMasternodeConnection(const CAddress& addrConnect, MasternodeProbeConn probe = MasternodeProbeConn::IsConnection) + EXCLUSIVE_LOCKS_REQUIRED(!mutexMsgProc); bool CheckIncomingNonce(uint64_t nonce); struct CFullyConnectedOnly { @@ -966,13 +965,14 @@ friend class CNode; bool IsMasternodeOrDisconnectRequested(const CService& addr); - void PushMessage(CNode* pnode, CSerializedNetMsg&& msg); + void PushMessage(CNode* pnode, CSerializedNetMsg&& msg) + EXCLUSIVE_LOCKS_REQUIRED(!mutexMsgProc, !m_total_bytes_sent_mutex); template bool ForEachNodeContinueIf(const Condition& cond, Callable&& func) { - LOCK(cs_vNodes); - for (auto&& node : vNodes) + LOCK(m_nodes_mutex); + for (auto&& node : m_nodes) if (cond(node)) if(!func(node)) return false; @@ -988,8 +988,8 @@ friend class CNode; template bool ForEachNodeContinueIf(const Condition& cond, Callable&& func) const { - LOCK(cs_vNodes); - for (const auto& node : vNodes) + LOCK(m_nodes_mutex); + for (const auto& node : m_nodes) if (cond(node)) if(!func(node)) return false; @@ -1005,8 +1005,8 @@ friend class CNode; template void ForEachNode(const Condition& cond, Callable&& func) { - LOCK(cs_vNodes); - for (auto&& node : vNodes) { + LOCK(m_nodes_mutex); + for (auto&& node : m_nodes) { if (cond(node)) func(node); } @@ -1021,8 +1021,8 @@ friend class CNode; template void ForEachNode(const Condition& cond, Callable&& func) const { - LOCK(cs_vNodes); - for (auto&& node : vNodes) { + LOCK(m_nodes_mutex); + for (auto&& node : m_nodes) { if (cond(node)) func(node); } @@ -1037,8 +1037,8 @@ friend class CNode; template void ForEachNodeThen(const Condition& cond, Callable&& pre, CallableAfter&& post) { - LOCK(cs_vNodes); - for (auto&& node : vNodes) { + LOCK(m_nodes_mutex); + for (auto&& node : m_nodes) { if (cond(node)) pre(node); } @@ -1054,8 +1054,8 @@ friend class CNode; template void ForEachNodeThen(const Condition& cond, Callable&& pre, CallableAfter&& post) const { - LOCK(cs_vNodes); - for (auto&& node : vNodes) { + LOCK(m_nodes_mutex); + for (auto&& node : m_nodes) { if (cond(node)) pre(node); } @@ -1068,9 +1068,6 @@ friend class CNode; ForEachNodeThen(FullyConnectedOnly, pre, post); } - std::vector CopyNodeVector(std::function cond = AllNodes); - void ReleaseNodeVector(const std::vector& vecNodes); - // Addrman functions /** * Return all or many randomly selected addresses, optionally by network. @@ -1109,9 +1106,9 @@ friend class CNode; // Count the number of block-relay-only peers we have over our limit. int GetExtraBlockRelayCount() const; - bool AddNode(const std::string& node); - bool RemoveAddedNode(const std::string& node); - std::vector GetAddedNodeInfo() const; + bool AddNode(const std::string& node) EXCLUSIVE_LOCKS_REQUIRED(!m_added_nodes_mutex); + bool RemoveAddedNode(const std::string& node) EXCLUSIVE_LOCKS_REQUIRED(!m_added_nodes_mutex); + std::vector GetAddedNodeInfo() const EXCLUSIVE_LOCKS_REQUIRED(!m_added_nodes_mutex); /** * Attempts to open a connection. Currently only used from tests. @@ -1124,7 +1121,7 @@ friend class CNode; * - Max total outbound connection capacity filled * - Max connection capacity for type is filled */ - bool AddConnection(const std::string& address, ConnectionType conn_type); + bool AddConnection(const std::string& address, ConnectionType conn_type) EXCLUSIVE_LOCKS_REQUIRED(!mutexMsgProc); bool AddPendingMasternode(const uint256& proTxHash); void SetMasternodeQuorumNodes(Consensus::LLMQType llmqType, const uint256& quorumHash, const std::set& proTxHashes); @@ -1154,32 +1151,30 @@ friend class CNode; //! that peer during `net_processing.cpp:PushNodeVersion()`. ServiceFlags GetLocalServices() const; - uint64_t GetMaxOutboundTarget() const; + uint64_t GetMaxOutboundTarget() const EXCLUSIVE_LOCKS_REQUIRED(!m_total_bytes_sent_mutex); std::chrono::seconds GetMaxOutboundTimeframe() const; //! check if the outbound target is reached //! if param historicalBlockServingLimit is set true, the function will //! response true if the limit for serving historical blocks has been reached - bool OutboundTargetReached(bool historicalBlockServingLimit) const; + bool OutboundTargetReached(bool historicalBlockServingLimit) const EXCLUSIVE_LOCKS_REQUIRED(!m_total_bytes_sent_mutex); //! response the bytes left in the current max outbound cycle //! in case of no limit, it will always response 0 - uint64_t GetOutboundTargetBytesLeft() const; + uint64_t GetOutboundTargetBytesLeft() const EXCLUSIVE_LOCKS_REQUIRED(!m_total_bytes_sent_mutex); - //! returns the time left in the current max outbound cycle - //! in case of no limit, it will always return 0 - std::chrono::seconds GetMaxOutboundTimeLeftInCycle() const; + std::chrono::seconds GetMaxOutboundTimeLeftInCycle() const EXCLUSIVE_LOCKS_REQUIRED(!m_total_bytes_sent_mutex); uint64_t GetTotalBytesRecv() const; - uint64_t GetTotalBytesSent() const; + uint64_t GetTotalBytesSent() const EXCLUSIVE_LOCKS_REQUIRED(!m_total_bytes_sent_mutex); /** Get a unique deterministic randomizer. */ CSipHasher GetDeterministicRandomizer(uint64_t id) const; unsigned int GetReceiveFloodSize() const; - void WakeMessageHandler(); - void WakeSelect(); + void WakeMessageHandler() EXCLUSIVE_LOCKS_REQUIRED(!mutexMsgProc); + void WakeSelect() EXCLUSIVE_LOCKS_REQUIRED(!mutexMsgProc); /** Attempts to obfuscate tx time through exponentially distributed emitting. Works assuming that a single interval is used. @@ -1192,6 +1187,26 @@ friend class CNode; /** Return true if we should disconnect the peer for failing an inactivity check. */ bool ShouldRunInactivityChecks(const CNode& node, std::chrono::seconds now) const; + /** + * RAII helper to atomically create a copy of `m_nodes` and add a reference + * to each of the nodes. The nodes are released when this object is destroyed. + */ + class NodesSnapshot + { + public: + explicit NodesSnapshot(const CConnman& connman, std::function cond = AllNodes, + bool shuffle = false); + ~NodesSnapshot(); + + const std::vector& Nodes() const + { + return m_nodes_copy; + } + + private: + std::vector m_nodes_copy; + }; + private: struct ListenSocket { public: @@ -1202,6 +1217,10 @@ friend class CNode; NetPermissionFlags m_permissions; }; + //! returns the time left in the current max outbound cycle + //! in case of no limit, it will always return 0 + std::chrono::seconds GetMaxOutboundTimeLeftInCycle_() const EXCLUSIVE_LOCKS_REQUIRED(m_total_bytes_sent_mutex); + bool BindListenPort(const CService& bindAddr, bilingual_str& strError, NetPermissionFlags permissions); bool Bind(const CService& addr, unsigned int flags, NetPermissionFlags permissions); bool InitBinds( @@ -1209,17 +1228,19 @@ friend class CNode; const std::vector& whiteBinds, const std::vector& onion_binds); - void ThreadOpenAddedConnections(); - void AddAddrFetch(const std::string& strDest); - void ProcessAddrFetch(); - void ThreadOpenConnections(const std::vector connect, CDeterministicMNManager& dmnman); - void ThreadMessageHandler(); - void ThreadI2PAcceptIncoming(CMasternodeSync& mn_sync); - void AcceptConnection(const ListenSocket& hListenSocket, CMasternodeSync& mn_sync); + void ThreadOpenAddedConnections() EXCLUSIVE_LOCKS_REQUIRED(!m_added_nodes_mutex, !mutexMsgProc); + void AddAddrFetch(const std::string& strDest) EXCLUSIVE_LOCKS_REQUIRED(!m_addr_fetches_mutex); + void ProcessAddrFetch() EXCLUSIVE_LOCKS_REQUIRED(!m_addr_fetches_mutex, !mutexMsgProc); + void ThreadOpenConnections(const std::vector connect, CDeterministicMNManager& dmnman) + EXCLUSIVE_LOCKS_REQUIRED(!m_addr_fetches_mutex, !m_added_nodes_mutex, !m_nodes_mutex, !mutexMsgProc); + void ThreadMessageHandler() EXCLUSIVE_LOCKS_REQUIRED(!mutexMsgProc); + void ThreadI2PAcceptIncoming(CMasternodeSync& mn_sync) EXCLUSIVE_LOCKS_REQUIRED(!mutexMsgProc); + void AcceptConnection(const ListenSocket& hListenSocket, CMasternodeSync& mn_sync) + EXCLUSIVE_LOCKS_REQUIRED(!mutexMsgProc); /** * Create a `CNode` object from a socket that has just been accepted and add the node to - * the `vNodes` member. + * the `m_nodes` member. * @param[in] hSocket Connected socket to communicate with the peer. * @param[in] permissionFlags The peer's permissions. * @param[in] addr_bind The address and port at our side of the connection. @@ -1229,30 +1250,95 @@ friend class CNode; NetPermissionFlags permissionFlags, const CAddress& addr_bind, const CAddress& addr, - CMasternodeSync& mn_sync); + CMasternodeSync& mn_sync) EXCLUSIVE_LOCKS_REQUIRED(!mutexMsgProc); void DisconnectNodes(); void NotifyNumConnectionsChanged(CMasternodeSync& mn_sync); void CalculateNumConnectionsChangedStats(); /** Return true if the peer is inactive and should be disconnected. */ bool InactivityCheck(const CNode& node) const; - bool GenerateSelectSet(std::set &recv_set, std::set &send_set, std::set &error_set); + + /** + * Generate a collection of sockets to check for IO readiness. + * @param[in] nodes Select from these nodes' sockets. + * @param[out] recv_set Sockets to check for read readiness. + * @param[out] send_set Sockets to check for write readiness. + * @param[out] error_set Sockets to check for errors. + * @return true if at least one socket is to be checked (the returned set is not empty) + */ + bool GenerateSelectSet(const std::vector& nodes, + std::set& recv_set, + std::set& send_set, + std::set& error_set); + + /** + * Check which sockets are ready for IO. + * @param[in] nodes Select from these nodes' sockets (in supported event methods). + * @param[in] only_poll Permit zero timeout polling + * @param[out] recv_set Sockets which are ready for read. + * @param[out] send_set Sockets which are ready for write. + * @param[out] error_set Sockets which have errors. + * This calls `GenerateSelectSet()` to gather a list of sockets to check. + */ + void SocketEvents(const std::vector& nodes, + std::set& recv_set, + std::set& send_set, + std::set& error_set, + bool only_poll); + #ifdef USE_KQUEUE - void SocketEventsKqueue(std::set &recv_set, std::set &send_set, std::set &error_set, bool fOnlyPoll); + void SocketEventsKqueue(std::set& recv_set, + std::set& send_set, + std::set& error_set, + bool only_poll); #endif #ifdef USE_EPOLL - void SocketEventsEpoll(std::set &recv_set, std::set &send_set, std::set &error_set, bool fOnlyPoll); + void SocketEventsEpoll(std::set& recv_set, + std::set& send_set, + std::set& error_set, + bool only_poll); #endif #ifdef USE_POLL - void SocketEventsPoll(std::set &recv_set, std::set &send_set, std::set &error_set, bool fOnlyPoll); + void SocketEventsPoll(const std::vector& nodes, + std::set& recv_set, + std::set& send_set, + std::set& error_set, + bool only_poll); #endif - void SocketEventsSelect(std::set &recv_set, std::set &send_set, std::set &error_set, bool fOnlyPoll); - void SocketEvents(std::set &recv_set, std::set &send_set, std::set &error_set, bool fOnlyPoll); - void SocketHandler(CMasternodeSync& mn_sync); - void ThreadSocketHandler(CMasternodeSync& mn_sync); - void ThreadDNSAddressSeed(); + void SocketEventsSelect(const std::vector& nodes, + std::set& recv_set, + std::set& send_set, + std::set& error_set, + bool only_poll); + + /** + * Check connected and listening sockets for IO readiness and process them accordingly. + */ + void SocketHandler(CMasternodeSync& mn_sync) EXCLUSIVE_LOCKS_REQUIRED(!m_total_bytes_sent_mutex, !mutexMsgProc); + + /** + * Do the read/write for connected sockets that are ready for IO. + * @param[in] recv_set Sockets that are ready for read. + * @param[in] send_set Sockets that are ready for send. + * @param[in] error_set Sockets that have an exceptional condition (error). + */ + void SocketHandlerConnected(const std::set& recv_set, + const std::set& send_set, + const std::set& error_set) + EXCLUSIVE_LOCKS_REQUIRED(!m_total_bytes_sent_mutex, !mutexMsgProc); + + /** + * Accept incoming connections, one from each read-ready listening socket. + * @param[in] recv_set Sockets that are ready for read. + */ + void SocketHandlerListening(const std::set& recv_set, CMasternodeSync& mn_sync) + EXCLUSIVE_LOCKS_REQUIRED(!mutexMsgProc); + + void ThreadSocketHandler(CMasternodeSync& mn_sync) EXCLUSIVE_LOCKS_REQUIRED(!m_total_bytes_sent_mutex, !mutexMsgProc); + void ThreadDNSAddressSeed() EXCLUSIVE_LOCKS_REQUIRED(!m_addr_fetches_mutex, !m_nodes_mutex); void ThreadOpenMasternodeConnections(CDeterministicMNManager& dmnman, CMasternodeMetaMan& mn_metaman, - CMasternodeSync& mn_sync); + CMasternodeSync& mn_sync) + EXCLUSIVE_LOCKS_REQUIRED(!m_addr_fetches_mutex, !m_nodes_mutex, !mutexMsgProc); uint64_t CalculateKeyedNetGroup(const CAddress& ad) const; @@ -1276,12 +1362,12 @@ friend class CNode; NodeId GetNewNodeId(); size_t SocketSendData(CNode& node) EXCLUSIVE_LOCKS_REQUIRED(node.cs_vSend); - size_t SocketRecvData(CNode* pnode); + size_t SocketRecvData(CNode* pnode) EXCLUSIVE_LOCKS_REQUIRED(!mutexMsgProc); void DumpAddresses(); // Network stats void RecordBytesRecv(uint64_t bytes); - void RecordBytesSent(uint64_t bytes); + void RecordBytesSent(uint64_t bytes) EXCLUSIVE_LOCKS_REQUIRED(!m_total_bytes_sent_mutex); /** * Return vector of current BLOCK_RELAY peers. @@ -1295,15 +1381,14 @@ friend class CNode; void UnregisterEvents(CNode* pnode); // Network usage totals - mutable RecursiveMutex cs_totalBytesRecv; - mutable RecursiveMutex cs_totalBytesSent; - uint64_t nTotalBytesRecv GUARDED_BY(cs_totalBytesRecv) {0}; - uint64_t nTotalBytesSent GUARDED_BY(cs_totalBytesSent) {0}; + mutable Mutex m_total_bytes_sent_mutex; + std::atomic nTotalBytesRecv{0}; + uint64_t nTotalBytesSent GUARDED_BY(m_total_bytes_sent_mutex) {0}; // outbound limit & stats - uint64_t nMaxOutboundTotalBytesSentInCycle GUARDED_BY(cs_totalBytesSent) {0}; - std::chrono::seconds nMaxOutboundCycleStartTime GUARDED_BY(cs_totalBytesSent) {0}; - uint64_t nMaxOutboundLimit GUARDED_BY(cs_totalBytesSent); + uint64_t nMaxOutboundTotalBytesSentInCycle GUARDED_BY(m_total_bytes_sent_mutex) {0}; + std::chrono::seconds nMaxOutboundCycleStartTime GUARDED_BY(m_total_bytes_sent_mutex) {0}; + uint64_t nMaxOutboundLimit GUARDED_BY(m_total_bytes_sent_mutex); // P2P timeout in seconds std::chrono::seconds m_peer_connect_timeout; @@ -1320,21 +1405,23 @@ friend class CNode; bool fAddressesInitialized{false}; CAddrMan& addrman; std::deque m_addr_fetches GUARDED_BY(m_addr_fetches_mutex); - RecursiveMutex m_addr_fetches_mutex; - std::vector vAddedNodes GUARDED_BY(cs_vAddedNodes); - mutable RecursiveMutex cs_vAddedNodes; + Mutex m_addr_fetches_mutex; + std::vector m_added_nodes GUARDED_BY(m_added_nodes_mutex); + mutable Mutex m_added_nodes_mutex; + std::vector m_nodes GUARDED_BY(m_nodes_mutex); + std::list m_nodes_disconnected; + mutable RecursiveMutex m_nodes_mutex; + std::atomic nLastNodeId{0}; + unsigned int nPrevNodeCount{0}; + std::vector vPendingMasternodes; mutable RecursiveMutex cs_vPendingMasternodes; std::map, std::set> masternodeQuorumNodes GUARDED_BY(cs_vPendingMasternodes); std::map, std::set> masternodeQuorumRelayMembers GUARDED_BY(cs_vPendingMasternodes); std::set masternodePendingProbes GUARDED_BY(cs_vPendingMasternodes); - std::vector vNodes GUARDED_BY(cs_vNodes); - std::list vNodesDisconnected; + mutable Mutex cs_mapSocketToNode; std::unordered_map mapSocketToNode GUARDED_BY(cs_mapSocketToNode); - mutable RecursiveMutex cs_vNodes; - std::atomic nLastNodeId{0}; - unsigned int nPrevNodeCount{0}; /** * Cache responses to addr requests to minimize privacy leak. diff --git a/src/net_processing.cpp b/src/net_processing.cpp index 7f79ab4892..3a33be97c7 100644 --- a/src/net_processing.cpp +++ b/src/net_processing.cpp @@ -374,38 +374,45 @@ class PeerManagerImpl final : public PeerManager bool ignore_incoming_txs); /** Overridden from CValidationInterface. */ - void BlockConnected(const std::shared_ptr& pblock, const CBlockIndex* pindexConnected) override; - void BlockDisconnected(const std::shared_ptr &block, const CBlockIndex* pindex) override; - void UpdatedBlockTip(const CBlockIndex *pindexNew, const CBlockIndex *pindexFork, bool fInitialDownload) override; - void BlockChecked(const CBlock& block, const BlockValidationState& state) override; + void BlockConnected(const std::shared_ptr& pblock, const CBlockIndex* pindexConnected) override + EXCLUSIVE_LOCKS_REQUIRED(!m_peer_mutex, !m_recent_confirmed_transactions_mutex); + void BlockDisconnected(const std::shared_ptr &block, const CBlockIndex* pindex) override + EXCLUSIVE_LOCKS_REQUIRED(!m_recent_confirmed_transactions_mutex); + void UpdatedBlockTip(const CBlockIndex *pindexNew, const CBlockIndex *pindexFork, bool fInitialDownload) override + EXCLUSIVE_LOCKS_REQUIRED(!m_peer_mutex); + void BlockChecked(const CBlock& block, const BlockValidationState& state) override + EXCLUSIVE_LOCKS_REQUIRED(!m_peer_mutex); void NewPoWValidBlock(const CBlockIndex *pindex, const std::shared_ptr& pblock) override; /** Implement NetEventsInterface */ - void InitializeNode(CNode* pnode) override; - void FinalizeNode(const CNode& node) override; - bool ProcessMessages(CNode* pfrom, std::atomic& interrupt) override; - bool SendMessages(CNode* pto) override EXCLUSIVE_LOCKS_REQUIRED(pto->cs_sendProcessing); + void InitializeNode(CNode* pnode) override EXCLUSIVE_LOCKS_REQUIRED(!m_peer_mutex); + void FinalizeNode(const CNode& node) override EXCLUSIVE_LOCKS_REQUIRED(!m_peer_mutex); + bool ProcessMessages(CNode* pfrom, std::atomic& interrupt) override + EXCLUSIVE_LOCKS_REQUIRED(!m_peer_mutex, !m_recent_confirmed_transactions_mutex); + bool SendMessages(CNode* pto) override EXCLUSIVE_LOCKS_REQUIRED(pto->cs_sendProcessing) + EXCLUSIVE_LOCKS_REQUIRED(!m_peer_mutex, !m_recent_confirmed_transactions_mutex); /** Implement PeerManager */ void CheckForStaleTipAndEvictPeers() override; - bool GetNodeStateStats(NodeId nodeid, CNodeStateStats& stats) const override; + bool GetNodeStateStats(NodeId nodeid, CNodeStateStats& stats) const override EXCLUSIVE_LOCKS_REQUIRED(!m_peer_mutex); bool IgnoresIncomingTxs() override { return m_ignore_incoming_txs; } - void SendPings() override; - void PushInventory(NodeId nodeid, const CInv& inv) override; - void RelayInv(CInv &inv, const int minProtoVersion) override; - void RelayInvFiltered(CInv &inv, const CTransaction &relatedTx, const int minProtoVersion) override; - void RelayInvFiltered(CInv &inv, const uint256 &relatedTxHash, const int minProtoVersion) override; - void RelayTransaction(const uint256& txid) override; + void SendPings() override EXCLUSIVE_LOCKS_REQUIRED(!m_peer_mutex);; + void PushInventory(NodeId nodeid, const CInv& inv) override EXCLUSIVE_LOCKS_REQUIRED(!m_peer_mutex); + void RelayInv(CInv &inv, const int minProtoVersion) override EXCLUSIVE_LOCKS_REQUIRED(!m_peer_mutex); + void RelayInvFiltered(CInv &inv, const CTransaction &relatedTx, const int minProtoVersion) override EXCLUSIVE_LOCKS_REQUIRED(!m_peer_mutex); + void RelayInvFiltered(CInv &inv, const uint256 &relatedTxHash, const int minProtoVersion) override EXCLUSIVE_LOCKS_REQUIRED(!m_peer_mutex); + void RelayTransaction(const uint256& txid) override EXCLUSIVE_LOCKS_REQUIRED(!m_peer_mutex); void SetBestHeight(int height) override { m_best_height = height; }; - void Misbehaving(const NodeId pnode, const int howmuch, const std::string& message = "") override; + void Misbehaving(const NodeId pnode, const int howmuch, const std::string& message = "") override EXCLUSIVE_LOCKS_REQUIRED(!m_peer_mutex); void ProcessMessage(CNode& pfrom, const std::string& msg_type, CDataStream& vRecv, - const std::chrono::microseconds time_received, const std::atomic& interruptMsgProc) override; - bool IsBanned(NodeId pnode) override EXCLUSIVE_LOCKS_REQUIRED(cs_main); - bool IsInvInFilter(NodeId nodeid, const uint256& hash) const override; + const std::chrono::microseconds time_received, const std::atomic& interruptMsgProc) override + EXCLUSIVE_LOCKS_REQUIRED(!m_peer_mutex, !m_recent_confirmed_transactions_mutex); + bool IsBanned(NodeId pnode) override EXCLUSIVE_LOCKS_REQUIRED(cs_main, !m_peer_mutex); + bool IsInvInFilter(NodeId nodeid, const uint256& hash) const override EXCLUSIVE_LOCKS_REQUIRED(!m_peer_mutex); private: /** Helper to process result of external handlers of message */ - void ProcessPeerMsgRet(const PeerMsgRet& ret, CNode& pfrom); + void ProcessPeerMsgRet(const PeerMsgRet& ret, CNode& pfrom) EXCLUSIVE_LOCKS_REQUIRED(!m_peer_mutex); /** Consider evicting an outbound peer based on the amount of time they've been behind our tip */ void ConsiderEviction(CNode& pto, int64_t time_in_seconds) EXCLUSIVE_LOCKS_REQUIRED(cs_main); @@ -414,15 +421,15 @@ class PeerManagerImpl final : public PeerManager void EvictExtraOutboundPeers(int64_t time_in_seconds) EXCLUSIVE_LOCKS_REQUIRED(cs_main); /** Retrieve unbroadcast transactions from the mempool and reattempt sending to peers */ - void ReattemptInitialBroadcast(CScheduler& scheduler); + void ReattemptInitialBroadcast(CScheduler& scheduler) EXCLUSIVE_LOCKS_REQUIRED(!m_peer_mutex); /** Get a shared pointer to the Peer object. * May return an empty shared_ptr if the Peer object can't be found. */ - PeerRef GetPeerRef(NodeId id) const; + PeerRef GetPeerRef(NodeId id) const EXCLUSIVE_LOCKS_REQUIRED(!m_peer_mutex); /** Get a shared pointer to the Peer object and remove it from m_peer_map. * May return an empty shared_ptr if the Peer object can't be found. */ - PeerRef RemovePeer(NodeId id); + PeerRef RemovePeer(NodeId id) EXCLUSIVE_LOCKS_REQUIRED(!m_peer_mutex); /** * Potentially mark a node discouraged based on the contents of a BlockValidationState object @@ -435,7 +442,8 @@ class PeerManagerImpl final : public PeerManager * @return Returns true if the peer was punished (probably disconnected) */ bool MaybePunishNodeForBlock(NodeId nodeid, const BlockValidationState& state, - bool via_compact_block, const std::string& message = ""); + bool via_compact_block, const std::string& message = "") + EXCLUSIVE_LOCKS_REQUIRED(!m_peer_mutex); /** * Potentially ban a node based on the contents of a TxValidationState object @@ -444,7 +452,8 @@ class PeerManagerImpl final : public PeerManager * * Changes here may need to be reflected in TxRelayMayResultInDisconnect(). */ - bool MaybePunishNodeForTx(NodeId nodeid, const TxValidationState& state, const std::string& message = ""); + bool MaybePunishNodeForTx(NodeId nodeid, const TxValidationState& state, const std::string& message = "") + EXCLUSIVE_LOCKS_REQUIRED(!m_peer_mutex); /** Maybe disconnect a peer and discourage future connections from its address. * @@ -454,14 +463,16 @@ class PeerManagerImpl final : public PeerManager */ bool MaybeDiscourageAndDisconnect(CNode& pnode, Peer& peer); - void ProcessOrphanTx(std::set& orphan_work_set) - EXCLUSIVE_LOCKS_REQUIRED(cs_main, g_cs_orphans); + void ProcessOrphanTx(std::set& orphan_work_set) EXCLUSIVE_LOCKS_REQUIRED(cs_main, g_cs_orphans) + EXCLUSIVE_LOCKS_REQUIRED(!m_peer_mutex); /** Process a single headers message from a peer. */ void ProcessHeadersMessage(CNode& pfrom, const Peer& peer, const std::vector& headers, - bool via_compact_block); + bool via_compact_block) + EXCLUSIVE_LOCKS_REQUIRED(!m_peer_mutex); - void SendBlockTransactions(CNode& pfrom, const CBlock& block, const BlockTransactionsRequest& req); + void SendBlockTransactions(CNode& pfrom, const CBlock& block, const BlockTransactionsRequest& req) + EXCLUSIVE_LOCKS_REQUIRED(!m_peer_mutex); /** Send a version message to a peer */ void PushNodeVersion(CNode& pnode, const Peer& peer); @@ -482,7 +493,7 @@ class PeerManagerImpl final : public PeerManager * @param[in] fReachable Whether the address' network is reachable. We relay unreachable * addresses less. */ - void RelayAddress(NodeId originator, const CAddress& addr, bool fReachable); + void RelayAddress(NodeId originator, const CAddress& addr, bool fReachable) EXCLUSIVE_LOCKS_REQUIRED(!m_peer_mutex); const CChainParams& m_chainparams; CConnman& m_connman; @@ -608,7 +619,8 @@ class PeerManagerImpl final : public PeerManager /** Number of outbound peers with m_chain_sync.m_protect. */ int m_outbound_peers_with_protect_from_disconnect GUARDED_BY(cs_main) = 0; - bool AlreadyHave(const CInv& inv) EXCLUSIVE_LOCKS_REQUIRED(cs_main); + bool AlreadyHave(const CInv& inv) + EXCLUSIVE_LOCKS_REQUIRED(cs_main, !m_recent_confirmed_transactions_mutex); /** * Filter for transactions that were recently rejected by @@ -3357,7 +3369,7 @@ void PeerManagerImpl::ProcessMessage( pfrom.nServices = nServices; pfrom.SetAddrLocal(addrMe); { - LOCK(pfrom.cs_SubVer); + LOCK(pfrom.m_subver_mutex); pfrom.cleanSubVer = cleanSubVer; } peer->m_starting_height = starting_height; diff --git a/src/qt/clientmodel.h b/src/qt/clientmodel.h index f328ed7519..46068ca264 100644 --- a/src/qt/clientmodel.h +++ b/src/qt/clientmodel.h @@ -63,7 +63,7 @@ class ClientModel : public QObject //! Return number of connections, default is in- and outbound (total) int getNumConnections(unsigned int flags = CONNECTIONS_ALL) const; int getNumBlocks() const; - uint256 getBestBlockHash(); + uint256 getBestBlockHash() EXCLUSIVE_LOCKS_REQUIRED(!m_cached_tip_mutex); int getHeaderTipHeight() const; int64_t getHeaderTipTime() const; diff --git a/src/qt/peertablemodel.cpp b/src/qt/peertablemodel.cpp index e1a4a1f46f..a85958769d 100644 --- a/src/qt/peertablemodel.cpp +++ b/src/qt/peertablemodel.cpp @@ -28,7 +28,7 @@ bool NodeLessThan::operator()(const CNodeCombinedStats &left, const CNodeCombine case PeerTableModel::NetNodeId: return pLeft->nodeid < pRight->nodeid; case PeerTableModel::Address: - return pLeft->addrName.compare(pRight->addrName) < 0; + return pLeft->m_addr_name.compare(pRight->m_addr_name) < 0; case PeerTableModel::Network: return pLeft->m_network < pRight->m_network; case PeerTableModel::Ping: @@ -163,7 +163,7 @@ QVariant PeerTableModel::data(const QModelIndex &index, int role) const return (qint64)rec->nodeStats.nodeid; case Address: // prepend to peer address down-arrow symbol for inbound connection and up-arrow for outbound connection - return QString(rec->nodeStats.fInbound ? "↓ " : "↑ ") + QString::fromStdString(rec->nodeStats.addrName); + return QString(rec->nodeStats.fInbound ? "↓ " : "↑ ") + QString::fromStdString(rec->nodeStats.m_addr_name); case Network: return GUIUtil::NetworkToQString(rec->nodeStats.m_network); case Ping: diff --git a/src/qt/rpcconsole.cpp b/src/qt/rpcconsole.cpp index 25d06547ec..627b59f3fe 100644 --- a/src/qt/rpcconsole.cpp +++ b/src/qt/rpcconsole.cpp @@ -70,6 +70,7 @@ namespace { const QStringList historyFilter = QStringList() << "importprivkey" << "importmulti" + << "sethdseed" << "signmessagewithprivkey" << "signrawtransactionwithkey" << "upgradetohd" @@ -1230,7 +1231,7 @@ void RPCConsole::updateDetailWidget() } const CNodeCombinedStats *stats = clientModel->getPeerTableModel()->getNodeStats(selected_rows.first().row()); // update the detail ui with latest node information - QString peerAddrDetails(QString::fromStdString(stats->nodeStats.addrName) + " "); + QString peerAddrDetails(QString::fromStdString(stats->nodeStats.m_addr_name) + " "); peerAddrDetails += tr("(peer id: %1)").arg(QString::number(stats->nodeStats.nodeid)); if (!stats->nodeStats.addrLocal.empty()) peerAddrDetails += "
" + tr("via %1").arg(QString::fromStdString(stats->nodeStats.addrLocal)); diff --git a/src/random.cpp b/src/random.cpp index 0c6129ea25..92ea4976e1 100644 --- a/src/random.cpp +++ b/src/random.cpp @@ -377,7 +377,7 @@ class RNGState { { } - void AddEvent(uint32_t event_info) noexcept + void AddEvent(uint32_t event_info) noexcept EXCLUSIVE_LOCKS_REQUIRED(!m_events_mutex) { LOCK(m_events_mutex); @@ -391,7 +391,7 @@ class RNGState { /** * Feed (the hash of) all events added through AddEvent() to hasher. */ - void SeedEvents(CSHA512& hasher) noexcept + void SeedEvents(CSHA512& hasher) noexcept EXCLUSIVE_LOCKS_REQUIRED(!m_events_mutex) { // We use only SHA256 for the events hashing to get the ASM speedups we have for SHA256, // since we want it to be fast as network peers may be able to trigger it repeatedly. @@ -410,7 +410,7 @@ class RNGState { * * If this function has never been called with strong_seed = true, false is returned. */ - bool MixExtract(unsigned char* out, size_t num, CSHA512&& hasher, bool strong_seed) noexcept + bool MixExtract(unsigned char* out, size_t num, CSHA512&& hasher, bool strong_seed) noexcept EXCLUSIVE_LOCKS_REQUIRED(!m_mutex) { assert(num <= 32); unsigned char buf[64]; diff --git a/src/rpc/client.cpp b/src/rpc/client.cpp index 4645854ba6..510ce0443e 100644 --- a/src/rpc/client.cpp +++ b/src/rpc/client.cpp @@ -45,6 +45,7 @@ static const CRPCConvertParam vRPCConvertParams[] = { "sendtoaddress", 7, "conf_target" }, { "sendtoaddress", 9, "avoid_reuse" }, { "settxfee", 0, "amount" }, + { "sethdseed", 0, "newkeypool" }, { "getreceivedbyaddress", 1, "minconf" }, { "getreceivedbyaddress", 2, "addlocked" }, { "getreceivedbylabel", 1, "minconf" }, diff --git a/src/rpc/net.cpp b/src/rpc/net.cpp index b88e9de9c9..a5e4244995 100644 --- a/src/rpc/net.cpp +++ b/src/rpc/net.cpp @@ -200,7 +200,7 @@ static RPCHelpMan getpeerinfo() CNodeStateStats statestats; bool fStateStats = peerman.GetNodeStateStats(stats.nodeid, statestats); obj.pushKV("id", stats.nodeid); - obj.pushKV("addr", stats.addrName); + obj.pushKV("addr", stats.m_addr_name); if (stats.addrBind.IsValid()) { obj.pushKV("addrbind", stats.addrBind.ToString()); } diff --git a/src/scheduler.h b/src/scheduler.h index c2396da742..438fb8c126 100644 --- a/src/scheduler.h +++ b/src/scheduler.h @@ -46,10 +46,10 @@ class CScheduler typedef std::function Function; /** Call func at/after time t */ - void schedule(Function f, std::chrono::system_clock::time_point t); + void schedule(Function f, std::chrono::system_clock::time_point t) EXCLUSIVE_LOCKS_REQUIRED(!newTaskMutex); /** Call f once after the delta has passed */ - void scheduleFromNow(Function f, std::chrono::milliseconds delta) + void scheduleFromNow(Function f, std::chrono::milliseconds delta) EXCLUSIVE_LOCKS_REQUIRED(!newTaskMutex) { schedule(std::move(f), std::chrono::system_clock::now() + delta); } @@ -60,29 +60,29 @@ class CScheduler * The timing is not exact: Every time f is finished, it is rescheduled to run again after delta. If you need more * accurate scheduling, don't use this method. */ - void scheduleEvery(Function f, std::chrono::milliseconds delta); + void scheduleEvery(Function f, std::chrono::milliseconds delta) EXCLUSIVE_LOCKS_REQUIRED(!newTaskMutex); /** * Mock the scheduler to fast forward in time. * Iterates through items on taskQueue and reschedules them * to be delta_seconds sooner. */ - void MockForward(std::chrono::seconds delta_seconds); + void MockForward(std::chrono::seconds delta_seconds) EXCLUSIVE_LOCKS_REQUIRED(!newTaskMutex); /** * Services the queue 'forever'. Should be run in a thread. */ - void serviceQueue(); + void serviceQueue() EXCLUSIVE_LOCKS_REQUIRED(!newTaskMutex); /** Tell any threads running serviceQueue to stop as soon as the current task is done */ - void stop() + void stop() EXCLUSIVE_LOCKS_REQUIRED(!newTaskMutex) { WITH_LOCK(newTaskMutex, stopRequested = true); newTaskScheduled.notify_all(); if (m_service_thread.joinable()) m_service_thread.join(); } /** Tell any threads running serviceQueue to stop when there is no work left to be done */ - void StopWhenDrained() + void StopWhenDrained() EXCLUSIVE_LOCKS_REQUIRED(!newTaskMutex) { WITH_LOCK(newTaskMutex, stopWhenEmpty = true); newTaskScheduled.notify_all(); @@ -94,10 +94,11 @@ class CScheduler * and first and last task times */ size_t getQueueInfo(std::chrono::system_clock::time_point& first, - std::chrono::system_clock::time_point& last) const; + std::chrono::system_clock::time_point& last) const + EXCLUSIVE_LOCKS_REQUIRED(!newTaskMutex); /** Returns true if there are threads actively running in serviceQueue() */ - bool AreThreadsServicingQueue() const; + bool AreThreadsServicingQueue() const EXCLUSIVE_LOCKS_REQUIRED(!newTaskMutex); private: mutable Mutex newTaskMutex; @@ -128,8 +129,8 @@ class SingleThreadedSchedulerClient std::list> m_callbacks_pending GUARDED_BY(m_callbacks_mutex); bool m_are_callbacks_running GUARDED_BY(m_callbacks_mutex) = false; - void MaybeScheduleProcessQueue(); - void ProcessQueue(); + void MaybeScheduleProcessQueue() EXCLUSIVE_LOCKS_REQUIRED(!m_callbacks_mutex); + void ProcessQueue() EXCLUSIVE_LOCKS_REQUIRED(!m_callbacks_mutex); public: explicit SingleThreadedSchedulerClient(CScheduler& scheduler LIFETIMEBOUND) : m_scheduler{scheduler} {} @@ -140,15 +141,15 @@ class SingleThreadedSchedulerClient * Practically, this means that callbacks can behave as if they are executed * in order by a single thread. */ - void AddToProcessQueue(std::function func); + void AddToProcessQueue(std::function func) EXCLUSIVE_LOCKS_REQUIRED(!m_callbacks_mutex); /** * Processes all remaining queue members on the calling thread, blocking until queue is empty * Must be called after the CScheduler has no remaining processing threads! */ - void EmptyQueue(); + void EmptyQueue() EXCLUSIVE_LOCKS_REQUIRED(!m_callbacks_mutex); - size_t CallbacksPending(); + size_t CallbacksPending() EXCLUSIVE_LOCKS_REQUIRED(!m_callbacks_mutex); }; #endif diff --git a/src/script/descriptor.cpp b/src/script/descriptor.cpp index 3664a90db2..cc17ea947a 100644 --- a/src/script/descriptor.cpp +++ b/src/script/descriptor.cpp @@ -166,7 +166,7 @@ struct PubkeyProvider * write_cache is the cache to write keys to (if not nullptr) * Caches are not exclusive but this is not tested. Currently we use them exclusively */ - virtual bool GetPubKey(int pos, const SigningProvider& arg, CPubKey& key, KeyOriginInfo& info, const DescriptorCache* read_cache = nullptr, DescriptorCache* write_cache = nullptr) = 0; + virtual bool GetPubKey(int pos, const SigningProvider& arg, CPubKey& key, KeyOriginInfo& info, const DescriptorCache* read_cache = nullptr, DescriptorCache* write_cache = nullptr) const = 0; /** Whether this represent multiple public keys at different positions. */ virtual bool IsRange() const = 0; @@ -181,7 +181,7 @@ struct PubkeyProvider virtual bool ToPrivateString(const SigningProvider& arg, std::string& out) const = 0; /** Get the descriptor string form with the xpub at the last hardened derivation */ - virtual bool ToNormalizedString(const SigningProvider& arg, std::string& out, bool priv) const = 0; + virtual bool ToNormalizedString(const SigningProvider& arg, std::string& out, const DescriptorCache* cache = nullptr) const = 0; /** Derive a private key, if private data is available in arg. */ virtual bool GetPrivKey(int pos, const SigningProvider& arg, CKey& key) const = 0; @@ -199,7 +199,7 @@ class OriginPubkeyProvider final : public PubkeyProvider public: OriginPubkeyProvider(uint32_t exp_index, KeyOriginInfo info, std::unique_ptr provider) : PubkeyProvider(exp_index), m_origin(std::move(info)), m_provider(std::move(provider)) {} - bool GetPubKey(int pos, const SigningProvider& arg, CPubKey& key, KeyOriginInfo& info, const DescriptorCache* read_cache = nullptr, DescriptorCache* write_cache = nullptr) override + bool GetPubKey(int pos, const SigningProvider& arg, CPubKey& key, KeyOriginInfo& info, const DescriptorCache* read_cache = nullptr, DescriptorCache* write_cache = nullptr) const override { if (!m_provider->GetPubKey(pos, arg, key, info, read_cache, write_cache)) return false; std::copy(std::begin(m_origin.fingerprint), std::end(m_origin.fingerprint), info.fingerprint); @@ -216,10 +216,10 @@ class OriginPubkeyProvider final : public PubkeyProvider ret = "[" + OriginString() + "]" + std::move(sub); return true; } - bool ToNormalizedString(const SigningProvider& arg, std::string& ret, bool priv) const override + bool ToNormalizedString(const SigningProvider& arg, std::string& ret, const DescriptorCache* cache) const override { std::string sub; - if (!m_provider->ToNormalizedString(arg, sub, priv)) return false; + if (!m_provider->ToNormalizedString(arg, sub, cache)) return false; // If m_provider is a BIP32PubkeyProvider, we may get a string formatted like a OriginPubkeyProvider // In that case, we need to strip out the leading square bracket and fingerprint from the substring, // and append that to our own origin string. @@ -244,7 +244,7 @@ class ConstPubkeyProvider final : public PubkeyProvider public: ConstPubkeyProvider(uint32_t exp_index, const CPubKey& pubkey) : PubkeyProvider(exp_index), m_pubkey(pubkey) {} - bool GetPubKey(int pos, const SigningProvider& arg, CPubKey& key, KeyOriginInfo& info, const DescriptorCache* read_cache = nullptr, DescriptorCache* write_cache = nullptr) override + bool GetPubKey(int pos, const SigningProvider& arg, CPubKey& key, KeyOriginInfo& info, const DescriptorCache* read_cache = nullptr, DescriptorCache* write_cache = nullptr) const override { key = m_pubkey; info.path.clear(); @@ -262,9 +262,8 @@ class ConstPubkeyProvider final : public PubkeyProvider ret = EncodeSecret(key); return true; } - bool ToNormalizedString(const SigningProvider& arg, std::string& ret, bool priv) const override + bool ToNormalizedString(const SigningProvider& arg, std::string& ret, const DescriptorCache* cache) const override { - if (priv) return ToPrivateString(arg, ret); ret = ToString(); return true; } @@ -287,9 +286,6 @@ class BIP32PubkeyProvider final : public PubkeyProvider CExtPubKey m_root_extkey; KeyPath m_path; DeriveType m_derive; - // Cache of the parent of the final derived pubkeys. - // Primarily useful for situations when no read_cache is provided - CExtPubKey m_cached_xpub; bool GetExtKey(const SigningProvider& arg, CExtKey& ret) const { @@ -304,11 +300,14 @@ class BIP32PubkeyProvider final : public PubkeyProvider } // Derives the last xprv - bool GetDerivedExtKey(const SigningProvider& arg, CExtKey& xprv) const + bool GetDerivedExtKey(const SigningProvider& arg, CExtKey& xprv, CExtKey& last_hardened) const { if (!GetExtKey(arg, xprv)) return false; for (auto entry : m_path) { xprv.Derive(xprv, entry); + if (entry >> 31) { + last_hardened = xprv; + } } return true; } @@ -326,7 +325,7 @@ class BIP32PubkeyProvider final : public PubkeyProvider BIP32PubkeyProvider(uint32_t exp_index, const CExtPubKey& extkey, KeyPath path, DeriveType derive) : PubkeyProvider(exp_index), m_root_extkey(extkey), m_path(std::move(path)), m_derive(derive) {} bool IsRange() const override { return m_derive != DeriveType::NO; } size_t GetSize() const override { return 33; } - bool GetPubKey(int pos, const SigningProvider& arg, CPubKey& key_out, KeyOriginInfo& final_info_out, const DescriptorCache* read_cache = nullptr, DescriptorCache* write_cache = nullptr) override + bool GetPubKey(int pos, const SigningProvider& arg, CPubKey& key_out, KeyOriginInfo& final_info_out, const DescriptorCache* read_cache = nullptr, DescriptorCache* write_cache = nullptr) const override { // Info of parent of the to be derived pubkey KeyOriginInfo parent_info; @@ -342,6 +341,7 @@ class BIP32PubkeyProvider final : public PubkeyProvider // Derive keys or fetch them from cache CExtPubKey final_extkey = m_root_extkey; CExtPubKey parent_extkey = m_root_extkey; + CExtPubKey last_hardened_extkey; bool der = true; if (read_cache) { if (!read_cache->GetCachedDerivedExtPubKey(m_expr_index, pos, final_extkey)) { @@ -351,16 +351,17 @@ class BIP32PubkeyProvider final : public PubkeyProvider final_extkey = parent_extkey; if (m_derive == DeriveType::UNHARDENED) der = parent_extkey.Derive(final_extkey, pos); } - } else if (m_cached_xpub.pubkey.IsValid() && m_derive != DeriveType::HARDENED) { - parent_extkey = final_extkey = m_cached_xpub; - if (m_derive == DeriveType::UNHARDENED) der = parent_extkey.Derive(final_extkey, pos); } else if (IsHardened()) { CExtKey xprv; - if (!GetDerivedExtKey(arg, xprv)) return false; + CExtKey lh_xprv; + if (!GetDerivedExtKey(arg, xprv, lh_xprv)) return false; parent_extkey = xprv.Neuter(); if (m_derive == DeriveType::UNHARDENED) der = xprv.Derive(xprv, pos); if (m_derive == DeriveType::HARDENED) der = xprv.Derive(xprv, pos | 0x80000000UL); final_extkey = xprv.Neuter(); + if (lh_xprv.key.IsValid()) { + last_hardened_extkey = lh_xprv.Neuter(); + } } else { for (auto entry : m_path) { der = parent_extkey.Derive(parent_extkey, entry); @@ -375,15 +376,14 @@ class BIP32PubkeyProvider final : public PubkeyProvider final_info_out = final_info_out_tmp; key_out = final_extkey.pubkey; - // We rely on the consumer to check that m_derive isn't HARDENED as above - // But we can't have already cached something in case we read something from the cache - // and parent_extkey isn't actually the parent. - if (!m_cached_xpub.pubkey.IsValid()) m_cached_xpub = parent_extkey; - if (write_cache) { // Only cache parent if there is any unhardened derivation if (m_derive != DeriveType::HARDENED) { write_cache->CacheParentExtPubKey(m_expr_index, parent_extkey); + // Cache last hardened xpub if we have it + if (last_hardened_extkey.pubkey.IsValid()) { + write_cache->CacheLastHardenedExtPubKey(m_expr_index, last_hardened_extkey); + } } else if (final_info_out.path.size() > 0) { write_cache->CacheDerivedExtPubKey(m_expr_index, pos, final_extkey); } @@ -411,11 +411,10 @@ class BIP32PubkeyProvider final : public PubkeyProvider } return true; } - bool ToNormalizedString(const SigningProvider& arg, std::string& out, bool priv) const override + bool ToNormalizedString(const SigningProvider& arg, std::string& out, const DescriptorCache* cache) const override { // For hardened derivation type, just return the typical string, nothing to normalize if (m_derive == DeriveType::HARDENED) { - if (priv) return ToPrivateString(arg, out); out = ToString(); return true; } @@ -428,33 +427,42 @@ class BIP32PubkeyProvider final : public PubkeyProvider } // Either no derivation or all unhardened derivation if (i == -1) { - if (priv) return ToPrivateString(arg, out); out = ToString(); return true; } - // Derive the xpub at the last hardened step - CExtKey xprv; - if (!GetExtKey(arg, xprv)) return false; + // Get the path to the last hardened stup KeyOriginInfo origin; int k = 0; for (; k <= i; ++k) { - // Derive - xprv.Derive(xprv, m_path.at(k)); // Add to the path origin.path.push_back(m_path.at(k)); - // First derivation element, get the fingerprint for origin - if (k == 0) { - std::copy(xprv.vchFingerprint, xprv.vchFingerprint + 4, origin.fingerprint); - } } // Build the remaining path KeyPath end_path; for (; k < (int)m_path.size(); ++k) { end_path.push_back(m_path.at(k)); } + // Get the fingerprint + CKeyID id = m_root_extkey.pubkey.GetID(); + std::copy(id.begin(), id.begin() + 4, origin.fingerprint); + + CExtPubKey xpub; + CExtKey lh_xprv; + // If we have the cache, just get the parent xpub + if (cache != nullptr) { + cache->GetCachedLastHardenedExtPubKey(m_expr_index, xpub); + } + if (!xpub.pubkey.IsValid()) { + // Cache miss, or nor cache, or need privkey + CExtKey xprv; + if (!GetDerivedExtKey(arg, xprv, lh_xprv)) return false; + xpub = lh_xprv.Neuter(); + } + assert(xpub.pubkey.IsValid()); + // Build the string std::string origin_str = HexStr(origin.fingerprint) + FormatHDKeypath(origin.path); - out = "[" + origin_str + "]" + (priv ? EncodeExtKey(xprv) : EncodeExtPubKey(xprv.Neuter())) + FormatHDKeypath(end_path); + out = "[" + origin_str + "]" + EncodeExtPubKey(xpub) + FormatHDKeypath(end_path); if (IsRange()) { out += "/*"; assert(m_derive == DeriveType::UNHARDENED); @@ -464,7 +472,8 @@ class BIP32PubkeyProvider final : public PubkeyProvider bool GetPrivKey(int pos, const SigningProvider& arg, CKey& key) const override { CExtKey extkey; - if (!GetDerivedExtKey(arg, extkey)) return false; + CExtKey dummy; + if (!GetDerivedExtKey(arg, extkey, dummy)) return false; if (m_derive == DeriveType::UNHARDENED) extkey.Derive(extkey, pos); if (m_derive == DeriveType::HARDENED) extkey.Derive(extkey, pos | 0x80000000UL); key = extkey.key; @@ -481,34 +490,42 @@ class DescriptorImpl : public Descriptor const std::string m_name; protected: - //! The sub-descriptor argument (nullptr for everything but SH and WSH). + //! The sub-descriptor arguments (empty for everything but SH and WSH). //! In doc/descriptors.m this is referred to as SCRIPT expressions sh(SCRIPT) //! and wsh(SCRIPT), and distinct from KEY expressions and ADDR expressions. - const std::unique_ptr m_subdescriptor_arg; + //! Subdescriptors can only ever generate a single script. + const std::vector> m_subdescriptor_args; //! Return a serialization of anything except pubkey and script arguments, to be prepended to those. virtual std::string ToStringExtra() const { return ""; } /** A helper function to construct the scripts for this descriptor. * - * This function is invoked once for every CScript produced by evaluating - * m_subdescriptor_arg, or just once in case m_subdescriptor_arg is nullptr. - + * This function is invoked once by ExpandHelper. + * * @param pubkeys The evaluations of the m_pubkey_args field. - * @param scripts The evaluation of m_subdescriptor_arg (or nullptr when m_subdescriptor_arg is nullptr). + * @param scripts The evaluations of m_subdescriptor_args (one for each m_subdescriptor_args element). * @param out A FlatSigningProvider to put scripts or public keys in that are necessary to the solver. - * The script arguments to this function are automatically added, as is the origin info of the provided pubkeys. + * The origin info of the provided pubkeys is automatically added. * @return A vector with scriptPubKeys for this descriptor. */ - virtual std::vector MakeScripts(const std::vector& pubkeys, const CScript* script, FlatSigningProvider& out) const = 0; + virtual std::vector MakeScripts(const std::vector& pubkeys, Span scripts, FlatSigningProvider& out) const = 0; public: - DescriptorImpl(std::vector> pubkeys, std::unique_ptr script, const std::string& name) : m_pubkey_args(std::move(pubkeys)), m_name(name), m_subdescriptor_arg(std::move(script)) {} + DescriptorImpl(std::vector> pubkeys, const std::string& name) : m_pubkey_args(std::move(pubkeys)), m_name(name), m_subdescriptor_args() {} + DescriptorImpl(std::vector> pubkeys, std::unique_ptr script, const std::string& name) : m_pubkey_args(std::move(pubkeys)), m_name(name), m_subdescriptor_args(Vector(std::move(script))) {} + + enum class StringType + { + PUBLIC, + PRIVATE, + NORMALIZED, + }; bool IsSolvable() const override { - if (m_subdescriptor_arg) { - if (!m_subdescriptor_arg->IsSolvable()) return false; + for (const auto& arg : m_subdescriptor_args) { + if (!arg->IsSolvable()) return false; } return true; } @@ -518,13 +535,25 @@ class DescriptorImpl : public Descriptor for (const auto& pubkey : m_pubkey_args) { if (pubkey->IsRange()) return true; } - if (m_subdescriptor_arg) { - if (m_subdescriptor_arg->IsRange()) return true; + for (const auto& arg : m_subdescriptor_args) { + if (arg->IsRange()) return true; } return false; } - bool ToStringHelper(const SigningProvider* arg, std::string& out, bool priv, bool normalized) const + virtual bool ToStringSubScriptHelper(const SigningProvider* arg, std::string& ret, const StringType type, const DescriptorCache* cache = nullptr) const + { + size_t pos = 0; + for (const auto& scriptarg : m_subdescriptor_args) { + if (pos++) ret += ","; + std::string tmp; + if (!scriptarg->ToStringHelper(arg, tmp, type, cache)) return false; + ret += std::move(tmp); + } + return true; + } + + bool ToStringHelper(const SigningProvider* arg, std::string& out, const StringType type, const DescriptorCache* cache = nullptr) const { std::string extra = ToStringExtra(); size_t pos = extra.size() > 0 ? 1 : 0; @@ -532,42 +561,43 @@ class DescriptorImpl : public Descriptor for (const auto& pubkey : m_pubkey_args) { if (pos++) ret += ","; std::string tmp; - if (normalized) { - if (!pubkey->ToNormalizedString(*arg, tmp, priv)) return false; - } else if (priv) { - if (!pubkey->ToPrivateString(*arg, tmp)) return false; - } else { - tmp = pubkey->ToString(); + switch (type) { + case StringType::NORMALIZED: + if (!pubkey->ToNormalizedString(*arg, tmp, cache)) return false; + break; + case StringType::PRIVATE: + if (!pubkey->ToPrivateString(*arg, tmp)) return false; + break; + case StringType::PUBLIC: + tmp = pubkey->ToString(); + break; } ret += std::move(tmp); } - if (m_subdescriptor_arg) { - if (pos++) ret += ","; - std::string tmp; - if (!m_subdescriptor_arg->ToStringHelper(arg, tmp, priv, normalized)) return false; - ret += std::move(tmp); - } - out = std::move(ret) + ")"; + std::string subscript; + if (!ToStringSubScriptHelper(arg, subscript, type, cache)) return false; + if (pos && subscript.size()) ret += ','; + out = std::move(ret) + std::move(subscript) + ")"; return true; } std::string ToString() const final { std::string ret; - ToStringHelper(nullptr, ret, false, false); + ToStringHelper(nullptr, ret, StringType::PUBLIC); return AddChecksum(ret); } bool ToPrivateString(const SigningProvider& arg, std::string& out) const override final { - bool ret = ToStringHelper(&arg, out, true, false); + bool ret = ToStringHelper(&arg, out, StringType::PRIVATE); out = AddChecksum(out); return ret; } - bool ToNormalizedString(const SigningProvider& arg, std::string& out, bool priv) const override final + bool ToNormalizedString(const SigningProvider& arg, std::string& out, const DescriptorCache* cache) const override final { - bool ret = ToStringHelper(&arg, out, priv, true); + bool ret = ToStringHelper(&arg, out, StringType::NORMALIZED, cache); out = AddChecksum(out); return ret; } @@ -577,17 +607,20 @@ class DescriptorImpl : public Descriptor std::vector> entries; entries.reserve(m_pubkey_args.size()); - // Construct temporary data in `entries` and `subscripts`, to avoid producing output in case of failure. + // Construct temporary data in `entries`, `subscripts`, and `subprovider` to avoid producing output in case of failure. for (const auto& p : m_pubkey_args) { entries.emplace_back(); if (!p->GetPubKey(pos, arg, entries.back().first, entries.back().second, read_cache, write_cache)) return false; } std::vector subscripts; - if (m_subdescriptor_arg) { - FlatSigningProvider subprovider; - if (!m_subdescriptor_arg->ExpandHelper(pos, arg, read_cache, subscripts, subprovider, write_cache)) return false; - out = Merge(out, subprovider); + FlatSigningProvider subprovider; + for (const auto& subarg : m_subdescriptor_args) { + std::vector outscripts; + if (!subarg->ExpandHelper(pos, arg, read_cache, outscripts, subprovider, write_cache)) return false; + assert(outscripts.size() == 1); + subscripts.emplace_back(std::move(outscripts[0])); } + out = Merge(std::move(out), std::move(subprovider)); std::vector pubkeys; pubkeys.reserve(entries.size()); @@ -595,17 +628,8 @@ class DescriptorImpl : public Descriptor pubkeys.push_back(entry.first); out.origins.emplace(entry.first.GetID(), std::make_pair(CPubKey(entry.first), std::move(entry.second))); } - if (m_subdescriptor_arg) { - for (const auto& subscript : subscripts) { - out.scripts.emplace(CScriptID(subscript), subscript); - std::vector addscripts = MakeScripts(pubkeys, &subscript, out); - for (auto& addscript : addscripts) { - output_scripts.push_back(std::move(addscript)); - } - } - } else { - output_scripts = MakeScripts(pubkeys, nullptr, out); - } + + output_scripts = MakeScripts(pubkeys, Span{subscripts}, out); return true; } @@ -626,10 +650,8 @@ class DescriptorImpl : public Descriptor if (!p->GetPrivKey(pos, provider, key)) continue; out.keys.emplace(key.GetPubKey().GetID(), key); } - if (m_subdescriptor_arg) { - FlatSigningProvider subprovider; - m_subdescriptor_arg->ExpandPrivate(pos, provider, subprovider); - out = Merge(out, subprovider); + for (const auto& arg : m_subdescriptor_args) { + arg->ExpandPrivate(pos, provider, out); } } std::optional GetOutputType() const override { return std::nullopt; } @@ -641,9 +663,9 @@ class AddressDescriptor final : public DescriptorImpl const CTxDestination m_destination; protected: std::string ToStringExtra() const override { return EncodeDestination(m_destination); } - std::vector MakeScripts(const std::vector&, const CScript*, FlatSigningProvider&) const override { return Vector(GetScriptForDestination(m_destination)); } + std::vector MakeScripts(const std::vector&, Span, FlatSigningProvider&) const override { return Vector(GetScriptForDestination(m_destination)); } public: - AddressDescriptor(CTxDestination destination) : DescriptorImpl({}, {}, "addr"), m_destination(std::move(destination)) {} + AddressDescriptor(CTxDestination destination) : DescriptorImpl({}, "addr"), m_destination(std::move(destination)) {} bool IsSolvable() const final { return false; } std::optional GetOutputType() const override @@ -664,9 +686,9 @@ class RawDescriptor final : public DescriptorImpl const CScript m_script; protected: std::string ToStringExtra() const override { return HexStr(m_script); } - std::vector MakeScripts(const std::vector&, const CScript*, FlatSigningProvider&) const override { return Vector(m_script); } + std::vector MakeScripts(const std::vector&, Span, FlatSigningProvider&) const override { return Vector(m_script); } public: - RawDescriptor(CScript script) : DescriptorImpl({}, {}, "raw"), m_script(std::move(script)) {} + RawDescriptor(CScript script) : DescriptorImpl({}, "raw"), m_script(std::move(script)) {} bool IsSolvable() const final { return false; } std::optional GetOutputType() const override @@ -687,9 +709,9 @@ class RawDescriptor final : public DescriptorImpl class PKDescriptor final : public DescriptorImpl { protected: - std::vector MakeScripts(const std::vector& keys, const CScript*, FlatSigningProvider&) const override { return Vector(GetScriptForRawPubKey(keys[0])); } + std::vector MakeScripts(const std::vector& keys, Span, FlatSigningProvider&) const override { return Vector(GetScriptForRawPubKey(keys[0])); } public: - PKDescriptor(std::unique_ptr prov) : DescriptorImpl(Vector(std::move(prov)), {}, "pk") {} + PKDescriptor(std::unique_ptr prov) : DescriptorImpl(Vector(std::move(prov)), "pk") {} std::optional GetOutputType() const override { return OutputType::LEGACY; } bool IsSingleType() const final { return true; } }; @@ -698,14 +720,14 @@ class PKDescriptor final : public DescriptorImpl class PKHDescriptor final : public DescriptorImpl { protected: - std::vector MakeScripts(const std::vector& keys, const CScript*, FlatSigningProvider& out) const override + std::vector MakeScripts(const std::vector& keys, Span, FlatSigningProvider& out) const override { CKeyID id = keys[0].GetID(); out.pubkeys.emplace(id, keys[0]); return Vector(GetScriptForDestination(PKHash(id))); } public: - PKHDescriptor(std::unique_ptr prov) : DescriptorImpl(Vector(std::move(prov)), {}, "pkh") {} + PKHDescriptor(std::unique_ptr prov) : DescriptorImpl(Vector(std::move(prov)), "pkh") {} std::optional GetOutputType() const override { return OutputType::LEGACY; } bool IsSingleType() const final { return true; } }; @@ -717,7 +739,7 @@ class MultisigDescriptor final : public DescriptorImpl const bool m_sorted; protected: std::string ToStringExtra() const override { return strprintf("%i", m_threshold); } - std::vector MakeScripts(const std::vector& keys, const CScript*, FlatSigningProvider&) const override { + std::vector MakeScripts(const std::vector& keys, Span, FlatSigningProvider&) const override { if (m_sorted) { std::vector sorted_keys(keys); std::sort(sorted_keys.begin(), sorted_keys.end()); @@ -726,7 +748,7 @@ class MultisigDescriptor final : public DescriptorImpl return Vector(GetScriptForMultisig(m_threshold, keys)); } public: - MultisigDescriptor(int threshold, std::vector> providers, bool sorted = false) : DescriptorImpl(std::move(providers), {}, sorted ? "sortedmulti" : "multi"), m_threshold(threshold), m_sorted(sorted) {} + MultisigDescriptor(int threshold, std::vector> providers, bool sorted = false) : DescriptorImpl(std::move(providers), sorted ? "sortedmulti" : "multi"), m_threshold(threshold), m_sorted(sorted) {} bool IsSingleType() const final { return true; } }; @@ -734,13 +756,18 @@ class MultisigDescriptor final : public DescriptorImpl class SHDescriptor final : public DescriptorImpl { protected: - std::vector MakeScripts(const std::vector&, const CScript* script, FlatSigningProvider&) const override { return Vector(GetScriptForDestination(ScriptHash(*script))); } + std::vector MakeScripts(const std::vector&, Span scripts, FlatSigningProvider& out) const override + { + auto ret = Vector(GetScriptForDestination(ScriptHash(scripts[0]))); + if (ret.size()) out.scripts.emplace(CScriptID(scripts[0]), scripts[0]); + return ret; + } public: SHDescriptor(std::unique_ptr desc) : DescriptorImpl({}, std::move(desc), "sh") {} std::optional GetOutputType() const override { - assert(m_subdescriptor_arg); + assert(m_subdescriptor_args.size() == 1); return OutputType::LEGACY; } bool IsSingleType() const final { return true; } @@ -750,7 +777,7 @@ class SHDescriptor final : public DescriptorImpl class ComboDescriptor final : public DescriptorImpl { protected: - std::vector MakeScripts(const std::vector& keys, const CScript*, FlatSigningProvider& out) const override + std::vector MakeScripts(const std::vector& keys, Span scripts, FlatSigningProvider& out) const override { std::vector ret; CKeyID id = keys[0].GetID(); @@ -766,7 +793,7 @@ class ComboDescriptor final : public DescriptorImpl } public: - ComboDescriptor(std::unique_ptr prov) : DescriptorImpl(Vector(std::move(prov)), {}, "combo") {} + ComboDescriptor(std::unique_ptr prov) : DescriptorImpl(Vector(std::move(prov)), "combo") {} std::optional GetOutputType() const override { return OutputType::LEGACY; } bool IsSingleType() const final { return false; } }; @@ -776,8 +803,8 @@ class ComboDescriptor final : public DescriptorImpl //////////////////////////////////////////////////////////////////////////// enum class ParseScriptContext { - TOP, - P2SH + TOP, //!< Top-level context (script goes directly in scriptPubKey) + P2SH, //!< Inside sh() (script becomes P2SH redeemScript) }; /** Parse a key path, being passed a split list of elements (the first element is ignored). */ @@ -804,10 +831,11 @@ enum class ParseScriptContext { } /** Parse a public key that excludes origin information. */ -std::unique_ptr ParsePubkeyInner(uint32_t key_exp_index, const Span& sp, bool permit_uncompressed, FlatSigningProvider& out, std::string& error) +std::unique_ptr ParsePubkeyInner(uint32_t key_exp_index, const Span& sp, ParseScriptContext ctx, FlatSigningProvider& out, std::string& error) { using namespace spanparsing; + bool permit_uncompressed = ctx == ParseScriptContext::TOP || ctx == ParseScriptContext::P2SH; auto split = Split(sp, '/'); std::string str(split[0].begin(), split[0].end()); if (str.size() == 0) { @@ -865,7 +893,7 @@ std::unique_ptr ParsePubkeyInner(uint32_t key_exp_index, const S } /** Parse a public key including origin information (if enabled). */ -std::unique_ptr ParsePubkey(uint32_t key_exp_index, const Span& sp, bool permit_uncompressed, FlatSigningProvider& out, std::string& error) +std::unique_ptr ParsePubkey(uint32_t key_exp_index, const Span& sp, ParseScriptContext ctx, FlatSigningProvider& out, std::string& error) { using namespace spanparsing; @@ -874,7 +902,7 @@ std::unique_ptr ParsePubkey(uint32_t key_exp_index, const Span ParsePubkey(uint32_t key_exp_index, const Span(key_exp_index, std::move(info), std::move(provider)); } /** Parse a script in a particular context. */ -std::unique_ptr ParseScript(uint32_t key_exp_index, Span& sp, ParseScriptContext ctx, FlatSigningProvider& out, std::string& error) +std::unique_ptr ParseScript(uint32_t& key_exp_index, Span& sp, ParseScriptContext ctx, FlatSigningProvider& out, std::string& error) { using namespace spanparsing; auto expr = Expr(sp); bool sorted_multi = false; if (Func("pk", expr)) { - auto pubkey = ParsePubkey(key_exp_index, expr, true, out, error); + auto pubkey = ParsePubkey(key_exp_index, expr, ctx, out, error); if (!pubkey) return nullptr; + ++key_exp_index; return std::make_unique(std::move(pubkey)); } if (Func("pkh", expr)) { - auto pubkey = ParsePubkey(key_exp_index, expr, true, out, error); + auto pubkey = ParsePubkey(key_exp_index, expr, ctx, out, error); if (!pubkey) return nullptr; + ++key_exp_index; return std::make_unique(std::move(pubkey)); } if (ctx == ParseScriptContext::TOP && Func("combo", expr)) { - auto pubkey = ParsePubkey(key_exp_index, expr, true, out, error); + auto pubkey = ParsePubkey(key_exp_index, expr, ctx, out, error); if (!pubkey) return nullptr; + ++key_exp_index; return std::make_unique(std::move(pubkey)); - } else if (ctx != ParseScriptContext::TOP && Func("combo", expr)) { - error = "Cannot have combo in non-top level"; + } else if (Func("combo", expr)) { + error = "Can only have combo() at top level"; return nullptr; } if ((sorted_multi = Func("sortedmulti", expr)) || Func("multi", expr)) { @@ -941,7 +972,7 @@ std::unique_ptr ParseScript(uint32_t key_exp_index, SpanGetSize() + 1; providers.emplace_back(std::move(pk)); @@ -975,8 +1006,8 @@ std::unique_ptr ParseScript(uint32_t key_exp_index, Span(std::move(desc)); - } else if (ctx != ParseScriptContext::TOP && Func("sh", expr)) { - error = "Cannot have sh in non-top level"; + } else if (Func("sh", expr)) { + error = "Can only have sh() at top level"; return nullptr; } if (ctx == ParseScriptContext::TOP && Func("addr", expr)) { @@ -986,6 +1017,9 @@ std::unique_ptr ParseScript(uint32_t key_exp_index, Span(std::move(dest)); + } else if (Func("addr", expr)) { + error = "Can only have addr() at top level"; + return nullptr; } if (ctx == ParseScriptContext::TOP && Func("raw", expr)) { std::string str(expr.begin(), expr.end()); @@ -995,6 +1029,9 @@ std::unique_ptr ParseScript(uint32_t key_exp_index, Span(CScript(bytes.begin(), bytes.end())); + } else if (Func("raw", expr)) { + error = "Can only have raw() at top level"; + return nullptr; } if (ctx == ParseScriptContext::P2SH) { error = "A function is needed within P2SH"; @@ -1104,7 +1141,8 @@ std::unique_ptr Parse(const std::string& descriptor, FlatSigningProv { Span sp{descriptor}; if (!CheckChecksum(sp, require_checksum, error)) return nullptr; - auto ret = ParseScript(0, sp, ParseScriptContext::TOP, out, error); + uint32_t key_exp_index = 0; + auto ret = ParseScript(key_exp_index, sp, ParseScriptContext::TOP, out, error); if (sp.size() == 0 && ret) return std::unique_ptr(std::move(ret)); return nullptr; } @@ -1134,6 +1172,11 @@ void DescriptorCache::CacheDerivedExtPubKey(uint32_t key_exp_pos, uint32_t der_i xpubs[der_index] = xpub; } +void DescriptorCache::CacheLastHardenedExtPubKey(uint32_t key_exp_pos, const CExtPubKey& xpub) +{ + m_last_hardened_xpubs[key_exp_pos] = xpub; +} + bool DescriptorCache::GetCachedParentExtPubKey(uint32_t key_exp_pos, CExtPubKey& xpub) const { const auto& it = m_parent_xpubs.find(key_exp_pos); @@ -1152,6 +1195,55 @@ bool DescriptorCache::GetCachedDerivedExtPubKey(uint32_t key_exp_pos, uint32_t d return true; } +bool DescriptorCache::GetCachedLastHardenedExtPubKey(uint32_t key_exp_pos, CExtPubKey& xpub) const +{ + const auto& it = m_last_hardened_xpubs.find(key_exp_pos); + if (it == m_last_hardened_xpubs.end()) return false; + xpub = it->second; + return true; +} + +DescriptorCache DescriptorCache::MergeAndDiff(const DescriptorCache& other) +{ + DescriptorCache diff; + for (const auto& parent_xpub_pair : other.GetCachedParentExtPubKeys()) { + CExtPubKey xpub; + if (GetCachedParentExtPubKey(parent_xpub_pair.first, xpub)) { + if (xpub != parent_xpub_pair.second) { + throw std::runtime_error(std::string(__func__) + ": New cached parent xpub does not match already cached parent xpub"); + } + continue; + } + CacheParentExtPubKey(parent_xpub_pair.first, parent_xpub_pair.second); + diff.CacheParentExtPubKey(parent_xpub_pair.first, parent_xpub_pair.second); + } + for (const auto& derived_xpub_map_pair : other.GetCachedDerivedExtPubKeys()) { + for (const auto& derived_xpub_pair : derived_xpub_map_pair.second) { + CExtPubKey xpub; + if (GetCachedDerivedExtPubKey(derived_xpub_map_pair.first, derived_xpub_pair.first, xpub)) { + if (xpub != derived_xpub_pair.second) { + throw std::runtime_error(std::string(__func__) + ": New cached derived xpub does not match already cached derived xpub"); + } + continue; + } + CacheDerivedExtPubKey(derived_xpub_map_pair.first, derived_xpub_pair.first, derived_xpub_pair.second); + diff.CacheDerivedExtPubKey(derived_xpub_map_pair.first, derived_xpub_pair.first, derived_xpub_pair.second); + } + } + for (const auto& lh_xpub_pair : other.GetCachedLastHardenedExtPubKeys()) { + CExtPubKey xpub; + if (GetCachedLastHardenedExtPubKey(lh_xpub_pair.first, xpub)) { + if (xpub != lh_xpub_pair.second) { + throw std::runtime_error(std::string(__func__) + ": New cached last hardened xpub does not match already cached last hardened xpub"); + } + continue; + } + CacheLastHardenedExtPubKey(lh_xpub_pair.first, lh_xpub_pair.second); + diff.CacheLastHardenedExtPubKey(lh_xpub_pair.first, lh_xpub_pair.second); + } + return diff; +} + const ExtPubKeyMap DescriptorCache::GetCachedParentExtPubKeys() const { return m_parent_xpubs; @@ -1161,3 +1253,8 @@ const std::unordered_map DescriptorCache::GetCachedDeriv { return m_derived_xpubs; } + +const ExtPubKeyMap DescriptorCache::GetCachedLastHardenedExtPubKeys() const +{ + return m_last_hardened_xpubs; +} diff --git a/src/script/descriptor.h b/src/script/descriptor.h index 28df1ed4df..f5bbf487c6 100644 --- a/src/script/descriptor.h +++ b/src/script/descriptor.h @@ -22,6 +22,8 @@ class DescriptorCache { std::unordered_map m_derived_xpubs; /** Map key expression index -> parent xpub */ ExtPubKeyMap m_parent_xpubs; + /** Map key expression index -> last hardened xpub */ + ExtPubKeyMap m_last_hardened_xpubs; public: /** Cache a parent xpub @@ -50,11 +52,30 @@ class DescriptorCache { * @param[in] xpub The CExtPubKey to get from cache */ bool GetCachedDerivedExtPubKey(uint32_t key_exp_pos, uint32_t der_index, CExtPubKey& xpub) const; + /** Cache a last hardened xpub + * + * @param[in] key_exp_pos Position of the key expression within the descriptor + * @param[in] xpub The CExtPubKey to cache + */ + void CacheLastHardenedExtPubKey(uint32_t key_exp_pos, const CExtPubKey& xpub); + /** Retrieve a cached last hardened xpub + * + * @param[in] key_exp_pos Position of the key expression within the descriptor + * @param[in] xpub The CExtPubKey to get from cache + */ + bool GetCachedLastHardenedExtPubKey(uint32_t key_exp_pos, CExtPubKey& xpub) const; /** Retrieve all cached parent xpubs */ const ExtPubKeyMap GetCachedParentExtPubKeys() const; /** Retrieve all cached derived xpubs */ const std::unordered_map GetCachedDerivedExtPubKeys() const; + /** Retrieve all cached last hardened xpubs */ + const ExtPubKeyMap GetCachedLastHardenedExtPubKeys() const; + + /** Combine another DescriptorCache into this one. + * Returns a cache containing the items from the other cache unknown to current cache + */ + DescriptorCache MergeAndDiff(const DescriptorCache& other); }; /** \brief Interface for parsed descriptor objects. @@ -94,7 +115,7 @@ struct Descriptor { virtual bool ToPrivateString(const SigningProvider& provider, std::string& out) const = 0; /** Convert the descriptor to a normalized string. Normalized descriptors have the xpub at the last hardened step. This fails if the provided provider does not have the private keys to derive that xpub. */ - virtual bool ToNormalizedString(const SigningProvider& provider, std::string& out, bool priv) const = 0; + virtual bool ToNormalizedString(const SigningProvider& provider, std::string& out, const DescriptorCache* cache = nullptr) const = 0; /** Expand a descriptor at a specified position. * diff --git a/src/sync.h b/src/sync.h index 2a046a78d8..c75b6554cf 100644 --- a/src/sync.h +++ b/src/sync.h @@ -77,8 +77,6 @@ inline void AssertLockNotHeldInternal(const char* pszName, const char* pszFile, inline void DeleteLock(void* cs) {} inline bool LockStackEmpty() { return true; } #endif -#define AssertLockHeld(cs) AssertLockHeldInternal(#cs, __FILE__, __LINE__, &cs) -#define AssertLockNotHeld(cs) AssertLockNotHeldInternal(#cs, __FILE__, __LINE__, &cs) /** * Template mixin that adds -Wthread-safety locking annotations and lock order @@ -138,10 +136,18 @@ class LOCKABLE SharedAnnotatedMixin : public AnnotatedMixin using RecursiveMutex = AnnotatedMixin; /** Wrapped mutex: supports waiting but not recursive locking */ -typedef AnnotatedMixin Mutex; +using Mutex = AnnotatedMixin; + /** Wrapped shared mutex: supports read locking via .shared_lock, exlusive locking via .lock; * does not support recursive locking */ -typedef SharedAnnotatedMixin SharedMutex; +using SharedMutex = SharedAnnotatedMixin; + +#define AssertLockHeld(cs) AssertLockHeldInternal(#cs, __FILE__, __LINE__, &cs) + +inline void AssertLockNotHeldInline(const char* name, const char* file, int line, Mutex* cs) EXCLUSIVE_LOCKS_REQUIRED(!cs) { AssertLockNotHeldInternal(name, file, line, cs); } +inline void AssertLockNotHeldInline(const char* name, const char* file, int line, RecursiveMutex* cs) LOCKS_EXCLUDED(cs) { AssertLockNotHeldInternal(name, file, line, cs); } +inline void AssertLockNotHeldInline(const char* name, const char* file, int line, SharedMutex* cs) LOCKS_EXCLUDED(cs) { AssertLockNotHeldInternal(name, file, line, cs); } +#define AssertLockNotHeld(cs) AssertLockNotHeldInline(#cs, __FILE__, __LINE__, &cs) /** Prints a lock contention to the log */ void LockContention(const char* pszName, const char* pszFile, int nLine); diff --git a/src/test/denialofservice_tests.cpp b/src/test/denialofservice_tests.cpp index fb1675e20d..c45d20be37 100644 --- a/src/test/denialofservice_tests.cpp +++ b/src/test/denialofservice_tests.cpp @@ -74,7 +74,7 @@ BOOST_AUTO_TEST_CASE(outbound_slow_chain_eviction) // Mock an outbound peer CAddress addr1(ip(0xa0b0c001), NODE_NONE); - CNode dummyNode1(id++, ServiceFlags(NODE_NETWORK), INVALID_SOCKET, addr1, 0, 0, CAddress(), "", ConnectionType::OUTBOUND_FULL_RELAY); + CNode dummyNode1(id++, ServiceFlags(NODE_NETWORK), INVALID_SOCKET, addr1, /* nKeyedNetGroupIn */ 0, /* nLocalHostNonceIn */ 0, CAddress(), /* pszDest */ "", ConnectionType::OUTBOUND_FULL_RELAY, /* inbound_onion */ false); dummyNode1.SetCommonVersion(PROTOCOL_VERSION); peerLogic->InitializeNode(&dummyNode1); @@ -124,7 +124,7 @@ BOOST_AUTO_TEST_CASE(outbound_slow_chain_eviction) static void AddRandomOutboundPeer(std::vector& vNodes, PeerManager& peerLogic, ConnmanTestMsg& connman) { CAddress addr(ip(g_insecure_rand_ctx.randbits(32)), NODE_NONE); - vNodes.emplace_back(new CNode(id++, ServiceFlags(NODE_NETWORK), INVALID_SOCKET, addr, 0, 0, CAddress(), "", ConnectionType::OUTBOUND_FULL_RELAY)); + vNodes.emplace_back(new CNode(id++, ServiceFlags(NODE_NETWORK), INVALID_SOCKET, addr, /* nKeyedNetGroupIn */ 0, /* nLocalHostNonceIn */ 0, CAddress(), /* pszDest */ "", ConnectionType::OUTBOUND_FULL_RELAY, /* inbound_onion */ false)); CNode &node = *vNodes.back(); node.SetCommonVersion(PROTOCOL_VERSION); @@ -220,7 +220,7 @@ BOOST_AUTO_TEST_CASE(peer_discouragement) banman->ClearBanned(); CAddress addr1(ip(0xa0b0c001), NODE_NONE); - CNode dummyNode1(id++, NODE_NETWORK, INVALID_SOCKET, addr1, 0, 0, CAddress(), "", ConnectionType::INBOUND); + CNode dummyNode1(id++, NODE_NETWORK, INVALID_SOCKET, addr1, /* nKeyedNetGroupIn */ 0, /* nLocalHostNonceIn */ 0, CAddress(), /* pszDest */ "", ConnectionType::INBOUND, /* inbound_onion */ false); dummyNode1.SetCommonVersion(PROTOCOL_VERSION); peerLogic->InitializeNode(&dummyNode1); dummyNode1.fSuccessfullyConnected = true; @@ -233,7 +233,7 @@ BOOST_AUTO_TEST_CASE(peer_discouragement) BOOST_CHECK(!banman->IsDiscouraged(ip(0xa0b0c001|0x0000ff00))); // Different IP, not discouraged CAddress addr2(ip(0xa0b0c002), NODE_NONE); - CNode dummyNode2(id++, NODE_NETWORK, INVALID_SOCKET, addr2, 1, 1, CAddress(), "", ConnectionType::INBOUND); + CNode dummyNode2(id++, NODE_NETWORK, INVALID_SOCKET, addr2, /* nKeyedNetGroupIn */ 1, /* nLocalHostNonceIn */ 1, CAddress(), /* pszDest */ "", ConnectionType::INBOUND, /* inbound_onion */ false); dummyNode2.SetCommonVersion(PROTOCOL_VERSION); peerLogic->InitializeNode(&dummyNode2); dummyNode2.fSuccessfullyConnected = true; @@ -271,7 +271,7 @@ BOOST_AUTO_TEST_CASE(DoS_bantime) SetMockTime(nStartTime); // Overrides future calls to GetTime() CAddress addr(ip(0xa0b0c001), NODE_NONE); - CNode dummyNode(id++, NODE_NETWORK, INVALID_SOCKET, addr, 4, 4, CAddress(), "", ConnectionType::INBOUND); + CNode dummyNode(id++, NODE_NETWORK, INVALID_SOCKET, addr, /* nKeyedNetGroupIn */ 4, /* nLocalHostNonceIn */ 4, CAddress(), /* pszDest */ "", ConnectionType::INBOUND, /* inbound_onion */ false); dummyNode.SetCommonVersion(PROTOCOL_VERSION); peerLogic->InitializeNode(&dummyNode); dummyNode.fSuccessfullyConnected = true; diff --git a/src/test/descriptor_tests.cpp b/src/test/descriptor_tests.cpp index 2324610754..8d0c812b6f 100644 --- a/src/test/descriptor_tests.cpp +++ b/src/test/descriptor_tests.cpp @@ -24,7 +24,7 @@ void CheckUnparsable(const std::string& prv, const std::string& pub, const std:: auto parse_pub = Parse(pub, keys_pub, error); BOOST_CHECK_MESSAGE(!parse_priv, prv); BOOST_CHECK_MESSAGE(!parse_pub, pub); - BOOST_CHECK(error == expected_error); + BOOST_CHECK_EQUAL(error, expected_error); } constexpr int DEFAULT = 0; @@ -115,14 +115,10 @@ void DoCheck(const std::string& prv, const std::string& pub, const std::string& // Check that private can produce the normalized descriptors std::string norm1; - BOOST_CHECK(parse_priv->ToNormalizedString(keys_priv, norm1, false)); + BOOST_CHECK(parse_priv->ToNormalizedString(keys_priv, norm1)); BOOST_CHECK(EqualDescriptor(norm1, norm_pub)); - BOOST_CHECK(parse_pub->ToNormalizedString(keys_priv, norm1, false)); + BOOST_CHECK(parse_pub->ToNormalizedString(keys_priv, norm1)); BOOST_CHECK(EqualDescriptor(norm1, norm_pub)); - BOOST_CHECK(parse_priv->ToNormalizedString(keys_priv, norm1, true)); - BOOST_CHECK(EqualDescriptor(norm1, norm_prv)); - BOOST_CHECK(parse_pub->ToNormalizedString(keys_priv, norm1, true)); - BOOST_CHECK(EqualDescriptor(norm1, norm_prv)); // Check whether IsRange on both returns the expected result BOOST_CHECK_EQUAL(parse_pub->IsRange(), (flags & RANGE) != 0); @@ -335,9 +331,8 @@ BOOST_AUTO_TEST_CASE(descriptor_test) // Check for invalid nesting of structures CheckUnparsable("sh(XJvEUEcFWCHCyruc8ZX5exPZaGe4UR7gC5FHrhwPnQGDs1uWCsT2)", "sh(03a34b99f22c790c4e36b2b3c2c35a36db06226e41c692fc82b8b56ac1c540c5bd)", "A function is needed within P2SH"); // P2SH needs a script, not a key - CheckUnparsable("sh(sh(pk(XJvEUEcFWCHCyruc8ZX5exPZaGe4UR7gC5FHrhwPnQGDs1uWCsT2)))", "sh(sh(pk(03a34b99f22c790c4e36b2b3c2c35a36db06226e41c692fc82b8b56ac1c540c5bd)))", "Cannot have sh in non-top level"); // Cannot embed P2SH inside P2SH - CheckUnparsable("sh(combo(XJvEUEcFWCHCyruc8ZX5exPZaGe4UR7gC5FHrhwPnQGDs1uWCsT2))", "sh(combo(03a34b99f22c790c4e36b2b3c2c35a36db06226e41c692fc82b8b56ac1c540c5bd))", "Cannot have combo in non-top level"); // Old must be top level - + CheckUnparsable("sh(sh(pk(XJvEUEcFWCHCyruc8ZX5exPZaGe4UR7gC5FHrhwPnQGDs1uWCsT2)))", "sh(sh(pk(03a34b99f22c790c4e36b2b3c2c35a36db06226e41c692fc82b8b56ac1c540c5bd)))", "Can only have sh() at top level"); // Cannot embed P2SH inside P2SH + CheckUnparsable("sh(combo(XJvEUEcFWCHCyruc8ZX5exPZaGe4UR7gC5FHrhwPnQGDs1uWCsT2))", "sh(combo(03a34b99f22c790c4e36b2b3c2c35a36db06226e41c692fc82b8b56ac1c540c5bd))", "Can only have combo() at top level"); // Old must be top level // Checksums Check("sh(multi(2,[00000000/111'/222]xprvA1RpRA33e1JQ7ifknakTFpgNXPmW2YvmhqLQYMmrj4xJXXWYpDPS3xz7iAxn8L39njGVyuoseXzU6rcxFLJ8HFsTjSyQbLYnMpCqE2VbFWc,xprv9uPDJpEQgRQfDcW7BkF7eTya6RPxXeJCqCJGHuCJ4GiRVLzkTXBAJMu2qaMWPrS7AANYqdq6vcBcBUdJCVVFceUvJFjaPdGZ2y9WACViL4L/0))#ggrsrxfy", "sh(multi(2,[00000000/111'/222]xpub6ERApfZwUNrhLCkDtcHTcxd75RbzS1ed54G1LkBUHQVHQKqhMkhgbmJbZRkrgZw4koxb5JaHWkY4ALHY2grBGRjaDMzQLcgJvLJuZZvRcEL,xpub68NZiKmJWnxxS6aaHmn81bvJeTESw724CRDs6HbuccFQN9Ku14VQrADWgqbhhTHBaohPX4CjNLf9fq9MYo6oDaPPLPxSb7gwQN3ih19Zm4Y/0))#tjg09x5t", "sh(multi(2,[00000000/111'/222]xprvA1RpRA33e1JQ7ifknakTFpgNXPmW2YvmhqLQYMmrj4xJXXWYpDPS3xz7iAxn8L39njGVyuoseXzU6rcxFLJ8HFsTjSyQbLYnMpCqE2VbFWc,xprv9uPDJpEQgRQfDcW7BkF7eTya6RPxXeJCqCJGHuCJ4GiRVLzkTXBAJMu2qaMWPrS7AANYqdq6vcBcBUdJCVVFceUvJFjaPdGZ2y9WACViL4L/0))#ggrsrxfy", "sh(multi(2,[00000000/111'/222]xpub6ERApfZwUNrhLCkDtcHTcxd75RbzS1ed54G1LkBUHQVHQKqhMkhgbmJbZRkrgZw4koxb5JaHWkY4ALHY2grBGRjaDMzQLcgJvLJuZZvRcEL,xpub68NZiKmJWnxxS6aaHmn81bvJeTESw724CRDs6HbuccFQN9Ku14VQrADWgqbhhTHBaohPX4CjNLf9fq9MYo6oDaPPLPxSb7gwQN3ih19Zm4Y/0))#tjg09x5t", DEFAULT, {{"a91445a9a622a8b0a1269944be477640eedc447bbd8487"}}, OutputType::LEGACY, {{0x8000006FUL,222},{0}}); Check("sh(multi(2,[00000000/111'/222]xprvA1RpRA33e1JQ7ifknakTFpgNXPmW2YvmhqLQYMmrj4xJXXWYpDPS3xz7iAxn8L39njGVyuoseXzU6rcxFLJ8HFsTjSyQbLYnMpCqE2VbFWc,xprv9uPDJpEQgRQfDcW7BkF7eTya6RPxXeJCqCJGHuCJ4GiRVLzkTXBAJMu2qaMWPrS7AANYqdq6vcBcBUdJCVVFceUvJFjaPdGZ2y9WACViL4L/0))", "sh(multi(2,[00000000/111'/222]xpub6ERApfZwUNrhLCkDtcHTcxd75RbzS1ed54G1LkBUHQVHQKqhMkhgbmJbZRkrgZw4koxb5JaHWkY4ALHY2grBGRjaDMzQLcgJvLJuZZvRcEL,xpub68NZiKmJWnxxS6aaHmn81bvJeTESw724CRDs6HbuccFQN9Ku14VQrADWgqbhhTHBaohPX4CjNLf9fq9MYo6oDaPPLPxSb7gwQN3ih19Zm4Y/0))", "sh(multi(2,[00000000/111'/222]xprvA1RpRA33e1JQ7ifknakTFpgNXPmW2YvmhqLQYMmrj4xJXXWYpDPS3xz7iAxn8L39njGVyuoseXzU6rcxFLJ8HFsTjSyQbLYnMpCqE2VbFWc,xprv9uPDJpEQgRQfDcW7BkF7eTya6RPxXeJCqCJGHuCJ4GiRVLzkTXBAJMu2qaMWPrS7AANYqdq6vcBcBUdJCVVFceUvJFjaPdGZ2y9WACViL4L/0))", "sh(multi(2,[00000000/111'/222]xpub6ERApfZwUNrhLCkDtcHTcxd75RbzS1ed54G1LkBUHQVHQKqhMkhgbmJbZRkrgZw4koxb5JaHWkY4ALHY2grBGRjaDMzQLcgJvLJuZZvRcEL,xpub68NZiKmJWnxxS6aaHmn81bvJeTESw724CRDs6HbuccFQN9Ku14VQrADWgqbhhTHBaohPX4CjNLf9fq9MYo6oDaPPLPxSb7gwQN3ih19Zm4Y/0))", DEFAULT, {{"a91445a9a622a8b0a1269944be477640eedc447bbd8487"}}, OutputType::LEGACY, {{0x8000006FUL,222},{0}}); diff --git a/src/test/fuzz/net.cpp b/src/test/fuzz/net.cpp index 84d6834af8..4d18a81754 100644 --- a/src/test/fuzz/net.cpp +++ b/src/test/fuzz/net.cpp @@ -40,9 +40,6 @@ FUZZ_TARGET_INIT(net, initialize_net) CConnman connman{fuzzed_data_provider.ConsumeIntegral(), fuzzed_data_provider.ConsumeIntegral(), addrman}; node.CloseSocketDisconnect(&connman); }, - [&] { - node.MaybeSetAddrName(fuzzed_data_provider.ConsumeRandomLengthString(32)); - }, [&] { const std::vector asmap = ConsumeRandomLengthBitVector(fuzzed_data_provider); if (!SanityCheckASMap(asmap)) { @@ -75,7 +72,6 @@ FUZZ_TARGET_INIT(net, initialize_net) } (void)node.GetAddrLocal(); - (void)node.GetAddrName(); (void)node.GetId(); (void)node.GetLocalNonce(); (void)node.GetLocalServices(); diff --git a/src/test/net_tests.cpp b/src/test/net_tests.cpp index 68c266694a..29b714c584 100644 --- a/src/test/net_tests.cpp +++ b/src/test/net_tests.cpp @@ -194,14 +194,15 @@ BOOST_AUTO_TEST_CASE(cnode_simple_test) id++, NODE_NETWORK, hSocket, addr, /* nKeyedNetGroupIn = */ 0, /* nLocalHostNonceIn = */ 0, - CAddress(), pszDest, ConnectionType::OUTBOUND_FULL_RELAY); + CAddress(), pszDest, ConnectionType::OUTBOUND_FULL_RELAY, + /* inbound_onion = */ false); BOOST_CHECK(pnode1->IsFullOutboundConn() == true); BOOST_CHECK(pnode1->IsManualConn() == false); BOOST_CHECK(pnode1->IsBlockOnlyConn() == false); BOOST_CHECK(pnode1->IsFeelerConn() == false); BOOST_CHECK(pnode1->IsAddrFetchConn() == false); BOOST_CHECK(pnode1->IsInboundConn() == false); - BOOST_CHECK(pnode1->IsInboundOnion() == false); + BOOST_CHECK(pnode1->m_inbound_onion == false); BOOST_CHECK_EQUAL(pnode1->ConnectedThroughNetwork(), Network::NET_IPV4); std::unique_ptr pnode2 = std::make_unique( @@ -216,7 +217,7 @@ BOOST_AUTO_TEST_CASE(cnode_simple_test) BOOST_CHECK(pnode2->IsFeelerConn() == false); BOOST_CHECK(pnode2->IsAddrFetchConn() == false); BOOST_CHECK(pnode2->IsInboundConn() == true); - BOOST_CHECK(pnode2->IsInboundOnion() == false); + BOOST_CHECK(pnode2->m_inbound_onion == false); BOOST_CHECK_EQUAL(pnode2->ConnectedThroughNetwork(), Network::NET_IPV4); std::unique_ptr pnode3 = std::make_unique( @@ -231,7 +232,7 @@ BOOST_AUTO_TEST_CASE(cnode_simple_test) BOOST_CHECK(pnode3->IsFeelerConn() == false); BOOST_CHECK(pnode3->IsAddrFetchConn() == false); BOOST_CHECK(pnode3->IsInboundConn() == false); - BOOST_CHECK(pnode3->IsInboundOnion() == false); + BOOST_CHECK(pnode3->m_inbound_onion == false); BOOST_CHECK_EQUAL(pnode3->ConnectedThroughNetwork(), Network::NET_IPV4); std::unique_ptr pnode4 = std::make_unique( @@ -246,7 +247,7 @@ BOOST_AUTO_TEST_CASE(cnode_simple_test) BOOST_CHECK(pnode4->IsFeelerConn() == false); BOOST_CHECK(pnode4->IsAddrFetchConn() == false); BOOST_CHECK(pnode4->IsInboundConn() == true); - BOOST_CHECK(pnode4->IsInboundOnion() == true); + BOOST_CHECK(pnode4->m_inbound_onion == true); BOOST_CHECK_EQUAL(pnode4->ConnectedThroughNetwork(), Network::NET_ONION); } @@ -740,7 +741,7 @@ BOOST_AUTO_TEST_CASE(ipv4_peer_with_ipv6_addrMe_test) in_addr ipv4AddrPeer; ipv4AddrPeer.s_addr = 0xa0b0c001; CAddress addr = CAddress(CService(ipv4AddrPeer, 7777), NODE_NETWORK); - std::unique_ptr pnode = std::make_unique(0, NODE_NETWORK, INVALID_SOCKET, addr, 0, 0, CAddress{}, std::string{}, ConnectionType::OUTBOUND_FULL_RELAY); + std::unique_ptr pnode = std::make_unique(0, NODE_NETWORK, INVALID_SOCKET, addr, /* nKeyedNetGroupIn */ 0, /* nLocalHostNonceIn */ 0, CAddress{}, /* pszDest */ std::string{}, ConnectionType::OUTBOUND_FULL_RELAY, /* inbound_onion */ false); pnode->fSuccessfullyConnected.store(true); // the peer claims to be reaching us via IPv6 diff --git a/src/test/util/net.h b/src/test/util/net.h index b27364fee8..26a67a3e9b 100644 --- a/src/test/util/net.h +++ b/src/test/util/net.h @@ -23,16 +23,16 @@ struct ConnmanTestMsg : public CConnman { void AddTestNode(CNode& node) { - LOCK(cs_vNodes); - vNodes.push_back(&node); + LOCK(m_nodes_mutex); + m_nodes.push_back(&node); } void ClearTestNodes() { - LOCK(cs_vNodes); - for (CNode* node : vNodes) { + LOCK(m_nodes_mutex); + for (CNode* node : m_nodes) { delete node; } - vNodes.clear(); + m_nodes.clear(); } void ProcessMessagesOnce(CNode& node) { m_msgproc->ProcessMessages(&node, flagInterruptMsgProc); } diff --git a/src/threadinterrupt.h b/src/threadinterrupt.h index c60ed3b336..c053c01499 100644 --- a/src/threadinterrupt.h +++ b/src/threadinterrupt.h @@ -22,7 +22,7 @@ class CThreadInterrupt using Clock = std::chrono::steady_clock; CThreadInterrupt(); explicit operator bool() const; - void operator()(); + void operator()() EXCLUSIVE_LOCKS_REQUIRED(!mut); void reset(); bool sleep_for(Clock::duration rel_time) EXCLUSIVE_LOCKS_REQUIRED(!mut); diff --git a/src/validationinterface.cpp b/src/validationinterface.cpp index cbee7d7cb8..ac8b19dcdc 100644 --- a/src/validationinterface.cpp +++ b/src/validationinterface.cpp @@ -47,7 +47,7 @@ struct MainSignalsInstance { explicit MainSignalsInstance(CScheduler& scheduler LIFETIMEBOUND) : m_schedulerClient(scheduler) {} - void Register(std::shared_ptr callbacks) + void Register(std::shared_ptr callbacks) EXCLUSIVE_LOCKS_REQUIRED(!m_mutex) { LOCK(m_mutex); auto inserted = m_map.emplace(callbacks.get(), m_list.end()); @@ -55,7 +55,7 @@ struct MainSignalsInstance { inserted.first->second->callbacks = std::move(callbacks); } - void Unregister(CValidationInterface* callbacks) + void Unregister(CValidationInterface* callbacks) EXCLUSIVE_LOCKS_REQUIRED(!m_mutex) { LOCK(m_mutex); auto it = m_map.find(callbacks); @@ -69,7 +69,7 @@ struct MainSignalsInstance { //! map entry. After this call, the list may still contain callbacks that //! are currently executing, but it will be cleared when they are done //! executing. - void Clear() + void Clear() EXCLUSIVE_LOCKS_REQUIRED(!m_mutex) { LOCK(m_mutex); for (const auto& entry : m_map) { @@ -78,7 +78,7 @@ struct MainSignalsInstance { m_map.clear(); } - template void Iterate(F&& f) + template void Iterate(F&& f) EXCLUSIVE_LOCKS_REQUIRED(!m_mutex) { WAIT_LOCK(m_mutex, lock); for (auto it = m_list.begin(); it != m_list.end();) { diff --git a/src/versionbits.h b/src/versionbits.h index c02769809e..a02fc8cce5 100644 --- a/src/versionbits.h +++ b/src/versionbits.h @@ -90,16 +90,16 @@ class VersionBitsCache static uint32_t Mask(const Consensus::Params& params, Consensus::DeploymentPos pos); /** Get the BIP9 state for a given deployment for the block after pindexPrev. */ - ThresholdState State(const CBlockIndex* pindexPrev, const Consensus::Params& params, Consensus::DeploymentPos pos); + ThresholdState State(const CBlockIndex* pindexPrev, const Consensus::Params& params, Consensus::DeploymentPos pos) EXCLUSIVE_LOCKS_REQUIRED(!m_mutex); /** Get the block height at which the BIP9 deployment switched into the state for the block after pindexPrev. */ - int StateSinceHeight(const CBlockIndex* pindexPrev, const Consensus::Params& params, Consensus::DeploymentPos pos); + int StateSinceHeight(const CBlockIndex* pindexPrev, const Consensus::Params& params, Consensus::DeploymentPos pos) EXCLUSIVE_LOCKS_REQUIRED(!m_mutex); /** Determine what nVersion a new block should use */ - int32_t ComputeBlockVersion(const CBlockIndex* pindexPrev, const Consensus::Params& params); + int32_t ComputeBlockVersion(const CBlockIndex* pindexPrev, const Consensus::Params& params) EXCLUSIVE_LOCKS_REQUIRED(!m_mutex); - void Clear(); + void Clear() EXCLUSIVE_LOCKS_REQUIRED(!m_mutex); }; class AbstractEHFManager diff --git a/src/wallet/rpcdump.cpp b/src/wallet/rpcdump.cpp index 230a6fb2c7..55488c7080 100644 --- a/src/wallet/rpcdump.cpp +++ b/src/wallet/rpcdump.cpp @@ -1766,7 +1766,7 @@ static UniValue ProcessDescriptorImport(CWallet * const pwallet, const UniValue& if (!w_desc.descriptor->GetOutputType()) { warnings.push_back("Unknown output type, cannot set descriptor to active."); } else { - pwallet->SetActiveScriptPubKeyMan(spk_manager->GetID(), internal); + pwallet->AddActiveScriptPubKeyMan(spk_manager->GetID(), internal); } } @@ -1972,8 +1972,6 @@ RPCHelpMan listdescriptors() throw JSONRPCError(RPC_WALLET_ERROR, "listdescriptors is not available for non-descriptor wallets"); } - EnsureWalletIsUnlocked(wallet.get()); - LOCK(wallet->cs_wallet); UniValue descriptors(UniValue::VARR); @@ -1987,7 +1985,7 @@ RPCHelpMan listdescriptors() LOCK(desc_spk_man->cs_desc_man); const auto& wallet_descriptor = desc_spk_man->GetWalletDescriptor(); std::string descriptor; - if (!desc_spk_man->GetDescriptorString(descriptor, false)) { + if (!desc_spk_man->GetDescriptorString(descriptor)) { throw JSONRPCError(RPC_WALLET_ERROR, "Can't get normalized descriptor string."); } spk.pushKV("desc", descriptor); diff --git a/src/wallet/rpcwallet.cpp b/src/wallet/rpcwallet.cpp index fec1bb8e4e..6b3986ef3f 100644 --- a/src/wallet/rpcwallet.cpp +++ b/src/wallet/rpcwallet.cpp @@ -83,14 +83,12 @@ static bool ParseIncludeWatchonly(const UniValue& include_watchonly, const CWall /** Checks if a CKey is in the given CWallet compressed or otherwise*/ -/* bool HaveKey(const SigningProvider& wallet, const CKey& key) { CKey key2; key2.Set(key.begin(), key.end(), !key.IsCompressed()); return wallet.HaveKey(key.GetPubKey().GetID()) || wallet.HaveKey(key2.GetPubKey().GetID()); } -*/ bool GetWalletNameFromJSONRPCRequest(const JSONRPCRequest& request, std::string& wallet_name) { @@ -2976,7 +2974,7 @@ static RPCHelpMan createwallet() { {"wallet_name", RPCArg::Type::STR, RPCArg::Optional::NO, "The name for the new wallet. If this is a path, the wallet will be created at the path location."}, {"disable_private_keys", RPCArg::Type::BOOL, /* default */ "false", "Disable the possibility of private keys (only watchonlys are possible in this mode)."}, - {"blank", RPCArg::Type::BOOL, /* default */ "false", "Create a blank wallet. A blank wallet has no keys or HD seed. One can be set using upgradetohd."}, + {"blank", RPCArg::Type::BOOL, /* default */ "false", "Create a blank wallet. A blank wallet has no keys or HD seed. One can be set using upgradetohd (by mnemonic) or sethdseed (WIF private key)."}, {"passphrase", RPCArg::Type::STR, RPCArg::Optional::OMITTED, "Encrypt the wallet with this passphrase."}, {"avoid_reuse", RPCArg::Type::BOOL, /* default */ "false", "Keep track of coin reuse, and treat dirty and clean coins differently with privacy considerations in mind."}, {"descriptors", RPCArg::Type::BOOL, /* default */ "false", "Create a native descriptor wallet. The wallet will use descriptors internally to handle address creation"}, @@ -3517,7 +3515,7 @@ static RPCHelpMan fundrawtransaction() CAmount fee; int change_position; CCoinControl coin_control; - // Automatically select (additional) coins. Can be overriden by options.add_inputs. + // Automatically select (additional) coins. Can be overridden by options.add_inputs. coin_control.m_add_inputs = true; FundTransaction(pwallet, tx, fee, change_position, request.params[1], coin_control); @@ -3938,7 +3936,7 @@ RPCHelpMan getaddressinfo() DescriptorScriptPubKeyMan* desc_spk_man = dynamic_cast(pwallet->GetScriptPubKeyMan(scriptPubKey)); if (desc_spk_man) { std::string desc_str; - if (desc_spk_man->GetDescriptorString(desc_str, false)) { + if (desc_spk_man->GetDescriptorString(desc_str)) { ret.pushKV("parent_desc", desc_str); } } @@ -4272,6 +4270,85 @@ static RPCHelpMan send() }; } +static RPCHelpMan sethdseed() +{ + return RPCHelpMan{"sethdseed", + "\nSet or generate a new HD wallet seed. Non-HD wallets will not be upgraded to being a HD wallet. Wallets that are already\n" + "HD can not be updated to a new HD seed.\n" + "\nNote that you will need to MAKE A NEW BACKUP of your wallet after setting the HD wallet seed." + + HELP_REQUIRING_PASSPHRASE, + { + {"newkeypool", RPCArg::Type::BOOL, /* default */ "true", "Whether to flush old unused addresses, including change addresses, from the keypool and regenerate it.\n" + "If true, the next address from getnewaddress and change address from getrawchangeaddress will be from this new seed.\n" + "If false, addresses from the existing keypool will be used until it has been depleted."}, + {"seed", RPCArg::Type::STR, /* default */ "random seed", "The WIF private key to use as the new HD seed.\n" + "The seed value can be retrieved using the dumpwallet command. It is the private key marked hdseed=1"}, + }, + RPCResult{RPCResult::Type::NONE, "", ""}, + RPCExamples{ + HelpExampleCli("sethdseed", "") + + HelpExampleCli("sethdseed", "false") + + HelpExampleCli("sethdseed", "true \"wifkey\"") + + HelpExampleRpc("sethdseed", "true, \"wifkey\"") + }, + [&](const RPCHelpMan& self, const JSONRPCRequest& request) -> UniValue +{ + // TODO: add mnemonic feature to sethdseed or remove it in favour of upgradetohd + std::shared_ptr const wallet = GetWalletForJSONRPCRequest(request); + if (!wallet) return NullUniValue; + CWallet* const pwallet = wallet.get(); + + LegacyScriptPubKeyMan& spk_man = EnsureLegacyScriptPubKeyMan(*pwallet, true); + + if (pwallet->IsWalletFlagSet(WALLET_FLAG_DISABLE_PRIVATE_KEYS)) { + throw JSONRPCError(RPC_WALLET_ERROR, "Cannot set a HD seed to a wallet with private keys disabled"); + } + + LOCK2(pwallet->cs_wallet, spk_man.cs_KeyStore); + + // Do not do anything to non-HD wallets + if (!pwallet->CanSupportFeature(FEATURE_HD)) { + throw JSONRPCError(RPC_WALLET_ERROR, "Cannot set a HD seed on a non-HD wallet. Use the upgradewallet RPC in order to upgrade a non-HD wallet to HD"); + } + if (pwallet->IsHDEnabled()) { + throw JSONRPCError(RPC_WALLET_ERROR, "Cannot set a HD seed. The wallet already has a seed"); + } + + EnsureWalletIsUnlocked(pwallet); + + bool flush_key_pool = true; + if (!request.params[0].isNull()) { + flush_key_pool = request.params[0].get_bool(); + } + + if (request.params[1].isNull()) { + spk_man.GenerateNewHDChain("", ""); + } else { + CKey key = DecodeSecret(request.params[1].get_str()); + if (!key.IsValid()) { + throw JSONRPCError(RPC_INVALID_ADDRESS_OR_KEY, "Invalid private key"); + } + if (HaveKey(spk_man, key)) { + throw JSONRPCError(RPC_INVALID_ADDRESS_OR_KEY, "Already have this key (either as an HD seed or as a loose private key)"); + } + CHDChain newHdChain; + if (!newHdChain.SetSeed(SecureVector(key.begin(), key.end()), true)) { + throw JSONRPCError(RPC_INVALID_ADDRESS_OR_KEY, "Invalid private key: SetSeed failed"); + } + if (!spk_man.AddHDChainSingle(newHdChain)) { + throw JSONRPCError(RPC_INVALID_ADDRESS_OR_KEY, "Invalid private key: AddHDChainSingle failed"); + } + // add default account + newHdChain.AddAccount(); + } + + if (flush_key_pool) spk_man.NewKeyPool(); + + return NullUniValue; +}, + }; +} + RPCHelpMan walletprocesspsbt() { return RPCHelpMan{"walletprocesspsbt", @@ -4432,7 +4509,7 @@ static RPCHelpMan walletcreatefundedpsbt() CMutableTransaction rawTx = ConstructTransaction(request.params[0], request.params[1], request.params[2]); CCoinControl coin_control; // Automatically select coins, unless at least one is manually selected. Can - // be overriden by options.add_inputs. + // be overridden by options.add_inputs. coin_control.m_add_inputs = rawTx.vin.size() == 0; FundTransaction(pwallet, rawTx, fee, change_position, request.params[3], coin_control); @@ -4463,14 +4540,18 @@ static RPCHelpMan walletcreatefundedpsbt() static RPCHelpMan upgradewallet() { return RPCHelpMan{"upgradewallet", - "\nUpgrade the wallet. Upgrades to the latest version if no version number is specified\n" + "\nUpgrade the wallet. Upgrades to the latest version if no version number is specified.\n" "New keys may be generated and a new wallet backup will need to be made.", { - {"version", RPCArg::Type::NUM, /* default */ strprintf("%d", FEATURE_LATEST), "The version number to upgrade to. Default is the latest wallet version"} + {"version", RPCArg::Type::NUM, /* default */ strprintf("%d", FEATURE_LATEST), "The version number to upgrade to. Default is the latest wallet version."} }, RPCResult{ RPCResult::Type::OBJ, "", "", { + {RPCResult::Type::STR, "wallet_name", "Name of wallet this operation was performed on"}, + {RPCResult::Type::NUM, "previous_version", "Version of wallet before this operation"}, + {RPCResult::Type::NUM, "current_version", "Version of wallet after this operation"}, + {RPCResult::Type::STR, "result", /* optional */ true, "Description of result, if no error"}, {RPCResult::Type::STR, "error", /* optional */ true, "Error message (if there is one)"} }, }, @@ -4493,11 +4574,27 @@ static RPCHelpMan upgradewallet() version = request.params[0].get_int(); } bilingual_str error; - if (!pwallet->UpgradeWallet(version, error)) { - throw JSONRPCError(RPC_WALLET_ERROR, error.original); + const int previous_version{pwallet->GetVersion()}; + const bool wallet_upgraded{pwallet->UpgradeWallet(version, error)}; + const int current_version{pwallet->GetVersion()}; + std::string result; + + if (wallet_upgraded) { + if (previous_version == current_version) { + result = "Already at latest version. Wallet version unchanged."; + } else { + result = strprintf("Wallet upgraded successfully from version %i to version %i.", previous_version, current_version); + } } + UniValue obj(UniValue::VOBJ); - if (!error.empty()) { + obj.pushKV("wallet_name", pwallet->GetName()); + obj.pushKV("previous_version", previous_version); + obj.pushKV("current_version", current_version); + if (!result.empty()) { + obj.pushKV("result", result); + } else { + CHECK_NONFATAL(!error.empty()); obj.pushKV("error", error.original); } return obj; @@ -4576,6 +4673,7 @@ static const CRPCCommand commands[] = { "wallet", "send", &send, {"outputs","conf_target","estimate_mode","options"} }, { "wallet", "sendmany", &sendmany, {"dummy","amounts","minconf","addlocked","comment","subtractfeefrom","use_is","use_cj","conf_target","estimate_mode"} }, { "wallet", "sendtoaddress", &sendtoaddress, {"address","amount","comment","comment_to","subtractfeefromamount","use_is","use_cj","conf_target","estimate_mode", "avoid_reuse"} }, + { "wallet", "sethdseed", &sethdseed, {"newkeypool","seed"} }, { "wallet", "setcoinjoinrounds", &setcoinjoinrounds, {"rounds"} }, { "wallet", "setcoinjoinamount", &setcoinjoinamount, {"amount"} }, { "wallet", "setlabel", &setlabel, {"address","label"} }, diff --git a/src/wallet/scriptpubkeyman.cpp b/src/wallet/scriptpubkeyman.cpp index fdf019d183..45c03055f1 100644 --- a/src/wallet/scriptpubkeyman.cpp +++ b/src/wallet/scriptpubkeyman.cpp @@ -218,14 +218,14 @@ bool LegacyScriptPubKeyMan::CheckDecryptionKey(const CKeyingMaterial& master_key if (keyFail) { return false; } - if (!keyPass && !accept_no_keys && (hdChain.IsNull() || !hdChain.IsNull() && !hdChain.IsCrypted())) { + if (!keyPass && !accept_no_keys && (m_hd_chain.IsNull() || !m_hd_chain.IsNull() && !m_hd_chain.IsCrypted())) { return false; } - if(!hdChain.IsNull() && !hdChain.IsCrypted()) { + if(!m_hd_chain.IsNull() && !m_hd_chain.IsCrypted()) { // try to decrypt seed and make sure it matches CHDChain hdChainTmp; - if (!DecryptHDChain(master_key, hdChainTmp) || (hdChain.GetID() != hdChainTmp.GetSeedHash())) { + if (!DecryptHDChain(master_key, hdChainTmp) || (m_hd_chain.GetID() != hdChainTmp.GetSeedHash())) { return false; } } @@ -267,8 +267,8 @@ bool LegacyScriptPubKeyMan::Encrypt(const CKeyingMaterial& master_key, WalletBat } if (!hdChainCurrent.IsNull()) { - assert(EncryptHDChain(master_key, hdChain)); - assert(SetHDChain(hdChain)); + assert(EncryptHDChain(master_key, m_hd_chain)); + assert(LoadHDChain(m_hd_chain)); CHDChain hdChainCrypted; assert(GetHDChain(hdChainCrypted)); @@ -277,7 +277,7 @@ bool LegacyScriptPubKeyMan::Encrypt(const CKeyingMaterial& master_key, WalletBat assert(hdChainCurrent.GetID() == hdChainCrypted.GetID()); assert(hdChainCurrent.GetSeedHash() != hdChainCrypted.GetSeedHash()); - assert(SetHDChain(*encrypted_batch, hdChainCrypted, false)); + assert(AddHDChain(*encrypted_batch, hdChainCrypted)); } encrypted_batch = nullptr; @@ -396,7 +396,7 @@ void LegacyScriptPubKeyMan::GenerateNewCryptedHDChain(const SecureString& secure CHDChain hdChainPrev = hdChainTmp; bool res = EncryptHDChain(vMasterKey, hdChainTmp); assert(res); - res = SetHDChain(hdChainTmp); + res = LoadHDChain(hdChainTmp); assert(res); CHDChain hdChainCrypted; @@ -407,8 +407,8 @@ void LegacyScriptPubKeyMan::GenerateNewCryptedHDChain(const SecureString& secure assert(hdChainPrev.GetID() == hdChainCrypted.GetID()); assert(hdChainPrev.GetSeedHash() != hdChainCrypted.GetSeedHash()); - if (!SetHDChainSingle(hdChainCrypted, false)) { - throw std::runtime_error(std::string(__func__) + ": SetHDChainSingle failed"); + if (!AddHDChainSingle(hdChainCrypted)) { + throw std::runtime_error(std::string(__func__) + ": AddHDChainSingle failed"); } } @@ -426,8 +426,8 @@ void LegacyScriptPubKeyMan::GenerateNewHDChain(const SecureString& secureMnemoni // add default account newHdChain.AddAccount(); - if (!SetHDChainSingle(newHdChain, false)) { - throw std::runtime_error(std::string(__func__) + ": SetHDChainSingle failed"); + if (!AddHDChainSingle(newHdChain)) { + throw std::runtime_error(std::string(__func__) + ": AddHDChainSingle failed"); } if (!NewKeyPool()) { @@ -435,14 +435,24 @@ void LegacyScriptPubKeyMan::GenerateNewHDChain(const SecureString& secureMnemoni } } -bool LegacyScriptPubKeyMan::SetHDChain(WalletBatch &batch, const CHDChain& chain, bool memonly) +bool LegacyScriptPubKeyMan::LoadHDChain(const CHDChain& chain) { LOCK(cs_KeyStore); - if (!SetHDChain(chain)) + if (m_storage.HasEncryptionKeys() != chain.IsCrypted()) return false; + + m_hd_chain = chain; + return true; +} + +bool LegacyScriptPubKeyMan::AddHDChain(WalletBatch &batch, const CHDChain& chain) +{ + LOCK(cs_KeyStore); + + if (!LoadHDChain(chain)) return false; - if (!memonly) { + { if (chain.IsCrypted() && encrypted_batch) { if (!encrypted_batch->WriteHDChain(chain)) throw std::runtime_error(std::string(__func__) + ": WriteHDChain failed for encrypted batch"); @@ -458,10 +468,10 @@ bool LegacyScriptPubKeyMan::SetHDChain(WalletBatch &batch, const CHDChain& chain return true; } -bool LegacyScriptPubKeyMan::SetHDChainSingle(const CHDChain& chain, bool memonly) +bool LegacyScriptPubKeyMan::AddHDChainSingle(const CHDChain& chain) { WalletBatch batch(m_storage.GetDatabase()); - return SetHDChain(batch, chain, memonly); + return AddHDChain(batch, chain); } bool LegacyScriptPubKeyMan::GetDecryptedHDChain(CHDChain& hdChainRet) @@ -539,40 +549,40 @@ bool LegacyScriptPubKeyMan::DecryptHDChain(const CKeyingMaterial& vMasterKeyIn, if (!m_storage.HasEncryptionKeys()) return true; - if (hdChain.IsNull()) + if (m_hd_chain.IsNull()) return false; - if (!hdChain.IsCrypted()) + if (!m_hd_chain.IsCrypted()) return false; SecureVector vchSecureSeed; - SecureVector vchSecureCryptedSeed = hdChain.GetSeed(); + SecureVector vchSecureCryptedSeed = m_hd_chain.GetSeed(); std::vector vchCryptedSeed(vchSecureCryptedSeed.begin(), vchSecureCryptedSeed.end()); - if (!DecryptSecret(vMasterKeyIn, vchCryptedSeed, hdChain.GetID(), vchSecureSeed)) + if (!DecryptSecret(vMasterKeyIn, vchCryptedSeed, m_hd_chain.GetID(), vchSecureSeed)) return false; - hdChainRet = hdChain; + hdChainRet = m_hd_chain; if (!hdChainRet.SetSeed(vchSecureSeed, false)) return false; // hash of decrypted seed must match chain id - if (hdChainRet.GetSeedHash() != hdChain.GetID()) + if (hdChainRet.GetSeedHash() != m_hd_chain.GetID()) return false; SecureVector vchSecureCryptedMnemonic; SecureVector vchSecureCryptedMnemonicPassphrase; // it's ok to have no mnemonic if wallet was initialized via hdseed - if (hdChain.GetMnemonic(vchSecureCryptedMnemonic, vchSecureCryptedMnemonicPassphrase)) { + if (m_hd_chain.GetMnemonic(vchSecureCryptedMnemonic, vchSecureCryptedMnemonicPassphrase)) { SecureVector vchSecureMnemonic; SecureVector vchSecureMnemonicPassphrase; std::vector vchCryptedMnemonic(vchSecureCryptedMnemonic.begin(), vchSecureCryptedMnemonic.end()); std::vector vchCryptedMnemonicPassphrase(vchSecureCryptedMnemonicPassphrase.begin(), vchSecureCryptedMnemonicPassphrase.end()); - if (!vchCryptedMnemonic.empty() && !DecryptSecret(vMasterKeyIn, vchCryptedMnemonic, hdChain.GetID(), vchSecureMnemonic)) + if (!vchCryptedMnemonic.empty() && !DecryptSecret(vMasterKeyIn, vchCryptedMnemonic, m_hd_chain.GetID(), vchSecureMnemonic)) return false; - if (!vchCryptedMnemonicPassphrase.empty() && !DecryptSecret(vMasterKeyIn, vchCryptedMnemonicPassphrase, hdChain.GetID(), vchSecureMnemonicPassphrase)) + if (!vchCryptedMnemonicPassphrase.empty() && !DecryptSecret(vMasterKeyIn, vchCryptedMnemonicPassphrase, m_hd_chain.GetID(), vchSecureMnemonicPassphrase)) return false; if (!hdChainRet.SetMnemonic(vchSecureMnemonic, vchSecureMnemonicPassphrase, false)) @@ -1090,16 +1100,6 @@ bool LegacyScriptPubKeyMan::AddWatchOnly(const CScript& dest, int64_t nCreateTim return AddWatchOnly(dest); } -bool LegacyScriptPubKeyMan::SetHDChain(const CHDChain& chain) -{ - LOCK(cs_KeyStore); - - if (m_storage.HasEncryptionKeys() != chain.IsCrypted()) return false; - - hdChain = chain; - return true; -} - bool LegacyScriptPubKeyMan::HaveHDKey(const CKeyID &address, CHDChain& hdChainCurrent) const { LOCK(cs_KeyStore); @@ -1322,8 +1322,8 @@ void LegacyScriptPubKeyMan::DeriveNewChildKey(WalletBatch &batch, CKeyMetadata& if (!hdChainCurrent.SetAccount(nAccountIndex, acc)) throw std::runtime_error(std::string(__func__) + ": SetAccount failed"); - if (!SetHDChain(batch, hdChainCurrent, false)) { - throw std::runtime_error(std::string(__func__) + ": SetHDChain failed"); + if (!AddHDChain(batch, hdChainCurrent)) { + throw std::runtime_error(std::string(__func__) + ": AddHDChain failed"); } if (!AddHDPubKey(batch, childKey.Neuter(), fInternal)) @@ -1758,8 +1758,8 @@ std::set LegacyScriptPubKeyMan::GetKeys() const bool LegacyScriptPubKeyMan::GetHDChain(CHDChain& hdChainRet) const { LOCK(cs_KeyStore); - hdChainRet = hdChain; - return !hdChain.IsNull(); + hdChainRet = m_hd_chain; + return !m_hd_chain.IsNull(); } void LegacyScriptPubKeyMan::SetInternal(bool internal) {} @@ -1950,34 +1950,10 @@ bool DescriptorScriptPubKeyMan::TopUp(unsigned int size) } m_map_pubkeys[pubkey] = i; } - // Write the cache - for (const auto& parent_xpub_pair : temp_cache.GetCachedParentExtPubKeys()) { - CExtPubKey xpub; - if (m_wallet_descriptor.cache.GetCachedParentExtPubKey(parent_xpub_pair.first, xpub)) { - if (xpub != parent_xpub_pair.second) { - throw std::runtime_error(std::string(__func__) + ": New cached parent xpub does not match already cached parent xpub"); - } - continue; - } - if (!batch.WriteDescriptorParentCache(parent_xpub_pair.second, id, parent_xpub_pair.first)) { - throw std::runtime_error(std::string(__func__) + ": writing cache item failed"); - } - m_wallet_descriptor.cache.CacheParentExtPubKey(parent_xpub_pair.first, parent_xpub_pair.second); - } - for (const auto& derived_xpub_map_pair : temp_cache.GetCachedDerivedExtPubKeys()) { - for (const auto& derived_xpub_pair : derived_xpub_map_pair.second) { - CExtPubKey xpub; - if (m_wallet_descriptor.cache.GetCachedDerivedExtPubKey(derived_xpub_map_pair.first, derived_xpub_pair.first, xpub)) { - if (xpub != derived_xpub_pair.second) { - throw std::runtime_error(std::string(__func__) + ": New cached derived xpub does not match already cached derived xpub"); - } - continue; - } - if (!batch.WriteDescriptorDerivedCache(derived_xpub_pair.second, id, derived_xpub_map_pair.first, derived_xpub_pair.first)) { - throw std::runtime_error(std::string(__func__) + ": writing cache item failed"); - } - m_wallet_descriptor.cache.CacheDerivedExtPubKey(derived_xpub_map_pair.first, derived_xpub_pair.first, derived_xpub_pair.second); - } + // Merge and write the cache + DescriptorCache new_items = m_wallet_descriptor.cache.MergeAndDiff(temp_cache); + if (!batch.WriteDescriptorCacheItems(id, new_items)) { + throw std::runtime_error(std::string(__func__) + ": writing cache items failed"); } m_max_cached_index++; } @@ -2402,15 +2378,41 @@ const std::vector DescriptorScriptPubKeyMan::GetScriptPubKeys() const return script_pub_keys; } -bool DescriptorScriptPubKeyMan::GetDescriptorString(std::string& out, bool priv) const +bool DescriptorScriptPubKeyMan::GetDescriptorString(std::string& out) const { LOCK(cs_desc_man); - if (m_storage.IsLocked()) { - return false; + + FlatSigningProvider provider; + provider.keys = GetKeys(); + + return m_wallet_descriptor.descriptor->ToNormalizedString(provider, out, &m_wallet_descriptor.cache); +} + +void DescriptorScriptPubKeyMan::UpgradeDescriptorCache() +{ + LOCK(cs_desc_man); + if (m_storage.IsLocked() || m_storage.IsWalletFlagSet(WALLET_FLAG_LAST_HARDENED_XPUB_CACHED)) { + return; + } + + // Skip if we have the last hardened xpub cache + if (m_wallet_descriptor.cache.GetCachedLastHardenedExtPubKeys().size() > 0) { + return; } + // Expand the descriptor FlatSigningProvider provider; provider.keys = GetKeys(); + FlatSigningProvider out_keys; + std::vector scripts_temp; + DescriptorCache temp_cache; + if (!m_wallet_descriptor.descriptor->Expand(0, provider, scripts_temp, out_keys, &temp_cache)){ + throw std::runtime_error("Unable to expand descriptor"); + } - return m_wallet_descriptor.descriptor->ToNormalizedString(provider, out, priv); + // Cache the last hardened xpubs + DescriptorCache diff = m_wallet_descriptor.cache.MergeAndDiff(temp_cache); + if (!WalletBatch(m_storage.GetDatabase()).WriteDescriptorCacheItems(GetID(), diff)) { + throw std::runtime_error(std::string(__func__) + ": writing cache items failed"); + } } diff --git a/src/wallet/scriptpubkeyman.h b/src/wallet/scriptpubkeyman.h index adf19cc8e5..524deb21a2 100644 --- a/src/wallet/scriptpubkeyman.h +++ b/src/wallet/scriptpubkeyman.h @@ -37,7 +37,7 @@ class WalletStorage virtual bool IsWalletFlagSet(uint64_t) const = 0; virtual void UnsetBlankWalletFlag(WalletBatch&) = 0; virtual bool CanSupportFeature(enum WalletFeature) const = 0; - virtual void SetMinVersion(enum WalletFeature, WalletBatch* = nullptr, bool = false) = 0; + virtual void SetMinVersion(enum WalletFeature, WalletBatch* = nullptr) = 0; virtual const CKeyingMaterial& GetEncryptionKey() const = 0; virtual bool HasEncryptionKeys() const = 0; virtual bool IsLocked(bool fForMixing = false) const = 0; @@ -278,15 +278,11 @@ class LegacyScriptPubKeyMan : public ScriptPubKeyMan, public FillableSigningProv /** Add a KeyOriginInfo to the wallet */ bool AddKeyOriginWithDB(WalletBatch& batch, const CPubKey& pubkey, const KeyOriginInfo& info); - /* Set the HD chain model (chain child index counters) */ - bool SetHDChain(WalletBatch &batch, const CHDChain& chain, bool memonly); - bool EncryptHDChain(const CKeyingMaterial& vMasterKeyIn, CHDChain& chain); bool DecryptHDChain(const CKeyingMaterial& vMasterKeyIn, CHDChain& hdChainRet) const; - bool SetHDChain(const CHDChain& chain); /* the HD chain data model (external chain counters) */ - CHDChain hdChain GUARDED_BY(cs_KeyStore); + CHDChain m_hd_chain GUARDED_BY(cs_KeyStore); /* HD derive new child key (on internal or external chain) */ void DeriveNewChildKey(WalletBatch& batch, CKeyMetadata& metadata, CKey& secretRet, uint32_t nAccountIndex, bool fInternal /*= false*/) EXCLUSIVE_LOCKS_REQUIRED(cs_KeyStore); @@ -398,11 +394,15 @@ class LegacyScriptPubKeyMan : public ScriptPubKeyMan, public FillableSigningProv //! Generate a new key CPubKey GenerateNewKey(WalletBatch& batch, uint32_t nAccountIndex, bool fInternal /*= false*/) EXCLUSIVE_LOCKS_REQUIRED(cs_KeyStore); + /* Set the HD chain model (chain child index counters) and writes it to the database */ + bool AddHDChain(WalletBatch &batch, const CHDChain& chain); + //! Load a HD chain model (used by LoadWallet) + bool LoadHDChain(const CHDChain& chain); /** * Set the HD chain model (chain child index counters) using temporary wallet db object * which causes db flush every time these methods are used */ - bool SetHDChainSingle(const CHDChain& chain, bool memonly); + bool AddHDChainSingle(const CHDChain& chain); //! Adds a watch-only address to the store, without saving it to disk (used by LoadWallet) bool LoadWatchOnly(const CScript &dest); @@ -606,7 +606,9 @@ class DescriptorScriptPubKeyMan : public ScriptPubKeyMan const WalletDescriptor GetWalletDescriptor() const EXCLUSIVE_LOCKS_REQUIRED(cs_desc_man); const std::vector GetScriptPubKeys() const; - bool GetDescriptorString(std::string& out, bool priv) const; + bool GetDescriptorString(std::string& out) const; + + void UpgradeDescriptorCache(); }; #endif // BITCOIN_WALLET_SCRIPTPUBKEYMAN_H diff --git a/src/wallet/wallet.cpp b/src/wallet/wallet.cpp index e1f2537561..e2bcbcdabd 100644 --- a/src/wallet/wallet.cpp +++ b/src/wallet/wallet.cpp @@ -389,6 +389,19 @@ void CWallet::UpgradeKeyMetadata() SetWalletFlag(WALLET_FLAG_KEY_ORIGIN_METADATA); } +void CWallet::UpgradeDescriptorCache() +{ + if (!IsWalletFlagSet(WALLET_FLAG_DESCRIPTORS) || IsLocked() || IsWalletFlagSet(WALLET_FLAG_LAST_HARDENED_XPUB_CACHED)) { + return; + } + + for (ScriptPubKeyMan* spkm : GetAllScriptPubKeyMans()) { + DescriptorScriptPubKeyMan* desc_spkm = dynamic_cast(spkm); + desc_spkm->UpgradeDescriptorCache(); + } + SetWalletFlag(WALLET_FLAG_LAST_HARDENED_XPUB_CACHED); +} + bool CWallet::ChangeWalletPassphrase(const SecureString& strOldWalletPassphrase, const SecureString& strNewWalletPassphrase) { bool fWasLocked = IsLocked(true); @@ -442,21 +455,13 @@ void CWallet::chainStateFlushed(const CBlockLocator& loc) batch.WriteBestBlock(loc); } -void CWallet::SetMinVersion(enum WalletFeature nVersion, WalletBatch* batch_in, bool fExplicit) +void CWallet::SetMinVersion(enum WalletFeature nVersion, WalletBatch* batch_in) { LOCK(cs_wallet); if (nWalletVersion >= nVersion) return; - - // when doing an explicit upgrade, if we pass the max version permitted, upgrade all the way - if (fExplicit && nVersion > nWalletMaxVersion) - nVersion = FEATURE_LATEST; - nWalletVersion = nVersion; - if (nVersion > nWalletMaxVersion) - nWalletMaxVersion = nVersion; - { WalletBatch* batch = batch_in ? batch_in : new WalletBatch(GetDatabase()); if (nWalletVersion > 40000) @@ -466,18 +471,6 @@ void CWallet::SetMinVersion(enum WalletFeature nVersion, WalletBatch* batch_in, } } -bool CWallet::SetMaxVersion(int nVersion) -{ - LOCK(cs_wallet); - // cannot downgrade below current version - if (nWalletVersion > nVersion) - return false; - - nWalletMaxVersion = nVersion; - - return true; -} - std::set CWallet::GetConflicts(const uint256& txid) const { std::set result; @@ -665,7 +658,7 @@ bool CWallet::EncryptWallet(const SecureString& strWalletPassphrase) // Encryption was introduced in version 0.4.0 - SetMinVersion(FEATURE_WALLETCRYPT, encrypted_batch, true); + SetMinVersion(FEATURE_WALLETCRYPT, encrypted_batch); if (!encrypted_batch->TxnCommit()) { delete encrypted_batch; @@ -1684,19 +1677,28 @@ bool CWallet::IsWalletFlagSet(uint64_t flag) const return (m_wallet_flags & flag); } -bool CWallet::SetWalletFlags(uint64_t overwriteFlags, bool memonly) +bool CWallet::LoadWalletFlags(uint64_t flags) { LOCK(cs_wallet); - m_wallet_flags = overwriteFlags; - if (((overwriteFlags & KNOWN_WALLET_FLAGS) >> 32) ^ (overwriteFlags >> 32)) { + if (((flags & KNOWN_WALLET_FLAGS) >> 32) ^ (flags >> 32)) { // contains unknown non-tolerable wallet flags return false; } - if (!memonly && !WalletBatch(GetDatabase()).WriteWalletFlags(m_wallet_flags)) { + m_wallet_flags = flags; + + return true; +} + +bool CWallet::AddWalletFlags(uint64_t flags) +{ + LOCK(cs_wallet); + // We should never be writing unknown non-tolerable wallet flags + assert(((flags & KNOWN_WALLET_FLAGS) >> 32) == (flags >> 32)); + if (!WalletBatch(GetDatabase()).WriteWalletFlags(flags)) { throw std::runtime_error(std::string(__func__) + ": writing wallet flags failed"); } - return true; + return LoadWalletFlags(flags); } int64_t CWalletTx::GetTxTime() const @@ -4574,8 +4576,9 @@ std::shared_ptr CWallet::Create(interfaces::Chain& chain, interfaces::C if (fFirstRun) { - walletInstance->SetMaxVersion(FEATURE_LATEST); - walletInstance->SetWalletFlags(wallet_creation_flags, false); + walletInstance->SetMinVersion(FEATURE_LATEST); + + walletInstance->AddWalletFlags(wallet_creation_flags); // Only create LegacyScriptPubKeyMan when not descriptor wallet if (!walletInstance->IsWalletFlagSet(WALLET_FLAG_DESCRIPTORS)) { @@ -4600,8 +4603,8 @@ std::shared_ptr CWallet::Create(interfaces::Chain& chain, interfaces::C } LOCK(walletInstance->cs_wallet); if (auto spk_man = walletInstance->GetLegacyScriptPubKeyMan()) { - if (!spk_man->SetHDChainSingle(newHdChain, false)) { - error = strprintf(_("%s failed"), "SetHDChainSingle"); + if (!spk_man->AddHDChainSingle(newHdChain)) { + error = strprintf(_("%s failed"), "AddHDChainSingle"); return nullptr; } } @@ -4887,6 +4890,7 @@ std::shared_ptr CWallet::Create(interfaces::Chain& chain, interfaces::C bool CWallet::UpgradeWallet(int version, bilingual_str& error) { + int prev_version = GetVersion(); int nMaxVersion = version; auto nMinVersion = DEFAULT_USE_HD_WALLET ? FEATURE_LATEST : FEATURE_COMPRPUBKEY; if (nMaxVersion == 0) { @@ -4898,17 +4902,18 @@ bool CWallet::UpgradeWallet(int version, bilingual_str& error) } if (nMaxVersion < GetVersion()) { - error = Untranslated("Cannot downgrade wallet"); + error = strprintf(_("Cannot downgrade wallet from version %i to version %i. Wallet version unchanged."), prev_version, version); return false; } // TODO: consider discourage users to skip passphrase for HD wallets for v21 if (false && nMaxVersion >= FEATURE_HD && !IsHDEnabled()) { error = Untranslated("You should use upgradetohd RPC to upgrade non-HD wallet to HD"); + error = strprintf(_("Cannot upgrade a non HD wallet from version %i to version %i which is non-HD wallet. Use upgradetohd RPC"), prev_version, version); return false; } - SetMaxVersion(nMaxVersion); + SetMinVersion(GetClosestWalletFeature(version)); return true; } @@ -5649,12 +5654,21 @@ void CWallet::SetupDescriptorScriptPubKeyMans() spk_manager->SetupDescriptorGeneration(master_key); uint256 id = spk_manager->GetID(); m_spk_managers[id] = std::move(spk_manager); - SetActiveScriptPubKeyMan(id, internal); + AddActiveScriptPubKeyMan(id, internal); } } } -void CWallet::SetActiveScriptPubKeyMan(uint256 id, bool internal, bool memonly) +void CWallet::AddActiveScriptPubKeyMan(uint256 id, bool internal) +{ + WalletBatch batch(GetDatabase()); + if (!batch.WriteActiveScriptPubKeyMan(id, internal)) { + throw std::runtime_error(std::string(__func__) + ": writing active ScriptPubKeyMan id failed"); + } + LoadActiveScriptPubKeyMan(id, internal); +} + +void CWallet::LoadActiveScriptPubKeyMan(uint256 id, bool internal) { WalletLogPrintf("Setting spkMan to active: id = %s, type = %d, internal = %d\n", id.ToString(), static_cast(OutputType::LEGACY), static_cast(internal)); auto& spk_mans = internal ? m_internal_spk_managers : m_external_spk_managers; @@ -5662,12 +5676,6 @@ void CWallet::SetActiveScriptPubKeyMan(uint256 id, bool internal, bool memonly) spk_man->SetInternal(internal); spk_mans = spk_man; - if (!memonly) { - WalletBatch batch(GetDatabase()); - if (!batch.WriteActiveScriptPubKeyMan(id, internal)) { - throw std::runtime_error(std::string(__func__) + ": writing active ScriptPubKeyMan id failed"); - } - } NotifyCanGetAddressesChanged(); } diff --git a/src/wallet/wallet.h b/src/wallet/wallet.h index a0c5618ac5..39dc9b96d9 100644 --- a/src/wallet/wallet.h +++ b/src/wallet/wallet.h @@ -128,6 +128,7 @@ static constexpr uint64_t KNOWN_WALLET_FLAGS = WALLET_FLAG_AVOID_REUSE | WALLET_FLAG_BLANK_WALLET | WALLET_FLAG_KEY_ORIGIN_METADATA + | WALLET_FLAG_LAST_HARDENED_XPUB_CACHED | WALLET_FLAG_DISABLE_PRIVATE_KEYS | WALLET_FLAG_DESCRIPTORS; @@ -138,6 +139,7 @@ static const std::map WALLET_FLAG_MAP{ {"avoid_reuse", WALLET_FLAG_AVOID_REUSE}, {"blank", WALLET_FLAG_BLANK_WALLET}, {"key_origin_metadata", WALLET_FLAG_KEY_ORIGIN_METADATA}, + {"last_hardened_xpub_cached", WALLET_FLAG_LAST_HARDENED_XPUB_CACHED}, {"disable_private_keys", WALLET_FLAG_DISABLE_PRIVATE_KEYS}, {"descriptor_wallet", WALLET_FLAG_DESCRIPTORS}, }; @@ -706,9 +708,6 @@ class CWallet final : public WalletStorage, public interfaces::Chain::Notificati //! the current wallet version: clients below this version are not able to load the wallet int nWalletVersion GUARDED_BY(cs_wallet){FEATURE_BASE}; - //! the maximum wallet format version: memory-only variable that specifies to what version this wallet may be upgraded - int nWalletMaxVersion GUARDED_BY(cs_wallet) = FEATURE_BASE; - int64_t nNextResend = 0; bool fBroadcastTransactions = false; // Local time that the tip block was received. Used to schedule wallet rebroadcasts. @@ -911,8 +910,8 @@ class CWallet final : public WalletStorage, public interfaces::Chain::Notificati const CWalletTx* GetWalletTx(const uint256& hash) const EXCLUSIVE_LOCKS_REQUIRED(cs_wallet); bool IsTrusted(const CWalletTx& wtx, std::set& trusted_parents) const EXCLUSIVE_LOCKS_REQUIRED(cs_wallet); - //! check whether we are allowed to upgrade (or already support) to the named feature - bool CanSupportFeature(enum WalletFeature wf) const override EXCLUSIVE_LOCKS_REQUIRED(cs_wallet) { AssertLockHeld(cs_wallet); return nWalletMaxVersion >= wf; } + //! check whether we support the named feature + bool CanSupportFeature(enum WalletFeature wf) const override EXCLUSIVE_LOCKS_REQUIRED(cs_wallet) { AssertLockHeld(cs_wallet); return IsFeatureSupported(nWalletVersion, wf); } /** * populate vCoins with vector of available COutputs. @@ -981,7 +980,10 @@ class CWallet final : public WalletStorage, public interfaces::Chain::Notificati //! Upgrade stored CKeyMetadata objects to store key origin info as KeyOriginInfo void UpgradeKeyMetadata() EXCLUSIVE_LOCKS_REQUIRED(cs_wallet); - bool LoadMinVersion(int nVersion) EXCLUSIVE_LOCKS_REQUIRED(cs_wallet) { AssertLockHeld(cs_wallet); nWalletVersion = nVersion; nWalletMaxVersion = std::max(nWalletMaxVersion, nVersion); return true; } + //! Upgrade DescriptorCaches + void UpgradeDescriptorCache() EXCLUSIVE_LOCKS_REQUIRED(cs_wallet); + + bool LoadMinVersion(int nVersion) EXCLUSIVE_LOCKS_REQUIRED(cs_wallet) { AssertLockHeld(cs_wallet); nWalletVersion = nVersion; return true; } //! Adds a destination data tuple to the store, without saving it to disk void LoadDestData(const CTxDestination& dest, const std::string& key, const std::string& value) EXCLUSIVE_LOCKS_REQUIRED(cs_wallet); @@ -1198,11 +1200,8 @@ class CWallet final : public WalletStorage, public interfaces::Chain::Notificati unsigned int GetKeyPoolSize() const EXCLUSIVE_LOCKS_REQUIRED(cs_wallet); - //! signify that a particular wallet feature is now used. this may change nWalletVersion and nWalletMaxVersion if those are lower - void SetMinVersion(enum WalletFeature, WalletBatch* batch_in = nullptr, bool fExplicit = false) override; - - //! change which version we're allowed to upgrade to (note that this does not immediately imply upgrading to that format) - bool SetMaxVersion(int nVersion); + //! signify that a particular wallet feature is now used. + void SetMinVersion(enum WalletFeature, WalletBatch* batch_in = nullptr) override; //! get the current wallet format (the oldest client version guaranteed to understand this wallet) int GetVersion() const { LOCK(cs_wallet); return nWalletVersion; } @@ -1334,7 +1333,9 @@ class CWallet final : public WalletStorage, public interfaces::Chain::Notificati /** overwrite all flags by the given uint64_t returns false if unknown, non-tolerable flags are present */ - bool SetWalletFlags(uint64_t overwriteFlags, bool memOnly); + bool AddWalletFlags(uint64_t flags); + /** Loads the flags into the wallet. (used by LoadWallet) */ + bool LoadWalletFlags(uint64_t flags); /** Determine if we are a legacy wallet */ bool IsLegacy() const; @@ -1415,12 +1416,15 @@ class CWallet final : public WalletStorage, public interfaces::Chain::Notificati //! Instantiate a descriptor ScriptPubKeyMan from the WalletDescriptor and load it void LoadDescriptorScriptPubKeyMan(uint256 id, WalletDescriptor& desc); - //! Sets the active ScriptPubKeyMan for the specified type and internal + //! Adds the active ScriptPubKeyMan for the specified type and internal. Writes it to the wallet file + //! @param[in] id The unique id for the ScriptPubKeyMan + //! @param[in] internal Whether this ScriptPubKeyMan provides change addresses + void AddActiveScriptPubKeyMan(uint256 id, bool internal); + + //! Loads an active ScriptPubKeyMan for the specified type and internal. (used by LoadWallet) //! @param[in] id The unique id for the ScriptPubKeyMan - //! @param[in] type The OutputType this ScriptPubKeyMan provides addresses for //! @param[in] internal Whether this ScriptPubKeyMan provides change addresses - //! @param[in] memonly Whether to record this update to the database. Set to true for wallet loading, normally false when actually updating the wallet. - void SetActiveScriptPubKeyMan(uint256 id, bool internal, bool memonly = false); + void LoadActiveScriptPubKeyMan(uint256 id, bool internal); //! Create new DescriptorScriptPubKeyMans and add them to the wallet void SetupDescriptorScriptPubKeyMans(); diff --git a/src/wallet/walletdb.cpp b/src/wallet/walletdb.cpp index 3c9d2394bd..a7b02934f3 100644 --- a/src/wallet/walletdb.cpp +++ b/src/wallet/walletdb.cpp @@ -60,6 +60,7 @@ const std::string TX{"tx"}; const std::string VERSION{"version"}; const std::string WALLETDESCRIPTOR{"walletdescriptor"}; const std::string WALLETDESCRIPTORCACHE{"walletdescriptorcache"}; +const std::string WALLETDESCRIPTORLHCACHE{"walletdescriptorlhcache"}; const std::string WALLETDESCRIPTORCKEY{"walletdescriptorckey"}; const std::string WALLETDESCRIPTORKEY{"walletdescriptorkey"}; const std::string WATCHMETA{"watchmeta"}; @@ -272,6 +273,35 @@ bool WalletBatch::WriteDescriptorParentCache(const CExtPubKey& xpub, const uint2 return WriteIC(std::make_pair(std::make_pair(DBKeys::WALLETDESCRIPTORCACHE, desc_id), key_exp_index), ser_xpub); } +bool WalletBatch::WriteDescriptorLastHardenedCache(const CExtPubKey& xpub, const uint256& desc_id, uint32_t key_exp_index) +{ + std::vector ser_xpub(BIP32_EXTKEY_SIZE); + xpub.Encode(ser_xpub.data()); + return WriteIC(std::make_pair(std::make_pair(DBKeys::WALLETDESCRIPTORLHCACHE, desc_id), key_exp_index), ser_xpub); +} + +bool WalletBatch::WriteDescriptorCacheItems(const uint256& desc_id, const DescriptorCache& cache) +{ + for (const auto& parent_xpub_pair : cache.GetCachedParentExtPubKeys()) { + if (!WriteDescriptorParentCache(parent_xpub_pair.second, desc_id, parent_xpub_pair.first)) { + return false; + } + } + for (const auto& derived_xpub_map_pair : cache.GetCachedDerivedExtPubKeys()) { + for (const auto& derived_xpub_pair : derived_xpub_map_pair.second) { + if (!WriteDescriptorDerivedCache(derived_xpub_pair.second, desc_id, derived_xpub_map_pair.first, derived_xpub_pair.first)) { + return false; + } + } + } + for (const auto& lh_xpub_pair : cache.GetCachedLastHardenedExtPubKeys()) { + if (!WriteDescriptorLastHardenedCache(lh_xpub_pair.second, desc_id, lh_xpub_pair.first)) { + return false; + } + } + return true; +} + class CWalletScanState { public: unsigned int nKeys{0}; @@ -516,7 +546,7 @@ ReadKeyValue(CWallet* pwallet, CDataStream& ssKey, CDataStream& ssValue, CHDChain chain; ssValue >> chain; assert ((strType == DBKeys::CRYPTED_HDCHAIN) == chain.IsCrypted()); - if (!pwallet->GetOrCreateLegacyScriptPubKeyMan()->SetHDChainSingle(chain, true)) + if (!pwallet->GetOrCreateLegacyScriptPubKeyMan()->LoadHDChain(chain)) { strErr = "Error reading wallet database: SetHDChain failed"; return false; @@ -554,13 +584,6 @@ ReadKeyValue(CWallet* pwallet, CDataStream& ssKey, CDataStream& ssValue, strErr = "Invalid governance object: LoadGovernanceObject"; return false; } - } else if (strType == DBKeys::FLAGS) { - uint64_t flags; - ssValue >> flags; - if (!pwallet->SetWalletFlags(flags, true)) { - strErr = "Error reading wallet database: Unknown non-tolerable wallet flags found"; - return false; - } } else if (strType == DBKeys::OLD_KEY) { strErr = "Found unsupported 'wkey' record, try loading with version 0.17"; return false; @@ -610,6 +633,17 @@ ReadKeyValue(CWallet* pwallet, CDataStream& ssKey, CDataStream& ssValue, } else { wss.m_descriptor_caches[desc_id].CacheDerivedExtPubKey(key_exp_index, der_index, xpub); } + } else if (strType == DBKeys::WALLETDESCRIPTORLHCACHE) { + uint256 desc_id; + uint32_t key_exp_index; + ssKey >> desc_id; + ssKey >> key_exp_index; + + std::vector ser_xpub(BIP32_EXTKEY_SIZE); + ssValue >> ser_xpub; + CExtPubKey xpub; + xpub.Decode(ser_xpub.data()); + wss.m_descriptor_caches[desc_id].CacheLastHardenedExtPubKey(key_exp_index, xpub); } else if (strType == DBKeys::WALLETDESCRIPTORKEY) { uint256 desc_id; CPubKey pubkey; @@ -665,7 +699,8 @@ ReadKeyValue(CWallet* pwallet, CDataStream& ssKey, CDataStream& ssValue, } else if (strType != DBKeys::BESTBLOCK && strType != DBKeys::BESTBLOCK_NOMERKLE && strType != DBKeys::MINVERSION && strType != DBKeys::ACENTRY && strType != DBKeys::VERSION && strType != DBKeys::SETTINGS && - strType != DBKeys::PRIVATESEND_SALT && strType != DBKeys::COINJOIN_SALT) { + strType != DBKeys::PRIVATESEND_SALT && strType != DBKeys::COINJOIN_SALT && + strType != DBKeys::FLAGS) { wss.m_unknown_records++; } } catch (const std::exception& e) { @@ -711,6 +746,16 @@ DBErrors WalletBatch::LoadWallet(CWallet* pwallet) pwallet->LoadMinVersion(nMinVersion); } + // Load wallet flags, so they are known when processing other records. + // The FLAGS key is absent during wallet creation. + uint64_t flags; + if (m_batch->Read(DBKeys::FLAGS, flags)) { + if (!pwallet->LoadWalletFlags(flags)) { + pwallet->WalletLogPrintf("Error reading wallet database: Unknown non-tolerable wallet flags found\n"); + return DBErrors::CORRUPT; + } + } + // Get cursor if (!m_batch->StartCursor()) { @@ -768,10 +813,10 @@ DBErrors WalletBatch::LoadWallet(CWallet* pwallet) // Set the active ScriptPubKeyMans for (auto spk_man : wss.m_active_external_spks) { - pwallet->SetActiveScriptPubKeyMan(spk_man.second, /* internal */ false, /* memonly */ true); + pwallet->LoadActiveScriptPubKeyMan(spk_man.second, /* internal */ false); } for (auto spk_man : wss.m_active_internal_spks) { - pwallet->SetActiveScriptPubKeyMan(spk_man.second, /* internal */ true, /* memonly */ true); + pwallet->LoadActiveScriptPubKeyMan(spk_man.second, /* internal */ true); } // Set the descriptor caches @@ -840,6 +885,14 @@ DBErrors WalletBatch::LoadWallet(CWallet* pwallet) result = DBErrors::CORRUPT; } + // Upgrade all of the descriptor caches to cache the last hardened xpub + // This operation is not atomic, but if it fails, only new entries are added so it is backwards compatible + try { + pwallet->UpgradeDescriptorCache(); + } catch (...) { + result = DBErrors::CORRUPT; + } + return result; } diff --git a/src/wallet/walletdb.h b/src/wallet/walletdb.h index 547fb1e351..916155d727 100644 --- a/src/wallet/walletdb.h +++ b/src/wallet/walletdb.h @@ -217,6 +217,8 @@ class WalletBatch bool WriteDescriptor(const uint256& desc_id, const WalletDescriptor& descriptor); bool WriteDescriptorDerivedCache(const CExtPubKey& xpub, const uint256& desc_id, uint32_t key_exp_index, uint32_t der_index); bool WriteDescriptorParentCache(const CExtPubKey& xpub, const uint256& desc_id, uint32_t key_exp_index); + bool WriteDescriptorLastHardenedCache(const CExtPubKey& xpub, const uint256& desc_id, uint32_t key_exp_index); + bool WriteDescriptorCacheItems(const uint256& desc_id, const DescriptorCache& cache); /// Write destination data key,value tuple to database bool WriteDestData(const std::string &address, const std::string &key, const std::string &value); diff --git a/src/wallet/walletutil.cpp b/src/wallet/walletutil.cpp index 2e98940e3b..68bae737a8 100644 --- a/src/wallet/walletutil.cpp +++ b/src/wallet/walletutil.cpp @@ -28,3 +28,17 @@ fs::path GetWalletDir() return path; } + +bool IsFeatureSupported(int wallet_version, int feature_version) +{ + return wallet_version >= feature_version; +} + +WalletFeature GetClosestWalletFeature(int version) +{ + const std::array wallet_features{{FEATURE_LATEST, FEATURE_HD, FEATURE_COMPRPUBKEY, FEATURE_WALLETCRYPT, FEATURE_BASE}}; + for (const WalletFeature& wf : wallet_features) { + if (version >= wf) return wf; + } + return static_cast(0); +} diff --git a/src/wallet/walletutil.h b/src/wallet/walletutil.h index 93547ce0d9..2e67db11d4 100644 --- a/src/wallet/walletutil.h +++ b/src/wallet/walletutil.h @@ -23,6 +23,8 @@ enum WalletFeature FEATURE_LATEST = FEATURE_HD }; +bool IsFeatureSupported(int wallet_version, int feature_version); +WalletFeature GetClosestWalletFeature(int version); enum WalletFlags : uint64_t { // wallet flags in the upper section (> 1 << 31) will lead to not opening the wallet if flag is unknown @@ -35,6 +37,9 @@ enum WalletFlags : uint64_t { // Indicates that the metadata has already been upgraded to contain key origins WALLET_FLAG_KEY_ORIGIN_METADATA = (1ULL << 1), + // Indicates that the descriptor cache has been upgraded to cache last hardened xpubs + WALLET_FLAG_LAST_HARDENED_XPUB_CACHED = (1ULL << 2), + // will enforce the rule that the wallet can't contain any private keys (only watch-only/pubkeys) WALLET_FLAG_DISABLE_PRIVATE_KEYS = (1ULL << 32), diff --git a/test/functional/data/wallets/high_minversion/.walletlock b/test/functional/data/wallets/high_minversion/.walletlock deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/test/functional/data/wallets/high_minversion/GENERATE.md b/test/functional/data/wallets/high_minversion/GENERATE.md deleted file mode 100644 index e55c4557ca..0000000000 --- a/test/functional/data/wallets/high_minversion/GENERATE.md +++ /dev/null @@ -1,8 +0,0 @@ -The wallet has been created by starting Bitcoin Core with the options -`-regtest -datadir=/tmp -nowallet -walletdir=$(pwd)/test/functional/data/wallets/`. - -In the source code, `WalletFeature::FEATURE_LATEST` has been modified to be large, so that the minversion is too high -for a current build of the wallet. - -The wallet has then been created with the RPC `createwallet high_minversion true true`, so that a blank wallet with -private keys disabled is created. diff --git a/test/functional/data/wallets/high_minversion/db.log b/test/functional/data/wallets/high_minversion/db.log deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/test/functional/data/wallets/high_minversion/wallet.dat b/test/functional/data/wallets/high_minversion/wallet.dat deleted file mode 100644 index 99ab809263..0000000000 Binary files a/test/functional/data/wallets/high_minversion/wallet.dat and /dev/null differ diff --git a/test/functional/test_framework/bdb.py b/test/functional/test_framework/bdb.py new file mode 100644 index 0000000000..9de358aa0a --- /dev/null +++ b/test/functional/test_framework/bdb.py @@ -0,0 +1,152 @@ +#!/usr/bin/env python3 +# Copyright (c) 2020 The Bitcoin Core developers +# Distributed under the MIT software license, see the accompanying +# file COPYING or http://www.opensource.org/licenses/mit-license.php. +""" +Utilities for working directly with the wallet's BDB database file + +This is specific to the configuration of BDB used in this project: + - pagesize: 4096 bytes + - Outer database contains single subdatabase named 'main' + - btree + - btree leaf pages + +Each key-value pair is two entries in a btree leaf. The first is the key, the one that follows +is the value. And so on. Note that the entry data is itself not in the correct order. Instead +entry offsets are stored in the correct order and those offsets are needed to then retrieve +the data itself. + +Page format can be found in BDB source code dbinc/db_page.h +This only implements the deserialization of btree metadata pages and normal btree pages. Overflow +pages are not implemented but may be needed in the future if dealing with wallets with large +transactions. + +`db_dump -da wallet.dat` is useful to see the data in a wallet.dat BDB file +""" + +import binascii +import struct + +# Important constants +PAGESIZE = 4096 +OUTER_META_PAGE = 0 +INNER_META_PAGE = 2 + +# Page type values +BTREE_INTERNAL = 3 +BTREE_LEAF = 5 +BTREE_META = 9 + +# Some magic numbers for sanity checking +BTREE_MAGIC = 0x053162 +DB_VERSION = 9 + +# Deserializes a leaf page into a dict. +# Btree internal pages have the same header, for those, return None. +# For the btree leaf pages, deserialize them and put all the data into a dict +def dump_leaf_page(data): + page_info = {} + page_header = data[0:26] + _, pgno, prev_pgno, next_pgno, entries, hf_offset, level, pg_type = struct.unpack('QIIIHHBB', page_header) + page_info['pgno'] = pgno + page_info['prev_pgno'] = prev_pgno + page_info['next_pgno'] = next_pgno + page_info['entries'] = entries + page_info['hf_offset'] = hf_offset + page_info['level'] = level + page_info['pg_type'] = pg_type + page_info['entry_offsets'] = struct.unpack('{}H'.format(entries), data[26:26 + entries * 2]) + page_info['entries'] = [] + + if pg_type == BTREE_INTERNAL: + # Skip internal pages. These are the internal nodes of the btree and don't contain anything relevant to us + return None + + assert pg_type == BTREE_LEAF, 'A non-btree leaf page has been encountered while dumping leaves' + + for i in range(0, entries): + offset = page_info['entry_offsets'][i] + entry = {'offset': offset} + page_data_header = data[offset:offset + 3] + e_len, pg_type = struct.unpack('HB', page_data_header) + entry['len'] = e_len + entry['pg_type'] = pg_type + entry['data'] = data[offset + 3:offset + 3 + e_len] + page_info['entries'].append(entry) + + return page_info + +# Deserializes a btree metadata page into a dict. +# Does a simple sanity check on the magic value, type, and version +def dump_meta_page(page): + # metadata page + # general metadata + metadata = {} + meta_page = page[0:72] + _, pgno, magic, version, pagesize, encrypt_alg, pg_type, metaflags, _, free, last_pgno, nparts, key_count, record_count, flags, uid = struct.unpack('QIIIIBBBBIIIIII20s', meta_page) + metadata['pgno'] = pgno + metadata['magic'] = magic + metadata['version'] = version + metadata['pagesize'] = pagesize + metadata['encrypt_alg'] = encrypt_alg + metadata['pg_type'] = pg_type + metadata['metaflags'] = metaflags + metadata['free'] = free + metadata['last_pgno'] = last_pgno + metadata['nparts'] = nparts + metadata['key_count'] = key_count + metadata['record_count'] = record_count + metadata['flags'] = flags + metadata['uid'] = binascii.hexlify(uid) + + assert magic == BTREE_MAGIC, 'bdb magic does not match bdb btree magic' + assert pg_type == BTREE_META, 'Metadata page is not a btree metadata page' + assert version == DB_VERSION, 'Database too new' + + # btree metadata + btree_meta_page = page[72:512] + _, minkey, re_len, re_pad, root, _, crypto_magic, _, iv, chksum = struct.unpack('IIIII368sI12s16s20s', btree_meta_page) + metadata['minkey'] = minkey + metadata['re_len'] = re_len + metadata['re_pad'] = re_pad + metadata['root'] = root + metadata['crypto_magic'] = crypto_magic + metadata['iv'] = binascii.hexlify(iv) + metadata['chksum'] = binascii.hexlify(chksum) + return metadata + +# Given the dict from dump_leaf_page, get the key-value pairs and put them into a dict +def extract_kv_pairs(page_data): + out = {} + last_key = None + for i, entry in enumerate(page_data['entries']): + # By virtue of these all being pairs, even number entries are keys, and odd are values + if i % 2 == 0: + out[entry['data']] = b'' + last_key = entry['data'] + else: + out[last_key] = entry['data'] + return out + +# Extract the key-value pairs of the BDB file given in filename +def dump_bdb_kv(filename): + # Read in the BDB file and start deserializing it + pages = [] + with open(filename, 'rb') as f: + data = f.read(PAGESIZE) + while len(data) > 0: + pages.append(data) + data = f.read(PAGESIZE) + + # Sanity check the meta pages + dump_meta_page(pages[OUTER_META_PAGE]) + dump_meta_page(pages[INNER_META_PAGE]) + + # Fetch the kv pairs from the leaf pages + kv = {} + for i in range(3, len(pages)): + info = dump_leaf_page(pages[i]) + if info is not None: + info_kv = extract_kv_pairs(info) + kv = {**kv, **info_kv} + return kv diff --git a/test/functional/test_framework/util.py b/test/functional/test_framework/util.py index fd636ec136..0736669811 100644 --- a/test/functional/test_framework/util.py +++ b/test/functional/test_framework/util.py @@ -9,6 +9,7 @@ from binascii import unhexlify from decimal import Decimal, ROUND_DOWN from subprocess import CalledProcessError +import hashlib import inspect import json import logging @@ -268,6 +269,14 @@ def wait_until_helper(predicate, *, attempts=float('inf'), timeout=float('inf'), else: return False +def sha256sum_file(filename): + h = hashlib.sha256() + with open(filename, 'rb') as f: + d = f.read(4096) + while len(d) > 0: + h.update(d) + d = f.read(4096) + return h.digest() # RPC/P2P connection constants and functions ############################################ diff --git a/test/functional/wallet_hd.py b/test/functional/wallet_hd.py index f07bf69625..eaa2c0e710 100755 --- a/test/functional/wallet_hd.py +++ b/test/functional/wallet_hd.py @@ -11,6 +11,7 @@ from test_framework.test_framework import BitcoinTestFramework from test_framework.util import ( assert_equal, + assert_raises_rpc_error, ) class WalletHDTest(BitcoinTestFramework): @@ -137,5 +138,148 @@ def run_test(self): assert_equal(keypath[0:13], "m/44'/1'/0'/1") + if not self.options.descriptors: + # NOTE: sethdseed can't replace existing seed in Dash Core + # though bitcoin lets to do it. Therefore this functional test + # are not the same with bitcoin's + # Generate a new HD seed on node 1 and make sure it is set + + self.nodes[1].createwallet(wallet_name='wallet_new_seed', blank=True) + wallet_new_seed = self.nodes[1].get_wallet_rpc('wallet_new_seed') + assert 'hdchainid' not in wallet_new_seed.getwalletinfo() + wallet_new_seed.sethdseed() + new_masterkeyid = wallet_new_seed.getwalletinfo()['hdchainid'] + addr = wallet_new_seed.getnewaddress() + # Make sure the new address is the first from the keypool + assert_equal(wallet_new_seed.getaddressinfo(addr)['hdkeypath'], "m/44'/1'/0'/0/1") + wallet_new_seed.keypoolrefill(1) # Fill keypool with 1 key + + # Set a new HD seed on node 1 without flushing the keypool + new_seed = self.nodes[0].dumpprivkey(self.nodes[0].getnewaddress()) + assert_raises_rpc_error(-4, "Cannot set a HD seed. The wallet already has a seed", wallet_new_seed.sethdseed, False, new_seed) + self.nodes[1].createwallet(wallet_name='wallet_imported_seed', blank=True) + wallet_imported_seed = self.nodes[1].get_wallet_rpc('wallet_imported_seed') + wallet_imported_seed.sethdseed(False, new_seed) + + new_masterkeyid = wallet_imported_seed.getwalletinfo()['hdchainid'] + addr = wallet_imported_seed.getnewaddress() + assert_equal(new_masterkeyid, wallet_imported_seed.getaddressinfo(addr)['hdchainid']) + # Make sure the new address continues previous keypool + assert_equal(wallet_imported_seed.getaddressinfo(addr)['hdkeypath'], "m/44'/1'/0'/0/0") + + # Check that the next address is from the new seed + wallet_imported_seed.keypoolrefill(1) + next_addr = wallet_imported_seed.getnewaddress() + assert_equal(new_masterkeyid, wallet_imported_seed.getaddressinfo(next_addr)['hdchainid']) + # Make sure the new address is not from previous keypool + assert_equal(wallet_imported_seed.getaddressinfo(next_addr)['hdkeypath'], "m/44'/1'/0'/0/1") + assert next_addr != addr + + self.nodes[1].createwallet(wallet_name='wallet_no_seed', blank=True) + wallet_no_seed = self.nodes[1].get_wallet_rpc('wallet_no_seed') + wallet_no_seed.importprivkey(non_hd_key) + # Sethdseed parameter validity + assert_raises_rpc_error(-1, 'sethdseed', self.nodes[0].sethdseed, False, new_seed, 0) + assert_raises_rpc_error(-5, "Invalid private key", wallet_no_seed.sethdseed, False, "not_wif") + assert_raises_rpc_error(-1, "JSON value is not a boolean as expected", wallet_no_seed.sethdseed, "Not_bool") + assert_raises_rpc_error(-1, "JSON value is not a string as expected", wallet_no_seed.sethdseed, False, True) + assert_raises_rpc_error(-5, "Already have this key", wallet_no_seed.sethdseed, False, non_hd_key) + + self.log.info('Test sethdseed restoring with keys outside of the initial keypool') + self.nodes[0].generate(10) + # Restart node 1 with keypool of 3 and a different wallet + self.nodes[1].createwallet(wallet_name='origin', blank=True) + self.restart_node(1, extra_args=['-keypool=3', '-wallet=origin']) + self.connect_nodes(0, 1) + + # sethdseed restoring and seeing txs to addresses out of the keypool + origin_rpc = self.nodes[1].get_wallet_rpc('origin') + seed = self.nodes[0].dumpprivkey(self.nodes[0].getnewaddress()) + origin_rpc.sethdseed(True, seed) + + self.nodes[1].createwallet(wallet_name='restore', blank=True) + restore_rpc = self.nodes[1].get_wallet_rpc('restore') + restore_rpc.sethdseed(True, seed) # Set to be the same seed as origin_rpc + + self.nodes[1].createwallet(wallet_name='restore2', blank=True) + restore2_rpc = self.nodes[1].get_wallet_rpc('restore2') + restore2_rpc.sethdseed(True, seed) # Set to be the same seed as origin_rpc + + # Check persistence of inactive seed by reloading restore. restore2 is still loaded to test the case where the wallet is not reloaded + restore_rpc.unloadwallet() + self.nodes[1].loadwallet('restore') + restore_rpc = self.nodes[1].get_wallet_rpc('restore') + + # Empty origin keypool and get an address that is beyond the initial keypool + origin_rpc.getnewaddress() + origin_rpc.getnewaddress() + last_addr = origin_rpc.getnewaddress() # Last address of initial keypool + addr = origin_rpc.getnewaddress() # First address beyond initial keypool + + # Check that the restored seed has last_addr but does not have addr + info = restore_rpc.getaddressinfo(last_addr) + assert_equal(info['ismine'], True) + info = restore_rpc.getaddressinfo(addr) + assert_equal(info['ismine'], False) + info = restore2_rpc.getaddressinfo(last_addr) + assert_equal(info['ismine'], True) + info = restore2_rpc.getaddressinfo(addr) + assert_equal(info['ismine'], False) + # Check that the origin seed has addr + info = origin_rpc.getaddressinfo(addr) + assert_equal(info['ismine'], True) + + # Send a transaction to addr, which is out of the initial keypool. + # The wallet that has set a new seed (restore_rpc) should not detect this transaction. + txid = self.nodes[0].sendtoaddress(addr, 1) + origin_rpc.sendrawtransaction(self.nodes[0].gettransaction(txid)['hex']) + self.nodes[0].generate(1) + self.sync_blocks() + origin_rpc.gettransaction(txid) + assert_raises_rpc_error(-5, 'Invalid or non-wallet transaction id', restore_rpc.gettransaction, txid) + out_of_kp_txid = txid + + # Send a transaction to last_addr, which is in the initial keypool. + # The wallet that has set a new seed (restore_rpc) should detect this transaction and generate 3 new keys from the initial seed. + # The previous transaction (out_of_kp_txid) should still not be detected as a rescan is required. + txid = self.nodes[0].sendtoaddress(last_addr, 1) + origin_rpc.sendrawtransaction(self.nodes[0].gettransaction(txid)['hex']) + self.nodes[0].generate(1) + self.sync_blocks() + origin_rpc.gettransaction(txid) + restore_rpc.gettransaction(txid) + assert_raises_rpc_error(-5, 'Invalid or non-wallet transaction id', restore_rpc.gettransaction, out_of_kp_txid) + restore2_rpc.gettransaction(txid) + assert_raises_rpc_error(-5, 'Invalid or non-wallet transaction id', restore2_rpc.gettransaction, out_of_kp_txid) + + # After rescanning, restore_rpc should now see out_of_kp_txid and generate an additional key. + # addr should now be part of restore_rpc and be ismine + restore_rpc.rescanblockchain() + restore_rpc.gettransaction(out_of_kp_txid) + info = restore_rpc.getaddressinfo(addr) + assert_equal(info['ismine'], True) + restore2_rpc.rescanblockchain() + restore2_rpc.gettransaction(out_of_kp_txid) + info = restore2_rpc.getaddressinfo(addr) + assert_equal(info['ismine'], True) + + # Check again that 3 keys were derived. + # Empty keypool and get an address that is beyond the initial keypool + origin_rpc.getnewaddress() + origin_rpc.getnewaddress() + last_addr = origin_rpc.getnewaddress() + addr = origin_rpc.getnewaddress() + + # Check that the restored seed has last_addr but does not have addr + info = restore_rpc.getaddressinfo(last_addr) + assert_equal(info['ismine'], True) + info = restore_rpc.getaddressinfo(addr) + assert_equal(info['ismine'], False) + info = restore2_rpc.getaddressinfo(last_addr) + assert_equal(info['ismine'], True) + info = restore2_rpc.getaddressinfo(addr) + assert_equal(info['ismine'], False) + + if __name__ == '__main__': WalletHDTest().main () diff --git a/test/functional/wallet_listdescriptors.py b/test/functional/wallet_listdescriptors.py index e1ffd34ecd..8896a765f0 100755 --- a/test/functional/wallet_listdescriptors.py +++ b/test/functional/wallet_listdescriptors.py @@ -73,6 +73,10 @@ def run_test(self): } assert_equal(expected, wallet.listdescriptors()) + self.log.info("Test listdescriptors with encrypted wallet") + wallet.encryptwallet("pass") + assert_equal(expected, wallet.listdescriptors()) + self.log.info('Test non-active non-range combo descriptor') node.createwallet(wallet_name='w4', blank=True, descriptors=True) wallet = node.get_wallet_rpc('w4') diff --git a/test/functional/wallet_upgradewallet.py b/test/functional/wallet_upgradewallet.py index 4d81201c01..f7cf9ee8da 100755 --- a/test/functional/wallet_upgradewallet.py +++ b/test/functional/wallet_upgradewallet.py @@ -10,25 +10,29 @@ import os import shutil +import struct +from test_framework.bdb import dump_bdb_kv from test_framework.blocktools import COINBASE_MATURITY from test_framework.test_framework import BitcoinTestFramework from test_framework.util import ( assert_equal, assert_greater_than, - assert_greater_than_or_equal, assert_is_hex_string, + sha256sum_file, ) +UPGRADED_KEYMETA_VERSION = 12 + class UpgradeWalletTest(BitcoinTestFramework): def set_test_params(self): self.setup_clean_chain = True self.num_nodes = 3 self.extra_args = [ - [], # current wallet version - ["-usehd=1"], # v18.2.2 wallet - ["-usehd=0"] # v0.16.1.1 wallet + ["-keypool=2"], # current wallet version + ["-usehd=1", "-keypool=2"], # v18.2.2 wallet + ["-usehd=0", "-keypool=2"], # v0.16.1.1 wallet ] self.wallet_names = [self.default_wallet_name, None, None] @@ -65,6 +69,32 @@ def dumb_sync_blocks(self): v18_2_node.submitblock(b) assert_equal(v18_2_node.getblockcount(), to_height) + def test_upgradewallet(self, wallet, previous_version, requested_version=None, expected_version=None): + unchanged = expected_version == previous_version + new_version = previous_version if unchanged else expected_version if expected_version else requested_version + assert_equal(wallet.getwalletinfo()["walletversion"], previous_version) + assert_equal(wallet.upgradewallet(requested_version), + { + "wallet_name": "", + "previous_version": previous_version, + "current_version": new_version, + "result": "Already at latest version. Wallet version unchanged." if unchanged else "Wallet upgraded successfully from version {} to version {}.".format(previous_version, new_version), + } + ) + assert_equal(wallet.getwalletinfo()["walletversion"], new_version) + + def test_upgradewallet_error(self, wallet, previous_version, requested_version, msg): + assert_equal(wallet.getwalletinfo()["walletversion"], previous_version) + assert_equal(wallet.upgradewallet(requested_version), + { + "wallet_name": "", + "previous_version": previous_version, + "current_version": previous_version, + "error": msg, + } + ) + assert_equal(wallet.getwalletinfo()["walletversion"], previous_version) + def run_test(self): self.nodes[0].generatetoaddress(COINBASE_MATURITY + 1, self.nodes[0].getnewaddress()) self.dumb_sync_blocks() @@ -84,12 +114,12 @@ def run_test(self): self.log.info("Test upgradewallet RPC...") # Prepare for copying of the older wallet - node_master_wallet_dir = os.path.join(node_master.datadir, "regtest/wallets") + node_master_wallet_dir = os.path.join(node_master.datadir, "regtest/wallets", self.default_wallet_name) + node_master_wallet = os.path.join(node_master_wallet_dir, self.default_wallet_name, self.wallet_data_filename) v18_2_wallet = os.path.join(v18_2_node.datadir, "regtest/wallets/wallet.dat") v16_1_wallet = os.path.join(v16_1_node.datadir, "regtest/wallets/wallet.dat") self.stop_nodes() - # Copy the 0.16.3 wallet to the last Dash Core version and open it: shutil.rmtree(node_master_wallet_dir) os.mkdir(node_master_wallet_dir) shutil.copy( @@ -99,39 +129,45 @@ def run_test(self): self.restart_node(0, ['-nowallet']) node_master.loadwallet('') - wallet = node_master.get_wallet_rpc('') - old_version = wallet.getwalletinfo()["walletversion"] - - # calling upgradewallet without version arguments - # should return nothing if successful - assert_equal(wallet.upgradewallet(), {}) - new_version = wallet.getwalletinfo()["walletversion"] - # upgraded wallet version should be greater than older one - assert_greater_than_or_equal(new_version, old_version) + def copy_v16(): + node_master.get_wallet_rpc(self.default_wallet_name).unloadwallet() + # Copy the 0.16.3 wallet to the last Dash Core version and open it: + shutil.rmtree(node_master_wallet_dir) + os.mkdir(node_master_wallet_dir) + shutil.copy( + v18_2_wallet, + node_master_wallet_dir + ) + node_master.loadwallet(self.default_wallet_name) + + def copy_non_hd(): + node_master.get_wallet_rpc(self.default_wallet_name).unloadwallet() + # Copy the 19.3.0 wallet to the last Dash Core version and open it: + shutil.rmtree(node_master_wallet_dir) + os.mkdir(node_master_wallet_dir) + shutil.copy( + v16_1_wallet, + node_master_wallet_dir + ) + node_master.loadwallet(self.default_wallet_name) + + + self.restart_node(0) + copy_v16() + wallet = node_master.get_wallet_rpc(self.default_wallet_name) + self.test_upgradewallet(wallet, previous_version=120200, expected_version=120200) # wallet should still contain the same balance assert_equal(wallet.getbalance(), v18_2_balance) - self.stop_node(0) - # Copy the 19.3.0 wallet to the last Dash Core version and open it: - shutil.rmtree(node_master_wallet_dir) - os.mkdir(node_master_wallet_dir) - shutil.copy( - v16_1_wallet, - node_master_wallet_dir - ) - self.restart_node(0, ['-nowallet']) - node_master.loadwallet('') - - wallet = node_master.get_wallet_rpc('') + copy_non_hd() + wallet = node_master.get_wallet_rpc(self.default_wallet_name) # should have no master key hash before conversion - assert_equal('hdseedid' in wallet.getwalletinfo(), False) + assert_equal('hdchainid' in wallet.getwalletinfo(), False) # calling upgradewallet with explicit version number # should return nothing if successful - assert_equal(wallet.upgradewallet(169900), {}) - new_version = wallet.getwalletinfo()["walletversion"] - # upgraded wallet would not have 120200 version until HD seed actually appeared - assert_greater_than(120200, new_version) + self.log.info("Test upgradewallet to HD will have version 120200 but no HD seed actually appeared") + self.test_upgradewallet(wallet, previous_version=61000, expected_version=120200, requested_version=169900) # after conversion master key hash should not be present yet assert 'hdchainid' not in wallet.getwalletinfo() assert_equal(wallet.upgradetohd(), True) @@ -139,5 +175,65 @@ def run_test(self): assert_equal(new_version, 120200) assert_is_hex_string(wallet.getwalletinfo()['hdchainid']) + self.log.info("Intermediary versions don't effect anything") + copy_non_hd() + # Wallet starts with 61000 (legacy "latest") + assert_equal(61000, wallet.getwalletinfo()['walletversion']) + wallet.unloadwallet() + before_checksum = sha256sum_file(node_master_wallet) + node_master.loadwallet('') + # Test an "upgrade" from 61000 to 120199 has no effect, as the next version is 120200 + self.test_upgradewallet(wallet, previous_version=61000, requested_version=120199, expected_version=61000) + wallet.unloadwallet() + assert_equal(before_checksum, sha256sum_file(node_master_wallet)) + node_master.loadwallet('') + + self.log.info('Wallets cannot be downgraded') + copy_non_hd() + self.test_upgradewallet_error(wallet, previous_version=61000, requested_version=40000, + msg="Cannot downgrade wallet from version 61000 to version 40000. Wallet version unchanged.") + wallet.unloadwallet() + assert_equal(before_checksum, sha256sum_file(node_master_wallet)) + node_master.loadwallet('') + + self.log.info('Can upgrade to HD') + # Inspect the old wallet and make sure there is no hdchain + orig_kvs = dump_bdb_kv(node_master_wallet) + assert b'\x07hdchain' not in orig_kvs + # Upgrade to HD + self.test_upgradewallet(wallet, previous_version=61000, requested_version=120200) + # Check that there is now a hd chain and it is version 1, no internal chain counter + new_kvs = dump_bdb_kv(node_master_wallet) + wallet.upgradetohd() + new_kvs = dump_bdb_kv(node_master_wallet) + assert b'\x07hdchain' in new_kvs + hd_chain = new_kvs[b'\x07hdchain'] + # hd_chain: + # obj.nVersion int + # obj.id uint256 + # obj.fCrypted bool + # obj.vchSeed SecureVector + # obj.vchMnemonic SecureVector + # obj.vchMnemonicPassphrase SecureVector + # obj.mapAccounts map<> accounts + assert_greater_than(220, len(hd_chain)) + assert_greater_than(len(hd_chain), 180) + hd_chain_version, seed_id, is_crypted = struct.unpack('