作者: 章天杰

Go 并发:goroutine 调度与竞态处理

goroutine 的调度

操作系统线程、逻辑处理器和 goroutine

goroutine 是 Go 赖以支持并发的重要特性,Go 的运行时在一组逻辑处理器上调度 goroutine,每一个逻辑处理器都对应一个操作系统线程。

我们通过几种常见的情况来看看在 Go 的调度下 goroutine、逻辑处理器和操作系统线程的行为。

最普通的情况,当我们创建一个 goroutine 并准备运行它,这个 goroutine 会进入调度器的全局执行队列中,调度器会将队列中的 goroutine 分配给逻辑处理器来运行,goroutine 此时进入到逻辑处理器的本地执行队列中等待处理。下图说明了这种情况。

更复杂一些,有时 goroutine 会被系统调用阻塞,这时,逻辑处理器会与该 goroutine 及执行该 goroutine 的线程分离,如下图所示。此时,这个逻辑处理器会新建一个线程,并从本地执行队列中选取另一个 goroutine 来执行。当先前阻塞的 goroutine 准备好时,它会回到本地执行队列,多余的线程会保存好备用。

另一种常见的情况是 goroutine 中需要进行网络调用,此时 goroutine 与逻辑处理器也会分离,goroutine 会被到带网络轮询器的运行时上,直到网络调用完成之后,重新分配给逻辑处理器。

goroutine 的暂停与重新调度

一个 goroutine 在运行完毕之前,调度器可以暂停它的运行并随后重新调度它。我们通过下面这段程序来观察一下这种切换行为。

我们通过 runtime.GOMAXPROCS(1) 来限定只使用一个逻辑调度器,并在这个逻辑调度器上创建两个 goroutine,每一个 goroutine 中会执行一些比较耗时间的操作。

在 Go Playground 运行这段代码:https://go.dev/play/p/T3Clh4e-cH-

package main

import (
    "fmt"
    "runtime"
    "sync"
    "time"
)

var wg sync.WaitGroup

func main() {
    // 限定只使用一个逻辑调度器
    runtime.GOMAXPROCS(1)

    wg.Add(2)

    go longTimeFunc("A")
    go longTimeFunc("B")

    wg.Wait()
}

func longTimeFunc(name string){
    defer wg.Done()
    for i := 0; i < 100; i++ {
        time.Sleep(time.Duration(500)*time.Millisecond)
        fmt.Println(name)
    }
}
运行结果:
B
A <- 切换 goroutine
A
B <- 切换 goroutine
B
...(省略 195 行)

在这个执行结果中可以看到,在我们设置的唯一逻辑处理器上,两个 goroutine 切换运行。当然,每一次的执行结果会有所变化。

并行任务

上面的例子中只使用了一个逻辑处理器,接下来,我们来观察使用多个逻辑处理器来并行地来处理任务。

这里再来明晰一下并发与并行的概念。直观上,并发(concurrency)和并行(parallelism)都是在说“同时处理多件事情”,但它们并不是相同的概念:并行强调的是同时多件事情,某一时刻,多个事情同时在发生,而并发强调同时管理多件事情,这些事情可能是轮流做的,某一时刻,可能只有一件事情在发生。

打一个比方,并发是诊室里只有一个医生,他一会儿为病人 A 检查,一会儿为病人 B 开药,一会儿又回来看 A 的检查结果。而并行是诊室里有多个医生,每个医生同时在工作,各自负责一些病人。

对应到我们在谈论的 Go 上,我们的第一个例子中,多个 goroutine 是并发地在一个逻辑处理器上执行的。Go 1.5 之后,Go 运行时默认为每一个可用的物理处理器分配一个逻辑处理器,多个逻辑处理器会并行地执行任务。

多个逻辑处理器不一定会带来更好的性能,反之,单个逻辑处理器也不意味着低效率——Go 1.5 之前的版本上,一个 Go 程序只会分配一个逻辑处理器。

下面这段程序中,我们准备了两个 goroutine,它们分别会打印 100 个 A 和 100 个 B。首先我们还是像之前一样限定使用一个逻辑调度器。

在 Go Playground 运行这段代码:https://go.dev/play/p/QoqcDsEX4t7

package main

import (
    "fmt"
    "runtime"
    "sync"
)

func main() {
    // 限定只使用一个逻辑调度器
    runtime.GOMAXPROCS(1)

    var wg sync.WaitGroup
    wg.Add(2)

    // goroutine A
    go func() {
        defer wg.Done()
        for i := 1; i <= 100; i++ {
            fmt.Printf("A ")
        }
    }()

    // goroutine B
    go func() {
        defer wg.Done()
        for i := 1; i <= 100; i++ {
            fmt.Printf("B ")
        }
    }()

    wg.Wait()
}
运行结果:
B B B B B …(省略 90 个 B)… B B B B B A A A A A …(省略 90 个 A)… A A A A A

根据上面的运行结果可以分析到,在我们设置的唯一的逻辑调度器上,只发生了一次 goroutine 切换,因此呈现的是 100 个 B 后面跟着 100 个 A。这可以认为是因为打印这些字符需要的时间很短,所以没等待来回切换,两个 goroutine 就运行完了。

现在我们把逻辑处理器的数量调整到 2,即 runtime.GOMAXPROCS(2),其他代码没有变动,这里不再重复。

在 Go Playground 运行这段代码:https://go.dev/play/p/ZLnnjS3kMNr

运行结果:
B B B B A A A A A B B B B B B B A A A ...(后续字符省略)

可以看到,A 和 B 是混合着打印出来的,这是因为两个 goroutine 是分别在两个逻辑处理器上并行运行的(这个运行结果是在多核心的 CPU 上得到的,即两个逻辑处理器分别在两个物理处理器上)。当然,每次运行的结果会有所变化,不同的环境也会对运行结果造成影响。

共享资源与竞态

竞态

程序中总存在一些需要被多个 goroutine 同时访问并操作的共享资源,他可能是一个变量,可能是一个文件。如果不加同步控制去访问、操作这些资源,就会产生竞态。竞态会导致许多潜在的问题,我们通过下面这段程序制造竞态,来观察并分析问题是如何产生的。

这段程序中启动了两个 goroutine,分别对 shared 做 100000 次 +1 的操作。

在 Go Playground 运行这段代码:https://go.dev/play/p/AuauJtSkykF

package main

import (
    "fmt"
    "sync"
)

var shared int
var wg sync.WaitGroup

func main() {
    wg.Add(2)

    go inc("A")
    go inc("B")

    wg.Wait()
    fmt.Printf("%d\n", shared)
}

func inc(name string) {
    defer wg.Done()
    for i := 1; i <= 100000; i++ {
        shared = shared + 1
    }
}
运行结果:
107197

每次运行的结果会有所不同,最终的结果大概率不是 200000。这是因为 shared = shared + 1 不是一个原子操作,如果一个 goroutine 在一次自增未完全完成的时候被切换,那么另一个 goroutine 就会在原值上继续操作。也就是说,有一些自增操作被“覆盖”了。

Go 1.1 之后加入了竞态检测器,可以在测试程序(go test)、编译并运行程序(go run)、构建(go build)、安装包(go install)时加上 -race 标志,竞态检测器会监视内存访问、检测对共享变量的非同步访问,如果检测到竞争行为,会给出警告。

例如,对上面的程序使用 -race 标志:

go run -race main.go

得到下面的结果

==================
WARNING: DATA RACE
Read at 0x0000009c1e90 by goroutine 8:
  main.inc()
      /foo/main.go:24 +0x84

Previous write at 0x0000009c1e90 by goroutine 7:
  main.inc()
      /foo/main.go:24 +0xa4

Goroutine 8 (running) created at:
  main.main()
      /foo/main.go:15 +0xa8

Goroutine 7 (running) created at:
  main.main()
      /foo/main.go:14 +0x7b
==================
200000
Found 1 data race(s)
exit status 66

Go 的竞态检测器为我们返回了引发竞态的 goroutine 和对应代码所在的位置。

更多竞态检测器的信息可以参考 Go 官方博客 Introducing the Go Race Detector

既然竞态会导致各种潜在的问题,我们就要着手来应对它。Go 支持传统的同步机制——加锁,也有一种具有 Go 特色的同步机制——通道,我们先来看前者,它又可以分为两种形式——原子函数和互斥锁。

原子函数和互斥锁

原子函数

Go 在 atomic 中提供了一些原子函数,它们通过很底层的锁机制获得原子能力。例如,可以用原子函数 AddInt64 来改写上面的代码,如下所示。

在 Go Playground 运行这段代码:https://go.dev/play/p/RzHIWBVQjda

package main

import (
    "fmt"
    "sync"
    "sync/atomic"
)

var shared int64
var wg sync.WaitGroup

func main() {
    wg.Add(2)

    go inc("A")
    go inc("B")

    wg.Wait()
    fmt.Printf("%d\n", shared)
}

func inc(name string) {
    defer wg.Done()
    for i := 1; i <= 100000; i++ {
        atomic.AddInt64(&shared, 1)
    }
}
运行结果:
200000

我们用 atomic.AddInt64(&shared, 1) 替换了原来的 shared = shared + 1,这个原子函数会保证同一时刻只有一个 goroutine 对 shared 进行并完成加法操作。

atomic 包提供原子交换(Swap*)、原子 CAS(CompareAndSwap*)、原子加(Add*)、原子取(Load*)、原子存(Store*)等一系列原子能力。

互斥锁

Go 在 sync 包中提供了 Mutex 类型,即互斥锁。互斥锁可以是一段代码成为临界区,同一时刻只允许一个 goroutine 进入。

我们用互斥锁来改写之前的代码,如下所示:

在 Go Playground 运行这段代码:https://go.dev/play/p/JFuMXZ1VhoY

package main

import (
    "fmt"
    "sync"
)

var shared int
var wg sync.WaitGroup
var mutex sync.Mutex

func main() {
    wg.Add(2)

    go inc("A")
    go inc("B")

    wg.Wait()
    fmt.Printf("%d\n", shared)
}

func inc(name string) {
    defer wg.Done()
    for i := 1; i <= 100000; i++ {
        mutex.Lock()
        shared = shared + 1 // 临界区
        mutex.Unlock()
    }
}
运行结果:
200000

shared = shared + 1 操作之前,我们对 mutex 上锁,在完成自增操作后解锁,这样就是使 shared = shared + 1 这一行成为了临界区,只有一个 goroutine 能够进入,也就避免了竞态带来的问题。

原子函数和互斥锁是传统且常见的同步机制,除此之外,Go 还提供另一种同步机制——通道。

通道

无缓冲通道

通道机制基于通信顺序进程(Communicating Sequential Process, CSP)这种消息传递模型。

正如“通道”这个名字所言,它是用来传递东西的。再来回顾一下,我们面临的问题是:有多个 goroutine 都要使用一个共享资源,但同时只能有一个 goroutine 使用这个资源。来看看通道机制如何做到这件事情。

让我们想象,现在有两个小人 Alice 和 Bob,分别站在两面悬崖上,它们都需要有一份共享的物品,现在由 Alice 持有,Bob 无法使用,Alice 使用完毕之后,要想办法给 Bob。

他们想到了一个好办法,准备一条绳子,Alice 和 Bob 各牵着绳子的一头,把这份物品从绳子上“滑”过去。

以此类比,Alice 和 Bob 就是两个 goroutine,他们之间用来从传递物品的绳子就是通道,更精确一些——这是一个无缓冲通道。也就是说,通道本身没有能力存储值,这就要求发送方和接受方同时准备好。就好像 Alice 和 Bob 的绳子一样,如果其中有一头没人牵着,那就没法传递物品,像下面这样:

