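// Package injection implements a Caddy HTTP handler ("http.handlers.injection")
// that inserts the contents of a configured file in front of a marker string in
// responses whose Content-Type matches a regex, and rewrites the response's
// Content-Security-Policy (header and <meta> tag) with a nonce so the injected
// markup is allowed to run.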
package injection

import (
	"bytes"
	"fmt"
	"io/ioutil"
	"math"
	"net/http"
	"regexp"
	"strings"

	"github.com/pkg/errors"
	"go.uber.org/zap"

	"github.com/caddyserver/caddy/v2"
	"github.com/caddyserver/caddy/v2/caddyconfig/caddyfile"
	"github.com/caddyserver/caddy/v2/caddyconfig/httpcaddyfile"
	"github.com/caddyserver/caddy/v2/modules/caddyhttp"
)

func init() {
	caddy.RegisterModule(Middleware{})
	httpcaddyfile.RegisterHandlerDirective("injection", parseCaddyfile)
}

type Middleware struct {
	// Regex to specify which kind of response should we filter
	ContentType string `json:"content_type"`
29 30
	Inject      string `json:"inject"`
	Before      string `json:"before"`
31 32 33

	compiledContentTypeRegex *regexp.Regexp

34
	Logger *zap.Logger
35 36 37 38 39 40 41 42 43 44 45 46
}

// CaddyModule returns the Caddy module information: the module ID under
// http.handlers and a constructor yielding a fresh *Middleware.
func (Middleware) CaddyModule() caddy.ModuleInfo {
	newModule := func() caddy.Module {
		return new(Middleware)
	}
	return caddy.ModuleInfo{ID: "http.handlers.injection", New: newModule}
}

// Provision implements caddy.Provisioner.
func (m *Middleware) Provision(ctx caddy.Context) error {
47 48
	m.Logger = ctx.Logger(m)
	m.Logger.Info("Provisioning injection plugin",
49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70
		zap.String("ContentType", m.ContentType),
		zap.String("Inject", m.Inject))
	return nil
}

// Validate implements caddy.Validator. It compiles ContentType and keeps the
// result for matching response media types; an invalid pattern is reported
// as a configuration error.
func (m *Middleware) Validate() error {
	re, err := regexp.Compile(m.ContentType)
	m.compiledContentTypeRegex = re
	if err != nil {
		return fmt.Errorf("invalid regex content_type: %w", err)
	}
	return nil
}

// ContentTypeStatus records whether a response's Content-Type has been
// checked against the configured regex and, if so, whether it matched.
type ContentTypeStatus int

// ContentTypeStatus values. toBeChecked is the zero value.
const (
	toBeChecked ContentTypeStatus = 0
	noMatch     ContentTypeStatus = 1
	matches     ContentTypeStatus = 2
)

// LineHandler transforms a single line of a buffered response body (without
// its trailing newline) and returns the replacement text. A non-nil error
// aborts the write in progress.
type LineHandler interface {
	HandleLine(line string) (string, error)
}

75
type InjectedWriter struct {
76 77 78
	OriginalWriter    http.ResponseWriter
	Request           *http.Request
	RecordedHTML      bytes.Buffer
79
	totalBytesWritten int
80
	Logger            *zap.Logger
81
	contentTypeStatus ContentTypeStatus
82
	LineHandler       LineHandler
Simao Gomes Viana's avatar
Simao Gomes Viana committed
83
	cspNonce		  string
84
	hasSeenClosingHead bool
85
	M                 *Middleware
86 87 88
}

func (i InjectedWriter) Header() http.Header {
89
	return i.OriginalWriter.Header()
90 91 92
}

func (i *InjectedWriter) Write(bytes []byte) (int, error) {
93 94
	if i.LineHandler == nil {
		i.LineHandler = i
95
	}
96
	if i.contentTypeStatus == noMatch {
97 98 99
		return i.OriginalWriter.Write(bytes)
	} else if i.contentTypeStatus == toBeChecked && !i.M.compiledContentTypeRegex.MatchString(
		strings.Split(i.OriginalWriter.Header().Get("Content-Type"), ";")[0]) {
100
		i.contentTypeStatus = noMatch
101
		return i.OriginalWriter.Write(bytes)
102
	}
Simao Gomes Viana's avatar
Fixes  
Simao Gomes Viana committed
103
	i.contentTypeStatus = matches
104 105
	i.RecordedHTML.Write(bytes)
	recordedString := i.RecordedHTML.String()
106
	if strings.ContainsRune(recordedString, '\n') {
107
		i.RecordedHTML.Truncate(0)
108 109 110 111 112 113
		isLastLineComplete := false
		if strings.HasSuffix(recordedString, "\n") {
			isLastLineComplete = true
		}
		lines := strings.Split(recordedString, "\n")
		for index, line := range lines {
114
			if !isLastLineComplete && index == len(lines)-1 {
115
				// Write the incomplete line back into the buffer
116
				i.RecordedHTML.WriteString(line)
117 118
				break
			}
119
			newString, err := i.LineHandler.HandleLine(i.handleCSPForLine(line))
120 121 122 123 124
			if err != nil {
				return 0, err
			}
			newString += "\n"
			newBytes := []byte(newString)
125
			bytesWritten, err := i.OriginalWriter.Write(newBytes)
126 127 128 129 130 131 132 133 134 135 136 137 138
			if err != nil {
				return 0, errors.Wrap(err, "error occurred while writing out response bytes")
			}
			if bytesWritten != len(newBytes) {
				return bytesWritten, errors.Wrap(err, "couldn't write out complete response bytes")
			}
			i.totalBytesWritten += bytesWritten
		}
	}
	return len(bytes), nil
}

func (i *InjectedWriter) textToInject() (string, error) {
Simao Gomes Viana's avatar
Simao Gomes Viana committed
139 140 141
	if len(i.M.Inject) == 0 {
		return "", nil
	}
142
	content, err := ioutil.ReadFile(i.M.Inject)
143
	if err != nil {
144
		i.Logger.Warn("Could not read file to inject!", zap.Error(err))
145 146 147 148 149 150
		return "", err
	}
	contentString := string(content)
	return contentString, nil
}

151
func (i *InjectedWriter) HandleLine(line string) (string, error) {
152
	if strings.Contains(line, i.M.Before) {
153
		textToInject, err := i.textToInject()
Simao Gomes Viana's avatar
Simao Gomes Viana committed
154
		textToInject = i.HandleCSPForText(textToInject)
155 156 157
		if err != nil {
			return line, nil
		}
158
		return strings.Replace(line, i.M.Before, textToInject+i.M.Before, 1), nil
159 160 161 162
	}
	return line, nil
}

163 164 165 166 167 168 169 170
// extractValueForDirective returns the trimmed source list of the first
// `name ` directive in csp. The value ends at the next ';' or at the next
// occurrence of `name `, whichever comes first. It returns "" when the
// directive is absent.
func extractValueForDirective(csp string, name string) string {
	marker := name + " "
	start := strings.Index(csp, marker)
	if start < 0 {
		return ""
	}
	rest := csp[start+len(marker):]
	if next := strings.Index(rest, marker); next >= 0 {
		rest = rest[:next]
	}
	if end := strings.IndexByte(rest, ';'); end >= 0 {
		rest = rest[:end]
	}
	return strings.TrimSpace(rest)
}

Simao Gomes Viana's avatar
Simao Gomes Viana committed
171 172
func (i *InjectedWriter) HandleCSP() error {
	csp := i.OriginalWriter.Header().Get("Content-Security-Policy")
Simao Gomes Viana's avatar
Simao Gomes Viana committed
173
	if len(strings.TrimSpace(csp)) != 0 {
Simao Gomes Viana's avatar
Simao Gomes Viana committed
174 175 176 177 178
		var err error
		i.cspNonce, err = GenerateRandomStringURLSafe(6)
		if err != nil {
			return err
		}
179 180 181 182 183 184 185 186 187
		i.OriginalWriter.Header().Set("Content-Security-Policy", i.transformCSP(csp))
	}

	return nil
}

func (i *InjectedWriter) transformCSP(csp string) string {
	defaultSrc := extractValueForDirective(csp, "default-src")
	if len(defaultSrc) == 0 {
Simao Gomes Viana's avatar
Simao Gomes Viana committed
188 189
		// add back 'unsafe-hashes' when appropriate
		defaultSrc = "'self' https: data: blob: 'unsafe-eval' 'unsafe-inline'"
190 191 192 193 194 195 196 197 198 199
		csp = fmt.Sprintf("default-src %s; %s", defaultSrc, csp)
	}
	cspSrcArg := fmt.Sprintf("'nonce-%s' 'unsafe-inline'", i.cspNonce)
	if strings.Contains(csp, "script-src ") {
		if !strings.Contains(
			extractValueForDirective(csp, "script-src"),
			"'unsafe-inline'",
		) {
			// we need to add a source instead of adding the entire directive
			csp = strings.Replace(csp, "script-src ", fmt.Sprintf("script-src %s ", cspSrcArg), 1)
200
		}
201 202 203 204 205
	} else {
		cspSrcArgFinal := cspSrcArg
		if strings.Contains(defaultSrc, "'unsafe-inline'") {
			// Skip nonce if unsafe-inline otherwise it will be disabled
			cspSrcArgFinal = ""
Simao Gomes Viana's avatar
Simao Gomes Viana committed
206
		}
207 208 209 210 211 212 213 214 215
		csp += fmt.Sprintf("; script-src %s %s", defaultSrc, cspSrcArgFinal)
	}
	if strings.Contains(csp, "style-src ") {
		if !strings.Contains(
			extractValueForDirective(csp, "style-src"),
			"'unsafe-inline'",
		) {
			// we need to add a source instead of adding the entire directive
			csp = strings.Replace(csp, "style-src ", fmt.Sprintf("style-src %s ", cspSrcArg), 1)
Simao Gomes Viana's avatar
Simao Gomes Viana committed
216
		}
217 218 219 220 221 222 223
	} else {
		cspSrcArgFinal := cspSrcArg
		if strings.Contains(defaultSrc, "'unsafe-inline'") {
			// Skip nonce if unsafe-inline otherwise it will be disabled
			cspSrcArgFinal = ""
		}
		csp += fmt.Sprintf("; style-src %s %s", defaultSrc, cspSrcArgFinal)
Simao Gomes Viana's avatar
Simao Gomes Viana committed
224
	}
225
	return csp
Simao Gomes Viana's avatar
Simao Gomes Viana committed
226 227
}

228
func (i *InjectedWriter) HandleCSPForText(text string) string {
Simao Gomes Viana's avatar
Simao Gomes Viana committed
229 230
	if len(i.cspNonce) == 0 {
		// Remove nonce attributes since CSP is not active
231 232 233 234 235
		return strings.ReplaceAll(text, " nonce=\"{{csp-nonce}}\"", "")
	}
	return strings.ReplaceAll(text, "{{csp-nonce}}", i.cspNonce)
}

// nonNegativeMin returns the smallest non-negative value in is, or -1 when
// there is none. It is meant for combining strings.Index results, where -1
// means "not found".
func nonNegativeMin(is ...int) int {
	// smallest is only a placeholder until the first candidate is found.
	smallest := math.MaxInt32
	found := false
	for _, v := range is {
		if v < 0 {
			continue
		}
		if !found || v < smallest {
			smallest = v
			found = true
		}
	}
	if !found {
		return -1
	}
	return smallest
}

// Markers used when locating and rewriting CSP <meta> tags in HTML lines.
const (
	metaTag         = "<meta "
	httpEquivPrefix = "http-equiv=\""
	metaCSPPrefix   = httpEquivPrefix + "content-security-policy\""
	contentPrefix   = "content=\""
	metaEnd         = "</meta>"
)
256 257 258 259 260 261 262 263 264 265 266 267 268 269
func (i *InjectedWriter) handleCSPForLine(line string) string {
	if len(i.cspNonce) == 0 {
		return line
	}
	if i.hasSeenClosingHead {
		return line
	}
	lowerLine := strings.ToLower(line)
	httpEquivIndex := strings.Index(lowerLine, metaCSPPrefix)
	closingHeadIndex := strings.Index(lowerLine, "</head>")
	if httpEquivIndex > closingHeadIndex && closingHeadIndex != -1 {
		i.hasSeenClosingHead = true
		return line
	}
270
	lineToReturn := line
271 272 273
	if httpEquivIndex >= len(metaTag) {
		i.Logger.Debug("Found CSP in HTML, replacing it")
		fullTagToEnd := line[httpEquivIndex:]
274 275 276 277 278 279 280 281 282 283 284 285 286
		endIndex := -1
		endSuffixLen := 2 // />

		endIndexSlashClose := strings.Index(fullTagToEnd, "/>")
		endIndexMetaEnd := strings.Index(fullTagToEnd, metaEnd)
		endIndexGtImplClose := strings.Index(fullTagToEnd, ">")

		endIndex = nonNegativeMin(endIndexSlashClose, endIndexMetaEnd, endIndexGtImplClose)
		switch endIndex {
		default: fallthrough
		case -1:
			return line
		case endIndexMetaEnd:
Simao Gomes Viana's avatar
Simao Gomes Viana committed
287
			endSuffixLen = len(metaEnd)
288 289
		case endIndexGtImplClose:
			endSuffixLen = 1 // >
290
		}
291

292 293 294 295
		fullTag := fullTagToEnd[:endIndex]

		fullContentAttrStartIndex := strings.Index(fullTag, contentPrefix)
		if fullContentAttrStartIndex == -1 {
296
			goto end
297 298 299 300
		}
		contentAttrToEnd := fullTag[fullContentAttrStartIndex+len(contentPrefix):]
		contentAttrEndIndex := strings.Index(contentAttrToEnd, "\"")
		if contentAttrEndIndex == -1 {
301
			goto end
302 303 304
		}
		contentAttrValue := contentAttrToEnd[:contentAttrEndIndex]
		if len(contentAttrValue) == 0 {
305
			goto end
306
		}
307 308 309 310
		if strings.Contains(contentAttrValue, "default-src") {
			// Otherwise we could run into issues.
			// We'll remove the tag entirely if it doesn't have default-src.
			// The one in the header will still be transformed.
Simao Gomes Viana's avatar
Simao Gomes Viana committed
311 312 313
			goodCsp := i.transformCSP(contentAttrValue)
			newTag := fmt.Sprintf("http-equiv=\"content-security-policy\" content=\"%s\" ", goodCsp)
			i.Logger.Debug("Replaced CSP in HTML")
314 315
			lineToReturn = strings.Replace(line, fullTag, newTag, 1)
			goto end
Simao Gomes Viana's avatar
Simao Gomes Viana committed
316 317
		} else {
			fullTagEndToEnd :=
318
			 	line[strings.LastIndex(line[:httpEquivIndex+1], metaTag):endIndex+httpEquivIndex+endSuffixLen]
319 320 321
			lineToReturn = strings.Replace(line, fullTagEndToEnd, "", 1)
			i.Logger.Debug("Removing CSP entirely")
			goto end
322
		}
Simao Gomes Viana's avatar
Simao Gomes Viana committed
323

324
	}
325
end:
326 327
	if closingHeadIndex != -1 {
		i.hasSeenClosingHead = true
Simao Gomes Viana's avatar
Simao Gomes Viana committed
328
	}
329
	return lineToReturn
Simao Gomes Viana's avatar
Simao Gomes Viana committed
330 331
}

Simao Gomes Viana's avatar
Simao Gomes Viana committed
332
func (i *InjectedWriter) Flush() error {
333
	var err error
334
	finalString := i.RecordedHTML.String()
335
	if len(finalString) > 0 {
336
		finalString, err = i.LineHandler.HandleLine(i.handleCSPForLine(finalString))
337 338 339
		if err != nil {
			return err
		}
340
		n, err := i.OriginalWriter.Write([]byte(finalString))
341 342 343 344 345 346 347 348
		if err != nil {
			return err
		}
		i.totalBytesWritten += n
	}
	return nil
}

Simao Gomes Viana's avatar
Simao Gomes Viana committed
349
func (i *InjectedWriter) WriteHeader(statusCode int) {
Simao Gomes Viana's avatar
Simao Gomes Viana committed
350
	if statusCode < http.StatusOK || statusCode >= 600 || i.M.ShouldBypassForResponse(i.OriginalWriter) {
Simao Gomes Viana's avatar
Simao Gomes Viana committed
351
		i.Logger.Debug("This request is not eligible to be modified, passing thru.")
Simao Gomes Viana's avatar
Simao Gomes Viana committed
352 353 354
		i.OriginalWriter.WriteHeader(statusCode)
		return
	}
Simao Gomes Viana's avatar
Simao Gomes Viana committed
355 356
	// Ignore error because it's not critical
	_ = i.HandleCSP()
357 358
	i.OriginalWriter.Header().Del("Content-Length")
	i.OriginalWriter.WriteHeader(statusCode)
359 360
}

361 362 363 364
func CreateInjectedWriter(
	w http.ResponseWriter, r *http.Request, m *Middleware,
) *InjectedWriter {
	iw := &InjectedWriter{
365 366 367
		OriginalWriter: w,
		Request:        r,
		RecordedHTML:   bytes.Buffer{},
368 369
		Logger:         m.Logger,
		M:              m,
370
	}
Simao Gomes Viana's avatar
Fixes  
Simao Gomes Viana committed
371
	if len(m.ContentType) == 0 {
372
		iw.contentTypeStatus = matches
Simao Gomes Viana's avatar
Fixes  
Simao Gomes Viana committed
373
	}
374 375 376
	return iw
}

Simao Gomes Viana's avatar
Simao Gomes Viana committed
377
func (m Middleware) IsWebSocket(r *http.Request) bool {
378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403
	connectionValue := r.Header.Get("connection")
	connectionValues := strings.Split(connectionValue, ",")
	connectionValueMatches := false
	for _, connectionElem := range connectionValues {
		if strings.EqualFold(strings.TrimSpace(connectionElem), "upgrade") {
			connectionValueMatches = true
			break
		}
	}
	if !connectionValueMatches {
		return false
	}

	upgradeValue := r.Header.Get("upgrade")
	upgradeValues := strings.Split(upgradeValue, ",")
	upgradeValueMatches := false
	for _, upgradeElem := range upgradeValues {
		if strings.EqualFold(strings.TrimSpace(upgradeElem), "websocket") {
			upgradeValueMatches = true
		}
	}
	if !upgradeValueMatches {
		return false
	}

	return true
Simao Gomes Viana's avatar
Simao Gomes Viana committed
404 405 406
}

func (m Middleware) ShouldBypassForRequest(w http.ResponseWriter, r *http.Request) bool {
Simao Gomes Viana's avatar
Simao Gomes Viana committed
407 408 409
	isWebsocket := m.IsWebSocket(r)
	m.Logger.Debug("This is a websocket, passing thru.")
	return isWebsocket
Simao Gomes Viana's avatar
Simao Gomes Viana committed
410 411 412 413 414 415
}

// ShouldBypassForResponse reports whether the response being written through
// w must be left unmodified, which is the case whenever it carries a
// non-empty Upgrade header.
func (m Middleware) ShouldBypassForResponse(w http.ResponseWriter) bool {
	upgrade := w.Header().Get("upgrade")
	return upgrade != ""
}

416 417
// ServeHTTP implements caddyhttp.MiddlewareHandler.
func (m Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error {
Simao Gomes Viana's avatar
Simao Gomes Viana committed
418
	if m.ShouldBypassForRequest(w, r) {
Simao Gomes Viana's avatar
Simao Gomes Viana committed
419
		m.Logger.Debug("This request is going to pass thru.")
Simao Gomes Viana's avatar
Simao Gomes Viana committed
420 421 422
		return next.ServeHTTP(w ,r)
	}

Simao Gomes Viana's avatar
Simao Gomes Viana committed
423 424
	var err error

425 426
	r.Header.Set("Accept-Encoding", "identity")
	injectedWriter := CreateInjectedWriter(w, r, &m)
Simao Gomes Viana's avatar
Simao Gomes Viana committed
427 428

	err = next.ServeHTTP(injectedWriter, r)
429 430 431
	if err != nil {
		return err
	}
Simao Gomes Viana's avatar
Simao Gomes Viana committed
432
	if err := injectedWriter.Flush(); err != nil {
433 434
		return err
	}
435
	m.Logger.Debug("", zap.Int("total bytes written", injectedWriter.totalBytesWritten))
436 437 438 439 440 441
	return nil
}

// UnmarshalCaddyfile implements caddyfile.Unmarshaler.
func (m *Middleware) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
	if !d.Next() {
Simao Gomes Viana's avatar
Simao Gomes Viana committed
442
		return d.Err("expected token following injection")
443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458
	}
	for d.NextBlock(0) {
		key := d.Val()
		var value string
		d.Args(&value)
		if d.NextArg() {
			return d.ArgErr()
		}
		switch key {
		case "content_type":
			m.ContentType = value
		case "inject":
			m.Inject = value
		case "before":
			m.Before = value
		default:
Simao Gomes Viana's avatar
Simao Gomes Viana committed
459
			return d.Err(fmt.Sprintf("invalid key for injection directive: %s", key))
460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478
		}
	}
	return nil
}

// parseCaddyfile builds a Middleware from the directive tokens in h. The
// (possibly partially filled) value is returned together with any parse
// error.
func parseCaddyfile(h httpcaddyfile.Helper) (caddyhttp.MiddlewareHandler, error) {
	m := Middleware{}
	if err := m.UnmarshalCaddyfile(h.Dispenser); err != nil {
		return m, err
	}
	return m, nil
}

// Interface guards: compile-time checks that *Middleware implements the
// interfaces Caddy relies on.
var _ caddy.Provisioner = (*Middleware)(nil)
var _ caddy.Validator = (*Middleware)(nil)
var _ caddyhttp.MiddlewareHandler = (*Middleware)(nil)
var _ caddyfile.Unmarshaler = (*Middleware)(nil)