You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

303 lines
9.4 KiB
Go

/*
* Copyright 2025 CloudWeGo Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// Package httptransport provides a configurable cURL-style logging RoundTripper
// for HTTP-based ChatModel clients. It logs the real outbound HTTP request
// (as a cURL command) and the inbound HTTP response. Streaming responses (SSE
// or NDJSON) can be logged chunk-by-chunk without breaking the stream.
//
// Quick usage:
//
//	client := &http.Client{Transport: httptransport.NewCurlRT(
//		http.DefaultTransport,
//		httptransport.WithLogger(log.Default()),
//		// Or pass a context-aware logger to extract request IDs:
//		httptransport.WithCtxLogger(httptransport.IDCtxLogger{L: log.Default()}),
//		// Security controls:
//		httptransport.WithPrintAuth(false),                    // mask Authorization
//		httptransport.WithMaskHeaders([]string{"X-API-KEY"}), // mask custom headers
//		// Streaming controls:
//		httptransport.WithStreamLogging(true),
//		httptransport.WithMaxStreamLogBytes(8192),
//	)}
//	cm, _ := openai.NewChatModel(ctx, &openai.ChatModelConfig{HTTPClient: client, ...})
//
// Notes:
// - WithCtxLogger is preferred when you carry a request/log ID in context.
// - WithPrintAuth controls whether the Authorization header is printed.
// - WithMaskHeaders and WithMaskFunc allow masking arbitrary headers.
// - When stream logging is enabled, headers are logged once, and chunks are
// emitted as they are read. With a plain Logger, a capped summary is printed
// on Close(); with a CtxLogger, each chunk is logged directly.
package httptransport
import (
	"bytes"
	"context"
	"io"
	"log"
	"net/http"
	"sort"
	"strings"
)
// sanitizeLogValue strips every line feed and carriage return from s so that a
// logged value cannot start an extra, forged log line.
func sanitizeLogValue(s string) string {
	// Fast path: most values contain no line terminators at all.
	if !strings.ContainsAny(s, "\r\n") {
		return s
	}
	var out strings.Builder
	out.Grow(len(s))
	// Both characters are single-byte ASCII, so a byte-wise scan never splits
	// a multi-byte UTF-8 sequence.
	for i := 0; i < len(s); i++ {
		if ch := s[i]; ch != '\n' && ch != '\r' {
			out.WriteByte(ch)
		}
	}
	return out.String()
}
// Logger is the smallest logging surface this package needs: one printf-style
// method that takes no context. The value returned by log.Default satisfies it.
type Logger interface {
	Printf(format string, args ...any)
}
// CtxLogger is the context-aware counterpart of Logger. Each call receives the
// context of the HTTP request being logged, so an implementation can pull
// request IDs or other structured fields out of it.
type CtxLogger interface {
	Printf(ctx context.Context, format string, args ...any)
}
// CurlRT is an http.RoundTripper that logs requests/responses in cURL style.
// Configure it with the CurlOption helpers via NewCurlRT.
type CurlRT struct {
// base is the wrapped RoundTripper that performs the actual HTTP exchange.
base http.RoundTripper
// logger receives log lines whenever ctxLogger is nil; NewCurlRT falls back
// to log.Default() when none is configured.
logger Logger
// ctxLogger, when non-nil, takes precedence over logger and is handed the
// request context with every log line.
ctxLogger CtxLogger
// printAuth controls the Authorization header: when false its value is
// replaced with "<redacted>" in logs.
printAuth bool
// maskHeaders holds lower-cased header names whose values are passed through
// maskFn before being logged.
maskHeaders map[string]struct{}
// maskFn maps (header name, header value) to the text that is logged for a
// masked header; NewCurlRT installs one that returns "<redacted>" by default.
maskFn func(string, string) string
// streamEnabled turns on chunk logging for responses that streamCTFilter
// classifies as streaming.
streamEnabled bool
// maxStreamLogBytes caps the stream summary kept for a plain Logger;
// non-positive values fall back to 8192 in newLoggingReadCloser.
maxStreamLogBytes int
// streamCTFilter reports whether a response Content-Type denotes a stream;
// the NewCurlRT default matches text/event-stream and application/x-ndjson.
streamCTFilter func(string) bool
}
// CurlOption is a functional option: a function that mutates the CurlRT being
// built. NewCurlRT applies the options in the order given and then fills in
// defaults for anything still unset.
type CurlOption func(*CurlRT)
// WithLogger installs l as the plain printf-style logger. It is only consulted
// when no context-aware logger has been configured.
func WithLogger(l Logger) CurlOption {
	return func(rt *CurlRT) {
		rt.logger = l
	}
}
// WithCtxLogger installs l as the context-aware logger. When set, it is used
// for request, response and stream-chunk logs in place of the plain Logger.
func WithCtxLogger(l CtxLogger) CurlOption {
	return func(rt *CurlRT) {
		rt.ctxLogger = l
	}
}
// WithPrintAuth decides whether the Authorization header value appears in
// logs. Passing false makes it show up as "<redacted>".
func WithPrintAuth(b bool) CurlOption {
	return func(rt *CurlRT) {
		rt.printAuth = b
	}
}
// WithMaskHeaders registers header names whose values must be masked in logs.
// Names are stored lower-cased, so matching is case-insensitive, and repeated
// use of the option accumulates names rather than replacing them.
func WithMaskHeaders(names []string) CurlOption {
	return func(rt *CurlRT) {
		set := rt.maskHeaders
		if set == nil {
			set = make(map[string]struct{}, len(names))
			rt.maskHeaders = set
		}
		for i := range names {
			set[strings.ToLower(names[i])] = struct{}{}
		}
	}
}
// WithMaskFunc installs f as the function that produces the logged text for a
// masked header. f is called with the header name and its real value.
func WithMaskFunc(f func(name, value string) string) CurlOption {
	return func(rt *CurlRT) {
		rt.maskFn = f
	}
}
// WithStreamLogging switches chunk logging of streaming responses (SSE/NDJSON)
// on or off.
func WithStreamLogging(enabled bool) CurlOption {
	return func(rt *CurlRT) {
		rt.streamEnabled = enabled
	}
}
// WithMaxStreamLogBytes sets the size cap, in bytes, of the stream summary
// that is collected when only a plain Logger is configured.
func WithMaxStreamLogBytes(n int) CurlOption {
	return func(rt *CurlRT) {
		rt.maxStreamLogBytes = n
	}
}
// WithStreamContentTypeFilter installs f as the predicate that decides, from
// the response Content-Type value, whether a response is treated as a stream.
func WithStreamContentTypeFilter(f func(ct string) bool) CurlOption {
	return func(rt *CurlRT) {
		rt.streamCTFilter = f
	}
}
// NewCurlRT wraps base in a logging RoundTripper configured by opts.
//
// Options are applied in order; nil options are skipped. Afterwards every
// setting that is still unset receives a default:
//   - base: http.DefaultTransport, so a nil base does not panic on first use
//   - logger: log.Default()
//   - mask function: always returns "<redacted>"
//   - stream filter: matches Content-Types containing "text/event-stream" or
//     "application/x-ndjson", case-insensitively
func NewCurlRT(base http.RoundTripper, opts ...CurlOption) *CurlRT {
	if base == nil {
		base = http.DefaultTransport
	}
	rt := &CurlRT{base: base}
	for _, opt := range opts {
		if opt != nil {
			opt(rt)
		}
	}
	if rt.logger == nil {
		rt.logger = log.Default()
	}
	if rt.maskFn == nil {
		rt.maskFn = func(_ string, _ string) string { return "<redacted>" }
	}
	if rt.streamCTFilter == nil {
		rt.streamCTFilter = func(ct string) bool {
			ct = strings.ToLower(ct)
			return strings.Contains(ct, "text/event-stream") || strings.Contains(ct, "application/x-ndjson")
		}
	}
	return rt
}
// RoundTrip implements http.RoundTripper. It logs the outgoing request as a
// cURL command, delegates to the wrapped transport, and logs the response.
//
// The request body is buffered so it can be logged and then replayed. The
// caller's request is not modified: the body is consumed and closed, and a
// clone carrying the buffered bytes is sent instead. A failure to read the
// request body is returned as the error.
//
// For streaming responses (stream logging enabled and the Content-Type matches
// the filter) only status and headers are logged here and the body is wrapped
// so chunks are logged as the caller reads them. All other responses are read
// fully, logged, and replaced by an in-memory body. If that read fails midway,
// the replacement body yields the bytes received so far followed by the same
// error, so the caller observes what it would have seen without logging.
func (c *CurlRT) RoundTrip(req *http.Request) (*http.Response, error) {
	ctx := req.Context()

	outbound := req
	var reqBody []byte
	if req.Body != nil {
		data, readErr := io.ReadAll(req.Body)
		// RoundTrip must always close the body; nothing more will be read
		// from it, so the close error carries no useful information.
		_ = req.Body.Close()
		if readErr != nil {
			return nil, readErr
		}
		reqBody = data
		// Send a clone so the caller's request keeps its original Body field.
		outbound = req.Clone(ctx)
		outbound.Body = io.NopCloser(bytes.NewReader(reqBody))
	}
	c.logf(ctx, "[curl request] %s", c.buildCurl(outbound, reqBody))

	resp, err := c.base.RoundTrip(outbound)
	if err != nil {
		return nil, err
	}

	if c.streamEnabled && resp.Body != nil && c.streamCTFilter(resp.Header.Get("Content-Type")) {
		c.logf(ctx, "[curl response] HTTP/%d.%d %d\n%s\n\n(streaming...)", resp.ProtoMajor, resp.ProtoMinor, resp.StatusCode, c.formatHeaders(resp.Header))
		resp.Body = newLoggingReadCloser(resp.Body, ctx, c)
		return resp, nil
	}

	var (
		respBody []byte
		readErr  error
	)
	if resp.Body != nil {
		respBody, readErr = io.ReadAll(resp.Body)
		// The body has been drained (or has failed); release it now because
		// the caller will only ever see the in-memory replacement.
		_ = resp.Body.Close()
		var replay io.Reader = bytes.NewReader(respBody)
		if readErr != nil {
			replay = io.MultiReader(replay, bodyReadFailure{err: readErr})
		}
		resp.Body = io.NopCloser(replay)
	}
	c.logf(ctx, "[curl response] HTTP/%d.%d %d\n%s\n\n%s", resp.ProtoMajor, resp.ProtoMinor, resp.StatusCode, c.formatHeaders(resp.Header), string(respBody))
	if readErr != nil {
		c.logf(ctx, "[curl response] body read error: %v", readErr)
	}
	return resp, nil
}

// logf sends one log line to the context-aware logger when one is configured
// and to the plain logger otherwise.
func (c *CurlRT) logf(ctx context.Context, format string, args ...any) {
	if c.ctxLogger != nil {
		c.ctxLogger.Printf(ctx, format, args...)
		return
	}
	c.logger.Printf(format, args...)
}

// bodyReadFailure is a reader that fails every Read with err. It is chained
// after the bytes already received to replay a body whose read was interrupted.
type bodyReadFailure struct{ err error }

// Read always reports the stored error.
func (f bodyReadFailure) Read([]byte) (int, error) { return 0, f.err }
// mask returns the text to log for the header name/value pair.
//
// Authorization (matched case-insensitively) becomes "<redacted>" unless
// printAuth is set. Any header listed in maskHeaders is rewritten by maskFn.
// Everything else is returned unchanged.
func (c *CurlRT) mask(name, value string) string {
	_, listed := c.maskHeaders[strings.ToLower(name)]
	switch {
	case !c.printAuth && strings.EqualFold(name, "Authorization"):
		return "<redacted>"
	case listed:
		return c.maskFn(name, value)
	default:
		return value
	}
}
// buildCurl renders req and its already-buffered body as a single-line cURL
// command suitable for logging.
//
// Header names are emitted in sorted order so the same request always yields
// the same line. Header values go through mask, and every piece has CR/LF
// removed. The URL, each header and the body are single-quoted for a POSIX
// shell with embedded single quotes escaped, so the logged command can be
// pasted into a shell without the quoting breaking. The --data argument is
// only added when body is non-empty.
func (c *CurlRT) buildCurl(req *http.Request, body []byte) string {
	var b strings.Builder
	b.WriteString("curl -X ")
	b.WriteString(sanitizeLogValue(req.Method))
	b.WriteByte(' ')
	b.WriteString(curlQuote(sanitizeLogValue(req.URL.String())))

	// Map iteration order is random; sort for reproducible output.
	names := make([]string, 0, len(req.Header))
	for name := range req.Header {
		names = append(names, name)
	}
	sort.Strings(names)
	for _, name := range names {
		for _, value := range req.Header[name] {
			b.WriteString(" -H ")
			b.WriteString(curlQuote(sanitizeLogValue(name) + ": " + sanitizeLogValue(c.mask(name, value))))
		}
	}

	if len(body) > 0 {
		b.WriteString(" --data ")
		b.WriteString(curlQuote(sanitizeLogValue(string(body))))
	}
	return b.String()
}

// curlQuote wraps s in single quotes for a POSIX shell, rewriting each embedded
// single quote as '\'' so the argument cannot end its own quoting.
func curlQuote(s string) string {
	return "'" + strings.ReplaceAll(s, "'", `'\''`) + "'"
}
// formatHeaders renders h as "Name: value" lines, one per header value, each
// terminated by a newline.
//
// Names are emitted in sorted order so repeated logs of the same headers are
// identical; values of one name keep their original order. Values go through
// mask, and CR/LF is stripped from both names and values.
func (c *CurlRT) formatHeaders(h http.Header) string {
	// Map iteration order is random; sort for reproducible output.
	names := make([]string, 0, len(h))
	for name := range h {
		names = append(names, name)
	}
	sort.Strings(names)

	var b strings.Builder
	for _, name := range names {
		cleanName := sanitizeLogValue(name)
		for _, value := range h[name] {
			b.WriteString(cleanName)
			b.WriteString(": ")
			b.WriteString(sanitizeLogValue(c.mask(name, value)))
			b.WriteByte('\n')
		}
	}
	return b.String()
}
// loggingReadCloser wraps a streaming response body and logs what passes
// through it while the caller reads. With a CtxLogger each line of a chunk is
// logged as it is read; with only a plain Logger the data is collected into a
// capped summary that is printed on Close.
type loggingReadCloser struct {
// rc is the wrapped response body; Read and Close are forwarded to it.
rc io.ReadCloser
// ctx is the request context handed to cl with every chunk log.
ctx context.Context
// l is the plain logger used to print the summary on Close.
l Logger
// cl, when non-nil, receives each chunk line directly from Read.
cl CtxLogger
// cap is the maximum number of bytes to record in summary.
cap int
// total counts the bytes recorded in summary so far.
total int
// summary accumulates stream data for the plain logger; it is nil when a
// CtxLogger is in use.
summary *bytes.Buffer
}
// newLoggingReadCloser wraps rc so that data read from it is logged according
// to c's configuration. ctx is passed along to the context-aware logger.
//
// The summary cap is c.maxStreamLogBytes, or 8192 when that value is not
// positive. A summary buffer is allocated only when no context-aware logger is
// configured, since that logger receives chunks directly.
func newLoggingReadCloser(rc io.ReadCloser, ctx context.Context, c *CurlRT) io.ReadCloser {
	limit := 8192
	if c.maxStreamLogBytes > 0 {
		limit = c.maxStreamLogBytes
	}
	wrapped := &loggingReadCloser{
		rc:  rc,
		ctx: ctx,
		l:   c.logger,
		cl:  c.ctxLogger,
		cap: limit,
	}
	if c.ctxLogger == nil {
		wrapped.summary = new(bytes.Buffer)
	}
	return wrapped
}
// Read forwards to the wrapped body and records the bytes that were read. The
// data and error returned to the caller are exactly those of the wrapped body.
//
// With a context-aware logger, the chunk is split on '\n' and each line is
// logged immediately; the empty piece after a trailing newline is skipped. A
// line that spans two Read calls is therefore logged as two entries.
//
// With only a plain logger, the bytes are appended verbatim to the summary
// until cap bytes have been recorded. Every byte counts toward the cap,
// newlines included, and no separators are inserted between reads, so the
// summary is a faithful prefix of the stream that cannot grow past cap.
func (lrc *loggingReadCloser) Read(p []byte) (int, error) {
	n, err := lrc.rc.Read(p)
	if n <= 0 {
		return n, err
	}
	chunk := p[:n]

	if lrc.cl != nil {
		lines := bytes.Split(chunk, []byte("\n"))
		last := len(lines) - 1
		for i, line := range lines {
			if i == last && len(line) == 0 {
				// Artifact of a chunk that ends in '\n', not a real line.
				continue
			}
			lrc.cl.Printf(lrc.ctx, "[curl stream chunk] %s", string(line))
		}
		return n, err
	}

	if lrc.summary != nil {
		if room := lrc.cap - lrc.total; room > 0 {
			if len(chunk) > room {
				chunk = chunk[:room]
			}
			lrc.summary.Write(chunk)
			lrc.total += len(chunk)
		}
	}
	return n, err
}
// Close prints the collected stream summary, if there is one, through the
// plain logger and then closes the wrapped body, returning its error.
//
// The summary buffer is emptied once it has been printed, so calling Close
// again (a common pattern with deferred closes) does not print it twice.
func (lrc *loggingReadCloser) Close() error {
	if lrc.summary != nil && lrc.summary.Len() > 0 && lrc.l != nil {
		lrc.l.Printf("[curl stream summary]\n%s", lrc.summary.String())
		lrc.summary.Reset()
	}
	return lrc.rc.Close()
}
// IDCtxLogger adapts a plain Logger into a CtxLogger that tags each line with
// a request ID taken from the context.
type IDCtxLogger struct{ L Logger }

// Printf looks up the context value stored under the string key "log_id".
// When it is a non-empty string, the line is prefixed with "[req_id=<id>] ";
// otherwise the line is forwarded to L unchanged.
func (i IDCtxLogger) Printf(ctx context.Context, format string, args ...any) {
	id, _ := ctx.Value("log_id").(string)
	if id == "" {
		i.L.Printf(format, args...)
		return
	}
	tagged := make([]any, 0, len(args)+1)
	tagged = append(tagged, id)
	tagged = append(tagged, args...)
	i.L.Printf("[req_id=%s] "+format, tagged...)
}