Skip to content

Commit

Permalink
feat(btc): txn add EstimateTransactionSize().
Browse files Browse the repository at this point in the history
  • Loading branch information
Zhangguiguang committed May 24, 2024
1 parent dc98732 commit ba6c628
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 49 deletions.
59 changes: 30 additions & 29 deletions core/btc/transaction_build.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,56 +18,59 @@ const (
)

type Transaction struct {
inputs []input
outputs []output
netParams *chaincfg.Params
netParams *chaincfg.Params
msgTx *wire.MsgTx
prevOutFetcher *txscript.MultiPrevOutFetcher
}

type input struct {
outPoint *wire.OutPoint
prevOut *wire.TxOut
}

type output *wire.TxOut

func NewTransaction(chainnet string) (*Transaction, error) {
net, err := netParamsOf(chainnet)
if err != nil {
return nil, err
}
return &Transaction{netParams: net}, nil
tx := wire.NewMsgTx(wire.TxVersion)
prevOutFetcher := txscript.NewMultiPrevOutFetcher(nil)
return &Transaction{
netParams: net,
msgTx: tx,
prevOutFetcher: prevOutFetcher,
}, nil
}

func (t *Transaction) TotalInputValue() int64 {
total := int64(0)
for _, v := range t.inputs {
total += v.prevOut.Value
for _, v := range t.msgTx.TxIn {
out := t.prevOutFetcher.FetchPrevOutput(v.PreviousOutPoint)
total += out.Value
}
return total
}

func (t *Transaction) TotalOutputValue() int64 {
total := int64(0)
for _, v := range t.outputs {
for _, v := range t.msgTx.TxOut {
total += v.Value
}
return total
}

func (t *Transaction) EstimateTransactionSize() int64 {
return virtualSize(ensureSignOrFakeSign(t.msgTx, t.prevOutFetcher))
}

func (t *Transaction) AddInput(txId string, index int64, address string, value int64) error {
outPoint, err := outPoint(txId, uint32(index))
if err != nil {
return err
}
pkScript, err := addrToPkScript(address, t.netParams)
pkScript, err := addressToPkScript(address, t.netParams)
if err != nil {
return err
}
input := input{
outPoint: outPoint,
prevOut: wire.NewTxOut(value, pkScript),
}
t.inputs = append(t.inputs, input)
txIn := wire.NewTxIn(outPoint, nil, nil)
prevOut := wire.NewTxOut(value, pkScript)
t.msgTx.TxIn = append(t.msgTx.TxIn, txIn)
t.prevOutFetcher.AddPrevOut(*outPoint, prevOut)
return nil
}

Expand All @@ -80,11 +83,9 @@ func (t *Transaction) AddInput2(txId string, index int64, prevTx string) error {
if err != nil {
return err
}
input := input{
outPoint: outPoint,
prevOut: prevOut,
}
t.inputs = append(t.inputs, input)
txIn := wire.NewTxIn(outPoint, nil, nil)
t.msgTx.TxIn = append(t.msgTx.TxIn, txIn)
t.prevOutFetcher.AddPrevOut(*outPoint, prevOut)
return nil
}

Expand All @@ -93,12 +94,12 @@ func (t *Transaction) AddOutput(address string, value int64) error {
if value == 0 {
return t.AddOpReturn(address)
}
pkScript, err := addrToPkScript(address, t.netParams)
pkScript, err := addressToPkScript(address, t.netParams)
if err != nil {
return err
}
output := wire.NewTxOut(value, pkScript)
t.outputs = append(t.outputs, output)
t.msgTx.TxOut = append(t.msgTx.TxOut, output)
return nil
}

Expand All @@ -109,7 +110,7 @@ func (t *Transaction) AddOpReturn(opReturn string) error {
return err
}
output := wire.NewTxOut(0, script)
t.outputs = append(t.outputs, output)
t.msgTx.TxOut = append(t.msgTx.TxOut, output)
return nil
}

Expand Down Expand Up @@ -139,7 +140,7 @@ func prevTxOut(preTx string, index uint32) (*wire.TxOut, error) {
}
}