如果发送、接收的两个 goroutine 没有都准备好,无缓冲通道就会导致其中一方阻塞等待,直到双方同时准备好为止,所以无缓冲通道的发送和接受动作本身就是同步的。下面我们用 Go 代码来模拟一下 Alice 和 Bob 来回传递物品的动作。

在 Go Playground 运行这段代码:https://go.dev/play/p/XxIx89bG1s8

package main

import (
    "fmt"
    "math/rand"
    "sync"
    "time"
)

var wg sync.WaitGroup

func main() {
    rand.Seed(time.Now().Unix())

    wg.Add(2)

    // 创建 int 类型的无缓冲通道
    rope := make(chan int)

    go transfer("Alice", rope)
    go transfer("Bob", rope)

    rope <- 0

    wg.Wait()
}

func transfer(name string, rope chan int) {
    defer wg.Done()
    for {
        // 从绳子上拿取物品
        times, ok := <-rope
        if !ok {
            fmt.Printf("%s 结束接收、n", name)
            return
        }
        // 以一定概率结束传送
        if n := rand.Intn(100); n > 80 {
            fmt.Printf("%s 收到物品并结束发送,物品共被传送 %d 次、n", name, times)
            close(rope)
            return
        }
        times++
        fmt.Printf("%s 收到物品,并开始物品的第 %d 次传送、n", name, times)
        rope <- times
    }
}

在 Go 中,需要用 make 来声明通道,声明通道时需要指定通道中允许传输的数据类型。上面的代码中,使用 rope := make(chan int) 创建了一个允许共享 int 类型数据的无缓冲通道。

随后,为 Alice 和 Bob 各启动了一个 goroutine,执行 transfer 操作。在 transfer 函数中,我们使用 <-rope 符号从通道中接收值,使用 rope <- 向通道中发送值。<- 符号和通道名称的位置与数据的流向是一致的。

上面这段程序的运行结果如下:

Bob 收到物品,并开始物品的第 1 次传送
Alice 收到物品,并开始物品的第 2 次传送
Bob 收到物品,并开始物品的第 3 次传送
Alice 收到物品,并开始物品的第 4 次传送
Bob 收到物品,并开始物品的第 5 次传送
Alice 收到物品并结束发送,物品共被传送 5 次
Bob 结束接收

这个结果符合我们的预期。

有缓冲通道

除了无缓冲通道,Go 还提供有缓冲通道,有缓冲通道在被接收前能存储一个或多个值。

还是用悬崖上的 Alice 和 Bob 来举例,他们之间现在建起了一条传送带!这意味着在一方准备好接收之前,共享资源可以存放在传送带上,发送者可以去干别的事情。

前面的无缓冲通道会在收发两方任一一方没准备好的时候阻塞,那么有缓冲通道的阻塞条件是什么呢?举个例子来看,这个例子中我们让 Alice 来做发送者,Bob 做接收者,传送带可以容纳三件物品放在上面。

我们用这个例子模拟了一个容量为 3 的有缓冲通道,它在没有空间继续容纳新的值时阻塞发送动作,在通道中没有值可以获取时阻塞接收动作。

我们用通道来改写一下之前多个 goroutine 各自自增 100000 次的程序。

在 Go Playground 运行这段代码:https://go.dev/play/p/A9J8Q2S-E7P

package main

import (
    "fmt"
    "sync"
)

var wg sync.WaitGroup

func main() {
    wg.Add(2)

    shared := make(chan int, 1)
    shared <- 0

    go inc("A", shared)
    go inc("B", shared)

    wg.Wait()

    close(shared)
    fmt.Printf("%d\n", <-shared)
}

func inc(name string, shared chan int) {
    defer wg.Done()
    for i := 1; i <= 100000; i++ {
        cur := <-shared
        cur++
        shared <- cur
    }
}
运行结果:
200000

在这段代码中使用了一个带 1 个容量的有缓冲通道。有缓冲通道的声明是在 make 的时候带上容量值,例如 shared := make(chan int, 1)。值得关注的是在第 21 行和 22 行,通道先被关闭了,随后我们又从中读出了值。这是因为,通道被关闭之后,仍可以从中接收数据,但不可以继续发送数据。

Java 创建线程的三种方式

继承 Thread 类创建线程

Java 使用 Thread 类代表线程,所有的线程对象都必须是 Thread 类或其子类的实例。

可以通过继承 Thread 类来创建线程,一般步骤如下:

  • 定义一个子类(本例中命名为 ThreadByThread),该类继承 Thread 类并重写 run()方法,run() 方法体即为线程需要完成的任务
  • 创建该子类的示例,即创建线程对象
  • 调用该线程对象的 start() 方法来启动线程

示例代码如下:

// ThreadTest.java
package cn.imztj.test.thread;

public class ThreadTest {
    public static void main(String[] args) {
        new ThreadByThread().start();
    }
}

class ThreadByThread extends Thread {
    @Override
    public void run() {
        for (int i = 0; i < 5; i++) {
            System.out.println("Current Thread: " + Thread.currentThread().getName() + ", Current i: " + i);
        }
    }
}

测试结果如下,可以看到,两个线程交替执行:

Current Thread: Thread-1, Current i: 0
Current Thread: Thread-0, Current i: 0
Current Thread: Thread-1, Current i: 1
Current Thread: Thread-0, Current i: 1
Current Thread: Thread-1, Current i: 2
Current Thread: Thread-0, Current i: 2
Current Thread: Thread-1, Current i: 3
Current Thread: Thread-0, Current i: 3
Current Thread: Thread-1, Current i: 4
Current Thread: Thread-0, Current i: 4

