go-ethereum中p2p源码学习

前一篇笔记学习了以太坊p2p网络的发现和维护机制,这篇笔记就来了解一下p2p服务

server.go

这是P2P服务的主逻辑代码所在处

Start

服务的启动代码如下:

func (srv *Server) Start() (err error) {
	srv.lock.Lock()
	defer srv.lock.Unlock()
	if srv.running { //判断是否启动,避免重复启动
		return errors.New("server already running")
	}
	srv.running = true 
	srv.log = srv.Config.Logger
	if srv.log == nil {
		srv.log = log.New()
	}
	if srv.NoDial && srv.ListenAddr == "" { //判断节点是否主动连接其他节点或者监听提起节点
		srv.log.Warn("P2P server will be useless, neither dialing nor listening")
	}

	// static fields
	if srv.PrivateKey == nil {
		return errors.New("Server.PrivateKey must be set to a non-nil key")
	}
	if srv.newTransport == nil {
		srv.newTransport = newRLPX
	}
	if srv.Dialer == nil {
		srv.Dialer = TCPDialer{&net.Dialer{Timeout: defaultDialTimeout}}
	}
	srv.quit = make(chan struct{})
	srv.addpeer = make(chan *conn)
	srv.delpeer = make(chan peerDrop)
	srv.posthandshake = make(chan *conn)
	srv.addstatic = make(chan *enode.Node)
	srv.removestatic = make(chan *enode.Node)
	srv.addtrusted = make(chan *enode.Node)
	srv.removetrusted = make(chan *enode.Node)
	srv.peerOp = make(chan peerOpFunc)
	srv.peerOpDone = make(chan struct{})

	if err := srv.setupLocalNode(); err != nil {
		return err
	}
	if srv.ListenAddr != "" {
		if err := srv.setupListening(); err != nil {
			return err
		}
	}
	if err := srv.setupDiscovery(); err != nil {
		return err
	}

	dynPeers := srv.maxDialedConns()
	dialer := newDialState(srv.localnode.ID(), srv.StaticNodes, srv.BootstrapNodes, srv.ntab, dynPeers, srv.NetRestrict)
	srv.loopWG.Add(1)
	go srv.run(dialer)
	return nil
}

首先判断是否启动,避免多起启动实例,然后将服务已运行的标志位置true。然后初始化log实例,检查是否要主动连接其他节点,检查私钥是否为空。之后配置rlpx与dial类,稍后再介绍这两个东西。再往下初始化了一系列的channel留作同步用。接着调用了setupLocalNode实现如下

setupLocalNode

func (srv *Server) setupLocalNode() error {
	pubkey := crypto.FromECDSAPub(&srv.PrivateKey.PublicKey)
	srv.ourHandshake = &protoHandshake{Version: baseProtocolVersion, Name: srv.Name, ID: pubkey[1:]}
	for _, p := range srv.Protocols {
		srv.ourHandshake.Caps = append(srv.ourHandshake.Caps, p.cap())
	}
	sort.Sort(capsByNameAndVersion(srv.ourHandshake.Caps))

	db, err := enode.OpenDB(srv.Config.NodeDatabase)
	if err != nil {
		return err
	}
	srv.nodedb = db
	srv.localnode = enode.NewLocalNode(db, srv.PrivateKey)
	srv.localnode.SetFallbackIP(net.IP{127, 0, 0, 1})
	srv.localnode.Set(capsByNameAndVersion(srv.ourHandshake.Caps))
	for _, p := range srv.Protocols {
		for _, e := range p.Attributes {
			srv.localnode.Set(e)
		}
	}
	switch srv.NAT.(type) {
	case nil:
	case nat.ExtIP:
		ip, _ := srv.NAT.ExternalIP()
		srv.localnode.SetStaticIP(ip)
	default:
		srv.loopWG.Add(1)
		go func() {
			defer srv.loopWG.Done()
			if ip, err := srv.NAT.ExternalIP(); err == nil {
				srv.localnode.SetStaticIP(ip)
			}
		}()
	}
	return nil
}

首先FromECDSAPub是将公钥以字节数组的形式表示(根据椭圆加密算法,我们知道公钥实际上是一对点坐标,这里我们将这对点用字节数组表示出来),之后构造了握手协议的实例,主要记录了版本号(5);服务名以及ID(就是公钥)。之后遍历了服务的所有协议,将每个协议的cap添加到握手协议的caps中(cap实际上记录了该协议的版本号及名字)。下面先创建了数据库,并配置给服务。接着调用NewLocalNode创建本地节点,实际上就是实例化了一个LocalNode对象,存储了公钥私钥数据库对象等信息。接下来设置了fallbackIP,Set方法实际上是将对象存储在localnode的entries中,capsByNameAndVersion实现了Entry接口。然后遍历服务中所有协议的所有Attributes(实际上也是一个个Entry数组)存储起来。

setupListening

配置完localnode后接着调用了setupListening

func (srv *Server) setupListening() error {
	listener, err := net.Listen("tcp", srv.ListenAddr)
	if err != nil {
		return err
	}
	laddr := listener.Addr().(*net.TCPAddr)
	srv.ListenAddr = laddr.String()
	srv.listener = listener
	srv.localnode.Set(enr.TCP(laddr.Port))

	srv.loopWG.Add(1)
	go srv.listenLoop()

	if !laddr.IP.IsLoopback() && srv.NAT != nil {
		srv.loopWG.Add(1)
		go func() {
			nat.Map(srv.NAT, srv.quit, "tcp", laddr.Port, laddr.Port, "ethereum p2p")
			srv.loopWG.Done()
		}()
	}
	return nil
}

这里主要是监听某个tcp端口地址,启动了listenLoop

func (srv *Server) listenLoop() {
	defer srv.loopWG.Done()
	srv.log.Debug("TCP listener up", "addr", srv.listener.Addr())

	tokens := defaultMaxPendingPeers
	if srv.MaxPendingPeers > 0 {
		tokens = srv.MaxPendingPeers
	}
	slots := make(chan struct{}, tokens)
	for i := 0; i < tokens; i++ {
		slots <- struct{}{}
	}

	for {
		<-slots

		var (
			fd  net.Conn
			err error
		)
		for {
			fd, err = srv.listener.Accept()
			if netutil.IsTemporaryError(err) {
				srv.log.Debug("Temporary read error", "err", err)
				continue
			} else if err != nil {
				srv.log.Debug("Read error", "err", err)
				return
			}
			break
		}

		if srv.NetRestrict != nil {
			if tcp, ok := fd.RemoteAddr().(*net.TCPAddr); ok && !srv.NetRestrict.Contains(tcp.IP) {
				srv.log.Debug("Rejected conn (not whitelisted in NetRestrict)", "addr", fd.RemoteAddr())
				fd.Close()
				slots <- struct{}{}
				continue
			}
		}

		var ip net.IP
		if tcp, ok := fd.RemoteAddr().(*net.TCPAddr); ok {
			ip = tcp.IP
		}
		fd = newMeteredConn(fd, true, ip)
		srv.log.Trace("Accepted connection", "addr", fd.RemoteAddr())
		go func() {
			srv.SetupConn(fd, inboundConn, nil)
			slots <- struct{}{}
		}()
	}
}

首先规定了最大的等待连接数量tokens,然后创建了一个容量与之一样大的channel,在于一个无限循环中中,利用channel机制,启动了tokens个无限循环。每个循环会接收一个连接请求,实际上虽然是无限循环,在获得一个请求后循环便结束了(之所以要用无限循环是要跳过其中的临时性错误)。然后检查白名单,不在名单内的IP都拒绝服务。对于可以服务的连接,单独启动一个goroutine去处理,然后主循环继续,如果连接数未达到最大,则继续等待连接到来。处理连接的方法是SetupConn:

func (srv *Server) SetupConn(fd net.Conn, flags connFlag, dialDest *enode.Node) error {
	c := &conn{fd: fd, transport: srv.newTransport(fd), flags: flags, cont: make(chan error)}
	err := srv.setupConn(c, flags, dialDest)
	if err != nil {
		c.close(err)
		srv.log.Trace("Setting up connection failed", "addr", fd.RemoteAddr(), "err", err)
	}
	return err
}

SetupConn是执行一个握手协议,并尝试把连接创建成一个peer对象。可以看到只是先创建了conn对象,然后调用了setupConn:

func (srv *Server) setupConn(c *conn, flags connFlag, dialDest *enode.Node) error {
	srv.lock.Lock()
	running := srv.running
	srv.lock.Unlock()
	if !running {
		return errServerStopped
	}
	
	var dialPubkey *ecdsa.PublicKey
	if dialDest != nil {
		dialPubkey = new(ecdsa.PublicKey)
		if err := dialDest.Load((*enode.Secp256k1)(dialPubkey)); err != nil {
			return errors.New("dial destination doesn't have a secp256k1 public key")
		}
	}
	
	remotePubkey, err := c.doEncHandshake(srv.PrivateKey, dialPubkey)
	if err != nil {
		srv.log.Trace("Failed RLPx handshake", "addr", c.fd.RemoteAddr(), "conn", c.flags, "err", err)
		return err
	}
	if dialDest != nil {
		// For dialed connections, check that the remote public key matches.
		if dialPubkey.X.Cmp(remotePubkey.X) != 0 || dialPubkey.Y.Cmp(remotePubkey.Y) != 0 {
			return DiscUnexpectedIdentity
		}
		c.node = dialDest
	} else {
		c.node = nodeFromConn(remotePubkey, c.fd)
	}
	if conn, ok := c.fd.(*meteredConn); ok {
		conn.handshakeDone(c.node.ID())
	}
	clog := srv.log.New("id", c.node.ID(), "addr", c.fd.RemoteAddr(), "conn", c.flags)
	err = srv.checkpoint(c, srv.posthandshake)
	if err != nil {
		clog.Trace("Rejected peer before protocol handshake", "err", err)
		return err
	}
	
	phs, err := c.doProtoHandshake(srv.ourHandshake)
	if err != nil {
		clog.Trace("Failed proto handshake", "err", err)
		return err
	}
	if id := c.node.ID(); !bytes.Equal(crypto.Keccak256(phs.ID), id[:]) {
		clog.Trace("Wrong devp2p handshake identity", "phsid", hex.EncodeToString(phs.ID))
		return DiscUnexpectedIdentity
	}
	c.caps, c.name = phs.Caps, phs.Name
	err = srv.checkpoint(c, srv.addpeer)
	if err != nil {
		clog.Trace("Rejected peer", "err", err)
		return err
	}
	
	clog.Trace("connection set up", "inbound", dialDest == nil)
	return nil
}

首先确保服务正在运行,然后判断远端节点是否为nil(为nil其实是被动连接,不为nil其实是在dial中主动连接),来计算其公钥。当收到一个连接时,这里为nil不执行。然后调用doEncHandshake开始加密握手,这里实际上是调用的rlpx中的方法,稍后再讲。这里最后得到远端的公钥。对于收到一个连接,远端node开始为空,这里调用nodeFromConn创建一个,主要是记录公钥、ip地址及端口号。接着执行checkpoint方法:

func (srv *Server) checkpoint(c *conn, stage chan<- *conn) error {
	select {
	case stage <- c:
	case <-srv.quit:
		return errServerStopped
	}
	select {
	case err := <-c.cont:
		return err
	case <-srv.quit:
		return errServerStopped
	}
}

实际上就是给posthandshake赋值,然后触发后续逻辑,我们稍后再讲。紧接着又执行了协议握手,调用了doProtoHandshake方法,也是rlpx中方法,传入的参数是ourHandshake,也就是在配置localnode是初始化的,记录了版本号、服务名和自己公钥以及服务中所有协议的摘要。这个方法返回远端的协议信息,之后在conn对象中记录下来远端的服务名和服务中所有协议摘要。同样也通过checkpoint传递给addpeer用来触发后续逻辑。完成后就和远端建立的连接。

setupDiscovery

刚才我们在分析配置网络监听的代码时顺便看了收到连接时的逻辑。再回到Start中,下面又调用了setupDiscovery;

func (srv *Server) setupDiscovery() error {
	if srv.NoDiscovery && !srv.DiscoveryV5 {
		return nil
	}

	addr, err := net.ResolveUDPAddr("udp", srv.ListenAddr)
	if err != nil {
		return err
	}
	conn, err := net.ListenUDP("udp", addr)
	if err != nil {
		return err
	}
	realaddr := conn.LocalAddr().(*net.UDPAddr)
	srv.log.Debug("UDP listener up", "addr", realaddr)
	if srv.NAT != nil {
		if !realaddr.IP.IsLoopback() {
			go nat.Map(srv.NAT, srv.quit, "udp", realaddr.Port, realaddr.Port, "ethereum discovery")
		}
	}
	srv.localnode.SetFallbackUDP(realaddr.Port)

	
	var unhandled chan discover.ReadPacket
	var sconn *sharedUDPConn
	if !srv.NoDiscovery {
		if srv.DiscoveryV5 {
			unhandled = make(chan discover.ReadPacket, 100)
			sconn = &sharedUDPConn{conn, unhandled}
		}
		cfg := discover.Config{
			PrivateKey:  srv.PrivateKey,
			NetRestrict: srv.NetRestrict,
			Bootnodes:   srv.BootstrapNodes,
			Unhandled:   unhandled,
		}
		ntab, err := discover.ListenUDP(conn, srv.localnode, cfg)
		if err != nil {
			return err
		}
		srv.ntab = ntab
	}
	
	if srv.DiscoveryV5 {
		var ntab *discv5.Network
		var err error
		if sconn != nil {
			ntab, err = discv5.ListenUDP(srv.PrivateKey, sconn, "", srv.NetRestrict)
		} else {
			ntab, err = discv5.ListenUDP(srv.PrivateKey, conn, "", srv.NetRestrict)
		}
		if err != nil {
			return err
		}
		if err := ntab.SetFallbackNodes(srv.BootstrapNodesV5); err != nil {
			return err
		}
		srv.DiscV5 = ntab
	}
	return nil
}

这里主要就是监听udp连接,配置私钥、白名单、bootstrap节点等,然后调用discover的ListenUDP开始节点发现,后续逻辑详见这里

run

稍微总结一下,start的逻辑主要是配置:配置localnode,处理tcp连接用于节点通信,处理udp连接用于节点发现。配置完毕后,调用newDialState创建了一个dialstate对象,然后运行run方法:

func (srv *Server) run(dialstate dialer) {
	srv.log.Info("Started P2P networking", "self", srv.localnode.Node())
	defer srv.loopWG.Done()
	defer srv.nodedb.Close()

	var (
		peers        = make(map[enode.ID]*Peer)
		inboundCount = 0
		trusted      = make(map[enode.ID]bool, len(srv.TrustedNodes))
		taskdone     = make(chan task, maxActiveDialTasks)
		runningTasks []task
		queuedTasks  []task // tasks that can't run yet
	)
	
	for _, n := range srv.TrustedNodes {
		trusted[n.ID()] = true
	}

	
	delTask := func(t task) {
		for i := range runningTasks {
			if runningTasks[i] == t {
				runningTasks = append(runningTasks[:i], runningTasks[i+1:]...)
				break
			}
		}
	}
	
	startTasks := func(ts []task) (rest []task) {
		i := 0
		for ; len(runningTasks) < maxActiveDialTasks && i < len(ts); i++ {
			t := ts[i]
			srv.log.Trace("New dial task", "task", t)
			go func() { t.Do(srv); taskdone <- t }()
			runningTasks = append(runningTasks, t)
		}
		return ts[i:]
	}
	scheduleTasks := func() {
		queuedTasks = append(queuedTasks[:0], startTasks(queuedTasks)...)
		if len(runningTasks) < maxActiveDialTasks {
			nt := dialstate.newTasks(len(runningTasks)+len(queuedTasks), peers, time.Now())
			queuedTasks = append(queuedTasks, startTasks(nt)...)
		}
	}

running:
	for {
		scheduleTasks()
		select{
		    ...
		}
	}

}

run方法主要是处理一些连接的逻辑,首先定义了两个队列runningTasks与queuedTasks,保存正在运行和等待运行的任务。然后定义了三个处理任务的方法。delTask就是删除任务。startTasks就是将任务添加到runningTasks并执行do方法,对于暂时无法运行的任务则返回。scheduleTasks是用来启动任务,他会先尝试启动等待中的任务,然后用newTasks新建一个任务,添加到queuedTasks中。

在接下来的一个无限循环开始,首先启动scheduleTasks,触发一些连接的建立和任务的启动,之后进入select开始阻塞,等待特定的逻辑被触发。

run:running

回顾刚才的代码流程,在执行完加密握手后,调用了checkpoint,将一个conn对象传给了srv.posthandshake,出发了run中如下逻辑:

case c := <-srv.posthandshake:
			if trusted[c.node.ID()] {
				c.flags |= trustedConn
			}
			select {
			case c.cont <- srv.encHandshakeChecks(peers, inboundCount, c):
			case <-srv.quit:
				break running
			}

首先检测是否是可信节点,是的话修改flags的值。然后执行encHandshakeChecks方法:

func (srv *Server) encHandshakeChecks(peers map[enode.ID]*Peer, inboundCount int, c *conn) error {
	switch {
	case !c.is(trustedConn|staticDialedConn) && len(peers) >= srv.MaxPeers:
		return DiscTooManyPeers
	case !c.is(trustedConn) && c.is(inboundConn) && inboundCount >= srv.maxInboundConns():
		return DiscTooManyPeers
	case peers[c.node.ID()] != nil:
		return DiscAlreadyConnected
	case c.node.ID() == srv.localnode.ID():
		return DiscSelf
	default:
		return nil
	}
}

前两个case主要是通过检查conn的flags判断是节点的性质以及连接数,后两个检查是否是自己或者已连接。检查结果赋值给c.cont,这时回到checkpoint中,如果刚才检查无误setupConn流程继续,否则则抛出异常拒绝服务。

再往下,在setupConn中执行完协议握手后,同样调用了checkpoint方法,这次是给srv.addpeer赋值来触发逻辑:

case c := <-srv.addpeer:
	err := srv.protoHandshakeChecks(peers, inboundCount, c)
	if err == nil {
		p := newPeer(c, srv.Protocols)
		if srv.EnableMsgEvents {
			p.events = &srv.peerFeed
		}
		name := truncateName(c.name)
		srv.log.Debug("Adding p2p peer", "name", name, "addr", c.fd.RemoteAddr(), "peers", len(peers)+1)
		go srv.runPeer(p)
		peers[c.node.ID()] = p
		if p.Inbound() {
			inboundCount++
		}
	}
	select {
	case c.cont <- err:
	case <-srv.quit:
		break running
	}

同样先检查节点,这次调用的是protoHandshakeChecks:

func (srv *Server) protoHandshakeChecks(peers map[enode.ID]*Peer, inboundCount int, c *conn) error {
	if len(srv.Protocols) > 0 && countMatchingProtocols(srv.Protocols, c.caps) == 0 {
		return DiscUselessPeer
	}
	return srv.encHandshakeChecks(peers, inboundCount, c)
}

由于握手后,我们知道了对方能提供的协议详情,这里进行了匹配检查,如果双方没有能匹配到的协议,则返回DiscUselessPeer,之后和刚才加密握手一样调用encHandshakeChecks进行节点检查。最后如果都没有问题,调用newPeer创建一个新的Peer。这里表示握手通过,连接正式建立。之后进行了后续操作,如将刚才生成的peer记录下来,另外如果是一个接入的连接,则inboundCount自增。同时调用了runPeer方法

func (srv *Server) runPeer(p *Peer) {
	if srv.newPeerHook != nil {
		srv.newPeerHook(p)
	}

	srv.peerFeed.Send(&PeerEvent{
		Type: PeerEventTypeAdd,
		Peer: p.ID(),
	})

	remoteRequested, err := p.run()

	srv.peerFeed.Send(&PeerEvent{
		Type:  PeerEventTypeDrop,
		Peer:  p.ID(),
		Error: err.Error(),
	})

	srv.delpeer <- peerDrop{p, err, remoteRequested}
}

这里主要是先广播了peer的建立,然后调用peer的run方法,这里稍后再讲,知道peer连接断开,然后给delpeer赋值触发相应逻辑。由于runPeer是运行在一个单独goroutine中,所以不会阻塞server的run运行,我们回到run中,srv.addpeer对应的逻辑最后有个阻塞,会给c.cont赋值,这时回到setupConn,如果赋值为空表示没有错误,连接正常,则继续setupConn的逻辑。这样一次握手完成。再看peer断开时,srv.delpeer被赋值,触发如下逻辑:

case pd := <-srv.delpeer:
	d := common.PrettyDuration(mclock.Now() - pd.created)
	pd.log.Debug("Removing p2p peer", "duration", d, "peers", len(peers)-1, "req", pd.requested, "err", pd.err)
	delete(peers, pd.ID())
	if pd.Inbound() {
		inboundCount--
	}
}

主要的逻辑就是从peers删除对应的peer,然后如果是接入型peer,inboundCount再自减1。

到这里server的主逻辑分析完毕,除了这些,服务还提供了一些方法供外部使用,首先看AddPeer:

func (srv *Server) AddPeer(node *enode.Node) {
	select {
	case srv.addstatic <- node:
	case <-srv.quit:
	}
}

也是利用channel模式触发run中的逻辑

case n := <-srv.addstatic:
	srv.log.Trace("Adding static node", "node", n)
	dialstate.addStatic(n)

很简单就是将要添加的节点dialstate的static这个map中,key就是节点的ID,值就是一个dialTask。

除此之外还有RemovePeer、AddTrustedPeer、RemoveTrustedPeer等方法,都是利用向一个channel中赋值,来触发run中的逻辑这里不再详细叙述。

rlpx.go

这是一个较独立的模块,所以拿出来分析了,详见此文:go-ethereum中p2p-rlpx源码学习

dial.go

dial在p2p中也负责链接的建立,在p2pserver中第一次出现是在start方法内构造了一个TCPDialer对象赋值给Dialer:

	if srv.Dialer == nil {
		srv.Dialer = TCPDialer{&net.Dialer{Timeout: defaultDialTimeout}}
	}
	
type TCPDialer struct {
	*net.Dialer
}	

func (t TCPDialer) Dial(dest *enode.Node) (net.Conn, error) {
	addr := &net.TCPAddr{IP: dest.IP(), Port: dest.TCP()}
	return t.Dialer.Dial("tcp", addr.String())
}

TCPDialer实际上对Dialer进行了封装,然后提供了Dial用于和指定Node建立tcp链接。除此之外还有一个重要的结构体dialstate,它在p2pserver中的start方法最后进行了实例化,并作为参数传递给了run方法,初始化方法如下:

func newDialState(self enode.ID, static []*enode.Node, bootnodes []*enode.Node, ntab discoverTable, maxdyn int, netrestrict *netutil.Netlist) *dialstate {
	s := &dialstate{
		maxDynDials: maxdyn,
		ntab:        ntab,
		self:        self,
		netrestrict: netrestrict,
		static:      make(map[enode.ID]*dialTask),
		dialing:     make(map[enode.ID]connFlag),
		bootnodes:   make([]*enode.Node, len(bootnodes)),
		randomNodes: make([]*enode.Node, maxdyn/2),
		hist:        new(dialHistory),
	}
	copy(s.bootnodes, bootnodes)
	for _, n := range static {
		s.addStatic(n)
	}
	return s
}

maxDynDials就是server的maxDialedConns,计算如下:

func (srv *Server) maxDialedConns() int {
	if srv.NoDiscovery || srv.NoDial {
		return 0
	}
	r := srv.DialRatio
	if r == 0 {
		r = defaultDialRatio
	}
	return srv.MaxPeers / r
}

他在节点从不进行发现或进行连接时等于0,否则根据MaxPeers和DialRatio计算。ntab就是discover的table,代表节点发现协议。其他的赋值了信赖节点和boot节点。实例化之后,在run方法中调用了它的下面几个方法:

newTasks

这是在定义scheduleTasks时调用的

func (s *dialstate) newTasks(nRunning int, peers map[enode.ID]*Peer, now time.Time) []task {
	if s.start.IsZero() {
		s.start = now
	}

	var newtasks []task
	addDial := func(flag connFlag, n *enode.Node) bool {
		if err := s.checkDial(n, peers); err != nil {
			log.Trace("Skipping dial candidate", "id", n.ID(), "addr", &net.TCPAddr{IP: n.IP(), Port: n.TCP()}, "err", err)
			return false
		}
		s.dialing[n.ID()] = flag
		newtasks = append(newtasks, &dialTask{flags: flag, dest: n})
		return true
	}

	needDynDials := s.maxDynDials
	for _, p := range peers {
		if p.rw.is(dynDialedConn) {
			needDynDials--
		}
	}
	for _, flag := range s.dialing {
		if flag&dynDialedConn != 0 {
			needDynDials--
		}
	}

	s.hist.expire(now)

	for id, t := range s.static {
		err := s.checkDial(t.dest, peers)
		switch err {
		case errNotWhitelisted, errSelf:
			log.Warn("Removing static dial candidate", "id", t.dest.ID, "addr", &net.TCPAddr{IP: t.dest.IP(), Port: t.dest.TCP()}, "err", err)
			delete(s.static, t.dest.ID())
		case nil:
			s.dialing[id] = t.flags
			newtasks = append(newtasks, t)
		}
	}

	if len(peers) == 0 && len(s.bootnodes) > 0 && needDynDials > 0 && now.Sub(s.start) > fallbackInterval {
		bootnode := s.bootnodes[0]
		s.bootnodes = append(s.bootnodes[:0], s.bootnodes[1:]...)
		s.bootnodes = append(s.bootnodes, bootnode)

		if addDial(dynDialedConn, bootnode) {
			needDynDials--
		}
	}

	randomCandidates := needDynDials / 2
	if randomCandidates > 0 {
		n := s.ntab.ReadRandomNodes(s.randomNodes)
		for i := 0; i < randomCandidates && i < n; i++ {
			if addDial(dynDialedConn, s.randomNodes[i]) {
				needDynDials--
			}
		}
	}

	i := 0
	for ; i < len(s.lookupBuf) && needDynDials > 0; i++ {
		if addDial(dynDialedConn, s.lookupBuf[i]) {
			needDynDials--
		}
	}
	s.lookupBuf = s.lookupBuf[:copy(s.lookupBuf, s.lookupBuf[i:])]
	if len(s.lookupBuf) < needDynDials && !s.lookupRunning {
		s.lookupRunning = true
		newtasks = append(newtasks, &discoverTask{})
	}

	if nRunning == 0 && len(newtasks) == 0 && s.hist.Len() > 0 {
		t := &waitExpireTask{s.hist.min().exp.Sub(now)}
		newtasks = append(newtasks, t)
	}
	return newtasks
}

他所返回的是一组task对象,task有一个Do方法。开头第一个检查时间的是为了初始化开始时间。然后定义了addDial方法,主要将节点包装成dialTask添加到newtasks中。不过首先对节点进行了检查,主要是检查是否已连接或者正在连接或者非信任节点等。之后计算了需要建立动态连接的数量。

然后遍历所有静态节点,检查每个节点,没有问题的将其暂存到newtasks中。接下来如果已连接数为0,并且已经过了fallbackInterval时间,则寻找一个bootnode进行连接。接下来在randomNodes中添加需要建立动态连接的数量的一半的节点。后面如果还未达到需要建立动态连接的数量要求,则从lookupBuf中挑选。如果数量还不够创建discoverTask添加进去,用于节点发现。最后当什么都不做创建waitExpireTask返回。

这个方法就是添加一系列要执行的任务。

Do

之后回到p2pserver的run方法中,在startTasks方法内,会启动这一系列任务,通过Do方法,不同的任务有同的Do方法:

dialTask

func (t *dialTask) Do(srv *Server) {
	if t.dest.Incomplete() {
		if !t.resolve(srv) {
			return
		}
	}
	err := t.dial(srv, t.dest)
	if err != nil {
		log.Trace("Dial error", "task", t, "err", err)
		// Try resolving the ID of static nodes if dialing failed.
		if _, ok := err.(*dialError); ok && t.flags&staticDialedConn != 0 {
			if t.resolve(srv) {
				t.dial(srv, t.dest)
			}
		}
	}
}

Incomplete方法是检查节点是否有IP地址,如果没有则调用resolve方法:

func (t *dialTask) resolve(srv *Server) bool {
	if srv.ntab == nil {
		log.Debug("Can't resolve node", "id", t.dest.ID, "err", "discovery is disabled")
		return false
	}
	if t.resolveDelay == 0 {
		t.resolveDelay = initialResolveDelay
	}
	if time.Since(t.lastResolved) < t.resolveDelay {
		return false
	}
	resolved := srv.ntab.Resolve(t.dest)
	t.lastResolved = time.Now()
	if resolved == nil {
		t.resolveDelay *= 2
		if t.resolveDelay > maxResolveDelay {
			t.resolveDelay = maxResolveDelay
		}
		log.Debug("Resolving node failed", "id", t.dest.ID, "newdelay", t.resolveDelay)
		return false
	}

	t.resolveDelay = initialResolveDelay
	t.dest = resolved
	log.Debug("Resolved node", "id", t.dest.ID, "addr", &net.TCPAddr{IP: t.dest.IP(), Port: t.dest.TCP()})
	return true
}

这个方法是进程查询的,规定最小查询间隔是60秒。调用的是table中的Resolve,具体代码就不贴了,实际上就是先查找k桶,找到目标节点返回或摘到离他较近的节点调用lookup进行查找,所以就是节点查找的过程。如果查不到节点,则将最小查询间隔翻倍。如果查到的话,则更新节点。回到Do中,如果IP地址被补全,则调用dial进行连接,

func (t *dialTask) dial(srv *Server, dest *enode.Node) error {
	fd, err := srv.Dialer.Dial(dest)
	if err != nil {
		return &dialError{err}
	}
	mfd := newMeteredConn(fd, false, dest.IP())
	return srv.SetupConn(mfd, t.flags, dest)
}

Dialer实际上就是前文的TCPDialer,建立tcp连接后,调用了SetupConn,之后的逻辑前文以及分析过了。

discoverTask

func (t *discoverTask) Do(srv *Server) {
	next := srv.lastLookup.Add(lookupInterval)
	if now := time.Now(); now.Before(next) {
		time.Sleep(next.Sub(now))
	}
	srv.lastLookup = time.Now()
	t.results = srv.ntab.LookupRandom()
}

首先也是判断是否在最小间隔内,若是则等待,否则执行disocer的LookupRandom方法进行随机查找

waitExpireTask

func (t waitExpireTask) Do(*Server) {
	time.Sleep(t.Duration)
}

就是一个定时方法

addStatic removeStatic

这两个方法是提供给p2pserver的run方法中使用的,就是添加或移除节点。

dial总结

结合p2pserver分析,可知首先封装了一个TCP连接对象给server。然后初始化dialstate进行连接的建立。首先在run的大循环中第一次调用scheduleTasks,此时queuedTasks和queuedTasks都为空,此时添加的任务就是那些静态节点或桶中节点包装的dialTask,以及还有可能的discoverTask和最后的waitExpireTask。然后启动这些任务。对于dialTask,主动和相关节点建立联系;对于discoverTask,执行节点发现逻辑;对于waitExpireTask进行定时方法。

peer.go

newPeer

前面介绍了链路建立的准备工作,到这里peer代表一个已经建立好的连接。在p2p服务中,最早出现peer的地方是在节点主动发出或收到一个连接请求,并进行协议握手后,双方检查协议匹配性,检查通过后利用newPeer方法建立了一个peer对象,表示一条稳定的连接:

func newPeer(conn *conn, protocols []Protocol) *Peer {
	protomap := matchProtocols(protocols, conn.caps, conn)
	p := &Peer{
		rw:       conn,
		running:  protomap,
		created:  mclock.Now(),
		disc:     make(chan DiscReason),
		protoErr: make(chan error, len(protomap)+1), // protocols + pingLoop
		closed:   make(chan struct{}),
		log:      log.New("id", conn.node.ID(), "conn", conn.flags),
	}
	return p
}

matchProtocols表示双方都能提供的协议,参数中protocols表示自己能提供的协议,conn.caps表示对方能提供的协议,来看具体实现:

func matchProtocols(protocols []Protocol, caps []Cap, rw MsgReadWriter) map[string]*protoRW {
	sort.Sort(capsByNameAndVersion(caps))
	offset := baseProtocolLength
	result := make(map[string]*protoRW)

outer:
	for _, cap := range caps {
		for _, proto := range protocols {
			if proto.Name == cap.Name && proto.Version == cap.Version {
				// If an old protocol version matched, revert it
				if old := result[cap.Name]; old != nil {
					offset -= old.Length
				}
				// Assign the new match
				result[cap.Name] = &protoRW{Protocol: proto, offset: offset, in: make(chan Msg), w: rw}
				offset += proto.Length

				continue outer
			}
		}
	}
	return result
}

基本就是遍历自己和对方的协议集合,将所有名字和版本号一样的协议封装成protoRW对象并按名字存入result中,最后返回二者都能提供的协议集合。

run

回到newPeer中,构建了一个Peer对象并返回。紧接着启动了一个goroutine去调用runPeer,在runPeer中执行了peer的run方法:

func (p *Peer) run() (remoteRequested bool, err error) {
	var (
		writeStart = make(chan struct{}, 1)
		writeErr   = make(chan error, 1)
		readErr    = make(chan error, 1)
		reason     DiscReason // sent to the peer
	)
	p.wg.Add(2)
	go p.readLoop(readErr)
	go p.pingLoop()

	writeStart <- struct{}{}
	p.startProtocols(writeStart, writeErr)

loop:
	for {
		select {
		case err = <-writeErr:
			if err != nil {
				reason = DiscNetworkError
				break loop
			}
			writeStart <- struct{}{}
		case err = <-readErr:
			if r, ok := err.(DiscReason); ok {
				remoteRequested = true
				reason = r
			} else {
				reason = DiscNetworkError
			}
			break loop
		case err = <-p.protoErr:
			reason = discReasonForError(err)
			break loop
		case err = <-p.disc:
			reason = discReasonForError(err)
			break loop
		}
	}

	close(p.closed)
	p.rw.close(reason)
	p.wg.Wait()
	return remoteRequested, err
}

readLoop

在run方法中,首先在一个独立goroutine中启动了readLoop方法:

func (p *Peer) readLoop(errc chan<- error) {
	defer p.wg.Done()
	for {
		msg, err := p.rw.ReadMsg()
		if err != nil {
			errc <- err
			return
		}
		msg.ReceivedAt = time.Now()
		if err = p.handle(msg); err != nil {
			errc <- err
			return
		}
	}
}

这就是一个阻塞型的从流中读取信息的方法。rw是conn类型,conn是在开启一个tcp连接后,在SetupConn中将net.Conn包装后的对象。它的ReadMsg实际上就是rlpx的ReadMsg方法,前文已经分析过,返回的是一个Msg对象。之后在readLoop中给mag打上接受的时间戳,然后调用handle处理这个msg

