// RUN: mlir-opt --transform-interpreter --split-input-file %s -verify-diagnostics | FileCheck %s

// Tests for `transform.structured.fuse_into_containing_op`. Each split below
// holds a payload function plus a `@__transform_main` sequence that matches a
// producer op and an `scf.forall` and asks for the producer to be fused into
// the loop.

#map0 = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>

module {
  // The producer `linalg.fill` is consumed through a `tensor.extract_slice`
  // inside the loop.
  // CHECK-LABEL: func.func @fuse_tileable_op
  // CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index
  // CHECK-SAME: %[[IN:[0-9a-z]+]]: tensor<?xf32>
  // CHECK-SAME: %[[OUT:[0-9a-z]+]]: tensor<?xf32>
  func.func @fuse_tileable_op(%arg0: index, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>) -> tensor<?xf32> {
    %cst = arith.constant 4.200000e+01 : f32
    %c0 = arith.constant 0 : index
    %0 = linalg.fill ins(%cst : f32) outs(%arg1 : tensor<?xf32>) -> tensor<?xf32>
    %d0 = tensor.dim %arg1, %c0 : tensor<?xf32>
    %1 = affine.apply #map0()[%d0, %arg0]

    // CHECK: scf.forall {{.*}} {
    %2 = scf.forall (%arg3) in (%1) shared_outs(%o = %arg2) -> (tensor<?xf32>) {
      %3 = affine.apply #map1(%arg3)[%arg0]
      %4 = affine.min #map2(%arg3)[%d0, %arg0]
      %5 = tensor.extract_slice %o[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>

      // CHECK: %[[T0:.*]] = tensor.extract_slice %[[IN]][%{{.*}}] [%{{.*}}] [{{.*}}]
      // CHECK: %[[T1:.*]] = linalg.fill {{.*}} outs(%[[T0]]
      %6 = tensor.extract_slice %0[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>

      // CHECK: %[[T2:.*]] = linalg.elemwise_unary ins(%[[T1]]
      %7 = linalg.elemwise_unary ins(%6 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %7 into %o[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
      }
    }
    // CHECK: }
    func.return %2 : tensor<?xf32>
  }

  // Check no failure when nothing happens.
  func.func @dummy1() { return }
  func.func @dummy2() { return }
  func.func @dummy3() { return }

  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
      %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.op<"linalg.fill">
      %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.op<"scf.forall">

      // linalg.fill is tileable. The op is tiled and fused.
      transform.structured.fuse_into_containing_op %0 into %1
        : (!transform.op<"linalg.fill">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.any_op)
      transform.yield
    }
  }
}

// -----

#map0 = affine_map<()[s0] -> (64 ceildiv s0)>
#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
#map2 = affine_map<(d0)[s0] -> (-(d0 * s0) + 64, s0)>

module {
  // The producer `tensor.empty` is used directly (not through a slice) inside
  // the loop.
  // CHECK-LABEL: func.func @fuse_untileable_op
  // CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index
  // CHECK-SAME: %[[IN:[0-9a-z]+]]: tensor<64xf32>
  // CHECK-SAME: %[[OUT:[0-9a-z]+]]: tensor<64xf32>
  func.func @fuse_untileable_op(%arg0: index, %arg1: tensor<64xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> {
    %0 = tensor.empty(%arg0) : tensor<?xf32>
    %1 = affine.apply #map0()[%arg0]

    // CHECK: scf.forall {{.*}} {
    %2 = scf.forall (%arg3) in (%1) shared_outs(%o = %arg2) -> (tensor<64xf32>) {
      // CHECK: %[[INIT_TENSOR:.*]] = tensor.empty
      %3 = affine.apply #map1(%arg3)[%arg0]
      %4 = affine.min #map2(%arg3)[%arg0]
      %5 = tensor.extract_slice %o[%3] [%4] [1] : tensor<64xf32> to tensor<?xf32>

      // CHECK: %[[T2:.*]] = linalg.elemwise_unary ins(%[[INIT_TENSOR]]
      %7 = linalg.elemwise_unary ins(%0 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %7 into %o[%3] [%4] [1] : tensor<?xf32> into tensor<64xf32>
      }
    }
    // CHECK: }
    func.return %2 : tensor<64xf32>
  }

  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
      %0 = transform.structured.match ops{["tensor.empty"]} in %arg1 : (!transform.any_op) -> !transform.op<"tensor.empty">
      %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.op<"scf.forall">

      // tensor.empty is not tileable. The op is cloned and fused.
      transform.structured.fuse_into_containing_op %0 into %1
        : (!transform.op<"tensor.empty">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.any_op)
      transform.yield
    }
  }
}