实现 Runnable 接口创建线程

  • 首先,我们定义一个 Runnable 接口的实现类(本例中命名为 ThreadByRunnable),并重写该接口的 run() 方法,run() 方法体即为线程需要完成的任务
  • 创建该实现类的实例,并用这个实例作为 Thread 的 target 来创建线程对象
  • 通过调用线程对象的 start() 方法来启动线程
// RunnableTest.java
package cn.imztj.test.thread;

public class RunnableTest {
    public static void main(String[] args) {
        ThreadByRunnable threadByRunnable = new ThreadByRunnable();
        new Thread(threadByRunnable, "线程 1").start();
        new Thread(threadByRunnable, "线程 2").start();
    }
}

class ThreadByRunnable implements Runnable {

    @Override
    public void run() {
        for (int i = 0; i < 5; i++) {
            System.out.println("Current Thread: " + Thread.currentThread().getName() + ", Current i: " + i);
        }
    }
}

测试结果如下,可以看到两个线程交替执行,在本例中,我们还对线程进行了命名:

Current Thread: 线程 2, Current i: 0
Current Thread: 线程 1, Current i: 0
Current Thread: 线程 2, Current i: 1
Current Thread: 线程 1, Current i: 1
Current Thread: 线程 2, Current i: 2
Current Thread: 线程 1, Current i: 2
Current Thread: 线程 2, Current i: 3
Current Thread: 线程 1, Current i: 3
Current Thread: 线程 2, Current i: 4
Current Thread: 线程 1, Current i: 4

利用 Callable 和 Future 创建线程

与 Runnable 接口不同,Callable 接口提供了一个 call() 方法作为线程执行体,与 Runnable 接口的 run() 方法相比,call() 方法可以有返回值,可以声明抛出异常,而 run() 方法则不行。

Callable 接口在 JUC (java.util.concurrent) 包下,其定义如下:

public interface Callable<V> {
    /**
     * Computes a result, or throws an exception if unable to do so.
     *
     * @return computed result
     * @throws Exception if unable to compute a result
     */
    V `call()` throws Exception;
}

Callable 接口不是 Runnable 接口的子接口,Callable 对象是不能直接作为 Thread 对象的 target 的。针对这个问题,Java5 提供了 Future 接口来接收 Callable 接口中 call() 方法的返回值,Java 还引入了 RunnableFuture 接口,它是 Runnable 接口和 Future 接口的子接口,可以作为 Thread 对象的 target。同时提供了一个 RunnableFuture 接口的实现类:FutureTask ,FutureTask 可以作为 Thread 对象的 target。

我们可以在 IDEA 中查看上述类和接口的关系图:

使用 Callable 和 Future 创建线程的步骤如下:

  • 定义一个类实现 Callable 接口,并重写 call() 方法,call() 方法即为线程执行体,有返回值,且可以声明异常。
  • 创建 Callable 实现类的实例,使用 FutureTask 类来包装 Callable 对象
  • 使用 FutureTask 对象作为 Thread 对象的 target 创建并启动线程
  • 调用 FutureTask 对象的 get() 方法来获得子线程执行结束后的返回值
// CallableTest.java
package cn.imztj.test.thread;

import java.util.concurrent.Callable;
import java.util.concurrent.FutureTask;

