server.go 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704
  1. package postgres
  2. import (
  3. "bufio"
  4. "crypto/md5"
  5. "crypto/rand"
  6. "crypto/tls"
  7. "encoding/binary"
  8. "fmt"
  9. "io"
  10. "net"
  11. "strings"
  12. "sync"
  13. "time"
  14. "github.com/seaweedfs/seaweedfs/weed/glog"
  15. "github.com/seaweedfs/seaweedfs/weed/query/engine"
  16. "github.com/seaweedfs/seaweedfs/weed/util/version"
  17. )
  18. // PostgreSQL protocol constants
  19. const (
  20. // Protocol versions
  21. PG_PROTOCOL_VERSION_3 = 196608 // PostgreSQL 3.0 protocol (0x00030000)
  22. PG_SSL_REQUEST = 80877103 // SSL request (0x04d2162f)
  23. PG_GSSAPI_REQUEST = 80877104 // GSSAPI request (0x04d21630)
  24. // Message types from client
  25. PG_MSG_STARTUP = 0x00
  26. PG_MSG_QUERY = 'Q'
  27. PG_MSG_PARSE = 'P'
  28. PG_MSG_BIND = 'B'
  29. PG_MSG_EXECUTE = 'E'
  30. PG_MSG_DESCRIBE = 'D'
  31. PG_MSG_CLOSE = 'C'
  32. PG_MSG_FLUSH = 'H'
  33. PG_MSG_SYNC = 'S'
  34. PG_MSG_TERMINATE = 'X'
  35. PG_MSG_PASSWORD = 'p'
  36. // Response types to client
  37. PG_RESP_AUTH_OK = 'R'
  38. PG_RESP_BACKEND_KEY = 'K'
  39. PG_RESP_PARAMETER = 'S'
  40. PG_RESP_READY = 'Z'
  41. PG_RESP_COMMAND = 'C'
  42. PG_RESP_DATA_ROW = 'D'
  43. PG_RESP_ROW_DESC = 'T'
  44. PG_RESP_PARSE_COMPLETE = '1'
  45. PG_RESP_BIND_COMPLETE = '2'
  46. PG_RESP_CLOSE_COMPLETE = '3'
  47. PG_RESP_ERROR = 'E'
  48. PG_RESP_NOTICE = 'N'
  49. // Transaction states
  50. PG_TRANS_IDLE = 'I'
  51. PG_TRANS_INTRANS = 'T'
  52. PG_TRANS_ERROR = 'E'
  53. // Authentication methods
  54. AUTH_OK = 0
  55. AUTH_CLEAR = 3
  56. AUTH_MD5 = 5
  57. AUTH_TRUST = 10
  58. // PostgreSQL data types
  59. PG_TYPE_BOOL = 16
  60. PG_TYPE_BYTEA = 17
  61. PG_TYPE_INT8 = 20
  62. PG_TYPE_INT4 = 23
  63. PG_TYPE_TEXT = 25
  64. PG_TYPE_FLOAT4 = 700
  65. PG_TYPE_FLOAT8 = 701
  66. PG_TYPE_VARCHAR = 1043
  67. PG_TYPE_TIMESTAMP = 1114
  68. PG_TYPE_JSON = 114
  69. PG_TYPE_JSONB = 3802
  70. // Default values
  71. DEFAULT_POSTGRES_PORT = 5432
  72. )
  73. // Authentication method type
  74. type AuthMethod int
  75. const (
  76. AuthTrust AuthMethod = iota
  77. AuthPassword
  78. AuthMD5
  79. )
  80. // PostgreSQL server configuration
  81. type PostgreSQLServerConfig struct {
  82. Host string
  83. Port int
  84. AuthMethod AuthMethod
  85. Users map[string]string
  86. TLSConfig *tls.Config
  87. MaxConns int
  88. IdleTimeout time.Duration
  89. StartupTimeout time.Duration // Timeout for client startup handshake
  90. Database string
  91. }
  92. // PostgreSQL server
  93. type PostgreSQLServer struct {
  94. config *PostgreSQLServerConfig
  95. listener net.Listener
  96. sqlEngine *engine.SQLEngine
  97. sessions map[uint32]*PostgreSQLSession
  98. sessionMux sync.RWMutex
  99. shutdown chan struct{}
  100. wg sync.WaitGroup
  101. nextConnID uint32
  102. }
  103. // PostgreSQL session
  104. type PostgreSQLSession struct {
  105. conn net.Conn
  106. reader *bufio.Reader
  107. writer *bufio.Writer
  108. authenticated bool
  109. username string
  110. database string
  111. parameters map[string]string
  112. preparedStmts map[string]*PreparedStatement
  113. portals map[string]*Portal
  114. transactionState byte
  115. processID uint32
  116. secretKey uint32
  117. created time.Time
  118. lastActivity time.Time
  119. mutex sync.Mutex
  120. }
  121. // Prepared statement
  122. type PreparedStatement struct {
  123. Name string
  124. Query string
  125. ParamTypes []uint32
  126. Fields []FieldDescription
  127. }
  128. // Portal (cursor)
  129. type Portal struct {
  130. Name string
  131. Statement string
  132. Parameters [][]byte
  133. Suspended bool
  134. }
  135. // Field description
  136. type FieldDescription struct {
  137. Name string
  138. TableOID uint32
  139. AttrNum int16
  140. TypeOID uint32
  141. TypeSize int16
  142. TypeMod int32
  143. Format int16
  144. }
  145. // NewPostgreSQLServer creates a new PostgreSQL protocol server
  146. func NewPostgreSQLServer(config *PostgreSQLServerConfig, masterAddr string) (*PostgreSQLServer, error) {
  147. if config.Port <= 0 {
  148. config.Port = DEFAULT_POSTGRES_PORT
  149. }
  150. if config.Host == "" {
  151. config.Host = "localhost"
  152. }
  153. if config.Database == "" {
  154. config.Database = "default"
  155. }
  156. if config.MaxConns <= 0 {
  157. config.MaxConns = 100
  158. }
  159. if config.IdleTimeout <= 0 {
  160. config.IdleTimeout = time.Hour
  161. }
  162. if config.StartupTimeout <= 0 {
  163. config.StartupTimeout = 30 * time.Second
  164. }
  165. // Create SQL engine (now uses CockroachDB parser for PostgreSQL compatibility)
  166. sqlEngine := engine.NewSQLEngine(masterAddr)
  167. server := &PostgreSQLServer{
  168. config: config,
  169. sqlEngine: sqlEngine,
  170. sessions: make(map[uint32]*PostgreSQLSession),
  171. shutdown: make(chan struct{}),
  172. nextConnID: 1,
  173. }
  174. return server, nil
  175. }
  176. // Start begins listening for PostgreSQL connections
  177. func (s *PostgreSQLServer) Start() error {
  178. addr := fmt.Sprintf("%s:%d", s.config.Host, s.config.Port)
  179. var listener net.Listener
  180. var err error
  181. if s.config.TLSConfig != nil {
  182. listener, err = tls.Listen("tcp", addr, s.config.TLSConfig)
  183. glog.Infof("PostgreSQL Server with TLS listening on %s", addr)
  184. } else {
  185. listener, err = net.Listen("tcp", addr)
  186. glog.Infof("PostgreSQL Server listening on %s", addr)
  187. }
  188. if err != nil {
  189. return fmt.Errorf("failed to start PostgreSQL server on %s: %v", addr, err)
  190. }
  191. s.listener = listener
  192. // Start accepting connections
  193. s.wg.Add(1)
  194. go s.acceptConnections()
  195. // Start cleanup routine
  196. s.wg.Add(1)
  197. go s.cleanupSessions()
  198. return nil
  199. }
  200. // Stop gracefully shuts down the PostgreSQL server
  201. func (s *PostgreSQLServer) Stop() error {
  202. close(s.shutdown)
  203. if s.listener != nil {
  204. s.listener.Close()
  205. }
  206. // Close all sessions
  207. s.sessionMux.Lock()
  208. for _, session := range s.sessions {
  209. session.close()
  210. }
  211. s.sessions = make(map[uint32]*PostgreSQLSession)
  212. s.sessionMux.Unlock()
  213. s.wg.Wait()
  214. glog.Infof("PostgreSQL Server stopped")
  215. return nil
  216. }
  217. // acceptConnections handles incoming PostgreSQL connections
  218. func (s *PostgreSQLServer) acceptConnections() {
  219. defer s.wg.Done()
  220. for {
  221. select {
  222. case <-s.shutdown:
  223. return
  224. default:
  225. }
  226. conn, err := s.listener.Accept()
  227. if err != nil {
  228. select {
  229. case <-s.shutdown:
  230. return
  231. default:
  232. glog.Errorf("Failed to accept PostgreSQL connection: %v", err)
  233. continue
  234. }
  235. }
  236. // Check connection limit
  237. s.sessionMux.RLock()
  238. sessionCount := len(s.sessions)
  239. s.sessionMux.RUnlock()
  240. if sessionCount >= s.config.MaxConns {
  241. glog.Warningf("Maximum connections reached (%d), rejecting connection from %s",
  242. s.config.MaxConns, conn.RemoteAddr())
  243. conn.Close()
  244. continue
  245. }
  246. s.wg.Add(1)
  247. go s.handleConnection(conn)
  248. }
  249. }
  250. // handleConnection processes a single PostgreSQL connection
  251. func (s *PostgreSQLServer) handleConnection(conn net.Conn) {
  252. defer s.wg.Done()
  253. defer conn.Close()
  254. // Generate unique connection ID
  255. connID := s.generateConnectionID()
  256. secretKey := s.generateSecretKey()
  257. // Create session
  258. session := &PostgreSQLSession{
  259. conn: conn,
  260. reader: bufio.NewReader(conn),
  261. writer: bufio.NewWriter(conn),
  262. authenticated: false,
  263. database: s.config.Database,
  264. parameters: make(map[string]string),
  265. preparedStmts: make(map[string]*PreparedStatement),
  266. portals: make(map[string]*Portal),
  267. transactionState: PG_TRANS_IDLE,
  268. processID: connID,
  269. secretKey: secretKey,
  270. created: time.Now(),
  271. lastActivity: time.Now(),
  272. }
  273. // Register session
  274. s.sessionMux.Lock()
  275. s.sessions[connID] = session
  276. s.sessionMux.Unlock()
  277. // Clean up on exit
  278. defer func() {
  279. s.sessionMux.Lock()
  280. delete(s.sessions, connID)
  281. s.sessionMux.Unlock()
  282. }()
  283. glog.V(2).Infof("New PostgreSQL connection from %s (ID: %d)", conn.RemoteAddr(), connID)
  284. // Handle startup
  285. err := s.handleStartup(session)
  286. if err != nil {
  287. // Handle common disconnection scenarios more gracefully
  288. if strings.Contains(err.Error(), "client disconnected") {
  289. glog.V(1).Infof("Client startup disconnected from %s (ID: %d): %v", conn.RemoteAddr(), connID, err)
  290. } else if strings.Contains(err.Error(), "timeout") {
  291. glog.Warningf("Startup timeout for connection %d from %s: %v", connID, conn.RemoteAddr(), err)
  292. } else {
  293. glog.Errorf("Startup failed for connection %d from %s: %v", connID, conn.RemoteAddr(), err)
  294. }
  295. return
  296. }
  297. // Handle messages
  298. for {
  299. select {
  300. case <-s.shutdown:
  301. return
  302. default:
  303. }
  304. // Set read timeout
  305. conn.SetReadDeadline(time.Now().Add(30 * time.Second))
  306. err := s.handleMessage(session)
  307. if err != nil {
  308. if err == io.EOF {
  309. glog.Infof("PostgreSQL client disconnected (ID: %d)", connID)
  310. } else {
  311. glog.Errorf("Error handling PostgreSQL message (ID: %d): %v", connID, err)
  312. }
  313. return
  314. }
  315. session.lastActivity = time.Now()
  316. }
  317. }
  318. // handleStartup processes the PostgreSQL startup sequence
  319. func (s *PostgreSQLServer) handleStartup(session *PostgreSQLSession) error {
  320. // Set a startup timeout to prevent hanging connections
  321. startupTimeout := s.config.StartupTimeout
  322. session.conn.SetReadDeadline(time.Now().Add(startupTimeout))
  323. defer session.conn.SetReadDeadline(time.Time{}) // Clear timeout
  324. for {
  325. // Read startup message length
  326. length := make([]byte, 4)
  327. _, err := io.ReadFull(session.reader, length)
  328. if err != nil {
  329. if err == io.EOF {
  330. // Client disconnected during startup - this is common for health checks
  331. return fmt.Errorf("client disconnected during startup handshake")
  332. }
  333. if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
  334. return fmt.Errorf("startup handshake timeout after %v", startupTimeout)
  335. }
  336. return fmt.Errorf("failed to read message length during startup: %v", err)
  337. }
  338. msgLength := binary.BigEndian.Uint32(length) - 4
  339. if msgLength > 10000 { // Reasonable limit for startup messages
  340. return fmt.Errorf("startup message too large: %d bytes", msgLength)
  341. }
  342. // Read startup message content
  343. msg := make([]byte, msgLength)
  344. _, err = io.ReadFull(session.reader, msg)
  345. if err != nil {
  346. if err == io.EOF {
  347. return fmt.Errorf("client disconnected while reading startup message")
  348. }
  349. if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
  350. return fmt.Errorf("startup message read timeout")
  351. }
  352. return fmt.Errorf("failed to read startup message: %v", err)
  353. }
  354. // Parse protocol version
  355. protocolVersion := binary.BigEndian.Uint32(msg[0:4])
  356. switch protocolVersion {
  357. case PG_SSL_REQUEST:
  358. // Reject SSL request - send 'N' to indicate SSL not supported
  359. _, err = session.conn.Write([]byte{'N'})
  360. if err != nil {
  361. return fmt.Errorf("failed to reject SSL request: %v", err)
  362. }
  363. // Continue loop to read the actual startup message
  364. continue
  365. case PG_GSSAPI_REQUEST:
  366. // Reject GSSAPI request - send 'N' to indicate GSSAPI not supported
  367. _, err = session.conn.Write([]byte{'N'})
  368. if err != nil {
  369. return fmt.Errorf("failed to reject GSSAPI request: %v", err)
  370. }
  371. // Continue loop to read the actual startup message
  372. continue
  373. case PG_PROTOCOL_VERSION_3:
  374. // This is the actual startup message, break out of loop
  375. break
  376. default:
  377. return fmt.Errorf("unsupported protocol version: %d", protocolVersion)
  378. }
  379. // Parse parameters
  380. params := strings.Split(string(msg[4:]), "\x00")
  381. for i := 0; i < len(params)-1; i += 2 {
  382. if params[i] == "user" {
  383. session.username = params[i+1]
  384. } else if params[i] == "database" {
  385. session.database = params[i+1]
  386. }
  387. session.parameters[params[i]] = params[i+1]
  388. }
  389. // Break out of the main loop - we have the startup message
  390. break
  391. }
  392. // Handle authentication
  393. err := s.handleAuthentication(session)
  394. if err != nil {
  395. return err
  396. }
  397. // Send parameter status messages
  398. err = s.sendParameterStatus(session, "server_version", fmt.Sprintf("%s (SeaweedFS)", version.VERSION_NUMBER))
  399. if err != nil {
  400. return err
  401. }
  402. err = s.sendParameterStatus(session, "server_encoding", "UTF8")
  403. if err != nil {
  404. return err
  405. }
  406. err = s.sendParameterStatus(session, "client_encoding", "UTF8")
  407. if err != nil {
  408. return err
  409. }
  410. err = s.sendParameterStatus(session, "DateStyle", "ISO, MDY")
  411. if err != nil {
  412. return err
  413. }
  414. err = s.sendParameterStatus(session, "integer_datetimes", "on")
  415. if err != nil {
  416. return err
  417. }
  418. // Send backend key data
  419. err = s.sendBackendKeyData(session)
  420. if err != nil {
  421. return err
  422. }
  423. // Send ready for query
  424. err = s.sendReadyForQuery(session)
  425. if err != nil {
  426. return err
  427. }
  428. session.authenticated = true
  429. return nil
  430. }
  431. // handleAuthentication processes authentication
  432. func (s *PostgreSQLServer) handleAuthentication(session *PostgreSQLSession) error {
  433. switch s.config.AuthMethod {
  434. case AuthTrust:
  435. return s.sendAuthenticationOk(session)
  436. case AuthPassword:
  437. return s.handlePasswordAuth(session)
  438. case AuthMD5:
  439. return s.handleMD5Auth(session)
  440. default:
  441. return fmt.Errorf("unsupported authentication method")
  442. }
  443. }
  444. // sendAuthenticationOk sends authentication OK message
  445. func (s *PostgreSQLServer) sendAuthenticationOk(session *PostgreSQLSession) error {
  446. msg := make([]byte, 9)
  447. msg[0] = PG_RESP_AUTH_OK
  448. binary.BigEndian.PutUint32(msg[1:5], 8)
  449. binary.BigEndian.PutUint32(msg[5:9], AUTH_OK)
  450. _, err := session.writer.Write(msg)
  451. if err == nil {
  452. err = session.writer.Flush()
  453. }
  454. return err
  455. }
  456. // handlePasswordAuth handles clear password authentication
  457. func (s *PostgreSQLServer) handlePasswordAuth(session *PostgreSQLSession) error {
  458. // Send password request
  459. msg := make([]byte, 9)
  460. msg[0] = PG_RESP_AUTH_OK
  461. binary.BigEndian.PutUint32(msg[1:5], 8)
  462. binary.BigEndian.PutUint32(msg[5:9], AUTH_CLEAR)
  463. _, err := session.writer.Write(msg)
  464. if err != nil {
  465. return err
  466. }
  467. err = session.writer.Flush()
  468. if err != nil {
  469. return err
  470. }
  471. // Read password response
  472. msgType := make([]byte, 1)
  473. _, err = io.ReadFull(session.reader, msgType)
  474. if err != nil {
  475. return err
  476. }
  477. if msgType[0] != PG_MSG_PASSWORD {
  478. return fmt.Errorf("expected password message, got %c", msgType[0])
  479. }
  480. length := make([]byte, 4)
  481. _, err = io.ReadFull(session.reader, length)
  482. if err != nil {
  483. return err
  484. }
  485. msgLength := binary.BigEndian.Uint32(length) - 4
  486. password := make([]byte, msgLength)
  487. _, err = io.ReadFull(session.reader, password)
  488. if err != nil {
  489. return err
  490. }
  491. // Verify password
  492. expectedPassword, exists := s.config.Users[session.username]
  493. if !exists || string(password[:len(password)-1]) != expectedPassword { // Remove null terminator
  494. return s.sendError(session, "28P01", "authentication failed for user \""+session.username+"\"")
  495. }
  496. return s.sendAuthenticationOk(session)
  497. }
  498. // handleMD5Auth handles MD5 password authentication
  499. func (s *PostgreSQLServer) handleMD5Auth(session *PostgreSQLSession) error {
  500. // Generate salt
  501. salt := make([]byte, 4)
  502. _, err := rand.Read(salt)
  503. if err != nil {
  504. return err
  505. }
  506. // Send MD5 request
  507. msg := make([]byte, 13)
  508. msg[0] = PG_RESP_AUTH_OK
  509. binary.BigEndian.PutUint32(msg[1:5], 12)
  510. binary.BigEndian.PutUint32(msg[5:9], AUTH_MD5)
  511. copy(msg[9:13], salt)
  512. _, err = session.writer.Write(msg)
  513. if err != nil {
  514. return err
  515. }
  516. err = session.writer.Flush()
  517. if err != nil {
  518. return err
  519. }
  520. // Read password response
  521. msgType := make([]byte, 1)
  522. _, err = io.ReadFull(session.reader, msgType)
  523. if err != nil {
  524. return err
  525. }
  526. if msgType[0] != PG_MSG_PASSWORD {
  527. return fmt.Errorf("expected password message, got %c", msgType[0])
  528. }
  529. length := make([]byte, 4)
  530. _, err = io.ReadFull(session.reader, length)
  531. if err != nil {
  532. return err
  533. }
  534. msgLength := binary.BigEndian.Uint32(length) - 4
  535. response := make([]byte, msgLength)
  536. _, err = io.ReadFull(session.reader, response)
  537. if err != nil {
  538. return err
  539. }
  540. // Verify MD5 hash
  541. expectedPassword, exists := s.config.Users[session.username]
  542. if !exists {
  543. return s.sendError(session, "28P01", "authentication failed for user \""+session.username+"\"")
  544. }
  545. // Calculate expected hash: md5(md5(password + username) + salt)
  546. inner := md5.Sum([]byte(expectedPassword + session.username))
  547. expected := fmt.Sprintf("md5%x", md5.Sum(append([]byte(fmt.Sprintf("%x", inner)), salt...)))
  548. if string(response[:len(response)-1]) != expected { // Remove null terminator
  549. return s.sendError(session, "28P01", "authentication failed for user \""+session.username+"\"")
  550. }
  551. return s.sendAuthenticationOk(session)
  552. }
  553. // generateConnectionID generates a unique connection ID
  554. func (s *PostgreSQLServer) generateConnectionID() uint32 {
  555. s.sessionMux.Lock()
  556. defer s.sessionMux.Unlock()
  557. id := s.nextConnID
  558. s.nextConnID++
  559. return id
  560. }
  561. // generateSecretKey generates a secret key for the connection
  562. func (s *PostgreSQLServer) generateSecretKey() uint32 {
  563. key := make([]byte, 4)
  564. rand.Read(key)
  565. return binary.BigEndian.Uint32(key)
  566. }
  567. // close marks the session as closed
  568. func (s *PostgreSQLSession) close() {
  569. s.mutex.Lock()
  570. defer s.mutex.Unlock()
  571. if s.conn != nil {
  572. s.conn.Close()
  573. s.conn = nil
  574. }
  575. }
  576. // cleanupSessions periodically cleans up idle sessions
  577. func (s *PostgreSQLServer) cleanupSessions() {
  578. defer s.wg.Done()
  579. ticker := time.NewTicker(time.Minute)
  580. defer ticker.Stop()
  581. for {
  582. select {
  583. case <-s.shutdown:
  584. return
  585. case <-ticker.C:
  586. s.cleanupIdleSessions()
  587. }
  588. }
  589. }
  590. // cleanupIdleSessions removes sessions that have been idle too long
  591. func (s *PostgreSQLServer) cleanupIdleSessions() {
  592. now := time.Now()
  593. s.sessionMux.Lock()
  594. defer s.sessionMux.Unlock()
  595. for id, session := range s.sessions {
  596. if now.Sub(session.lastActivity) > s.config.IdleTimeout {
  597. glog.Infof("Closing idle PostgreSQL session %d", id)
  598. session.close()
  599. delete(s.sessions, id)
  600. }
  601. }
  602. }
  603. // GetAddress returns the server address
  604. func (s *PostgreSQLServer) GetAddress() string {
  605. return fmt.Sprintf("%s:%d", s.config.Host, s.config.Port)
  606. }