Compare commits

...

10 Commits

Author SHA1 Message Date
Gyu-Ho Lee 9f5189cd00 tools/v3benchmark: fix import pb path 2016-02-12 11:13:20 -08:00
Gyu-Ho Lee 6bc9b6e4ed *: bump to v2.2.5+git 2016-02-01 13:21:47 -08:00
Gyu-Ho Lee bc9ddf2601 *: bump to v2.2.5 2016-02-01 11:59:45 -08:00
Gyu-Ho Lee e7285f5626 *: updates for gRPC Godeps update 2016-02-01 11:28:39 -08:00
Gyu-Ho Lee 4613a7e61b Godeps: update gRPC 2016-02-01 11:28:29 -08:00
Xiang Li 31d1fa20bf Godeps: update boltdb 2016-01-28 10:48:22 -08:00
Hitoshi Mitake 259f89d59a etcdserver, auth: not cache a flag of auth status
This commit removes a flag that indicates auth is enabled or disabled
because it doesn't have an invalidation mechanism.

Fixes https://github.com/coreos/etcd/issues/3601 and https://github.com/coreos/etcd/issues/3964

Conflicts (Resolved):
	etcdserver/auth/auth.go
2016-01-27 15:36:16 -08:00
Gyu-Ho Lee 756d701f13 client: do not timeout when wait is true
Current V2 watch waits by encoding URL with wait=true.
When a client sets 'no-sync', it requests directly to
proxy and the proxy redirects it by cloning the request
object, which leads to cancel the original request when
it times out and the cloned request gets closed prematurely.

This fixes coreos#3894 by querying
the original client request in order to not use context timeout
when 'wait=true'.
2016-01-27 15:29:56 -08:00
Xiang Li d14a673ace etcdmain: proxy should only lookup srv if there is no existing cluster file 2016-01-27 15:29:54 -08:00
Gyu-Ho Lee 4abb231505 *: bump to v2.2.4+git 2016-01-13 14:22:16 -08:00
526 changed files with 55669 additions and 13351 deletions

94
Godeps/Godeps.json generated
View File

@ -1,6 +1,6 @@
{
"ImportPath": "github.com/coreos/etcd",
"GoVersion": "go1.4.2",
"GoVersion": "go1.5.1",
"Packages": [
"./..."
],
@ -10,6 +10,10 @@
"Comment": "null-5",
"Rev": "75cd24fc2f2c2a2088577d12123ddee5f54e0675"
},
{
"ImportPath": "github.com/akrennmair/gopcap",
"Rev": "00e11033259acb75598ba416495bb708d864a010"
},
{
"ImportPath": "github.com/beorn7/perks/quantile",
"Rev": "b965b613227fddccbfffe13eae360ed3fa822f8d"
@ -20,17 +24,21 @@
},
{
"ImportPath": "github.com/boltdb/bolt",
"Comment": "v1.0-119-g90fef38",
"Rev": "90fef389f98027ca55594edd7dbd6e7f3926fdad"
"Comment": "v1.1.0-19-g0b00eff",
"Rev": "0b00effdd7a8270ebd91c24297e51643e370dd52"
},
{
"ImportPath": "github.com/bradfitz/http2",
"Rev": "3e36af6d3af0e56fa3da71099f864933dea3d9fb"
"ImportPath": "github.com/cheggaaa/pb",
"Rev": "da1f27ad1d9509b16f65f52fd9d8138b0f2dc7b2"
},
{
"ImportPath": "github.com/codegangsta/cli",
"Comment": "1.2.0-26-gf7ebb76",
"Rev": "f7ebb761e83e21225d1d8954fde853bf8edd46c4"
"Comment": "1.2.0-183-gb5232bb",
"Rev": "b5232bb2934f606f9f27a1305f1eea224e8e8b88"
},
{
"ImportPath": "github.com/coreos/gexpect",
"Rev": "5173270e159f5aa8fbc999dc7e3dcb50f4098a69"
},
{
"ImportPath": "github.com/coreos/go-semver/semver",
@ -55,9 +63,15 @@
"ImportPath": "github.com/coreos/pkg/capnslog",
"Rev": "2c77715c4df99b5420ffcae14ead08f52104065d"
},
{
"ImportPath": "github.com/cpuguy83/go-md2man/md2man",
"Comment": "v1.0.4",
"Rev": "71acacd42f85e5e82f70a55327789582a5200a90"
},
{
"ImportPath": "github.com/gogo/protobuf/proto",
"Rev": "64f27bf06efee53589314a6e5a4af34cdd85adf6"
"Comment": "v0.1-118-ge8904f5",
"Rev": "e8904f58e872a473a5b91bc9bf3377d223555263"
},
{
"ImportPath": "github.com/golang/glog",
@ -65,20 +79,37 @@
},
{
"ImportPath": "github.com/golang/protobuf/proto",
"Rev": "5677a0e3d5e89854c9974e1256839ee23f8233ca"
"Rev": "6aaa8d47701fa6cf07e914ec01fde3d4a1fe79c3"
},
{
"ImportPath": "github.com/google/btree",
"Rev": "cc6329d4279e3f025a53a83c397d2339b5705c45"
},
{
"ImportPath": "github.com/inconshreveable/mousetrap",
"Rev": "76626ae9c91c4f2a10f34cad8ce83ea42c93bb75"
},
{
"ImportPath": "github.com/jonboulle/clockwork",
"Rev": "72f9bd7c4e0c2a40055ab3d0f09654f730cce982"
},
{
"ImportPath": "github.com/kballard/go-shellquote",
"Rev": "d8ec1a69a250a17bb0e419c386eac1f3711dc142"
},
{
"ImportPath": "github.com/kr/pty",
"Comment": "release.r56-29-gf7ee69f",
"Rev": "f7ee69f31298ecbe5d2b349c711e2547a617d398"
},
{
"ImportPath": "github.com/matttproud/golang_protobuf_extensions/pbutil",
"Rev": "fc2b8d3a73c4867e51861bbdd5ae3c1f0869dd6a"
},
{
"ImportPath": "github.com/olekukonko/ts",
"Rev": "ecf753e7c962639ab5a1fb46f7da627d4c0a04b8"
},
{
"ImportPath": "github.com/prometheus/client_golang/prometheus",
"Comment": "0.7.0-52-ge51041b",
@ -102,8 +133,25 @@
"Rev": "454a56f35412459b5e684fd5ec0f9211b94f002a"
},
{
"ImportPath": "github.com/rakyll/pb",
"Rev": "dc507ad06b7462501281bb4691ee43f0b1d1ec37"
"ImportPath": "github.com/russross/blackfriday",
"Comment": "v1.4-2-g300106c",
"Rev": "300106c228d52c8941d4b3de6054a6062a86dda3"
},
{
"ImportPath": "github.com/shurcooL/sanitized_anchor_name",
"Rev": "10ef21a441db47d8b13ebcc5fd2310f636973c77"
},
{
"ImportPath": "github.com/spacejam/loghisto",
"Rev": "323309774dec8b7430187e46cd0793974ccca04a"
},
{
"ImportPath": "github.com/spf13/cobra",
"Rev": "1c44ec8d3f1552cac48999f9306da23c4d8a288b"
},
{
"ImportPath": "github.com/spf13/pflag",
"Rev": "08b1a584251b5b62f458943640fc8ebd4d50aaa5"
},
{
"ImportPath": "github.com/stretchr/testify/assert",
@ -127,27 +175,27 @@
},
{
"ImportPath": "golang.org/x/net/context",
"Rev": "7dbad50ab5b31073856416cdcfeb2796d682f844"
"Rev": "04b9de9b512f58addf28c9853d50ebef61c3953e"
},
{
"ImportPath": "golang.org/x/oauth2",
"Rev": "3046bc76d6dfd7d3707f6640f85e42d9c4050f50"
"ImportPath": "golang.org/x/net/http2",
"Rev": "04b9de9b512f58addf28c9853d50ebef61c3953e"
},
{
"ImportPath": "golang.org/x/net/internal/timeseries",
"Rev": "04b9de9b512f58addf28c9853d50ebef61c3953e"
},
{
"ImportPath": "golang.org/x/net/trace",
"Rev": "04b9de9b512f58addf28c9853d50ebef61c3953e"
},
{
"ImportPath": "golang.org/x/sys/unix",
"Rev": "9c60d1c508f5134d1ca726b4641db998f2523357"
},
{
"ImportPath": "google.golang.org/cloud/compute/metadata",
"Rev": "f20d6dcccb44ed49de45ae3703312cb46e627db1"
},
{
"ImportPath": "google.golang.org/cloud/internal",
"Rev": "f20d6dcccb44ed49de45ae3703312cb46e627db1"
},
{
"ImportPath": "google.golang.org/grpc",
"Rev": "f5ebd86be717593ab029545492c93ddf8914832b"
"Rev": "e29d659177655e589850ba7d3d83f7ce12ef23dd"
}
]
}

View File

@ -0,0 +1,5 @@
#*
*~
/tools/pass/pass
/tools/pcaptest/pcaptest
/tools/tcpdump/tcpdump

View File

@ -0,0 +1,27 @@
Copyright (c) 2009-2011 Andreas Krennmair. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Andreas Krennmair nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@ -0,0 +1,11 @@
# PCAP
This is a simple wrapper around libpcap for Go. Originally written by Andreas
Krennmair <ak@synflood.at> and only minorly touched up by Mark Smith <mark@qq.is>.
Please see the included pcaptest.go and tcpdump.go programs for instructions on
how to use this library.
Miek Gieben <miek@miek.nl> has created a more Go-like package and replaced functionality
with standard functions from the standard library. The package has also been renamed to
pcap.

View File

@ -0,0 +1,527 @@
package pcap
import (
"encoding/binary"
"fmt"
"net"
"reflect"
"strings"
)
const (
TYPE_IP = 0x0800
TYPE_ARP = 0x0806
TYPE_IP6 = 0x86DD
TYPE_VLAN = 0x8100
IP_ICMP = 1
IP_INIP = 4
IP_TCP = 6
IP_UDP = 17
)
const (
ERRBUF_SIZE = 256
// According to pcap-linktype(7).
LINKTYPE_NULL = 0
LINKTYPE_ETHERNET = 1
LINKTYPE_TOKEN_RING = 6
LINKTYPE_ARCNET = 7
LINKTYPE_SLIP = 8
LINKTYPE_PPP = 9
LINKTYPE_FDDI = 10
LINKTYPE_ATM_RFC1483 = 100
LINKTYPE_RAW = 101
LINKTYPE_PPP_HDLC = 50
LINKTYPE_PPP_ETHER = 51
LINKTYPE_C_HDLC = 104
LINKTYPE_IEEE802_11 = 105
LINKTYPE_FRELAY = 107
LINKTYPE_LOOP = 108
LINKTYPE_LINUX_SLL = 113
LINKTYPE_LTALK = 104
LINKTYPE_PFLOG = 117
LINKTYPE_PRISM_HEADER = 119
LINKTYPE_IP_OVER_FC = 122
LINKTYPE_SUNATM = 123
LINKTYPE_IEEE802_11_RADIO = 127
LINKTYPE_ARCNET_LINUX = 129
LINKTYPE_LINUX_IRDA = 144
LINKTYPE_LINUX_LAPD = 177
)
type addrHdr interface {
SrcAddr() string
DestAddr() string
Len() int
}
type addrStringer interface {
String(addr addrHdr) string
}
func decodemac(pkt []byte) uint64 {
mac := uint64(0)
for i := uint(0); i < 6; i++ {
mac = (mac << 8) + uint64(pkt[i])
}
return mac
}
// Decode decodes the headers of a Packet.
func (p *Packet) Decode() {
if len(p.Data) <= 14 {
return
}
p.Type = int(binary.BigEndian.Uint16(p.Data[12:14]))
p.DestMac = decodemac(p.Data[0:6])
p.SrcMac = decodemac(p.Data[6:12])
if len(p.Data) >= 15 {
p.Payload = p.Data[14:]
}
switch p.Type {
case TYPE_IP:
p.decodeIp()
case TYPE_IP6:
p.decodeIp6()
case TYPE_ARP:
p.decodeArp()
case TYPE_VLAN:
p.decodeVlan()
}
}
func (p *Packet) headerString(headers []interface{}) string {
// If there's just one header, return that.
if len(headers) == 1 {
if hdr, ok := headers[0].(fmt.Stringer); ok {
return hdr.String()
}
}
// If there are two headers (IPv4/IPv6 -> TCP/UDP/IP..)
if len(headers) == 2 {
// Commonly the first header is an address.
if addr, ok := p.Headers[0].(addrHdr); ok {
if hdr, ok := p.Headers[1].(addrStringer); ok {
return fmt.Sprintf("%s %s", p.Time, hdr.String(addr))
}
}
}
// For IP in IP, we do a recursive call.
if len(headers) >= 2 {
if addr, ok := headers[0].(addrHdr); ok {
if _, ok := headers[1].(addrHdr); ok {
return fmt.Sprintf("%s > %s IP in IP: ",
addr.SrcAddr(), addr.DestAddr(), p.headerString(headers[1:]))
}
}
}
var typeNames []string
for _, hdr := range headers {
typeNames = append(typeNames, reflect.TypeOf(hdr).String())
}
return fmt.Sprintf("unknown [%s]", strings.Join(typeNames, ","))
}
// String prints a one-line representation of the packet header.
// The output is suitable for use in a tcpdump program.
func (p *Packet) String() string {
// If there are no headers, print "unsupported protocol".
if len(p.Headers) == 0 {
return fmt.Sprintf("%s unsupported protocol %d", p.Time, int(p.Type))
}
return fmt.Sprintf("%s %s", p.Time, p.headerString(p.Headers))
}
// Arphdr is a ARP packet header.
type Arphdr struct {
Addrtype uint16
Protocol uint16
HwAddressSize uint8
ProtAddressSize uint8
Operation uint16
SourceHwAddress []byte
SourceProtAddress []byte
DestHwAddress []byte
DestProtAddress []byte
}
func (arp *Arphdr) String() (s string) {
switch arp.Operation {
case 1:
s = "ARP request"
case 2:
s = "ARP Reply"
}
if arp.Addrtype == LINKTYPE_ETHERNET && arp.Protocol == TYPE_IP {
s = fmt.Sprintf("%012x (%s) > %012x (%s)",
decodemac(arp.SourceHwAddress), arp.SourceProtAddress,
decodemac(arp.DestHwAddress), arp.DestProtAddress)
} else {
s = fmt.Sprintf("addrtype = %d protocol = %d", arp.Addrtype, arp.Protocol)
}
return
}
func (p *Packet) decodeArp() {
if len(p.Payload) < 8 {
return
}
pkt := p.Payload
arp := new(Arphdr)
arp.Addrtype = binary.BigEndian.Uint16(pkt[0:2])
arp.Protocol = binary.BigEndian.Uint16(pkt[2:4])
arp.HwAddressSize = pkt[4]
arp.ProtAddressSize = pkt[5]
arp.Operation = binary.BigEndian.Uint16(pkt[6:8])
if len(pkt) < int(8+2*arp.HwAddressSize+2*arp.ProtAddressSize) {
return
}
arp.SourceHwAddress = pkt[8 : 8+arp.HwAddressSize]
arp.SourceProtAddress = pkt[8+arp.HwAddressSize : 8+arp.HwAddressSize+arp.ProtAddressSize]
arp.DestHwAddress = pkt[8+arp.HwAddressSize+arp.ProtAddressSize : 8+2*arp.HwAddressSize+arp.ProtAddressSize]
arp.DestProtAddress = pkt[8+2*arp.HwAddressSize+arp.ProtAddressSize : 8+2*arp.HwAddressSize+2*arp.ProtAddressSize]
p.Headers = append(p.Headers, arp)
if len(pkt) >= int(8+2*arp.HwAddressSize+2*arp.ProtAddressSize) {
p.Payload = p.Payload[8+2*arp.HwAddressSize+2*arp.ProtAddressSize:]
}
}
// IPadr is the header of an IP packet.
type Iphdr struct {
Version uint8
Ihl uint8
Tos uint8
Length uint16
Id uint16
Flags uint8
FragOffset uint16
Ttl uint8
Protocol uint8
Checksum uint16
SrcIp []byte
DestIp []byte
}
func (p *Packet) decodeIp() {
if len(p.Payload) < 20 {
return
}
pkt := p.Payload
ip := new(Iphdr)
ip.Version = uint8(pkt[0]) >> 4
ip.Ihl = uint8(pkt[0]) & 0x0F
ip.Tos = pkt[1]
ip.Length = binary.BigEndian.Uint16(pkt[2:4])
ip.Id = binary.BigEndian.Uint16(pkt[4:6])
flagsfrags := binary.BigEndian.Uint16(pkt[6:8])
ip.Flags = uint8(flagsfrags >> 13)
ip.FragOffset = flagsfrags & 0x1FFF
ip.Ttl = pkt[8]
ip.Protocol = pkt[9]
ip.Checksum = binary.BigEndian.Uint16(pkt[10:12])
ip.SrcIp = pkt[12:16]
ip.DestIp = pkt[16:20]
pEnd := int(ip.Length)
if pEnd > len(pkt) {
pEnd = len(pkt)
}
if len(pkt) >= pEnd && int(ip.Ihl*4) < pEnd {
p.Payload = pkt[ip.Ihl*4 : pEnd]
} else {
p.Payload = []byte{}
}
p.Headers = append(p.Headers, ip)
p.IP = ip
switch ip.Protocol {
case IP_TCP:
p.decodeTcp()
case IP_UDP:
p.decodeUdp()
case IP_ICMP:
p.decodeIcmp()
case IP_INIP:
p.decodeIp()
}
}
func (ip *Iphdr) SrcAddr() string { return net.IP(ip.SrcIp).String() }
func (ip *Iphdr) DestAddr() string { return net.IP(ip.DestIp).String() }
func (ip *Iphdr) Len() int { return int(ip.Length) }
type Vlanhdr struct {
Priority byte
DropEligible bool
VlanIdentifier int
Type int // Not actually part of the vlan header, but the type of the actual packet
}
func (v *Vlanhdr) String() {
fmt.Sprintf("VLAN Priority:%d Drop:%v Tag:%d", v.Priority, v.DropEligible, v.VlanIdentifier)
}
func (p *Packet) decodeVlan() {
pkt := p.Payload
vlan := new(Vlanhdr)
if len(pkt) < 4 {
return
}
vlan.Priority = (pkt[2] & 0xE0) >> 13
vlan.DropEligible = pkt[2]&0x10 != 0
vlan.VlanIdentifier = int(binary.BigEndian.Uint16(pkt[:2])) & 0x0FFF
vlan.Type = int(binary.BigEndian.Uint16(p.Payload[2:4]))
p.Headers = append(p.Headers, vlan)
if len(pkt) >= 5 {
p.Payload = p.Payload[4:]
}
switch vlan.Type {
case TYPE_IP:
p.decodeIp()
case TYPE_IP6:
p.decodeIp6()
case TYPE_ARP:
p.decodeArp()
}
}
type Tcphdr struct {
SrcPort uint16
DestPort uint16
Seq uint32
Ack uint32
DataOffset uint8
Flags uint16
Window uint16
Checksum uint16
Urgent uint16
Data []byte
}
const (
TCP_FIN = 1 << iota
TCP_SYN
TCP_RST
TCP_PSH
TCP_ACK
TCP_URG
TCP_ECE
TCP_CWR
TCP_NS
)
func (p *Packet) decodeTcp() {
if len(p.Payload) < 20 {
return
}
pkt := p.Payload
tcp := new(Tcphdr)
tcp.SrcPort = binary.BigEndian.Uint16(pkt[0:2])
tcp.DestPort = binary.BigEndian.Uint16(pkt[2:4])
tcp.Seq = binary.BigEndian.Uint32(pkt[4:8])
tcp.Ack = binary.BigEndian.Uint32(pkt[8:12])
tcp.DataOffset = (pkt[12] & 0xF0) >> 4
tcp.Flags = binary.BigEndian.Uint16(pkt[12:14]) & 0x1FF
tcp.Window = binary.BigEndian.Uint16(pkt[14:16])
tcp.Checksum = binary.BigEndian.Uint16(pkt[16:18])
tcp.Urgent = binary.BigEndian.Uint16(pkt[18:20])
if len(pkt) >= int(tcp.DataOffset*4) {
p.Payload = pkt[tcp.DataOffset*4:]
}
p.Headers = append(p.Headers, tcp)
p.TCP = tcp
}
func (tcp *Tcphdr) String(hdr addrHdr) string {
return fmt.Sprintf("TCP %s:%d > %s:%d %s SEQ=%d ACK=%d LEN=%d",
hdr.SrcAddr(), int(tcp.SrcPort), hdr.DestAddr(), int(tcp.DestPort),
tcp.FlagsString(), int64(tcp.Seq), int64(tcp.Ack), hdr.Len())
}
func (tcp *Tcphdr) FlagsString() string {
var sflags []string
if 0 != (tcp.Flags & TCP_SYN) {
sflags = append(sflags, "syn")
}
if 0 != (tcp.Flags & TCP_FIN) {
sflags = append(sflags, "fin")
}
if 0 != (tcp.Flags & TCP_ACK) {
sflags = append(sflags, "ack")
}
if 0 != (tcp.Flags & TCP_PSH) {
sflags = append(sflags, "psh")
}
if 0 != (tcp.Flags & TCP_RST) {
sflags = append(sflags, "rst")
}
if 0 != (tcp.Flags & TCP_URG) {
sflags = append(sflags, "urg")
}
if 0 != (tcp.Flags & TCP_NS) {
sflags = append(sflags, "ns")
}
if 0 != (tcp.Flags & TCP_CWR) {
sflags = append(sflags, "cwr")
}
if 0 != (tcp.Flags & TCP_ECE) {
sflags = append(sflags, "ece")
}
return fmt.Sprintf("[%s]", strings.Join(sflags, " "))
}
type Udphdr struct {
SrcPort uint16
DestPort uint16
Length uint16
Checksum uint16
}
func (p *Packet) decodeUdp() {
if len(p.Payload) < 8 {
return
}
pkt := p.Payload
udp := new(Udphdr)
udp.SrcPort = binary.BigEndian.Uint16(pkt[0:2])
udp.DestPort = binary.BigEndian.Uint16(pkt[2:4])
udp.Length = binary.BigEndian.Uint16(pkt[4:6])
udp.Checksum = binary.BigEndian.Uint16(pkt[6:8])
p.Headers = append(p.Headers, udp)
p.UDP = udp
if len(p.Payload) >= 8 {
p.Payload = pkt[8:]
}
}
func (udp *Udphdr) String(hdr addrHdr) string {
return fmt.Sprintf("UDP %s:%d > %s:%d LEN=%d CHKSUM=%d",
hdr.SrcAddr(), int(udp.SrcPort), hdr.DestAddr(), int(udp.DestPort),
int(udp.Length), int(udp.Checksum))
}
type Icmphdr struct {
Type uint8
Code uint8
Checksum uint16
Id uint16
Seq uint16
Data []byte
}
func (p *Packet) decodeIcmp() *Icmphdr {
if len(p.Payload) < 8 {
return nil
}
pkt := p.Payload
icmp := new(Icmphdr)
icmp.Type = pkt[0]
icmp.Code = pkt[1]
icmp.Checksum = binary.BigEndian.Uint16(pkt[2:4])
icmp.Id = binary.BigEndian.Uint16(pkt[4:6])
icmp.Seq = binary.BigEndian.Uint16(pkt[6:8])
p.Payload = pkt[8:]
p.Headers = append(p.Headers, icmp)
return icmp
}
func (icmp *Icmphdr) String(hdr addrHdr) string {
return fmt.Sprintf("ICMP %s > %s Type = %d Code = %d ",
hdr.SrcAddr(), hdr.DestAddr(), icmp.Type, icmp.Code)
}
func (icmp *Icmphdr) TypeString() (result string) {
switch icmp.Type {
case 0:
result = fmt.Sprintf("Echo reply seq=%d", icmp.Seq)
case 3:
switch icmp.Code {
case 0:
result = "Network unreachable"
case 1:
result = "Host unreachable"
case 2:
result = "Protocol unreachable"
case 3:
result = "Port unreachable"
default:
result = "Destination unreachable"
}
case 8:
result = fmt.Sprintf("Echo request seq=%d", icmp.Seq)
case 30:
result = "Traceroute"
}
return
}
type Ip6hdr struct {
// http://www.networksorcery.com/enp/protocol/ipv6.htm
Version uint8 // 4 bits
TrafficClass uint8 // 8 bits
FlowLabel uint32 // 20 bits
Length uint16 // 16 bits
NextHeader uint8 // 8 bits, same as Protocol in Iphdr
HopLimit uint8 // 8 bits
SrcIp []byte // 16 bytes
DestIp []byte // 16 bytes
}
func (p *Packet) decodeIp6() {
if len(p.Payload) < 40 {
return
}
pkt := p.Payload
ip6 := new(Ip6hdr)
ip6.Version = uint8(pkt[0]) >> 4
ip6.TrafficClass = uint8((binary.BigEndian.Uint16(pkt[0:2]) >> 4) & 0x00FF)
ip6.FlowLabel = binary.BigEndian.Uint32(pkt[0:4]) & 0x000FFFFF
ip6.Length = binary.BigEndian.Uint16(pkt[4:6])
ip6.NextHeader = pkt[6]
ip6.HopLimit = pkt[7]
ip6.SrcIp = pkt[8:24]
ip6.DestIp = pkt[24:40]
if len(p.Payload) >= 40 {
p.Payload = pkt[40:]
}
p.Headers = append(p.Headers, ip6)
switch ip6.NextHeader {
case IP_TCP:
p.decodeTcp()
case IP_UDP:
p.decodeUdp()
case IP_ICMP:
p.decodeIcmp()
case IP_INIP:
p.decodeIp()
}
}
func (ip6 *Ip6hdr) SrcAddr() string { return net.IP(ip6.SrcIp).String() }
func (ip6 *Ip6hdr) DestAddr() string { return net.IP(ip6.DestIp).String() }
func (ip6 *Ip6hdr) Len() int { return int(ip6.Length) }

View File

@ -0,0 +1,247 @@
package pcap
import (
"bytes"
"testing"
"time"
)
var testSimpleTcpPacket *Packet = &Packet{
Data: []byte{
0x00, 0x00, 0x0c, 0x9f, 0xf0, 0x20, 0xbc, 0x30, 0x5b, 0xe8, 0xd3, 0x49,
0x08, 0x00, 0x45, 0x00, 0x01, 0xa4, 0x39, 0xdf, 0x40, 0x00, 0x40, 0x06,
0x55, 0x5a, 0xac, 0x11, 0x51, 0x49, 0xad, 0xde, 0xfe, 0xe1, 0xc5, 0xf7,
0x00, 0x50, 0xc5, 0x7e, 0x0e, 0x48, 0x49, 0x07, 0x42, 0x32, 0x80, 0x18,
0x00, 0x73, 0xab, 0xb1, 0x00, 0x00, 0x01, 0x01, 0x08, 0x0a, 0x03, 0x77,
0x37, 0x9c, 0x42, 0x77, 0x5e, 0x3a, 0x47, 0x45, 0x54, 0x20, 0x2f, 0x20,
0x48, 0x54, 0x54, 0x50, 0x2f, 0x31, 0x2e, 0x31, 0x0d, 0x0a, 0x48, 0x6f,
0x73, 0x74, 0x3a, 0x20, 0x77, 0x77, 0x77, 0x2e, 0x66, 0x69, 0x73, 0x68,
0x2e, 0x63, 0x6f, 0x6d, 0x0d, 0x0a, 0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63,
0x74, 0x69, 0x6f, 0x6e, 0x3a, 0x20, 0x6b, 0x65, 0x65, 0x70, 0x2d, 0x61,
0x6c, 0x69, 0x76, 0x65, 0x0d, 0x0a, 0x55, 0x73, 0x65, 0x72, 0x2d, 0x41,
0x67, 0x65, 0x6e, 0x74, 0x3a, 0x20, 0x4d, 0x6f, 0x7a, 0x69, 0x6c, 0x6c,
0x61, 0x2f, 0x35, 0x2e, 0x30, 0x20, 0x28, 0x58, 0x31, 0x31, 0x3b, 0x20,
0x4c, 0x69, 0x6e, 0x75, 0x78, 0x20, 0x78, 0x38, 0x36, 0x5f, 0x36, 0x34,
0x29, 0x20, 0x41, 0x70, 0x70, 0x6c, 0x65, 0x57, 0x65, 0x62, 0x4b, 0x69,
0x74, 0x2f, 0x35, 0x33, 0x35, 0x2e, 0x32, 0x20, 0x28, 0x4b, 0x48, 0x54,
0x4d, 0x4c, 0x2c, 0x20, 0x6c, 0x69, 0x6b, 0x65, 0x20, 0x47, 0x65, 0x63,
0x6b, 0x6f, 0x29, 0x20, 0x43, 0x68, 0x72, 0x6f, 0x6d, 0x65, 0x2f, 0x31,
0x35, 0x2e, 0x30, 0x2e, 0x38, 0x37, 0x34, 0x2e, 0x31, 0x32, 0x31, 0x20,
0x53, 0x61, 0x66, 0x61, 0x72, 0x69, 0x2f, 0x35, 0x33, 0x35, 0x2e, 0x32,
0x0d, 0x0a, 0x41, 0x63, 0x63, 0x65, 0x70, 0x74, 0x3a, 0x20, 0x74, 0x65,
0x78, 0x74, 0x2f, 0x68, 0x74, 0x6d, 0x6c, 0x2c, 0x61, 0x70, 0x70, 0x6c,
0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x2f, 0x78, 0x68, 0x74, 0x6d,
0x6c, 0x2b, 0x78, 0x6d, 0x6c, 0x2c, 0x61, 0x70, 0x70, 0x6c, 0x69, 0x63,
0x61, 0x74, 0x69, 0x6f, 0x6e, 0x2f, 0x78, 0x6d, 0x6c, 0x3b, 0x71, 0x3d,
0x30, 0x2e, 0x39, 0x2c, 0x2a, 0x2f, 0x2a, 0x3b, 0x71, 0x3d, 0x30, 0x2e,
0x38, 0x0d, 0x0a, 0x41, 0x63, 0x63, 0x65, 0x70, 0x74, 0x2d, 0x45, 0x6e,
0x63, 0x6f, 0x64, 0x69, 0x6e, 0x67, 0x3a, 0x20, 0x67, 0x7a, 0x69, 0x70,
0x2c, 0x64, 0x65, 0x66, 0x6c, 0x61, 0x74, 0x65, 0x2c, 0x73, 0x64, 0x63,
0x68, 0x0d, 0x0a, 0x41, 0x63, 0x63, 0x65, 0x70, 0x74, 0x2d, 0x4c, 0x61,
0x6e, 0x67, 0x75, 0x61, 0x67, 0x65, 0x3a, 0x20, 0x65, 0x6e, 0x2d, 0x55,
0x53, 0x2c, 0x65, 0x6e, 0x3b, 0x71, 0x3d, 0x30, 0x2e, 0x38, 0x0d, 0x0a,
0x41, 0x63, 0x63, 0x65, 0x70, 0x74, 0x2d, 0x43, 0x68, 0x61, 0x72, 0x73,
0x65, 0x74, 0x3a, 0x20, 0x49, 0x53, 0x4f, 0x2d, 0x38, 0x38, 0x35, 0x39,
0x2d, 0x31, 0x2c, 0x75, 0x74, 0x66, 0x2d, 0x38, 0x3b, 0x71, 0x3d, 0x30,
0x2e, 0x37, 0x2c, 0x2a, 0x3b, 0x71, 0x3d, 0x30, 0x2e, 0x33, 0x0d, 0x0a,
0x0d, 0x0a,
}}
func BenchmarkDecodeSimpleTcpPacket(b *testing.B) {
for i := 0; i < b.N; i++ {
testSimpleTcpPacket.Decode()
}
}
func TestDecodeSimpleTcpPacket(t *testing.T) {
p := testSimpleTcpPacket
p.Decode()
if p.DestMac != 0x00000c9ff020 {
t.Error("Dest mac", p.DestMac)
}
if p.SrcMac != 0xbc305be8d349 {
t.Error("Src mac", p.SrcMac)
}
if len(p.Headers) != 2 {
t.Error("Incorrect number of headers", len(p.Headers))
return
}
if ip, ipOk := p.Headers[0].(*Iphdr); ipOk {
if ip.Version != 4 {
t.Error("ip Version", ip.Version)
}
if ip.Ihl != 5 {
t.Error("ip header length", ip.Ihl)
}
if ip.Tos != 0 {
t.Error("ip TOS", ip.Tos)
}
if ip.Length != 420 {
t.Error("ip Length", ip.Length)
}
if ip.Id != 14815 {
t.Error("ip ID", ip.Id)
}
if ip.Flags != 0x02 {
t.Error("ip Flags", ip.Flags)
}
if ip.FragOffset != 0 {
t.Error("ip Fragoffset", ip.FragOffset)
}
if ip.Ttl != 64 {
t.Error("ip TTL", ip.Ttl)
}
if ip.Protocol != 6 {
t.Error("ip Protocol", ip.Protocol)
}
if ip.Checksum != 0x555A {
t.Error("ip Checksum", ip.Checksum)
}
if !bytes.Equal(ip.SrcIp, []byte{172, 17, 81, 73}) {
t.Error("ip Src", ip.SrcIp)
}
if !bytes.Equal(ip.DestIp, []byte{173, 222, 254, 225}) {
t.Error("ip Dest", ip.DestIp)
}
if tcp, tcpOk := p.Headers[1].(*Tcphdr); tcpOk {
if tcp.SrcPort != 50679 {
t.Error("tcp srcport", tcp.SrcPort)
}
if tcp.DestPort != 80 {
t.Error("tcp destport", tcp.DestPort)
}
if tcp.Seq != 0xc57e0e48 {
t.Error("tcp seq", tcp.Seq)
}
if tcp.Ack != 0x49074232 {
t.Error("tcp ack", tcp.Ack)
}
if tcp.DataOffset != 8 {
t.Error("tcp dataoffset", tcp.DataOffset)
}
if tcp.Flags != 0x18 {
t.Error("tcp flags", tcp.Flags)
}
if tcp.Window != 0x73 {
t.Error("tcp window", tcp.Window)
}
if tcp.Checksum != 0xabb1 {
t.Error("tcp checksum", tcp.Checksum)
}
if tcp.Urgent != 0 {
t.Error("tcp urgent", tcp.Urgent)
}
} else {
t.Error("Second header is not TCP header")
}
} else {
t.Error("First header is not IP header")
}
if string(p.Payload) != "GET / HTTP/1.1\r\nHost: www.fish.com\r\nConnection: keep-alive\r\nUser-Agent: Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/535.2 (KHTML, like Gecko) Chrome/15.0.874.121 Safari/535.2\r\nAccept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8\r\nAccept-Encoding: gzip,deflate,sdch\r\nAccept-Language: en-US,en;q=0.8\r\nAccept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.3\r\n\r\n" {
t.Error("--- PAYLOAD STRING ---\n", string(p.Payload), "\n--- PAYLOAD BYTES ---\n", p.Payload)
}
}
// Makes sure packet payload doesn't display the 6 trailing null of this packet
// as part of the payload. They're actually the ethernet trailer.
func TestDecodeSmallTcpPacketHasEmptyPayload(t *testing.T) {
p := &Packet{
// This packet is only 54 bits (an empty TCP RST), thus 6 trailing null
// bytes are added by the ethernet layer to make it the minimum packet size.
Data: []byte{
0xbc, 0x30, 0x5b, 0xe8, 0xd3, 0x49, 0xb8, 0xac, 0x6f, 0x92, 0xd5, 0xbf,
0x08, 0x00, 0x45, 0x00, 0x00, 0x28, 0x00, 0x00, 0x40, 0x00, 0x40, 0x06,
0x3f, 0x9f, 0xac, 0x11, 0x51, 0xc5, 0xac, 0x11, 0x51, 0x49, 0x00, 0x63,
0x9a, 0xef, 0x00, 0x00, 0x00, 0x00, 0x2e, 0xc1, 0x27, 0x83, 0x50, 0x14,
0x00, 0x00, 0xc3, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
}}
p.Decode()
if p.Payload == nil {
t.Error("Nil payload")
}
if len(p.Payload) != 0 {
t.Error("Non-empty payload:", p.Payload)
}
}
func TestDecodeVlanPacket(t *testing.T) {
p := &Packet{
Data: []byte{
0x00, 0x10, 0xdb, 0xff, 0x10, 0x00, 0x00, 0x15, 0x2c, 0x9d, 0xcc, 0x00, 0x81, 0x00, 0x01, 0xf7,
0x08, 0x00, 0x45, 0x00, 0x00, 0x28, 0x29, 0x8d, 0x40, 0x00, 0x7d, 0x06, 0x83, 0xa0, 0xac, 0x1b,
0xca, 0x8e, 0x45, 0x16, 0x94, 0xe2, 0xd4, 0x0a, 0x00, 0x50, 0xdf, 0xab, 0x9c, 0xc6, 0xcd, 0x1e,
0xe5, 0xd1, 0x50, 0x10, 0x01, 0x00, 0x5a, 0x74, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
}}
p.Decode()
if p.Type != TYPE_VLAN {
t.Error("Didn't detect vlan")
}
if len(p.Headers) != 3 {
t.Error("Incorrect number of headers:", len(p.Headers))
for i, h := range p.Headers {
t.Errorf("Header %d: %#v", i, h)
}
t.FailNow()
}
if _, ok := p.Headers[0].(*Vlanhdr); !ok {
t.Errorf("First header isn't vlan: %q", p.Headers[0])
}
if _, ok := p.Headers[1].(*Iphdr); !ok {
t.Errorf("Second header isn't IP: %q", p.Headers[1])
}
if _, ok := p.Headers[2].(*Tcphdr); !ok {
t.Errorf("Third header isn't TCP: %q", p.Headers[2])
}
}
func TestDecodeFuzzFallout(t *testing.T) {
testData := []struct {
Data []byte
}{
{[]byte("000000000000\x81\x000")},
{[]byte("000000000000\x81\x00000")},
{[]byte("000000000000\x86\xdd0")},
{[]byte("000000000000\b\x000")},
{[]byte("000000000000\b\x060")},
{[]byte{}},
{[]byte("000000000000\b\x0600000000")},
{[]byte("000000000000\x86\xdd000000\x01000000000000000000000000000000000")},
{[]byte("000000000000\x81\x0000\b\x0600000000")},
{[]byte("000000000000\b\x00n0000000000000000000")},
{[]byte("000000000000\x86\xdd000000\x0100000000000000000000000000000000000")},
{[]byte("000000000000\x81\x0000\b\x00g0000000000000000000")},
//{[]byte()},
{[]byte("000000000000\b\x00400000000\x110000000000")},
{[]byte("0nMء\xfe\x13\x13\x81\x00gr\b\x00&x\xc9\xe5b'\x1e0\x00\x04\x00\x0020596224")},
{[]byte("000000000000\x81\x0000\b\x00400000000\x110000000000")},
{[]byte("000000000000\b\x00000000000\x0600\xff0000000")},
{[]byte("000000000000\x86\xdd000000\x06000000000000000000000000000000000")},
{[]byte("000000000000\x81\x0000\b\x00000000000\x0600b0000000")},
{[]byte("000000000000\x81\x0000\b\x00400000000\x060000000000")},
{[]byte("000000000000\x86\xdd000000\x11000000000000000000000000000000000")},
{[]byte("000000000000\x86\xdd000000\x0600000000000000000000000000000000000000000000M")},
{[]byte("000000000000\b\x00500000000\x0600000000000")},
{[]byte("0nM\xd80\xfe\x13\x13\x81\x00gr\b\x00&x\xc9\xe5b'\x1e0\x00\x04\x00\x0020596224")},
}
for _, entry := range testData {
pkt := &Packet{
Time: time.Now(),
Caplen: uint32(len(entry.Data)),
Len: uint32(len(entry.Data)),
Data: entry.Data,
}
pkt.Decode()
/*
func() {
defer func() {
if err := recover(); err != nil {
t.Fatalf("%d. %q failed: %v", idx, string(entry.Data), err)
}
}()
pkt.Decode()
}()
*/
}
}

