// RUN: transform-opt-ch4 %s --transform-interpreter --verify-diagnostics
//
// RUN: transform-opt-ch4 %s \
// RUN:     --transform-interpreter='entry-point=__transform_main_v2' \
// RUN:     --verify-diagnostics

// ****************************** IMPORTANT NOTE ******************************
//
// If you are changing this file, you may also need to change
// mlir/docs/Tutorials/Transform accordingly.
//
// ****************************************************************************

// Original function to optimize. It chains a matmul, an elementwise addition
// of the bias and an elementwise max with zero. The `expected-remark`
// annotations below are checked by `--verify-diagnostics` against the remarks
// emitted by the transform sequences for either entry point; each annotation
// must stay on the line immediately above the operation it refers to.
func.func @fc_relu(%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>,
                   %bias: tensor<512x512xf32>, %output: tensor<512x512xf32>)
                   -> tensor<512x512xf32> {
  // Matrix-matrix multiplication.
  // expected-remark @below {{matmul}}
  %matmul = linalg.matmul ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>)
                          outs(%output: tensor<512x512xf32>) -> tensor<512x512xf32>

  // Elementwise addition.
  // expected-remark @below {{elementwise binary}}
  %biased = linalg.elemwise_binary { fun = #linalg.binary_fn<add> }
    ins(%matmul, %bias : tensor<512x512xf32>, tensor<512x512xf32>)
    outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32>

  // Elementwise max with 0 (ReLU).
  %c0f = arith.constant 0.0 : f32
  // expected-remark @below {{elementwise binary}}
  %relued = linalg.elemwise_binary { fun = #linalg.binary_fn<max_signed> }
    ins(%biased, %c0f : tensor<512x512xf32>, f32)
    outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32>
  func.return %relued : tensor<512x512xf32>
}

// The module containing named sequences must have an attribute allowing them
// to enable verification.
module @transforms attributes { transform.with_named_sequence } {
  // Entry point. This takes as the only argument the root operation (typically
  // pass root) given to the transform interpreter.
  transform.named_sequence @__transform_main(
      %root: !transform.any_op {transform.readonly}) {
    // Collect operations that match the criteria specified in the named
    // sequence. If the named sequence fails with a silenceable failure,
    // silences it (the message is forwarded to the debug stream). If the named
    // sequence succeeds, appends its results to the results of this operation.
    %elemwise = transform.collect_matching @match_elemwise in %root
      : (!transform.any_op) -> !transform.any_op
    %matmul = transform.collect_matching @match_matmul in %root
      : (!transform.any_op) -> !transform.any_op

    // Run the action sequences on the collected handles. Failures of the
    // included sequences are propagated to this sequence.
    transform.include @print_elemwise failures(propagate)  (%elemwise)
      : (!transform.any_op) -> ()
    transform.include @print_matmul failures(propagate)  (%matmul)
      : (!transform.any_op) -> ()

    transform.yield
  }

  // Alternative entry point. It is selected by the second RUN line through
  // the `entry-point=__transform_main_v2` interpreter option.
  transform.named_sequence @__transform_main_v2(
      %root: !transform.any_op {transform.readonly}) {
    // Collect groups of operations that match the criteria specified in the
    // named sequence.
    %matmul, %el1, %el2 = transform.collect_matching @match_matmul_elemwise in %root
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)

    // Combine the two elementwise handles into one so that a single include
    // of @print_elemwise covers both of them.
    %elemwise = transform.merge_handles %el1, %el2 : !transform.any_op

    transform.include @print_elemwise failures(propagate)  (%elemwise)
      : (!transform.any_op) -> ()
    transform.include @print_matmul failures(propagate)  (%matmul)
      : (!transform.any_op) -> ()

    transform.yield
  }

  // This is a matcher sequence. It is given an operation to match and the
  // match is considered successful unless any nested operation produces a
  // failure. The values yielded by this operation will be forwarded to the
  // rewriter sequence on success.
  //
  // This matcher accepts `%entry` only if it is a `linalg.elemwise_binary`
  // and yields the same handle back.
  transform.named_sequence @match_elemwise(
      %entry: !transform.any_op {transform.readonly}) -> !transform.any_op {
    transform.match.operation_name %entry ["linalg.elemwise_binary"]
      : !transform.any_op
    transform.yield %entry : !transform.any_op
  }

  // Matcher sequence with the same structure as @match_elemwise: accepts
  // `%entry` only if it is a `linalg.matmul` and yields the same handle back.
  transform.named_sequence @match_matmul(
      %entry: !transform.any_op {transform.readonly}) -> !transform.any_op {
    transform.match.operation_name %entry ["linalg.matmul"] : !transform.any_op
    transform.yield %entry : !transform.any_op
  }

  // This is an action sequence. It emits the remark "elementwise binary" at
  // the operations associated with the given handle and yields nothing.
  transform.named_sequence @print_elemwise(
      %elemwise_binary: !transform.any_op {transform.readonly}) {
    transform.debug.emit_remark_at
      %elemwise_binary, "elementwise binary" : !transform.any_op
    transform.yield
  }

  // Action sequence that emits the remark "matmul" at the operations
  // associated with the given handle and yields nothing.
  transform.named_sequence @print_matmul(
      %matmul: !transform.any_op {transform.readonly}) {
    transform.debug.emit_remark_at %matmul, "matmul" : !transform.any_op
    transform.yield
  }

  // This is also a matcher sequence. It is similarly given an operation to
  // match and nested operations must succeed in order for a match to be deemed
  // successful. It starts matching from the last operation in the use-def chain
  // and goes back because each operand (use) has exactly one definition.
  transform.named_sequence @match_matmul_elemwise(
      %last: !transform.any_op {transform.readonly})
      -> (!transform.any_op, !transform.any_op, !transform.any_op) {
    // The last operation must be an elementwise binary.
    transform.match.operation_name %last ["linalg.elemwise_binary"]
      : !transform.any_op

    // Its first operand must be defined by another operation, to which we
    // will get a handle here. We are guaranteed that the first operand exists
    // because we know the operation is binary, but even in absence of such a
    // guarantee, this operation would have produced a silenceable failure when
    // `%last` does not have enough operands.
    %middle = transform.get_producer_of_operand %last[0]
      : (!transform.any_op) -> !transform.any_op

    // The defining operation must itself be an elementwise binary.
    transform.match.operation_name %middle ["linalg.elemwise_binary"]
      : !transform.any_op

    // And the first operand of that operation must be defined by yet another
    // operation.
    %matmul = transform.get_producer_of_operand %middle[0]
      : (!transform.any_op) -> !transform.any_op

    // And that operation is a matmul.
    transform.match.operation_name %matmul ["linalg.matmul"] : !transform.any_op

    // We will yield the handles to the matmul and the two elementwise
    // operations separately.
    transform.yield %matmul, %middle, %last
      : !transform.any_op, !transform.any_op, !transform.any_op
  }
}