Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion deps/Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[deps]
ArgParse = "c7e460c6-2fb9-53a9-8c5b-16f535851c63"
CMake_jll = "3f4e10e2-61f2-5801-8945-23b9d642d0e6"
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Scratch = "6c6a2e73-6563-6170-7368-637461726353"
106 changes: 71 additions & 35 deletions deps/build_local.jl
Original file line number Diff line number Diff line change
@@ -1,57 +1,93 @@
# build a local version of mlir-jl-tblgen

using Pkg
Pkg.activate(@__DIR__)
Pkg.instantiate()
using Pkg, Scratch, Preferences, CMake_jll, ArgParse

if haskey(ENV, "GITHUB_ACTIONS")
println(
"::warning ::Using a locally-built mlir-jl-tblgen; A bump of mlir_jl_tblgen_jll will be required before releasing MLIR.jl.",
)
end

using Pkg, Scratch, Preferences, Libdl, CMake_jll

MLIR = Base.UUID("bfde9dd4-8f40-4a1e-be09-1475335e1c92")

# get scratch directories
scratch_dir = get_scratch!(MLIR, "build")
isdir(scratch_dir) && rm(scratch_dir; recursive=true)
s = ArgParseSettings()
#! format: off
@add_arg_table s begin
"--build-dir"
help = "Build directory."
arg_type = String
default = mktempdir()
"--install-dir"
help = "Scratch directory for installation."
arg_type = String
default = get_scratch!(MLIR, "build")
"--llvm-version"
help = "Target LLVM/MLIR version."
arg_type = String
default = "$(Base.libllvm_version.major).$(Base.libllvm_version.minor)"
"--llvm-assertions"
help = "Build LLVM/MLIR with assertions enabled."
arg_type = Bool
default = try
cglobal((:_ZN4llvm24DisableABIBreakingChecksE, Base.libllvm_path()), Cvoid)
false
catch
true
end
"--llvm-dir"
help = "Path to LLVM installation"
arg_type = String
"--debug"
help = "Build with debug symbols."
action = :store_true
end
#! format: on

parsed_args = parse_args(ARGS, s)

println("Parsed args:")
for (k, v) in parsed_args
println(" $k => $v")
end
println()

source_dir = joinpath(@__DIR__, "tblgen")
llvm_version = VersionNumber(parsed_args["llvm-version"])
llvm_assertions = parsed_args["llvm-assertions"]
debug = parsed_args["debug"]

# get build directory
build_dir = if isempty(ARGS)
mktempdir()
else
ARGS[1]
install_dir = parsed_args["install-dir"]
if isdir(install_dir)
println("Removing existing installation at $install_dir")
rm(install_dir; recursive=true)
end

build_dir = parsed_args["build-dir"]
mkpath(build_dir)

# download LLVM
Pkg.activate(; temp=true)
llvm_assertions = try
cglobal((:_ZN4llvm24DisableABIBreakingChecksE, Base.libllvm_path()), Cvoid)
false
catch
true
end
llvm_pkg_version = "$(Base.libllvm_version.major).$(Base.libllvm_version.minor)"
LLVM = if llvm_assertions
Pkg.add(; name="LLVM_full_assert_jll", version=llvm_pkg_version)
using LLVM_full_assert_jll
LLVM_full_assert_jll
# download LLVM and MLIR artifacts if required
llvm_dir = if !isnothing(parsed_args["llvm-dir"])
parsed_args["llvm-dir"]
else
Pkg.add(; name="LLVM_full_jll", version=llvm_pkg_version)
using LLVM_full_jll
LLVM_full_jll
Pkg.activate(; temp=true)

LLVM = if llvm_assertions
Pkg.add(; name="LLVM_full_assert_jll", version=llvm_version)
Base.require(Core.Module(:LLVM_full_assert_jll), :LLVM_full_assert_jll)
else
Pkg.add(; name="LLVM_full_jll", version=llvm_version)
Base.require(Core.Module(:LLVM_full_jll), :LLVM_full_jll)
end
LLVM.artifact_dir
end
LLVM_DIR = joinpath(LLVM.artifact_dir, "lib", "cmake", "llvm")
MLIR_DIR = joinpath(LLVM.artifact_dir, "lib", "cmake", "mlir")

