# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from subprocess import Popen
import os
import subprocess
import tempfile
import traceback
from ipykernel.kernelbase import Kernel

__version__ = "0.0.1"


def _get_executable():
    """Find the mlir-opt executable.

    The program name is read from the `MLIR_OPT_EXECUTABLE` environment
    variable and defaults to `mlir-opt`. If it contains a directory component
    it is used as-is; otherwise each directory of `PATH` is searched.

    Returns:
        Path of the first matching regular file that is executable.

    Raises:
        OSError: If no executable could be located.
    """

    def is_exe(fpath):
        """Returns whether `fpath` is an existing, executable regular file."""
        return os.path.isfile(fpath) and os.access(fpath, os.X_OK)

    program = os.environ.get("MLIR_OPT_EXECUTABLE", "mlir-opt")
    directory, name = os.path.split(program)
    if directory:
        # An explicit location was given: do not search PATH.
        if is_exe(program):
            return program
    else:
        # Fall back to the platform default search path when PATH is unset so
        # that a missing variable is reported as "not found" below rather than
        # surfacing as an unrelated KeyError.
        search_path = os.environ.get("PATH", os.defpath)
        for entry in search_path.split(os.pathsep):
            candidate = os.path.join(entry, name)
            if is_exe(candidate):
                return candidate
    raise OSError("mlir-opt not found, please see README")


class MlirOptKernel(Kernel):
    """Kernel using mlir-opt inside jupyter.

    The reproducer syntax (`// configuration:`) is used to run passes. The
    previous result can be referenced by using `_` (this variable is reset
    upon error). E.g.,

    ```mlir
    // configuration: --pass
    func.func @foo(%tensor: tensor<2x3xf64>) -> tensor<3x2xf64> { ... }
    ```

    ```mlir
    // configuration: --next-pass
    _
    ```
    """

    implementation = "mlir"
    implementation_version = __version__

    language_version = __version__
    language = "mlir"
    language_info = {
        "name": "mlir",
        "codemirror_mode": {"name": "mlir"},
        "mimetype": "text/x-mlir",
        "file_extension": ".mlir",
        "pygments_lexer": "text",
    }

    @property
    def banner(self):
        """Returns kernel banner."""
        # Just a placeholder.
        return "mlir-opt kernel %s" % __version__

    def __init__(self, **kwargs):
        """Initializes the kernel state; `kwargs` are forwarded to `Kernel`."""
        super().__init__(**kwargs)
        # Decoded output of the last successful run, or None.
        self._ = None
        # Lazily resolved path of the mlir-opt binary (see get_executable).
        self.executable = None
        # Whether the request currently being executed asked for no output.
        self.silent = False

    def get_executable(self):
        """Returns the mlir-opt executable path, locating it on first use.

        Raises:
            OSError: If the executable cannot be found.
        """
        if not self.executable:
            self.executable = _get_executable()
        return self.executable

    def _send_stream(self, name, text):
        """Sends `text` on the stream called `name` unless running silently."""
        if not self.silent:
            stream_content = {"name": name, "text": text}
            self.send_response(self.iopub_socket, "stream", stream_content)

    def process_output(self, output):
        """Reports regular command output."""
        self._send_stream("stdout", output)

    def process_error(self, output):
        """Reports error response."""
        self._send_stream("stderr", output)

    def _run_mlir_opt(self, code):
        """Runs `code` through mlir-opt by piping it via the filesystem.

        Args:
            code: Cell contents. When it ends with a line consisting of `_`,
                that `_` is replaced by the previous result.

        Returns:
            Tuple `(output, errors, exitcode)`: the raw stdout and stderr
            bytes of the process (with the temporary input file name replaced
            by a placeholder) and the process return code.

        Raises:
            NameError: If `_` is used but no previous result is available.
            OSError: If the mlir-opt executable cannot be found.
        """
        # Simple handling of repeating last line.
        if code.endswith("\n_"):
            if not self._:
                raise NameError("No previous result set")
            code = code[:-1] + self._

        # Resolve everything that can fail before creating the temporary file
        # so that no file (or open handle) is left behind on these errors.
        executable = self.get_executable()

        inputmlir = tempfile.NamedTemporaryFile(delete=False)
        try:
            # The handle must be closed before mlir-opt reads the file; the
            # context manager guarantees that even if the write fails.
            with inputmlir:
                inputmlir.write(code.encode("utf-8"))
            command = [
                # Specify input and output file to error out if also
                # set as arg.
                executable,
                "--color",
                inputmlir.name,
                "-o",
                "-",
            ]
            pipe = Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            output, errors = pipe.communicate()
            exitcode = pipe.returncode
        finally:
            os.unlink(inputmlir.name)

        # Replace temporary filename with placeholder. This takes the very
        # remote chance where the full input filename (generated above)
        # overlaps with something in the dump unrelated to the file.
        fname = inputmlir.name.encode("utf-8")
        output = output.replace(fname, b"<>")
        errors = errors.replace(fname, b"<>")
        return output, errors, exitcode

    def do_execute(
        self, code, silent, store_history=True, user_expressions=None, allow_stdin=False
    ):
        """Execute user code using mlir-opt binary.

        Args:
            code: Cell contents to feed to mlir-opt.
            silent: When true, no stream or result messages are sent.
            store_history: Unused by this kernel.
            user_expressions: Unused by this kernel.
            allow_stdin: Unused by this kernel.

        Returns:
            A reply dict whose "status" is "ok", "error" or "abort".
        """

        def ok_status():
            """Returns OK status."""
            return {
                "status": "ok",
                "execution_count": self.execution_count,
                "payload": [],
                "user_expressions": {},
            }

        self.silent = silent
        if not code.strip():
            return ok_status()

        try:
            output, errors, exitcode = self._run_mlir_opt(code)

            # Only a successful run may be referenced by `_` afterwards.
            if exitcode:
                self._ = None
            else:
                self._ = output.decode("utf-8")
        except KeyboardInterrupt:
            return {"status": "abort", "execution_count": self.execution_count}
        except Exception as error:
            # Print traceback for local debugging.
            traceback.print_exc()
            self._ = None
            exitcode = 255
            errors = repr(error).encode("utf-8")

        # Diagnostics are for display only, so never let undecodable bytes in
        # stderr abort the request.
        error_text = errors.decode("utf-8", errors="replace")

        if exitcode:
            content = {"ename": "", "evalue": str(exitcode), "traceback": []}

            self.send_response(self.iopub_socket, "error", content)
            self.process_error(error_text)

            content["execution_count"] = self.execution_count
            content["status"] = "error"
            return content

        if not silent:
            content = {
                "execution_count": self.execution_count,
                "data": {"text/x-mlir": self._},
                "metadata": {},
            }
            self.send_response(self.iopub_socket, "execute_result", content)
            self.process_output(self._)
            self.process_error(error_text)
        return ok_status()