package templ import ( "bytes" "context" "crypto/sha256" "encoding/hex" "encoding/json" "errors" "fmt" "html" "html/template" "io" "net/http" "os" "reflect" "runtime" "sort" "strconv" "strings" "sync" "time" "github.com/a-h/templ/safehtml" ) // Types exposed by all components. // Component is the interface that all templates implement. type Component interface { // Render the template. Render(ctx context.Context, w io.Writer) error } // ComponentFunc converts a function that matches the Component interface's // Render method into a Component. type ComponentFunc func(ctx context.Context, w io.Writer) error // Render the template. func (cf ComponentFunc) Render(ctx context.Context, w io.Writer) error { return cf(ctx, w) } // WithNonce sets a CSP nonce on the context and returns it. func WithNonce(ctx context.Context, nonce string) context.Context { ctx, v := getContext(ctx) v.nonce = nonce return ctx } // GetNonce returns the CSP nonce value set with WithNonce, or an // empty string if none has been set. func GetNonce(ctx context.Context) (nonce string) { if ctx == nil { return "" } _, v := getContext(ctx) return v.nonce } func WithChildren(ctx context.Context, children Component) context.Context { ctx, v := getContext(ctx) v.children = &children return ctx } func ClearChildren(ctx context.Context) context.Context { _, v := getContext(ctx) v.children = nil return ctx } // NopComponent is a component that doesn't render anything. var NopComponent = ComponentFunc(func(ctx context.Context, w io.Writer) error { return nil }) // GetChildren from the context. func GetChildren(ctx context.Context) Component { _, v := getContext(ctx) if v.children == nil { return NopComponent } return *v.children } // EscapeString escapes HTML text within templates. func EscapeString(s string) string { return html.EscapeString(s) } // Bool attribute value. func Bool(value bool) bool { return value } // Classes for CSS. 
// Supported types are string, ConstantCSSClass, ComponentCSSClass, map[string]bool. func Classes(classes ...any) CSSClasses { return CSSClasses(classes) } // CSSClasses is a slice of CSS classes. type CSSClasses []any // String returns the names of all CSS classes. func (classes CSSClasses) String() string { if len(classes) == 0 { return "" } cp := newCSSProcessor() for _, v := range classes { cp.Add(v) } return cp.String() } func newCSSProcessor() *cssProcessor { return &cssProcessor{ classNameToEnabled: make(map[string]bool), } } type cssProcessor struct { classNameToEnabled map[string]bool orderedNames []string } func (cp *cssProcessor) Add(item any) { switch c := item.(type) { case []string: for _, className := range c { cp.AddClassName(className, true) } case string: cp.AddClassName(c, true) case ConstantCSSClass: cp.AddClassName(c.ClassName(), true) case ComponentCSSClass: cp.AddClassName(c.ClassName(), true) case map[string]bool: // In Go, map keys are iterated in a randomized order. // So the keys in the map must be sorted to produce consistent output. 
keys := make([]string, len(c)) var i int for key := range c { keys[i] = key i++ } sort.Strings(keys) for _, className := range keys { cp.AddClassName(className, c[className]) } case []KeyValue[string, bool]: for _, kv := range c { cp.AddClassName(kv.Key, kv.Value) } case KeyValue[string, bool]: cp.AddClassName(c.Key, c.Value) case []KeyValue[CSSClass, bool]: for _, kv := range c { cp.AddClassName(kv.Key.ClassName(), kv.Value) } case KeyValue[CSSClass, bool]: cp.AddClassName(c.Key.ClassName(), c.Value) case CSSClasses: for _, item := range c { cp.Add(item) } case []CSSClass: for _, item := range c { cp.Add(item) } case func() CSSClass: cp.AddClassName(c().ClassName(), true) default: cp.AddClassName(unknownTypeClassName, true) } } func (cp *cssProcessor) AddClassName(className string, enabled bool) { cp.classNameToEnabled[className] = enabled cp.orderedNames = append(cp.orderedNames, className) } func (cp *cssProcessor) String() string { // Order the outputs according to how they were input, and remove disabled names. rendered := make(map[string]any, len(cp.classNameToEnabled)) var names []string for _, name := range cp.orderedNames { if enabled := cp.classNameToEnabled[name]; !enabled { continue } if _, hasBeenRendered := rendered[name]; hasBeenRendered { continue } names = append(names, name) rendered[name] = struct{}{} } return strings.Join(names, " ") } // KeyValue is a key and value pair. type KeyValue[TKey comparable, TValue any] struct { Key TKey `json:"name"` Value TValue `json:"value"` } // KV creates a new key/value pair from the input key and value. func KV[TKey comparable, TValue any](key TKey, value TValue) KeyValue[TKey, TValue] { return KeyValue[TKey, TValue]{ Key: key, Value: value, } } const unknownTypeClassName = "--templ-css-class-unknown-type" // Class returns a CSS class name. // Deprecated: use a string instead. func Class(name string) CSSClass { return SafeClass(name) } // SafeClass bypasses CSS class name validation. 
//
// Deprecated: use a string instead.
func SafeClass(name string) CSSClass {
	return ConstantCSSClass(name)
}

// CSSClass provides a class name.
type CSSClass interface {
	ClassName() string
}

// ConstantCSSClass is a string constant of a CSS class name.
//
// Deprecated: use a string instead.
type ConstantCSSClass string

// ClassName of the CSS class.
func (css ConstantCSSClass) ClassName() string {
	return string(css)
}

// ComponentCSSClass is a templ.CSS: a class name (ID) paired with the CSS
// definition (Class) that renderCSSItemsToBuilder writes out.
type ComponentCSSClass struct {
	// ID of the class, will be autogenerated.
	ID string
	// Definition of the CSS.
	Class SafeCSS
}

// ClassName of the CSS class.
func (css ComponentCSSClass) ClassName() string {
	return css.ID
}

// CSSID returns name followed by "_" and the first four hex digits of the
// SHA-256 hash of css. The ID therefore changes whenever the CSS changes.
func CSSID(name string, css string) string {
	sum := sha256.Sum256([]byte(css))
	hp := hex.EncodeToString(sum[:])[0:4]
	// Benchmarking showed this was fastest, and with fewest allocations (1).
	// Using strings.Builder (2 allocs).
	// Using fmt.Sprintf (3 allocs).
	return name + "_" + hp
}

// NewCSSMiddleware creates HTTP middleware that renders a global stylesheet of ComponentCSSClass
// CSS if the request path matches, or updates the HTTP context to ensure that any handlers that
// use templ.Components skip rendering `); err != nil { return err } } return nil }
//
// NOTE(review): text appears to have been lost in the comment above. It was
// probably literal HTML style markup stripped during extraction. The fragment
// after "skip rendering" looks like the tail of a function body, so the
// declarations that originally sat between here and renderCSSItemsToBuilder
// are missing. Restore them from upstream.

// renderCSSItemsToBuilder appends to sb the CSS definition of every
// ComponentCSSClass found in classes. A class is found if it appears directly,
// as the key of an enabled KeyValue, inside CSSClasses or []CSSClass, or as
// the result of a func() CSSClass. Class IDs that v already records as
// rendered are skipped, and each written ID is recorded in v. Plain class
// names (strings, ConstantCSSClass, maps, string KeyValues) contribute no CSS.
func renderCSSItemsToBuilder(sb *strings.Builder, v *contextValue, classes ...any) {
	for _, c := range classes {
		switch ccc := c.(type) {
		case ComponentCSSClass:
			if !v.hasClassBeenRendered(ccc.ID) {
				sb.WriteString(string(ccc.Class))
				v.addClass(ccc.ID)
			}
		case KeyValue[ComponentCSSClass, bool]:
			if !ccc.Value {
				continue
			}
			renderCSSItemsToBuilder(sb, v, ccc.Key)
		case KeyValue[CSSClass, bool]:
			if !ccc.Value {
				continue
			}
			renderCSSItemsToBuilder(sb, v, ccc.Key)
		case CSSClasses:
			renderCSSItemsToBuilder(sb, v, ccc...)
		case []CSSClass:
			for _, item := range ccc {
				renderCSSItemsToBuilder(sb, v, item)
			}
		case func() CSSClass:
			renderCSSItemsToBuilder(sb, v, ccc())
		case []string:
			// Skip. These are class names, not CSS classes.
		case string:
			// Skip. This is a class name, not a CSS class.
		case ConstantCSSClass:
			// Skip. This is a class name, not a CSS class.
		case CSSClass:
			// Skip. This is a class name, not a CSS class.
		case map[string]bool:
			// Skip. These are class names, not CSS classes.
		case KeyValue[string, bool]:
			// Skip. These are class names, not CSS classes.
		case []KeyValue[string, bool]:
			// Skip. These are class names, not CSS classes.
		case KeyValue[ConstantCSSClass, bool]:
			// Skip. These are class names, not CSS classes.
		case []KeyValue[ConstantCSSClass, bool]:
			// Skip. These are class names, not CSS classes.
		}
	}
}

// SafeCSS is CSS that has been sanitized.
type SafeCSS string

// SafeCSSProperty marks a CSS property value as already safe. SanitizeCSS
// emits values of exactly this type without sanitizing them; the property
// name is still sanitized.
type SafeCSSProperty string

// safeCSSPropertyType is cached so SanitizeCSS can detect SafeCSSProperty
// values of its generic parameter.
var safeCSSPropertyType = reflect.TypeOf(SafeCSSProperty(""))

// SanitizeCSS sanitizes CSS properties to ensure that they are safe.
//
// It returns a "property:value;" declaration. If value's type is exactly
// SafeCSSProperty, only the property name is sanitized, using
// safehtml.SanitizeCSSProperty. Otherwise both the name and the value are
// sanitized with safehtml.SanitizeCSS.
func SanitizeCSS[T ~string](property string, value T) SafeCSS {
	if reflect.TypeOf(value) == safeCSSPropertyType {
		return SafeCSS(safehtml.SanitizeCSSProperty(property) + ":" + string(value) + ";")
	}
	p, v := safehtml.SanitizeCSS(property, string(value))
	return SafeCSS(p + ":" + v + ";")
}

// Attributes is a map[string]any type used for spread attributes. It is a
// distinct named type, not a type alias.
type Attributes map[string]any

// sortedKeys returns the keys of a map in sorted order.
func sortedKeys(m map[string]any) (keys []string) { keys = make([]string, len(m)) var i int for k := range m { keys[i] = k i++ } sort.Strings(keys) return keys } func writeStrings(w io.Writer, ss ...string) (err error) { for _, s := range ss { if _, err = io.WriteString(w, s); err != nil { return err } } return nil } func RenderAttributes(ctx context.Context, w io.Writer, attributes Attributes) (err error) { for _, key := range sortedKeys(attributes) { value := attributes[key] switch value := value.(type) { case string: if err = writeStrings(w, ` `, EscapeString(key), `="`, EscapeString(value), `"`); err != nil { return err } case *string: if value != nil { if err = writeStrings(w, ` `, EscapeString(key), `="`, EscapeString(*value), `"`); err != nil { return err } } case bool: if value { if err = writeStrings(w, ` `, EscapeString(key)); err != nil { return err } } case *bool: if value != nil && *value { if err = writeStrings(w, ` `, EscapeString(key)); err != nil { return err } } case KeyValue[string, bool]: if value.Value { if err = writeStrings(w, ` `, EscapeString(key), `="`, EscapeString(value.Key), `"`); err != nil { return err } } case KeyValue[bool, bool]: if value.Value && value.Key { if err = writeStrings(w, ` `, EscapeString(key)); err != nil { return err } } case func() bool: if value() { if err = writeStrings(w, ` `, EscapeString(key)); err != nil { return err } } } } return nil } // Script handling. func safeEncodeScriptParams(escapeHTML bool, params []any) []string { encodedParams := make([]string, len(params)) for i := 0; i < len(encodedParams); i++ { enc, _ := json.Marshal(params[i]) if !escapeHTML { encodedParams[i] = string(enc) continue } encodedParams[i] = EscapeString(string(enc)) } return encodedParams } // SafeScript encodes unknown parameters for safety for inside HTML attributes. 
func SafeScript(functionName string, params ...any) string { encodedParams := safeEncodeScriptParams(true, params) sb := new(strings.Builder) sb.WriteString(functionName) sb.WriteRune('(') sb.WriteString(strings.Join(encodedParams, ",")) sb.WriteRune(')') return sb.String() } // SafeScript encodes unknown parameters for safety for inline scripts. func SafeScriptInline(functionName string, params ...any) string { encodedParams := safeEncodeScriptParams(false, params) sb := new(strings.Builder) sb.WriteString(functionName) sb.WriteRune('(') sb.WriteString(strings.Join(encodedParams, ",")) sb.WriteRune(')') return sb.String() } type contextKeyType int const contextKey = contextKeyType(0) type contextValue struct { ss map[string]struct{} onceHandles map[*OnceHandle]struct{} children *Component nonce string } func (v *contextValue) setHasBeenRendered(h *OnceHandle) { if v.onceHandles == nil { v.onceHandles = map[*OnceHandle]struct{}{} } v.onceHandles[h] = struct{}{} } func (v *contextValue) getHasBeenRendered(h *OnceHandle) (ok bool) { if v.onceHandles == nil { v.onceHandles = map[*OnceHandle]struct{}{} } _, ok = v.onceHandles[h] return } func (v *contextValue) addScript(s string) { if v.ss == nil { v.ss = map[string]struct{}{} } v.ss["script_"+s] = struct{}{} } func (v *contextValue) hasScriptBeenRendered(s string) (ok bool) { if v.ss == nil { v.ss = map[string]struct{}{} } _, ok = v.ss["script_"+s] return } func (v *contextValue) addClass(s string) { if v.ss == nil { v.ss = map[string]struct{}{} } v.ss["class_"+s] = struct{}{} } func (v *contextValue) hasClassBeenRendered(s string) (ok bool) { if v.ss == nil { v.ss = map[string]struct{}{} } _, ok = v.ss["class_"+s] return } // InitializeContext initializes context used to store internal state used during rendering. 
func InitializeContext(ctx context.Context) context.Context {
	if _, ok := ctx.Value(contextKey).(*contextValue); ok {
		return ctx
	}
	v := &contextValue{}
	ctx = context.WithValue(ctx, contextKey, v)
	return ctx
}

// getContext returns ctx together with its templ rendering state. If ctx has
// no state yet, fresh state is attached via InitializeContext and the derived
// context is returned. Callers that need the state to persist must keep using
// the returned context.
func getContext(ctx context.Context) (context.Context, *contextValue) {
	v, ok := ctx.Value(contextKey).(*contextValue)
	if !ok {
		ctx = InitializeContext(ctx)
		v = ctx.Value(contextKey).(*contextValue)
	}
	return ctx, v
}

// ComponentScript is a templ Script template.
type ComponentScript struct {
	// Name of the script, e.g. print.
	Name string
	// Function to render.
	Function string
	// Call of the function in JavaScript syntax, including parameters, and
	// ensures parameters are HTML escaped; useful for injecting into HTML
	// attributes like onclick, onhover, etc.
	//
	// Given:
	//    functionName("some string",12345)
	// It would render:
	//    __templ_functionName_sha("some string",12345)
	//
	// NOTE(review): since parameters are HTML escaped, the quotes above would
	// presumably appear as &#34; in the actual output. Confirm.
	//
	// This can be injected into HTML attributes:
	//
	// NOTE(review): the example that followed here appears to have been lost
	// (likely stripped HTML markup). Restore it from upstream.
	Call string
	// Call of the function in JavaScript syntax, including parameters. It
	// does not HTML escape parameters; useful for directly calling in script
	// elements.
	//
	// Given:
	//    functionName("some string",12345)
	// It would render:
	//    __templ_functionName_sha("some string",12345)
	//
	// This can be used to call the function inside a script tag:
	//
	// NOTE(review): the example that followed here appears to have been lost.
	CallInline string
}

var _ Component = ComponentScript{}

// writeScriptHeader builds a nonce attribute (the HTML-escaped value of
// GetNonce(ctx)) when a nonce is set. It presumably writes an opening script
// tag that includes that attribute to w.
//
// NOTE(review): this block is corrupted. The format string passed to
// fmt.Fprintf below and everything after it, up to the fragment that follows,
// appear to have been lost; nonceAttr is no longer referenced, which is
// consistent with that. This was probably literal script markup stripped
// during extraction. The remainder of this function and the declarations
// that followed it (presumably including ComponentScript's Render method,
// given the interface assertion above) must be restored from upstream.
func writeScriptHeader(ctx context.Context, w io.Writer) (err error) {
	var nonceAttr string
	if nonce := GetNonce(ctx); nonce != "" {
		nonceAttr = " nonce=\"" + EscapeString(nonce) + "\""
	}
	_, err = fmt.Fprintf(w, ``); err != nil { return err } } return nil }