func addrToPkScript(addr string, network *chaincfg.Params) ([]byte, error) {
func addressToPkScript(addr string, network *chaincfg.Params) ([]byte, error) {
address, err := btcutil.DecodeAddress(addr, network)
if err != nil {
return nil, err
Expand Down
2 changes: 1 addition & 1 deletion core/btc/transaction_build_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ func TestTransaction_AddOpReturn(t *testing.T) {
require.NoError(t, err)
err = txn.AddOpReturn("ComingChat")
require.NoError(t, err)
require.Equal(t, hex.EncodeToString(txn.outputs[0].PkScript), "6a0a436f6d696e6743686174")
require.Equal(t, hex.EncodeToString(txn.msgTx.TxOut[0].PkScript), "6a0a436f6d696e6743686174")

err = txn.AddOpReturn("len75hellohellohellohellohellohellohellohellohellohellohellohellohellohello")
require.NoError(t, err)
Expand Down
28 changes: 27 additions & 1 deletion core/btc/transaction_decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"strconv"

"github.com/btcsuite/btcd/blockchain"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg"
"github.com/btcsuite/btcd/txscript"
Expand Down Expand Up @@ -93,7 +94,8 @@ func DecodePsbtTransactionDetail(psbtHex string, chainnet string) (d *Transactio
return
}
feeFloat := txFee.ToUnit(btcutil.AmountSatoshi)
vSize := virtualSize(packet.UnsignedTx)
copyedTx := packet.UnsignedTx.Copy()
vSize := virtualSize(ensureSignOrFakeSign(copyedTx, nil))
feeRate := feeFloat / float64(vSize)

inputs := make([]*TxOut, len(packet.Inputs))
Expand Down Expand Up @@ -205,3 +207,27 @@ func virtualSize(tx *wire.MsgTx) int64 {
weight := (baseSize * (blockchain.WitnessScaleFactor - 1)) + totalSize
return (weight + blockchain.WitnessScaleFactor - 1) / blockchain.WitnessScaleFactor
}

var _fakePrivatekey *btcec.PrivateKey

func ensureSignOrFakeSign(tx *wire.MsgTx, fetcher txscript.PrevOutputFetcher) *wire.MsgTx {
if len(tx.TxIn) == 0 || len(tx.TxOut) == 0 {
return tx // cannot sign
}
if len(tx.TxIn[0].SignatureScript) != 0 || len(tx.TxIn[0].Witness) != 0 {
return tx // no need sign
}

var err error
if _fakePrivatekey == nil {
if _fakePrivatekey, err = btcec.NewPrivateKey(); err != nil {
return tx // sign failed
}
}
if fetcher == nil {
fetcher = txscript.NewCannedPrevOutputFetcher(tx.TxOut[0].PkScript, 100000000)
}
// fake sign
_ = Sign(tx, _fakePrivatekey, fetcher, false)
return tx
}
25 changes: 7 additions & 18 deletions core/btc/transaction_sign.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,34 +26,20 @@ func (t *Transaction) SignWithAccount(account base.Account) (signedTxn *base.Opt
}

func (t *Transaction) SignedTransactionWithAccount(account base.Account) (signedTxn base.SignedTransaction, err error) {
if len(t.inputs) == 0 || len(t.outputs) == 0 {
return nil, errors.New("invalid inputs or outputs")
}

btcAcc, ok := account.(*Account)
if !ok {
return nil, base.ErrInvalidAccountType
}

tx := wire.NewMsgTx(wire.TxVersion)
prevOutFetcher := txscript.NewMultiPrevOutFetcher(nil)
for _, input := range t.inputs {
txIn := wire.NewTxIn(input.outPoint, nil, nil)
tx.TxIn = append(tx.TxIn, txIn)
prevOutFetcher.AddPrevOut(*input.outPoint, input.prevOut)
}
for _, output := range t.outputs {
tx.TxOut = append(tx.TxOut, output)
}

copyedTx := t.msgTx.Copy()
privateKey := btcAcc.privateKey
isComing := btcAcc.addressType == AddressTypeComingTaproot
err = Sign(tx, privateKey, prevOutFetcher, isComing)
err = Sign(copyedTx, privateKey, t.prevOutFetcher, isComing)
if err != nil {
return nil, err
}
return &SignedTransaction{
msgTx: tx,
msgTx: copyedTx,
}, nil
}

Expand All @@ -66,7 +52,10 @@ func (t *SignedTransaction) HexString() (res *base.OptionalString, err error) {
return base.NewOptionalString(str), nil
}

func Sign(tx *wire.MsgTx, privKey *btcec.PrivateKey, prevOutFetcher *txscript.MultiPrevOutFetcher, isComing bool) error {
func Sign(tx *wire.MsgTx, privKey *btcec.PrivateKey, prevOutFetcher txscript.PrevOutputFetcher, isComing bool) error {
if len(tx.TxIn) == 0 || len(tx.TxOut) == 0 {
return errors.New("invalid inputs or outputs")
}
for i, in := range tx.TxIn {
prevOut := prevOutFetcher.FetchPrevOutput(in.PreviousOutPoint)
txSigHashes := txscript.NewTxSigHashes(tx, prevOutFetcher)
Expand Down

0 comments on commit ba6c628

Please sign in to comment.