| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704 |
- package postgres
- import (
- "bufio"
- "crypto/md5"
- "crypto/rand"
- "crypto/tls"
- "encoding/binary"
- "fmt"
- "io"
- "net"
- "strings"
- "sync"
- "time"
- "github.com/seaweedfs/seaweedfs/weed/glog"
- "github.com/seaweedfs/seaweedfs/weed/query/engine"
- "github.com/seaweedfs/seaweedfs/weed/util/version"
- )
// PostgreSQL wire-protocol constants used by this package.

// Codes carried in the first four bytes of a startup packet. Each one is a
// 16-bit major number in the high half and a 16-bit minor number in the low
// half.
const (
	PG_PROTOCOL_VERSION_3 = 3 << 16         // protocol 3.0 (196608, 0x00030000)
	PG_SSL_REQUEST        = 1234<<16 | 5679 // SSL request (80877103, 0x04d2162f)
	PG_GSSAPI_REQUEST     = 1234<<16 | 5680 // GSSAPI request (80877104, 0x04d21630)
)

// Message type bytes sent by the client.
const (
	PG_MSG_STARTUP   = 0x00
	PG_MSG_QUERY     = 'Q'
	PG_MSG_PARSE     = 'P'
	PG_MSG_BIND      = 'B'
	PG_MSG_EXECUTE   = 'E'
	PG_MSG_DESCRIBE  = 'D'
	PG_MSG_CLOSE     = 'C'
	PG_MSG_FLUSH     = 'H'
	PG_MSG_SYNC      = 'S'
	PG_MSG_TERMINATE = 'X'
	PG_MSG_PASSWORD  = 'p'
)

// Message type bytes sent by the server. PG_RESP_AUTH_OK is the type byte of
// every authentication message in this file, including the password and MD5
// challenges, not only the final "OK".
const (
	PG_RESP_AUTH_OK        = 'R'
	PG_RESP_BACKEND_KEY    = 'K'
	PG_RESP_PARAMETER      = 'S'
	PG_RESP_READY          = 'Z'
	PG_RESP_COMMAND        = 'C'
	PG_RESP_DATA_ROW       = 'D'
	PG_RESP_ROW_DESC       = 'T'
	PG_RESP_PARSE_COMPLETE = '1'
	PG_RESP_BIND_COMPLETE  = '2'
	PG_RESP_CLOSE_COMPLETE = '3'
	PG_RESP_ERROR          = 'E'
	PG_RESP_NOTICE         = 'N'
)

// Transaction status bytes.
const (
	PG_TRANS_IDLE    = 'I'
	PG_TRANS_INTRANS = 'T'
	PG_TRANS_ERROR   = 'E'
)

// Authentication codes written as the 32-bit payload of an authentication
// message.
const (
	AUTH_OK    = 0
	AUTH_CLEAR = 3
	AUTH_MD5   = 5
	AUTH_TRUST = 10 // NOTE(review): not referenced in the visible part of this file
)

// PostgreSQL data type OIDs.
const (
	PG_TYPE_BOOL      = 16
	PG_TYPE_BYTEA     = 17
	PG_TYPE_INT8      = 20
	PG_TYPE_INT4      = 23
	PG_TYPE_TEXT      = 25
	PG_TYPE_JSON      = 114
	PG_TYPE_FLOAT4    = 700
	PG_TYPE_FLOAT8    = 701
	PG_TYPE_VARCHAR   = 1043
	PG_TYPE_TIMESTAMP = 1114
	PG_TYPE_JSONB     = 3802
)

// DEFAULT_POSTGRES_PORT is the TCP port used when the configuration does not
// name one.
const DEFAULT_POSTGRES_PORT = 5432
// AuthMethod selects how the server authenticates a connecting client.
type AuthMethod int

