diff --git a/include/souper/Infer/InstSynthesis.h b/include/souper/Infer/InstSynthesis.h index 6c03db5bc..7299e6106 100644 --- a/include/souper/Infer/InstSynthesis.h +++ b/include/souper/Infer/InstSynthesis.h @@ -70,10 +70,13 @@ typedef std::pair LocVar; typedef std::pair LocInst; /// A component is a fixed-width instruction kind +/// or created from Origin struct Component { Inst::Kind Kind; unsigned Width; std::vector OpWidths; + Inst *Origin; + std::vector OriginOps; }; /// Unsupported components kinds @@ -94,35 +97,35 @@ static const std::set UnsupportedCompKinds = { /// a component of that width is instantiated. /// Again, note that constants are treated as ordinary inputs static const std::vector CompLibrary = { - Component{Inst::Add, 0, {0,0}}, - Component{Inst::Sub, 0, {0,0}}, - Component{Inst::Mul, 0, {0,0}}, - Component{Inst::UDiv, 0, {0,0}}, - Component{Inst::SDiv, 0, {0,0}}, - Component{Inst::UDivExact, 0, {0,0}}, - Component{Inst::SDivExact, 0, {0,0}}, - Component{Inst::URem, 0, {0,0}}, - Component{Inst::SRem, 0, {0,0}}, - Component{Inst::And, 0, {0,0}}, - Component{Inst::Or, 0, {0,0}}, - Component{Inst::Xor, 0, {0,0}}, - Component{Inst::Shl, 0, {0,0}}, - Component{Inst::LShr, 0, {0,0}}, - Component{Inst::LShrExact, 0, {0,0}}, - Component{Inst::AShr, 0, {0,0}}, - Component{Inst::AShrExact, 0, {0,0}}, - Component{Inst::Select, 0, {1,0,0}}, - Component{Inst::Eq, 1, {0,0}}, - Component{Inst::Ne, 1, {0,0}}, - Component{Inst::Ult, 1, {0,0}}, - Component{Inst::Slt, 1, {0,0}}, - Component{Inst::Ule, 1, {0,0}}, - Component{Inst::Sle, 1, {0,0}}, + Component{Inst::Add, 0, {0,0}, 0, {}}, + Component{Inst::Sub, 0, {0,0}, 0, {}}, + Component{Inst::Mul, 0, {0,0}, 0, {}}, + Component{Inst::UDiv, 0, {0,0}, 0, {}}, + Component{Inst::SDiv, 0, {0,0}, 0, {}}, + Component{Inst::UDivExact, 0, {0,0}, 0, {}}, + Component{Inst::SDivExact, 0, {0,0}, 0, {}}, + Component{Inst::URem, 0, {0,0}, 0, {}}, + Component{Inst::SRem, 0, {0,0}, 0, {}}, + Component{Inst::And, 0, {0,0}, 0, {}}, + Component{Inst::Or, 0, {0,0}, 0, {}}, + Component{Inst::Xor, 0, {0,0}, 0, {}}, + Component{Inst::Shl, 0, {0,0}, 0, {}}, + Component{Inst::LShr, 0, {0,0}, 0, {}}, + Component{Inst::LShrExact, 0, {0,0}, 0, {}}, + Component{Inst::AShr, 0, {0,0}, 0, {}}, + Component{Inst::AShrExact, 0, {0,0}, 0, {}}, + Component{Inst::Select, 0, {1,0,0}, 0, {}}, + Component{Inst::Eq, 1, {0,0}, 0, {}}, + Component{Inst::Ne, 1, {0,0}, 0, {}}, + Component{Inst::Ult, 1, {0,0}, 0, {}}, + Component{Inst::Slt, 1, {0,0}, 0, {}}, + Component{Inst::Ule, 1, {0,0}, 0, {}}, + Component{Inst::Sle, 1, {0,0}, 0, {}}, // - Component{Inst::CtPop, 0, {0}}, - Component{Inst::BSwap, 0, {0}}, - Component{Inst::Cttz, 0, {0}}, - Component{Inst::Ctlz, 0, {0}} + Component{Inst::CtPop, 0, {0}, 0, {}}, + Component{Inst::BSwap, 0, {0}, 0, {}}, + Component{Inst::Cttz, 0, {0}, 0, {}}, + Component{Inst::Ctlz, 0, {0}, 0, {}} }; class InstSynthesis { @@ -132,6 +135,7 @@ class InstSynthesis { const BlockPCs &BPCs, const std::vector &PCs, Inst *TargetLHS, Inst *&RHS, + const std::vector &LHSComps, InstContext &IC, unsigned Timeout); private: @@ -139,6 +143,7 @@ class InstSynthesis { SMTLIBSolver *LSMTSolver; const BlockPCs *LBPCs; const std::vector *LPCs; + const std::vector *LLHSComps; InstContext *LIC; unsigned LTimeout; @@ -291,6 +296,7 @@ class InstSynthesis { /// Helper functions void filterFixedWidthIntrinsicComps(); + Component getCompFromInst(Inst *); void getInputVars(Inst *I, std::vector &InputVars); std::string getLocVarStr(const LocVar &Loc, const std::string Prefix=""); LocVar getLocVarFromStr(const std::string &Str); @@ -318,7 +324,7 @@ class InstSynthesis { }; void findCands(Inst *Root, std::vector &Guesses, InstContext &IC, - int Max); + bool WidthMustMatch, bool FilterVars, int Max); Inst *getInstCopy(Inst *I, InstContext &IC, std::map &InstCache, diff --git a/lib/Extractor/Solver.cpp b/lib/Extractor/Solver.cpp index 33eecd868..f4d18a549 100644 --- a/lib/Extractor/Solver.cpp +++ b/lib/Extractor/Solver.cpp @@ -48,7 +48,7 @@ static cl::opt InferNop("souper-infer-nop", static cl::opt StressNop("souper-stress-nop", cl::desc("stress-test big queries in nop synthesis by always performing all of the small queries (slow!) (default=false)"), cl::init(false)); -static cl::optMaxNops("souper-max-nops", +static cl::optMaxCands("souper-max-cands", cl::desc("maximum number of values from the LHS to try to use as the RHS (default=20)"), cl::init(20)); static cl::opt InferInts("souper-infer-iN", @@ -145,7 +145,8 @@ class BaseSolver : public Solver { if (InferNop) { std::vector Guesses; - findCands(LHS, Guesses, IC, MaxNops); + findCands(LHS, Guesses, IC, /*WidthMustMatch=*/true, /*FilterVars=*/false, + MaxCands); Inst *Ante = IC.getConst(APInt(1, true)); BlockPCs BPCsCopy; @@ -206,8 +207,12 @@ class BaseSolver : public Solver { } if (InferInsts && SMTSolver->supportsModels()) { + std::vector LHSComps; + findCands(LHS, LHSComps, IC, /*WidthMustMatch=*/false, /*FilterVars=*/true, + MaxCands); InstSynthesis IS; - EC = IS.synthesize(SMTSolver.get(), BPCs, PCs, LHS, RHS, IC, Timeout); + EC = IS.synthesize(SMTSolver.get(), BPCs, PCs, LHS, RHS, + LHSComps, IC, Timeout); if (EC || RHS) return EC; } diff --git a/lib/Infer/InstSynthesis.cpp b/lib/Infer/InstSynthesis.cpp index 1ac03443e..cb0f2e860 100644 --- a/lib/Infer/InstSynthesis.cpp +++ b/lib/Infer/InstSynthesis.cpp @@ -59,6 +59,7 @@ std::error_code InstSynthesis::synthesize(SMTLIBSolver *SMTSolver, const BlockPCs &BPCs, const std::vector &PCs, Inst *TargetLHS, Inst *&RHS, + const std::vector &LHSComps, InstContext &IC, unsigned Timeout) { std::error_code EC; @@ -66,6 +67,7 @@ std::error_code InstSynthesis::synthesize(SMTLIBSolver *SMTSolver, LSMTSolver = SMTSolver; LBPCs = &BPCs; LPCs = &PCs; + LLHSComps = &LHSComps; LIC = &IC; LTimeout = Timeout; @@ -91,7 +93,7 @@ std::error_code InstSynthesis::synthesize(SMTLIBSolver *SMTSolver, if (DebugLevel > 0) { llvm::outs() << "; starting synthesis for LHS\n"; - PrintReplacementLHS(llvm::outs(), BPCs, PCs, LHS, Context); + PrintReplacementLHS(llvm::outs(), BPCs, PCs, LHS, Context, true); if (DebugLevel > 2) printInitInfo(); } @@ -322,7 +324,7 @@ void InstSynthesis::setCompLibrary() { for (auto KindStr : splitString(CmdUserCompKinds.c_str())) { Inst::Kind K = Inst::getKind(KindStr); if (KindStr == Inst::getKindName(Inst::Const)) // Special case - InitConstComps.push_back(Component{Inst::Const, 0, {}}); + InitConstComps.push_back(Component{Inst::Const, 0, {}, 0, {}}); else if (K == Inst::ZExt || K == Inst::SExt || K == Inst::Trunc) report_fatal_error("don't use zext/sext/trunc explicitly"); else if (K == Inst::None) @@ -338,13 +340,13 @@ void InstSynthesis::setCompLibrary() { InitComps.push_back(Comp); } else { InitComps = CompLibrary; - InitConstComps.push_back(Component{Inst::Const, 0, {}}); + InitConstComps.push_back(Component{Inst::Const, 0, {}, 0, {}}); } for (auto const &In : Inputs) { if (In->Width == DefaultWidth) continue; - Comps.push_back(Component{Inst::ZExt, DefaultWidth, {In->Width}}); - Comps.push_back(Component{Inst::SExt, DefaultWidth, {In->Width}}); + Comps.push_back(Component{Inst::ZExt, DefaultWidth, {In->Width}, 0, {}}); + Comps.push_back(Component{Inst::SExt, DefaultWidth, {In->Width}, 0, {}}); } // Second, for each input/constant create a component of DefaultWidth for (auto &Comp : InitComps) { @@ -362,7 +364,11 @@ void InstSynthesis::setCompLibrary() { } // Third, create one trunc comp to match the output width if necessary if (LHS->Width < DefaultWidth) - Comps.push_back(Component{Inst::Trunc, LHS->Width, {DefaultWidth}}); + Comps.push_back(Component{Inst::Trunc, LHS->Width, {DefaultWidth}, 0, {}}); + // Finally, add LHS components (if provided) directly to Comps, + // their widths are already initialized. + for (auto I : *LLHSComps) + Comps.push_back(getCompFromInst(I)); } void InstSynthesis::initInputVars(InstContext &IC) { @@ -438,10 +444,11 @@ void InstSynthesis::filterFixedWidthIntrinsicComps() { void InstSynthesis::initComponents(InstContext &IC) { for (unsigned J = 0; J < Comps.size(); ++J) { - auto const &Comp = Comps[J]; + auto &Comp = Comps[J]; std::string LocVarStr; // First, init component inputs std::vector CompOps; + std::map OpsReplacements; std::vector OpsLocVar; for (unsigned K = 0; K < Comp.OpWidths.size(); ++K) { LocVar In = std::make_pair(J+1, K+1); @@ -464,6 +471,11 @@ void InstSynthesis::initComponents(InstContext &IC) { CompInstMap[In] = OpInst; CompOps.push_back(OpInst); OpsLocVar.push_back(In); + // Update OpsReplacements + if (Comp.Origin) { + assert(Comp.OriginOps.size()); + OpsReplacements.insert(std::make_pair(Comp.OriginOps[K], OpInst)); + } } // Store all input locations CompOpLocVars.push_back(OpsLocVar); @@ -479,13 +491,23 @@ void InstSynthesis::initComponents(InstContext &IC) { // Third, instantiate the component (aka Inst) assert(Comp.Width && "comp width not set"); Inst *CompInst; - if (Comp.Kind == Inst::Select) { - Inst *C = IC.getInst(Inst::Trunc, 1, {CompOps[0]}); - CompInst = IC.getInst(Comp.Kind, Comp.Width, {C, CompOps[1], CompOps[2]}); - } else { - CompInst = IC.getInst(Comp.Kind, Comp.Width, CompOps); + if (Comp.Origin) { + assert(Comp.OriginOps.size() == CompOps.size()); + CompInst = replaceVars(Comp.Origin, *LIC, OpsReplacements); if (Comp.Width < DefaultWidth && Comp.Kind != Inst::Trunc) CompInst = IC.getInst(Inst::ZExt, DefaultWidth, {CompInst}); + // Update LHS component + Comp.Origin = CompInst; + Comp.OriginOps = CompOps; + } else { + if (Comp.Kind == Inst::Select) { + Inst *C = IC.getInst(Inst::Trunc, 1, {CompOps[0]}); + CompInst = IC.getInst(Comp.Kind, Comp.Width, {C, CompOps[1], CompOps[2]}); + } else { + CompInst = IC.getInst(Comp.Kind, Comp.Width, CompOps); + if (Comp.Width < DefaultWidth && Comp.Kind != Inst::Trunc) + CompInst = IC.getInst(Inst::ZExt, DefaultWidth, {CompInst}); + } } // Update CompInstMap map with concrete Inst CompInstMap[Out] = CompInst; @@ -517,12 +539,14 @@ void InstSynthesis::printInitInfo() { llvm::outs() << "N: " << N << ", M: " << M << "\n"; llvm::outs() << "default width: " << DefaultWidth << "\n"; llvm::outs() << "output width: " << LHS->Width << "\n"; - llvm::outs() << "component library: "; + llvm::outs() << "component library: " << Comps.size() << "\n"; for (auto const &Comp : Comps) { llvm::outs() << Inst::getKindName(Comp.Kind) << " (" << Comp.Width << ", { "; for (auto const &Width : Comp.OpWidths) llvm::outs() << Width << " "; - llvm::outs() << "}); "; + llvm::outs() << "})\n"; + if (Comp.Origin) + PrintReplacementRHS(llvm::outs(), Comp.Origin, Context, true); } if (Comps.size()) llvm::outs() << "\n"; @@ -980,19 +1004,35 @@ Inst *InstSynthesis::createInstFromWiring( llvm::outs() << "- creating inst " << Inst::getKindName(Comp.Kind) << ", width " << Comp.Width << "\n"; llvm::outs() << "before junk removal:\n"; - PrintReplacementRHS(llvm::outs(), IC.getInst(Comp.Kind, Comp.Width, Ops), - Context); + if (Comp.Origin) + PrintReplacementRHS(llvm::outs(), Comp.Origin, Context); + else + PrintReplacementRHS(llvm::outs(), IC.getInst(Comp.Kind, Comp.Width, Ops), + Context); } // Sanity checks if (Ops.size() == 2 && Ops[0]->K == Inst::Const && Ops[1]->K == Inst::Const) report_fatal_error("inst operands are constants!"); assert(Comp.Width == 1 || Comp.Width == DefaultWidth || Comp.Width == LHS->Width); - // Create instruction - if (Comp.Kind == Inst::Select) { + // Instruction is a LHS component + if (Comp.Origin) { + assert(Comp.OriginOps.size() == Ops.size()); + std::map OpsReplacements; + for (unsigned J = 0; J < Ops.size(); ++J) + OpsReplacements.insert(std::make_pair(Comp.OriginOps[J], Ops[J])); + Inst *Copy = replaceVars(Comp.Origin, *LIC, OpsReplacements); + // Update ops + Ops = Copy->Ops; + } + // Create instruction from a component + if (Comp.Kind == Inst::Phi) { + assert(Comp.Origin && "Phi support for LHS components only"); + return IC.getPhi(Comp.Origin->B, Ops); + } else if (Comp.Kind == Inst::Select) { Ops[0] = IC.getInst(Inst::Trunc, 1, {Ops[0]}); return createCleanInst(Comp.Kind, Comp.Width, Ops, IC); - } if (Comp.Width < DefaultWidth && Comp.Kind != Inst::Trunc) { + } else if (Comp.Width < DefaultWidth && Comp.Kind != Inst::Trunc) { Inst *Ret = createCleanInst(Comp.Kind, Comp.Width, Ops, IC); return IC.getInst(Inst::ZExt, DefaultWidth, {Ret}); } else @@ -1214,6 +1254,18 @@ Inst *InstSynthesis::createCleanInst(Inst::Kind Kind, unsigned Width, return IC.getInst(Kind, Width, Ops); } +Component InstSynthesis::getCompFromInst(Inst *I) { + std::vector IV; + getInputVars(I, IV); + sort(IV.begin(), IV.end()); + IV.erase(unique(IV.begin(), IV.end()), IV.end()); + std::vector OpWidths; + for (auto In : IV) + OpWidths.push_back(In->Width); + + return Component{I->K, I->Width, OpWidths, I, IV}; +} + void InstSynthesis::getInputVars(Inst *I, std::vector &InputVars) { if (I->K == Inst::Var) InputVars.push_back(I); @@ -1456,7 +1508,7 @@ void InstSynthesis::constrainConstWiring(const Inst *Cand, } void findCands(Inst *Root, std::vector &Guesses, InstContext &IC, - int Max) { + bool WidthMustMatch, bool FilterVars, int Max) { // breadth-first search std::set Visited; std::queue> Q; @@ -1472,19 +1524,16 @@ void findCands(Inst *Root, std::vector &Guesses, InstContext &IC, for (auto Op : I->Ops) Q.push(std::make_tuple(Op, Benefit)); } - if (Benefit > 1 && I->Width == Root->Width && I->Available) + if (Benefit > 1 && I->Available && I->K != Inst::Const + && I->K != Inst::UntypedConst) { + if (WidthMustMatch && I->Width != Root->Width) + continue; + if (FilterVars && I->K == Inst::Var) + continue; Guesses.emplace_back(I); - // TODO: run experiments and see if it's worth doing these - if (0) { - if (Benefit > 2 && I->Width > Root->Width) - Guesses.emplace_back(IC.getInst(Inst::Trunc, Root->Width, {I})); - if (Benefit > 2 && I->Width < Root->Width) { - Guesses.emplace_back(IC.getInst(Inst::SExt, Root->Width, {I})); - Guesses.emplace_back(IC.getInst(Inst::ZExt, Root->Width, {I})); - } + if (Guesses.size() >= Max) + return; } - if (Guesses.size() >= Max) - return; } } } diff --git a/test/Infer/four-adds.opt b/test/Infer/four-adds.opt new file mode 100644 index 000000000..6bdb7b686 --- /dev/null +++ b/test/Infer/four-adds.opt @@ -0,0 +1,14 @@ +; REQUIRES: solver, solver-model + +; -souper-synthesis-comps=const is just a hack to avoid the initialization of the whole component library +; RUN: %souper-check %solver -infer-rhs -souper-infer-inst -souper-synthesis-comps=const -souper-synthesis-ignore-cost %s > %t1 +; RUN: %FileCheck %s < %t1 + +; CHECK: result %4 + +%0:i32 = var +%1:i32 = add 1:i32, %0 +%2:i32 = add 1:i32, %1 +%3:i32 = add 1:i32, %2 +%4:i32 = add 1:i32, %3 +infer %4 diff --git a/test/Infer/lhs-phi-comp.opt b/test/Infer/lhs-phi-comp.opt new file mode 100644 index 000000000..443313252 --- /dev/null +++ b/test/Infer/lhs-phi-comp.opt @@ -0,0 +1,22 @@ +; REQUIRES: solver, solver-model + +; RUN: %souper-check %solver -infer-rhs -souper-infer-nop -souper-infer-inst -souper-synthesis-comps=and,or,xor %s > %t1 +; RUN: %FileCheck %s < %t1 + +; CHECK: %10:i64 = xor %71, %72 +; CHECK-NEXT: result %10 + +%71:i64 = var +%72in:i64 = var +%1 = block 2 +%junk:i64 = var +%72 = phi %1, %72in, %junk +%200 = xor %71, -1 +%201 = xor %72, -1 +%202 = and %71, %201 +%203 = and %72, %200 +%204 = or %202, %203 +infer %204 + +; %205 = xor %71, %72 +; result %205