View File

@ -0,0 +1,206 @@
package pcap
import (
"encoding/binary"
"fmt"
"io"
"time"
)
// FileHeader is the parsed header of a pcap file.
// http://wiki.wireshark.org/Development/LibpcapFileFormat
type FileHeader struct {
MagicNumber uint32
VersionMajor uint16
VersionMinor uint16
TimeZone int32
SigFigs uint32
SnapLen uint32
Network uint32
}
type PacketTime struct {
Sec int32
Usec int32
}
// Convert the PacketTime to a go Time struct.
func (p *PacketTime) Time() time.Time {
return time.Unix(int64(p.Sec), int64(p.Usec)*1000)
}
// Packet is a single packet parsed from a pcap file.
//
// Convenient access to IP, TCP, and UDP headers is provided after Decode()
// is called if the packet is of the appropriate type.
type Packet struct {
Time time.Time // packet send/receive time
Caplen uint32 // bytes stored in the file (caplen <= len)
Len uint32 // bytes sent/received
Data []byte // packet data
Type int // protocol type, see LINKTYPE_*
DestMac uint64
SrcMac uint64
Headers []interface{} // decoded headers, in order
Payload []byte // remaining non-header bytes
IP *Iphdr // IP header (for IP packets, after decoding)
TCP *Tcphdr // TCP header (for TCP packets, after decoding)
UDP *Udphdr // UDP header (for UDP packets after decoding)
}
// Reader parses pcap files.
type Reader struct {
flip bool
buf io.Reader
err error
fourBytes []byte
twoBytes []byte
sixteenBytes []byte
Header FileHeader
}
// NewReader reads pcap data from an io.Reader.
func NewReader(reader io.Reader) (*Reader, error) {
r := &Reader{
buf: reader,
fourBytes: make([]byte, 4),
twoBytes: make([]byte, 2),
sixteenBytes: make([]byte, 16),
}
switch magic := r.readUint32(); magic {
case 0xa1b2c3d4:
r.flip = false
case 0xd4c3b2a1:
r.flip = true
default:
return nil, fmt.Errorf("pcap: bad magic number: %0x", magic)
}
r.Header = FileHeader{
MagicNumber: 0xa1b2c3d4,
VersionMajor: r.readUint16(),
VersionMinor: r.readUint16(),
TimeZone: r.readInt32(),
SigFigs: r.readUint32(),
SnapLen: r.readUint32(),
Network: r.readUint32(),
}
return r, nil
}
// Next returns the next packet or nil if no more packets can be read.
func (r *Reader) Next() *Packet {
d := r.sixteenBytes
r.err = r.read(d)
if r.err != nil {
return nil
}
timeSec := asUint32(d[0:4], r.flip)
timeUsec := asUint32(d[4:8], r.flip)
capLen := asUint32(d[8:12], r.flip)
origLen := asUint32(d[12:16], r.flip)
data := make([]byte, capLen)
if r.err = r.read(data); r.err != nil {
return nil
}
return &Packet{
Time: time.Unix(int64(timeSec), int64(timeUsec)),
Caplen: capLen,
Len: origLen,
Data: data,
}
}
func (r *Reader) read(data []byte) error {
var err error
n, err := r.buf.Read(data)
for err == nil && n != len(data) {
var chunk int
chunk, err = r.buf.Read(data[n:])
n += chunk
}
if len(data) == n {
return nil
}
return err
}
func (r *Reader) readUint32() uint32 {
data := r.fourBytes
if r.err = r.read(data); r.err != nil {
return 0
}
return asUint32(data, r.flip)
}
func (r *Reader) readInt32() int32 {
data := r.fourBytes
if r.err = r.read(data); r.err != nil {
return 0
}
return int32(asUint32(data, r.flip))
}
func (r *Reader) readUint16() uint16 {
data := r.twoBytes
if r.err = r.read(data); r.err != nil {
return 0
}
return asUint16(data, r.flip)
}
// Writer writes a pcap file.
type Writer struct {
writer io.Writer
buf []byte
}
// NewWriter creates a Writer that stores output in an io.Writer.
// The FileHeader is written immediately.
func NewWriter(writer io.Writer, header *FileHeader) (*Writer, error) {
w := &Writer{
writer: writer,
buf: make([]byte, 24),
}
binary.LittleEndian.PutUint32(w.buf, header.MagicNumber)
binary.LittleEndian.PutUint16(w.buf[4:], header.VersionMajor)
binary.LittleEndian.PutUint16(w.buf[6:], header.VersionMinor)
binary.LittleEndian.PutUint32(w.buf[8:], uint32(header.TimeZone))
binary.LittleEndian.PutUint32(w.buf[12:], header.SigFigs)
binary.LittleEndian.PutUint32(w.buf[16:], header.SnapLen)
binary.LittleEndian.PutUint32(w.buf[20:], header.Network)
if _, err := writer.Write(w.buf); err != nil {
return nil, err
}
return w, nil
}
// Writer writes a packet to the underlying writer.
func (w *Writer) Write(pkt *Packet) error {
binary.LittleEndian.PutUint32(w.buf, uint32(pkt.Time.Unix()))
binary.LittleEndian.PutUint32(w.buf[4:], uint32(pkt.Time.Nanosecond()))
binary.LittleEndian.PutUint32(w.buf[8:], uint32(pkt.Time.Unix()))
binary.LittleEndian.PutUint32(w.buf[12:], pkt.Len)
if _, err := w.writer.Write(w.buf[:16]); err != nil {
return err
}
_, err := w.writer.Write(pkt.Data)
return err
}
func asUint32(data []byte, flip bool) uint32 {
if flip {
return binary.BigEndian.Uint32(data)
}
return binary.LittleEndian.Uint32(data)
}
func asUint16(data []byte, flip bool) uint16 {
if flip {
return binary.BigEndian.Uint16(data)
}
return binary.LittleEndian.Uint16(data)
}

View File

@ -0,0 +1,266 @@
// Interface to both live and offline pcap parsing.
package pcap
/*
#cgo linux LDFLAGS: -lpcap
#cgo freebsd LDFLAGS: -lpcap
#cgo darwin LDFLAGS: -lpcap
#cgo windows CFLAGS: -I C:/WpdPack/Include
#cgo windows,386 LDFLAGS: -L C:/WpdPack/Lib -lwpcap
#cgo windows,amd64 LDFLAGS: -L C:/WpdPack/Lib/x64 -lwpcap
#include <stdlib.h>
#include <pcap.h>
// Workaround for not knowing how to cast to const u_char**
int hack_pcap_next_ex(pcap_t *p, struct pcap_pkthdr **pkt_header,
u_char **pkt_data) {
return pcap_next_ex(p, pkt_header, (const u_char **)pkt_data);
}
*/
import "C"
import (
"errors"
"net"
"syscall"
"time"
"unsafe"
)
type Pcap struct {
cptr *C.pcap_t
}
type Stat struct {
PacketsReceived uint32
PacketsDropped uint32
PacketsIfDropped uint32
}
type Interface struct {
Name string
Description string
Addresses []IFAddress
// TODO: add more elements
}
type IFAddress struct {
IP net.IP
Netmask net.IPMask
// TODO: add broadcast + PtP dst ?
}
func (p *Pcap) Next() (pkt *Packet) {
rv, _ := p.NextEx()
return rv
}
// Openlive opens a device and returns a *Pcap handler
func Openlive(device string, snaplen int32, promisc bool, timeout_ms int32) (handle *Pcap, err error) {
var buf *C.char
buf = (*C.char)(C.calloc(ERRBUF_SIZE, 1))
h := new(Pcap)
var pro int32
if promisc {
pro = 1
}
dev := C.CString(device)
defer C.free(unsafe.Pointer(dev))
h.cptr = C.pcap_open_live(dev, C.int(snaplen), C.int(pro), C.int(timeout_ms), buf)
if nil == h.cptr {
handle = nil
err = errors.New(C.GoString(buf))
} else {
handle = h
}
C.free(unsafe.Pointer(buf))
return
}
func Openoffline(file string) (handle *Pcap, err error) {
var buf *C.char
buf = (*C.char)(C.calloc(ERRBUF_SIZE, 1))
h := new(Pcap)
cf := C.CString(file)
defer C.free(unsafe.Pointer(cf))
h.cptr = C.pcap_open_offline(cf, buf)
if nil == h.cptr {
handle = nil
err = errors.New(C.GoString(buf))
} else {
handle = h
}
C.free(unsafe.Pointer(buf))
return
}
func (p *Pcap) NextEx() (pkt *Packet, result int32) {
var pkthdr *C.struct_pcap_pkthdr
var buf_ptr *C.u_char
var buf unsafe.Pointer
result = int32(C.hack_pcap_next_ex(p.cptr, &pkthdr, &buf_ptr))
buf = unsafe.Pointer(buf_ptr)
if nil == buf {
return
}
pkt = new(Packet)
pkt.Time = time.Unix(int64(pkthdr.ts.tv_sec), int64(pkthdr.ts.tv_usec)*1000)
pkt.Caplen = uint32(pkthdr.caplen)
pkt.Len = uint32(pkthdr.len)
pkt.Data = C.GoBytes(buf, C.int(pkthdr.caplen))
return
}
func (p *Pcap) Close() {
C.pcap_close(p.cptr)
}
func (p *Pcap) Geterror() error {
return errors.New(C.GoString(C.pcap_geterr(p.cptr)))
}
func (p *Pcap) Getstats() (stat *Stat, err error) {
var cstats _Ctype_struct_pcap_stat
if -1 == C.pcap_stats(p.cptr, &cstats) {
return nil, p.Geterror()
}
stats := new(Stat)
stats.PacketsReceived = uint32(cstats.ps_recv)
stats.PacketsDropped = uint32(cstats.ps_drop)
stats.PacketsIfDropped = uint32(cstats.ps_ifdrop)
return stats, nil
}
func (p *Pcap) Setfilter(expr string) (err error) {
var bpf _Ctype_struct_bpf_program
cexpr := C.CString(expr)
defer C.free(unsafe.Pointer(cexpr))
if -1 == C.pcap_compile(p.cptr, &bpf, cexpr, 1, 0) {
return p.Geterror()
}
if -1 == C.pcap_setfilter(p.cptr, &bpf) {
C.pcap_freecode(&bpf)
return p.Geterror()
}
C.pcap_freecode(&bpf)
return nil
}
func Version() string {
return C.GoString(C.pcap_lib_version())
}
func (p *Pcap) Datalink() int {
return int(C.pcap_datalink(p.cptr))
}
func (p *Pcap) Setdatalink(dlt int) error {
if -1 == C.pcap_set_datalink(p.cptr, C.int(dlt)) {
return p.Geterror()
}
return nil
}
func DatalinkValueToName(dlt int) string {
if name := C.pcap_datalink_val_to_name(C.int(dlt)); name != nil {
return C.GoString(name)
}
return ""
}
func DatalinkValueToDescription(dlt int) string {
if desc := C.pcap_datalink_val_to_description(C.int(dlt)); desc != nil {
return C.GoString(desc)
}
return ""
}
func Findalldevs() (ifs []Interface, err error) {
var buf *C.char
buf = (*C.char)(C.calloc(ERRBUF_SIZE, 1))
defer C.free(unsafe.Pointer(buf))
var alldevsp *C.pcap_if_t
if -1 == C.pcap_findalldevs((**C.pcap_if_t)(&alldevsp), buf) {
return nil, errors.New(C.GoString(buf))
}
defer C.pcap_freealldevs((*C.pcap_if_t)(alldevsp))
dev := alldevsp
var i uint32
for i = 0; dev != nil; dev = (*C.pcap_if_t)(dev.next) {
i++
}
ifs = make([]Interface, i)
dev = alldevsp
for j := uint32(0); dev != nil; dev = (*C.pcap_if_t)(dev.next) {
var iface Interface
iface.Name = C.GoString(dev.name)
iface.Description = C.GoString(dev.description)
iface.Addresses = findalladdresses(dev.addresses)
// TODO: add more elements
ifs[j] = iface
j++
}
return
}
func findalladdresses(addresses *_Ctype_struct_pcap_addr) (retval []IFAddress) {
// TODO - make it support more than IPv4 and IPv6?
retval = make([]IFAddress, 0, 1)
for curaddr := addresses; curaddr != nil; curaddr = (*_Ctype_struct_pcap_addr)(curaddr.next) {
var a IFAddress
var err error
if a.IP, err = sockaddr_to_IP((*syscall.RawSockaddr)(unsafe.Pointer(curaddr.addr))); err != nil {
continue
}
if a.Netmask, err = sockaddr_to_IP((*syscall.RawSockaddr)(unsafe.Pointer(curaddr.addr))); err != nil {
continue
}
retval = append(retval, a)
}
return
}
func sockaddr_to_IP(rsa *syscall.RawSockaddr) (IP []byte, err error) {
switch rsa.Family {
case syscall.AF_INET:
pp := (*syscall.RawSockaddrInet4)(unsafe.Pointer(rsa))
IP = make([]byte, 4)
for i := 0; i < len(IP); i++ {
IP[i] = pp.Addr[i]
}
return
case syscall.AF_INET6:
pp := (*syscall.RawSockaddrInet6)(unsafe.Pointer(rsa))
IP = make([]byte, 16)
for i := 0; i < len(IP); i++ {
IP[i] = pp.Addr[i]
}
return
}
err = errors.New("Unsupported address type")
return
}
func (p *Pcap) Inject(data []byte) (err error) {
buf := (*C.char)(C.malloc((C.size_t)(len(data))))
for i := 0; i < len(data); i++ {
*(*byte)(unsafe.Pointer(uintptr(unsafe.Pointer(buf)) + uintptr(i))) = data[i]
}
if -1 == C.pcap_sendpacket(p.cptr, (*C.u_char)(unsafe.Pointer(buf)), (C.int)(len(data))) {
err = p.Geterror()
}
C.free(unsafe.Pointer(buf))
return
}

View File

@ -0,0 +1,49 @@
package main
import (
"flag"
"fmt"
"os"
"runtime/pprof"
"time"
"github.com/coreos/etcd/Godeps/_workspace/src/github.com/akrennmair/gopcap"
)
func main() {
var filename *string = flag.String("file", "", "filename")
var decode *bool = flag.Bool("d", false, "If true, decode each packet")
var cpuprofile *string = flag.String("cpuprofile", "", "filename")
flag.Parse()
h, err := pcap.Openoffline(*filename)
if err != nil {
fmt.Printf("Couldn't create pcap reader: %v", err)
}
if *cpuprofile != "" {
if out, err := os.Create(*cpuprofile); err == nil {
pprof.StartCPUProfile(out)
defer func() {
pprof.StopCPUProfile()
out.Close()
}()
} else {
panic(err)
}
}
i, nilPackets := 0, 0
start := time.Now()
for pkt, code := h.NextEx(); code != -2; pkt, code = h.NextEx() {
if pkt == nil {
nilPackets++
} else if *decode {
pkt.Decode()
}
i++
}
duration := time.Since(start)
fmt.Printf("Took %v to process %v packets, %v per packet, %d nil packets\n", duration, i, duration/time.Duration(i), nilPackets)
}

View File

@ -0,0 +1,96 @@
package main
// Parses a pcap file, writes it back to disk, then verifies the files
// are the same.
import (
"bufio"
"flag"
"fmt"
"io"
"os"
"github.com/coreos/etcd/Godeps/_workspace/src/github.com/akrennmair/gopcap"
)
var input *string = flag.String("input", "", "input file")
var output *string = flag.String("output", "", "output file")
var decode *bool = flag.Bool("decode", false, "print decoded packets")
func copyPcap(dest, src string) {
f, err := os.Open(src)
if err != nil {
fmt.Printf("couldn't open %q: %v\n", src, err)
return
}
defer f.Close()
reader, err := pcap.NewReader(bufio.NewReader(f))
if err != nil {
fmt.Printf("couldn't create reader: %v\n", err)
return
}
w, err := os.Create(dest)
if err != nil {
fmt.Printf("couldn't open %q: %v\n", dest, err)
return
}
defer w.Close()
buf := bufio.NewWriter(w)
writer, err := pcap.NewWriter(buf, &reader.Header)
if err != nil {
fmt.Printf("couldn't create writer: %v\n", err)
return
}
for {
pkt := reader.Next()
if pkt == nil {
break
}
if *decode {
pkt.Decode()
fmt.Println(pkt.String())
}
writer.Write(pkt)
}
buf.Flush()
}
func check(dest, src string) {
f, err := os.Open(src)
if err != nil {
fmt.Printf("couldn't open %q: %v\n", src, err)
return
}
defer f.Close()
freader := bufio.NewReader(f)
g, err := os.Open(dest)
if err != nil {
fmt.Printf("couldn't open %q: %v\n", src, err)
return
}
defer g.Close()
greader := bufio.NewReader(g)
for {
fb, ferr := freader.ReadByte()
gb, gerr := greader.ReadByte()
if ferr == io.EOF && gerr == io.EOF {
break
}
if fb == gb {
continue
}
fmt.Println("FAIL")
return
}
fmt.Println("PASS")
}
func main() {
flag.Parse()
copyPcap(*output, *input)
check(*output, *input)
}

View File

@ -0,0 +1,82 @@
package main
import (
"flag"
"fmt"
"time"
"github.com/coreos/etcd/Godeps/_workspace/src/github.com/akrennmair/gopcap"
)
func min(x uint32, y uint32) uint32 {
if x < y {
return x
}
return y
}
func main() {
var device *string = flag.String("d", "", "device")
var file *string = flag.String("r", "", "file")
var expr *string = flag.String("e", "", "filter expression")
flag.Parse()
var h *pcap.Pcap
var err error
ifs, err := pcap.Findalldevs()
if len(ifs) == 0 {
fmt.Printf("Warning: no devices found : %s\n", err)
} else {
for i := 0; i < len(ifs); i++ {
fmt.Printf("dev %d: %s (%s)\n", i+1, ifs[i].Name, ifs[i].Description)
}
}
if *device != "" {
h, err = pcap.Openlive(*device, 65535, true, 0)
if h == nil {
fmt.Printf("Openlive(%s) failed: %s\n", *device, err)
return
}
} else if *file != "" {
h, err = pcap.Openoffline(*file)
if h == nil {
fmt.Printf("Openoffline(%s) failed: %s\n", *file, err)
return
}
} else {
fmt.Printf("usage: pcaptest [-d <device> | -r <file>]\n")
return
}
defer h.Close()
fmt.Printf("pcap version: %s\n", pcap.Version())
if *expr != "" {
fmt.Printf("Setting filter: %s\n", *expr)
err := h.Setfilter(*expr)
if err != nil {
fmt.Printf("Warning: setting filter failed: %s\n", err)
}
}
for pkt := h.Next(); pkt != nil; pkt = h.Next() {
fmt.Printf("time: %d.%06d (%s) caplen: %d len: %d\nData:",
int64(pkt.Time.Second()), int64(pkt.Time.Nanosecond()),
time.Unix(int64(pkt.Time.Second()), 0).String(), int64(pkt.Caplen), int64(pkt.Len))
for i := uint32(0); i < pkt.Caplen; i++ {
if i%32 == 0 {
fmt.Printf("\n")
}
if 32 <= pkt.Data[i] && pkt.Data[i] <= 126 {
fmt.Printf("%c", pkt.Data[i])
} else {
fmt.Printf(".")
}
}
fmt.Printf("\n\n")
}
}

View File

@ -0,0 +1,121 @@
package main
import (
"bufio"
"flag"
"fmt"
"os"
"github.com/coreos/etcd/Godeps/_workspace/src/github.com/akrennmair/gopcap"
)
const (
TYPE_IP = 0x0800
TYPE_ARP = 0x0806
TYPE_IP6 = 0x86DD
IP_ICMP = 1
IP_INIP = 4
IP_TCP = 6
IP_UDP = 17
)
var out *bufio.Writer
var errout *bufio.Writer
func main() {
var device *string = flag.String("i", "", "interface")
var snaplen *int = flag.Int("s", 65535, "snaplen")
var hexdump *bool = flag.Bool("X", false, "hexdump")
expr := ""
out = bufio.NewWriter(os.Stdout)
errout = bufio.NewWriter(os.Stderr)
flag.Usage = func() {
fmt.Fprintf(errout, "usage: %s [ -i interface ] [ -s snaplen ] [ -X ] [ expression ]\n", os.Args[0])
errout.Flush()
os.Exit(1)
}
flag.Parse()
if len(flag.Args()) > 0 {
expr = flag.Arg(0)
}
if *device == "" {
devs, err := pcap.Findalldevs()
if err != nil {
fmt.Fprintf(errout, "tcpdump: couldn't find any devices: %s\n", err)
}
if 0 == len(devs) {
flag.Usage()
}
*device = devs[0].Name
}
h, err := pcap.Openlive(*device, int32(*snaplen), true, 0)
if h == nil {
fmt.Fprintf(errout, "tcpdump: %s\n", err)
errout.Flush()
return
}
defer h.Close()
if expr != "" {
ferr := h.Setfilter(expr)
if ferr != nil {
fmt.Fprintf(out, "tcpdump: %s\n", ferr)
out.Flush()
}
}
for pkt := h.Next(); pkt != nil; pkt = h.Next() {
pkt.Decode()
fmt.Fprintf(out, "%s\n", pkt.String())
if *hexdump {
Hexdump(pkt)
}
out.Flush()
}
}
func min(a, b int) int {
if a < b {
return a
}
return b
}
func Hexdump(pkt *pcap.Packet) {
for i := 0; i < len(pkt.Data); i += 16 {
Dumpline(uint32(i), pkt.Data[i:min(i+16, len(pkt.Data))])
}
}
func Dumpline(addr uint32, line []byte) {
fmt.Fprintf(out, "\t0x%04x: ", int32(addr))
var i uint16
for i = 0; i < 16 && i < uint16(len(line)); i++ {
if i%2 == 0 {
out.WriteString(" ")
}
fmt.Fprintf(out, "%02x", line[i])
}
for j := i; j <= 16; j++ {
if j%2 == 0 {
out.WriteString(" ")
}
out.WriteString(" ")
}
out.WriteString(" ")
for i = 0; i < 16 && i < uint16(len(line)); i++ {
if line[i] >= 32 && line[i] <= 126 {
fmt.Fprintf(out, "%c", line[i])
} else {
out.WriteString(".")
}
}
out.WriteString("\n")
}

View File

@ -1,8 +1,8 @@
Bolt [![Build Status](https://drone.io/github.com/boltdb/bolt/status.png)](https://drone.io/github.com/boltdb/bolt/latest) [![Coverage Status](https://coveralls.io/repos/boltdb/bolt/badge.png?branch=master)](https://coveralls.io/r/boltdb/bolt?branch=master) [![GoDoc](https://godoc.org/github.com/boltdb/bolt?status.png)](https://godoc.org/github.com/boltdb/bolt) ![Version](http://img.shields.io/badge/version-1.0-green.png)
====
Bolt is a pure Go key/value store inspired by [Howard Chu's][hyc_symas] and
the [LMDB project][lmdb]. The goal of the project is to provide a simple,
Bolt is a pure Go key/value store inspired by [Howard Chu's][hyc_symas]
[LMDB project][lmdb]. The goal of the project is to provide a simple,
fast, and reliable database for projects that don't require a full database
server such as Postgres or MySQL.
@ -180,8 +180,8 @@ and then safely close your transaction if an error is returned. This is the
recommended way to use Bolt transactions.
However, sometimes you may want to manually start and end your transactions.
You can use the `Tx.Begin()` function directly but _please_ be sure to close the
transaction.
You can use the `Tx.Begin()` function directly but **please** be sure to close
the transaction.
```go
// Start a writable transaction.
@ -256,7 +256,7 @@ db.View(func(tx *bolt.Tx) error {
```
The `Get()` function does not return an error because its operation is
guarenteed to work (unless there is some kind of system failure). If the key
guaranteed to work (unless there is some kind of system failure). If the key
exists then it will return its byte slice value. If it doesn't exist then it
will return `nil`. It's important to note that you can have a zero-length value
set to a key which is different than the key not existing.
@ -268,6 +268,50 @@ transaction is open. If you need to use a value outside of the transaction
then you must use `copy()` to copy it to another byte slice.
### Autoincrementing integer for the bucket
By using the NextSequence() function, you can let Bolt determine a sequence
which can be used as the unique identifier for your key/value pairs. See the
example below.
```go
// CreateUser saves u to the store. The new user ID is set on u once the data is persisted.
func (s *Store) CreateUser(u *User) error {
return s.db.Update(func(tx *bolt.Tx) error {
// Retrieve the users bucket.
// This should be created when the DB is first opened.
b := tx.Bucket([]byte("users"))
// Generate ID for the user.
// This returns an error only if the Tx is closed or not writeable.
// That can't happen in an Update() call so I ignore the error check.
id, _ = b.NextSequence()
u.ID = int(id)
// Marshal user data into bytes.
buf, err := json.Marshal(u)
if err != nil {
return err
}
// Persist bytes to users bucket.
return b.Put(itob(u.ID), buf)
})
}
// itob returns an 8-byte big endian representation of v.
func itob(v int) []byte {
b := make([]byte, 8)
binary.BigEndian.PutUint64(b, uint64(v))
return b
}
type User struct {
ID int
...
}
```
### Iterating over keys
Bolt stores its keys in byte-sorted order within a bucket. This makes sequential
@ -382,8 +426,11 @@ func (*Bucket) DeleteBucket(key []byte) error
Bolt is a single file so it's easy to backup. You can use the `Tx.WriteTo()`
function to write a consistent view of the database to a writer. If you call
this from a read-only transaction, it will perform a hot backup and not block
your other database reads and writes. It will also use `O_DIRECT` when available
to prevent page cache trashing.
your other database reads and writes.
By default, it will use a regular file handle which will utilize the operating
system's page cache. See the [`Tx`](https://godoc.org/github.com/boltdb/bolt#Tx)
documentation for information about optimizing for larger-than-RAM datasets.
One common use case is to backup over HTTP so you can use tools like `cURL` to
do database backups:
@ -500,7 +547,7 @@ they are libraries bundled into the application, however, their underlying
structure is a log-structured merge-tree (LSM tree). An LSM tree optimizes
random writes by using a write ahead log and multi-tiered, sorted files called
SSTables. Bolt uses a B+tree internally and only a single file. Both approaches
have trade offs.
have trade-offs.
If you require a high random write throughput (>10,000 w/sec) or you need to use
spinning disks then LevelDB could be a good choice. If your application is
@ -568,7 +615,9 @@ Here are a few things to note when evaluating and using Bolt:
can in memory and will release memory as needed to other processes. This means
that Bolt can show very high memory usage when working with large databases.
However, this is expected and the OS will release memory as needed. Bolt can
handle databases much larger than the available physical RAM.
handle databases much larger than the available physical RAM, provided its
memory-map fits in the process virtual address space. It may be problematic
on 32-bits systems.
* The data structures in the Bolt database are memory mapped so the data file
will be endian specific. This means that you cannot copy a Bolt file from a
@ -587,6 +636,56 @@ Here are a few things to note when evaluating and using Bolt:
[page-allocation]: https://github.com/boltdb/bolt/issues/308#issuecomment-74811638
## Reading the Source
Bolt is a relatively small code base (<3KLOC) for an embedded, serializable,
transactional key/value database so it can be a good starting point for people
interested in how databases work.
The best places to start are the main entry points into Bolt:
- `Open()` - Initializes the reference to the database. It's responsible for
creating the database if it doesn't exist, obtaining an exclusive lock on the
file, reading the meta pages, & memory-mapping the file.
- `DB.Begin()` - Starts a read-only or read-write transaction depending on the
value of the `writable` argument. This requires briefly obtaining the "meta"
lock to keep track of open transactions. Only one read-write transaction can
exist at a time so the "rwlock" is acquired during the life of a read-write
transaction.
- `Bucket.Put()` - Writes a key/value pair into a bucket. After validating the
arguments, a cursor is used to traverse the B+tree to the page and position
where they key & value will be written. Once the position is found, the bucket
materializes the underlying page and the page's parent pages into memory as
"nodes". These nodes are where mutations occur during read-write transactions.
These changes get flushed to disk during commit.
- `Bucket.Get()` - Retrieves a key/value pair from a bucket. This uses a cursor
to move to the page & position of a key/value pair. During a read-only
transaction, the key and value data is returned as a direct reference to the
underlying mmap file so there's no allocation overhead. For read-write
transactions, this data may reference the mmap file or one of the in-memory
node values.
- `Cursor` - This object is simply for traversing the B+tree of on-disk pages
or in-memory nodes. It can seek to a specific key, move to the first or last
value, or it can move forward or backward. The cursor handles the movement up
and down the B+tree transparently to the end user.
- `Tx.Commit()` - Converts the in-memory dirty nodes and the list of free pages
into pages to be written to disk. Writing to disk then occurs in two phases.
First, the dirty pages are written to disk and an `fsync()` occurs. Second, a
new meta page with an incremented transaction ID is written and another
`fsync()` occurs. This two phase write ensures that partially written data
pages are ignored in the event of a crash since the meta page pointing to them
is never written. Partially written meta pages are invalidated because they
are written with a checksum.
If you have additional notes that could be helpful for others, please submit
them via pull request.
## Other Projects Using Bolt
Below is a list of public, open source projects that use Bolt:
@ -615,7 +714,11 @@ Below is a list of public, open source projects that use Bolt:
* [Freehold](http://tshannon.bitbucket.org/freehold/) - An open, secure, and lightweight platform for your files and data.
* [Prometheus Annotation Server](https://github.com/oliver006/prom_annotation_server) - Annotation server for PromDash & Prometheus service monitoring system.
* [Consul](https://github.com/hashicorp/consul) - Consul is service discovery and configuration made easy. Distributed, highly available, and datacenter-aware.
* [Kala](https://github.com/ajvb/kala) - Kala is a modern job scheduler optimized to run on a single node. It is persistant, JSON over HTTP API, ISO 8601 duration notation, and dependent jobs.
* [Kala](https://github.com/ajvb/kala) - Kala is a modern job scheduler optimized to run on a single node. It is persistent, JSON over HTTP API, ISO 8601 duration notation, and dependent jobs.
* [drive](https://github.com/odeke-em/drive) - drive is an unofficial Google Drive command line client for \*NIX operating systems.
* [stow](https://github.com/djherbis/stow) - a persistence manager for objects
backed by boltdb.
* [buckets](https://github.com/joyrexus/buckets) - a bolt wrapper streamlining
simple tx and key scans.
If you are using Bolt in a project please send a pull request to add it to the list.

View File

@ -0,0 +1,9 @@
// +build arm64
package bolt
// maxMapSize represents the largest mmap size supported by Bolt.
const maxMapSize = 0xFFFFFFFFFFFF // 256TB
// maxAllocSize is the size used when creating array pointers.
const maxAllocSize = 0x7FFFFFFF

View File

@ -4,8 +4,6 @@ import (
"syscall"
)
var odirect = syscall.O_DIRECT
// fdatasync flushes written data to a file descriptor.
func fdatasync(db *DB) error {
return syscall.Fdatasync(int(db.file.Fd()))

View File

@ -11,8 +11,6 @@ const (
msInvalidate // invalidate cached data
)
var odirect int
func msync(db *DB) error {
_, _, errno := syscall.Syscall(syscall.SYS_MSYNC, uintptr(unsafe.Pointer(db.data)), uintptr(db.datasz), msInvalidate)
if errno != 0 {

View File

@ -0,0 +1,9 @@
// +build ppc64le
package bolt
// maxMapSize represents the largest mmap size supported by Bolt.
const maxMapSize = 0xFFFFFFFFFFFF // 256TB
// maxAllocSize is the size used when creating array pointers.
const maxAllocSize = 0x7FFFFFFF

View File

@ -0,0 +1,9 @@
// +build s390x
package bolt
// maxMapSize represents the largest mmap size supported by Bolt.
const maxMapSize = 0xFFFFFFFFFFFF // 256TB
// maxAllocSize is the size used when creating array pointers.
const maxAllocSize = 0x7FFFFFFF

View File

@ -58,7 +58,7 @@ func mmap(db *DB, sz int) error {
}
// Map the data file to memory.
b, err := syscall.Mmap(int(db.file.Fd()), 0, sz, syscall.PROT_READ, syscall.MAP_SHARED)
b, err := syscall.Mmap(int(db.file.Fd()), 0, sz, syscall.PROT_READ, syscall.MAP_SHARED|db.MmapFlags)
if err != nil {
return err
}

View File

@ -2,11 +2,12 @@ package bolt
import (
"fmt"
"github.com/coreos/etcd/Godeps/_workspace/src/golang.org/x/sys/unix"
"os"
"syscall"
"time"
"unsafe"
"github.com/coreos/etcd/Godeps/_workspace/src/golang.org/x/sys/unix"
)
// flock acquires an advisory lock on a file descriptor.
@ -67,7 +68,7 @@ func mmap(db *DB, sz int) error {
}
// Map the data file to memory.
b, err := unix.Mmap(int(db.file.Fd()), 0, sz, syscall.PROT_READ, syscall.MAP_SHARED)
b, err := unix.Mmap(int(db.file.Fd()), 0, sz, syscall.PROT_READ, syscall.MAP_SHARED|db.MmapFlags)
if err != nil {
return err
}

View File

@ -8,7 +8,37 @@ import (
"unsafe"
)
var odirect int
// LockFileEx code derived from golang build filemutex_windows.go @ v1.5.1
var (
modkernel32 = syscall.NewLazyDLL("kernel32.dll")
procLockFileEx = modkernel32.NewProc("LockFileEx")
procUnlockFileEx = modkernel32.NewProc("UnlockFileEx")
)
const (
// see https://msdn.microsoft.com/en-us/library/windows/desktop/aa365203(v=vs.85).aspx
flagLockExclusive = 2
flagLockFailImmediately = 1
// see https://msdn.microsoft.com/en-us/library/windows/desktop/ms681382(v=vs.85).aspx
errLockViolation syscall.Errno = 0x21
)
func lockFileEx(h syscall.Handle, flags, reserved, locklow, lockhigh uint32, ol *syscall.Overlapped) (err error) {
r, _, err := procLockFileEx.Call(uintptr(h), uintptr(flags), uintptr(reserved), uintptr(locklow), uintptr(lockhigh), uintptr(unsafe.Pointer(ol)))
if r == 0 {
return err
}
return nil
}
func unlockFileEx(h syscall.Handle, reserved, locklow, lockhigh uint32, ol *syscall.Overlapped) (err error) {
r, _, err := procUnlockFileEx.Call(uintptr(h), uintptr(reserved), uintptr(locklow), uintptr(lockhigh), uintptr(unsafe.Pointer(ol)), 0)
if r == 0 {
return err
}
return nil
}
// fdatasync flushes written data to a file descriptor.
func fdatasync(db *DB) error {
@ -16,13 +46,37 @@ func fdatasync(db *DB) error {
}
// flock acquires an advisory lock on a file descriptor.
func flock(f *os.File, _ bool, _ time.Duration) error {
return nil
func flock(f *os.File, exclusive bool, timeout time.Duration) error {
var t time.Time
for {
// If we're beyond our timeout then return an error.
// This can only occur after we've attempted a flock once.
if t.IsZero() {
t = time.Now()
} else if timeout > 0 && time.Since(t) > timeout {
return ErrTimeout
}
var flag uint32 = flagLockFailImmediately
if exclusive {
flag |= flagLockExclusive
}
err := lockFileEx(syscall.Handle(f.Fd()), flag, 0, 1, 0, &syscall.Overlapped{})
if err == nil {
return nil
} else if err != errLockViolation {
return err
}
// Wait for a bit and try again.
time.Sleep(50 * time.Millisecond)
}
}
// funlock releases an advisory lock on a file descriptor.
func funlock(f *os.File) error {
return nil
return unlockFileEx(syscall.Handle(f.Fd()), 0, 1, 0, &syscall.Overlapped{})
}
// mmap memory maps a DB's data file.

View File

@ -2,8 +2,6 @@
package bolt
var odirect int
// fdatasync flushes written data to a file descriptor.
func fdatasync(db *DB) error {
return db.file.Sync()

View File

@ -11,7 +11,7 @@ const (
MaxKeySize = 32768
// MaxValueSize is the maximum length of a value, in bytes.
MaxValueSize = 4294967295
MaxValueSize = (1 << 31) - 2
)
const (
@ -99,6 +99,7 @@ func (b *Bucket) Cursor() *Cursor {
// Bucket retrieves a nested bucket by name.
// Returns nil if the bucket does not exist.
// The bucket instance is only valid for the lifetime of the transaction.
func (b *Bucket) Bucket(name []byte) *Bucket {
if b.buckets != nil {
if child := b.buckets[string(name)]; child != nil {
@ -148,6 +149,7 @@ func (b *Bucket) openBucket(value []byte) *Bucket {
// CreateBucket creates a new bucket at the given key and returns the new bucket.
// Returns an error if the key already exists, if the bucket name is blank, or if the bucket name is too long.
// The bucket instance is only valid for the lifetime of the transaction.
func (b *Bucket) CreateBucket(key []byte) (*Bucket, error) {
if b.tx.db == nil {
return nil, ErrTxClosed
@ -192,6 +194,7 @@ func (b *Bucket) CreateBucket(key []byte) (*Bucket, error) {
// CreateBucketIfNotExists creates a new bucket if it doesn't already exist and returns a reference to it.
// Returns an error if the bucket name is blank, or if the bucket name is too long.
// The bucket instance is only valid for the lifetime of the transaction.
func (b *Bucket) CreateBucketIfNotExists(key []byte) (*Bucket, error) {
child, err := b.CreateBucket(key)
if err == ErrBucketExists {
@ -270,6 +273,7 @@ func (b *Bucket) Get(key []byte) []byte {
// Put sets the value for a key in the bucket.
// If the key exist then its previous value will be overwritten.
// Supplied value must remain valid for the life of the transaction.
// Returns an error if the bucket was created from a read-only transaction, if the key is blank, if the key is too large, or if the value is too large.
func (b *Bucket) Put(key []byte, value []byte) error {
if b.tx.db == nil {
@ -346,7 +350,8 @@ func (b *Bucket) NextSequence() (uint64, error) {
// ForEach executes a function for each key/value pair in a bucket.
// If the provided function returns an error then the iteration is stopped and
// the error is returned to the caller.
// the error is returned to the caller. The provided function must not modify
// the bucket; this will result in undefined behavior.
func (b *Bucket) ForEach(fn func(k, v []byte) error) error {
if b.tx.db == nil {
return ErrTxClosed

View File

@ -253,7 +253,7 @@ func TestBucket_Delete_FreelistOverflow(t *testing.T) {
b := tx.Bucket([]byte("0"))
c := b.Cursor()
for k, _ := c.First(); k != nil; k, _ = c.Next() {
b.Delete(k)
c.Delete()
}
return nil
})

View File

@ -34,6 +34,13 @@ func (c *Cursor) First() (key []byte, value []byte) {
p, n := c.bucket.pageNode(c.bucket.root)
c.stack = append(c.stack, elemRef{page: p, node: n, index: 0})
c.first()
// If we land on an empty page then move to the next value.
// https://github.com/boltdb/bolt/issues/450
if c.stack[len(c.stack)-1].count() == 0 {
c.next()
}
k, v, flags := c.keyValue()
if (flags & uint32(bucketLeafFlag)) != 0 {
return k, nil
@ -209,28 +216,37 @@ func (c *Cursor) last() {
// next moves to the next leaf element and returns the key and value.
// If the cursor is at the last leaf element then it stays there and returns nil.
func (c *Cursor) next() (key []byte, value []byte, flags uint32) {
// Attempt to move over one element until we're successful.
// Move up the stack as we hit the end of each page in our stack.
var i int
for i = len(c.stack) - 1; i >= 0; i-- {
elem := &c.stack[i]
if elem.index < elem.count()-1 {
elem.index++
break
for {
// Attempt to move over one element until we're successful.
// Move up the stack as we hit the end of each page in our stack.
var i int
for i = len(c.stack) - 1; i >= 0; i-- {
elem := &c.stack[i]
if elem.index < elem.count()-1 {
elem.index++
break
}
}
}
// If we've hit the root page then stop and return. This will leave the
// cursor on the last element of the last page.
if i == -1 {
return nil, nil, 0
}
// If we've hit the root page then stop and return. This will leave the
// cursor on the last element of the last page.
if i == -1 {
return nil, nil, 0
}
// Otherwise start from where we left off in the stack and find the
// first element of the first leaf page.
c.stack = c.stack[:i+1]
c.first()
return c.keyValue()
// Otherwise start from where we left off in the stack and find the
// first element of the first leaf page.
c.stack = c.stack[:i+1]
c.first()
// If this is an empty page then restart and move back up the stack.
// https://github.com/boltdb/bolt/issues/450
if c.stack[len(c.stack)-1].count() == 0 {
continue
}
return c.keyValue()
}
}
// search recursively performs a binary search against a given page/node until it finds a given key.

View File

@ -303,6 +303,49 @@ func TestCursor_Restart(t *testing.T) {
tx.Rollback()
}
// Ensure that a cursor can skip over empty pages that have been deleted.
func TestCursor_First_EmptyPages(t *testing.T) {
db := NewTestDB()
defer db.Close()
// Create 1000 keys in the "widgets" bucket.
db.Update(func(tx *bolt.Tx) error {
b, err := tx.CreateBucket([]byte("widgets"))
if err != nil {
t.Fatal(err)
}
for i := 0; i < 1000; i++ {
if err := b.Put(u64tob(uint64(i)), []byte{}); err != nil {
t.Fatal(err)
}
}
return nil
})
// Delete half the keys and then try to iterate.
db.Update(func(tx *bolt.Tx) error {
b := tx.Bucket([]byte("widgets"))
for i := 0; i < 600; i++ {
if err := b.Delete(u64tob(uint64(i))); err != nil {
t.Fatal(err)
}
}
c := b.Cursor()
var n int
for k, _ := c.First(); k != nil; k, _ = c.Next() {
n++
}
if n != 400 {
t.Fatalf("unexpected key count: %d", n)
}
return nil
})
}
// Ensure that a Tx can iterate over all elements in a bucket.
func TestCursor_QuickCheck(t *testing.T) {
f := func(items testdata) bool {

View File

@ -63,6 +63,10 @@ type DB struct {
// https://github.com/boltdb/bolt/issues/284
NoGrowSync bool
// If you want to read the entire database fast, you can set MmapFlag to
// syscall.MAP_POPULATE on Linux 2.6.23+ for sequential read-ahead.
MmapFlags int
// MaxBatchSize is the maximum size of a batch. Default value is
// copied from DefaultMaxBatchSize in Open.
//
@ -136,6 +140,7 @@ func Open(path string, mode os.FileMode, options *Options) (*DB, error) {
options = DefaultOptions
}
db.NoGrowSync = options.NoGrowSync
db.MmapFlags = options.MmapFlags
// Set default values for later DB operations.
db.MaxBatchSize = DefaultMaxBatchSize
@ -672,6 +677,9 @@ type Options struct {
// Open database in read-only mode. Uses flock(..., LOCK_SH |LOCK_NB) to
// grab a shared lock (UNIX).
ReadOnly bool
// Sets the DB.MmapFlags flag before memory mapping the file.
MmapFlags int
}
// DefaultOptions represent the options used if nil options are passed into Open().

View File

@ -39,9 +39,6 @@ func TestOpen(t *testing.T) {
// Ensure that opening an already open database file will timeout.
func TestOpen_Timeout(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("timeout not supported on windows")
}
if runtime.GOOS == "solaris" {
t.Skip("solaris fcntl locks don't support intra-process locking")
}
@ -66,9 +63,6 @@ func TestOpen_Timeout(t *testing.T) {
// Ensure that opening an already open database file will wait until its closed.
func TestOpen_Wait(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("timeout not supported on windows")
}
if runtime.GOOS == "solaris" {
t.Skip("solaris fcntl locks don't support intra-process locking")
}
@ -622,7 +616,7 @@ func TestDB_Consistency(t *testing.T) {
})
}
// Ensure that DB stats can be substracted from one another.
// Ensure that DB stats can be subtracted from one another.
func TestDBStats_Sub(t *testing.T) {
var a, b bolt.Stats
a.TxStats.PageCount = 3

View File

@ -29,6 +29,14 @@ type Tx struct {
pages map[pgid]*page
stats TxStats
commitHandlers []func()
// WriteFlag specifies the flag for write-related methods like WriteTo().
// Tx opens the database file with the specified flag to copy the data.
//
// By default, the flag is unset, which works well for mostly in-memory
// workloads. For databases that are much larger than available RAM,
// set the flag to syscall.O_DIRECT to avoid trashing the page cache.
WriteFlag int
}
// init initializes the transaction.
@ -87,18 +95,21 @@ func (tx *Tx) Stats() TxStats {
// Bucket retrieves a bucket by name.
// Returns nil if the bucket does not exist.
// The bucket instance is only valid for the lifetime of the transaction.
func (tx *Tx) Bucket(name []byte) *Bucket {
return tx.root.Bucket(name)
}
// CreateBucket creates a new bucket.
// Returns an error if the bucket already exists, if the bucket name is blank, or if the bucket name is too long.
// The bucket instance is only valid for the lifetime of the transaction.
func (tx *Tx) CreateBucket(name []byte) (*Bucket, error) {
return tx.root.CreateBucket(name)
}
// CreateBucketIfNotExists creates a new bucket if it doesn't already exist.
// Returns an error if the bucket name is blank, or if the bucket name is too long.
// The bucket instance is only valid for the lifetime of the transaction.
func (tx *Tx) CreateBucketIfNotExists(name []byte) (*Bucket, error) {
return tx.root.CreateBucketIfNotExists(name)
}
@ -236,7 +247,8 @@ func (tx *Tx) close() {
var freelistPendingN = tx.db.freelist.pending_count()
var freelistAlloc = tx.db.freelist.size()
// Remove writer lock.
// Remove transaction ref & writer lock.
tx.db.rwtx = nil
tx.db.rwlock.Unlock()
// Merge statistics.
@ -250,7 +262,12 @@ func (tx *Tx) close() {
} else {
tx.db.removeTx(tx)
}
// Clear all references.
tx.db = nil
tx.meta = nil
tx.root = Bucket{tx: tx}
tx.pages = nil
}
// Copy writes the entire database to a writer.
@ -263,21 +280,18 @@ func (tx *Tx) Copy(w io.Writer) error {
// WriteTo writes the entire database to a writer.
// If err == nil then exactly tx.Size() bytes will be written into the writer.
func (tx *Tx) WriteTo(w io.Writer) (n int64, err error) {
// Attempt to open reader directly.
var f *os.File
if f, err = os.OpenFile(tx.db.path, os.O_RDONLY|odirect, 0); err != nil {
// Fallback to a regular open if that doesn't work.
if f, err = os.OpenFile(tx.db.path, os.O_RDONLY, 0); err != nil {
return 0, err
}
// Attempt to open reader with WriteFlag
f, err := os.OpenFile(tx.db.path, os.O_RDONLY|tx.WriteFlag, 0)
if err != nil {
return 0, err
}
defer f.Close()
// Copy the meta pages.
tx.db.metalock.Lock()
n, err = io.CopyN(w, f, int64(tx.db.pageSize*2))
tx.db.metalock.Unlock()
if err != nil {
_ = f.Close()
return n, fmt.Errorf("meta copy: %s", err)
}
@ -285,7 +299,6 @@ func (tx *Tx) WriteTo(w io.Writer) (n int64, err error) {
wn, err := io.CopyN(w, f, tx.Size()-int64(tx.db.pageSize*2))
n += wn
if err != nil {
_ = f.Close()
return n, err
}

View File

@ -1 +0,0 @@
*~

View File

@ -1,19 +0,0 @@
# This file is like Go's AUTHORS file: it lists Copyright holders.
# The list of humans who have contributd is in the CONTRIBUTORS file.
#
# To contribute to this project, because it will eventually be folded
# back in to Go itself, you need to submit a CLA:
#
# http://golang.org/doc/contribute.html#copyright
#
# Then you get added to CONTRIBUTORS and you or your company get added
# to the AUTHORS file.
Blake Mizerany <blake.mizerany@gmail.com> github=bmizerany
Daniel Morsing <daniel.morsing@gmail.com> github=DanielMorsing
Gabriel Aszalos <gabriel.aszalos@gmail.com> github=gbbr
Google, Inc.
Keith Rarick <kr@xph.us> github=kr
Matthew Keenan <tank.en.mate@gmail.com> <github@mattkeenan.net> github=mattkeenan
Matt Layher <mdlayher@gmail.com> github=mdlayher
Tatsuhiro Tsujikawa <tatsuhiro.t@gmail.com> github=tatsuhiro-t

View File

@ -1,19 +0,0 @@
# This file is like Go's CONTRIBUTORS file: it lists humans.
# The list of copyright holders (which may be companies) are in the AUTHORS file.
#
# To contribute to this project, because it will eventually be folded
# back in to Go itself, you need to submit a CLA:
#
# http://golang.org/doc/contribute.html#copyright
#
# Then you get added to CONTRIBUTORS and you or your company get added
# to the AUTHORS file.
Blake Mizerany <blake.mizerany@gmail.com> github=bmizerany
Brad Fitzpatrick <bradfitz@golang.org> github=bradfitz
Daniel Morsing <daniel.morsing@gmail.com> github=DanielMorsing
Gabriel Aszalos <gabriel.aszalos@gmail.com> github=gbbr
Keith Rarick <kr@xph.us> github=kr
Matthew Keenan <tank.en.mate@gmail.com> <github@mattkeenan.net> github=mattkeenan
Matt Layher <mdlayher@gmail.com> github=mdlayher
Tatsuhiro Tsujikawa <tatsuhiro.t@gmail.com> github=tatsuhiro-t

View File

@ -1,5 +0,0 @@
We only accept contributions from users who have gone through Go's
contribution process (signed a CLA).
Please acknowledge whether you have (and use the same email) if
sending a pull request.

View File

@ -1,7 +0,0 @@
Copyright 2014 Google & the Go AUTHORS
Go AUTHORS are:
See https://code.google.com/p/go/source/browse/AUTHORS
Licensed under the terms of Go itself:
https://code.google.com/p/go/source/browse/LICENSE

View File

@ -1,75 +0,0 @@
// Copyright 2014 The Go Authors.
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
// Licensed under the same terms as Go itself:
// https://code.google.com/p/go/source/browse/LICENSE
package http2
import (
"errors"
)
// buffer is an io.ReadWriteCloser backed by a fixed size buffer.
// It never allocates, but moves old data as new data is written.
type buffer struct {
buf []byte
r, w int
closed bool
err error // err to return to reader
}
var (
errReadEmpty = errors.New("read from empty buffer")
errWriteFull = errors.New("write on full buffer")
)
// Read copies bytes from the buffer into p.
// It is an error to read when no data is available.
func (b *buffer) Read(p []byte) (n int, err error) {
n = copy(p, b.buf[b.r:b.w])
b.r += n
if b.closed && b.r == b.w {
err = b.err
} else if b.r == b.w && n == 0 {
err = errReadEmpty
}
return n, err
}
// Len returns the number of bytes of the unread portion of the buffer.
func (b *buffer) Len() int {
return b.w - b.r
}
// Write copies bytes from p into the buffer.
// It is an error to write more data than the buffer can hold.
func (b *buffer) Write(p []byte) (n int, err error) {
if b.closed {
return 0, errors.New("closed")
}
// Slide existing data to beginning.
if b.r > 0 && len(p) > len(b.buf)-b.w {
copy(b.buf, b.buf[b.r:b.w])
b.w -= b.r
b.r = 0
}
// Write new data.
n = copy(b.buf[b.w:], p)
b.w += n
if n < len(p) {
err = errWriteFull
}
return n, err
}
// Close marks the buffer as closed. Future calls to Write will
// return an error. Future calls to Read, once the buffer is
// empty, will return err.
func (b *buffer) Close(err error) {
if !b.closed {
b.closed = true
b.err = err
}
}

View File

@ -1,73 +0,0 @@
// Copyright 2014 The Go Authors.
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
// Licensed under the same terms as Go itself:
// https://code.google.com/p/go/source/browse/LICENSE
package http2
import (
"io"
"reflect"
"testing"
)
var bufferReadTests = []struct {
buf buffer
read, wn int
werr error
wp []byte
wbuf buffer
}{
{
buffer{[]byte{'a', 0}, 0, 1, false, nil},
5, 1, nil, []byte{'a'},
buffer{[]byte{'a', 0}, 1, 1, false, nil},
},
{
buffer{[]byte{'a', 0}, 0, 1, true, io.EOF},
5, 1, io.EOF, []byte{'a'},
buffer{[]byte{'a', 0}, 1, 1, true, io.EOF},
},
{
buffer{[]byte{0, 'a'}, 1, 2, false, nil},
5, 1, nil, []byte{'a'},
buffer{[]byte{0, 'a'}, 2, 2, false, nil},
},
{
buffer{[]byte{0, 'a'}, 1, 2, true, io.EOF},
5, 1, io.EOF, []byte{'a'},
buffer{[]byte{0, 'a'}, 2, 2, true, io.EOF},
},
{
buffer{[]byte{}, 0, 0, false, nil},
5, 0, errReadEmpty, []byte{},
buffer{[]byte{}, 0, 0, false, nil},
},
{
buffer{[]byte{}, 0, 0, true, io.EOF},
5, 0, io.EOF, []byte{},
buffer{[]byte{}, 0, 0, true, io.EOF},
},
}
func TestBufferRead(t *testing.T) {
for i, tt := range bufferReadTests {
read := make([]byte, tt.read)
n, err := tt.buf.Read(read)
if n != tt.wn {
t.Errorf("#%d: wn = %d want %d", i, n, tt.wn)
continue
}
if err != tt.werr {
t.Errorf("#%d: werr = %v want %v", i, err, tt.werr)
continue
}
read = read[:n]
if !reflect.DeepEqual(read, tt.wp) {
t.Errorf("#%d: read = %+v want %+v", i, read, tt.wp)
}
if !reflect.DeepEqual(tt.buf, tt.wbuf) {
t.Errorf("#%d: buf = %+v want %+v", i, tt.buf, tt.wbuf)
}
}
}

View File

@ -1,5 +0,0 @@
h2demo.linux: h2demo.go
GOOS=linux go build --tags=h2demo -o h2demo.linux .
upload: h2demo.linux
cat h2demo.linux | go run launch.go --write_object=http2-demo-server-tls/h2demo --write_object_is_public

View File

@ -1,43 +0,0 @@
// Copyright 2014 The Go Authors.
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
// Licensed under the same terms as Go itself:
// https://code.google.com/p/go/source/browse/LICENSE
package http2
import (
"sync"
)
type pipe struct {
b buffer
c sync.Cond
m sync.Mutex
}
// Read waits until data is available and copies bytes
// from the buffer into p.
func (r *pipe) Read(p []byte) (n int, err error) {
r.c.L.Lock()
defer r.c.L.Unlock()
for r.b.Len() == 0 && !r.b.closed {
r.c.Wait()
}
return r.b.Read(p)
}
// Write copies bytes from p into the buffer and wakes a reader.
// It is an error to write more data than the buffer can hold.
func (w *pipe) Write(p []byte) (n int, err error) {
w.c.L.Lock()
defer w.c.L.Unlock()
defer w.c.Signal()
return w.b.Write(p)
}
func (c *pipe) Close(err error) {
c.c.L.Lock()
defer c.c.L.Unlock()
defer c.c.Signal()
c.b.Close(err)
}

View File

@ -1,24 +0,0 @@
// Copyright 2014 The Go Authors.
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
// Licensed under the same terms as Go itself:
// https://code.google.com/p/go/source/browse/LICENSE
package http2
import (
"errors"
"testing"
)
func TestPipeClose(t *testing.T) {
var p pipe
p.c.L = &p.m
a := errors.New("a")
b := errors.New("b")
p.Close(a)
p.Close(b)
_, err := p.Read(make([]byte, 1))
if err != a {
t.Errorf("err = %v want %v", err, a)
}
}

View File

@ -1,553 +0,0 @@
// Copyright 2015 The Go Authors.
// See https://go.googlesource.com/go/+/master/CONTRIBUTORS
// Licensed under the same terms as Go itself:
// https://go.googlesource.com/go/+/master/LICENSE
package http2
import (
"bufio"
"bytes"
"crypto/tls"
"errors"
"fmt"
"io"
"log"
"net"
"net/http"
"strconv"
"strings"
"sync"
"github.com/coreos/etcd/Godeps/_workspace/src/github.com/bradfitz/http2/hpack"
)
type Transport struct {
Fallback http.RoundTripper
// TODO: remove this and make more general with a TLS dial hook, like http
InsecureTLSDial bool
connMu sync.Mutex
conns map[string][]*clientConn // key is host:port
}
type clientConn struct {
t *Transport
tconn *tls.Conn
tlsState *tls.ConnectionState
connKey []string // key(s) this connection is cached in, in t.conns
readerDone chan struct{} // closed on error
readerErr error // set before readerDone is closed
hdec *hpack.Decoder
nextRes *http.Response
mu sync.Mutex
closed bool
goAway *GoAwayFrame // if non-nil, the GoAwayFrame we received
streams map[uint32]*clientStream
nextStreamID uint32
bw *bufio.Writer
werr error // first write error that has occurred
br *bufio.Reader
fr *Framer
// Settings from peer:
maxFrameSize uint32
maxConcurrentStreams uint32
initialWindowSize uint32
hbuf bytes.Buffer // HPACK encoder writes into this
henc *hpack.Encoder
}
type clientStream struct {
ID uint32
resc chan resAndError
pw *io.PipeWriter
pr *io.PipeReader
}
type stickyErrWriter struct {
w io.Writer
err *error
}
func (sew stickyErrWriter) Write(p []byte) (n int, err error) {
if *sew.err != nil {
return 0, *sew.err
}
n, err = sew.w.Write(p)
*sew.err = err
return
}
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
if req.URL.Scheme != "https" {
if t.Fallback == nil {
return nil, errors.New("http2: unsupported scheme and no Fallback")
}
return t.Fallback.RoundTrip(req)
}
host, port, err := net.SplitHostPort(req.URL.Host)
if err != nil {
host = req.URL.Host
port = "443"
}
for {
cc, err := t.getClientConn(host, port)
if err != nil {
return nil, err
}
res, err := cc.roundTrip(req)
if shouldRetryRequest(err) { // TODO: or clientconn is overloaded (too many outstanding requests)?
continue
}
if err != nil {
return nil, err
}
return res, nil
}
}
// CloseIdleConnections closes any connections which were previously
// connected from previous requests but are now sitting idle.
// It does not interrupt any connections currently in use.
func (t *Transport) CloseIdleConnections() {
t.connMu.Lock()
defer t.connMu.Unlock()
for _, vv := range t.conns {
for _, cc := range vv {
cc.closeIfIdle()
}
}
}
var errClientConnClosed = errors.New("http2: client conn is closed")
func shouldRetryRequest(err error) bool {
// TODO: or GOAWAY graceful shutdown stuff
return err == errClientConnClosed
}
func (t *Transport) removeClientConn(cc *clientConn) {
t.connMu.Lock()
defer t.connMu.Unlock()
for _, key := range cc.connKey {
vv, ok := t.conns[key]
if !ok {
continue
}
newList := filterOutClientConn(vv, cc)
if len(newList) > 0 {
t.conns[key] = newList
} else {
delete(t.conns, key)
}
}
}
func filterOutClientConn(in []*clientConn, exclude *clientConn) []*clientConn {
out := in[:0]
for _, v := range in {
if v != exclude {
out = append(out, v)
}
}
return out
}
func (t *Transport) getClientConn(host, port string) (*clientConn, error) {
t.connMu.Lock()
defer t.connMu.Unlock()
key := net.JoinHostPort(host, port)
for _, cc := range t.conns[key] {
if cc.canTakeNewRequest() {
return cc, nil
}
}
if t.conns == nil {
t.conns = make(map[string][]*clientConn)
}
cc, err := t.newClientConn(host, port, key)
if err != nil {
return nil, err
}
t.conns[key] = append(t.conns[key], cc)
return cc, nil
}
func (t *Transport) newClientConn(host, port, key string) (*clientConn, error) {
cfg := &tls.Config{
ServerName: host,
NextProtos: []string{NextProtoTLS},
InsecureSkipVerify: t.InsecureTLSDial,
}
tconn, err := tls.Dial("tcp", host+":"+port, cfg)
if err != nil {
return nil, err
}
if err := tconn.Handshake(); err != nil {
return nil, err
}
if !t.InsecureTLSDial {
if err := tconn.VerifyHostname(cfg.ServerName); err != nil {
return nil, err
}
}
state := tconn.ConnectionState()
if p := state.NegotiatedProtocol; p != NextProtoTLS {
// TODO(bradfitz): fall back to Fallback
return nil, fmt.Errorf("bad protocol: %v", p)
}
if !state.NegotiatedProtocolIsMutual {
return nil, errors.New("could not negotiate protocol mutually")
}
if _, err := tconn.Write(clientPreface); err != nil {
return nil, err
}
cc := &clientConn{
t: t,
tconn: tconn,
connKey: []string{key}, // TODO: cert's validated hostnames too
tlsState: &state,
readerDone: make(chan struct{}),
nextStreamID: 1,
maxFrameSize: 16 << 10, // spec default
initialWindowSize: 65535, // spec default
maxConcurrentStreams: 1000, // "infinite", per spec. 1000 seems good enough.
streams: make(map[uint32]*clientStream),
}
cc.bw = bufio.NewWriter(stickyErrWriter{tconn, &cc.werr})
cc.br = bufio.NewReader(tconn)
cc.fr = NewFramer(cc.bw, cc.br)
cc.henc = hpack.NewEncoder(&cc.hbuf)
cc.fr.WriteSettings()
// TODO: re-send more conn-level flow control tokens when server uses all these.
cc.fr.WriteWindowUpdate(0, 1<<30) // um, 0x7fffffff doesn't work to Google? it hangs?
cc.bw.Flush()
if cc.werr != nil {
return nil, cc.werr
}
// Read the obligatory SETTINGS frame
f, err := cc.fr.ReadFrame()
if err != nil {
return nil, err
}
sf, ok := f.(*SettingsFrame)
if !ok {
return nil, fmt.Errorf("expected settings frame, got: %T", f)
}
cc.fr.WriteSettingsAck()
cc.bw.Flush()
sf.ForeachSetting(func(s Setting) error {
switch s.ID {
case SettingMaxFrameSize:
cc.maxFrameSize = s.Val
case SettingMaxConcurrentStreams:
cc.maxConcurrentStreams = s.Val
case SettingInitialWindowSize:
cc.initialWindowSize = s.Val
default:
// TODO(bradfitz): handle more
log.Printf("Unhandled Setting: %v", s)
}
return nil
})
// TODO: figure out henc size
cc.hdec = hpack.NewDecoder(initialHeaderTableSize, cc.onNewHeaderField)
go cc.readLoop()
return cc, nil
}
func (cc *clientConn) setGoAway(f *GoAwayFrame) {
cc.mu.Lock()
defer cc.mu.Unlock()
cc.goAway = f
}
func (cc *clientConn) canTakeNewRequest() bool {
cc.mu.Lock()
defer cc.mu.Unlock()
return cc.goAway == nil &&
int64(len(cc.streams)+1) < int64(cc.maxConcurrentStreams) &&
cc.nextStreamID < 2147483647
}
func (cc *clientConn) closeIfIdle() {
cc.mu.Lock()
if len(cc.streams) > 0 {
cc.mu.Unlock()
return
}
cc.closed = true
// TODO: do clients send GOAWAY too? maybe? Just Close:
cc.mu.Unlock()
cc.tconn.Close()
}
func (cc *clientConn) roundTrip(req *http.Request) (*http.Response, error) {
cc.mu.Lock()
if cc.closed {
cc.mu.Unlock()
return nil, errClientConnClosed
}
cs := cc.newStream()
hasBody := false // TODO
// we send: HEADERS[+CONTINUATION] + (DATA?)
hdrs := cc.encodeHeaders(req)
first := true
for len(hdrs) > 0 {
chunk := hdrs
if len(chunk) > int(cc.maxFrameSize) {
chunk = chunk[:cc.maxFrameSize]
}
hdrs = hdrs[len(chunk):]
endHeaders := len(hdrs) == 0
if first {
cc.fr.WriteHeaders(HeadersFrameParam{
StreamID: cs.ID,
BlockFragment: chunk,
EndStream: !hasBody,
EndHeaders: endHeaders,
})
first = false
} else {
cc.fr.WriteContinuation(cs.ID, endHeaders, chunk)
}
}
cc.bw.Flush()
werr := cc.werr
cc.mu.Unlock()
if hasBody {
// TODO: write data. and it should probably be interleaved:
// go ... io.Copy(dataFrameWriter{cc, cs, ...}, req.Body) ... etc
}
if werr != nil {
return nil, werr
}
re := <-cs.resc
if re.err != nil {
return nil, re.err
}
res := re.res
res.Request = req
res.TLS = cc.tlsState
return res, nil
}
// requires cc.mu be held.
func (cc *clientConn) encodeHeaders(req *http.Request) []byte {
cc.hbuf.Reset()
// TODO(bradfitz): figure out :authority-vs-Host stuff between http2 and Go
host := req.Host
if host == "" {
host = req.URL.Host
}
path := req.URL.Path
if path == "" {
path = "/"
}
cc.writeHeader(":authority", host) // probably not right for all sites
cc.writeHeader(":method", req.Method)
cc.writeHeader(":path", path)
cc.writeHeader(":scheme", "https")
for k, vv := range req.Header {
lowKey := strings.ToLower(k)
if lowKey == "host" {
continue
}
for _, v := range vv {
cc.writeHeader(lowKey, v)
}
}
return cc.hbuf.Bytes()
}
func (cc *clientConn) writeHeader(name, value string) {
log.Printf("sending %q = %q", name, value)
cc.henc.WriteField(hpack.HeaderField{Name: name, Value: value})
}
type resAndError struct {
res *http.Response
err error
}
// requires cc.mu be held.
func (cc *clientConn) newStream() *clientStream {
cs := &clientStream{
ID: cc.nextStreamID,
resc: make(chan resAndError, 1),
}
cc.nextStreamID += 2
cc.streams[cs.ID] = cs
return cs
}
func (cc *clientConn) streamByID(id uint32, andRemove bool) *clientStream {
cc.mu.Lock()
defer cc.mu.Unlock()
cs := cc.streams[id]
if andRemove {
delete(cc.streams, id)
}
return cs
}
// runs in its own goroutine.
func (cc *clientConn) readLoop() {
defer cc.t.removeClientConn(cc)
defer close(cc.readerDone)
activeRes := map[uint32]*clientStream{} // keyed by streamID
// Close any response bodies if the server closes prematurely.
// TODO: also do this if we've written the headers but not
// gotten a response yet.
defer func() {
err := cc.readerErr
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
for _, cs := range activeRes {
cs.pw.CloseWithError(err)
}
}()
// continueStreamID is the stream ID we're waiting for
// continuation frames for.
var continueStreamID uint32
for {
f, err := cc.fr.ReadFrame()
if err != nil {
cc.readerErr = err
return
}
log.Printf("Transport received %v: %#v", f.Header(), f)
streamID := f.Header().StreamID
_, isContinue := f.(*ContinuationFrame)
if isContinue {
if streamID != continueStreamID {
log.Printf("Protocol violation: got CONTINUATION with id %d; want %d", streamID, continueStreamID)
cc.readerErr = ConnectionError(ErrCodeProtocol)
return
}
} else if continueStreamID != 0 {
// Continue frames need to be adjacent in the stream
// and we were in the middle of headers.
log.Printf("Protocol violation: got %T for stream %d, want CONTINUATION for %d", f, streamID, continueStreamID)
cc.readerErr = ConnectionError(ErrCodeProtocol)
return
}
if streamID%2 == 0 {
// Ignore streams pushed from the server for now.
// These always have an even stream id.
continue
}
streamEnded := false
if ff, ok := f.(streamEnder); ok {
streamEnded = ff.StreamEnded()
}
cs := cc.streamByID(streamID, streamEnded)
if cs == nil {
log.Printf("Received frame for untracked stream ID %d", streamID)
continue
}
switch f := f.(type) {
case *HeadersFrame:
cc.nextRes = &http.Response{
Proto: "HTTP/2.0",
ProtoMajor: 2,
Header: make(http.Header),
}
cs.pr, cs.pw = io.Pipe()
cc.hdec.Write(f.HeaderBlockFragment())
case *ContinuationFrame:
cc.hdec.Write(f.HeaderBlockFragment())
case *DataFrame:
log.Printf("DATA: %q", f.Data())
cs.pw.Write(f.Data())
case *GoAwayFrame:
cc.t.removeClientConn(cc)
if f.ErrCode != 0 {
// TODO: deal with GOAWAY more. particularly the error code
log.Printf("transport got GOAWAY with error code = %v", f.ErrCode)
}
cc.setGoAway(f)
default:
log.Printf("Transport: unhandled response frame type %T", f)
}
headersEnded := false
if he, ok := f.(headersEnder); ok {
headersEnded = he.HeadersEnded()
if headersEnded {
continueStreamID = 0
} else {
continueStreamID = streamID
}
}
if streamEnded {
cs.pw.Close()
delete(activeRes, streamID)
}
if headersEnded {
if cs == nil {
panic("couldn't find stream") // TODO be graceful
}
// TODO: set the Body to one which notes the
// Close and also sends the server a
// RST_STREAM
cc.nextRes.Body = cs.pr
res := cc.nextRes
activeRes[streamID] = cs
cs.resc <- resAndError{res: res}
}
}
}
func (cc *clientConn) onNewHeaderField(f hpack.HeaderField) {
// TODO: verifiy pseudo headers come before non-pseudo headers
// TODO: verifiy the status is set
log.Printf("Header field: %+v", f)
if f.Name == ":status" {
code, err := strconv.Atoi(f.Value)
if err != nil {
panic("TODO: be graceful")
}
cc.nextRes.Status = f.Value + " " + http.StatusText(code)
cc.nextRes.StatusCode = code
return
}
if strings.HasPrefix(f.Name, ":") {
// "Endpoints MUST NOT generate pseudo-header fields other than those defined in this document."
// TODO: treat as invalid?
return
}
cc.nextRes.Header.Add(http.CanonicalHeaderKey(f.Name), f.Value)
}

View File

@ -1,168 +0,0 @@
// Copyright 2015 The Go Authors.
// See https://go.googlesource.com/go/+/master/CONTRIBUTORS
// Licensed under the same terms as Go itself:
// https://go.googlesource.com/go/+/master/LICENSE
package http2
import (
"flag"
"io"
"io/ioutil"
"net/http"
"os"
"reflect"
"strings"
"testing"
"time"
)
var (
extNet = flag.Bool("extnet", false, "do external network tests")
transportHost = flag.String("transporthost", "http2.golang.org", "hostname to use for TestTransport")
insecure = flag.Bool("insecure", false, "insecure TLS dials")
)
func TestTransportExternal(t *testing.T) {
if !*extNet {
t.Skip("skipping external network test")
}
req, _ := http.NewRequest("GET", "https://"+*transportHost+"/", nil)
rt := &Transport{
InsecureTLSDial: *insecure,
}
res, err := rt.RoundTrip(req)
if err != nil {
t.Fatalf("%v", err)
}
res.Write(os.Stdout)
}
func TestTransport(t *testing.T) {
const body = "sup"
st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
io.WriteString(w, body)
})
defer st.Close()
tr := &Transport{InsecureTLSDial: true}
defer tr.CloseIdleConnections()
req, err := http.NewRequest("GET", st.ts.URL, nil)
if err != nil {
t.Fatal(err)
}
res, err := tr.RoundTrip(req)
if err != nil {
t.Fatal(err)
}
defer res.Body.Close()
t.Logf("Got res: %+v", res)
if g, w := res.StatusCode, 200; g != w {
t.Errorf("StatusCode = %v; want %v", g, w)
}
if g, w := res.Status, "200 OK"; g != w {
t.Errorf("Status = %q; want %q", g, w)
}
wantHeader := http.Header{
"Content-Length": []string{"3"},
"Content-Type": []string{"text/plain; charset=utf-8"},
}
if !reflect.DeepEqual(res.Header, wantHeader) {
t.Errorf("res Header = %v; want %v", res.Header, wantHeader)
}
if res.Request != req {
t.Errorf("Response.Request = %p; want %p", res.Request, req)
}
if res.TLS == nil {
t.Errorf("Response.TLS = nil; want non-nil", res.TLS)
}
slurp, err := ioutil.ReadAll(res.Body)
if err != nil {
t.Error("Body read: %v", err)
} else if string(slurp) != body {
t.Errorf("Body = %q; want %q", slurp, body)
}
}
func TestTransportReusesConns(t *testing.T) {
st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
io.WriteString(w, r.RemoteAddr)
}, optOnlyServer)
defer st.Close()
tr := &Transport{InsecureTLSDial: true}
defer tr.CloseIdleConnections()
get := func() string {
req, err := http.NewRequest("GET", st.ts.URL, nil)
if err != nil {
t.Fatal(err)
}
res, err := tr.RoundTrip(req)
if err != nil {
t.Fatal(err)
}
defer res.Body.Close()
slurp, err := ioutil.ReadAll(res.Body)
if err != nil {
t.Fatalf("Body read: %v", err)
}
addr := strings.TrimSpace(string(slurp))
if addr == "" {
t.Fatalf("didn't get an addr in response")
}
return addr
}
first := get()
second := get()
if first != second {
t.Errorf("first and second responses were on different connections: %q vs %q", first, second)
}
}
func TestTransportAbortClosesPipes(t *testing.T) {
shutdown := make(chan struct{})
st := newServerTester(t,
func(w http.ResponseWriter, r *http.Request) {
w.(http.Flusher).Flush()
<-shutdown
},
optOnlyServer,
)
defer st.Close()
defer close(shutdown) // we must shutdown before st.Close() to avoid hanging
done := make(chan struct{})
requestMade := make(chan struct{})
go func() {
defer close(done)
tr := &Transport{
InsecureTLSDial: true,
}
req, err := http.NewRequest("GET", st.ts.URL, nil)
if err != nil {
t.Fatal(err)
}
res, err := tr.RoundTrip(req)
if err != nil {
t.Fatal(err)
}
defer res.Body.Close()
close(requestMade)
_, err = ioutil.ReadAll(res.Body)
if err == nil {
t.Error("expected error from res.Body.Read")
}
}()
<-requestMade
// Now force the serve loop to end, via closing the connection.
st.closeConn()
// deadlock? that's a bug.
select {
case <-done:
case <-time.After(3 * time.Second):
t.Fatal("timeout")
}
}

View File

@ -0,0 +1,4 @@
language: go
go:
- 1.4.2
sudo: false

View File

@ -57,6 +57,12 @@ bar.ShowTimeLeft = true
// show average speed
bar.ShowSpeed = true
// sets the width of the progress bar
bar.SetWidth(80)
// sets the width of the progress bar, but if terminal size smaller will be ignored
bar.SetMaxWidth(80)
// convert output to readable format (like KB, MB)
bar.SetUnits(pb.U_BYTES)

View File

@ -1,14 +1,14 @@
package main
import (
"github.com/cheggaaa/pb"
"os"
"fmt"
"github.com/coreos/etcd/Godeps/_workspace/src/github.com/cheggaaa/pb"
"io"
"time"
"strings"
"net/http"
"os"
"strconv"
"strings"
"time"
)
func main() {
@ -18,12 +18,12 @@ func main() {
return
}
sourceName, destName := os.Args[1], os.Args[2]
// check source
var source io.Reader
var sourceSize int64
if strings.HasPrefix(sourceName, "http://") {
// open as url
// open as url
resp, err := http.Get(sourceName)
if err != nil {
fmt.Printf("Can't get %s: %v\n", sourceName, err)
@ -54,9 +54,7 @@ func main() {
sourceSize = sourceStat.Size()
source = s
}
// create dest
dest, err := os.Create(destName)
if err != nil {
@ -64,15 +62,15 @@ func main() {
return
}
defer dest.Close()
// create bar
// create bar
bar := pb.New(int(sourceSize)).SetUnits(pb.U_BYTES).SetRefreshRate(time.Millisecond * 10)
bar.ShowSpeed = true
bar.Start()
// create multi writer
writer := io.MultiWriter(dest, bar)
// and copy
io.Copy(writer, source)
bar.Finish()

View File

@ -1,7 +1,7 @@
package main
import (
"github.com/cheggaaa/pb"
"github.com/coreos/etcd/Godeps/_workspace/src/github.com/cheggaaa/pb"
"time"
)
@ -13,7 +13,7 @@ func main() {
bar.ShowPercent = true
// show bar (by default already true)
bar.ShowPercent = true
bar.ShowBar = true
// no need counters
bar.ShowCounters = true

45
Godeps/_workspace/src/github.com/cheggaaa/pb/format.go generated vendored Normal file
View File

@ -0,0 +1,45 @@
package pb
import (
"fmt"
"strconv"
"strings"
)
type Units int
const (
// By default, without type handle
U_NO Units = iota
// Handle as b, Kb, Mb, etc
U_BYTES
)
// Format integer
func Format(i int64, units Units) string {
switch units {
case U_BYTES:
return FormatBytes(i)
default:
// by default just convert to string
return strconv.FormatInt(i, 10)
}
}
// Convert bytes to human readable string. Like a 2 MB, 64.2 KB, 52 B
func FormatBytes(i int64) (result string) {
switch {
case i > (1024 * 1024 * 1024 * 1024):
result = fmt.Sprintf("%.02f TB", float64(i)/1024/1024/1024/1024)
case i > (1024 * 1024 * 1024):
result = fmt.Sprintf("%.02f GB", float64(i)/1024/1024/1024)
case i > (1024 * 1024):
result = fmt.Sprintf("%.02f MB", float64(i)/1024/1024)
case i > 1024:
result = fmt.Sprintf("%.02f KB", float64(i)/1024)
default:
result = fmt.Sprintf("%d B", i)
}
result = strings.Trim(result, " ")
return
}

View File

@ -0,0 +1,37 @@
package pb
import (
"fmt"
"strconv"
"testing"
)
func Test_DefaultsToInteger(t *testing.T) {
value := int64(1000)
expected := strconv.Itoa(int(value))
actual := Format(value, -1)
if actual != expected {
t.Error(fmt.Sprintf("Expected {%s} was {%s}", expected, actual))
}
}
func Test_CanFormatAsInteger(t *testing.T) {
value := int64(1000)
expected := strconv.Itoa(int(value))
actual := Format(value, U_NO)
if actual != expected {
t.Error(fmt.Sprintf("Expected {%s} was {%s}", expected, actual))
}
}
func Test_CanFormatAsBytes(t *testing.T) {
value := int64(1000)
expected := "1000 B"
actual := Format(value, U_BYTES)
if actual != expected {
t.Error(fmt.Sprintf("Expected {%s} was {%s}", expected, actual))
}
}

367
Godeps/_workspace/src/github.com/cheggaaa/pb/pb.go generated vendored Normal file
View File

@ -0,0 +1,367 @@
package pb
import (
"fmt"
"io"
"math"
"strings"
"sync"
"sync/atomic"
"time"
"unicode/utf8"
)
const (
// Default refresh rate - 200ms
DEFAULT_REFRESH_RATE = time.Millisecond * 200
FORMAT = "[=>-]"
)
// DEPRECATED
// variables for backward compatibility, from now do not work
// use pb.Format and pb.SetRefreshRate
var (
DefaultRefreshRate = DEFAULT_REFRESH_RATE
BarStart, BarEnd, Empty, Current, CurrentN string
)
// Create new progress bar object
func New(total int) *ProgressBar {
return New64(int64(total))
}
// Create new progress bar object uding int64 as total
func New64(total int64) *ProgressBar {
pb := &ProgressBar{
Total: total,
RefreshRate: DEFAULT_REFRESH_RATE,
ShowPercent: true,
ShowCounters: true,
ShowBar: true,
ShowTimeLeft: true,
ShowFinalTime: true,
Units: U_NO,
ManualUpdate: false,
isFinish: make(chan struct{}),
currentValue: -1,
}
return pb.Format(FORMAT)
}
// Create new object and start
func StartNew(total int) *ProgressBar {
return New(total).Start()
}
// Callback for custom output
// For example:
// bar.Callback = func(s string) {
// mySuperPrint(s)
// }
//
type Callback func(out string)
type ProgressBar struct {
current int64 // current must be first member of struct (https://code.google.com/p/go/issues/detail?id=5278)
Total int64
RefreshRate time.Duration
ShowPercent, ShowCounters bool
ShowSpeed, ShowTimeLeft, ShowBar bool
ShowFinalTime bool
Output io.Writer
Callback Callback
NotPrint bool
Units Units
Width int
ForceWidth bool
ManualUpdate bool
finishOnce sync.Once //Guards isFinish
isFinish chan struct{}
startTime time.Time
startValue int64
currentValue int64
prefix, postfix string
BarStart string
BarEnd string
Empty string
Current string
CurrentN string
}
// Start print
func (pb *ProgressBar) Start() *ProgressBar {
pb.startTime = time.Now()
pb.startValue = pb.current
if pb.Total == 0 {
pb.ShowTimeLeft = false
pb.ShowPercent = false
}
if !pb.ManualUpdate {
go pb.writer()
}
return pb
}
// Increment current value
func (pb *ProgressBar) Increment() int {
return pb.Add(1)
}
// Set current value
func (pb *ProgressBar) Set(current int) *ProgressBar {
return pb.Set64(int64(current))
}
// Set64 sets the current value as int64
func (pb *ProgressBar) Set64(current int64) *ProgressBar {
atomic.StoreInt64(&pb.current, current)
return pb
}
// Add to current value
func (pb *ProgressBar) Add(add int) int {
return int(pb.Add64(int64(add)))
}
func (pb *ProgressBar) Add64(add int64) int64 {
return atomic.AddInt64(&pb.current, add)
}
// Set prefix string
func (pb *ProgressBar) Prefix(prefix string) *ProgressBar {
pb.prefix = prefix
return pb
}
// Set postfix string
func (pb *ProgressBar) Postfix(postfix string) *ProgressBar {
pb.postfix = postfix
return pb
}
// Set custom format for bar
// Example: bar.Format("[=>_]")
func (pb *ProgressBar) Format(format string) *ProgressBar {
formatEntries := strings.Split(format, "")
if len(formatEntries) == 5 {
pb.BarStart = formatEntries[0]
pb.BarEnd = formatEntries[4]
pb.Empty = formatEntries[3]
pb.Current = formatEntries[1]
pb.CurrentN = formatEntries[2]
}
return pb
}
// Set bar refresh rate
func (pb *ProgressBar) SetRefreshRate(rate time.Duration) *ProgressBar {
pb.RefreshRate = rate
return pb
}
// Set units
// bar.SetUnits(U_NO) - by default
// bar.SetUnits(U_BYTES) - for Mb, Kb, etc
func (pb *ProgressBar) SetUnits(units Units) *ProgressBar {
pb.Units = units
return pb
}
// Set max width, if width is bigger than terminal width, will be ignored
func (pb *ProgressBar) SetMaxWidth(width int) *ProgressBar {
pb.Width = width
pb.ForceWidth = false
return pb
}
// Set bar width
func (pb *ProgressBar) SetWidth(width int) *ProgressBar {
pb.Width = width
pb.ForceWidth = true
return pb
}
// End print
func (pb *ProgressBar) Finish() {
//Protect multiple calls
pb.finishOnce.Do(func() {
close(pb.isFinish)
pb.write(atomic.LoadInt64(&pb.current))
if !pb.NotPrint {
fmt.Println()
}
})
}
// End print and write string 'str'
func (pb *ProgressBar) FinishPrint(str string) {
pb.Finish()
fmt.Println(str)
}
// implement io.Writer
func (pb *ProgressBar) Write(p []byte) (n int, err error) {
n = len(p)
pb.Add(n)
return
}
// implement io.Reader
func (pb *ProgressBar) Read(p []byte) (n int, err error) {
n = len(p)
pb.Add(n)
return
}
// Create new proxy reader over bar
func (pb *ProgressBar) NewProxyReader(r io.Reader) *Reader {
return &Reader{r, pb}
}
func (pb *ProgressBar) write(current int64) {
width := pb.GetWidth()
var percentBox, countersBox, timeLeftBox, speedBox, barBox, end, out string
// percents
if pb.ShowPercent {
percent := float64(current) / (float64(pb.Total) / float64(100))
percentBox = fmt.Sprintf(" %.02f %% ", percent)
}
// counters
if pb.ShowCounters {
if pb.Total > 0 {
countersBox = fmt.Sprintf("%s / %s ", Format(current, pb.Units), Format(pb.Total, pb.Units))
} else {
countersBox = Format(current, pb.Units) + " / ? "
}
}
// time left
fromStart := time.Now().Sub(pb.startTime)
currentFromStart := current - pb.startValue
select {
case <-pb.isFinish:
if pb.ShowFinalTime {
left := (fromStart / time.Second) * time.Second
timeLeftBox = left.String()
}
default:
if pb.ShowTimeLeft && currentFromStart > 0 {
perEntry := fromStart / time.Duration(currentFromStart)
left := time.Duration(pb.Total-currentFromStart) * perEntry
left = (left / time.Second) * time.Second
timeLeftBox = left.String()
}
}
// speed
if pb.ShowSpeed && currentFromStart > 0 {
fromStart := time.Now().Sub(pb.startTime)
speed := float64(currentFromStart) / (float64(fromStart) / float64(time.Second))
speedBox = Format(int64(speed), pb.Units) + "/s "
}
barWidth := utf8.RuneCountInString(countersBox + pb.BarStart + pb.BarEnd + percentBox + timeLeftBox + speedBox + pb.prefix + pb.postfix)
// bar
if pb.ShowBar {
size := width - barWidth
if size > 0 {
if pb.Total > 0 {
curCount := int(math.Ceil((float64(current) / float64(pb.Total)) * float64(size)))
emptCount := size - curCount
barBox = pb.BarStart
if emptCount < 0 {
emptCount = 0
}
if curCount > size {
curCount = size
}
if emptCount <= 0 {
barBox += strings.Repeat(pb.Current, curCount)
} else if curCount > 0 {
barBox += strings.Repeat(pb.Current, curCount-1) + pb.CurrentN
}
barBox += strings.Repeat(pb.Empty, emptCount) + pb.BarEnd
} else {
barBox = pb.BarStart
pos := size - int(current)%int(size)
if pos-1 > 0 {
barBox += strings.Repeat(pb.Empty, pos-1)
}
barBox += pb.Current
if size-pos-1 > 0 {
barBox += strings.Repeat(pb.Empty, size-pos-1)
}
barBox += pb.BarEnd
}
}
}
// check len
out = pb.prefix + countersBox + barBox + percentBox + speedBox + timeLeftBox + pb.postfix
if utf8.RuneCountInString(out) < width {
end = strings.Repeat(" ", width-utf8.RuneCountInString(out))
}
// and print!
switch {
case pb.Output != nil:
fmt.Fprint(pb.Output, "\r"+out+end)
case pb.Callback != nil:
pb.Callback(out + end)
case !pb.NotPrint:
fmt.Print("\r" + out + end)
}
}
func (pb *ProgressBar) GetWidth() int {
if pb.ForceWidth {
return pb.Width
}
width := pb.Width
termWidth, _ := terminalWidth()
if width == 0 || termWidth <= width {
width = termWidth
}
return width
}
// Write the current state of the progressbar
func (pb *ProgressBar) Update() {
c := atomic.LoadInt64(&pb.current)
if c != pb.currentValue {
pb.write(c)
pb.currentValue = c
}
}
// Internal loop for writing progressbar
func (pb *ProgressBar) writer() {
pb.Update()
for {
select {
case <-pb.isFinish:
return
case <-time.After(pb.RefreshRate):
pb.Update()
}
}
}
type window struct {
Row uint16
Col uint16
Xpixel uint16
Ypixel uint16
}

View File

@ -0,0 +1,7 @@
// +build linux darwin freebsd netbsd openbsd
package pb
import "syscall"
const sys_ioctl = syscall.SYS_IOCTL

View File

@ -0,0 +1,5 @@
// +build solaris
package pb
const sys_ioctl = 54

View File

@ -0,0 +1,37 @@
package pb
import (
"testing"
)
func Test_IncrementAddsOne(t *testing.T) {
count := 5000
bar := New(count)
expected := 1
actual := bar.Increment()
if actual != expected {
t.Errorf("Expected {%d} was {%d}", expected, actual)
}
}
func Test_Width(t *testing.T) {
count := 5000
bar := New(count)
width := 100
bar.SetWidth(100).Callback = func(out string) {
if len(out) != width {
t.Errorf("Bar width expected {%d} was {%d}", len(out), width)
}
}
bar.Start()
bar.Increment()
bar.Finish()
}
func Test_MultipleFinish(t *testing.T) {
bar := New(5000)
bar.Add(2000)
bar.Finish()
bar.Finish()
}

View File

@ -1,16 +1,16 @@
// +build windows
package pb
import (
"github.com/olekukonko/ts"
)
func bold(str string) string {
return str
}
func terminalWidth() (int, error) {
size , err := ts.GetSize()
return size.Col() , err
}
// +build windows
package pb
import (
"github.com/coreos/etcd/Godeps/_workspace/src/github.com/olekukonko/ts"
)
func bold(str string) string {
return str
}
func terminalWidth() (int, error) {
size, err := ts.GetSize()
return size.Col(), err
}

View File

@ -1,8 +1,9 @@
// +build linux darwin freebsd
// +build linux darwin freebsd netbsd openbsd solaris
package pb
import (
"os"
"runtime"
"syscall"
"unsafe"
@ -13,6 +14,16 @@ const (
TIOCGWINSZ_OSX = 1074295912
)
var tty *os.File
func init() {
var err error
tty, err = os.Open("/dev/tty")
if err != nil {
tty = os.Stdin
}
}
func bold(str string) string {
return "\033[1m" + str + "\033[0m"
}
@ -23,8 +34,8 @@ func terminalWidth() (int, error) {
if runtime.GOOS == "darwin" {
tio = TIOCGWINSZ_OSX
}
res, _, err := syscall.Syscall(syscall.SYS_IOCTL,
uintptr(syscall.Stdin),
res, _, err := syscall.Syscall(sys_ioctl,
tty.Fd(),
uintptr(tio),
uintptr(unsafe.Pointer(w)),
)

17
Godeps/_workspace/src/github.com/cheggaaa/pb/reader.go generated vendored Normal file
View File

@ -0,0 +1,17 @@
package pb
import (
"io"
)
// It's proxy reader, implement io.Reader
type Reader struct {
io.Reader
bar *ProgressBar
}
func (r *Reader) Read(p []byte) (n int, err error) {
n, err = r.Reader.Read(p)
r.bar.Add(n)
return
}

View File

@ -1,5 +1,18 @@
language: go
go: 1.1
sudo: false
go:
- 1.0.3
- 1.1.2
- 1.2.2
- 1.3.3
- 1.4.2
- 1.5.1
- tip
matrix:
allow_failures:
- go: tip
script:
- go vet ./...

View File

@ -1,18 +1,17 @@
[![Coverage](http://gocover.io/_badge/github.com/codegangsta/cli?0)](http://gocover.io/github.com/codegangsta/cli)
[![Build Status](https://travis-ci.org/codegangsta/cli.png?branch=master)](https://travis-ci.org/codegangsta/cli)
[![GoDoc](https://godoc.org/github.com/codegangsta/cli?status.svg)](https://godoc.org/github.com/codegangsta/cli)
# cli.go
cli.go is simple, fast, and fun package for building command line apps in Go. The goal is to enable developers to write fast and distributable command line applications in an expressive way.
You can view the API docs here:
http://godoc.org/github.com/codegangsta/cli
`cli.go` is simple, fast, and fun package for building command line apps in Go. The goal is to enable developers to write fast and distributable command line applications in an expressive way.
## Overview
Command line apps are usually so tiny that there is absolutely no reason why your code should *not* be self-documenting. Things like generating help text and parsing command flags/options should not hinder productivity when writing a command line app.
**This is where cli.go comes into play.** cli.go makes command line programming fun, organized, and expressive!
**This is where `cli.go` comes into play.** `cli.go` makes command line programming fun, organized, and expressive!
## Installation
Make sure you have a working Go environment (go 1.1 is *required*). [See the install instructions](http://golang.org/doc/install.html).
Make sure you have a working Go environment (go 1.1+ is *required*). [See the install instructions](http://golang.org/doc/install.html).
To install `cli.go`, simply run:
```
@ -25,7 +24,7 @@ export PATH=$PATH:$GOPATH/bin
```
## Getting Started
One of the philosophies behind cli.go is that an API should be playful and full of discovery. So a cli.go app can be as little as one line of code in `main()`.
One of the philosophies behind `cli.go` is that an API should be playful and full of discovery. So a `cli.go` app can be as little as one line of code in `main()`.
``` go
package main
@ -68,8 +67,9 @@ Running this already gives you a ton of functionality, plus support for things l
Being a programmer can be a lonely job. Thankfully by the power of automation that is not the case! Let's create a greeter app to fend off our demons of loneliness!
Start by creating a directory named `greet`, and within it, add a file, `greet.go` with the following code in it:
``` go
/* greet.go */
package main
import (
@ -84,7 +84,7 @@ func main() {
app.Action = func(c *cli.Context) {
println("Hello friend!")
}
app.Run(os.Args)
}
```
@ -102,7 +102,8 @@ $ greet
Hello friend!
```
cli.go also generates some bitchass help text:
`cli.go` also generates neat help text:
```
$ greet help
NAME:
@ -157,6 +158,34 @@ app.Action = func(c *cli.Context) {
...
```
You can also set a destination variable for a flag, to which the content will be scanned.
``` go
...
var language string
app.Flags = []cli.Flag {
cli.StringFlag{
Name: "lang",
Value: "english",
Usage: "language for the greeting",
Destination: &language,
},
}
app.Action = func(c *cli.Context) {
name := "someone"
if len(c.Args()) > 0 {
name = c.Args()[0]
}
if language == "spanish" {
println("Hola", name)
} else {
println("Hello", name)
}
}
...
```
See full list of flags at http://godoc.org/github.com/codegangsta/cli
#### Alternate Names
You can set alternate (or short) names for flags by providing a comma-delimited list for the `Name`. e.g.
@ -171,6 +200,8 @@ app.Flags = []cli.Flag {
}
```
That flag can then be set with `--lang spanish` or `-l spanish`. Note that giving two different forms of the same flag in the same command invocation is an error.
#### Values from the Environment
You can also have the default value set from the environment via `EnvVar`. e.g.
@ -186,7 +217,18 @@ app.Flags = []cli.Flag {
}
```
That flag can then be set with `--lang spanish` or `-l spanish`. Note that giving two different forms of the same flag in the same command invocation is an error.
The `EnvVar` may also be given as a comma-delimited "cascade", where the first environment variable that resolves is used as the default.
``` go
app.Flags = []cli.Flag {
cli.StringFlag{
Name: "lang, l",
Value: "english",
Usage: "language for the greeting",
EnvVar: "LEGACY_COMPAT_LANG,APP_LANG,LANG",
},
}
```
### Subcommands
@ -196,7 +238,7 @@ Subcommands can be defined for a more git-like command line app.
app.Commands = []cli.Command{
{
Name: "add",
ShortName: "a",
Aliases: []string{"a"},
Usage: "add a task to the list",
Action: func(c *cli.Context) {
println("added task: ", c.Args().First())
@ -204,7 +246,7 @@ app.Commands = []cli.Command{
},
{
Name: "complete",
ShortName: "c",
Aliases: []string{"c"},
Usage: "complete a task on the list",
Action: func(c *cli.Context) {
println("completed task: ", c.Args().First())
@ -212,7 +254,7 @@ app.Commands = []cli.Command{
},
{
Name: "template",
ShortName: "r",
Aliases: []string{"r"},
Usage: "options for task templates",
Subcommands: []cli.Command{
{
@ -230,7 +272,7 @@ app.Commands = []cli.Command{
},
},
},
},
},
}
...
```
@ -248,8 +290,8 @@ app := cli.NewApp()
app.EnableBashCompletion = true
app.Commands = []cli.Command{
{
Name: "complete",
ShortName: "c",
Name: "complete",
Aliases: []string{"c"},
Usage: "complete a task on the list",
Action: func(c *cli.Context) {
println("completed task: ", c.Args().First())
@ -275,13 +317,25 @@ setting the `PROG` variable to the name of your program:
`PROG=myprogram source /.../cli/autocomplete/bash_autocomplete`
#### To Distribute
Copy `autocomplete/bash_autocomplete` into `/etc/bash_completion.d/` and rename
it to the name of the program you wish to add autocomplete support for (or
automatically install it there if you are distributing a package). Don't forget
to source the file to make it active in the current shell.
```
sudo cp src/bash_autocomplete /etc/bash_completion.d/<myprogram>
source /etc/bash_completion.d/<myprogram>
```
Alternatively, you can just document that users should source the generic
`autocomplete/bash_autocomplete` in their bash configuration with `$PROG` set
to the name of their program (as above).
## Contribution Guidelines
Feel free to put up a pull request to fix a bug or maybe add a feature. I will give it a code review and make sure that it does not break backwards compatibility. If I or any other collaborators agree that it is in line with the vision of the project, we will work with you to get the code into a mergeable state and merge it into the master branch.
If you are have contributed something significant to the project, I will most likely add you as a collaborator. As a collaborator you are given the ability to merge others pull requests. It is very important that new code does not break existing code, so be careful about what code you do choose to merge. If you have any questions feel free to link @codegangsta to the issue in question and we can review it together.
If you have contributed something significant to the project, I will most likely add you as a collaborator. As a collaborator you are given the ability to merge others pull requests. It is very important that new code does not break existing code, so be careful about what code you do choose to merge. If you have any questions feel free to link @codegangsta to the issue in question and we can review it together.
If you feel like you have contributed to the project but have not yet been added as a collaborator, I probably forgot to add you. Hit @codegangsta up over email and we will get it figured out.
## About
cli.go is written by none other than the [Code Gangsta](http://codegangsta.io)

View File

@ -2,18 +2,23 @@ package cli
import (
"fmt"
"io"
"io/ioutil"
"os"
"time"
)
// App is the main structure of a cli application. It is recomended that
// and app be created with the cli.NewApp() function
// an app be created with the cli.NewApp() function
type App struct {
// The name of the program. Defaults to os.Args[0]
Name string
// Full name of command for help, defaults to Name
HelpName string
// Description of the program.
Usage string
// Description of the program argument format.
ArgsUsage string
// Version of the program
Version string
// List of commands to execute
@ -24,21 +29,32 @@ type App struct {
EnableBashCompletion bool
// Boolean to hide built-in help command
HideHelp bool
// Boolean to hide built-in version flag
HideVersion bool
// An action to execute when the bash-completion flag is set
BashComplete func(context *Context)
// An action to execute before any subcommands are run, but after the context is ready
// If a non-nil error is returned, no subcommands are run
Before func(context *Context) error
// An action to execute after any subcommands are run, but after the subcommand has finished
// It is run even if Action() panics
After func(context *Context) error
// The action to execute when no subcommands are specified
Action func(context *Context)
// Execute this function if the proper command cannot be found
CommandNotFound func(context *Context, command string)
// Compilation date
Compiled time.Time
// Author
// List of all authors who contributed
Authors []Author
// Copyright of the binary if any
Copyright string
// Name of Author (Note: Use App.Authors, this is deprecated)
Author string
// Author e-mail
// Email of Author (Note: Use App.Authors, this is deprecated)
Email string
// Writer writer to write output to
Writer io.Writer
}
// Tries to find out when this binary was compiled.
@ -55,61 +71,95 @@ func compileTime() time.Time {
func NewApp() *App {
return &App{
Name: os.Args[0],
HelpName: os.Args[0],
Usage: "A new cli application",
Version: "0.0.0",
BashComplete: DefaultAppComplete,
Action: helpCommand.Action,
Compiled: compileTime(),
Writer: os.Stdout,
}
}
// Entry point to the cli app. Parses the arguments slice and routes to the proper flag/args combination
func (a *App) Run(arguments []string) error {
func (a *App) Run(arguments []string) (err error) {
if a.Author != "" || a.Email != "" {
a.Authors = append(a.Authors, Author{Name: a.Author, Email: a.Email})
}
newCmds := []Command{}
for _, c := range a.Commands {
if c.HelpName == "" {
c.HelpName = fmt.Sprintf("%s %s", a.HelpName, c.Name)
}
newCmds = append(newCmds, c)
}
a.Commands = newCmds
// append help to commands
if a.Command(helpCommand.Name) == nil && !a.HideHelp {
a.Commands = append(a.Commands, helpCommand)
a.appendFlag(HelpFlag)
if (HelpFlag != BoolFlag{}) {
a.appendFlag(HelpFlag)
}
}
//append version/help flags
if a.EnableBashCompletion {
a.appendFlag(BashCompletionFlag)
}
a.appendFlag(VersionFlag)
if !a.HideVersion {
a.appendFlag(VersionFlag)
}
// parse flags
set := flagSet(a.Name, a.Flags)
set.SetOutput(ioutil.Discard)
err := set.Parse(arguments[1:])
err = set.Parse(arguments[1:])
nerr := normalizeFlags(a.Flags, set)
if nerr != nil {
fmt.Println(nerr)
context := NewContext(a, set, set)
fmt.Fprintln(a.Writer, nerr)
context := NewContext(a, set, nil)
ShowAppHelp(context)
fmt.Println("")
return nerr
}
context := NewContext(a, set, set)
if err != nil {
fmt.Printf("Incorrect Usage.\n\n")
ShowAppHelp(context)
fmt.Println("")
return err
}
context := NewContext(a, set, nil)
if checkCompletions(context) {
return nil
}
if checkHelp(context) {
if err != nil {
fmt.Fprintln(a.Writer, "Incorrect Usage.")
fmt.Fprintln(a.Writer)
ShowAppHelp(context)
return err
}
if !a.HideHelp && checkHelp(context) {
ShowAppHelp(context)
return nil
}
if checkVersion(context) {
if !a.HideVersion && checkVersion(context) {
ShowVersion(context)
return nil
}
if a.After != nil {
defer func() {
afterErr := a.After(context)
if afterErr != nil {
if err != nil {
err = NewMultiError(err, afterErr)
} else {
err = afterErr
}
}
}()
}
if a.Before != nil {
err := a.Before(context)
if err != nil {
@ -134,21 +184,32 @@ func (a *App) Run(arguments []string) error {
// Another entry point to the cli app, takes care of passing arguments and error handling
func (a *App) RunAndExitOnError() {
if err := a.Run(os.Args); err != nil {
os.Stderr.WriteString(fmt.Sprintln(err))
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
}
// Invokes the subcommand given the context, parses ctx.Args() to generate command-specific flags
func (a *App) RunAsSubcommand(ctx *Context) error {
func (a *App) RunAsSubcommand(ctx *Context) (err error) {
// append help to commands
if len(a.Commands) > 0 {
if a.Command(helpCommand.Name) == nil && !a.HideHelp {
a.Commands = append(a.Commands, helpCommand)
a.appendFlag(HelpFlag)
if (HelpFlag != BoolFlag{}) {
a.appendFlag(HelpFlag)
}
}
}
newCmds := []Command{}
for _, c := range a.Commands {
if c.HelpName == "" {
c.HelpName = fmt.Sprintf("%s %s", a.HelpName, c.Name)
}
newCmds = append(newCmds, c)
}
a.Commands = newCmds
// append flags
if a.EnableBashCompletion {
a.appendFlag(BashCompletionFlag)
@ -157,31 +218,32 @@ func (a *App) RunAsSubcommand(ctx *Context) error {
// parse flags
set := flagSet(a.Name, a.Flags)
set.SetOutput(ioutil.Discard)
err := set.Parse(ctx.Args().Tail())
err = set.Parse(ctx.Args().Tail())
nerr := normalizeFlags(a.Flags, set)
context := NewContext(a, set, ctx.globalSet)
context := NewContext(a, set, ctx)
if nerr != nil {
fmt.Println(nerr)
fmt.Fprintln(a.Writer, nerr)
fmt.Fprintln(a.Writer)
if len(a.Commands) > 0 {
ShowSubcommandHelp(context)
} else {
ShowCommandHelp(ctx, context.Args().First())
}
fmt.Println("")
return nerr
}
if err != nil {
fmt.Printf("Incorrect Usage.\n\n")
ShowSubcommandHelp(context)
return err
}
if checkCompletions(context) {
return nil
}
if err != nil {
fmt.Fprintln(a.Writer, "Incorrect Usage.")
fmt.Fprintln(a.Writer)
ShowSubcommandHelp(context)
return err
}
if len(a.Commands) > 0 {
if checkSubcommandHelp(context) {
return nil
@ -192,6 +254,19 @@ func (a *App) RunAsSubcommand(ctx *Context) error {
}
}
if a.After != nil {
defer func() {
afterErr := a.After(context)
if afterErr != nil {
if err != nil {
err = NewMultiError(err, afterErr)
} else {
err = afterErr
}
}
}()
}
if a.Before != nil {
err := a.Before(context)
if err != nil {
@ -209,11 +284,7 @@ func (a *App) RunAsSubcommand(ctx *Context) error {
}
// Run default Action
if len(a.Commands) > 0 {
a.Action(context)
} else {
a.Action(ctx)
}
a.Action(context)
return nil
}
@ -244,3 +315,19 @@ func (a *App) appendFlag(flag Flag) {
a.Flags = append(a.Flags, flag)
}
}
// Author represents someone who has contributed to a cli project.
type Author struct {
Name string // The Authors name
Email string // The Authors email
}
// String makes Author comply to the Stringer interface, to allow an easy print in the templating process
func (a Author) String() string {
e := ""
if a.Email != "" {
e = "<" + a.Email + "> "
}
return fmt.Sprintf("%v %v", a.Name, e)
}

View File

@ -1,423 +0,0 @@
package cli_test
import (
"fmt"
"os"
"testing"
"github.com/coreos/etcd/Godeps/_workspace/src/github.com/codegangsta/cli"
)
func ExampleApp() {
// set args for examples sake
os.Args = []string{"greet", "--name", "Jeremy"}
app := cli.NewApp()
app.Name = "greet"
app.Flags = []cli.Flag{
cli.StringFlag{Name: "name", Value: "bob", Usage: "a name to say"},
}
app.Action = func(c *cli.Context) {
fmt.Printf("Hello %v\n", c.String("name"))
}
app.Run(os.Args)
// Output:
// Hello Jeremy
}
func ExampleAppSubcommand() {
// set args for examples sake
os.Args = []string{"say", "hi", "english", "--name", "Jeremy"}
app := cli.NewApp()
app.Name = "say"
app.Commands = []cli.Command{
{
Name: "hello",
ShortName: "hi",
Usage: "use it to see a description",
Description: "This is how we describe hello the function",
Subcommands: []cli.Command{
{
Name: "english",
ShortName: "en",
Usage: "sends a greeting in english",
Description: "greets someone in english",
Flags: []cli.Flag{
cli.StringFlag{
Name: "name",
Value: "Bob",
Usage: "Name of the person to greet",
},
},
Action: func(c *cli.Context) {
fmt.Println("Hello,", c.String("name"))
},
},
},
},
}
app.Run(os.Args)
// Output:
// Hello, Jeremy
}
func ExampleAppHelp() {
// set args for examples sake
os.Args = []string{"greet", "h", "describeit"}
app := cli.NewApp()
app.Name = "greet"
app.Flags = []cli.Flag{
cli.StringFlag{Name: "name", Value: "bob", Usage: "a name to say"},
}
app.Commands = []cli.Command{
{
Name: "describeit",
ShortName: "d",
Usage: "use it to see a description",
Description: "This is how we describe describeit the function",
Action: func(c *cli.Context) {
fmt.Printf("i like to describe things")
},
},
}
app.Run(os.Args)
// Output:
// NAME:
// describeit - use it to see a description
//
// USAGE:
// command describeit [arguments...]
//
// DESCRIPTION:
// This is how we describe describeit the function
}
func ExampleAppBashComplete() {
// set args for examples sake
os.Args = []string{"greet", "--generate-bash-completion"}
app := cli.NewApp()
app.Name = "greet"
app.EnableBashCompletion = true
app.Commands = []cli.Command{
{
Name: "describeit",
ShortName: "d",
Usage: "use it to see a description",
Description: "This is how we describe describeit the function",
Action: func(c *cli.Context) {
fmt.Printf("i like to describe things")
},
}, {
Name: "next",
Usage: "next example",
Description: "more stuff to see when generating bash completion",
Action: func(c *cli.Context) {
fmt.Printf("the next example")
},
},
}
app.Run(os.Args)
// Output:
// describeit
// d
// next
// help
// h
}
func TestApp_Run(t *testing.T) {
s := ""
app := cli.NewApp()
app.Action = func(c *cli.Context) {
s = s + c.Args().First()
}
err := app.Run([]string{"command", "foo"})
expect(t, err, nil)
err = app.Run([]string{"command", "bar"})
expect(t, err, nil)
expect(t, s, "foobar")
}
var commandAppTests = []struct {
name string
expected bool
}{
{"foobar", true},
{"batbaz", true},
{"b", true},
{"f", true},
{"bat", false},
{"nothing", false},
}
func TestApp_Command(t *testing.T) {
app := cli.NewApp()
fooCommand := cli.Command{Name: "foobar", ShortName: "f"}
batCommand := cli.Command{Name: "batbaz", ShortName: "b"}
app.Commands = []cli.Command{
fooCommand,
batCommand,
}
for _, test := range commandAppTests {
expect(t, app.Command(test.name) != nil, test.expected)
}
}
func TestApp_CommandWithArgBeforeFlags(t *testing.T) {
var parsedOption, firstArg string
app := cli.NewApp()
command := cli.Command{
Name: "cmd",
Flags: []cli.Flag{
cli.StringFlag{Name: "option", Value: "", Usage: "some option"},
},
Action: func(c *cli.Context) {
parsedOption = c.String("option")
firstArg = c.Args().First()
},
}
app.Commands = []cli.Command{command}
app.Run([]string{"", "cmd", "my-arg", "--option", "my-option"})
expect(t, parsedOption, "my-option")
expect(t, firstArg, "my-arg")
}
func TestApp_Float64Flag(t *testing.T) {
var meters float64
app := cli.NewApp()
app.Flags = []cli.Flag{
cli.Float64Flag{Name: "height", Value: 1.5, Usage: "Set the height, in meters"},
}
app.Action = func(c *cli.Context) {
meters = c.Float64("height")
}
app.Run([]string{"", "--height", "1.93"})
expect(t, meters, 1.93)
}
func TestApp_ParseSliceFlags(t *testing.T) {
var parsedOption, firstArg string
var parsedIntSlice []int
var parsedStringSlice []string
app := cli.NewApp()
command := cli.Command{
Name: "cmd",
Flags: []cli.Flag{
cli.IntSliceFlag{Name: "p", Value: &cli.IntSlice{}, Usage: "set one or more ip addr"},
cli.StringSliceFlag{Name: "ip", Value: &cli.StringSlice{}, Usage: "set one or more ports to open"},
},
Action: func(c *cli.Context) {
parsedIntSlice = c.IntSlice("p")
parsedStringSlice = c.StringSlice("ip")
parsedOption = c.String("option")
firstArg = c.Args().First()
},
}
app.Commands = []cli.Command{command}
app.Run([]string{"", "cmd", "my-arg", "-p", "22", "-p", "80", "-ip", "8.8.8.8", "-ip", "8.8.4.4"})
IntsEquals := func(a, b []int) bool {
if len(a) != len(b) {
return false
}
for i, v := range a {
if v != b[i] {
return false
}
}
return true
}
StrsEquals := func(a, b []string) bool {
if len(a) != len(b) {
return false
}
for i, v := range a {
if v != b[i] {
return false
}
}
return true
}
var expectedIntSlice = []int{22, 80}
var expectedStringSlice = []string{"8.8.8.8", "8.8.4.4"}
if !IntsEquals(parsedIntSlice, expectedIntSlice) {
t.Errorf("%v does not match %v", parsedIntSlice, expectedIntSlice)
}
if !StrsEquals(parsedStringSlice, expectedStringSlice) {
t.Errorf("%v does not match %v", parsedStringSlice, expectedStringSlice)
}
}
func TestApp_BeforeFunc(t *testing.T) {
beforeRun, subcommandRun := false, false
beforeError := fmt.Errorf("fail")
var err error
app := cli.NewApp()
app.Before = func(c *cli.Context) error {
beforeRun = true
s := c.String("opt")
if s == "fail" {
return beforeError
}
return nil
}
app.Commands = []cli.Command{
cli.Command{
Name: "sub",
Action: func(c *cli.Context) {
subcommandRun = true
},
},
}
app.Flags = []cli.Flag{
cli.StringFlag{Name: "opt"},
}
// run with the Before() func succeeding
err = app.Run([]string{"command", "--opt", "succeed", "sub"})
if err != nil {
t.Fatalf("Run error: %s", err)
}
if beforeRun == false {
t.Errorf("Before() not executed when expected")
}
if subcommandRun == false {
t.Errorf("Subcommand not executed when expected")
}
// reset
beforeRun, subcommandRun = false, false
// run with the Before() func failing
err = app.Run([]string{"command", "--opt", "fail", "sub"})
// should be the same error produced by the Before func
if err != beforeError {
t.Errorf("Run error expected, but not received")
}
if beforeRun == false {
t.Errorf("Before() not executed when expected")
}
if subcommandRun == true {
t.Errorf("Subcommand executed when NOT expected")
}
}
func TestAppHelpPrinter(t *testing.T) {
oldPrinter := cli.HelpPrinter
defer func() {
cli.HelpPrinter = oldPrinter
}()
var wasCalled = false
cli.HelpPrinter = func(template string, data interface{}) {
wasCalled = true
}
app := cli.NewApp()
app.Run([]string{"-h"})
if wasCalled == false {
t.Errorf("Help printer expected to be called, but was not")
}
}
func TestAppVersionPrinter(t *testing.T) {
oldPrinter := cli.VersionPrinter
defer func() {
cli.VersionPrinter = oldPrinter
}()
var wasCalled = false
cli.VersionPrinter = func(c *cli.Context) {
wasCalled = true
}
app := cli.NewApp()
ctx := cli.NewContext(app, nil, nil)
cli.ShowVersion(ctx)
if wasCalled == false {
t.Errorf("Version printer expected to be called, but was not")
}
}
func TestAppCommandNotFound(t *testing.T) {
beforeRun, subcommandRun := false, false
app := cli.NewApp()
app.CommandNotFound = func(c *cli.Context, command string) {
beforeRun = true
}
app.Commands = []cli.Command{
cli.Command{
Name: "bar",
Action: func(c *cli.Context) {
subcommandRun = true
},
},
}
app.Run([]string{"command", "foo"})
expect(t, beforeRun, true)
expect(t, subcommandRun, false)
}
func TestGlobalFlagsInSubcommands(t *testing.T) {
subcommandRun := false
app := cli.NewApp()
app.Flags = []cli.Flag{
cli.BoolFlag{Name: "debug, d", Usage: "Enable debugging"},
}
app.Commands = []cli.Command{
cli.Command{
Name: "foo",
Subcommands: []cli.Command{
{
Name: "bar",
Action: func(c *cli.Context) {
if c.GlobalBool("debug") {
subcommandRun = true
}
},
},
},
},
}
app.Run([]string{"command", "-d", "foo", "bar"})
expect(t, subcommandRun, true)
}

View File

@ -1,13 +1,14 @@
#! /bin/bash
: ${PROG:=$(basename ${BASH_SOURCE})}
_cli_bash_autocomplete() {
local cur prev opts base
local cur opts base
COMPREPLY=()
cur="${COMP_WORDS[COMP_CWORD]}"
prev="${COMP_WORDS[COMP_CWORD-1]}"
opts=$( ${COMP_WORDS[@]:0:$COMP_CWORD} --generate-bash-completion )
COMPREPLY=( $(compgen -W "${opts}" -- ${cur}) )
return 0
}
complete -F _cli_bash_autocomplete $PROG
complete -F _cli_bash_autocomplete $PROG

View File

@ -17,3 +17,24 @@
// app.Run(os.Args)
// }
package cli
import (
"strings"
)
type MultiError struct {
Errors []error
}
func NewMultiError(err ...error) MultiError {
return MultiError{Errors: err}
}
func (m MultiError) Error() string {
errs := make([]string, len(m.Errors))
for i, err := range m.Errors {
errs[i] = err.Error()
}
return strings.Join(errs, "\n")
}

View File

@ -1,100 +0,0 @@
package cli_test
import (
"os"
"github.com/coreos/etcd/Godeps/_workspace/src/github.com/codegangsta/cli"
)
func Example() {
app := cli.NewApp()
app.Name = "todo"
app.Usage = "task list on the command line"
app.Commands = []cli.Command{
{
Name: "add",
ShortName: "a",
Usage: "add a task to the list",
Action: func(c *cli.Context) {
println("added task: ", c.Args().First())
},
},
{
Name: "complete",
ShortName: "c",
Usage: "complete a task on the list",
Action: func(c *cli.Context) {
println("completed task: ", c.Args().First())
},
},
}
app.Run(os.Args)
}
func ExampleSubcommand() {
app := cli.NewApp()
app.Name = "say"
app.Commands = []cli.Command{
{
Name: "hello",
ShortName: "hi",
Usage: "use it to see a description",
Description: "This is how we describe hello the function",
Subcommands: []cli.Command{
{
Name: "english",
ShortName: "en",
Usage: "sends a greeting in english",
Description: "greets someone in english",
Flags: []cli.Flag{
cli.StringFlag{
Name: "name",
Value: "Bob",
Usage: "Name of the person to greet",
},
},
Action: func(c *cli.Context) {
println("Hello, ", c.String("name"))
},
}, {
Name: "spanish",
ShortName: "sp",
Usage: "sends a greeting in spanish",
Flags: []cli.Flag{
cli.StringFlag{
Name: "surname",
Value: "Jones",
Usage: "Surname of the person to greet",
},
},
Action: func(c *cli.Context) {
println("Hola, ", c.String("surname"))
},
}, {
Name: "french",
ShortName: "fr",
Usage: "sends a greeting in french",
Flags: []cli.Flag{
cli.StringFlag{
Name: "nickname",
Value: "Stevie",
Usage: "Nickname of the person to greet",
},
},
Action: func(c *cli.Context) {
println("Bonjour, ", c.String("nickname"))
},
},
},
}, {
Name: "bye",
Usage: "says goodbye",
Action: func(c *cli.Context) {
println("bye")
},
},
}
app.Run(os.Args)
}

View File

@ -10,17 +10,24 @@ import (
type Command struct {
// The name of the command
Name string
// short name of the command. Typically one character
// short name of the command. Typically one character (deprecated, use `Aliases`)
ShortName string
// A list of aliases for the command
Aliases []string
// A short description of the usage of this command
Usage string
// A longer explanation of how the command works
Description string
// A short description of the arguments of this command
ArgsUsage string
// The function to call when checking for bash command completions
BashComplete func(context *Context)
// An action to execute before any sub-subcommands are run, but after the context is ready
// If a non-nil error is returned, no sub-subcommands are run
Before func(context *Context) error
// An action to execute after any subcommands are run, but after the subcommand has finished
// It is run even if Action() panics
After func(context *Context) error
// The function to call when this command is invoked
Action func(context *Context)
// List of child commands
@ -31,16 +38,28 @@ type Command struct {
SkipFlagParsing bool
// Boolean to hide built-in help command
HideHelp bool
// Full name of command for help, defaults to full command name, including parent commands.
HelpName string
commandNamePath []string
}
// Returns the full name of the command.
// For subcommands this ensures that parent commands are part of the command path
func (c Command) FullName() string {
if c.commandNamePath == nil {
return c.Name
}
return strings.Join(c.commandNamePath, " ")
}
// Invokes the command given the context, parses ctx.Args() to generate command-specific flags
func (c Command) Run(ctx *Context) error {
if len(c.Subcommands) > 0 || c.Before != nil {
if len(c.Subcommands) > 0 || c.Before != nil || c.After != nil {
return c.startApp(ctx)
}
if !c.HideHelp {
if !c.HideHelp && (HelpFlag != BoolFlag{}) {
// append help to flags
c.Flags = append(
c.Flags,
@ -55,40 +74,57 @@ func (c Command) Run(ctx *Context) error {
set := flagSet(c.Name, c.Flags)
set.SetOutput(ioutil.Discard)
firstFlagIndex := -1
for index, arg := range ctx.Args() {
if strings.HasPrefix(arg, "-") {
firstFlagIndex = index
break
var err error
if !c.SkipFlagParsing {
firstFlagIndex := -1
terminatorIndex := -1
for index, arg := range ctx.Args() {
if arg == "--" {
terminatorIndex = index
break
} else if strings.HasPrefix(arg, "-") && firstFlagIndex == -1 {
firstFlagIndex = index
}
}
if firstFlagIndex > -1 {
args := ctx.Args()
regularArgs := make([]string, len(args[1:firstFlagIndex]))
copy(regularArgs, args[1:firstFlagIndex])
var flagArgs []string
if terminatorIndex > -1 {
flagArgs = args[firstFlagIndex:terminatorIndex]
regularArgs = append(regularArgs, args[terminatorIndex:]...)
} else {
flagArgs = args[firstFlagIndex:]
}
err = set.Parse(append(flagArgs, regularArgs...))
} else {
err = set.Parse(ctx.Args().Tail())
}
} else {
if c.SkipFlagParsing {
err = set.Parse(append([]string{"--"}, ctx.Args().Tail()...))
}
}
var err error
if firstFlagIndex > -1 && !c.SkipFlagParsing {
args := ctx.Args()
regularArgs := args[1:firstFlagIndex]
flagArgs := args[firstFlagIndex:]
err = set.Parse(append(flagArgs, regularArgs...))
} else {
err = set.Parse(ctx.Args().Tail())
}
if err != nil {
fmt.Printf("Incorrect Usage.\n\n")
fmt.Fprintln(ctx.App.Writer, "Incorrect Usage.")
fmt.Fprintln(ctx.App.Writer)
ShowCommandHelp(ctx, c.Name)
fmt.Println("")
return err
}
nerr := normalizeFlags(c.Flags, set)
if nerr != nil {
fmt.Println(nerr)
fmt.Println("")
fmt.Fprintln(ctx.App.Writer, nerr)
fmt.Fprintln(ctx.App.Writer)
ShowCommandHelp(ctx, c.Name)
fmt.Println("")
return nerr
}
context := NewContext(ctx.App, set, ctx.globalSet)
context := NewContext(ctx.App, set, ctx)
if checkCommandCompletions(context, c.Name) {
return nil
@ -102,9 +138,24 @@ func (c Command) Run(ctx *Context) error {
return nil
}
func (c Command) Names() []string {
names := []string{c.Name}
if c.ShortName != "" {
names = append(names, c.ShortName)
}
return append(names, c.Aliases...)
}
// Returns true if Command.Name or Command.ShortName matches given name
func (c Command) HasName(name string) bool {
return c.Name == name || c.ShortName == name
for _, n := range c.Names() {
if n == name {
return true
}
}
return false
}
func (c Command) startApp(ctx *Context) error {
@ -112,6 +163,12 @@ func (c Command) startApp(ctx *Context) error {
// set the name and usage
app.Name = fmt.Sprintf("%s %s", ctx.App.Name, c.Name)
if c.HelpName == "" {
app.HelpName = c.HelpName
} else {
app.HelpName = fmt.Sprintf("%s %s", ctx.App.Name, c.Name)
}
if c.Description != "" {
app.Usage = c.Description
} else {
@ -126,6 +183,13 @@ func (c Command) startApp(ctx *Context) error {
app.Flags = c.Flags
app.HideHelp = c.HideHelp
app.Version = ctx.App.Version
app.HideVersion = ctx.App.HideVersion
app.Compiled = ctx.App.Compiled
app.Author = ctx.App.Author
app.Email = ctx.App.Email
app.Writer = ctx.App.Writer
// bash completion
app.EnableBashCompletion = ctx.App.EnableBashCompletion
if c.BashComplete != nil {
@ -134,11 +198,19 @@ func (c Command) startApp(ctx *Context) error {
// set the actions
app.Before = c.Before
app.After = c.After
if c.Action != nil {
app.Action = c.Action
} else {
app.Action = helpSubcommand.Action
}
var newCmds []Command
for _, cc := range app.Commands {
cc.commandNamePath = []string{c.Name, cc.Name}
newCmds = append(newCmds, cc)
}
app.Commands = newCmds
return app.RunAsSubcommand(ctx)
}

View File

@ -1,49 +0,0 @@
package cli_test
import (
"flag"
"testing"
"github.com/coreos/etcd/Godeps/_workspace/src/github.com/codegangsta/cli"
)
func TestCommandDoNotIgnoreFlags(t *testing.T) {
app := cli.NewApp()
set := flag.NewFlagSet("test", 0)
test := []string{"blah", "blah", "-break"}
set.Parse(test)
c := cli.NewContext(app, set, set)
command := cli.Command{
Name: "test-cmd",
ShortName: "tc",
Usage: "this is for testing",
Description: "testing",
Action: func(_ *cli.Context) {},
}
err := command.Run(c)
expect(t, err.Error(), "flag provided but not defined: -break")
}
func TestCommandIgnoreFlags(t *testing.T) {
app := cli.NewApp()
set := flag.NewFlagSet("test", 0)
test := []string{"blah", "blah"}
set.Parse(test)
c := cli.NewContext(app, set, set)
command := cli.Command{
Name: "test-cmd",
ShortName: "tc",
Usage: "this is for testing",
Description: "testing",
Action: func(_ *cli.Context) {},
SkipFlagParsing: true,
}
err := command.Run(c)
expect(t, err, nil)
}

View File

@ -13,16 +13,17 @@ import (
// can be used to retrieve context-specific Args and
// parsed command-line options.
type Context struct {
App *App
Command Command
flagSet *flag.FlagSet
globalSet *flag.FlagSet
setFlags map[string]bool
App *App
Command Command
flagSet *flag.FlagSet
setFlags map[string]bool
globalSetFlags map[string]bool
parentContext *Context
}
// Creates a new context. For use in when invoking an App or Command action.
func NewContext(app *App, set *flag.FlagSet, globalSet *flag.FlagSet) *Context {
return &Context{App: app, flagSet: set, globalSet: globalSet}
func NewContext(app *App, set *flag.FlagSet, parentCtx *Context) *Context {
return &Context{App: app, flagSet: set, parentContext: parentCtx}
}
// Looks up the value of a local int flag, returns 0 if no int flag exists
@ -72,40 +73,66 @@ func (c *Context) Generic(name string) interface{} {
// Looks up the value of a global int flag, returns 0 if no int flag exists
func (c *Context) GlobalInt(name string) int {
return lookupInt(name, c.globalSet)
if fs := lookupGlobalFlagSet(name, c); fs != nil {
return lookupInt(name, fs)
}
return 0
}
// Looks up the value of a global time.Duration flag, returns 0 if no time.Duration flag exists
func (c *Context) GlobalDuration(name string) time.Duration {
return lookupDuration(name, c.globalSet)
if fs := lookupGlobalFlagSet(name, c); fs != nil {
return lookupDuration(name, fs)
}
return 0
}
// Looks up the value of a global bool flag, returns false if no bool flag exists
func (c *Context) GlobalBool(name string) bool {
return lookupBool(name, c.globalSet)
if fs := lookupGlobalFlagSet(name, c); fs != nil {
return lookupBool(name, fs)
}
return false
}
// Looks up the value of a global string flag, returns "" if no string flag exists
func (c *Context) GlobalString(name string) string {
return lookupString(name, c.globalSet)
if fs := lookupGlobalFlagSet(name, c); fs != nil {
return lookupString(name, fs)
}
return ""
}
// Looks up the value of a global string slice flag, returns nil if no string slice flag exists
func (c *Context) GlobalStringSlice(name string) []string {
return lookupStringSlice(name, c.globalSet)
if fs := lookupGlobalFlagSet(name, c); fs != nil {
return lookupStringSlice(name, fs)
}
return nil
}
// Looks up the value of a global int slice flag, returns nil if no int slice flag exists
func (c *Context) GlobalIntSlice(name string) []int {
return lookupIntSlice(name, c.globalSet)
if fs := lookupGlobalFlagSet(name, c); fs != nil {
return lookupIntSlice(name, fs)
}
return nil
}
// Looks up the value of a global generic flag, returns nil if no generic flag exists
func (c *Context) GlobalGeneric(name string) interface{} {
return lookupGeneric(name, c.globalSet)
if fs := lookupGlobalFlagSet(name, c); fs != nil {
return lookupGeneric(name, fs)
}
return nil
}
// Determines if the flag was actually set exists
// Returns the number of flags set
func (c *Context) NumFlags() int {
return c.flagSet.NFlag()
}
// Determines if the flag was actually set
func (c *Context) IsSet(name string) bool {
if c.setFlags == nil {
c.setFlags = make(map[string]bool)
@ -116,6 +143,23 @@ func (c *Context) IsSet(name string) bool {
return c.setFlags[name] == true
}
// Determines if the global flag was actually set
func (c *Context) GlobalIsSet(name string) bool {
if c.globalSetFlags == nil {
c.globalSetFlags = make(map[string]bool)
ctx := c
if ctx.parentContext != nil {
ctx = ctx.parentContext
}
for ; ctx != nil && c.globalSetFlags[name] == false; ctx = ctx.parentContext {
ctx.flagSet.Visit(func(f *flag.Flag) {
c.globalSetFlags[f.Name] = true
})
}
}
return c.globalSetFlags[name]
}
// Returns a slice of flag names used in this context.
func (c *Context) FlagNames() (names []string) {
for _, flag := range c.Command.Flags {
@ -128,6 +172,23 @@ func (c *Context) FlagNames() (names []string) {
return
}
// Returns a slice of global flag names used by the app.
func (c *Context) GlobalFlagNames() (names []string) {
for _, flag := range c.App.Flags {
name := strings.Split(flag.getName(), ",")[0]
if name == "help" || name == "version" {
continue
}
names = append(names, name)
}
return
}
// Returns the parent context, if any
func (c *Context) Parent() *Context {
return c.parentContext
}
type Args []string
// Returns the command line arguments associated with the context.
@ -172,6 +233,18 @@ func (a Args) Swap(from, to int) error {
return nil
}
func lookupGlobalFlagSet(name string, ctx *Context) *flag.FlagSet {
if ctx.parentContext != nil {
ctx = ctx.parentContext
}
for ; ctx != nil; ctx = ctx.parentContext {
if f := ctx.flagSet.Lookup(name); f != nil {
return ctx.flagSet
}
}
return nil
}
func lookupInt(name string, set *flag.FlagSet) int {
f := set.Lookup(name)
if f != nil {

View File

@ -1,77 +0,0 @@
package cli_test
import (
"flag"
"testing"
"time"
"github.com/coreos/etcd/Godeps/_workspace/src/github.com/codegangsta/cli"
)
func TestNewContext(t *testing.T) {
set := flag.NewFlagSet("test", 0)
set.Int("myflag", 12, "doc")
globalSet := flag.NewFlagSet("test", 0)
globalSet.Int("myflag", 42, "doc")
command := cli.Command{Name: "mycommand"}
c := cli.NewContext(nil, set, globalSet)
c.Command = command
expect(t, c.Int("myflag"), 12)
expect(t, c.GlobalInt("myflag"), 42)
expect(t, c.Command.Name, "mycommand")
}
func TestContext_Int(t *testing.T) {
set := flag.NewFlagSet("test", 0)
set.Int("myflag", 12, "doc")
c := cli.NewContext(nil, set, set)
expect(t, c.Int("myflag"), 12)
}
func TestContext_Duration(t *testing.T) {
set := flag.NewFlagSet("test", 0)
set.Duration("myflag", time.Duration(12*time.Second), "doc")
c := cli.NewContext(nil, set, set)
expect(t, c.Duration("myflag"), time.Duration(12*time.Second))
}
func TestContext_String(t *testing.T) {
set := flag.NewFlagSet("test", 0)
set.String("myflag", "hello world", "doc")
c := cli.NewContext(nil, set, set)
expect(t, c.String("myflag"), "hello world")
}
func TestContext_Bool(t *testing.T) {
set := flag.NewFlagSet("test", 0)
set.Bool("myflag", false, "doc")
c := cli.NewContext(nil, set, set)
expect(t, c.Bool("myflag"), false)
}
func TestContext_BoolT(t *testing.T) {
set := flag.NewFlagSet("test", 0)
set.Bool("myflag", true, "doc")
c := cli.NewContext(nil, set, set)
expect(t, c.BoolT("myflag"), true)
}
func TestContext_Args(t *testing.T) {
set := flag.NewFlagSet("test", 0)
set.Bool("myflag", false, "doc")
c := cli.NewContext(nil, set, set)
set.Parse([]string{"--myflag", "bat", "baz"})
expect(t, len(c.Args()), 2)
expect(t, c.Bool("myflag"), true)
}
func TestContext_IsSet(t *testing.T) {
set := flag.NewFlagSet("test", 0)
set.Bool("myflag", false, "doc")
set.String("otherflag", "hello world", "doc")
c := cli.NewContext(nil, set, set)
set.Parse([]string{"--myflag", "bat", "baz"})
expect(t, c.IsSet("myflag"), true)
expect(t, c.IsSet("otherflag"), false)
expect(t, c.IsSet("bogusflag"), false)
}

View File

@ -21,6 +21,8 @@ var VersionFlag = BoolFlag{
}
// This flag prints the help for all commands and subcommands
// Set to the zero value (BoolFlag{}) to disable flag -- keeps subcommand
// unless HideHelp is set to true)
var HelpFlag = BoolFlag{
Name: "help, h",
Usage: "show help",
@ -67,15 +69,24 @@ type GenericFlag struct {
EnvVar string
}
// String returns the string representation of the generic flag to display the
// help text to the user (uses the String() method of the generic flag to show
// the value)
func (f GenericFlag) String() string {
return withEnvHint(f.EnvVar, fmt.Sprintf("%s%s %v\t`%v` %s", prefixFor(f.Name), f.Name, f.Value, "-"+f.Name+" option -"+f.Name+" option", f.Usage))
return withEnvHint(f.EnvVar, fmt.Sprintf("%s%s \"%v\"\t%v", prefixFor(f.Name), f.Name, f.Value, f.Usage))
}
// Apply takes the flagset and calls Set on the generic flag with the value
// provided by the user for parsing by the flag
func (f GenericFlag) Apply(set *flag.FlagSet) {
val := f.Value
if f.EnvVar != "" {
if envVal := os.Getenv(f.EnvVar); envVal != "" {
val.Set(envVal)
for _, envVar := range strings.Split(f.EnvVar, ",") {
envVar = strings.TrimSpace(envVar)
if envVal := os.Getenv(envVar); envVal != "" {
val.Set(envVal)
break
}
}
}
@ -88,21 +99,27 @@ func (f GenericFlag) getName() string {
return f.Name
}
// StringSlice is an opaque type for []string to satisfy flag.Value
type StringSlice []string
// Set appends the string value to the list of values
func (f *StringSlice) Set(value string) error {
*f = append(*f, value)
return nil
}
// String returns a readable representation of this value (for usage defaults)
func (f *StringSlice) String() string {
return fmt.Sprintf("%s", *f)
}
// Value returns the slice of strings set by this flag
func (f *StringSlice) Value() []string {
return *f
}
// StringSlice is a string flag that can be specified multiple times on the
// command-line
type StringSliceFlag struct {
Name string
Value *StringSlice
@ -110,24 +127,34 @@ type StringSliceFlag struct {
EnvVar string
}
// String returns the usage
func (f StringSliceFlag) String() string {
firstName := strings.Trim(strings.Split(f.Name, ",")[0], " ")
pref := prefixFor(firstName)
return withEnvHint(f.EnvVar, fmt.Sprintf("%s '%v'\t%v", prefixedNames(f.Name), pref+firstName+" option "+pref+firstName+" option", f.Usage))
return withEnvHint(f.EnvVar, fmt.Sprintf("%s [%v]\t%v", prefixedNames(f.Name), pref+firstName+" option "+pref+firstName+" option", f.Usage))
}
// Apply populates the flag given the flag set and environment
func (f StringSliceFlag) Apply(set *flag.FlagSet) {
if f.EnvVar != "" {
if envVal := os.Getenv(f.EnvVar); envVal != "" {
newVal := &StringSlice{}
for _, s := range strings.Split(envVal, ",") {
newVal.Set(s)
for _, envVar := range strings.Split(f.EnvVar, ",") {
envVar = strings.TrimSpace(envVar)
if envVal := os.Getenv(envVar); envVal != "" {
newVal := &StringSlice{}
for _, s := range strings.Split(envVal, ",") {
s = strings.TrimSpace(s)
newVal.Set(s)
}
f.Value = newVal
break
}
f.Value = newVal
}
}
eachName(f.Name, func(name string) {
if f.Value == nil {
f.Value = &StringSlice{}
}
set.Var(f.Value, name, f.Usage)
})
}
@ -136,10 +163,11 @@ func (f StringSliceFlag) getName() string {
return f.Name
}
// StringSlice is an opaque type for []int to satisfy flag.Value
type IntSlice []int
// Set parses the value into an integer and appends it to the list of values
func (f *IntSlice) Set(value string) error {
tmp, err := strconv.Atoi(value)
if err != nil {
return err
@ -149,14 +177,18 @@ func (f *IntSlice) Set(value string) error {
return nil
}
// String returns a readable representation of this value (for usage defaults)
func (f *IntSlice) String() string {
return fmt.Sprintf("%d", *f)
}
// Value returns the slice of ints set by this flag
func (f *IntSlice) Value() []int {
return *f
}
// IntSliceFlag is an int flag that can be specified multiple times on the
// command-line
type IntSliceFlag struct {
Name string
Value *IntSlice
@ -164,27 +196,37 @@ type IntSliceFlag struct {
EnvVar string
}
// String returns the usage
func (f IntSliceFlag) String() string {
firstName := strings.Trim(strings.Split(f.Name, ",")[0], " ")
pref := prefixFor(firstName)
return withEnvHint(f.EnvVar, fmt.Sprintf("%s '%v'\t%v", prefixedNames(f.Name), pref+firstName+" option "+pref+firstName+" option", f.Usage))
return withEnvHint(f.EnvVar, fmt.Sprintf("%s [%v]\t%v", prefixedNames(f.Name), pref+firstName+" option "+pref+firstName+" option", f.Usage))
}
// Apply populates the flag given the flag set and environment
func (f IntSliceFlag) Apply(set *flag.FlagSet) {
if f.EnvVar != "" {
if envVal := os.Getenv(f.EnvVar); envVal != "" {
newVal := &IntSlice{}
for _, s := range strings.Split(envVal, ",") {
err := newVal.Set(s)
if err != nil {
fmt.Fprintf(os.Stderr, err.Error())
for _, envVar := range strings.Split(f.EnvVar, ",") {
envVar = strings.TrimSpace(envVar)
if envVal := os.Getenv(envVar); envVal != "" {
newVal := &IntSlice{}
for _, s := range strings.Split(envVal, ",") {
s = strings.TrimSpace(s)
err := newVal.Set(s)
if err != nil {
fmt.Fprintf(os.Stderr, err.Error())
}
}
f.Value = newVal
break
}
f.Value = newVal
}
}
eachName(f.Name, func(name string) {
if f.Value == nil {
f.Value = &IntSlice{}
}
set.Var(f.Value, name, f.Usage)
})
}
@ -193,28 +235,40 @@ func (f IntSliceFlag) getName() string {
return f.Name
}
// BoolFlag is a switch that defaults to false
type BoolFlag struct {
Name string
Usage string
EnvVar string
Name string
Usage string
EnvVar string
Destination *bool
}
// String returns a readable representation of this value (for usage defaults)
func (f BoolFlag) String() string {
return withEnvHint(f.EnvVar, fmt.Sprintf("%s\t%v", prefixedNames(f.Name), f.Usage))
}
// Apply populates the flag given the flag set and environment
func (f BoolFlag) Apply(set *flag.FlagSet) {
val := false
if f.EnvVar != "" {
if envVal := os.Getenv(f.EnvVar); envVal != "" {
envValBool, err := strconv.ParseBool(envVal)
if err == nil {
val = envValBool
for _, envVar := range strings.Split(f.EnvVar, ",") {
envVar = strings.TrimSpace(envVar)
if envVal := os.Getenv(envVar); envVal != "" {
envValBool, err := strconv.ParseBool(envVal)
if err == nil {
val = envValBool
}
break
}
}
}
eachName(f.Name, func(name string) {
if f.Destination != nil {
set.BoolVar(f.Destination, name, val, f.Usage)
return
}
set.Bool(name, val, f.Usage)
})
}
@ -223,28 +277,41 @@ func (f BoolFlag) getName() string {
return f.Name
}
// BoolTFlag this represents a boolean flag that is true by default, but can
// still be set to false by --some-flag=false
type BoolTFlag struct {
Name string
Usage string
EnvVar string
Name string
Usage string
EnvVar string
Destination *bool
}
// String returns a readable representation of this value (for usage defaults)
func (f BoolTFlag) String() string {
return withEnvHint(f.EnvVar, fmt.Sprintf("%s\t%v", prefixedNames(f.Name), f.Usage))
}
// Apply populates the flag given the flag set and environment
func (f BoolTFlag) Apply(set *flag.FlagSet) {
val := true
if f.EnvVar != "" {
if envVal := os.Getenv(f.EnvVar); envVal != "" {
envValBool, err := strconv.ParseBool(envVal)
if err == nil {
val = envValBool
for _, envVar := range strings.Split(f.EnvVar, ",") {
envVar = strings.TrimSpace(envVar)
if envVal := os.Getenv(envVar); envVal != "" {
envValBool, err := strconv.ParseBool(envVal)
if err == nil {
val = envValBool
break
}
}
}
}
eachName(f.Name, func(name string) {
if f.Destination != nil {
set.BoolVar(f.Destination, name, val, f.Usage)
return
}
set.Bool(name, val, f.Usage)
})
}
@ -253,19 +320,22 @@ func (f BoolTFlag) getName() string {
return f.Name
}
// StringFlag represents a flag that takes as string value
type StringFlag struct {
Name string
Value string
Usage string
EnvVar string
Name string
Value string
Usage string
EnvVar string
Destination *string
}
// String returns the usage
func (f StringFlag) String() string {
var fmtString string
fmtString = "%s %v\t%v"
if len(f.Value) > 0 {
fmtString = "%s '%v'\t%v"
fmtString = "%s \"%v\"\t%v"
} else {
fmtString = "%s %v\t%v"
}
@ -273,14 +343,23 @@ func (f StringFlag) String() string {
return withEnvHint(f.EnvVar, fmt.Sprintf(fmtString, prefixedNames(f.Name), f.Value, f.Usage))
}
// Apply populates the flag given the flag set and environment
func (f StringFlag) Apply(set *flag.FlagSet) {
if f.EnvVar != "" {
if envVal := os.Getenv(f.EnvVar); envVal != "" {
f.Value = envVal
for _, envVar := range strings.Split(f.EnvVar, ",") {
envVar = strings.TrimSpace(envVar)
if envVal := os.Getenv(envVar); envVal != "" {
f.Value = envVal
break
}
}
}
eachName(f.Name, func(name string) {
if f.Destination != nil {
set.StringVar(f.Destination, name, f.Value, f.Usage)
return
}
set.String(name, f.Value, f.Usage)
})
}
@ -289,28 +368,41 @@ func (f StringFlag) getName() string {
return f.Name
}
// IntFlag is a flag that takes an integer
// Errors if the value provided cannot be parsed
type IntFlag struct {
Name string
Value int
Usage string
EnvVar string
Name string
Value int
Usage string
EnvVar string
Destination *int
}
// String returns the usage
func (f IntFlag) String() string {
return withEnvHint(f.EnvVar, fmt.Sprintf("%s '%v'\t%v", prefixedNames(f.Name), f.Value, f.Usage))
return withEnvHint(f.EnvVar, fmt.Sprintf("%s \"%v\"\t%v", prefixedNames(f.Name), f.Value, f.Usage))
}
// Apply populates the flag given the flag set and environment
func (f IntFlag) Apply(set *flag.FlagSet) {
if f.EnvVar != "" {
if envVal := os.Getenv(f.EnvVar); envVal != "" {
envValInt, err := strconv.ParseUint(envVal, 10, 64)
if err == nil {
f.Value = int(envValInt)
for _, envVar := range strings.Split(f.EnvVar, ",") {
envVar = strings.TrimSpace(envVar)
if envVal := os.Getenv(envVar); envVal != "" {
envValInt, err := strconv.ParseInt(envVal, 0, 64)
if err == nil {
f.Value = int(envValInt)
break
}
}
}
}
eachName(f.Name, func(name string) {
if f.Destination != nil {
set.IntVar(f.Destination, name, f.Value, f.Usage)
return
}
set.Int(name, f.Value, f.Usage)
})
}
@ -319,28 +411,41 @@ func (f IntFlag) getName() string {
return f.Name
}
// DurationFlag is a flag that takes a duration specified in Go's duration
// format: https://golang.org/pkg/time/#ParseDuration
type DurationFlag struct {
Name string
Value time.Duration
Usage string
EnvVar string
Name string
Value time.Duration
Usage string
EnvVar string
Destination *time.Duration
}
// String returns a readable representation of this value (for usage defaults)
func (f DurationFlag) String() string {
return withEnvHint(f.EnvVar, fmt.Sprintf("%s '%v'\t%v", prefixedNames(f.Name), f.Value, f.Usage))
return withEnvHint(f.EnvVar, fmt.Sprintf("%s \"%v\"\t%v", prefixedNames(f.Name), f.Value, f.Usage))
}
// Apply populates the flag given the flag set and environment
func (f DurationFlag) Apply(set *flag.FlagSet) {
if f.EnvVar != "" {
if envVal := os.Getenv(f.EnvVar); envVal != "" {
envValDuration, err := time.ParseDuration(envVal)
if err == nil {
f.Value = envValDuration
for _, envVar := range strings.Split(f.EnvVar, ",") {
envVar = strings.TrimSpace(envVar)
if envVal := os.Getenv(envVar); envVal != "" {
envValDuration, err := time.ParseDuration(envVal)
if err == nil {
f.Value = envValDuration
break
}
}
}
}
eachName(f.Name, func(name string) {
if f.Destination != nil {
set.DurationVar(f.Destination, name, f.Value, f.Usage)
return
}
set.Duration(name, f.Value, f.Usage)
})
}
@ -349,28 +454,40 @@ func (f DurationFlag) getName() string {
return f.Name
}
// Float64Flag is a flag that takes an float value
// Errors if the value provided cannot be parsed
type Float64Flag struct {
Name string
Value float64
Usage string
EnvVar string
Name string
Value float64
Usage string
EnvVar string
Destination *float64
}
// String returns the usage
func (f Float64Flag) String() string {
return withEnvHint(f.EnvVar, fmt.Sprintf("%s '%v'\t%v", prefixedNames(f.Name), f.Value, f.Usage))
return withEnvHint(f.EnvVar, fmt.Sprintf("%s \"%v\"\t%v", prefixedNames(f.Name), f.Value, f.Usage))
}
// Apply populates the flag given the flag set and environment
func (f Float64Flag) Apply(set *flag.FlagSet) {
if f.EnvVar != "" {
if envVal := os.Getenv(f.EnvVar); envVal != "" {
envValFloat, err := strconv.ParseFloat(envVal, 10)
if err == nil {
f.Value = float64(envValFloat)
for _, envVar := range strings.Split(f.EnvVar, ",") {
envVar = strings.TrimSpace(envVar)
if envVal := os.Getenv(envVar); envVal != "" {
envValFloat, err := strconv.ParseFloat(envVal, 10)
if err == nil {
f.Value = float64(envValFloat)
}
}
}
}
eachName(f.Name, func(name string) {
if f.Destination != nil {
set.Float64Var(f.Destination, name, f.Value, f.Usage)
return
}
set.Float64(name, f.Value, f.Usage)
})
}
@ -404,7 +521,7 @@ func prefixedNames(fullName string) (prefixed string) {
func withEnvHint(envVar, str string) string {
envText := ""
if envVar != "" {
envText = fmt.Sprintf(" [$%s]", envVar)
envText = fmt.Sprintf(" [$%s]", strings.Join(strings.Split(envVar, ","), ", $"))
}
return str + envText
}

View File

@ -1,587 +0,0 @@
package cli_test
import (
"fmt"
"os"
"reflect"
"strings"
"testing"
"github.com/coreos/etcd/Godeps/_workspace/src/github.com/codegangsta/cli"
)
var boolFlagTests = []struct {
name string
expected string
}{
{"help", "--help\t"},
{"h", "-h\t"},
}
func TestBoolFlagHelpOutput(t *testing.T) {
for _, test := range boolFlagTests {
flag := cli.BoolFlag{Name: test.name}
output := flag.String()
if output != test.expected {
t.Errorf("%s does not match %s", output, test.expected)
}
}
}
var stringFlagTests = []struct {
name string
value string
expected string
}{
{"help", "", "--help \t"},
{"h", "", "-h \t"},
{"h", "", "-h \t"},
{"test", "Something", "--test 'Something'\t"},
}
func TestStringFlagHelpOutput(t *testing.T) {
for _, test := range stringFlagTests {
flag := cli.StringFlag{Name: test.name, Value: test.value}
output := flag.String()
if output != test.expected {
t.Errorf("%s does not match %s", output, test.expected)
}
}
}
func TestStringFlagWithEnvVarHelpOutput(t *testing.T) {
os.Setenv("APP_FOO", "derp")
for _, test := range stringFlagTests {
flag := cli.StringFlag{Name: test.name, Value: test.value, EnvVar: "APP_FOO"}
output := flag.String()
if !strings.HasSuffix(output, " [$APP_FOO]") {
t.Errorf("%s does not end with [$APP_FOO]", output)
}
}
}
var stringSliceFlagTests = []struct {
name string
value *cli.StringSlice
expected string
}{
{"help", func() *cli.StringSlice {
s := &cli.StringSlice{}
s.Set("")
return s
}(), "--help '--help option --help option'\t"},
{"h", func() *cli.StringSlice {
s := &cli.StringSlice{}
s.Set("")
return s
}(), "-h '-h option -h option'\t"},
{"h", func() *cli.StringSlice {
s := &cli.StringSlice{}
s.Set("")
return s
}(), "-h '-h option -h option'\t"},
{"test", func() *cli.StringSlice {
s := &cli.StringSlice{}
s.Set("Something")
return s
}(), "--test '--test option --test option'\t"},
}
func TestStringSliceFlagHelpOutput(t *testing.T) {
for _, test := range stringSliceFlagTests {
flag := cli.StringSliceFlag{Name: test.name, Value: test.value}
output := flag.String()
if output != test.expected {
t.Errorf("%q does not match %q", output, test.expected)
}
}
}
func TestStringSliceFlagWithEnvVarHelpOutput(t *testing.T) {
os.Setenv("APP_QWWX", "11,4")
for _, test := range stringSliceFlagTests {
flag := cli.StringSliceFlag{Name: test.name, Value: test.value, EnvVar: "APP_QWWX"}
output := flag.String()
if !strings.HasSuffix(output, " [$APP_QWWX]") {
t.Errorf("%q does not end with [$APP_QWWX]", output)
}
}
}
var intFlagTests = []struct {
name string
expected string
}{
{"help", "--help '0'\t"},
{"h", "-h '0'\t"},
}
func TestIntFlagHelpOutput(t *testing.T) {
for _, test := range intFlagTests {
flag := cli.IntFlag{Name: test.name}
output := flag.String()
if output != test.expected {
t.Errorf("%s does not match %s", output, test.expected)
}
}
}
func TestIntFlagWithEnvVarHelpOutput(t *testing.T) {
os.Setenv("APP_BAR", "2")
for _, test := range intFlagTests {
flag := cli.IntFlag{Name: test.name, EnvVar: "APP_BAR"}
output := flag.String()
if !strings.HasSuffix(output, " [$APP_BAR]") {
t.Errorf("%s does not end with [$APP_BAR]", output)
}
}
}
var durationFlagTests = []struct {
name string
expected string
}{
{"help", "--help '0'\t"},
{"h", "-h '0'\t"},
}
func TestDurationFlagHelpOutput(t *testing.T) {
for _, test := range durationFlagTests {
flag := cli.DurationFlag{Name: test.name}
output := flag.String()
if output != test.expected {
t.Errorf("%s does not match %s", output, test.expected)
}
}
}
func TestDurationFlagWithEnvVarHelpOutput(t *testing.T) {
os.Setenv("APP_BAR", "2h3m6s")
for _, test := range durationFlagTests {
flag := cli.DurationFlag{Name: test.name, EnvVar: "APP_BAR"}
output := flag.String()
if !strings.HasSuffix(output, " [$APP_BAR]") {
t.Errorf("%s does not end with [$APP_BAR]", output)
}
}
}
var intSliceFlagTests = []struct {
name string
value *cli.IntSlice
expected string
}{
{"help", &cli.IntSlice{}, "--help '--help option --help option'\t"},
{"h", &cli.IntSlice{}, "-h '-h option -h option'\t"},
{"h", &cli.IntSlice{}, "-h '-h option -h option'\t"},
{"test", func() *cli.IntSlice {
i := &cli.IntSlice{}
i.Set("9")
return i
}(), "--test '--test option --test option'\t"},
}
func TestIntSliceFlagHelpOutput(t *testing.T) {
for _, test := range intSliceFlagTests {
flag := cli.IntSliceFlag{Name: test.name, Value: test.value}
output := flag.String()
if output != test.expected {
t.Errorf("%q does not match %q", output, test.expected)
}
}
}
func TestIntSliceFlagWithEnvVarHelpOutput(t *testing.T) {
os.Setenv("APP_SMURF", "42,3")
for _, test := range intSliceFlagTests {
flag := cli.IntSliceFlag{Name: test.name, Value: test.value, EnvVar: "APP_SMURF"}
output := flag.String()
if !strings.HasSuffix(output, " [$APP_SMURF]") {
t.Errorf("%q does not end with [$APP_SMURF]", output)
}
}
}
var float64FlagTests = []struct {
name string
expected string
}{
{"help", "--help '0'\t"},
{"h", "-h '0'\t"},
}
func TestFloat64FlagHelpOutput(t *testing.T) {
for _, test := range float64FlagTests {
flag := cli.Float64Flag{Name: test.name}
output := flag.String()
if output != test.expected {
t.Errorf("%s does not match %s", output, test.expected)
}
}
}
func TestFloat64FlagWithEnvVarHelpOutput(t *testing.T) {
os.Setenv("APP_BAZ", "99.4")
for _, test := range float64FlagTests {
flag := cli.Float64Flag{Name: test.name, EnvVar: "APP_BAZ"}
output := flag.String()
if !strings.HasSuffix(output, " [$APP_BAZ]") {
t.Errorf("%s does not end with [$APP_BAZ]", output)
}
}
}
var genericFlagTests = []struct {
name string
value cli.Generic
expected string
}{
{"help", &Parser{}, "--help <nil>\t`-help option -help option` "},
{"h", &Parser{}, "-h <nil>\t`-h option -h option` "},
{"test", &Parser{}, "--test <nil>\t`-test option -test option` "},
}
func TestGenericFlagHelpOutput(t *testing.T) {
for _, test := range genericFlagTests {
flag := cli.GenericFlag{Name: test.name}
output := flag.String()
if output != test.expected {
t.Errorf("%q does not match %q", output, test.expected)
}
}
}
func TestGenericFlagWithEnvVarHelpOutput(t *testing.T) {
os.Setenv("APP_ZAP", "3")
for _, test := range genericFlagTests {
flag := cli.GenericFlag{Name: test.name, EnvVar: "APP_ZAP"}
output := flag.String()
if !strings.HasSuffix(output, " [$APP_ZAP]") {
t.Errorf("%s does not end with [$APP_ZAP]", output)
}
}
}
func TestParseMultiString(t *testing.T) {
(&cli.App{
Flags: []cli.Flag{
cli.StringFlag{Name: "serve, s"},
},
Action: func(ctx *cli.Context) {
if ctx.String("serve") != "10" {
t.Errorf("main name not set")
}
if ctx.String("s") != "10" {
t.Errorf("short name not set")
}
},
}).Run([]string{"run", "-s", "10"})
}
func TestParseMultiStringFromEnv(t *testing.T) {
os.Setenv("APP_COUNT", "20")
(&cli.App{
Flags: []cli.Flag{
cli.StringFlag{Name: "count, c", EnvVar: "APP_COUNT"},
},
Action: func(ctx *cli.Context) {
if ctx.String("count") != "20" {
t.Errorf("main name not set")
}
if ctx.String("c") != "20" {
t.Errorf("short name not set")
}
},
}).Run([]string{"run"})
}
func TestParseMultiStringSlice(t *testing.T) {
(&cli.App{
Flags: []cli.Flag{
cli.StringSliceFlag{Name: "serve, s", Value: &cli.StringSlice{}},
},
Action: func(ctx *cli.Context) {
if !reflect.DeepEqual(ctx.StringSlice("serve"), []string{"10", "20"}) {
t.Errorf("main name not set")
}
if !reflect.DeepEqual(ctx.StringSlice("s"), []string{"10", "20"}) {
t.Errorf("short name not set")
}
},
}).Run([]string{"run", "-s", "10", "-s", "20"})
}
func TestParseMultiStringSliceFromEnv(t *testing.T) {
os.Setenv("APP_INTERVALS", "20,30,40")
(&cli.App{
Flags: []cli.Flag{
cli.StringSliceFlag{Name: "intervals, i", Value: &cli.StringSlice{}, EnvVar: "APP_INTERVALS"},
},
Action: func(ctx *cli.Context) {
if !reflect.DeepEqual(ctx.StringSlice("intervals"), []string{"20", "30", "40"}) {
t.Errorf("main name not set from env")
}
if !reflect.DeepEqual(ctx.StringSlice("i"), []string{"20", "30", "40"}) {
t.Errorf("short name not set from env")
}
},
}).Run([]string{"run"})
}
func TestParseMultiInt(t *testing.T) {
a := cli.App{
Flags: []cli.Flag{
cli.IntFlag{Name: "serve, s"},
},
Action: func(ctx *cli.Context) {
if ctx.Int("serve") != 10 {
t.Errorf("main name not set")
}
if ctx.Int("s") != 10 {
t.Errorf("short name not set")
}
},
}
a.Run([]string{"run", "-s", "10"})
}
func TestParseMultiIntFromEnv(t *testing.T) {
os.Setenv("APP_TIMEOUT_SECONDS", "10")
a := cli.App{
Flags: []cli.Flag{
cli.IntFlag{Name: "timeout, t", EnvVar: "APP_TIMEOUT_SECONDS"},
},
Action: func(ctx *cli.Context) {
if ctx.Int("timeout") != 10 {
t.Errorf("main name not set")
}
if ctx.Int("t") != 10 {
t.Errorf("short name not set")
}
},
}
a.Run([]string{"run"})
}
func TestParseMultiIntSlice(t *testing.T) {
(&cli.App{
Flags: []cli.Flag{
cli.IntSliceFlag{Name: "serve, s", Value: &cli.IntSlice{}},
},
Action: func(ctx *cli.Context) {
if !reflect.DeepEqual(ctx.IntSlice("serve"), []int{10, 20}) {
t.Errorf("main name not set")
}
if !reflect.DeepEqual(ctx.IntSlice("s"), []int{10, 20}) {
t.Errorf("short name not set")
}
},
}).Run([]string{"run", "-s", "10", "-s", "20"})
}
func TestParseMultiIntSliceFromEnv(t *testing.T) {
os.Setenv("APP_INTERVALS", "20,30,40")
(&cli.App{
Flags: []cli.Flag{
cli.IntSliceFlag{Name: "intervals, i", Value: &cli.IntSlice{}, EnvVar: "APP_INTERVALS"},
},
Action: func(ctx *cli.Context) {
if !reflect.DeepEqual(ctx.IntSlice("intervals"), []int{20, 30, 40}) {
t.Errorf("main name not set from env")
}
if !reflect.DeepEqual(ctx.IntSlice("i"), []int{20, 30, 40}) {
t.Errorf("short name not set from env")
}
},
}).Run([]string{"run"})
}
func TestParseMultiFloat64(t *testing.T) {
a := cli.App{
Flags: []cli.Flag{
cli.Float64Flag{Name: "serve, s"},
},
Action: func(ctx *cli.Context) {
if ctx.Float64("serve") != 10.2 {
t.Errorf("main name not set")
}
if ctx.Float64("s") != 10.2 {
t.Errorf("short name not set")
}
},
}
a.Run([]string{"run", "-s", "10.2"})
}
func TestParseMultiFloat64FromEnv(t *testing.T) {
os.Setenv("APP_TIMEOUT_SECONDS", "15.5")
a := cli.App{
Flags: []cli.Flag{
cli.Float64Flag{Name: "timeout, t", EnvVar: "APP_TIMEOUT_SECONDS"},
},
Action: func(ctx *cli.Context) {
if ctx.Float64("timeout") != 15.5 {
t.Errorf("main name not set")
}
if ctx.Float64("t") != 15.5 {
t.Errorf("short name not set")
}
},
}
a.Run([]string{"run"})
}
func TestParseMultiBool(t *testing.T) {
a := cli.App{
Flags: []cli.Flag{
cli.BoolFlag{Name: "serve, s"},
},
Action: func(ctx *cli.Context) {
if ctx.Bool("serve") != true {
t.Errorf("main name not set")
}
if ctx.Bool("s") != true {
t.Errorf("short name not set")
}
},
}
a.Run([]string{"run", "--serve"})
}
func TestParseMultiBoolFromEnv(t *testing.T) {
os.Setenv("APP_DEBUG", "1")
a := cli.App{
Flags: []cli.Flag{
cli.BoolFlag{Name: "debug, d", EnvVar: "APP_DEBUG"},
},
Action: func(ctx *cli.Context) {
if ctx.Bool("debug") != true {
t.Errorf("main name not set from env")
}
if ctx.Bool("d") != true {
t.Errorf("short name not set from env")
}
},
}
a.Run([]string{"run"})
}
func TestParseMultiBoolT(t *testing.T) {
a := cli.App{
Flags: []cli.Flag{
cli.BoolTFlag{Name: "serve, s"},
},
Action: func(ctx *cli.Context) {
if ctx.BoolT("serve") != true {
t.Errorf("main name not set")
}
if ctx.BoolT("s") != true {
t.Errorf("short name not set")
}
},
}
a.Run([]string{"run", "--serve"})
}
func TestParseMultiBoolTFromEnv(t *testing.T) {
os.Setenv("APP_DEBUG", "0")
a := cli.App{
Flags: []cli.Flag{
cli.BoolTFlag{Name: "debug, d", EnvVar: "APP_DEBUG"},
},
Action: func(ctx *cli.Context) {
if ctx.BoolT("debug") != false {
t.Errorf("main name not set from env")
}
if ctx.BoolT("d") != false {
t.Errorf("short name not set from env")
}
},
}
a.Run([]string{"run"})
}
type Parser [2]string
func (p *Parser) Set(value string) error {
parts := strings.Split(value, ",")
if len(parts) != 2 {
return fmt.Errorf("invalid format")
}
(*p)[0] = parts[0]
(*p)[1] = parts[1]
return nil
}
func (p *Parser) String() string {
return fmt.Sprintf("%s,%s", p[0], p[1])
}
func TestParseGeneric(t *testing.T) {
a := cli.App{
Flags: []cli.Flag{
cli.GenericFlag{Name: "serve, s", Value: &Parser{}},
},
Action: func(ctx *cli.Context) {
if !reflect.DeepEqual(ctx.Generic("serve"), &Parser{"10", "20"}) {
t.Errorf("main name not set")
}
if !reflect.DeepEqual(ctx.Generic("s"), &Parser{"10", "20"}) {
t.Errorf("short name not set")
}
},
}
a.Run([]string{"run", "-s", "10,20"})
}
func TestParseGenericFromEnv(t *testing.T) {
os.Setenv("APP_SERVE", "20,30")
a := cli.App{
Flags: []cli.Flag{
cli.GenericFlag{Name: "serve, s", Value: &Parser{}, EnvVar: "APP_SERVE"},
},
Action: func(ctx *cli.Context) {
if !reflect.DeepEqual(ctx.Generic("serve"), &Parser{"20", "30"}) {
t.Errorf("main name not set from env")
}
if !reflect.DeepEqual(ctx.Generic("s"), &Parser{"20", "30"}) {
t.Errorf("short name not set from env")
}
},
}
a.Run([]string{"run"})
}

View File

@ -2,7 +2,8 @@ package cli
import (
"fmt"
"os"
"io"
"strings"
"text/tabwriter"
"text/template"
)
@ -14,31 +15,33 @@ var AppHelpTemplate = `NAME:
{{.Name}} - {{.Usage}}
USAGE:
{{.Name}} {{if .Flags}}[global options] {{end}}command{{if .Flags}} [command options]{{end}} [arguments...]
{{.HelpName}} {{if .Flags}}[global options]{{end}}{{if .Commands}} command [command options]{{end}} {{if .ArgsUsage}}{{.ArgsUsage}}{{else}}[arguments...]{{end}}
{{if .Version}}
VERSION:
{{.Version}}{{if or .Author .Email}}
AUTHOR:{{if .Author}}
{{.Author}}{{if .Email}} - <{{.Email}}>{{end}}{{else}}
{{.Email}}{{end}}{{end}}
{{.Version}}
{{end}}{{if len .Authors}}
AUTHOR(S):
{{range .Authors}}{{ . }}{{end}}
{{end}}{{if .Commands}}
COMMANDS:
{{range .Commands}}{{.Name}}{{with .ShortName}}, {{.}}{{end}}{{ "\t" }}{{.Usage}}
{{end}}{{if .Flags}}
{{range .Commands}}{{join .Names ", "}}{{ "\t" }}{{.Usage}}
{{end}}{{end}}{{if .Flags}}
GLOBAL OPTIONS:
{{range .Flags}}{{.}}
{{end}}{{end}}
{{end}}{{end}}{{if .Copyright }}
COPYRIGHT:
{{.Copyright}}
{{end}}
`
// The text template for the command help topic.
// cli.go uses text/template to render templates. You can
// render custom help text by setting this variable.
var CommandHelpTemplate = `NAME:
{{.Name}} - {{.Usage}}
{{.HelpName}} - {{.Usage}}
USAGE:
command {{.Name}}{{if .Flags}} [command options]{{end}} [arguments...]{{if .Description}}
{{.HelpName}}{{if .Flags}} [command options]{{end}} {{if .ArgsUsage}}{{.ArgsUsage}}{{else}}[arguments...]{{end}}{{if .Description}}
DESCRIPTION:
{{.Description}}{{end}}{{if .Flags}}
@ -52,13 +55,13 @@ OPTIONS:
// cli.go uses text/template to render templates. You can
// render custom help text by setting this variable.
var SubcommandHelpTemplate = `NAME:
{{.Name}} - {{.Usage}}
{{.HelpName}} - {{.Usage}}
USAGE:
{{.Name}} command{{if .Flags}} [command options]{{end}} [arguments...]
{{.HelpName}} command{{if .Flags}} [command options]{{end}} {{if .ArgsUsage}}{{.ArgsUsage}}{{else}}[arguments...]{{end}}
COMMANDS:
{{range .Commands}}{{.Name}}{{with .ShortName}}, {{.}}{{end}}{{ "\t" }}{{.Usage}}
{{range .Commands}}{{join .Names ", "}}{{ "\t" }}{{.Usage}}
{{end}}{{if .Flags}}
OPTIONS:
{{range .Flags}}{{.}}
@ -67,8 +70,9 @@ OPTIONS:
var helpCommand = Command{
Name: "help",
ShortName: "h",
Aliases: []string{"h"},
Usage: "Shows a list of commands or help for one command",
ArgsUsage: "[command]",
Action: func(c *Context) {
args := c.Args()
if args.Present() {
@ -81,8 +85,9 @@ var helpCommand = Command{
var helpSubcommand = Command{
Name: "help",
ShortName: "h",
Aliases: []string{"h"},
Usage: "Shows a list of commands or help for one command",
ArgsUsage: "[command]",
Action: func(c *Context) {
args := c.Args()
if args.Present() {
@ -93,45 +98,52 @@ var helpSubcommand = Command{
},
}
// Prints help for the App
var HelpPrinter = printHelp
// Prints help for the App or Command
type helpPrinter func(w io.Writer, templ string, data interface{})
var HelpPrinter helpPrinter = printHelp
// Prints version for the App
var VersionPrinter = printVersion
func ShowAppHelp(c *Context) {
HelpPrinter(AppHelpTemplate, c.App)
HelpPrinter(c.App.Writer, AppHelpTemplate, c.App)
}
// Prints the list of subcommands as the default app completion method
func DefaultAppComplete(c *Context) {
for _, command := range c.App.Commands {
fmt.Println(command.Name)
if command.ShortName != "" {
fmt.Println(command.ShortName)
for _, name := range command.Names() {
fmt.Fprintln(c.App.Writer, name)
}
}
}
// Prints help for the given command
func ShowCommandHelp(c *Context, command string) {
for _, c := range c.App.Commands {
func ShowCommandHelp(ctx *Context, command string) {
// show the subcommand help for a command with subcommands
if command == "" {
HelpPrinter(ctx.App.Writer, SubcommandHelpTemplate, ctx.App)
return
}
for _, c := range ctx.App.Commands {
if c.HasName(command) {
HelpPrinter(CommandHelpTemplate, c)
HelpPrinter(ctx.App.Writer, CommandHelpTemplate, c)
return
}
}
if c.App.CommandNotFound != nil {
c.App.CommandNotFound(c, command)
if ctx.App.CommandNotFound != nil {
ctx.App.CommandNotFound(ctx, command)
} else {
fmt.Printf("No help topic for '%v'\n", command)
fmt.Fprintf(ctx.App.Writer, "No help topic for '%v'\n", command)
}
}
// Prints help for the given subcommand
func ShowSubcommandHelp(c *Context) {
HelpPrinter(SubcommandHelpTemplate, c.App)
ShowCommandHelp(c, c.Command.Name)
}
// Prints the version number of the App
@ -140,7 +152,7 @@ func ShowVersion(c *Context) {
}
func printVersion(c *Context) {
fmt.Printf("%v version %v\n", c.App.Name, c.App.Version)
fmt.Fprintf(c.App.Writer, "%v version %v\n", c.App.Name, c.App.Version)
}
// Prints the lists of commands within a given context
@ -159,9 +171,13 @@ func ShowCommandCompletions(ctx *Context, command string) {
}
}
func printHelp(templ string, data interface{}) {
w := tabwriter.NewWriter(os.Stdout, 0, 8, 1, '\t', 0)
t := template.Must(template.New("help").Parse(templ))
func printHelp(out io.Writer, templ string, data interface{}) {
funcMap := template.FuncMap{
"join": strings.Join,
}
w := tabwriter.NewWriter(out, 0, 8, 1, '\t', 0)
t := template.Must(template.New("help").Funcs(funcMap).Parse(templ))
err := t.Execute(w, data)
if err != nil {
panic(err)
@ -170,21 +186,27 @@ func printHelp(templ string, data interface{}) {
}
func checkVersion(c *Context) bool {
if c.GlobalBool("version") {
ShowVersion(c)
return true
found := false
if VersionFlag.Name != "" {
eachName(VersionFlag.Name, func(name string) {
if c.GlobalBool(name) || c.Bool(name) {
found = true
}
})
}
return false
return found
}
func checkHelp(c *Context) bool {
if c.GlobalBool("h") || c.GlobalBool("help") {
ShowAppHelp(c)
return true
found := false
if HelpFlag.Name != "" {
eachName(HelpFlag.Name, func(name string) {
if c.GlobalBool(name) || c.Bool(name) {
found = true
}
})
}
return false
return found
}
func checkCommandHelp(c *Context, name string) bool {
@ -206,7 +228,7 @@ func checkSubcommandHelp(c *Context) bool {
}
func checkCompletions(c *Context) bool {
if c.GlobalBool(BashCompletionFlag.Name) && c.App.EnableBashCompletion {
if (c.GlobalBool(BashCompletionFlag.Name) || c.Bool(BashCompletionFlag.Name)) && c.App.EnableBashCompletion {
ShowCompletions(c)
return true
}

View File

@ -1,19 +0,0 @@
package cli_test
import (
"reflect"
"testing"
)
/* Test Helpers */
func expect(t *testing.T, a interface{}, b interface{}) {
if a != b {
t.Errorf("Expected %v (type %v) - Got %v (type %v)", b, reflect.TypeOf(b), a, reflect.TypeOf(a))
}
}
func refute(t *testing.T, a interface{}, b interface{}) {
if a == b {
t.Errorf("Did not expect %v (type %v) - Got %v (type %v)", b, reflect.TypeOf(b), a, reflect.TypeOf(a))
}
}

View File

@ -0,0 +1,7 @@
Copyright (C) 2014 Thomas Rooney
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

View File

@ -0,0 +1,64 @@
# Gexpect
Gexpect is a pure golang expect-like module.
It makes it simpler to create and control other terminal applications.
child, err := gexpect.Spawn("python")
if err != nil {
panic(err)
}
child.Expect(">>>")
child.SendLine("print 'Hello World'")
child.Interact()
child.Close()
## Examples
`Spawn` handles the argument parsing from a string
child.Spawn("/bin/sh -c 'echo \"my complicated command\" | tee log | cat > log2'")
child.ReadLine() // ReadLine() (string, error)
child.ReadUntil(' ') // ReadUntil(delim byte) ([]byte, error)
`ReadLine`, `ReadUntil` and `SendLine` send strings from/to `stdout/stdin` respectively
child := gexpect.Spawn("cat")
child.SendLine("echoing process_stdin") // SendLine(command string) (error)
msg, _ := child.ReadLine() // msg = echoing process_stdin
`Wait` and `Close` allow for graceful and ungraceful termination.
child.Wait() // Waits until the child terminates naturally.
child.Close() // Sends a kill command
`AsyncInteractChannels` spawns two go routines to pipe into and from `stdout`/`stdin`, allowing for some usecases to be a little simpler.
child := gexpect.spawn("sh")
sender, reciever := child.AsyncInteractChannels()
sender <- "echo Hello World\n" // Send to stdin
line, open := <- reciever // Recieve a line from stdout/stderr
// When the subprocess stops (e.g. with child.Close()) , receiver is closed
if open {
fmt.Printf("Received %s", line)]
}
`ExpectRegex` uses golang's internal regex engine to wait until a match from the process with the given regular expression (or an error on process termination with no match).
child := gexpect.Spawn("echo accb")
match, _ := child.ExpectRegex("a..b")
// (match=true)
`ExpectRegexFind` allows for groups to be extracted from process stdout. The first element is an array of containing the total matched text, followed by each subexpression group match.
child := gexpect.Spawn("echo 123 456 789")
result, _ := child.ExpectRegexFind("\d+ (\d+) (\d+)")
// result = []string{"123 456 789", "456", "789"}
See `gexpect_test.go` and the `examples` folder for full syntax
## Credits
github.com/kballard/go-shellquote
github.com/kr/pty
KMP Algorithm: "http://blog.databigbang.com/searching-for-substrings-in-streams-a-slight-modification-of-the-knuth-morris-pratt-algorithm-in-haxe/"

View File

@ -0,0 +1,27 @@
package main
import gexpect "github.com/coreos/etcd/Godeps/_workspace/src/github.com/coreos/gexpect"
import "log"
func main() {
log.Printf("Testing Ftp... ")
child, err := gexpect.Spawn("ftp ftp.openbsd.org")
if err != nil {
panic(err)
}
child.Expect("Name")
child.SendLine("anonymous")
child.Expect("Password")
child.SendLine("pexpect@sourceforge.net")
child.Expect("ftp> ")
child.SendLine("cd /pub/OpenBSD/3.7/packages/i386")
child.Expect("ftp> ")
child.SendLine("bin")
child.Expect("ftp> ")
child.SendLine("prompt")
child.Expect("ftp> ")
child.SendLine("pwd")
child.Expect("ftp> ")
log.Printf("Success\n")
}

View File

@ -0,0 +1,15 @@
package main
import gexpect "github.com/coreos/etcd/Godeps/_workspace/src/github.com/coreos/gexpect"
import "log"
func main() {
log.Printf("Testing Ping interact... \n")
child, err := gexpect.Spawn("ping -c8 127.0.0.1")
if err != nil {
panic(err)
}
child.Interact()
log.Printf("Success\n")
}

View File

@ -0,0 +1,22 @@
package main
import "github.com/coreos/etcd/Godeps/_workspace/src/github.com/coreos/gexpect"
import "fmt"
func main() {
fmt.Printf("Starting python.. \n")
child, err := gexpect.Spawn("python")
if err != nil {
panic(err)
}
fmt.Printf("Expecting >>>.. \n")
child.Expect(">>>")
fmt.Printf("print 'Hello World'..\n")
child.SendLine("print 'Hello World'")
child.Expect(">>>")
fmt.Printf("Interacting.. \n")
child.Interact()
fmt.Printf("Done \n")
child.Close()
}

View File

@ -0,0 +1,53 @@
package main
import "github.com/coreos/etcd/Godeps/_workspace/src/github.com/coreos/gexpect"
import "fmt"
import "strings"
func main() {
waitChan := make(chan string)
fmt.Printf("Starting screen.. \n")
child, err := gexpect.Spawn("screen")
if err != nil {
panic(err)
}
sender, reciever := child.AsyncInteractChannels()
go func() {
waitString := ""
count := 0
for {
select {
case waitString = <-waitChan:
count++
case msg, open := <-reciever:
if !open {
return
}
fmt.Printf("Recieved: %s\n", msg)
if strings.Contains(msg, waitString) {
if count >= 1 {
waitChan <- msg
count -= 1
}
}
}
}
}()
wait := func(str string) {
waitChan <- str
<-waitChan
}
fmt.Printf("Waiting until started.. \n")
wait(" ")
fmt.Printf("Sending Enter.. \n")
sender <- "\n"
wait("$")
fmt.Printf("Sending echo.. \n")
sender <- "echo Hello World\n"
wait("Hello World")
fmt.Printf("Received echo. \n")
}

View File

@ -0,0 +1,430 @@
package gexpect
import (
"bytes"
"errors"
"fmt"
"io"
"os"
"os/exec"
"regexp"
"time"
"unicode/utf8"
shell "github.com/coreos/etcd/Godeps/_workspace/src/github.com/kballard/go-shellquote"
"github.com/coreos/etcd/Godeps/_workspace/src/github.com/kr/pty"
)
type ExpectSubprocess struct {
Cmd *exec.Cmd
buf *buffer
outputBuffer []byte
}
type buffer struct {
f *os.File
b bytes.Buffer
collect bool
collection bytes.Buffer
}
func (buf *buffer) StartCollecting() {
buf.collect = true
}
func (buf *buffer) StopCollecting() (result string) {
result = string(buf.collection.Bytes())
buf.collect = false
buf.collection.Reset()
return result
}
func (buf *buffer) Read(chunk []byte) (int, error) {
nread := 0
if buf.b.Len() > 0 {
n, err := buf.b.Read(chunk)
if err != nil {
return n, err
}
if n == len(chunk) {
return n, nil
}
nread = n
}
fn, err := buf.f.Read(chunk[nread:])
return fn + nread, err
}
func (buf *buffer) ReadRune() (r rune, size int, err error) {
l := buf.b.Len()
chunk := make([]byte, utf8.UTFMax)
if l > 0 {
n, err := buf.b.Read(chunk)
if err != nil {
return 0, 0, err
}
if utf8.FullRune(chunk) {
r, rL := utf8.DecodeRune(chunk)
if n > rL {
buf.PutBack(chunk[rL:n])
}
if buf.collect {
buf.collection.WriteRune(r)
}
return r, rL, nil
}
}
// else add bytes from the file, then try that
for l < utf8.UTFMax {
fn, err := buf.f.Read(chunk[l : l+1])
if err != nil {
return 0, 0, err
}
l = l + fn
if utf8.FullRune(chunk) {
r, rL := utf8.DecodeRune(chunk)
if buf.collect {
buf.collection.WriteRune(r)
}
return r, rL, nil
}
}
return 0, 0, errors.New("File is not a valid UTF=8 encoding")
}
func (buf *buffer) PutBack(chunk []byte) {
if len(chunk) == 0 {
return
}
if buf.b.Len() == 0 {
buf.b.Write(chunk)
return
}
d := make([]byte, 0, len(chunk)+buf.b.Len())
d = append(d, chunk...)
d = append(d, buf.b.Bytes()...)
buf.b.Reset()
buf.b.Write(d)
}
func SpawnAtDirectory(command string, directory string) (*ExpectSubprocess, error) {
expect, err := _spawn(command)
if err != nil {
return nil, err
}
expect.Cmd.Dir = directory
return _start(expect)
}
func Command(command string) (*ExpectSubprocess, error) {
expect, err := _spawn(command)
if err != nil {
return nil, err
}
return expect, nil
}
func (expect *ExpectSubprocess) Start() error {
_, err := _start(expect)
return err
}
func Spawn(command string) (*ExpectSubprocess, error) {
expect, err := _spawn(command)
if err != nil {
return nil, err
}
return _start(expect)
}
func (expect *ExpectSubprocess) Close() error {
return expect.Cmd.Process.Kill()
}
func (expect *ExpectSubprocess) AsyncInteractChannels() (send chan string, receive chan string) {
receive = make(chan string)
send = make(chan string)
go func() {
for {
str, err := expect.ReadLine()
if err != nil {
close(receive)
return
}
receive <- str
}
}()
go func() {
for {
select {
case sendCommand, exists := <-send:
{
if !exists {
return
}
err := expect.Send(sendCommand)
if err != nil {
receive <- "gexpect Error: " + err.Error()
return
}
}
}
}
}()
return
}
func (expect *ExpectSubprocess) ExpectRegex(regex string) (bool, error) {
return regexp.MatchReader(regex, expect.buf)
}
func (expect *ExpectSubprocess) expectRegexFind(regex string, output bool) ([]string, string, error) {
re, err := regexp.Compile(regex)
if err != nil {
return nil, "", err
}
expect.buf.StartCollecting()
pairs := re.FindReaderSubmatchIndex(expect.buf)
stringIndexedInto := expect.buf.StopCollecting()
l := len(pairs)
numPairs := l / 2
result := make([]string, numPairs)
for i := 0; i < numPairs; i += 1 {
result[i] = stringIndexedInto[pairs[i*2]:pairs[i*2+1]]
}
// convert indexes to strings
if len(result) == 0 {
err = fmt.Errorf("ExpectRegex didn't find regex '%v'.", regex)
}
return result, stringIndexedInto, err
}
func (expect *ExpectSubprocess) expectTimeoutRegexFind(regex string, timeout time.Duration) (result []string, out string, err error) {
t := make(chan bool)
go func() {
result, out, err = expect.ExpectRegexFindWithOutput(regex)
t <- false
}()
go func() {
time.Sleep(timeout)
err = fmt.Errorf("ExpectRegex timed out after %v finding '%v'.\nOutput:\n%s", timeout, regex, expect.Collect())
t <- true
}()
<-t
return result, out, err
}
func (expect *ExpectSubprocess) ExpectRegexFind(regex string) ([]string, error) {
result, _, err := expect.expectRegexFind(regex, false)
return result, err
}
func (expect *ExpectSubprocess) ExpectTimeoutRegexFind(regex string, timeout time.Duration) ([]string, error) {
result, _, err := expect.expectTimeoutRegexFind(regex, timeout)
return result, err
}
func (expect *ExpectSubprocess) ExpectRegexFindWithOutput(regex string) ([]string, string, error) {
return expect.expectRegexFind(regex, true)
}
func (expect *ExpectSubprocess) ExpectTimeoutRegexFindWithOutput(regex string, timeout time.Duration) ([]string, string, error) {
return expect.expectTimeoutRegexFind(regex, timeout)
}
func buildKMPTable(searchString string) []int {
pos := 2
cnd := 0
length := len(searchString)
var table []int
if length < 2 {
length = 2
}
table = make([]int, length)
table[0] = -1
table[1] = 0
for pos < len(searchString) {
if searchString[pos-1] == searchString[cnd] {
cnd += 1
table[pos] = cnd
pos += 1
} else if cnd > 0 {
cnd = table[cnd]
} else {
table[pos] = 0
pos += 1
}
}
return table
}
func (expect *ExpectSubprocess) ExpectTimeout(searchString string, timeout time.Duration) (e error) {
result := make(chan error)
go func() {
result <- expect.Expect(searchString)
}()
select {
case e = <-result:
case <-time.After(timeout):
e = fmt.Errorf("Expect timed out after %v waiting for '%v'.\nOutput:\n%s", timeout, searchString, expect.Collect())
}
return e
}
func (expect *ExpectSubprocess) Expect(searchString string) (e error) {
chunk := make([]byte, len(searchString)*2)
target := len(searchString)
if expect.outputBuffer != nil {
expect.outputBuffer = expect.outputBuffer[0:]
}
m := 0
i := 0
// Build KMP Table
table := buildKMPTable(searchString)
for {
n, err := expect.buf.Read(chunk)
if err != nil {
return err
}
if expect.outputBuffer != nil {
expect.outputBuffer = append(expect.outputBuffer, chunk[:n]...)
}
offset := m + i
for m+i-offset < n {
if searchString[i] == chunk[m+i-offset] {
i += 1
if i == target {
unreadIndex := m + i - offset
if len(chunk) > unreadIndex {
expect.buf.PutBack(chunk[unreadIndex:])
}
return nil
}
} else {
m += i - table[i]
if table[i] > -1 {
i = table[i]
} else {
i = 0
}
}
}
}
}
func (expect *ExpectSubprocess) Send(command string) error {
_, err := io.WriteString(expect.buf.f, command)
return err
}
func (expect *ExpectSubprocess) Capture() {
if expect.outputBuffer == nil {
expect.outputBuffer = make([]byte, 0)
}
}
func (expect *ExpectSubprocess) Collect() []byte {
collectOutput := make([]byte, len(expect.outputBuffer))
copy(collectOutput, expect.outputBuffer)
expect.outputBuffer = nil
return collectOutput
}
func (expect *ExpectSubprocess) SendLine(command string) error {
_, err := io.WriteString(expect.buf.f, command+"\r\n")
return err
}
func (expect *ExpectSubprocess) Interact() {
defer expect.Cmd.Wait()
io.Copy(os.Stdout, &expect.buf.b)
go io.Copy(os.Stdout, expect.buf.f)
go io.Copy(expect.buf.f, os.Stdin)
}
func (expect *ExpectSubprocess) ReadUntil(delim byte) ([]byte, error) {
join := make([]byte, 1, 512)
chunk := make([]byte, 255)
for {
n, err := expect.buf.Read(chunk)
if err != nil {
return join, err
}
for i := 0; i < n; i++ {
if chunk[i] == delim {
if len(chunk) > i+1 {
expect.buf.PutBack(chunk[i+1:])
}
return join, nil
} else {
join = append(join, chunk[i])
}
}
}
}
func (expect *ExpectSubprocess) Wait() error {
return expect.Cmd.Wait()
}
func (expect *ExpectSubprocess) ReadLine() (string, error) {
str, err := expect.ReadUntil('\n')
if err != nil {
return "", err
}
return string(str), nil
}
func _start(expect *ExpectSubprocess) (*ExpectSubprocess, error) {
f, err := pty.Start(expect.Cmd)
if err != nil {
return nil, err
}
expect.buf.f = f
return expect, nil
}
func _spawn(command string) (*ExpectSubprocess, error) {
wrapper := new(ExpectSubprocess)
wrapper.outputBuffer = nil
splitArgs, err := shell.Split(command)
if err != nil {
return nil, err
}
numArguments := len(splitArgs) - 1
if numArguments < 0 {
return nil, errors.New("gexpect: No command given to spawn")
}
path, err := exec.LookPath(splitArgs[0])
if err != nil {
return nil, err
}
if numArguments >= 1 {
wrapper.Cmd = exec.Command(path, splitArgs[1:]...)
} else {
wrapper.Cmd = exec.Command(path)
}
wrapper.buf = new(buffer)
return wrapper, nil
}

View File

@ -0,0 +1,21 @@
The MIT License (MIT)
Copyright (c) 2014 Brian Goff
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@ -0,0 +1,19 @@
package md2man
import (
"github.com/coreos/etcd/Godeps/_workspace/src/github.com/russross/blackfriday"
)
func Render(doc []byte) []byte {
renderer := RoffRenderer(0)
extensions := 0
extensions |= blackfriday.EXTENSION_NO_INTRA_EMPHASIS
extensions |= blackfriday.EXTENSION_TABLES
extensions |= blackfriday.EXTENSION_FENCED_CODE
extensions |= blackfriday.EXTENSION_AUTOLINK
extensions |= blackfriday.EXTENSION_SPACE_HEADERS
extensions |= blackfriday.EXTENSION_FOOTNOTES
extensions |= blackfriday.EXTENSION_TITLEBLOCK
return blackfriday.Markdown(doc, renderer, extensions)
}

View File

@ -0,0 +1,269 @@
package md2man
import (
"bytes"
"fmt"
"html"
"strings"
"github.com/coreos/etcd/Godeps/_workspace/src/github.com/russross/blackfriday"
)
type roffRenderer struct{}
func RoffRenderer(flags int) blackfriday.Renderer {
return &roffRenderer{}
}
func (r *roffRenderer) GetFlags() int {
return 0
}
func (r *roffRenderer) TitleBlock(out *bytes.Buffer, text []byte) {
out.WriteString(".TH ")
splitText := bytes.Split(text, []byte("\n"))
for i, line := range splitText {
line = bytes.TrimPrefix(line, []byte("% "))
if i == 0 {
line = bytes.Replace(line, []byte("("), []byte("\" \""), 1)
line = bytes.Replace(line, []byte(")"), []byte("\" \""), 1)
}
line = append([]byte("\""), line...)
line = append(line, []byte("\" ")...)
out.Write(line)
}
out.WriteString(" \"\"\n")
}
func (r *roffRenderer) BlockCode(out *bytes.Buffer, text []byte, lang string) {
out.WriteString("\n.PP\n.RS\n\n.nf\n")
escapeSpecialChars(out, text)
out.WriteString("\n.fi\n.RE\n")
}
func (r *roffRenderer) BlockQuote(out *bytes.Buffer, text []byte) {
out.WriteString("\n.PP\n.RS\n")
out.Write(text)
out.WriteString("\n.RE\n")
}
func (r *roffRenderer) BlockHtml(out *bytes.Buffer, text []byte) {
out.Write(text)
}
func (r *roffRenderer) Header(out *bytes.Buffer, text func() bool, level int, id string) {
marker := out.Len()
switch {
case marker == 0:
// This is the doc header
out.WriteString(".TH ")
case level == 1:
out.WriteString("\n\n.SH ")
case level == 2:
out.WriteString("\n.SH ")
default:
out.WriteString("\n.SS ")
}
if !text() {
out.Truncate(marker)
return
}
}
func (r *roffRenderer) HRule(out *bytes.Buffer) {
out.WriteString("\n.ti 0\n\\l'\\n(.lu'\n")
}
func (r *roffRenderer) List(out *bytes.Buffer, text func() bool, flags int) {
marker := out.Len()
out.WriteString(".IP ")
if flags&blackfriday.LIST_TYPE_ORDERED != 0 {
out.WriteString("\\(bu 2")
} else {
out.WriteString("\\n+[step" + string(flags) + "]")
}
out.WriteString("\n")
if !text() {
out.Truncate(marker)
return
}
}
func (r *roffRenderer) ListItem(out *bytes.Buffer, text []byte, flags int) {
out.WriteString("\n\\item ")
out.Write(text)
}
func (r *roffRenderer) Paragraph(out *bytes.Buffer, text func() bool) {
marker := out.Len()
out.WriteString("\n.PP\n")
if !text() {
out.Truncate(marker)
return
}
if marker != 0 {
out.WriteString("\n")
}
}
// TODO: This might now work
func (r *roffRenderer) Table(out *bytes.Buffer, header []byte, body []byte, columnData []int) {
out.WriteString(".TS\nallbox;\n")
out.Write(header)
out.Write(body)
out.WriteString("\n.TE\n")
}
func (r *roffRenderer) TableRow(out *bytes.Buffer, text []byte) {
if out.Len() > 0 {
out.WriteString("\n")
}
out.Write(text)
out.WriteString("\n")
}
func (r *roffRenderer) TableHeaderCell(out *bytes.Buffer, text []byte, align int) {
if out.Len() > 0 {
out.WriteString(" ")
}
out.Write(text)
out.WriteString(" ")
}
// TODO: This is probably broken
func (r *roffRenderer) TableCell(out *bytes.Buffer, text []byte, align int) {
if out.Len() > 0 {
out.WriteString("\t")
}
out.Write(text)
out.WriteString("\t")
}
func (r *roffRenderer) Footnotes(out *bytes.Buffer, text func() bool) {
}
func (r *roffRenderer) FootnoteItem(out *bytes.Buffer, name, text []byte, flags int) {
}
func (r *roffRenderer) AutoLink(out *bytes.Buffer, link []byte, kind int) {
out.WriteString("\n\\[la]")
out.Write(link)
out.WriteString("\\[ra]")
}
func (r *roffRenderer) CodeSpan(out *bytes.Buffer, text []byte) {
out.WriteString("\\fB\\fC")
escapeSpecialChars(out, text)
out.WriteString("\\fR")
}
func (r *roffRenderer) DoubleEmphasis(out *bytes.Buffer, text []byte) {
out.WriteString("\\fB")
out.Write(text)
out.WriteString("\\fP")
}
func (r *roffRenderer) Emphasis(out *bytes.Buffer, text []byte) {
out.WriteString("\\fI")
out.Write(text)
out.WriteString("\\fP")
}
func (r *roffRenderer) Image(out *bytes.Buffer, link []byte, title []byte, alt []byte) {
}
func (r *roffRenderer) LineBreak(out *bytes.Buffer) {
out.WriteString("\n.br\n")
}
func (r *roffRenderer) Link(out *bytes.Buffer, link []byte, title []byte, content []byte) {
r.AutoLink(out, link, 0)
}
func (r *roffRenderer) RawHtmlTag(out *bytes.Buffer, tag []byte) {
out.Write(tag)
}
func (r *roffRenderer) TripleEmphasis(out *bytes.Buffer, text []byte) {
out.WriteString("\\s+2")
out.Write(text)
out.WriteString("\\s-2")
}
func (r *roffRenderer) StrikeThrough(out *bytes.Buffer, text []byte) {
}
func (r *roffRenderer) FootnoteRef(out *bytes.Buffer, ref []byte, id int) {
}
func (r *roffRenderer) Entity(out *bytes.Buffer, entity []byte) {
out.WriteString(html.UnescapeString(string(entity)))
}
func processFooterText(text []byte) []byte {
text = bytes.TrimPrefix(text, []byte("% "))
newText := []byte{}
textArr := strings.Split(string(text), ") ")
for i, w := range textArr {
if i == 0 {
w = strings.Replace(w, "(", "\" \"", 1)
w = fmt.Sprintf("\"%s\"", w)
} else {
w = fmt.Sprintf(" \"%s\"", w)
}
newText = append(newText, []byte(w)...)
}
newText = append(newText, []byte(" \"\"")...)
return newText
}
func (r *roffRenderer) NormalText(out *bytes.Buffer, text []byte) {
escapeSpecialChars(out, text)
}
func (r *roffRenderer) DocumentHeader(out *bytes.Buffer) {
}
func (r *roffRenderer) DocumentFooter(out *bytes.Buffer) {
}
func needsBackslash(c byte) bool {
for _, r := range []byte("-_&\\~") {
if c == r {
return true
}
}
return false
}
func escapeSpecialChars(out *bytes.Buffer, text []byte) {
for i := 0; i < len(text); i++ {
// directly copy normal characters
org := i
for i < len(text) && !needsBackslash(text[i]) {
i++
}
if i > org {
out.Write(text[org:i])
}
// escape a character
if i >= len(text) {
break
}
out.WriteByte('\\')
out.WriteByte(text[i])
}
}

36
Godeps/_workspace/src/github.com/gogo/protobuf/LICENSE generated vendored Normal file
View File

@ -0,0 +1,36 @@
Extensions for Protocol Buffers to create more go like structures.
Copyright (c) 2013, Vastech SA (PTY) LTD. All rights reserved.
http://github.com/gogo/protobuf/gogoproto
Go support for Protocol Buffers - Google's data interchange format
Copyright 2010 The Go Authors. All rights reserved.
https://github.com/golang/protobuf
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@ -39,5 +39,5 @@ test: install generate-test-pbs
generate-test-pbs:
make install
make -C testdata
protoc-min-version --version="3.0.0" --proto_path=.:../../../../ proto3_proto/proto3.proto
protoc-min-version --version="3.0.0" --proto_path=.:../../../../ --gogo_out=. proto3_proto/proto3.proto
make

View File

@ -401,17 +401,18 @@ type fakeMarshaler struct {
err error
}
func (f fakeMarshaler) Marshal() ([]byte, error) {
return f.b, f.err
func (f *fakeMarshaler) Marshal() ([]byte, error) { return f.b, f.err }
func (f *fakeMarshaler) String() string { return fmt.Sprintf("Bytes: %v Error: %v", f.b, f.err) }
func (f *fakeMarshaler) ProtoMessage() {}
func (f *fakeMarshaler) Reset() {}
type msgWithFakeMarshaler struct {
M *fakeMarshaler `protobuf:"bytes,1,opt,name=fake"`
}
func (f fakeMarshaler) String() string {
return fmt.Sprintf("Bytes: %v Error: %v", f.b, f.err)
}
func (f fakeMarshaler) ProtoMessage() {}
func (f fakeMarshaler) Reset() {}
func (m *msgWithFakeMarshaler) String() string { return CompactTextString(m) }
func (m *msgWithFakeMarshaler) ProtoMessage() {}
func (m *msgWithFakeMarshaler) Reset() {}
// Simple tests for proto messages that implement the Marshaler interface.
func TestMarshalerEncoding(t *testing.T) {
@ -423,7 +424,7 @@ func TestMarshalerEncoding(t *testing.T) {
}{
{
name: "Marshaler that fails",
m: fakeMarshaler{
m: &fakeMarshaler{
err: errors.New("some marshal err"),
b: []byte{5, 6, 7},
},
@ -431,9 +432,25 @@ func TestMarshalerEncoding(t *testing.T) {
want: nil,
wantErr: errors.New("some marshal err"),
},
{
name: "Marshaler that fails with RequiredNotSetError",
m: &msgWithFakeMarshaler{
M: &fakeMarshaler{
err: &RequiredNotSetError{},
b: []byte{5, 6, 7},
},
},
// Since there's an error that can be continued after,
// the buffer should be written.
want: []byte{
10, 3, // for &msgWithFakeMarshaler
5, 6, 7, // for &fakeMarshaler
},
wantErr: &RequiredNotSetError{},
},
{
name: "Marshaler that succeeds",
m: fakeMarshaler{
m: &fakeMarshaler{
b: []byte{0, 1, 2, 3, 4, 127, 255},
},
want: []byte{0, 1, 2, 3, 4, 127, 255},
@ -443,6 +460,10 @@ func TestMarshalerEncoding(t *testing.T) {
for _, test := range tests {
b := NewBuffer(nil)
err := b.Marshal(test.m)
if _, ok := err.(*RequiredNotSetError); ok {
// We're not in package proto, so we can only assert the type in this case.
err = &RequiredNotSetError{}
}
if !reflect.DeepEqual(test.wantErr, err) {
t.Errorf("%s: got err %v wanted %v", test.name, err, test.wantErr)
}
@ -1281,7 +1302,7 @@ func TestEnum(t *testing.T) {
// We don't care what the value actually is, just as long as it doesn't crash.
func TestPrintingNilEnumFields(t *testing.T) {
pb := new(GoEnum)
fmt.Sprintf("%+v", pb)
_ = fmt.Sprintf("%+v", pb)
}
// Verify that absent required fields cause Marshal/Unmarshal to return errors.
@ -1925,6 +1946,83 @@ func TestMapFieldRoundTrips(t *testing.T) {
}
}
func TestMapFieldWithNil(t *testing.T) {
m := &MessageWithMap{
MsgMapping: map[int64]*FloatingPoint{
1: nil,
},
}
b, err := Marshal(m)
if err == nil {
t.Fatalf("Marshal of bad map should have failed, got these bytes: %v", b)
}
}
func TestOneof(t *testing.T) {
m := &Communique{}
b, err := Marshal(m)
if err != nil {
t.Fatalf("Marshal of empty message with oneof: %v", err)
}
if len(b) != 0 {
t.Errorf("Marshal of empty message yielded too many bytes: %v", b)
}
m = &Communique{
Union: &Communique_Name{"Barry"},
}
// Round-trip.
b, err = Marshal(m)
if err != nil {
t.Fatalf("Marshal of message with oneof: %v", err)
}
if len(b) != 7 { // name tag/wire (1) + name len (1) + name (5)
t.Errorf("Incorrect marshal of message with oneof: %v", b)
}
m.Reset()
if err := Unmarshal(b, m); err != nil {
t.Fatalf("Unmarshal of message with oneof: %v", err)
}
if x, ok := m.Union.(*Communique_Name); !ok || x.Name != "Barry" {
t.Errorf("After round trip, Union = %+v", m.Union)
}
if name := m.GetName(); name != "Barry" {
t.Errorf("After round trip, GetName = %q, want %q", name, "Barry")
}
// Let's try with a message in the oneof.
m.Union = &Communique_Msg{&Strings{StringField: String("deep deep string")}}
b, err = Marshal(m)
if err != nil {
t.Fatalf("Marshal of message with oneof set to message: %v", err)
}
if len(b) != 20 { // msg tag/wire (1) + msg len (1) + msg (1 + 1 + 16)
t.Errorf("Incorrect marshal of message with oneof set to message: %v", b)
}
m.Reset()
if err := Unmarshal(b, m); err != nil {
t.Fatalf("Unmarshal of message with oneof set to message: %v", err)
}
ss, ok := m.Union.(*Communique_Msg)
if !ok || ss.Msg.GetStringField() != "deep deep string" {
t.Errorf("After round trip with oneof set to message, Union = %+v", m.Union)
}
}
func TestInefficientPackedBool(t *testing.T) {
// https://github.com/golang/protobuf/issues/76
inp := []byte{
0x12, 0x02, // 0x12 = 2<<3|2; 2 bytes
// Usually a bool should take a single byte,
// but it is permitted to be any varint.
0xb9, 0x30,
}
if err := Unmarshal(inp, new(MoreRepeated)); err != nil {
t.Error(err)
}
}
// Benchmarks
func testMsg() *GoTest {

View File

@ -30,7 +30,7 @@
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// Protocol buffer deep copy and merge.
// TODO: MessageSet and RawMessage.
// TODO: RawMessage.
package proto
@ -75,12 +75,13 @@ func Merge(dst, src Message) {
}
func mergeStruct(out, in reflect.Value) {
sprop := GetProperties(in.Type())
for i := 0; i < in.NumField(); i++ {
f := in.Type().Field(i)
if strings.HasPrefix(f.Name, "XXX_") {
continue
}
mergeAny(out.Field(i), in.Field(i))
mergeAny(out.Field(i), in.Field(i), false, sprop.Prop[i])
}
if emIn, ok := in.Addr().Interface().(extensionsMap); ok {
@ -103,7 +104,10 @@ func mergeStruct(out, in reflect.Value) {
}
}
func mergeAny(out, in reflect.Value) {
// mergeAny performs a merge between two values of the same type.
// viaPtr indicates whether the values were indirected through a pointer (implying proto2).
// prop is set if this is a struct field (it may be nil).
func mergeAny(out, in reflect.Value, viaPtr bool, prop *Properties) {
if in.Type() == protoMessageType {
if !in.IsNil() {
if out.IsNil() {
@ -117,7 +121,21 @@ func mergeAny(out, in reflect.Value) {
switch in.Kind() {
case reflect.Bool, reflect.Float32, reflect.Float64, reflect.Int32, reflect.Int64,
reflect.String, reflect.Uint32, reflect.Uint64:
if !viaPtr && isProto3Zero(in) {
return
}
out.Set(in)
case reflect.Interface:
// Probably a oneof field; copy non-nil values.
if in.IsNil() {
return
}
// Allocate destination if it is not set, or set to a different type.
// Otherwise we will merge as normal.
if out.IsNil() || out.Elem().Type() != in.Elem().Type() {
out.Set(reflect.New(in.Elem().Elem().Type())) // interface -> *T -> T -> new(T)
}
mergeAny(out.Elem(), in.Elem(), false, nil)
case reflect.Map:
if in.Len() == 0 {
return
@ -132,7 +150,7 @@ func mergeAny(out, in reflect.Value) {
switch elemKind {
case reflect.Ptr:
val = reflect.New(in.Type().Elem().Elem())
mergeAny(val, in.MapIndex(key))
mergeAny(val, in.MapIndex(key), false, nil)
case reflect.Slice:
val = in.MapIndex(key)
val = reflect.ValueOf(append([]byte{}, val.Bytes()...))
@ -148,13 +166,21 @@ func mergeAny(out, in reflect.Value) {
if out.IsNil() {
out.Set(reflect.New(in.Elem().Type()))
}
mergeAny(out.Elem(), in.Elem())
mergeAny(out.Elem(), in.Elem(), true, nil)
case reflect.Slice:
if in.IsNil() {
return
}
if in.Type().Elem().Kind() == reflect.Uint8 {
// []byte is a scalar bytes field, not a repeated field.
// Edge case: if this is in a proto3 message, a zero length
// bytes field is considered the zero value, and should not
// be merged.
if prop != nil && prop.proto3 && in.Len() == 0 {
return
}
// Make a deep copy.
// Append to []byte{} instead of []byte(nil) so that we never end up
// with a nil result.
@ -172,7 +198,7 @@ func mergeAny(out, in reflect.Value) {
default:
for i := 0; i < n; i++ {
x := reflect.Indirect(reflect.New(in.Type().Elem()))
mergeAny(x, in.Index(i))
mergeAny(x, in.Index(i), false, nil)
out.Set(reflect.Append(out, x))
}
}
@ -189,7 +215,7 @@ func mergeExtension(out, in map[int32]Extension) {
eOut := Extension{desc: eIn.desc}
if eIn.value != nil {
v := reflect.New(reflect.TypeOf(eIn.value)).Elem()
mergeAny(v, reflect.ValueOf(eIn.value))
mergeAny(v, reflect.ValueOf(eIn.value), false, nil)
eOut.value = v.Interface()
}
if eIn.enc != nil {

View File

@ -35,6 +35,8 @@ import (
"testing"
"github.com/coreos/etcd/Godeps/_workspace/src/github.com/gogo/protobuf/proto"
proto3pb "github.com/coreos/etcd/Godeps/_workspace/src/github.com/gogo/protobuf/proto/proto3_proto"
pb "github.com/coreos/etcd/Godeps/_workspace/src/github.com/gogo/protobuf/proto/testdata"
)
@ -213,6 +215,45 @@ var mergeTests = []struct {
ByteMapping: map[bool][]byte{true: []byte("wowsa")},
},
},
// proto3 shouldn't merge zero values,
// in the same way that proto2 shouldn't merge nils.
{
src: &proto3pb.Message{
Name: "Aaron",
Data: []byte(""), // zero value, but not nil
},
dst: &proto3pb.Message{
HeightInCm: 176,
Data: []byte("texas!"),
},
want: &proto3pb.Message{
Name: "Aaron",
HeightInCm: 176,
Data: []byte("texas!"),
},
},
// Oneof fields should merge by assignment.
{
src: &pb.Communique{
Union: &pb.Communique_Number{Number: 41},
},
dst: &pb.Communique{
Union: &pb.Communique_Name{Name: "Bobby Tables"},
},
want: &pb.Communique{
Union: &pb.Communique_Number{Number: 41},
},
},
// Oneof nil is the same as not set.
{
src: &pb.Communique{},
dst: &pb.Communique{
Union: &pb.Communique_Name{Name: "Bobby Tables"},
},
want: &pb.Communique{
Union: &pb.Communique_Name{Name: "Bobby Tables"},
},
},
}
func TestMerge(t *testing.T) {

View File

@ -46,6 +46,10 @@ import (
// errOverflow is returned when an integer is too large to be represented.
var errOverflow = errors.New("proto: integer overflow")
// ErrInternalBadWireType is returned by generated code when an incorrect
// wire type is encountered. It does not get returned to user code.
var ErrInternalBadWireType = errors.New("proto: internal error: bad wiretype for oneof")
// The fundamental decoders that interpret bytes on the wire.
// Those that take integer types all return uint64 and are
// therefore of type valueDecoder.
@ -314,6 +318,24 @@ func UnmarshalMerge(buf []byte, pb Message) error {
return NewBuffer(buf).Unmarshal(pb)
}
// DecodeMessage reads a count-delimited message from the Buffer.
func (p *Buffer) DecodeMessage(pb Message) error {
enc, err := p.DecodeRawBytes(false)
if err != nil {
return err
}
return NewBuffer(enc).Unmarshal(pb)
}
// DecodeGroup reads a tag-delimited group from the Buffer.
func (p *Buffer) DecodeGroup(pb Message) error {
typ, base, err := getbase(pb)
if err != nil {
return err
}
return p.unmarshalType(typ.Elem(), GetProperties(typ.Elem()), true, base)
}
// Unmarshal parses the protocol buffer representation in the
// Buffer and places the decoded result in pb. If the struct
// underlying pb does not match the data in the buffer, the results can be
@ -370,11 +392,11 @@ func (o *Buffer) unmarshalType(st reflect.Type, prop *StructProperties, is_group
if prop.extendable {
if e := structPointer_Interface(base, st).(extendableProto); isExtensionField(e, int32(tag)) {
if err = o.skip(st, tag, wire); err == nil {
if ee, ok := e.(extensionsMap); ok {
if ee, eok := e.(extensionsMap); eok {
ext := ee.ExtensionMap()[int32(tag)] // may be missing
ext.enc = append(ext.enc, o.buf[oi:o.index]...)
ee.ExtensionMap()[int32(tag)] = ext
} else if ee, ok := e.(extensionsBytes); ok {
} else if ee, eok := e.(extensionsBytes); eok {
ext := ee.GetExtensions()
*ext = append(*ext, o.buf[oi:o.index]...)
}
@ -382,6 +404,20 @@ func (o *Buffer) unmarshalType(st reflect.Type, prop *StructProperties, is_group
continue
}
}
// Maybe it's a oneof?
if prop.oneofUnmarshaler != nil {
m := structPointer_Interface(base, st).(Message)
// First return value indicates whether tag is a oneof field.
ok, err = prop.oneofUnmarshaler(m, tag, wire, o)
if err == ErrInternalBadWireType {
// Map the error to something more descriptive.
// Do the formatting here to save generated code space.
err = fmt.Errorf("bad wiretype for oneof field in %T", m)
}
if ok {
continue
}
}
err = o.skipAndSave(st, tag, wire, base, prop.unrecField)
continue
}
@ -566,9 +602,13 @@ func (o *Buffer) dec_slice_packed_bool(p *Properties, base structPointer) error
return err
}
nb := int(nn) // number of bytes of encoded bools
fin := o.index + nb
if fin < o.index {
return errOverflow
}
y := *v
for i := 0; i < nb; i++ {
for o.index < fin {
u, err := p.valDec(o)
if err != nil {
return err
@ -680,7 +720,7 @@ func (o *Buffer) dec_new_map(p *Properties, base structPointer) error {
oi := o.index // index at the end of this map entry
o.index -= len(raw) // move buffer back to start of map entry
mptr := structPointer_Map(base, p.field, p.mtype) // *map[K]V
mptr := structPointer_NewAt(base, p.field, p.mtype) // *map[K]V
if mptr.Elem().IsNil() {
mptr.Elem().Set(reflect.MakeMap(mptr.Type().Elem()))
}
@ -732,8 +772,14 @@ func (o *Buffer) dec_new_map(p *Properties, base structPointer) error {
return fmt.Errorf("proto: bad map data tag %d", raw[0])
}
}
keyelem, valelem := keyptr.Elem(), valptr.Elem()
if !keyelem.IsValid() || !valelem.IsValid() {
// We did not decode the key or the value in the map entry.
// Either way, it's an invalid map entry.
return fmt.Errorf("proto: bad map data: missing key/val")
}
v.SetMapIndex(keyptr.Elem(), valptr.Elem())
v.SetMapIndex(keyelem, valelem)
return nil
}

View File

@ -228,6 +228,20 @@ func Marshal(pb Message) ([]byte, error) {
return p.buf, err
}
// EncodeMessage writes the protocol buffer to the Buffer,
// prefixed by a varint-encoded length.
func (p *Buffer) EncodeMessage(pb Message) error {
t, base, err := getbase(pb)
if structPointer_IsNil(base) {
return ErrNil
}
if err == nil {
var state errorState
err = p.enc_len_struct(GetProperties(t.Elem()), base, &state)
}
return err
}
// Marshal takes the protocol buffer
// and encodes it into the wire format, writing the result to the
// Buffer.
@ -318,7 +332,7 @@ func size_bool(p *Properties, base structPointer) int {
func size_proto3_bool(p *Properties, base structPointer) int {
v := *structPointer_BoolVal(base, p.field)
if !v {
if !v && !p.oneof {
return 0
}
return len(p.tagcode) + 1 // each bool takes exactly one byte
@ -361,7 +375,7 @@ func size_int32(p *Properties, base structPointer) (n int) {
func size_proto3_int32(p *Properties, base structPointer) (n int) {
v := structPointer_Word32Val(base, p.field)
x := int32(word32Val_Get(v)) // permit sign extension to use full 64-bit range
if x == 0 {
if x == 0 && !p.oneof {
return 0
}
n += len(p.tagcode)
@ -407,7 +421,7 @@ func size_uint32(p *Properties, base structPointer) (n int) {
func size_proto3_uint32(p *Properties, base structPointer) (n int) {
v := structPointer_Word32Val(base, p.field)
x := word32Val_Get(v)
if x == 0 {
if x == 0 && !p.oneof {
return 0
}
n += len(p.tagcode)
@ -452,7 +466,7 @@ func size_int64(p *Properties, base structPointer) (n int) {
func size_proto3_int64(p *Properties, base structPointer) (n int) {
v := structPointer_Word64Val(base, p.field)
x := word64Val_Get(v)
if x == 0 {
if x == 0 && !p.oneof {
return 0
}
n += len(p.tagcode)
@ -495,7 +509,7 @@ func size_string(p *Properties, base structPointer) (n int) {
func size_proto3_string(p *Properties, base structPointer) (n int) {
v := *structPointer_StringVal(base, p.field)
if v == "" {
if v == "" && !p.oneof {
return 0
}
n += len(p.tagcode)
@ -529,7 +543,7 @@ func (o *Buffer) enc_struct_message(p *Properties, base structPointer) error {
}
o.buf = append(o.buf, p.tagcode...)
o.EncodeRawBytes(data)
return nil
return state.err
}
o.buf = append(o.buf, p.tagcode...)
@ -667,7 +681,7 @@ func (o *Buffer) enc_proto3_slice_byte(p *Properties, base structPointer) error
func size_slice_byte(p *Properties, base structPointer) (n int) {
s := *structPointer_Bytes(base, p.field)
if s == nil {
if s == nil && !p.oneof {
return 0
}
n += len(p.tagcode)
@ -677,7 +691,7 @@ func size_slice_byte(p *Properties, base structPointer) (n int) {
func size_proto3_slice_byte(p *Properties, base structPointer) (n int) {
s := *structPointer_Bytes(base, p.field)
if len(s) == 0 {
if len(s) == 0 && !p.oneof {
return 0
}
n += len(p.tagcode)
@ -1084,7 +1098,7 @@ func (o *Buffer) enc_new_map(p *Properties, base structPointer) error {
repeated MapFieldEntry map_field = N;
*/
v := structPointer_Map(base, p.field, p.mtype).Elem() // map[K]V
v := structPointer_NewAt(base, p.field, p.mtype).Elem() // map[K]V
if v.Len() == 0 {
return nil
}
@ -1101,11 +1115,15 @@ func (o *Buffer) enc_new_map(p *Properties, base structPointer) error {
return nil
}
keys := v.MapKeys()
sort.Sort(mapKeys(keys))
for _, key := range keys {
// Don't sort map keys. It is not required by the spec, and C++ doesn't do it.
for _, key := range v.MapKeys() {
val := v.MapIndex(key)
// The only illegal map entry values are nil message pointers.
if val.Kind() == reflect.Ptr && val.IsNil() {
return errors.New("proto: map has nil element")
}
keycopy.Set(key)
valcopy.Set(val)
@ -1118,7 +1136,7 @@ func (o *Buffer) enc_new_map(p *Properties, base structPointer) error {
}
func size_new_map(p *Properties, base structPointer) int {
v := structPointer_Map(base, p.field, p.mtype).Elem() // map[K]V
v := structPointer_NewAt(base, p.field, p.mtype).Elem() // map[K]V
keycopy, valcopy, keybase, valbase := mapEncodeScratch(p.mtype)
@ -1196,6 +1214,14 @@ func (o *Buffer) enc_struct(prop *StructProperties, base structPointer) error {
}
}
// Do oneof fields.
if prop.oneofMarshaler != nil {
m := structPointer_Interface(base, prop.stype).(Message)
if err := prop.oneofMarshaler(m, o); err != nil {
return err
}
}
// Add unrecognized fields at the end.
if prop.unrecField.IsValid() {
v := *structPointer_Bytes(base, prop.unrecField)
@ -1221,6 +1247,27 @@ func size_struct(prop *StructProperties, base structPointer) (n int) {
n += len(v)
}
// Factor in any oneof fields.
// TODO: This could be faster and use less reflection.
if prop.oneofMarshaler != nil {
sv := reflect.ValueOf(structPointer_Interface(base, prop.stype)).Elem()
for i := 0; i < prop.stype.NumField(); i++ {
fv := sv.Field(i)
if fv.Kind() != reflect.Interface || fv.IsNil() {
continue
}
if prop.stype.Field(i).Tag.Get("protobuf_oneof") == "" {
continue
}
spv := fv.Elem() // interface -> *T
sv := spv.Elem() // *T -> T
sf := sv.Type().Field(0) // StructField inside T
var prop Properties
prop.Init(sf.Type, "whatever", sf.Tag.Get("protobuf"), &sf)
n += prop.size(&prop, toStructPointer(spv))
}
}
return
}

View File

@ -30,7 +30,6 @@
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// Protocol buffer comparison.
// TODO: MessageSet.
package proto
@ -154,6 +153,17 @@ func equalAny(v1, v2 reflect.Value) bool {
return v1.Float() == v2.Float()
case reflect.Int32, reflect.Int64:
return v1.Int() == v2.Int()
case reflect.Interface:
// Probably a oneof field; compare the inner values.
n1, n2 := v1.IsNil(), v2.IsNil()
if n1 || n2 {
return n1 == n2
}
e1, e2 := v1.Elem(), v2.Elem()
if e1.Type() != e2.Type() {
return false
}
return equalAny(e1, e2)
case reflect.Map:
if v1.Len() != v2.Len() {
return false

View File

@ -180,6 +180,24 @@ var EqualTests = []struct {
&pb.MessageWithMap{NameMapping: map[int32]string{1: "Rob"}},
false,
},
{
"oneof same",
&pb.Communique{Union: &pb.Communique_Number{Number: 41}},
&pb.Communique{Union: &pb.Communique_Number{Number: 41}},
true,
},
{
"oneof one nil",
&pb.Communique{Union: &pb.Communique_Number{Number: 41}},
&pb.Communique{},
false,
},
{
"oneof different",
&pb.Communique{Union: &pb.Communique_Number{Number: 41}},
&pb.Communique{Union: &pb.Communique_Name{Name: "Bobby Tables"}},
false,
},
}
func TestEqual(t *testing.T) {

View File

@ -311,7 +311,9 @@ func GetExtension(pb extendableProto, extension *ExtensionDesc) (interface{}, er
emap := epb.ExtensionMap()
e, ok := emap[extension.Field]
if !ok {
return nil, ErrMissingExtension
// defaultExtensionValue returns the default value or
// ErrMissingExtension if there is no default.
return defaultExtensionValue(extension)
}
if e.value != nil {
// Already decoded. Check the descriptor, though.
@ -356,10 +358,46 @@ func GetExtension(pb extendableProto, extension *ExtensionDesc) (interface{}, er
}
o += n + l
}
return defaultExtensionValue(extension)
}
panic("unreachable")
}
// defaultExtensionValue returns the default value for extension.
// If no default for an extension is defined ErrMissingExtension is returned.
func defaultExtensionValue(extension *ExtensionDesc) (interface{}, error) {
t := reflect.TypeOf(extension.ExtensionType)
props := extensionProperties(extension)
sf, _, err := fieldDefault(t, props)
if err != nil {
return nil, err
}
if sf == nil || sf.value == nil {
// There is no default value.
return nil, ErrMissingExtension
}
if t.Kind() != reflect.Ptr {
// We do not need to return a Ptr, we can directly return sf.value.
return sf.value, nil
}
// We need to return an interface{} that is a pointer to sf.value.
value := reflect.New(t).Elem()
value.Set(reflect.New(value.Type().Elem()))
if sf.kind == reflect.Int32 {
// We may have an int32 or an enum, but the underlying data is int32.
// Since we can't set an int32 into a non int32 reflect.value directly
// set it as a int32.
value.Elem().SetInt(int64(sf.value.(int32)))
} else {
value.Elem().Set(reflect.ValueOf(sf.value))
}
return value.Interface(), nil
}
// decodeExtension decodes an extension encoded in b.
func decodeExtension(b []byte, extension *ExtensionDesc) (interface{}, error) {
o := NewBuffer(b)

View File

@ -32,6 +32,8 @@
package proto_test
import (
"fmt"
"reflect"
"testing"
"github.com/coreos/etcd/Godeps/_workspace/src/github.com/gogo/protobuf/proto"
@ -93,6 +95,143 @@ func TestGetExtensionStability(t *testing.T) {
}
}
func TestGetExtensionDefaults(t *testing.T) {
var setFloat64 float64 = 1
var setFloat32 float32 = 2
var setInt32 int32 = 3
var setInt64 int64 = 4
var setUint32 uint32 = 5
var setUint64 uint64 = 6
var setBool = true
var setBool2 = false
var setString = "Goodnight string"
var setBytes = []byte("Goodnight bytes")
var setEnum = pb.DefaultsMessage_TWO
type testcase struct {
ext *proto.ExtensionDesc // Extension we are testing.
want interface{} // Expected value of extension, or nil (meaning that GetExtension will fail).
def interface{} // Expected value of extension after ClearExtension().
}
tests := []testcase{
{pb.E_NoDefaultDouble, setFloat64, nil},
{pb.E_NoDefaultFloat, setFloat32, nil},
{pb.E_NoDefaultInt32, setInt32, nil},
{pb.E_NoDefaultInt64, setInt64, nil},
{pb.E_NoDefaultUint32, setUint32, nil},
{pb.E_NoDefaultUint64, setUint64, nil},
{pb.E_NoDefaultSint32, setInt32, nil},
{pb.E_NoDefaultSint64, setInt64, nil},
{pb.E_NoDefaultFixed32, setUint32, nil},
{pb.E_NoDefaultFixed64, setUint64, nil},
{pb.E_NoDefaultSfixed32, setInt32, nil},
{pb.E_NoDefaultSfixed64, setInt64, nil},
{pb.E_NoDefaultBool, setBool, nil},
{pb.E_NoDefaultBool, setBool2, nil},
{pb.E_NoDefaultString, setString, nil},
{pb.E_NoDefaultBytes, setBytes, nil},
{pb.E_NoDefaultEnum, setEnum, nil},
{pb.E_DefaultDouble, setFloat64, float64(3.1415)},
{pb.E_DefaultFloat, setFloat32, float32(3.14)},
{pb.E_DefaultInt32, setInt32, int32(42)},
{pb.E_DefaultInt64, setInt64, int64(43)},
{pb.E_DefaultUint32, setUint32, uint32(44)},
{pb.E_DefaultUint64, setUint64, uint64(45)},
{pb.E_DefaultSint32, setInt32, int32(46)},
{pb.E_DefaultSint64, setInt64, int64(47)},
{pb.E_DefaultFixed32, setUint32, uint32(48)},
{pb.E_DefaultFixed64, setUint64, uint64(49)},
{pb.E_DefaultSfixed32, setInt32, int32(50)},
{pb.E_DefaultSfixed64, setInt64, int64(51)},
{pb.E_DefaultBool, setBool, true},
{pb.E_DefaultBool, setBool2, true},
{pb.E_DefaultString, setString, "Hello, string"},
{pb.E_DefaultBytes, setBytes, []byte("Hello, bytes")},
{pb.E_DefaultEnum, setEnum, pb.DefaultsMessage_ONE},
}
checkVal := func(test testcase, msg *pb.DefaultsMessage, valWant interface{}) error {
val, err := proto.GetExtension(msg, test.ext)
if err != nil {
if valWant != nil {
return fmt.Errorf("GetExtension(): %s", err)
}
if want := proto.ErrMissingExtension; err != want {
return fmt.Errorf("Unexpected error: got %v, want %v", err, want)
}
return nil
}
// All proto2 extension values are either a pointer to a value or a slice of values.
ty := reflect.TypeOf(val)
tyWant := reflect.TypeOf(test.ext.ExtensionType)
if got, want := ty, tyWant; got != want {
return fmt.Errorf("unexpected reflect.TypeOf(): got %v want %v", got, want)
}
tye := ty.Elem()
tyeWant := tyWant.Elem()
if got, want := tye, tyeWant; got != want {
return fmt.Errorf("unexpected reflect.TypeOf().Elem(): got %v want %v", got, want)
}
// Check the name of the type of the value.
// If it is an enum it will be type int32 with the name of the enum.
if got, want := tye.Name(), tye.Name(); got != want {
return fmt.Errorf("unexpected reflect.TypeOf().Elem().Name(): got %v want %v", got, want)
}
// Check that value is what we expect.
// If we have a pointer in val, get the value it points to.
valExp := val
if ty.Kind() == reflect.Ptr {
valExp = reflect.ValueOf(val).Elem().Interface()
}
if got, want := valExp, valWant; !reflect.DeepEqual(got, want) {
return fmt.Errorf("unexpected reflect.DeepEqual(): got %v want %v", got, want)
}
return nil
}
setTo := func(test testcase) interface{} {
setTo := reflect.ValueOf(test.want)
if typ := reflect.TypeOf(test.ext.ExtensionType); typ.Kind() == reflect.Ptr {
setTo = reflect.New(typ).Elem()
setTo.Set(reflect.New(setTo.Type().Elem()))
setTo.Elem().Set(reflect.ValueOf(test.want))
}
return setTo.Interface()
}
for _, test := range tests {
msg := &pb.DefaultsMessage{}
name := test.ext.Name
// Check the initial value.
if err := checkVal(test, msg, test.def); err != nil {
t.Errorf("%s: %v", name, err)
}
// Set the per-type value and check value.
name = fmt.Sprintf("%s (set to %T %v)", name, test.want, test.want)
if err := proto.SetExtension(msg, test.ext, setTo(test)); err != nil {
t.Errorf("%s: SetExtension(): %v", name, err)
continue
}
if err := checkVal(test, msg, test.want); err != nil {
t.Errorf("%s: %v", name, err)
continue
}
// Set and check the value.
name += " (cleared)"
proto.ClearExtension(msg, test.ext)
if err := checkVal(test, msg, test.def); err != nil {
t.Errorf("%s: %v", name, err)
}
}
}
func TestExtensionsRoundTrip(t *testing.T) {
msg := &pb.MyMessage{}
ext1 := &pb.Ext{

View File

@ -30,179 +30,230 @@
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
/*
Package proto converts data structures to and from the wire format of
protocol buffers. It works in concert with the Go source code generated
for .proto files by the protocol compiler.
Package proto converts data structures to and from the wire format of
protocol buffers. It works in concert with the Go source code generated
for .proto files by the protocol compiler.
A summary of the properties of the protocol buffer interface
for a protocol buffer variable v:
A summary of the properties of the protocol buffer interface
for a protocol buffer variable v:
- Names are turned from camel_case to CamelCase for export.
- There are no methods on v to set fields; just treat
them as structure fields.
- There are getters that return a field's value if set,
and return the field's default value if unset.
The getters work even if the receiver is a nil message.
- The zero value for a struct is its correct initialization state.
All desired fields must be set before marshaling.
- A Reset() method will restore a protobuf struct to its zero state.
- Non-repeated fields are pointers to the values; nil means unset.
That is, optional or required field int32 f becomes F *int32.
- Repeated fields are slices.
- Helper functions are available to aid the setting of fields.
msg.Foo = proto.String("hello") // set field
- Constants are defined to hold the default values of all fields that
have them. They have the form Default_StructName_FieldName.
Because the getter methods handle defaulted values,
direct use of these constants should be rare.
- Enums are given type names and maps from names to values.
Enum values are prefixed by the enclosing message's name, or by the
enum's type name if it is a top-level enum. Enum types have a String
method, and a Enum method to assist in message construction.
- Nested messages, groups and enums have type names prefixed with the name of
the surrounding message type.
- Extensions are given descriptor names that start with E_,
followed by an underscore-delimited list of the nested messages
that contain it (if any) followed by the CamelCased name of the
extension field itself. HasExtension, ClearExtension, GetExtension
and SetExtension are functions for manipulating extensions.
- Marshal and Unmarshal are functions to encode and decode the wire format.
- Names are turned from camel_case to CamelCase for export.
- There are no methods on v to set fields; just treat
them as structure fields.
- There are getters that return a field's value if set,
and return the field's default value if unset.
The getters work even if the receiver is a nil message.
- The zero value for a struct is its correct initialization state.
All desired fields must be set before marshaling.
- A Reset() method will restore a protobuf struct to its zero state.
- Non-repeated fields are pointers to the values; nil means unset.
That is, optional or required field int32 f becomes F *int32.
- Repeated fields are slices.
- Helper functions are available to aid the setting of fields.
msg.Foo = proto.String("hello") // set field
- Constants are defined to hold the default values of all fields that
have them. They have the form Default_StructName_FieldName.
Because the getter methods handle defaulted values,
direct use of these constants should be rare.
- Enums are given type names and maps from names to values.
Enum values are prefixed by the enclosing message's name, or by the
enum's type name if it is a top-level enum. Enum types have a String
method, and a Enum method to assist in message construction.
- Nested messages, groups and enums have type names prefixed with the name of
the surrounding message type.
- Extensions are given descriptor names that start with E_,
followed by an underscore-delimited list of the nested messages
that contain it (if any) followed by the CamelCased name of the
extension field itself. HasExtension, ClearExtension, GetExtension
and SetExtension are functions for manipulating extensions.
- Oneof field sets are given a single field in their message,
with distinguished wrapper types for each possible field value.
- Marshal and Unmarshal are functions to encode and decode the wire format.
The simplest way to describe this is to see an example.
Given file test.proto, containing
The simplest way to describe this is to see an example.
Given file test.proto, containing
package example;
package example;
enum FOO { X = 17; }
enum FOO { X = 17; }
message Test {
required string label = 1;
optional int32 type = 2 [default=77];
repeated int64 reps = 3;
optional group OptionalGroup = 4 {
required string RequiredField = 5;
}
message Test {
required string label = 1;
optional int32 type = 2 [default=77];
repeated int64 reps = 3;
optional group OptionalGroup = 4 {
required string RequiredField = 5;
}
oneof union {
int32 number = 6;
string name = 7;
}
}
The resulting file, test.pb.go, is:
package example
import proto "github.com/gogo/protobuf/proto"
import math "math"
type FOO int32
const (
FOO_X FOO = 17
)
var FOO_name = map[int32]string{
17: "X",
}
var FOO_value = map[string]int32{
"X": 17,
}
func (x FOO) Enum() *FOO {
p := new(FOO)
*p = x
return p
}
func (x FOO) String() string {
return proto.EnumName(FOO_name, int32(x))
}
func (x *FOO) UnmarshalJSON(data []byte) error {
value, err := proto.UnmarshalJSONEnum(FOO_value, data)
if err != nil {
return err
}
*x = FOO(value)
return nil
}
The resulting file, test.pb.go, is:
type Test struct {
Label *string `protobuf:"bytes,1,req,name=label" json:"label,omitempty"`
Type *int32 `protobuf:"varint,2,opt,name=type,def=77" json:"type,omitempty"`
Reps []int64 `protobuf:"varint,3,rep,name=reps" json:"reps,omitempty"`
Optionalgroup *Test_OptionalGroup `protobuf:"group,4,opt,name=OptionalGroup" json:"optionalgroup,omitempty"`
// Types that are valid to be assigned to Union:
// *Test_Number
// *Test_Name
Union isTest_Union `protobuf_oneof:"union"`
XXX_unrecognized []byte `json:"-"`
}
func (m *Test) Reset() { *m = Test{} }
func (m *Test) String() string { return proto.CompactTextString(m) }
func (*Test) ProtoMessage() {}
package example
type isTest_Union interface {
isTest_Union()
}
import proto "github.com/gogo/protobuf/proto"
import math "math"
type Test_Number struct {
Number int32 `protobuf:"varint,6,opt,name=number"`
}
type Test_Name struct {
Name string `protobuf:"bytes,7,opt,name=name"`
}
type FOO int32
const (
FOO_X FOO = 17
)
var FOO_name = map[int32]string{
17: "X",
func (*Test_Number) isTest_Union() {}
func (*Test_Name) isTest_Union() {}
func (m *Test) GetUnion() isTest_Union {
if m != nil {
return m.Union
}
var FOO_value = map[string]int32{
"X": 17,
return nil
}
const Default_Test_Type int32 = 77
func (m *Test) GetLabel() string {
if m != nil && m.Label != nil {
return *m.Label
}
return ""
}
func (x FOO) Enum() *FOO {
p := new(FOO)
*p = x
return p
func (m *Test) GetType() int32 {
if m != nil && m.Type != nil {
return *m.Type
}
func (x FOO) String() string {
return proto.EnumName(FOO_name, int32(x))
return Default_Test_Type
}
func (m *Test) GetOptionalgroup() *Test_OptionalGroup {
if m != nil {
return m.Optionalgroup
}
func (x *FOO) UnmarshalJSON(data []byte) error {
value, err := proto.UnmarshalJSONEnum(FOO_value, data)
if err != nil {
return err
}
*x = FOO(value)
return nil
return nil
}
type Test_OptionalGroup struct {
RequiredField *string `protobuf:"bytes,5,req" json:"RequiredField,omitempty"`
}
func (m *Test_OptionalGroup) Reset() { *m = Test_OptionalGroup{} }
func (m *Test_OptionalGroup) String() string { return proto.CompactTextString(m) }
func (m *Test_OptionalGroup) GetRequiredField() string {
if m != nil && m.RequiredField != nil {
return *m.RequiredField
}
return ""
}
type Test struct {
Label *string `protobuf:"bytes,1,req,name=label" json:"label,omitempty"`
Type *int32 `protobuf:"varint,2,opt,name=type,def=77" json:"type,omitempty"`
Reps []int64 `protobuf:"varint,3,rep,name=reps" json:"reps,omitempty"`
Optionalgroup *Test_OptionalGroup `protobuf:"group,4,opt,name=OptionalGroup" json:"optionalgroup,omitempty"`
XXX_unrecognized []byte `json:"-"`
func (m *Test) GetNumber() int32 {
if x, ok := m.GetUnion().(*Test_Number); ok {
return x.Number
}
func (m *Test) Reset() { *m = Test{} }
func (m *Test) String() string { return proto.CompactTextString(m) }
func (*Test) ProtoMessage() {}
const Default_Test_Type int32 = 77
return 0
}
func (m *Test) GetLabel() string {
if m != nil && m.Label != nil {
return *m.Label
}
return ""
func (m *Test) GetName() string {
if x, ok := m.GetUnion().(*Test_Name); ok {
return x.Name
}
return ""
}
func (m *Test) GetType() int32 {
if m != nil && m.Type != nil {
return *m.Type
}
return Default_Test_Type
func init() {
proto.RegisterEnum("example.FOO", FOO_name, FOO_value)
}
To create and play with a Test object:
package main
import (
"log"
"github.com/gogo/protobuf/proto"
pb "./example.pb"
)
func main() {
test := &pb.Test{
Label: proto.String("hello"),
Type: proto.Int32(17),
Optionalgroup: &pb.Test_OptionalGroup{
RequiredField: proto.String("good bye"),
},
Union: &pb.Test_Name{"fred"},
}
func (m *Test) GetOptionalgroup() *Test_OptionalGroup {
if m != nil {
return m.Optionalgroup
}
return nil
data, err := proto.Marshal(test)
if err != nil {
log.Fatal("marshaling error: ", err)
}
type Test_OptionalGroup struct {
RequiredField *string `protobuf:"bytes,5,req" json:"RequiredField,omitempty"`
newTest := &pb.Test{}
err = proto.Unmarshal(data, newTest)
if err != nil {
log.Fatal("unmarshaling error: ", err)
}
func (m *Test_OptionalGroup) Reset() { *m = Test_OptionalGroup{} }
func (m *Test_OptionalGroup) String() string { return proto.CompactTextString(m) }
func (m *Test_OptionalGroup) GetRequiredField() string {
if m != nil && m.RequiredField != nil {
return *m.RequiredField
}
return ""
// Now test and newTest contain the same data.
if test.GetLabel() != newTest.GetLabel() {
log.Fatalf("data mismatch %q != %q", test.GetLabel(), newTest.GetLabel())
}
func init() {
proto.RegisterEnum("example.FOO", FOO_name, FOO_value)
}
To create and play with a Test object:
package main
import (
"log"
"github.com/gogo/protobuf/proto"
pb "./example.pb"
)
func main() {
test := &pb.Test{
Label: proto.String("hello"),
Type: proto.Int32(17),
Optionalgroup: &pb.Test_OptionalGroup{
RequiredField: proto.String("good bye"),
},
}
data, err := proto.Marshal(test)
if err != nil {
log.Fatal("marshaling error: ", err)
}
newTest := &pb.Test{}
err = proto.Unmarshal(data, newTest)
if err != nil {
log.Fatal("unmarshaling error: ", err)
}
// Now test and newTest contain the same data.
if test.GetLabel() != newTest.GetLabel() {
log.Fatalf("data mismatch %q != %q", test.GetLabel(), newTest.GetLabel())
}
// etc.
// Use a type switch to determine which oneof was set.
switch u := test.Union.(type) {
case *pb.Test_Number: // u.Number contains the number.
case *pb.Test_Name: // u.Name contains the string.
}
// etc.
}
*/
package proto
@ -211,6 +262,7 @@ import (
"fmt"
"log"
"reflect"
"sort"
"strconv"
"sync"
)
@ -385,13 +437,13 @@ func UnmarshalJSONEnum(m map[string]int32, data []byte, enumName string) (int32,
// DebugPrint dumps the encoded data in b in a debugging format with a header
// including the string s. Used in testing but made available for general debugging.
func (o *Buffer) DebugPrint(s string, b []byte) {
func (p *Buffer) DebugPrint(s string, b []byte) {
var u uint64
obuf := o.buf
index := o.index
o.buf = b
o.index = 0
obuf := p.buf
index := p.index
p.buf = b
p.index = 0
depth := 0
fmt.Printf("\n--- %s ---\n", s)
@ -402,12 +454,12 @@ out:
fmt.Print(" ")
}
index := o.index
if index == len(o.buf) {
index := p.index
if index == len(p.buf) {
break
}
op, err := o.DecodeVarint()
op, err := p.DecodeVarint()
if err != nil {
fmt.Printf("%3d: fetching op err %v\n", index, err)
break out
@ -424,7 +476,7 @@ out:
case WireBytes:
var r []byte
r, err = o.DecodeRawBytes(false)
r, err = p.DecodeRawBytes(false)
if err != nil {
break out
}
@ -445,7 +497,7 @@ out:
fmt.Printf("\n")
case WireFixed32:
u, err = o.DecodeFixed32()
u, err = p.DecodeFixed32()
if err != nil {
fmt.Printf("%3d: t=%3d fix32 err %v\n", index, tag, err)
break out
@ -453,16 +505,15 @@ out:
fmt.Printf("%3d: t=%3d fix32 %d\n", index, tag, u)
case WireFixed64:
u, err = o.DecodeFixed64()
u, err = p.DecodeFixed64()
if err != nil {
fmt.Printf("%3d: t=%3d fix64 err %v\n", index, tag, err)
break out
}
fmt.Printf("%3d: t=%3d fix64 %d\n", index, tag, u)
break
case WireVarint:
u, err = o.DecodeVarint()
u, err = p.DecodeVarint()
if err != nil {
fmt.Printf("%3d: t=%3d varint err %v\n", index, tag, err)
break out
@ -470,30 +521,22 @@ out:
fmt.Printf("%3d: t=%3d varint %d\n", index, tag, u)
case WireStartGroup:
if err != nil {
fmt.Printf("%3d: t=%3d start err %v\n", index, tag, err)
break out
}
fmt.Printf("%3d: t=%3d start\n", index, tag)
depth++
case WireEndGroup:
depth--
if err != nil {
fmt.Printf("%3d: t=%3d end err %v\n", index, tag, err)
break out
}
fmt.Printf("%3d: t=%3d end\n", index, tag)
}
}
if depth != 0 {
fmt.Printf("%3d: start-end not balanced %d\n", o.index, depth)
fmt.Printf("%3d: start-end not balanced %d\n", p.index, depth)
}
fmt.Printf("\n")
o.buf = obuf
o.index = index
p.buf = obuf
p.index = index
}
// SetDefaults sets unset protocol buffer fields to their default values.
@ -668,123 +711,173 @@ func buildDefaultMessage(t reflect.Type) (dm defaultMessage) {
}
ft := t.Field(fi).Type
var canHaveDefault, nestedMessage bool
switch ft.Kind() {
case reflect.Ptr:
if ft.Elem().Kind() == reflect.Struct {
nestedMessage = true
} else {
canHaveDefault = true // proto2 scalar field
}
case reflect.Slice:
switch ft.Elem().Kind() {
case reflect.Ptr:
nestedMessage = true // repeated message
case reflect.Uint8:
canHaveDefault = true // bytes field
}
case reflect.Map:
if ft.Elem().Kind() == reflect.Ptr {
nestedMessage = true // map with message values
}
sf, nested, err := fieldDefault(ft, prop)
switch {
case err != nil:
log.Print(err)
case nested:
dm.nested = append(dm.nested, fi)
case sf != nil:
sf.index = fi
dm.scalars = append(dm.scalars, *sf)
}
if !canHaveDefault {
if nestedMessage {
dm.nested = append(dm.nested, fi)
}
continue
}
sf := scalarField{
index: fi,
kind: ft.Elem().Kind(),
}
// scalar fields without defaults
if !prop.HasDefault {
dm.scalars = append(dm.scalars, sf)
continue
}
// a scalar field: either *T or []byte
switch ft.Elem().Kind() {
case reflect.Bool:
x, err := strconv.ParseBool(prop.Default)
if err != nil {
log.Printf("proto: bad default bool %q: %v", prop.Default, err)
continue
}
sf.value = x
case reflect.Float32:
x, err := strconv.ParseFloat(prop.Default, 32)
if err != nil {
log.Printf("proto: bad default float32 %q: %v", prop.Default, err)
continue
}
sf.value = float32(x)
case reflect.Float64:
x, err := strconv.ParseFloat(prop.Default, 64)
if err != nil {
log.Printf("proto: bad default float64 %q: %v", prop.Default, err)
continue
}
sf.value = x
case reflect.Int32:
x, err := strconv.ParseInt(prop.Default, 10, 32)
if err != nil {
log.Printf("proto: bad default int32 %q: %v", prop.Default, err)
continue
}
sf.value = int32(x)
case reflect.Int64:
x, err := strconv.ParseInt(prop.Default, 10, 64)
if err != nil {
log.Printf("proto: bad default int64 %q: %v", prop.Default, err)
continue
}
sf.value = x
case reflect.String:
sf.value = prop.Default
case reflect.Uint8:
// []byte (not *uint8)
sf.value = []byte(prop.Default)
case reflect.Uint32:
x, err := strconv.ParseUint(prop.Default, 10, 32)
if err != nil {
log.Printf("proto: bad default uint32 %q: %v", prop.Default, err)
continue
}
sf.value = uint32(x)
case reflect.Uint64:
x, err := strconv.ParseUint(prop.Default, 10, 64)
if err != nil {
log.Printf("proto: bad default uint64 %q: %v", prop.Default, err)
continue
}
sf.value = x
default:
log.Printf("proto: unhandled def kind %v", ft.Elem().Kind())
continue
}
dm.scalars = append(dm.scalars, sf)
}
return dm
}
// fieldDefault returns the scalarField for field type ft.
// sf will be nil if the field can not have a default.
// nestedMessage will be true if this is a nested message.
// Note that sf.index is not set on return.
func fieldDefault(ft reflect.Type, prop *Properties) (sf *scalarField, nestedMessage bool, err error) {
var canHaveDefault bool
switch ft.Kind() {
case reflect.Ptr:
if ft.Elem().Kind() == reflect.Struct {
nestedMessage = true
} else {
canHaveDefault = true // proto2 scalar field
}
case reflect.Slice:
switch ft.Elem().Kind() {
case reflect.Ptr:
nestedMessage = true // repeated message
case reflect.Uint8:
canHaveDefault = true // bytes field
}
case reflect.Map:
if ft.Elem().Kind() == reflect.Ptr {
nestedMessage = true // map with message values
}
}
if !canHaveDefault {
if nestedMessage {
return nil, true, nil
}
return nil, false, nil
}
// We now know that ft is a pointer or slice.
sf = &scalarField{kind: ft.Elem().Kind()}
// scalar fields without defaults
if !prop.HasDefault {
return sf, false, nil
}
// a scalar field: either *T or []byte
switch ft.Elem().Kind() {
case reflect.Bool:
x, err := strconv.ParseBool(prop.Default)
if err != nil {
return nil, false, fmt.Errorf("proto: bad default bool %q: %v", prop.Default, err)
}
sf.value = x
case reflect.Float32:
x, err := strconv.ParseFloat(prop.Default, 32)
if err != nil {
return nil, false, fmt.Errorf("proto: bad default float32 %q: %v", prop.Default, err)
}
sf.value = float32(x)
case reflect.Float64:
x, err := strconv.ParseFloat(prop.Default, 64)
if err != nil {
return nil, false, fmt.Errorf("proto: bad default float64 %q: %v", prop.Default, err)
}
sf.value = x
case reflect.Int32:
x, err := strconv.ParseInt(prop.Default, 10, 32)
if err != nil {
return nil, false, fmt.Errorf("proto: bad default int32 %q: %v", prop.Default, err)
}
sf.value = int32(x)
case reflect.Int64:
x, err := strconv.ParseInt(prop.Default, 10, 64)
if err != nil {
return nil, false, fmt.Errorf("proto: bad default int64 %q: %v", prop.Default, err)
}
sf.value = x
case reflect.String:
sf.value = prop.Default
case reflect.Uint8:
// []byte (not *uint8)
sf.value = []byte(prop.Default)
case reflect.Uint32:
x, err := strconv.ParseUint(prop.Default, 10, 32)
if err != nil {
return nil, false, fmt.Errorf("proto: bad default uint32 %q: %v", prop.Default, err)
}
sf.value = uint32(x)
case reflect.Uint64:
x, err := strconv.ParseUint(prop.Default, 10, 64)
if err != nil {
return nil, false, fmt.Errorf("proto: bad default uint64 %q: %v", prop.Default, err)
}
sf.value = x
default:
return nil, false, fmt.Errorf("proto: unhandled def kind %v", ft.Elem().Kind())
}
return sf, false, nil
}
// Map fields may have key types of non-float scalars, strings and enums.
// The easiest way to sort them in some deterministic order is to use fmt.
// If this turns out to be inefficient we can always consider other options,
// such as doing a Schwartzian transform.
type mapKeys []reflect.Value
func mapKeys(vs []reflect.Value) sort.Interface {
s := mapKeySorter{
vs: vs,
// default Less function: textual comparison
less: func(a, b reflect.Value) bool {
return fmt.Sprint(a.Interface()) < fmt.Sprint(b.Interface())
},
}
func (s mapKeys) Len() int { return len(s) }
func (s mapKeys) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
func (s mapKeys) Less(i, j int) bool {
return fmt.Sprint(s[i].Interface()) < fmt.Sprint(s[j].Interface())
// Type specialization per https://developers.google.com/protocol-buffers/docs/proto#maps;
// numeric keys are sorted numerically.
if len(vs) == 0 {
return s
}
switch vs[0].Kind() {
case reflect.Int32, reflect.Int64:
s.less = func(a, b reflect.Value) bool { return a.Int() < b.Int() }
case reflect.Uint32, reflect.Uint64:
s.less = func(a, b reflect.Value) bool { return a.Uint() < b.Uint() }
}
return s
}
type mapKeySorter struct {
vs []reflect.Value
less func(a, b reflect.Value) bool
}
func (s mapKeySorter) Len() int { return len(s.vs) }
func (s mapKeySorter) Swap(i, j int) { s.vs[i], s.vs[j] = s.vs[j], s.vs[i] }
func (s mapKeySorter) Less(i, j int) bool {
return s.less(s.vs[i], s.vs[j])
}
// isProto3Zero reports whether v is a zero proto3 value.
func isProto3Zero(v reflect.Value) bool {
switch v.Kind() {
case reflect.Bool:
return !v.Bool()
case reflect.Int32, reflect.Int64:
return v.Int() == 0
case reflect.Uint32, reflect.Uint64:
return v.Uint() == 0
case reflect.Float32, reflect.Float64:
return v.Float() == 0
case reflect.String:
return v.String() == ""
}
return false
}

View File

@ -44,11 +44,11 @@ import (
"sort"
)
// ErrNoMessageTypeId occurs when a protocol buffer does not have a message type ID.
// errNoMessageTypeID occurs when a protocol buffer does not have a message type ID.
// A message type ID is required for storing a protocol buffer in a message set.
var ErrNoMessageTypeId = errors.New("proto does not have a message type ID")
var errNoMessageTypeID = errors.New("proto does not have a message type ID")
// The first two types (_MessageSet_Item and MessageSet)
// The first two types (_MessageSet_Item and messageSet)
// model what the protocol compiler produces for the following protocol message:
// message MessageSet {
// repeated group Item = 1 {
@ -58,27 +58,20 @@ var ErrNoMessageTypeId = errors.New("proto does not have a message type ID")
// }
// That is the MessageSet wire format. We can't use a proto to generate these
// because that would introduce a circular dependency between it and this package.
//
// When a proto1 proto has a field that looks like:
// optional message<MessageSet> info = 3;
// the protocol compiler produces a field in the generated struct that looks like:
// Info *_proto_.MessageSet `protobuf:"bytes,3,opt,name=info"`
// The package is automatically inserted so there is no need for that proto file to
// import this package.
type _MessageSet_Item struct {
TypeId *int32 `protobuf:"varint,2,req,name=type_id"`
Message []byte `protobuf:"bytes,3,req,name=message"`
}
type MessageSet struct {
type messageSet struct {
Item []*_MessageSet_Item `protobuf:"group,1,rep"`
XXX_unrecognized []byte
// TODO: caching?
}
// Make sure MessageSet is a Message.
var _ Message = (*MessageSet)(nil)
// Make sure messageSet is a Message.
var _ Message = (*messageSet)(nil)
// messageTypeIder is an interface satisfied by a protocol buffer type
// that may be stored in a MessageSet.
@ -86,7 +79,7 @@ type messageTypeIder interface {
MessageTypeId() int32
}
func (ms *MessageSet) find(pb Message) *_MessageSet_Item {
func (ms *messageSet) find(pb Message) *_MessageSet_Item {
mti, ok := pb.(messageTypeIder)
if !ok {
return nil
@ -100,24 +93,24 @@ func (ms *MessageSet) find(pb Message) *_MessageSet_Item {
return nil
}
func (ms *MessageSet) Has(pb Message) bool {
func (ms *messageSet) Has(pb Message) bool {
if ms.find(pb) != nil {
return true
}
return false
}
func (ms *MessageSet) Unmarshal(pb Message) error {
func (ms *messageSet) Unmarshal(pb Message) error {
if item := ms.find(pb); item != nil {
return Unmarshal(item.Message, pb)
}
if _, ok := pb.(messageTypeIder); !ok {
return ErrNoMessageTypeId
return errNoMessageTypeID
}
return nil // TODO: return error instead?
}
func (ms *MessageSet) Marshal(pb Message) error {
func (ms *messageSet) Marshal(pb Message) error {
msg, err := Marshal(pb)
if err != nil {
return err
@ -130,7 +123,7 @@ func (ms *MessageSet) Marshal(pb Message) error {
mti, ok := pb.(messageTypeIder)
if !ok {
return ErrNoMessageTypeId
return errNoMessageTypeID
}
mtid := mti.MessageTypeId()
@ -141,9 +134,9 @@ func (ms *MessageSet) Marshal(pb Message) error {
return nil
}
func (ms *MessageSet) Reset() { *ms = MessageSet{} }
func (ms *MessageSet) String() string { return CompactTextString(ms) }
func (*MessageSet) ProtoMessage() {}
func (ms *messageSet) Reset() { *ms = messageSet{} }
func (ms *messageSet) String() string { return CompactTextString(ms) }
func (*messageSet) ProtoMessage() {}
// Support for the message_set_wire_format message option.
@ -169,7 +162,7 @@ func MarshalMessageSet(m map[int32]Extension) ([]byte, error) {
}
sort.Ints(ids)
ms := &MessageSet{Item: make([]*_MessageSet_Item, 0, len(m))}
ms := &messageSet{Item: make([]*_MessageSet_Item, 0, len(m))}
for _, id := range ids {
e := m[int32(id)]
// Remove the wire type and field number varint, as well as the length varint.
@ -186,7 +179,7 @@ func MarshalMessageSet(m map[int32]Extension) ([]byte, error) {
// UnmarshalMessageSet decodes the extension map encoded in buf in the message set wire format.
// It is called by generated Unmarshal methods on protocol buffer messages with the message_set_wire_format option.
func UnmarshalMessageSet(buf []byte, m map[int32]Extension) error {
ms := new(MessageSet)
ms := new(messageSet)
if err := Unmarshal(buf, ms); err != nil {
return err
}

View File

@ -38,7 +38,7 @@ import (
func TestUnmarshalMessageSetWithDuplicate(t *testing.T) {
// Check that a repeated message set entry will be concatenated.
in := &MessageSet{
in := &messageSet{
Item: []*_MessageSet_Item{
{TypeId: Int32(12345), Message: []byte("hoo")},
{TypeId: Int32(12345), Message: []byte("hah")},

View File

@ -144,8 +144,8 @@ func structPointer_ExtMap(p structPointer, f field) *map[int32]Extension {
return structPointer_ifield(p, f).(*map[int32]Extension)
}
// Map returns the reflect.Value for the address of a map field in the struct.
func structPointer_Map(p structPointer, f field, typ reflect.Type) reflect.Value {
// NewAt returns the reflect.Value for a pointer to a field in the struct.
func structPointer_NewAt(p structPointer, f field, typ reflect.Type) reflect.Value {
return structPointer_field(p, f).Addr()
}

View File

@ -130,8 +130,8 @@ func structPointer_ExtMap(p structPointer, f field) *map[int32]Extension {
return (*map[int32]Extension)(unsafe.Pointer(uintptr(p) + uintptr(f)))
}
// Map returns the reflect.Value for the address of a map field in the struct.
func structPointer_Map(p structPointer, f field, typ reflect.Type) reflect.Value {
// NewAt returns the reflect.Value for a pointer to a field in the struct.
func structPointer_NewAt(p structPointer, f field, typ reflect.Type) reflect.Value {
return reflect.NewAt(typ, unsafe.Pointer(uintptr(p)+uintptr(f)))
}

View File

@ -42,6 +42,7 @@ package proto
import (
"fmt"
"log"
"os"
"reflect"
"sort"
@ -89,6 +90,12 @@ type decoder func(p *Buffer, prop *Properties, base structPointer) error
// A valueDecoder decodes a single integer in a particular encoding.
type valueDecoder func(o *Buffer) (x uint64, err error)
// A oneofMarshaler does the marshaling for all oneof fields in a message.
type oneofMarshaler func(Message, *Buffer) error
// A oneofUnmarshaler does the unmarshaling for a oneof field in a message.
type oneofUnmarshaler func(Message, int, int, *Buffer) (bool, error)
// tagMap is an optimization over map[int]int for typical protocol buffer
// use-cases. Encoded protocol buffers are often in tag order with small tag
// numbers.
@ -137,6 +144,21 @@ type StructProperties struct {
order []int // list of struct field numbers in tag order
unrecField field // field id of the XXX_unrecognized []byte field
extendable bool // is this an extendable proto
oneofMarshaler oneofMarshaler
oneofUnmarshaler oneofUnmarshaler
stype reflect.Type
// OneofTypes contains information about the oneof fields in this message.
// It is keyed by the original name of a field.
OneofTypes map[string]*OneofProperties
}
// OneofProperties represents information about a specific field in a oneof.
type OneofProperties struct {
Type reflect.Type // pointer to generated struct type for this oneof field
Field int // struct field number of the containing oneof in the message
Prop *Properties
}
// Implement the sorting interface so we can sort the fields in tag order, as recommended by the spec.
@ -161,6 +183,7 @@ type Properties struct {
Packed bool // relevant for repeated primitives only
Enum string // set for enum types only
proto3 bool // whether this is known to be a proto3 field; set for []byte only
oneof bool // whether this is a oneof field
Default string // default value
HasDefault bool // whether an explicit default was provided
@ -216,6 +239,9 @@ func (p *Properties) String() string {
if p.proto3 {
s += ",proto3"
}
if p.oneof {
s += ",oneof"
}
if len(p.Enum) > 0 {
s += ",enum=" + p.Enum
}
@ -292,6 +318,8 @@ func (p *Properties) Parse(s string) {
p.Enum = f[5:]
case f == "proto3":
p.proto3 = true
case f == "oneof":
p.oneof = true
case strings.HasPrefix(f, "def="):
p.HasDefault = true
p.Default = f[4:] // rest of string
@ -713,6 +741,7 @@ func getPropertiesLocked(t reflect.Type) *StructProperties {
prop.Prop = make([]*Properties, t.NumField())
prop.order = make([]int, t.NumField())
isOneofMessage := false
for i := 0; i < t.NumField(); i++ {
f := t.Field(i)
p := new(Properties)
@ -733,6 +762,10 @@ func getPropertiesLocked(t reflect.Type) *StructProperties {
if f.Name == "XXX_unrecognized" { // special case
prop.unrecField = toField(&f)
}
oneof := f.Tag.Get("protobuf_oneof") != "" // special case
if oneof {
isOneofMessage = true
}
prop.Prop[i] = p
prop.order[i] = i
if debug {
@ -742,7 +775,7 @@ func getPropertiesLocked(t reflect.Type) *StructProperties {
}
print("\n")
}
if p.enc == nil && !strings.HasPrefix(f.Name, "XXX_") {
if p.enc == nil && !strings.HasPrefix(f.Name, "XXX_") && !oneof {
fmt.Fprintln(os.Stderr, "proto: no encoder for", f.Name, f.Type.String(), "[GetProperties]")
}
}
@ -750,6 +783,41 @@ func getPropertiesLocked(t reflect.Type) *StructProperties {
// Re-order prop.order.
sort.Sort(prop)
type oneofMessage interface {
XXX_OneofFuncs() (func(Message, *Buffer) error, func(Message, int, int, *Buffer) (bool, error), []interface{})
}
if om, ok := reflect.Zero(reflect.PtrTo(t)).Interface().(oneofMessage); isOneofMessage && ok {
var oots []interface{}
prop.oneofMarshaler, prop.oneofUnmarshaler, oots = om.XXX_OneofFuncs()
prop.stype = t
// Interpret oneof metadata.
prop.OneofTypes = make(map[string]*OneofProperties)
for _, oot := range oots {
oop := &OneofProperties{
Type: reflect.ValueOf(oot).Type(), // *T
Prop: new(Properties),
}
sft := oop.Type.Elem().Field(0)
oop.Prop.Name = sft.Name
oop.Prop.Parse(sft.Tag.Get("protobuf"))
// There will be exactly one interface field that
// this new value is assignable to.
for i := 0; i < t.NumField(); i++ {
f := t.Field(i)
if f.Type.Kind() != reflect.Interface {
continue
}
if !oop.Type.AssignableTo(f.Type) {
continue
}
oop.Field = i
break
}
prop.OneofTypes[oop.Prop.OrigName] = oop
}
}
// build required counts
// build tags
reqCount := 0
@ -813,3 +881,35 @@ func RegisterEnum(typeName string, unusedNameMap map[int32]string, valueMap map[
}
enumStringMaps[typeName] = unusedNameMap
}
// EnumValueMap returns the mapping from names to integers of the
// enum type enumType, or a nil if not found.
func EnumValueMap(enumType string) map[string]int32 {
return enumValueMaps[enumType]
}
// A registry of all linked message types.
// The string is a fully-qualified proto name ("pkg.Message").
var (
protoTypes = make(map[string]reflect.Type)
revProtoTypes = make(map[reflect.Type]string)
)
// RegisterType is called from generated code and maps from the fully qualified
// proto name to the type (pointer to struct) of the protocol buffer.
func RegisterType(x Message, name string) {
if _, ok := protoTypes[name]; ok {
// TODO: Some day, make this a panic.
log.Printf("proto: duplicate proto type registered: %s", name)
return
}
t := reflect.TypeOf(x)
protoTypes[name] = t
revProtoTypes[t] = name
}
// MessageName returns the fully-qualified proto name for the given message type.
func MessageName(x Message) string { return revProtoTypes[reflect.TypeOf(x)] }
// MessageType returns the message type (pointer to struct) for a named message.
func MessageType(name string) reflect.Type { return protoTypes[name] }

View File

@ -16,10 +16,14 @@ It has these top-level messages:
package proto3_proto
import proto "github.com/coreos/etcd/Godeps/_workspace/src/github.com/gogo/protobuf/proto"
import fmt "fmt"
import math "math"
import testdata "github.com/coreos/etcd/Godeps/_workspace/src/github.com/gogo/protobuf/proto/testdata"
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
type Message_Humour int32
@ -118,5 +122,8 @@ func (m *MessageWithMap) GetByteMapping() map[bool][]byte {
}
func init() {
proto.RegisterType((*Message)(nil), "proto3_proto.Message")
proto.RegisterType((*Nested)(nil), "proto3_proto.Nested")
proto.RegisterType((*MessageWithMap)(nil), "proto3_proto.MessageWithMap")
proto.RegisterEnum("proto3_proto.Message_Humour", Message_Humour_name, Message_Humour_value)
}

View File

@ -124,6 +124,11 @@ var SizeTests = []struct {
{"map field with big entry", &pb.MessageWithMap{NameMapping: map[int32]string{8: strings.Repeat("x", 125)}}},
{"map field with big key and val", &pb.MessageWithMap{StrToStr: map[string]string{strings.Repeat("x", 70): strings.Repeat("y", 70)}}},
{"map field with big numeric key", &pb.MessageWithMap{NameMapping: map[int32]string{0xf00d: "om nom nom"}}},
{"oneof not set", &pb.Communique{}},
{"oneof zero int32", &pb.Communique{Union: &pb.Communique_Number{Number: 0}}},
{"oneof int32", &pb.Communique{Union: &pb.Communique_Number{Number: 3}}},
{"oneof string", &pb.Communique{Union: &pb.Communique_Name{Name: "Rhythmic Fman"}}},
}
func TestSize(t *testing.T) {

View File

@ -32,6 +32,6 @@
all: regenerate
regenerate:
go install github.com/gogo/protobuf/protoc-gen-gogo/version/protoc-min-version
go install github.com/gogo/protobuf/protoc-min-version
protoc-min-version --version="3.0.0" --gogo_out=. test.proto

Some files were not shown because too many files have changed in this diff Show More