// -----

module {
  func.func @foo(%0: tensor<f32>) -> tensor<f32> {
    return %0: tensor<f32>
  }

  // The producer `linalg.fill` initializes the loop's shared output and is
  // consumed through a rank-reducing `tensor.extract_slice` of the block
  // argument.
  // CHECK-LABEL: func.func @fuse_tileable_op_rank_reducing
  // CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index
  // CHECK-SAME: %[[IN:[0-9a-z]+]]: tensor<?xf32>
  // CHECK-SAME: %[[OUT:[0-9a-z]+]]: tensor<?xf32>
  func.func @fuse_tileable_op_rank_reducing(%arg0: index, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>) -> tensor<?xf32> {
    %cst = arith.constant 4.200000e+01 : f32
    %c0 = arith.constant 0 : index
    %0 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<?xf32>) -> tensor<?xf32>
    %d0 = tensor.dim %arg1, %c0 : tensor<?xf32>

    // CHECK: scf.forall {{.*}} -> (tensor<?xf32>) {
    %2 = scf.forall (%arg3) in (%d0) shared_outs(%o = %0) -> (tensor<?xf32>) {
      %5 = tensor.extract_slice %o[%arg3] [1] [1] : tensor<?xf32> to tensor<f32>

      // CHECK: tensor.extract_slice %{{.*}}[%{{.*}}] [1] [1] : tensor<?xf32> to tensor<1xf32>
      // CHECK: linalg.fill ins(%{{.*}} : f32) outs(%{{.*}} : tensor<1xf32>) -> tensor<1xf32>
      // CHECK: tensor.extract_slice %{{.*}}[0] [1] [1] : tensor<1xf32> to tensor<f32>
      // CHECK: func.call @foo(%{{.*}}) : (tensor<f32>) -> tensor<f32>
      %7 = func.call @foo(%5) : (tensor<f32>) -> tensor<f32>

      scf.forall.in_parallel {
        // CHECK: tensor.parallel_insert_slice %{{.*}} into %{{.*}}[%{{.*}}] [1] [1] : tensor<f32> into tensor<?xf32>
        tensor.parallel_insert_slice %7 into %o[%arg3] [1] [1] : tensor<f32> into tensor<?xf32>
      }
    }
    // CHECK: }
    func.return %2 : tensor<?xf32>
  }

  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
      %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.op<"linalg.fill">
      %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.op<"scf.forall">

      // linalg.fill is tileable. The op is tiled and fused.
      transform.structured.fuse_into_containing_op %0 into %1
        : (!transform.op<"linalg.fill">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.any_op)
      transform.yield
    }
  }
}

// -----

#map0 = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>

module {
  // The producer `linalg.fill` initializes the loop's shared output, so its
  // only use inside the loop is through the block argument `%o`.
  // CHECK-LABEL: func.func @fuse_tileable_op_through_bbarg
  // CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index
  // CHECK-SAME: %[[IN:[0-9a-z]+]]: tensor<?xf32>
  // CHECK-SAME: %[[OUT:[0-9a-z]+]]: tensor<?xf32>
  func.func @fuse_tileable_op_through_bbarg(%arg0: index, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>) -> tensor<?xf32> {
    %cst = arith.constant 4.200000e+01 : f32
    %c0 = arith.constant 0 : index
    %0 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<?xf32>) -> tensor<?xf32>
    %d0 = tensor.dim %arg1, %c0 : tensor<?xf32>
    %1 = affine.apply #map0()[%d0, %arg0]

    // CHECK: scf.forall {{.*}} shared_outs(%[[BBARGOUT:.*]] = %[[OUT]]) -> (tensor<?xf32>) {
    %2 = scf.forall (%arg3) in (%1) shared_outs(%o = %0) -> (tensor<?xf32>) {
      %3 = affine.apply #map1(%arg3)[%arg0]
      %4 = affine.min #map2(%arg3)[%d0, %arg0]
      %5 = tensor.extract_slice %o[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>

      // CHECK: %[[T0:.*]] = tensor.extract_slice %[[BBARGOUT]][%{{.*}}] [%{{.*}}] [{{.*}}]
      // CHECK: %[[T1:.*]] = linalg.fill {{.*}} outs(%[[T0]]
      %6 = tensor.extract_slice %arg1[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>

      // CHECK: %[[T2:.*]] = linalg.elemwise_unary {{.*}} outs(%[[T1]]
      %7 = linalg.elemwise_unary ins(%6 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %7 into %o[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
      }
    }
    // CHECK: }
    func.return %2 : tensor<?xf32>
  }

  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
      %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op
      %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.any_op

      // linalg.fill is tileable. The op is tiled and fused.
      transform.structured.fuse_into_containing_op %0 into %1
        : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
      transform.yield
    }
  }
}

// -----

#map0 = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>

module {
  // The producer is a two-result `linalg.generic`; only result #0 is used
  // inside the loop.
  // CHECK-LABEL: func.func @fuse_tileable_multi_output_op
  // CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index
  // CHECK-SAME: %[[IN:[0-9a-z]+]]: tensor<?xf32>
  // CHECK-SAME: %[[OUT_1:[0-9a-z]+]]: tensor<?xf32>
  // CHECK-SAME: %[[OUT_2:[0-9a-z]+]]: tensor<?xf32>
  // CHECK-SAME: %[[OUT_3:[0-9a-z]+]]: tensor<?xf32>
  func.func @fuse_tileable_multi_output_op(%idx: index, %in: tensor<?xf32>, %out_1: tensor<?xf32>, %out_2: tensor<?xf32>, %out_3: tensor<?xf32>) -> tensor<?xf32> {
    %cst = arith.constant 4.200000e+01 : f32
    %c0 = arith.constant 0 : index

    %0:2 = linalg.generic {
      indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
      iterator_types = ["parallel"]
    } ins(%in : tensor<?xf32>) outs(%out_1, %out_3 : tensor<?xf32>, tensor<?xf32>) {
      ^bb0(%a: f32, %b: f32, %c: f32):
        %d = arith.addf %a, %b : f32
        %e = arith.addf %d, %c : f32
        linalg.yield %d, %e : f32, f32
    } -> (tensor<?xf32>, tensor<?xf32>)
    %d0 = tensor.dim %out_1, %c0 : tensor<?xf32>

    %1 = affine.apply #map0()[%d0, %idx]

    // CHECK: scf.forall {{.*}} {
    %2 = scf.forall (%i) in (%1) shared_outs(%o = %out_2) -> (tensor<?xf32>) {
      %3 = affine.apply #map1(%i)[%idx]
      %4 = affine.min #map2(%i)[%d0, %idx]
      %5 = tensor.extract_slice %o[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>

      // CHECK: %[[T0:.*]] = tensor.extract_slice %[[IN]][%{{.*}}] [%{{.*}}] [{{.*}}]
      // CHECK: %[[T1:.*]]:2 = linalg.generic {{.*}} ins(%[[T0]]
      %6 = tensor.extract_slice %0#0[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>

      // CHECK: %[[T2:.*]] = linalg.elemwise_unary ins(%[[T1]]#0
      %7 = linalg.elemwise_unary ins(%6 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %7 into %o[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
      }
    }
    // CHECK: }
    func.return %2 : tensor<?xf32>
  }

  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
      %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.op<"linalg.generic">
      %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.op<"scf.forall">

      // linalg.generic is tileable. The op is tiled and fused.
      transform.structured.fuse_into_containing_op %0 into %1
        : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.any_op)
      transform.yield
    }
  }
}

// -----

module {
  // The producer handle passed to the fusion op lists the same `linalg.fill`
  // twice.
  // CHECK-LABEL: func.func @fuse_repeated
  func.func @fuse_repeated(%fill: tensor<2xf32>, %output: tensor<2xf32>) -> tensor<2xf32> {
    %c0 = arith.constant 0.0 : f32
    %0 = linalg.fill ins(%c0 : f32) outs(%fill : tensor<2xf32>) -> tensor<2xf32>

    // CHECK: scf.forall
    %1 = scf.forall (%i) in (2) shared_outs(%arg1 = %output) -> (tensor<2xf32>) {
      %2 = tensor.extract_slice %0[%i][1][1] : tensor<2xf32> to tensor<1xf32>
      %3 = tensor.extract_slice %arg1[%i][1][1] : tensor<2xf32> to tensor<1xf32>
      // CHECK: %[[FUSED:.+]] = linalg.fill
      // CHECK: elemwise_unary ins(%[[FUSED]]
      %4 = linalg.elemwise_unary ins(%2 : tensor<1xf32>) outs(%3 : tensor<1xf32>) -> tensor<1xf32>
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %4 into %arg1[%i][1][1] : tensor<1xf32> into tensor<2xf32>
      }
    }

    return %1 : tensor<2xf32>
  }

  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
      %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op
      %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.any_op

      // Create a new handle that points to `linalg.fill` twice.
      %2 = transform.merge_handles %0, %0 : !transform.any_op

      // It shouldn't be a problem to fuse this handle.
      transform.structured.fuse_into_containing_op %2 into %1
        : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
      transform.yield
    }
  }
}

// -----

#map0 = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>

module {
  // Both results of the producer `linalg.generic` are also returned from the
  // function, i.e. they have uses after the loop. The remark emitted on the
  // returned containing-op handle is verified via -verify-diagnostics.
  // CHECK-LABEL: func.func @fuse_tileable_multi_output_op_multi_use
  // CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index
  // CHECK-SAME: %[[IN:[0-9a-z]+]]: tensor<?xf32>
  // CHECK-SAME: %[[OUT_1:[0-9a-z]+]]: tensor<?xf32>
  // CHECK-SAME: %[[OUT_2:[0-9a-z]+]]: tensor<?xf32>
  // CHECK-SAME: %[[OUT_3:[0-9a-z]+]]: tensor<?xf32>
  func.func @fuse_tileable_multi_output_op_multi_use(%idx: index, %in: tensor<?xf32>, %out_1: tensor<?xf32>, %out_2: tensor<?xf32>, %out_3: tensor<?xf32>)
      -> (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) {
    %cst = arith.constant 4.200000e+01 : f32
    %c0 = arith.constant 0 : index

    // CHECK: %[[G0:.*]]:2 = linalg.generic
    %0:2 = linalg.generic {
      indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
      iterator_types = ["parallel"]
    } ins(%in : tensor<?xf32>) outs(%out_1, %out_3 : tensor<?xf32>, tensor<?xf32>) {
      ^bb0(%a: f32, %b: f32, %c: f32):
        %d = arith.addf %a, %b : f32
        %e = arith.addf %d, %c : f32
        linalg.yield %d, %e : f32, f32
    } -> (tensor<?xf32>, tensor<?xf32>)
    %d0 = tensor.dim %out_1, %c0 : tensor<?xf32>

    %1 = affine.apply #map0()[%d0, %idx]

    // CHECK: %[[R0:.*]]:2 = scf.forall (%[[ARG5:.*]]) in (%{{.*}}) shared_outs(%[[ARG6:.*]] = %[[OUT_2]], %[[ARG7:.*]] = %[[OUT_1]])
    // CHECK-SAME: -> (tensor<?xf32>, tensor<?xf32>) {
    // expected-remark @below{{new containing op}}
    %2 = scf.forall (%i) in (%1) shared_outs(%o = %out_2) -> (tensor<?xf32>) {
      // CHECK: %[[I0:.*]] = affine.apply {{.*}}
      %3 = affine.apply #map1(%i)[%idx]
      // CHECK: %[[I1:.*]] = affine.min {{.*}}
      %4 = affine.min #map2(%i)[%d0, %idx]
      %5 = tensor.extract_slice %o[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>

      // CHECK: %[[T1:.*]]:2 = linalg.generic {{.*}}
      %6 = tensor.extract_slice %0#0[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>

      %7 = linalg.elemwise_unary ins(%6 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
      scf.forall.in_parallel {
        // CHECK: tensor.parallel_insert_slice %[[T1]]#0 into %[[ARG7]][%[[I0]]] [%[[I1]]] [1] : tensor<?xf32> into tensor<?xf32>
        tensor.parallel_insert_slice %7 into %o[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
      }
    }
    // CHECK: return %[[R0]]#0, %[[R0]]#1, %[[G0]]#1
    func.return %2, %0#0, %0#1 : tensor<?xf32>, tensor<?xf32>, tensor<?xf32>
    // CHECK: }
  }

  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
      %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.op<"linalg.generic">
      %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.op<"scf.forall">

      // linalg.generic is tileable. The op is tiled and fused.
      %fused, %containing = transform.structured.fuse_into_containing_op %0 into %1
        : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.any_op)
      transform.debug.emit_remark_at %containing, "new containing op" : !transform.any_op
      transform.yield
    }
  }
}

// -----

#map0 = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>

module {
  // The producer result is used before the loop (`tensor.dim`), inside the
  // loop (`tensor.extract_slice`) and after the loop (`func.return`).
  // CHECK-LABEL: func.func @fuse_tileable_mixed_dominating_uses
  // CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index
  // CHECK-SAME: %[[IN:[0-9a-z]+]]: tensor<?xf32>
  // CHECK-SAME: %[[OUT_1:[0-9a-z]+]]: tensor<?xf32>
  // CHECK-SAME: %[[OUT_2:[0-9a-z]+]]: tensor<?xf32>
  // CHECK-SAME: %[[OUT_3:[0-9a-z]+]]: tensor<?xf32>
  func.func @fuse_tileable_mixed_dominating_uses(%idx: index, %in: tensor<?xf32>, %out_1: tensor<?xf32>, %out_2: tensor<?xf32>, %out_3: tensor<?xf32>)
      -> (tensor<?xf32>, tensor<?xf32>) {
    %cst = arith.constant 4.200000e+01 : f32
    %c0 = arith.constant 0 : index

    // CHECK: %[[G0:.*]] = linalg.generic
    %0 = linalg.generic {
      indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
      iterator_types = ["parallel"]
    } ins(%in : tensor<?xf32>) outs(%out_1 : tensor<?xf32>) {
      ^bb0(%a: f32, %b: f32):
        %d = arith.addf %a, %b : f32
        linalg.yield %d : f32
    } -> tensor<?xf32>
    // CHECK: %[[D0:.*]] = tensor.dim %[[G0]]
    %d0 = tensor.dim %0, %c0 : tensor<?xf32>

    %1 = affine.apply #map0()[%d0, %idx]

    // CHECK: %[[R0:.*]]:2 = scf.forall (%[[ARG5:.*]]) in (%{{.*}}) shared_outs(%[[ARG6:.*]] = %[[OUT_2]], %[[ARG7:.*]] = %[[OUT_1]])
    // CHECK-SAME: -> (tensor<?xf32>, tensor<?xf32>) {
    %2 = scf.forall (%i) in (%1) shared_outs(%o = %out_2) -> (tensor<?xf32>) {
      // CHECK: %[[I0:.*]] = affine.apply {{.*}}
      %3 = affine.apply #map1(%i)[%idx]
      // CHECK: %[[I1:.*]] = affine.min {{.*}}
      %4 = affine.min #map2(%i)[%d0, %idx]
      %5 = tensor.extract_slice %o[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>

      // CHECK: %[[T1:.*]] = linalg.generic {{.*}}
      %6 = tensor.extract_slice %0[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>

      %7 = linalg.elemwise_unary ins(%6 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
      scf.forall.in_parallel {
        // CHECK: tensor.parallel_insert_slice %[[T1]] into %[[ARG7]][%[[I0]]] [%[[I1]]] [1] : tensor<?xf32> into tensor<?xf32>
        tensor.parallel_insert_slice %7 into %o[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
      }
    }
    // CHECK: return %[[R0]]#0, %[[R0]]#1
    func.return %2, %0 : tensor<?xf32>, tensor<?xf32>
    // CHECK: }
  }

  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
      %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.op<"linalg.generic">
      %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.op<"scf.forall">

      // linalg.generic is tileable. The op is tiled and fused.
      transform.structured.fuse_into_containing_op %0 into %1
        : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.any_op)
      transform.yield
    }
  }
}

// -----

#map0 = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
#map3 = affine_map<(d0, d1) -> (d0, d1)>
#map4 = affine_map<(d0, d1) -> (d0)>

module {
  // The producer is a `linalg.generic` with a reduction iterator; its result
  // is used inside the loop and returned from the function.
  // CHECK-LABEL: func.func @fuse_tileable_reductions
  // CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index
  // CHECK-SAME: %[[IN:[0-9a-z]+]]: tensor<?x?xf32>
  // CHECK-SAME: %[[OUT_1:[0-9a-z]+]]: tensor<?xf32>
  // CHECK-SAME: %[[OUT_2:[0-9a-z]+]]: tensor<?xf32>
  // CHECK-SAME: %[[OUT_3:[0-9a-z]+]]: tensor<?xf32>
  func.func @fuse_tileable_reductions(%idx: index, %in: tensor<?x?xf32>, %out_1: tensor<?xf32>, %out_2: tensor<?xf32>, %out_3: tensor<?xf32>)
      -> (tensor<?xf32>, tensor<?xf32>) {
    %cst = arith.constant 4.200000e+01 : f32
    %c0 = arith.constant 0 : index

    %0 = linalg.generic {
      indexing_maps = [#map3, #map4],
      iterator_types = ["parallel", "reduction"]
    } ins(%in : tensor<?x?xf32>) outs(%out_1 : tensor<?xf32>) {
      ^bb0(%a: f32, %b: f32):
        %d = arith.maximumf %a, %b : f32
        linalg.yield %d : f32
    } -> tensor<?xf32>
    %d0 = tensor.dim %out_1, %c0 : tensor<?xf32>

    %1 = affine.apply #map0()[%d0, %idx]

    // CHECK: %[[R0:.*]]:2 = scf.forall (%[[ARG5:.*]]) in (%{{.*}}) shared_outs(%[[ARG6:.*]] = %[[OUT_2]], %[[ARG7:.*]] = %[[OUT_1]])
    // CHECK-SAME: -> (tensor<?xf32>, tensor<?xf32>) {
    %2 = scf.forall (%i) in (%1) shared_outs(%o = %out_2) -> (tensor<?xf32>) {
      // CHECK: %[[I0:.*]] = affine.apply {{.*}}
      %3 = affine.apply #map1(%i)[%idx]
      // CHECK: %[[I1:.*]] = affine.min {{.*}}
      %4 = affine.min #map2(%i)[%d0, %idx]
      %5 = tensor.extract_slice %o[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>

      // CHECK: %[[T1:.*]] = linalg.generic {{.*}}
      %6 = tensor.extract_slice %0[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>

      %7 = linalg.elemwise_unary ins(%6 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
      scf.forall.in_parallel {
        // CHECK: tensor.parallel_insert_slice %[[T1]] into %[[ARG7]][%[[I0]]] [%[[I1]]] [1] : tensor<?xf32> into tensor<?xf32>
        tensor.parallel_insert_slice %7 into %o[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
      }
    }
    // CHECK: return %[[R0]]#0, %[[R0]]#1
    func.return %2, %0 : tensor<?xf32>, tensor<?xf32>
    // CHECK: }
  }

  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
      %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.op<"linalg.generic">
      %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.op<"scf.forall">

      // linalg.generic is tileable. The op is tiled and fused.
      transform.structured.fuse_into_containing_op %0 into %1
        : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.any_op)
      transform.yield
    }
  }
}

// -----

#map0 = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
#map3 = affine_map<(d0) -> (d0)>

module {
  // Two chained producers (an addf generic feeding a mulf generic) are fused
  // one after the other; the second fusion targets the loop handle returned
  // by the first.
  // CHECK-LABEL: func.func @fuse_tileable_using_new_handle
  // CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index
  // CHECK-SAME: %[[IN:[0-9a-z]+]]: tensor<?xf32>
  // CHECK-SAME: %[[OUT_1:[0-9a-z]+]]: tensor<?xf32>
  // CHECK-SAME: %[[OUT_2:[0-9a-z]+]]: tensor<?xf32>
  // CHECK-SAME: %[[OUT_3:[0-9a-z]+]]: tensor<?xf32>
  func.func @fuse_tileable_using_new_handle(%idx: index, %in: tensor<?xf32>, %out_1: tensor<?xf32>, %out_2: tensor<?xf32>, %out_3: tensor<?xf32>)
      -> (tensor<?xf32>, tensor<?xf32>) {
    %cst = arith.constant 4.200000e+01 : f32
    %c0 = arith.constant 0 : index

    %0 = linalg.generic {
      indexing_maps = [#map3, #map3],
      iterator_types = ["parallel"]
    } ins(%in : tensor<?xf32>) outs(%out_1 : tensor<?xf32>) {
      ^bb0(%a: f32, %b: f32):
        %d = arith.addf %a, %b : f32
        linalg.yield %d : f32
    } -> tensor<?xf32>

    %1 = linalg.generic {
      indexing_maps = [#map3, #map3],
      iterator_types = ["parallel"]
    } ins(%0 : tensor<?xf32>) outs(%out_1 : tensor<?xf32>) {
      ^bb0(%a: f32, %b: f32):
        %d = arith.mulf %a, %b : f32
        linalg.yield %d : f32
    } -> tensor<?xf32>
    %d0 = tensor.dim %out_1, %c0 : tensor<?xf32>

    %2 = affine.apply #map0()[%d0, %idx]

    // CHECK: %[[R0:.*]]:2 = scf.forall (%[[ARG5:.*]]) in (%{{.*}}) shared_outs(%[[ARG6:.*]] = %[[OUT_2]], %[[ARG7:.*]] = %[[OUT_1]])
    // CHECK-SAME: -> (tensor<?xf32>, tensor<?xf32>) {
    %3 = scf.forall (%i) in (%2) shared_outs(%o = %out_2) -> (tensor<?xf32>) {
      // CHECK: %[[I0:.*]] = affine.apply {{.*}}
      %4 = affine.apply #map1(%i)[%idx]
      // CHECK: %[[I1:.*]] = affine.min {{.*}}
      %5 = affine.min #map2(%i)[%d0, %idx]
      %6 = tensor.extract_slice %o[%4] [%5] [1] : tensor<?xf32> to tensor<?xf32>

      // CHECK: %[[T1:.*]] = linalg.generic {{.*}}
      // CHECK: %[[T2:.*]] = linalg.generic {{.*}}
      %7 = tensor.extract_slice %1[%4] [%5] [1] : tensor<?xf32> to tensor<?xf32>

      %8 = linalg.elemwise_unary ins(%7 : tensor<?xf32>) outs(%6 : tensor<?xf32>) -> tensor<?xf32>
      scf.forall.in_parallel {
        // CHECK: tensor.parallel_insert_slice %[[T2]] into %[[ARG7]][%[[I0]]] [%[[I1]]] [1] : tensor<?xf32> into tensor<?xf32>
        // The tile is written back at the same offset (%4) it was extracted
        // from, matching the other tests in this file.
        tensor.parallel_insert_slice %8 into %o[%4] [%5] [1] : tensor<?xf32> into tensor<?xf32>
      }
    }
    // CHECK: return %[[R0]]#0, %[[R0]]#1
    func.return %3, %1 : tensor<?xf32>, tensor<?xf32>
    // CHECK: }
  }

  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
      %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.op<"linalg.generic">
      // Split the matched handle into the addf generic and the mulf generic.
      %add, %mul = transform.split_handle %0
        : (!transform.op<"linalg.generic">)
        -> (!transform.op<"linalg.generic">, !transform.op<"linalg.generic">)
      %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.op<"scf.forall">

      // Fuse the direct producer of the loop first, then fuse its own
      // producer into the loop handle returned by the first fusion.
      %fused_ops, %new_forall = transform.structured.fuse_into_containing_op %mul into %1
        : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.op<"scf.forall">)
      %fused_ops_2, %new_forall_2 = transform.structured.fuse_into_containing_op %add into %new_forall
        : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.op<"scf.forall">)
      transform.yield
    }
  }
}

// -----

// This is a regression test. Make sure that the transform succeeds and valid
// IR is generated.

module {
  // CHECK-LABEL: func.func @softmax_dispatch_0_generic_16x128x128_f32
  func.func @softmax_dispatch_0_generic_16x128x128_f32() -> tensor<16x128x128xf32> {
    %c0 = arith.constant 0 : index
    %cst = arith.constant dense<5.000000e+00> : tensor<16x128x128xf32>
    %cst_1 = arith.constant 5.000000e+00 : f32
    %1 = tensor.empty() : tensor<16x128xf32>
    %2 = tensor.empty() : tensor<16x128x128xf32>
    %3 = linalg.fill ins(%cst_1 : f32) outs(%1 : tensor<16x128xf32>) -> tensor<16x128xf32>
    %4 = linalg.fill ins(%cst_1 : f32) outs(%1 : tensor<16x128xf32>) -> tensor<16x128xf32>
    // The `producer` unit attribute is what the transform sequence matches on.
    %5 = linalg.generic {producer, indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%cst : tensor<16x128x128xf32>) outs(%4 : tensor<16x128xf32>) {
    ^bb0(%in: f32, %out: f32):
      %8 = arith.maximumf %in, %out : f32
      linalg.yield %8 : f32
    } -> tensor<16x128xf32>
    %c16 = arith.constant 16 : index
    %c32 = arith.constant 32 : index
    %7 = scf.forall (%arg0, %arg1) in (16, 32) shared_outs(%arg2 = %2) -> (tensor<16x128x128xf32>) {
      %11 = affine.apply affine_map<(d0) -> (d0 * 4)>(%arg1)
      %extracted_slice = tensor.extract_slice %5[%arg0, %11] [1, 4] [1, 1] : tensor<16x128xf32> to tensor<1x4xf32>
      %extracted_slice_3 = tensor.extract_slice %2[%arg0, %11, 0] [1, 4, 128] [1, 1, 1] : tensor<16x128x128xf32> to tensor<1x4x128xf32>
      %extracted_slice_4 = tensor.extract_slice %3[%arg0, %11] [1, 4] [1, 1] : tensor<16x128xf32> to tensor<1x4xf32>
      %15:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%extracted_slice : tensor<1x4xf32>) outs(%extracted_slice_3, %extracted_slice_4 : tensor<1x4x128xf32>, tensor<1x4xf32>) {
      ^bb0(%in: f32, %out: f32, %out_9: f32):
        %22 = arith.subf %cst_1, %in : f32
        %23 = math.exp %22 : f32
        %24 = arith.addf %23, %out_9 : f32
        linalg.yield %23, %24 : f32, f32
      } -> (tensor<1x4x128xf32>, tensor<1x4xf32>)
      %extracted_slice_5 = tensor.extract_slice %5[%arg0, %11] [1, 4] [1, 1] : tensor<16x128xf32> to tensor<1x4xf32>
      %extracted_slice_6 = tensor.extract_slice %2[%arg0, %11, 0] [1, 4, 128] [1, 1, 1] : tensor<16x128x128xf32> to tensor<1x4x128xf32>
      %extracted_slice_7 = tensor.extract_slice %3[%arg0, %11] [1, 4] [1, 1] : tensor<16x128xf32> to tensor<1x4xf32>
      %19:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%extracted_slice_5 : tensor<1x4xf32>) outs(%extracted_slice_6, %extracted_slice_7 : tensor<1x4x128xf32>, tensor<1x4xf32>) {
      ^bb0(%in: f32, %out: f32, %out_9: f32):
        %22 = arith.subf %cst_1, %in : f32
        %23 = math.exp %22 : f32
        %24 = arith.addf %23, %out_9 : f32
        linalg.yield %23, %24 : f32, f32
      } -> (tensor<1x4x128xf32>, tensor<1x4xf32>)
      %extracted_slice_8 = tensor.extract_slice %arg2[%arg0, %11, 0] [1, 4, 128] [1, 1, 1] : tensor<16x128x128xf32> to tensor<1x4x128xf32>
      %20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%15#0, %19#1 : tensor<1x4x128xf32>, tensor<1x4xf32>) outs(%extracted_slice_8 : tensor<1x4x128xf32>) {
      ^bb0(%in: f32, %in_9: f32, %out: f32):
        %22 = arith.divf %in, %in_9 : f32
        linalg.yield %22 : f32
      } -> tensor<1x4x128xf32>
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %20 into %arg2[%arg0, %11, 0] [1, 4, 128] [1, 1, 1] : tensor<1x4x128xf32> into tensor<16x128x128xf32>
      }
    }
    return %7 : tensor<16x128x128xf32>
  }

  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
      %0 = transform.structured.match attributes{producer} in %arg1 : (!transform.any_op) -> !transform.op<"linalg.generic">
      %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.op<"scf.forall">
      transform.structured.fuse_into_containing_op %0 into %1
        : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.any_op)
      transform.yield
    }
  }
}

