Compare commits
10 Commits
bdee27b19e
...
9f5189cd00
Author | SHA1 | Date |
---|---|---|
Gyu-Ho Lee | 9f5189cd00 | |
Gyu-Ho Lee | 6bc9b6e4ed | |
Gyu-Ho Lee | bc9ddf2601 | |
Gyu-Ho Lee | e7285f5626 | |
Gyu-Ho Lee | 4613a7e61b | |
Xiang Li | 31d1fa20bf | |
Hitoshi Mitake | 259f89d59a | |
Gyu-Ho Lee | 756d701f13 | |
Xiang Li | d14a673ace | |
Gyu-Ho Lee | 4abb231505 |
|
@ -1,6 +1,6 @@
|
|||
{
|
||||
"ImportPath": "github.com/coreos/etcd",
|
||||
"GoVersion": "go1.4.2",
|
||||
"GoVersion": "go1.5.1",
|
||||
"Packages": [
|
||||
"./..."
|
||||
],
|
||||
|
@ -10,6 +10,10 @@
|
|||
"Comment": "null-5",
|
||||
"Rev": "75cd24fc2f2c2a2088577d12123ddee5f54e0675"
|
||||
},
|
||||
{
|
||||
"ImportPath": "github.com/akrennmair/gopcap",
|
||||
"Rev": "00e11033259acb75598ba416495bb708d864a010"
|
||||
},
|
||||
{
|
||||
"ImportPath": "github.com/beorn7/perks/quantile",
|
||||
"Rev": "b965b613227fddccbfffe13eae360ed3fa822f8d"
|
||||
|
@ -20,17 +24,21 @@
|
|||
},
|
||||
{
|
||||
"ImportPath": "github.com/boltdb/bolt",
|
||||
"Comment": "v1.0-119-g90fef38",
|
||||
"Rev": "90fef389f98027ca55594edd7dbd6e7f3926fdad"
|
||||
"Comment": "v1.1.0-19-g0b00eff",
|
||||
"Rev": "0b00effdd7a8270ebd91c24297e51643e370dd52"
|
||||
},
|
||||
{
|
||||
"ImportPath": "github.com/bradfitz/http2",
|
||||
"Rev": "3e36af6d3af0e56fa3da71099f864933dea3d9fb"
|
||||
"ImportPath": "github.com/cheggaaa/pb",
|
||||
"Rev": "da1f27ad1d9509b16f65f52fd9d8138b0f2dc7b2"
|
||||
},
|
||||
{
|
||||
"ImportPath": "github.com/codegangsta/cli",
|
||||
"Comment": "1.2.0-26-gf7ebb76",
|
||||
"Rev": "f7ebb761e83e21225d1d8954fde853bf8edd46c4"
|
||||
"Comment": "1.2.0-183-gb5232bb",
|
||||
"Rev": "b5232bb2934f606f9f27a1305f1eea224e8e8b88"
|
||||
},
|
||||
{
|
||||
"ImportPath": "github.com/coreos/gexpect",
|
||||
"Rev": "5173270e159f5aa8fbc999dc7e3dcb50f4098a69"
|
||||
},
|
||||
{
|
||||
"ImportPath": "github.com/coreos/go-semver/semver",
|
||||
|
@ -55,9 +63,15 @@
|
|||
"ImportPath": "github.com/coreos/pkg/capnslog",
|
||||
"Rev": "2c77715c4df99b5420ffcae14ead08f52104065d"
|
||||
},
|
||||
{
|
||||
"ImportPath": "github.com/cpuguy83/go-md2man/md2man",
|
||||
"Comment": "v1.0.4",
|
||||
"Rev": "71acacd42f85e5e82f70a55327789582a5200a90"
|
||||
},
|
||||
{
|
||||
"ImportPath": "github.com/gogo/protobuf/proto",
|
||||
"Rev": "64f27bf06efee53589314a6e5a4af34cdd85adf6"
|
||||
"Comment": "v0.1-118-ge8904f5",
|
||||
"Rev": "e8904f58e872a473a5b91bc9bf3377d223555263"
|
||||
},
|
||||
{
|
||||
"ImportPath": "github.com/golang/glog",
|
||||
|
@ -65,20 +79,37 @@
|
|||
},
|
||||
{
|
||||
"ImportPath": "github.com/golang/protobuf/proto",
|
||||
"Rev": "5677a0e3d5e89854c9974e1256839ee23f8233ca"
|
||||
"Rev": "6aaa8d47701fa6cf07e914ec01fde3d4a1fe79c3"
|
||||
},
|
||||
{
|
||||
"ImportPath": "github.com/google/btree",
|
||||
"Rev": "cc6329d4279e3f025a53a83c397d2339b5705c45"
|
||||
},
|
||||
{
|
||||
"ImportPath": "github.com/inconshreveable/mousetrap",
|
||||
"Rev": "76626ae9c91c4f2a10f34cad8ce83ea42c93bb75"
|
||||
},
|
||||
{
|
||||
"ImportPath": "github.com/jonboulle/clockwork",
|
||||
"Rev": "72f9bd7c4e0c2a40055ab3d0f09654f730cce982"
|
||||
},
|
||||
{
|
||||
"ImportPath": "github.com/kballard/go-shellquote",
|
||||
"Rev": "d8ec1a69a250a17bb0e419c386eac1f3711dc142"
|
||||
},
|
||||
{
|
||||
"ImportPath": "github.com/kr/pty",
|
||||
"Comment": "release.r56-29-gf7ee69f",
|
||||
"Rev": "f7ee69f31298ecbe5d2b349c711e2547a617d398"
|
||||
},
|
||||
{
|
||||
"ImportPath": "github.com/matttproud/golang_protobuf_extensions/pbutil",
|
||||
"Rev": "fc2b8d3a73c4867e51861bbdd5ae3c1f0869dd6a"
|
||||
},
|
||||
{
|
||||
"ImportPath": "github.com/olekukonko/ts",
|
||||
"Rev": "ecf753e7c962639ab5a1fb46f7da627d4c0a04b8"
|
||||
},
|
||||
{
|
||||
"ImportPath": "github.com/prometheus/client_golang/prometheus",
|
||||
"Comment": "0.7.0-52-ge51041b",
|
||||
|
@ -102,8 +133,25 @@
|
|||
"Rev": "454a56f35412459b5e684fd5ec0f9211b94f002a"
|
||||
},
|
||||
{
|
||||
"ImportPath": "github.com/rakyll/pb",
|
||||
"Rev": "dc507ad06b7462501281bb4691ee43f0b1d1ec37"
|
||||
"ImportPath": "github.com/russross/blackfriday",
|
||||
"Comment": "v1.4-2-g300106c",
|
||||
"Rev": "300106c228d52c8941d4b3de6054a6062a86dda3"
|
||||
},
|
||||
{
|
||||
"ImportPath": "github.com/shurcooL/sanitized_anchor_name",
|
||||
"Rev": "10ef21a441db47d8b13ebcc5fd2310f636973c77"
|
||||
},
|
||||
{
|
||||
"ImportPath": "github.com/spacejam/loghisto",
|
||||
"Rev": "323309774dec8b7430187e46cd0793974ccca04a"
|
||||
},
|
||||
{
|
||||
"ImportPath": "github.com/spf13/cobra",
|
||||
"Rev": "1c44ec8d3f1552cac48999f9306da23c4d8a288b"
|
||||
},
|
||||
{
|
||||
"ImportPath": "github.com/spf13/pflag",
|
||||
"Rev": "08b1a584251b5b62f458943640fc8ebd4d50aaa5"
|
||||
},
|
||||
{
|
||||
"ImportPath": "github.com/stretchr/testify/assert",
|
||||
|
@ -127,27 +175,27 @@
|
|||
},
|
||||
{
|
||||
"ImportPath": "golang.org/x/net/context",
|
||||
"Rev": "7dbad50ab5b31073856416cdcfeb2796d682f844"
|
||||
"Rev": "04b9de9b512f58addf28c9853d50ebef61c3953e"
|
||||
},
|
||||
{
|
||||
"ImportPath": "golang.org/x/oauth2",
|
||||
"Rev": "3046bc76d6dfd7d3707f6640f85e42d9c4050f50"
|
||||
"ImportPath": "golang.org/x/net/http2",
|
||||
"Rev": "04b9de9b512f58addf28c9853d50ebef61c3953e"
|
||||
},
|
||||
{
|
||||
"ImportPath": "golang.org/x/net/internal/timeseries",
|
||||
"Rev": "04b9de9b512f58addf28c9853d50ebef61c3953e"
|
||||
},
|
||||
{
|
||||
"ImportPath": "golang.org/x/net/trace",
|
||||
"Rev": "04b9de9b512f58addf28c9853d50ebef61c3953e"
|
||||
},
|
||||
{
|
||||
"ImportPath": "golang.org/x/sys/unix",
|
||||
"Rev": "9c60d1c508f5134d1ca726b4641db998f2523357"
|
||||
},
|
||||
{
|
||||
"ImportPath": "google.golang.org/cloud/compute/metadata",
|
||||
"Rev": "f20d6dcccb44ed49de45ae3703312cb46e627db1"
|
||||
},
|
||||
{
|
||||
"ImportPath": "google.golang.org/cloud/internal",
|
||||
"Rev": "f20d6dcccb44ed49de45ae3703312cb46e627db1"
|
||||
},
|
||||
{
|
||||
"ImportPath": "google.golang.org/grpc",
|
||||
"Rev": "f5ebd86be717593ab029545492c93ddf8914832b"
|
||||
"Rev": "e29d659177655e589850ba7d3d83f7ce12ef23dd"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
|
|
@ -0,0 +1,5 @@
|
|||
#*
|
||||
*~
|
||||
/tools/pass/pass
|
||||
/tools/pcaptest/pcaptest
|
||||
/tools/tcpdump/tcpdump
|
|
@ -0,0 +1,27 @@
|
|||
Copyright (c) 2009-2011 Andreas Krennmair. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
* Redistributions in binary form must reproduce the above
|
||||
copyright notice, this list of conditions and the following disclaimer
|
||||
in the documentation and/or other materials provided with the
|
||||
distribution.
|
||||
* Neither the name of Andreas Krennmair nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
@ -0,0 +1,11 @@
|
|||
# PCAP
|
||||
|
||||
This is a simple wrapper around libpcap for Go. Originally written by Andreas
|
||||
Krennmair <ak@synflood.at> and only minorly touched up by Mark Smith <mark@qq.is>.
|
||||
|
||||
Please see the included pcaptest.go and tcpdump.go programs for instructions on
|
||||
how to use this library.
|
||||
|
||||
Miek Gieben <miek@miek.nl> has created a more Go-like package and replaced functionality
|
||||
with standard functions from the standard library. The package has also been renamed to
|
||||
pcap.
|
|
@ -0,0 +1,527 @@
|
|||
package pcap
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net"
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
TYPE_IP = 0x0800
|
||||
TYPE_ARP = 0x0806
|
||||
TYPE_IP6 = 0x86DD
|
||||
TYPE_VLAN = 0x8100
|
||||
|
||||
IP_ICMP = 1
|
||||
IP_INIP = 4
|
||||
IP_TCP = 6
|
||||
IP_UDP = 17
|
||||
)
|
||||
|
||||
const (
|
||||
ERRBUF_SIZE = 256
|
||||
|
||||
// According to pcap-linktype(7).
|
||||
LINKTYPE_NULL = 0
|
||||
LINKTYPE_ETHERNET = 1
|
||||
LINKTYPE_TOKEN_RING = 6
|
||||
LINKTYPE_ARCNET = 7
|
||||
LINKTYPE_SLIP = 8
|
||||
LINKTYPE_PPP = 9
|
||||
LINKTYPE_FDDI = 10
|
||||
LINKTYPE_ATM_RFC1483 = 100
|
||||
LINKTYPE_RAW = 101
|
||||
LINKTYPE_PPP_HDLC = 50
|
||||
LINKTYPE_PPP_ETHER = 51
|
||||
LINKTYPE_C_HDLC = 104
|
||||
LINKTYPE_IEEE802_11 = 105
|
||||
LINKTYPE_FRELAY = 107
|
||||
LINKTYPE_LOOP = 108
|
||||
LINKTYPE_LINUX_SLL = 113
|
||||
LINKTYPE_LTALK = 104
|
||||
LINKTYPE_PFLOG = 117
|
||||
LINKTYPE_PRISM_HEADER = 119
|
||||
LINKTYPE_IP_OVER_FC = 122
|
||||
LINKTYPE_SUNATM = 123
|
||||
LINKTYPE_IEEE802_11_RADIO = 127
|
||||
LINKTYPE_ARCNET_LINUX = 129
|
||||
LINKTYPE_LINUX_IRDA = 144
|
||||
LINKTYPE_LINUX_LAPD = 177
|
||||
)
|
||||
|
||||
type addrHdr interface {
|
||||
SrcAddr() string
|
||||
DestAddr() string
|
||||
Len() int
|
||||
}
|
||||
|
||||
type addrStringer interface {
|
||||
String(addr addrHdr) string
|
||||
}
|
||||
|
||||
func decodemac(pkt []byte) uint64 {
|
||||
mac := uint64(0)
|
||||
for i := uint(0); i < 6; i++ {
|
||||
mac = (mac << 8) + uint64(pkt[i])
|
||||
}
|
||||
return mac
|
||||
}
|
||||
|
||||
// Decode decodes the headers of a Packet.
|
||||
func (p *Packet) Decode() {
|
||||
if len(p.Data) <= 14 {
|
||||
return
|
||||
}
|
||||
|
||||
p.Type = int(binary.BigEndian.Uint16(p.Data[12:14]))
|
||||
p.DestMac = decodemac(p.Data[0:6])
|
||||
p.SrcMac = decodemac(p.Data[6:12])
|
||||
|
||||
if len(p.Data) >= 15 {
|
||||
p.Payload = p.Data[14:]
|
||||
}
|
||||
|
||||
switch p.Type {
|
||||
case TYPE_IP:
|
||||
p.decodeIp()
|
||||
case TYPE_IP6:
|
||||
p.decodeIp6()
|
||||
case TYPE_ARP:
|
||||
p.decodeArp()
|
||||
case TYPE_VLAN:
|
||||
p.decodeVlan()
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Packet) headerString(headers []interface{}) string {
|
||||
// If there's just one header, return that.
|
||||
if len(headers) == 1 {
|
||||
if hdr, ok := headers[0].(fmt.Stringer); ok {
|
||||
return hdr.String()
|
||||
}
|
||||
}
|
||||
// If there are two headers (IPv4/IPv6 -> TCP/UDP/IP..)
|
||||
if len(headers) == 2 {
|
||||
// Commonly the first header is an address.
|
||||
if addr, ok := p.Headers[0].(addrHdr); ok {
|
||||
if hdr, ok := p.Headers[1].(addrStringer); ok {
|
||||
return fmt.Sprintf("%s %s", p.Time, hdr.String(addr))
|
||||
}
|
||||
}
|
||||
}
|
||||
// For IP in IP, we do a recursive call.
|
||||
if len(headers) >= 2 {
|
||||
if addr, ok := headers[0].(addrHdr); ok {
|
||||
if _, ok := headers[1].(addrHdr); ok {
|
||||
return fmt.Sprintf("%s > %s IP in IP: ",
|
||||
addr.SrcAddr(), addr.DestAddr(), p.headerString(headers[1:]))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var typeNames []string
|
||||
for _, hdr := range headers {
|
||||
typeNames = append(typeNames, reflect.TypeOf(hdr).String())
|
||||
}
|
||||
|
||||
return fmt.Sprintf("unknown [%s]", strings.Join(typeNames, ","))
|
||||
}
|
||||
|
||||
// String prints a one-line representation of the packet header.
|
||||
// The output is suitable for use in a tcpdump program.
|
||||
func (p *Packet) String() string {
|
||||
// If there are no headers, print "unsupported protocol".
|
||||
if len(p.Headers) == 0 {
|
||||
return fmt.Sprintf("%s unsupported protocol %d", p.Time, int(p.Type))
|
||||
}
|
||||
return fmt.Sprintf("%s %s", p.Time, p.headerString(p.Headers))
|
||||
}
|
||||
|
||||
// Arphdr is a ARP packet header.
|
||||
type Arphdr struct {
|
||||
Addrtype uint16
|
||||
Protocol uint16
|
||||
HwAddressSize uint8
|
||||
ProtAddressSize uint8
|
||||
Operation uint16
|
||||
SourceHwAddress []byte
|
||||
SourceProtAddress []byte
|
||||
DestHwAddress []byte
|
||||
DestProtAddress []byte
|
||||
}
|
||||
|
||||
func (arp *Arphdr) String() (s string) {
|
||||
switch arp.Operation {
|
||||
case 1:
|
||||
s = "ARP request"
|
||||
case 2:
|
||||
s = "ARP Reply"
|
||||
}
|
||||
if arp.Addrtype == LINKTYPE_ETHERNET && arp.Protocol == TYPE_IP {
|
||||
s = fmt.Sprintf("%012x (%s) > %012x (%s)",
|
||||
decodemac(arp.SourceHwAddress), arp.SourceProtAddress,
|
||||
decodemac(arp.DestHwAddress), arp.DestProtAddress)
|
||||
} else {
|
||||
s = fmt.Sprintf("addrtype = %d protocol = %d", arp.Addrtype, arp.Protocol)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (p *Packet) decodeArp() {
|
||||
if len(p.Payload) < 8 {
|
||||
return
|
||||
}
|
||||
|
||||
pkt := p.Payload
|
||||
arp := new(Arphdr)
|
||||
arp.Addrtype = binary.BigEndian.Uint16(pkt[0:2])
|
||||
arp.Protocol = binary.BigEndian.Uint16(pkt[2:4])
|
||||
arp.HwAddressSize = pkt[4]
|
||||
arp.ProtAddressSize = pkt[5]
|
||||
arp.Operation = binary.BigEndian.Uint16(pkt[6:8])
|
||||
|
||||
if len(pkt) < int(8+2*arp.HwAddressSize+2*arp.ProtAddressSize) {
|
||||
return
|
||||
}
|
||||
arp.SourceHwAddress = pkt[8 : 8+arp.HwAddressSize]
|
||||
arp.SourceProtAddress = pkt[8+arp.HwAddressSize : 8+arp.HwAddressSize+arp.ProtAddressSize]
|
||||
arp.DestHwAddress = pkt[8+arp.HwAddressSize+arp.ProtAddressSize : 8+2*arp.HwAddressSize+arp.ProtAddressSize]
|
||||
arp.DestProtAddress = pkt[8+2*arp.HwAddressSize+arp.ProtAddressSize : 8+2*arp.HwAddressSize+2*arp.ProtAddressSize]
|
||||
|
||||
p.Headers = append(p.Headers, arp)
|
||||
|
||||
if len(pkt) >= int(8+2*arp.HwAddressSize+2*arp.ProtAddressSize) {
|
||||
p.Payload = p.Payload[8+2*arp.HwAddressSize+2*arp.ProtAddressSize:]
|
||||
}
|
||||
}
|
||||
|
||||
// IPadr is the header of an IP packet.
|
||||
type Iphdr struct {
|
||||
Version uint8
|
||||
Ihl uint8
|
||||
Tos uint8
|
||||
Length uint16
|
||||
Id uint16
|
||||
Flags uint8
|
||||
FragOffset uint16
|
||||
Ttl uint8
|
||||
Protocol uint8
|
||||
Checksum uint16
|
||||
SrcIp []byte
|
||||
DestIp []byte
|
||||
}
|
||||
|
||||
func (p *Packet) decodeIp() {
|
||||
if len(p.Payload) < 20 {
|
||||
return
|
||||
}
|
||||
|
||||
pkt := p.Payload
|
||||
ip := new(Iphdr)
|
||||
|
||||
ip.Version = uint8(pkt[0]) >> 4
|
||||
ip.Ihl = uint8(pkt[0]) & 0x0F
|
||||
ip.Tos = pkt[1]
|
||||
ip.Length = binary.BigEndian.Uint16(pkt[2:4])
|
||||
ip.Id = binary.BigEndian.Uint16(pkt[4:6])
|
||||
flagsfrags := binary.BigEndian.Uint16(pkt[6:8])
|
||||
ip.Flags = uint8(flagsfrags >> 13)
|
||||
ip.FragOffset = flagsfrags & 0x1FFF
|
||||
ip.Ttl = pkt[8]
|
||||
ip.Protocol = pkt[9]
|
||||
ip.Checksum = binary.BigEndian.Uint16(pkt[10:12])
|
||||
ip.SrcIp = pkt[12:16]
|
||||
ip.DestIp = pkt[16:20]
|
||||
|
||||
pEnd := int(ip.Length)
|
||||
if pEnd > len(pkt) {
|
||||
pEnd = len(pkt)
|
||||
}
|
||||
|
||||
if len(pkt) >= pEnd && int(ip.Ihl*4) < pEnd {
|
||||
p.Payload = pkt[ip.Ihl*4 : pEnd]
|
||||
} else {
|
||||
p.Payload = []byte{}
|
||||
}
|
||||
|
||||
p.Headers = append(p.Headers, ip)
|
||||
p.IP = ip
|
||||
|
||||
switch ip.Protocol {
|
||||
case IP_TCP:
|
||||
p.decodeTcp()
|
||||
case IP_UDP:
|
||||
p.decodeUdp()
|
||||
case IP_ICMP:
|
||||
p.decodeIcmp()
|
||||
case IP_INIP:
|
||||
p.decodeIp()
|
||||
}
|
||||
}
|
||||
|
||||
func (ip *Iphdr) SrcAddr() string { return net.IP(ip.SrcIp).String() }
|
||||
func (ip *Iphdr) DestAddr() string { return net.IP(ip.DestIp).String() }
|
||||
func (ip *Iphdr) Len() int { return int(ip.Length) }
|
||||
|
||||
type Vlanhdr struct {
|
||||
Priority byte
|
||||
DropEligible bool
|
||||
VlanIdentifier int
|
||||
Type int // Not actually part of the vlan header, but the type of the actual packet
|
||||
}
|
||||
|
||||
func (v *Vlanhdr) String() {
|
||||
fmt.Sprintf("VLAN Priority:%d Drop:%v Tag:%d", v.Priority, v.DropEligible, v.VlanIdentifier)
|
||||
}
|
||||
|
||||
func (p *Packet) decodeVlan() {
|
||||
pkt := p.Payload
|
||||
vlan := new(Vlanhdr)
|
||||
if len(pkt) < 4 {
|
||||
return
|
||||
}
|
||||
|
||||
vlan.Priority = (pkt[2] & 0xE0) >> 13
|
||||
vlan.DropEligible = pkt[2]&0x10 != 0
|
||||
vlan.VlanIdentifier = int(binary.BigEndian.Uint16(pkt[:2])) & 0x0FFF
|
||||
vlan.Type = int(binary.BigEndian.Uint16(p.Payload[2:4]))
|
||||
p.Headers = append(p.Headers, vlan)
|
||||
|
||||
if len(pkt) >= 5 {
|
||||
p.Payload = p.Payload[4:]
|
||||
}
|
||||
|
||||
switch vlan.Type {
|
||||
case TYPE_IP:
|
||||
p.decodeIp()
|
||||
case TYPE_IP6:
|
||||
p.decodeIp6()
|
||||
case TYPE_ARP:
|
||||
p.decodeArp()
|
||||
}
|
||||
}
|
||||
|
||||
type Tcphdr struct {
|
||||
SrcPort uint16
|
||||
DestPort uint16
|
||||
Seq uint32
|
||||
Ack uint32
|
||||
DataOffset uint8
|
||||
Flags uint16
|
||||
Window uint16
|
||||
Checksum uint16
|
||||
Urgent uint16
|
||||
Data []byte
|
||||
}
|
||||
|
||||
const (
|
||||
TCP_FIN = 1 << iota
|
||||
TCP_SYN
|
||||
TCP_RST
|
||||
TCP_PSH
|
||||
TCP_ACK
|
||||
TCP_URG
|
||||
TCP_ECE
|
||||
TCP_CWR
|
||||
TCP_NS
|
||||
)
|
||||
|
||||
func (p *Packet) decodeTcp() {
|
||||
if len(p.Payload) < 20 {
|
||||
return
|
||||
}
|
||||
|
||||
pkt := p.Payload
|
||||
tcp := new(Tcphdr)
|
||||
tcp.SrcPort = binary.BigEndian.Uint16(pkt[0:2])
|
||||
tcp.DestPort = binary.BigEndian.Uint16(pkt[2:4])
|
||||
tcp.Seq = binary.BigEndian.Uint32(pkt[4:8])
|
||||
tcp.Ack = binary.BigEndian.Uint32(pkt[8:12])
|
||||
tcp.DataOffset = (pkt[12] & 0xF0) >> 4
|
||||
tcp.Flags = binary.BigEndian.Uint16(pkt[12:14]) & 0x1FF
|
||||
tcp.Window = binary.BigEndian.Uint16(pkt[14:16])
|
||||
tcp.Checksum = binary.BigEndian.Uint16(pkt[16:18])
|
||||
tcp.Urgent = binary.BigEndian.Uint16(pkt[18:20])
|
||||
if len(pkt) >= int(tcp.DataOffset*4) {
|
||||
p.Payload = pkt[tcp.DataOffset*4:]
|
||||
}
|
||||
p.Headers = append(p.Headers, tcp)
|
||||
p.TCP = tcp
|
||||
}
|
||||
|
||||
func (tcp *Tcphdr) String(hdr addrHdr) string {
|
||||
return fmt.Sprintf("TCP %s:%d > %s:%d %s SEQ=%d ACK=%d LEN=%d",
|
||||
hdr.SrcAddr(), int(tcp.SrcPort), hdr.DestAddr(), int(tcp.DestPort),
|
||||
tcp.FlagsString(), int64(tcp.Seq), int64(tcp.Ack), hdr.Len())
|
||||
}
|
||||
|
||||
func (tcp *Tcphdr) FlagsString() string {
|
||||
var sflags []string
|
||||
if 0 != (tcp.Flags & TCP_SYN) {
|
||||
sflags = append(sflags, "syn")
|
||||
}
|
||||
if 0 != (tcp.Flags & TCP_FIN) {
|
||||
sflags = append(sflags, "fin")
|
||||
}
|
||||
if 0 != (tcp.Flags & TCP_ACK) {
|
||||
sflags = append(sflags, "ack")
|
||||
}
|
||||
if 0 != (tcp.Flags & TCP_PSH) {
|
||||
sflags = append(sflags, "psh")
|
||||
}
|
||||
if 0 != (tcp.Flags & TCP_RST) {
|
||||
sflags = append(sflags, "rst")
|
||||
}
|
||||
if 0 != (tcp.Flags & TCP_URG) {
|
||||
sflags = append(sflags, "urg")
|
||||
}
|
||||
if 0 != (tcp.Flags & TCP_NS) {
|
||||
sflags = append(sflags, "ns")
|
||||
}
|
||||
if 0 != (tcp.Flags & TCP_CWR) {
|
||||
sflags = append(sflags, "cwr")
|
||||
}
|
||||
if 0 != (tcp.Flags & TCP_ECE) {
|
||||
sflags = append(sflags, "ece")
|
||||
}
|
||||
return fmt.Sprintf("[%s]", strings.Join(sflags, " "))
|
||||
}
|
||||
|
||||
type Udphdr struct {
|
||||
SrcPort uint16
|
||||
DestPort uint16
|
||||
Length uint16
|
||||
Checksum uint16
|
||||
}
|
||||
|
||||
func (p *Packet) decodeUdp() {
|
||||
if len(p.Payload) < 8 {
|
||||
return
|
||||
}
|
||||
|
||||
pkt := p.Payload
|
||||
udp := new(Udphdr)
|
||||
udp.SrcPort = binary.BigEndian.Uint16(pkt[0:2])
|
||||
udp.DestPort = binary.BigEndian.Uint16(pkt[2:4])
|
||||
udp.Length = binary.BigEndian.Uint16(pkt[4:6])
|
||||
udp.Checksum = binary.BigEndian.Uint16(pkt[6:8])
|
||||
p.Headers = append(p.Headers, udp)
|
||||
p.UDP = udp
|
||||
if len(p.Payload) >= 8 {
|
||||
p.Payload = pkt[8:]
|
||||
}
|
||||
}
|
||||
|
||||
func (udp *Udphdr) String(hdr addrHdr) string {
|
||||
return fmt.Sprintf("UDP %s:%d > %s:%d LEN=%d CHKSUM=%d",
|
||||
hdr.SrcAddr(), int(udp.SrcPort), hdr.DestAddr(), int(udp.DestPort),
|
||||
int(udp.Length), int(udp.Checksum))
|
||||
}
|
||||
|
||||
type Icmphdr struct {
|
||||
Type uint8
|
||||
Code uint8
|
||||
Checksum uint16
|
||||
Id uint16
|
||||
Seq uint16
|
||||
Data []byte
|
||||
}
|
||||
|
||||
func (p *Packet) decodeIcmp() *Icmphdr {
|
||||
if len(p.Payload) < 8 {
|
||||
return nil
|
||||
}
|
||||
|
||||
pkt := p.Payload
|
||||
icmp := new(Icmphdr)
|
||||
icmp.Type = pkt[0]
|
||||
icmp.Code = pkt[1]
|
||||
icmp.Checksum = binary.BigEndian.Uint16(pkt[2:4])
|
||||
icmp.Id = binary.BigEndian.Uint16(pkt[4:6])
|
||||
icmp.Seq = binary.BigEndian.Uint16(pkt[6:8])
|
||||
p.Payload = pkt[8:]
|
||||
p.Headers = append(p.Headers, icmp)
|
||||
return icmp
|
||||
}
|
||||
|
||||
func (icmp *Icmphdr) String(hdr addrHdr) string {
|
||||
return fmt.Sprintf("ICMP %s > %s Type = %d Code = %d ",
|
||||
hdr.SrcAddr(), hdr.DestAddr(), icmp.Type, icmp.Code)
|
||||
}
|
||||
|
||||
func (icmp *Icmphdr) TypeString() (result string) {
|
||||
switch icmp.Type {
|
||||
case 0:
|
||||
result = fmt.Sprintf("Echo reply seq=%d", icmp.Seq)
|
||||
case 3:
|
||||
switch icmp.Code {
|
||||
case 0:
|
||||
result = "Network unreachable"
|
||||
case 1:
|
||||
result = "Host unreachable"
|
||||
case 2:
|
||||
result = "Protocol unreachable"
|
||||
case 3:
|
||||
result = "Port unreachable"
|
||||
default:
|
||||
result = "Destination unreachable"
|
||||
}
|
||||
case 8:
|
||||
result = fmt.Sprintf("Echo request seq=%d", icmp.Seq)
|
||||
case 30:
|
||||
result = "Traceroute"
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
type Ip6hdr struct {
|
||||
// http://www.networksorcery.com/enp/protocol/ipv6.htm
|
||||
Version uint8 // 4 bits
|
||||
TrafficClass uint8 // 8 bits
|
||||
FlowLabel uint32 // 20 bits
|
||||
Length uint16 // 16 bits
|
||||
NextHeader uint8 // 8 bits, same as Protocol in Iphdr
|
||||
HopLimit uint8 // 8 bits
|
||||
SrcIp []byte // 16 bytes
|
||||
DestIp []byte // 16 bytes
|
||||
}
|
||||
|
||||
func (p *Packet) decodeIp6() {
|
||||
if len(p.Payload) < 40 {
|
||||
return
|
||||
}
|
||||
|
||||
pkt := p.Payload
|
||||
ip6 := new(Ip6hdr)
|
||||
ip6.Version = uint8(pkt[0]) >> 4
|
||||
ip6.TrafficClass = uint8((binary.BigEndian.Uint16(pkt[0:2]) >> 4) & 0x00FF)
|
||||
ip6.FlowLabel = binary.BigEndian.Uint32(pkt[0:4]) & 0x000FFFFF
|
||||
ip6.Length = binary.BigEndian.Uint16(pkt[4:6])
|
||||
ip6.NextHeader = pkt[6]
|
||||
ip6.HopLimit = pkt[7]
|
||||
ip6.SrcIp = pkt[8:24]
|
||||
ip6.DestIp = pkt[24:40]
|
||||
|
||||
if len(p.Payload) >= 40 {
|
||||
p.Payload = pkt[40:]
|
||||
}
|
||||
|
||||
p.Headers = append(p.Headers, ip6)
|
||||
|
||||
switch ip6.NextHeader {
|
||||
case IP_TCP:
|
||||
p.decodeTcp()
|
||||
case IP_UDP:
|
||||
p.decodeUdp()
|
||||
case IP_ICMP:
|
||||
p.decodeIcmp()
|
||||
case IP_INIP:
|
||||
p.decodeIp()
|
||||
}
|
||||
}
|
||||
|
||||
func (ip6 *Ip6hdr) SrcAddr() string { return net.IP(ip6.SrcIp).String() }
|
||||
func (ip6 *Ip6hdr) DestAddr() string { return net.IP(ip6.DestIp).String() }
|
||||
func (ip6 *Ip6hdr) Len() int { return int(ip6.Length) }
|
247
Godeps/_workspace/src/github.com/akrennmair/gopcap/decode_test.go
generated
vendored
Normal file
247
Godeps/_workspace/src/github.com/akrennmair/gopcap/decode_test.go
generated
vendored
Normal file
|
@ -0,0 +1,247 @@
|
|||
package pcap
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
var testSimpleTcpPacket *Packet = &Packet{
|
||||
Data: []byte{
|
||||
0x00, 0x00, 0x0c, 0x9f, 0xf0, 0x20, 0xbc, 0x30, 0x5b, 0xe8, 0xd3, 0x49,
|
||||
0x08, 0x00, 0x45, 0x00, 0x01, 0xa4, 0x39, 0xdf, 0x40, 0x00, 0x40, 0x06,
|
||||
0x55, 0x5a, 0xac, 0x11, 0x51, 0x49, 0xad, 0xde, 0xfe, 0xe1, 0xc5, 0xf7,
|
||||
0x00, 0x50, 0xc5, 0x7e, 0x0e, 0x48, 0x49, 0x07, 0x42, 0x32, 0x80, 0x18,
|
||||
0x00, 0x73, 0xab, 0xb1, 0x00, 0x00, 0x01, 0x01, 0x08, 0x0a, 0x03, 0x77,
|
||||
0x37, 0x9c, 0x42, 0x77, 0x5e, 0x3a, 0x47, 0x45, 0x54, 0x20, 0x2f, 0x20,
|
||||
0x48, 0x54, 0x54, 0x50, 0x2f, 0x31, 0x2e, 0x31, 0x0d, 0x0a, 0x48, 0x6f,
|
||||
0x73, 0x74, 0x3a, 0x20, 0x77, 0x77, 0x77, 0x2e, 0x66, 0x69, 0x73, 0x68,
|
||||
0x2e, 0x63, 0x6f, 0x6d, 0x0d, 0x0a, 0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63,
|
||||
0x74, 0x69, 0x6f, 0x6e, 0x3a, 0x20, 0x6b, 0x65, 0x65, 0x70, 0x2d, 0x61,
|
||||
0x6c, 0x69, 0x76, 0x65, 0x0d, 0x0a, 0x55, 0x73, 0x65, 0x72, 0x2d, 0x41,
|
||||
0x67, 0x65, 0x6e, 0x74, 0x3a, 0x20, 0x4d, 0x6f, 0x7a, 0x69, 0x6c, 0x6c,
|
||||
0x61, 0x2f, 0x35, 0x2e, 0x30, 0x20, 0x28, 0x58, 0x31, 0x31, 0x3b, 0x20,
|
||||
0x4c, 0x69, 0x6e, 0x75, 0x78, 0x20, 0x78, 0x38, 0x36, 0x5f, 0x36, 0x34,
|
||||
0x29, 0x20, 0x41, 0x70, 0x70, 0x6c, 0x65, 0x57, 0x65, 0x62, 0x4b, 0x69,
|
||||
0x74, 0x2f, 0x35, 0x33, 0x35, 0x2e, 0x32, 0x20, 0x28, 0x4b, 0x48, 0x54,
|
||||
0x4d, 0x4c, 0x2c, 0x20, 0x6c, 0x69, 0x6b, 0x65, 0x20, 0x47, 0x65, 0x63,
|
||||
0x6b, 0x6f, 0x29, 0x20, 0x43, 0x68, 0x72, 0x6f, 0x6d, 0x65, 0x2f, 0x31,
|
||||
0x35, 0x2e, 0x30, 0x2e, 0x38, 0x37, 0x34, 0x2e, 0x31, 0x32, 0x31, 0x20,
|
||||
0x53, 0x61, 0x66, 0x61, 0x72, 0x69, 0x2f, 0x35, 0x33, 0x35, 0x2e, 0x32,
|
||||
0x0d, 0x0a, 0x41, 0x63, 0x63, 0x65, 0x70, 0x74, 0x3a, 0x20, 0x74, 0x65,
|
||||
0x78, 0x74, 0x2f, 0x68, 0x74, 0x6d, 0x6c, 0x2c, 0x61, 0x70, 0x70, 0x6c,
|
||||
0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x2f, 0x78, 0x68, 0x74, 0x6d,
|
||||
0x6c, 0x2b, 0x78, 0x6d, 0x6c, 0x2c, 0x61, 0x70, 0x70, 0x6c, 0x69, 0x63,
|
||||
0x61, 0x74, 0x69, 0x6f, 0x6e, 0x2f, 0x78, 0x6d, 0x6c, 0x3b, 0x71, 0x3d,
|
||||
0x30, 0x2e, 0x39, 0x2c, 0x2a, 0x2f, 0x2a, 0x3b, 0x71, 0x3d, 0x30, 0x2e,
|
||||
0x38, 0x0d, 0x0a, 0x41, 0x63, 0x63, 0x65, 0x70, 0x74, 0x2d, 0x45, 0x6e,
|
||||
0x63, 0x6f, 0x64, 0x69, 0x6e, 0x67, 0x3a, 0x20, 0x67, 0x7a, 0x69, 0x70,
|
||||
0x2c, 0x64, 0x65, 0x66, 0x6c, 0x61, 0x74, 0x65, 0x2c, 0x73, 0x64, 0x63,
|
||||
0x68, 0x0d, 0x0a, 0x41, 0x63, 0x63, 0x65, 0x70, 0x74, 0x2d, 0x4c, 0x61,
|
||||
0x6e, 0x67, 0x75, 0x61, 0x67, 0x65, 0x3a, 0x20, 0x65, 0x6e, 0x2d, 0x55,
|
||||
0x53, 0x2c, 0x65, 0x6e, 0x3b, 0x71, 0x3d, 0x30, 0x2e, 0x38, 0x0d, 0x0a,
|
||||
0x41, 0x63, 0x63, 0x65, 0x70, 0x74, 0x2d, 0x43, 0x68, 0x61, 0x72, 0x73,
|
||||
0x65, 0x74, 0x3a, 0x20, 0x49, 0x53, 0x4f, 0x2d, 0x38, 0x38, 0x35, 0x39,
|
||||
0x2d, 0x31, 0x2c, 0x75, 0x74, 0x66, 0x2d, 0x38, 0x3b, 0x71, 0x3d, 0x30,
|
||||
0x2e, 0x37, 0x2c, 0x2a, 0x3b, 0x71, 0x3d, 0x30, 0x2e, 0x33, 0x0d, 0x0a,
|
||||
0x0d, 0x0a,
|
||||
}}
|
||||
|
||||
func BenchmarkDecodeSimpleTcpPacket(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
testSimpleTcpPacket.Decode()
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeSimpleTcpPacket(t *testing.T) {
|
||||
p := testSimpleTcpPacket
|
||||
p.Decode()
|
||||
if p.DestMac != 0x00000c9ff020 {
|
||||
t.Error("Dest mac", p.DestMac)
|
||||
}
|
||||
if p.SrcMac != 0xbc305be8d349 {
|
||||
t.Error("Src mac", p.SrcMac)
|
||||
}
|
||||
if len(p.Headers) != 2 {
|
||||
t.Error("Incorrect number of headers", len(p.Headers))
|
||||
return
|
||||
}
|
||||
if ip, ipOk := p.Headers[0].(*Iphdr); ipOk {
|
||||
if ip.Version != 4 {
|
||||
t.Error("ip Version", ip.Version)
|
||||
}
|
||||
if ip.Ihl != 5 {
|
||||
t.Error("ip header length", ip.Ihl)
|
||||
}
|
||||
if ip.Tos != 0 {
|
||||
t.Error("ip TOS", ip.Tos)
|
||||
}
|
||||
if ip.Length != 420 {
|
||||
t.Error("ip Length", ip.Length)
|
||||
}
|
||||
if ip.Id != 14815 {
|
||||
t.Error("ip ID", ip.Id)
|
||||
}
|
||||
if ip.Flags != 0x02 {
|
||||
t.Error("ip Flags", ip.Flags)
|
||||
}
|
||||
if ip.FragOffset != 0 {
|
||||
t.Error("ip Fragoffset", ip.FragOffset)
|
||||
}
|
||||
if ip.Ttl != 64 {
|
||||
t.Error("ip TTL", ip.Ttl)
|
||||
}
|
||||
if ip.Protocol != 6 {
|
||||
t.Error("ip Protocol", ip.Protocol)
|
||||
}
|
||||
if ip.Checksum != 0x555A {
|
||||
t.Error("ip Checksum", ip.Checksum)
|
||||
}
|
||||
if !bytes.Equal(ip.SrcIp, []byte{172, 17, 81, 73}) {
|
||||
t.Error("ip Src", ip.SrcIp)
|
||||
}
|
||||
if !bytes.Equal(ip.DestIp, []byte{173, 222, 254, 225}) {
|
||||
t.Error("ip Dest", ip.DestIp)
|
||||
}
|
||||
if tcp, tcpOk := p.Headers[1].(*Tcphdr); tcpOk {
|
||||
if tcp.SrcPort != 50679 {
|
||||
t.Error("tcp srcport", tcp.SrcPort)
|
||||
}
|
||||
if tcp.DestPort != 80 {
|
||||
t.Error("tcp destport", tcp.DestPort)
|
||||
}
|
||||
if tcp.Seq != 0xc57e0e48 {
|
||||
t.Error("tcp seq", tcp.Seq)
|
||||
}
|
||||
if tcp.Ack != 0x49074232 {
|
||||
t.Error("tcp ack", tcp.Ack)
|
||||
}
|
||||
if tcp.DataOffset != 8 {
|
||||
t.Error("tcp dataoffset", tcp.DataOffset)
|
||||
}
|
||||
if tcp.Flags != 0x18 {
|
||||
t.Error("tcp flags", tcp.Flags)
|
||||
}
|
||||
if tcp.Window != 0x73 {
|
||||
t.Error("tcp window", tcp.Window)
|
||||
}
|
||||
if tcp.Checksum != 0xabb1 {
|
||||
t.Error("tcp checksum", tcp.Checksum)
|
||||
}
|
||||
if tcp.Urgent != 0 {
|
||||
t.Error("tcp urgent", tcp.Urgent)
|
||||
}
|
||||
} else {
|
||||
t.Error("Second header is not TCP header")
|
||||
}
|
||||
} else {
|
||||
t.Error("First header is not IP header")
|
||||
}
|
||||
if string(p.Payload) != "GET / HTTP/1.1\r\nHost: www.fish.com\r\nConnection: keep-alive\r\nUser-Agent: Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/535.2 (KHTML, like Gecko) Chrome/15.0.874.121 Safari/535.2\r\nAccept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8\r\nAccept-Encoding: gzip,deflate,sdch\r\nAccept-Language: en-US,en;q=0.8\r\nAccept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.3\r\n\r\n" {
|
||||
t.Error("--- PAYLOAD STRING ---\n", string(p.Payload), "\n--- PAYLOAD BYTES ---\n", p.Payload)
|
||||
}
|
||||
}
|
||||
|
||||
// Makes sure packet payload doesn't display the 6 trailing null of this packet
|
||||
// as part of the payload. They're actually the ethernet trailer.
|
||||
func TestDecodeSmallTcpPacketHasEmptyPayload(t *testing.T) {
|
||||
p := &Packet{
|
||||
// This packet is only 54 bits (an empty TCP RST), thus 6 trailing null
|
||||
// bytes are added by the ethernet layer to make it the minimum packet size.
|
||||
Data: []byte{
|
||||
0xbc, 0x30, 0x5b, 0xe8, 0xd3, 0x49, 0xb8, 0xac, 0x6f, 0x92, 0xd5, 0xbf,
|
||||
0x08, 0x00, 0x45, 0x00, 0x00, 0x28, 0x00, 0x00, 0x40, 0x00, 0x40, 0x06,
|
||||
0x3f, 0x9f, 0xac, 0x11, 0x51, 0xc5, 0xac, 0x11, 0x51, 0x49, 0x00, 0x63,
|
||||
0x9a, 0xef, 0x00, 0x00, 0x00, 0x00, 0x2e, 0xc1, 0x27, 0x83, 0x50, 0x14,
|
||||
0x00, 0x00, 0xc3, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
}}
|
||||
p.Decode()
|
||||
if p.Payload == nil {
|
||||
t.Error("Nil payload")
|
||||
}
|
||||
if len(p.Payload) != 0 {
|
||||
t.Error("Non-empty payload:", p.Payload)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeVlanPacket(t *testing.T) {
|
||||
p := &Packet{
|
||||
Data: []byte{
|
||||
0x00, 0x10, 0xdb, 0xff, 0x10, 0x00, 0x00, 0x15, 0x2c, 0x9d, 0xcc, 0x00, 0x81, 0x00, 0x01, 0xf7,
|
||||
0x08, 0x00, 0x45, 0x00, 0x00, 0x28, 0x29, 0x8d, 0x40, 0x00, 0x7d, 0x06, 0x83, 0xa0, 0xac, 0x1b,
|
||||
0xca, 0x8e, 0x45, 0x16, 0x94, 0xe2, 0xd4, 0x0a, 0x00, 0x50, 0xdf, 0xab, 0x9c, 0xc6, 0xcd, 0x1e,
|
||||
0xe5, 0xd1, 0x50, 0x10, 0x01, 0x00, 0x5a, 0x74, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
}}
|
||||
p.Decode()
|
||||
if p.Type != TYPE_VLAN {
|
||||
t.Error("Didn't detect vlan")
|
||||
}
|
||||
if len(p.Headers) != 3 {
|
||||
t.Error("Incorrect number of headers:", len(p.Headers))
|
||||
for i, h := range p.Headers {
|
||||
t.Errorf("Header %d: %#v", i, h)
|
||||
}
|
||||
t.FailNow()
|
||||
}
|
||||
if _, ok := p.Headers[0].(*Vlanhdr); !ok {
|
||||
t.Errorf("First header isn't vlan: %q", p.Headers[0])
|
||||
}
|
||||
if _, ok := p.Headers[1].(*Iphdr); !ok {
|
||||
t.Errorf("Second header isn't IP: %q", p.Headers[1])
|
||||
}
|
||||
if _, ok := p.Headers[2].(*Tcphdr); !ok {
|
||||
t.Errorf("Third header isn't TCP: %q", p.Headers[2])
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeFuzzFallout(t *testing.T) {
|
||||
testData := []struct {
|
||||
Data []byte
|
||||
}{
|
||||
{[]byte("000000000000\x81\x000")},
|
||||
{[]byte("000000000000\x81\x00000")},
|
||||
{[]byte("000000000000\x86\xdd0")},
|
||||
{[]byte("000000000000\b\x000")},
|
||||
{[]byte("000000000000\b\x060")},
|
||||
{[]byte{}},
|
||||
{[]byte("000000000000\b\x0600000000")},
|
||||
{[]byte("000000000000\x86\xdd000000\x01000000000000000000000000000000000")},
|
||||
{[]byte("000000000000\x81\x0000\b\x0600000000")},
|
||||
{[]byte("000000000000\b\x00n0000000000000000000")},
|
||||
{[]byte("000000000000\x86\xdd000000\x0100000000000000000000000000000000000")},
|
||||
{[]byte("000000000000\x81\x0000\b\x00g0000000000000000000")},
|
||||
//{[]byte()},
|
||||
{[]byte("000000000000\b\x00400000000\x110000000000")},
|
||||
{[]byte("0nMء\xfe\x13\x13\x81\x00gr\b\x00&x\xc9\xe5b'\x1e0\x00\x04\x00\x0020596224")},
|
||||
{[]byte("000000000000\x81\x0000\b\x00400000000\x110000000000")},
|
||||
{[]byte("000000000000\b\x00000000000\x0600\xff0000000")},
|
||||
{[]byte("000000000000\x86\xdd000000\x06000000000000000000000000000000000")},
|
||||
{[]byte("000000000000\x81\x0000\b\x00000000000\x0600b0000000")},
|
||||
{[]byte("000000000000\x81\x0000\b\x00400000000\x060000000000")},
|
||||
{[]byte("000000000000\x86\xdd000000\x11000000000000000000000000000000000")},
|
||||
{[]byte("000000000000\x86\xdd000000\x0600000000000000000000000000000000000000000000M")},
|
||||
{[]byte("000000000000\b\x00500000000\x0600000000000")},
|
||||
{[]byte("0nM\xd80\xfe\x13\x13\x81\x00gr\b\x00&x\xc9\xe5b'\x1e0\x00\x04\x00\x0020596224")},
|
||||
}
|
||||
|
||||
for _, entry := range testData {
|
||||
pkt := &Packet{
|
||||
Time: time.Now(),
|
||||
Caplen: uint32(len(entry.Data)),
|
||||
Len: uint32(len(entry.Data)),
|
||||
Data: entry.Data,
|
||||
}
|
||||
|
||||
pkt.Decode()
|
||||
/*
|
||||
func() {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
t.Fatalf("%d. %q failed: %v", idx, string(entry.Data), err)
|
||||
}
|
||||
}()
|
||||
pkt.Decode()
|
||||
}()
|
||||
*/
|
||||
}
|
||||
}
|
|
@ -0,0 +1,206 @@
|
|||
package pcap
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"time"
|
||||
)
|
||||
|
||||
// FileHeader is the parsed header of a pcap file.
|
||||
// http://wiki.wireshark.org/Development/LibpcapFileFormat
|
||||
type FileHeader struct {
|
||||
MagicNumber uint32
|
||||
VersionMajor uint16
|
||||
VersionMinor uint16
|
||||
TimeZone int32
|
||||
SigFigs uint32
|
||||
SnapLen uint32
|
||||
Network uint32
|
||||
}
|
||||
|
||||
type PacketTime struct {
|
||||
Sec int32
|
||||
Usec int32
|
||||
}
|
||||
|
||||
// Convert the PacketTime to a go Time struct.
|
||||
func (p *PacketTime) Time() time.Time {
|
||||
return time.Unix(int64(p.Sec), int64(p.Usec)*1000)
|
||||
}
|
||||
|
||||
// Packet is a single packet parsed from a pcap file.
|
||||
//
|
||||
// Convenient access to IP, TCP, and UDP headers is provided after Decode()
|
||||
// is called if the packet is of the appropriate type.
|
||||
type Packet struct {
|
||||
Time time.Time // packet send/receive time
|
||||
Caplen uint32 // bytes stored in the file (caplen <= len)
|
||||
Len uint32 // bytes sent/received
|
||||
Data []byte // packet data
|
||||
|
||||
Type int // protocol type, see LINKTYPE_*
|
||||
DestMac uint64
|
||||
SrcMac uint64
|
||||
|
||||
Headers []interface{} // decoded headers, in order
|
||||
Payload []byte // remaining non-header bytes
|
||||
|
||||
IP *Iphdr // IP header (for IP packets, after decoding)
|
||||
TCP *Tcphdr // TCP header (for TCP packets, after decoding)
|
||||
UDP *Udphdr // UDP header (for UDP packets after decoding)
|
||||
}
|
||||
|
||||
// Reader parses pcap files.
|
||||
type Reader struct {
|
||||
flip bool
|
||||
buf io.Reader
|
||||
err error
|
||||
fourBytes []byte
|
||||
twoBytes []byte
|
||||
sixteenBytes []byte
|
||||
Header FileHeader
|
||||
}
|
||||
|
||||
// NewReader reads pcap data from an io.Reader.
|
||||
func NewReader(reader io.Reader) (*Reader, error) {
|
||||
r := &Reader{
|
||||
buf: reader,
|
||||
fourBytes: make([]byte, 4),
|
||||
twoBytes: make([]byte, 2),
|
||||
sixteenBytes: make([]byte, 16),
|
||||
}
|
||||
switch magic := r.readUint32(); magic {
|
||||
case 0xa1b2c3d4:
|
||||
r.flip = false
|
||||
case 0xd4c3b2a1:
|
||||
r.flip = true
|
||||
default:
|
||||
return nil, fmt.Errorf("pcap: bad magic number: %0x", magic)
|
||||
}
|
||||
r.Header = FileHeader{
|
||||
MagicNumber: 0xa1b2c3d4,
|
||||
VersionMajor: r.readUint16(),
|
||||
VersionMinor: r.readUint16(),
|
||||
TimeZone: r.readInt32(),
|
||||
SigFigs: r.readUint32(),
|
||||
SnapLen: r.readUint32(),
|
||||
Network: r.readUint32(),
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// Next returns the next packet or nil if no more packets can be read.
|
||||
func (r *Reader) Next() *Packet {
|
||||
d := r.sixteenBytes
|
||||
r.err = r.read(d)
|
||||
if r.err != nil {
|
||||
return nil
|
||||
}
|
||||
timeSec := asUint32(d[0:4], r.flip)
|
||||
timeUsec := asUint32(d[4:8], r.flip)
|
||||
capLen := asUint32(d[8:12], r.flip)
|
||||
origLen := asUint32(d[12:16], r.flip)
|
||||
|
||||
data := make([]byte, capLen)
|
||||
if r.err = r.read(data); r.err != nil {
|
||||
return nil
|
||||
}
|
||||
return &Packet{
|
||||
Time: time.Unix(int64(timeSec), int64(timeUsec)),
|
||||
Caplen: capLen,
|
||||
Len: origLen,
|
||||
Data: data,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Reader) read(data []byte) error {
|
||||
var err error
|
||||
n, err := r.buf.Read(data)
|
||||
for err == nil && n != len(data) {
|
||||
var chunk int
|
||||
chunk, err = r.buf.Read(data[n:])
|
||||
n += chunk
|
||||
}
|
||||
if len(data) == n {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *Reader) readUint32() uint32 {
|
||||
data := r.fourBytes
|
||||
if r.err = r.read(data); r.err != nil {
|
||||
return 0
|
||||
}
|
||||
return asUint32(data, r.flip)
|
||||
}
|
||||
|
||||
func (r *Reader) readInt32() int32 {
|
||||
data := r.fourBytes
|
||||
if r.err = r.read(data); r.err != nil {
|
||||
return 0
|
||||
}
|
||||
return int32(asUint32(data, r.flip))
|
||||
}
|
||||
|
||||
func (r *Reader) readUint16() uint16 {
|
||||
data := r.twoBytes
|
||||
if r.err = r.read(data); r.err != nil {
|
||||
return 0
|
||||
}
|
||||
return asUint16(data, r.flip)
|
||||
}
|
||||
|
||||
// Writer writes a pcap file.
|
||||
type Writer struct {
|
||||
writer io.Writer
|
||||
buf []byte
|
||||
}
|
||||
|
||||
// NewWriter creates a Writer that stores output in an io.Writer.
|
||||
// The FileHeader is written immediately.
|
||||
func NewWriter(writer io.Writer, header *FileHeader) (*Writer, error) {
|
||||
w := &Writer{
|
||||
writer: writer,
|
||||
buf: make([]byte, 24),
|
||||
}
|
||||
binary.LittleEndian.PutUint32(w.buf, header.MagicNumber)
|
||||
binary.LittleEndian.PutUint16(w.buf[4:], header.VersionMajor)
|
||||
binary.LittleEndian.PutUint16(w.buf[6:], header.VersionMinor)
|
||||
binary.LittleEndian.PutUint32(w.buf[8:], uint32(header.TimeZone))
|
||||
binary.LittleEndian.PutUint32(w.buf[12:], header.SigFigs)
|
||||
binary.LittleEndian.PutUint32(w.buf[16:], header.SnapLen)
|
||||
binary.LittleEndian.PutUint32(w.buf[20:], header.Network)
|
||||
if _, err := writer.Write(w.buf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return w, nil
|
||||
}
|
||||
|
||||
// Writer writes a packet to the underlying writer.
|
||||
func (w *Writer) Write(pkt *Packet) error {
|
||||
binary.LittleEndian.PutUint32(w.buf, uint32(pkt.Time.Unix()))
|
||||
binary.LittleEndian.PutUint32(w.buf[4:], uint32(pkt.Time.Nanosecond()))
|
||||
binary.LittleEndian.PutUint32(w.buf[8:], uint32(pkt.Time.Unix()))
|
||||
binary.LittleEndian.PutUint32(w.buf[12:], pkt.Len)
|
||||
if _, err := w.writer.Write(w.buf[:16]); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := w.writer.Write(pkt.Data)
|
||||
return err
|
||||
}
|
||||
|
||||
func asUint32(data []byte, flip bool) uint32 {
|
||||
if flip {
|
||||
return binary.BigEndian.Uint32(data)
|
||||
}
|
||||
return binary.LittleEndian.Uint32(data)
|
||||
}
|
||||
|
||||
func asUint16(data []byte, flip bool) uint16 {
|
||||
if flip {
|
||||
return binary.BigEndian.Uint16(data)
|
||||
}
|
||||
return binary.LittleEndian.Uint16(data)
|
||||
}
|
|
@ -0,0 +1,266 @@
|
|||
// Interface to both live and offline pcap parsing.
|
||||
package pcap
|
||||
|
||||
/*
|
||||
#cgo linux LDFLAGS: -lpcap
|
||||
#cgo freebsd LDFLAGS: -lpcap
|
||||
#cgo darwin LDFLAGS: -lpcap
|
||||
#cgo windows CFLAGS: -I C:/WpdPack/Include
|
||||
#cgo windows,386 LDFLAGS: -L C:/WpdPack/Lib -lwpcap
|
||||
#cgo windows,amd64 LDFLAGS: -L C:/WpdPack/Lib/x64 -lwpcap
|
||||
#include <stdlib.h>
|
||||
#include <pcap.h>
|
||||
|
||||
// Workaround for not knowing how to cast to const u_char**
|
||||
int hack_pcap_next_ex(pcap_t *p, struct pcap_pkthdr **pkt_header,
|
||||
u_char **pkt_data) {
|
||||
return pcap_next_ex(p, pkt_header, (const u_char **)pkt_data);
|
||||
}
|
||||
*/
|
||||
import "C"
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"syscall"
|
||||
"time"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
type Pcap struct {
|
||||
cptr *C.pcap_t
|
||||
}
|
||||
|
||||
type Stat struct {
|
||||
PacketsReceived uint32
|
||||
PacketsDropped uint32
|
||||
PacketsIfDropped uint32
|
||||
}
|
||||
|
||||
type Interface struct {
|
||||
Name string
|
||||
Description string
|
||||
Addresses []IFAddress
|
||||
// TODO: add more elements
|
||||
}
|
||||
|
||||
type IFAddress struct {
|
||||
IP net.IP
|
||||
Netmask net.IPMask
|
||||
// TODO: add broadcast + PtP dst ?
|
||||
}
|
||||
|
||||
func (p *Pcap) Next() (pkt *Packet) {
|
||||
rv, _ := p.NextEx()
|
||||
return rv
|
||||
}
|
||||
|
||||
// Openlive opens a device and returns a *Pcap handler
|
||||
func Openlive(device string, snaplen int32, promisc bool, timeout_ms int32) (handle *Pcap, err error) {
|
||||
var buf *C.char
|
||||
buf = (*C.char)(C.calloc(ERRBUF_SIZE, 1))
|
||||
h := new(Pcap)
|
||||
var pro int32
|
||||
if promisc {
|
||||
pro = 1
|
||||
}
|
||||
|
||||
dev := C.CString(device)
|
||||
defer C.free(unsafe.Pointer(dev))
|
||||
|
||||
h.cptr = C.pcap_open_live(dev, C.int(snaplen), C.int(pro), C.int(timeout_ms), buf)
|
||||
if nil == h.cptr {
|
||||
handle = nil
|
||||
err = errors.New(C.GoString(buf))
|
||||
} else {
|
||||
handle = h
|
||||
}
|
||||
C.free(unsafe.Pointer(buf))
|
||||
return
|
||||
}
|
||||
|
||||
func Openoffline(file string) (handle *Pcap, err error) {
|
||||
var buf *C.char
|
||||
buf = (*C.char)(C.calloc(ERRBUF_SIZE, 1))
|
||||
h := new(Pcap)
|
||||
|
||||
cf := C.CString(file)
|
||||
defer C.free(unsafe.Pointer(cf))
|
||||
|
||||
h.cptr = C.pcap_open_offline(cf, buf)
|
||||
if nil == h.cptr {
|
||||
handle = nil
|
||||
err = errors.New(C.GoString(buf))
|
||||
} else {
|
||||
handle = h
|
||||
}
|
||||
C.free(unsafe.Pointer(buf))
|
||||
return
|
||||
}
|
||||
|
||||
func (p *Pcap) NextEx() (pkt *Packet, result int32) {
|
||||
var pkthdr *C.struct_pcap_pkthdr
|
||||
|
||||
var buf_ptr *C.u_char
|
||||
var buf unsafe.Pointer
|
||||
result = int32(C.hack_pcap_next_ex(p.cptr, &pkthdr, &buf_ptr))
|
||||
|
||||
buf = unsafe.Pointer(buf_ptr)
|
||||
if nil == buf {
|
||||
return
|
||||
}
|
||||
|
||||
pkt = new(Packet)
|
||||
pkt.Time = time.Unix(int64(pkthdr.ts.tv_sec), int64(pkthdr.ts.tv_usec)*1000)
|
||||
pkt.Caplen = uint32(pkthdr.caplen)
|
||||
pkt.Len = uint32(pkthdr.len)
|
||||
pkt.Data = C.GoBytes(buf, C.int(pkthdr.caplen))
|
||||
return
|
||||
}
|
||||
|
||||
func (p *Pcap) Close() {
|
||||
C.pcap_close(p.cptr)
|
||||
}
|
||||
|
||||
func (p *Pcap) Geterror() error {
|
||||
return errors.New(C.GoString(C.pcap_geterr(p.cptr)))
|
||||
}
|
||||
|
||||
func (p *Pcap) Getstats() (stat *Stat, err error) {
|
||||
var cstats _Ctype_struct_pcap_stat
|
||||
if -1 == C.pcap_stats(p.cptr, &cstats) {
|
||||
return nil, p.Geterror()
|
||||
}
|
||||
stats := new(Stat)
|
||||
stats.PacketsReceived = uint32(cstats.ps_recv)
|
||||
stats.PacketsDropped = uint32(cstats.ps_drop)
|
||||
stats.PacketsIfDropped = uint32(cstats.ps_ifdrop)
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func (p *Pcap) Setfilter(expr string) (err error) {
|
||||
var bpf _Ctype_struct_bpf_program
|
||||
cexpr := C.CString(expr)
|
||||
defer C.free(unsafe.Pointer(cexpr))
|
||||
|
||||
if -1 == C.pcap_compile(p.cptr, &bpf, cexpr, 1, 0) {
|
||||
return p.Geterror()
|
||||
}
|
||||
|
||||
if -1 == C.pcap_setfilter(p.cptr, &bpf) {
|
||||
C.pcap_freecode(&bpf)
|
||||
return p.Geterror()
|
||||
}
|
||||
C.pcap_freecode(&bpf)
|
||||
return nil
|
||||
}
|
||||
|
||||
func Version() string {
|
||||
return C.GoString(C.pcap_lib_version())
|
||||
}
|
||||
|
||||
func (p *Pcap) Datalink() int {
|
||||
return int(C.pcap_datalink(p.cptr))
|
||||
}
|
||||
|
||||
func (p *Pcap) Setdatalink(dlt int) error {
|
||||
if -1 == C.pcap_set_datalink(p.cptr, C.int(dlt)) {
|
||||
return p.Geterror()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func DatalinkValueToName(dlt int) string {
|
||||
if name := C.pcap_datalink_val_to_name(C.int(dlt)); name != nil {
|
||||
return C.GoString(name)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func DatalinkValueToDescription(dlt int) string {
|
||||
if desc := C.pcap_datalink_val_to_description(C.int(dlt)); desc != nil {
|
||||
return C.GoString(desc)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func Findalldevs() (ifs []Interface, err error) {
|
||||
var buf *C.char
|
||||
buf = (*C.char)(C.calloc(ERRBUF_SIZE, 1))
|
||||
defer C.free(unsafe.Pointer(buf))
|
||||
var alldevsp *C.pcap_if_t
|
||||
|
||||
if -1 == C.pcap_findalldevs((**C.pcap_if_t)(&alldevsp), buf) {
|
||||
return nil, errors.New(C.GoString(buf))
|
||||
}
|
||||
defer C.pcap_freealldevs((*C.pcap_if_t)(alldevsp))
|
||||
dev := alldevsp
|
||||
var i uint32
|
||||
for i = 0; dev != nil; dev = (*C.pcap_if_t)(dev.next) {
|
||||
i++
|
||||
}
|
||||
ifs = make([]Interface, i)
|
||||
dev = alldevsp
|
||||
for j := uint32(0); dev != nil; dev = (*C.pcap_if_t)(dev.next) {
|
||||
var iface Interface
|
||||
iface.Name = C.GoString(dev.name)
|
||||
iface.Description = C.GoString(dev.description)
|
||||
iface.Addresses = findalladdresses(dev.addresses)
|
||||
// TODO: add more elements
|
||||
ifs[j] = iface
|
||||
j++
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func findalladdresses(addresses *_Ctype_struct_pcap_addr) (retval []IFAddress) {
|
||||
// TODO - make it support more than IPv4 and IPv6?
|
||||
retval = make([]IFAddress, 0, 1)
|
||||
for curaddr := addresses; curaddr != nil; curaddr = (*_Ctype_struct_pcap_addr)(curaddr.next) {
|
||||
var a IFAddress
|
||||
var err error
|
||||
if a.IP, err = sockaddr_to_IP((*syscall.RawSockaddr)(unsafe.Pointer(curaddr.addr))); err != nil {
|
||||
continue
|
||||
}
|
||||
if a.Netmask, err = sockaddr_to_IP((*syscall.RawSockaddr)(unsafe.Pointer(curaddr.addr))); err != nil {
|
||||
continue
|
||||
}
|
||||
retval = append(retval, a)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func sockaddr_to_IP(rsa *syscall.RawSockaddr) (IP []byte, err error) {
|
||||
switch rsa.Family {
|
||||
case syscall.AF_INET:
|
||||
pp := (*syscall.RawSockaddrInet4)(unsafe.Pointer(rsa))
|
||||
IP = make([]byte, 4)
|
||||
for i := 0; i < len(IP); i++ {
|
||||
IP[i] = pp.Addr[i]
|
||||
}
|
||||
return
|
||||
case syscall.AF_INET6:
|
||||
pp := (*syscall.RawSockaddrInet6)(unsafe.Pointer(rsa))
|
||||
IP = make([]byte, 16)
|
||||
for i := 0; i < len(IP); i++ {
|
||||
IP[i] = pp.Addr[i]
|
||||
}
|
||||
return
|
||||
}
|
||||
err = errors.New("Unsupported address type")
|
||||
return
|
||||
}
|
||||
|
||||
func (p *Pcap) Inject(data []byte) (err error) {
|
||||
buf := (*C.char)(C.malloc((C.size_t)(len(data))))
|
||||
|
||||
for i := 0; i < len(data); i++ {
|
||||
*(*byte)(unsafe.Pointer(uintptr(unsafe.Pointer(buf)) + uintptr(i))) = data[i]
|
||||
}
|
||||
|
||||
if -1 == C.pcap_sendpacket(p.cptr, (*C.u_char)(unsafe.Pointer(buf)), (C.int)(len(data))) {
|
||||
err = p.Geterror()
|
||||
}
|
||||
C.free(unsafe.Pointer(buf))
|
||||
return
|
||||
}
|
49
Godeps/_workspace/src/github.com/akrennmair/gopcap/tools/benchmark/benchmark.go
generated
vendored
Normal file
49
Godeps/_workspace/src/github.com/akrennmair/gopcap/tools/benchmark/benchmark.go
generated
vendored
Normal file
|
@ -0,0 +1,49 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime/pprof"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/etcd/Godeps/_workspace/src/github.com/akrennmair/gopcap"
|
||||
)
|
||||
|
||||
func main() {
|
||||
var filename *string = flag.String("file", "", "filename")
|
||||
var decode *bool = flag.Bool("d", false, "If true, decode each packet")
|
||||
var cpuprofile *string = flag.String("cpuprofile", "", "filename")
|
||||
|
||||
flag.Parse()
|
||||
|
||||
h, err := pcap.Openoffline(*filename)
|
||||
if err != nil {
|
||||
fmt.Printf("Couldn't create pcap reader: %v", err)
|
||||
}
|
||||
|
||||
if *cpuprofile != "" {
|
||||
if out, err := os.Create(*cpuprofile); err == nil {
|
||||
pprof.StartCPUProfile(out)
|
||||
defer func() {
|
||||
pprof.StopCPUProfile()
|
||||
out.Close()
|
||||
}()
|
||||
} else {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
i, nilPackets := 0, 0
|
||||
start := time.Now()
|
||||
for pkt, code := h.NextEx(); code != -2; pkt, code = h.NextEx() {
|
||||
if pkt == nil {
|
||||
nilPackets++
|
||||
} else if *decode {
|
||||
pkt.Decode()
|
||||
}
|
||||
i++
|
||||
}
|
||||
duration := time.Since(start)
|
||||
fmt.Printf("Took %v to process %v packets, %v per packet, %d nil packets\n", duration, i, duration/time.Duration(i), nilPackets)
|
||||
}
|
96
Godeps/_workspace/src/github.com/akrennmair/gopcap/tools/pass/pass.go
generated
vendored
Normal file
96
Godeps/_workspace/src/github.com/akrennmair/gopcap/tools/pass/pass.go
generated
vendored
Normal file
|
@ -0,0 +1,96 @@
|
|||
package main
|
||||
|
||||
// Parses a pcap file, writes it back to disk, then verifies the files
|
||||
// are the same.
|
||||
import (
|
||||
"bufio"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
|
||||
"github.com/coreos/etcd/Godeps/_workspace/src/github.com/akrennmair/gopcap"
|
||||
)
|
||||
|
||||
var input *string = flag.String("input", "", "input file")
|
||||
var output *string = flag.String("output", "", "output file")
|
||||
var decode *bool = flag.Bool("decode", false, "print decoded packets")
|
||||
|
||||
func copyPcap(dest, src string) {
|
||||
f, err := os.Open(src)
|
||||
if err != nil {
|
||||
fmt.Printf("couldn't open %q: %v\n", src, err)
|
||||
return
|
||||
}
|
||||
defer f.Close()
|
||||
reader, err := pcap.NewReader(bufio.NewReader(f))
|
||||
if err != nil {
|
||||
fmt.Printf("couldn't create reader: %v\n", err)
|
||||
return
|
||||
}
|
||||
w, err := os.Create(dest)
|
||||
if err != nil {
|
||||
fmt.Printf("couldn't open %q: %v\n", dest, err)
|
||||
return
|
||||
}
|
||||
defer w.Close()
|
||||
buf := bufio.NewWriter(w)
|
||||
writer, err := pcap.NewWriter(buf, &reader.Header)
|
||||
if err != nil {
|
||||
fmt.Printf("couldn't create writer: %v\n", err)
|
||||
return
|
||||
}
|
||||
for {
|
||||
pkt := reader.Next()
|
||||
if pkt == nil {
|
||||
break
|
||||
}
|
||||
if *decode {
|
||||
pkt.Decode()
|
||||
fmt.Println(pkt.String())
|
||||
}
|
||||
writer.Write(pkt)
|
||||
}
|
||||
buf.Flush()
|
||||
}
|
||||
|
||||
func check(dest, src string) {
|
||||
f, err := os.Open(src)
|
||||
if err != nil {
|
||||
fmt.Printf("couldn't open %q: %v\n", src, err)
|
||||
return
|
||||
}
|
||||
defer f.Close()
|
||||
freader := bufio.NewReader(f)
|
||||
|
||||
g, err := os.Open(dest)
|
||||
if err != nil {
|
||||
fmt.Printf("couldn't open %q: %v\n", src, err)
|
||||
return
|
||||
}
|
||||
defer g.Close()
|
||||
greader := bufio.NewReader(g)
|
||||
|
||||
for {
|
||||
fb, ferr := freader.ReadByte()
|
||||
gb, gerr := greader.ReadByte()
|
||||
|
||||
if ferr == io.EOF && gerr == io.EOF {
|
||||
break
|
||||
}
|
||||
if fb == gb {
|
||||
continue
|
||||
}
|
||||
fmt.Println("FAIL")
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Println("PASS")
|
||||
}
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
|
||||
copyPcap(*output, *input)
|
||||
check(*output, *input)
|
||||
}
|
82
Godeps/_workspace/src/github.com/akrennmair/gopcap/tools/pcaptest/pcaptest.go
generated
vendored
Normal file
82
Godeps/_workspace/src/github.com/akrennmair/gopcap/tools/pcaptest/pcaptest.go
generated
vendored
Normal file
|
@ -0,0 +1,82 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/etcd/Godeps/_workspace/src/github.com/akrennmair/gopcap"
|
||||
)
|
||||
|
||||
func min(x uint32, y uint32) uint32 {
|
||||
if x < y {
|
||||
return x
|
||||
}
|
||||
return y
|
||||
}
|
||||
|
||||
func main() {
|
||||
var device *string = flag.String("d", "", "device")
|
||||
var file *string = flag.String("r", "", "file")
|
||||
var expr *string = flag.String("e", "", "filter expression")
|
||||
|
||||
flag.Parse()
|
||||
|
||||
var h *pcap.Pcap
|
||||
var err error
|
||||
|
||||
ifs, err := pcap.Findalldevs()
|
||||
if len(ifs) == 0 {
|
||||
fmt.Printf("Warning: no devices found : %s\n", err)
|
||||
} else {
|
||||
for i := 0; i < len(ifs); i++ {
|
||||
fmt.Printf("dev %d: %s (%s)\n", i+1, ifs[i].Name, ifs[i].Description)
|
||||
}
|
||||
}
|
||||
|
||||
if *device != "" {
|
||||
h, err = pcap.Openlive(*device, 65535, true, 0)
|
||||
if h == nil {
|
||||
fmt.Printf("Openlive(%s) failed: %s\n", *device, err)
|
||||
return
|
||||
}
|
||||
} else if *file != "" {
|
||||
h, err = pcap.Openoffline(*file)
|
||||
if h == nil {
|
||||
fmt.Printf("Openoffline(%s) failed: %s\n", *file, err)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
fmt.Printf("usage: pcaptest [-d <device> | -r <file>]\n")
|
||||
return
|
||||
}
|
||||
defer h.Close()
|
||||
|
||||
fmt.Printf("pcap version: %s\n", pcap.Version())
|
||||
|
||||
if *expr != "" {
|
||||
fmt.Printf("Setting filter: %s\n", *expr)
|
||||
err := h.Setfilter(*expr)
|
||||
if err != nil {
|
||||
fmt.Printf("Warning: setting filter failed: %s\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
for pkt := h.Next(); pkt != nil; pkt = h.Next() {
|
||||
fmt.Printf("time: %d.%06d (%s) caplen: %d len: %d\nData:",
|
||||
int64(pkt.Time.Second()), int64(pkt.Time.Nanosecond()),
|
||||
time.Unix(int64(pkt.Time.Second()), 0).String(), int64(pkt.Caplen), int64(pkt.Len))
|
||||
for i := uint32(0); i < pkt.Caplen; i++ {
|
||||
if i%32 == 0 {
|
||||
fmt.Printf("\n")
|
||||
}
|
||||
if 32 <= pkt.Data[i] && pkt.Data[i] <= 126 {
|
||||
fmt.Printf("%c", pkt.Data[i])
|
||||
} else {
|
||||
fmt.Printf(".")
|
||||
}
|
||||
}
|
||||
fmt.Printf("\n\n")
|
||||
}
|
||||
|
||||
}
|
121
Godeps/_workspace/src/github.com/akrennmair/gopcap/tools/tcpdump/tcpdump.go
generated
vendored
Normal file
121
Godeps/_workspace/src/github.com/akrennmair/gopcap/tools/tcpdump/tcpdump.go
generated
vendored
Normal file
|
@ -0,0 +1,121 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/coreos/etcd/Godeps/_workspace/src/github.com/akrennmair/gopcap"
|
||||
)
|
||||
|
||||
const (
|
||||
TYPE_IP = 0x0800
|
||||
TYPE_ARP = 0x0806
|
||||
TYPE_IP6 = 0x86DD
|
||||
|
||||
IP_ICMP = 1
|
||||
IP_INIP = 4
|
||||
IP_TCP = 6
|
||||
IP_UDP = 17
|
||||
)
|
||||
|
||||
var out *bufio.Writer
|
||||
var errout *bufio.Writer
|
||||
|
||||
func main() {
|
||||
var device *string = flag.String("i", "", "interface")
|
||||
var snaplen *int = flag.Int("s", 65535, "snaplen")
|
||||
var hexdump *bool = flag.Bool("X", false, "hexdump")
|
||||
expr := ""
|
||||
|
||||
out = bufio.NewWriter(os.Stdout)
|
||||
errout = bufio.NewWriter(os.Stderr)
|
||||
|
||||
flag.Usage = func() {
|
||||
fmt.Fprintf(errout, "usage: %s [ -i interface ] [ -s snaplen ] [ -X ] [ expression ]\n", os.Args[0])
|
||||
errout.Flush()
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
flag.Parse()
|
||||
|
||||
if len(flag.Args()) > 0 {
|
||||
expr = flag.Arg(0)
|
||||
}
|
||||
|
||||
if *device == "" {
|
||||
devs, err := pcap.Findalldevs()
|
||||
if err != nil {
|
||||
fmt.Fprintf(errout, "tcpdump: couldn't find any devices: %s\n", err)
|
||||
}
|
||||
if 0 == len(devs) {
|
||||
flag.Usage()
|
||||
}
|
||||
*device = devs[0].Name
|
||||
}
|
||||
|
||||
h, err := pcap.Openlive(*device, int32(*snaplen), true, 0)
|
||||
if h == nil {
|
||||
fmt.Fprintf(errout, "tcpdump: %s\n", err)
|
||||
errout.Flush()
|
||||
return
|
||||
}
|
||||
defer h.Close()
|
||||
|
||||
if expr != "" {
|
||||
ferr := h.Setfilter(expr)
|
||||
if ferr != nil {
|
||||
fmt.Fprintf(out, "tcpdump: %s\n", ferr)
|
||||
out.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
for pkt := h.Next(); pkt != nil; pkt = h.Next() {
|
||||
pkt.Decode()
|
||||
fmt.Fprintf(out, "%s\n", pkt.String())
|
||||
if *hexdump {
|
||||
Hexdump(pkt)
|
||||
}
|
||||
out.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
func min(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func Hexdump(pkt *pcap.Packet) {
|
||||
for i := 0; i < len(pkt.Data); i += 16 {
|
||||
Dumpline(uint32(i), pkt.Data[i:min(i+16, len(pkt.Data))])
|
||||
}
|
||||
}
|
||||
|
||||
func Dumpline(addr uint32, line []byte) {
|
||||
fmt.Fprintf(out, "\t0x%04x: ", int32(addr))
|
||||
var i uint16
|
||||
for i = 0; i < 16 && i < uint16(len(line)); i++ {
|
||||
if i%2 == 0 {
|
||||
out.WriteString(" ")
|
||||
}
|
||||
fmt.Fprintf(out, "%02x", line[i])
|
||||
}
|
||||
for j := i; j <= 16; j++ {
|
||||
if j%2 == 0 {
|
||||
out.WriteString(" ")
|
||||
}
|
||||
out.WriteString(" ")
|
||||
}
|
||||
out.WriteString(" ")
|
||||
for i = 0; i < 16 && i < uint16(len(line)); i++ {
|
||||
if line[i] >= 32 && line[i] <= 126 {
|
||||
fmt.Fprintf(out, "%c", line[i])
|
||||
} else {
|
||||
out.WriteString(".")
|
||||
}
|
||||
}
|
||||
out.WriteString("\n")
|
||||
}
|
|
@ -1,8 +1,8 @@
|
|||
Bolt [![Build Status](https://drone.io/github.com/boltdb/bolt/status.png)](https://drone.io/github.com/boltdb/bolt/latest) [![Coverage Status](https://coveralls.io/repos/boltdb/bolt/badge.png?branch=master)](https://coveralls.io/r/boltdb/bolt?branch=master) [![GoDoc](https://godoc.org/github.com/boltdb/bolt?status.png)](https://godoc.org/github.com/boltdb/bolt) ![Version](http://img.shields.io/badge/version-1.0-green.png)
|
||||
====
|
||||
|
||||
Bolt is a pure Go key/value store inspired by [Howard Chu's][hyc_symas] and
|
||||
the [LMDB project][lmdb]. The goal of the project is to provide a simple,
|
||||
Bolt is a pure Go key/value store inspired by [Howard Chu's][hyc_symas]
|
||||
[LMDB project][lmdb]. The goal of the project is to provide a simple,
|
||||
fast, and reliable database for projects that don't require a full database
|
||||
server such as Postgres or MySQL.
|
||||
|
||||
|
@ -180,8 +180,8 @@ and then safely close your transaction if an error is returned. This is the
|
|||
recommended way to use Bolt transactions.
|
||||
|
||||
However, sometimes you may want to manually start and end your transactions.
|
||||
You can use the `Tx.Begin()` function directly but _please_ be sure to close the
|
||||
transaction.
|
||||
You can use the `Tx.Begin()` function directly but **please** be sure to close
|
||||
the transaction.
|
||||
|
||||
```go
|
||||
// Start a writable transaction.
|
||||
|
@ -256,7 +256,7 @@ db.View(func(tx *bolt.Tx) error {
|
|||
```
|
||||
|
||||
The `Get()` function does not return an error because its operation is
|
||||
guarenteed to work (unless there is some kind of system failure). If the key
|
||||
guaranteed to work (unless there is some kind of system failure). If the key
|
||||
exists then it will return its byte slice value. If it doesn't exist then it
|
||||
will return `nil`. It's important to note that you can have a zero-length value
|
||||
set to a key which is different than the key not existing.
|
||||
|
@ -268,6 +268,50 @@ transaction is open. If you need to use a value outside of the transaction
|
|||
then you must use `copy()` to copy it to another byte slice.
|
||||
|
||||
|
||||
### Autoincrementing integer for the bucket
|
||||
By using the NextSequence() function, you can let Bolt determine a sequence
|
||||
which can be used as the unique identifier for your key/value pairs. See the
|
||||
example below.
|
||||
|
||||
```go
|
||||
// CreateUser saves u to the store. The new user ID is set on u once the data is persisted.
|
||||
func (s *Store) CreateUser(u *User) error {
|
||||
return s.db.Update(func(tx *bolt.Tx) error {
|
||||
// Retrieve the users bucket.
|
||||
// This should be created when the DB is first opened.
|
||||
b := tx.Bucket([]byte("users"))
|
||||
|
||||
// Generate ID for the user.
|
||||
// This returns an error only if the Tx is closed or not writeable.
|
||||
// That can't happen in an Update() call so I ignore the error check.
|
||||
id, _ = b.NextSequence()
|
||||
u.ID = int(id)
|
||||
|
||||
// Marshal user data into bytes.
|
||||
buf, err := json.Marshal(u)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Persist bytes to users bucket.
|
||||
return b.Put(itob(u.ID), buf)
|
||||
})
|
||||
}
|
||||
|
||||
// itob returns an 8-byte big endian representation of v.
|
||||
func itob(v int) []byte {
|
||||
b := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(b, uint64(v))
|
||||
return b
|
||||
}
|
||||
|
||||
type User struct {
|
||||
ID int
|
||||
...
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
### Iterating over keys
|
||||
|
||||
Bolt stores its keys in byte-sorted order within a bucket. This makes sequential
|
||||
|
@ -382,8 +426,11 @@ func (*Bucket) DeleteBucket(key []byte) error
|
|||
Bolt is a single file so it's easy to backup. You can use the `Tx.WriteTo()`
|
||||
function to write a consistent view of the database to a writer. If you call
|
||||
this from a read-only transaction, it will perform a hot backup and not block
|
||||
your other database reads and writes. It will also use `O_DIRECT` when available
|
||||
to prevent page cache trashing.
|
||||
your other database reads and writes.
|
||||
|
||||
By default, it will use a regular file handle which will utilize the operating
|
||||
system's page cache. See the [`Tx`](https://godoc.org/github.com/boltdb/bolt#Tx)
|
||||
documentation for information about optimizing for larger-than-RAM datasets.
|
||||
|
||||
One common use case is to backup over HTTP so you can use tools like `cURL` to
|
||||
do database backups:
|
||||
|
@ -500,7 +547,7 @@ they are libraries bundled into the application, however, their underlying
|
|||
structure is a log-structured merge-tree (LSM tree). An LSM tree optimizes
|
||||
random writes by using a write ahead log and multi-tiered, sorted files called
|
||||
SSTables. Bolt uses a B+tree internally and only a single file. Both approaches
|
||||
have trade offs.
|
||||
have trade-offs.
|
||||
|
||||
If you require a high random write throughput (>10,000 w/sec) or you need to use
|
||||
spinning disks then LevelDB could be a good choice. If your application is
|
||||
|
@ -568,7 +615,9 @@ Here are a few things to note when evaluating and using Bolt:
|
|||
can in memory and will release memory as needed to other processes. This means
|
||||
that Bolt can show very high memory usage when working with large databases.
|
||||
However, this is expected and the OS will release memory as needed. Bolt can
|
||||
handle databases much larger than the available physical RAM.
|
||||
handle databases much larger than the available physical RAM, provided its
|
||||
memory-map fits in the process virtual address space. It may be problematic
|
||||
on 32-bits systems.
|
||||
|
||||
* The data structures in the Bolt database are memory mapped so the data file
|
||||
will be endian specific. This means that you cannot copy a Bolt file from a
|
||||
|
@ -587,6 +636,56 @@ Here are a few things to note when evaluating and using Bolt:
|
|||
[page-allocation]: https://github.com/boltdb/bolt/issues/308#issuecomment-74811638
|
||||
|
||||
|
||||
## Reading the Source
|
||||
|
||||
Bolt is a relatively small code base (<3KLOC) for an embedded, serializable,
|
||||
transactional key/value database so it can be a good starting point for people
|
||||
interested in how databases work.
|
||||
|
||||
The best places to start are the main entry points into Bolt:
|
||||
|
||||
- `Open()` - Initializes the reference to the database. It's responsible for
|
||||
creating the database if it doesn't exist, obtaining an exclusive lock on the
|
||||
file, reading the meta pages, & memory-mapping the file.
|
||||
|
||||
- `DB.Begin()` - Starts a read-only or read-write transaction depending on the
|
||||
value of the `writable` argument. This requires briefly obtaining the "meta"
|
||||
lock to keep track of open transactions. Only one read-write transaction can
|
||||
exist at a time so the "rwlock" is acquired during the life of a read-write
|
||||
transaction.
|
||||
|
||||
- `Bucket.Put()` - Writes a key/value pair into a bucket. After validating the
|
||||
arguments, a cursor is used to traverse the B+tree to the page and position
|
||||
where they key & value will be written. Once the position is found, the bucket
|
||||
materializes the underlying page and the page's parent pages into memory as
|
||||
"nodes". These nodes are where mutations occur during read-write transactions.
|
||||
These changes get flushed to disk during commit.
|
||||
|
||||
- `Bucket.Get()` - Retrieves a key/value pair from a bucket. This uses a cursor
|
||||
to move to the page & position of a key/value pair. During a read-only
|
||||
transaction, the key and value data is returned as a direct reference to the
|
||||
underlying mmap file so there's no allocation overhead. For read-write
|
||||
transactions, this data may reference the mmap file or one of the in-memory
|
||||
node values.
|
||||
|
||||
- `Cursor` - This object is simply for traversing the B+tree of on-disk pages
|
||||
or in-memory nodes. It can seek to a specific key, move to the first or last
|
||||
value, or it can move forward or backward. The cursor handles the movement up
|
||||
and down the B+tree transparently to the end user.
|
||||
|
||||
- `Tx.Commit()` - Converts the in-memory dirty nodes and the list of free pages
|
||||
into pages to be written to disk. Writing to disk then occurs in two phases.
|
||||
First, the dirty pages are written to disk and an `fsync()` occurs. Second, a
|
||||
new meta page with an incremented transaction ID is written and another
|
||||
`fsync()` occurs. This two phase write ensures that partially written data
|
||||
pages are ignored in the event of a crash since the meta page pointing to them
|
||||
is never written. Partially written meta pages are invalidated because they
|
||||
are written with a checksum.
|
||||
|
||||
If you have additional notes that could be helpful for others, please submit
|
||||
them via pull request.
|
||||
|
||||
|
||||
## Other Projects Using Bolt
|
||||
|
||||
Below is a list of public, open source projects that use Bolt:
|
||||
|
@ -615,7 +714,11 @@ Below is a list of public, open source projects that use Bolt:
|
|||
* [Freehold](http://tshannon.bitbucket.org/freehold/) - An open, secure, and lightweight platform for your files and data.
|
||||
* [Prometheus Annotation Server](https://github.com/oliver006/prom_annotation_server) - Annotation server for PromDash & Prometheus service monitoring system.
|
||||
* [Consul](https://github.com/hashicorp/consul) - Consul is service discovery and configuration made easy. Distributed, highly available, and datacenter-aware.
|
||||
* [Kala](https://github.com/ajvb/kala) - Kala is a modern job scheduler optimized to run on a single node. It is persistant, JSON over HTTP API, ISO 8601 duration notation, and dependent jobs.
|
||||
* [Kala](https://github.com/ajvb/kala) - Kala is a modern job scheduler optimized to run on a single node. It is persistent, JSON over HTTP API, ISO 8601 duration notation, and dependent jobs.
|
||||
* [drive](https://github.com/odeke-em/drive) - drive is an unofficial Google Drive command line client for \*NIX operating systems.
|
||||
* [stow](https://github.com/djherbis/stow) - a persistence manager for objects
|
||||
backed by boltdb.
|
||||
* [buckets](https://github.com/joyrexus/buckets) - a bolt wrapper streamlining
|
||||
simple tx and key scans.
|
||||
|
||||
If you are using Bolt in a project please send a pull request to add it to the list.
|
||||
|
|
|
@ -0,0 +1,9 @@
|
|||
// +build arm64
|
||||
|
||||
package bolt
|
||||
|
||||
// maxMapSize represents the largest mmap size supported by Bolt.
|
||||
const maxMapSize = 0xFFFFFFFFFFFF // 256TB
|
||||
|
||||
// maxAllocSize is the size used when creating array pointers.
|
||||
const maxAllocSize = 0x7FFFFFFF
|
|
@ -4,8 +4,6 @@ import (
|
|||
"syscall"
|
||||
)
|
||||
|
||||
var odirect = syscall.O_DIRECT
|
||||
|
||||
// fdatasync flushes written data to a file descriptor.
|
||||
func fdatasync(db *DB) error {
|
||||
return syscall.Fdatasync(int(db.file.Fd()))
|
||||
|
|
|
@ -11,8 +11,6 @@ const (
|
|||
msInvalidate // invalidate cached data
|
||||
)
|
||||
|
||||
var odirect int
|
||||
|
||||
func msync(db *DB) error {
|
||||
_, _, errno := syscall.Syscall(syscall.SYS_MSYNC, uintptr(unsafe.Pointer(db.data)), uintptr(db.datasz), msInvalidate)
|
||||
if errno != 0 {
|
||||
|
|
|
@ -0,0 +1,9 @@
|
|||
// +build ppc64le
|
||||
|
||||
package bolt
|
||||
|
||||
// maxMapSize represents the largest mmap size supported by Bolt.
|
||||
const maxMapSize = 0xFFFFFFFFFFFF // 256TB
|
||||
|
||||
// maxAllocSize is the size used when creating array pointers.
|
||||
const maxAllocSize = 0x7FFFFFFF
|
|
@ -0,0 +1,9 @@
|
|||
// +build s390x
|
||||
|
||||
package bolt
|
||||
|
||||
// maxMapSize represents the largest mmap size supported by Bolt.
|
||||
const maxMapSize = 0xFFFFFFFFFFFF // 256TB
|
||||
|
||||
// maxAllocSize is the size used when creating array pointers.
|
||||
const maxAllocSize = 0x7FFFFFFF
|
|
@ -58,7 +58,7 @@ func mmap(db *DB, sz int) error {
|
|||
}
|
||||
|
||||
// Map the data file to memory.
|
||||
b, err := syscall.Mmap(int(db.file.Fd()), 0, sz, syscall.PROT_READ, syscall.MAP_SHARED)
|
||||
b, err := syscall.Mmap(int(db.file.Fd()), 0, sz, syscall.PROT_READ, syscall.MAP_SHARED|db.MmapFlags)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -2,11 +2,12 @@ package bolt
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/coreos/etcd/Godeps/_workspace/src/golang.org/x/sys/unix"
|
||||
"os"
|
||||
"syscall"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/coreos/etcd/Godeps/_workspace/src/golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// flock acquires an advisory lock on a file descriptor.
|
||||
|
@ -67,7 +68,7 @@ func mmap(db *DB, sz int) error {
|
|||
}
|
||||
|
||||
// Map the data file to memory.
|
||||
b, err := unix.Mmap(int(db.file.Fd()), 0, sz, syscall.PROT_READ, syscall.MAP_SHARED)
|
||||
b, err := unix.Mmap(int(db.file.Fd()), 0, sz, syscall.PROT_READ, syscall.MAP_SHARED|db.MmapFlags)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -8,7 +8,37 @@ import (
|
|||
"unsafe"
|
||||
)
|
||||
|
||||
var odirect int
|
||||
// LockFileEx code derived from golang build filemutex_windows.go @ v1.5.1
|
||||
var (
|
||||
modkernel32 = syscall.NewLazyDLL("kernel32.dll")
|
||||
procLockFileEx = modkernel32.NewProc("LockFileEx")
|
||||
procUnlockFileEx = modkernel32.NewProc("UnlockFileEx")
|
||||
)
|
||||
|
||||
const (
|
||||
// see https://msdn.microsoft.com/en-us/library/windows/desktop/aa365203(v=vs.85).aspx
|
||||
flagLockExclusive = 2
|
||||
flagLockFailImmediately = 1
|
||||
|
||||
// see https://msdn.microsoft.com/en-us/library/windows/desktop/ms681382(v=vs.85).aspx
|
||||
errLockViolation syscall.Errno = 0x21
|
||||
)
|
||||
|
||||
func lockFileEx(h syscall.Handle, flags, reserved, locklow, lockhigh uint32, ol *syscall.Overlapped) (err error) {
|
||||
r, _, err := procLockFileEx.Call(uintptr(h), uintptr(flags), uintptr(reserved), uintptr(locklow), uintptr(lockhigh), uintptr(unsafe.Pointer(ol)))
|
||||
if r == 0 {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func unlockFileEx(h syscall.Handle, reserved, locklow, lockhigh uint32, ol *syscall.Overlapped) (err error) {
|
||||
r, _, err := procUnlockFileEx.Call(uintptr(h), uintptr(reserved), uintptr(locklow), uintptr(lockhigh), uintptr(unsafe.Pointer(ol)), 0)
|
||||
if r == 0 {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// fdatasync flushes written data to a file descriptor.
|
||||
func fdatasync(db *DB) error {
|
||||
|
@ -16,13 +46,37 @@ func fdatasync(db *DB) error {
|
|||
}
|
||||
|
||||
// flock acquires an advisory lock on a file descriptor.
|
||||
func flock(f *os.File, _ bool, _ time.Duration) error {
|
||||
return nil
|
||||
func flock(f *os.File, exclusive bool, timeout time.Duration) error {
|
||||
var t time.Time
|
||||
for {
|
||||
// If we're beyond our timeout then return an error.
|
||||
// This can only occur after we've attempted a flock once.
|
||||
if t.IsZero() {
|
||||
t = time.Now()
|
||||
} else if timeout > 0 && time.Since(t) > timeout {
|
||||
return ErrTimeout
|
||||
}
|
||||
|
||||
var flag uint32 = flagLockFailImmediately
|
||||
if exclusive {
|
||||
flag |= flagLockExclusive
|
||||
}
|
||||
|
||||
err := lockFileEx(syscall.Handle(f.Fd()), flag, 0, 1, 0, &syscall.Overlapped{})
|
||||
if err == nil {
|
||||
return nil
|
||||
} else if err != errLockViolation {
|
||||
return err
|
||||
}
|
||||
|
||||
// Wait for a bit and try again.
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
||||
// funlock releases an advisory lock on a file descriptor.
|
||||
func funlock(f *os.File) error {
|
||||
return nil
|
||||
return unlockFileEx(syscall.Handle(f.Fd()), 0, 1, 0, &syscall.Overlapped{})
|
||||
}
|
||||
|
||||
// mmap memory maps a DB's data file.
|
||||
|
|
|
@ -2,8 +2,6 @@
|
|||
|
||||
package bolt
|
||||
|
||||
var odirect int
|
||||
|
||||
// fdatasync flushes written data to a file descriptor.
|
||||
func fdatasync(db *DB) error {
|
||||
return db.file.Sync()
|
||||
|
|
|
@ -11,7 +11,7 @@ const (
|
|||
MaxKeySize = 32768
|
||||
|
||||
// MaxValueSize is the maximum length of a value, in bytes.
|
||||
MaxValueSize = 4294967295
|
||||
MaxValueSize = (1 << 31) - 2
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -99,6 +99,7 @@ func (b *Bucket) Cursor() *Cursor {
|
|||
|
||||
// Bucket retrieves a nested bucket by name.
|
||||
// Returns nil if the bucket does not exist.
|
||||
// The bucket instance is only valid for the lifetime of the transaction.
|
||||
func (b *Bucket) Bucket(name []byte) *Bucket {
|
||||
if b.buckets != nil {
|
||||
if child := b.buckets[string(name)]; child != nil {
|
||||
|
@ -148,6 +149,7 @@ func (b *Bucket) openBucket(value []byte) *Bucket {
|
|||
|
||||
// CreateBucket creates a new bucket at the given key and returns the new bucket.
|
||||
// Returns an error if the key already exists, if the bucket name is blank, or if the bucket name is too long.
|
||||
// The bucket instance is only valid for the lifetime of the transaction.
|
||||
func (b *Bucket) CreateBucket(key []byte) (*Bucket, error) {
|
||||
if b.tx.db == nil {
|
||||
return nil, ErrTxClosed
|
||||
|
@ -192,6 +194,7 @@ func (b *Bucket) CreateBucket(key []byte) (*Bucket, error) {
|
|||
|
||||
// CreateBucketIfNotExists creates a new bucket if it doesn't already exist and returns a reference to it.
|
||||
// Returns an error if the bucket name is blank, or if the bucket name is too long.
|
||||
// The bucket instance is only valid for the lifetime of the transaction.
|
||||
func (b *Bucket) CreateBucketIfNotExists(key []byte) (*Bucket, error) {
|
||||
child, err := b.CreateBucket(key)
|
||||
if err == ErrBucketExists {
|
||||
|
@ -270,6 +273,7 @@ func (b *Bucket) Get(key []byte) []byte {
|
|||
|
||||
// Put sets the value for a key in the bucket.
|
||||
// If the key exist then its previous value will be overwritten.
|
||||
// Supplied value must remain valid for the life of the transaction.
|
||||
// Returns an error if the bucket was created from a read-only transaction, if the key is blank, if the key is too large, or if the value is too large.
|
||||
func (b *Bucket) Put(key []byte, value []byte) error {
|
||||
if b.tx.db == nil {
|
||||
|
@ -346,7 +350,8 @@ func (b *Bucket) NextSequence() (uint64, error) {
|
|||
|
||||
// ForEach executes a function for each key/value pair in a bucket.
|
||||
// If the provided function returns an error then the iteration is stopped and
|
||||
// the error is returned to the caller.
|
||||
// the error is returned to the caller. The provided function must not modify
|
||||
// the bucket; this will result in undefined behavior.
|
||||
func (b *Bucket) ForEach(fn func(k, v []byte) error) error {
|
||||
if b.tx.db == nil {
|
||||
return ErrTxClosed
|
||||
|
|
|
@ -253,7 +253,7 @@ func TestBucket_Delete_FreelistOverflow(t *testing.T) {
|
|||
b := tx.Bucket([]byte("0"))
|
||||
c := b.Cursor()
|
||||
for k, _ := c.First(); k != nil; k, _ = c.Next() {
|
||||
b.Delete(k)
|
||||
c.Delete()
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
|
|
@ -34,6 +34,13 @@ func (c *Cursor) First() (key []byte, value []byte) {
|
|||
p, n := c.bucket.pageNode(c.bucket.root)
|
||||
c.stack = append(c.stack, elemRef{page: p, node: n, index: 0})
|
||||
c.first()
|
||||
|
||||
// If we land on an empty page then move to the next value.
|
||||
// https://github.com/boltdb/bolt/issues/450
|
||||
if c.stack[len(c.stack)-1].count() == 0 {
|
||||
c.next()
|
||||
}
|
||||
|
||||
k, v, flags := c.keyValue()
|
||||
if (flags & uint32(bucketLeafFlag)) != 0 {
|
||||
return k, nil
|
||||
|
@ -209,28 +216,37 @@ func (c *Cursor) last() {
|
|||
// next moves to the next leaf element and returns the key and value.
|
||||
// If the cursor is at the last leaf element then it stays there and returns nil.
|
||||
func (c *Cursor) next() (key []byte, value []byte, flags uint32) {
|
||||
// Attempt to move over one element until we're successful.
|
||||
// Move up the stack as we hit the end of each page in our stack.
|
||||
var i int
|
||||
for i = len(c.stack) - 1; i >= 0; i-- {
|
||||
elem := &c.stack[i]
|
||||
if elem.index < elem.count()-1 {
|
||||
elem.index++
|
||||
break
|
||||
for {
|
||||
// Attempt to move over one element until we're successful.
|
||||
// Move up the stack as we hit the end of each page in our stack.
|
||||
var i int
|
||||
for i = len(c.stack) - 1; i >= 0; i-- {
|
||||
elem := &c.stack[i]
|
||||
if elem.index < elem.count()-1 {
|
||||
elem.index++
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If we've hit the root page then stop and return. This will leave the
|
||||
// cursor on the last element of the last page.
|
||||
if i == -1 {
|
||||
return nil, nil, 0
|
||||
}
|
||||
// If we've hit the root page then stop and return. This will leave the
|
||||
// cursor on the last element of the last page.
|
||||
if i == -1 {
|
||||
return nil, nil, 0
|
||||
}
|
||||
|
||||
// Otherwise start from where we left off in the stack and find the
|
||||
// first element of the first leaf page.
|
||||
c.stack = c.stack[:i+1]
|
||||
c.first()
|
||||
return c.keyValue()
|
||||
// Otherwise start from where we left off in the stack and find the
|
||||
// first element of the first leaf page.
|
||||
c.stack = c.stack[:i+1]
|
||||
c.first()
|
||||
|
||||
// If this is an empty page then restart and move back up the stack.
|
||||
// https://github.com/boltdb/bolt/issues/450
|
||||
if c.stack[len(c.stack)-1].count() == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
return c.keyValue()
|
||||
}
|
||||
}
|
||||
|
||||
// search recursively performs a binary search against a given page/node until it finds a given key.
|
||||
|
|
|
@ -303,6 +303,49 @@ func TestCursor_Restart(t *testing.T) {
|
|||
tx.Rollback()
|
||||
}
|
||||
|
||||
// Ensure that a cursor can skip over empty pages that have been deleted.
|
||||
func TestCursor_First_EmptyPages(t *testing.T) {
|
||||
db := NewTestDB()
|
||||
defer db.Close()
|
||||
|
||||
// Create 1000 keys in the "widgets" bucket.
|
||||
db.Update(func(tx *bolt.Tx) error {
|
||||
b, err := tx.CreateBucket([]byte("widgets"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
for i := 0; i < 1000; i++ {
|
||||
if err := b.Put(u64tob(uint64(i)), []byte{}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
// Delete half the keys and then try to iterate.
|
||||
db.Update(func(tx *bolt.Tx) error {
|
||||
b := tx.Bucket([]byte("widgets"))
|
||||
for i := 0; i < 600; i++ {
|
||||
if err := b.Delete(u64tob(uint64(i))); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
c := b.Cursor()
|
||||
var n int
|
||||
for k, _ := c.First(); k != nil; k, _ = c.Next() {
|
||||
n++
|
||||
}
|
||||
if n != 400 {
|
||||
t.Fatalf("unexpected key count: %d", n)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// Ensure that a Tx can iterate over all elements in a bucket.
|
||||
func TestCursor_QuickCheck(t *testing.T) {
|
||||
f := func(items testdata) bool {
|
||||
|
|
|
@ -63,6 +63,10 @@ type DB struct {
|
|||
// https://github.com/boltdb/bolt/issues/284
|
||||
NoGrowSync bool
|
||||
|
||||
// If you want to read the entire database fast, you can set MmapFlag to
|
||||
// syscall.MAP_POPULATE on Linux 2.6.23+ for sequential read-ahead.
|
||||
MmapFlags int
|
||||
|
||||
// MaxBatchSize is the maximum size of a batch. Default value is
|
||||
// copied from DefaultMaxBatchSize in Open.
|
||||
//
|
||||
|
@ -136,6 +140,7 @@ func Open(path string, mode os.FileMode, options *Options) (*DB, error) {
|
|||
options = DefaultOptions
|
||||
}
|
||||
db.NoGrowSync = options.NoGrowSync
|
||||
db.MmapFlags = options.MmapFlags
|
||||
|
||||
// Set default values for later DB operations.
|
||||
db.MaxBatchSize = DefaultMaxBatchSize
|
||||
|
@ -672,6 +677,9 @@ type Options struct {
|
|||
// Open database in read-only mode. Uses flock(..., LOCK_SH |LOCK_NB) to
|
||||
// grab a shared lock (UNIX).
|
||||
ReadOnly bool
|
||||
|
||||
// Sets the DB.MmapFlags flag before memory mapping the file.
|
||||
MmapFlags int
|
||||
}
|
||||
|
||||
// DefaultOptions represent the options used if nil options are passed into Open().
|
||||
|
|
|
@ -39,9 +39,6 @@ func TestOpen(t *testing.T) {
|
|||
|
||||
// Ensure that opening an already open database file will timeout.
|
||||
func TestOpen_Timeout(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("timeout not supported on windows")
|
||||
}
|
||||
if runtime.GOOS == "solaris" {
|
||||
t.Skip("solaris fcntl locks don't support intra-process locking")
|
||||
}
|
||||
|
@ -66,9 +63,6 @@ func TestOpen_Timeout(t *testing.T) {
|
|||
|
||||
// Ensure that opening an already open database file will wait until its closed.
|
||||
func TestOpen_Wait(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("timeout not supported on windows")
|
||||
}
|
||||
if runtime.GOOS == "solaris" {
|
||||
t.Skip("solaris fcntl locks don't support intra-process locking")
|
||||
}
|
||||
|
@ -622,7 +616,7 @@ func TestDB_Consistency(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
// Ensure that DB stats can be substracted from one another.
|
||||
// Ensure that DB stats can be subtracted from one another.
|
||||
func TestDBStats_Sub(t *testing.T) {
|
||||
var a, b bolt.Stats
|
||||
a.TxStats.PageCount = 3
|
||||
|
|
|
@ -29,6 +29,14 @@ type Tx struct {
|
|||
pages map[pgid]*page
|
||||
stats TxStats
|
||||
commitHandlers []func()
|
||||
|
||||
// WriteFlag specifies the flag for write-related methods like WriteTo().
|
||||
// Tx opens the database file with the specified flag to copy the data.
|
||||
//
|
||||
// By default, the flag is unset, which works well for mostly in-memory
|
||||
// workloads. For databases that are much larger than available RAM,
|
||||
// set the flag to syscall.O_DIRECT to avoid trashing the page cache.
|
||||
WriteFlag int
|
||||
}
|
||||
|
||||
// init initializes the transaction.
|
||||
|
@ -87,18 +95,21 @@ func (tx *Tx) Stats() TxStats {
|
|||
|
||||
// Bucket retrieves a bucket by name.
|
||||
// Returns nil if the bucket does not exist.
|
||||
// The bucket instance is only valid for the lifetime of the transaction.
|
||||
func (tx *Tx) Bucket(name []byte) *Bucket {
|
||||
return tx.root.Bucket(name)
|
||||
}
|
||||
|
||||
// CreateBucket creates a new bucket.
|
||||
// Returns an error if the bucket already exists, if the bucket name is blank, or if the bucket name is too long.
|
||||
// The bucket instance is only valid for the lifetime of the transaction.
|
||||
func (tx *Tx) CreateBucket(name []byte) (*Bucket, error) {
|
||||
return tx.root.CreateBucket(name)
|
||||
}
|
||||
|
||||
// CreateBucketIfNotExists creates a new bucket if it doesn't already exist.
|
||||
// Returns an error if the bucket name is blank, or if the bucket name is too long.
|
||||
// The bucket instance is only valid for the lifetime of the transaction.
|
||||
func (tx *Tx) CreateBucketIfNotExists(name []byte) (*Bucket, error) {
|
||||
return tx.root.CreateBucketIfNotExists(name)
|
||||
}
|
||||
|
@ -236,7 +247,8 @@ func (tx *Tx) close() {
|
|||
var freelistPendingN = tx.db.freelist.pending_count()
|
||||
var freelistAlloc = tx.db.freelist.size()
|
||||
|
||||
// Remove writer lock.
|
||||
// Remove transaction ref & writer lock.
|
||||
tx.db.rwtx = nil
|
||||
tx.db.rwlock.Unlock()
|
||||
|
||||
// Merge statistics.
|
||||
|
@ -250,7 +262,12 @@ func (tx *Tx) close() {
|
|||
} else {
|
||||
tx.db.removeTx(tx)
|
||||
}
|
||||
|
||||
// Clear all references.
|
||||
tx.db = nil
|
||||
tx.meta = nil
|
||||
tx.root = Bucket{tx: tx}
|
||||
tx.pages = nil
|
||||
}
|
||||
|
||||
// Copy writes the entire database to a writer.
|
||||
|
@ -263,21 +280,18 @@ func (tx *Tx) Copy(w io.Writer) error {
|
|||
// WriteTo writes the entire database to a writer.
|
||||
// If err == nil then exactly tx.Size() bytes will be written into the writer.
|
||||
func (tx *Tx) WriteTo(w io.Writer) (n int64, err error) {
|
||||
// Attempt to open reader directly.
|
||||
var f *os.File
|
||||
if f, err = os.OpenFile(tx.db.path, os.O_RDONLY|odirect, 0); err != nil {
|
||||
// Fallback to a regular open if that doesn't work.
|
||||
if f, err = os.OpenFile(tx.db.path, os.O_RDONLY, 0); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
// Attempt to open reader with WriteFlag
|
||||
f, err := os.OpenFile(tx.db.path, os.O_RDONLY|tx.WriteFlag, 0)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
// Copy the meta pages.
|
||||
tx.db.metalock.Lock()
|
||||
n, err = io.CopyN(w, f, int64(tx.db.pageSize*2))
|
||||
tx.db.metalock.Unlock()
|
||||
if err != nil {
|
||||
_ = f.Close()
|
||||
return n, fmt.Errorf("meta copy: %s", err)
|
||||
}
|
||||
|
||||
|
@ -285,7 +299,6 @@ func (tx *Tx) WriteTo(w io.Writer) (n int64, err error) {
|
|||
wn, err := io.CopyN(w, f, tx.Size()-int64(tx.db.pageSize*2))
|
||||
n += wn
|
||||
if err != nil {
|
||||
_ = f.Close()
|
||||
return n, err
|
||||
}
|
||||
|
||||
|
|
|
@ -1 +0,0 @@
|
|||
*~
|
|
@ -1,19 +0,0 @@
|
|||
# This file is like Go's AUTHORS file: it lists Copyright holders.
|
||||
# The list of humans who have contributd is in the CONTRIBUTORS file.
|
||||
#
|
||||
# To contribute to this project, because it will eventually be folded
|
||||
# back in to Go itself, you need to submit a CLA:
|
||||
#
|
||||
# http://golang.org/doc/contribute.html#copyright
|
||||
#
|
||||
# Then you get added to CONTRIBUTORS and you or your company get added
|
||||
# to the AUTHORS file.
|
||||
|
||||
Blake Mizerany <blake.mizerany@gmail.com> github=bmizerany
|
||||
Daniel Morsing <daniel.morsing@gmail.com> github=DanielMorsing
|
||||
Gabriel Aszalos <gabriel.aszalos@gmail.com> github=gbbr
|
||||
Google, Inc.
|
||||
Keith Rarick <kr@xph.us> github=kr
|
||||
Matthew Keenan <tank.en.mate@gmail.com> <github@mattkeenan.net> github=mattkeenan
|
||||
Matt Layher <mdlayher@gmail.com> github=mdlayher
|
||||
Tatsuhiro Tsujikawa <tatsuhiro.t@gmail.com> github=tatsuhiro-t
|
|
@ -1,19 +0,0 @@
|
|||
# This file is like Go's CONTRIBUTORS file: it lists humans.
|
||||
# The list of copyright holders (which may be companies) are in the AUTHORS file.
|
||||
#
|
||||
# To contribute to this project, because it will eventually be folded
|
||||
# back in to Go itself, you need to submit a CLA:
|
||||
#
|
||||
# http://golang.org/doc/contribute.html#copyright
|
||||
#
|
||||
# Then you get added to CONTRIBUTORS and you or your company get added
|
||||
# to the AUTHORS file.
|
||||
|
||||
Blake Mizerany <blake.mizerany@gmail.com> github=bmizerany
|
||||
Brad Fitzpatrick <bradfitz@golang.org> github=bradfitz
|
||||
Daniel Morsing <daniel.morsing@gmail.com> github=DanielMorsing
|
||||
Gabriel Aszalos <gabriel.aszalos@gmail.com> github=gbbr
|
||||
Keith Rarick <kr@xph.us> github=kr
|
||||
Matthew Keenan <tank.en.mate@gmail.com> <github@mattkeenan.net> github=mattkeenan
|
||||
Matt Layher <mdlayher@gmail.com> github=mdlayher
|
||||
Tatsuhiro Tsujikawa <tatsuhiro.t@gmail.com> github=tatsuhiro-t
|
|
@ -1,5 +0,0 @@
|
|||
We only accept contributions from users who have gone through Go's
|
||||
contribution process (signed a CLA).
|
||||
|
||||
Please acknowledge whether you have (and use the same email) if
|
||||
sending a pull request.
|
|
@ -1,7 +0,0 @@
|
|||
Copyright 2014 Google & the Go AUTHORS
|
||||
|
||||
Go AUTHORS are:
|
||||
See https://code.google.com/p/go/source/browse/AUTHORS
|
||||
|
||||
Licensed under the terms of Go itself:
|
||||
https://code.google.com/p/go/source/browse/LICENSE
|
|
@ -1,75 +0,0 @@
|
|||
// Copyright 2014 The Go Authors.
|
||||
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
|
||||
// Licensed under the same terms as Go itself:
|
||||
// https://code.google.com/p/go/source/browse/LICENSE
|
||||
|
||||
package http2
|
||||
|
||||
import (
|
||||
"errors"
|
||||
)
|
||||
|
||||
// buffer is an io.ReadWriteCloser backed by a fixed size buffer.
|
||||
// It never allocates, but moves old data as new data is written.
|
||||
type buffer struct {
|
||||
buf []byte
|
||||
r, w int
|
||||
closed bool
|
||||
err error // err to return to reader
|
||||
}
|
||||
|
||||
var (
|
||||
errReadEmpty = errors.New("read from empty buffer")
|
||||
errWriteFull = errors.New("write on full buffer")
|
||||
)
|
||||
|
||||
// Read copies bytes from the buffer into p.
|
||||
// It is an error to read when no data is available.
|
||||
func (b *buffer) Read(p []byte) (n int, err error) {
|
||||
n = copy(p, b.buf[b.r:b.w])
|
||||
b.r += n
|
||||
if b.closed && b.r == b.w {
|
||||
err = b.err
|
||||
} else if b.r == b.w && n == 0 {
|
||||
err = errReadEmpty
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
// Len returns the number of bytes of the unread portion of the buffer.
|
||||
func (b *buffer) Len() int {
|
||||
return b.w - b.r
|
||||
}
|
||||
|
||||
// Write copies bytes from p into the buffer.
|
||||
// It is an error to write more data than the buffer can hold.
|
||||
func (b *buffer) Write(p []byte) (n int, err error) {
|
||||
if b.closed {
|
||||
return 0, errors.New("closed")
|
||||
}
|
||||
|
||||
// Slide existing data to beginning.
|
||||
if b.r > 0 && len(p) > len(b.buf)-b.w {
|
||||
copy(b.buf, b.buf[b.r:b.w])
|
||||
b.w -= b.r
|
||||
b.r = 0
|
||||
}
|
||||
|
||||
// Write new data.
|
||||
n = copy(b.buf[b.w:], p)
|
||||
b.w += n
|
||||
if n < len(p) {
|
||||
err = errWriteFull
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
// Close marks the buffer as closed. Future calls to Write will
|
||||
// return an error. Future calls to Read, once the buffer is
|
||||
// empty, will return err.
|
||||
func (b *buffer) Close(err error) {
|
||||
if !b.closed {
|
||||
b.closed = true
|
||||
b.err = err
|
||||
}
|
||||
}
|
|
@ -1,73 +0,0 @@
|
|||
// Copyright 2014 The Go Authors.
|
||||
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
|
||||
// Licensed under the same terms as Go itself:
|
||||
// https://code.google.com/p/go/source/browse/LICENSE
|
||||
|
||||
package http2
|
||||
|
||||
import (
|
||||
"io"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var bufferReadTests = []struct {
|
||||
buf buffer
|
||||
read, wn int
|
||||
werr error
|
||||
wp []byte
|
||||
wbuf buffer
|
||||
}{
|
||||
{
|
||||
buffer{[]byte{'a', 0}, 0, 1, false, nil},
|
||||
5, 1, nil, []byte{'a'},
|
||||
buffer{[]byte{'a', 0}, 1, 1, false, nil},
|
||||
},
|
||||
{
|
||||
buffer{[]byte{'a', 0}, 0, 1, true, io.EOF},
|
||||
5, 1, io.EOF, []byte{'a'},
|
||||
buffer{[]byte{'a', 0}, 1, 1, true, io.EOF},
|
||||
},
|
||||
{
|
||||
buffer{[]byte{0, 'a'}, 1, 2, false, nil},
|
||||
5, 1, nil, []byte{'a'},
|
||||
buffer{[]byte{0, 'a'}, 2, 2, false, nil},
|
||||
},
|
||||
{
|
||||
buffer{[]byte{0, 'a'}, 1, 2, true, io.EOF},
|
||||
5, 1, io.EOF, []byte{'a'},
|
||||
buffer{[]byte{0, 'a'}, 2, 2, true, io.EOF},
|
||||
},
|
||||
{
|
||||
buffer{[]byte{}, 0, 0, false, nil},
|
||||
5, 0, errReadEmpty, []byte{},
|
||||
buffer{[]byte{}, 0, 0, false, nil},
|
||||
},
|
||||
{
|
||||
buffer{[]byte{}, 0, 0, true, io.EOF},
|
||||
5, 0, io.EOF, []byte{},
|
||||
buffer{[]byte{}, 0, 0, true, io.EOF},
|
||||
},
|
||||
}
|
||||
|
||||
func TestBufferRead(t *testing.T) {
|
||||
for i, tt := range bufferReadTests {
|
||||
read := make([]byte, tt.read)
|
||||
n, err := tt.buf.Read(read)
|
||||
if n != tt.wn {
|
||||
t.Errorf("#%d: wn = %d want %d", i, n, tt.wn)
|
||||
continue
|
||||
}
|
||||
if err != tt.werr {
|
||||
t.Errorf("#%d: werr = %v want %v", i, err, tt.werr)
|
||||
continue
|
||||
}
|
||||
read = read[:n]
|
||||
if !reflect.DeepEqual(read, tt.wp) {
|
||||
t.Errorf("#%d: read = %+v want %+v", i, read, tt.wp)
|
||||
}
|
||||
if !reflect.DeepEqual(tt.buf, tt.wbuf) {
|
||||
t.Errorf("#%d: buf = %+v want %+v", i, tt.buf, tt.wbuf)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,5 +0,0 @@
|
|||
h2demo.linux: h2demo.go
|
||||
GOOS=linux go build --tags=h2demo -o h2demo.linux .
|
||||
|
||||
upload: h2demo.linux
|
||||
cat h2demo.linux | go run launch.go --write_object=http2-demo-server-tls/h2demo --write_object_is_public
|
|
@ -1,43 +0,0 @@
|
|||
// Copyright 2014 The Go Authors.
|
||||
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
|
||||
// Licensed under the same terms as Go itself:
|
||||
// https://code.google.com/p/go/source/browse/LICENSE
|
||||
|
||||
package http2
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
type pipe struct {
|
||||
b buffer
|
||||
c sync.Cond
|
||||
m sync.Mutex
|
||||
}
|
||||
|
||||
// Read waits until data is available and copies bytes
|
||||
// from the buffer into p.
|
||||
func (r *pipe) Read(p []byte) (n int, err error) {
|
||||
r.c.L.Lock()
|
||||
defer r.c.L.Unlock()
|
||||
for r.b.Len() == 0 && !r.b.closed {
|
||||
r.c.Wait()
|
||||
}
|
||||
return r.b.Read(p)
|
||||
}
|
||||
|
||||
// Write copies bytes from p into the buffer and wakes a reader.
|
||||
// It is an error to write more data than the buffer can hold.
|
||||
func (w *pipe) Write(p []byte) (n int, err error) {
|
||||
w.c.L.Lock()
|
||||
defer w.c.L.Unlock()
|
||||
defer w.c.Signal()
|
||||
return w.b.Write(p)
|
||||
}
|
||||
|
||||
func (c *pipe) Close(err error) {
|
||||
c.c.L.Lock()
|
||||
defer c.c.L.Unlock()
|
||||
defer c.c.Signal()
|
||||
c.b.Close(err)
|
||||
}
|
|
@ -1,24 +0,0 @@
|
|||
// Copyright 2014 The Go Authors.
|
||||
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
|
||||
// Licensed under the same terms as Go itself:
|
||||
// https://code.google.com/p/go/source/browse/LICENSE
|
||||
|
||||
package http2
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestPipeClose(t *testing.T) {
|
||||
var p pipe
|
||||
p.c.L = &p.m
|
||||
a := errors.New("a")
|
||||
b := errors.New("b")
|
||||
p.Close(a)
|
||||
p.Close(b)
|
||||
_, err := p.Read(make([]byte, 1))
|
||||
if err != a {
|
||||
t.Errorf("err = %v want %v", err, a)
|
||||
}
|
||||
}
|
|
@ -1,553 +0,0 @@
|
|||
// Copyright 2015 The Go Authors.
|
||||
// See https://go.googlesource.com/go/+/master/CONTRIBUTORS
|
||||
// Licensed under the same terms as Go itself:
|
||||
// https://go.googlesource.com/go/+/master/LICENSE
|
||||
|
||||
package http2
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/coreos/etcd/Godeps/_workspace/src/github.com/bradfitz/http2/hpack"
|
||||
)
|
||||
|
||||
type Transport struct {
|
||||
Fallback http.RoundTripper
|
||||
|
||||
// TODO: remove this and make more general with a TLS dial hook, like http
|
||||
InsecureTLSDial bool
|
||||
|
||||
connMu sync.Mutex
|
||||
conns map[string][]*clientConn // key is host:port
|
||||
}
|
||||
|
||||
type clientConn struct {
|
||||
t *Transport
|
||||
tconn *tls.Conn
|
||||
tlsState *tls.ConnectionState
|
||||
connKey []string // key(s) this connection is cached in, in t.conns
|
||||
|
||||
readerDone chan struct{} // closed on error
|
||||
readerErr error // set before readerDone is closed
|
||||
hdec *hpack.Decoder
|
||||
nextRes *http.Response
|
||||
|
||||
mu sync.Mutex
|
||||
closed bool
|
||||
goAway *GoAwayFrame // if non-nil, the GoAwayFrame we received
|
||||
streams map[uint32]*clientStream
|
||||
nextStreamID uint32
|
||||
bw *bufio.Writer
|
||||
werr error // first write error that has occurred
|
||||
br *bufio.Reader
|
||||
fr *Framer
|
||||
// Settings from peer:
|
||||
maxFrameSize uint32
|
||||
maxConcurrentStreams uint32
|
||||
initialWindowSize uint32
|
||||
hbuf bytes.Buffer // HPACK encoder writes into this
|
||||
henc *hpack.Encoder
|
||||
}
|
||||
|
||||
type clientStream struct {
|
||||
ID uint32
|
||||
resc chan resAndError
|
||||
pw *io.PipeWriter
|
||||
pr *io.PipeReader
|
||||
}
|
||||
|
||||
type stickyErrWriter struct {
|
||||
w io.Writer
|
||||
err *error
|
||||
}
|
||||
|
||||
func (sew stickyErrWriter) Write(p []byte) (n int, err error) {
|
||||
if *sew.err != nil {
|
||||
return 0, *sew.err
|
||||
}
|
||||
n, err = sew.w.Write(p)
|
||||
*sew.err = err
|
||||
return
|
||||
}
|
||||
|
||||
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
if req.URL.Scheme != "https" {
|
||||
if t.Fallback == nil {
|
||||
return nil, errors.New("http2: unsupported scheme and no Fallback")
|
||||
}
|
||||
return t.Fallback.RoundTrip(req)
|
||||
}
|
||||
|
||||
host, port, err := net.SplitHostPort(req.URL.Host)
|
||||
if err != nil {
|
||||
host = req.URL.Host
|
||||
port = "443"
|
||||
}
|
||||
|
||||
for {
|
||||
cc, err := t.getClientConn(host, port)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
res, err := cc.roundTrip(req)
|
||||
if shouldRetryRequest(err) { // TODO: or clientconn is overloaded (too many outstanding requests)?
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
}
|
||||
|
||||
// CloseIdleConnections closes any connections which were previously
|
||||
// connected from previous requests but are now sitting idle.
|
||||
// It does not interrupt any connections currently in use.
|
||||
func (t *Transport) CloseIdleConnections() {
|
||||
t.connMu.Lock()
|
||||
defer t.connMu.Unlock()
|
||||
for _, vv := range t.conns {
|
||||
for _, cc := range vv {
|
||||
cc.closeIfIdle()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var errClientConnClosed = errors.New("http2: client conn is closed")
|
||||
|
||||
func shouldRetryRequest(err error) bool {
|
||||
// TODO: or GOAWAY graceful shutdown stuff
|
||||
return err == errClientConnClosed
|
||||
}
|
||||
|
||||
func (t *Transport) removeClientConn(cc *clientConn) {
|
||||
t.connMu.Lock()
|
||||
defer t.connMu.Unlock()
|
||||
for _, key := range cc.connKey {
|
||||
vv, ok := t.conns[key]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
newList := filterOutClientConn(vv, cc)
|
||||
if len(newList) > 0 {
|
||||
t.conns[key] = newList
|
||||
} else {
|
||||
delete(t.conns, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func filterOutClientConn(in []*clientConn, exclude *clientConn) []*clientConn {
|
||||
out := in[:0]
|
||||
for _, v := range in {
|
||||
if v != exclude {
|
||||
out = append(out, v)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Transport) getClientConn(host, port string) (*clientConn, error) {
|
||||
t.connMu.Lock()
|
||||
defer t.connMu.Unlock()
|
||||
|
||||
key := net.JoinHostPort(host, port)
|
||||
|
||||
for _, cc := range t.conns[key] {
|
||||
if cc.canTakeNewRequest() {
|
||||
return cc, nil
|
||||
}
|
||||
}
|
||||
if t.conns == nil {
|
||||
t.conns = make(map[string][]*clientConn)
|
||||
}
|
||||
cc, err := t.newClientConn(host, port, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t.conns[key] = append(t.conns[key], cc)
|
||||
return cc, nil
|
||||
}
|
||||
|
||||
func (t *Transport) newClientConn(host, port, key string) (*clientConn, error) {
|
||||
cfg := &tls.Config{
|
||||
ServerName: host,
|
||||
NextProtos: []string{NextProtoTLS},
|
||||
InsecureSkipVerify: t.InsecureTLSDial,
|
||||
}
|
||||
tconn, err := tls.Dial("tcp", host+":"+port, cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := tconn.Handshake(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !t.InsecureTLSDial {
|
||||
if err := tconn.VerifyHostname(cfg.ServerName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
state := tconn.ConnectionState()
|
||||
if p := state.NegotiatedProtocol; p != NextProtoTLS {
|
||||
// TODO(bradfitz): fall back to Fallback
|
||||
return nil, fmt.Errorf("bad protocol: %v", p)
|
||||
}
|
||||
if !state.NegotiatedProtocolIsMutual {
|
||||
return nil, errors.New("could not negotiate protocol mutually")
|
||||
}
|
||||
if _, err := tconn.Write(clientPreface); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cc := &clientConn{
|
||||
t: t,
|
||||
tconn: tconn,
|
||||
connKey: []string{key}, // TODO: cert's validated hostnames too
|
||||
tlsState: &state,
|
||||
readerDone: make(chan struct{}),
|
||||
nextStreamID: 1,
|
||||
maxFrameSize: 16 << 10, // spec default
|
||||
initialWindowSize: 65535, // spec default
|
||||
maxConcurrentStreams: 1000, // "infinite", per spec. 1000 seems good enough.
|
||||
streams: make(map[uint32]*clientStream),
|
||||
}
|
||||
cc.bw = bufio.NewWriter(stickyErrWriter{tconn, &cc.werr})
|
||||
cc.br = bufio.NewReader(tconn)
|
||||
cc.fr = NewFramer(cc.bw, cc.br)
|
||||
cc.henc = hpack.NewEncoder(&cc.hbuf)
|
||||
|
||||
cc.fr.WriteSettings()
|
||||
// TODO: re-send more conn-level flow control tokens when server uses all these.
|
||||
cc.fr.WriteWindowUpdate(0, 1<<30) // um, 0x7fffffff doesn't work to Google? it hangs?
|
||||
cc.bw.Flush()
|
||||
if cc.werr != nil {
|
||||
return nil, cc.werr
|
||||
}
|
||||
|
||||
// Read the obligatory SETTINGS frame
|
||||
f, err := cc.fr.ReadFrame()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sf, ok := f.(*SettingsFrame)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("expected settings frame, got: %T", f)
|
||||
}
|
||||
cc.fr.WriteSettingsAck()
|
||||
cc.bw.Flush()
|
||||
|
||||
sf.ForeachSetting(func(s Setting) error {
|
||||
switch s.ID {
|
||||
case SettingMaxFrameSize:
|
||||
cc.maxFrameSize = s.Val
|
||||
case SettingMaxConcurrentStreams:
|
||||
cc.maxConcurrentStreams = s.Val
|
||||
case SettingInitialWindowSize:
|
||||
cc.initialWindowSize = s.Val
|
||||
default:
|
||||
// TODO(bradfitz): handle more
|
||||
log.Printf("Unhandled Setting: %v", s)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
// TODO: figure out henc size
|
||||
cc.hdec = hpack.NewDecoder(initialHeaderTableSize, cc.onNewHeaderField)
|
||||
|
||||
go cc.readLoop()
|
||||
return cc, nil
|
||||
}
|
||||
|
||||
func (cc *clientConn) setGoAway(f *GoAwayFrame) {
|
||||
cc.mu.Lock()
|
||||
defer cc.mu.Unlock()
|
||||
cc.goAway = f
|
||||
}
|
||||
|
||||
func (cc *clientConn) canTakeNewRequest() bool {
|
||||
cc.mu.Lock()
|
||||
defer cc.mu.Unlock()
|
||||
return cc.goAway == nil &&
|
||||
int64(len(cc.streams)+1) < int64(cc.maxConcurrentStreams) &&
|
||||
cc.nextStreamID < 2147483647
|
||||
}
|
||||
|
||||
func (cc *clientConn) closeIfIdle() {
|
||||
cc.mu.Lock()
|
||||
if len(cc.streams) > 0 {
|
||||
cc.mu.Unlock()
|
||||
return
|
||||
}
|
||||
cc.closed = true
|
||||
// TODO: do clients send GOAWAY too? maybe? Just Close:
|
||||
cc.mu.Unlock()
|
||||
|
||||
cc.tconn.Close()
|
||||
}
|
||||
|
||||
func (cc *clientConn) roundTrip(req *http.Request) (*http.Response, error) {
|
||||
cc.mu.Lock()
|
||||
|
||||
if cc.closed {
|
||||
cc.mu.Unlock()
|
||||
return nil, errClientConnClosed
|
||||
}
|
||||
|
||||
cs := cc.newStream()
|
||||
hasBody := false // TODO
|
||||
|
||||
// we send: HEADERS[+CONTINUATION] + (DATA?)
|
||||
hdrs := cc.encodeHeaders(req)
|
||||
first := true
|
||||
for len(hdrs) > 0 {
|
||||
chunk := hdrs
|
||||
if len(chunk) > int(cc.maxFrameSize) {
|
||||
chunk = chunk[:cc.maxFrameSize]
|
||||
}
|
||||
hdrs = hdrs[len(chunk):]
|
||||
endHeaders := len(hdrs) == 0
|
||||
if first {
|
||||
cc.fr.WriteHeaders(HeadersFrameParam{
|
||||
StreamID: cs.ID,
|
||||
BlockFragment: chunk,
|
||||
EndStream: !hasBody,
|
||||
EndHeaders: endHeaders,
|
||||
})
|
||||
first = false
|
||||
} else {
|
||||
cc.fr.WriteContinuation(cs.ID, endHeaders, chunk)
|
||||
}
|
||||
}
|
||||
cc.bw.Flush()
|
||||
werr := cc.werr
|
||||
cc.mu.Unlock()
|
||||
|
||||
if hasBody {
|
||||
// TODO: write data. and it should probably be interleaved:
|
||||
// go ... io.Copy(dataFrameWriter{cc, cs, ...}, req.Body) ... etc
|
||||
}
|
||||
|
||||
if werr != nil {
|
||||
return nil, werr
|
||||
}
|
||||
|
||||
re := <-cs.resc
|
||||
if re.err != nil {
|
||||
return nil, re.err
|
||||
}
|
||||
res := re.res
|
||||
res.Request = req
|
||||
res.TLS = cc.tlsState
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// requires cc.mu be held.
|
||||
func (cc *clientConn) encodeHeaders(req *http.Request) []byte {
|
||||
cc.hbuf.Reset()
|
||||
|
||||
// TODO(bradfitz): figure out :authority-vs-Host stuff between http2 and Go
|
||||
host := req.Host
|
||||
if host == "" {
|
||||
host = req.URL.Host
|
||||
}
|
||||
|
||||
path := req.URL.Path
|
||||
if path == "" {
|
||||
path = "/"
|
||||
}
|
||||
|
||||
cc.writeHeader(":authority", host) // probably not right for all sites
|
||||
cc.writeHeader(":method", req.Method)
|
||||
cc.writeHeader(":path", path)
|
||||
cc.writeHeader(":scheme", "https")
|
||||
|
||||
for k, vv := range req.Header {
|
||||
lowKey := strings.ToLower(k)
|
||||
if lowKey == "host" {
|
||||
continue
|
||||
}
|
||||
for _, v := range vv {
|
||||
cc.writeHeader(lowKey, v)
|
||||
}
|
||||
}
|
||||
return cc.hbuf.Bytes()
|
||||
}
|
||||
|
||||
func (cc *clientConn) writeHeader(name, value string) {
|
||||
log.Printf("sending %q = %q", name, value)
|
||||
cc.henc.WriteField(hpack.HeaderField{Name: name, Value: value})
|
||||
}
|
||||
|
||||
type resAndError struct {
|
||||
res *http.Response
|
||||
err error
|
||||
}
|
||||
|
||||
// requires cc.mu be held.
|
||||
func (cc *clientConn) newStream() *clientStream {
|
||||
cs := &clientStream{
|
||||
ID: cc.nextStreamID,
|
||||
resc: make(chan resAndError, 1),
|
||||
}
|
||||
cc.nextStreamID += 2
|
||||
cc.streams[cs.ID] = cs
|
||||
return cs
|
||||
}
|
||||
|
||||
func (cc *clientConn) streamByID(id uint32, andRemove bool) *clientStream {
|
||||
cc.mu.Lock()
|
||||
defer cc.mu.Unlock()
|
||||
cs := cc.streams[id]
|
||||
if andRemove {
|
||||
delete(cc.streams, id)
|
||||
}
|
||||
return cs
|
||||
}
|
||||
|
||||
// runs in its own goroutine.
|
||||
func (cc *clientConn) readLoop() {
|
||||
defer cc.t.removeClientConn(cc)
|
||||
defer close(cc.readerDone)
|
||||
|
||||
activeRes := map[uint32]*clientStream{} // keyed by streamID
|
||||
// Close any response bodies if the server closes prematurely.
|
||||
// TODO: also do this if we've written the headers but not
|
||||
// gotten a response yet.
|
||||
defer func() {
|
||||
err := cc.readerErr
|
||||
if err == io.EOF {
|
||||
err = io.ErrUnexpectedEOF
|
||||
}
|
||||
for _, cs := range activeRes {
|
||||
cs.pw.CloseWithError(err)
|
||||
}
|
||||
}()
|
||||
|
||||
// continueStreamID is the stream ID we're waiting for
|
||||
// continuation frames for.
|
||||
var continueStreamID uint32
|
||||
|
||||
for {
|
||||
f, err := cc.fr.ReadFrame()
|
||||
if err != nil {
|
||||
cc.readerErr = err
|
||||
return
|
||||
}
|
||||
log.Printf("Transport received %v: %#v", f.Header(), f)
|
||||
|
||||
streamID := f.Header().StreamID
|
||||
|
||||
_, isContinue := f.(*ContinuationFrame)
|
||||
if isContinue {
|
||||
if streamID != continueStreamID {
|
||||
log.Printf("Protocol violation: got CONTINUATION with id %d; want %d", streamID, continueStreamID)
|
||||
cc.readerErr = ConnectionError(ErrCodeProtocol)
|
||||
return
|
||||
}
|
||||
} else if continueStreamID != 0 {
|
||||
// Continue frames need to be adjacent in the stream
|
||||
// and we were in the middle of headers.
|
||||
log.Printf("Protocol violation: got %T for stream %d, want CONTINUATION for %d", f, streamID, continueStreamID)
|
||||
cc.readerErr = ConnectionError(ErrCodeProtocol)
|
||||
return
|
||||
}
|
||||
|
||||
if streamID%2 == 0 {
|
||||
// Ignore streams pushed from the server for now.
|
||||
// These always have an even stream id.
|
||||
continue
|
||||
}
|
||||
streamEnded := false
|
||||
if ff, ok := f.(streamEnder); ok {
|
||||
streamEnded = ff.StreamEnded()
|
||||
}
|
||||
|
||||
cs := cc.streamByID(streamID, streamEnded)
|
||||
if cs == nil {
|
||||
log.Printf("Received frame for untracked stream ID %d", streamID)
|
||||
continue
|
||||
}
|
||||
|
||||
switch f := f.(type) {
|
||||
case *HeadersFrame:
|
||||
cc.nextRes = &http.Response{
|
||||
Proto: "HTTP/2.0",
|
||||
ProtoMajor: 2,
|
||||
Header: make(http.Header),
|
||||
}
|
||||
cs.pr, cs.pw = io.Pipe()
|
||||
cc.hdec.Write(f.HeaderBlockFragment())
|
||||
case *ContinuationFrame:
|
||||
cc.hdec.Write(f.HeaderBlockFragment())
|
||||
case *DataFrame:
|
||||
log.Printf("DATA: %q", f.Data())
|
||||
cs.pw.Write(f.Data())
|
||||
case *GoAwayFrame:
|
||||
cc.t.removeClientConn(cc)
|
||||
if f.ErrCode != 0 {
|
||||
// TODO: deal with GOAWAY more. particularly the error code
|
||||
log.Printf("transport got GOAWAY with error code = %v", f.ErrCode)
|
||||
}
|
||||
cc.setGoAway(f)
|
||||
default:
|
||||
log.Printf("Transport: unhandled response frame type %T", f)
|
||||
}
|
||||
headersEnded := false
|
||||
if he, ok := f.(headersEnder); ok {
|
||||
headersEnded = he.HeadersEnded()
|
||||
if headersEnded {
|
||||
continueStreamID = 0
|
||||
} else {
|
||||
continueStreamID = streamID
|
||||
}
|
||||
}
|
||||
|
||||
if streamEnded {
|
||||
cs.pw.Close()
|
||||
delete(activeRes, streamID)
|
||||
}
|
||||
if headersEnded {
|
||||
if cs == nil {
|
||||
panic("couldn't find stream") // TODO be graceful
|
||||
}
|
||||
// TODO: set the Body to one which notes the
|
||||
// Close and also sends the server a
|
||||
// RST_STREAM
|
||||
cc.nextRes.Body = cs.pr
|
||||
res := cc.nextRes
|
||||
activeRes[streamID] = cs
|
||||
cs.resc <- resAndError{res: res}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (cc *clientConn) onNewHeaderField(f hpack.HeaderField) {
|
||||
// TODO: verifiy pseudo headers come before non-pseudo headers
|
||||
// TODO: verifiy the status is set
|
||||
log.Printf("Header field: %+v", f)
|
||||
if f.Name == ":status" {
|
||||
code, err := strconv.Atoi(f.Value)
|
||||
if err != nil {
|
||||
panic("TODO: be graceful")
|
||||
}
|
||||
cc.nextRes.Status = f.Value + " " + http.StatusText(code)
|
||||
cc.nextRes.StatusCode = code
|
||||
return
|
||||
}
|
||||
if strings.HasPrefix(f.Name, ":") {
|
||||
// "Endpoints MUST NOT generate pseudo-header fields other than those defined in this document."
|
||||
// TODO: treat as invalid?
|
||||
return
|
||||
}
|
||||
cc.nextRes.Header.Add(http.CanonicalHeaderKey(f.Name), f.Value)
|
||||
}
|
|
@ -1,168 +0,0 @@
|
|||
// Copyright 2015 The Go Authors.
|
||||
// See https://go.googlesource.com/go/+/master/CONTRIBUTORS
|
||||
// Licensed under the same terms as Go itself:
|
||||
// https://go.googlesource.com/go/+/master/LICENSE
|
||||
|
||||
package http2
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"os"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
extNet = flag.Bool("extnet", false, "do external network tests")
|
||||
transportHost = flag.String("transporthost", "http2.golang.org", "hostname to use for TestTransport")
|
||||
insecure = flag.Bool("insecure", false, "insecure TLS dials")
|
||||
)
|
||||
|
||||
func TestTransportExternal(t *testing.T) {
|
||||
if !*extNet {
|
||||
t.Skip("skipping external network test")
|
||||
}
|
||||
req, _ := http.NewRequest("GET", "https://"+*transportHost+"/", nil)
|
||||
rt := &Transport{
|
||||
InsecureTLSDial: *insecure,
|
||||
}
|
||||
res, err := rt.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatalf("%v", err)
|
||||
}
|
||||
res.Write(os.Stdout)
|
||||
}
|
||||
|
||||
func TestTransport(t *testing.T) {
|
||||
const body = "sup"
|
||||
st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
io.WriteString(w, body)
|
||||
})
|
||||
defer st.Close()
|
||||
|
||||
tr := &Transport{InsecureTLSDial: true}
|
||||
defer tr.CloseIdleConnections()
|
||||
|
||||
req, err := http.NewRequest("GET", st.ts.URL, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
res, err := tr.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
t.Logf("Got res: %+v", res)
|
||||
if g, w := res.StatusCode, 200; g != w {
|
||||
t.Errorf("StatusCode = %v; want %v", g, w)
|
||||
}
|
||||
if g, w := res.Status, "200 OK"; g != w {
|
||||
t.Errorf("Status = %q; want %q", g, w)
|
||||
}
|
||||
wantHeader := http.Header{
|
||||
"Content-Length": []string{"3"},
|
||||
"Content-Type": []string{"text/plain; charset=utf-8"},
|
||||
}
|
||||
if !reflect.DeepEqual(res.Header, wantHeader) {
|
||||
t.Errorf("res Header = %v; want %v", res.Header, wantHeader)
|
||||
}
|
||||
if res.Request != req {
|
||||
t.Errorf("Response.Request = %p; want %p", res.Request, req)
|
||||
}
|
||||
if res.TLS == nil {
|
||||
t.Errorf("Response.TLS = nil; want non-nil", res.TLS)
|
||||
}
|
||||
slurp, err := ioutil.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
t.Error("Body read: %v", err)
|
||||
} else if string(slurp) != body {
|
||||
t.Errorf("Body = %q; want %q", slurp, body)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestTransportReusesConns(t *testing.T) {
|
||||
st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
io.WriteString(w, r.RemoteAddr)
|
||||
}, optOnlyServer)
|
||||
defer st.Close()
|
||||
tr := &Transport{InsecureTLSDial: true}
|
||||
defer tr.CloseIdleConnections()
|
||||
get := func() string {
|
||||
req, err := http.NewRequest("GET", st.ts.URL, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
res, err := tr.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
slurp, err := ioutil.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("Body read: %v", err)
|
||||
}
|
||||
addr := strings.TrimSpace(string(slurp))
|
||||
if addr == "" {
|
||||
t.Fatalf("didn't get an addr in response")
|
||||
}
|
||||
return addr
|
||||
}
|
||||
first := get()
|
||||
second := get()
|
||||
if first != second {
|
||||
t.Errorf("first and second responses were on different connections: %q vs %q", first, second)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransportAbortClosesPipes(t *testing.T) {
|
||||
shutdown := make(chan struct{})
|
||||
st := newServerTester(t,
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
w.(http.Flusher).Flush()
|
||||
<-shutdown
|
||||
},
|
||||
optOnlyServer,
|
||||
)
|
||||
defer st.Close()
|
||||
defer close(shutdown) // we must shutdown before st.Close() to avoid hanging
|
||||
|
||||
done := make(chan struct{})
|
||||
requestMade := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
tr := &Transport{
|
||||
InsecureTLSDial: true,
|
||||
}
|
||||
req, err := http.NewRequest("GET", st.ts.URL, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
res, err := tr.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
close(requestMade)
|
||||
_, err = ioutil.ReadAll(res.Body)
|
||||
if err == nil {
|
||||
t.Error("expected error from res.Body.Read")
|
||||
}
|
||||
}()
|
||||
|
||||
<-requestMade
|
||||
// Now force the serve loop to end, via closing the connection.
|
||||
st.closeConn()
|
||||
// deadlock? that's a bug.
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
}
|
|
@ -0,0 +1,4 @@
|
|||
language: go
|
||||
go:
|
||||
- 1.4.2
|
||||
sudo: false
|
|
@ -57,6 +57,12 @@ bar.ShowTimeLeft = true
|
|||
// show average speed
|
||||
bar.ShowSpeed = true
|
||||
|
||||
// sets the width of the progress bar
|
||||
bar.SetWidth(80)
|
||||
|
||||
// sets the width of the progress bar, but if terminal size smaller will be ignored
|
||||
bar.SetMaxWidth(80)
|
||||
|
||||
// convert output to readable format (like KB, MB)
|
||||
bar.SetUnits(pb.U_BYTES)
|
||||
|
|
@ -1,14 +1,14 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"github.com/cheggaaa/pb"
|
||||
"os"
|
||||
"fmt"
|
||||
"github.com/coreos/etcd/Godeps/_workspace/src/github.com/cheggaaa/pb"
|
||||
"io"
|
||||
"time"
|
||||
"strings"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
func main() {
|
||||
|
@ -18,12 +18,12 @@ func main() {
|
|||
return
|
||||
}
|
||||
sourceName, destName := os.Args[1], os.Args[2]
|
||||
|
||||
|
||||
// check source
|
||||
var source io.Reader
|
||||
var sourceSize int64
|
||||
if strings.HasPrefix(sourceName, "http://") {
|
||||
// open as url
|
||||
// open as url
|
||||
resp, err := http.Get(sourceName)
|
||||
if err != nil {
|
||||
fmt.Printf("Can't get %s: %v\n", sourceName, err)
|
||||
|
@ -54,9 +54,7 @@ func main() {
|
|||
sourceSize = sourceStat.Size()
|
||||
source = s
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
// create dest
|
||||
dest, err := os.Create(destName)
|
||||
if err != nil {
|
||||
|
@ -64,15 +62,15 @@ func main() {
|
|||
return
|
||||
}
|
||||
defer dest.Close()
|
||||
|
||||
// create bar
|
||||
|
||||
// create bar
|
||||
bar := pb.New(int(sourceSize)).SetUnits(pb.U_BYTES).SetRefreshRate(time.Millisecond * 10)
|
||||
bar.ShowSpeed = true
|
||||
bar.Start()
|
||||
|
||||
|
||||
// create multi writer
|
||||
writer := io.MultiWriter(dest, bar)
|
||||
|
||||
|
||||
// and copy
|
||||
io.Copy(writer, source)
|
||||
bar.Finish()
|
|
@ -1,7 +1,7 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"github.com/cheggaaa/pb"
|
||||
"github.com/coreos/etcd/Godeps/_workspace/src/github.com/cheggaaa/pb"
|
||||
"time"
|
||||
)
|
||||
|
||||
|
@ -13,7 +13,7 @@ func main() {
|
|||
bar.ShowPercent = true
|
||||
|
||||
// show bar (by default already true)
|
||||
bar.ShowPercent = true
|
||||
bar.ShowBar = true
|
||||
|
||||
// no need counters
|
||||
bar.ShowCounters = true
|
|
@ -0,0 +1,45 @@
|
|||
package pb
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Units int
|
||||
|
||||
const (
|
||||
// By default, without type handle
|
||||
U_NO Units = iota
|
||||
// Handle as b, Kb, Mb, etc
|
||||
U_BYTES
|
||||
)
|
||||
|
||||
// Format integer
|
||||
func Format(i int64, units Units) string {
|
||||
switch units {
|
||||
case U_BYTES:
|
||||
return FormatBytes(i)
|
||||
default:
|
||||
// by default just convert to string
|
||||
return strconv.FormatInt(i, 10)
|
||||
}
|
||||
}
|
||||
|
||||
// Convert bytes to human readable string. Like a 2 MB, 64.2 KB, 52 B
|
||||
func FormatBytes(i int64) (result string) {
|
||||
switch {
|
||||
case i > (1024 * 1024 * 1024 * 1024):
|
||||
result = fmt.Sprintf("%.02f TB", float64(i)/1024/1024/1024/1024)
|
||||
case i > (1024 * 1024 * 1024):
|
||||
result = fmt.Sprintf("%.02f GB", float64(i)/1024/1024/1024)
|
||||
case i > (1024 * 1024):
|
||||
result = fmt.Sprintf("%.02f MB", float64(i)/1024/1024)
|
||||
case i > 1024:
|
||||
result = fmt.Sprintf("%.02f KB", float64(i)/1024)
|
||||
default:
|
||||
result = fmt.Sprintf("%d B", i)
|
||||
}
|
||||
result = strings.Trim(result, " ")
|
||||
return
|
||||
}
|
|
@ -0,0 +1,37 @@
|
|||
package pb
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_DefaultsToInteger(t *testing.T) {
|
||||
value := int64(1000)
|
||||
expected := strconv.Itoa(int(value))
|
||||
actual := Format(value, -1)
|
||||
|
||||
if actual != expected {
|
||||
t.Error(fmt.Sprintf("Expected {%s} was {%s}", expected, actual))
|
||||
}
|
||||
}
|
||||
|
||||
func Test_CanFormatAsInteger(t *testing.T) {
|
||||
value := int64(1000)
|
||||
expected := strconv.Itoa(int(value))
|
||||
actual := Format(value, U_NO)
|
||||
|
||||
if actual != expected {
|
||||
t.Error(fmt.Sprintf("Expected {%s} was {%s}", expected, actual))
|
||||
}
|
||||
}
|
||||
|
||||
func Test_CanFormatAsBytes(t *testing.T) {
|
||||
value := int64(1000)
|
||||
expected := "1000 B"
|
||||
actual := Format(value, U_BYTES)
|
||||
|
||||
if actual != expected {
|
||||
t.Error(fmt.Sprintf("Expected {%s} was {%s}", expected, actual))
|
||||
}
|
||||
}
|
|
@ -0,0 +1,367 @@
|
|||
package pb
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
const (
|
||||
// Default refresh rate - 200ms
|
||||
DEFAULT_REFRESH_RATE = time.Millisecond * 200
|
||||
FORMAT = "[=>-]"
|
||||
)
|
||||
|
||||
// DEPRECATED
|
||||
// variables for backward compatibility, from now do not work
|
||||
// use pb.Format and pb.SetRefreshRate
|
||||
var (
|
||||
DefaultRefreshRate = DEFAULT_REFRESH_RATE
|
||||
BarStart, BarEnd, Empty, Current, CurrentN string
|
||||
)
|
||||
|
||||
// Create new progress bar object
|
||||
func New(total int) *ProgressBar {
|
||||
return New64(int64(total))
|
||||
}
|
||||
|
||||
// Create new progress bar object uding int64 as total
|
||||
func New64(total int64) *ProgressBar {
|
||||
pb := &ProgressBar{
|
||||
Total: total,
|
||||
RefreshRate: DEFAULT_REFRESH_RATE,
|
||||
ShowPercent: true,
|
||||
ShowCounters: true,
|
||||
ShowBar: true,
|
||||
ShowTimeLeft: true,
|
||||
ShowFinalTime: true,
|
||||
Units: U_NO,
|
||||
ManualUpdate: false,
|
||||
isFinish: make(chan struct{}),
|
||||
currentValue: -1,
|
||||
}
|
||||
return pb.Format(FORMAT)
|
||||
}
|
||||
|
||||
// Create new object and start
|
||||
func StartNew(total int) *ProgressBar {
|
||||
return New(total).Start()
|
||||
}
|
||||
|
||||
// Callback for custom output
|
||||
// For example:
|
||||
// bar.Callback = func(s string) {
|
||||
// mySuperPrint(s)
|
||||
// }
|
||||
//
|
||||
type Callback func(out string)
|
||||
|
||||
type ProgressBar struct {
|
||||
current int64 // current must be first member of struct (https://code.google.com/p/go/issues/detail?id=5278)
|
||||
|
||||
Total int64
|
||||
RefreshRate time.Duration
|
||||
ShowPercent, ShowCounters bool
|
||||
ShowSpeed, ShowTimeLeft, ShowBar bool
|
||||
ShowFinalTime bool
|
||||
Output io.Writer
|
||||
Callback Callback
|
||||
NotPrint bool
|
||||
Units Units
|
||||
Width int
|
||||
ForceWidth bool
|
||||
ManualUpdate bool
|
||||
|
||||
finishOnce sync.Once //Guards isFinish
|
||||
isFinish chan struct{}
|
||||
|
||||
startTime time.Time
|
||||
startValue int64
|
||||
currentValue int64
|
||||
|
||||
prefix, postfix string
|
||||
|
||||
BarStart string
|
||||
BarEnd string
|
||||
Empty string
|
||||
Current string
|
||||
CurrentN string
|
||||
}
|
||||
|
||||
// Start print
|
||||
func (pb *ProgressBar) Start() *ProgressBar {
|
||||
pb.startTime = time.Now()
|
||||
pb.startValue = pb.current
|
||||
if pb.Total == 0 {
|
||||
pb.ShowTimeLeft = false
|
||||
pb.ShowPercent = false
|
||||
}
|
||||
if !pb.ManualUpdate {
|
||||
go pb.writer()
|
||||
}
|
||||
return pb
|
||||
}
|
||||
|
||||
// Increment current value
|
||||
func (pb *ProgressBar) Increment() int {
|
||||
return pb.Add(1)
|
||||
}
|
||||
|
||||
// Set current value
|
||||
func (pb *ProgressBar) Set(current int) *ProgressBar {
|
||||
return pb.Set64(int64(current))
|
||||
}
|
||||
|
||||
// Set64 sets the current value as int64
|
||||
func (pb *ProgressBar) Set64(current int64) *ProgressBar {
|
||||
atomic.StoreInt64(&pb.current, current)
|
||||
return pb
|
||||
}
|
||||
|
||||
// Add to current value
|
||||
func (pb *ProgressBar) Add(add int) int {
|
||||
return int(pb.Add64(int64(add)))
|
||||
}
|
||||
|
||||
func (pb *ProgressBar) Add64(add int64) int64 {
|
||||
return atomic.AddInt64(&pb.current, add)
|
||||
}
|
||||
|
||||
// Set prefix string
|
||||
func (pb *ProgressBar) Prefix(prefix string) *ProgressBar {
|
||||
pb.prefix = prefix
|
||||
return pb
|
||||
}
|
||||
|
||||
// Set postfix string
|
||||
func (pb *ProgressBar) Postfix(postfix string) *ProgressBar {
|
||||
pb.postfix = postfix
|
||||
return pb
|
||||
}
|
||||
|
||||
// Set custom format for bar
|
||||
// Example: bar.Format("[=>_]")
|
||||
func (pb *ProgressBar) Format(format string) *ProgressBar {
|
||||
formatEntries := strings.Split(format, "")
|
||||
if len(formatEntries) == 5 {
|
||||
pb.BarStart = formatEntries[0]
|
||||
pb.BarEnd = formatEntries[4]
|
||||
pb.Empty = formatEntries[3]
|
||||
pb.Current = formatEntries[1]
|
||||
pb.CurrentN = formatEntries[2]
|
||||
}
|
||||
return pb
|
||||
}
|
||||
|
||||
// Set bar refresh rate
|
||||
func (pb *ProgressBar) SetRefreshRate(rate time.Duration) *ProgressBar {
|
||||
pb.RefreshRate = rate
|
||||
return pb
|
||||
}
|
||||
|
||||
// Set units
|
||||
// bar.SetUnits(U_NO) - by default
|
||||
// bar.SetUnits(U_BYTES) - for Mb, Kb, etc
|
||||
func (pb *ProgressBar) SetUnits(units Units) *ProgressBar {
|
||||
pb.Units = units
|
||||
return pb
|
||||
}
|
||||
|
||||
// Set max width, if width is bigger than terminal width, will be ignored
|
||||
func (pb *ProgressBar) SetMaxWidth(width int) *ProgressBar {
|
||||
pb.Width = width
|
||||
pb.ForceWidth = false
|
||||
return pb
|
||||
}
|
||||
|
||||
// Set bar width
|
||||
func (pb *ProgressBar) SetWidth(width int) *ProgressBar {
|
||||
pb.Width = width
|
||||
pb.ForceWidth = true
|
||||
return pb
|
||||
}
|
||||
|
||||
// End print
|
||||
func (pb *ProgressBar) Finish() {
|
||||
//Protect multiple calls
|
||||
pb.finishOnce.Do(func() {
|
||||
close(pb.isFinish)
|
||||
pb.write(atomic.LoadInt64(&pb.current))
|
||||
if !pb.NotPrint {
|
||||
fmt.Println()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// End print and write string 'str'
|
||||
func (pb *ProgressBar) FinishPrint(str string) {
|
||||
pb.Finish()
|
||||
fmt.Println(str)
|
||||
}
|
||||
|
||||
// implement io.Writer
|
||||
func (pb *ProgressBar) Write(p []byte) (n int, err error) {
|
||||
n = len(p)
|
||||
pb.Add(n)
|
||||
return
|
||||
}
|
||||
|
||||
// implement io.Reader
|
||||
func (pb *ProgressBar) Read(p []byte) (n int, err error) {
|
||||
n = len(p)
|
||||
pb.Add(n)
|
||||
return
|
||||
}
|
||||
|
||||
// Create new proxy reader over bar
|
||||
func (pb *ProgressBar) NewProxyReader(r io.Reader) *Reader {
|
||||
return &Reader{r, pb}
|
||||
}
|
||||
|
||||
func (pb *ProgressBar) write(current int64) {
|
||||
width := pb.GetWidth()
|
||||
|
||||
var percentBox, countersBox, timeLeftBox, speedBox, barBox, end, out string
|
||||
|
||||
// percents
|
||||
if pb.ShowPercent {
|
||||
percent := float64(current) / (float64(pb.Total) / float64(100))
|
||||
percentBox = fmt.Sprintf(" %.02f %% ", percent)
|
||||
}
|
||||
|
||||
// counters
|
||||
if pb.ShowCounters {
|
||||
if pb.Total > 0 {
|
||||
countersBox = fmt.Sprintf("%s / %s ", Format(current, pb.Units), Format(pb.Total, pb.Units))
|
||||
} else {
|
||||
countersBox = Format(current, pb.Units) + " / ? "
|
||||
}
|
||||
}
|
||||
|
||||
// time left
|
||||
fromStart := time.Now().Sub(pb.startTime)
|
||||
currentFromStart := current - pb.startValue
|
||||
select {
|
||||
case <-pb.isFinish:
|
||||
if pb.ShowFinalTime {
|
||||
left := (fromStart / time.Second) * time.Second
|
||||
timeLeftBox = left.String()
|
||||
}
|
||||
default:
|
||||
if pb.ShowTimeLeft && currentFromStart > 0 {
|
||||
perEntry := fromStart / time.Duration(currentFromStart)
|
||||
left := time.Duration(pb.Total-currentFromStart) * perEntry
|
||||
left = (left / time.Second) * time.Second
|
||||
timeLeftBox = left.String()
|
||||
}
|
||||
}
|
||||
|
||||
// speed
|
||||
if pb.ShowSpeed && currentFromStart > 0 {
|
||||
fromStart := time.Now().Sub(pb.startTime)
|
||||
speed := float64(currentFromStart) / (float64(fromStart) / float64(time.Second))
|
||||
speedBox = Format(int64(speed), pb.Units) + "/s "
|
||||
}
|
||||
|
||||
barWidth := utf8.RuneCountInString(countersBox + pb.BarStart + pb.BarEnd + percentBox + timeLeftBox + speedBox + pb.prefix + pb.postfix)
|
||||
// bar
|
||||
if pb.ShowBar {
|
||||
size := width - barWidth
|
||||
if size > 0 {
|
||||
if pb.Total > 0 {
|
||||
curCount := int(math.Ceil((float64(current) / float64(pb.Total)) * float64(size)))
|
||||
emptCount := size - curCount
|
||||
barBox = pb.BarStart
|
||||
if emptCount < 0 {
|
||||
emptCount = 0
|
||||
}
|
||||
if curCount > size {
|
||||
curCount = size
|
||||
}
|
||||
if emptCount <= 0 {
|
||||
barBox += strings.Repeat(pb.Current, curCount)
|
||||
} else if curCount > 0 {
|
||||
barBox += strings.Repeat(pb.Current, curCount-1) + pb.CurrentN
|
||||
}
|
||||
|
||||
barBox += strings.Repeat(pb.Empty, emptCount) + pb.BarEnd
|
||||
} else {
|
||||
|
||||
barBox = pb.BarStart
|
||||
pos := size - int(current)%int(size)
|
||||
if pos-1 > 0 {
|
||||
barBox += strings.Repeat(pb.Empty, pos-1)
|
||||
}
|
||||
barBox += pb.Current
|
||||
if size-pos-1 > 0 {
|
||||
barBox += strings.Repeat(pb.Empty, size-pos-1)
|
||||
}
|
||||
barBox += pb.BarEnd
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// check len
|
||||
out = pb.prefix + countersBox + barBox + percentBox + speedBox + timeLeftBox + pb.postfix
|
||||
if utf8.RuneCountInString(out) < width {
|
||||
end = strings.Repeat(" ", width-utf8.RuneCountInString(out))
|
||||
}
|
||||
|
||||
// and print!
|
||||
switch {
|
||||
case pb.Output != nil:
|
||||
fmt.Fprint(pb.Output, "\r"+out+end)
|
||||
case pb.Callback != nil:
|
||||
pb.Callback(out + end)
|
||||
case !pb.NotPrint:
|
||||
fmt.Print("\r" + out + end)
|
||||
}
|
||||
}
|
||||
|
||||
func (pb *ProgressBar) GetWidth() int {
|
||||
if pb.ForceWidth {
|
||||
return pb.Width
|
||||
}
|
||||
|
||||
width := pb.Width
|
||||
termWidth, _ := terminalWidth()
|
||||
if width == 0 || termWidth <= width {
|
||||
width = termWidth
|
||||
}
|
||||
|
||||
return width
|
||||
}
|
||||
|
||||
// Write the current state of the progressbar
|
||||
func (pb *ProgressBar) Update() {
|
||||
c := atomic.LoadInt64(&pb.current)
|
||||
if c != pb.currentValue {
|
||||
pb.write(c)
|
||||
pb.currentValue = c
|
||||
}
|
||||
}
|
||||
|
||||
// Internal loop for writing progressbar
|
||||
func (pb *ProgressBar) writer() {
|
||||
pb.Update()
|
||||
for {
|
||||
select {
|
||||
case <-pb.isFinish:
|
||||
return
|
||||
case <-time.After(pb.RefreshRate):
|
||||
pb.Update()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type window struct {
|
||||
Row uint16
|
||||
Col uint16
|
||||
Xpixel uint16
|
||||
Ypixel uint16
|
||||
}
|
|
@ -0,0 +1,7 @@
|
|||
// +build linux darwin freebsd netbsd openbsd
|
||||
|
||||
package pb
|
||||
|
||||
import "syscall"
|
||||
|
||||
const sys_ioctl = syscall.SYS_IOCTL
|
|
@ -0,0 +1,5 @@
|
|||
// +build solaris
|
||||
|
||||
package pb
|
||||
|
||||
const sys_ioctl = 54
|
|
@ -0,0 +1,37 @@
|
|||
package pb
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_IncrementAddsOne(t *testing.T) {
|
||||
count := 5000
|
||||
bar := New(count)
|
||||
expected := 1
|
||||
actual := bar.Increment()
|
||||
|
||||
if actual != expected {
|
||||
t.Errorf("Expected {%d} was {%d}", expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_Width(t *testing.T) {
|
||||
count := 5000
|
||||
bar := New(count)
|
||||
width := 100
|
||||
bar.SetWidth(100).Callback = func(out string) {
|
||||
if len(out) != width {
|
||||
t.Errorf("Bar width expected {%d} was {%d}", len(out), width)
|
||||
}
|
||||
}
|
||||
bar.Start()
|
||||
bar.Increment()
|
||||
bar.Finish()
|
||||
}
|
||||
|
||||
func Test_MultipleFinish(t *testing.T) {
|
||||
bar := New(5000)
|
||||
bar.Add(2000)
|
||||
bar.Finish()
|
||||
bar.Finish()
|
||||
}
|
|
@ -1,16 +1,16 @@
|
|||
// +build windows
|
||||
|
||||
package pb
|
||||
|
||||
import (
|
||||
"github.com/olekukonko/ts"
|
||||
)
|
||||
|
||||
func bold(str string) string {
|
||||
return str
|
||||
}
|
||||
|
||||
func terminalWidth() (int, error) {
|
||||
size , err := ts.GetSize()
|
||||
return size.Col() , err
|
||||
}
|
||||
// +build windows
|
||||
|
||||
package pb
|
||||
|
||||
import (
|
||||
"github.com/coreos/etcd/Godeps/_workspace/src/github.com/olekukonko/ts"
|
||||
)
|
||||
|
||||
func bold(str string) string {
|
||||
return str
|
||||
}
|
||||
|
||||
func terminalWidth() (int, error) {
|
||||
size, err := ts.GetSize()
|
||||
return size.Col(), err
|
||||
}
|
|
@ -1,8 +1,9 @@
|
|||
// +build linux darwin freebsd
|
||||
// +build linux darwin freebsd netbsd openbsd solaris
|
||||
|
||||
package pb
|
||||
|
||||
import (
|
||||
"os"
|
||||
"runtime"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
@ -13,6 +14,16 @@ const (
|
|||
TIOCGWINSZ_OSX = 1074295912
|
||||
)
|
||||
|
||||
var tty *os.File
|
||||
|
||||
func init() {
|
||||
var err error
|
||||
tty, err = os.Open("/dev/tty")
|
||||
if err != nil {
|
||||
tty = os.Stdin
|
||||
}
|
||||
}
|
||||
|
||||
func bold(str string) string {
|
||||
return "\033[1m" + str + "\033[0m"
|
||||
}
|
||||
|
@ -23,8 +34,8 @@ func terminalWidth() (int, error) {
|
|||
if runtime.GOOS == "darwin" {
|
||||
tio = TIOCGWINSZ_OSX
|
||||
}
|
||||
res, _, err := syscall.Syscall(syscall.SYS_IOCTL,
|
||||
uintptr(syscall.Stdin),
|
||||
res, _, err := syscall.Syscall(sys_ioctl,
|
||||
tty.Fd(),
|
||||
uintptr(tio),
|
||||
uintptr(unsafe.Pointer(w)),
|
||||
)
|
|
@ -0,0 +1,17 @@
|
|||
package pb
|
||||
|
||||
import (
|
||||
"io"
|
||||
)
|
||||
|
||||
// It's proxy reader, implement io.Reader
|
||||
type Reader struct {
|
||||
io.Reader
|
||||
bar *ProgressBar
|
||||
}
|
||||
|
||||
func (r *Reader) Read(p []byte) (n int, err error) {
|
||||
n, err = r.Reader.Read(p)
|
||||
r.bar.Add(n)
|
||||
return
|
||||
}
|
|
@ -1,5 +1,18 @@
|
|||
language: go
|
||||
go: 1.1
|
||||
sudo: false
|
||||
|
||||
go:
|
||||
- 1.0.3
|
||||
- 1.1.2
|
||||
- 1.2.2
|
||||
- 1.3.3
|
||||
- 1.4.2
|
||||
- 1.5.1
|
||||
- tip
|
||||
|
||||
matrix:
|
||||
allow_failures:
|
||||
- go: tip
|
||||
|
||||
script:
|
||||
- go vet ./...
|
||||
|
|
|
@ -1,18 +1,17 @@
|
|||
[![Coverage](http://gocover.io/_badge/github.com/codegangsta/cli?0)](http://gocover.io/github.com/codegangsta/cli)
|
||||
[![Build Status](https://travis-ci.org/codegangsta/cli.png?branch=master)](https://travis-ci.org/codegangsta/cli)
|
||||
[![GoDoc](https://godoc.org/github.com/codegangsta/cli?status.svg)](https://godoc.org/github.com/codegangsta/cli)
|
||||
|
||||
# cli.go
|
||||
cli.go is simple, fast, and fun package for building command line apps in Go. The goal is to enable developers to write fast and distributable command line applications in an expressive way.
|
||||
|
||||
You can view the API docs here:
|
||||
http://godoc.org/github.com/codegangsta/cli
|
||||
`cli.go` is simple, fast, and fun package for building command line apps in Go. The goal is to enable developers to write fast and distributable command line applications in an expressive way.
|
||||
|
||||
## Overview
|
||||
Command line apps are usually so tiny that there is absolutely no reason why your code should *not* be self-documenting. Things like generating help text and parsing command flags/options should not hinder productivity when writing a command line app.
|
||||
|
||||
**This is where cli.go comes into play.** cli.go makes command line programming fun, organized, and expressive!
|
||||
**This is where `cli.go` comes into play.** `cli.go` makes command line programming fun, organized, and expressive!
|
||||
|
||||
## Installation
|
||||
Make sure you have a working Go environment (go 1.1 is *required*). [See the install instructions](http://golang.org/doc/install.html).
|
||||
Make sure you have a working Go environment (go 1.1+ is *required*). [See the install instructions](http://golang.org/doc/install.html).
|
||||
|
||||
To install `cli.go`, simply run:
|
||||
```
|
||||
|
@ -25,7 +24,7 @@ export PATH=$PATH:$GOPATH/bin
|
|||
```
|
||||
|
||||
## Getting Started
|
||||
One of the philosophies behind cli.go is that an API should be playful and full of discovery. So a cli.go app can be as little as one line of code in `main()`.
|
||||
One of the philosophies behind `cli.go` is that an API should be playful and full of discovery. So a `cli.go` app can be as little as one line of code in `main()`.
|
||||
|
||||
``` go
|
||||
package main
|
||||
|
@ -68,8 +67,9 @@ Running this already gives you a ton of functionality, plus support for things l
|
|||
|
||||
Being a programmer can be a lonely job. Thankfully by the power of automation that is not the case! Let's create a greeter app to fend off our demons of loneliness!
|
||||
|
||||
Start by creating a directory named `greet`, and within it, add a file, `greet.go` with the following code in it:
|
||||
|
||||
``` go
|
||||
/* greet.go */
|
||||
package main
|
||||
|
||||
import (
|
||||
|
@ -84,7 +84,7 @@ func main() {
|
|||
app.Action = func(c *cli.Context) {
|
||||
println("Hello friend!")
|
||||
}
|
||||
|
||||
|
||||
app.Run(os.Args)
|
||||
}
|
||||
```
|
||||
|
@ -102,7 +102,8 @@ $ greet
|
|||
Hello friend!
|
||||
```
|
||||
|
||||
cli.go also generates some bitchass help text:
|
||||
`cli.go` also generates neat help text:
|
||||
|
||||
```
|
||||
$ greet help
|
||||
NAME:
|
||||
|
@ -157,6 +158,34 @@ app.Action = func(c *cli.Context) {
|
|||
...
|
||||
```
|
||||
|
||||
You can also set a destination variable for a flag, to which the content will be scanned.
|
||||
``` go
|
||||
...
|
||||
var language string
|
||||
app.Flags = []cli.Flag {
|
||||
cli.StringFlag{
|
||||
Name: "lang",
|
||||
Value: "english",
|
||||
Usage: "language for the greeting",
|
||||
Destination: &language,
|
||||
},
|
||||
}
|
||||
app.Action = func(c *cli.Context) {
|
||||
name := "someone"
|
||||
if len(c.Args()) > 0 {
|
||||
name = c.Args()[0]
|
||||
}
|
||||
if language == "spanish" {
|
||||
println("Hola", name)
|
||||
} else {
|
||||
println("Hello", name)
|
||||
}
|
||||
}
|
||||
...
|
||||
```
|
||||
|
||||
See full list of flags at http://godoc.org/github.com/codegangsta/cli
|
||||
|
||||
#### Alternate Names
|
||||
|
||||
You can set alternate (or short) names for flags by providing a comma-delimited list for the `Name`. e.g.
|
||||
|
@ -171,6 +200,8 @@ app.Flags = []cli.Flag {
|
|||
}
|
||||
```
|
||||
|
||||
That flag can then be set with `--lang spanish` or `-l spanish`. Note that giving two different forms of the same flag in the same command invocation is an error.
|
||||
|
||||
#### Values from the Environment
|
||||
|
||||
You can also have the default value set from the environment via `EnvVar`. e.g.
|
||||
|
@ -186,7 +217,18 @@ app.Flags = []cli.Flag {
|
|||
}
|
||||
```
|
||||
|
||||
That flag can then be set with `--lang spanish` or `-l spanish`. Note that giving two different forms of the same flag in the same command invocation is an error.
|
||||
The `EnvVar` may also be given as a comma-delimited "cascade", where the first environment variable that resolves is used as the default.
|
||||
|
||||
``` go
|
||||
app.Flags = []cli.Flag {
|
||||
cli.StringFlag{
|
||||
Name: "lang, l",
|
||||
Value: "english",
|
||||
Usage: "language for the greeting",
|
||||
EnvVar: "LEGACY_COMPAT_LANG,APP_LANG,LANG",
|
||||
},
|
||||
}
|
||||
```
|
||||
|
||||
### Subcommands
|
||||
|
||||
|
@ -196,7 +238,7 @@ Subcommands can be defined for a more git-like command line app.
|
|||
app.Commands = []cli.Command{
|
||||
{
|
||||
Name: "add",
|
||||
ShortName: "a",
|
||||
Aliases: []string{"a"},
|
||||
Usage: "add a task to the list",
|
||||
Action: func(c *cli.Context) {
|
||||
println("added task: ", c.Args().First())
|
||||
|
@ -204,7 +246,7 @@ app.Commands = []cli.Command{
|
|||
},
|
||||
{
|
||||
Name: "complete",
|
||||
ShortName: "c",
|
||||
Aliases: []string{"c"},
|
||||
Usage: "complete a task on the list",
|
||||
Action: func(c *cli.Context) {
|
||||
println("completed task: ", c.Args().First())
|
||||
|
@ -212,7 +254,7 @@ app.Commands = []cli.Command{
|
|||
},
|
||||
{
|
||||
Name: "template",
|
||||
ShortName: "r",
|
||||
Aliases: []string{"r"},
|
||||
Usage: "options for task templates",
|
||||
Subcommands: []cli.Command{
|
||||
{
|
||||
|
@ -230,7 +272,7 @@ app.Commands = []cli.Command{
|
|||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
...
|
||||
```
|
||||
|
@ -248,8 +290,8 @@ app := cli.NewApp()
|
|||
app.EnableBashCompletion = true
|
||||
app.Commands = []cli.Command{
|
||||
{
|
||||
Name: "complete",
|
||||
ShortName: "c",
|
||||
Name: "complete",
|
||||
Aliases: []string{"c"},
|
||||
Usage: "complete a task on the list",
|
||||
Action: func(c *cli.Context) {
|
||||
println("completed task: ", c.Args().First())
|
||||
|
@ -275,13 +317,25 @@ setting the `PROG` variable to the name of your program:
|
|||
|
||||
`PROG=myprogram source /.../cli/autocomplete/bash_autocomplete`
|
||||
|
||||
#### To Distribute
|
||||
|
||||
Copy `autocomplete/bash_autocomplete` into `/etc/bash_completion.d/` and rename
|
||||
it to the name of the program you wish to add autocomplete support for (or
|
||||
automatically install it there if you are distributing a package). Don't forget
|
||||
to source the file to make it active in the current shell.
|
||||
|
||||
```
|
||||
sudo cp src/bash_autocomplete /etc/bash_completion.d/<myprogram>
|
||||
source /etc/bash_completion.d/<myprogram>
|
||||
```
|
||||
|
||||
Alternatively, you can just document that users should source the generic
|
||||
`autocomplete/bash_autocomplete` in their bash configuration with `$PROG` set
|
||||
to the name of their program (as above).
|
||||
|
||||
## Contribution Guidelines
|
||||
Feel free to put up a pull request to fix a bug or maybe add a feature. I will give it a code review and make sure that it does not break backwards compatibility. If I or any other collaborators agree that it is in line with the vision of the project, we will work with you to get the code into a mergeable state and merge it into the master branch.
|
||||
|
||||
If you are have contributed something significant to the project, I will most likely add you as a collaborator. As a collaborator you are given the ability to merge others pull requests. It is very important that new code does not break existing code, so be careful about what code you do choose to merge. If you have any questions feel free to link @codegangsta to the issue in question and we can review it together.
|
||||
If you have contributed something significant to the project, I will most likely add you as a collaborator. As a collaborator you are given the ability to merge others pull requests. It is very important that new code does not break existing code, so be careful about what code you do choose to merge. If you have any questions feel free to link @codegangsta to the issue in question and we can review it together.
|
||||
|
||||
If you feel like you have contributed to the project but have not yet been added as a collaborator, I probably forgot to add you. Hit @codegangsta up over email and we will get it figured out.
|
||||
|
||||
## About
|
||||
cli.go is written by none other than the [Code Gangsta](http://codegangsta.io)
|
||||
|
|
|
@ -2,18 +2,23 @@ package cli
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
// App is the main structure of a cli application. It is recomended that
|
||||
// and app be created with the cli.NewApp() function
|
||||
// an app be created with the cli.NewApp() function
|
||||
type App struct {
|
||||
// The name of the program. Defaults to os.Args[0]
|
||||
Name string
|
||||
// Full name of command for help, defaults to Name
|
||||
HelpName string
|
||||
// Description of the program.
|
||||
Usage string
|
||||
// Description of the program argument format.
|
||||
ArgsUsage string
|
||||
// Version of the program
|
||||
Version string
|
||||
// List of commands to execute
|
||||
|
@ -24,21 +29,32 @@ type App struct {
|
|||
EnableBashCompletion bool
|
||||
// Boolean to hide built-in help command
|
||||
HideHelp bool
|
||||
// Boolean to hide built-in version flag
|
||||
HideVersion bool
|
||||
// An action to execute when the bash-completion flag is set
|
||||
BashComplete func(context *Context)
|
||||
// An action to execute before any subcommands are run, but after the context is ready
|
||||
// If a non-nil error is returned, no subcommands are run
|
||||
Before func(context *Context) error
|
||||
// An action to execute after any subcommands are run, but after the subcommand has finished
|
||||
// It is run even if Action() panics
|
||||
After func(context *Context) error
|
||||
// The action to execute when no subcommands are specified
|
||||
Action func(context *Context)
|
||||
// Execute this function if the proper command cannot be found
|
||||
CommandNotFound func(context *Context, command string)
|
||||
// Compilation date
|
||||
Compiled time.Time
|
||||
// Author
|
||||
// List of all authors who contributed
|
||||
Authors []Author
|
||||
// Copyright of the binary if any
|
||||
Copyright string
|
||||
// Name of Author (Note: Use App.Authors, this is deprecated)
|
||||
Author string
|
||||
// Author e-mail
|
||||
// Email of Author (Note: Use App.Authors, this is deprecated)
|
||||
Email string
|
||||
// Writer writer to write output to
|
||||
Writer io.Writer
|
||||
}
|
||||
|
||||
// Tries to find out when this binary was compiled.
|
||||
|
@ -55,61 +71,95 @@ func compileTime() time.Time {
|
|||
func NewApp() *App {
|
||||
return &App{
|
||||
Name: os.Args[0],
|
||||
HelpName: os.Args[0],
|
||||
Usage: "A new cli application",
|
||||
Version: "0.0.0",
|
||||
BashComplete: DefaultAppComplete,
|
||||
Action: helpCommand.Action,
|
||||
Compiled: compileTime(),
|
||||
Writer: os.Stdout,
|
||||
}
|
||||
}
|
||||
|
||||
// Entry point to the cli app. Parses the arguments slice and routes to the proper flag/args combination
|
||||
func (a *App) Run(arguments []string) error {
|
||||
func (a *App) Run(arguments []string) (err error) {
|
||||
if a.Author != "" || a.Email != "" {
|
||||
a.Authors = append(a.Authors, Author{Name: a.Author, Email: a.Email})
|
||||
}
|
||||
|
||||
newCmds := []Command{}
|
||||
for _, c := range a.Commands {
|
||||
if c.HelpName == "" {
|
||||
c.HelpName = fmt.Sprintf("%s %s", a.HelpName, c.Name)
|
||||
}
|
||||
newCmds = append(newCmds, c)
|
||||
}
|
||||
a.Commands = newCmds
|
||||
|
||||
// append help to commands
|
||||
if a.Command(helpCommand.Name) == nil && !a.HideHelp {
|
||||
a.Commands = append(a.Commands, helpCommand)
|
||||
a.appendFlag(HelpFlag)
|
||||
if (HelpFlag != BoolFlag{}) {
|
||||
a.appendFlag(HelpFlag)
|
||||
}
|
||||
}
|
||||
|
||||
//append version/help flags
|
||||
if a.EnableBashCompletion {
|
||||
a.appendFlag(BashCompletionFlag)
|
||||
}
|
||||
a.appendFlag(VersionFlag)
|
||||
|
||||
if !a.HideVersion {
|
||||
a.appendFlag(VersionFlag)
|
||||
}
|
||||
|
||||
// parse flags
|
||||
set := flagSet(a.Name, a.Flags)
|
||||
set.SetOutput(ioutil.Discard)
|
||||
err := set.Parse(arguments[1:])
|
||||
err = set.Parse(arguments[1:])
|
||||
nerr := normalizeFlags(a.Flags, set)
|
||||
if nerr != nil {
|
||||
fmt.Println(nerr)
|
||||
context := NewContext(a, set, set)
|
||||
fmt.Fprintln(a.Writer, nerr)
|
||||
context := NewContext(a, set, nil)
|
||||
ShowAppHelp(context)
|
||||
fmt.Println("")
|
||||
return nerr
|
||||
}
|
||||
context := NewContext(a, set, set)
|
||||
|
||||
if err != nil {
|
||||
fmt.Printf("Incorrect Usage.\n\n")
|
||||
ShowAppHelp(context)
|
||||
fmt.Println("")
|
||||
return err
|
||||
}
|
||||
context := NewContext(a, set, nil)
|
||||
|
||||
if checkCompletions(context) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if checkHelp(context) {
|
||||
if err != nil {
|
||||
fmt.Fprintln(a.Writer, "Incorrect Usage.")
|
||||
fmt.Fprintln(a.Writer)
|
||||
ShowAppHelp(context)
|
||||
return err
|
||||
}
|
||||
|
||||
if !a.HideHelp && checkHelp(context) {
|
||||
ShowAppHelp(context)
|
||||
return nil
|
||||
}
|
||||
|
||||
if checkVersion(context) {
|
||||
if !a.HideVersion && checkVersion(context) {
|
||||
ShowVersion(context)
|
||||
return nil
|
||||
}
|
||||
|
||||
if a.After != nil {
|
||||
defer func() {
|
||||
afterErr := a.After(context)
|
||||
if afterErr != nil {
|
||||
if err != nil {
|
||||
err = NewMultiError(err, afterErr)
|
||||
} else {
|
||||
err = afterErr
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
if a.Before != nil {
|
||||
err := a.Before(context)
|
||||
if err != nil {
|
||||
|
@ -134,21 +184,32 @@ func (a *App) Run(arguments []string) error {
|
|||
// Another entry point to the cli app, takes care of passing arguments and error handling
|
||||
func (a *App) RunAndExitOnError() {
|
||||
if err := a.Run(os.Args); err != nil {
|
||||
os.Stderr.WriteString(fmt.Sprintln(err))
|
||||
fmt.Fprintln(os.Stderr, err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
// Invokes the subcommand given the context, parses ctx.Args() to generate command-specific flags
|
||||
func (a *App) RunAsSubcommand(ctx *Context) error {
|
||||
func (a *App) RunAsSubcommand(ctx *Context) (err error) {
|
||||
// append help to commands
|
||||
if len(a.Commands) > 0 {
|
||||
if a.Command(helpCommand.Name) == nil && !a.HideHelp {
|
||||
a.Commands = append(a.Commands, helpCommand)
|
||||
a.appendFlag(HelpFlag)
|
||||
if (HelpFlag != BoolFlag{}) {
|
||||
a.appendFlag(HelpFlag)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
newCmds := []Command{}
|
||||
for _, c := range a.Commands {
|
||||
if c.HelpName == "" {
|
||||
c.HelpName = fmt.Sprintf("%s %s", a.HelpName, c.Name)
|
||||
}
|
||||
newCmds = append(newCmds, c)
|
||||
}
|
||||
a.Commands = newCmds
|
||||
|
||||
// append flags
|
||||
if a.EnableBashCompletion {
|
||||
a.appendFlag(BashCompletionFlag)
|
||||
|
@ -157,31 +218,32 @@ func (a *App) RunAsSubcommand(ctx *Context) error {
|
|||
// parse flags
|
||||
set := flagSet(a.Name, a.Flags)
|
||||
set.SetOutput(ioutil.Discard)
|
||||
err := set.Parse(ctx.Args().Tail())
|
||||
err = set.Parse(ctx.Args().Tail())
|
||||
nerr := normalizeFlags(a.Flags, set)
|
||||
context := NewContext(a, set, ctx.globalSet)
|
||||
context := NewContext(a, set, ctx)
|
||||
|
||||
if nerr != nil {
|
||||
fmt.Println(nerr)
|
||||
fmt.Fprintln(a.Writer, nerr)
|
||||
fmt.Fprintln(a.Writer)
|
||||
if len(a.Commands) > 0 {
|
||||
ShowSubcommandHelp(context)
|
||||
} else {
|
||||
ShowCommandHelp(ctx, context.Args().First())
|
||||
}
|
||||
fmt.Println("")
|
||||
return nerr
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
fmt.Printf("Incorrect Usage.\n\n")
|
||||
ShowSubcommandHelp(context)
|
||||
return err
|
||||
}
|
||||
|
||||
if checkCompletions(context) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
fmt.Fprintln(a.Writer, "Incorrect Usage.")
|
||||
fmt.Fprintln(a.Writer)
|
||||
ShowSubcommandHelp(context)
|
||||
return err
|
||||
}
|
||||
|
||||
if len(a.Commands) > 0 {
|
||||
if checkSubcommandHelp(context) {
|
||||
return nil
|
||||
|
@ -192,6 +254,19 @@ func (a *App) RunAsSubcommand(ctx *Context) error {
|
|||
}
|
||||
}
|
||||
|
||||
if a.After != nil {
|
||||
defer func() {
|
||||
afterErr := a.After(context)
|
||||
if afterErr != nil {
|
||||
if err != nil {
|
||||
err = NewMultiError(err, afterErr)
|
||||
} else {
|
||||
err = afterErr
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
if a.Before != nil {
|
||||
err := a.Before(context)
|
||||
if err != nil {
|
||||
|
@ -209,11 +284,7 @@ func (a *App) RunAsSubcommand(ctx *Context) error {
|
|||
}
|
||||
|
||||
// Run default Action
|
||||
if len(a.Commands) > 0 {
|
||||
a.Action(context)
|
||||
} else {
|
||||
a.Action(ctx)
|
||||
}
|
||||
a.Action(context)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -244,3 +315,19 @@ func (a *App) appendFlag(flag Flag) {
|
|||
a.Flags = append(a.Flags, flag)
|
||||
}
|
||||
}
|
||||
|
||||
// Author represents someone who has contributed to a cli project.
|
||||
type Author struct {
|
||||
Name string // The Authors name
|
||||
Email string // The Authors email
|
||||
}
|
||||
|
||||
// String makes Author comply to the Stringer interface, to allow an easy print in the templating process
|
||||
func (a Author) String() string {
|
||||
e := ""
|
||||
if a.Email != "" {
|
||||
e = "<" + a.Email + "> "
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%v %v", a.Name, e)
|
||||
}
|
||||
|
|
|
@ -1,423 +0,0 @@
|
|||
package cli_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/coreos/etcd/Godeps/_workspace/src/github.com/codegangsta/cli"
|
||||
)
|
||||
|
||||
func ExampleApp() {
|
||||
// set args for examples sake
|
||||
os.Args = []string{"greet", "--name", "Jeremy"}
|
||||
|
||||
app := cli.NewApp()
|
||||
app.Name = "greet"
|
||||
app.Flags = []cli.Flag{
|
||||
cli.StringFlag{Name: "name", Value: "bob", Usage: "a name to say"},
|
||||
}
|
||||
app.Action = func(c *cli.Context) {
|
||||
fmt.Printf("Hello %v\n", c.String("name"))
|
||||
}
|
||||
app.Run(os.Args)
|
||||
// Output:
|
||||
// Hello Jeremy
|
||||
}
|
||||
|
||||
func ExampleAppSubcommand() {
|
||||
// set args for examples sake
|
||||
os.Args = []string{"say", "hi", "english", "--name", "Jeremy"}
|
||||
app := cli.NewApp()
|
||||
app.Name = "say"
|
||||
app.Commands = []cli.Command{
|
||||
{
|
||||
Name: "hello",
|
||||
ShortName: "hi",
|
||||
Usage: "use it to see a description",
|
||||
Description: "This is how we describe hello the function",
|
||||
Subcommands: []cli.Command{
|
||||
{
|
||||
Name: "english",
|
||||
ShortName: "en",
|
||||
Usage: "sends a greeting in english",
|
||||
Description: "greets someone in english",
|
||||
Flags: []cli.Flag{
|
||||
cli.StringFlag{
|
||||
Name: "name",
|
||||
Value: "Bob",
|
||||
Usage: "Name of the person to greet",
|
||||
},
|
||||
},
|
||||
Action: func(c *cli.Context) {
|
||||
fmt.Println("Hello,", c.String("name"))
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
app.Run(os.Args)
|
||||
// Output:
|
||||
// Hello, Jeremy
|
||||
}
|
||||
|
||||
func ExampleAppHelp() {
|
||||
// set args for examples sake
|
||||
os.Args = []string{"greet", "h", "describeit"}
|
||||
|
||||
app := cli.NewApp()
|
||||
app.Name = "greet"
|
||||
app.Flags = []cli.Flag{
|
||||
cli.StringFlag{Name: "name", Value: "bob", Usage: "a name to say"},
|
||||
}
|
||||
app.Commands = []cli.Command{
|
||||
{
|
||||
Name: "describeit",
|
||||
ShortName: "d",
|
||||
Usage: "use it to see a description",
|
||||
Description: "This is how we describe describeit the function",
|
||||
Action: func(c *cli.Context) {
|
||||
fmt.Printf("i like to describe things")
|
||||
},
|
||||
},
|
||||
}
|
||||
app.Run(os.Args)
|
||||
// Output:
|
||||
// NAME:
|
||||
// describeit - use it to see a description
|
||||
//
|
||||
// USAGE:
|
||||
// command describeit [arguments...]
|
||||
//
|
||||
// DESCRIPTION:
|
||||
// This is how we describe describeit the function
|
||||
}
|
||||
|
||||
func ExampleAppBashComplete() {
|
||||
// set args for examples sake
|
||||
os.Args = []string{"greet", "--generate-bash-completion"}
|
||||
|
||||
app := cli.NewApp()
|
||||
app.Name = "greet"
|
||||
app.EnableBashCompletion = true
|
||||
app.Commands = []cli.Command{
|
||||
{
|
||||
Name: "describeit",
|
||||
ShortName: "d",
|
||||
Usage: "use it to see a description",
|
||||
Description: "This is how we describe describeit the function",
|
||||
Action: func(c *cli.Context) {
|
||||
fmt.Printf("i like to describe things")
|
||||
},
|
||||
}, {
|
||||
Name: "next",
|
||||
Usage: "next example",
|
||||
Description: "more stuff to see when generating bash completion",
|
||||
Action: func(c *cli.Context) {
|
||||
fmt.Printf("the next example")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
app.Run(os.Args)
|
||||
// Output:
|
||||
// describeit
|
||||
// d
|
||||
// next
|
||||
// help
|
||||
// h
|
||||
}
|
||||
|
||||
func TestApp_Run(t *testing.T) {
|
||||
s := ""
|
||||
|
||||
app := cli.NewApp()
|
||||
app.Action = func(c *cli.Context) {
|
||||
s = s + c.Args().First()
|
||||
}
|
||||
|
||||
err := app.Run([]string{"command", "foo"})
|
||||
expect(t, err, nil)
|
||||
err = app.Run([]string{"command", "bar"})
|
||||
expect(t, err, nil)
|
||||
expect(t, s, "foobar")
|
||||
}
|
||||
|
||||
var commandAppTests = []struct {
|
||||
name string
|
||||
expected bool
|
||||
}{
|
||||
{"foobar", true},
|
||||
{"batbaz", true},
|
||||
{"b", true},
|
||||
{"f", true},
|
||||
{"bat", false},
|
||||
{"nothing", false},
|
||||
}
|
||||
|
||||
func TestApp_Command(t *testing.T) {
|
||||
app := cli.NewApp()
|
||||
fooCommand := cli.Command{Name: "foobar", ShortName: "f"}
|
||||
batCommand := cli.Command{Name: "batbaz", ShortName: "b"}
|
||||
app.Commands = []cli.Command{
|
||||
fooCommand,
|
||||
batCommand,
|
||||
}
|
||||
|
||||
for _, test := range commandAppTests {
|
||||
expect(t, app.Command(test.name) != nil, test.expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApp_CommandWithArgBeforeFlags(t *testing.T) {
|
||||
var parsedOption, firstArg string
|
||||
|
||||
app := cli.NewApp()
|
||||
command := cli.Command{
|
||||
Name: "cmd",
|
||||
Flags: []cli.Flag{
|
||||
cli.StringFlag{Name: "option", Value: "", Usage: "some option"},
|
||||
},
|
||||
Action: func(c *cli.Context) {
|
||||
parsedOption = c.String("option")
|
||||
firstArg = c.Args().First()
|
||||
},
|
||||
}
|
||||
app.Commands = []cli.Command{command}
|
||||
|
||||
app.Run([]string{"", "cmd", "my-arg", "--option", "my-option"})
|
||||
|
||||
expect(t, parsedOption, "my-option")
|
||||
expect(t, firstArg, "my-arg")
|
||||
}
|
||||
|
||||
func TestApp_Float64Flag(t *testing.T) {
|
||||
var meters float64
|
||||
|
||||
app := cli.NewApp()
|
||||
app.Flags = []cli.Flag{
|
||||
cli.Float64Flag{Name: "height", Value: 1.5, Usage: "Set the height, in meters"},
|
||||
}
|
||||
app.Action = func(c *cli.Context) {
|
||||
meters = c.Float64("height")
|
||||
}
|
||||
|
||||
app.Run([]string{"", "--height", "1.93"})
|
||||
expect(t, meters, 1.93)
|
||||
}
|
||||
|
||||
func TestApp_ParseSliceFlags(t *testing.T) {
|
||||
var parsedOption, firstArg string
|
||||
var parsedIntSlice []int
|
||||
var parsedStringSlice []string
|
||||
|
||||
app := cli.NewApp()
|
||||
command := cli.Command{
|
||||
Name: "cmd",
|
||||
Flags: []cli.Flag{
|
||||
cli.IntSliceFlag{Name: "p", Value: &cli.IntSlice{}, Usage: "set one or more ip addr"},
|
||||
cli.StringSliceFlag{Name: "ip", Value: &cli.StringSlice{}, Usage: "set one or more ports to open"},
|
||||
},
|
||||
Action: func(c *cli.Context) {
|
||||
parsedIntSlice = c.IntSlice("p")
|
||||
parsedStringSlice = c.StringSlice("ip")
|
||||
parsedOption = c.String("option")
|
||||
firstArg = c.Args().First()
|
||||
},
|
||||
}
|
||||
app.Commands = []cli.Command{command}
|
||||
|
||||
app.Run([]string{"", "cmd", "my-arg", "-p", "22", "-p", "80", "-ip", "8.8.8.8", "-ip", "8.8.4.4"})
|
||||
|
||||
IntsEquals := func(a, b []int) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i, v := range a {
|
||||
if v != b[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
StrsEquals := func(a, b []string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i, v := range a {
|
||||
if v != b[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
var expectedIntSlice = []int{22, 80}
|
||||
var expectedStringSlice = []string{"8.8.8.8", "8.8.4.4"}
|
||||
|
||||
if !IntsEquals(parsedIntSlice, expectedIntSlice) {
|
||||
t.Errorf("%v does not match %v", parsedIntSlice, expectedIntSlice)
|
||||
}
|
||||
|
||||
if !StrsEquals(parsedStringSlice, expectedStringSlice) {
|
||||
t.Errorf("%v does not match %v", parsedStringSlice, expectedStringSlice)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApp_BeforeFunc(t *testing.T) {
|
||||
beforeRun, subcommandRun := false, false
|
||||
beforeError := fmt.Errorf("fail")
|
||||
var err error
|
||||
|
||||
app := cli.NewApp()
|
||||
|
||||
app.Before = func(c *cli.Context) error {
|
||||
beforeRun = true
|
||||
s := c.String("opt")
|
||||
if s == "fail" {
|
||||
return beforeError
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
app.Commands = []cli.Command{
|
||||
cli.Command{
|
||||
Name: "sub",
|
||||
Action: func(c *cli.Context) {
|
||||
subcommandRun = true
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
app.Flags = []cli.Flag{
|
||||
cli.StringFlag{Name: "opt"},
|
||||
}
|
||||
|
||||
// run with the Before() func succeeding
|
||||
err = app.Run([]string{"command", "--opt", "succeed", "sub"})
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Run error: %s", err)
|
||||
}
|
||||
|
||||
if beforeRun == false {
|
||||
t.Errorf("Before() not executed when expected")
|
||||
}
|
||||
|
||||
if subcommandRun == false {
|
||||
t.Errorf("Subcommand not executed when expected")
|
||||
}
|
||||
|
||||
// reset
|
||||
beforeRun, subcommandRun = false, false
|
||||
|
||||
// run with the Before() func failing
|
||||
err = app.Run([]string{"command", "--opt", "fail", "sub"})
|
||||
|
||||
// should be the same error produced by the Before func
|
||||
if err != beforeError {
|
||||
t.Errorf("Run error expected, but not received")
|
||||
}
|
||||
|
||||
if beforeRun == false {
|
||||
t.Errorf("Before() not executed when expected")
|
||||
}
|
||||
|
||||
if subcommandRun == true {
|
||||
t.Errorf("Subcommand executed when NOT expected")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestAppHelpPrinter(t *testing.T) {
|
||||
oldPrinter := cli.HelpPrinter
|
||||
defer func() {
|
||||
cli.HelpPrinter = oldPrinter
|
||||
}()
|
||||
|
||||
var wasCalled = false
|
||||
cli.HelpPrinter = func(template string, data interface{}) {
|
||||
wasCalled = true
|
||||
}
|
||||
|
||||
app := cli.NewApp()
|
||||
app.Run([]string{"-h"})
|
||||
|
||||
if wasCalled == false {
|
||||
t.Errorf("Help printer expected to be called, but was not")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppVersionPrinter(t *testing.T) {
|
||||
oldPrinter := cli.VersionPrinter
|
||||
defer func() {
|
||||
cli.VersionPrinter = oldPrinter
|
||||
}()
|
||||
|
||||
var wasCalled = false
|
||||
cli.VersionPrinter = func(c *cli.Context) {
|
||||
wasCalled = true
|
||||
}
|
||||
|
||||
app := cli.NewApp()
|
||||
ctx := cli.NewContext(app, nil, nil)
|
||||
cli.ShowVersion(ctx)
|
||||
|
||||
if wasCalled == false {
|
||||
t.Errorf("Version printer expected to be called, but was not")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppCommandNotFound(t *testing.T) {
|
||||
beforeRun, subcommandRun := false, false
|
||||
app := cli.NewApp()
|
||||
|
||||
app.CommandNotFound = func(c *cli.Context, command string) {
|
||||
beforeRun = true
|
||||
}
|
||||
|
||||
app.Commands = []cli.Command{
|
||||
cli.Command{
|
||||
Name: "bar",
|
||||
Action: func(c *cli.Context) {
|
||||
subcommandRun = true
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
app.Run([]string{"command", "foo"})
|
||||
|
||||
expect(t, beforeRun, true)
|
||||
expect(t, subcommandRun, false)
|
||||
}
|
||||
|
||||
func TestGlobalFlagsInSubcommands(t *testing.T) {
|
||||
subcommandRun := false
|
||||
app := cli.NewApp()
|
||||
|
||||
app.Flags = []cli.Flag{
|
||||
cli.BoolFlag{Name: "debug, d", Usage: "Enable debugging"},
|
||||
}
|
||||
|
||||
app.Commands = []cli.Command{
|
||||
cli.Command{
|
||||
Name: "foo",
|
||||
Subcommands: []cli.Command{
|
||||
{
|
||||
Name: "bar",
|
||||
Action: func(c *cli.Context) {
|
||||
if c.GlobalBool("debug") {
|
||||
subcommandRun = true
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
app.Run([]string{"command", "-d", "foo", "bar"})
|
||||
|
||||
expect(t, subcommandRun, true)
|
||||
}
|
7
Godeps/_workspace/src/github.com/codegangsta/cli/autocomplete/bash_autocomplete
generated
vendored
7
Godeps/_workspace/src/github.com/codegangsta/cli/autocomplete/bash_autocomplete
generated
vendored
|
@ -1,13 +1,14 @@
|
|||
#! /bin/bash
|
||||
|
||||
: ${PROG:=$(basename ${BASH_SOURCE})}
|
||||
|
||||
_cli_bash_autocomplete() {
|
||||
local cur prev opts base
|
||||
local cur opts base
|
||||
COMPREPLY=()
|
||||
cur="${COMP_WORDS[COMP_CWORD]}"
|
||||
prev="${COMP_WORDS[COMP_CWORD-1]}"
|
||||
opts=$( ${COMP_WORDS[@]:0:$COMP_CWORD} --generate-bash-completion )
|
||||
COMPREPLY=( $(compgen -W "${opts}" -- ${cur}) )
|
||||
return 0
|
||||
}
|
||||
|
||||
complete -F _cli_bash_autocomplete $PROG
|
||||
complete -F _cli_bash_autocomplete $PROG
|
||||
|
|
|
@ -17,3 +17,24 @@
|
|||
// app.Run(os.Args)
|
||||
// }
|
||||
package cli
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
type MultiError struct {
|
||||
Errors []error
|
||||
}
|
||||
|
||||
func NewMultiError(err ...error) MultiError {
|
||||
return MultiError{Errors: err}
|
||||
}
|
||||
|
||||
func (m MultiError) Error() string {
|
||||
errs := make([]string, len(m.Errors))
|
||||
for i, err := range m.Errors {
|
||||
errs[i] = err.Error()
|
||||
}
|
||||
|
||||
return strings.Join(errs, "\n")
|
||||
}
|
||||
|
|
|
@ -1,100 +0,0 @@
|
|||
package cli_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/coreos/etcd/Godeps/_workspace/src/github.com/codegangsta/cli"
|
||||
)
|
||||
|
||||
func Example() {
|
||||
app := cli.NewApp()
|
||||
app.Name = "todo"
|
||||
app.Usage = "task list on the command line"
|
||||
app.Commands = []cli.Command{
|
||||
{
|
||||
Name: "add",
|
||||
ShortName: "a",
|
||||
Usage: "add a task to the list",
|
||||
Action: func(c *cli.Context) {
|
||||
println("added task: ", c.Args().First())
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "complete",
|
||||
ShortName: "c",
|
||||
Usage: "complete a task on the list",
|
||||
Action: func(c *cli.Context) {
|
||||
println("completed task: ", c.Args().First())
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
app.Run(os.Args)
|
||||
}
|
||||
|
||||
func ExampleSubcommand() {
|
||||
app := cli.NewApp()
|
||||
app.Name = "say"
|
||||
app.Commands = []cli.Command{
|
||||
{
|
||||
Name: "hello",
|
||||
ShortName: "hi",
|
||||
Usage: "use it to see a description",
|
||||
Description: "This is how we describe hello the function",
|
||||
Subcommands: []cli.Command{
|
||||
{
|
||||
Name: "english",
|
||||
ShortName: "en",
|
||||
Usage: "sends a greeting in english",
|
||||
Description: "greets someone in english",
|
||||
Flags: []cli.Flag{
|
||||
cli.StringFlag{
|
||||
Name: "name",
|
||||
Value: "Bob",
|
||||
Usage: "Name of the person to greet",
|
||||
},
|
||||
},
|
||||
Action: func(c *cli.Context) {
|
||||
println("Hello, ", c.String("name"))
|
||||
},
|
||||
}, {
|
||||
Name: "spanish",
|
||||
ShortName: "sp",
|
||||
Usage: "sends a greeting in spanish",
|
||||
Flags: []cli.Flag{
|
||||
cli.StringFlag{
|
||||
Name: "surname",
|
||||
Value: "Jones",
|
||||
Usage: "Surname of the person to greet",
|
||||
},
|
||||
},
|
||||
Action: func(c *cli.Context) {
|
||||
println("Hola, ", c.String("surname"))
|
||||
},
|
||||
}, {
|
||||
Name: "french",
|
||||
ShortName: "fr",
|
||||
Usage: "sends a greeting in french",
|
||||
Flags: []cli.Flag{
|
||||
cli.StringFlag{
|
||||
Name: "nickname",
|
||||
Value: "Stevie",
|
||||
Usage: "Nickname of the person to greet",
|
||||
},
|
||||
},
|
||||
Action: func(c *cli.Context) {
|
||||
println("Bonjour, ", c.String("nickname"))
|
||||
},
|
||||
},
|
||||
},
|
||||
}, {
|
||||
Name: "bye",
|
||||
Usage: "says goodbye",
|
||||
Action: func(c *cli.Context) {
|
||||
println("bye")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
app.Run(os.Args)
|
||||
}
|
|
@ -10,17 +10,24 @@ import (
|
|||
type Command struct {
|
||||
// The name of the command
|
||||
Name string
|
||||
// short name of the command. Typically one character
|
||||
// short name of the command. Typically one character (deprecated, use `Aliases`)
|
||||
ShortName string
|
||||
// A list of aliases for the command
|
||||
Aliases []string
|
||||
// A short description of the usage of this command
|
||||
Usage string
|
||||
// A longer explanation of how the command works
|
||||
Description string
|
||||
// A short description of the arguments of this command
|
||||
ArgsUsage string
|
||||
// The function to call when checking for bash command completions
|
||||
BashComplete func(context *Context)
|
||||
// An action to execute before any sub-subcommands are run, but after the context is ready
|
||||
// If a non-nil error is returned, no sub-subcommands are run
|
||||
Before func(context *Context) error
|
||||
// An action to execute after any subcommands are run, but after the subcommand has finished
|
||||
// It is run even if Action() panics
|
||||
After func(context *Context) error
|
||||
// The function to call when this command is invoked
|
||||
Action func(context *Context)
|
||||
// List of child commands
|
||||
|
@ -31,16 +38,28 @@ type Command struct {
|
|||
SkipFlagParsing bool
|
||||
// Boolean to hide built-in help command
|
||||
HideHelp bool
|
||||
|
||||
// Full name of command for help, defaults to full command name, including parent commands.
|
||||
HelpName string
|
||||
commandNamePath []string
|
||||
}
|
||||
|
||||
// Returns the full name of the command.
|
||||
// For subcommands this ensures that parent commands are part of the command path
|
||||
func (c Command) FullName() string {
|
||||
if c.commandNamePath == nil {
|
||||
return c.Name
|
||||
}
|
||||
return strings.Join(c.commandNamePath, " ")
|
||||
}
|
||||
|
||||
// Invokes the command given the context, parses ctx.Args() to generate command-specific flags
|
||||
func (c Command) Run(ctx *Context) error {
|
||||
|
||||
if len(c.Subcommands) > 0 || c.Before != nil {
|
||||
if len(c.Subcommands) > 0 || c.Before != nil || c.After != nil {
|
||||
return c.startApp(ctx)
|
||||
}
|
||||
|
||||
if !c.HideHelp {
|
||||
if !c.HideHelp && (HelpFlag != BoolFlag{}) {
|
||||
// append help to flags
|
||||
c.Flags = append(
|
||||
c.Flags,
|
||||
|
@ -55,40 +74,57 @@ func (c Command) Run(ctx *Context) error {
|
|||
set := flagSet(c.Name, c.Flags)
|
||||
set.SetOutput(ioutil.Discard)
|
||||
|
||||
firstFlagIndex := -1
|
||||
for index, arg := range ctx.Args() {
|
||||
if strings.HasPrefix(arg, "-") {
|
||||
firstFlagIndex = index
|
||||
break
|
||||
var err error
|
||||
if !c.SkipFlagParsing {
|
||||
firstFlagIndex := -1
|
||||
terminatorIndex := -1
|
||||
for index, arg := range ctx.Args() {
|
||||
if arg == "--" {
|
||||
terminatorIndex = index
|
||||
break
|
||||
} else if strings.HasPrefix(arg, "-") && firstFlagIndex == -1 {
|
||||
firstFlagIndex = index
|
||||
}
|
||||
}
|
||||
|
||||
if firstFlagIndex > -1 {
|
||||
args := ctx.Args()
|
||||
regularArgs := make([]string, len(args[1:firstFlagIndex]))
|
||||
copy(regularArgs, args[1:firstFlagIndex])
|
||||
|
||||
var flagArgs []string
|
||||
if terminatorIndex > -1 {
|
||||
flagArgs = args[firstFlagIndex:terminatorIndex]
|
||||
regularArgs = append(regularArgs, args[terminatorIndex:]...)
|
||||
} else {
|
||||
flagArgs = args[firstFlagIndex:]
|
||||
}
|
||||
|
||||
err = set.Parse(append(flagArgs, regularArgs...))
|
||||
} else {
|
||||
err = set.Parse(ctx.Args().Tail())
|
||||
}
|
||||
} else {
|
||||
if c.SkipFlagParsing {
|
||||
err = set.Parse(append([]string{"--"}, ctx.Args().Tail()...))
|
||||
}
|
||||
}
|
||||
|
||||
var err error
|
||||
if firstFlagIndex > -1 && !c.SkipFlagParsing {
|
||||
args := ctx.Args()
|
||||
regularArgs := args[1:firstFlagIndex]
|
||||
flagArgs := args[firstFlagIndex:]
|
||||
err = set.Parse(append(flagArgs, regularArgs...))
|
||||
} else {
|
||||
err = set.Parse(ctx.Args().Tail())
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
fmt.Printf("Incorrect Usage.\n\n")
|
||||
fmt.Fprintln(ctx.App.Writer, "Incorrect Usage.")
|
||||
fmt.Fprintln(ctx.App.Writer)
|
||||
ShowCommandHelp(ctx, c.Name)
|
||||
fmt.Println("")
|
||||
return err
|
||||
}
|
||||
|
||||
nerr := normalizeFlags(c.Flags, set)
|
||||
if nerr != nil {
|
||||
fmt.Println(nerr)
|
||||
fmt.Println("")
|
||||
fmt.Fprintln(ctx.App.Writer, nerr)
|
||||
fmt.Fprintln(ctx.App.Writer)
|
||||
ShowCommandHelp(ctx, c.Name)
|
||||
fmt.Println("")
|
||||
return nerr
|
||||
}
|
||||
context := NewContext(ctx.App, set, ctx.globalSet)
|
||||
context := NewContext(ctx.App, set, ctx)
|
||||
|
||||
if checkCommandCompletions(context, c.Name) {
|
||||
return nil
|
||||
|
@ -102,9 +138,24 @@ func (c Command) Run(ctx *Context) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (c Command) Names() []string {
|
||||
names := []string{c.Name}
|
||||
|
||||
if c.ShortName != "" {
|
||||
names = append(names, c.ShortName)
|
||||
}
|
||||
|
||||
return append(names, c.Aliases...)
|
||||
}
|
||||
|
||||
// Returns true if Command.Name or Command.ShortName matches given name
|
||||
func (c Command) HasName(name string) bool {
|
||||
return c.Name == name || c.ShortName == name
|
||||
for _, n := range c.Names() {
|
||||
if n == name {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c Command) startApp(ctx *Context) error {
|
||||
|
@ -112,6 +163,12 @@ func (c Command) startApp(ctx *Context) error {
|
|||
|
||||
// set the name and usage
|
||||
app.Name = fmt.Sprintf("%s %s", ctx.App.Name, c.Name)
|
||||
if c.HelpName == "" {
|
||||
app.HelpName = c.HelpName
|
||||
} else {
|
||||
app.HelpName = fmt.Sprintf("%s %s", ctx.App.Name, c.Name)
|
||||
}
|
||||
|
||||
if c.Description != "" {
|
||||
app.Usage = c.Description
|
||||
} else {
|
||||
|
@ -126,6 +183,13 @@ func (c Command) startApp(ctx *Context) error {
|
|||
app.Flags = c.Flags
|
||||
app.HideHelp = c.HideHelp
|
||||
|
||||
app.Version = ctx.App.Version
|
||||
app.HideVersion = ctx.App.HideVersion
|
||||
app.Compiled = ctx.App.Compiled
|
||||
app.Author = ctx.App.Author
|
||||
app.Email = ctx.App.Email
|
||||
app.Writer = ctx.App.Writer
|
||||
|
||||
// bash completion
|
||||
app.EnableBashCompletion = ctx.App.EnableBashCompletion
|
||||
if c.BashComplete != nil {
|
||||
|
@ -134,11 +198,19 @@ func (c Command) startApp(ctx *Context) error {
|
|||
|
||||
// set the actions
|
||||
app.Before = c.Before
|
||||
app.After = c.After
|
||||
if c.Action != nil {
|
||||
app.Action = c.Action
|
||||
} else {
|
||||
app.Action = helpSubcommand.Action
|
||||
}
|
||||
|
||||
var newCmds []Command
|
||||
for _, cc := range app.Commands {
|
||||
cc.commandNamePath = []string{c.Name, cc.Name}
|
||||
newCmds = append(newCmds, cc)
|
||||
}
|
||||
app.Commands = newCmds
|
||||
|
||||
return app.RunAsSubcommand(ctx)
|
||||
}
|
||||
|
|
|
@ -1,49 +0,0 @@
|
|||
package cli_test
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"testing"
|
||||
|
||||
"github.com/coreos/etcd/Godeps/_workspace/src/github.com/codegangsta/cli"
|
||||
)
|
||||
|
||||
func TestCommandDoNotIgnoreFlags(t *testing.T) {
|
||||
app := cli.NewApp()
|
||||
set := flag.NewFlagSet("test", 0)
|
||||
test := []string{"blah", "blah", "-break"}
|
||||
set.Parse(test)
|
||||
|
||||
c := cli.NewContext(app, set, set)
|
||||
|
||||
command := cli.Command{
|
||||
Name: "test-cmd",
|
||||
ShortName: "tc",
|
||||
Usage: "this is for testing",
|
||||
Description: "testing",
|
||||
Action: func(_ *cli.Context) {},
|
||||
}
|
||||
err := command.Run(c)
|
||||
|
||||
expect(t, err.Error(), "flag provided but not defined: -break")
|
||||
}
|
||||
|
||||
func TestCommandIgnoreFlags(t *testing.T) {
|
||||
app := cli.NewApp()
|
||||
set := flag.NewFlagSet("test", 0)
|
||||
test := []string{"blah", "blah"}
|
||||
set.Parse(test)
|
||||
|
||||
c := cli.NewContext(app, set, set)
|
||||
|
||||
command := cli.Command{
|
||||
Name: "test-cmd",
|
||||
ShortName: "tc",
|
||||
Usage: "this is for testing",
|
||||
Description: "testing",
|
||||
Action: func(_ *cli.Context) {},
|
||||
SkipFlagParsing: true,
|
||||
}
|
||||
err := command.Run(c)
|
||||
|
||||
expect(t, err, nil)
|
||||
}
|
|
@ -13,16 +13,17 @@ import (
|
|||
// can be used to retrieve context-specific Args and
|
||||
// parsed command-line options.
|
||||
type Context struct {
|
||||
App *App
|
||||
Command Command
|
||||
flagSet *flag.FlagSet
|
||||
globalSet *flag.FlagSet
|
||||
setFlags map[string]bool
|
||||
App *App
|
||||
Command Command
|
||||
flagSet *flag.FlagSet
|
||||
setFlags map[string]bool
|
||||
globalSetFlags map[string]bool
|
||||
parentContext *Context
|
||||
}
|
||||
|
||||
// Creates a new context. For use in when invoking an App or Command action.
|
||||
func NewContext(app *App, set *flag.FlagSet, globalSet *flag.FlagSet) *Context {
|
||||
return &Context{App: app, flagSet: set, globalSet: globalSet}
|
||||
func NewContext(app *App, set *flag.FlagSet, parentCtx *Context) *Context {
|
||||
return &Context{App: app, flagSet: set, parentContext: parentCtx}
|
||||
}
|
||||
|
||||
// Looks up the value of a local int flag, returns 0 if no int flag exists
|
||||
|
@ -72,40 +73,66 @@ func (c *Context) Generic(name string) interface{} {
|
|||
|
||||
// Looks up the value of a global int flag, returns 0 if no int flag exists
|
||||
func (c *Context) GlobalInt(name string) int {
|
||||
return lookupInt(name, c.globalSet)
|
||||
if fs := lookupGlobalFlagSet(name, c); fs != nil {
|
||||
return lookupInt(name, fs)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// Looks up the value of a global time.Duration flag, returns 0 if no time.Duration flag exists
|
||||
func (c *Context) GlobalDuration(name string) time.Duration {
|
||||
return lookupDuration(name, c.globalSet)
|
||||
if fs := lookupGlobalFlagSet(name, c); fs != nil {
|
||||
return lookupDuration(name, fs)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// Looks up the value of a global bool flag, returns false if no bool flag exists
|
||||
func (c *Context) GlobalBool(name string) bool {
|
||||
return lookupBool(name, c.globalSet)
|
||||
if fs := lookupGlobalFlagSet(name, c); fs != nil {
|
||||
return lookupBool(name, fs)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Looks up the value of a global string flag, returns "" if no string flag exists
|
||||
func (c *Context) GlobalString(name string) string {
|
||||
return lookupString(name, c.globalSet)
|
||||
if fs := lookupGlobalFlagSet(name, c); fs != nil {
|
||||
return lookupString(name, fs)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// Looks up the value of a global string slice flag, returns nil if no string slice flag exists
|
||||
func (c *Context) GlobalStringSlice(name string) []string {
|
||||
return lookupStringSlice(name, c.globalSet)
|
||||
if fs := lookupGlobalFlagSet(name, c); fs != nil {
|
||||
return lookupStringSlice(name, fs)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Looks up the value of a global int slice flag, returns nil if no int slice flag exists
|
||||
func (c *Context) GlobalIntSlice(name string) []int {
|
||||
return lookupIntSlice(name, c.globalSet)
|
||||
if fs := lookupGlobalFlagSet(name, c); fs != nil {
|
||||
return lookupIntSlice(name, fs)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Looks up the value of a global generic flag, returns nil if no generic flag exists
|
||||
func (c *Context) GlobalGeneric(name string) interface{} {
|
||||
return lookupGeneric(name, c.globalSet)
|
||||
if fs := lookupGlobalFlagSet(name, c); fs != nil {
|
||||
return lookupGeneric(name, fs)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Determines if the flag was actually set exists
|
||||
// Returns the number of flags set
|
||||
func (c *Context) NumFlags() int {
|
||||
return c.flagSet.NFlag()
|
||||
}
|
||||
|
||||
// Determines if the flag was actually set
|
||||
func (c *Context) IsSet(name string) bool {
|
||||
if c.setFlags == nil {
|
||||
c.setFlags = make(map[string]bool)
|
||||
|
@ -116,6 +143,23 @@ func (c *Context) IsSet(name string) bool {
|
|||
return c.setFlags[name] == true
|
||||
}
|
||||
|
||||
// Determines if the global flag was actually set
|
||||
func (c *Context) GlobalIsSet(name string) bool {
|
||||
if c.globalSetFlags == nil {
|
||||
c.globalSetFlags = make(map[string]bool)
|
||||
ctx := c
|
||||
if ctx.parentContext != nil {
|
||||
ctx = ctx.parentContext
|
||||
}
|
||||
for ; ctx != nil && c.globalSetFlags[name] == false; ctx = ctx.parentContext {
|
||||
ctx.flagSet.Visit(func(f *flag.Flag) {
|
||||
c.globalSetFlags[f.Name] = true
|
||||
})
|
||||
}
|
||||
}
|
||||
return c.globalSetFlags[name]
|
||||
}
|
||||
|
||||
// Returns a slice of flag names used in this context.
|
||||
func (c *Context) FlagNames() (names []string) {
|
||||
for _, flag := range c.Command.Flags {
|
||||
|
@ -128,6 +172,23 @@ func (c *Context) FlagNames() (names []string) {
|
|||
return
|
||||
}
|
||||
|
||||
// Returns a slice of global flag names used by the app.
|
||||
func (c *Context) GlobalFlagNames() (names []string) {
|
||||
for _, flag := range c.App.Flags {
|
||||
name := strings.Split(flag.getName(), ",")[0]
|
||||
if name == "help" || name == "version" {
|
||||
continue
|
||||
}
|
||||
names = append(names, name)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Returns the parent context, if any
|
||||
func (c *Context) Parent() *Context {
|
||||
return c.parentContext
|
||||
}
|
||||
|
||||
type Args []string
|
||||
|
||||
// Returns the command line arguments associated with the context.
|
||||
|
@ -172,6 +233,18 @@ func (a Args) Swap(from, to int) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func lookupGlobalFlagSet(name string, ctx *Context) *flag.FlagSet {
|
||||
if ctx.parentContext != nil {
|
||||
ctx = ctx.parentContext
|
||||
}
|
||||
for ; ctx != nil; ctx = ctx.parentContext {
|
||||
if f := ctx.flagSet.Lookup(name); f != nil {
|
||||
return ctx.flagSet
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func lookupInt(name string, set *flag.FlagSet) int {
|
||||
f := set.Lookup(name)
|
||||
if f != nil {
|
||||
|
|
|
@ -1,77 +0,0 @@
|
|||
package cli_test
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/etcd/Godeps/_workspace/src/github.com/codegangsta/cli"
|
||||
)
|
||||
|
||||
func TestNewContext(t *testing.T) {
|
||||
set := flag.NewFlagSet("test", 0)
|
||||
set.Int("myflag", 12, "doc")
|
||||
globalSet := flag.NewFlagSet("test", 0)
|
||||
globalSet.Int("myflag", 42, "doc")
|
||||
command := cli.Command{Name: "mycommand"}
|
||||
c := cli.NewContext(nil, set, globalSet)
|
||||
c.Command = command
|
||||
expect(t, c.Int("myflag"), 12)
|
||||
expect(t, c.GlobalInt("myflag"), 42)
|
||||
expect(t, c.Command.Name, "mycommand")
|
||||
}
|
||||
|
||||
func TestContext_Int(t *testing.T) {
|
||||
set := flag.NewFlagSet("test", 0)
|
||||
set.Int("myflag", 12, "doc")
|
||||
c := cli.NewContext(nil, set, set)
|
||||
expect(t, c.Int("myflag"), 12)
|
||||
}
|
||||
|
||||
func TestContext_Duration(t *testing.T) {
|
||||
set := flag.NewFlagSet("test", 0)
|
||||
set.Duration("myflag", time.Duration(12*time.Second), "doc")
|
||||
c := cli.NewContext(nil, set, set)
|
||||
expect(t, c.Duration("myflag"), time.Duration(12*time.Second))
|
||||
}
|
||||
|
||||
func TestContext_String(t *testing.T) {
|
||||
set := flag.NewFlagSet("test", 0)
|
||||
set.String("myflag", "hello world", "doc")
|
||||
c := cli.NewContext(nil, set, set)
|
||||
expect(t, c.String("myflag"), "hello world")
|
||||
}
|
||||
|
||||
func TestContext_Bool(t *testing.T) {
|
||||
set := flag.NewFlagSet("test", 0)
|
||||
set.Bool("myflag", false, "doc")
|
||||
c := cli.NewContext(nil, set, set)
|
||||
expect(t, c.Bool("myflag"), false)
|
||||
}
|
||||
|
||||
func TestContext_BoolT(t *testing.T) {
|
||||
set := flag.NewFlagSet("test", 0)
|
||||
set.Bool("myflag", true, "doc")
|
||||
c := cli.NewContext(nil, set, set)
|
||||
expect(t, c.BoolT("myflag"), true)
|
||||
}
|
||||
|
||||
func TestContext_Args(t *testing.T) {
|
||||
set := flag.NewFlagSet("test", 0)
|
||||
set.Bool("myflag", false, "doc")
|
||||
c := cli.NewContext(nil, set, set)
|
||||
set.Parse([]string{"--myflag", "bat", "baz"})
|
||||
expect(t, len(c.Args()), 2)
|
||||
expect(t, c.Bool("myflag"), true)
|
||||
}
|
||||
|
||||
func TestContext_IsSet(t *testing.T) {
|
||||
set := flag.NewFlagSet("test", 0)
|
||||
set.Bool("myflag", false, "doc")
|
||||
set.String("otherflag", "hello world", "doc")
|
||||
c := cli.NewContext(nil, set, set)
|
||||
set.Parse([]string{"--myflag", "bat", "baz"})
|
||||
expect(t, c.IsSet("myflag"), true)
|
||||
expect(t, c.IsSet("otherflag"), false)
|
||||
expect(t, c.IsSet("bogusflag"), false)
|
||||
}
|
|
@ -21,6 +21,8 @@ var VersionFlag = BoolFlag{
|
|||
}
|
||||
|
||||
// This flag prints the help for all commands and subcommands
|
||||
// Set to the zero value (BoolFlag{}) to disable flag -- keeps subcommand
|
||||
// unless HideHelp is set to true)
|
||||
var HelpFlag = BoolFlag{
|
||||
Name: "help, h",
|
||||
Usage: "show help",
|
||||
|
@ -67,15 +69,24 @@ type GenericFlag struct {
|
|||
EnvVar string
|
||||
}
|
||||
|
||||
// String returns the string representation of the generic flag to display the
|
||||
// help text to the user (uses the String() method of the generic flag to show
|
||||
// the value)
|
||||
func (f GenericFlag) String() string {
|
||||
return withEnvHint(f.EnvVar, fmt.Sprintf("%s%s %v\t`%v` %s", prefixFor(f.Name), f.Name, f.Value, "-"+f.Name+" option -"+f.Name+" option", f.Usage))
|
||||
return withEnvHint(f.EnvVar, fmt.Sprintf("%s%s \"%v\"\t%v", prefixFor(f.Name), f.Name, f.Value, f.Usage))
|
||||
}
|
||||
|
||||
// Apply takes the flagset and calls Set on the generic flag with the value
|
||||
// provided by the user for parsing by the flag
|
||||
func (f GenericFlag) Apply(set *flag.FlagSet) {
|
||||
val := f.Value
|
||||
if f.EnvVar != "" {
|
||||
if envVal := os.Getenv(f.EnvVar); envVal != "" {
|
||||
val.Set(envVal)
|
||||
for _, envVar := range strings.Split(f.EnvVar, ",") {
|
||||
envVar = strings.TrimSpace(envVar)
|
||||
if envVal := os.Getenv(envVar); envVal != "" {
|
||||
val.Set(envVal)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -88,21 +99,27 @@ func (f GenericFlag) getName() string {
|
|||
return f.Name
|
||||
}
|
||||
|
||||
// StringSlice is an opaque type for []string to satisfy flag.Value
|
||||
type StringSlice []string
|
||||
|
||||
// Set appends the string value to the list of values
|
||||
func (f *StringSlice) Set(value string) error {
|
||||
*f = append(*f, value)
|
||||
return nil
|
||||
}
|
||||
|
||||
// String returns a readable representation of this value (for usage defaults)
|
||||
func (f *StringSlice) String() string {
|
||||
return fmt.Sprintf("%s", *f)
|
||||
}
|
||||
|
||||
// Value returns the slice of strings set by this flag
|
||||
func (f *StringSlice) Value() []string {
|
||||
return *f
|
||||
}
|
||||
|
||||
// StringSlice is a string flag that can be specified multiple times on the
|
||||
// command-line
|
||||
type StringSliceFlag struct {
|
||||
Name string
|
||||
Value *StringSlice
|
||||
|
@ -110,24 +127,34 @@ type StringSliceFlag struct {
|
|||
EnvVar string
|
||||
}
|
||||
|
||||
// String returns the usage
|
||||
func (f StringSliceFlag) String() string {
|
||||
firstName := strings.Trim(strings.Split(f.Name, ",")[0], " ")
|
||||
pref := prefixFor(firstName)
|
||||
return withEnvHint(f.EnvVar, fmt.Sprintf("%s '%v'\t%v", prefixedNames(f.Name), pref+firstName+" option "+pref+firstName+" option", f.Usage))
|
||||
return withEnvHint(f.EnvVar, fmt.Sprintf("%s [%v]\t%v", prefixedNames(f.Name), pref+firstName+" option "+pref+firstName+" option", f.Usage))
|
||||
}
|
||||
|
||||
// Apply populates the flag given the flag set and environment
|
||||
func (f StringSliceFlag) Apply(set *flag.FlagSet) {
|
||||
if f.EnvVar != "" {
|
||||
if envVal := os.Getenv(f.EnvVar); envVal != "" {
|
||||
newVal := &StringSlice{}
|
||||
for _, s := range strings.Split(envVal, ",") {
|
||||
newVal.Set(s)
|
||||
for _, envVar := range strings.Split(f.EnvVar, ",") {
|
||||
envVar = strings.TrimSpace(envVar)
|
||||
if envVal := os.Getenv(envVar); envVal != "" {
|
||||
newVal := &StringSlice{}
|
||||
for _, s := range strings.Split(envVal, ",") {
|
||||
s = strings.TrimSpace(s)
|
||||
newVal.Set(s)
|
||||
}
|
||||
f.Value = newVal
|
||||
break
|
||||
}
|
||||
f.Value = newVal
|
||||
}
|
||||
}
|
||||
|
||||
eachName(f.Name, func(name string) {
|
||||
if f.Value == nil {
|
||||
f.Value = &StringSlice{}
|
||||
}
|
||||
set.Var(f.Value, name, f.Usage)
|
||||
})
|
||||
}
|
||||
|
@ -136,10 +163,11 @@ func (f StringSliceFlag) getName() string {
|
|||
return f.Name
|
||||
}
|
||||
|
||||
// StringSlice is an opaque type for []int to satisfy flag.Value
|
||||
type IntSlice []int
|
||||
|
||||
// Set parses the value into an integer and appends it to the list of values
|
||||
func (f *IntSlice) Set(value string) error {
|
||||
|
||||
tmp, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -149,14 +177,18 @@ func (f *IntSlice) Set(value string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// String returns a readable representation of this value (for usage defaults)
|
||||
func (f *IntSlice) String() string {
|
||||
return fmt.Sprintf("%d", *f)
|
||||
}
|
||||
|
||||
// Value returns the slice of ints set by this flag
|
||||
func (f *IntSlice) Value() []int {
|
||||
return *f
|
||||
}
|
||||
|
||||
// IntSliceFlag is an int flag that can be specified multiple times on the
|
||||
// command-line
|
||||
type IntSliceFlag struct {
|
||||
Name string
|
||||
Value *IntSlice
|
||||
|
@ -164,27 +196,37 @@ type IntSliceFlag struct {
|
|||
EnvVar string
|
||||
}
|
||||
|
||||
// String returns the usage
|
||||
func (f IntSliceFlag) String() string {
|
||||
firstName := strings.Trim(strings.Split(f.Name, ",")[0], " ")
|
||||
pref := prefixFor(firstName)
|
||||
return withEnvHint(f.EnvVar, fmt.Sprintf("%s '%v'\t%v", prefixedNames(f.Name), pref+firstName+" option "+pref+firstName+" option", f.Usage))
|
||||
return withEnvHint(f.EnvVar, fmt.Sprintf("%s [%v]\t%v", prefixedNames(f.Name), pref+firstName+" option "+pref+firstName+" option", f.Usage))
|
||||
}
|
||||
|
||||
// Apply populates the flag given the flag set and environment
|
||||
func (f IntSliceFlag) Apply(set *flag.FlagSet) {
|
||||
if f.EnvVar != "" {
|
||||
if envVal := os.Getenv(f.EnvVar); envVal != "" {
|
||||
newVal := &IntSlice{}
|
||||
for _, s := range strings.Split(envVal, ",") {
|
||||
err := newVal.Set(s)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, err.Error())
|
||||
for _, envVar := range strings.Split(f.EnvVar, ",") {
|
||||
envVar = strings.TrimSpace(envVar)
|
||||
if envVal := os.Getenv(envVar); envVal != "" {
|
||||
newVal := &IntSlice{}
|
||||
for _, s := range strings.Split(envVal, ",") {
|
||||
s = strings.TrimSpace(s)
|
||||
err := newVal.Set(s)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, err.Error())
|
||||
}
|
||||
}
|
||||
f.Value = newVal
|
||||
break
|
||||
}
|
||||
f.Value = newVal
|
||||
}
|
||||
}
|
||||
|
||||
eachName(f.Name, func(name string) {
|
||||
if f.Value == nil {
|
||||
f.Value = &IntSlice{}
|
||||
}
|
||||
set.Var(f.Value, name, f.Usage)
|
||||
})
|
||||
}
|
||||
|
@ -193,28 +235,40 @@ func (f IntSliceFlag) getName() string {
|
|||
return f.Name
|
||||
}
|
||||
|
||||
// BoolFlag is a switch that defaults to false
|
||||
type BoolFlag struct {
|
||||
Name string
|
||||
Usage string
|
||||
EnvVar string
|
||||
Name string
|
||||
Usage string
|
||||
EnvVar string
|
||||
Destination *bool
|
||||
}
|
||||
|
||||
// String returns a readable representation of this value (for usage defaults)
|
||||
func (f BoolFlag) String() string {
|
||||
return withEnvHint(f.EnvVar, fmt.Sprintf("%s\t%v", prefixedNames(f.Name), f.Usage))
|
||||
}
|
||||
|
||||
// Apply populates the flag given the flag set and environment
|
||||
func (f BoolFlag) Apply(set *flag.FlagSet) {
|
||||
val := false
|
||||
if f.EnvVar != "" {
|
||||
if envVal := os.Getenv(f.EnvVar); envVal != "" {
|
||||
envValBool, err := strconv.ParseBool(envVal)
|
||||
if err == nil {
|
||||
val = envValBool
|
||||
for _, envVar := range strings.Split(f.EnvVar, ",") {
|
||||
envVar = strings.TrimSpace(envVar)
|
||||
if envVal := os.Getenv(envVar); envVal != "" {
|
||||
envValBool, err := strconv.ParseBool(envVal)
|
||||
if err == nil {
|
||||
val = envValBool
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
eachName(f.Name, func(name string) {
|
||||
if f.Destination != nil {
|
||||
set.BoolVar(f.Destination, name, val, f.Usage)
|
||||
return
|
||||
}
|
||||
set.Bool(name, val, f.Usage)
|
||||
})
|
||||
}
|
||||
|
@ -223,28 +277,41 @@ func (f BoolFlag) getName() string {
|
|||
return f.Name
|
||||
}
|
||||
|
||||
// BoolTFlag this represents a boolean flag that is true by default, but can
|
||||
// still be set to false by --some-flag=false
|
||||
type BoolTFlag struct {
|
||||
Name string
|
||||
Usage string
|
||||
EnvVar string
|
||||
Name string
|
||||
Usage string
|
||||
EnvVar string
|
||||
Destination *bool
|
||||
}
|
||||
|
||||
// String returns a readable representation of this value (for usage defaults)
|
||||
func (f BoolTFlag) String() string {
|
||||
return withEnvHint(f.EnvVar, fmt.Sprintf("%s\t%v", prefixedNames(f.Name), f.Usage))
|
||||
}
|
||||
|
||||
// Apply populates the flag given the flag set and environment
|
||||
func (f BoolTFlag) Apply(set *flag.FlagSet) {
|
||||
val := true
|
||||
if f.EnvVar != "" {
|
||||
if envVal := os.Getenv(f.EnvVar); envVal != "" {
|
||||
envValBool, err := strconv.ParseBool(envVal)
|
||||
if err == nil {
|
||||
val = envValBool
|
||||
for _, envVar := range strings.Split(f.EnvVar, ",") {
|
||||
envVar = strings.TrimSpace(envVar)
|
||||
if envVal := os.Getenv(envVar); envVal != "" {
|
||||
envValBool, err := strconv.ParseBool(envVal)
|
||||
if err == nil {
|
||||
val = envValBool
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
eachName(f.Name, func(name string) {
|
||||
if f.Destination != nil {
|
||||
set.BoolVar(f.Destination, name, val, f.Usage)
|
||||
return
|
||||
}
|
||||
set.Bool(name, val, f.Usage)
|
||||
})
|
||||
}
|
||||
|
@ -253,19 +320,22 @@ func (f BoolTFlag) getName() string {
|
|||
return f.Name
|
||||
}
|
||||
|
||||
// StringFlag represents a flag that takes as string value
|
||||
type StringFlag struct {
|
||||
Name string
|
||||
Value string
|
||||
Usage string
|
||||
EnvVar string
|
||||
Name string
|
||||
Value string
|
||||
Usage string
|
||||
EnvVar string
|
||||
Destination *string
|
||||
}
|
||||
|
||||
// String returns the usage
|
||||
func (f StringFlag) String() string {
|
||||
var fmtString string
|
||||
fmtString = "%s %v\t%v"
|
||||
|
||||
if len(f.Value) > 0 {
|
||||
fmtString = "%s '%v'\t%v"
|
||||
fmtString = "%s \"%v\"\t%v"
|
||||
} else {
|
||||
fmtString = "%s %v\t%v"
|
||||
}
|
||||
|
@ -273,14 +343,23 @@ func (f StringFlag) String() string {
|
|||
return withEnvHint(f.EnvVar, fmt.Sprintf(fmtString, prefixedNames(f.Name), f.Value, f.Usage))
|
||||
}
|
||||
|
||||
// Apply populates the flag given the flag set and environment
|
||||
func (f StringFlag) Apply(set *flag.FlagSet) {
|
||||
if f.EnvVar != "" {
|
||||
if envVal := os.Getenv(f.EnvVar); envVal != "" {
|
||||
f.Value = envVal
|
||||
for _, envVar := range strings.Split(f.EnvVar, ",") {
|
||||
envVar = strings.TrimSpace(envVar)
|
||||
if envVal := os.Getenv(envVar); envVal != "" {
|
||||
f.Value = envVal
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
eachName(f.Name, func(name string) {
|
||||
if f.Destination != nil {
|
||||
set.StringVar(f.Destination, name, f.Value, f.Usage)
|
||||
return
|
||||
}
|
||||
set.String(name, f.Value, f.Usage)
|
||||
})
|
||||
}
|
||||
|
@ -289,28 +368,41 @@ func (f StringFlag) getName() string {
|
|||
return f.Name
|
||||
}
|
||||
|
||||
// IntFlag is a flag that takes an integer
|
||||
// Errors if the value provided cannot be parsed
|
||||
type IntFlag struct {
|
||||
Name string
|
||||
Value int
|
||||
Usage string
|
||||
EnvVar string
|
||||
Name string
|
||||
Value int
|
||||
Usage string
|
||||
EnvVar string
|
||||
Destination *int
|
||||
}
|
||||
|
||||
// String returns the usage
|
||||
func (f IntFlag) String() string {
|
||||
return withEnvHint(f.EnvVar, fmt.Sprintf("%s '%v'\t%v", prefixedNames(f.Name), f.Value, f.Usage))
|
||||
return withEnvHint(f.EnvVar, fmt.Sprintf("%s \"%v\"\t%v", prefixedNames(f.Name), f.Value, f.Usage))
|
||||
}
|
||||
|
||||
// Apply populates the flag given the flag set and environment
|
||||
func (f IntFlag) Apply(set *flag.FlagSet) {
|
||||
if f.EnvVar != "" {
|
||||
if envVal := os.Getenv(f.EnvVar); envVal != "" {
|
||||
envValInt, err := strconv.ParseUint(envVal, 10, 64)
|
||||
if err == nil {
|
||||
f.Value = int(envValInt)
|
||||
for _, envVar := range strings.Split(f.EnvVar, ",") {
|
||||
envVar = strings.TrimSpace(envVar)
|
||||
if envVal := os.Getenv(envVar); envVal != "" {
|
||||
envValInt, err := strconv.ParseInt(envVal, 0, 64)
|
||||
if err == nil {
|
||||
f.Value = int(envValInt)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
eachName(f.Name, func(name string) {
|
||||
if f.Destination != nil {
|
||||
set.IntVar(f.Destination, name, f.Value, f.Usage)
|
||||
return
|
||||
}
|
||||
set.Int(name, f.Value, f.Usage)
|
||||
})
|
||||
}
|
||||
|
@ -319,28 +411,41 @@ func (f IntFlag) getName() string {
|
|||
return f.Name
|
||||
}
|
||||
|
||||
// DurationFlag is a flag that takes a duration specified in Go's duration
|
||||
// format: https://golang.org/pkg/time/#ParseDuration
|
||||
type DurationFlag struct {
|
||||
Name string
|
||||
Value time.Duration
|
||||
Usage string
|
||||
EnvVar string
|
||||
Name string
|
||||
Value time.Duration
|
||||
Usage string
|
||||
EnvVar string
|
||||
Destination *time.Duration
|
||||
}
|
||||
|
||||
// String returns a readable representation of this value (for usage defaults)
|
||||
func (f DurationFlag) String() string {
|
||||
return withEnvHint(f.EnvVar, fmt.Sprintf("%s '%v'\t%v", prefixedNames(f.Name), f.Value, f.Usage))
|
||||
return withEnvHint(f.EnvVar, fmt.Sprintf("%s \"%v\"\t%v", prefixedNames(f.Name), f.Value, f.Usage))
|
||||
}
|
||||
|
||||
// Apply populates the flag given the flag set and environment
|
||||
func (f DurationFlag) Apply(set *flag.FlagSet) {
|
||||
if f.EnvVar != "" {
|
||||
if envVal := os.Getenv(f.EnvVar); envVal != "" {
|
||||
envValDuration, err := time.ParseDuration(envVal)
|
||||
if err == nil {
|
||||
f.Value = envValDuration
|
||||
for _, envVar := range strings.Split(f.EnvVar, ",") {
|
||||
envVar = strings.TrimSpace(envVar)
|
||||
if envVal := os.Getenv(envVar); envVal != "" {
|
||||
envValDuration, err := time.ParseDuration(envVal)
|
||||
if err == nil {
|
||||
f.Value = envValDuration
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
eachName(f.Name, func(name string) {
|
||||
if f.Destination != nil {
|
||||
set.DurationVar(f.Destination, name, f.Value, f.Usage)
|
||||
return
|
||||
}
|
||||
set.Duration(name, f.Value, f.Usage)
|
||||
})
|
||||
}
|
||||
|
@ -349,28 +454,40 @@ func (f DurationFlag) getName() string {
|
|||
return f.Name
|
||||
}
|
||||
|
||||
// Float64Flag is a flag that takes an float value
|
||||
// Errors if the value provided cannot be parsed
|
||||
type Float64Flag struct {
|
||||
Name string
|
||||
Value float64
|
||||
Usage string
|
||||
EnvVar string
|
||||
Name string
|
||||
Value float64
|
||||
Usage string
|
||||
EnvVar string
|
||||
Destination *float64
|
||||
}
|
||||
|
||||
// String returns the usage
|
||||
func (f Float64Flag) String() string {
|
||||
return withEnvHint(f.EnvVar, fmt.Sprintf("%s '%v'\t%v", prefixedNames(f.Name), f.Value, f.Usage))
|
||||
return withEnvHint(f.EnvVar, fmt.Sprintf("%s \"%v\"\t%v", prefixedNames(f.Name), f.Value, f.Usage))
|
||||
}
|
||||
|
||||
// Apply populates the flag given the flag set and environment
|
||||
func (f Float64Flag) Apply(set *flag.FlagSet) {
|
||||
if f.EnvVar != "" {
|
||||
if envVal := os.Getenv(f.EnvVar); envVal != "" {
|
||||
envValFloat, err := strconv.ParseFloat(envVal, 10)
|
||||
if err == nil {
|
||||
f.Value = float64(envValFloat)
|
||||
for _, envVar := range strings.Split(f.EnvVar, ",") {
|
||||
envVar = strings.TrimSpace(envVar)
|
||||
if envVal := os.Getenv(envVar); envVal != "" {
|
||||
envValFloat, err := strconv.ParseFloat(envVal, 10)
|
||||
if err == nil {
|
||||
f.Value = float64(envValFloat)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
eachName(f.Name, func(name string) {
|
||||
if f.Destination != nil {
|
||||
set.Float64Var(f.Destination, name, f.Value, f.Usage)
|
||||
return
|
||||
}
|
||||
set.Float64(name, f.Value, f.Usage)
|
||||
})
|
||||
}
|
||||
|
@ -404,7 +521,7 @@ func prefixedNames(fullName string) (prefixed string) {
|
|||
func withEnvHint(envVar, str string) string {
|
||||
envText := ""
|
||||
if envVar != "" {
|
||||
envText = fmt.Sprintf(" [$%s]", envVar)
|
||||
envText = fmt.Sprintf(" [$%s]", strings.Join(strings.Split(envVar, ","), ", $"))
|
||||
}
|
||||
return str + envText
|
||||
}
|
||||
|
|
|
@ -1,587 +0,0 @@
|
|||
package cli_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/coreos/etcd/Godeps/_workspace/src/github.com/codegangsta/cli"
|
||||
)
|
||||
|
||||
var boolFlagTests = []struct {
|
||||
name string
|
||||
expected string
|
||||
}{
|
||||
{"help", "--help\t"},
|
||||
{"h", "-h\t"},
|
||||
}
|
||||
|
||||
func TestBoolFlagHelpOutput(t *testing.T) {
|
||||
|
||||
for _, test := range boolFlagTests {
|
||||
flag := cli.BoolFlag{Name: test.name}
|
||||
output := flag.String()
|
||||
|
||||
if output != test.expected {
|
||||
t.Errorf("%s does not match %s", output, test.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var stringFlagTests = []struct {
|
||||
name string
|
||||
value string
|
||||
expected string
|
||||
}{
|
||||
{"help", "", "--help \t"},
|
||||
{"h", "", "-h \t"},
|
||||
{"h", "", "-h \t"},
|
||||
{"test", "Something", "--test 'Something'\t"},
|
||||
}
|
||||
|
||||
func TestStringFlagHelpOutput(t *testing.T) {
|
||||
|
||||
for _, test := range stringFlagTests {
|
||||
flag := cli.StringFlag{Name: test.name, Value: test.value}
|
||||
output := flag.String()
|
||||
|
||||
if output != test.expected {
|
||||
t.Errorf("%s does not match %s", output, test.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestStringFlagWithEnvVarHelpOutput(t *testing.T) {
|
||||
|
||||
os.Setenv("APP_FOO", "derp")
|
||||
for _, test := range stringFlagTests {
|
||||
flag := cli.StringFlag{Name: test.name, Value: test.value, EnvVar: "APP_FOO"}
|
||||
output := flag.String()
|
||||
|
||||
if !strings.HasSuffix(output, " [$APP_FOO]") {
|
||||
t.Errorf("%s does not end with [$APP_FOO]", output)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var stringSliceFlagTests = []struct {
|
||||
name string
|
||||
value *cli.StringSlice
|
||||
expected string
|
||||
}{
|
||||
{"help", func() *cli.StringSlice {
|
||||
s := &cli.StringSlice{}
|
||||
s.Set("")
|
||||
return s
|
||||
}(), "--help '--help option --help option'\t"},
|
||||
{"h", func() *cli.StringSlice {
|
||||
s := &cli.StringSlice{}
|
||||
s.Set("")
|
||||
return s
|
||||
}(), "-h '-h option -h option'\t"},
|
||||
{"h", func() *cli.StringSlice {
|
||||
s := &cli.StringSlice{}
|
||||
s.Set("")
|
||||
return s
|
||||
}(), "-h '-h option -h option'\t"},
|
||||
{"test", func() *cli.StringSlice {
|
||||
s := &cli.StringSlice{}
|
||||
s.Set("Something")
|
||||
return s
|
||||
}(), "--test '--test option --test option'\t"},
|
||||
}
|
||||
|
||||
func TestStringSliceFlagHelpOutput(t *testing.T) {
|
||||
|
||||
for _, test := range stringSliceFlagTests {
|
||||
flag := cli.StringSliceFlag{Name: test.name, Value: test.value}
|
||||
output := flag.String()
|
||||
|
||||
if output != test.expected {
|
||||
t.Errorf("%q does not match %q", output, test.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestStringSliceFlagWithEnvVarHelpOutput(t *testing.T) {
|
||||
|
||||
os.Setenv("APP_QWWX", "11,4")
|
||||
for _, test := range stringSliceFlagTests {
|
||||
flag := cli.StringSliceFlag{Name: test.name, Value: test.value, EnvVar: "APP_QWWX"}
|
||||
output := flag.String()
|
||||
|
||||
if !strings.HasSuffix(output, " [$APP_QWWX]") {
|
||||
t.Errorf("%q does not end with [$APP_QWWX]", output)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var intFlagTests = []struct {
|
||||
name string
|
||||
expected string
|
||||
}{
|
||||
{"help", "--help '0'\t"},
|
||||
{"h", "-h '0'\t"},
|
||||
}
|
||||
|
||||
func TestIntFlagHelpOutput(t *testing.T) {
|
||||
|
||||
for _, test := range intFlagTests {
|
||||
flag := cli.IntFlag{Name: test.name}
|
||||
output := flag.String()
|
||||
|
||||
if output != test.expected {
|
||||
t.Errorf("%s does not match %s", output, test.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntFlagWithEnvVarHelpOutput(t *testing.T) {
|
||||
|
||||
os.Setenv("APP_BAR", "2")
|
||||
for _, test := range intFlagTests {
|
||||
flag := cli.IntFlag{Name: test.name, EnvVar: "APP_BAR"}
|
||||
output := flag.String()
|
||||
|
||||
if !strings.HasSuffix(output, " [$APP_BAR]") {
|
||||
t.Errorf("%s does not end with [$APP_BAR]", output)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var durationFlagTests = []struct {
|
||||
name string
|
||||
expected string
|
||||
}{
|
||||
{"help", "--help '0'\t"},
|
||||
{"h", "-h '0'\t"},
|
||||
}
|
||||
|
||||
func TestDurationFlagHelpOutput(t *testing.T) {
|
||||
|
||||
for _, test := range durationFlagTests {
|
||||
flag := cli.DurationFlag{Name: test.name}
|
||||
output := flag.String()
|
||||
|
||||
if output != test.expected {
|
||||
t.Errorf("%s does not match %s", output, test.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDurationFlagWithEnvVarHelpOutput(t *testing.T) {
|
||||
|
||||
os.Setenv("APP_BAR", "2h3m6s")
|
||||
for _, test := range durationFlagTests {
|
||||
flag := cli.DurationFlag{Name: test.name, EnvVar: "APP_BAR"}
|
||||
output := flag.String()
|
||||
|
||||
if !strings.HasSuffix(output, " [$APP_BAR]") {
|
||||
t.Errorf("%s does not end with [$APP_BAR]", output)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var intSliceFlagTests = []struct {
|
||||
name string
|
||||
value *cli.IntSlice
|
||||
expected string
|
||||
}{
|
||||
{"help", &cli.IntSlice{}, "--help '--help option --help option'\t"},
|
||||
{"h", &cli.IntSlice{}, "-h '-h option -h option'\t"},
|
||||
{"h", &cli.IntSlice{}, "-h '-h option -h option'\t"},
|
||||
{"test", func() *cli.IntSlice {
|
||||
i := &cli.IntSlice{}
|
||||
i.Set("9")
|
||||
return i
|
||||
}(), "--test '--test option --test option'\t"},
|
||||
}
|
||||
|
||||
func TestIntSliceFlagHelpOutput(t *testing.T) {
|
||||
|
||||
for _, test := range intSliceFlagTests {
|
||||
flag := cli.IntSliceFlag{Name: test.name, Value: test.value}
|
||||
output := flag.String()
|
||||
|
||||
if output != test.expected {
|
||||
t.Errorf("%q does not match %q", output, test.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntSliceFlagWithEnvVarHelpOutput(t *testing.T) {
|
||||
|
||||
os.Setenv("APP_SMURF", "42,3")
|
||||
for _, test := range intSliceFlagTests {
|
||||
flag := cli.IntSliceFlag{Name: test.name, Value: test.value, EnvVar: "APP_SMURF"}
|
||||
output := flag.String()
|
||||
|
||||
if !strings.HasSuffix(output, " [$APP_SMURF]") {
|
||||
t.Errorf("%q does not end with [$APP_SMURF]", output)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var float64FlagTests = []struct {
|
||||
name string
|
||||
expected string
|
||||
}{
|
||||
{"help", "--help '0'\t"},
|
||||
{"h", "-h '0'\t"},
|
||||
}
|
||||
|
||||
func TestFloat64FlagHelpOutput(t *testing.T) {
|
||||
|
||||
for _, test := range float64FlagTests {
|
||||
flag := cli.Float64Flag{Name: test.name}
|
||||
output := flag.String()
|
||||
|
||||
if output != test.expected {
|
||||
t.Errorf("%s does not match %s", output, test.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFloat64FlagWithEnvVarHelpOutput(t *testing.T) {
|
||||
|
||||
os.Setenv("APP_BAZ", "99.4")
|
||||
for _, test := range float64FlagTests {
|
||||
flag := cli.Float64Flag{Name: test.name, EnvVar: "APP_BAZ"}
|
||||
output := flag.String()
|
||||
|
||||
if !strings.HasSuffix(output, " [$APP_BAZ]") {
|
||||
t.Errorf("%s does not end with [$APP_BAZ]", output)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var genericFlagTests = []struct {
|
||||
name string
|
||||
value cli.Generic
|
||||
expected string
|
||||
}{
|
||||
{"help", &Parser{}, "--help <nil>\t`-help option -help option` "},
|
||||
{"h", &Parser{}, "-h <nil>\t`-h option -h option` "},
|
||||
{"test", &Parser{}, "--test <nil>\t`-test option -test option` "},
|
||||
}
|
||||
|
||||
func TestGenericFlagHelpOutput(t *testing.T) {
|
||||
|
||||
for _, test := range genericFlagTests {
|
||||
flag := cli.GenericFlag{Name: test.name}
|
||||
output := flag.String()
|
||||
|
||||
if output != test.expected {
|
||||
t.Errorf("%q does not match %q", output, test.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenericFlagWithEnvVarHelpOutput(t *testing.T) {
|
||||
|
||||
os.Setenv("APP_ZAP", "3")
|
||||
for _, test := range genericFlagTests {
|
||||
flag := cli.GenericFlag{Name: test.name, EnvVar: "APP_ZAP"}
|
||||
output := flag.String()
|
||||
|
||||
if !strings.HasSuffix(output, " [$APP_ZAP]") {
|
||||
t.Errorf("%s does not end with [$APP_ZAP]", output)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseMultiString(t *testing.T) {
|
||||
(&cli.App{
|
||||
Flags: []cli.Flag{
|
||||
cli.StringFlag{Name: "serve, s"},
|
||||
},
|
||||
Action: func(ctx *cli.Context) {
|
||||
if ctx.String("serve") != "10" {
|
||||
t.Errorf("main name not set")
|
||||
}
|
||||
if ctx.String("s") != "10" {
|
||||
t.Errorf("short name not set")
|
||||
}
|
||||
},
|
||||
}).Run([]string{"run", "-s", "10"})
|
||||
}
|
||||
|
||||
func TestParseMultiStringFromEnv(t *testing.T) {
|
||||
os.Setenv("APP_COUNT", "20")
|
||||
(&cli.App{
|
||||
Flags: []cli.Flag{
|
||||
cli.StringFlag{Name: "count, c", EnvVar: "APP_COUNT"},
|
||||
},
|
||||
Action: func(ctx *cli.Context) {
|
||||
if ctx.String("count") != "20" {
|
||||
t.Errorf("main name not set")
|
||||
}
|
||||
if ctx.String("c") != "20" {
|
||||
t.Errorf("short name not set")
|
||||
}
|
||||
},
|
||||
}).Run([]string{"run"})
|
||||
}
|
||||
|
||||
func TestParseMultiStringSlice(t *testing.T) {
|
||||
(&cli.App{
|
||||
Flags: []cli.Flag{
|
||||
cli.StringSliceFlag{Name: "serve, s", Value: &cli.StringSlice{}},
|
||||
},
|
||||
Action: func(ctx *cli.Context) {
|
||||
if !reflect.DeepEqual(ctx.StringSlice("serve"), []string{"10", "20"}) {
|
||||
t.Errorf("main name not set")
|
||||
}
|
||||
if !reflect.DeepEqual(ctx.StringSlice("s"), []string{"10", "20"}) {
|
||||
t.Errorf("short name not set")
|
||||
}
|
||||
},
|
||||
}).Run([]string{"run", "-s", "10", "-s", "20"})
|
||||
}
|
||||
|
||||
func TestParseMultiStringSliceFromEnv(t *testing.T) {
|
||||
os.Setenv("APP_INTERVALS", "20,30,40")
|
||||
|
||||
(&cli.App{
|
||||
Flags: []cli.Flag{
|
||||
cli.StringSliceFlag{Name: "intervals, i", Value: &cli.StringSlice{}, EnvVar: "APP_INTERVALS"},
|
||||
},
|
||||
Action: func(ctx *cli.Context) {
|
||||
if !reflect.DeepEqual(ctx.StringSlice("intervals"), []string{"20", "30", "40"}) {
|
||||
t.Errorf("main name not set from env")
|
||||
}
|
||||
if !reflect.DeepEqual(ctx.StringSlice("i"), []string{"20", "30", "40"}) {
|
||||
t.Errorf("short name not set from env")
|
||||
}
|
||||
},
|
||||
}).Run([]string{"run"})
|
||||
}
|
||||
|
||||
func TestParseMultiInt(t *testing.T) {
|
||||
a := cli.App{
|
||||
Flags: []cli.Flag{
|
||||
cli.IntFlag{Name: "serve, s"},
|
||||
},
|
||||
Action: func(ctx *cli.Context) {
|
||||
if ctx.Int("serve") != 10 {
|
||||
t.Errorf("main name not set")
|
||||
}
|
||||
if ctx.Int("s") != 10 {
|
||||
t.Errorf("short name not set")
|
||||
}
|
||||
},
|
||||
}
|
||||
a.Run([]string{"run", "-s", "10"})
|
||||
}
|
||||
|
||||
func TestParseMultiIntFromEnv(t *testing.T) {
|
||||
os.Setenv("APP_TIMEOUT_SECONDS", "10")
|
||||
a := cli.App{
|
||||
Flags: []cli.Flag{
|
||||
cli.IntFlag{Name: "timeout, t", EnvVar: "APP_TIMEOUT_SECONDS"},
|
||||
},
|
||||
Action: func(ctx *cli.Context) {
|
||||
if ctx.Int("timeout") != 10 {
|
||||
t.Errorf("main name not set")
|
||||
}
|
||||
if ctx.Int("t") != 10 {
|
||||
t.Errorf("short name not set")
|
||||
}
|
||||
},
|
||||
}
|
||||
a.Run([]string{"run"})
|
||||
}
|
||||
|
||||
func TestParseMultiIntSlice(t *testing.T) {
|
||||
(&cli.App{
|
||||
Flags: []cli.Flag{
|
||||
cli.IntSliceFlag{Name: "serve, s", Value: &cli.IntSlice{}},
|
||||
},
|
||||
Action: func(ctx *cli.Context) {
|
||||
if !reflect.DeepEqual(ctx.IntSlice("serve"), []int{10, 20}) {
|
||||
t.Errorf("main name not set")
|
||||
}
|
||||
if !reflect.DeepEqual(ctx.IntSlice("s"), []int{10, 20}) {
|
||||
t.Errorf("short name not set")
|
||||
}
|
||||
},
|
||||
}).Run([]string{"run", "-s", "10", "-s", "20"})
|
||||
}
|
||||
|
||||
func TestParseMultiIntSliceFromEnv(t *testing.T) {
|
||||
os.Setenv("APP_INTERVALS", "20,30,40")
|
||||
|
||||
(&cli.App{
|
||||
Flags: []cli.Flag{
|
||||
cli.IntSliceFlag{Name: "intervals, i", Value: &cli.IntSlice{}, EnvVar: "APP_INTERVALS"},
|
||||
},
|
||||
Action: func(ctx *cli.Context) {
|
||||
if !reflect.DeepEqual(ctx.IntSlice("intervals"), []int{20, 30, 40}) {
|
||||
t.Errorf("main name not set from env")
|
||||
}
|
||||
if !reflect.DeepEqual(ctx.IntSlice("i"), []int{20, 30, 40}) {
|
||||
t.Errorf("short name not set from env")
|
||||
}
|
||||
},
|
||||
}).Run([]string{"run"})
|
||||
}
|
||||
|
||||
func TestParseMultiFloat64(t *testing.T) {
|
||||
a := cli.App{
|
||||
Flags: []cli.Flag{
|
||||
cli.Float64Flag{Name: "serve, s"},
|
||||
},
|
||||
Action: func(ctx *cli.Context) {
|
||||
if ctx.Float64("serve") != 10.2 {
|
||||
t.Errorf("main name not set")
|
||||
}
|
||||
if ctx.Float64("s") != 10.2 {
|
||||
t.Errorf("short name not set")
|
||||
}
|
||||
},
|
||||
}
|
||||
a.Run([]string{"run", "-s", "10.2"})
|
||||
}
|
||||
|
||||
func TestParseMultiFloat64FromEnv(t *testing.T) {
|
||||
os.Setenv("APP_TIMEOUT_SECONDS", "15.5")
|
||||
a := cli.App{
|
||||
Flags: []cli.Flag{
|
||||
cli.Float64Flag{Name: "timeout, t", EnvVar: "APP_TIMEOUT_SECONDS"},
|
||||
},
|
||||
Action: func(ctx *cli.Context) {
|
||||
if ctx.Float64("timeout") != 15.5 {
|
||||
t.Errorf("main name not set")
|
||||
}
|
||||
if ctx.Float64("t") != 15.5 {
|
||||
t.Errorf("short name not set")
|
||||
}
|
||||
},
|
||||
}
|
||||
a.Run([]string{"run"})
|
||||
}
|
||||
|
||||
func TestParseMultiBool(t *testing.T) {
|
||||
a := cli.App{
|
||||
Flags: []cli.Flag{
|
||||
cli.BoolFlag{Name: "serve, s"},
|
||||
},
|
||||
Action: func(ctx *cli.Context) {
|
||||
if ctx.Bool("serve") != true {
|
||||
t.Errorf("main name not set")
|
||||
}
|
||||
if ctx.Bool("s") != true {
|
||||
t.Errorf("short name not set")
|
||||
}
|
||||
},
|
||||
}
|
||||
a.Run([]string{"run", "--serve"})
|
||||
}
|
||||
|
||||
func TestParseMultiBoolFromEnv(t *testing.T) {
|
||||
os.Setenv("APP_DEBUG", "1")
|
||||
a := cli.App{
|
||||
Flags: []cli.Flag{
|
||||
cli.BoolFlag{Name: "debug, d", EnvVar: "APP_DEBUG"},
|
||||
},
|
||||
Action: func(ctx *cli.Context) {
|
||||
if ctx.Bool("debug") != true {
|
||||
t.Errorf("main name not set from env")
|
||||
}
|
||||
if ctx.Bool("d") != true {
|
||||
t.Errorf("short name not set from env")
|
||||
}
|
||||
},
|
||||
}
|
||||
a.Run([]string{"run"})
|
||||
}
|
||||
|
||||
func TestParseMultiBoolT(t *testing.T) {
|
||||
a := cli.App{
|
||||
Flags: []cli.Flag{
|
||||
cli.BoolTFlag{Name: "serve, s"},
|
||||
},
|
||||
Action: func(ctx *cli.Context) {
|
||||
if ctx.BoolT("serve") != true {
|
||||
t.Errorf("main name not set")
|
||||
}
|
||||
if ctx.BoolT("s") != true {
|
||||
t.Errorf("short name not set")
|
||||
}
|
||||
},
|
||||
}
|
||||
a.Run([]string{"run", "--serve"})
|
||||
}
|
||||
|
||||
func TestParseMultiBoolTFromEnv(t *testing.T) {
|
||||
os.Setenv("APP_DEBUG", "0")
|
||||
a := cli.App{
|
||||
Flags: []cli.Flag{
|
||||
cli.BoolTFlag{Name: "debug, d", EnvVar: "APP_DEBUG"},
|
||||
},
|
||||
Action: func(ctx *cli.Context) {
|
||||
if ctx.BoolT("debug") != false {
|
||||
t.Errorf("main name not set from env")
|
||||
}
|
||||
if ctx.BoolT("d") != false {
|
||||
t.Errorf("short name not set from env")
|
||||
}
|
||||
},
|
||||
}
|
||||
a.Run([]string{"run"})
|
||||
}
|
||||
|
||||
type Parser [2]string
|
||||
|
||||
func (p *Parser) Set(value string) error {
|
||||
parts := strings.Split(value, ",")
|
||||
if len(parts) != 2 {
|
||||
return fmt.Errorf("invalid format")
|
||||
}
|
||||
|
||||
(*p)[0] = parts[0]
|
||||
(*p)[1] = parts[1]
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Parser) String() string {
|
||||
return fmt.Sprintf("%s,%s", p[0], p[1])
|
||||
}
|
||||
|
||||
func TestParseGeneric(t *testing.T) {
|
||||
a := cli.App{
|
||||
Flags: []cli.Flag{
|
||||
cli.GenericFlag{Name: "serve, s", Value: &Parser{}},
|
||||
},
|
||||
Action: func(ctx *cli.Context) {
|
||||
if !reflect.DeepEqual(ctx.Generic("serve"), &Parser{"10", "20"}) {
|
||||
t.Errorf("main name not set")
|
||||
}
|
||||
if !reflect.DeepEqual(ctx.Generic("s"), &Parser{"10", "20"}) {
|
||||
t.Errorf("short name not set")
|
||||
}
|
||||
},
|
||||
}
|
||||
a.Run([]string{"run", "-s", "10,20"})
|
||||
}
|
||||
|
||||
func TestParseGenericFromEnv(t *testing.T) {
|
||||
os.Setenv("APP_SERVE", "20,30")
|
||||
a := cli.App{
|
||||
Flags: []cli.Flag{
|
||||
cli.GenericFlag{Name: "serve, s", Value: &Parser{}, EnvVar: "APP_SERVE"},
|
||||
},
|
||||
Action: func(ctx *cli.Context) {
|
||||
if !reflect.DeepEqual(ctx.Generic("serve"), &Parser{"20", "30"}) {
|
||||
t.Errorf("main name not set from env")
|
||||
}
|
||||
if !reflect.DeepEqual(ctx.Generic("s"), &Parser{"20", "30"}) {
|
||||
t.Errorf("short name not set from env")
|
||||
}
|
||||
},
|
||||
}
|
||||
a.Run([]string{"run"})
|
||||
}
|
|
@ -2,7 +2,8 @@ package cli
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"io"
|
||||
"strings"
|
||||
"text/tabwriter"
|
||||
"text/template"
|
||||
)
|
||||
|
@ -14,31 +15,33 @@ var AppHelpTemplate = `NAME:
|
|||
{{.Name}} - {{.Usage}}
|
||||
|
||||
USAGE:
|
||||
{{.Name}} {{if .Flags}}[global options] {{end}}command{{if .Flags}} [command options]{{end}} [arguments...]
|
||||
|
||||
{{.HelpName}} {{if .Flags}}[global options]{{end}}{{if .Commands}} command [command options]{{end}} {{if .ArgsUsage}}{{.ArgsUsage}}{{else}}[arguments...]{{end}}
|
||||
{{if .Version}}
|
||||
VERSION:
|
||||
{{.Version}}{{if or .Author .Email}}
|
||||
|
||||
AUTHOR:{{if .Author}}
|
||||
{{.Author}}{{if .Email}} - <{{.Email}}>{{end}}{{else}}
|
||||
{{.Email}}{{end}}{{end}}
|
||||
|
||||
{{.Version}}
|
||||
{{end}}{{if len .Authors}}
|
||||
AUTHOR(S):
|
||||
{{range .Authors}}{{ . }}{{end}}
|
||||
{{end}}{{if .Commands}}
|
||||
COMMANDS:
|
||||
{{range .Commands}}{{.Name}}{{with .ShortName}}, {{.}}{{end}}{{ "\t" }}{{.Usage}}
|
||||
{{end}}{{if .Flags}}
|
||||
{{range .Commands}}{{join .Names ", "}}{{ "\t" }}{{.Usage}}
|
||||
{{end}}{{end}}{{if .Flags}}
|
||||
GLOBAL OPTIONS:
|
||||
{{range .Flags}}{{.}}
|
||||
{{end}}{{end}}
|
||||
{{end}}{{end}}{{if .Copyright }}
|
||||
COPYRIGHT:
|
||||
{{.Copyright}}
|
||||
{{end}}
|
||||
`
|
||||
|
||||
// The text template for the command help topic.
|
||||
// cli.go uses text/template to render templates. You can
|
||||
// render custom help text by setting this variable.
|
||||
var CommandHelpTemplate = `NAME:
|
||||
{{.Name}} - {{.Usage}}
|
||||
{{.HelpName}} - {{.Usage}}
|
||||
|
||||
USAGE:
|
||||
command {{.Name}}{{if .Flags}} [command options]{{end}} [arguments...]{{if .Description}}
|
||||
{{.HelpName}}{{if .Flags}} [command options]{{end}} {{if .ArgsUsage}}{{.ArgsUsage}}{{else}}[arguments...]{{end}}{{if .Description}}
|
||||
|
||||
DESCRIPTION:
|
||||
{{.Description}}{{end}}{{if .Flags}}
|
||||
|
@ -52,13 +55,13 @@ OPTIONS:
|
|||
// cli.go uses text/template to render templates. You can
|
||||
// render custom help text by setting this variable.
|
||||
var SubcommandHelpTemplate = `NAME:
|
||||
{{.Name}} - {{.Usage}}
|
||||
{{.HelpName}} - {{.Usage}}
|
||||
|
||||
USAGE:
|
||||
{{.Name}} command{{if .Flags}} [command options]{{end}} [arguments...]
|
||||
{{.HelpName}} command{{if .Flags}} [command options]{{end}} {{if .ArgsUsage}}{{.ArgsUsage}}{{else}}[arguments...]{{end}}
|
||||
|
||||
COMMANDS:
|
||||
{{range .Commands}}{{.Name}}{{with .ShortName}}, {{.}}{{end}}{{ "\t" }}{{.Usage}}
|
||||
{{range .Commands}}{{join .Names ", "}}{{ "\t" }}{{.Usage}}
|
||||
{{end}}{{if .Flags}}
|
||||
OPTIONS:
|
||||
{{range .Flags}}{{.}}
|
||||
|
@ -67,8 +70,9 @@ OPTIONS:
|
|||
|
||||
var helpCommand = Command{
|
||||
Name: "help",
|
||||
ShortName: "h",
|
||||
Aliases: []string{"h"},
|
||||
Usage: "Shows a list of commands or help for one command",
|
||||
ArgsUsage: "[command]",
|
||||
Action: func(c *Context) {
|
||||
args := c.Args()
|
||||
if args.Present() {
|
||||
|
@ -81,8 +85,9 @@ var helpCommand = Command{
|
|||
|
||||
var helpSubcommand = Command{
|
||||
Name: "help",
|
||||
ShortName: "h",
|
||||
Aliases: []string{"h"},
|
||||
Usage: "Shows a list of commands or help for one command",
|
||||
ArgsUsage: "[command]",
|
||||
Action: func(c *Context) {
|
||||
args := c.Args()
|
||||
if args.Present() {
|
||||
|
@ -93,45 +98,52 @@ var helpSubcommand = Command{
|
|||
},
|
||||
}
|
||||
|
||||
// Prints help for the App
|
||||
var HelpPrinter = printHelp
|
||||
// Prints help for the App or Command
|
||||
type helpPrinter func(w io.Writer, templ string, data interface{})
|
||||
|
||||
var HelpPrinter helpPrinter = printHelp
|
||||
|
||||
// Prints version for the App
|
||||
var VersionPrinter = printVersion
|
||||
|
||||
func ShowAppHelp(c *Context) {
|
||||
HelpPrinter(AppHelpTemplate, c.App)
|
||||
HelpPrinter(c.App.Writer, AppHelpTemplate, c.App)
|
||||
}
|
||||
|
||||
// Prints the list of subcommands as the default app completion method
|
||||
func DefaultAppComplete(c *Context) {
|
||||
for _, command := range c.App.Commands {
|
||||
fmt.Println(command.Name)
|
||||
if command.ShortName != "" {
|
||||
fmt.Println(command.ShortName)
|
||||
for _, name := range command.Names() {
|
||||
fmt.Fprintln(c.App.Writer, name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Prints help for the given command
|
||||
func ShowCommandHelp(c *Context, command string) {
|
||||
for _, c := range c.App.Commands {
|
||||
func ShowCommandHelp(ctx *Context, command string) {
|
||||
// show the subcommand help for a command with subcommands
|
||||
if command == "" {
|
||||
HelpPrinter(ctx.App.Writer, SubcommandHelpTemplate, ctx.App)
|
||||
return
|
||||
}
|
||||
|
||||
for _, c := range ctx.App.Commands {
|
||||
if c.HasName(command) {
|
||||
HelpPrinter(CommandHelpTemplate, c)
|
||||
HelpPrinter(ctx.App.Writer, CommandHelpTemplate, c)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if c.App.CommandNotFound != nil {
|
||||
c.App.CommandNotFound(c, command)
|
||||
if ctx.App.CommandNotFound != nil {
|
||||
ctx.App.CommandNotFound(ctx, command)
|
||||
} else {
|
||||
fmt.Printf("No help topic for '%v'\n", command)
|
||||
fmt.Fprintf(ctx.App.Writer, "No help topic for '%v'\n", command)
|
||||
}
|
||||
}
|
||||
|
||||
// Prints help for the given subcommand
|
||||
func ShowSubcommandHelp(c *Context) {
|
||||
HelpPrinter(SubcommandHelpTemplate, c.App)
|
||||
ShowCommandHelp(c, c.Command.Name)
|
||||
}
|
||||
|
||||
// Prints the version number of the App
|
||||
|
@ -140,7 +152,7 @@ func ShowVersion(c *Context) {
|
|||
}
|
||||
|
||||
func printVersion(c *Context) {
|
||||
fmt.Printf("%v version %v\n", c.App.Name, c.App.Version)
|
||||
fmt.Fprintf(c.App.Writer, "%v version %v\n", c.App.Name, c.App.Version)
|
||||
}
|
||||
|
||||
// Prints the lists of commands within a given context
|
||||
|
@ -159,9 +171,13 @@ func ShowCommandCompletions(ctx *Context, command string) {
|
|||
}
|
||||
}
|
||||
|
||||
func printHelp(templ string, data interface{}) {
|
||||
w := tabwriter.NewWriter(os.Stdout, 0, 8, 1, '\t', 0)
|
||||
t := template.Must(template.New("help").Parse(templ))
|
||||
func printHelp(out io.Writer, templ string, data interface{}) {
|
||||
funcMap := template.FuncMap{
|
||||
"join": strings.Join,
|
||||
}
|
||||
|
||||
w := tabwriter.NewWriter(out, 0, 8, 1, '\t', 0)
|
||||
t := template.Must(template.New("help").Funcs(funcMap).Parse(templ))
|
||||
err := t.Execute(w, data)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
|
@ -170,21 +186,27 @@ func printHelp(templ string, data interface{}) {
|
|||
}
|
||||
|
||||
func checkVersion(c *Context) bool {
|
||||
if c.GlobalBool("version") {
|
||||
ShowVersion(c)
|
||||
return true
|
||||
found := false
|
||||
if VersionFlag.Name != "" {
|
||||
eachName(VersionFlag.Name, func(name string) {
|
||||
if c.GlobalBool(name) || c.Bool(name) {
|
||||
found = true
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
return false
|
||||
return found
|
||||
}
|
||||
|
||||
func checkHelp(c *Context) bool {
|
||||
if c.GlobalBool("h") || c.GlobalBool("help") {
|
||||
ShowAppHelp(c)
|
||||
return true
|
||||
found := false
|
||||
if HelpFlag.Name != "" {
|
||||
eachName(HelpFlag.Name, func(name string) {
|
||||
if c.GlobalBool(name) || c.Bool(name) {
|
||||
found = true
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
return false
|
||||
return found
|
||||
}
|
||||
|
||||
func checkCommandHelp(c *Context, name string) bool {
|
||||
|
@ -206,7 +228,7 @@ func checkSubcommandHelp(c *Context) bool {
|
|||
}
|
||||
|
||||
func checkCompletions(c *Context) bool {
|
||||
if c.GlobalBool(BashCompletionFlag.Name) && c.App.EnableBashCompletion {
|
||||
if (c.GlobalBool(BashCompletionFlag.Name) || c.Bool(BashCompletionFlag.Name)) && c.App.EnableBashCompletion {
|
||||
ShowCompletions(c)
|
||||
return true
|
||||
}
|
||||
|
|
|
@ -1,19 +0,0 @@
|
|||
package cli_test
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
/* Test Helpers */
|
||||
func expect(t *testing.T, a interface{}, b interface{}) {
|
||||
if a != b {
|
||||
t.Errorf("Expected %v (type %v) - Got %v (type %v)", b, reflect.TypeOf(b), a, reflect.TypeOf(a))
|
||||
}
|
||||
}
|
||||
|
||||
func refute(t *testing.T, a interface{}, b interface{}) {
|
||||
if a == b {
|
||||
t.Errorf("Did not expect %v (type %v) - Got %v (type %v)", b, reflect.TypeOf(b), a, reflect.TypeOf(a))
|
||||
}
|
||||
}
|
|
@ -0,0 +1,7 @@
|
|||
Copyright (C) 2014 Thomas Rooney
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
|
@ -0,0 +1,64 @@
|
|||
# Gexpect
|
||||
|
||||
Gexpect is a pure golang expect-like module.
|
||||
|
||||
It makes it simpler to create and control other terminal applications.
|
||||
|
||||
child, err := gexpect.Spawn("python")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
child.Expect(">>>")
|
||||
child.SendLine("print 'Hello World'")
|
||||
child.Interact()
|
||||
child.Close()
|
||||
|
||||
## Examples
|
||||
|
||||
`Spawn` handles the argument parsing from a string
|
||||
|
||||
child.Spawn("/bin/sh -c 'echo \"my complicated command\" | tee log | cat > log2'")
|
||||
child.ReadLine() // ReadLine() (string, error)
|
||||
child.ReadUntil(' ') // ReadUntil(delim byte) ([]byte, error)
|
||||
|
||||
`ReadLine`, `ReadUntil` and `SendLine` send strings from/to `stdout/stdin` respectively
|
||||
|
||||
child := gexpect.Spawn("cat")
|
||||
child.SendLine("echoing process_stdin") // SendLine(command string) (error)
|
||||
msg, _ := child.ReadLine() // msg = echoing process_stdin
|
||||
|
||||
`Wait` and `Close` allow for graceful and ungraceful termination.
|
||||
|
||||
child.Wait() // Waits until the child terminates naturally.
|
||||
child.Close() // Sends a kill command
|
||||
|
||||
`AsyncInteractChannels` spawns two go routines to pipe into and from `stdout`/`stdin`, allowing for some usecases to be a little simpler.
|
||||
|
||||
child := gexpect.spawn("sh")
|
||||
sender, reciever := child.AsyncInteractChannels()
|
||||
sender <- "echo Hello World\n" // Send to stdin
|
||||
line, open := <- reciever // Recieve a line from stdout/stderr
|
||||
// When the subprocess stops (e.g. with child.Close()) , receiver is closed
|
||||
if open {
|
||||
fmt.Printf("Received %s", line)]
|
||||
}
|
||||
|
||||
`ExpectRegex` uses golang's internal regex engine to wait until a match from the process with the given regular expression (or an error on process termination with no match).
|
||||
|
||||
child := gexpect.Spawn("echo accb")
|
||||
match, _ := child.ExpectRegex("a..b")
|
||||
// (match=true)
|
||||
|
||||
`ExpectRegexFind` allows for groups to be extracted from process stdout. The first element is an array of containing the total matched text, followed by each subexpression group match.
|
||||
|
||||
child := gexpect.Spawn("echo 123 456 789")
|
||||
result, _ := child.ExpectRegexFind("\d+ (\d+) (\d+)")
|
||||
// result = []string{"123 456 789", "456", "789"}
|
||||
|
||||
See `gexpect_test.go` and the `examples` folder for full syntax
|
||||
|
||||
## Credits
|
||||
|
||||
github.com/kballard/go-shellquote
|
||||
github.com/kr/pty
|
||||
KMP Algorithm: "http://blog.databigbang.com/searching-for-substrings-in-streams-a-slight-modification-of-the-knuth-morris-pratt-algorithm-in-haxe/"
|
|
@ -0,0 +1,27 @@
|
|||
package main
|
||||
|
||||
import gexpect "github.com/coreos/etcd/Godeps/_workspace/src/github.com/coreos/gexpect"
|
||||
import "log"
|
||||
|
||||
func main() {
|
||||
log.Printf("Testing Ftp... ")
|
||||
|
||||
child, err := gexpect.Spawn("ftp ftp.openbsd.org")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
child.Expect("Name")
|
||||
child.SendLine("anonymous")
|
||||
child.Expect("Password")
|
||||
child.SendLine("pexpect@sourceforge.net")
|
||||
child.Expect("ftp> ")
|
||||
child.SendLine("cd /pub/OpenBSD/3.7/packages/i386")
|
||||
child.Expect("ftp> ")
|
||||
child.SendLine("bin")
|
||||
child.Expect("ftp> ")
|
||||
child.SendLine("prompt")
|
||||
child.Expect("ftp> ")
|
||||
child.SendLine("pwd")
|
||||
child.Expect("ftp> ")
|
||||
log.Printf("Success\n")
|
||||
}
|
|
@ -0,0 +1,15 @@
|
|||
package main
|
||||
|
||||
import gexpect "github.com/coreos/etcd/Godeps/_workspace/src/github.com/coreos/gexpect"
|
||||
import "log"
|
||||
|
||||
func main() {
|
||||
log.Printf("Testing Ping interact... \n")
|
||||
|
||||
child, err := gexpect.Spawn("ping -c8 127.0.0.1")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
child.Interact()
|
||||
log.Printf("Success\n")
|
||||
}
|
22
Godeps/_workspace/src/github.com/coreos/gexpect/examples/python.go
generated
vendored
Normal file
22
Godeps/_workspace/src/github.com/coreos/gexpect/examples/python.go
generated
vendored
Normal file
|
@ -0,0 +1,22 @@
|
|||
package main
|
||||
|
||||
import "github.com/coreos/etcd/Godeps/_workspace/src/github.com/coreos/gexpect"
|
||||
import "fmt"
|
||||
|
||||
func main() {
|
||||
fmt.Printf("Starting python.. \n")
|
||||
child, err := gexpect.Spawn("python")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
fmt.Printf("Expecting >>>.. \n")
|
||||
child.Expect(">>>")
|
||||
fmt.Printf("print 'Hello World'..\n")
|
||||
child.SendLine("print 'Hello World'")
|
||||
child.Expect(">>>")
|
||||
|
||||
fmt.Printf("Interacting.. \n")
|
||||
child.Interact()
|
||||
fmt.Printf("Done \n")
|
||||
child.Close()
|
||||
}
|
53
Godeps/_workspace/src/github.com/coreos/gexpect/examples/screen.go
generated
vendored
Normal file
53
Godeps/_workspace/src/github.com/coreos/gexpect/examples/screen.go
generated
vendored
Normal file
|
@ -0,0 +1,53 @@
|
|||
package main
|
||||
|
||||
import "github.com/coreos/etcd/Godeps/_workspace/src/github.com/coreos/gexpect"
|
||||
import "fmt"
|
||||
import "strings"
|
||||
|
||||
func main() {
|
||||
waitChan := make(chan string)
|
||||
|
||||
fmt.Printf("Starting screen.. \n")
|
||||
|
||||
child, err := gexpect.Spawn("screen")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
sender, reciever := child.AsyncInteractChannels()
|
||||
go func() {
|
||||
waitString := ""
|
||||
count := 0
|
||||
for {
|
||||
select {
|
||||
case waitString = <-waitChan:
|
||||
count++
|
||||
case msg, open := <-reciever:
|
||||
if !open {
|
||||
return
|
||||
}
|
||||
fmt.Printf("Recieved: %s\n", msg)
|
||||
|
||||
if strings.Contains(msg, waitString) {
|
||||
if count >= 1 {
|
||||
waitChan <- msg
|
||||
count -= 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
wait := func(str string) {
|
||||
waitChan <- str
|
||||
<-waitChan
|
||||
}
|
||||
fmt.Printf("Waiting until started.. \n")
|
||||
wait(" ")
|
||||
fmt.Printf("Sending Enter.. \n")
|
||||
sender <- "\n"
|
||||
wait("$")
|
||||
fmt.Printf("Sending echo.. \n")
|
||||
sender <- "echo Hello World\n"
|
||||
wait("Hello World")
|
||||
fmt.Printf("Received echo. \n")
|
||||
}
|
|
@ -0,0 +1,430 @@
|
|||
package gexpect
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"regexp"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
shell "github.com/coreos/etcd/Godeps/_workspace/src/github.com/kballard/go-shellquote"
|
||||
"github.com/coreos/etcd/Godeps/_workspace/src/github.com/kr/pty"
|
||||
)
|
||||
|
||||
type ExpectSubprocess struct {
|
||||
Cmd *exec.Cmd
|
||||
buf *buffer
|
||||
outputBuffer []byte
|
||||
}
|
||||
|
||||
type buffer struct {
|
||||
f *os.File
|
||||
b bytes.Buffer
|
||||
collect bool
|
||||
|
||||
collection bytes.Buffer
|
||||
}
|
||||
|
||||
func (buf *buffer) StartCollecting() {
|
||||
buf.collect = true
|
||||
}
|
||||
|
||||
func (buf *buffer) StopCollecting() (result string) {
|
||||
result = string(buf.collection.Bytes())
|
||||
buf.collect = false
|
||||
buf.collection.Reset()
|
||||
return result
|
||||
}
|
||||
|
||||
func (buf *buffer) Read(chunk []byte) (int, error) {
|
||||
nread := 0
|
||||
if buf.b.Len() > 0 {
|
||||
n, err := buf.b.Read(chunk)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
if n == len(chunk) {
|
||||
return n, nil
|
||||
}
|
||||
nread = n
|
||||
}
|
||||
fn, err := buf.f.Read(chunk[nread:])
|
||||
return fn + nread, err
|
||||
}
|
||||
|
||||
func (buf *buffer) ReadRune() (r rune, size int, err error) {
|
||||
l := buf.b.Len()
|
||||
|
||||
chunk := make([]byte, utf8.UTFMax)
|
||||
if l > 0 {
|
||||
n, err := buf.b.Read(chunk)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
if utf8.FullRune(chunk) {
|
||||
r, rL := utf8.DecodeRune(chunk)
|
||||
if n > rL {
|
||||
buf.PutBack(chunk[rL:n])
|
||||
}
|
||||
if buf.collect {
|
||||
buf.collection.WriteRune(r)
|
||||
}
|
||||
return r, rL, nil
|
||||
}
|
||||
}
|
||||
// else add bytes from the file, then try that
|
||||
for l < utf8.UTFMax {
|
||||
fn, err := buf.f.Read(chunk[l : l+1])
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
l = l + fn
|
||||
|
||||
if utf8.FullRune(chunk) {
|
||||
r, rL := utf8.DecodeRune(chunk)
|
||||
if buf.collect {
|
||||
buf.collection.WriteRune(r)
|
||||
}
|
||||
return r, rL, nil
|
||||
}
|
||||
}
|
||||
return 0, 0, errors.New("File is not a valid UTF=8 encoding")
|
||||
}
|
||||
|
||||
func (buf *buffer) PutBack(chunk []byte) {
|
||||
if len(chunk) == 0 {
|
||||
return
|
||||
}
|
||||
if buf.b.Len() == 0 {
|
||||
buf.b.Write(chunk)
|
||||
return
|
||||
}
|
||||
d := make([]byte, 0, len(chunk)+buf.b.Len())
|
||||
d = append(d, chunk...)
|
||||
d = append(d, buf.b.Bytes()...)
|
||||
buf.b.Reset()
|
||||
buf.b.Write(d)
|
||||
}
|
||||
|
||||
func SpawnAtDirectory(command string, directory string) (*ExpectSubprocess, error) {
|
||||
expect, err := _spawn(command)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
expect.Cmd.Dir = directory
|
||||
return _start(expect)
|
||||
}
|
||||
|
||||
func Command(command string) (*ExpectSubprocess, error) {
|
||||
expect, err := _spawn(command)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return expect, nil
|
||||
}
|
||||
|
||||
func (expect *ExpectSubprocess) Start() error {
|
||||
_, err := _start(expect)
|
||||
return err
|
||||
}
|
||||
|
||||
func Spawn(command string) (*ExpectSubprocess, error) {
|
||||
expect, err := _spawn(command)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return _start(expect)
|
||||
}
|
||||
|
||||
func (expect *ExpectSubprocess) Close() error {
|
||||
return expect.Cmd.Process.Kill()
|
||||
}
|
||||
|
||||
func (expect *ExpectSubprocess) AsyncInteractChannels() (send chan string, receive chan string) {
|
||||
receive = make(chan string)
|
||||
send = make(chan string)
|
||||
|
||||
go func() {
|
||||
for {
|
||||
str, err := expect.ReadLine()
|
||||
if err != nil {
|
||||
close(receive)
|
||||
return
|
||||
}
|
||||
receive <- str
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case sendCommand, exists := <-send:
|
||||
{
|
||||
if !exists {
|
||||
return
|
||||
}
|
||||
err := expect.Send(sendCommand)
|
||||
if err != nil {
|
||||
receive <- "gexpect Error: " + err.Error()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (expect *ExpectSubprocess) ExpectRegex(regex string) (bool, error) {
|
||||
return regexp.MatchReader(regex, expect.buf)
|
||||
}
|
||||
|
||||
func (expect *ExpectSubprocess) expectRegexFind(regex string, output bool) ([]string, string, error) {
|
||||
re, err := regexp.Compile(regex)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
expect.buf.StartCollecting()
|
||||
pairs := re.FindReaderSubmatchIndex(expect.buf)
|
||||
stringIndexedInto := expect.buf.StopCollecting()
|
||||
l := len(pairs)
|
||||
numPairs := l / 2
|
||||
result := make([]string, numPairs)
|
||||
for i := 0; i < numPairs; i += 1 {
|
||||
result[i] = stringIndexedInto[pairs[i*2]:pairs[i*2+1]]
|
||||
}
|
||||
// convert indexes to strings
|
||||
|
||||
if len(result) == 0 {
|
||||
err = fmt.Errorf("ExpectRegex didn't find regex '%v'.", regex)
|
||||
}
|
||||
return result, stringIndexedInto, err
|
||||
}
|
||||
|
||||
func (expect *ExpectSubprocess) expectTimeoutRegexFind(regex string, timeout time.Duration) (result []string, out string, err error) {
|
||||
t := make(chan bool)
|
||||
go func() {
|
||||
result, out, err = expect.ExpectRegexFindWithOutput(regex)
|
||||
t <- false
|
||||
}()
|
||||
go func() {
|
||||
time.Sleep(timeout)
|
||||
err = fmt.Errorf("ExpectRegex timed out after %v finding '%v'.\nOutput:\n%s", timeout, regex, expect.Collect())
|
||||
t <- true
|
||||
}()
|
||||
<-t
|
||||
return result, out, err
|
||||
}
|
||||
|
||||
func (expect *ExpectSubprocess) ExpectRegexFind(regex string) ([]string, error) {
|
||||
result, _, err := expect.expectRegexFind(regex, false)
|
||||
return result, err
|
||||
}
|
||||
|
||||
func (expect *ExpectSubprocess) ExpectTimeoutRegexFind(regex string, timeout time.Duration) ([]string, error) {
|
||||
result, _, err := expect.expectTimeoutRegexFind(regex, timeout)
|
||||
return result, err
|
||||
}
|
||||
|
||||
func (expect *ExpectSubprocess) ExpectRegexFindWithOutput(regex string) ([]string, string, error) {
|
||||
return expect.expectRegexFind(regex, true)
|
||||
}
|
||||
|
||||
func (expect *ExpectSubprocess) ExpectTimeoutRegexFindWithOutput(regex string, timeout time.Duration) ([]string, string, error) {
|
||||
return expect.expectTimeoutRegexFind(regex, timeout)
|
||||
}
|
||||
|
||||
func buildKMPTable(searchString string) []int {
|
||||
pos := 2
|
||||
cnd := 0
|
||||
length := len(searchString)
|
||||
|
||||
var table []int
|
||||
if length < 2 {
|
||||
length = 2
|
||||
}
|
||||
|
||||
table = make([]int, length)
|
||||
table[0] = -1
|
||||
table[1] = 0
|
||||
|
||||
for pos < len(searchString) {
|
||||
if searchString[pos-1] == searchString[cnd] {
|
||||
cnd += 1
|
||||
table[pos] = cnd
|
||||
pos += 1
|
||||
} else if cnd > 0 {
|
||||
cnd = table[cnd]
|
||||
} else {
|
||||
table[pos] = 0
|
||||
pos += 1
|
||||
}
|
||||
}
|
||||
return table
|
||||
}
|
||||
|
||||
func (expect *ExpectSubprocess) ExpectTimeout(searchString string, timeout time.Duration) (e error) {
|
||||
result := make(chan error)
|
||||
go func() {
|
||||
result <- expect.Expect(searchString)
|
||||
}()
|
||||
select {
|
||||
case e = <-result:
|
||||
case <-time.After(timeout):
|
||||
e = fmt.Errorf("Expect timed out after %v waiting for '%v'.\nOutput:\n%s", timeout, searchString, expect.Collect())
|
||||
}
|
||||
return e
|
||||
}
|
||||
|
||||
func (expect *ExpectSubprocess) Expect(searchString string) (e error) {
|
||||
chunk := make([]byte, len(searchString)*2)
|
||||
target := len(searchString)
|
||||
if expect.outputBuffer != nil {
|
||||
expect.outputBuffer = expect.outputBuffer[0:]
|
||||
}
|
||||
m := 0
|
||||
i := 0
|
||||
// Build KMP Table
|
||||
table := buildKMPTable(searchString)
|
||||
|
||||
for {
|
||||
n, err := expect.buf.Read(chunk)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if expect.outputBuffer != nil {
|
||||
expect.outputBuffer = append(expect.outputBuffer, chunk[:n]...)
|
||||
}
|
||||
offset := m + i
|
||||
for m+i-offset < n {
|
||||
if searchString[i] == chunk[m+i-offset] {
|
||||
i += 1
|
||||
if i == target {
|
||||
unreadIndex := m + i - offset
|
||||
if len(chunk) > unreadIndex {
|
||||
expect.buf.PutBack(chunk[unreadIndex:])
|
||||
}
|
||||
return nil
|
||||
}
|
||||
} else {
|
||||
m += i - table[i]
|
||||
if table[i] > -1 {
|
||||
i = table[i]
|
||||
} else {
|
||||
i = 0
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (expect *ExpectSubprocess) Send(command string) error {
|
||||
_, err := io.WriteString(expect.buf.f, command)
|
||||
return err
|
||||
}
|
||||
|
||||
func (expect *ExpectSubprocess) Capture() {
|
||||
if expect.outputBuffer == nil {
|
||||
expect.outputBuffer = make([]byte, 0)
|
||||
}
|
||||
}
|
||||
|
||||
func (expect *ExpectSubprocess) Collect() []byte {
|
||||
collectOutput := make([]byte, len(expect.outputBuffer))
|
||||
copy(collectOutput, expect.outputBuffer)
|
||||
expect.outputBuffer = nil
|
||||
return collectOutput
|
||||
}
|
||||
|
||||
func (expect *ExpectSubprocess) SendLine(command string) error {
|
||||
_, err := io.WriteString(expect.buf.f, command+"\r\n")
|
||||
return err
|
||||
}
|
||||
|
||||
func (expect *ExpectSubprocess) Interact() {
|
||||
defer expect.Cmd.Wait()
|
||||
io.Copy(os.Stdout, &expect.buf.b)
|
||||
go io.Copy(os.Stdout, expect.buf.f)
|
||||
go io.Copy(expect.buf.f, os.Stdin)
|
||||
}
|
||||
|
||||
func (expect *ExpectSubprocess) ReadUntil(delim byte) ([]byte, error) {
|
||||
join := make([]byte, 1, 512)
|
||||
chunk := make([]byte, 255)
|
||||
|
||||
for {
|
||||
|
||||
n, err := expect.buf.Read(chunk)
|
||||
|
||||
if err != nil {
|
||||
return join, err
|
||||
}
|
||||
|
||||
for i := 0; i < n; i++ {
|
||||
if chunk[i] == delim {
|
||||
if len(chunk) > i+1 {
|
||||
expect.buf.PutBack(chunk[i+1:])
|
||||
}
|
||||
return join, nil
|
||||
} else {
|
||||
join = append(join, chunk[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (expect *ExpectSubprocess) Wait() error {
|
||||
return expect.Cmd.Wait()
|
||||
}
|
||||
|
||||
func (expect *ExpectSubprocess) ReadLine() (string, error) {
|
||||
str, err := expect.ReadUntil('\n')
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(str), nil
|
||||
}
|
||||
|
||||
func _start(expect *ExpectSubprocess) (*ExpectSubprocess, error) {
|
||||
f, err := pty.Start(expect.Cmd)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
expect.buf.f = f
|
||||
|
||||
return expect, nil
|
||||
}
|
||||
|
||||
func _spawn(command string) (*ExpectSubprocess, error) {
|
||||
wrapper := new(ExpectSubprocess)
|
||||
|
||||
wrapper.outputBuffer = nil
|
||||
|
||||
splitArgs, err := shell.Split(command)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
numArguments := len(splitArgs) - 1
|
||||
if numArguments < 0 {
|
||||
return nil, errors.New("gexpect: No command given to spawn")
|
||||
}
|
||||
path, err := exec.LookPath(splitArgs[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if numArguments >= 1 {
|
||||
wrapper.Cmd = exec.Command(path, splitArgs[1:]...)
|
||||
} else {
|
||||
wrapper.Cmd = exec.Command(path)
|
||||
}
|
||||
wrapper.buf = new(buffer)
|
||||
|
||||
return wrapper, nil
|
||||
}
|
|
@ -0,0 +1,21 @@
|
|||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2014 Brian Goff
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
19
Godeps/_workspace/src/github.com/cpuguy83/go-md2man/md2man/md2man.go
generated
vendored
Normal file
19
Godeps/_workspace/src/github.com/cpuguy83/go-md2man/md2man/md2man.go
generated
vendored
Normal file
|
@ -0,0 +1,19 @@
|
|||
package md2man
|
||||
|
||||
import (
|
||||
"github.com/coreos/etcd/Godeps/_workspace/src/github.com/russross/blackfriday"
|
||||
)
|
||||
|
||||
func Render(doc []byte) []byte {
|
||||
renderer := RoffRenderer(0)
|
||||
extensions := 0
|
||||
extensions |= blackfriday.EXTENSION_NO_INTRA_EMPHASIS
|
||||
extensions |= blackfriday.EXTENSION_TABLES
|
||||
extensions |= blackfriday.EXTENSION_FENCED_CODE
|
||||
extensions |= blackfriday.EXTENSION_AUTOLINK
|
||||
extensions |= blackfriday.EXTENSION_SPACE_HEADERS
|
||||
extensions |= blackfriday.EXTENSION_FOOTNOTES
|
||||
extensions |= blackfriday.EXTENSION_TITLEBLOCK
|
||||
|
||||
return blackfriday.Markdown(doc, renderer, extensions)
|
||||
}
|
269
Godeps/_workspace/src/github.com/cpuguy83/go-md2man/md2man/roff.go
generated
vendored
Normal file
269
Godeps/_workspace/src/github.com/cpuguy83/go-md2man/md2man/roff.go
generated
vendored
Normal file
|
@ -0,0 +1,269 @@
|
|||
package md2man
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"html"
|
||||
"strings"
|
||||
|
||||
"github.com/coreos/etcd/Godeps/_workspace/src/github.com/russross/blackfriday"
|
||||
)
|
||||
|
||||
type roffRenderer struct{}
|
||||
|
||||
func RoffRenderer(flags int) blackfriday.Renderer {
|
||||
return &roffRenderer{}
|
||||
}
|
||||
|
||||
func (r *roffRenderer) GetFlags() int {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (r *roffRenderer) TitleBlock(out *bytes.Buffer, text []byte) {
|
||||
out.WriteString(".TH ")
|
||||
|
||||
splitText := bytes.Split(text, []byte("\n"))
|
||||
for i, line := range splitText {
|
||||
line = bytes.TrimPrefix(line, []byte("% "))
|
||||
if i == 0 {
|
||||
line = bytes.Replace(line, []byte("("), []byte("\" \""), 1)
|
||||
line = bytes.Replace(line, []byte(")"), []byte("\" \""), 1)
|
||||
}
|
||||
line = append([]byte("\""), line...)
|
||||
line = append(line, []byte("\" ")...)
|
||||
out.Write(line)
|
||||
}
|
||||
|
||||
out.WriteString(" \"\"\n")
|
||||
}
|
||||
|
||||
func (r *roffRenderer) BlockCode(out *bytes.Buffer, text []byte, lang string) {
|
||||
out.WriteString("\n.PP\n.RS\n\n.nf\n")
|
||||
escapeSpecialChars(out, text)
|
||||
out.WriteString("\n.fi\n.RE\n")
|
||||
}
|
||||
|
||||
func (r *roffRenderer) BlockQuote(out *bytes.Buffer, text []byte) {
|
||||
out.WriteString("\n.PP\n.RS\n")
|
||||
out.Write(text)
|
||||
out.WriteString("\n.RE\n")
|
||||
}
|
||||
|
||||
func (r *roffRenderer) BlockHtml(out *bytes.Buffer, text []byte) {
|
||||
out.Write(text)
|
||||
}
|
||||
|
||||
func (r *roffRenderer) Header(out *bytes.Buffer, text func() bool, level int, id string) {
|
||||
marker := out.Len()
|
||||
|
||||
switch {
|
||||
case marker == 0:
|
||||
// This is the doc header
|
||||
out.WriteString(".TH ")
|
||||
case level == 1:
|
||||
out.WriteString("\n\n.SH ")
|
||||
case level == 2:
|
||||
out.WriteString("\n.SH ")
|
||||
default:
|
||||
out.WriteString("\n.SS ")
|
||||
}
|
||||
|
||||
if !text() {
|
||||
out.Truncate(marker)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (r *roffRenderer) HRule(out *bytes.Buffer) {
|
||||
out.WriteString("\n.ti 0\n\\l'\\n(.lu'\n")
|
||||
}
|
||||
|
||||
func (r *roffRenderer) List(out *bytes.Buffer, text func() bool, flags int) {
|
||||
marker := out.Len()
|
||||
out.WriteString(".IP ")
|
||||
if flags&blackfriday.LIST_TYPE_ORDERED != 0 {
|
||||
out.WriteString("\\(bu 2")
|
||||
} else {
|
||||
out.WriteString("\\n+[step" + string(flags) + "]")
|
||||
}
|
||||
out.WriteString("\n")
|
||||
if !text() {
|
||||
out.Truncate(marker)
|
||||
return
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (r *roffRenderer) ListItem(out *bytes.Buffer, text []byte, flags int) {
|
||||
out.WriteString("\n\\item ")
|
||||
out.Write(text)
|
||||
}
|
||||
|
||||
func (r *roffRenderer) Paragraph(out *bytes.Buffer, text func() bool) {
|
||||
marker := out.Len()
|
||||
out.WriteString("\n.PP\n")
|
||||
if !text() {
|
||||
out.Truncate(marker)
|
||||
return
|
||||
}
|
||||
if marker != 0 {
|
||||
out.WriteString("\n")
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: This might now work
|
||||
func (r *roffRenderer) Table(out *bytes.Buffer, header []byte, body []byte, columnData []int) {
|
||||
out.WriteString(".TS\nallbox;\n")
|
||||
|
||||
out.Write(header)
|
||||
out.Write(body)
|
||||
out.WriteString("\n.TE\n")
|
||||
}
|
||||
|
||||
func (r *roffRenderer) TableRow(out *bytes.Buffer, text []byte) {
|
||||
if out.Len() > 0 {
|
||||
out.WriteString("\n")
|
||||
}
|
||||
out.Write(text)
|
||||
out.WriteString("\n")
|
||||
}
|
||||
|
||||
func (r *roffRenderer) TableHeaderCell(out *bytes.Buffer, text []byte, align int) {
|
||||
if out.Len() > 0 {
|
||||
out.WriteString(" ")
|
||||
}
|
||||
out.Write(text)
|
||||
out.WriteString(" ")
|
||||
}
|
||||
|
||||
// TODO: This is probably broken
|
||||
func (r *roffRenderer) TableCell(out *bytes.Buffer, text []byte, align int) {
|
||||
if out.Len() > 0 {
|
||||
out.WriteString("\t")
|
||||
}
|
||||
out.Write(text)
|
||||
out.WriteString("\t")
|
||||
}
|
||||
|
||||
func (r *roffRenderer) Footnotes(out *bytes.Buffer, text func() bool) {
|
||||
|
||||
}
|
||||
|
||||
func (r *roffRenderer) FootnoteItem(out *bytes.Buffer, name, text []byte, flags int) {
|
||||
|
||||
}
|
||||
|
||||
func (r *roffRenderer) AutoLink(out *bytes.Buffer, link []byte, kind int) {
|
||||
out.WriteString("\n\\[la]")
|
||||
out.Write(link)
|
||||
out.WriteString("\\[ra]")
|
||||
}
|
||||
|
||||
func (r *roffRenderer) CodeSpan(out *bytes.Buffer, text []byte) {
|
||||
out.WriteString("\\fB\\fC")
|
||||
escapeSpecialChars(out, text)
|
||||
out.WriteString("\\fR")
|
||||
}
|
||||
|
||||
func (r *roffRenderer) DoubleEmphasis(out *bytes.Buffer, text []byte) {
|
||||
out.WriteString("\\fB")
|
||||
out.Write(text)
|
||||
out.WriteString("\\fP")
|
||||
}
|
||||
|
||||
func (r *roffRenderer) Emphasis(out *bytes.Buffer, text []byte) {
|
||||
out.WriteString("\\fI")
|
||||
out.Write(text)
|
||||
out.WriteString("\\fP")
|
||||
}
|
||||
|
||||
func (r *roffRenderer) Image(out *bytes.Buffer, link []byte, title []byte, alt []byte) {
|
||||
}
|
||||
|
||||
func (r *roffRenderer) LineBreak(out *bytes.Buffer) {
|
||||
out.WriteString("\n.br\n")
|
||||
}
|
||||
|
||||
func (r *roffRenderer) Link(out *bytes.Buffer, link []byte, title []byte, content []byte) {
|
||||
r.AutoLink(out, link, 0)
|
||||
}
|
||||
|
||||
func (r *roffRenderer) RawHtmlTag(out *bytes.Buffer, tag []byte) {
|
||||
out.Write(tag)
|
||||
}
|
||||
|
||||
func (r *roffRenderer) TripleEmphasis(out *bytes.Buffer, text []byte) {
|
||||
out.WriteString("\\s+2")
|
||||
out.Write(text)
|
||||
out.WriteString("\\s-2")
|
||||
}
|
||||
|
||||
func (r *roffRenderer) StrikeThrough(out *bytes.Buffer, text []byte) {
|
||||
}
|
||||
|
||||
func (r *roffRenderer) FootnoteRef(out *bytes.Buffer, ref []byte, id int) {
|
||||
|
||||
}
|
||||
|
||||
func (r *roffRenderer) Entity(out *bytes.Buffer, entity []byte) {
|
||||
out.WriteString(html.UnescapeString(string(entity)))
|
||||
}
|
||||
|
||||
func processFooterText(text []byte) []byte {
|
||||
text = bytes.TrimPrefix(text, []byte("% "))
|
||||
newText := []byte{}
|
||||
textArr := strings.Split(string(text), ") ")
|
||||
|
||||
for i, w := range textArr {
|
||||
if i == 0 {
|
||||
w = strings.Replace(w, "(", "\" \"", 1)
|
||||
w = fmt.Sprintf("\"%s\"", w)
|
||||
} else {
|
||||
w = fmt.Sprintf(" \"%s\"", w)
|
||||
}
|
||||
newText = append(newText, []byte(w)...)
|
||||
}
|
||||
newText = append(newText, []byte(" \"\"")...)
|
||||
|
||||
return newText
|
||||
}
|
||||
|
||||
func (r *roffRenderer) NormalText(out *bytes.Buffer, text []byte) {
|
||||
escapeSpecialChars(out, text)
|
||||
}
|
||||
|
||||
func (r *roffRenderer) DocumentHeader(out *bytes.Buffer) {
|
||||
}
|
||||
|
||||
func (r *roffRenderer) DocumentFooter(out *bytes.Buffer) {
|
||||
}
|
||||
|
||||
func needsBackslash(c byte) bool {
|
||||
for _, r := range []byte("-_&\\~") {
|
||||
if c == r {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func escapeSpecialChars(out *bytes.Buffer, text []byte) {
|
||||
for i := 0; i < len(text); i++ {
|
||||
// directly copy normal characters
|
||||
org := i
|
||||
|
||||
for i < len(text) && !needsBackslash(text[i]) {
|
||||
i++
|
||||
}
|
||||
if i > org {
|
||||
out.Write(text[org:i])
|
||||
}
|
||||
|
||||
// escape a character
|
||||
if i >= len(text) {
|
||||
break
|
||||
}
|
||||
out.WriteByte('\\')
|
||||
out.WriteByte(text[i])
|
||||
}
|
||||
}
|
|
@ -0,0 +1,36 @@
|
|||
Extensions for Protocol Buffers to create more go like structures.
|
||||
|
||||
Copyright (c) 2013, Vastech SA (PTY) LTD. All rights reserved.
|
||||
http://github.com/gogo/protobuf/gogoproto
|
||||
|
||||
Go support for Protocol Buffers - Google's data interchange format
|
||||
|
||||
Copyright 2010 The Go Authors. All rights reserved.
|
||||
https://github.com/golang/protobuf
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
* Redistributions in binary form must reproduce the above
|
||||
copyright notice, this list of conditions and the following disclaimer
|
||||
in the documentation and/or other materials provided with the
|
||||
distribution.
|
||||
* Neither the name of Google Inc. nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
|
@ -39,5 +39,5 @@ test: install generate-test-pbs
|
|||
generate-test-pbs:
|
||||
make install
|
||||
make -C testdata
|
||||
protoc-min-version --version="3.0.0" --proto_path=.:../../../../ proto3_proto/proto3.proto
|
||||
protoc-min-version --version="3.0.0" --proto_path=.:../../../../ --gogo_out=. proto3_proto/proto3.proto
|
||||
make
|
||||
|
|
|
@ -401,17 +401,18 @@ type fakeMarshaler struct {
|
|||
err error
|
||||
}
|
||||
|
||||
func (f fakeMarshaler) Marshal() ([]byte, error) {
|
||||
return f.b, f.err
|
||||
func (f *fakeMarshaler) Marshal() ([]byte, error) { return f.b, f.err }
|
||||
func (f *fakeMarshaler) String() string { return fmt.Sprintf("Bytes: %v Error: %v", f.b, f.err) }
|
||||
func (f *fakeMarshaler) ProtoMessage() {}
|
||||
func (f *fakeMarshaler) Reset() {}
|
||||
|
||||
type msgWithFakeMarshaler struct {
|
||||
M *fakeMarshaler `protobuf:"bytes,1,opt,name=fake"`
|
||||
}
|
||||
|
||||
func (f fakeMarshaler) String() string {
|
||||
return fmt.Sprintf("Bytes: %v Error: %v", f.b, f.err)
|
||||
}
|
||||
|
||||
func (f fakeMarshaler) ProtoMessage() {}
|
||||
|
||||
func (f fakeMarshaler) Reset() {}
|
||||
func (m *msgWithFakeMarshaler) String() string { return CompactTextString(m) }
|
||||
func (m *msgWithFakeMarshaler) ProtoMessage() {}
|
||||
func (m *msgWithFakeMarshaler) Reset() {}
|
||||
|
||||
// Simple tests for proto messages that implement the Marshaler interface.
|
||||
func TestMarshalerEncoding(t *testing.T) {
|
||||
|
@ -423,7 +424,7 @@ func TestMarshalerEncoding(t *testing.T) {
|
|||
}{
|
||||
{
|
||||
name: "Marshaler that fails",
|
||||
m: fakeMarshaler{
|
||||
m: &fakeMarshaler{
|
||||
err: errors.New("some marshal err"),
|
||||
b: []byte{5, 6, 7},
|
||||
},
|
||||
|
@ -431,9 +432,25 @@ func TestMarshalerEncoding(t *testing.T) {
|
|||
want: nil,
|
||||
wantErr: errors.New("some marshal err"),
|
||||
},
|
||||
{
|
||||
name: "Marshaler that fails with RequiredNotSetError",
|
||||
m: &msgWithFakeMarshaler{
|
||||
M: &fakeMarshaler{
|
||||
err: &RequiredNotSetError{},
|
||||
b: []byte{5, 6, 7},
|
||||
},
|
||||
},
|
||||
// Since there's an error that can be continued after,
|
||||
// the buffer should be written.
|
||||
want: []byte{
|
||||
10, 3, // for &msgWithFakeMarshaler
|
||||
5, 6, 7, // for &fakeMarshaler
|
||||
},
|
||||
wantErr: &RequiredNotSetError{},
|
||||
},
|
||||
{
|
||||
name: "Marshaler that succeeds",
|
||||
m: fakeMarshaler{
|
||||
m: &fakeMarshaler{
|
||||
b: []byte{0, 1, 2, 3, 4, 127, 255},
|
||||
},
|
||||
want: []byte{0, 1, 2, 3, 4, 127, 255},
|
||||
|
@ -443,6 +460,10 @@ func TestMarshalerEncoding(t *testing.T) {
|
|||
for _, test := range tests {
|
||||
b := NewBuffer(nil)
|
||||
err := b.Marshal(test.m)
|
||||
if _, ok := err.(*RequiredNotSetError); ok {
|
||||
// We're not in package proto, so we can only assert the type in this case.
|
||||
err = &RequiredNotSetError{}
|
||||
}
|
||||
if !reflect.DeepEqual(test.wantErr, err) {
|
||||
t.Errorf("%s: got err %v wanted %v", test.name, err, test.wantErr)
|
||||
}
|
||||
|
@ -1281,7 +1302,7 @@ func TestEnum(t *testing.T) {
|
|||
// We don't care what the value actually is, just as long as it doesn't crash.
|
||||
func TestPrintingNilEnumFields(t *testing.T) {
|
||||
pb := new(GoEnum)
|
||||
fmt.Sprintf("%+v", pb)
|
||||
_ = fmt.Sprintf("%+v", pb)
|
||||
}
|
||||
|
||||
// Verify that absent required fields cause Marshal/Unmarshal to return errors.
|
||||
|
@ -1925,6 +1946,83 @@ func TestMapFieldRoundTrips(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestMapFieldWithNil(t *testing.T) {
|
||||
m := &MessageWithMap{
|
||||
MsgMapping: map[int64]*FloatingPoint{
|
||||
1: nil,
|
||||
},
|
||||
}
|
||||
b, err := Marshal(m)
|
||||
if err == nil {
|
||||
t.Fatalf("Marshal of bad map should have failed, got these bytes: %v", b)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOneof(t *testing.T) {
|
||||
m := &Communique{}
|
||||
b, err := Marshal(m)
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal of empty message with oneof: %v", err)
|
||||
}
|
||||
if len(b) != 0 {
|
||||
t.Errorf("Marshal of empty message yielded too many bytes: %v", b)
|
||||
}
|
||||
|
||||
m = &Communique{
|
||||
Union: &Communique_Name{"Barry"},
|
||||
}
|
||||
|
||||
// Round-trip.
|
||||
b, err = Marshal(m)
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal of message with oneof: %v", err)
|
||||
}
|
||||
if len(b) != 7 { // name tag/wire (1) + name len (1) + name (5)
|
||||
t.Errorf("Incorrect marshal of message with oneof: %v", b)
|
||||
}
|
||||
m.Reset()
|
||||
if err := Unmarshal(b, m); err != nil {
|
||||
t.Fatalf("Unmarshal of message with oneof: %v", err)
|
||||
}
|
||||
if x, ok := m.Union.(*Communique_Name); !ok || x.Name != "Barry" {
|
||||
t.Errorf("After round trip, Union = %+v", m.Union)
|
||||
}
|
||||
if name := m.GetName(); name != "Barry" {
|
||||
t.Errorf("After round trip, GetName = %q, want %q", name, "Barry")
|
||||
}
|
||||
|
||||
// Let's try with a message in the oneof.
|
||||
m.Union = &Communique_Msg{&Strings{StringField: String("deep deep string")}}
|
||||
b, err = Marshal(m)
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal of message with oneof set to message: %v", err)
|
||||
}
|
||||
if len(b) != 20 { // msg tag/wire (1) + msg len (1) + msg (1 + 1 + 16)
|
||||
t.Errorf("Incorrect marshal of message with oneof set to message: %v", b)
|
||||
}
|
||||
m.Reset()
|
||||
if err := Unmarshal(b, m); err != nil {
|
||||
t.Fatalf("Unmarshal of message with oneof set to message: %v", err)
|
||||
}
|
||||
ss, ok := m.Union.(*Communique_Msg)
|
||||
if !ok || ss.Msg.GetStringField() != "deep deep string" {
|
||||
t.Errorf("After round trip with oneof set to message, Union = %+v", m.Union)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInefficientPackedBool(t *testing.T) {
|
||||
// https://github.com/golang/protobuf/issues/76
|
||||
inp := []byte{
|
||||
0x12, 0x02, // 0x12 = 2<<3|2; 2 bytes
|
||||
// Usually a bool should take a single byte,
|
||||
// but it is permitted to be any varint.
|
||||
0xb9, 0x30,
|
||||
}
|
||||
if err := Unmarshal(inp, new(MoreRepeated)); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmarks
|
||||
|
||||
func testMsg() *GoTest {
|
||||
|
|
|
@ -30,7 +30,7 @@
|
|||
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
// Protocol buffer deep copy and merge.
|
||||
// TODO: MessageSet and RawMessage.
|
||||
// TODO: RawMessage.
|
||||
|
||||
package proto
|
||||
|
||||
|
@ -75,12 +75,13 @@ func Merge(dst, src Message) {
|
|||
}
|
||||
|
||||
func mergeStruct(out, in reflect.Value) {
|
||||
sprop := GetProperties(in.Type())
|
||||
for i := 0; i < in.NumField(); i++ {
|
||||
f := in.Type().Field(i)
|
||||
if strings.HasPrefix(f.Name, "XXX_") {
|
||||
continue
|
||||
}
|
||||
mergeAny(out.Field(i), in.Field(i))
|
||||
mergeAny(out.Field(i), in.Field(i), false, sprop.Prop[i])
|
||||
}
|
||||
|
||||
if emIn, ok := in.Addr().Interface().(extensionsMap); ok {
|
||||
|
@ -103,7 +104,10 @@ func mergeStruct(out, in reflect.Value) {
|
|||
}
|
||||
}
|
||||
|
||||
func mergeAny(out, in reflect.Value) {
|
||||
// mergeAny performs a merge between two values of the same type.
|
||||
// viaPtr indicates whether the values were indirected through a pointer (implying proto2).
|
||||
// prop is set if this is a struct field (it may be nil).
|
||||
func mergeAny(out, in reflect.Value, viaPtr bool, prop *Properties) {
|
||||
if in.Type() == protoMessageType {
|
||||
if !in.IsNil() {
|
||||
if out.IsNil() {
|
||||
|
@ -117,7 +121,21 @@ func mergeAny(out, in reflect.Value) {
|
|||
switch in.Kind() {
|
||||
case reflect.Bool, reflect.Float32, reflect.Float64, reflect.Int32, reflect.Int64,
|
||||
reflect.String, reflect.Uint32, reflect.Uint64:
|
||||
if !viaPtr && isProto3Zero(in) {
|
||||
return
|
||||
}
|
||||
out.Set(in)
|
||||
case reflect.Interface:
|
||||
// Probably a oneof field; copy non-nil values.
|
||||
if in.IsNil() {
|
||||
return
|
||||
}
|
||||
// Allocate destination if it is not set, or set to a different type.
|
||||
// Otherwise we will merge as normal.
|
||||
if out.IsNil() || out.Elem().Type() != in.Elem().Type() {
|
||||
out.Set(reflect.New(in.Elem().Elem().Type())) // interface -> *T -> T -> new(T)
|
||||
}
|
||||
mergeAny(out.Elem(), in.Elem(), false, nil)
|
||||
case reflect.Map:
|
||||
if in.Len() == 0 {
|
||||
return
|
||||
|
@ -132,7 +150,7 @@ func mergeAny(out, in reflect.Value) {
|
|||
switch elemKind {
|
||||
case reflect.Ptr:
|
||||
val = reflect.New(in.Type().Elem().Elem())
|
||||
mergeAny(val, in.MapIndex(key))
|
||||
mergeAny(val, in.MapIndex(key), false, nil)
|
||||
case reflect.Slice:
|
||||
val = in.MapIndex(key)
|
||||
val = reflect.ValueOf(append([]byte{}, val.Bytes()...))
|
||||
|
@ -148,13 +166,21 @@ func mergeAny(out, in reflect.Value) {
|
|||
if out.IsNil() {
|
||||
out.Set(reflect.New(in.Elem().Type()))
|
||||
}
|
||||
mergeAny(out.Elem(), in.Elem())
|
||||
mergeAny(out.Elem(), in.Elem(), true, nil)
|
||||
case reflect.Slice:
|
||||
if in.IsNil() {
|
||||
return
|
||||
}
|
||||
if in.Type().Elem().Kind() == reflect.Uint8 {
|
||||
// []byte is a scalar bytes field, not a repeated field.
|
||||
|
||||
// Edge case: if this is in a proto3 message, a zero length
|
||||
// bytes field is considered the zero value, and should not
|
||||
// be merged.
|
||||
if prop != nil && prop.proto3 && in.Len() == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Make a deep copy.
|
||||
// Append to []byte{} instead of []byte(nil) so that we never end up
|
||||
// with a nil result.
|
||||
|
@ -172,7 +198,7 @@ func mergeAny(out, in reflect.Value) {
|
|||
default:
|
||||
for i := 0; i < n; i++ {
|
||||
x := reflect.Indirect(reflect.New(in.Type().Elem()))
|
||||
mergeAny(x, in.Index(i))
|
||||
mergeAny(x, in.Index(i), false, nil)
|
||||
out.Set(reflect.Append(out, x))
|
||||
}
|
||||
}
|
||||
|
@ -189,7 +215,7 @@ func mergeExtension(out, in map[int32]Extension) {
|
|||
eOut := Extension{desc: eIn.desc}
|
||||
if eIn.value != nil {
|
||||
v := reflect.New(reflect.TypeOf(eIn.value)).Elem()
|
||||
mergeAny(v, reflect.ValueOf(eIn.value))
|
||||
mergeAny(v, reflect.ValueOf(eIn.value), false, nil)
|
||||
eOut.value = v.Interface()
|
||||
}
|
||||
if eIn.enc != nil {
|
||||
|
|
|
@ -35,6 +35,8 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/coreos/etcd/Godeps/_workspace/src/github.com/gogo/protobuf/proto"
|
||||
|
||||
proto3pb "github.com/coreos/etcd/Godeps/_workspace/src/github.com/gogo/protobuf/proto/proto3_proto"
|
||||
pb "github.com/coreos/etcd/Godeps/_workspace/src/github.com/gogo/protobuf/proto/testdata"
|
||||
)
|
||||
|
||||
|
@ -213,6 +215,45 @@ var mergeTests = []struct {
|
|||
ByteMapping: map[bool][]byte{true: []byte("wowsa")},
|
||||
},
|
||||
},
|
||||
// proto3 shouldn't merge zero values,
|
||||
// in the same way that proto2 shouldn't merge nils.
|
||||
{
|
||||
src: &proto3pb.Message{
|
||||
Name: "Aaron",
|
||||
Data: []byte(""), // zero value, but not nil
|
||||
},
|
||||
dst: &proto3pb.Message{
|
||||
HeightInCm: 176,
|
||||
Data: []byte("texas!"),
|
||||
},
|
||||
want: &proto3pb.Message{
|
||||
Name: "Aaron",
|
||||
HeightInCm: 176,
|
||||
Data: []byte("texas!"),
|
||||
},
|
||||
},
|
||||
// Oneof fields should merge by assignment.
|
||||
{
|
||||
src: &pb.Communique{
|
||||
Union: &pb.Communique_Number{Number: 41},
|
||||
},
|
||||
dst: &pb.Communique{
|
||||
Union: &pb.Communique_Name{Name: "Bobby Tables"},
|
||||
},
|
||||
want: &pb.Communique{
|
||||
Union: &pb.Communique_Number{Number: 41},
|
||||
},
|
||||
},
|
||||
// Oneof nil is the same as not set.
|
||||
{
|
||||
src: &pb.Communique{},
|
||||
dst: &pb.Communique{
|
||||
Union: &pb.Communique_Name{Name: "Bobby Tables"},
|
||||
},
|
||||
want: &pb.Communique{
|
||||
Union: &pb.Communique_Name{Name: "Bobby Tables"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
func TestMerge(t *testing.T) {
|
||||
|
|
|
@ -46,6 +46,10 @@ import (
|
|||
// errOverflow is returned when an integer is too large to be represented.
|
||||
var errOverflow = errors.New("proto: integer overflow")
|
||||
|
||||
// ErrInternalBadWireType is returned by generated code when an incorrect
|
||||
// wire type is encountered. It does not get returned to user code.
|
||||
var ErrInternalBadWireType = errors.New("proto: internal error: bad wiretype for oneof")
|
||||
|
||||
// The fundamental decoders that interpret bytes on the wire.
|
||||
// Those that take integer types all return uint64 and are
|
||||
// therefore of type valueDecoder.
|
||||
|
@ -314,6 +318,24 @@ func UnmarshalMerge(buf []byte, pb Message) error {
|
|||
return NewBuffer(buf).Unmarshal(pb)
|
||||
}
|
||||
|
||||
// DecodeMessage reads a count-delimited message from the Buffer.
|
||||
func (p *Buffer) DecodeMessage(pb Message) error {
|
||||
enc, err := p.DecodeRawBytes(false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return NewBuffer(enc).Unmarshal(pb)
|
||||
}
|
||||
|
||||
// DecodeGroup reads a tag-delimited group from the Buffer.
|
||||
func (p *Buffer) DecodeGroup(pb Message) error {
|
||||
typ, base, err := getbase(pb)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return p.unmarshalType(typ.Elem(), GetProperties(typ.Elem()), true, base)
|
||||
}
|
||||
|
||||
// Unmarshal parses the protocol buffer representation in the
|
||||
// Buffer and places the decoded result in pb. If the struct
|
||||
// underlying pb does not match the data in the buffer, the results can be
|
||||
|
@ -370,11 +392,11 @@ func (o *Buffer) unmarshalType(st reflect.Type, prop *StructProperties, is_group
|
|||
if prop.extendable {
|
||||
if e := structPointer_Interface(base, st).(extendableProto); isExtensionField(e, int32(tag)) {
|
||||
if err = o.skip(st, tag, wire); err == nil {
|
||||
if ee, ok := e.(extensionsMap); ok {
|
||||
if ee, eok := e.(extensionsMap); eok {
|
||||
ext := ee.ExtensionMap()[int32(tag)] // may be missing
|
||||
ext.enc = append(ext.enc, o.buf[oi:o.index]...)
|
||||
ee.ExtensionMap()[int32(tag)] = ext
|
||||
} else if ee, ok := e.(extensionsBytes); ok {
|
||||
} else if ee, eok := e.(extensionsBytes); eok {
|
||||
ext := ee.GetExtensions()
|
||||
*ext = append(*ext, o.buf[oi:o.index]...)
|
||||
}
|
||||
|
@ -382,6 +404,20 @@ func (o *Buffer) unmarshalType(st reflect.Type, prop *StructProperties, is_group
|
|||
continue
|
||||
}
|
||||
}
|
||||
// Maybe it's a oneof?
|
||||
if prop.oneofUnmarshaler != nil {
|
||||
m := structPointer_Interface(base, st).(Message)
|
||||
// First return value indicates whether tag is a oneof field.
|
||||
ok, err = prop.oneofUnmarshaler(m, tag, wire, o)
|
||||
if err == ErrInternalBadWireType {
|
||||
// Map the error to something more descriptive.
|
||||
// Do the formatting here to save generated code space.
|
||||
err = fmt.Errorf("bad wiretype for oneof field in %T", m)
|
||||
}
|
||||
if ok {
|
||||
continue
|
||||
}
|
||||
}
|
||||
err = o.skipAndSave(st, tag, wire, base, prop.unrecField)
|
||||
continue
|
||||
}
|
||||
|
@ -566,9 +602,13 @@ func (o *Buffer) dec_slice_packed_bool(p *Properties, base structPointer) error
|
|||
return err
|
||||
}
|
||||
nb := int(nn) // number of bytes of encoded bools
|
||||
fin := o.index + nb
|
||||
if fin < o.index {
|
||||
return errOverflow
|
||||
}
|
||||
|
||||
y := *v
|
||||
for i := 0; i < nb; i++ {
|
||||
for o.index < fin {
|
||||
u, err := p.valDec(o)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -680,7 +720,7 @@ func (o *Buffer) dec_new_map(p *Properties, base structPointer) error {
|
|||
oi := o.index // index at the end of this map entry
|
||||
o.index -= len(raw) // move buffer back to start of map entry
|
||||
|
||||
mptr := structPointer_Map(base, p.field, p.mtype) // *map[K]V
|
||||
mptr := structPointer_NewAt(base, p.field, p.mtype) // *map[K]V
|
||||
if mptr.Elem().IsNil() {
|
||||
mptr.Elem().Set(reflect.MakeMap(mptr.Type().Elem()))
|
||||
}
|
||||
|
@ -732,8 +772,14 @@ func (o *Buffer) dec_new_map(p *Properties, base structPointer) error {
|
|||
return fmt.Errorf("proto: bad map data tag %d", raw[0])
|
||||
}
|
||||
}
|
||||
keyelem, valelem := keyptr.Elem(), valptr.Elem()
|
||||
if !keyelem.IsValid() || !valelem.IsValid() {
|
||||
// We did not decode the key or the value in the map entry.
|
||||
// Either way, it's an invalid map entry.
|
||||
return fmt.Errorf("proto: bad map data: missing key/val")
|
||||
}
|
||||
|
||||
v.SetMapIndex(keyptr.Elem(), valptr.Elem())
|
||||
v.SetMapIndex(keyelem, valelem)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -228,6 +228,20 @@ func Marshal(pb Message) ([]byte, error) {
|
|||
return p.buf, err
|
||||
}
|
||||
|
||||
// EncodeMessage writes the protocol buffer to the Buffer,
|
||||
// prefixed by a varint-encoded length.
|
||||
func (p *Buffer) EncodeMessage(pb Message) error {
|
||||
t, base, err := getbase(pb)
|
||||
if structPointer_IsNil(base) {
|
||||
return ErrNil
|
||||
}
|
||||
if err == nil {
|
||||
var state errorState
|
||||
err = p.enc_len_struct(GetProperties(t.Elem()), base, &state)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Marshal takes the protocol buffer
|
||||
// and encodes it into the wire format, writing the result to the
|
||||
// Buffer.
|
||||
|
@ -318,7 +332,7 @@ func size_bool(p *Properties, base structPointer) int {
|
|||
|
||||
func size_proto3_bool(p *Properties, base structPointer) int {
|
||||
v := *structPointer_BoolVal(base, p.field)
|
||||
if !v {
|
||||
if !v && !p.oneof {
|
||||
return 0
|
||||
}
|
||||
return len(p.tagcode) + 1 // each bool takes exactly one byte
|
||||
|
@ -361,7 +375,7 @@ func size_int32(p *Properties, base structPointer) (n int) {
|
|||
func size_proto3_int32(p *Properties, base structPointer) (n int) {
|
||||
v := structPointer_Word32Val(base, p.field)
|
||||
x := int32(word32Val_Get(v)) // permit sign extension to use full 64-bit range
|
||||
if x == 0 {
|
||||
if x == 0 && !p.oneof {
|
||||
return 0
|
||||
}
|
||||
n += len(p.tagcode)
|
||||
|
@ -407,7 +421,7 @@ func size_uint32(p *Properties, base structPointer) (n int) {
|
|||
func size_proto3_uint32(p *Properties, base structPointer) (n int) {
|
||||
v := structPointer_Word32Val(base, p.field)
|
||||
x := word32Val_Get(v)
|
||||
if x == 0 {
|
||||
if x == 0 && !p.oneof {
|
||||
return 0
|
||||
}
|
||||
n += len(p.tagcode)
|
||||
|
@ -452,7 +466,7 @@ func size_int64(p *Properties, base structPointer) (n int) {
|
|||
func size_proto3_int64(p *Properties, base structPointer) (n int) {
|
||||
v := structPointer_Word64Val(base, p.field)
|
||||
x := word64Val_Get(v)
|
||||
if x == 0 {
|
||||
if x == 0 && !p.oneof {
|
||||
return 0
|
||||
}
|
||||
n += len(p.tagcode)
|
||||
|
@ -495,7 +509,7 @@ func size_string(p *Properties, base structPointer) (n int) {
|
|||
|
||||
func size_proto3_string(p *Properties, base structPointer) (n int) {
|
||||
v := *structPointer_StringVal(base, p.field)
|
||||
if v == "" {
|
||||
if v == "" && !p.oneof {
|
||||
return 0
|
||||
}
|
||||
n += len(p.tagcode)
|
||||
|
@ -529,7 +543,7 @@ func (o *Buffer) enc_struct_message(p *Properties, base structPointer) error {
|
|||
}
|
||||
o.buf = append(o.buf, p.tagcode...)
|
||||
o.EncodeRawBytes(data)
|
||||
return nil
|
||||
return state.err
|
||||
}
|
||||
|
||||
o.buf = append(o.buf, p.tagcode...)
|
||||
|
@ -667,7 +681,7 @@ func (o *Buffer) enc_proto3_slice_byte(p *Properties, base structPointer) error
|
|||
|
||||
func size_slice_byte(p *Properties, base structPointer) (n int) {
|
||||
s := *structPointer_Bytes(base, p.field)
|
||||
if s == nil {
|
||||
if s == nil && !p.oneof {
|
||||
return 0
|
||||
}
|
||||
n += len(p.tagcode)
|
||||
|
@ -677,7 +691,7 @@ func size_slice_byte(p *Properties, base structPointer) (n int) {
|
|||
|
||||
func size_proto3_slice_byte(p *Properties, base structPointer) (n int) {
|
||||
s := *structPointer_Bytes(base, p.field)
|
||||
if len(s) == 0 {
|
||||
if len(s) == 0 && !p.oneof {
|
||||
return 0
|
||||
}
|
||||
n += len(p.tagcode)
|
||||
|
@ -1084,7 +1098,7 @@ func (o *Buffer) enc_new_map(p *Properties, base structPointer) error {
|
|||
repeated MapFieldEntry map_field = N;
|
||||
*/
|
||||
|
||||
v := structPointer_Map(base, p.field, p.mtype).Elem() // map[K]V
|
||||
v := structPointer_NewAt(base, p.field, p.mtype).Elem() // map[K]V
|
||||
if v.Len() == 0 {
|
||||
return nil
|
||||
}
|
||||
|
@ -1101,11 +1115,15 @@ func (o *Buffer) enc_new_map(p *Properties, base structPointer) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
keys := v.MapKeys()
|
||||
sort.Sort(mapKeys(keys))
|
||||
for _, key := range keys {
|
||||
// Don't sort map keys. It is not required by the spec, and C++ doesn't do it.
|
||||
for _, key := range v.MapKeys() {
|
||||
val := v.MapIndex(key)
|
||||
|
||||
// The only illegal map entry values are nil message pointers.
|
||||
if val.Kind() == reflect.Ptr && val.IsNil() {
|
||||
return errors.New("proto: map has nil element")
|
||||
}
|
||||
|
||||
keycopy.Set(key)
|
||||
valcopy.Set(val)
|
||||
|
||||
|
@ -1118,7 +1136,7 @@ func (o *Buffer) enc_new_map(p *Properties, base structPointer) error {
|
|||
}
|
||||
|
||||
func size_new_map(p *Properties, base structPointer) int {
|
||||
v := structPointer_Map(base, p.field, p.mtype).Elem() // map[K]V
|
||||
v := structPointer_NewAt(base, p.field, p.mtype).Elem() // map[K]V
|
||||
|
||||
keycopy, valcopy, keybase, valbase := mapEncodeScratch(p.mtype)
|
||||
|
||||
|
@ -1196,6 +1214,14 @@ func (o *Buffer) enc_struct(prop *StructProperties, base structPointer) error {
|
|||
}
|
||||
}
|
||||
|
||||
// Do oneof fields.
|
||||
if prop.oneofMarshaler != nil {
|
||||
m := structPointer_Interface(base, prop.stype).(Message)
|
||||
if err := prop.oneofMarshaler(m, o); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Add unrecognized fields at the end.
|
||||
if prop.unrecField.IsValid() {
|
||||
v := *structPointer_Bytes(base, prop.unrecField)
|
||||
|
@ -1221,6 +1247,27 @@ func size_struct(prop *StructProperties, base structPointer) (n int) {
|
|||
n += len(v)
|
||||
}
|
||||
|
||||
// Factor in any oneof fields.
|
||||
// TODO: This could be faster and use less reflection.
|
||||
if prop.oneofMarshaler != nil {
|
||||
sv := reflect.ValueOf(structPointer_Interface(base, prop.stype)).Elem()
|
||||
for i := 0; i < prop.stype.NumField(); i++ {
|
||||
fv := sv.Field(i)
|
||||
if fv.Kind() != reflect.Interface || fv.IsNil() {
|
||||
continue
|
||||
}
|
||||
if prop.stype.Field(i).Tag.Get("protobuf_oneof") == "" {
|
||||
continue
|
||||
}
|
||||
spv := fv.Elem() // interface -> *T
|
||||
sv := spv.Elem() // *T -> T
|
||||
sf := sv.Type().Field(0) // StructField inside T
|
||||
var prop Properties
|
||||
prop.Init(sf.Type, "whatever", sf.Tag.Get("protobuf"), &sf)
|
||||
n += prop.size(&prop, toStructPointer(spv))
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -30,7 +30,6 @@
|
|||
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
// Protocol buffer comparison.
|
||||
// TODO: MessageSet.
|
||||
|
||||
package proto
|
||||
|
||||
|
@ -154,6 +153,17 @@ func equalAny(v1, v2 reflect.Value) bool {
|
|||
return v1.Float() == v2.Float()
|
||||
case reflect.Int32, reflect.Int64:
|
||||
return v1.Int() == v2.Int()
|
||||
case reflect.Interface:
|
||||
// Probably a oneof field; compare the inner values.
|
||||
n1, n2 := v1.IsNil(), v2.IsNil()
|
||||
if n1 || n2 {
|
||||
return n1 == n2
|
||||
}
|
||||
e1, e2 := v1.Elem(), v2.Elem()
|
||||
if e1.Type() != e2.Type() {
|
||||
return false
|
||||
}
|
||||
return equalAny(e1, e2)
|
||||
case reflect.Map:
|
||||
if v1.Len() != v2.Len() {
|
||||
return false
|
||||
|
|
|
@ -180,6 +180,24 @@ var EqualTests = []struct {
|
|||
&pb.MessageWithMap{NameMapping: map[int32]string{1: "Rob"}},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"oneof same",
|
||||
&pb.Communique{Union: &pb.Communique_Number{Number: 41}},
|
||||
&pb.Communique{Union: &pb.Communique_Number{Number: 41}},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"oneof one nil",
|
||||
&pb.Communique{Union: &pb.Communique_Number{Number: 41}},
|
||||
&pb.Communique{},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"oneof different",
|
||||
&pb.Communique{Union: &pb.Communique_Number{Number: 41}},
|
||||
&pb.Communique{Union: &pb.Communique_Name{Name: "Bobby Tables"}},
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
func TestEqual(t *testing.T) {
|
||||
|
|
|
@ -311,7 +311,9 @@ func GetExtension(pb extendableProto, extension *ExtensionDesc) (interface{}, er
|
|||
emap := epb.ExtensionMap()
|
||||
e, ok := emap[extension.Field]
|
||||
if !ok {
|
||||
return nil, ErrMissingExtension
|
||||
// defaultExtensionValue returns the default value or
|
||||
// ErrMissingExtension if there is no default.
|
||||
return defaultExtensionValue(extension)
|
||||
}
|
||||
if e.value != nil {
|
||||
// Already decoded. Check the descriptor, though.
|
||||
|
@ -356,10 +358,46 @@ func GetExtension(pb extendableProto, extension *ExtensionDesc) (interface{}, er
|
|||
}
|
||||
o += n + l
|
||||
}
|
||||
return defaultExtensionValue(extension)
|
||||
}
|
||||
panic("unreachable")
|
||||
}
|
||||
|
||||
// defaultExtensionValue returns the default value for extension.
|
||||
// If no default for an extension is defined ErrMissingExtension is returned.
|
||||
func defaultExtensionValue(extension *ExtensionDesc) (interface{}, error) {
|
||||
t := reflect.TypeOf(extension.ExtensionType)
|
||||
props := extensionProperties(extension)
|
||||
|
||||
sf, _, err := fieldDefault(t, props)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if sf == nil || sf.value == nil {
|
||||
// There is no default value.
|
||||
return nil, ErrMissingExtension
|
||||
}
|
||||
|
||||
if t.Kind() != reflect.Ptr {
|
||||
// We do not need to return a Ptr, we can directly return sf.value.
|
||||
return sf.value, nil
|
||||
}
|
||||
|
||||
// We need to return an interface{} that is a pointer to sf.value.
|
||||
value := reflect.New(t).Elem()
|
||||
value.Set(reflect.New(value.Type().Elem()))
|
||||
if sf.kind == reflect.Int32 {
|
||||
// We may have an int32 or an enum, but the underlying data is int32.
|
||||
// Since we can't set an int32 into a non int32 reflect.value directly
|
||||
// set it as a int32.
|
||||
value.Elem().SetInt(int64(sf.value.(int32)))
|
||||
} else {
|
||||
value.Elem().Set(reflect.ValueOf(sf.value))
|
||||
}
|
||||
return value.Interface(), nil
|
||||
}
|
||||
|
||||
// decodeExtension decodes an extension encoded in b.
|
||||
func decodeExtension(b []byte, extension *ExtensionDesc) (interface{}, error) {
|
||||
o := NewBuffer(b)
|
||||
|
|
|
@ -32,6 +32,8 @@
|
|||
package proto_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/coreos/etcd/Godeps/_workspace/src/github.com/gogo/protobuf/proto"
|
||||
|
@ -93,6 +95,143 @@ func TestGetExtensionStability(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestGetExtensionDefaults(t *testing.T) {
|
||||
var setFloat64 float64 = 1
|
||||
var setFloat32 float32 = 2
|
||||
var setInt32 int32 = 3
|
||||
var setInt64 int64 = 4
|
||||
var setUint32 uint32 = 5
|
||||
var setUint64 uint64 = 6
|
||||
var setBool = true
|
||||
var setBool2 = false
|
||||
var setString = "Goodnight string"
|
||||
var setBytes = []byte("Goodnight bytes")
|
||||
var setEnum = pb.DefaultsMessage_TWO
|
||||
|
||||
type testcase struct {
|
||||
ext *proto.ExtensionDesc // Extension we are testing.
|
||||
want interface{} // Expected value of extension, or nil (meaning that GetExtension will fail).
|
||||
def interface{} // Expected value of extension after ClearExtension().
|
||||
}
|
||||
tests := []testcase{
|
||||
{pb.E_NoDefaultDouble, setFloat64, nil},
|
||||
{pb.E_NoDefaultFloat, setFloat32, nil},
|
||||
{pb.E_NoDefaultInt32, setInt32, nil},
|
||||
{pb.E_NoDefaultInt64, setInt64, nil},
|
||||
{pb.E_NoDefaultUint32, setUint32, nil},
|
||||
{pb.E_NoDefaultUint64, setUint64, nil},
|
||||
{pb.E_NoDefaultSint32, setInt32, nil},
|
||||
{pb.E_NoDefaultSint64, setInt64, nil},
|
||||
{pb.E_NoDefaultFixed32, setUint32, nil},
|
||||
{pb.E_NoDefaultFixed64, setUint64, nil},
|
||||
{pb.E_NoDefaultSfixed32, setInt32, nil},
|
||||
{pb.E_NoDefaultSfixed64, setInt64, nil},
|
||||
{pb.E_NoDefaultBool, setBool, nil},
|
||||
{pb.E_NoDefaultBool, setBool2, nil},
|
||||
{pb.E_NoDefaultString, setString, nil},
|
||||
{pb.E_NoDefaultBytes, setBytes, nil},
|
||||
{pb.E_NoDefaultEnum, setEnum, nil},
|
||||
{pb.E_DefaultDouble, setFloat64, float64(3.1415)},
|
||||
{pb.E_DefaultFloat, setFloat32, float32(3.14)},
|
||||
{pb.E_DefaultInt32, setInt32, int32(42)},
|
||||
{pb.E_DefaultInt64, setInt64, int64(43)},
|
||||
{pb.E_DefaultUint32, setUint32, uint32(44)},
|
||||
{pb.E_DefaultUint64, setUint64, uint64(45)},
|
||||
{pb.E_DefaultSint32, setInt32, int32(46)},
|
||||
{pb.E_DefaultSint64, setInt64, int64(47)},
|
||||
{pb.E_DefaultFixed32, setUint32, uint32(48)},
|
||||
{pb.E_DefaultFixed64, setUint64, uint64(49)},
|
||||
{pb.E_DefaultSfixed32, setInt32, int32(50)},
|
||||
{pb.E_DefaultSfixed64, setInt64, int64(51)},
|
||||
{pb.E_DefaultBool, setBool, true},
|
||||
{pb.E_DefaultBool, setBool2, true},
|
||||
{pb.E_DefaultString, setString, "Hello, string"},
|
||||
{pb.E_DefaultBytes, setBytes, []byte("Hello, bytes")},
|
||||
{pb.E_DefaultEnum, setEnum, pb.DefaultsMessage_ONE},
|
||||
}
|
||||
|
||||
checkVal := func(test testcase, msg *pb.DefaultsMessage, valWant interface{}) error {
|
||||
val, err := proto.GetExtension(msg, test.ext)
|
||||
if err != nil {
|
||||
if valWant != nil {
|
||||
return fmt.Errorf("GetExtension(): %s", err)
|
||||
}
|
||||
if want := proto.ErrMissingExtension; err != want {
|
||||
return fmt.Errorf("Unexpected error: got %v, want %v", err, want)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// All proto2 extension values are either a pointer to a value or a slice of values.
|
||||
ty := reflect.TypeOf(val)
|
||||
tyWant := reflect.TypeOf(test.ext.ExtensionType)
|
||||
if got, want := ty, tyWant; got != want {
|
||||
return fmt.Errorf("unexpected reflect.TypeOf(): got %v want %v", got, want)
|
||||
}
|
||||
tye := ty.Elem()
|
||||
tyeWant := tyWant.Elem()
|
||||
if got, want := tye, tyeWant; got != want {
|
||||
return fmt.Errorf("unexpected reflect.TypeOf().Elem(): got %v want %v", got, want)
|
||||
}
|
||||
|
||||
// Check the name of the type of the value.
|
||||
// If it is an enum it will be type int32 with the name of the enum.
|
||||
if got, want := tye.Name(), tye.Name(); got != want {
|
||||
return fmt.Errorf("unexpected reflect.TypeOf().Elem().Name(): got %v want %v", got, want)
|
||||
}
|
||||
|
||||
// Check that value is what we expect.
|
||||
// If we have a pointer in val, get the value it points to.
|
||||
valExp := val
|
||||
if ty.Kind() == reflect.Ptr {
|
||||
valExp = reflect.ValueOf(val).Elem().Interface()
|
||||
}
|
||||
if got, want := valExp, valWant; !reflect.DeepEqual(got, want) {
|
||||
return fmt.Errorf("unexpected reflect.DeepEqual(): got %v want %v", got, want)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
setTo := func(test testcase) interface{} {
|
||||
setTo := reflect.ValueOf(test.want)
|
||||
if typ := reflect.TypeOf(test.ext.ExtensionType); typ.Kind() == reflect.Ptr {
|
||||
setTo = reflect.New(typ).Elem()
|
||||
setTo.Set(reflect.New(setTo.Type().Elem()))
|
||||
setTo.Elem().Set(reflect.ValueOf(test.want))
|
||||
}
|
||||
return setTo.Interface()
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
msg := &pb.DefaultsMessage{}
|
||||
name := test.ext.Name
|
||||
|
||||
// Check the initial value.
|
||||
if err := checkVal(test, msg, test.def); err != nil {
|
||||
t.Errorf("%s: %v", name, err)
|
||||
}
|
||||
|
||||
// Set the per-type value and check value.
|
||||
name = fmt.Sprintf("%s (set to %T %v)", name, test.want, test.want)
|
||||
if err := proto.SetExtension(msg, test.ext, setTo(test)); err != nil {
|
||||
t.Errorf("%s: SetExtension(): %v", name, err)
|
||||
continue
|
||||
}
|
||||
if err := checkVal(test, msg, test.want); err != nil {
|
||||
t.Errorf("%s: %v", name, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Set and check the value.
|
||||
name += " (cleared)"
|
||||
proto.ClearExtension(msg, test.ext)
|
||||
if err := checkVal(test, msg, test.def); err != nil {
|
||||
t.Errorf("%s: %v", name, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtensionsRoundTrip(t *testing.T) {
|
||||
msg := &pb.MyMessage{}
|
||||
ext1 := &pb.Ext{
|
||||
|
|
|
@ -30,179 +30,230 @@
|
|||
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
/*
|
||||
Package proto converts data structures to and from the wire format of
|
||||
protocol buffers. It works in concert with the Go source code generated
|
||||
for .proto files by the protocol compiler.
|
||||
Package proto converts data structures to and from the wire format of
|
||||
protocol buffers. It works in concert with the Go source code generated
|
||||
for .proto files by the protocol compiler.
|
||||
|
||||
A summary of the properties of the protocol buffer interface
|
||||
for a protocol buffer variable v:
|
||||
A summary of the properties of the protocol buffer interface
|
||||
for a protocol buffer variable v:
|
||||
|
||||
- Names are turned from camel_case to CamelCase for export.
|
||||
- There are no methods on v to set fields; just treat
|
||||
them as structure fields.
|
||||
- There are getters that return a field's value if set,
|
||||
and return the field's default value if unset.
|
||||
The getters work even if the receiver is a nil message.
|
||||
- The zero value for a struct is its correct initialization state.
|
||||
All desired fields must be set before marshaling.
|
||||
- A Reset() method will restore a protobuf struct to its zero state.
|
||||
- Non-repeated fields are pointers to the values; nil means unset.
|
||||
That is, optional or required field int32 f becomes F *int32.
|
||||
- Repeated fields are slices.
|
||||
- Helper functions are available to aid the setting of fields.
|
||||
msg.Foo = proto.String("hello") // set field
|
||||
- Constants are defined to hold the default values of all fields that
|
||||
have them. They have the form Default_StructName_FieldName.
|
||||
Because the getter methods handle defaulted values,
|
||||
direct use of these constants should be rare.
|
||||
- Enums are given type names and maps from names to values.
|
||||
Enum values are prefixed by the enclosing message's name, or by the
|
||||
enum's type name if it is a top-level enum. Enum types have a String
|
||||
method, and a Enum method to assist in message construction.
|
||||
- Nested messages, groups and enums have type names prefixed with the name of
|
||||
the surrounding message type.
|
||||
- Extensions are given descriptor names that start with E_,
|
||||
followed by an underscore-delimited list of the nested messages
|
||||
that contain it (if any) followed by the CamelCased name of the
|
||||
extension field itself. HasExtension, ClearExtension, GetExtension
|
||||
and SetExtension are functions for manipulating extensions.
|
||||
- Marshal and Unmarshal are functions to encode and decode the wire format.
|
||||
- Names are turned from camel_case to CamelCase for export.
|
||||
- There are no methods on v to set fields; just treat
|
||||
them as structure fields.
|
||||
- There are getters that return a field's value if set,
|
||||
and return the field's default value if unset.
|
||||
The getters work even if the receiver is a nil message.
|
||||
- The zero value for a struct is its correct initialization state.
|
||||
All desired fields must be set before marshaling.
|
||||
- A Reset() method will restore a protobuf struct to its zero state.
|
||||
- Non-repeated fields are pointers to the values; nil means unset.
|
||||
That is, optional or required field int32 f becomes F *int32.
|
||||
- Repeated fields are slices.
|
||||
- Helper functions are available to aid the setting of fields.
|
||||
msg.Foo = proto.String("hello") // set field
|
||||
- Constants are defined to hold the default values of all fields that
|
||||
have them. They have the form Default_StructName_FieldName.
|
||||
Because the getter methods handle defaulted values,
|
||||
direct use of these constants should be rare.
|
||||
- Enums are given type names and maps from names to values.
|
||||
Enum values are prefixed by the enclosing message's name, or by the
|
||||
enum's type name if it is a top-level enum. Enum types have a String
|
||||
method, and a Enum method to assist in message construction.
|
||||
- Nested messages, groups and enums have type names prefixed with the name of
|
||||
the surrounding message type.
|
||||
- Extensions are given descriptor names that start with E_,
|
||||
followed by an underscore-delimited list of the nested messages
|
||||
that contain it (if any) followed by the CamelCased name of the
|
||||
extension field itself. HasExtension, ClearExtension, GetExtension
|
||||
and SetExtension are functions for manipulating extensions.
|
||||
- Oneof field sets are given a single field in their message,
|
||||
with distinguished wrapper types for each possible field value.
|
||||
- Marshal and Unmarshal are functions to encode and decode the wire format.
|
||||
|
||||
The simplest way to describe this is to see an example.
|
||||
Given file test.proto, containing
|
||||
The simplest way to describe this is to see an example.
|
||||
Given file test.proto, containing
|
||||
|
||||
package example;
|
||||
package example;
|
||||
|
||||
enum FOO { X = 17; }
|
||||
enum FOO { X = 17; }
|
||||
|
||||
message Test {
|
||||
required string label = 1;
|
||||
optional int32 type = 2 [default=77];
|
||||
repeated int64 reps = 3;
|
||||
optional group OptionalGroup = 4 {
|
||||
required string RequiredField = 5;
|
||||
}
|
||||
message Test {
|
||||
required string label = 1;
|
||||
optional int32 type = 2 [default=77];
|
||||
repeated int64 reps = 3;
|
||||
optional group OptionalGroup = 4 {
|
||||
required string RequiredField = 5;
|
||||
}
|
||||
oneof union {
|
||||
int32 number = 6;
|
||||
string name = 7;
|
||||
}
|
||||
}
|
||||
|
||||
The resulting file, test.pb.go, is:
|
||||
|
||||
package example
|
||||
|
||||
import proto "github.com/gogo/protobuf/proto"
|
||||
import math "math"
|
||||
|
||||
type FOO int32
|
||||
const (
|
||||
FOO_X FOO = 17
|
||||
)
|
||||
var FOO_name = map[int32]string{
|
||||
17: "X",
|
||||
}
|
||||
var FOO_value = map[string]int32{
|
||||
"X": 17,
|
||||
}
|
||||
|
||||
func (x FOO) Enum() *FOO {
|
||||
p := new(FOO)
|
||||
*p = x
|
||||
return p
|
||||
}
|
||||
func (x FOO) String() string {
|
||||
return proto.EnumName(FOO_name, int32(x))
|
||||
}
|
||||
func (x *FOO) UnmarshalJSON(data []byte) error {
|
||||
value, err := proto.UnmarshalJSONEnum(FOO_value, data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*x = FOO(value)
|
||||
return nil
|
||||
}
|
||||
|
||||
The resulting file, test.pb.go, is:
|
||||
type Test struct {
|
||||
Label *string `protobuf:"bytes,1,req,name=label" json:"label,omitempty"`
|
||||
Type *int32 `protobuf:"varint,2,opt,name=type,def=77" json:"type,omitempty"`
|
||||
Reps []int64 `protobuf:"varint,3,rep,name=reps" json:"reps,omitempty"`
|
||||
Optionalgroup *Test_OptionalGroup `protobuf:"group,4,opt,name=OptionalGroup" json:"optionalgroup,omitempty"`
|
||||
// Types that are valid to be assigned to Union:
|
||||
// *Test_Number
|
||||
// *Test_Name
|
||||
Union isTest_Union `protobuf_oneof:"union"`
|
||||
XXX_unrecognized []byte `json:"-"`
|
||||
}
|
||||
func (m *Test) Reset() { *m = Test{} }
|
||||
func (m *Test) String() string { return proto.CompactTextString(m) }
|
||||
func (*Test) ProtoMessage() {}
|
||||
|
||||
package example
|
||||
type isTest_Union interface {
|
||||
isTest_Union()
|
||||
}
|
||||
|
||||
import proto "github.com/gogo/protobuf/proto"
|
||||
import math "math"
|
||||
type Test_Number struct {
|
||||
Number int32 `protobuf:"varint,6,opt,name=number"`
|
||||
}
|
||||
type Test_Name struct {
|
||||
Name string `protobuf:"bytes,7,opt,name=name"`
|
||||
}
|
||||
|
||||
type FOO int32
|
||||
const (
|
||||
FOO_X FOO = 17
|
||||
)
|
||||
var FOO_name = map[int32]string{
|
||||
17: "X",
|
||||
func (*Test_Number) isTest_Union() {}
|
||||
func (*Test_Name) isTest_Union() {}
|
||||
|
||||
func (m *Test) GetUnion() isTest_Union {
|
||||
if m != nil {
|
||||
return m.Union
|
||||
}
|
||||
var FOO_value = map[string]int32{
|
||||
"X": 17,
|
||||
return nil
|
||||
}
|
||||
const Default_Test_Type int32 = 77
|
||||
|
||||
func (m *Test) GetLabel() string {
|
||||
if m != nil && m.Label != nil {
|
||||
return *m.Label
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x FOO) Enum() *FOO {
|
||||
p := new(FOO)
|
||||
*p = x
|
||||
return p
|
||||
func (m *Test) GetType() int32 {
|
||||
if m != nil && m.Type != nil {
|
||||
return *m.Type
|
||||
}
|
||||
func (x FOO) String() string {
|
||||
return proto.EnumName(FOO_name, int32(x))
|
||||
return Default_Test_Type
|
||||
}
|
||||
|
||||
func (m *Test) GetOptionalgroup() *Test_OptionalGroup {
|
||||
if m != nil {
|
||||
return m.Optionalgroup
|
||||
}
|
||||
func (x *FOO) UnmarshalJSON(data []byte) error {
|
||||
value, err := proto.UnmarshalJSONEnum(FOO_value, data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*x = FOO(value)
|
||||
return nil
|
||||
return nil
|
||||
}
|
||||
|
||||
type Test_OptionalGroup struct {
|
||||
RequiredField *string `protobuf:"bytes,5,req" json:"RequiredField,omitempty"`
|
||||
}
|
||||
func (m *Test_OptionalGroup) Reset() { *m = Test_OptionalGroup{} }
|
||||
func (m *Test_OptionalGroup) String() string { return proto.CompactTextString(m) }
|
||||
|
||||
func (m *Test_OptionalGroup) GetRequiredField() string {
|
||||
if m != nil && m.RequiredField != nil {
|
||||
return *m.RequiredField
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type Test struct {
|
||||
Label *string `protobuf:"bytes,1,req,name=label" json:"label,omitempty"`
|
||||
Type *int32 `protobuf:"varint,2,opt,name=type,def=77" json:"type,omitempty"`
|
||||
Reps []int64 `protobuf:"varint,3,rep,name=reps" json:"reps,omitempty"`
|
||||
Optionalgroup *Test_OptionalGroup `protobuf:"group,4,opt,name=OptionalGroup" json:"optionalgroup,omitempty"`
|
||||
XXX_unrecognized []byte `json:"-"`
|
||||
func (m *Test) GetNumber() int32 {
|
||||
if x, ok := m.GetUnion().(*Test_Number); ok {
|
||||
return x.Number
|
||||
}
|
||||
func (m *Test) Reset() { *m = Test{} }
|
||||
func (m *Test) String() string { return proto.CompactTextString(m) }
|
||||
func (*Test) ProtoMessage() {}
|
||||
const Default_Test_Type int32 = 77
|
||||
return 0
|
||||
}
|
||||
|
||||
func (m *Test) GetLabel() string {
|
||||
if m != nil && m.Label != nil {
|
||||
return *m.Label
|
||||
}
|
||||
return ""
|
||||
func (m *Test) GetName() string {
|
||||
if x, ok := m.GetUnion().(*Test_Name); ok {
|
||||
return x.Name
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (m *Test) GetType() int32 {
|
||||
if m != nil && m.Type != nil {
|
||||
return *m.Type
|
||||
}
|
||||
return Default_Test_Type
|
||||
func init() {
|
||||
proto.RegisterEnum("example.FOO", FOO_name, FOO_value)
|
||||
}
|
||||
|
||||
To create and play with a Test object:
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
|
||||
"github.com/gogo/protobuf/proto"
|
||||
pb "./example.pb"
|
||||
)
|
||||
|
||||
func main() {
|
||||
test := &pb.Test{
|
||||
Label: proto.String("hello"),
|
||||
Type: proto.Int32(17),
|
||||
Optionalgroup: &pb.Test_OptionalGroup{
|
||||
RequiredField: proto.String("good bye"),
|
||||
},
|
||||
Union: &pb.Test_Name{"fred"},
|
||||
}
|
||||
|
||||
func (m *Test) GetOptionalgroup() *Test_OptionalGroup {
|
||||
if m != nil {
|
||||
return m.Optionalgroup
|
||||
}
|
||||
return nil
|
||||
data, err := proto.Marshal(test)
|
||||
if err != nil {
|
||||
log.Fatal("marshaling error: ", err)
|
||||
}
|
||||
|
||||
type Test_OptionalGroup struct {
|
||||
RequiredField *string `protobuf:"bytes,5,req" json:"RequiredField,omitempty"`
|
||||
newTest := &pb.Test{}
|
||||
err = proto.Unmarshal(data, newTest)
|
||||
if err != nil {
|
||||
log.Fatal("unmarshaling error: ", err)
|
||||
}
|
||||
func (m *Test_OptionalGroup) Reset() { *m = Test_OptionalGroup{} }
|
||||
func (m *Test_OptionalGroup) String() string { return proto.CompactTextString(m) }
|
||||
|
||||
func (m *Test_OptionalGroup) GetRequiredField() string {
|
||||
if m != nil && m.RequiredField != nil {
|
||||
return *m.RequiredField
|
||||
}
|
||||
return ""
|
||||
// Now test and newTest contain the same data.
|
||||
if test.GetLabel() != newTest.GetLabel() {
|
||||
log.Fatalf("data mismatch %q != %q", test.GetLabel(), newTest.GetLabel())
|
||||
}
|
||||
|
||||
func init() {
|
||||
proto.RegisterEnum("example.FOO", FOO_name, FOO_value)
|
||||
}
|
||||
|
||||
To create and play with a Test object:
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
|
||||
"github.com/gogo/protobuf/proto"
|
||||
pb "./example.pb"
|
||||
)
|
||||
|
||||
func main() {
|
||||
test := &pb.Test{
|
||||
Label: proto.String("hello"),
|
||||
Type: proto.Int32(17),
|
||||
Optionalgroup: &pb.Test_OptionalGroup{
|
||||
RequiredField: proto.String("good bye"),
|
||||
},
|
||||
}
|
||||
data, err := proto.Marshal(test)
|
||||
if err != nil {
|
||||
log.Fatal("marshaling error: ", err)
|
||||
}
|
||||
newTest := &pb.Test{}
|
||||
err = proto.Unmarshal(data, newTest)
|
||||
if err != nil {
|
||||
log.Fatal("unmarshaling error: ", err)
|
||||
}
|
||||
// Now test and newTest contain the same data.
|
||||
if test.GetLabel() != newTest.GetLabel() {
|
||||
log.Fatalf("data mismatch %q != %q", test.GetLabel(), newTest.GetLabel())
|
||||
}
|
||||
// etc.
|
||||
// Use a type switch to determine which oneof was set.
|
||||
switch u := test.Union.(type) {
|
||||
case *pb.Test_Number: // u.Number contains the number.
|
||||
case *pb.Test_Name: // u.Name contains the string.
|
||||
}
|
||||
// etc.
|
||||
}
|
||||
*/
|
||||
package proto
|
||||
|
||||
|
@ -211,6 +262,7 @@ import (
|
|||
"fmt"
|
||||
"log"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strconv"
|
||||
"sync"
|
||||
)
|
||||
|
@ -385,13 +437,13 @@ func UnmarshalJSONEnum(m map[string]int32, data []byte, enumName string) (int32,
|
|||
|
||||
// DebugPrint dumps the encoded data in b in a debugging format with a header
|
||||
// including the string s. Used in testing but made available for general debugging.
|
||||
func (o *Buffer) DebugPrint(s string, b []byte) {
|
||||
func (p *Buffer) DebugPrint(s string, b []byte) {
|
||||
var u uint64
|
||||
|
||||
obuf := o.buf
|
||||
index := o.index
|
||||
o.buf = b
|
||||
o.index = 0
|
||||
obuf := p.buf
|
||||
index := p.index
|
||||
p.buf = b
|
||||
p.index = 0
|
||||
depth := 0
|
||||
|
||||
fmt.Printf("\n--- %s ---\n", s)
|
||||
|
@ -402,12 +454,12 @@ out:
|
|||
fmt.Print(" ")
|
||||
}
|
||||
|
||||
index := o.index
|
||||
if index == len(o.buf) {
|
||||
index := p.index
|
||||
if index == len(p.buf) {
|
||||
break
|
||||
}
|
||||
|
||||
op, err := o.DecodeVarint()
|
||||
op, err := p.DecodeVarint()
|
||||
if err != nil {
|
||||
fmt.Printf("%3d: fetching op err %v\n", index, err)
|
||||
break out
|
||||
|
@ -424,7 +476,7 @@ out:
|
|||
case WireBytes:
|
||||
var r []byte
|
||||
|
||||
r, err = o.DecodeRawBytes(false)
|
||||
r, err = p.DecodeRawBytes(false)
|
||||
if err != nil {
|
||||
break out
|
||||
}
|
||||
|
@ -445,7 +497,7 @@ out:
|
|||
fmt.Printf("\n")
|
||||
|
||||
case WireFixed32:
|
||||
u, err = o.DecodeFixed32()
|
||||
u, err = p.DecodeFixed32()
|
||||
if err != nil {
|
||||
fmt.Printf("%3d: t=%3d fix32 err %v\n", index, tag, err)
|
||||
break out
|
||||
|
@ -453,16 +505,15 @@ out:
|
|||
fmt.Printf("%3d: t=%3d fix32 %d\n", index, tag, u)
|
||||
|
||||
case WireFixed64:
|
||||
u, err = o.DecodeFixed64()
|
||||
u, err = p.DecodeFixed64()
|
||||
if err != nil {
|
||||
fmt.Printf("%3d: t=%3d fix64 err %v\n", index, tag, err)
|
||||
break out
|
||||
}
|
||||
fmt.Printf("%3d: t=%3d fix64 %d\n", index, tag, u)
|
||||
break
|
||||
|
||||
case WireVarint:
|
||||
u, err = o.DecodeVarint()
|
||||
u, err = p.DecodeVarint()
|
||||
if err != nil {
|
||||
fmt.Printf("%3d: t=%3d varint err %v\n", index, tag, err)
|
||||
break out
|
||||
|
@ -470,30 +521,22 @@ out:
|
|||
fmt.Printf("%3d: t=%3d varint %d\n", index, tag, u)
|
||||
|
||||
case WireStartGroup:
|
||||
if err != nil {
|
||||
fmt.Printf("%3d: t=%3d start err %v\n", index, tag, err)
|
||||
break out
|
||||
}
|
||||
fmt.Printf("%3d: t=%3d start\n", index, tag)
|
||||
depth++
|
||||
|
||||
case WireEndGroup:
|
||||
depth--
|
||||
if err != nil {
|
||||
fmt.Printf("%3d: t=%3d end err %v\n", index, tag, err)
|
||||
break out
|
||||
}
|
||||
fmt.Printf("%3d: t=%3d end\n", index, tag)
|
||||
}
|
||||
}
|
||||
|
||||
if depth != 0 {
|
||||
fmt.Printf("%3d: start-end not balanced %d\n", o.index, depth)
|
||||
fmt.Printf("%3d: start-end not balanced %d\n", p.index, depth)
|
||||
}
|
||||
fmt.Printf("\n")
|
||||
|
||||
o.buf = obuf
|
||||
o.index = index
|
||||
p.buf = obuf
|
||||
p.index = index
|
||||
}
|
||||
|
||||
// SetDefaults sets unset protocol buffer fields to their default values.
|
||||
|
@ -668,123 +711,173 @@ func buildDefaultMessage(t reflect.Type) (dm defaultMessage) {
|
|||
}
|
||||
ft := t.Field(fi).Type
|
||||
|
||||
var canHaveDefault, nestedMessage bool
|
||||
switch ft.Kind() {
|
||||
case reflect.Ptr:
|
||||
if ft.Elem().Kind() == reflect.Struct {
|
||||
nestedMessage = true
|
||||
} else {
|
||||
canHaveDefault = true // proto2 scalar field
|
||||
}
|
||||
|
||||
case reflect.Slice:
|
||||
switch ft.Elem().Kind() {
|
||||
case reflect.Ptr:
|
||||
nestedMessage = true // repeated message
|
||||
case reflect.Uint8:
|
||||
canHaveDefault = true // bytes field
|
||||
}
|
||||
|
||||
case reflect.Map:
|
||||
if ft.Elem().Kind() == reflect.Ptr {
|
||||
nestedMessage = true // map with message values
|
||||
}
|
||||
sf, nested, err := fieldDefault(ft, prop)
|
||||
switch {
|
||||
case err != nil:
|
||||
log.Print(err)
|
||||
case nested:
|
||||
dm.nested = append(dm.nested, fi)
|
||||
case sf != nil:
|
||||
sf.index = fi
|
||||
dm.scalars = append(dm.scalars, *sf)
|
||||
}
|
||||
|
||||
if !canHaveDefault {
|
||||
if nestedMessage {
|
||||
dm.nested = append(dm.nested, fi)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
sf := scalarField{
|
||||
index: fi,
|
||||
kind: ft.Elem().Kind(),
|
||||
}
|
||||
|
||||
// scalar fields without defaults
|
||||
if !prop.HasDefault {
|
||||
dm.scalars = append(dm.scalars, sf)
|
||||
continue
|
||||
}
|
||||
|
||||
// a scalar field: either *T or []byte
|
||||
switch ft.Elem().Kind() {
|
||||
case reflect.Bool:
|
||||
x, err := strconv.ParseBool(prop.Default)
|
||||
if err != nil {
|
||||
log.Printf("proto: bad default bool %q: %v", prop.Default, err)
|
||||
continue
|
||||
}
|
||||
sf.value = x
|
||||
case reflect.Float32:
|
||||
x, err := strconv.ParseFloat(prop.Default, 32)
|
||||
if err != nil {
|
||||
log.Printf("proto: bad default float32 %q: %v", prop.Default, err)
|
||||
continue
|
||||
}
|
||||
sf.value = float32(x)
|
||||
case reflect.Float64:
|
||||
x, err := strconv.ParseFloat(prop.Default, 64)
|
||||
if err != nil {
|
||||
log.Printf("proto: bad default float64 %q: %v", prop.Default, err)
|
||||
continue
|
||||
}
|
||||
sf.value = x
|
||||
case reflect.Int32:
|
||||
x, err := strconv.ParseInt(prop.Default, 10, 32)
|
||||
if err != nil {
|
||||
log.Printf("proto: bad default int32 %q: %v", prop.Default, err)
|
||||
continue
|
||||
}
|
||||
sf.value = int32(x)
|
||||
case reflect.Int64:
|
||||
x, err := strconv.ParseInt(prop.Default, 10, 64)
|
||||
if err != nil {
|
||||
log.Printf("proto: bad default int64 %q: %v", prop.Default, err)
|
||||
continue
|
||||
}
|
||||
sf.value = x
|
||||
case reflect.String:
|
||||
sf.value = prop.Default
|
||||
case reflect.Uint8:
|
||||
// []byte (not *uint8)
|
||||
sf.value = []byte(prop.Default)
|
||||
case reflect.Uint32:
|
||||
x, err := strconv.ParseUint(prop.Default, 10, 32)
|
||||
if err != nil {
|
||||
log.Printf("proto: bad default uint32 %q: %v", prop.Default, err)
|
||||
continue
|
||||
}
|
||||
sf.value = uint32(x)
|
||||
case reflect.Uint64:
|
||||
x, err := strconv.ParseUint(prop.Default, 10, 64)
|
||||
if err != nil {
|
||||
log.Printf("proto: bad default uint64 %q: %v", prop.Default, err)
|
||||
continue
|
||||
}
|
||||
sf.value = x
|
||||
default:
|
||||
log.Printf("proto: unhandled def kind %v", ft.Elem().Kind())
|
||||
continue
|
||||
}
|
||||
|
||||
dm.scalars = append(dm.scalars, sf)
|
||||
}
|
||||
|
||||
return dm
|
||||
}
|
||||
|
||||
// fieldDefault returns the scalarField for field type ft.
|
||||
// sf will be nil if the field can not have a default.
|
||||
// nestedMessage will be true if this is a nested message.
|
||||
// Note that sf.index is not set on return.
|
||||
func fieldDefault(ft reflect.Type, prop *Properties) (sf *scalarField, nestedMessage bool, err error) {
|
||||
var canHaveDefault bool
|
||||
switch ft.Kind() {
|
||||
case reflect.Ptr:
|
||||
if ft.Elem().Kind() == reflect.Struct {
|
||||
nestedMessage = true
|
||||
} else {
|
||||
canHaveDefault = true // proto2 scalar field
|
||||
}
|
||||
|
||||
case reflect.Slice:
|
||||
switch ft.Elem().Kind() {
|
||||
case reflect.Ptr:
|
||||
nestedMessage = true // repeated message
|
||||
case reflect.Uint8:
|
||||
canHaveDefault = true // bytes field
|
||||
}
|
||||
|
||||
case reflect.Map:
|
||||
if ft.Elem().Kind() == reflect.Ptr {
|
||||
nestedMessage = true // map with message values
|
||||
}
|
||||
}
|
||||
|
||||
if !canHaveDefault {
|
||||
if nestedMessage {
|
||||
return nil, true, nil
|
||||
}
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
// We now know that ft is a pointer or slice.
|
||||
sf = &scalarField{kind: ft.Elem().Kind()}
|
||||
|
||||
// scalar fields without defaults
|
||||
if !prop.HasDefault {
|
||||
return sf, false, nil
|
||||
}
|
||||
|
||||
// a scalar field: either *T or []byte
|
||||
switch ft.Elem().Kind() {
|
||||
case reflect.Bool:
|
||||
x, err := strconv.ParseBool(prop.Default)
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("proto: bad default bool %q: %v", prop.Default, err)
|
||||
}
|
||||
sf.value = x
|
||||
case reflect.Float32:
|
||||
x, err := strconv.ParseFloat(prop.Default, 32)
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("proto: bad default float32 %q: %v", prop.Default, err)
|
||||
}
|
||||
sf.value = float32(x)
|
||||
case reflect.Float64:
|
||||
x, err := strconv.ParseFloat(prop.Default, 64)
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("proto: bad default float64 %q: %v", prop.Default, err)
|
||||
}
|
||||
sf.value = x
|
||||
case reflect.Int32:
|
||||
x, err := strconv.ParseInt(prop.Default, 10, 32)
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("proto: bad default int32 %q: %v", prop.Default, err)
|
||||
}
|
||||
sf.value = int32(x)
|
||||
case reflect.Int64:
|
||||
x, err := strconv.ParseInt(prop.Default, 10, 64)
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("proto: bad default int64 %q: %v", prop.Default, err)
|
||||
}
|
||||
sf.value = x
|
||||
case reflect.String:
|
||||
sf.value = prop.Default
|
||||
case reflect.Uint8:
|
||||
// []byte (not *uint8)
|
||||
sf.value = []byte(prop.Default)
|
||||
case reflect.Uint32:
|
||||
x, err := strconv.ParseUint(prop.Default, 10, 32)
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("proto: bad default uint32 %q: %v", prop.Default, err)
|
||||
}
|
||||
sf.value = uint32(x)
|
||||
case reflect.Uint64:
|
||||
x, err := strconv.ParseUint(prop.Default, 10, 64)
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("proto: bad default uint64 %q: %v", prop.Default, err)
|
||||
}
|
||||
sf.value = x
|
||||
default:
|
||||
return nil, false, fmt.Errorf("proto: unhandled def kind %v", ft.Elem().Kind())
|
||||
}
|
||||
|
||||
return sf, false, nil
|
||||
}
|
||||
|
||||
// Map fields may have key types of non-float scalars, strings and enums.
|
||||
// The easiest way to sort them in some deterministic order is to use fmt.
|
||||
// If this turns out to be inefficient we can always consider other options,
|
||||
// such as doing a Schwartzian transform.
|
||||
|
||||
type mapKeys []reflect.Value
|
||||
func mapKeys(vs []reflect.Value) sort.Interface {
|
||||
s := mapKeySorter{
|
||||
vs: vs,
|
||||
// default Less function: textual comparison
|
||||
less: func(a, b reflect.Value) bool {
|
||||
return fmt.Sprint(a.Interface()) < fmt.Sprint(b.Interface())
|
||||
},
|
||||
}
|
||||
|
||||
func (s mapKeys) Len() int { return len(s) }
|
||||
func (s mapKeys) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
|
||||
func (s mapKeys) Less(i, j int) bool {
|
||||
return fmt.Sprint(s[i].Interface()) < fmt.Sprint(s[j].Interface())
|
||||
// Type specialization per https://developers.google.com/protocol-buffers/docs/proto#maps;
|
||||
// numeric keys are sorted numerically.
|
||||
if len(vs) == 0 {
|
||||
return s
|
||||
}
|
||||
switch vs[0].Kind() {
|
||||
case reflect.Int32, reflect.Int64:
|
||||
s.less = func(a, b reflect.Value) bool { return a.Int() < b.Int() }
|
||||
case reflect.Uint32, reflect.Uint64:
|
||||
s.less = func(a, b reflect.Value) bool { return a.Uint() < b.Uint() }
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
type mapKeySorter struct {
|
||||
vs []reflect.Value
|
||||
less func(a, b reflect.Value) bool
|
||||
}
|
||||
|
||||
func (s mapKeySorter) Len() int { return len(s.vs) }
|
||||
func (s mapKeySorter) Swap(i, j int) { s.vs[i], s.vs[j] = s.vs[j], s.vs[i] }
|
||||
func (s mapKeySorter) Less(i, j int) bool {
|
||||
return s.less(s.vs[i], s.vs[j])
|
||||
}
|
||||
|
||||
// isProto3Zero reports whether v is a zero proto3 value.
|
||||
func isProto3Zero(v reflect.Value) bool {
|
||||
switch v.Kind() {
|
||||
case reflect.Bool:
|
||||
return !v.Bool()
|
||||
case reflect.Int32, reflect.Int64:
|
||||
return v.Int() == 0
|
||||
case reflect.Uint32, reflect.Uint64:
|
||||
return v.Uint() == 0
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return v.Float() == 0
|
||||
case reflect.String:
|
||||
return v.String() == ""
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
|
|
@ -44,11 +44,11 @@ import (
|
|||
"sort"
|
||||
)
|
||||
|
||||
// ErrNoMessageTypeId occurs when a protocol buffer does not have a message type ID.
|
||||
// errNoMessageTypeID occurs when a protocol buffer does not have a message type ID.
|
||||
// A message type ID is required for storing a protocol buffer in a message set.
|
||||
var ErrNoMessageTypeId = errors.New("proto does not have a message type ID")
|
||||
var errNoMessageTypeID = errors.New("proto does not have a message type ID")
|
||||
|
||||
// The first two types (_MessageSet_Item and MessageSet)
|
||||
// The first two types (_MessageSet_Item and messageSet)
|
||||
// model what the protocol compiler produces for the following protocol message:
|
||||
// message MessageSet {
|
||||
// repeated group Item = 1 {
|
||||
|
@ -58,27 +58,20 @@ var ErrNoMessageTypeId = errors.New("proto does not have a message type ID")
|
|||
// }
|
||||
// That is the MessageSet wire format. We can't use a proto to generate these
|
||||
// because that would introduce a circular dependency between it and this package.
|
||||
//
|
||||
// When a proto1 proto has a field that looks like:
|
||||
// optional message<MessageSet> info = 3;
|
||||
// the protocol compiler produces a field in the generated struct that looks like:
|
||||
// Info *_proto_.MessageSet `protobuf:"bytes,3,opt,name=info"`
|
||||
// The package is automatically inserted so there is no need for that proto file to
|
||||
// import this package.
|
||||
|
||||
type _MessageSet_Item struct {
|
||||
TypeId *int32 `protobuf:"varint,2,req,name=type_id"`
|
||||
Message []byte `protobuf:"bytes,3,req,name=message"`
|
||||
}
|
||||
|
||||
type MessageSet struct {
|
||||
type messageSet struct {
|
||||
Item []*_MessageSet_Item `protobuf:"group,1,rep"`
|
||||
XXX_unrecognized []byte
|
||||
// TODO: caching?
|
||||
}
|
||||
|
||||
// Make sure MessageSet is a Message.
|
||||
var _ Message = (*MessageSet)(nil)
|
||||
// Make sure messageSet is a Message.
|
||||
var _ Message = (*messageSet)(nil)
|
||||
|
||||
// messageTypeIder is an interface satisfied by a protocol buffer type
|
||||
// that may be stored in a MessageSet.
|
||||
|
@ -86,7 +79,7 @@ type messageTypeIder interface {
|
|||
MessageTypeId() int32
|
||||
}
|
||||
|
||||
func (ms *MessageSet) find(pb Message) *_MessageSet_Item {
|
||||
func (ms *messageSet) find(pb Message) *_MessageSet_Item {
|
||||
mti, ok := pb.(messageTypeIder)
|
||||
if !ok {
|
||||
return nil
|
||||
|
@ -100,24 +93,24 @@ func (ms *MessageSet) find(pb Message) *_MessageSet_Item {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (ms *MessageSet) Has(pb Message) bool {
|
||||
func (ms *messageSet) Has(pb Message) bool {
|
||||
if ms.find(pb) != nil {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (ms *MessageSet) Unmarshal(pb Message) error {
|
||||
func (ms *messageSet) Unmarshal(pb Message) error {
|
||||
if item := ms.find(pb); item != nil {
|
||||
return Unmarshal(item.Message, pb)
|
||||
}
|
||||
if _, ok := pb.(messageTypeIder); !ok {
|
||||
return ErrNoMessageTypeId
|
||||
return errNoMessageTypeID
|
||||
}
|
||||
return nil // TODO: return error instead?
|
||||
}
|
||||
|
||||
func (ms *MessageSet) Marshal(pb Message) error {
|
||||
func (ms *messageSet) Marshal(pb Message) error {
|
||||
msg, err := Marshal(pb)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -130,7 +123,7 @@ func (ms *MessageSet) Marshal(pb Message) error {
|
|||
|
||||
mti, ok := pb.(messageTypeIder)
|
||||
if !ok {
|
||||
return ErrNoMessageTypeId
|
||||
return errNoMessageTypeID
|
||||
}
|
||||
|
||||
mtid := mti.MessageTypeId()
|
||||
|
@ -141,9 +134,9 @@ func (ms *MessageSet) Marshal(pb Message) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (ms *MessageSet) Reset() { *ms = MessageSet{} }
|
||||
func (ms *MessageSet) String() string { return CompactTextString(ms) }
|
||||
func (*MessageSet) ProtoMessage() {}
|
||||
func (ms *messageSet) Reset() { *ms = messageSet{} }
|
||||
func (ms *messageSet) String() string { return CompactTextString(ms) }
|
||||
func (*messageSet) ProtoMessage() {}
|
||||
|
||||
// Support for the message_set_wire_format message option.
|
||||
|
||||
|
@ -169,7 +162,7 @@ func MarshalMessageSet(m map[int32]Extension) ([]byte, error) {
|
|||
}
|
||||
sort.Ints(ids)
|
||||
|
||||
ms := &MessageSet{Item: make([]*_MessageSet_Item, 0, len(m))}
|
||||
ms := &messageSet{Item: make([]*_MessageSet_Item, 0, len(m))}
|
||||
for _, id := range ids {
|
||||
e := m[int32(id)]
|
||||
// Remove the wire type and field number varint, as well as the length varint.
|
||||
|
@ -186,7 +179,7 @@ func MarshalMessageSet(m map[int32]Extension) ([]byte, error) {
|
|||
// UnmarshalMessageSet decodes the extension map encoded in buf in the message set wire format.
|
||||
// It is called by generated Unmarshal methods on protocol buffer messages with the message_set_wire_format option.
|
||||
func UnmarshalMessageSet(buf []byte, m map[int32]Extension) error {
|
||||
ms := new(MessageSet)
|
||||
ms := new(messageSet)
|
||||
if err := Unmarshal(buf, ms); err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -38,7 +38,7 @@ import (
|
|||
|
||||
func TestUnmarshalMessageSetWithDuplicate(t *testing.T) {
|
||||
// Check that a repeated message set entry will be concatenated.
|
||||
in := &MessageSet{
|
||||
in := &messageSet{
|
||||
Item: []*_MessageSet_Item{
|
||||
{TypeId: Int32(12345), Message: []byte("hoo")},
|
||||
{TypeId: Int32(12345), Message: []byte("hah")},
|
||||
|
|
|
@ -144,8 +144,8 @@ func structPointer_ExtMap(p structPointer, f field) *map[int32]Extension {
|
|||
return structPointer_ifield(p, f).(*map[int32]Extension)
|
||||
}
|
||||
|
||||
// Map returns the reflect.Value for the address of a map field in the struct.
|
||||
func structPointer_Map(p structPointer, f field, typ reflect.Type) reflect.Value {
|
||||
// NewAt returns the reflect.Value for a pointer to a field in the struct.
|
||||
func structPointer_NewAt(p structPointer, f field, typ reflect.Type) reflect.Value {
|
||||
return structPointer_field(p, f).Addr()
|
||||
}
|
||||
|
||||
|
|
|
@ -130,8 +130,8 @@ func structPointer_ExtMap(p structPointer, f field) *map[int32]Extension {
|
|||
return (*map[int32]Extension)(unsafe.Pointer(uintptr(p) + uintptr(f)))
|
||||
}
|
||||
|
||||
// Map returns the reflect.Value for the address of a map field in the struct.
|
||||
func structPointer_Map(p structPointer, f field, typ reflect.Type) reflect.Value {
|
||||
// NewAt returns the reflect.Value for a pointer to a field in the struct.
|
||||
func structPointer_NewAt(p structPointer, f field, typ reflect.Type) reflect.Value {
|
||||
return reflect.NewAt(typ, unsafe.Pointer(uintptr(p)+uintptr(f)))
|
||||
}
|
||||
|
||||
|
|
|
@ -42,6 +42,7 @@ package proto
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"reflect"
|
||||
"sort"
|
||||
|
@ -89,6 +90,12 @@ type decoder func(p *Buffer, prop *Properties, base structPointer) error
|
|||
// A valueDecoder decodes a single integer in a particular encoding.
|
||||
type valueDecoder func(o *Buffer) (x uint64, err error)
|
||||
|
||||
// A oneofMarshaler does the marshaling for all oneof fields in a message.
|
||||
type oneofMarshaler func(Message, *Buffer) error
|
||||
|
||||
// A oneofUnmarshaler does the unmarshaling for a oneof field in a message.
|
||||
type oneofUnmarshaler func(Message, int, int, *Buffer) (bool, error)
|
||||
|
||||
// tagMap is an optimization over map[int]int for typical protocol buffer
|
||||
// use-cases. Encoded protocol buffers are often in tag order with small tag
|
||||
// numbers.
|
||||
|
@ -137,6 +144,21 @@ type StructProperties struct {
|
|||
order []int // list of struct field numbers in tag order
|
||||
unrecField field // field id of the XXX_unrecognized []byte field
|
||||
extendable bool // is this an extendable proto
|
||||
|
||||
oneofMarshaler oneofMarshaler
|
||||
oneofUnmarshaler oneofUnmarshaler
|
||||
stype reflect.Type
|
||||
|
||||
// OneofTypes contains information about the oneof fields in this message.
|
||||
// It is keyed by the original name of a field.
|
||||
OneofTypes map[string]*OneofProperties
|
||||
}
|
||||
|
||||
// OneofProperties represents information about a specific field in a oneof.
|
||||
type OneofProperties struct {
|
||||
Type reflect.Type // pointer to generated struct type for this oneof field
|
||||
Field int // struct field number of the containing oneof in the message
|
||||
Prop *Properties
|
||||
}
|
||||
|
||||
// Implement the sorting interface so we can sort the fields in tag order, as recommended by the spec.
|
||||
|
@ -161,6 +183,7 @@ type Properties struct {
|
|||
Packed bool // relevant for repeated primitives only
|
||||
Enum string // set for enum types only
|
||||
proto3 bool // whether this is known to be a proto3 field; set for []byte only
|
||||
oneof bool // whether this is a oneof field
|
||||
|
||||
Default string // default value
|
||||
HasDefault bool // whether an explicit default was provided
|
||||
|
@ -216,6 +239,9 @@ func (p *Properties) String() string {
|
|||
if p.proto3 {
|
||||
s += ",proto3"
|
||||
}
|
||||
if p.oneof {
|
||||
s += ",oneof"
|
||||
}
|
||||
if len(p.Enum) > 0 {
|
||||
s += ",enum=" + p.Enum
|
||||
}
|
||||
|
@ -292,6 +318,8 @@ func (p *Properties) Parse(s string) {
|
|||
p.Enum = f[5:]
|
||||
case f == "proto3":
|
||||
p.proto3 = true
|
||||
case f == "oneof":
|
||||
p.oneof = true
|
||||
case strings.HasPrefix(f, "def="):
|
||||
p.HasDefault = true
|
||||
p.Default = f[4:] // rest of string
|
||||
|
@ -713,6 +741,7 @@ func getPropertiesLocked(t reflect.Type) *StructProperties {
|
|||
prop.Prop = make([]*Properties, t.NumField())
|
||||
prop.order = make([]int, t.NumField())
|
||||
|
||||
isOneofMessage := false
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
f := t.Field(i)
|
||||
p := new(Properties)
|
||||
|
@ -733,6 +762,10 @@ func getPropertiesLocked(t reflect.Type) *StructProperties {
|
|||
if f.Name == "XXX_unrecognized" { // special case
|
||||
prop.unrecField = toField(&f)
|
||||
}
|
||||
oneof := f.Tag.Get("protobuf_oneof") != "" // special case
|
||||
if oneof {
|
||||
isOneofMessage = true
|
||||
}
|
||||
prop.Prop[i] = p
|
||||
prop.order[i] = i
|
||||
if debug {
|
||||
|
@ -742,7 +775,7 @@ func getPropertiesLocked(t reflect.Type) *StructProperties {
|
|||
}
|
||||
print("\n")
|
||||
}
|
||||
if p.enc == nil && !strings.HasPrefix(f.Name, "XXX_") {
|
||||
if p.enc == nil && !strings.HasPrefix(f.Name, "XXX_") && !oneof {
|
||||
fmt.Fprintln(os.Stderr, "proto: no encoder for", f.Name, f.Type.String(), "[GetProperties]")
|
||||
}
|
||||
}
|
||||
|
@ -750,6 +783,41 @@ func getPropertiesLocked(t reflect.Type) *StructProperties {
|
|||
// Re-order prop.order.
|
||||
sort.Sort(prop)
|
||||
|
||||
type oneofMessage interface {
|
||||
XXX_OneofFuncs() (func(Message, *Buffer) error, func(Message, int, int, *Buffer) (bool, error), []interface{})
|
||||
}
|
||||
if om, ok := reflect.Zero(reflect.PtrTo(t)).Interface().(oneofMessage); isOneofMessage && ok {
|
||||
var oots []interface{}
|
||||
prop.oneofMarshaler, prop.oneofUnmarshaler, oots = om.XXX_OneofFuncs()
|
||||
prop.stype = t
|
||||
|
||||
// Interpret oneof metadata.
|
||||
prop.OneofTypes = make(map[string]*OneofProperties)
|
||||
for _, oot := range oots {
|
||||
oop := &OneofProperties{
|
||||
Type: reflect.ValueOf(oot).Type(), // *T
|
||||
Prop: new(Properties),
|
||||
}
|
||||
sft := oop.Type.Elem().Field(0)
|
||||
oop.Prop.Name = sft.Name
|
||||
oop.Prop.Parse(sft.Tag.Get("protobuf"))
|
||||
// There will be exactly one interface field that
|
||||
// this new value is assignable to.
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
f := t.Field(i)
|
||||
if f.Type.Kind() != reflect.Interface {
|
||||
continue
|
||||
}
|
||||
if !oop.Type.AssignableTo(f.Type) {
|
||||
continue
|
||||
}
|
||||
oop.Field = i
|
||||
break
|
||||
}
|
||||
prop.OneofTypes[oop.Prop.OrigName] = oop
|
||||
}
|
||||
}
|
||||
|
||||
// build required counts
|
||||
// build tags
|
||||
reqCount := 0
|
||||
|
@ -813,3 +881,35 @@ func RegisterEnum(typeName string, unusedNameMap map[int32]string, valueMap map[
|
|||
}
|
||||
enumStringMaps[typeName] = unusedNameMap
|
||||
}
|
||||
|
||||
// EnumValueMap returns the mapping from names to integers of the
|
||||
// enum type enumType, or a nil if not found.
|
||||
func EnumValueMap(enumType string) map[string]int32 {
|
||||
return enumValueMaps[enumType]
|
||||
}
|
||||
|
||||
// A registry of all linked message types.
|
||||
// The string is a fully-qualified proto name ("pkg.Message").
|
||||
var (
|
||||
protoTypes = make(map[string]reflect.Type)
|
||||
revProtoTypes = make(map[reflect.Type]string)
|
||||
)
|
||||
|
||||
// RegisterType is called from generated code and maps from the fully qualified
|
||||
// proto name to the type (pointer to struct) of the protocol buffer.
|
||||
func RegisterType(x Message, name string) {
|
||||
if _, ok := protoTypes[name]; ok {
|
||||
// TODO: Some day, make this a panic.
|
||||
log.Printf("proto: duplicate proto type registered: %s", name)
|
||||
return
|
||||
}
|
||||
t := reflect.TypeOf(x)
|
||||
protoTypes[name] = t
|
||||
revProtoTypes[t] = name
|
||||
}
|
||||
|
||||
// MessageName returns the fully-qualified proto name for the given message type.
|
||||
func MessageName(x Message) string { return revProtoTypes[reflect.TypeOf(x)] }
|
||||
|
||||
// MessageType returns the message type (pointer to struct) for a named message.
|
||||
func MessageType(name string) reflect.Type { return protoTypes[name] }
|
||||
|
|
|
@ -16,10 +16,14 @@ It has these top-level messages:
|
|||
package proto3_proto
|
||||
|
||||
import proto "github.com/coreos/etcd/Godeps/_workspace/src/github.com/gogo/protobuf/proto"
|
||||
import fmt "fmt"
|
||||
import math "math"
|
||||
import testdata "github.com/coreos/etcd/Godeps/_workspace/src/github.com/gogo/protobuf/proto/testdata"
|
||||
|
||||
// Reference imports to suppress errors if they are not otherwise used.
|
||||
var _ = proto.Marshal
|
||||
var _ = fmt.Errorf
|
||||
var _ = math.Inf
|
||||
|
||||
type Message_Humour int32
|
||||
|
||||
|
@ -118,5 +122,8 @@ func (m *MessageWithMap) GetByteMapping() map[bool][]byte {
|
|||
}
|
||||
|
||||
func init() {
|
||||
proto.RegisterType((*Message)(nil), "proto3_proto.Message")
|
||||
proto.RegisterType((*Nested)(nil), "proto3_proto.Nested")
|
||||
proto.RegisterType((*MessageWithMap)(nil), "proto3_proto.MessageWithMap")
|
||||
proto.RegisterEnum("proto3_proto.Message_Humour", Message_Humour_name, Message_Humour_value)
|
||||
}
|
||||
|
|
|
@ -124,6 +124,11 @@ var SizeTests = []struct {
|
|||
{"map field with big entry", &pb.MessageWithMap{NameMapping: map[int32]string{8: strings.Repeat("x", 125)}}},
|
||||
{"map field with big key and val", &pb.MessageWithMap{StrToStr: map[string]string{strings.Repeat("x", 70): strings.Repeat("y", 70)}}},
|
||||
{"map field with big numeric key", &pb.MessageWithMap{NameMapping: map[int32]string{0xf00d: "om nom nom"}}},
|
||||
|
||||
{"oneof not set", &pb.Communique{}},
|
||||
{"oneof zero int32", &pb.Communique{Union: &pb.Communique_Number{Number: 0}}},
|
||||
{"oneof int32", &pb.Communique{Union: &pb.Communique_Number{Number: 3}}},
|
||||
{"oneof string", &pb.Communique{Union: &pb.Communique_Name{Name: "Rhythmic Fman"}}},
|
||||
}
|
||||
|
||||
func TestSize(t *testing.T) {
|
||||
|
|
|
@ -32,6 +32,6 @@
|
|||
all: regenerate
|
||||
|
||||
regenerate:
|
||||
go install github.com/gogo/protobuf/protoc-gen-gogo/version/protoc-min-version
|
||||
go install github.com/gogo/protobuf/protoc-min-version
|
||||
protoc-min-version --version="3.0.0" --gogo_out=. test.proto
|
||||
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue