go-ethereum中p2p-discover源码学习

区块链系统就是一个基于P2P网络的系统,这里先来学习一下以太坊的P2P网络实现。P2P网络运作的第一个要求就是节点之间能互相发现,这里以太坊用到了一个名为Kademlia的协议算法,这里来梳理一下ethereum P2P的discover实现

Kademlia协议

这个协议是2002年美国纽约大学Petar Maymounkov和David Mazières发表的一篇论文中所介绍的,该论文的核心部分翻译见这里,下面我们就来详细介绍一下这个协议

概述

Kademlia规定了网络的结构,也规定了通过节点查询进行信息交换的方式。Kademlia网络节点之间使用UDP进行通讯。参与通讯的所有节点形成一张虚拟网(或者叫做覆盖网)。这些节点通过一组数字(或称为节点ID)来进行身份标识。节点ID不仅可以用来做身份标识,还可以用来进行值定位(标识这哪些节点存储哪些资源)。

当进行节点搜索时,Kademlia算法需要知道与这些值相关的键,然后分步在网络中开始搜索,每一步都会找到一些节点,这些节点的ID与键更为接近,如果有节点直接返回搜索的值或者再也无法找到与键更为接近的节点ID的时候搜索便会停止。这种搜索值的方法是非常高效的:与其他的分布式散列表的实现类似,在一个包含n个节点的系统的值的搜索中,Kademlia仅访问O(log(n))个节点。

节点

Kademlia定义每个节点都以一个随机的ID,文中定义有160位,不必确保两个节点ID有什么联系,唯一需要做的足够随机。

距离度量

Kademlia定义两个节点的距离为两个节点ID异或的结果。如节点A为010101,节点B为110001,二者异或的100100,转换为十进制就是36,则二者之间的距离就是36。关于选择异或作为距离的度量,作者表示有一些几个特点:

  1. 一个节点到自己的距离为0,即d(x,x) = 0
  2. 从A到B与从B到A的距离相等,即d(x,y) = d(y,x)
  3. 满足三角不等式:d(x,z) = d(x,y) XOR d(y,z),而a+b>a XOR b

有了这三点,就足以证明异或算法也可以作为距离度量的标准,更重要的是异或计算非常高效。实际上用异或计算距离,更重要的是给出了一个节点分类的标准,类似于显示生活根据距离分类一样,便于通过一个ID搜索一个节点,也就是节点发现

K桶

文章提出了一个K桶的概念,实际上就是一个列表,如果ID有160位,那么一个节点就有160个所谓K桶,第i个桶保存着距离自己[2^i,2^i+1)范围内的节点,当寻找一个节点时,就从相应范围内的列表去搜索,k是一个列表的最大长度,如20。可知i越大,区间范围越大,里面的节点数可能越多。离自己越近的k桶内所记录的节点数虽然越少,但命中率越高。所以进行节点搜索是去寻找目标节点距离附近的节点,在进一步迭代,就能很快找到所需节点。

K桶使用一种类似于最近最久未使用的淘汰算法,当新节点被探知时,如果所在k桶未满,则直接在队尾插入,如果已满,则ping队头节点,如果节点不在线,则移除,把新节点插到队尾,如果队头节点在线,则从队头移到队尾,新节点被抛弃

RPC消息

Kademlia是利用一系列RPC消息来维护网络的

PING

主要作用是探测一个节点是否仍在线

STORE

通知一个节点存储一个键值对

FIND_NODE

节点定位。计算距离,寻找对应区间的k桶,选择一些节点发送消息,返回这些节点所知的距离目标节点更近的节点,然后对这些节点再次发送消息,多次迭代,最后定位到节点。

FIND_VALUE

定位资源。类似于FIND_NODE,返回的是节点的信息,如IP地址,udp端口即节点ID。

加入网络

要加入一个P2P首先必须要与一个网络内的节点建立通信,之后新节点进行自我定位,通过这种方法其他节点可以更新自己的k桶,新节点也可以获得网络信息。

其他

还有一些详细的协议规范请参考论文。

源码分析

go-ethereum的p2p实现源码主要集中在p2p目录下,该目录下的discover主要实现了节点方法算法,本文主要来梳理这一部分源码

table.go

在源码中table就是Kademlia的主要实现地方,先看常量和结构体

const (
	alpha           = 3  // Kademlia concurrency factor
	bucketSize      = 16 // Kademlia bucket size
	maxReplacements = 10 // Size of per-bucket replacement list

	// We keep buckets for the upper 1/15 of distances because
	// it's very unlikely we'll ever encounter a node that's closer.
	hashBits          = len(common.Hash{}) * 8
	nBuckets          = hashBits / 15       // Number of buckets
	bucketMinDistance = hashBits - nBuckets // Log distance of closest bucket

	// IP address limits.
	bucketIPLimit, bucketSubnet = 2, 24 // at most 2 addresses from the same /24
	tableIPLimit, tableSubnet   = 10, 24

	maxFindnodeFailures = 5 // Nodes exceeding this limit are dropped
	refreshInterval     = 30 * time.Minute
	revalidateInterval  = 10 * time.Second
	copyNodesInterval   = 30 * time.Second
	seedMinTableTime    = 5 * time.Minute
	seedCount           = 30
	seedMaxAge          = 5 * 24 * time.Hour
)

type Table struct {
	mutex   sync.Mutex        // protects buckets, bucket content, nursery, rand
	buckets [nBuckets]*bucket // index of known nodes by distance
	nursery []*node           // bootstrap nodes
	rand    *mrand.Rand       // source of randomness, periodically reseeded
	ips     netutil.DistinctNetSet

	db         *enode.DB // database of known nodes
	net        transport
	refreshReq chan chan struct{}
	initDone   chan struct{}

	closeOnce sync.Once
	closeReq  chan struct{}
	closed    chan struct{}

	nodeAddedHook func(*node) // for testing
}

首先常量中定义了一些Kademlia协议中的一些值,如k桶容量也就是k等于16,每次查找的节点为3,k桶置换表大小为10。有一点和协议不同的是,这里定义k桶的数量为hash长度的15分之一,没有像协议中定义有hash有多长就有多少个k桶。另外还有超时重试次数,刷新间隔等定义。

table中维护了一组k桶实例,一组bootstrap节点,还有数据库等辅助参数。我们看k桶的定义

type bucket struct {
	entries      []*node 
	replacements []*node 
	ips          netutil.DistinctNetSet
}

和协议中定义的类似,一组节点和一组置换节点。下面看初始化代码

func newTable(t transport, db *enode.DB, bootnodes []*enode.Node) (*Table, error) {
	tab := &Table{
		net:        t,
		db:         db,
		refreshReq: make(chan chan struct{}),
		initDone:   make(chan struct{}),
		closeReq:   make(chan struct{}),
		closed:     make(chan struct{}),
		rand:       mrand.New(mrand.NewSource(0)),
		ips:        netutil.DistinctNetSet{Subnet: tableSubnet, Limit: tableIPLimit},
	}
	if err := tab.setFallbackNodes(bootnodes); err != nil {
		return nil, err
	}
	for i := range tab.buckets {
		tab.buckets[i] = &bucket{
			ips: netutil.DistinctNetSet{Subnet: bucketSubnet, Limit: bucketIPLimit},
		}
	}
	tab.seedRand()
	tab.loadSeedNodes()

	go tab.loop()
	return tab, nil
}

func (tab *Table) setFallbackNodes(nodes []*enode.Node) error {
	for _, n := range nodes {
		if err := n.ValidateComplete(); err != nil {
			return fmt.Errorf("bad bootstrap node %q: %v", n, err)
		}
	}
	tab.nursery = wrapNodes(nodes)
	return nil
}

首先初始化一个table实例,然后调用setFallbackNodes初始化连接节点。在setFallbackNodes中先检查所有节点是否有效,接下来的wrapNodes方法主要是将enode.Node类型对象包装为discover包内的node对象。

在setFallbackNodes方法中table的bootstrap被设置完成之后,接下来的一个遍历是用来初始化所有k桶。

再往下,tab.seedRand()方法先用crypto/rand中的Read随机生成一个长度为8的byte数组,并用该数组生成一个64位int类型数去作为rand的种子。接下来loadSeedNodes实现如下

func (tab *Table) loadSeedNodes() {
	seeds := wrapNodes(tab.db.QuerySeeds(seedCount, seedMaxAge))
	seeds = append(seeds, tab.nursery...)
	for i := range seeds {
		seed := seeds[i]
		age := log.Lazy{Fn: func() interface{} { return time.Since(tab.db.LastPongReceived(seed.ID(), seed.IP())) }}
		log.Trace("Found seed node in database", "id", seed.ID(), "addr", seed.addr(), "age", age)
		tab.addSeenNode(seed)
	}
}

还是先用wrapNodes进行了一次包装,这次包装的对象来自于数据库,tab.db保存的是已知的节点,我们要查找的种子节点数量为30,种子节点最长寿命为5天(根据前面常量定义),QuerySeeds的逻辑也很简单,就是从数据库中随机查找,但是要过滤掉那些太老的节点,而且最后返回的节点数不超过30。接下来将数据库中返回的节点和前面的种子节点进行合并,开始遍历,遍历的目的是利用addSeenNode将节点添加到合适的桶中,既然要添加的合适的桶中,就要计算距离,我们来看看距离的计算

// go-ethereum\p2p\enode\node.go
func LogDist(a, b ID) int {
	lz := 0 //记录前导0数量
	for i := range a {
		x := a[i] ^ b[i]
		if x == 0 { //结果为8表示全为零,就是8个前导零
			lz += 8
		} else {
			lz += bits.LeadingZeros(x)
			break
		}
	}
	return len(a)*8 - lz
}

其实严格来说,这里已经不叫距离计算,由于前面k桶数量并不是根据原始Kademlia协议中写的那样定义而是ID的比特数除以15,所以对应的距离计算也要改变,通过阅读代码,我们发现距离定义变为比特数减去两个ID对应字节异或后前导零的数量之和的差。根据距离选桶的实现如下

func (tab *Table) bucket(id enode.ID) *bucket {
	d := enode.LogDist(tab.self().ID(), id)
	if d <= bucketMinDistance {
		return tab.buckets[0]
	}
	return tab.buckets[d-bucketMinDistance-1]
}

其中最短距离bucketMinDistance = hashBits - nBuckets。为的是确保能找到一个合适的桶,避免差值过大。在回到addSeenNode,找的合适的桶时,先判断是否满了或已包含,来决定否添加,或是否添加到缓存表中。加载完种子节点之后,启动了一个goroutine运行loop,实现如下。