////////////////////////////////////////////////////////////////////////////////
// Tests below are expected to fail.
////////////////////////////////////////////////////////////////////////////////

// -----

// NO-CHECK-LABEL-ON-EXPECTED-ERROR
func.func @copy_1d_1024xf16(%arg0: tensor<123x456xf32>, %arg1: tensor<456x789xf32>, %arg2 : tensor<123x789xf32>) -> tensor<123x789xf32> {
  %0 = arith.constant 0.000000e+00 : f32
  %1 = linalg.fill ins(%0 : f32) outs(%arg2 : tensor<123x789xf32>) -> tensor<123x789xf32>
  // expected-note @below {{containing op}}
  %2 = linalg.matmul ins(%arg0, %arg1 : tensor<123x456xf32>, tensor<456x789xf32>) outs(%1 : tensor<123x789xf32>) -> tensor<123x789xf32>
  return %2 : tensor<123x789xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %tiled_op, %forall_op = transform.structured.tile_using_forall %1
      num_threads [] tile_sizes [50, 16]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    // Note that we pass in %tiled_op, which isn't a container op.
    // The diagnostic below is anchored two lines down, on the line holding the
    // op name, so the result list must stay on its own line.
    // expected-error @+2 {{could not find next producer to fuse into container}}
    %fused_op, %new_containing_op =
      transform.structured.fuse_into_containing_op %0 into %tiled_op
        : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.yield
  }
}