LLVM_DIR = joinpath(llvm_dir, "lib", "cmake", "llvm")
MLIR_DIR = joinpath(llvm_dir, "lib", "cmake", "mlir")

# build and install
@info "Building" source_dir scratch_dir build_dir LLVM_DIR MLIR_DIR
@info "Building" source_dir install_dir build_dir LLVM_DIR MLIR_DIR
cmake() do cmake_path
config_opts = `-DLLVM_ROOT=$(LLVM_DIR) -DMLIR_ROOT=$(MLIR_DIR) -DCMAKE_INSTALL_PREFIX=$(scratch_dir)`
build_type = debug ? "Debug" : "Release"
config_opts = `-DLLVM_ROOT=$(LLVM_DIR) -DMLIR_ROOT=$(MLIR_DIR) -DCMAKE_INSTALL_PREFIX=$(install_dir)`
if Sys.iswindows()
# prevent picking up MSVC
config_opts = `$config_opts -G "MSYS Makefiles"`
Expand All @@ -60,7 +96,7 @@ cmake() do cmake_path
run(`$cmake_path --build $(build_dir) --target install`)
end

bin_path = joinpath(scratch_dir, "bin", only(readdir(joinpath(scratch_dir, "bin"))))
bin_path = joinpath(install_dir, "bin", only(readdir(joinpath(install_dir, "bin"))))
isfile(bin_path) || error("Could not find executable $bin_path in build directory")

# tell LLVM.jl to load our executable instead of the default artifact one
Expand Down
38 changes: 14 additions & 24 deletions deps/tblgen/jl-generators.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Config/llvm-config.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatAdapters.h"
#include "llvm/Support/FormatCommon.h"
Expand All @@ -48,14 +49,6 @@

namespace
{

llvm::cl::opt<bool> ExplainMissing(
"explain-missing",
llvm::cl::desc("Print the reason for skipping operations from output"));
llvm::cl::opt<std::string> DialectName(
"dialect-name", llvm::cl::desc("Override the inferred dialect name, used as the name for the generated Julia module."),
llvm::cl::value_desc("dialect"));

using namespace mlir;
using namespace mlir::tblgen;

Expand Down Expand Up @@ -126,13 +119,7 @@ namespace
return mlir::tblgen::Operator(op).getDialectName() ==
any_op.getDialectName();
}));
std::string dialect_name;
if (DialectName.empty()) {
dialect_name = any_op.getDialectName().str();
} else {
dialect_name = DialectName;
}
return dialect_name;
return any_op.getDialectName().str();
}

std::string sanitizeName(std::string name, std::optional<std::string> modulename = std::nullopt) {
Expand Down Expand Up @@ -161,18 +148,21 @@ namespace

} // namespace

extern bool disableModuleWrap;
extern bool isExternal;

