484 lines
18 KiB
C++
484 lines
18 KiB
C++
//===- IRInterfaces.cpp - MLIR IR interfaces pybind -----------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include <cstdint>
|
|
#include <optional>
|
|
#include <pybind11/cast.h>
|
|
#include <pybind11/detail/common.h>
|
|
#include <pybind11/pybind11.h>
|
|
#include <pybind11/pytypes.h>
|
|
#include <string>
|
|
#include <utility>
|
|
#include <vector>
|
|
|
|
#include "IRModule.h"
|
|
#include "mlir-c/BuiltinAttributes.h"
|
|
#include "mlir-c/IR.h"
|
|
#include "mlir-c/Interfaces.h"
|
|
#include "mlir-c/Support.h"
|
|
#include "llvm/ADT/STLExtras.h"
|
|
#include "llvm/ADT/SmallVector.h"
|
|
|
|
namespace py = pybind11;
|
|
|
|
namespace mlir {
|
|
namespace python {
|
|
|
|
// Docstrings attached to the Python bindings defined in this file. Each one is
// spelled as adjacent string literals, one per docstring line.

/// Docstring of the interface constructor.
constexpr static const char *constructorDoc =
    "Creates an interface from a given operation/opview object or from a\n"
    "subclass of OpView. Raises ValueError if the operation does not implement the\n"
    "interface.";

/// Docstring of the `operation` property.
constexpr static const char *operationDoc =
    "Returns an Operation for which the interface was constructed.";

/// Docstring of the `opview` property.
constexpr static const char *opviewDoc =
    "Returns an OpView subclass _instance_ for which the interface was\n"
    "constructed";

/// Docstring of `InferTypeOpInterface.inferReturnTypes`.
constexpr static const char *inferReturnTypesDoc =
    "Given the arguments required to build an operation, attempts to infer\n"
    "its return types. Raises ValueError on failure.";

/// Docstring of `InferShapedTypeOpInterface.inferReturnTypeComponents`.
constexpr static const char *inferReturnTypeComponentsDoc =
    "Given the arguments required to build an operation, attempts to infer\n"
    "its return shaped type components. Raises ValueError on failure.";
|
|
|
|
namespace {
|
|
|
|
/// Takes in an optional ist of operands and converts them into a SmallVector
|
|
/// of MlirVlaues. Returns an empty SmallVector if the list is empty.
|
|
llvm::SmallVector<MlirValue> wrapOperands(std::optional<py::list> operandList) {
|
|
llvm::SmallVector<MlirValue> mlirOperands;
|
|
|
|
if (!operandList || operandList->empty()) {
|
|
return mlirOperands;
|
|
}
|
|
|
|
// Note: as the list may contain other lists this may not be final size.
|
|
mlirOperands.reserve(operandList->size());
|
|
for (const auto &&it : llvm::enumerate(*operandList)) {
|
|
if (it.value().is_none())
|
|
continue;
|
|
|
|
PyValue *val;
|
|
try {
|
|
val = py::cast<PyValue *>(it.value());
|
|
if (!val)
|
|
throw py::cast_error();
|
|
mlirOperands.push_back(val->get());
|
|
continue;
|
|
} catch (py::cast_error &err) {
|
|
// Intentionally unhandled to try sequence below first.
|
|
(void)err;
|
|
}
|
|
|
|
try {
|
|
auto vals = py::cast<py::sequence>(it.value());
|
|
for (py::object v : vals) {
|
|
try {
|
|
val = py::cast<PyValue *>(v);
|
|
if (!val)
|
|
throw py::cast_error();
|
|
mlirOperands.push_back(val->get());
|
|
} catch (py::cast_error &err) {
|
|
throw py::value_error(
|
|
(llvm::Twine("Operand ") + llvm::Twine(it.index()) +
|
|
" must be a Value or Sequence of Values (" + err.what() + ")")
|
|
.str());
|
|
}
|
|
}
|
|
continue;
|
|
} catch (py::cast_error &err) {
|
|
throw py::value_error((llvm::Twine("Operand ") + llvm::Twine(it.index()) +
|
|
" must be a Value or Sequence of Values (" +
|
|
err.what() + ")")
|
|
.str());
|
|
}
|
|
|
|
throw py::cast_error();
|
|
}
|
|
|
|
return mlirOperands;
|
|
}
|
|
|
|
/// Takes in an optional vector of PyRegions and returns a SmallVector of
|
|
/// MlirRegion. Returns an empty SmallVector if the list is empty.
|
|
llvm::SmallVector<MlirRegion>
|
|
wrapRegions(std::optional<std::vector<PyRegion>> regions) {
|
|
llvm::SmallVector<MlirRegion> mlirRegions;
|
|
|
|
if (regions) {
|
|
mlirRegions.reserve(regions->size());
|
|
for (PyRegion ®ion : *regions) {
|
|
mlirRegions.push_back(region);
|
|
}
|
|
}
|
|
|
|
return mlirRegions;
|
|
}
|
|
|
|
} // namespace
|
|
|
|
/// CRTP base class for Python classes representing MLIR Op interfaces.
|
|
/// Interface hierarchies are flat so no base class is expected here. The
|
|
/// derived class is expected to define the following static fields:
|
|
/// - `const char *pyClassName` - the name of the Python class to create;
|
|
/// - `GetTypeIDFunctionTy getInterfaceID` - the function producing the TypeID
|
|
/// of the interface.
|
|
/// Derived classes may redefine the `bindDerived(ClassTy &)` method to bind
|
|
/// interface-specific methods.
|
|
///
|
|
/// An interface class may be constructed from either an Operation/OpView object
|
|
/// or from a subclass of OpView. In the latter case, only the static interface
|
|
/// methods are available, similarly to calling ConcereteOp::staticMethod on the
|
|
/// C++ side. Implementations of concrete interfaces can use the `isStatic`
|
|
/// method to check whether the interface object was constructed from a class or
|
|
/// an operation/opview instance. The `getOpName` always succeeds and returns a
|
|
/// canonical name of the operation suitable for lookups.
|
|
template <typename ConcreteIface>
|
|
class PyConcreteOpInterface {
|
|
protected:
|
|
using ClassTy = py::class_<ConcreteIface>;
|
|
using GetTypeIDFunctionTy = MlirTypeID (*)();
|
|
|
|
public:
|
|
/// Constructs an interface instance from an object that is either an
|
|
/// operation or a subclass of OpView. In the latter case, only the static
|
|
/// methods of the interface are accessible to the caller.
|
|
PyConcreteOpInterface(py::object object, DefaultingPyMlirContext context)
|
|
: obj(std::move(object)) {
|
|
try {
|
|
operation = &py::cast<PyOperation &>(obj);
|
|
} catch (py::cast_error &) {
|
|
// Do nothing.
|
|
}
|
|
|
|
try {
|
|
operation = &py::cast<PyOpView &>(obj).getOperation();
|
|
} catch (py::cast_error &) {
|
|
// Do nothing.
|
|
}
|
|
|
|
if (operation != nullptr) {
|
|
if (!mlirOperationImplementsInterface(*operation,
|
|
ConcreteIface::getInterfaceID())) {
|
|
std::string msg = "the operation does not implement ";
|
|
throw py::value_error(msg + ConcreteIface::pyClassName);
|
|
}
|
|
|
|
MlirIdentifier identifier = mlirOperationGetName(*operation);
|
|
MlirStringRef stringRef = mlirIdentifierStr(identifier);
|
|
opName = std::string(stringRef.data, stringRef.length);
|
|
} else {
|
|
try {
|
|
opName = obj.attr("OPERATION_NAME").template cast<std::string>();
|
|
} catch (py::cast_error &) {
|
|
throw py::type_error(
|
|
"Op interface does not refer to an operation or OpView class");
|
|
}
|
|
|
|
if (!mlirOperationImplementsInterfaceStatic(
|
|
mlirStringRefCreate(opName.data(), opName.length()),
|
|
context.resolve().get(), ConcreteIface::getInterfaceID())) {
|
|
std::string msg = "the operation does not implement ";
|
|
throw py::value_error(msg + ConcreteIface::pyClassName);
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Creates the Python bindings for this class in the given module.
|
|
static void bind(py::module &m) {
|
|
py::class_<ConcreteIface> cls(m, ConcreteIface::pyClassName,
|
|
py::module_local());
|
|
cls.def(py::init<py::object, DefaultingPyMlirContext>(), py::arg("object"),
|
|
py::arg("context") = py::none(), constructorDoc)
|
|
.def_property_readonly("operation",
|
|
&PyConcreteOpInterface::getOperationObject,
|
|
operationDoc)
|
|
.def_property_readonly("opview", &PyConcreteOpInterface::getOpView,
|
|
opviewDoc);
|
|
ConcreteIface::bindDerived(cls);
|
|
}
|
|
|
|
/// Hook for derived classes to add class-specific bindings.
|
|
static void bindDerived(ClassTy &cls) {}
|
|
|
|
/// Returns `true` if this object was constructed from a subclass of OpView
|
|
/// rather than from an operation instance.
|
|
bool isStatic() { return operation == nullptr; }
|
|
|
|
/// Returns the operation instance from which this object was constructed.
|
|
/// Throws a type error if this object was constructed from a subclass of
|
|
/// OpView.
|
|
py::object getOperationObject() {
|
|
if (operation == nullptr) {
|
|
throw py::type_error("Cannot get an operation from a static interface");
|
|
}
|
|
|
|
return operation->getRef().releaseObject();
|
|
}
|
|
|
|
/// Returns the opview of the operation instance from which this object was
|
|
/// constructed. Throws a type error if this object was constructed form a
|
|
/// subclass of OpView.
|
|
py::object getOpView() {
|
|
if (operation == nullptr) {
|
|
throw py::type_error("Cannot get an opview from a static interface");
|
|
}
|
|
|
|
return operation->createOpView();
|
|
}
|
|
|
|
/// Returns the canonical name of the operation this interface is constructed
|
|
/// from.
|
|
const std::string &getOpName() { return opName; }
|
|
|
|
private:
|
|
PyOperation *operation = nullptr;
|
|
std::string opName;
|
|
py::object obj;
|
|
};
|
|
|
|
/// Python wrapper for InferTypeOpInterface. This interface has only static
|
|
/// methods.
|
|
class PyInferTypeOpInterface
|
|
: public PyConcreteOpInterface<PyInferTypeOpInterface> {
|
|
public:
|
|
using PyConcreteOpInterface<PyInferTypeOpInterface>::PyConcreteOpInterface;
|
|
|
|
constexpr static const char *pyClassName = "InferTypeOpInterface";
|
|
constexpr static GetTypeIDFunctionTy getInterfaceID =
|
|
&mlirInferTypeOpInterfaceTypeID;
|
|
|
|
/// C-style user-data structure for type appending callback.
|
|
struct AppendResultsCallbackData {
|
|
std::vector<PyType> &inferredTypes;
|
|
PyMlirContext &pyMlirContext;
|
|
};
|
|
|
|
/// Appends the types provided as the two first arguments to the user-data
|
|
/// structure (expects AppendResultsCallbackData).
|
|
static void appendResultsCallback(intptr_t nTypes, MlirType *types,
|
|
void *userData) {
|
|
auto *data = static_cast<AppendResultsCallbackData *>(userData);
|
|
data->inferredTypes.reserve(data->inferredTypes.size() + nTypes);
|
|
for (intptr_t i = 0; i < nTypes; ++i) {
|
|
data->inferredTypes.emplace_back(data->pyMlirContext.getRef(), types[i]);
|
|
}
|
|
}
|
|
|
|
/// Given the arguments required to build an operation, attempts to infer its
|
|
/// return types. Throws value_error on failure.
|
|
std::vector<PyType>
|
|
inferReturnTypes(std::optional<py::list> operandList,
|
|
std::optional<PyAttribute> attributes, void *properties,
|
|
std::optional<std::vector<PyRegion>> regions,
|
|
DefaultingPyMlirContext context,
|
|
DefaultingPyLocation location) {
|
|
llvm::SmallVector<MlirValue> mlirOperands =
|
|
wrapOperands(std::move(operandList));
|
|
llvm::SmallVector<MlirRegion> mlirRegions = wrapRegions(std::move(regions));
|
|
|
|
std::vector<PyType> inferredTypes;
|
|
PyMlirContext &pyContext = context.resolve();
|
|
AppendResultsCallbackData data{inferredTypes, pyContext};
|
|
MlirStringRef opNameRef =
|
|
mlirStringRefCreate(getOpName().data(), getOpName().length());
|
|
MlirAttribute attributeDict =
|
|
attributes ? attributes->get() : mlirAttributeGetNull();
|
|
|
|
MlirLogicalResult result = mlirInferTypeOpInterfaceInferReturnTypes(
|
|
opNameRef, pyContext.get(), location.resolve(), mlirOperands.size(),
|
|
mlirOperands.data(), attributeDict, properties, mlirRegions.size(),
|
|
mlirRegions.data(), &appendResultsCallback, &data);
|
|
|
|
if (mlirLogicalResultIsFailure(result)) {
|
|
throw py::value_error("Failed to infer result types");
|
|
}
|
|
|
|
return inferredTypes;
|
|
}
|
|
|
|
static void bindDerived(ClassTy &cls) {
|
|
cls.def("inferReturnTypes", &PyInferTypeOpInterface::inferReturnTypes,
|
|
py::arg("operands") = py::none(),
|
|
py::arg("attributes") = py::none(),
|
|
py::arg("properties") = py::none(), py::arg("regions") = py::none(),
|
|
py::arg("context") = py::none(), py::arg("loc") = py::none(),
|
|
inferReturnTypesDoc);
|
|
}
|
|
};
|
|
|
|
/// Wrapper around an shaped type components.
|
|
class PyShapedTypeComponents {
|
|
public:
|
|
PyShapedTypeComponents(MlirType elementType) : elementType(elementType) {}
|
|
PyShapedTypeComponents(py::list shape, MlirType elementType)
|
|
: shape(std::move(shape)), elementType(elementType), ranked(true) {}
|
|
PyShapedTypeComponents(py::list shape, MlirType elementType,
|
|
MlirAttribute attribute)
|
|
: shape(std::move(shape)), elementType(elementType), attribute(attribute),
|
|
ranked(true) {}
|
|
PyShapedTypeComponents(PyShapedTypeComponents &) = delete;
|
|
PyShapedTypeComponents(PyShapedTypeComponents &&other) noexcept
|
|
: shape(other.shape), elementType(other.elementType),
|
|
attribute(other.attribute), ranked(other.ranked) {}
|
|
|
|
static void bind(py::module &m) {
|
|
py::class_<PyShapedTypeComponents>(m, "ShapedTypeComponents",
|
|
py::module_local())
|
|
.def_property_readonly(
|
|
"element_type",
|
|
[](PyShapedTypeComponents &self) { return self.elementType; },
|
|
"Returns the element type of the shaped type components.")
|
|
.def_static(
|
|
"get",
|
|
[](PyType &elementType) {
|
|
return PyShapedTypeComponents(elementType);
|
|
},
|
|
py::arg("element_type"),
|
|
"Create an shaped type components object with only the element "
|
|
"type.")
|
|
.def_static(
|
|
"get",
|
|
[](py::list shape, PyType &elementType) {
|
|
return PyShapedTypeComponents(std::move(shape), elementType);
|
|
},
|
|
py::arg("shape"), py::arg("element_type"),
|
|
"Create a ranked shaped type components object.")
|
|
.def_static(
|
|
"get",
|
|
[](py::list shape, PyType &elementType, PyAttribute &attribute) {
|
|
return PyShapedTypeComponents(std::move(shape), elementType,
|
|
attribute);
|
|
},
|
|
py::arg("shape"), py::arg("element_type"), py::arg("attribute"),
|
|
"Create a ranked shaped type components object with attribute.")
|
|
.def_property_readonly(
|
|
"has_rank",
|
|
[](PyShapedTypeComponents &self) -> bool { return self.ranked; },
|
|
"Returns whether the given shaped type component is ranked.")
|
|
.def_property_readonly(
|
|
"rank",
|
|
[](PyShapedTypeComponents &self) -> py::object {
|
|
if (!self.ranked) {
|
|
return py::none();
|
|
}
|
|
return py::int_(self.shape.size());
|
|
},
|
|
"Returns the rank of the given ranked shaped type components. If "
|
|
"the shaped type components does not have a rank, None is "
|
|
"returned.")
|
|
.def_property_readonly(
|
|
"shape",
|
|
[](PyShapedTypeComponents &self) -> py::object {
|
|
if (!self.ranked) {
|
|
return py::none();
|
|
}
|
|
return py::list(self.shape);
|
|
},
|
|
"Returns the shape of the ranked shaped type components as a list "
|
|
"of integers. Returns none if the shaped type component does not "
|
|
"have a rank.");
|
|
}
|
|
|
|
pybind11::object getCapsule();
|
|
static PyShapedTypeComponents createFromCapsule(pybind11::object capsule);
|
|
|
|
private:
|
|
py::list shape;
|
|
MlirType elementType;
|
|
MlirAttribute attribute;
|
|
bool ranked{false};
|
|
};
|
|
|
|
/// Python wrapper for InferShapedTypeOpInterface. This interface has only
|
|
/// static methods.
|
|
class PyInferShapedTypeOpInterface
|
|
: public PyConcreteOpInterface<PyInferShapedTypeOpInterface> {
|
|
public:
|
|
using PyConcreteOpInterface<
|
|
PyInferShapedTypeOpInterface>::PyConcreteOpInterface;
|
|
|
|
constexpr static const char *pyClassName = "InferShapedTypeOpInterface";
|
|
constexpr static GetTypeIDFunctionTy getInterfaceID =
|
|
&mlirInferShapedTypeOpInterfaceTypeID;
|
|
|
|
/// C-style user-data structure for type appending callback.
|
|
struct AppendResultsCallbackData {
|
|
std::vector<PyShapedTypeComponents> &inferredShapedTypeComponents;
|
|
};
|
|
|
|
/// Appends the shaped type components provided as unpacked shape, element
|
|
/// type, attribute to the user-data.
|
|
static void appendResultsCallback(bool hasRank, intptr_t rank,
|
|
const int64_t *shape, MlirType elementType,
|
|
MlirAttribute attribute, void *userData) {
|
|
auto *data = static_cast<AppendResultsCallbackData *>(userData);
|
|
if (!hasRank) {
|
|
data->inferredShapedTypeComponents.emplace_back(elementType);
|
|
} else {
|
|
py::list shapeList;
|
|
for (intptr_t i = 0; i < rank; ++i) {
|
|
shapeList.append(shape[i]);
|
|
}
|
|
data->inferredShapedTypeComponents.emplace_back(shapeList, elementType,
|
|
attribute);
|
|
}
|
|
}
|
|
|
|
/// Given the arguments required to build an operation, attempts to infer the
|
|
/// shaped type components. Throws value_error on failure.
|
|
std::vector<PyShapedTypeComponents> inferReturnTypeComponents(
|
|
std::optional<py::list> operandList,
|
|
std::optional<PyAttribute> attributes, void *properties,
|
|
std::optional<std::vector<PyRegion>> regions,
|
|
DefaultingPyMlirContext context, DefaultingPyLocation location) {
|
|
llvm::SmallVector<MlirValue> mlirOperands =
|
|
wrapOperands(std::move(operandList));
|
|
llvm::SmallVector<MlirRegion> mlirRegions = wrapRegions(std::move(regions));
|
|
|
|
std::vector<PyShapedTypeComponents> inferredShapedTypeComponents;
|
|
PyMlirContext &pyContext = context.resolve();
|
|
AppendResultsCallbackData data{inferredShapedTypeComponents};
|
|
MlirStringRef opNameRef =
|
|
mlirStringRefCreate(getOpName().data(), getOpName().length());
|
|
MlirAttribute attributeDict =
|
|
attributes ? attributes->get() : mlirAttributeGetNull();
|
|
|
|
MlirLogicalResult result = mlirInferShapedTypeOpInterfaceInferReturnTypes(
|
|
opNameRef, pyContext.get(), location.resolve(), mlirOperands.size(),
|
|
mlirOperands.data(), attributeDict, properties, mlirRegions.size(),
|
|
mlirRegions.data(), &appendResultsCallback, &data);
|
|
|
|
if (mlirLogicalResultIsFailure(result)) {
|
|
throw py::value_error("Failed to infer result shape type components");
|
|
}
|
|
|
|
return inferredShapedTypeComponents;
|
|
}
|
|
|
|
static void bindDerived(ClassTy &cls) {
|
|
cls.def("inferReturnTypeComponents",
|
|
&PyInferShapedTypeOpInterface::inferReturnTypeComponents,
|
|
py::arg("operands") = py::none(),
|
|
py::arg("attributes") = py::none(), py::arg("regions") = py::none(),
|
|
py::arg("properties") = py::none(), py::arg("context") = py::none(),
|
|
py::arg("loc") = py::none(), inferReturnTypeComponentsDoc);
|
|
}
|
|
};
|
|
|
|
/// Registers every op-interface binding of this file in module `m`.
/// ShapedTypeComponents is registered before InferShapedTypeOpInterface,
/// whose `inferReturnTypeComponents` returns a list of them.
void populateIRInterfaces(py::module &m) {
  // Interface with only type inference.
  PyInferTypeOpInterface::bind(m);

  // Value type returned by the shaped-type interface, then the interface.
  PyShapedTypeComponents::bind(m);
  PyInferShapedTypeOpInterface::bind(m);
}
|
|
|
|
} // namespace python
|
|
} // namespace mlir
|