// RenderScriptItems renders a `); err != nil { return err } } return nil }

// bufferPool recycles *bytes.Buffer values handed out by GetBuffer.
var bufferPool = sync.Pool{
	New: func() any {
		return new(bytes.Buffer)
	},
}

// GetBuffer returns an empty *bytes.Buffer from a shared pool. Return it with
// ReleaseBuffer when finished.
func GetBuffer() *bytes.Buffer {
	return bufferPool.Get().(*bytes.Buffer)
}

// ReleaseBuffer resets b and returns it to the pool used by GetBuffer. b must
// not be used afterwards, because a later GetBuffer call may hand it out again.
func ReleaseBuffer(b *bytes.Buffer) {
	b.Reset()
	bufferPool.Put(b)
}

// JoinStringErrs joins an optional list of errors.
func JoinStringErrs(s string, errs ...error) (string, error) { return s, errors.Join(errs...) } // Error returned during template rendering. type Error struct { Err error // FileName of the template file. FileName string // Line index of the error. Line int // Col index of the error. Col int } func (e Error) Error() string { if e.FileName == "" { e.FileName = "templ" } return fmt.Sprintf("%s: error at line %d, col %d: %v", e.FileName, e.Line, e.Col, e.Err) } func (e Error) Unwrap() error { return e.Err } // Raw renders the input HTML to the output without applying HTML escaping. // // Use of this component presents a security risk - the HTML should come from // a trusted source, because it will be included as-is in the output. func Raw[T ~string](html T, errs ...error) Component { return ComponentFunc(func(ctx context.Context, w io.Writer) (err error) { if err = errors.Join(errs...); err != nil { return err } _, err = io.WriteString(w, string(html)) return err }) } // FromGoHTML creates a templ Component from a Go html/template template. func FromGoHTML(t *template.Template, data any) Component { return ComponentFunc(func(ctx context.Context, w io.Writer) (err error) { return t.Execute(w, data) }) } // ToGoHTML renders the component to a Go html/template template.HTML string. func ToGoHTML(ctx context.Context, c Component) (s template.HTML, err error) { b := GetBuffer() defer ReleaseBuffer(b) if err = c.Render(ctx, b); err != nil { return } s = template.HTML(b.String()) return } // WriteWatchModeString is used when rendering templates in development mode. // the generator would have written non-go code to the _templ.txt file, which // is then read by this function and written to the output. 
func WriteWatchModeString(w io.Writer, lineNum int) error {
	// Caller(1) identifies the generated _templ.go file that called us. Its
	// _templ.txt sibling holds the strings.
	_, path, _, _ := runtime.Caller(1)
	if !strings.HasSuffix(path, "_templ.go") {
		return errors.New("templ: WriteWatchModeString can only be called from _templ.go")
	}
	// Swap only the trailing suffix. strings.Replace(…, 1) would rewrite the
	// first "_templ.go" anywhere in the path, e.g. inside a directory name.
	txtFilePath := strings.TrimSuffix(path, "_templ.go") + "_templ.txt"
	return writeWatchedLine(w, txtFilePath, lineNum)
}