func (tab *Table) loop() {
	var (
		revalidate     = time.NewTimer(tab.nextRevalidateTime())
		refresh        = time.NewTicker(refreshInterval)
		copyNodes      = time.NewTicker(copyNodesInterval)
		refreshDone    = make(chan struct{})           // where doRefresh reports completion
		revalidateDone chan struct{}                   // where doRevalidate reports completion
		waiting        = []chan struct{}{tab.initDone} // holds waiting callers while doRefresh runs
	)
	defer refresh.Stop()
	defer revalidate.Stop()
	defer copyNodes.Stop()

	// Start initial refresh.
	go tab.doRefresh(refreshDone)

loop:
	.....
}

首先初始化了几个定时器和几个channel。之后启动了一个goroutine执行初始化刷新。

func (tab *Table) doRefresh(done chan struct{}) {
	defer close(done)

	tab.loadSeedNodes()

	var key ecdsa.PublicKey
	if err := tab.self().Load((*enode.Secp256k1)(&key)); err == nil {
		tab.lookup(encodePubkey(&key), false)
	}
	
	for i := 0; i < 3; i++ {
		var target encPubkey
		crand.Read(target[:])
		tab.lookup(target, false)
	}
}

刷新操作中先加载了种子节点。之后执行了自我查找,就是查找自己,用的就是lookup,我们来看一下实现:

func (tab *Table) lookup(targetKey encPubkey, refreshIfEmpty bool) []*node {
	var (
		target         = enode.ID(crypto.Keccak256Hash(targetKey[:]))
		asked          = make(map[enode.ID]bool)
		seen           = make(map[enode.ID]bool)
		reply          = make(chan []*node, alpha)
		pendingQueries = 0
		result         *nodesByDistance
	)
	asked[tab.self().ID()] = true

	for {
		tab.mutex.Lock()
		result = tab.closest(target, bucketSize)
		tab.mutex.Unlock()
		if len(result.entries) > 0 || !refreshIfEmpty {
			break
		}
		<-tab.refresh()
		refreshIfEmpty = false
	}

	for {
		for i := 0; i < len(result.entries) && pendingQueries < alpha; i++ {
			n := result.entries[i]
			if !asked[n.ID()] {
				asked[n.ID()] = true
				pendingQueries++
				go tab.findnode(n, targetKey, reply)
			}
		}
		if pendingQueries == 0 {
			break
		}
		select {
		case nodes := <-reply:
			for _, n := range nodes {
				if n != nil && !seen[n.ID()] {
					seen[n.ID()] = true
					result.push(n, bucketSize)
				}
			}
		case <-tab.closeReq:
			return nil 
		}
		pendingQueries--
	}
	return result.entries
}

首先定义了一系列变量,然后从k桶中取得离目标最近的几个节点,实现如下

func (tab *Table) closest(target enode.ID, nresults int) *nodesByDistance {
	close := &nodesByDistance{target: target}
	for _, b := range &tab.buckets {
		for _, n := range b.entries {
			if n.livenessChecks > 0 {
				close.push(n, nresults)
			}
		}
	}
	return close
}

type nodesByDistance struct {
	entries []*node
	target  enode.ID
}

nodesByDistance类型存储着目标ID和一组距离目标较近的节点。实际上closest逻辑很简单,就是遍历所有桶内的节点,看是比已有的更近,判断是否添加的逻辑在push中:

func (h *nodesByDistance) push(n *node, maxElems int) {
	ix := sort.Search(len(h.entries), func(i int) bool {
		return enode.DistCmp(h.target, h.entries[i].ID(), n.ID()) > 0
	})
	if len(h.entries) < maxElems {
		h.entries = append(h.entries, n)
	}
	if ix == len(h.entries) {
	} else {
		copy(h.entries[ix+1:], h.entries[ix:])
		h.entries[ix] = n
	}
}

DistCmp是给a,b判断二者谁离目标更近, sort.Search则是得出给定的ID在entries的位置,如果最后ix排到末尾而且entries已满,说明这个点不比已有的点目标更近。继续回到lookup,找到一组离目标较近的节点后,在第二个循环体内,开始遍历这一组节点,对每个节点都执行findnode操作,但是最多每次并发执行alpha个,也就是3个。由于entries是有序的,所以前面的都是例目标最近的。下面看findnode操作:

func (tab *Table) findnode(n *node, targetKey encPubkey, reply chan<- []*node) {
	fails := tab.db.FindFails(n.ID(), n.IP())
	r, err := tab.net.findnode(n.ID(), n.addr(), targetKey)
	if err == errClosed {
		reply <- nil
		return
	} else if err != nil || len(r) == 0 {
		fails++
		tab.db.UpdateFindFails(n.ID(), n.IP(), fails)
		log.Trace("Findnode failed", "id", n.ID(), "failcount", fails, "err", err)
		if fails >= maxFindnodeFailures {
			log.Trace("Too many findnode failures, dropping", "id", n.ID(), "failcount", fails)
			tab.delete(n)
		}
	} else if fails > 0 {
		tab.db.UpdateFindFails(n.ID(), n.IP(), fails-1)
	}

	for _, n := range r {
		tab.addSeenNode(n)
	}
	reply <- r
}

实际上findnode的逻辑并不在这里,这里先不管后面再说,先知道findnode会返回一组离目标更近的node,否则返回一个错误。对于某个node当错误次数大于5次时就把它删除。而对于之前出错但这次成功的话,就将其累积出错次数减一。最后遍历返回的一组node,调用addSeenNode,这个方法在之前加载种子节点时出现,就是将节点添加到合适桶中。方法最后向reply中写入r。这样就回到lookup中,有一个select,当reply可以取值时,执行下面逻辑

case nodes := <-reply:
			for _, n := range nodes {
				if n != nil && !seen[n.ID()] {
					seen[n.ID()] = true
					result.push(n, bucketSize)
				}
			}

这里继续将新得到的node添加到有序的entries中,继续启动新一轮循环按照前面的逻辑继续查找。至于什么时候停止呢?在遍历entries时,对于每个node会判断是否访问过,对于访问过就不在进行findnode,当所有entries都被访问就意味着已经找不到离目标更近的节点了,这是for循环结束,pendingQueries为0,根据代码执行break退出外层循环,返回entries。这样lookup就执行完毕。代码返回到doRefresh中,在执行完初始的自我查找之后,开始了一个小循环,进行随机查找,只查找3个,每次产生一个随机的ID调用lookup进行查找,目的都是尽可能的完善桶。doRefresh结束后后回到loop中,由于doRefresh是一个单独的goroutine,loop主要循环在loop标签的代码中:

loop:
	for {
		select {
		case <-refresh.C:
			tab.seedRand()
			if refreshDone == nil {
				refreshDone = make(chan struct{})
				go tab.doRefresh(refreshDone)
			}
		case req := <-tab.refreshReq:
			waiting = append(waiting, req)
			if refreshDone == nil {
				refreshDone = make(chan struct{})
				go tab.doRefresh(refreshDone)
			}
		case <-refreshDone:
			for _, ch := range waiting {
				close(ch)
			}
			waiting, refreshDone = nil, nil
		case <-revalidate.C:
			revalidateDone = make(chan struct{})
			go tab.doRevalidate(revalidateDone)
		case <-revalidateDone:
			revalidate.Reset(tab.nextRevalidateTime())
			revalidateDone = nil
		case <-copyNodes.C:
			go tab.copyLiveNodes()
		case <-tab.closeReq:
			break loop
		}
	}

可见还是一个同步代码。外层是一个无限循环,内存是一个select阻塞,根据不同的操作触发不同代码。首先是一个refresh定时器,每30分钟触发一次,调用doRefresh执行刷新。除了定时刷新还有主动刷新,使用的是refreshReq这个channel。另外还有一个revalidate定时器,它的时间不固定,随机从0到10秒内选一个时间,到时间后触发doRevalidate,逻辑如下:

func (tab *Table) doRevalidate(done chan<- struct{}) {
	defer func() { done <- struct{}{} }()

	last, bi := tab.nodeToRevalidate()
	if last == nil {
		return
	}

	err := tab.net.ping(last.ID(), last.addr())

	tab.mutex.Lock()
	defer tab.mutex.Unlock()
	b := tab.buckets[bi]
	if err == nil {
		last.livenessChecks++
		log.Debug("Revalidated node", "b", bi, "id", last.ID(), "checks", last.livenessChecks)
		tab.bumpInBucket(b, last)
		return
	}
	if r := tab.replace(b, last); r != nil {
		log.Debug("Replaced dead node", "b", bi, "id", last.ID(), "ip", last.IP(), "checks", last.livenessChecks, "r", r.ID(), "rip", r.IP())
	} else {
		log.Debug("Removed dead node", "b", bi, "id", last.ID(), "ip", last.IP(), "checks", last.livenessChecks)
	}
}

这段代码主要逻辑是随机找一个非空的桶,取得其最后一个节点,然后去ping这个节点,看是否在线,如果在线将其移到桶的前部,否则从缓存节点找一个去替代这个节点。当操作结束后,重新设置revalidate时间,准备下一次随机验证。最后还有一个定时器copyNodesInterval,每30分钟触发一次,执行copyLiveNodes逻辑

func (tab *Table) copyLiveNodes() {
	tab.mutex.Lock()
	defer tab.mutex.Unlock()

	now := time.Now()
	for _, b := range &tab.buckets {
		for _, n := range b.entries {
			if n.livenessChecks > 0 && now.Sub(n.addedAt) >= seedMinTableTime {
				tab.db.UpdateNode(unwrapNode(n))
			}
		}
	}
}

主要逻辑是,遍历所有节点,对存活检查过的,且添加时间大于5分钟的节点存入数据库,存储的内容是经过rlp编码的数据。再回到loop中,只剩下一个关闭请求的channel,用于退出loop。退出loop的逻辑也很简单,如下

if refreshDone != nil {
	<-refreshDone
}
for _, ch := range waiting {
	close(ch)
}
if revalidateDone != nil {
	<-revalidateDone
}
close(tab.closed)

先等待刷新结束,然后关闭那些等待中的操作,再等待验证结束,最后关闭closed用于正常退出close方法。closeReq是在close方法中发出的:

func (tab *Table) Close() {
	tab.closeOnce.Do(func() {
		if tab.net != nil {
			tab.net.close()
		}
		close(tab.closeReq)
		<-tab.closed
	})
}

到这里p2p节点的发现以及维护逻辑源码梳理完毕。总体还是很清晰的,基本都是按照Kademlia协议要求的进行的,主要流程是先加载种子节点,这些种子节点来源于用户指定或者之前数据库的值,然后利用自我查找进一步的填充k桶,最后通过各种定时器维护k桶。

udp.go

前面table.go实现了Kademlia协议,udp则是进行网络通信的,先看常量的定义

const (
	pingPacket = iota + 1 // zero is 'reserved'
	pongPacket
	findnodePacket
	neighborsPacket
)

定义了4种数据包,pingPacket就是询问节点是否在线,pongPacket是对ping的回应。findnodePacket是请求节点,neighborsPacket是对findnode的响应。详细的结构体如下

type (
	ping struct {
		senderKey *ecdsa.PublicKey // filled in by preverify

		Version    uint
		From, To   rpcEndpoint
		Expiration uint64
		// Ignore additional fields (for forward compatibility).
		Rest []rlp.RawValue `rlp:"tail"`
	}

	// pong is the reply to ping.
	pong struct {
		// This field should mirror the UDP envelope address
		// of the ping packet, which provides a way to discover the
		// the external address (after NAT).
		To rpcEndpoint

		ReplyTok   []byte // This contains the hash of the ping packet.
		Expiration uint64 // Absolute timestamp at which the packet becomes invalid.
		// Ignore additional fields (for forward compatibility).
		Rest []rlp.RawValue `rlp:"tail"`
	}

	// findnode is a query for nodes close to the given target.
	findnode struct {
		Target     encPubkey
		Expiration uint64
		// Ignore additional fields (for forward compatibility).
		Rest []rlp.RawValue `rlp:"tail"`
	}

	// reply to findnode
	neighbors struct {
		Nodes      []rpcNode
		Expiration uint64
		// Ignore additional fields (for forward compatibility).
		Rest []rlp.RawValue `rlp:"tail"`
	}

	rpcNode struct {
		IP  net.IP // len 4 for IPv4 or 16 for IPv6
		UDP uint16 // for discovery protocol
		TCP uint16 // for RLPx protocol
		ID  encPubkey
	}

	rpcEndpoint struct {
		IP  net.IP // len 4 for IPv4 or 16 for IPv6
		UDP uint16 // for discovery protocol
		TCP uint16 // for RLPx protocol
	}
)

还有两个接口,分别定义数据包和upd连接。

type packet interface {
	// preverify checks whether the packet is valid and should be handled at all.
	preverify(t *udp, from *net.UDPAddr, fromID enode.ID, fromKey encPubkey) error
	// handle handles the packet.
	handle(t *udp, from *net.UDPAddr, fromID enode.ID, mac []byte)
	// name returns the name of the packet for logging purposes.
	name() string
}

type conn interface {
	ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error)
	WriteToUDP(b []byte, addr *net.UDPAddr) (n int, err error)
	Close() error
	LocalAddr() net.Addr
}

udp的结构如下

type udp struct {
	conn        conn
	netrestrict *netutil.Netlist
	priv        *ecdsa.PrivateKey
	localNode   *enode.LocalNode
	db          *enode.DB
	tab         *Table
	wg          sync.WaitGroup

	addReplyMatcher chan *replyMatcher
	gotreply        chan reply
	closing         chan struct{}
}

接下来看udp的创建

func ListenUDP(c conn, ln *enode.LocalNode, cfg Config) (*Table, error) {
	tab, _, err := newUDP(c, ln, cfg)
	if err != nil {
		return nil, err
	}
	return tab, nil
}

func newUDP(c conn, ln *enode.LocalNode, cfg Config) (*Table, *udp, error) {
	udp := &udp{
		conn:            c,
		priv:            cfg.PrivateKey,
		netrestrict:     cfg.NetRestrict,
		localNode:       ln,
		db:              ln.Database(),
		closing:         make(chan struct{}),
		gotreply:        make(chan reply),
		addReplyMatcher: make(chan *replyMatcher),
	}
	tab, err := newTable(udp, ln.Database(), cfg.Bootnodes)
	if err != nil {
		return nil, nil, err
	}
	udp.tab = tab

	udp.wg.Add(2)
	go udp.loop()
	go udp.readLoop(cfg.Unhandled)
	return udp.tab, udp, nil
}

外部调用的是ListenUDP,传入一个连接,然后调用newUDP创建,由于UDP是为table服务的,所以也创建了table。接下来启动了两个goroutine,后面再讲,我们下面主要看看那几个PRC是如何实现的

ping

在table中的doRevalidate也就是检查节点活性的时候调用udp的ping方法

func (t *udp) ping(toid enode.ID, toaddr *net.UDPAddr) error {
	return <-t.sendPing(toid, toaddr, nil)
}

func (t *udp) sendPing(toid enode.ID, toaddr *net.UDPAddr, callback func()) <-chan error {
	req := &ping{
		Version:    4,
		From:       t.ourEndpoint(),
		To:         makeEndpoint(toaddr, 0), 
		Expiration: uint64(time.Now().Add(expiration).Unix()),
	}
	packet, hash, err := encodePacket(t.priv, pingPacket, req)
	if err != nil {
		errc := make(chan error, 1)
		errc <- err
		return errc
	}
	errc := t.pending(toid, toaddr.IP, pongPacket, func(p interface{}) (matched bool, requestDone bool) {
		matched = bytes.Equal(p.(*pong).ReplyTok, hash)
		if matched && callback != nil {
			callback()
		}
		return matched, matched
	})
	t.localNode.UDPContact(toaddr)
	t.write(toaddr, toid, req.name(), packet)
	return errc
}

主要执行的地方在sendPing,首先构造一个ping对象,这里面指明了发送方和接收方,超时时间为20s。接收方和发送方都是以rpcEndpoint表示的,内含IP地址,udp和tcp端口。紧接着调用encodePacket打包数据。

