Skip to content

Commit 677840c

Browse files
author
Simon Klix
committed
split up the python bindings and fixed minor bug for gate label creation for empty gate lists
1 parent f1352b5 commit 677840c

11 files changed

Lines changed: 57 additions & 3256 deletions

File tree

plugins/machine_learning/include/machine_learning/labels/subgraph_label.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ namespace hal
4444
* @param[in] ctx - The machine learning context.
4545
* @returns A vector of label vectors for each pair on success, an error otherwise.
4646
*/
47-
virtual Result<std::vector<std::vector<u32>>> calculate_labels(Context& ctx, const std::vector<Gate*>& subgraphs) const = 0;
47+
virtual Result<std::vector<std::vector<u32>>> calculate_labels(Context& ctx, const std::vector<std::vector<Gate*>>& subgraphs) const = 0;
4848

4949
virtual std::string to_string() const = 0;
5050
};
@@ -74,7 +74,7 @@ namespace hal
7474
* @param[in] ctx - The machine learning context.
7575
* @returns A vector of label vectors for each pair on success, an error otherwise.
7676
*/
77-
Result<std::vector<std::vector<u32>>> calculate_labels(Context& ctx, const std::vector<Gate*>& subgraphs) const override;
77+
Result<std::vector<std::vector<u32>>> calculate_labels(Context& ctx, const std::vector<std::vector<Gate*>>& subgraphs) const override;
7878

7979
/**
8080
* @brief Helper function to annotate contained components to a falttened subgraph of gates in front of a register by analyzing a unflattened twin netlist
@@ -130,7 +130,7 @@ namespace hal
130130
u32 label_size = m_key_words.size() + 1;
131131

132132
std::vector<u32> v(label_size, 0);
133-
for (const auto [m_idx, m_count] : matches)
133+
for (const auto& [m_idx, m_count] : matches)
134134
{
135135
if (m_idx >= m_key_words.size())
136136
{
@@ -168,7 +168,7 @@ namespace hal
168168
* @param[in] ctx - The machine learning context.
169169
* @returns A vector of label vectors for each pair on success, an error otherwise.
170170
*/
171-
Result<std::vector<std::vector<u32>>> calculate_labels(Context& ctx, const std::vector<Gate*>& subgraphs) const override;
171+
Result<std::vector<std::vector<u32>>> calculate_labels(Context& ctx, const std::vector<std::vector<Gate*>>& subgraphs) const override;
172172