public class CallableTest {
    public static void main(String[] args) {
        ThreadByCallable threadByCallable = new ThreadByCallable();
        FutureTask<String> futureTask1 = new FutureTask<>(threadByCallable);
        FutureTask<String> futureTask2 = new FutureTask<>(threadByCallable);
        new Thread(futureTask1, "线程 1").start();
        new Thread(futureTask2, "线程 2").start();
        try {
            System.out.println("线程 1 的返回值:" + futureTask1.get());
            System.out.println("线程 2 的返回值:" + futureTask2.get());
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}

class ThreadByCallable implements Callable {

    @Override
    public Object `call()` throws Exception {
        for (int i = 0; i < 5; i++) {
            System.out.println("Current Thread: " + Thread.currentThread().getName() + ", Current i: " + i);
        }
        return Thread.currentThread().getName();
    }
}

执行结果如下:

Current Thread: 线程 1, Current i: 0
Current Thread: 线程 2, Current i: 0
Current Thread: 线程 1, Current i: 1
Current Thread: 线程 2, Current i: 1
Current Thread: 线程 1, Current i: 2
Current Thread: 线程 2, Current i: 2
Current Thread: 线程 1, Current i: 3
Current Thread: 线程 2, Current i: 3
Current Thread: 线程 1, Current i: 4
Current Thread: 线程 2, Current i: 4
线程 1 的返回值:线程 1
线程 2 的返回值:线程 2

三种创建方式的对比

采用继承 Thread 类方式,优点是代码编写上比较简单,如果需要访问当前线程,无需使用 Thread.currentThread() 方法,可以直接使用 this。缺点是,由于已经继承了 Thread 类,所以它将不能再继承其他的父类。

采用实现 Runnable 接口方式,优点是线程类还可以继承其他的类,而且在这种方式下,多个线程可以共享同一个目标对象,适合多个相同线程来处理同一份资源。在这种情况下,如果需要访问当前线程,必须使用 Thread.currentThread() 方法。

采用实现 Callable 接口的方式,同样可以可以避免 Java 中单继承的限制,适合多个线程进行资源共享。这种方式还可以获得一个 Future 对象,通过 Future 对象可以得到异步计算的结果。它提供了检查计算是否完成的方法,以等待计算的完成,并检索计算的结果,也可以取消任务的执行。

另外,实际开发中应避免显式创建线程,而是建议使用线程池,而线程池只能放入 Runable 或 Callable 接口实现类,不能直接放入继承 Thread 的类。

平衡二叉树(AVL 树):概念、实现原理和算法代码

引入

上一篇中已经介绍了二叉搜索树(BST),在二叉搜索树的复杂度分析中,我们提到,二叉搜索树的算法复杂度与其拓扑结构(具体来说是其树的深度)有关。当 BST 最为「平衡」时,查找、删除、插入的算法为 O(\log{n}),而最不平衡情况下,算法复杂度为 O(n)

当节点数 n 增长时,O(\log{n}) 的算法相对于 O(n) 的算法在运算次数上增长较慢,性能更优。

因此,我们希望按照更加「平衡」的方式构建二叉树。那么,能否让二叉树一层一层地「铺满」,构建出一个完全二叉树?可以的,但是这样的插入算法较为复杂。我们下面介绍一种比较容易实现的解决平衡二叉树的算法——AVL 树,它是由两位俄罗斯数学家 G. M. Adelson-Velsky 和 Evgenii Landis 在 1962 年发明的。

定义

怎样的一棵树可以被称为高度平衡的呢?

一棵 BST 若要成为平衡二叉树,应满足:

  • 其左子树和右子树深度之差绝对值不超过 1
  • 其左子树和右子树也是平衡二叉树

这是一个具有递归特征的定义。值得注意的是,从定义上说,空树也属于一种特殊的平衡二叉树。

引入平衡因子 (Balance Factor, BF) 的概念,平衡因子是指二叉树上一个节点的左子树与右子树的深度之差。一棵平衡二叉树,其任意节点上的平衡应为 – 1、0 或 1。

根据平衡二叉树的定义,下面给出一些例子,节点左上角标注的是当前节点的平衡因子:

图一是一个平衡二叉树,它满足平衡二叉树的定义。图二不是平衡二叉树,其原因并不是不满足平衡因子的条件,而是因为它不满足 BST 的构成条件,这提醒我们平衡二叉树首先要是一棵 BST。图三满足平衡二叉树的构成条件。图 4 中的节点 (8) 平衡因子为 3,不满足平衡二叉树的要求。

实现原理

概述

AVL 树的实现原理是,在构建 BST 时,时刻对 BST 进行「监督」,即每插入一个节点,便检查插入操作是否破坏了 BST 的平衡性。

如果当前的插入操作破坏了平衡性,那么要找出最小不平衡子树,对其执行某种「旋转」操作来调整其中节点的连接关系,使之重新平衡。

这里所说的最小不平衡子树,指的是以离插入节点最近,且平衡因子绝对值大于 1 的节点为根的子树。(注意:是离插入节点最近,而不是离根节点最近)

那么问题的关键就是,如何进行这种「旋转」调整,使得树始终保持平衡?下面将进行叙述。

根据插入操作发生位置的不同,可以将需要进行调平的情况分成下面四类(记最小不平衡子树的根节点为 N):

  • 情况①:对 N 的左儿子的左子树进行插入之后(左 – 左插入)
  • 情况②:对 N 的左儿子的右子树进行插入之后(左 – 右插入)
  • 情况③:对 N 的右儿子的左子树进行插入之后(右 – 左插入)
  • 情况④:对 N 的右儿子的右子树进行插入之后(右 – 右插入)

下面是四种情况的图示:

可以看出情况①和情况④是对称的,情况②和情况③是对称的,所以在理解上实际只需要理解两种情况,但是在编程时要区分四种情况。

我们先统领性地讲一下针对情况①④和情况②③采用的调平方法:

  • 情况①④(即左 – 左和右 – 右),通过单旋转进行调平。
  • 情况②③(即左 – 右和右 – 左),通过双旋转进行调平。

下面具体介绍两种旋转方法。

单旋转

单旋转可以用于处理情形①和情形④,其中情形①需要右旋(顺时针旋转)调平,情形④需要左旋(逆时针旋转)调平。下面通过两个实例来理解这里提到的右旋和左旋。

首先看上图的上半部分。向原 BST 中插入节点 (6) 时破坏了 BST 的平衡,标为红色节点 (8) 处为最小不平衡子树的根节点(显然,节点 (6) 是对节点 (8) 左儿子的左子树进行插入,符合情形①)。将以节点 (8) 为根节点的最小不平衡子树「取出」,我们发现,如果对这棵子树进行一次右旋(顺时针旋转),让节点 (7) 成为根节点,节点 (6) 和节点 (8) 成为节点 (7) 的左右子节点,将其放回原 BST,即可保持 BST 的平衡。

同样的,对于情形④,对最小不平衡子树进行一次左旋(顺时针旋转),便可保持 BST 平衡。

下面给出一个更加形式化的操作说明:

如上图,对于情况①,标为红色的 n_2 为其最小不平衡子树的根节点,为了使 BST 重新平衡,将其进行一次右旋,使得 X 抬高一层,Z 下降一层,同时 Y 作为 n_1 的右子树被转移给 n_2 作为左子树。情况④是对称的,不再赘述。

通过单次左旋或右旋,即可调整情况①④导致的不平衡。

双旋转

单次旋转已经可以处理一部分的不平衡,但是针对情况②③,却是无效的,这里用形式化的表示来说明这种无效的情况:

可见,对于情况②,一次旋转并不能消除不平衡。情况③与情况②对称,可想而知结果相同。

对此,需要借助左 – 右双旋转或右 – 左双旋转,下面是一个实例。

节点 (3) 的插入破坏了节点 (4) 处的平衡,节点 (4) 为最小不平衡子树的根节点。如图所示,先执行左旋,再执行右旋,即可通过两次旋转纠正不平衡。

下面给出形式化的操作图示:

代码实现

搞清楚了如何进行通过「旋转」进行调平,便可以写出 AVL 树的代码了。

单旋转

/* 向左单旋转 */
static Position SingleRotateWithLeft(Position K2)
{
    Position K1;

    K1 = K2->Left;
    K2->Left = K1->Right;
    K1->Right = K2;

    K2->Height = Max(Height(K2->Left), Height(K2->Right)) + 1;
    K1->Height = Max(Height(K1->Left), K2->Height) + 1;

    return K1;
}

/* 向右单旋转 */
static Position SingleRotateWithRight(Position K1)
{
    Position K2;

    K2 = K1->Right;
    K1->Right = K2->Left;
    K2->Left = K1;

    K1->Height = Max(Height(K1->Left), Height(K1->Right)) + 1;
    K2->Height = Max(Height(K2->Right), K1->Height) + 1;

    return K2;
}

双旋转

/* 左 - 右双旋转 */
static Position DoubleRotateWithLeft(Position K3)
{
    K3->Left = SingleRotateWithRight(K3->Left);
    return SingleRotateWithLeft(K3);
}

/* 右 - 左双旋转 */
static Position DoubleRotateWithRight(Position K1)
{
    K1->Right = SingleRotateWithLeft(K1->Right);
    return SingleRotateWithRight(K1);
}

插入操作

/* 向 AVL 中插入节点 */
AvlTree Insert(ElementType X, AvlTree T)
{
    if (T == NULL)
    {
        T = malloc(sizeof(struct AvlNode));
        if (T == NULL)
            FatalError("Out of space!!!");
        else
        {
            T->Element = X;
            T->Height = 0;
            T->Left = T->Right = NULL;
        }
    }
    else if (X < T->Element)
    {
        T->Left = Insert(X, T->Left);
        if (Height(T->Left) - Height(T->Right) == 2)
            if (X < T->Left->Element)
                T = SingleRotateWithLeft(T);
            else
                T = DoubleRotateWithLeft(T);
    }
    else if (X > T->Element)
    {
        T->Right = Insert(X, T->Right);
        if (Height(T->Right) - Height(T->Left) == 2)
            if (X > T->Right->Element)
                T = SingleRotateWithRight(T);
            else
                T = DoubleRotateWithRight(T);
    }

    T->Height = Max(Height(T->Left), Height(T->Right)) + 1;
    return T;
}

参考资料

[1] 程杰. 大话数据结构 [M]. 清华大学出版社: 北京, 2011:328-341.
[2] (美)Mark Allen Weiss. 数据结构与算法分析: C 语言描述 [M]. 机械工业出版社: 北京, 2017:80-89.
[3] Vamei(张腾飞). 纸上谈兵: AVL 树 [EB/OL].https://www.cnblogs.com/vamei/archive/2013/03/21/2964092.html,2013-3-21.

二叉查找树(BST):概念、基本操作和性能分析

二叉查找树 (Binary Search Tree, BST) 是二叉树在查找中的一种重要应用形式。

在这篇文章接下来的叙述中,做出以下假设:

  • 尽管二叉树节点中存储的数据类型是任意的,但为了大小比较和理解的方便,本文中将数据类型指定为整数。
  • 各个节点中的存储的数据是没有重复的。

定义

若一棵二叉树为二叉查找树,那么它必须满足:

  • 对于树中的任意节点 X,若其左子树不为空,则左子树中每一个节点的值都小于其根节点的值
  • 对于树中的任意节点 X,若其右子树不为空,则右子树中每一个节点的值都大于或等于其根节点的值
  • 对于树中的任意节点 X,其左子树、右子树也为二叉查找树

这是一个明显带有递归特征的定义。根据这个定义,下面给出一个正确的 BST 示例和两个错误的 BST 实例。

声明

我们给出 BST 的类型定义并列出 BST 的一系列基本操作,如下即为 BST 的声明程序。

typedef int ElementType;

struct TreeNode;
typedef struct TreeNode *Position;
typedef struct TreeNode *SearchTree;

/*BST 的基本操作定义 */
SearchTree MakeEmpty(SearchTree T);
Position Find(ElementType X, SearchTree T);
Position FindMin(SearchTree T);
Position FindMax(SearchTree T);
SearchTree Insert(ElementType X, SearchTree T);
SearchTree Delete(ElementType X, SearchTree T);
ElementType Retrieve(Position P);

/* 单个树节点的结构体定义 */
struct TreeNode
{
    ElementType Element;
    SearchTree Left;
    SearchTree Right;
};

接下来我们将对 BST 的基本操作进行逐一实现。

基本操作

初始化 MakeEmpty

该操作用于初始化树 T

事实上,我们可以把 BST 的第一个元素初始化为一个单节点树,但是下面给出的初始化代码是一种更加遵循递归原则的形式(如 BST 的定义一样)。

/* 初始化一棵树 T*/
SearchTree MakeEmpty(SearchTree T)
{
    if (T != NULL)
    {
        MakeEmpty(T->Left);
        MakeEmpty(T->Right);
        free(T);
    }
    return NULL;
}

查找任意值 Find

该操作用于返回给定的树 T 中值为 X 的节点的指针。若不存在则返回 NULL。

/* 对树 T 查找值 X 的位置 */
Position Find(ElementType X, SearchTree T)
{
    if (T == NULL)
        return NULL;
    if (X < T->Element)
        return Find(X, T->Left);
    else if (X> T->Element)
        return Find(X, T->Right);
    else
        return T;
}

如果树 T 本身为 NULL,则直接返回 NULL。否则根据所寻值 X 与当前节点值的大小关系对左子树或右子树进行递归调用,若找到相等值即返回。

查找最小/最大值 FindMin/FindMax

与查找任意值 X 的 Find 操作类似,查找最小 / 最大值的程序可以很容易地利用递归的思想写出来。

下面的例程中阐释了 FindMin 操作的递归实现:从根节点出发,不断地向左寻找左子树(这是由于 BST 的性质),最终找到的就是最小的元素。FindMax 操作仅在寻找的方向上有所不同。

/*FindMin 操作的递归实现 */
Position FindMin(SearchTree T)
{
    if(T == NULL)
        return NULL;
    else
    if(T->Left == NULL)
        return T;
    else
        return FindMin(T->Left);
}

事实上,查找最小/最大值的操作使用非递归的思想实现也是非常简单的,下面是 FindMin 操作的非递归实现。

/*FindMin 操作的非递归实现 */
Position FindMin(SearchTree T)
{
    if(T != NULL)
        while(T->Left != NULL)
            T = T->Left;

    return T;
}

插入 Insert

在查找操作的基础上,我们很容易实现 BST 的插入操作,其本质上就是在树 T 中寻找适合待插值 X 的位置:如果在树 T 中找到了 X,那么不执行操作(或者执行特定的更新操作);如果树 T 中不存在 X,那么,Find 操作最终找到的位置就是适合 X 的位置。

/* 将值 X 插入到二叉查找树 T 中 */
SearchTree Insert(ElementType X, SearchTree T)
{
    if (T == NULL)
    {
        T = malloc(sizeof(struct TreeNode));
        if (T == NULL)
            return NULL; /*Out of space*/
        else
        {
            T->Element = X;
            T->Left = T->Right = NULL;
        }
    }
    else if (X < T->Element)
        T->Left = Insert(X, T->Left);
    else if (X> T->Element)
        T->Right = Insert(X, T->Right);

    return T;
}

有了上面的 Insert 操作,我们就可以用一段简单的程序来构造一棵 BST 了,像下面这样:

int i;
int a[7] = {67, 73, 129, 25, 101, 310, 123};
SearchTree *T = NULL;
for (i = 0; i < 7; i++)
{
    Insert(a[i], T);
}

删除 Delete

与前面的操作相比,从一棵 BST 中删除一个节点是最困难的,因为我们需要保证删除节点后的新树依然满足 BST 的性质。因此我们需要将可能存在的情况进行分类考虑。

首先是最简单的情况:要删除的节点是一个叶子节点,这时可以直接删除该节点,下面是这种情况的示意图。

其次是待删除节点 N 有一个子节点的情况,这时需要将其父节点 N_{parent} 连接至待删除节点 N 的子节点 N_{child} 后,再对待删除节点执行删除,下面是这种情况的示意图。

最复杂的情况是待删除的节点 N 有两个子节点 N_{lchild}N_{rchild},这时,通常采用的删除方法是,找到以 N_{rchild} 为根节点的右子树中数据最小的节点 N_{rmin} 来替代 N,然后递归地删除 N_{rmin}
举个例子,在下左的 BST 中删除具有两个子节点的节点 (2),根据上述的删除方法,删除的结果就是,(2) 的右子树中的最小节点 (3) 替代了 (2),并且原 (3) 节点被删除。

根据上面的思路,可以得到如下的删除操作代码:

/* 从树 T 中删除值为 X 的节点 */
SearchTree Delete(ElementType X, SearchTree T)
{
    Position TmpCell;

    if (T == NULL)
        Error("Element not found");
    else if (X < T->Element)
        T->Left = Delete(X, T->Left);
    else if (X> T->Element)
        T->Right = Delete(X, T->Right);
    else
        if (T->Left && T->Right)
        {
            TmpCell = FindMin(T->Right);
            T->Element = TmpCell->Element;
            T->Right = Delete(T->Element, T->Right);
        }
    else
    {
        TmpCell = T;
        if (T->Left == NULL)
            T = T->Right;
        else if (T->Right == NULL)
            T = T->Left;
        free(TmpCell);
    }
    return T;
}

性能分析

考虑两个元素内容相同但顺序不同的序列:{4, 2, 1, 3, 6, 5, 7, 9, 8} 和 {1, 2, 3, 4, 5, 6, 7, 8, 9},下图展示了两种序列得到的 BST,前一种得到的是一个相对「平衡」的 BST,而后一种,其构造序列是严格从小到大排列的,于是得到的 BST 是一种极端的右斜树。

在这两个不同的树上考虑 BST 的查找、插入和删除操作,很容易得出结论:BST 各项操作的复杂度和树的拓扑结构有着密切的联系。以查找操作为例,在左树中查找节点 (7) 仅需两次查找,而右树中则需要 7 次。

因此在树的结构上,我们希望 BST 是比较「平衡」的,其中最优的情况是,由 n 个元素构成的 BST 的高度 h 与其构成的完全二叉树高度相等,即满足 h = \lfloor \log{n} \rfloor + 1,最差的情况类似上图右树所示,其高度满足 h=n

与之对应,查找、删除、插入算法的复杂度最优情况下为 O(\log{n}),最坏情况下为 O(n)

这里也引出了一个问题,即如何使二叉树更加平衡,这涉及到一种古老的平衡查找树:AVL 树,将在下一篇做介绍。

参考资料

[1] (美)Mark Allen Weiss. 数据结构与算法分析: C 语言描述 [M]. 机械工业出版社: 北京, 2017:73-80.
[2] 程杰. 大话数据结构 [M]. 清华大学出版社: 北京, 2011:313-328.
[3] 维基百科:二叉搜索树 [EB/OL].https://zh.wikipedia.org/wiki/二叉搜索树,2020-2-14.

旷视「天元」深度学习框架上手:概况、安装和初步体验

概况

天元(英文名:MegEngine)是旷视科技 3 月 25 日开源的深度学习框架,这一名字取自围棋棋盘中心点的名称,也有向 AlphaGo 致敬之意。听了旷视大佬们在发布会上的介绍,MegEngine 从 2014 年开始研发,作为旷视内部全员使用的框架,MegEngine 是驱动旷视在深度学习领域取得一系列成绩的核心动力。

趁没开学在家,对照官方文档快速上手了一下,这篇文章会记录一下对 MegEngine 初步的一些感受。

  • 天元 MegEngine 官方网站:https://megengine.org.cn/
  • GitHub repo:https://github.com/MegEngine/MegEngine
  • 中文社区:https://discuss.megengine.org.cn/
  • MegStudio:https://studio.brainpp.com/

安装

目前 MegEngine 的支持平台还比较单一,仅支持 Linux 环境下安装。对 Windows 用户,官方提示了可以使用 WSL 来运行 MegEngine,不过只支持 CPU 后端。

详细的环境要求是:

  • 64 位、16.04 及以上版本的 Ubuntu
  • Python 3.5+
  • NVIDIA 驱动版本 418.x

旷视 MegStudio 平台提供了 MegEngine 0.3.1 + Python 3.8 的环境,可以很方便地开始使用 MegEngine。

除了 MegStudio,实测 Google Colab 平台也可以成功安装 MegEngine。

MegEngine 安装包本身集成了 CUDA 环境,因此不区分 CPU 和 GPU 版本。

通过 pip 安装 MegEngine:

pip3 install megengine -f https://megengine.org.cn/whl/mge.html

整个安装包 700MB+,下载的速度非常快,服务器在国内果然无惧速度问题。

import 一下来测试安装是否成功,官方对 MegEngine 的习惯性缩写是 mge

import megengine as mge

对于要参与到 MegEngine 开发贡献的开发者,或者需要使用未进入 release 的功能,则需要从源码安装,可以参考 README 中的指引

体验

MegStudio

MegStudio 是旷视开放的提供免费算力的在线深度学习开发平台。目前提供下面三种配置的环境,其中基础版的环境不限时长,高级版(CPU)和高级版(GPU)环境需要通过算力卡获得。目前算力卡是通过邀请用户的方式获得。

MegStudio 开发环境是基于 JupyterLab,环境关闭之后额外添加的文件会被销毁。使用体验上和 JupyterLab 基本没有区别,如果官方能默认支持一下代码补全就更好了。

友好的中文文档

对照 官方文档 上手的过程当中,很大的一个体会是 MegEngine 的文档做得很用心。

首先从语言上说,中文书写的文档对于国内的开发者来说无疑是很友好的。

官方文档目前分为基础学习和进阶学习两个部分,MegEngine 团队设计了一个循序渐进的入门指引帮助使用者熟悉 MegEngine。

MegEngine 的文档中穿插了对神经网络知识的简要讲解,比如下面是文档中介绍的 BP 过程。

文档代码中的注释也非常详细,介绍了代码释义,和 API 的细节,阅读起来会比较轻松。


支持基于 Module 的网络搭建

MegEngine 提供两种网络搭建方式:基于 functional (提供常见算子)和基于 Module(提供常见网络层)。

基于 Module 的构建方式和 PyTorch 的风格十分相似,下面分别是 MegEngine 和 PyTorch 的 LeNet 实现代码,可以看出总体的写法没有很大的区别,对 PyTorch 用户来说,熟悉 MegEngine 网络搭建压力很小(官方也提到支持导入 PyTorch Module)。

# MegEngine Implementation
class LeNet(M.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = M.Conv2d(1, 6, 5)
        self.relu1 = M.ReLU()
        self.pool1 = M.MaxPool2d(2, 2)
        self.conv2 = M.Conv2d(6, 16, 5)
        self.relu2 = M.ReLU()
        self.pool2 = M.MaxPool2d(2, 2)
        self.fc1 = M.Linear(16 * 5 * 5, 120)
        self.relu3 = M.ReLU()
        self.fc2 = M.Linear(120, 84)
        self.relu4 = M.ReLU()
        self.classifer = M.Linear(84, 10)

    def forward(self, x):
        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.pool2(self.relu2(self.conv2(x)))
        x = F.flatten(x, 1)
        x = self.relu3(self.fc1(x))
        x = self.relu4(self.fc2(x))
        x = self.classifer(x)
        return x
# PyTorch Implementation
class LeNet5(nn.Module):
   def __init__(self):
       super().__init__()
       self.conv1 = nn.Conv2d(1, 6, 5, padding=2)
       self.conv2 = nn.Conv2d(6, 16, 5)
       self.fc1 = nn.Linear(16*5*5, 120)
       self.fc2 = nn.Linear(120, 84)
       self.fc3 = nn.Linear(84, 10)
   def forward(self, x):
       x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
       x = F.max_pool2d(F.relu(self.conv2(x)), (2, 2))
       x = x.view(-1, self.num_flat_features(x))
       x = F.relu(self.fc1(x))
       x = F.relu(self.fc2(x))
       x = self.fc3(x)
       return x
   def num_flat_features(self, x):
       size = x.size()[1:]
       num_features = 1
       for s in size:
           num_features *= s
       return num_features

动态图与静态图的转换

此前的 TensorFlow 框架注重静态图,这对于调试来说比较困难,PyTorch 侧重动态图,对调试开发来说比较便利,但是不利于高效部署。MegEngine 在动态图和静态图的兼容上下了功夫,支持通过用 jit 包中的 trace 装饰器来完成动静转换,这样调试时可以使用动态图,实时输出中间结果,部署时再使用静态图,提升训练和推理速度。这也是 MegEngine 作为工业级框架的重要特性之一。

import megengine.functional as F
from megengine.jit import trace

trace.enabled = True

@trace
def train_func(data, label, *, opt, net):
    pred = net(data)
    loss = F.cross_entropy_with_softmax(pred, label)
    opt.backward(loss)
    return pred, loss

train_func(data, label, opt=optimizer, net=le_net)

中文社区

MegEngine 建立了中文语言的 社区,对国内开发者无疑是一大好处,我在站务反馈区提了一个捉虫帖,没想到没几分钟之后管理员就反馈修改完毕,效率可以说是非常高了!

以上就是对 MegEngine 初步的体验,对国内人工智能生态来说,MegEngine 的开源无疑是一件好事,期待未来 MegEngine 和国内深度学习框架的发展!