func encodePacket(priv *ecdsa.PrivateKey, ptype byte, req interface{}) (packet, hash []byte, err error) {
	b := new(bytes.Buffer)
	b.Write(headSpace)
	b.WriteByte(ptype)
	if err := rlp.Encode(b, req); err != nil {
		log.Error("Can't encode discv4 packet", "err", err)
		return nil, nil, err
	}
	packet = b.Bytes()
	sig, err := crypto.Sign(crypto.Keccak256(packet[headSize:]), priv)
	if err != nil {
		log.Error("Can't sign discv4 packet", "err", err)
		return nil, nil, err
	}
	copy(packet[macSize:], sig)
	hash = crypto.Keccak256(packet[macSize:])
	copy(packet, hash)
	return packet, hash, nil
}

数据包中首先有一个长度为97的空字节数组,然后写入包的类型即pingPacket,最后再将刚才的ping对象编码为rlp格式写入。然后对数据进行签名,签名前先对除开头空字节部分以外的数据进行摘要。然后将签名结果写入开头预留的地方,然后在对签名数据和数据进行摘要,填充到开头,这样就完成了对数据打包,这也是常见的消息认证的步骤。回到sendPing中,有了要发送的数据,接下来调用pending

func (t *udp) pending(id enode.ID, ip net.IP, ptype byte, callback replyMatchFunc) <-chan error {
	ch := make(chan error, 1)
	p := &replyMatcher{from: id, ip: ip, ptype: ptype, callback: callback, errc: ch}
	select {
	case t.addReplyMatcher <- p:
		// loop will handle it
	case <-t.closing:
		ch <- errClosed
	}
	return ch
}

type replyMatchFunc func(interface{}) (matched bool, requestDone bool)

pending主要是添加一个响应的匹配者用于匹配响应,匹配主要是通过对hash进行对比,这里的hash是前面数据包中的签名数据和数据摘要后的结果。最后发送数据。其实到这里还没有结束,sendPing返回一个channel对象err,而在ping中一直阻塞,知道可以从err中取值。err是由pending返回的,在pending中被replyMatcher持有并赋值给addReplyMatcher,这是触发loop方法(在newUDP中启动的一个goroutine)中的select结构中的一个条件

case p := <-t.addReplyMatcher:
			p.deadline = time.Now().Add(respTimeout)
			plist.PushBack(p)

这里将pending中构造的replyMatcher存到plist中。

findnode

在table中的lookup中使用了findnode功能,为的是进行节点查找,而实际执行的地方在udp的findnode中

func (t *udp) findnode(toid enode.ID, toaddr *net.UDPAddr, target encPubkey) ([]*node, error) {
	if time.Since(t.db.LastPingReceived(toid, toaddr.IP)) > bondExpiration {
		t.ping(toid, toaddr)
		time.Sleep(respTimeout)
	}

	nodes := make([]*node, 0, bucketSize)
	nreceived := 0
	errc := t.pending(toid, toaddr.IP, neighborsPacket, func(r interface{}) (matched bool, requestDone bool) {
		reply := r.(*neighbors)
		for _, rn := range reply.Nodes {
			nreceived++
			n, err := t.nodeFromRPC(toaddr, rn)
			if err != nil {
				log.Trace("Invalid neighbor node received", "ip", rn.IP, "addr", toaddr, "err", err)
				continue
			}
			nodes = append(nodes, n)
		}
		return true, nreceived >= bucketSize
	})
	t.send(toaddr, toid, findnodePacket, &findnode{
		Target:     target,
		Expiration: uint64(time.Now().Add(expiration).Unix()),
	})
	return nodes, <-errc
}

首先对于那些超过24小时没有ping过的节点先ping一次看是否存活。之后也是利用pending添加了一个响应的匹配。最后调用send发送数据,和ping一样,也是先进行打包,在发送数据。

响应

对于响应,这里要看在新建udp时启动的第二个goroutine–readLoop。在readLoop中先分析了udp包的长度和发送者,然后处理这个包

func (t *udp) handlePacket(from *net.UDPAddr, buf []byte) error {
	packet, fromKey, hash, err := decodePacket(buf)
	if err != nil {
		log.Debug("Bad discv4 packet", "addr", from, "err", err)
		return err
	}
	fromID := fromKey.id()
	if err == nil {
		err = packet.preverify(t, from, fromID, fromKey)
	}
	log.Trace("<< "+packet.name(), "id", fromID, "addr", from, "err", err)
	if err == nil {
		packet.handle(t, from, fromID, hash)
	}
	return err
}

第一步先解包

func decodePacket(buf []byte) (packet, encPubkey, []byte, error) {
	if len(buf) < headSize+1 {
		return nil, encPubkey{}, nil, errPacketTooSmall
	}
	hash, sig, sigdata := buf[:macSize], buf[macSize:headSize], buf[headSize:]
	shouldhash := crypto.Keccak256(buf[macSize:])
	if !bytes.Equal(hash, shouldhash) {
		return nil, encPubkey{}, nil, errBadHash
	}
	fromKey, err := recoverNodeKey(crypto.Keccak256(buf[headSize:]), sig)
	if err != nil {
		return nil, fromKey, hash, err
	}

	var req packet
	switch ptype := sigdata[0]; ptype {
	case pingPacket:
		req = new(ping)
	case pongPacket:
		req = new(pong)
	case findnodePacket:
		req = new(findnode)
	case neighborsPacket:
		req = new(neighbors)
	default:
		return nil, fromKey, hash, fmt.Errorf("unknown type: %d", ptype)
	}
	s := rlp.NewStream(bytes.NewReader(sigdata[1:]), 0)
	err = s.Decode(req)
	return req, fromKey, hash, err
}