bool emitOpTableDefs(const llvm::RecordKeeper &recordKeeper,
llvm::raw_ostream &os)
bool emitOpTableDefs(llvm::raw_ostream &os, const llvm::RecordKeeper &recordKeeper, bool disableModuleWrap, bool isExternal, std::optional<std::string> dialectName)
{

#if LLVM_VERSION_MAJOR >= 16
std::vector<llvm::Record *> opdefs = recordKeeper.getAllDerivedDefinitionsIfDefined("Op");
auto _opdefs = recordKeeper.getAllDerivedDefinitionsIfDefined("Op");
#else
std::vector<llvm::Record *> opdefs = recordKeeper.getAllDerivedDefinitions("Op");
auto _opdefs = recordKeeper.getAllDerivedDefinitions("Op");
#endif

// LLVM 20 changed the return type to `ArrayRef<const Record*>` and the const cast away was giving me headaches
llvm::ArrayRef<llvm::Record*> opdefs(
const_cast<llvm::Record* const*>(_opdefs.data()),
_opdefs.size()
);

const char *imports;
if (isExternal)
{
Expand Down Expand Up @@ -230,9 +220,9 @@ end
std::string modulecontents = "";

std::string modulename;
if (!DialectName.empty())
if (!dialectName.has_value())
{
modulename = DialectName;
modulename = dialectName.value();
} else {
modulename = getDialectName(opdefs);
}
Expand Down
85 changes: 53 additions & 32 deletions deps/tblgen/mlir-jl-tblgen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include <optional>

#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
Expand All @@ -26,45 +28,64 @@

using namespace llvm;

using generator_function = bool(const llvm::RecordKeeper& recordKeeper,
llvm::raw_ostream& os);

struct GeneratorInfo {
const char* name;
generator_function* generator;
enum ActionType {
EmitOpTableDefs,
};

extern generator_function emitOpTableDefs;
extern generator_function emitTestTableDefs;
// defined in jl-generators.cc
extern bool emitOpTableDefs(llvm::raw_ostream &os, const llvm::RecordKeeper &recordKeeper, bool disableModuleWrap, bool isExternal, std::optional<std::string> dialectName);

static std::array<GeneratorInfo, 1> generators {{
{"jl-op-defs", emitOpTableDefs},
}};
cl::opt<ActionType> generator(
"generator",
cl::desc("Generator to run"),
cl::values(clEnumValN(EmitOpTableDefs, "emit-op-table-defs",
"Emit Julia definitions for MLIR operations")),
cl::Required
);

generator_function* generator;
bool disableModuleWrap;
bool isExternal;
cl::opt<bool> disableModuleWrap(
"disable-module-wrap",
cl::desc("Disable module wrap"),
cl::init(false)
);

int main(int argc, char **argv) {
llvm::InitLLVM y(argc, argv);
llvm::cl::opt<std::string> generatorOpt("generator", llvm::cl::desc("Generator to run"), cl::Required);
llvm::cl::opt<bool> disableModuleWrapOpt("disable-module-wrap", llvm::cl::desc("Disable module wrap"), cl::init(false));
llvm::cl::opt<bool> isExternalOpt("external", llvm::cl::desc("Mark the dialect as external and generate bindings accordingly"), cl::init(false));
cl::ParseCommandLineOptions(argc, argv);
for (const auto& spec : generators) {
if (generatorOpt == spec.name) {
generator = spec.generator;
cl::opt<bool> isExternal(
"external",
cl::desc("Mark the dialect as external and generate bindings accordingly"),
cl::init(false)
);

cl::opt<std::string> dialectName(
"dialect-name",
llvm::cl::desc("Override the inferred dialect name, used as the name for the generated Julia module."),
llvm::cl::value_desc("dialect")
);

#if LLVM_VERSION_MAJOR < 20
static bool MlirJuliaTablegenMain(llvm::raw_ostream &os, llvm::RecordKeeper &recordKeeper) {
#else
static bool MlirJuliaTablegenMain(llvm::raw_ostream &os, const llvm::RecordKeeper &recordKeeper) {
#endif
switch (generator) {
case EmitOpTableDefs: {
std::optional<std::string> dialectNameOpt;
if (!dialectName.empty()) dialectNameOpt = dialectName;
return emitOpTableDefs(os, recordKeeper, disableModuleWrap, isExternal, dialectNameOpt);
break;
}
default:
llvm::errs() << "Invalid generator type\n";
return true;
}
if (!generator) {
llvm::errs() << "Invalid generator type\n";
abort();
}
disableModuleWrap = disableModuleWrapOpt;
isExternal = isExternalOpt;
}

return TableGenMain(argv[0], [](raw_ostream& os, RecordKeeper &records) {
return generator(records, os);
});
int main(int argc, char **argv) {
llvm::InitLLVM y(argc, argv);
cl::ParseCommandLineOptions(argc, argv);
#if LLVM_VERSION_MAJOR < 20
return TableGenMain(argv[0], MlirJuliaTablegenMain);
#else
std::function<bool(llvm::raw_ostream&, const llvm::RecordKeeper&)> mainFn = MlirJuliaTablegenMain;
return TableGenMain(argv[0], mainFn);
#endif
}
Loading