// writeWatchedLine writes the unquoted contents of line lineNum (1-based) of
// txtFilePath to w. It returns an error if the file cannot be read, the line
// does not exist, or the line is not a valid quoted-string body.
func writeWatchedLine(w io.Writer, txtFilePath string, lineNum int) error {
	literals, err := getWatchedStrings(txtFilePath)
	if err != nil {
		return fmt.Errorf("templ: failed to cache strings: %w", err)
	}
	// Guard both ends: lineNum < 1 would otherwise index literals[-1] and panic.
	if lineNum < 1 || lineNum > len(literals) {
		return errors.New("templ: failed to find line " + strconv.Itoa(lineNum) + " in " + txtFilePath)
	}
	unquoted, err := strconv.Unquote(`"` + literals[lineNum-1] + `"`)
	if err != nil {
		return err
	}
	_, err = io.WriteString(w, unquoted)
	return err
}

var (
	// watchModeCache maps a _templ.txt path to its most recently read
	// contents. Guarded by watchStateMutex.
	watchModeCache  = map[string]watchState{}
	watchStateMutex sync.Mutex
)

// watchState is a cached read of a _templ.txt file.
type watchState struct {
	// modTime is the file's modification time when it was read.
	modTime time.Time
	// strings holds the file's lines.
	strings []string
}