173173
/**
174174
* @brief Helper function to annotate contained components to top module of netlöist by reading from netlist metadata at a path
@@ -224,7 +224,7 @@ namespace hal
224224
u32 label_size = m_key_words.size() + 1;
225225

226226
std::vector<u32> v(label_size, 0);
227-
for (const auto [m_idx, m_count] : matches)
227+
for (const auto& [m_idx, m_count] : matches)
228228
{
229229
if (m_idx >= m_key_words.size())
230230
{
@@ -238,4 +238,4 @@ namespace hal
238238
};
239239
} // namespace subgraph_label
240240
} // namespace machine_learning
241-
} // namespace hal
241+
} // namespace hal

plugins/machine_learning/python/bindings/edge_feature.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ namespace hal
88
{
99
void bind_edge_features(py::module& m, py::module& py_edge_feature)
1010
{
11+
UNUSED(m);
12+
1113
py::class_<machine_learning::edge_feature::EdgeFeature, RawPtrWrapper<machine_learning::edge_feature::EdgeFeature>> py_edge_feature_class(py_edge_feature,
1214
"EdgeFeature",
1315
R"(

plugins/machine_learning/python/bindings/gate_label.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ namespace hal
88
{
99
void bind_gate_labels(py::module& m, py::module& py_gate_label)
1010
{
11+
UNUSED(m);
12+
1113
py::class_<machine_learning::gate_label::GateLabel, std::shared_ptr<machine_learning::gate_label::GateLabel>> py_gate_label_class(py_gate_label, "GateLabel", R"(
1214
Base class for calculating labels for machine learning models.
1315

plugins/machine_learning/python/bindings/gate_pair_feature.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ namespace hal
88
{
99
void bind_gate_pair_features(py::module& m, py::module& py_gate_pair_feature)
1010
{
11+
UNUSED(m);
12+
1113
// machine_learning::features::gate_pair_feature
1214
py::class_<machine_learning::gate_pair_feature::GatePairFeature, std::shared_ptr<machine_learning::gate_pair_feature::GatePairFeature>> py_gate_pair_feature_class(
1315
py_gate_pair_feature, "GatePairFeature", R"(

plugins/machine_learning/python/bindings/gate_pair_label.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ namespace hal
88
{
99
void bind_gate_pair_labels(py::module& m, py::module& py_gate_pair_label)
1010
{
11+
UNUSED(m);
12+
1113
py::class_<machine_learning::gate_pair_label::GatePairLabel> py_gate_pair_label_class(py_gate_pair_label,
1214
"GatePairLabel",
1315
R"(

plugins/machine_learning/python/bindings/core.cpp renamed to plugins/machine_learning/python/bindings/plugin.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
#include "bindings/register.h"
1+
#include "register.h"
22

33
namespace hal
44
{
55
namespace machine_learning
66
{
77
namespace python
88
{
9-
void bind_core(py::module& m)
9+
void bind_plugin(py::module& m)
1010
{
1111
py::class_<MachineLearningPlugin, RawPtrWrapper<MachineLearningPlugin>, BasePluginInterface> py_machine_learning_plugin(
1212
m, "MachineLearningPlugin", R"(Provides machine learning functionality as a plugin within the HAL framework.)");

plugins/machine_learning/python/bindings/register.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ namespace hal
88
{
99
namespace python
1010
{
11-
void bind_core(py::module& m);
11+
void bind_plugin(py::module& m);
1212
void bind_gate_features(py::module& m, py::module& py_gate_feature);
1313
void bind_gate_pair_features(py::module& m, py::module& py_gate_pair_feature);
1414
void bind_gate_pair_labels(py::module& m, py::module& py_gate_pair_label);

plugins/machine_learning/python/bindings/subgraph_label.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ namespace hal
88
{
99
void bind_subgraph_labels(py::module& m, py::module& py_subgraph_label)
1010
{
11+
UNUSED(m);
12+
1113
py::class_<machine_learning::subgraph_label::SubgraphLabel, std::shared_ptr<machine_learning::subgraph_label::SubgraphLabel>> py_subgraph_label_class(py_subgraph_label, "SubgraphLabel", R"(
1214
Base class for calculating labels for machine learning models.
1315
@@ -39,7 +41,7 @@ namespace hal
3941

4042
py_subgraph_label_class.def(
4143
"calculate_labels",
42-
[](const machine_learning::subgraph_label::SubgraphLabel& self, machine_learning::Context& ctx, const std::vector<Gate*>& subgraphs) -> std::optional<std::vector<std::vector<u32>>> {
44+
[](const machine_learning::subgraph_label::SubgraphLabel& self, machine_learning::Context& ctx, const std::vector<std::vector<Gate*>>& subgraphs) -> std::optional<std::vector<std::vector<u32>>> {
4345
auto res = self.calculate_labels(ctx, subgraphs);
4446
if (res.is_ok())
4547
{
@@ -111,7 +113,7 @@ Construct a ContainedComponents labeler.
111113

112114
py_contained_components.def(
113115
"calculate_labels",
114-
[](const machine_learning::subgraph_label::SubgraphLabel& self, machine_learning::Context& ctx, const std::vector<Gate*>& subgraphs) -> std::optional<std::vector<std::vector<u32>>> {
116+
[](const machine_learning::subgraph_label::SubgraphLabel& self, machine_learning::Context& ctx, const std::vector<std::vector<Gate*>>& subgraphs) -> std::optional<std::vector<std::vector<u32>>> {
115117
auto res = self.calculate_labels(ctx, subgraphs);
116118
if (res.is_ok())
117119
{
@@ -212,7 +214,7 @@ Construct a ContainedComponentsNetlist labeler.
212214

213215
py_contained_components_netlist.def(
214216
"calculate_labels",
215-
[](const machine_learning::subgraph_label::SubgraphLabel& self, machine_learning::Context& ctx, const std::vector<Gate*>& subgraphs) -> std::optional<std::vector<std::vector<u32>>> {
217+
[](const machine_learning::subgraph_label::SubgraphLabel& self, machine_learning::Context& ctx, const std::vector<std::vector<Gate*>>& subgraphs) -> std::optional<std::vector<std::vector<u32>>> {
216218
auto res = self.calculate_labels(ctx, subgraphs);
217219
if (res.is_ok())
218220
{

0 commit comments

Comments
 (0)