I recently bought a Raspberry Pi to tinker around with, and I've been having a blast developing little applications and services for it. I've been wanting to expose some of these services; however, I'm skeptical of my Linux system administration skills, so I figured I would expose them through another server on my network. As I've mentioned before, I've been playing around with Go, so I figured it'd be a fun project to build a small TCP proxy server with it. Ideally the server would listen for incoming TCP connections and forward them to another host and/or port. Note that this post was heavily inspired by this gist by vmihailenco. The complete code for this blog post can be found at the jokeofweek/gotcpproxy repo.

To start out, we'll use the flag package to accept command-line arguments for the proxy server's host/port and the remote host/port to which all connections will be proxied. By default we'll simply forward all traffic from localhost:80 to localhost:8000.

package main

import (
    "flag"
    "fmt"
)

var fromHost = flag.String("from", "localhost:80", "The proxy server's host.")
var toHost = flag.String("to", "localhost:8000", "The host that the proxy " +
    "server should forward requests to.")

func main() {
    // Parse the command-line arguments.
    flag.Parse()
    fmt.Printf("Proxying %s->%s.\r\n", *fromHost, *toHost)
}
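To see how the defaults and overrides interact without touching the global flag set, the same two flags can be declared on a separate flag.FlagSet and fed any argument slice. This is a small sketch; parseArgs is a hypothetical helper for experimenting and is not part of the gotcpproxy repo.

```go
package main

import (
    "flag"
    "fmt"
)

// parseArgs is a hypothetical helper: it declares the proxy's -from and -to
// flags on a fresh FlagSet so any argument slice can be parsed, instead of
// os.Args via the global flag.Parse.
func parseArgs(args []string) (from, to string, err error) {
    fs := flag.NewFlagSet("gotcpproxy", flag.ContinueOnError)
    fromHost := fs.String("from", "localhost:80", "The proxy server's host.")
    toHost := fs.String("to", "localhost:8000", "The host that the proxy "+
        "server should forward requests to.")
    // With ContinueOnError, an unknown flag is returned as an error (after a
    // usage message is printed) instead of exiting the program.
    if err = fs.Parse(args); err != nil {
        return "", "", err
    }
    return *fromHost, *toHost, nil
}

func main() {
    // No arguments: both defaults apply.
    from, to, _ := parseArgs(nil)
    fmt.Printf("Proxying %s->%s.\n", from, to)

    // Overriding only -from leaves -to at its default.
    from, to, _ = parseArgs([]string{"-from", "0.0.0.0:8080"})
    fmt.Printf("Proxying %s->%s.\n", from, to)
}
```

Running it prints `Proxying localhost:80->localhost:8000.` followed by `Proxying 0.0.0.0:8080->localhost:8000.`, since any flag that isn't passed keeps its default.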

Now that we're reading in our arguments, let's set up the proxy server, which will listen for incoming connections. To do this we will use the net package, which has an extremely convenient Listen function that sets up a server socket. We will also use the log package to log errors. Calling log.Fatal logs the message and then exits the program (it calls os.Exit(1)), so we can use it for errors that occur while starting the server.

package main

import (
    "flag"
    "fmt"
    "log"
    "net"
)

// ...

func main() {
    // ...
    // Set up our listening server
    server, err := net.Listen("tcp", *fromHost)

    // If any error occurs while setting up our listening server, error out.
    if err != nil {
        log.Fatal(err)
    }
}

From this point, we could simply loop indefinitely, accepting connections and proxying them. However, we'd like to make sure that an external user can't use our proxy server to bombard the hidden service, so we will limit the number of active connections. This can be particularly useful for a low-power machine like the Raspberry Pi. To do this, we're going to use an extremely neat feature of Go called channels. A channel is a typed conduit that goroutines can send values to and receive values from: receiving blocks until a value is available, and sending blocks until there is room for the value. When creating a channel with make, we can pass an extra argument specifying how many items it can hold (its buffer size); sends to a buffered channel only block once the buffer is full. We will effectively be using such a channel as a semaphore to limit the number of connections that can be interacting with the remote server at any given time.
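A quick way to see the buffer size in action is to keep sending to a channel until a send would block. The sketch below borrows the select statement with a default case (which comes up again later in this post) so that the probe itself never blocks; fillBuffered is a hypothetical helper for illustration.

```go
package main

import "fmt"

// fillBuffered is a hypothetical helper: it sends to a channel of the given
// capacity until a send would block and returns how many sends succeeded.
// The default case runs whenever the send can't proceed immediately.
func fillBuffered(capacity int) int {
    ch := make(chan bool, capacity)
    sent := 0
    for {
        select {
        case ch <- true:
            sent++
        default:
            // The buffer is full: a plain `ch <- true` would block here
            // until some other goroutine received a value.
            return sent
        }
    }
}

func main() {
    fmt.Println(fillBuffered(3))
    fmt.Println(fillBuffered(0))
}
```

fillBuffered(3) returns 3, while fillBuffered(0) returns 0: an unbuffered channel has no room at all, so a send only completes when a receiver is already waiting.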

We will have two channels: one is a channel of connections awaiting processing, and the other is a buffered channel of booleans in which each boolean represents one free connection slot. Whenever a connection is accepted, it is sent to the first channel. A separate goroutine will take each waiting connection, block until a boolean is available in the second channel, take it, process the connection, and then put a boolean back into the second channel. Let's add a command-line argument for the maximum number of active connections and create our two channels. Note that we have to fill the boolean channel with the maximum number of booleans up front, or else we will never be able to handle any connections.

// ...
var maxConnections = flag.Int("c", 25, "The maximum number of active " +
    "connections at any given time.")

func main() {
    // ...
    // The channel of connections which are waiting to be processed.
    waiting := make(chan net.Conn)
    // The booleans representing the free active connection spaces.
    spaces := make(chan bool, *maxConnections)
    // Initialize the spaces
    for i := 0; i < *maxConnections; i++ {
        spaces <- true
    }
}
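To convince ourselves that the pre-filled spaces channel really caps concurrency, here is a standalone sketch of the same pattern using plain goroutines instead of connections. peakConcurrency is hypothetical, and the time.Sleep is only a stand-in for proxying work; the function records the largest number of workers that were ever running at once.

```go
package main

import (
    "fmt"
    "sync"
    "sync/atomic"
    "time"
)

// peakConcurrency runs `jobs` goroutines, each of which must hold one of
// `limit` slots in a buffered channel (a counting semaphore, like the
// proxy's spaces channel), and returns the peak number running at once.
func peakConcurrency(jobs, limit int) int64 {
    spaces := make(chan bool, limit)
    for i := 0; i < limit; i++ {
        spaces <- true
    }

    var active, peak int64
    var wg sync.WaitGroup
    for i := 0; i < jobs; i++ {
        <-spaces // take a slot; blocks while all slots are in use
        wg.Add(1)
        go func() {
            defer wg.Done()
            n := atomic.AddInt64(&active, 1)
            // Raise peak to n if n is larger, retrying on concurrent updates.
            for {
                p := atomic.LoadInt64(&peak)
                if n <= p || atomic.CompareAndSwapInt64(&peak, p, n) {
                    break
                }
            }
            time.Sleep(5 * time.Millisecond) // stand-in for proxying work
            atomic.AddInt64(&active, -1)
            spaces <- true // give the slot back
        }()
    }
    wg.Wait()
    return atomic.LoadInt64(&peak)
}

func main() {
    fmt.Println("peak:", peakConcurrency(20, 4))
}
```

With 20 jobs and a limit of 4, the reported peak can never exceed 4, because a worker only returns its slot after it has stopped counting itself as active.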

Now that we've got our two channels created, we will create our goroutine (called the connection matcher), which uses a range clause to iterate over the connections sent to the waiting channel. For each one, it blocks until there is a free space, reserves it, and handles the connection in a new goroutine. Once that's done, our server can simply loop indefinitely, accepting connections and sending them to the waiting channel.

func main() {
    // ...
    // Start the connection matcher.
    go matchConnections(waiting, spaces)

    // Loop indefinitely, accepting connections and handling them.
    for {
        connection, err := server.Accept()
        if err != nil {
            // Log the error.
            log.Print(err)
        } else {
            // Queue the connection; the connection matcher will handle it.
            log.Printf("Received connection from %s.\r\n",
                connection.RemoteAddr())
            waiting <- connection
        }
    }
}

func matchConnections(waiting chan net.Conn, spaces chan bool) {
    // Iterate over each connection in the waiting channel
    for connection := range waiting {
        // Block until we have a space.
        <-spaces
        // Create a new goroutine which will call the connection handler and 
        // then free up the space.
        go func(connection net.Conn) {
            handleConnection(connection)
            spaces <- true
            log.Printf("Closed connection from %s.\r\n", connection.RemoteAddr())
        }(connection)

    }
}

func handleConnection(connection net.Conn) {
    // Handle our connection...
}
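The accept loop above never closes the waiting channel, so the matcher's range loop runs for as long as the server does. In a program that needs to shut down cleanly, closing the channel is what ends a range loop. The sketch below is a simplified, hypothetical version of the matcher that works on integers instead of connections and adds a sync.WaitGroup so the caller can wait for in-flight handlers; matchJobs and sumViaMatcher are illustrative names, not code from the repo.

```go
package main

import (
    "fmt"
    "sync"
)

// matchJobs is a hypothetical, simplified matchConnections that works on
// ints instead of net.Conn. The range loop keeps receiving until waiting is
// closed and drained; the WaitGroup makes it return only after every
// handler goroutine has finished.
func matchJobs(waiting <-chan int, spaces chan bool, handle func(int)) {
    var wg sync.WaitGroup
    for job := range waiting {
        <-spaces // block until a slot is free
        wg.Add(1)
        go func(job int) {
            defer wg.Done()
            handle(job)
            spaces <- true // free the slot
        }(job)
    }
    wg.Wait()
}

// sumViaMatcher feeds 1..n through matchJobs with `limit` slots and returns
// the sum computed by the handlers.
func sumViaMatcher(n, limit int) int {
    waiting := make(chan int)
    spaces := make(chan bool, limit)
    for i := 0; i < limit; i++ {
        spaces <- true
    }

    var mu sync.Mutex
    sum := 0
    done := make(chan struct{})
    go func() {
        matchJobs(waiting, spaces, func(v int) {
            mu.Lock()
            sum += v
            mu.Unlock()
        })
        close(done)
    }()

    for i := 1; i <= n; i++ {
        waiting <- i
    }
    close(waiting) // ends the range loop inside matchJobs
    <-done
    return sum
}

func main() {
    fmt.Println(sumViaMatcher(10, 2))
}
```

sumViaMatcher(10, 2) returns 55: every queued value is handled exactly once, even though at most two handlers run at a time.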

Now that we've got our basic server up and running, how can we go about proxying the data from the client to the remote server and back? We're going to use the net.Dial function to connect to our remote server.

func handleConnection(connection net.Conn) {
    // Always close our connection.
    defer connection.Close()

    // Try to connect to remote server.
    remote, err := net.Dial("tcp", *toHost)
    if err != nil {
        // Exit out when an error occurs
        log.Print(err)
        return
    }
    defer remote.Close()
}

But how do we actually transfer the content? We need to copy bytes from the client to the remote server and vice versa, and since we don't want to limit ourselves to HTTP, we need to keep doing this until one of the connections is closed. We'll create a function that copies content from one connection to another until either is closed, and we'll run it in two separate goroutines to handle both directions. We will use a channel called complete to signal when at least one of the connections has closed, so we know when we're done handling our connection.

func handleConnection(connection net.Conn) {
    // ...

    // Create our channel which waits for completion, and start both copying
    // goroutines.
    complete := make(chan bool)
    go copyContent(connection, remote, complete)
    go copyContent(remote, connection, complete)
    // Block until we've completed!
    <-complete
}

func copyContent(from net.Conn, to net.Conn, complete chan bool) {
    var err error = nil
    var bytes []byte = make([]byte, 256)
    var read int = 0
    for {
        // Read data from the source connection.
        read, err = from.Read(bytes)
        // If any error occurred, write to complete as we are done (one of the
        // connections closed).
        if err != nil {
            complete <- true
            break
        }
        // Write data to the destination.
        _, err = to.Write(bytes[:read])
        // Same error checking.
        if err != nil {
            complete <- true
            break
        }
    }
}
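The standard library already ships this read/write loop as io.Copy, which keeps copying until the source returns an error and reports a clean io.EOF as success. It also writes any bytes that Read returns together with an error, which the hand-written loop discards. The sketch below is an alternative to the loop above, not the code from the repo; forward is a hypothetical test helper that uses net.Pipe (in-memory connection pairs) to stand in for the client and remote sockets.

```go
package main

import (
    "fmt"
    "io"
    "net"
)

// copyContent is a hypothetical io.Copy-based version of the copy loop:
// io.Copy reads from `from` and writes to `to` until Read fails, then we
// signal completion exactly as before.
func copyContent(from, to net.Conn, complete chan<- bool) {
    io.Copy(to, from)
    complete <- true
}

// forward pushes msg through copyContent using two net.Pipe pairs:
// clientSide <-> proxyIn plays the client, proxyOut <-> remoteSide the remote.
func forward(msg string) (string, error) {
    clientSide, proxyIn := net.Pipe()
    proxyOut, remoteSide := net.Pipe()
    defer proxyIn.Close()
    defer proxyOut.Close()
    defer remoteSide.Close()

    complete := make(chan bool, 1)
    go copyContent(proxyIn, proxyOut, complete)

    go func() {
        clientSide.Write([]byte(msg))
        clientSide.Close() // makes proxyIn.Read return io.EOF
    }()

    buf := make([]byte, len(msg))
    if _, err := io.ReadFull(remoteSide, buf); err != nil {
        return "", err
    }
    <-complete // io.Copy saw EOF and copyContent signalled completion
    return string(buf), nil
}

func main() {
    fmt.Println(forward("hello"))
}
```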

And that's it! We've now got a simple proxy server which takes care of limiting active connections for us. One thing to keep in mind: web browsers often open many connections in parallel, so you may want to raise the maximum number of connections if you're proxying a web service, as the limit can quickly be hit. This was a very fun project for me to build, and I think it was a neat opportunity to learn about Go's concurrency model. Again, all the code for this post can be found at the jokeofweek/gotcpproxy repo.

I hope you enjoyed this post, thanks for reading!

Dominic

Edit:

I've had two mistakes pointed out to me since I first posted this, the first by my friend Antoine and the second by commenter funny_falcon.

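The first is that I never set an explicit limit on the number of items in the waiting channel. (Because waiting is unbuffered, a send on it already blocks until the connection matcher receives, so the accept loop can't run far ahead of the matcher; the real gain is making that limit explicit and tunable.) Let's add an option for the maximum number of waiting connections.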

// ...
var maxWaitingConnections = flag.Int("cw", 10000, "The maximum number of " +
    "connections that can be waiting to be served.")

func main() {
    // ...
    // The channel of connections which are waiting to be processed.
    waiting := make(chan net.Conn, *maxWaitingConnections)
    // ...
}
// ...

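Now the main function's accept loop will only block on sending to the waiting channel once it holds *maxWaitingConnections connections, so at most that many accepted sockets can be queued for the connection matcher at once!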

The second mistake is that the complete channel is unbuffered, so a send on it blocks until someone receives, yet we only wait for one goroutine to send a complete message before returning from the function. Once we return, the deferred Close calls make the second goroutine's Read fail, and it then blocks forever trying to send on complete, so we leak a goroutine for every connection. It would also be nice if, when one goroutine ends, it signaled the other that the connection is over, as some clients and servers have different timeouts and could otherwise leave connections hanging. To do this we will use the select statement, which, combined with a default case, lets us check in a non-blocking fashion whether a channel has a value ready.

func handleConnection(connection net.Conn) {
    // ...

    // Create our channel which waits for completion, and our two channels to
    // signal that a goroutine is done.
    complete := make(chan bool, 2)
    ch1 := make(chan bool, 1)
    ch2 := make(chan bool, 1)
    go copyContent(connection, remote, complete, ch1, ch2)
    go copyContent(remote, connection, complete, ch2, ch1)
    // Block until we've completed both goroutines!
    <-complete
    <-complete
}

func copyContent(from net.Conn, to net.Conn, complete chan bool, done chan bool, otherDone chan bool) {
    var err error = nil
    var bytes []byte = make([]byte, 256)
    var read int = 0
    for {
        select {
        // If we received a done message from the other goroutine, we exit.
        case <-otherDone:
            complete <- true
            return
        default:
            // Note that otherDone is only checked between reads: while this
            // goroutine is blocked in from.Read below, it won't notice the
            // signal until that Read returns (e.g. the peer closes or times out).
            // Read data from the source connection.
            read, err = from.Read(bytes)
            // If any error occurred, we are done (one of the connections
            // closed), so signal completion and notify the other goroutine.
            if err != nil {
                complete <- true
                done <- true
                return
            }
            // Write data to the destination.
            _, err = to.Write(bytes[:read])
            // Same error checking.
            if err != nil {
                complete <- true
                done <- true
                return
            }
        }
    }
}
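With the select approach above, otherDone is only checked between reads, so a goroutine parked in from.Read won't see the signal until its peer closes or times out, and handleConnection keeps waiting on complete in the meantime. A common alternative, different from the approach in this post, is to close both connections as soon as either direction finishes: closing a net.Conn makes any Read blocked on it return an error, which wakes the other goroutine right away. This sketch assumes that trade-off is acceptable (it cuts off any data still in flight in the other direction, so protocols relying on half-closed connections would want TCPConn.CloseWrite instead). proxy and roundTrip are hypothetical names; roundTrip wires a local echo server that never hangs up on its own, a one-shot proxy, and a client.

```go
package main

import (
    "errors"
    "fmt"
    "io"
    "net"
    "sync"
    "time"
)

// proxy copies in both directions. Whichever direction finishes first closes
// both connections, which makes the Read still blocked in the other goroutine
// return an error. proxy returns only after both copies have stopped.
func proxy(client, remote net.Conn) {
    var once sync.Once
    var wg sync.WaitGroup
    wg.Add(2)
    pipe := func(dst, src net.Conn) {
        defer wg.Done()
        io.Copy(dst, src)
        once.Do(func() {
            client.Close()
            remote.Close()
        })
    }
    go pipe(remote, client)
    go pipe(client, remote)
    wg.Wait()
}

// roundTrip sends msg through a one-shot proxy in front of an echo server,
// hangs up, and fails if proxy does not return soon after the client leaves.
func roundTrip(msg string) (string, error) {
    remoteLn, err := net.Listen("tcp", "127.0.0.1:0")
    if err != nil {
        return "", err
    }
    defer remoteLn.Close()
    go func() {
        if c, err := remoteLn.Accept(); err == nil {
            io.Copy(c, c) // echo until the proxy closes this connection
            c.Close()
        }
    }()

    proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
    if err != nil {
        return "", err
    }
    defer proxyLn.Close()
    finished := make(chan struct{})
    go func() {
        defer close(finished)
        client, err := proxyLn.Accept()
        if err != nil {
            return
        }
        remote, err := net.Dial("tcp", remoteLn.Addr().String())
        if err != nil {
            client.Close()
            return
        }
        proxy(client, remote)
    }()

    conn, err := net.Dial("tcp", proxyLn.Addr().String())
    if err != nil {
        return "", err
    }
    reply := make([]byte, len(msg))
    if _, err = conn.Write([]byte(msg)); err == nil {
        _, err = io.ReadFull(conn, reply)
    }
    conn.Close() // the client hangs up; the echo server would stay open
    if err != nil {
        return "", err
    }

    select {
    case <-finished:
        return string(reply), nil
    case <-time.After(2 * time.Second):
        return "", errors.New("proxy did not shut down after the client left")
    }
}

func main() {
    fmt.Println(roundTrip("ping"))
}
```

Because the echo server never closes on its own, roundTrip only succeeds if closing both sockets is what unblocks the remote-to-client copy.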