// getWatchedStrings returns the lines of txtFilePath. It reads the file on
// first use and re-reads it only when its modification time has advanced
// past the cached one.
func getWatchedStrings(txtFilePath string) ([]string, error) {
	watchStateMutex.Lock()
	defer watchStateMutex.Unlock()

	state, cached := watchModeCache[txtFilePath]
	if !cached {
		return cacheStrings(txtFilePath)
	}

	// Skip the stat while the cached modification time is under 100ms old.
	// NOTE(review): this measures from the file's mtime, not from the last
	// check, so in the steady state every call still stats the file. Confirm
	// that is intended.
	if time.Since(state.modTime) < time.Millisecond*100 {
		return state.strings, nil
	}

	info, err := os.Stat(txtFilePath)
	if err != nil {
		return nil, fmt.Errorf("templ: failed to stat %s: %w", txtFilePath, err)
	}

	if !info.ModTime().After(state.modTime) {
		return state.strings, nil
	}

	return cacheStrings(txtFilePath)
}

// cacheStrings reads txtFilePath, splits it on "\n", and stores the lines with
// the file's modification time in watchModeCache. The caller must hold
// watchStateMutex.
func cacheStrings(txtFilePath string) ([]string, error) {
	txtFile, err := os.Open(txtFilePath)
	if err != nil {
		return nil, fmt.Errorf("templ: failed to open %s: %w", txtFilePath, err)
	}
	defer txtFile.Close()

	// Stat the open handle so modTime matches the contents actually read.
	info, err := txtFile.Stat()
	if err != nil {
		return nil, fmt.Errorf("templ: failed to stat %s: %w", txtFilePath, err)
	}

	all, err := io.ReadAll(txtFile)
	if err != nil {
		return nil, fmt.Errorf("templ: failed to read %s: %w", txtFilePath, err)
	}

	literals := strings.Split(string(all), "\n")
	watchModeCache[txtFilePath] = watchState{
		modTime: info.ModTime(),
		strings: literals,
	}

	return literals, nil
}