// Supported authentication methods. AuthTrust is the zero value, so a
// configuration that leaves AuthMethod unset accepts clients without asking
// for credentials.
const (
	AuthTrust    = AuthMethod(iota) // accept the client without credentials
	AuthPassword                    // clear-text password exchange
	AuthMD5                         // salted MD5 challenge/response
)
- // PostgreSQL server configuration
- type PostgreSQLServerConfig struct {
- Host string
- Port int
- AuthMethod AuthMethod
- Users map[string]string
- TLSConfig *tls.Config
- MaxConns int
- IdleTimeout time.Duration
- StartupTimeout time.Duration // Timeout for client startup handshake
- Database string
- }
- // PostgreSQL server
- type PostgreSQLServer struct {
- config *PostgreSQLServerConfig
- listener net.Listener
- sqlEngine *engine.SQLEngine
- sessions map[uint32]*PostgreSQLSession
- sessionMux sync.RWMutex
- shutdown chan struct{}
- wg sync.WaitGroup
- nextConnID uint32
- }
- // PostgreSQL session
- type PostgreSQLSession struct {
- conn net.Conn
- reader *bufio.Reader
- writer *bufio.Writer
- authenticated bool
- username string
- database string
- parameters map[string]string
- preparedStmts map[string]*PreparedStatement
- portals map[string]*Portal
- transactionState byte
- processID uint32
- secretKey uint32
- created time.Time
- lastActivity time.Time
- mutex sync.Mutex
- }
- // Prepared statement
- type PreparedStatement struct {
- Name string
- Query string
- ParamTypes []uint32
- Fields []FieldDescription
- }
// Portal is a named cursor stored in a session.
type Portal struct {
	Name       string   // portal name
	Statement  string   // presumably the name of the statement the portal was bound from; confirm against the Bind handler
	Parameters [][]byte // one raw value per parameter
	Suspended  bool     // NOTE(review): looks like it marks a partially executed portal; confirm
}
// FieldDescription describes one result column.
//
// NOTE(review): the fields look like the per-column entries of a row
// description message; confirm against the code that serialises them.
type FieldDescription struct {
	Name     string // column name
	TableOID uint32 // OID of the table the column belongs to
	AttrNum  int16  // attribute number of the column
	TypeOID  uint32 // data type OID (see the PG_TYPE_* constants)
	TypeSize int16  // size of the data type
	TypeMod  int32  // type modifier
	Format   int16  // format code
}
- // NewPostgreSQLServer creates a new PostgreSQL protocol server
- func NewPostgreSQLServer(config *PostgreSQLServerConfig, masterAddr string) (*PostgreSQLServer, error) {
- if config.Port <= 0 {
- config.Port = DEFAULT_POSTGRES_PORT
- }
- if config.Host == "" {
- config.Host = "localhost"
- }
- if config.Database == "" {
- config.Database = "default"
- }
- if config.MaxConns <= 0 {
- config.MaxConns = 100
- }
- if config.IdleTimeout <= 0 {
- config.IdleTimeout = time.Hour
- }
- if config.StartupTimeout <= 0 {
- config.StartupTimeout = 30 * time.Second
- }
- // Create SQL engine (now uses CockroachDB parser for PostgreSQL compatibility)
- sqlEngine := engine.NewSQLEngine(masterAddr)
- server := &PostgreSQLServer{
- config: config,
- sqlEngine: sqlEngine,
- sessions: make(map[uint32]*PostgreSQLSession),
- shutdown: make(chan struct{}),
- nextConnID: 1,
- }
- return server, nil
- }
- // Start begins listening for PostgreSQL connections
- func (s *PostgreSQLServer) Start() error {
- addr := fmt.Sprintf("%s:%d", s.config.Host, s.config.Port)
- var listener net.Listener
- var err error
- if s.config.TLSConfig != nil {
- listener, err = tls.Listen("tcp", addr, s.config.TLSConfig)
- glog.Infof("PostgreSQL Server with TLS listening on %s", addr)
- } else {
- listener, err = net.Listen("tcp", addr)
- glog.Infof("PostgreSQL Server listening on %s", addr)
- }
- if err != nil {
- return fmt.Errorf("failed to start PostgreSQL server on %s: %v", addr, err)
- }
- s.listener = listener
- // Start accepting connections
- s.wg.Add(1)
- go s.acceptConnections()
- // Start cleanup routine
- s.wg.Add(1)
- go s.cleanupSessions()
- return nil
- }
- // Stop gracefully shuts down the PostgreSQL server
- func (s *PostgreSQLServer) Stop() error {
- close(s.shutdown)
- if s.listener != nil {
- s.listener.Close()
- }
- // Close all sessions
- s.sessionMux.Lock()
- for _, session := range s.sessions {
- session.close()
- }
- s.sessions = make(map[uint32]*PostgreSQLSession)
- s.sessionMux.Unlock()
- s.wg.Wait()
- glog.Infof("PostgreSQL Server stopped")
- return nil
- }
- // acceptConnections handles incoming PostgreSQL connections
- func (s *PostgreSQLServer) acceptConnections() {
- defer s.wg.Done()
- for {
- select {
- case <-s.shutdown:
- return
- default:
- }
- conn, err := s.listener.Accept()
- if err != nil {
- select {
- case <-s.shutdown:
- return
- default:
- glog.Errorf("Failed to accept PostgreSQL connection: %v", err)
- continue
- }
- }
- // Check connection limit
- s.sessionMux.RLock()
- sessionCount := len(s.sessions)
- s.sessionMux.RUnlock()
- if sessionCount >= s.config.MaxConns {
- glog.Warningf("Maximum connections reached (%d), rejecting connection from %s",
- s.config.MaxConns, conn.RemoteAddr())
- conn.Close()
- continue
- }
- s.wg.Add(1)
- go s.handleConnection(conn)
- }
- }
- // handleConnection processes a single PostgreSQL connection
- func (s *PostgreSQLServer) handleConnection(conn net.Conn) {
- defer s.wg.Done()
- defer conn.Close()
- // Generate unique connection ID
- connID := s.generateConnectionID()
- secretKey := s.generateSecretKey()
- // Create session
- session := &PostgreSQLSession{
- conn: conn,
- reader: bufio.NewReader(conn),
- writer: bufio.NewWriter(conn),
- authenticated: false,
- database: s.config.Database,
- parameters: make(map[string]string),
- preparedStmts: make(map[string]*PreparedStatement),
- portals: make(map[string]*Portal),
- transactionState: PG_TRANS_IDLE,
- processID: connID,
- secretKey: secretKey,
- created: time.Now(),
- lastActivity: time.Now(),
- }
- // Register session
- s.sessionMux.Lock()
- s.sessions[connID] = session
- s.sessionMux.Unlock()
- // Clean up on exit
- defer func() {
- s.sessionMux.Lock()
- delete(s.sessions, connID)
- s.sessionMux.Unlock()
- }()
- glog.V(2).Infof("New PostgreSQL connection from %s (ID: %d)", conn.RemoteAddr(), connID)
- // Handle startup
- err := s.handleStartup(session)
- if err != nil {
- // Handle common disconnection scenarios more gracefully
- if strings.Contains(err.Error(), "client disconnected") {
- glog.V(1).Infof("Client startup disconnected from %s (ID: %d): %v", conn.RemoteAddr(), connID, err)
- } else if strings.Contains(err.Error(), "timeout") {
- glog.Warningf("Startup timeout for connection %d from %s: %v", connID, conn.RemoteAddr(), err)
- } else {
- glog.Errorf("Startup failed for connection %d from %s: %v", connID, conn.RemoteAddr(), err)
- }
- return
- }
- // Handle messages
- for {
- select {
- case <-s.shutdown:
- return
- default:
- }
- // Set read timeout
- conn.SetReadDeadline(time.Now().Add(30 * time.Second))
- err := s.handleMessage(session)
- if err != nil {
- if err == io.EOF {
- glog.Infof("PostgreSQL client disconnected (ID: %d)", connID)
- } else {
- glog.Errorf("Error handling PostgreSQL message (ID: %d): %v", connID, err)
- }
- return
- }
- session.lastActivity = time.Now()
- }
- }
- // handleStartup processes the PostgreSQL startup sequence
- func (s *PostgreSQLServer) handleStartup(session *PostgreSQLSession) error {
- // Set a startup timeout to prevent hanging connections
- startupTimeout := s.config.StartupTimeout
- session.conn.SetReadDeadline(time.Now().Add(startupTimeout))
- defer session.conn.SetReadDeadline(time.Time{}) // Clear timeout
- for {
- // Read startup message length
- length := make([]byte, 4)
- _, err := io.ReadFull(session.reader, length)
- if err != nil {
- if err == io.EOF {
- // Client disconnected during startup - this is common for health checks
- return fmt.Errorf("client disconnected during startup handshake")
- }
- if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
- return fmt.Errorf("startup handshake timeout after %v", startupTimeout)
- }
- return fmt.Errorf("failed to read message length during startup: %v", err)
- }
- msgLength := binary.BigEndian.Uint32(length) - 4
- if msgLength > 10000 { // Reasonable limit for startup messages
- return fmt.Errorf("startup message too large: %d bytes", msgLength)
- }
- // Read startup message content
- msg := make([]byte, msgLength)
- _, err = io.ReadFull(session.reader, msg)
- if err != nil {
- if err == io.EOF {
- return fmt.Errorf("client disconnected while reading startup message")
- }
- if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
- return fmt.Errorf("startup message read timeout")
- }
- return fmt.Errorf("failed to read startup message: %v", err)
- }
- // Parse protocol version
- protocolVersion := binary.BigEndian.Uint32(msg[0:4])
- switch protocolVersion {
- case PG_SSL_REQUEST:
- // Reject SSL request - send 'N' to indicate SSL not supported
- _, err = session.conn.Write([]byte{'N'})
- if err != nil {
- return fmt.Errorf("failed to reject SSL request: %v", err)
- }
- // Continue loop to read the actual startup message
- continue
- case PG_GSSAPI_REQUEST:
- // Reject GSSAPI request - send 'N' to indicate GSSAPI not supported
- _, err = session.conn.Write([]byte{'N'})
- if err != nil {
- return fmt.Errorf("failed to reject GSSAPI request: %v", err)
- }
- // Continue loop to read the actual startup message
- continue
- case PG_PROTOCOL_VERSION_3:
- // This is the actual startup message, break out of loop
- break
- default:
- return fmt.Errorf("unsupported protocol version: %d", protocolVersion)
- }
- // Parse parameters
- params := strings.Split(string(msg[4:]), "\x00")
- for i := 0; i < len(params)-1; i += 2 {
- if params[i] == "user" {
- session.username = params[i+1]
- } else if params[i] == "database" {
- session.database = params[i+1]
- }
- session.parameters[params[i]] = params[i+1]
- }
- // Break out of the main loop - we have the startup message
- break
- }
- // Handle authentication
- err := s.handleAuthentication(session)
- if err != nil {
- return err
- }
- // Send parameter status messages
- err = s.sendParameterStatus(session, "server_version", fmt.Sprintf("%s (SeaweedFS)", version.VERSION_NUMBER))
- if err != nil {
- return err
- }
- err = s.sendParameterStatus(session, "server_encoding", "UTF8")
- if err != nil {
- return err
- }
- err = s.sendParameterStatus(session, "client_encoding", "UTF8")
- if err != nil {
- return err
- }
- err = s.sendParameterStatus(session, "DateStyle", "ISO, MDY")
- if err != nil {
- return err
- }
- err = s.sendParameterStatus(session, "integer_datetimes", "on")
- if err != nil {
- return err
- }
- // Send backend key data
- err = s.sendBackendKeyData(session)
- if err != nil {
- return err
- }
- // Send ready for query
- err = s.sendReadyForQuery(session)
- if err != nil {
- return err
- }
- session.authenticated = true
- return nil
- }
- // handleAuthentication processes authentication
- func (s *PostgreSQLServer) handleAuthentication(session *PostgreSQLSession) error {
- switch s.config.AuthMethod {
- case AuthTrust:
- return s.sendAuthenticationOk(session)
- case AuthPassword:
- return s.handlePasswordAuth(session)
- case AuthMD5:
- return s.handleMD5Auth(session)
- default:
- return fmt.Errorf("unsupported authentication method")
- }
- }
- // sendAuthenticationOk sends authentication OK message
- func (s *PostgreSQLServer) sendAuthenticationOk(session *PostgreSQLSession) error {
- msg := make([]byte, 9)
- msg[0] = PG_RESP_AUTH_OK
- binary.BigEndian.PutUint32(msg[1:5], 8)
- binary.BigEndian.PutUint32(msg[5:9], AUTH_OK)
- _, err := session.writer.Write(msg)
- if err == nil {
- err = session.writer.Flush()
- }
- return err
- }
- // handlePasswordAuth handles clear password authentication
- func (s *PostgreSQLServer) handlePasswordAuth(session *PostgreSQLSession) error {
- // Send password request
- msg := make([]byte, 9)
- msg[0] = PG_RESP_AUTH_OK
- binary.BigEndian.PutUint32(msg[1:5], 8)
- binary.BigEndian.PutUint32(msg[5:9], AUTH_CLEAR)
- _, err := session.writer.Write(msg)
- if err != nil {
- return err
- }
- err = session.writer.Flush()
- if err != nil {
- return err
- }
- // Read password response
- msgType := make([]byte, 1)
- _, err = io.ReadFull(session.reader, msgType)
- if err != nil {
- return err
- }
- if msgType[0] != PG_MSG_PASSWORD {
- return fmt.Errorf("expected password message, got %c", msgType[0])
- }
- length := make([]byte, 4)
- _, err = io.ReadFull(session.reader, length)
- if err != nil {
- return err
- }
- msgLength := binary.BigEndian.Uint32(length) - 4
- password := make([]byte, msgLength)
- _, err = io.ReadFull(session.reader, password)
- if err != nil {
- return err
- }
- // Verify password
- expectedPassword, exists := s.config.Users[session.username]
- if !exists || string(password[:len(password)-1]) != expectedPassword { // Remove null terminator
- return s.sendError(session, "28P01", "authentication failed for user \""+session.username+"\"")
- }
- return s.sendAuthenticationOk(session)
- }
- // handleMD5Auth handles MD5 password authentication
- func (s *PostgreSQLServer) handleMD5Auth(session *PostgreSQLSession) error {
- // Generate salt
- salt := make([]byte, 4)
- _, err := rand.Read(salt)
- if err != nil {
- return err
- }
- // Send MD5 request
- msg := make([]byte, 13)
- msg[0] = PG_RESP_AUTH_OK
- binary.BigEndian.PutUint32(msg[1:5], 12)
- binary.BigEndian.PutUint32(msg[5:9], AUTH_MD5)
- copy(msg[9:13], salt)
- _, err = session.writer.Write(msg)
- if err != nil {
- return err
- }
- err = session.writer.Flush()
- if err != nil {
- return err
- }
- // Read password response
- msgType := make([]byte, 1)
- _, err = io.ReadFull(session.reader, msgType)
- if err != nil {
- return err
- }
- if msgType[0] != PG_MSG_PASSWORD {
- return fmt.Errorf("expected password message, got %c", msgType[0])
- }
- length := make([]byte, 4)
- _, err = io.ReadFull(session.reader, length)
- if err != nil {
- return err
- }
- msgLength := binary.BigEndian.Uint32(length) - 4
- response := make([]byte, msgLength)
- _, err = io.ReadFull(session.reader, response)
- if err != nil {
- return err
- }
- // Verify MD5 hash
- expectedPassword, exists := s.config.Users[session.username]
- if !exists {
- return s.sendError(session, "28P01", "authentication failed for user \""+session.username+"\"")
- }
- // Calculate expected hash: md5(md5(password + username) + salt)
- inner := md5.Sum([]byte(expectedPassword + session.username))
- expected := fmt.Sprintf("md5%x", md5.Sum(append([]byte(fmt.Sprintf("%x", inner)), salt...)))
- if string(response[:len(response)-1]) != expected { // Remove null terminator
- return s.sendError(session, "28P01", "authentication failed for user \""+session.username+"\"")
- }
- return s.sendAuthenticationOk(session)
- }
- // generateConnectionID generates a unique connection ID
- func (s *PostgreSQLServer) generateConnectionID() uint32 {
- s.sessionMux.Lock()
- defer s.sessionMux.Unlock()
- id := s.nextConnID
- s.nextConnID++
- return id
- }
- // generateSecretKey generates a secret key for the connection
- func (s *PostgreSQLServer) generateSecretKey() uint32 {
- key := make([]byte, 4)
- rand.Read(key)
- return binary.BigEndian.Uint32(key)
- }
- // close marks the session as closed
- func (s *PostgreSQLSession) close() {
- s.mutex.Lock()
- defer s.mutex.Unlock()
- if s.conn != nil {
- s.conn.Close()
- s.conn = nil
- }
- }
- // cleanupSessions periodically cleans up idle sessions
- func (s *PostgreSQLServer) cleanupSessions() {
- defer s.wg.Done()
- ticker := time.NewTicker(time.Minute)
- defer ticker.Stop()
- for {
- select {
- case <-s.shutdown:
- return
- case <-ticker.C:
- s.cleanupIdleSessions()
- }
- }
- }
- // cleanupIdleSessions removes sessions that have been idle too long
- func (s *PostgreSQLServer) cleanupIdleSessions() {
- now := time.Now()
- s.sessionMux.Lock()
- defer s.sessionMux.Unlock()
- for id, session := range s.sessions {
- if now.Sub(session.lastActivity) > s.config.IdleTimeout {
- glog.Infof("Closing idle PostgreSQL session %d", id)
- session.close()
- delete(s.sessions, id)
- }
- }
- }
- // GetAddress returns the server address
- func (s *PostgreSQLServer) GetAddress() string {
- return fmt.Sprintf("%s:%d", s.config.Host, s.config.Port)
- }
|