diff --git a/jinja2cli/cli.py b/jinja2cli/cli.py index dc24495..a97d7a3 100644 --- a/jinja2cli/cli.py +++ b/jinja2cli/cli.py @@ -242,12 +242,13 @@ def _load_json5(): } -def render(template_path, data, extensions, strict=False): +def render(template_path, data, extensions, strict=False, stream=False): from jinja2 import ( __version__ as jinja_version, Environment, FileSystemLoader, StrictUndefined, + BaseLoader, ) # Starting with jinja2 3.1, `with_` and `autoescape` are no longer @@ -262,7 +263,7 @@ def render(template_path, data, extensions, strict=False): extensions.append(ext) env = Environment( - loader=FileSystemLoader(os.path.dirname(template_path)), + loader=BaseLoader() if stream else FileSystemLoader(os.path.dirname(template_path)), extensions=extensions, keep_trailing_newline=True, ) @@ -273,7 +274,8 @@ def render(template_path, data, extensions, strict=False): env.globals["environ"] = lambda key: force_text(os.environ.get(key)) env.globals["get_context"] = lambda: data - return env.get_template(os.path.basename(template_path)).render(data) + template = env.from_string(template_path.read()) if stream else env.get_template(os.path.basename(template_path)) + return template.render(data) def is_fd_alive(fd): @@ -285,53 +287,57 @@ def is_fd_alive(fd): def cli(opts, args): - template_path, data = args - format = opts.format - if data in ("-", ""): - if data == "-" or (data == "" and is_fd_alive(sys.stdin)): - data = sys.stdin.read() - if format == "auto": - # default to yaml first if available since yaml - # is a superset of json - if has_format("yaml"): - format = "yaml" - else: - format = "json" - else: - path = os.path.join(os.getcwd(), os.path.expanduser(data)) - if format == "auto": - ext = os.path.splitext(path)[1][1:] - if has_format(ext): - format = ext - else: - raise InvalidDataFormat(ext) - - with open(path) as fp: - data = fp.read() - - template_path = os.path.abspath(template_path) - - if data: - try: - fn, except_exc, raise_exc = get_format(format) - except InvalidDataFormat: - if format in ("yml", "yaml"): - raise InvalidDataFormat("%s: install pyyaml to fix" % format) - if format == "toml": - raise InvalidDataFormat("toml: install toml to fix") - if format == "xml": - raise InvalidDataFormat("xml: install xmltodict to fix") - if format == "hjson": - raise InvalidDataFormat("hjson: install hjson to fix") - if format == "json5": - raise InvalidDataFormat("json5: install json5 to fix") - raise - try: - data = fn(data) or {} - except except_exc: - raise raise_exc("%s ..." % data[:60]) - else: + if opts.stream: data = {} + template_path = sys.stdin + else: + template_path, data = args + format = opts.format + if data in ("-", ""): + if data == "-" or (data == "" and is_fd_alive(sys.stdin)): + data = sys.stdin.read() + if format == "auto": + # default to yaml first if available since yaml + # is a superset of json + if has_format("yaml"): + format = "yaml" + else: + format = "json" + else: + path = os.path.join(os.getcwd(), os.path.expanduser(data)) + if format == "auto": + ext = os.path.splitext(path)[1][1:] + if has_format(ext): + format = ext + else: + raise InvalidDataFormat(ext) + + with open(path) as fp: + data = fp.read() + + template_path = os.path.abspath(template_path) + + if data: + try: + fn, except_exc, raise_exc = get_format(format) + except InvalidDataFormat: + if format in ("yml", "yaml"): + raise InvalidDataFormat("%s: install pyyaml to fix" % format) + if format == "toml": + raise InvalidDataFormat("toml: install toml to fix") + if format == "xml": + raise InvalidDataFormat("xml: install xmltodict to fix") + if format == "hjson": + raise InvalidDataFormat("hjson: install hjson to fix") + if format == "json5": + raise InvalidDataFormat("json5: install json5 to fix") + raise + try: + data = fn(data) or {} + except except_exc: + raise raise_exc("%s ..." % data[:60]) + else: + data = {} extensions = [] for ext in opts.extensions: @@ -362,7 +368,7 @@ def cli(opts, args): out = codecs.getwriter("utf8")(out) - out.write(render(template_path, data, extensions, opts.strict)) + out.write(render(template_path, data, extensions, opts.strict, opts.stream)) out.flush() return 0 @@ -461,12 +467,18 @@ def main(): metavar="FILE", action="store", ) + parser.add_option( + "-S", + "--stream", + help="Input template comes from stdin. There is no input data argument", + action="store_true", + ) opts, args = parser.parse_args() # Dedupe list opts.extensions = set(opts.extensions) - if len(args) == 0: + if len(args) == 0 and not opts.stream: parser.print_help() sys.exit(1)