diff --git a/core/btc/transaction_build.go b/core/btc/transaction_build.go index c8ac8a2..7d72219 100644 --- a/core/btc/transaction_build.go +++ b/core/btc/transaction_build.go @@ -18,56 +18,59 @@ const ( ) type Transaction struct { - inputs []input - outputs []output - netParams *chaincfg.Params + netParams *chaincfg.Params + msgTx *wire.MsgTx + prevOutFetcher *txscript.MultiPrevOutFetcher } -type input struct { - outPoint *wire.OutPoint - prevOut *wire.TxOut -} - -type output *wire.TxOut - func NewTransaction(chainnet string) (*Transaction, error) { net, err := netParamsOf(chainnet) if err != nil { return nil, err } - return &Transaction{netParams: net}, nil + tx := wire.NewMsgTx(wire.TxVersion) + prevOutFetcher := txscript.NewMultiPrevOutFetcher(nil) + return &Transaction{ + netParams: net, + msgTx: tx, + prevOutFetcher: prevOutFetcher, + }, nil } func (t *Transaction) TotalInputValue() int64 { total := int64(0) - for _, v := range t.inputs { - total += v.prevOut.Value + for _, v := range t.msgTx.TxIn { + out := t.prevOutFetcher.FetchPrevOutput(v.PreviousOutPoint) + total += out.Value } return total } func (t *Transaction) TotalOutputValue() int64 { total := int64(0) - for _, v := range t.outputs { + for _, v := range t.msgTx.TxOut { total += v.Value } return total } +func (t *Transaction) EstimateTransactionSize() int64 { + return virtualSize(ensureSignOrFakeSign(t.msgTx, t.prevOutFetcher)) +} + func (t *Transaction) AddInput(txId string, index int64, address string, value int64) error { outPoint, err := outPoint(txId, uint32(index)) if err != nil { return err } - pkScript, err := addrToPkScript(address, t.netParams) + pkScript, err := addressToPkScript(address, t.netParams) if err != nil { return err } - input := input{ - outPoint: outPoint, - prevOut: wire.NewTxOut(value, pkScript), - } - t.inputs = append(t.inputs, input) + txIn := wire.NewTxIn(outPoint, nil, nil) + prevOut := wire.NewTxOut(value, pkScript) + t.msgTx.TxIn = append(t.msgTx.TxIn, txIn) + t.prevOutFetcher.AddPrevOut(*outPoint, prevOut) return nil } @@ -80,11 +83,9 @@ func (t *Transaction) AddInput2(txId string, index int64, prevTx string) error { if err != nil { return err } - input := input{ - outPoint: outPoint, - prevOut: prevOut, - } - t.inputs = append(t.inputs, input) + txIn := wire.NewTxIn(outPoint, nil, nil) + t.msgTx.TxIn = append(t.msgTx.TxIn, txIn) + t.prevOutFetcher.AddPrevOut(*outPoint, prevOut) return nil } @@ -93,12 +94,12 @@ func (t *Transaction) AddOutput(address string, value int64) error { if value == 0 { return t.AddOpReturn(address) } - pkScript, err := addrToPkScript(address, t.netParams) + pkScript, err := addressToPkScript(address, t.netParams) if err != nil { return err } output := wire.NewTxOut(value, pkScript) - t.outputs = append(t.outputs, output) + t.msgTx.TxOut = append(t.msgTx.TxOut, output) return nil } @@ -109,7 +110,7 @@ func (t *Transaction) AddOpReturn(opReturn string) error { return err } output := wire.NewTxOut(0, script) - t.outputs = append(t.outputs, output) + t.msgTx.TxOut = append(t.msgTx.TxOut, output) return nil } @@ -139,7 +140,7 @@ func prevTxOut(preTx string, index uint32) (*wire.TxOut, error) { } } -func addrToPkScript(addr string, network *chaincfg.Params) ([]byte, error) { +func addressToPkScript(addr string, network *chaincfg.Params) ([]byte, error) { address, err := btcutil.DecodeAddress(addr, network) if err != nil { return nil, err diff --git a/core/btc/transaction_build_test.go b/core/btc/transaction_build_test.go index 3a61c47..c488293 100644 --- a/core/btc/transaction_build_test.go +++ b/core/btc/transaction_build_test.go @@ -12,7 +12,7 @@ func TestTransaction_AddOpReturn(t *testing.T) { require.NoError(t, err) err = txn.AddOpReturn("ComingChat") require.NoError(t, err) - require.Equal(t, hex.EncodeToString(txn.outputs[0].PkScript), "6a0a436f6d696e6743686174") + require.Equal(t, hex.EncodeToString(txn.msgTx.TxOut[0].PkScript), "6a0a436f6d696e6743686174") err = txn.AddOpReturn("len75hellohellohellohellohellohellohellohellohellohellohellohellohellohello") require.NoError(t, err) diff --git a/core/btc/transaction_decode.go b/core/btc/transaction_decode.go index 5e34ab0..8adaa0a 100644 --- a/core/btc/transaction_decode.go +++ b/core/btc/transaction_decode.go @@ -5,6 +5,7 @@ import ( "strconv" "github.com/btcsuite/btcd/blockchain" + "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/txscript" @@ -93,7 +94,8 @@ func DecodePsbtTransactionDetail(psbtHex string, chainnet string) (d *Transactio return } feeFloat := txFee.ToUnit(btcutil.AmountSatoshi) - vSize := virtualSize(packet.UnsignedTx) + copyedTx := packet.UnsignedTx.Copy() + vSize := virtualSize(ensureSignOrFakeSign(copyedTx, nil)) feeRate := feeFloat / float64(vSize) inputs := make([]*TxOut, len(packet.Inputs)) @@ -205,3 +207,27 @@ func virtualSize(tx *wire.MsgTx) int64 { weight := (baseSize * (blockchain.WitnessScaleFactor - 1)) + totalSize return (weight + blockchain.WitnessScaleFactor - 1) / blockchain.WitnessScaleFactor } + +var _fakePrivatekey *btcec.PrivateKey + +func ensureSignOrFakeSign(tx *wire.MsgTx, fetcher txscript.PrevOutputFetcher) *wire.MsgTx { + if len(tx.TxIn) == 0 || len(tx.TxOut) == 0 { + return tx // cannot sign + } + if len(tx.TxIn[0].SignatureScript) != 0 || len(tx.TxIn[0].Witness) != 0 { + return tx // no need sign + } + + var err error + if _fakePrivatekey == nil { + if _fakePrivatekey, err = btcec.NewPrivateKey(); err != nil { + return tx // sign failed + } + } + if fetcher == nil { + fetcher = txscript.NewCannedPrevOutputFetcher(tx.TxOut[0].PkScript, 100000000) + } + // fake sign + _ = Sign(tx, _fakePrivatekey, fetcher, false) + return tx +} diff --git a/core/btc/transaction_sign.go b/core/btc/transaction_sign.go index 74182fe..b04e339 100644 --- a/core/btc/transaction_sign.go +++ b/core/btc/transaction_sign.go @@ -26,34 +26,20 @@ func (t *Transaction) SignWithAccount(account base.Account) (signedTxn *base.Opt } func (t *Transaction) SignedTransactionWithAccount(account base.Account) (signedTxn base.SignedTransaction, err error) { - if len(t.inputs) == 0 || len(t.outputs) == 0 { - return nil, errors.New("invalid inputs or outputs") - } - btcAcc, ok := account.(*Account) if !ok { return nil, base.ErrInvalidAccountType } - tx := wire.NewMsgTx(wire.TxVersion) - prevOutFetcher := txscript.NewMultiPrevOutFetcher(nil) - for _, input := range t.inputs { - txIn := wire.NewTxIn(input.outPoint, nil, nil) - tx.TxIn = append(tx.TxIn, txIn) - prevOutFetcher.AddPrevOut(*input.outPoint, input.prevOut) - } - for _, output := range t.outputs { - tx.TxOut = append(tx.TxOut, output) - } - + copyedTx := t.msgTx.Copy() privateKey := btcAcc.privateKey isComing := btcAcc.addressType == AddressTypeComingTaproot - err = Sign(tx, privateKey, prevOutFetcher, isComing) + err = Sign(copyedTx, privateKey, t.prevOutFetcher, isComing) if err != nil { return nil, err } return &SignedTransaction{ - msgTx: tx, + msgTx: copyedTx, }, nil } @@ -66,7 +52,10 @@ func (t *SignedTransaction) HexString() (res *base.OptionalString, err error) { return base.NewOptionalString(str), nil } -func Sign(tx *wire.MsgTx, privKey *btcec.PrivateKey, prevOutFetcher *txscript.MultiPrevOutFetcher, isComing bool) error { +func Sign(tx *wire.MsgTx, privKey *btcec.PrivateKey, prevOutFetcher txscript.PrevOutputFetcher, isComing bool) error { + if len(tx.TxIn) == 0 || len(tx.TxOut) == 0 { + return errors.New("invalid inputs or outputs") + } for i, in := range tx.TxIn { prevOut := prevOutFetcher.FetchPrevOutput(in.PreviousOutPoint) txSigHashes := txscript.NewTxSigHashes(tx, prevOutFetcher)