func (p *Peer) handle(msg Msg) error {
	switch {
	case msg.Code == pingMsg:
		msg.Discard()
		go SendItems(p.rw, pongMsg)
	case msg.Code == discMsg:
		var reason [1]DiscReason
		rlp.Decode(msg.Payload, &reason)
		return reason[0]
	case msg.Code < baseProtocolLength:
		return msg.Discard()
	default:
		proto, err := p.getProto(msg.Code)
		if err != nil {
			return fmt.Errorf("msg code out of range: %v", msg.Code)
		}
		select {
		case proto.in <- msg:
			return nil
		case <-p.closed:
			return io.EOF
		}
	}
	return nil
}

这里主要是根据msg.Code决定执行的逻辑。主要是判断是否是ping消息或的断开连接的消息。都不是的话判断code是否满足规范,满足的话根据code取具体的协议getProto:

func (p *Peer) getProto(code uint64) (*protoRW, error) {
	for _, proto := range p.running {
		if code >= proto.offset && code < proto.offset+proto.Length {
			return proto, nil
		}
	}
	return nil, newPeerError(errInvalidMsgCode, "%d", code)
}

这里根据code查找具体协议,匹配的方式是根据协议的offset和length,看是否在[offset,offset+length)区间内,然后返回具体协议。找到后将msg赋值给proto.in触发相应逻辑。到这里handle执行完毕,readLoop的一个循环结束,开始下一个循环等待数据到来,这样一次完整的数据读取完成。还有一点需要注意的是handle多次出现了Discard方法:

func (msg Msg) Discard() error {
	_, err := io.Copy(ioutil.Discard, msg.Payload)
	return err
}

var Discard io.Writer = devNull(0)

由于msg的payload携带的是一个Reader对象,当消息抛弃是也要把read的内容读完,所以这里使用Discard这个虚拟的写对象,这个不执行什么实际操作,但是不会报错,可以安全的丢弃数据。

pingLoop

再回到run方法中,和readLoop同时启动的还有pingLoop

func (p *Peer) pingLoop() {
	ping := time.NewTimer(pingInterval)
	defer p.wg.Done()
	defer ping.Stop()
	for {
		select {
		case <-ping.C:
			if err := SendItems(p.rw, pingMsg); err != nil {
				p.protoErr <- err
				return
			}
			ping.Reset(pingInterval)
		case <-p.closed:
			return
		}
	}
}

这相当于是一个心跳包,即使双方没有数据传输,也要每隔一段时间内发送一个ping包看对方是否在线,这里时间间隔是15秒。15秒后执行SendItems方法

func SendItems(w MsgWriter, msgcode uint64, elems ...interface{}) error {
	return Send(w, msgcode, elems)
}

func Send(w MsgWriter, msgcode uint64, data interface{}) error {
	size, r, err := rlp.EncodeToReader(data)
	if err != nil {
		return err
	}
	return w.WriteMsg(Msg{Code: msgcode, Size: uint32(size), Payload: r})
}

这里的ping包是一个空包,只有一个pingMsg,最后还用rlpx的WriteMsg发送一个封装好的msg对象。发送后对方接受的逻辑就在上文分析的handle方法,

case msg.Code == pingMsg:
	msg.Discard()
	go SendItems(p.rw, pongMsg)

可见显示抛弃消息,然后为了避免阻塞发送了pong包,和发送ping包一样。当我们接受到pong包时,由于pongMsg小于baseProtocolLength,所以直接被抛弃。

startProtocols

再回到run中,由于readLoop和pingLoop都是异步进行了,我们看主线程的逻辑,首先调用了startProtocols

func (p *Peer) startProtocols(writeStart <-chan struct{}, writeErr chan<- error) {
	p.wg.Add(len(p.running))
	for _, proto := range p.running {
		proto := proto
		proto.closed = p.closed
		proto.wstart = writeStart
		proto.werr = writeErr
		var rw MsgReadWriter = proto
		if p.events != nil {
			rw = newMsgEventer(rw, p.events, p.ID(), proto.Name)
		}
		p.log.Trace(fmt.Sprintf("Starting protocol %s/%d", proto.Name, proto.Version))
		go func() {
			err := proto.Run(p, rw)
			if err == nil {
				p.log.Trace(fmt.Sprintf("Protocol %s/%d returned", proto.Name, proto.Version))
				err = errProtocolReturned
			} else if err != io.EOF {
				p.log.Trace(fmt.Sprintf("Protocol %s/%d failed", proto.Name, proto.Version), "err", err)
			}
			p.protoErr <- err
			p.wg.Done()
		}()
	}
}

running就是在newPeer中挑选的匹配上的协议,这里遍历这些协议进行启动操作。由于这里遍历的是Protocol的封装类protoRW,所以先对其几个channel赋值,然后启动一个匿名方法,执行协议的Run方法。

run:loop

在启动完所有协议后,开启了一个loop,也是一个无限循环,这个主要是处理各种错误的,如读写错误等,对于所有读写错误处理都是直接结束循环,后续触发peer的结束。

题图来自unsplash:https://unsplash.com/photos/gooBgyq17i0