diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerCostbased.java b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerCostbased.java index 1f9abb4c180..3c33d783ab1 100644 --- a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerCostbased.java +++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerCostbased.java @@ -170,7 +170,7 @@ private ArrayList rewriteDefaultStatementBlock(DMLProgram prog, selectFederatedExecutionPlan(sbHop, paramMap); if(sbHop instanceof FunctionOp) { String funcName = ((FunctionOp) sbHop).getFunctionName(); - Map funcParamMap = getParamMap((FunctionOp) sbHop); + Map funcParamMap = FederatedPlannerUtils.getParamMap((FunctionOp) sbHop); if ( paramMap != null && funcParamMap != null) funcParamMap.putAll(paramMap); paramMap = funcParamMap; @@ -182,22 +182,6 @@ private ArrayList rewriteDefaultStatementBlock(DMLProgram prog, return new ArrayList<>(Collections.singletonList(sb)); } - /** - * Return parameter map containing the mapping from parameter name to input hop - * for all parameters of the function hop. - * @param funcOp hop for which the mapping of parameter names to input hops are made - * @return parameter map or empty map if function has no parameters - */ - private Map getParamMap(FunctionOp funcOp){ - String[] inputNames = funcOp.getInputVariableNames(); - Map paramMap = new HashMap<>(); - if ( inputNames != null ){ - for ( int i = 0; i < funcOp.getInput().size(); i++ ) - paramMap.put(inputNames[i],funcOp.getInput(i)); - } - return paramMap; - } - /** * Set final fedouts of all hops starting from terminal hops. 
*/ @@ -327,13 +311,21 @@ private void visitFedPlanHop(Hop currentHop, Map paramMap) { ArrayList hopRels = getFedPlans(currentHop, paramMap); // Put NONE HopRel into memo table if no FOUT or LOUT HopRels were added if(hopRels.isEmpty()) - hopRels.add(getNONEHopRel(currentHop)); + hopRels.add(getNONEHopRel(currentHop, paramMap)); addTrace(hopRels); hopRelMemo.put(currentHop, hopRels); } - private HopRel getNONEHopRel(Hop currentHop){ - HopRel noneHopRel = new HopRel(currentHop, FederatedOutput.NONE, hopRelMemo); + private ArrayList getHopInputs(Hop currentHop, Map paramMap){ + if ( HopRewriteUtils.isData(currentHop, Types.OpOpData.TRANSIENTREAD) ) + return FederatedPlannerUtils.getTransientInputs(currentHop, paramMap, transientWrites); + else + return currentHop.getInput(); + } + + private HopRel getNONEHopRel(Hop currentHop, Map paramMap){ + ArrayList inputs = getHopInputs(currentHop, paramMap); + HopRel noneHopRel = new HopRel(currentHop, FederatedOutput.NONE, hopRelMemo, inputs); FType[] inputFType = noneHopRel.getInputDependency().stream().map(HopRel::getFType).toArray(FType[]::new); FType outputFType = getFederatedOut(currentHop, inputFType); noneHopRel.setFType(outputFType); @@ -348,9 +340,7 @@ private HopRel getNONEHopRel(Hop currentHop){ */ private ArrayList getFedPlans(Hop currentHop, Map paramMap){ ArrayList hopRels = new ArrayList<>(); - ArrayList inputHops = currentHop.getInput(); - if ( HopRewriteUtils.isData(currentHop, Types.OpOpData.TRANSIENTREAD) ) - inputHops = getTransientInputs(currentHop, paramMap); + ArrayList inputHops = getHopInputs(currentHop, paramMap); if ( HopRewriteUtils.isData(currentHop, Types.OpOpData.TRANSIENTWRITE) ) transientWrites.put(currentHop.getName(), currentHop); if ( HopRewriteUtils.isData(currentHop, Types.OpOpData.FEDERATED) ) @@ -453,6 +443,8 @@ private void updateExplain(){ private void debugLog(Hop currentHop){ if ( LOG.isDebugEnabled() ){ LOG.debug("Visiting HOP: " + currentHop + " Input size: " + 
currentHop.getInput().size()); + if (currentHop.getPrivacy() != null) + LOG.debug(currentHop.getPrivacy()); int index = 0; for ( Hop hop : currentHop.getInput()){ if ( hop == null ) diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerUtils.java b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerUtils.java new file mode 100644 index 00000000000..45b711a41d6 --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerUtils.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.fedplanner; + +import org.apache.sysds.hops.FunctionOp; +import org.apache.sysds.hops.Hop; +import org.apache.sysds.runtime.DMLRuntimeException; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +public class FederatedPlannerUtils { + /** + * Get transient inputs from either paramMap or transientWrites. + * Inputs from paramMap has higher priority than inputs from transientWrites. 
+ * @param currentHop hop for which inputs are read from maps + * @param paramMap of local parameters + * @param transientWrites map of transient writes + * @return inputs of currentHop + */ + public static ArrayList getTransientInputs(Hop currentHop, Map paramMap, Map transientWrites){ + Hop tWriteHop = null; + if ( paramMap != null) + tWriteHop = paramMap.get(currentHop.getName()); + if ( tWriteHop == null ) + tWriteHop = transientWrites.get(currentHop.getName()); + if ( tWriteHop == null ) + throw new DMLRuntimeException("Transient write not found for " + currentHop); + else + return new ArrayList<>(Collections.singletonList(tWriteHop)); + } + + /** + * Return parameter map containing the mapping from parameter name to input hop + * for all parameters of the function hop. + * @param funcOp hop for which the mapping of parameter names to input hops are made + * @return parameter map or empty map if function has no parameters + */ + public static Map getParamMap(FunctionOp funcOp){ + String[] inputNames = funcOp.getInputVariableNames(); + Map paramMap = new HashMap<>(); + if ( inputNames != null ){ + for ( int i = 0; i < funcOp.getInput().size(); i++ ) + paramMap.put(inputNames[i],funcOp.getInput(i)); + } + return paramMap; + } +} diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/PrivacyConstraintLoader.java b/src/main/java/org/apache/sysds/hops/fedplanner/PrivacyConstraintLoader.java new file mode 100644 index 00000000000..82e43169887 --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/fedplanner/PrivacyConstraintLoader.java @@ -0,0 +1,281 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.fedplanner; + +import org.apache.commons.lang3.tuple.Pair; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.sysds.api.DMLException; +import org.apache.sysds.common.Types; +import org.apache.sysds.hops.FunctionOp; +import org.apache.sysds.hops.Hop; +import org.apache.sysds.hops.LiteralOp; +import org.apache.sysds.hops.rewrite.HopRewriteUtils; +import org.apache.sysds.parser.DMLProgram; +import org.apache.sysds.parser.DataExpression; +import org.apache.sysds.parser.ForStatement; +import org.apache.sysds.parser.ForStatementBlock; +import org.apache.sysds.parser.FunctionStatement; +import org.apache.sysds.parser.FunctionStatementBlock; +import org.apache.sysds.parser.IfStatement; +import org.apache.sysds.parser.IfStatementBlock; +import org.apache.sysds.parser.Statement; +import org.apache.sysds.parser.StatementBlock; +import org.apache.sysds.parser.WhileStatement; +import org.apache.sysds.parser.WhileStatementBlock; +import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysds.runtime.controlprogram.federated.FederatedData; +import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest; +import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse; +import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF; +import org.apache.sysds.runtime.controlprogram.federated.FederatedWorkerHandlerException; +import 
org.apache.sysds.runtime.instructions.cp.Data; +import org.apache.sysds.runtime.instructions.fed.InitFEDInstruction; +import org.apache.sysds.runtime.io.IOUtilFunctions; +import org.apache.sysds.runtime.lineage.LineageItem; +import org.apache.sysds.runtime.privacy.DMLPrivacyException; +import org.apache.sysds.runtime.privacy.PrivacyConstraint; +import org.apache.sysds.runtime.privacy.propagation.PrivacyPropagator; +import org.apache.sysds.utils.JSONHelper; +import org.apache.wink.json4j.JSONObject; + +import java.io.BufferedReader; +import java.io.InputStreamReader; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.UnknownHostException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.Future; + +public class PrivacyConstraintLoader { + + private final Map memo = new HashMap<>(); + private final Map transientWrites = new HashMap<>(); + + public void loadConstraints(DMLProgram prog){ + rewriteStatementBlocks(prog, prog.getStatementBlocks(), null); + } + + private void rewriteStatementBlocks(DMLProgram prog, List sbs, Map paramMap) { + sbs.forEach(block -> rewriteStatementBlock(prog, block, paramMap)); + } + + private void rewriteStatementBlock(DMLProgram prog, StatementBlock block, Map paramMap){ + if(block instanceof WhileStatementBlock) + rewriteWhileStatementBlock(prog, (WhileStatementBlock) block, paramMap); + else if(block instanceof IfStatementBlock) + rewriteIfStatementBlock(prog, (IfStatementBlock) block, paramMap); + else if(block instanceof ForStatementBlock) { + // This also includes ParForStatementBlocks + rewriteForStatementBlock(prog, (ForStatementBlock) block, paramMap); + } + else if(block instanceof FunctionStatementBlock) + rewriteFunctionStatementBlock(prog, (FunctionStatementBlock) block, paramMap); + else { + // StatementBlock type (no subclass) + rewriteDefaultStatementBlock(prog, block, paramMap); + } + } + + private void rewriteWhileStatementBlock(DMLProgram 
prog, WhileStatementBlock whileSB, Map paramMap) { + Hop whilePredicateHop = whileSB.getPredicateHops(); + loadPrivacyConstraint(whilePredicateHop, paramMap); + for(Statement stm : whileSB.getStatements()) { + WhileStatement whileStm = (WhileStatement) stm; + rewriteStatementBlocks(prog, whileStm.getBody(), paramMap); + } + } + + private void rewriteIfStatementBlock(DMLProgram prog, IfStatementBlock ifSB, Map paramMap) { + loadPrivacyConstraint(ifSB.getPredicateHops(), paramMap); + for(Statement statement : ifSB.getStatements()) { + IfStatement ifStatement = (IfStatement) statement; + rewriteStatementBlocks(prog, ifStatement.getIfBody(), paramMap); + rewriteStatementBlocks(prog, ifStatement.getElseBody(), paramMap); + } + } + + private void rewriteForStatementBlock(DMLProgram prog, ForStatementBlock forSB, Map paramMap) { + loadPrivacyConstraint(forSB.getFromHops(), paramMap); + loadPrivacyConstraint(forSB.getToHops(), paramMap); + loadPrivacyConstraint(forSB.getIncrementHops(), paramMap); + for(Statement statement : forSB.getStatements()) { + ForStatement forStatement = ((ForStatement) statement); + rewriteStatementBlocks(prog, forStatement.getBody(), paramMap); + } + } + + private void rewriteFunctionStatementBlock(DMLProgram prog, FunctionStatementBlock funcSB, Map paramMap) { + for(Statement statement : funcSB.getStatements()) { + FunctionStatement funcStm = (FunctionStatement) statement; + rewriteStatementBlocks(prog, funcStm.getBody(), paramMap); + } + } + + private void rewriteDefaultStatementBlock(DMLProgram prog, StatementBlock sb, Map paramMap) { + if(sb.hasHops()) { + for(Hop sbHop : sb.getHops()) { + loadPrivacyConstraint(sbHop, paramMap); + if(sbHop instanceof FunctionOp) { + String funcName = ((FunctionOp) sbHop).getFunctionName(); + Map funcParamMap = FederatedPlannerUtils.getParamMap((FunctionOp) sbHop); + if ( paramMap != null && funcParamMap != null) + funcParamMap.putAll(paramMap); + paramMap = funcParamMap; + FunctionStatementBlock sbFuncBlock = 
prog.getBuiltinFunctionDictionary().getFunction(funcName); + rewriteStatementBlock(prog, sbFuncBlock, paramMap); + } + } + } + } + + private void loadPrivacyConstraint(Hop root, Map paramMap){ + if ( root != null && !memo.containsKey(root.getHopID()) ){ + for ( Hop input : root.getInput() ){ + loadPrivacyConstraint(input, paramMap); + } + propagatePrivConstraintsLocal(root, paramMap); + memo.put(root.getHopID(), root); + } + } + + private void propagatePrivConstraintsLocal(Hop currentHop, Map paramMap){ + if ( currentHop.isFederatedDataOp() ) + loadFederatedPrivacyConstraints(currentHop); + else if ( HopRewriteUtils.isData(currentHop, Types.OpOpData.TRANSIENTWRITE) ){ + currentHop.setPrivacy(currentHop.getInput(0).getPrivacy()); + transientWrites.put(currentHop.getName(), currentHop); + } + else if ( HopRewriteUtils.isData(currentHop, Types.OpOpData.TRANSIENTREAD) ){ + currentHop.setPrivacy(FederatedPlannerUtils.getTransientInputs(currentHop, paramMap, transientWrites).get(0).getPrivacy()); + } else { + PrivacyPropagator.hopPropagation(currentHop); + } + } + + /** + * Get privacy constraints from federated workers for DataOps and set the most restrictive level on the hop. + * @param hop federated DataOp hop for which privacy constraints are loaded + */ + private static void loadFederatedPrivacyConstraints(Hop hop){ + try { + PrivacyConstraint.PrivacyLevel constraintLevel = hop.getInput(0).getInput().stream().parallel() + .map( in -> ((LiteralOp)in).getStringValue() ) + .map(PrivacyConstraintLoader::sendPrivConstraintRequest) + .map(PrivacyConstraintLoader::unwrapPrivConstraint) + .map(constraint -> (constraint != null) ? 
constraint.getPrivacyLevel() : PrivacyConstraint.PrivacyLevel.None) + .reduce(PrivacyConstraint.PrivacyLevel.None, (out,in) -> { + if ( out == PrivacyConstraint.PrivacyLevel.Private || in == PrivacyConstraint.PrivacyLevel.Private ) + return PrivacyConstraint.PrivacyLevel.Private; + else if ( out == PrivacyConstraint.PrivacyLevel.PrivateAggregation || in == PrivacyConstraint.PrivacyLevel.PrivateAggregation ) + return PrivacyConstraint.PrivacyLevel.PrivateAggregation; + else + return out; + }); + PrivacyConstraint fedDataPrivConstraint = (constraintLevel != PrivacyConstraint.PrivacyLevel.None) ? + new PrivacyConstraint(constraintLevel) : null; + + hop.setPrivacy(fedDataPrivConstraint); + } + catch(Exception ex) { + throw new DMLException(ex); + } + } + + private static Future sendPrivConstraintRequest(String address) + { + try{ + String[] parsedAddress = InitFEDInstruction.parseURL(address); + String host = parsedAddress[0]; + int port = Integer.parseInt(parsedAddress[1]); + PrivacyConstraintRetriever retriever = new PrivacyConstraintRetriever(parsedAddress[2]); + FederatedRequest privacyRetrieval = + new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1, retriever); + InetSocketAddress inetAddress = new InetSocketAddress(InetAddress.getByName(host), port); + return FederatedData.executeFederatedOperation(inetAddress, privacyRetrieval); + } catch(UnknownHostException ex){ + throw new DMLException(ex); + } + } + + private static PrivacyConstraint unwrapPrivConstraint(Future privConstraintFuture) + { + try { + FederatedResponse privConstraintResponse = privConstraintFuture.get(); + return (PrivacyConstraint) privConstraintResponse.getData()[0]; + } catch(Exception ex){ + throw new DMLException(ex); + } + } + + /** + * FederatedUDF for retrieving privacy constraint of data stored in file name. 
+ */ + public static class PrivacyConstraintRetriever extends FederatedUDF { + private static final long serialVersionUID = 3551741240135587183L; + private final String filename; + + public PrivacyConstraintRetriever(String filename){ + super(new long[]{}); + this.filename = filename; + } + + /** + * Reads metadata JSON object, parses privacy constraint and returns the constraint in FederatedResponse; other read or parse failures are rethrown as DMLRuntimeException with the original cause attached. + * @param ec execution context + * @param data one or many data objects + * @return FederatedResponse with privacy constraint object + */ + @Override + public FederatedResponse execute(ExecutionContext ec, Data... data) { + PrivacyConstraint privacyConstraint; + FileSystem fs = null; + try { + String mtdname = DataExpression.getMTDFileName(filename); + Path path = new Path(mtdname); + fs = IOUtilFunctions.getFileSystem(mtdname); + try(BufferedReader br = new BufferedReader(new InputStreamReader(fs.open(path)))) { + JSONObject metadataObject = JSONHelper.parse(br); + privacyConstraint = PrivacyPropagator.parseAndReturnPrivacyConstraint(metadataObject); + } + } + catch (DMLPrivacyException | FederatedWorkerHandlerException ex){ + throw ex; + } + catch (Exception ex) { + String msg = "Exception in reading metadata of: " + filename; + throw new DMLRuntimeException(msg, ex); + } + finally { + IOUtilFunctions.closeSilently(fs); + } + return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, privacyConstraint); + } + + @Override + public Pair getLineageItem(ExecutionContext ec) { + return null; + } + } + +} diff --git a/src/main/java/org/apache/sysds/hops/ipa/IPAPassRewriteFederatedPlan.java b/src/main/java/org/apache/sysds/hops/ipa/IPAPassRewriteFederatedPlan.java index 6be3b9c8ec8..e6c683eb386 100644 --- a/src/main/java/org/apache/sysds/hops/ipa/IPAPassRewriteFederatedPlan.java +++ b/src/main/java/org/apache/sysds/hops/ipa/IPAPassRewriteFederatedPlan.java @@ -23,6 +23,7 @@ import org.apache.sysds.conf.DMLConfig; import org.apache.sysds.hops.OptimizerUtils; import 
org.apache.sysds.hops.fedplanner.FTypes.FederatedPlanner; +import org.apache.sysds.hops.fedplanner.PrivacyConstraintLoader; import org.apache.sysds.parser.DMLProgram; /** @@ -58,16 +59,24 @@ public boolean isApplicable(FunctionCallGraph fgraph) { */ @Override public boolean rewriteProgram(DMLProgram prog, FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes) { - // obtain planner instance according to config String splanner = ConfigurationManager.getDMLConfig() .getTextValue(DMLConfig.FEDERATED_PLANNER); + loadPrivacyConstraints(prog, splanner); + generatePlan(prog, fgraph, fcallSizes, splanner); + return false; + } + + private void loadPrivacyConstraints(DMLProgram prog, String splanner){ + if (FederatedPlanner.isCompiled(splanner)) + new PrivacyConstraintLoader().loadConstraints(prog); + } + + private void generatePlan(DMLProgram prog, FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes, String splanner){ FederatedPlanner planner = FederatedPlanner.isCompiled(splanner) ? FederatedPlanner.valueOf(splanner.toUpperCase()) : FederatedPlanner.COMPILE_COST_BASED; - + // run planner rewrite with forced federated exec types planner.getPlanner().rewriteProgram(prog, fgraph, fcallSizes); - - return false; } } diff --git a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java index faec3504e9e..db20ada2808 100644 --- a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java +++ b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java @@ -27,10 +27,8 @@ import org.apache.sysds.api.DMLScript; import org.apache.sysds.conf.ConfigurationManager; import org.apache.sysds.conf.CompilerConfig.ConfigType; -import org.apache.sysds.conf.DMLConfig; import org.apache.sysds.hops.Hop; import org.apache.sysds.hops.OptimizerUtils; -import org.apache.sysds.hops.fedplanner.FTypes; import org.apache.sysds.parser.DMLProgram; import org.apache.sysds.parser.ForStatement; import 
org.apache.sysds.parser.ForStatementBlock; @@ -141,11 +139,6 @@ public ProgramRewriter(boolean staticRewrites, boolean dynamicRewrites) _dagRuleSet.add( new RewriteAlgebraicSimplificationDynamic() ); //dependencies: cse _dagRuleSet.add( new RewriteAlgebraicSimplificationStatic() ); //dependencies: cse } - String planner = ConfigurationManager.getDMLConfig() - .getTextValue(DMLConfig.FEDERATED_PLANNER); - if ( OptimizerUtils.FEDERATED_COMPILATION || FTypes.FederatedPlanner.isCompiled(planner) ) { - _dagRuleSet.add( new RewriteFederatedExecution() ); - } } // cleanup after all rewrites applied diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedExecution.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedExecution.java deleted file mode 100644 index 822b4b5d952..00000000000 --- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedExecution.java +++ /dev/null @@ -1,187 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ - -package org.apache.sysds.hops.rewrite; - -import org.apache.commons.lang3.tuple.Pair; -import org.apache.hadoop.fs.FileSystem; -import org.apache.hadoop.fs.Path; -import org.apache.log4j.Logger; -import org.apache.sysds.api.DMLException; -import org.apache.sysds.hops.Hop; -import org.apache.sysds.hops.LiteralOp; -import org.apache.sysds.parser.DataExpression; -import org.apache.sysds.runtime.DMLRuntimeException; -import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; -import org.apache.sysds.runtime.controlprogram.federated.FederatedData; -import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest; -import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse; -import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF; -import org.apache.sysds.runtime.controlprogram.federated.FederatedWorkerHandlerException; -import org.apache.sysds.runtime.instructions.cp.Data; -import org.apache.sysds.runtime.instructions.fed.InitFEDInstruction; -import org.apache.sysds.runtime.io.IOUtilFunctions; -import org.apache.sysds.runtime.lineage.LineageItem; -import org.apache.sysds.runtime.privacy.DMLPrivacyException; -import org.apache.sysds.runtime.privacy.PrivacyConstraint; -import org.apache.sysds.runtime.privacy.propagation.PrivacyPropagator; -import org.apache.sysds.utils.JSONHelper; -import org.apache.wink.json4j.JSONObject; - -import javax.net.ssl.SSLException; -import java.io.BufferedReader; -import java.io.InputStreamReader; -import java.net.InetAddress; -import java.net.InetSocketAddress; -import java.net.UnknownHostException; -import java.util.ArrayList; -import java.util.concurrent.Future; - -public class RewriteFederatedExecution extends HopRewriteRule { - private static final Logger LOG = Logger.getLogger(RewriteFederatedExecution.class); - - @Override - public ArrayList rewriteHopDAGs(ArrayList roots, ProgramRewriteStatus state) { - if ( roots != null ) - for ( Hop root : roots ) - rewriteHopDAG(root, 
state); - return roots; - } - - @Override - public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) { - if ( root != null ) - visitHop(root); - return root; - } - - private void visitHop(Hop hop){ - if (hop.isVisited()) - return; - - LOG.debug("RewriteFederatedExecution visitHop + " + hop); - - // Depth first to get to the input - for ( Hop input : hop.getInput() ) - visitHop(input); - - privacyBasedHopDecisionWithFedCall(hop); - hop.setVisited(); - } - - /** - * Get privacy constraints of DataOps from federated worker, - * propagate privacy constraints from input to current hop, - * and set federated output flag. - * @param hop current hop - */ - private static void privacyBasedHopDecisionWithFedCall(Hop hop){ - loadFederatedPrivacyConstraints(hop); - PrivacyPropagator.hopPropagation(hop); - } - - /** - * Get privacy constraints from federated workers for DataOps. - * @hop hop for which privacy constraints are loaded - */ - private static void loadFederatedPrivacyConstraints(Hop hop){ - if ( hop.isFederatedDataOp() && hop.getPrivacy() == null){ - try { - LOG.debug("Load privacy constraints of " + hop); - PrivacyConstraint privConstraint = unwrapPrivConstraint(sendPrivConstraintRequest(hop)); - LOG.debug("PrivacyConstraint retrieved: " + privConstraint); - hop.setPrivacy(privConstraint); - } - catch(Exception e) { - throw new DMLException(e); - } - } - } - - private static Future sendPrivConstraintRequest(Hop hop) - throws UnknownHostException, SSLException - { - String address = ((LiteralOp) hop.getInput(0).getInput(0)).getStringValue(); - String[] parsedAddress = InitFEDInstruction.parseURL(address); - String host = parsedAddress[0]; - int port = Integer.parseInt(parsedAddress[1]); - PrivacyConstraintRetriever retriever = new PrivacyConstraintRetriever(parsedAddress[2]); - FederatedRequest privacyRetrieval = - new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1, retriever); - InetSocketAddress inetAddress = new 
InetSocketAddress(InetAddress.getByName(host), port); - return FederatedData.executeFederatedOperation(inetAddress, privacyRetrieval); - } - - private static PrivacyConstraint unwrapPrivConstraint(Future privConstraintFuture) - throws Exception - { - FederatedResponse privConstraintResponse = privConstraintFuture.get(); - return (PrivacyConstraint) privConstraintResponse.getData()[0]; - } - - /** - * FederatedUDF for retrieving privacy constraint of data stored in file name. - */ - public static class PrivacyConstraintRetriever extends FederatedUDF { - private static final long serialVersionUID = 3551741240135587183L; - private final String filename; - - public PrivacyConstraintRetriever(String filename){ - super(new long[]{}); - this.filename = filename; - } - - /** - * Reads metadata JSON object, parses privacy constraint and returns the constraint in FederatedResponse. - * @param ec execution context - * @param data one or many data objects - * @return FederatedResponse with privacy constraint object - */ - @Override - public FederatedResponse execute(ExecutionContext ec, Data... 
data) { - PrivacyConstraint privacyConstraint; - FileSystem fs = null; - try { - String mtdname = DataExpression.getMTDFileName(filename); - Path path = new Path(mtdname); - fs = IOUtilFunctions.getFileSystem(mtdname); - try(BufferedReader br = new BufferedReader(new InputStreamReader(fs.open(path)))) { - JSONObject metadataObject = JSONHelper.parse(br); - privacyConstraint = PrivacyPropagator.parseAndReturnPrivacyConstraint(metadataObject); - } - } - catch (DMLPrivacyException | FederatedWorkerHandlerException ex){ - throw ex; - } - catch (Exception ex) { - String msg = "Exception in reading metadata of: " + filename; - throw new DMLRuntimeException(msg); - } - finally { - IOUtilFunctions.closeSilently(fs); - } - return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, privacyConstraint); - } - - @Override - public Pair getLineageItem(ExecutionContext ec) { - return null; - } - } -} diff --git a/src/main/java/org/apache/sysds/runtime/privacy/PrivacyConstraint.java b/src/main/java/org/apache/sysds/runtime/privacy/PrivacyConstraint.java index 8ea061844ad..fc9ba440c85 100644 --- a/src/main/java/org/apache/sysds/runtime/privacy/PrivacyConstraint.java +++ b/src/main/java/org/apache/sysds/runtime/privacy/PrivacyConstraint.java @@ -262,8 +262,11 @@ public boolean equals(Object other){ @Override public String toString(){ - return "General privacy level: " + privacyLevel + System.getProperty("line.separator") - + "Fine-grained privacy level: " + fineGrainedPrivacy.toString(); + String constraintString = "General privacy level: " + privacyLevel; + if ( fineGrainedPrivacy != null && fineGrainedPrivacy.hasConstraints() ) + constraintString = constraintString + System.getProperty("line.separator") + + "Fine-grained privacy level: " + fineGrainedPrivacy.toString(); + return constraintString; } } diff --git a/src/main/java/org/apache/sysds/runtime/privacy/propagation/PrivacyPropagator.java 
b/src/main/java/org/apache/sysds/runtime/privacy/propagation/PrivacyPropagator.java index 7e6c0127e52..94834ebc6e1 100644 --- a/src/main/java/org/apache/sysds/runtime/privacy/propagation/PrivacyPropagator.java +++ b/src/main/java/org/apache/sysds/runtime/privacy/propagation/PrivacyPropagator.java @@ -23,11 +23,18 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.Objects; +import org.apache.sysds.api.DMLException; import org.apache.sysds.hops.AggBinaryOp; import org.apache.sysds.hops.AggUnaryOp; import org.apache.sysds.hops.BinaryOp; +import org.apache.sysds.hops.DataGenOp; +import org.apache.sysds.hops.DataOp; +import org.apache.sysds.hops.FunctionOp; import org.apache.sysds.hops.Hop; +import org.apache.sysds.hops.LiteralOp; +import org.apache.sysds.hops.NaryOp; import org.apache.sysds.hops.ReorgOp; import org.apache.sysds.hops.TernaryOp; import org.apache.sysds.hops.UnaryOp; @@ -168,12 +175,39 @@ else if (privacyConstraint2 != null) * @param hop which the privacy constraints are propagated to */ public static void hopPropagation(Hop hop){ - PrivacyConstraint[] inputConstraints = hop.getInput().stream() + hopPropagation(hop, hop.getInput()); + } + + /** + * Propagate privacy constraints from input hops to given hop. 
+ * @param hop which the privacy constraints are propagated to (left unchanged if the hop type is rejected) + * @param inputHops inputs to given hop + */ + public static void hopPropagation(Hop hop, ArrayList inputHops){ + PrivacyConstraint[] inputConstraints = inputHops.stream() .map(Hop::getPrivacy).toArray(PrivacyConstraint[]::new); - if ( hop instanceof TernaryOp || hop instanceof BinaryOp || hop instanceof ReorgOp ) - hop.setPrivacy(mergeNary(inputConstraints, OperatorType.NonAggregate)); + OperatorType opType = getOpType(hop); + if (opType == null && Arrays.stream(inputConstraints).anyMatch(Objects::nonNull)) + throw new DMLException("Input has constraint but hop type not recognized by PrivacyPropagator. " + + "Hop is " + hop + " " + hop.getClass()); + hop.setPrivacy(mergeNary(inputConstraints, opType)); + } + + /** + * Get operator type of given hop. + * Returns null if hop type is not known. + * @param hop for which operator type is returned + * @return operator type of hop or null if hop type is unknown + */ + private static OperatorType getOpType(Hop hop){ + if ( hop instanceof TernaryOp || hop instanceof BinaryOp || hop instanceof ReorgOp + || hop instanceof DataOp || hop instanceof LiteralOp || hop instanceof NaryOp + || hop instanceof DataGenOp || hop instanceof FunctionOp ) + return OperatorType.NonAggregate; else if ( hop instanceof AggBinaryOp || hop instanceof AggUnaryOp || hop instanceof UnaryOp ) - hop.setPrivacy(mergeNary(inputConstraints, OperatorType.Aggregate)); + return OperatorType.Aggregate; + else + return null; } /** @@ -406,7 +440,7 @@ private static Instruction throwExceptionIfInputOrInstPrivacy(Instruction inst, if (inputOperands != null){ for ( CPOperand input : inputOperands ){ PrivacyConstraint privacyConstraint = getInputPrivacyConstraint(ec, input); - if ( privacyConstraint != null){ + if ( privacyConstraint != null && privacyConstraint.hasConstraints()){ throw new DMLPrivacyException("Input of instruction " + inst + " has privacy constraints activated, but the constraints 
are not propagated during preprocessing of instruction."); } } diff --git a/src/main/java/org/apache/sysds/utils/Explain.java b/src/main/java/org/apache/sysds/utils/Explain.java index 589f23a845b..ded46c039a1 100644 --- a/src/main/java/org/apache/sysds/utils/Explain.java +++ b/src/main/java/org/apache/sysds/utils/Explain.java @@ -626,6 +626,9 @@ else if( hop.requiresCheckpoint() ) } } + if ( hop.getPrivacy() != null ) + sb.append(" ").append(hop.getPrivacy().getPrivacyLevel().name()); + sb.append('\n'); hop.setVisited(); diff --git a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java index 14c093ebe82..2477bdef851 100644 --- a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java +++ b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java @@ -55,6 +55,7 @@ public class FederatedMultiplyPlanningTest extends AutomatedTestBase { private final static String TEST_NAME_8 = "FederatedMultiplyPlanningTest8"; private final static String TEST_NAME_9 = "FederatedMultiplyPlanningTest9"; private final static String TEST_NAME_10 = "FederatedMultiplyPlanningTest10"; + private final static String TEST_NAME_11 = "FederatedMultiplyPlanningTest11"; private final static String TEST_CLASS_DIR = TEST_DIR + FederatedMultiplyPlanningTest.class.getSimpleName() + "/"; private static File TEST_CONF_FILE = new File(SCRIPT_DIR + TEST_DIR, "SystemDS-config-cost-based.xml"); @@ -77,6 +78,7 @@ public void setUp() { addTestConfiguration(TEST_NAME_8, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_8, new String[] {"Z.scalar"})); addTestConfiguration(TEST_NAME_9, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_9, new String[] {"Z.scalar"})); addTestConfiguration(TEST_NAME_10, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_10, new String[] {"Z"})); + 
addTestConfiguration(TEST_NAME_11, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_11, new String[] {"Z"})); } @Parameterized.Parameters @@ -153,6 +155,12 @@ public void federatedMultiplyPlanningTest10(){ federatedTwoMatricesSingleNodeTest(TEST_NAME_10, expectedHeavyHitters); } + @Test + public void federatedMultiplyPlanningTest11(){ + String[] expectedHeavyHitters = new String[]{"fed_fedinit"}; + federatedTwoMatricesSingleNodeTest(TEST_NAME_11, expectedHeavyHitters); + } + private void writeStandardMatrix(String matrixName, long seed){ writeStandardMatrix(matrixName, seed, new PrivacyConstraint(PrivacyConstraint.PrivacyLevel.PrivateAggregation)); } diff --git a/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest11.dml b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest11.dml new file mode 100644 index 00000000000..147bf2cd139 --- /dev/null +++ b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest11.dml @@ -0,0 +1,34 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +X = federated(addresses=list($X1, $X2), + ranges=list(list(0, 0), list($r / 2, $c), list($r / 2, 0), list($r, $c))) +Y = federated(addresses=list($Y1, $Y2), + ranges=list(list(0, 0), list($r/2, $c), list($r / 2, 0), list($r, $c))) + +i = 0 +while(i < 10){ + Z0 = X * Y + Z = t(Z0) %*% X + i=i+1 +} + +write(Z, $Z) diff --git a/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest11Reference.dml b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest11Reference.dml new file mode 100644 index 00000000000..187623bbfeb --- /dev/null +++ b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest11Reference.dml @@ -0,0 +1,32 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X = rbind(read($X1), read($X2)) +Y = rbind(read($Y1), read($Y2)) + +i = 0 +while(i < 10){ + Z0 = X * Y + Z = t(Z0) %*% X + i=i+1 +} + +write(Z, $Z)