Accessing the underlying socket of a net/http response

I'm new to Go and evaluating it for a project.

I'm trying to write a custom handler to serve files with net/http. I can't use the default http.FileServer() handler because I need to have access to the underlying socket (the internal net.Conn) so I can perform some informational platform specific "syscall" calls on it (mainly TCP_INFO).

More precisely: I need to access the underlying socket of the http.ResponseWriter in the handler function:

// myHandler is the handler the question wants to write: it needs the net.Conn
// that w writes to (for example to query TCP_INFO on the socket), which
// net/http does not expose directly.
func myHandler(w http.ResponseWriter, r *http.Request) {
...
// I need the net.Conn of w
...
}

used in

// Route every request path on the default mux to myHandler.
http.Handle("/", http.HandlerFunc(myHandler))

Is there a way to do this? I looked at how websocket.Upgrade does it, but it uses Hijack(), which is 'too much': after hijacking I would have to 'speak HTTP' myself over the raw TCP socket I get. I just want a reference to the socket, not to take it over completely.

Since Go 1.13 (see issue #30694), http.Server has a ConnContext hook that lets you store the net.Conn in each request's context, which makes this fairly clean and simple:

package main

import (
  "net/http"
  "context"
  "net"
  "log"
)

// contextKey is the type of the context key used below. An unexported struct
// type (rather than a string) guarantees no other package can construct a key
// that collides with it.
type contextKey struct {
  key string
}

// ConnContextKey is the context key under which SaveConnInContext stores the
// connection.
var ConnContextKey = &contextKey{"http-conn"}

// SaveConnInContext has the signature required by http.Server.ConnContext. It
// returns a copy of ctx carrying c, so requests read from c can retrieve the
// connection with GetConn.
func SaveConnInContext(ctx context.Context, c net.Conn) context.Context {
  return context.WithValue(ctx, ConnContextKey, c)
}

// GetConn returns the connection stored by SaveConnInContext, or nil when r's
// context holds none (for example when the server was not configured with
// SaveConnInContext, or r was constructed directly in a test).
func GetConn(r *http.Request) net.Conn {
  // comma-ok assertion: a missing value yields nil instead of a panic.
  c, _ := r.Context().Value(ConnContextKey).(net.Conn)
  return c
}

// main serves myHandler on :8080 with SaveConnInContext installed as the
// ConnContext hook, so every request's context carries its net.Conn.
func main() {
  http.HandleFunc("/", myHandler)

  server := http.Server{
    Addr:        ":8080",
    ConnContext: SaveConnInContext,
  }
  // ListenAndServe only returns on failure (e.g. the port is already in use);
  // report it and exit non-zero instead of silently exiting with status 0.
  log.Fatal(server.ListenAndServe())
}

// myHandler obtains the connection the request arrived on via GetConn, which
// reads it from the request context populated by the ConnContext hook.
func myHandler(w http.ResponseWriter, r *http.Request) {
  conn := GetConn(r)
  ...
}

Before Go 1.13 (or wherever ConnContext is unavailable): for a server listening on a TCP port, net.Conn.RemoteAddr().String() is unique among the currently open connections and is available to the http.Handler as r.RemoteAddr, so it can be used as the key of a global map of Conns:

package main
import (
  "fmt"
  "log"
  "net"
  "net/http"
  "sync"
)

// conns maps each active connection's remote address (which net/http also
// exposes to handlers as r.RemoteAddr) to the connection. ConnState callbacks
// and handlers run on different goroutines, so every access must hold connsMu.
var (
  connsMu sync.Mutex
  conns   = make(map[string]net.Conn)
)

// ConnStateEvent is installed as http.Server.ConnState. It records a
// connection when it becomes active and forgets it once it has been hijacked
// or closed.
func ConnStateEvent(conn net.Conn, event http.ConnState) {
  key := conn.RemoteAddr().String()

  connsMu.Lock()
  defer connsMu.Unlock()

  switch event {
  case http.StateActive:
    conns[key] = conn
  case http.StateHijacked, http.StateClosed:
    delete(conns, key)
  }
}

// GetConn returns the connection r arrived on, or nil if it is not registered.
func GetConn(r *http.Request) net.Conn {
  connsMu.Lock()
  defer connsMu.Unlock()
  return conns[r.RemoteAddr]
}

// main serves myHandler on :8080 with ConnStateEvent installed as the
// ConnState hook that keeps the connection map up to date.
func main() {
  http.HandleFunc("/", myHandler)

  server := http.Server{
    Addr:      ":8080",
    ConnState: ConnStateEvent,
  }
  // ListenAndServe only returns on failure; report it instead of exiting 0.
  log.Fatal(server.ListenAndServe())
}

// myHandler obtains the connection the request arrived on via GetConn, which
// looks it up by r.RemoteAddr in the map maintained by ConnStateEvent.
func myHandler(w http.ResponseWriter, r *http.Request) {
  conn := GetConn(r)
  ...
}

For a server listening on a UNIX socket, net.Conn.RemoteAddr().String() is always "@", so the above doesn't work. To make this work, we can override net.Listener.Accept(), and use that to override net.Conn.RemoteAddr().String() so that it returns a unique string for each connection:

package main

import (
  "net/http"
  "net"
  "os"
  "golang.org/x/sys/unix"
  "fmt"
  "log"
)

func main() {
  http.HandleFunc("/", myHandler)

  listenPath := "/var/run/go_server.sock"
  l, err := NewUnixListener(listenPath)
  if err != nil {
    log.Fatal(err)
  }
  defer os.Remove(listenPath)

  server := http.Server{
    ConnState: ConnStateEvent,
  }
  server.Serve(NewConnSaveListener(l))
}

func myHandler(w http.ResponseWriter, r *http.Request) {
  conn := GetConn(r)
  if unixConn, isUnix := conn.(*net.UnixConn); isUnix {
    f, _ := unixConn.File()
    pcred, _ := unix.GetsockoptUcred(int(f.Fd()), unix.SOL_SOCKET, unix.SO_PEERCRED)
    f.Close()
    log.Printf("Remote UID: %d", pcred.Uid)
  }
}

var conns = make(map[string]net.Conn)
type connSaveListener struct {
  net.Listener
}
func NewConnSaveListener(wrap net.Listener) (net.Listener) {
  return connSaveListener{wrap}
}
func (self connSaveListener) Accept() (net.Conn, error) {
  conn, err := self.Listener.Accept()
  ptrStr := fmt.Sprintf("%d", &conn)
  conns[ptrStr] = conn
  return remoteAddrPtrConn{conn, ptrStr}, err
}
func GetConn(r *http.Request) (net.Conn) {
  return conns[r.RemoteAddr]
}
func ConnStateEvent(conn net.Conn, event http.ConnState) {
  if event == http.StateHijacked || event == http.StateClosed {
    delete(conns, conn.RemoteAddr().String())
  }
}
// remoteAddrPtrConn is a net.Conn whose RemoteAddr reports ptrStr (the
// identifier the connection was registered under) instead of the peer's real
// address. All other methods come from the embedded connection.
type remoteAddrPtrConn struct {
  net.Conn
  ptrStr string
}

func (c remoteAddrPtrConn) RemoteAddr() net.Addr {
  return remoteAddrPtr{ptrStr: c.ptrStr}
}

// remoteAddrPtr is the net.Addr handed out by remoteAddrPtrConn: its String
// is the registration identifier and its Network is empty.
type remoteAddrPtr struct {
  ptrStr string
}

func (a remoteAddrPtr) Network() string { return "" }

func (a remoteAddrPtr) String() string { return a.ptrStr }

var (
  _ net.Conn = remoteAddrPtrConn{}
  _ net.Addr = remoteAddrPtr{}
)

// NewUnixListener creates a Unix domain socket listener at path. Any file
// already at path is unlinked first, and the socket file ends up with mode
// 0660 (read/write for owner and group only).
func NewUnixListener(path string) (net.Listener, error) {
  // Remove whatever is at path (presumably a stale socket from an earlier
  // run); a path that does not exist is not an error.
  if err := unix.Unlink(path); err != nil && !os.IsNotExist(err) {
    return nil, err
  }
  // Set the umask to 0777 while the socket is created so it starts with no
  // permission bits, and restore the previous umask on return. This avoids a
  // window in which the socket exists with broader permissions before the
  // Chmod below. NOTE(review): the umask is process-wide, so files created
  // concurrently by other goroutines are affected too.
  mask := unix.Umask(0777)
  defer unix.Umask(mask)

  l, err := net.Listen("unix", path)
  if err != nil {
    return nil, err
  }

  // Grant read/write to owner and group; close the listener if that fails.
  if err := os.Chmod(path, 0660); err != nil {
    l.Close()
    return nil, err
  }

  return l, nil
}

Note that although in the current implementation http.ResponseWriter is a *http.response (note the lowercase!) which holds the connection, the field is unexported and you can't access it.

Instead take a look at the Server.ConnState hook: you can "register" a function which will be called when the connection state changes, see http.ConnState for details. For example you will get the net.Conn even before the request enters the handler (http.StateNew and http.StateActive states).

You can install a connection state listener by creating a custom Server like this:

// main serves myHandler on :8081 with ConnStateListener installed as the
// connection-state hook, and panics with the error once the server stops.
func main() {
    http.HandleFunc("/", myHandler)

    const timeout = 10 * time.Second
    srv := &http.Server{
        Addr:           ":8081",
        ConnState:      ConnStateListener,
        MaxHeaderBytes: 1 << 20,
        ReadTimeout:    timeout,
        WriteTimeout:   timeout,
    }

    err := srv.ListenAndServe()
    panic(err)
}

// ConnStateListener is installed as http.Server.ConnState and prints each
// connection state transition together with the connection it applies to.
func ConnStateListener(c net.Conn, cs http.ConnState) {
    // The format string previously contained a literal line break (a
    // "newline in string" compile error); \n is the intended escape.
    fmt.Printf("CONN STATE: %v, %v\n", cs, c)
}

This way you will have exactly the desired net.Conn even before (and also during and after) invoking the handler. The downside is that it is not "paired" with the ResponseWriter, you have to do that manually if you need that.

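You can use an http.Hijacker to take over the TCP connection from the ResponseWriter. Once you've done that, you're free to use the socket to do whatever you want — but net/http no longer manages the connection, so from then on you must speak HTTP on it yourself.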

See http://golang.org/pkg/net/http/#Hijacker, which also contains a good example.

It looks like you cannot "pair" a socket (or net.Conn) to either http.Request or http.ResponseWriter.

But you can implement your own Listener:

package main

import (
    "fmt"
    "net"
    "net/http"
    "time"
    "log"
)

func main() {
    // Accept plain TCP connections on :8080.
    nl, err := net.Listen("tcp", ":8080")
    if err != nil {
        log.Fatal(err)
    }

    // Serve through MyListener so every connection the server sees is a MyConn.
    srv := &http.Server{Handler: &MyHandler{}}
    if err := srv.Serve(&MyListener{nl}); err != nil {
        log.Fatal(err)
    }
}

// net.Conn

// MyConn is a net.Conn that forwards every call to the connection it wraps
// (nc). Because MyListener hands MyConn values to http.Server, this is the
// place to keep per-connection state or intercept I/O.
type MyConn struct {
    nc net.Conn
}

func (c MyConn) Read(p []byte) (int, error)  { return c.nc.Read(p) }
func (c MyConn) Write(p []byte) (int, error) { return c.nc.Write(p) }
func (c MyConn) Close() error                { return c.nc.Close() }

func (c MyConn) LocalAddr() net.Addr  { return c.nc.LocalAddr() }
func (c MyConn) RemoteAddr() net.Addr { return c.nc.RemoteAddr() }

func (c MyConn) SetDeadline(t time.Time) error      { return c.nc.SetDeadline(t) }
func (c MyConn) SetReadDeadline(t time.Time) error  { return c.nc.SetReadDeadline(t) }
func (c MyConn) SetWriteDeadline(t time.Time) error { return c.nc.SetWriteDeadline(t) }

// net.Listener

// MyListener is a net.Listener that delegates to nl but wraps each accepted
// connection in a MyConn.
type MyListener struct {
    nl net.Listener
}

func (l MyListener) Accept() (net.Conn, error) {
    inner, err := l.nl.Accept()
    if err == nil {
        return MyConn{nc: inner}, nil
    }
    return nil, err
}

func (l MyListener) Close() error   { return l.nl.Close() }
func (l MyListener) Addr() net.Addr { return l.nl.Addr() }

var (
    _ net.Conn     = MyConn{}
    _ net.Listener = MyListener{}
)

// http.Handler

// MyHandler is an http.Handler that answers every request with the plain
// text "Hello World".
type MyHandler struct {
    // ...
}

// ServeHTTP writes the fixed greeting to w.
func (h *MyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
    fmt.Fprint(w, "Hello World")
}

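This can be done with reflection. It's a bit "dirty" and fragile — it depends on unexported fields of net/http and net that change between Go releases (for example, netFD's sysfd field became pfd.Sysfd in Go 1.9):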

package main

import "net/http"
import "fmt"
import "runtime"

import "reflect"

// myHandler digs the OS socket descriptor out of w by reflecting over
// unexported net/http and net internals, prints it, and answers "Hello World".
//
// WARNING: this relies on private field names (http.response.conn,
// http.conn.rwc, net.TCPConn.conn, net.conn.fd and the descriptor inside
// net.netFD) that may change in any Go release. Every step is checked, so
// when the path does not match (another Go version, or a ResponseWriter that
// is not *http.response) the handler just answers the request instead of
// panicking.
func myHandler(w http.ResponseWriter, r *http.Request) {
    // field dereferences any pointers/interfaces in v and returns its field
    // called name, or the zero Value when v is not (a pointer to) a struct.
    field := func(v reflect.Value, name string) reflect.Value {
        for v.IsValid() && (v.Kind() == reflect.Ptr || v.Kind() == reflect.Interface) {
            v = v.Elem()
        }
        if v.Kind() != reflect.Struct {
            return reflect.Value{}
        }
        return v.FieldByName(name)
    }

    // w is a "http.response" struct from which we get the 'conn' field,
    // which is a http.conn from which we get the 'rwc' field,
    // which is a net.TCPConn from which we get the embedded conn,
    // which is a net.conn from which we get the 'fd' field (a *netFD).
    netFD := field(field(field(field(reflect.ValueOf(w), "conn"), "rwc"), "conn"), "fd")

    // The system socket is netFD.sysfd before Go 1.9 and netFD.pfd.Sysfd since.
    netFdPtr := field(netFD, "sysfd")
    if !netFdPtr.IsValid() {
        netFdPtr = field(field(netFD, "pfd"), "Sysfd")
    }

    if netFdPtr.IsValid() {
        fmt.Printf("netFDPtr= %v\n", netFdPtr)

        // The descriptor's type is platform specific: an int on Linux.
        if runtime.GOOS == "linux" && netFdPtr.Kind() == reflect.Int {
            fd := int(netFdPtr.Int())
            fmt.Printf("fd = %v\n", fd)
            // fd is the socket - we can call unix.Syscall6(unix.SYS_GETSOCKOPT, uintptr(fd),....) on it for instance
        }
    } else {
        fmt.Println("socket descriptor not found")
    }

    fmt.Fprintf(w, "Hello World")
}

// main serves myHandler on :8081 and prints the error that stops the server.
func main() {
    http.HandleFunc("/", myHandler)
    // ListenAndServe blocks until the server fails; report why.
    fmt.Println(http.ListenAndServe(":8081", nil).Error())
}

Ideally the library should be augmented with a method to get the underlying net.Conn. (Go 1.13 added http.Server.ConnContext, which makes this possible without reflection.)

Expanding on KGJV's answer, a working solution using reflection to maintain a map of connections indexed by net.Conn instance memory addresses.

Instances of net.Conn can be looked up by pointer, with the pointer derived via reflection from the http.ResponseWriter (which is an *http.response, not an http.Response).

It's a bit nasty, but since reflection can read unexported fields yet cannot hand them back as usable values (Value.Interface panics on them), it's the only way I could see of doing it.

// conns holds every open connection keyed by the address of its net.Conn
// value (see connToPtr); connMutex must be held for every access to it.
var (
    conns     = make(map[uintptr]net.Conn)
    connMutex sync.Mutex
)

// writerToConnPtr returns the address of the net.Conn that w writes to, for
// use as a key into conns. It returns 0 when w is not net/http's internal
// *http.response over a pointer-based connection (for example an
// httptest.ResponseRecorder or a wrapping middleware writer). It relies on
// unexported net/http and net fields and may stop matching in a future Go
// release.
func writerToConnPtr(w http.ResponseWriter) uintptr {
    // field dereferences any pointers/interfaces in v and returns its field
    // called name, or the zero Value when v is not (a pointer to) a struct,
    // so a mismatch anywhere in the chain yields zero instead of a panic.
    field := func(v reflect.Value, name string) reflect.Value {
        for v.IsValid() && (v.Kind() == reflect.Ptr || v.Kind() == reflect.Interface) {
            v = v.Elem()
        }
        if v.Kind() != reflect.Struct {
            return reflect.Value{}
        }
        return v.FieldByName(name)
    }

    // *http.response -> conn (*http.conn) -> rwc (net.Conn, e.g. *net.TCPConn)
    // -> embedded net.conn. The embedded struct sits at the start of TCPConn,
    // so its address equals the *TCPConn pointer that connToPtr returns.
    nc := field(field(field(reflect.ValueOf(w), "conn"), "rwc"), "conn")
    if !nc.CanAddr() {
        return 0
    }
    return nc.Addr().Pointer()
}

// connToPtr returns the address held by c's dynamic value, for use as a
// conns key. c must hold a pointer (e.g. *net.TCPConn); reflect.Value.Pointer
// panics for non-pointer kinds such as struct values.
func connToPtr(c net.Conn) uintptr {
    return reflect.ValueOf(c).Pointer()
}

// ConnStateListener bound to server and maintains a list of connections by pointer
func ConnStateListener(c net.Conn, cs http.ConnState) {
    connPtr := connToPtr(c)
    connMutex.Lock()
    defer connMutex.Unlock()

    switch cs {
    case http.StateNew:
        log.Printf("CONN Opened: 0x%x
", connPtr)
        conns[connPtr] = c

    case http.StateClosed:
        log.Printf("CONN Closed: 0x%x
", connPtr)
        delete(conns, connPtr)
    }
}

// HandleRequest looks up the connection the request arrived on (keyed by the
// pointer writerToConnPtr derives from w) and then works with it.
func HandleRequest(w http.ResponseWriter, r *http.Request) {
    connPtr := writerToConnPtr(w)

    // Hold the lock only for the lookup. Keeping it for the rest of the
    // handler (as a deferred Unlock did) serializes every request and blocks
    // ConnStateListener for the duration.
    connMutex.Lock()
    conn, ok := conns[connPtr]
    connMutex.Unlock()

    if !ok {
        log.Printf("error: no matching connection found")
        return
    }

    // Do something with connection here...
    _ = conn
}

// Bind with http.Server.ConnState = ConnStateListener