解包就是打包的逆过程。我们先回顾一下打包的内容:Hash(签名加数据),签名(对数据摘要后签名),数据(类型,请求体)。再看解包过程,先判断长度是否争取,然后分理出hash,签名和数据。然后验证hash是否正确。再根据签名以及被签名的内容计算出公钥。之后根据类型(数据部分的第一个字节)构造出请求对象。然后解码原始数据。通过解包我们得到了请求数据,发送节点公钥,以及hash值。之后回到handlePacket中,先计算发送者ID(就是对公钥的摘要),接着根据不同的请求执行不同的逻辑。

ping的接受与响应

以ping为例,执行ping的preverify和handle方法

func (req *ping) preverify(t *udp, from *net.UDPAddr, fromID enode.ID, fromKey encPubkey) error {
	if expired(req.Expiration) {
		return errExpired
	}
	key, err := decodePubkey(fromKey)
	if err != nil {
		return errors.New("invalid public key")
	}
	req.senderKey = key
	return nil
}

preverify主要是判断是否超时然后解码出公钥。主要处理逻辑在handle中

func (req *ping) handle(t *udp, from *net.UDPAddr, fromID enode.ID, mac []byte) {
	t.send(from, fromID, pongPacket, &pong{
		To:         makeEndpoint(from, req.From.TCP),
		ReplyTok:   mac,
		Expiration: uint64(time.Now().Add(expiration).Unix()),
	})

	n := wrapNode(enode.NewV4(req.senderKey, from.IP, int(req.From.TCP), from.Port))
	if time.Since(t.db.LastPongReceived(n.ID(), from.IP)) > bondExpiration {
		t.sendPing(fromID, from, func() {
			t.tab.addVerifiedNode(n)
		})
	} else {
		t.tab.addVerifiedNode(n)
	}

	t.db.UpdateLastPingReceived(n.ID(), from.IP, time.Now())
	t.localNode.UDPEndpointStatement(from, &net.UDPAddr{IP: req.To.IP, Port: int(req.To.UDP)})
}

直接是调用send发送数据,由于是相应ping,所以这里类型为pongPacket,相应体也是一个pong对象,同用叶经理了打包和发送的过程。发送完之后还有一些事情要处理。首先判断这个节点上一次pong相应是否超过24小时,若是超过则进行ping,否则更新节点,随后也要更新数据库。

pong的接收

到这一步我们梳理了节点A发送ping到节点B,B返回一个pong相应到节点A,我们再看节点A收到pong相应的动作。直接看pong的preverify和handle。

func (req *pong) preverify(t *udp, from *net.UDPAddr, fromID enode.ID, fromKey encPubkey) error {
	if expired(req.Expiration) {
		return errExpired
	}
	if !t.handleReply(fromID, from.IP, pongPacket, req) {
		return errUnsolicitedReply
	}
	return nil
}

func (t *udp) handleReply(from enode.ID, fromIP net.IP, ptype byte, req packet) bool {
	matched := make(chan bool, 1)
	select {
	case t.gotreply <- reply{from, fromIP, ptype, req, matched}:
		// loop will handle it
		return <-matched
	case <-t.closing:
		return false
	}
}

preverify第一步也是验证是否超时,接着调用handleReply分发相应。首先是t.gotreply获得赋值,触发loop中的select逻辑,

case r := <-t.gotreply:
	var matched bool // whether any replyMatcher considered the reply acceptable.
	for el := plist.Front(); el != nil; el = el.Next() {
		p := el.Value.(*replyMatcher)
		if p.from == r.from && p.ptype == r.ptype && p.ip.Equal(r.ip) {
			ok, requestDone := p.callback(r.data)
			matched = matched || ok
			if requestDone {
				p.errc <- nil
				plist.Remove(el)
			}
			contTimeouts = 0
		}
	}
	r.matched <- matched

这里首先遍历plist,plist存储着我们发送请求时的响应匹配者,也就是回调。我们通过发送者,请求类型和ip三个条件判断是否匹配。如果匹配上的话执行callback方法,对于ping的callback并没有实质的内容,只是通过比较两个hash值是否相等来判断是否匹配。两个hash值分别是发送ping请求时对签名和数据的摘要以及pong响应的ReplyTok,而ReplyTok来自于对请求包解包后得到hash,实际上也就是请求所携带的hash,只不过在响应是写到了pong的ReplyTok字段中。只有二者相等,才能证明是从正确节点得到了响应。接下来如果请求结束,则给p.errc赋值为nil并移除这个回调,对于p.err这个channel,他在发送请求后一直阻塞,知道这里阻塞得到解除,ping流程结束。回到loop的select中,接着给matched赋值,这时handleReply的阻塞得到解决,preverify方法调用结束,接着调用handle方法,这里主要就是更新节点信息。刚才是成功响应时的流程,对于超时的情况,在loop的select有一个timeout触发的case,他会检测plist中各个回调是否超时,对于超时的给p.errc赋值超时错误,然后移除该回调。同时统计超时次数,对于超时次数过多的话检查时间是否准确。这个我们稍后再讲。

findnode的接收与响应

我们再梳理一下findnode的接收与响应。前面的逻辑都是一样的,直接从handlePacket中开始,实际上是看preverify和handle。

func (req *findnode) preverify(t *udp, from *net.UDPAddr, fromID enode.ID, fromKey encPubkey) error {
	if expired(req.Expiration) {
		return errExpired
	}
	if time.Since(t.db.LastPongReceived(fromID, from.IP)) > bondExpiration {
		return errUnknownNode
	}
	return nil
}

首先判断是否超时,再看这个节点是否过长时间没有响应。主要看handle

