structSimplifyRedundantTranspose:publicmlir::OpRewritePattern<TransposeOp>{/// We register this pattern to match every toy.transpose in the IR./// The "benefit" is used by the framework to order the patterns and process/// them in order of profitability.SimplifyRedundantTranspose(mlir::MLIRContext*context):OpRewritePattern<TransposeOp>(context,/*benefit=*/1){}/// This method attempts to match a pattern and rewrite it. The rewriter/// argument is the orchestrator of the sequence of rewrites. The pattern is/// expected to interact with it to perform any changes to the IR from here.mlir::LogicalResultmatchAndRewrite(TransposeOpop,mlir::PatternRewriter&rewriter)constoverride{std::cout<<"call matchAndRewrite transpose op "<<op.getOperationName().str()<<"\n";// Look through the input of the current transpose.mlir::ValuetransposeInput=op.getOperand();TransposeOptransposeInputOp=transposeInput.getDefiningOp<TransposeOp>();// Input defined by another transpose? If not, no match.if(!transposeInputOp)returnfailure();op.emitWarning()<<"arrive here"<<"\n";// Otherwise, we have a redundant transpose. Use the rewriter.rewriter.replaceOp(op,{transposeInputOp.getOperand()});returnsuccess();}};
// "toyc.cpp"if(enableOpt){mlir::PassManagerpm(module.get()->getName());// Apply any generic pass manager command line options and run the pipeline.if(mlir::failed(mlir::applyPassManagerCLOptions(pm)))return4;// Add a run of the canonicalizer to optimize the mlir module.pm.addNestedPass<mlir::toy::FuncOp>(mlir::createCanonicalizerPass());if(mlir::failed(pm.run(*module)))return4;}
// Register our patterns as "canonicalization" patterns on the TransposeOp so// that they can be picked up by the Canonicalization framework.voidTransposeOp::getCanonicalizationPatterns(RewritePatternSet&results,MLIRContext*context){results.add<SimplifyRedundantTranspose>(context);}
// ReshapeReshapeOpt: fold reshape(reshape(x)) -> reshape(x).
//
// NOTE(review): this is TableGen(DRR)-generated code (tblgen_* naming,
// (void) silencers, notifyMatchFailure shape) — presumably from a
// ToyCombine.td pattern; do not hand-edit, regenerate instead.
struct ReshapeReshapeOptPattern : public ::mlir::RewritePattern {
  // Root op "toy.reshape", benefit 2, and "toy.reshape" listed as a
  // generated op name.
  ReshapeReshapeOptPattern(::mlir::MLIRContext *context)
      : ::mlir::RewritePattern("toy.reshape", 2, context, {"toy.reshape"}) {}
  ::mlir::LogicalResult matchAndRewrite(::mlir::Operation *op0,
      ::mlir::PatternRewriter &rewriter) const override {
    // Variables for capturing values and attributes used while creating ops
    ::mlir::Operation::operand_range arg(op0->getOperands());
    ::llvm::SmallVector<::mlir::Operation *, 4> tblgen_ops;

    // Match
    tblgen_ops.push_back(op0);
    // op0 is guaranteed to be the root "toy.reshape" by the RewritePattern
    // ctor above, hence no null-check after the dyn_cast here.
    auto castedOp0 = ::llvm::dyn_cast<::mlir::toy::ReshapeOp>(op0);
    (void)castedOp0;
    {
      // The inner reshape: the op defining operand 0 of the root.
      auto *op1 = (*castedOp0.getODSOperands(0).begin()).getDefiningOp();
      if (!(op1)) {
        return rewriter.notifyMatchFailure(castedOp0, [&](::mlir::Diagnostic &diag) {
          diag << "There's no operation that defines operand 0 of castedOp0";
        });
      }
      auto castedOp1 = ::llvm::dyn_cast<::mlir::toy::ReshapeOp>(op1);
      (void)castedOp1;
      if (!(castedOp1)) {
        return rewriter.notifyMatchFailure(op1, [&](::mlir::Diagnostic &diag) {
          diag << "castedOp1 is not ::mlir::toy::ReshapeOp type";
        });
      }
      // Capture the inner reshape's operand: the replacement reshape will
      // read from it directly, skipping the intermediate reshape.
      arg = castedOp1.getODSOperands(0);
      tblgen_ops.push_back(op1);
    }

    // Rewrite
    // Fuse both matched ops' locations so diagnostics point at the whole
    // folded pair.
    auto odsLoc = rewriter.getFusedLoc({tblgen_ops[0]->getLoc(), tblgen_ops[1]->getLoc()});
    (void)odsLoc;
    ::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values;
    ::mlir::toy::ReshapeOp tblgen_ReshapeOp_0;
    {
      ::llvm::SmallVector<::mlir::Value, 4> tblgen_values;
      (void)tblgen_values;
      ::llvm::SmallVector<::mlir::NamedAttribute, 4> tblgen_attrs;
      (void)tblgen_attrs;
      // Single operand: the inner reshape's input.
      tblgen_values.push_back((*arg.begin()));
      ::llvm::SmallVector<::mlir::Type, 4> tblgen_types;
      (void)tblgen_types;
      // Result types are copied from the outer (root) reshape so the
      // replacement is type-identical to what it replaces.
      for (auto v : castedOp0.getODSResults(0)) {
        tblgen_types.push_back(v.getType());
      }
      tblgen_ReshapeOp_0 = rewriter.create<::mlir::toy::ReshapeOp>(odsLoc, tblgen_types, tblgen_values, tblgen_attrs);
    }
    // Replace all uses of the root op with the newly created reshape.
    for (auto v : ::llvm::SmallVector<::mlir::Value, 4>{tblgen_ReshapeOp_0.getODSResults(0)}) {
      tblgen_repl_values.push_back(v);
    }
    rewriter.replaceOp(op0, tblgen_repl_values);
    return ::mlir::success();
  };
};

// Entry point the pass/driver calls to register all generated DRR patterns.
// FoldConstantReshapeOptPattern and RedundantReshapeOptPattern are declared
// elsewhere in the generated file (not visible in this chunk).
void LLVM_ATTRIBUTE_UNUSED populateWithGenerated(::mlir::RewritePatternSet &patterns) {
  patterns.add<FoldConstantReshapeOptPattern>(patterns.getContext());
  patterns.add<RedundantReshapeOptPattern>(patterns.getContext());
  patterns.add<ReshapeReshapeOptPattern>(patterns.getContext());
}