func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID enode.ID, mac []byte) {
	target := enode.ID(crypto.Keccak256Hash(req.Target[:]))
	t.tab.mutex.Lock()
	closest := t.tab.closest(target, bucketSize).entries
	t.tab.mutex.Unlock()

	p := neighbors{Expiration: uint64(time.Now().Add(expiration).Unix())}
	var sent bool
	for _, n := range closest {
		if netutil.CheckRelayIP(from.IP, n.IP()) == nil {
			p.Nodes = append(p.Nodes, nodeToRPC(n))
		}
		if len(p.Nodes) == maxNeighbors {
			t.send(from, fromID, neighborsPacket, &p)
			p.Nodes = p.Nodes[:0]
			sent = true
		}
	}
	if len(p.Nodes) > 0 || !sent {
		t.send(from, fromID, neighborsPacket, &p)
	}
}

首先得到目标的ID,然后调用table的closest得出距目标较近的一组节点,然后构造响应体,主要有超时时间和要返回的节点,最后通过send方法发送响应。到这里我们梳理了节点A发送findnode请求到B,B找到一组节点发回A,再看A接到neighborsPacket响应时的动作,主要还是看preverify和handle。

func (req *neighbors) preverify(t *udp, from *net.UDPAddr, fromID enode.ID, fromKey encPubkey) error {
	if expired(req.Expiration) {
		return errExpired
	}
	if !t.handleReply(fromID, from.IP, neighborsPacket, req) {
		return errUnsolicitedReply
	}
	return nil
}

先检查超时,在调用handleReply,还是回到loop中执行之前findnode时预留的回调。回调如下:

func(r interface{}) (matched bool, requestDone bool) {
		reply := r.(*neighbors)
		for _, rn := range reply.Nodes {
			nreceived++
			n, err := t.nodeFromRPC(toaddr, rn)
			if err != nil {
				log.Trace("Invalid neighbor node received", "ip", rn.IP, "addr", toaddr, "err", err)
				continue
			}
			nodes = append(nodes, n)
		}
		return true, nreceived >= bucketSize
	}

在这里将接受到的节点存储到nodes中,在findnode最后errc阻塞得到接触,findnode顺利返回nodes供table使用。

ndoe.go

这是discover包内的节点类,它包装了enode,同时又附加了节点添加时间和活性检查时间两个成员

type node struct {
	enode.Node
	addedAt        time.Time 
	livenessChecks uint      
}

这个类提供了一组对公钥操作的方法:encodePubkey是将公钥记为64字节类型,decodePubkey则是从字节数组还原公钥,recoverNodeKey从签名数据和原始数据中还原出公钥。还规定了节点的ID就是对公钥的字节数组形式摘要的结果。最后提供了一组包装与解包装enode的方法.

ntp.go

这是一个时间校准的工具类,在udp中如果多次超时的话就要考虑是否是本地时间出错。调用了checkClockDrift方法

func checkClockDrift() {
	drift, err := sntpDrift(ntpChecks)
	if err != nil {
		return
	}
	if drift < -driftThreshold || drift > driftThreshold {
		log.Warn(fmt.Sprintf("System clock seems off by %v, which can prevent network connectivity", drift))
		log.Warn("Please enable network time synchronisation in system settings.")
	} else {
		log.Debug("NTP sanity check done", "drift", drift)
	}
}

func sntpDrift(measurements int) (time.Duration, error) {
	addr, err := net.ResolveUDPAddr("udp", ntpPool+":123")
	if err != nil {
		return 0, err
	}
	request := make([]byte, 48)
	request[0] = 3<<3 | 3

	drifts := []time.Duration{}
	for i := 0; i < measurements+2; i++ {
		conn, err := net.DialUDP("udp", nil, addr)
		if err != nil {
			return 0, err
		}
		defer conn.Close()

		sent := time.Now()
		if _, err = conn.Write(request); err != nil {
			return 0, err
		}
		conn.SetDeadline(time.Now().Add(5 * time.Second))

		reply := make([]byte, 48)
		if _, err = conn.Read(reply); err != nil {
			return 0, err
		}
		elapsed := time.Since(sent)

		sec := uint64(reply[43]) | uint64(reply[42])<<8 | uint64(reply[41])<<16 | uint64(reply[40])<<24
		frac := uint64(reply[47]) | uint64(reply[46])<<8 | uint64(reply[45])<<16 | uint64(reply[44])<<24

		nanosec := sec*1e9 + (frac*1e9)>>32

		t := time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC).Add(time.Duration(nanosec)).Local()

		drifts = append(drifts, sent.Sub(t)+elapsed/2)
	}
	sort.Sort(durationSlice(drifts))

	drift := time.Duration(0)
	for i := 1; i < len(drifts)-1; i++ {
		drift += drifts[i]
	}
	return drift / time.Duration(measurements), nil
}

首先可以看出来是以udp方法通信的,目标地址是pool.ntp.org:123。首先构造了请求体,就是一个长度为48的字节数组,第一个字节是00011011(从左到右第3-5位表示协议版本号,也就是3,第6-8位表示操作模式,客户端为3)。接下来启动一个循环,一共5次。每次都通过udp与目标地址通信。读取到的响应长度也是48字节。分两部分,最后4字节表示秒的小数部分,再向前数4个字节表示秒。之后我们计算出服务给我们的时间的纳秒表示,之后计算出我们与服务器的时间差。详细的SNTP协议见官方文档

题图来自unsplash:https://unsplash.com/photos/-CZERTBlepA