diff --git a/heron/mlmgr/src/main/java/org/apache/heron/LocalStormDoTask.java b/heron/mlmgr/src/main/java/org/apache/heron/LocalStormDoTask.java
new file mode 100644
index 00000000000..71b7aff474e
--- /dev/null
+++ b/heron/mlmgr/src/main/java/org/apache/heron/LocalStormDoTask.java
@@ -0,0 +1,82 @@
+package org.apache.heron;
+
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import org.apache.samoa.topology.impl.StormSamoaUtils;
+import org.apache.samoa.topology.impl.StormTopology;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.apache.commons.configuration.Configuration;
+
+import backtype.storm.Config;
+import backtype.storm.utils.Utils;
+
+/**
+ * The main class to execute a SAMOA task in LOCAL mode in Storm.
+ *
+ * @author Arinto Murdopo
+ *
+ */
+public class LocalStormDoTask {
+
+  private static final Logger logger = LoggerFactory.getLogger(LocalStormDoTask.class);
+  private static final String EXECUTION_DURATION_KEY = "samoa.storm.local.mode.execution.duration";
+  private static final String SAMOA_STORM_PROPERTY_FILE_LOC = "samoa-storm.properties";
+  /**
+   * Builds a Storm topology from the SAMOA arguments, runs it on an in-process
+   * {@code LocalCluster} for the configured duration and then kills it. The cluster
+   * is shut down even when submitting, configuring or killing the topology fails.
+   *
+   * @param args
+   *          the arguments
+   */
+  public static void main(String[] args) {
+    // Mutable, typed copy: the list is handed to numWorkers and converted back afterwards
+    // (presumably numWorkers consumes its own option from the list - confirm).
+    List<String> tmpArgs = new ArrayList<String>(Arrays.asList(args));
+    int numWorker = StormSamoaUtils.numWorkers(tmpArgs);
+    args = tmpArgs.toArray(new String[0]);
+
+    // convert the arguments into Storm topology
+    StormTopology stormTopo = StormSamoaUtils.argsToTopology(args);
+    String topologyName = stormTopo.getTopologyName();
+
+    Config conf = new Config();
+    conf.setDebug(false);
+    // local mode
+    conf.setMaxTaskParallelism(numWorker);
+
+    backtype.storm.LocalCluster cluster = new backtype.storm.LocalCluster();
+    try {
+      cluster.submitTopology(topologyName, conf, stormTopo.getStormBuilder().createTopology());
+      // Read local mode execution duration from property file; the value is multiplied
+      // by 1000 before sleeping (presumably seconds converted to milliseconds).
+      Configuration stormConfig = StormSamoaUtils.getPropertyConfig(SAMOA_STORM_PROPERTY_FILE_LOC);
+      long executionDuration = stormConfig.getLong(EXECUTION_DURATION_KEY);
+      backtype.storm.utils.Utils.sleep(executionDuration * 1000);
+      cluster.killTopology(topologyName);
+    } finally {
+      cluster.shutdown();
+    }
+  }
+}
diff --git a/heron/mlmgr/src/main/java/org/apache/heron/learners/AdaptiveLearner.java b/heron/mlmgr/src/main/java/org/apache/heron/learners/AdaptiveLearner.java
new file mode 100644
index 00000000000..2fa97cc527f
--- /dev/null
+++ b/heron/mlmgr/src/main/java/org/apache/heron/learners/AdaptiveLearner.java
@@ -0,0 +1,46 @@
+package org.apache.heron.learners;
+
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+import org.apache.samoa.moa.classifiers.core.driftdetection.ChangeDetector;
+
+/**
+ * A {@link Learner} that additionally exposes a {@link ChangeDetector}. Initializing the classifier should initialize
+ * a processing item (PI) connected to the input stream, plus a result stream that other PIs can connect to.
+ */
+
+public interface AdaptiveLearner extends Learner {
+
+  /**
+   * Returns the change detector attached to this learner.
+   *
+   * @return the change detector item
+   */
+  ChangeDetector getChangeDetector();
+
+  /**
+   * Attaches the change detector this learner should use.
+   *
+   * @param cd
+   *          the change detector item
+   */
+  void setChangeDetector(ChangeDetector cd);
+
+}
diff --git a/heron/mlmgr/src/main/java/org/apache/heron/learners/ClassificationLearner.java b/heron/mlmgr/src/main/java/org/apache/heron/learners/ClassificationLearner.java
new file mode 100644
index 00000000000..ee9bea1b404
--- /dev/null
+++ b/heron/mlmgr/src/main/java/org/apache/heron/learners/ClassificationLearner.java
@@ -0,0 +1,26 @@
+package org.apache.heron.learners;
+
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+import org.apache.samoa.learners.Learner;
+
+/**
+ * Marker interface for classification learners; it declares no members beyond those inherited from {@code Learner}.
+ *
+ * NOTE(review): because of the explicit import above, this extends {@code org.apache.samoa.learners.Learner} rather
+ * than the {@code Learner} interface declared in this package (the one {@code AdaptiveLearner} extends) - confirm
+ * that this is intended.
+ */
+public interface ClassificationLearner extends Learner {
+
+}
diff --git a/heron/mlmgr/src/main/java/org/apache/heron/learners/InstanceContent.java b/heron/mlmgr/src/main/java/org/apache/heron/learners/InstanceContent.java
new file mode 100644
index 00000000000..5a12eefe31b
--- /dev/null
+++ b/heron/mlmgr/src/main/java/org/apache/heron/learners/InstanceContent.java
@@ -0,0 +1,202 @@
+package org.apache.heron.learners;
+
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/**
+ * License
+ */
+
+import net.jcip.annotations.Immutable;
+
+import org.apache.samoa.core.SerializableInstance;
+import org.apache.samoa.instances.Instance;
+
+import java.io.Serializable;
+
+/**
+ * The Class InstanceContent.
+ */
+// NOTE(review): annotated @Immutable although the fields are non-final and setters are exposed.
+@Immutable
+final public class InstanceContent implements Serializable {
+
+  private static final long serialVersionUID = -8620668863064613841L;
+
+  private long instanceIndex;
+  private int classifierIndex;
+  private int evaluationIndex;
+  private SerializableInstance instance;
+  private boolean isTraining;
+  private boolean isTesting;
+  private boolean isLast = false;
+
+  /**
+   * Creates an empty content object: every field keeps its default value and the wrapped instance stays
+   * {@code null}.
+   */
+  public InstanceContent() {
+  }
+
+  /**
+   * Creates the content for a single instance.
+   *
+   * @param index
+   *          the instance index
+   * @param instance
+   *          the instance, copied into a {@link SerializableInstance}; when {@code null} no instance is stored
+   * @param isTraining
+   *          whether the instance is training data
+   * @param isTesting
+   *          whether the instance is testing data
+   */
+  public InstanceContent(long index, Instance instance,
+      boolean isTraining, boolean isTesting) {
+    this.instanceIndex = index;
+    this.isTraining = isTraining;
+    this.isTesting = isTesting;
+    this.instance = (instance == null) ? null : new SerializableInstance(instance);
+  }
+
+  /**
+   * Returns the wrapped instance.
+   *
+   * @return the instance, or {@code null} when none was supplied at construction
+   */
+  public Instance getInstance() {
+    return this.instance;
+  }
+
+  /**
+   * Returns the instance index.
+   *
+   * @return the index of the data vector.
+   */
+  public long getInstanceIndex() {
+    return this.instanceIndex;
+  }
+
+  /**
+   * Returns the class value of the wrapped instance, truncated to {@code int}.
+   *
+   * @return the true class of the vector.
+   * @throws NullPointerException
+   *           if no instance is stored
+   */
+  public int getClassId() {
+    double classValue = this.instance.classValue();
+    return (int) classValue;
+  }
+
+  /**
+   * Tells whether the instance is training data.
+   *
+   * @return true if this is training data.
+   */
+  public boolean isTraining() {
+    return this.isTraining;
+  }
+
+  /**
+   * Updates the training flag.
+   *
+   * @param training
+   *          the new flag value
+   */
+  public void setTraining(boolean training) {
+    this.isTraining = training;
+  }
+
+  /**
+   * Tells whether the instance is testing data.
+   *
+   * @return true if this is testing data.
+   */
+  public boolean isTesting() {
+    return this.isTesting;
+  }
+
+  /**
+   * Updates the testing flag.
+   *
+   * @param testing
+   *          the new flag value
+   */
+  public void setTesting(boolean testing) {
+    this.isTesting = testing;
+  }
+
+  /**
+   * Returns the classifier index.
+   *
+   * @return the classifier index
+   */
+  public int getClassifierIndex() {
+    return this.classifierIndex;
+  }
+
+  /**
+   * Updates the classifier index.
+   *
+   * @param classifierIndex
+   *          the new classifier index
+   */
+  public void setClassifierIndex(int classifierIndex) {
+    this.classifierIndex = classifierIndex;
+  }
+
+  /**
+   * Returns the evaluation index.
+   *
+   * @return the evaluation index
+   */
+  public int getEvaluationIndex() {
+    return this.evaluationIndex;
+  }
+
+  /**
+   * Updates the evaluation index.
+   *
+   * @param evaluationIndex
+   *          the new evaluation index
+   */
+  public void setEvaluationIndex(int evaluationIndex) {
+    this.evaluationIndex = evaluationIndex;
+  }
+
+  /**
+   * Updates the instance index.
+   *
+   * @param instanceIndex
+   *          the new instance index
+   */
+  public void setInstanceIndex(long instanceIndex) {
+    this.instanceIndex = instanceIndex;
+  }
+
+  /**
+   * Tells whether this content has been flagged as the last one.
+   *
+   * @return the value of the last flag ({@code false} unless {@link #setLast(boolean)} changed it)
+   */
+  public boolean isLastEvent() {
+    return this.isLast;
+  }
+
+  /**
+   * Updates the last flag.
+   *
+   * @param isLast
+   *          the new flag value
+   */
+  public void setLast(boolean isLast) {
+    this.isLast = isLast;
+  }
+
+  @Override
+  public String toString() {
+    final String pattern =
+        "InstanceContent [instanceIndex=%s, classifierIndex=%s, evaluationIndex=%s, instance=%s, isTraining=%s, isTesting=%s, isLast=%s]";
+    return String.format(pattern,
+        instanceIndex, classifierIndex, evaluationIndex, instance, isTraining, isTesting, isLast);
+  }
+}
diff --git a/heron/mlmgr/src/main/java/org/apache/heron/learners/InstanceContentEvent.java b/heron/mlmgr/src/main/java/org/apache/heron/learners/InstanceContentEvent.java
new file mode 100644
index 00000000000..0c82d5ce5df
--- /dev/null
+++ b/heron/mlmgr/src/main/java/org/apache/heron/learners/InstanceContentEvent.java
@@ -0,0 +1,200 @@
+package org.apache.heron.learners;
+
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.
The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+import org.apache.samoa.core.ContentEvent;
+import org.apache.samoa.core.SerializableInstance;
+import org.apache.samoa.instances.Instance;
+
+import net.jcip.annotations.Immutable;
+
+/**
+ * A {@link ContentEvent} that wraps a single {@link InstanceContent} and delegates every accessor to it.
+ *
+ * NOTE(review): annotated {@code @Immutable} although the wrapped content can be changed through the setters below.
+ */
+@Immutable
+final public class InstanceContentEvent implements ContentEvent {
+
+  private static final long serialVersionUID = -8620668863064613845L;
+  private InstanceContent instanceContent;
+
+  /**
+   * Creates an event without content. The delegating accessors dereference the wrapped content, so they throw
+   * {@link NullPointerException} on an event built this way (presumably only used by serialization frameworks -
+   * confirm).
+   */
+  public InstanceContentEvent() {
+
+  }
+
+  /**
+   * Instantiates a new instance event.
+   *
+   * @param index
+   *          the instance index
+   * @param instance
+   *          the instance
+   * @param isTraining
+   *          whether the instance is training data
+   * @param isTesting
+   *          whether the instance is testing data
+   */
+  public InstanceContentEvent(long index, Instance instance,
+      boolean isTraining, boolean isTesting) {
+    this.instanceContent = new InstanceContent(index, instance, isTraining, isTesting);
+  }
+
+  /**
+   * Gets the wrapped instance.
+   *
+   * @return the instance.
+   */
+  public Instance getInstance() {
+    return this.instanceContent.getInstance();
+  }
+
+  /**
+   * Gets the instance index.
+   *
+   * @return the index of the data vector.
+   */
+  public long getInstanceIndex() {
+    return this.instanceContent.getInstanceIndex();
+  }
+
+  /**
+   * Gets the class id.
+   *
+   * @return the true class of the vector.
+   */
+  public int getClassId() {
+    return this.instanceContent.getClassId();
+  }
+
+  /**
+   * Checks if is training.
+   *
+   * @return true if this is training data.
+   */
+  public boolean isTraining() {
+    return this.instanceContent.isTraining();
+  }
+
+  /**
+   * Set training flag.
+   *
+   * @param training
+   *          flag.
+   */
+  public void setTraining(boolean training) {
+    this.instanceContent.setTraining(training);
+  }
+
+  /**
+   * Checks if is testing.
+   *
+   * @return true if this is testing data.
+   */
+  public boolean isTesting() {
+    return this.instanceContent.isTesting();
+  }
+
+  /**
+   * Set testing flag.
+   *
+   * @param testing
+   *          flag.
+   */
+  public void setTesting(boolean testing) {
+    this.instanceContent.setTesting(testing);
+  }
+
+  /**
+   * Gets the classifier index.
+   *
+   * @return the classifier index
+   */
+  public int getClassifierIndex() {
+    return this.instanceContent.getClassifierIndex();
+  }
+
+  /**
+   * Sets the classifier index.
+   *
+   * @param classifierIndex
+   *          the new classifier index
+   */
+  public void setClassifierIndex(int classifierIndex) {
+    this.instanceContent.setClassifierIndex(classifierIndex);
+  }
+
+  /**
+   * Gets the evaluation index.
+   *
+   * @return the evaluation index
+   */
+  public int getEvaluationIndex() {
+    return this.instanceContent.getEvaluationIndex();
+  }
+
+  /**
+   * Sets the evaluation index.
+   *
+   * @param evaluationIndex
+   *          the new evaluation index
+   */
+  public void setEvaluationIndex(int evaluationIndex) {
+    this.instanceContent.setEvaluationIndex(evaluationIndex);
+  }
+
+  /**
+   * Builds a routing key from the indices of this event.
+   *
+   * @param key
+   *          {@code 0} selects the evaluation index alone; any other value selects
+   *          {@code 10000 * evaluationIndex + classifierIndex}
+   * @return the selected key as a decimal string
+   */
+  public String getKey(int key) {
+    if (key == 0) {
+      return Long.toString(this.getEvaluationIndex());
+    }
+    // Multiply in long arithmetic: 10000 * evaluationIndex overflows int for large evaluation indices.
+    return Long.toString(10000L * this.getEvaluationIndex() + this.getClassifierIndex());
+  }
+
+  /**
+   * Returns the classifier index as the key of this event.
+   */
+  @Override
+  public String getKey() {
+    return Long.toString(this.getClassifierIndex());
+  }
+
+  /**
+   * Parses the given string as a {@code long} and stores it as the instance index. Note that this is not the
+   * value {@link #getKey()} reports (that one is the classifier index).
+   *
+   * @throws NumberFormatException
+   *           if {@code str} is not a parsable {@code long}
+   */
+  @Override
+  public void setKey(String str) {
+    this.instanceContent.setInstanceIndex(Long.parseLong(str));
+  }
+
+  @Override
+  public boolean isLastEvent() {
+    return this.instanceContent.isLastEvent();
+  }
+
+  /**
+   * Sets the last-event flag on the wrapped content.
+   *
+   * @param isLast
+   *          the new flag value
+   */
+  public void setLast(boolean isLast) {
+    this.instanceContent.setLast(isLast);
+  }
+
+  /**
+   * Gets the Instance Content.
+   *
+   * @return the instance content
+   */
+  public InstanceContent getInstanceContent() {
+    return instanceContent;
+  }
+}
diff --git a/heron/mlmgr/src/main/java/org/apache/heron/learners/InstancesContentEvent.java b/heron/mlmgr/src/main/java/org/apache/heron/learners/InstancesContentEvent.java
new file mode 100644
index 00000000000..0f1e9afe5f8
--- /dev/null
+++ b/heron/mlmgr/src/main/java/org/apache/heron/learners/InstancesContentEvent.java
@@ -0,0 +1,121 @@
+package org.apache.heron.learners;
+
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+import net.jcip.annotations.Immutable;
+
+import java.util.LinkedList;
+import java.util.List;
+
+import org.apache.samoa.core.ContentEvent;
+
+/**
+ * A {@link ContentEvent} that carries a batch of {@link InstanceContent} objects.
+ *
+ * NOTE(review): annotated {@code @Immutable} although {@link #add(InstanceContent)} mutates the batch and
+ * {@link #getList()} hands out the live list.
+ */
+@Immutable
+final public class InstancesContentEvent implements ContentEvent {
+
+  private static final long serialVersionUID = -8620668863064613845L;
+
+  /** The batched contents, in insertion order. */
+  protected List<InstanceContent> instanceList = new LinkedList<InstanceContent>();
+
+  /**
+   * Creates an event with an empty batch.
+   */
+  public InstancesContentEvent() {
+
+  }
+
+  /**
+   * Creates an event whose batch initially holds the content of the given single-instance event.
+   *
+   * @param event
+   *          the event whose content is added
+   */
+  public InstancesContentEvent(InstanceContentEvent event) {
+    this.add(event.getInstanceContent());
+  }
+
+  /**
+   * Appends one content object to the batch.
+   *
+   * @param instance
+   *          the content to append
+   */
+  public void add(InstanceContent instance) {
+    instanceList.add(instance);
+  }
+
+  /**
+   * Gets the batched contents as a freshly allocated array.
+   *
+   * @return the contents, in insertion order
+   */
+  public InstanceContent[] getInstances() {
+    return instanceList.toArray(new InstanceContent[instanceList.size()]);
+  }
+
+  /**
+   * Gets the classifier index of the first content in the batch.
+   *
+   * @return the classifier index
+   * @throws IndexOutOfBoundsException
+   *           if the batch is empty
+   */
+  public int getClassifierIndex() {
+    return this.instanceList.get(0).getClassifierIndex();
+  }
+
+  /**
+   * Gets the evaluation index of the first content in the batch.
+   *
+   * @return the evaluation index
+   * @throws IndexOutOfBoundsException
+   *           if the batch is empty
+   */
+  public int getEvaluationIndex() {
+    return this.instanceList.get(0).getEvaluationIndex();
+  }
+
+  /**
+   * Builds a routing key from the indices of the first content in the batch.
+   *
+   * @param key
+   *          {@code 0} selects the evaluation index alone; any other value selects
+   *          {@code 10000 * evaluationIndex + classifierIndex}
+   * @return the selected key as a decimal string
+   */
+  public String getKey(int key) {
+    if (key == 0) {
+      return Long.toString(this.getEvaluationIndex());
+    }
+    // Multiply in long arithmetic: 10000 * evaluationIndex overflows int for large evaluation indices.
+    return Long.toString(10000L * this.getEvaluationIndex() + this.getClassifierIndex());
+  }
+
+  /**
+   * Returns the classifier index of the first content as the key of this event.
+   */
+  @Override
+  public String getKey() {
+    return Long.toString(this.getClassifierIndex());
+  }
+
+  @Override
+  public void setKey(String key) {
+    // Not needed: the key is derived from the batched contents.
+  }
+
+  /**
+   * Tells whether the most recently added content is flagged as the last one.
+   *
+   * @return the last flag of the final element, or {@code false} when the batch is empty
+   */
+  @Override
+  public boolean isLastEvent() {
+    // An empty batch has no final element to ask; report "not last" instead of failing on get(-1).
+    if (this.instanceList.isEmpty()) {
+      return false;
+    }
+    return this.instanceList.get(this.instanceList.size() - 1).isLastEvent();
+  }
+
+  /**
+   * Gets the live list backing this event (not a copy).
+   *
+   * @return the batched contents
+   */
+  public List<InstanceContent> getList() {
+    return this.instanceList;
+  }
+}
diff --git a/heron/mlmgr/src/main/java/org/apache/heron/learners/Learner.java b/heron/mlmgr/src/main/java/org/apache/heron/learners/Learner.java
new file mode 100644
index 00000000000..8efd93d1bff
--- /dev/null
+++ b/heron/mlmgr/src/main/java/org/apache/heron/learners/Learner.java
@@ -0,0 +1,63 @@
+package org.apache.heron.learners;
+
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.
See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+import java.io.Serializable;
+import java.util.Set;
+
+import org.apache.samoa.core.Processor;
+import org.apache.samoa.instances.Instances;
+import org.apache.samoa.topology.Stream;
+import org.apache.samoa.topology.TopologyBuilder;
+
+/**
+ * A Learner instance learns a model, and it can be either a classifier or a regressor. Initializing a Learner should
+ * initialize a {@link Processor}, connect the Learner with the input stream, and initialize the result stream so that
+ * other processors can subscribe to the results of this learner.
+ */
+
+public interface Learner extends Serializable {
+
+  /**
+   * Inits the Learner object.
+   *
+   * @param topologyBuilder
+   *          the topology builder
+   * @param dataset
+   *          the dataset
+   * @param parallelism
+   *          the parallelism
+   */
+  public void init(TopologyBuilder topologyBuilder, Instances dataset, int parallelism);
+
+  /**
+   * Gets the input processing item.
+   *
+   * @return the input processing item
+   */
+  public Processor getInputProcessor();
+
+  /**
+   * Gets the result streams, typed as {@link Stream} so that callers need no unchecked casts.
+   *
+   * @return the set of result streams
+   */
+  public Set<Stream> getResultStreams();
+}
diff --git a/heron/mlmgr/src/main/java/org/apache/heron/learners/RegressionLearner.java b/heron/mlmgr/src/main/java/org/apache/heron/learners/RegressionLearner.java
new file mode 100644
index 00000000000..dc4f7c5b6d6
--- /dev/null
+++ b/heron/mlmgr/src/main/java/org/apache/heron/learners/RegressionLearner.java
@@ -0,0 +1,25 @@
+package org.apache.heron.learners;
+
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+import org.apache.samoa.learners.Learner;
+
+/**
+ * Marker interface for regression learners; it declares no members beyond those inherited from {@code Learner}.
+ *
+ * NOTE(review): because of the explicit import above, this extends {@code org.apache.samoa.learners.Learner} rather
+ * than the {@code Learner} interface declared in this package - confirm that this is intended.
+ */
+public interface RegressionLearner extends Learner {
+
+}
diff --git a/heron/mlmgr/src/main/java/org/apache/heron/learners/ResultContentEvent.java b/heron/mlmgr/src/main/java/org/apache/heron/learners/ResultContentEvent.java
new file mode 100644
index 00000000000..e44642a611d
--- /dev/null
+++ b/heron/mlmgr/src/main/java/org/apache/heron/learners/ResultContentEvent.java
@@ -0,0 +1,212 @@
+package org.apache.heron.learners;
+
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+import org.apache.samoa.core.ContentEvent;
+import org.apache.samoa.core.SerializableInstance;
+import org.apache.samoa.instances.Instance;
+
+/**
+ * License
+ */
+
+/**
+ * The Class ResultEvent.
+ */
+final public class ResultContentEvent implements ContentEvent {
+
+  private static final long serialVersionUID = -2650420235386873306L;
+  private long instanceIndex;
+  private int classifierIndex;
+  private int evaluationIndex;
+  private SerializableInstance instance;
+
+  private int classId;
+  private double[] classVotes;
+
+  private final boolean isLast;
+
+  /**
+   * Creates an empty result event that is not flagged as the last one.
+   */
+  public ResultContentEvent() {
+    this(false);
+  }
+
+  /**
+   * Creates an empty result event with the given last flag.
+   *
+   * @param isLast
+   *          whether this is the last event
+   */
+  public ResultContentEvent(boolean isLast) {
+    this.isLast = isLast;
+  }
+
+  /**
+   * Creates a fully populated result event.
+   *
+   * @param instanceIndex
+   *          the instance index
+   * @param instance
+   *          the instance, copied into a {@link SerializableInstance}; when {@code null} no instance is stored
+   * @param classId
+   *          the class id
+   * @param classVotes
+   *          the class votes (the array is stored as-is, not copied)
+   * @param isLast
+   *          whether this is the last event
+   */
+  public ResultContentEvent(long instanceIndex, Instance instance, int classId,
+      double[] classVotes, boolean isLast) {
+    this.instanceIndex = instanceIndex;
+    this.classId = classId;
+    this.classVotes = classVotes;
+    this.isLast = isLast;
+    if (instance != null) {
+      this.instance = new SerializableInstance(instance);
+    }
+  }
+
+  /**
+   * Returns the instance this result refers to.
+   *
+   * @return the stored instance, possibly {@code null}
+   */
+  public SerializableInstance getInstance() {
+    return this.instance;
+  }
+
+  /**
+   * Replaces the stored instance.
+   *
+   * @param instance
+   *          the new instance
+   */
+  public void setInstance(SerializableInstance instance) {
+    this.instance = instance;
+  }
+
+  /**
+   * Returns the number of classes reported by the stored instance. Marked "To remove" by the original author.
+   *
+   * @return the num classes
+   * @throws NullPointerException
+   *           if no instance is stored
+   */
+  public int getNumClasses() {
+    return this.instance.numClasses();
+  }
+
+  /**
+   * Returns the instance index.
+   *
+   * @return the index of the data vector.
+   */
+  public long getInstanceIndex() {
+    return this.instanceIndex;
+  }
+
+  /**
+   * Returns the class id that was supplied at construction. Marked "To remove" by the original author.
+   *
+   * @return the class id
+   */
+  public int getClassId() {
+    return this.classId;
+  }
+
+  /**
+   * Returns the class votes array held by this event (not a copy).
+   *
+   * @return the class votes
+   */
+  public double[] getClassVotes() {
+    return this.classVotes;
+  }
+
+  /**
+   * Replaces the class votes.
+   *
+   * @param classVotes
+   *          the new class votes
+   */
+  public void setClassVotes(double[] classVotes) {
+    this.classVotes = classVotes;
+  }
+
+  /**
+   * Returns the classifier index.
+   *
+   * @return the classifier index
+   */
+  public int getClassifierIndex() {
+    return this.classifierIndex;
+  }
+
+  /**
+   * Updates the classifier index.
+   *
+   * @param classifierIndex
+   *          the new classifier index
+   */
+  public void setClassifierIndex(int classifierIndex) {
+    this.classifierIndex = classifierIndex;
+  }
+
+  /**
+   * Returns the evaluation index.
+   *
+   * @return the evaluation index
+   */
+  public int getEvaluationIndex() {
+    return this.evaluationIndex;
+  }
+
+  /**
+   * Updates the evaluation index.
+   *
+   * @param evaluationIndex
+   *          the new evaluation index
+   */
+  public void setEvaluationIndex(int evaluationIndex) {
+    this.evaluationIndex = evaluationIndex;
+  }
+
+  /**
+   * Builds a routing key from the indices of this event.
+   *
+   * @param key
+   *          {@code 0} selects the evaluation index alone; any other value selects
+   *          {@code evaluationIndex + 1000 * instanceIndex}
+   * @return the selected key as a decimal string
+   */
+  public String getKey(int key) {
+    long composite = (key == 0)
+        ? this.getEvaluationIndex()
+        : this.getEvaluationIndex() + 1000 * this.getInstanceIndex();
+    return Long.toString(composite);
+  }
+
+  /**
+   * Returns the evaluation index modulo 100 as the key of this event.
+   */
+  @Override
+  public String getKey() {
+    int bucket = this.getEvaluationIndex() % 100;
+    return Long.toString(bucket);
+  }
+
+  /**
+   * Parses the given string as an {@code int} and stores it as the evaluation index.
+   */
+  @Override
+  public void setKey(String str) {
+    this.evaluationIndex = Integer.parseInt(str);
+  }
+
+  @Override
+  public boolean isLastEvent() {
+    return this.isLast;
+  }
+
+}
diff --git a/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/LocalLearner.java b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/LocalLearner.java
new file mode 100644
index 00000000000..d42bd6f5e22
--- /dev/null
+++ b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/LocalLearner.java
@@ -0,0 +1,76 @@
+package org.apache.heron.learners.classifiers;
+
+/**
+ * Licensed to the Apache Software Foundation
(ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+
+import java.io.Serializable;
+import java.util.Map;
+
+import org.apache.samoa.instances.Instance;
+import org.apache.samoa.instances.Instances;
+
+/**
+ * Contract for learners that run locally, i.e. are not distributed.
+ *
+ * @author abifet
+ */
+public interface LocalLearner extends Serializable {
+
+  /**
+   * Produces a new learner object.
+   *
+   * @return the learner
+   */
+  LocalLearner create();
+
+  /**
+   * Estimates the class memberships of the given instance. Implementations must return an array whose elements are
+   * all zero when the instance is unclassified.
+   *
+   * @param inst
+   *          the instance to be classified
+   * @return an array containing the estimated membership probabilities of the test instance in each class
+   */
+  double[] getVotesForInstance(Instance inst);
+
+  /**
+   * Discards what has been learned so far; afterwards the classifier must behave like one started from scratch.
+   */
+  void resetLearning();
+
+  /**
+   * Incrementally trains this classifier with one more instance.
+   *
+   * @param inst
+   *          the instance to be used for training
+   */
+  void trainOnInstance(Instance inst);
+
+  /**
+   * Tells the learner where to obtain the attribute information of the instances. Deprecated; no replacement is
+   * named in this file.
+   *
+   * @param dataset
+   *          the dataset that contains the information
+   */
+  @Deprecated
+  void setDataset(Instances dataset);
+
+}
diff --git a/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/LocalLearnerProcessor.java b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/LocalLearnerProcessor.java
new file mode 100644
index 00000000000..e82713bb9d1
--- /dev/null
+++ b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/LocalLearnerProcessor.java
@@ -0,0 +1,223 @@
+package org.apache.heron.learners.classifiers;
+
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+
+/**
+ * License
+ */
+
+import org.apache.samoa.core.ContentEvent;
+import org.apache.samoa.core.Processor;
+import org.apache.samoa.instances.Instance;
+import org.apache.samoa.learners.InstanceContentEvent;
+import org.apache.samoa.learners.ResultContentEvent;
+import org.apache.samoa.moa.classifiers.core.driftdetection.ChangeDetector;
+import org.apache.samoa.topology.Stream;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import static org.apache.samoa.moa.core.Utils.maxIndex;
+
+/**
+ * A {@link Processor} that drives a single {@link LocalLearner}: incoming instance events are used to test and/or
+ * train the model, and prediction results are put on the output stream.
+ *
+ * NOTE(review): the event types are imported from {@code org.apache.samoa.learners}, not from the
+ * {@code org.apache.heron.learners} copies of the same classes - confirm which ones are intended.
+ */
+final public class LocalLearnerProcessor implements Processor {
+
+  /**
+   * Serialization version identifier.
+   */
+  private static final long serialVersionUID = -1577910988699148691L;
+
+  private static final Logger logger = LoggerFactory.getLogger(LocalLearnerProcessor.class);
+
+  // The wrapped learner that is tested and trained.
+  private LocalLearner model;
+  // Stream on which result events are emitted.
+  private Stream outputStream;
+  // Identifier stamped on outgoing results as their classifier index.
+  private int modelId;
+  // Number of instances the model has been trained on.
+  private long instancesCount = 0;
+
+  /**
+   * Sets the learner.
+   *
+   * @param model
+   *          the model to set
+   */
+  public void setLearner(LocalLearner model) {
+    this.model = model;
+  }
+
+  /**
+   * Gets the learner.
+   *
+   * @return the model
+   */
+  public LocalLearner getLearner() {
+    return model;
+  }
+
+  /**
+   * Sets the output stream on which result events are emitted.
+   *
+   * @param outputStream
+   *          the new output stream
+   */
+  public void setOutputStream(Stream outputStream) {
+    this.outputStream = outputStream;
+  }
+
+  /**
+   * Gets the output stream.
+   *
+   * @return the output stream
+   */
+  public Stream getOutputStream() {
+    return outputStream;
+  }
+
+  /**
+   * Gets the instances count.
+   *
+   * @return the number of instances this processor has trained its model on
+   */
+  public long getInstancesCount() {
+    return instancesCount;
+  }
+
+  /**
+   * Update stats.
+ * + * @param event + * the event + */ + private void updateStats(InstanceContentEvent event) { + Instance inst = event.getInstance(); + this.model.trainOnInstance(inst); + this.instancesCount++; + if (this.changeDetector != null) { + boolean correctlyClassifies = this.correctlyClassifies(inst); + double oldEstimation = this.changeDetector.getEstimation(); + this.changeDetector.input(correctlyClassifies ? 0 : 1); + if (this.changeDetector.getChange() && this.changeDetector.getEstimation() > oldEstimation) { + // Start a new classifier + this.model.resetLearning(); + this.changeDetector.resetLearning(); + } + } + } + + /** + * Gets whether this classifier correctly classifies an instance. Uses getVotesForInstance to obtain the prediction + * and the instance to obtain its true class. + * + * + * @param inst + * the instance to be classified + * @return true if the instance is correctly classified + */ + private boolean correctlyClassifies(Instance inst) { + return maxIndex(model.getVotesForInstance(inst)) == (int) inst.classValue(); + } + + /** The test. */ + protected int test; // to delete + + /** + * On event. 
+ * + * @param event + * the event + * @return true, if successful + */ + @Override + public boolean process(ContentEvent event) { + + InstanceContentEvent inEvent = (InstanceContentEvent) event; + Instance instance = inEvent.getInstance(); + + if (inEvent.getInstanceIndex() < 0) { + // end learning + ResultContentEvent outContentEvent = new ResultContentEvent(-1, instance, 0, + new double[0], inEvent.isLastEvent()); + outContentEvent.setClassifierIndex(this.modelId); + outContentEvent.setEvaluationIndex(inEvent.getEvaluationIndex()); + outputStream.put(outContentEvent); + return false; + } + + if (inEvent.isTesting()) { + double[] dist = model.getVotesForInstance(instance); + ResultContentEvent outContentEvent = new ResultContentEvent(inEvent.getInstanceIndex(), + instance, inEvent.getClassId(), dist, inEvent.isLastEvent()); + outContentEvent.setClassifierIndex(this.modelId); + outContentEvent.setEvaluationIndex(inEvent.getEvaluationIndex()); + logger.trace(inEvent.getInstanceIndex() + " {} {}", modelId, dist); + outputStream.put(outContentEvent); + } + + if (inEvent.isTraining()) { + updateStats(inEvent); + } + return false; + } + + /* + * (non-Javadoc) + * + * @see samoa.core.Processor#onCreate(int) + */ + @Override + public void onCreate(int id) { + this.modelId = id; + model = model.create(); + } + + /* + * (non-Javadoc) + * + * @see samoa.core.Processor#newProcessor(samoa.core.Processor) + */ + @Override + public Processor newProcessor(Processor sourceProcessor) { + LocalLearnerProcessor newProcessor = new LocalLearnerProcessor(); + LocalLearnerProcessor originProcessor = (LocalLearnerProcessor) sourceProcessor; + + if (originProcessor.getLearner() != null) { + newProcessor.setLearner(originProcessor.getLearner().create()); + } + + if (originProcessor.getChangeDetector() != null) { + newProcessor.setChangeDetector(originProcessor.getChangeDetector()); + } + + newProcessor.setOutputStream(originProcessor.getOutputStream()); + return newProcessor; + } + + 
protected ChangeDetector changeDetector; + + public ChangeDetector getChangeDetector() { + return this.changeDetector; + } + + public void setChangeDetector(ChangeDetector cd) { + this.changeDetector = cd; + } + +} diff --git a/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/NaiveBayes.java b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/NaiveBayes.java new file mode 100644 index 00000000000..92655ec787f --- /dev/null +++ b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/NaiveBayes.java @@ -0,0 +1,263 @@ +package org.apache.heron.learners.classifiers; + +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import java.util.HashMap; +import java.util.Map; + +import org.apache.samoa.instances.Instance; +import org.apache.samoa.instances.Instances; +import org.apache.samoa.moa.classifiers.core.attributeclassobservers.GaussianNumericAttributeClassObserver; +import org.apache.samoa.moa.core.GaussianEstimator; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Implementation of a non-distributed Naive Bayes classifier. + * + * At the moment, the implementation models all attributes as numeric attributes. 
+ * + * @author Olivier Van Laere (vanlaere yahoo-inc dot com) + */ +public class NaiveBayes implements LocalLearner { + + /** + * Default smoothing factor. For now fixed to 1E-20. + */ + private static final double ADDITIVE_SMOOTHING_FACTOR = 1e-20; + + /** + * serialVersionUID for serialization + */ + private static final long serialVersionUID = 1325775209672996822L; + + /** + * Instance of a logger for use in this class. + */ + private static final Logger logger = LoggerFactory.getLogger(NaiveBayes.class); + + /** + * The actual model. + */ + protected Map attributeObservers; + + /** + * Class statistics + */ + protected Map classInstances; + + /** + * Class zero-prototypes. + */ + protected Map classPrototypes; + + /** + * Retrieve the number of classes currently known to this local model + * + * @return the number of classes currently known to this local model + */ + protected int getNumberOfClasses() { + return this.classInstances.size(); + } + + /** + * Track training instances seen. + */ + protected long instancesSeen = 0L; + + /** + * Explicit no-arg constructor. + */ + public NaiveBayes() { + // Init the model + resetLearning(); + } + + /** + * Create an instance of this LocalLearner implementation. + */ + @Override + public LocalLearner create() { + return new NaiveBayes(); + } + + /** + * Predicts the class memberships for a given instance. If an instance is unclassified, the returned array elements + * will be all zero. + * + * Smoothing is being implemented by the AttributeClassObserver classes. At the moment, the + * GaussianNumericProbabilityAttributeClassObserver needs no smoothing as it processes continuous variables. + * + * Please note that we transform the scores to log space to avoid underflow, and we replace the multiplication with + * addition. + * + * The resulting scores are no longer probabilities, as a mixture of probability densities and probabilities can be + * used in the computation. 
+ * + * @param inst + * the instance to be classified + * @return an array containing the estimated membership scores of the test instance in each class, in log space. + */ + @Override + public double[] getVotesForInstance(Instance inst) { + // Prepare the results array + double[] votes = new double[getNumberOfClasses()]; + // Over all classes + for (int classIndex = 0; classIndex < votes.length; classIndex++) { + // Get the prior for this class + votes[classIndex] = Math.log(getPrior(classIndex)); + // Iterate over the instance attributes + for (int index = 0; index < inst.numAttributes(); index++) { + int attributeID = inst.index(index); + // Skip class attribute + if (attributeID == inst.classIndex()) + continue; + Double value = inst.value(attributeID); + // Get the observer for the given attribute + GaussianNumericAttributeClassObserver obs = attributeObservers.get(attributeID); + // Init the estimator to null by default + GaussianEstimator estimator = null; + if (obs != null && obs.getEstimator(classIndex) != null) { + // Get the estimator + estimator = obs.getEstimator(classIndex); + } + double valueNonZero; + // The null case should be handled by smoothing! + if (estimator != null) { + // Get the score for a NON-ZERO attribute value + valueNonZero = estimator.probabilityDensity(value); + } + // We don't have an estimator + else { + // Assign a very small probability that we do see this value + valueNonZero = ADDITIVE_SMOOTHING_FACTOR; + } + votes[classIndex] += Math.log(valueNonZero); // - Math.log(valueZero); + } + // Check for null in the case of prequential evaluation + if (this.classPrototypes.get(classIndex) != null) { + // Add the prototype for the class, already in log space + votes[classIndex] += Math.log(this.classPrototypes.get(classIndex)); + } + } + return votes; + } + + /** + * Compute the prior for the given classIndex. + * + * Implemented by maximum likelihood at the moment. 
+ * + * @param classIndex + * Id of the class for which we want to compute the prior. + * @return Prior probability for the requested class + */ + private double getPrior(int classIndex) { + // Maximum likelihood + Double currentCount = this.classInstances.get(classIndex); + if (currentCount == null || currentCount == 0) + return 0; + else + return currentCount * 1. / this.instancesSeen; + } + + /** + * Resets this classifier. It must be similar to starting a new classifier from scratch. + */ + @Override + public void resetLearning() { + // Reset priors + this.instancesSeen = 0L; + this.classInstances = new HashMap<>(); + this.classPrototypes = new HashMap<>(); + // Init the attribute observers + this.attributeObservers = new HashMap<>(); + } + + /** + * Trains this classifier incrementally using the given instance. + * + * @param inst + * the instance to be used for training + */ + @Override + public void trainOnInstance(Instance inst) { + // Update class statistics with weights + int classIndex = (int) inst.classValue(); + Double weight = this.classInstances.get(classIndex); + if (weight == null) + weight = 0.; + this.classInstances.put(classIndex, weight + inst.weight()); + + // Get the class prototype + Double classPrototype = this.classPrototypes.get(classIndex); + if (classPrototype == null) + classPrototype = 1.; + + // Iterate over the attributes of the given instance + for (int attributePosition = 0; attributePosition < inst + .numAttributes(); attributePosition++) { + // Get the attribute index - Dense -> 1:1, Sparse is remapped + int attributeID = inst.index(attributePosition); + // Skip class attribute + if (attributeID == inst.classIndex()) + continue; + // Get the attribute observer for the current attribute + GaussianNumericAttributeClassObserver obs = this.attributeObservers + .get(attributeID); + // Lazy init of observers, if null, instantiate a new one + if (obs == null) { + // FIXME: At this point, we model everything as a numeric + // attribute 
+ obs = new GaussianNumericAttributeClassObserver(); + this.attributeObservers.put(attributeID, obs); + } + + // Get the probability density function under the current model + GaussianEstimator obs_estimator = obs.getEstimator(classIndex); + if (obs_estimator != null) { + // Fetch the probability that the feature value is zero + double probDens_zero_current = obs_estimator.probabilityDensity(0); + classPrototype -= probDens_zero_current; + } + + // FIXME: Sanity check on data values, for now just learn + // Learn attribute value for given class + obs.observeAttributeClass(inst.valueSparse(attributePosition), + (int) inst.classValue(), inst.weight()); + + // Update obs_estimator to fetch the pdf from the updated model + obs_estimator = obs.getEstimator(classIndex); + // Fetch the probability that the feature value is zero + double probDens_zero_updated = obs_estimator.probabilityDensity(0); + // Update the class prototype + classPrototype += probDens_zero_updated; + } + // Store the class prototype + this.classPrototypes.put(classIndex, classPrototype); + // Count another training instance + this.instancesSeen++; + } + + @Override + public void setDataset(Instances dataset) { + // Do nothing + } +} diff --git a/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/SimpleClassifierAdapter.java b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/SimpleClassifierAdapter.java new file mode 100644 index 00000000000..8b850d7dda1 --- /dev/null +++ b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/SimpleClassifierAdapter.java @@ -0,0 +1,153 @@ +package org.apache.heron.learners.classifiers; + +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/** + * License + */ +import org.apache.samoa.instances.Instance; +import org.apache.samoa.instances.Instances; +import org.apache.samoa.instances.InstancesHeader; +import org.apache.samoa.moa.classifiers.functions.MajorityClass; + +import com.github.javacliparser.ClassOption; +import com.github.javacliparser.Configurable; + +/** + * + * Base class for adapting external classifiers. + * + */ +public class SimpleClassifierAdapter implements LocalLearner, Configurable { + + /** + * + */ + private static final long serialVersionUID = 4372366401338704353L; + + public ClassOption learnerOption = new ClassOption("learner", 'l', + "Classifier to train.", org.apache.samoa.moa.classifiers.Classifier.class, MajorityClass.class.getName()); + /** + * The learner. + */ + protected org.apache.samoa.moa.classifiers.Classifier learner; + + /** + * The is init. + */ + protected Boolean isInit; + + /** + * The dataset. + */ + protected Instances dataset; + + @Override + public void setDataset(Instances dataset) { + this.dataset = dataset; + } + + /** + * Instantiates a new learner. + * + * @param learner + * the learner + * @param dataset + * the dataset + */ + public SimpleClassifierAdapter(org.apache.samoa.moa.classifiers.Classifier learner, Instances dataset) { + this.learner = learner.copy(); + this.isInit = false; + this.dataset = dataset; + } + + /** + * Instantiates a new learner. 
+ * + */ + public SimpleClassifierAdapter() { + this.learner = ((org.apache.samoa.moa.classifiers.Classifier) this.learnerOption.getValue()).copy(); + this.isInit = false; + } + + /** + * Creates a new learner object. + * + * @return the learner + */ + @Override + public SimpleClassifierAdapter create() { + SimpleClassifierAdapter l = new SimpleClassifierAdapter(learner, dataset); + if (dataset == null) { + System.out.println("dataset null while creating"); + } + return l; + } + + /** + * Trains this classifier incrementally using the given instance. + * + * @param inst + * the instance to be used for training + */ + @Override + public void trainOnInstance(Instance inst) { + if (!this.isInit) { + this.isInit = true; + InstancesHeader instances = new InstancesHeader(dataset); + this.learner.setModelContext(instances); + this.learner.prepareForUse(); + } + if (inst.weight() > 0) { + inst.setDataset(dataset); + learner.trainOnInstance(inst); + } + } + + /** + * Predicts the class memberships for a given instance. If an instance is unclassified, the returned array elements + * must be all zero. + * + * @param inst + * the instance to be classified + * @return an array containing the estimated membership probabilities of the test instance in each class + */ + @Override + public double[] getVotesForInstance(Instance inst) { + double[] ret; + inst.setDataset(dataset); + if (!this.isInit) { + ret = new double[dataset.numClasses()]; + } else { + ret = learner.getVotesForInstance(inst); + } + return ret; + } + + /** + * Resets this classifier. It must be similar to starting a new classifier from scratch. 
+ * + */ + @Override + public void resetLearning() { + learner.resetLearning(); + } + +} diff --git a/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/SingleClassifier.java b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/SingleClassifier.java new file mode 100644 index 00000000000..eda5560d3e0 --- /dev/null +++ b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/SingleClassifier.java @@ -0,0 +1,113 @@ +package org.apache.heron.learners.classifiers; + +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + + +/** + * License + */ + +import com.google.common.collect.ImmutableSet; + +import java.util.Set; + +import org.apache.samoa.core.Processor; +import org.apache.samoa.instances.Instances; +import org.apache.samoa.learners.AdaptiveLearner; +import org.apache.samoa.learners.ClassificationLearner; +import org.apache.samoa.learners.Learner; +import org.apache.samoa.moa.classifiers.core.driftdetection.ChangeDetector; +import org.apache.samoa.topology.Stream; +import org.apache.samoa.topology.TopologyBuilder; + +import com.github.javacliparser.ClassOption; +import com.github.javacliparser.Configurable; + +/** + * + * Classifier that contain a single classifier. 
+ * + */
public final class SingleClassifier implements ClassificationLearner, AdaptiveLearner, Configurable {

  private static final long serialVersionUID = 684111382631697031L;

  /** Processor hosting the local learner; created in {@link #setLayout()}. */
  private LocalLearnerProcessor learnerP;

  /** Stream on which the learner processor emits its results. */
  private Stream resultStream;

  private Instances dataset;

  /** Option selecting the local learner; defaults to {@link SimpleClassifierAdapter}. */
  public ClassOption learnerOption = new ClassOption("learner", 'l',
      "Classifier to train.", LocalLearner.class, SimpleClassifierAdapter.class.getName());

  private TopologyBuilder builder;

  private int parallelism;

  /** Change detector handed to the learner processor; may be {@code null}. */
  protected ChangeDetector changeDetector;

  /**
   * Stores the arguments and builds the topology fragment for this learner.
   *
   * @param builder
   *          the topology builder to add the processor and stream to
   * @param dataset
   *          the dataset passed to the local learner
   * @param parallelism
   *          the parallelism used when adding the learner processor
   */
  @Override
  public void init(TopologyBuilder builder, Instances dataset, int parallelism) {
    this.builder = builder;
    this.dataset = dataset;
    this.parallelism = parallelism;
    this.setLayout();
  }

  /**
   * Creates the learner processor, configures it with the change detector and the configured local learner,
   * adds it to the topology and wires its result stream.
   */
  protected void setLayout() {
    learnerP = new LocalLearnerProcessor();
    learnerP.setChangeDetector(this.getChangeDetector());
    LocalLearner learner = this.learnerOption.getValue();
    learner.setDataset(this.dataset);
    learnerP.setLearner(learner);

    this.builder.addProcessor(learnerP, parallelism);
    resultStream = this.builder.createStream(learnerP);

    learnerP.setOutputStream(resultStream);
  }

  /**
   * @return the learner processor, which receives the input events
   */
  @Override
  public Processor getInputProcessor() {
    return learnerP;
  }

  /**
   * @return an immutable set containing the single result stream
   */
  @Override
  public Set<Stream> getResultStreams() {
    return ImmutableSet.of(this.resultStream);
  }

  /**
   * @return the change detector, or {@code null} if none was set
   */
  @Override
  public ChangeDetector getChangeDetector() {
    return this.changeDetector;
  }

  /**
   * @param cd
   *          the change detector to pass to the learner processor when the layout is built
   */
  @Override
  public void setChangeDetector(ChangeDetector cd) {
    this.changeDetector = cd;
  }
}
diff --git a/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/ensemble/AdaptiveBagging.java b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/ensemble/AdaptiveBagging.java new file mode 100644 index 00000000000..e8643c888cf --- /dev/null +++ 
b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/ensemble/AdaptiveBagging.java @@ -0,0 +1,153 @@ +package org.apache.heron.learners.classifiers.ensemble; + +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + + +/** + * License + */ + +import com.google.common.collect.ImmutableSet; + +import java.util.Set; + +import org.apache.samoa.core.Processor; +import org.apache.samoa.instances.Instances; +import org.apache.samoa.learners.AdaptiveLearner; +import org.apache.samoa.learners.ClassificationLearner; +import org.apache.samoa.learners.Learner; +import org.apache.samoa.learners.classifiers.trees.VerticalHoeffdingTree; +import org.apache.samoa.moa.classifiers.core.driftdetection.ADWINChangeDetector; +import org.apache.samoa.moa.classifiers.core.driftdetection.ChangeDetector; +import org.apache.samoa.topology.Stream; +import org.apache.samoa.topology.TopologyBuilder; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.github.javacliparser.ClassOption; +import com.github.javacliparser.Configurable; +import com.github.javacliparser.IntOption; + +/** + * An adaptive version of the Bagging Classifier by Oza and Russell. 
+ */ +public class AdaptiveBagging implements ClassificationLearner, Configurable { + + private static final long serialVersionUID = 8217274236558839040L; + private static final Logger logger = LoggerFactory.getLogger(AdaptiveBagging.class); + + /** The base learner option. */ + public ClassOption baseLearnerOption = new ClassOption("baseLearner", 'l', + "Classifier to train.", AdaptiveLearner.class, VerticalHoeffdingTree.class.getName()); + + /** The ensemble size option. */ + public IntOption ensembleSizeOption = new IntOption("ensembleSize", 's', + "The number of models in the bag.", 10, 1, Integer.MAX_VALUE); + + public ClassOption driftDetectionMethodOption = new ClassOption("driftDetectionMethod", 'd', + "Drift detection method to use.", ChangeDetector.class, ADWINChangeDetector.class.getName()); + + /** The distributor processor. */ + private BaggingDistributorProcessor distributorP; + + /** The input streams for the ensemble, one per member. */ + private Stream[] ensembleStreams; + + /** The result stream. */ + protected Stream resultStream; + + /** The dataset. */ + private Instances dataset; + + protected AdaptiveLearner[] ensemble; + + /** + * Sets the layout. + */ + protected void setLayout() { + int ensembleSize = this.ensembleSizeOption.getValue(); + + distributorP = new BaggingDistributorProcessor(); + distributorP.setEnsembleSize(ensembleSize); + builder.addProcessor(distributorP, 1); + + // instantiate classifier + ensemble = new AdaptiveLearner[ensembleSize]; + for (int i = 0; i < ensembleSize; i++) { + try { + ensemble[i] = (AdaptiveLearner) ClassOption.createObject(baseLearnerOption.getValueAsCLIString(), + baseLearnerOption.getRequiredType()); + } catch (Exception e) { + logger.error("Unable to create members of the ensemble. 
Please check your CLI parameters"); + e.printStackTrace(); + throw new IllegalArgumentException(e); + } + ensemble[i].setChangeDetector((ChangeDetector) this.driftDetectionMethodOption.getValue()); + ensemble[i].init(builder, this.dataset, 1); // sequential + } + + PredictionCombinerProcessor predictionCombinerP = new PredictionCombinerProcessor(); + predictionCombinerP.setEnsembleSize(ensembleSize); + this.builder.addProcessor(predictionCombinerP, 1); + + // Streams + resultStream = builder.createStream(predictionCombinerP); + predictionCombinerP.setOutputStream(resultStream); + + for (AdaptiveLearner member : ensemble) { + for (Stream subResultStream : member.getResultStreams()) { // a learner can have multiple output streams + this.builder.connectInputKeyStream(subResultStream, predictionCombinerP); // the key is the instance id to combine predictions + } + } + + ensembleStreams = new Stream[ensembleSize]; + for (int i = 0; i < ensembleSize; i++) { + ensembleStreams[i] = builder.createStream(distributorP); + builder.connectInputShuffleStream(ensembleStreams[i], ensemble[i].getInputProcessor()); // connect streams one-to-one with ensemble members (the type of connection does not matter) + } + + distributorP.setOutputStreams(ensembleStreams); + } + + /** The builder. 
*/ + private TopologyBuilder builder; + + @Override + public void init(TopologyBuilder builder, Instances dataset, int parallelism) { + this.builder = builder; + this.dataset = dataset; + this.setLayout(); + } + + @Override + public Processor getInputProcessor() { + return distributorP; + } + + /* + * (non-Javadoc) + * + * @see samoa.learners.Learner#getResultStreams() + */ + @Override + public Set getResultStreams() { + return ImmutableSet.of(this.resultStream); + } +} diff --git a/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/ensemble/Bagging.java b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/ensemble/Bagging.java new file mode 100644 index 00000000000..f3751ca9806 --- /dev/null +++ b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/ensemble/Bagging.java @@ -0,0 +1,148 @@ +package org.apache.heron.learners.classifiers.ensemble; + +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + + +/** + * License + */ + +import java.util.Set; + +import org.apache.samoa.core.Processor; +import org.apache.samoa.instances.Instances; +import org.apache.samoa.learners.ClassificationLearner; +import org.apache.samoa.learners.Learner; +import org.apache.samoa.learners.classifiers.trees.VerticalHoeffdingTree; +import org.apache.samoa.topology.Stream; +import org.apache.samoa.topology.TopologyBuilder; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.github.javacliparser.ClassOption; +import com.github.javacliparser.Configurable; +import com.github.javacliparser.IntOption; +import com.google.common.collect.ImmutableSet; +
/**
 * The Bagging Classifier by Oza and Russell.
 *
 * A distributor processor fans each instance out to the ensemble members and a prediction combiner merges
 * their result streams into a single result stream.
 */
public class Bagging implements ClassificationLearner, Configurable {

  /** The Constant serialVersionUID. */
  private static final long serialVersionUID = -2971850264864952099L;
  private static final Logger logger = LoggerFactory.getLogger(Bagging.class);

  /** The base learner option. */
  public ClassOption baseLearnerOption = new ClassOption("baseLearner", 'l',
      "Classifier to train.", Learner.class, VerticalHoeffdingTree.class.getName());

  /** The ensemble size option. */
  public IntOption ensembleSizeOption = new IntOption("ensembleSize", 's',
      "The number of models in the bag.", 10, 1, Integer.MAX_VALUE);

  /** The distributor processor. */
  private BaggingDistributorProcessor distributorP;

  /** The input streams for the ensemble, one per member. */
  private Stream[] ensembleStreams;

  /** The result stream. */
  protected Stream resultStream;

  /** The dataset. */
  private Instances dataset;

  /** The ensemble members, created in {@link #setLayout()}. */
  protected Learner[] ensemble;

  /** The builder. */
  private TopologyBuilder builder;

  /**
   * Builds the topology: one distributor, {@code ensembleSize} learners, and one prediction combiner whose
   * output is the result stream.
   *
   * @throws IllegalArgumentException
   *           if an ensemble member cannot be created from the base learner option
   */
  protected void setLayout() {
    int ensembleSize = this.ensembleSizeOption.getValue();

    distributorP = new BaggingDistributorProcessor();
    distributorP.setEnsembleSize(ensembleSize);
    builder.addProcessor(distributorP, 1);

    // instantiate classifier
    ensemble = new Learner[ensembleSize];
    for (int i = 0; i < ensembleSize; i++) {
      try {
        ensemble[i] = (Learner) ClassOption.createObject(baseLearnerOption.getValueAsCLIString(),
            baseLearnerOption.getRequiredType());
      } catch (Exception e) {
        // Log the cause through the logger instead of printing the stack trace to stderr.
        logger.error("Unable to create members of the ensemble. Please check your CLI parameters", e);
        throw new IllegalArgumentException(e);
      }
      ensemble[i].init(builder, this.dataset, 1); // sequential
    }

    PredictionCombinerProcessor predictionCombinerP = new PredictionCombinerProcessor();
    predictionCombinerP.setEnsembleSize(ensembleSize);
    this.builder.addProcessor(predictionCombinerP, 1);

    // Streams
    resultStream = builder.createStream(predictionCombinerP);
    predictionCombinerP.setOutputStream(resultStream);

    for (Learner member : ensemble) {
      // a learner can have multiple output streams
      for (Stream subResultStream : member.getResultStreams()) {
        // the key is the instance id to combine predictions
        this.builder.connectInputKeyStream(subResultStream, predictionCombinerP);
      }
    }

    ensembleStreams = new Stream[ensembleSize];
    for (int i = 0; i < ensembleSize; i++) {
      ensembleStreams[i] = builder.createStream(distributorP);
      // connect streams one-to-one with ensemble members (the type of connection does not matter)
      builder.connectInputShuffleStream(ensembleStreams[i], ensemble[i].getInputProcessor());
    }

    distributorP.setOutputStreams(ensembleStreams);
  }

  /**
   * Stores the builder and dataset and builds the layout.
   *
   * @param builder
   *          the topology builder
   * @param dataset
   *          the dataset handed to every ensemble member
   * @param parallelism
   *          not used by this learner
   */
  @Override
  public void init(TopologyBuilder builder, Instances dataset, int parallelism) {
    this.builder = builder;
    this.dataset = dataset;
    this.setLayout();
  }

  /**
   * @return the distributor processor, which receives the input events
   */
  @Override
  public Processor getInputProcessor() {
    return distributorP;
  }

  /**
   * @return an immutable set containing the combined result stream
   */
  @Override
  public Set<Stream> getResultStreams() {
    return ImmutableSet.of(this.resultStream);
  }
}
diff --git a/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/ensemble/BaggingDistributorProcessor.java b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/ensemble/BaggingDistributorProcessor.java new file mode 100644 index 00000000000..f5b36190984 --- /dev/null +++ b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/ensemble/BaggingDistributorProcessor.java @@ -0,0 +1,152 @@ +package org.apache.heron.learners.classifiers.ensemble; + +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + + +/** + * License + */ + +import java.util.Arrays; +import java.util.Random; + +import org.apache.samoa.core.ContentEvent; +import org.apache.samoa.core.Processor; +import org.apache.samoa.instances.Instance; +import org.apache.samoa.learners.InstanceContentEvent; +import org.apache.samoa.moa.core.MiscUtils; +import org.apache.samoa.topology.Stream; + +import com.google.common.base.Preconditions; + +/** + * The Class BaggingDistributorPE. + */ +public class BaggingDistributorProcessor implements Processor { + + private static final long serialVersionUID = -1550901409625192730L; + + /** The ensemble size. */ + private int ensembleSize; + + /** The stream ensemble. */ + private Stream[] ensembleStreams; + + /** Ramdom number generator. */ + protected Random random = new Random(); //TODO make random seed configurable + + /** + * On event. + * + * @param event + * the event + * @return true, if successful + */ + public boolean process(ContentEvent event) { + Preconditions.checkState(ensembleSize == ensembleStreams.length, String.format( + "Ensemble size ({}) and number of enseble streams ({}) do not match.", ensembleSize, ensembleStreams.length)); + InstanceContentEvent inEvent = (InstanceContentEvent) event; + + if (inEvent.getInstanceIndex() < 0) { + // end learning + for (Stream stream : ensembleStreams) + stream.put(event); + return false; + } + + if (inEvent.isTesting()) { + Instance testInstance = inEvent.getInstance(); + for (int i = 0; i < ensembleSize; i++) { + Instance instanceCopy = testInstance.copy(); + InstanceContentEvent instanceContentEvent = new InstanceContentEvent(inEvent.getInstanceIndex(), instanceCopy, + false, true); + instanceContentEvent.setClassifierIndex(i); //TODO probably not needed anymore + instanceContentEvent.setEvaluationIndex(inEvent.getEvaluationIndex()); //TODO probably not needed anymore + ensembleStreams[i].put(instanceContentEvent); + } + } + + // estimate model parameters using the training data + if 
(inEvent.isTraining()) { + train(inEvent); + } + return true; + } + + /** + * Train. + * + * @param inEvent + * the in event + */ + protected void train(InstanceContentEvent inEvent) { + Instance trainInstance = inEvent.getInstance(); + for (int i = 0; i < ensembleSize; i++) { + int k = MiscUtils.poisson(1.0, this.random); + if (k > 0) { + Instance weightedInstance = trainInstance.copy(); + weightedInstance.setWeight(trainInstance.weight() * k); + InstanceContentEvent instanceContentEvent = new InstanceContentEvent(inEvent.getInstanceIndex(), + weightedInstance, true, false); + instanceContentEvent.setClassifierIndex(i); + instanceContentEvent.setEvaluationIndex(inEvent.getEvaluationIndex()); + ensembleStreams[i].put(instanceContentEvent); + } + } + } + + @Override + public void onCreate(int id) { + // do nothing + } + + public Stream[] getOutputStreams() { + return ensembleStreams; + } + + public void setOutputStreams(Stream[] ensembleStreams) { + this.ensembleStreams = ensembleStreams; + } + + public int getEnsembleSize() { + return ensembleSize; + } + + public void setEnsembleSize(int ensembleSize) { + this.ensembleSize = ensembleSize; + } + + @Override + public Processor newProcessor(Processor sourceProcessor) { + BaggingDistributorProcessor newProcessor = new BaggingDistributorProcessor(); + BaggingDistributorProcessor originProcessor = (BaggingDistributorProcessor) sourceProcessor; + if (originProcessor.getOutputStreams() != null) { + newProcessor.setOutputStreams(Arrays.copyOf(originProcessor.getOutputStreams(), + originProcessor.getOutputStreams().length)); + } + newProcessor.setEnsembleSize(originProcessor.getEnsembleSize()); + /* + * if (originProcessor.getLearningCurve() != null){ + * newProcessor.setLearningCurve((LearningCurve) + * originProcessor.getLearningCurve().copy()); } + */ + return newProcessor; + } +} diff --git a/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/ensemble/Boosting.java 
b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/ensemble/Boosting.java new file mode 100644 index 00000000000..f765cbb4f69 --- /dev/null +++ b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/ensemble/Boosting.java @@ -0,0 +1,149 @@ +package org.apache.heron.learners.classifiers.ensemble; + +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/** + * License + */ + +import com.google.common.collect.ImmutableSet; + +import java.util.Set; + +import org.apache.samoa.core.Processor; +import org.apache.samoa.instances.Instances; +import org.apache.samoa.learners.ClassificationLearner; +import org.apache.samoa.learners.Learner; +import org.apache.samoa.learners.classifiers.SingleClassifier; +import org.apache.samoa.topology.Stream; +import org.apache.samoa.topology.TopologyBuilder; + +import com.github.javacliparser.ClassOption; +import com.github.javacliparser.Configurable; +import com.github.javacliparser.IntOption; + +/** + * The Bagging Classifier by Oza and Russell. + */ +public class Boosting implements ClassificationLearner, Configurable { + + /** The Constant serialVersionUID. 
*/ + private static final long serialVersionUID = -2971850264864952099L; + + /** The base learner option. */ + public ClassOption baseLearnerOption = new ClassOption("baseLearner", 'l', + "Classifier to train.", Learner.class, SingleClassifier.class.getName()); + + /** The ensemble size option. */ + public IntOption ensembleSizeOption = new IntOption("ensembleSize", 's', + "The number of models in the bag.", 10, 1, Integer.MAX_VALUE); + + /** The distributor processor. */ + private BoostingDistributorProcessor distributorP; + + /** The result stream. */ + protected Stream resultStream; + + /** The dataset. */ + private Instances dataset; + + protected Learner classifier; + + protected int parallelism; + + /** + * Sets the layout. + */ + protected void setLayout() { + + int sizeEnsemble = this.ensembleSizeOption.getValue(); + + distributorP = new BoostingDistributorProcessor(); + distributorP.setEnsembleSize(sizeEnsemble); + this.builder.addProcessor(distributorP, 1); + + // instantiate classifier + classifier = this.baseLearnerOption.getValue(); + classifier.init(builder, this.dataset, sizeEnsemble); + + BoostingPredictionCombinerProcessor predictionCombinerP = new BoostingPredictionCombinerProcessor(); + predictionCombinerP.setEnsembleSize(sizeEnsemble); + this.builder.addProcessor(predictionCombinerP, 1); + + // Streams + resultStream = this.builder.createStream(predictionCombinerP); + predictionCombinerP.setOutputStream(resultStream); + + for (Stream subResultStream : classifier.getResultStreams()) { + this.builder.connectInputKeyStream(subResultStream, predictionCombinerP); + } + + /* The testing stream. */ + Stream testingStream = this.builder.createStream(distributorP); + this.builder.connectInputKeyStream(testingStream, classifier.getInputProcessor()); + + /* The prediction stream. 
*/ + Stream predictionStream = this.builder.createStream(distributorP); + this.builder.connectInputKeyStream(predictionStream, classifier.getInputProcessor()); + +// distributorP.setOutputStream(testingStream); +// distributorP.setPredictionStream(predictionStream); + + // Addition to Bagging: stream to train + /* The training stream. */ + Stream trainingStream = this.builder.createStream(predictionCombinerP); + predictionCombinerP.setTrainingStream(trainingStream); + this.builder.connectInputKeyStream(trainingStream, classifier.getInputProcessor()); + + } + + /** The builder. */ + private TopologyBuilder builder; + + /* + * (non-Javadoc) + * + * @see samoa.classifiers.Classifier#init(samoa.engines.Engine, + * samoa.core.Stream, weka.core.Instances) + */ + + @Override + public void init(TopologyBuilder builder, Instances dataset, int parallelism) { + this.builder = builder; + this.dataset = dataset; + this.parallelism = parallelism; + this.setLayout(); + } + + @Override + public Processor getInputProcessor() { + return distributorP; + } + + /* + * (non-Javadoc) + * + * @see samoa.learners.Learner#getResultStreams() + */ + @Override + public Set getResultStreams() { + return ImmutableSet.of(this.resultStream); + } +} diff --git a/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/ensemble/BoostingDistributorProcessor.java b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/ensemble/BoostingDistributorProcessor.java new file mode 100644 index 00000000000..80ee8fd127d --- /dev/null +++ b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/ensemble/BoostingDistributorProcessor.java @@ -0,0 +1,34 @@ +package org.apache.heron.learners.classifiers.ensemble; + +import org.apache.heron.learners.InstanceContentEvent; + +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/** + * The Class BoostingDistributorProcessor. + */ +public class BoostingDistributorProcessor extends BaggingDistributorProcessor { + + @Override + protected void train(InstanceContentEvent inEvent) { + // Boosting is trained from the prediction combiner, not from the input + } + +} diff --git a/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/ensemble/BoostingPredictionCombinerProcessor.java b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/ensemble/BoostingPredictionCombinerProcessor.java new file mode 100644 index 00000000000..92dac338099 --- /dev/null +++ b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/ensemble/BoostingPredictionCombinerProcessor.java @@ -0,0 +1,178 @@ +package org.apache.heron.learners.classifiers.ensemble; + +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + + +/** + * License + */ +import java.util.HashMap; +import java.util.Map; +import java.util.Random; + +import org.apache.samoa.core.ContentEvent; +import org.apache.samoa.instances.Instance; +import org.apache.samoa.learners.InstanceContentEvent; +import org.apache.samoa.learners.ResultContentEvent; +import org.apache.samoa.moa.core.DoubleVector; +import org.apache.samoa.moa.core.Utils; +import org.apache.samoa.topology.Stream; + +/** + * The Class BoostingPredictionCombinerProcessor. + */ +public class BoostingPredictionCombinerProcessor extends PredictionCombinerProcessor { + + private static final long serialVersionUID = -1606045723451191232L; + + // Weigths classifier + protected double[] scms; + + // Weights instance + protected double[] swms; + + /** + * On event. 
+ * + * @param event + * the event + * @return true, if successful + */ + @Override + public boolean process(ContentEvent event) { + + ResultContentEvent inEvent = (ResultContentEvent) event; + double[] prediction = inEvent.getClassVotes(); + int instanceIndex = (int) inEvent.getInstanceIndex(); + + addStatisticsForInstanceReceived(instanceIndex, inEvent.getClassifierIndex(), prediction, 1); + // Boosting + addPredictions(instanceIndex, inEvent, prediction); + + if (inEvent.isLastEvent() || hasAllVotesArrivedInstance(instanceIndex)) { + DoubleVector combinedVote = this.mapVotesforInstanceReceived.get(instanceIndex); + if (combinedVote == null) { + combinedVote = new DoubleVector(); + } + ResultContentEvent outContentEvent = new ResultContentEvent(inEvent.getInstanceIndex(), + inEvent.getInstance(), inEvent.getClassId(), + combinedVote.getArrayCopy(), inEvent.isLastEvent()); + outContentEvent.setEvaluationIndex(inEvent.getEvaluationIndex()); + outputStream.put(outContentEvent); + clearStatisticsInstance(instanceIndex); + // Boosting + computeBoosting(inEvent, instanceIndex); + return true; + } + return false; + + } + + protected Random random; + + protected int trainingWeightSeenByModel; + + @Override + protected double getEnsembleMemberWeight(int i) { + double em = this.swms[i] / (this.scms[i] + this.swms[i]); + if ((em == 0.0) || (em > 0.5)) { + return 0.0; + } + double Bm = em / (1.0 - em); + return Math.log(1.0 / Bm); + } + + @Override + public void reset() { + this.random = new Random(); + this.trainingWeightSeenByModel = 0; + this.scms = new double[this.ensembleSize]; + this.swms = new double[this.ensembleSize]; + } + + private boolean correctlyClassifies(int i, Instance inst, int instanceIndex) { + int predictedClass = (int) mapPredictions.get(instanceIndex).getValue(i); + return predictedClass == (int) inst.classValue(); + } + + protected Map mapPredictions; + + private void addPredictions(int instanceIndex, ResultContentEvent inEvent, double[] prediction) { 
+ if (this.mapPredictions == null) { + this.mapPredictions = new HashMap<>(); + } + DoubleVector predictions = this.mapPredictions.get(instanceIndex); + if (predictions == null) { + predictions = new DoubleVector(); + } + predictions.setValue(inEvent.getClassifierIndex(), Utils.maxIndex(prediction)); + this.mapPredictions.put(instanceIndex, predictions); + } + + private void computeBoosting(ResultContentEvent inEvent, int instanceIndex) { + // Starts code for Boosting + // Send instances to train + double lambda_d = 1.0; + for (int i = 0; i < this.ensembleSize; i++) { + double k = lambda_d; + Instance inst = inEvent.getInstance(); + if (k > 0.0) { + Instance weightedInst = inst.copy(); + weightedInst.setWeight(inst.weight() * k); + // this.ensemble[i].trainOnInstance(weightedInst); + InstanceContentEvent instanceContentEvent = new InstanceContentEvent( + inEvent.getInstanceIndex(), weightedInst, true, false); + instanceContentEvent.setClassifierIndex(i); + instanceContentEvent.setEvaluationIndex(inEvent.getEvaluationIndex()); + trainingStream.put(instanceContentEvent); + } + if (this.correctlyClassifies(i, inst, instanceIndex)) { + this.scms[i] += lambda_d; + lambda_d *= this.trainingWeightSeenByModel / (2 * this.scms[i]); + } else { + this.swms[i] += lambda_d; + lambda_d *= this.trainingWeightSeenByModel / (2 * this.swms[i]); + } + } + } + + /** + * Gets the training stream. + * + * @return the training stream + */ + public Stream getTrainingStream() { + return trainingStream; + } + + /** + * Sets the training stream. + * + * @param trainingStream + * the new training stream + */ + public void setTrainingStream(Stream trainingStream) { + this.trainingStream = trainingStream; + } + + /** The training stream. 
*/ + private Stream trainingStream; + +} diff --git a/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/ensemble/PredictionCombinerProcessor.java b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/ensemble/PredictionCombinerProcessor.java new file mode 100644 index 00000000000..b68c96b0f92 --- /dev/null +++ b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/ensemble/PredictionCombinerProcessor.java @@ -0,0 +1,184 @@ +package org.apache.heron.learners.classifiers.ensemble; + +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + + +/** + * License + */ +import java.util.HashMap; +import java.util.Map; + +import org.apache.samoa.core.ContentEvent; +import org.apache.samoa.core.Processor; +import org.apache.samoa.learners.ResultContentEvent; +import org.apache.samoa.moa.core.DoubleVector; +import org.apache.samoa.topology.Stream; + +/** + * Combines predictions coming from an ensemble. Equivalent to a majority-vote classifier. + */ +public class PredictionCombinerProcessor implements Processor { + + private static final long serialVersionUID = -1606045723451191132L; + + /** + * The ensemble size. + */ + protected int ensembleSize; + + /** + * The output stream. 
+ */ + protected Stream outputStream; + + /** + * Sets the output stream. + * + * @param stream + * the new output stream + */ + public void setOutputStream(Stream stream) { + outputStream = stream; + } + + /** + * Gets the output stream. + * + * @return the output stream + */ + public Stream getOutputStream() { + return outputStream; + } + + /** + * Gets the size ensemble. + * + * @return the ensembleSize + */ + public int getEnsembleSize() { + return ensembleSize; + } + + /** + * Sets the size ensemble. + * + * @param ensembleSize + * the new size ensemble + */ + public void setEnsembleSize(int ensembleSize) { + this.ensembleSize = ensembleSize; + } + + protected Map mapCountsforInstanceReceived; + + protected Map mapVotesforInstanceReceived; + + /** + * On event. + * + * @param event + * the event + * @return true, if successful + */ + public boolean process(ContentEvent event) { + + ResultContentEvent inEvent = (ResultContentEvent) event; + double[] prediction = inEvent.getClassVotes(); + int instanceIndex = (int) inEvent.getInstanceIndex(); + + addStatisticsForInstanceReceived(instanceIndex, inEvent.getClassifierIndex(), prediction, 1); + if (hasAllVotesArrivedInstance(instanceIndex)) { + DoubleVector combinedVote = this.mapVotesforInstanceReceived.get(instanceIndex); + if (combinedVote == null) { + combinedVote = new DoubleVector(new double[inEvent.getInstance().numClasses()]); + } + ResultContentEvent outContentEvent = new ResultContentEvent(inEvent.getInstanceIndex(), inEvent.getInstance(), + inEvent.getClassId(), combinedVote.getArrayCopy(), inEvent.isLastEvent()); + outContentEvent.setEvaluationIndex(inEvent.getEvaluationIndex()); + outputStream.put(outContentEvent); + clearStatisticsInstance(instanceIndex); + return true; + } + return false; + + } + + @Override + public void onCreate(int id) { + this.reset(); + } + + public void reset() { + } + + /* + * (non-Javadoc) + * @see samoa.core.Processor#newProcessor(samoa.core.Processor) + */ + @Override + 
public Processor newProcessor(Processor sourceProcessor) { + PredictionCombinerProcessor newProcessor = new PredictionCombinerProcessor(); + PredictionCombinerProcessor originProcessor = (PredictionCombinerProcessor) sourceProcessor; + if (originProcessor.getOutputStream() != null) { + newProcessor.setOutputStream(originProcessor.getOutputStream()); + } + newProcessor.setEnsembleSize(originProcessor.getEnsembleSize()); + return newProcessor; + } + + protected void addStatisticsForInstanceReceived(int instanceIndex, int classifierIndex, double[] prediction, int add) { + if (this.mapCountsforInstanceReceived == null) { + this.mapCountsforInstanceReceived = new HashMap<>(); + this.mapVotesforInstanceReceived = new HashMap<>(); + } + DoubleVector vote = new DoubleVector(prediction); + if (vote.sumOfValues() > 0.0) { + vote.normalize(); + DoubleVector combinedVote = this.mapVotesforInstanceReceived.get(instanceIndex); + if (combinedVote == null) { + combinedVote = new DoubleVector(); + } + vote.scaleValues(getEnsembleMemberWeight(classifierIndex)); + combinedVote.addValues(vote); + + this.mapVotesforInstanceReceived.put(instanceIndex, combinedVote); + } + Integer count = this.mapCountsforInstanceReceived.get(instanceIndex); + if (count == null) { + count = 0; + } + this.mapCountsforInstanceReceived.put(instanceIndex, count + add); + } + + protected boolean hasAllVotesArrivedInstance(int instanceIndex) { + return (this.mapCountsforInstanceReceived.get(instanceIndex) == this.ensembleSize); + } + + protected void clearStatisticsInstance(int instanceIndex) { + this.mapCountsforInstanceReceived.remove(instanceIndex); + this.mapVotesforInstanceReceived.remove(instanceIndex); + } + + protected double getEnsembleMemberWeight(int i) { + return 1.0; + } + +} diff --git a/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/ensemble/Sharding.java b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/ensemble/Sharding.java new file mode 100644 index 
00000000000..27534c5915d --- /dev/null +++ b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/ensemble/Sharding.java @@ -0,0 +1,142 @@ +package org.apache.heron.learners.classifiers.ensemble; + +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + + +import java.util.Set; + +import org.apache.samoa.core.Processor; +import org.apache.samoa.instances.Instances; +import org.apache.samoa.learners.Learner; +import org.apache.samoa.learners.classifiers.trees.VerticalHoeffdingTree; +import org.apache.samoa.topology.Stream; +import org.apache.samoa.topology.TopologyBuilder; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.github.javacliparser.ClassOption; +import com.github.javacliparser.Configurable; +import com.github.javacliparser.IntOption; +import com.google.common.collect.ImmutableSet; + +/** + * Simple sharding meta-classifier. It trains an ensemble of learners by shuffling the training stream among them, so + * that each learner is completely independent from each other. 
+ */ +public class Sharding implements Learner, Configurable { + + private static final long serialVersionUID = -2971850264864952099L; + private static final Logger logger = LoggerFactory.getLogger(Sharding.class); + + /** The base learner class. */ + public ClassOption baseLearnerOption = new ClassOption("baseLearner", 'l', + "Classifier to train.", Learner.class, VerticalHoeffdingTree.class.getName()); + + /** The ensemble size option. */ + public IntOption ensembleSizeOption = new IntOption("ensembleSize", 's', + "The number of models in the bag.", 10, 1, Integer.MAX_VALUE); + + /** The distributor processor. */ + private ShardingDistributorProcessor distributor; + + /** The input streams for the ensemble, one per member. */ + private Stream[] ensembleStreams; + + /** The result stream. */ + protected Stream resultStream; + + /** The dataset. */ + private Instances dataset; + + protected Learner[] ensemble; + + /** + * Sets the layout. + */ + protected void setLayout() { + + int ensembleSize = this.ensembleSizeOption.getValue(); + + distributor = new ShardingDistributorProcessor(); + distributor.setEnsembleSize(ensembleSize); + this.builder.addProcessor(distributor, 1); + + // instantiate classifier + ensemble = new Learner[ensembleSize]; + for (int i = 0; i < ensembleSize; i++) { + try { + ensemble[i] = (Learner) ClassOption.createObject(baseLearnerOption.getValueAsCLIString(), + baseLearnerOption.getRequiredType()); + } catch (Exception e) { + logger.error("Unable to create members of the ensemble. 
Please check your CLI parameters"); + e.printStackTrace(); + throw new IllegalArgumentException(e); + } + ensemble[i].init(builder, this.dataset, 1); // sequential + } + + PredictionCombinerProcessor predictionCombiner = new PredictionCombinerProcessor(); + predictionCombiner.setEnsembleSize(ensembleSize); + this.builder.addProcessor(predictionCombiner, 1); + + // Streams + resultStream = this.builder.createStream(predictionCombiner); + predictionCombiner.setOutputStream(resultStream); + + for (Learner member : ensemble) { + for (Stream subResultStream : member.getResultStreams()) { // a learner can have multiple output streams + this.builder.connectInputKeyStream(subResultStream, predictionCombiner); // the key is the instance id to combine predictions + } + } + + ensembleStreams = new Stream[ensembleSize]; + for (int i = 0; i < ensembleSize; i++) { + ensembleStreams[i] = builder.createStream(distributor); + builder.connectInputShuffleStream(ensembleStreams[i], ensemble[i].getInputProcessor()); // connect streams one-to-one with ensemble members (the type of connection does not matter) + } + + distributor.setOutputStreams(ensembleStreams); + } + + /** The builder. 
*/ + private TopologyBuilder builder; + + @Override + public void init(TopologyBuilder builder, Instances dataset, int parallelism) { + this.builder = builder; + this.dataset = dataset; + this.setLayout(); + } + + @Override + public Processor getInputProcessor() { + return distributor; + } + + /* + * (non-Javadoc) + * + * @see samoa.learners.Learner#getResultStreams() + */ + @Override + public Set getResultStreams() { + return ImmutableSet.of(this.resultStream); + } +} diff --git a/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/ensemble/ShardingDistributorProcessor.java b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/ensemble/ShardingDistributorProcessor.java new file mode 100644 index 00000000000..4addfb9e00a --- /dev/null +++ b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/ensemble/ShardingDistributorProcessor.java @@ -0,0 +1,160 @@ +package org.apache.heron.learners.classifiers.ensemble; + +import java.util.Arrays; +import java.util.Random; + +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +/** + * License + */ + +import org.apache.samoa.core.ContentEvent; +import org.apache.samoa.core.Processor; +import org.apache.samoa.instances.Instance; +import org.apache.samoa.learners.InstanceContentEvent; +import org.apache.samoa.topology.Stream; + +/** + * The Class BaggingDistributorPE. + */ +public class ShardingDistributorProcessor implements Processor { + + private static final long serialVersionUID = -1550901409625192730L; + + /** The ensemble size. */ + private int ensembleSize; + + /** The stream ensemble. */ + private Stream[] ensembleStreams; + + /** Ramdom number generator. */ + protected Random random = new Random(); //TODO make random seed configurable + + /** + * On event. + * + * @param event + * the event + * @return true, if successful + */ + public boolean process(ContentEvent event) { + InstanceContentEvent inEvent = (InstanceContentEvent) event; + if (inEvent.isLastEvent()) { + // end learning + for (Stream stream : ensembleStreams) + stream.put(event); + return false; + } + + if (inEvent.isTesting()) { + Instance testInstance = inEvent.getInstance(); + for (int i = 0; i < ensembleSize; i++) { + Instance instanceCopy = testInstance.copy(); + InstanceContentEvent instanceContentEvent = new InstanceContentEvent(inEvent.getInstanceIndex(), instanceCopy, + false, true); + instanceContentEvent.setClassifierIndex(i); //TODO probably not needed anymore + instanceContentEvent.setEvaluationIndex(inEvent.getEvaluationIndex()); //TODO probably not needed anymore + ensembleStreams[i].put(instanceContentEvent); + } + } + + // estimate model parameters using the training data + if (inEvent.isTraining()) { + train(inEvent); + } + return false; + } + + /** + * Train. 
+ * + * @param inEvent + * the in event + */ + protected void train(InstanceContentEvent inEvent) { + Instance trainInst = inEvent.getInstance().copy(); + InstanceContentEvent instanceContentEvent = new InstanceContentEvent(inEvent.getInstanceIndex(), trainInst, + true, false); + int i = random.nextInt(ensembleSize); + instanceContentEvent.setClassifierIndex(i); + instanceContentEvent.setEvaluationIndex(inEvent.getEvaluationIndex()); + ensembleStreams[i].put(instanceContentEvent); + } + + /* + * (non-Javadoc) + * + * @see org.apache.s4.core.ProcessingElement#onCreate() + */ + @Override + public void onCreate(int id) { + // do nothing + } + + public Stream[] getOutputStreams() { + return ensembleStreams; + } + + public void setOutputStreams(Stream[] ensembleStreams) { + this.ensembleStreams = ensembleStreams; + } + + /** + * Gets the size ensemble. + * + * @return the size ensemble + */ + public int getEnsembleSize() { + return ensembleSize; + } + + /** + * Sets the size ensemble. + * + * @param ensembleSize + * the new size ensemble + */ + public void setEnsembleSize(int ensembleSize) { + this.ensembleSize = ensembleSize; + } + + /* + * (non-Javadoc) + * + * @see samoa.core.Processor#newProcessor(samoa.core.Processor) + */ + @Override + public Processor newProcessor(Processor sourceProcessor) { + ShardingDistributorProcessor newProcessor = new ShardingDistributorProcessor(); + ShardingDistributorProcessor originProcessor = (ShardingDistributorProcessor) sourceProcessor; + if (originProcessor.getOutputStreams() != null) { + newProcessor.setOutputStreams(Arrays.copyOf(originProcessor.getOutputStreams(), + originProcessor.getOutputStreams().length)); + } + newProcessor.setEnsembleSize(originProcessor.getEnsembleSize()); + /* + * if (originProcessor.getLearningCurve() != null){ + * newProcessor.setLearningCurve((LearningCurve) + * originProcessor.getLearningCurve().copy()); } + */ + return newProcessor; + } +} diff --git 
a/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/rules/AMRulesRegressor.java b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/rules/AMRulesRegressor.java new file mode 100644 index 00000000000..71a426ac185 --- /dev/null +++ b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/rules/AMRulesRegressor.java @@ -0,0 +1,177 @@ +package org.apache.heron.learners.classifiers.rules; + +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + + +import com.google.common.collect.ImmutableSet; + +import java.util.Set; + +import org.apache.samoa.core.Processor; +import org.apache.samoa.instances.Instances; +import org.apache.samoa.learners.RegressionLearner; +import org.apache.samoa.learners.classifiers.rules.centralized.AMRulesRegressorProcessor; +import org.apache.samoa.moa.classifiers.rules.core.attributeclassobservers.FIMTDDNumericAttributeClassLimitObserver; +import org.apache.samoa.moa.classifiers.rules.core.voting.ErrorWeightedVote; +import org.apache.samoa.topology.Stream; +import org.apache.samoa.topology.TopologyBuilder; + +import com.github.javacliparser.Configurable; +import com.github.javacliparser.ClassOption; +import com.github.javacliparser.FlagOption; +import com.github.javacliparser.FloatOption; +import com.github.javacliparser.IntOption; +import com.github.javacliparser.MultiChoiceOption; + +/** + * AMRules Regressor is the task for the serialized implementation of AMRules algorithm for regression rule. It is + * adapted to SAMOA from the implementation of AMRules in MOA. + * + * @author Anh Thu Vu + * + */ + +public class AMRulesRegressor implements RegressionLearner, Configurable { + + /** + * + */ + private static final long serialVersionUID = 1L; + + // Options + public FloatOption splitConfidenceOption = new FloatOption( + "splitConfidence", + 'c', + "Hoeffding Bound Parameter. The allowable error in split decision, values closer to 0 will take longer to decide.", + 0.0000001, 0.0, 1.0); + + public FloatOption tieThresholdOption = new FloatOption("tieThreshold", + 't', "Hoeffding Bound Parameter. Threshold below which a split will be forced to break ties.", + 0.05, 0.0, 1.0); + + public IntOption gracePeriodOption = new IntOption("gracePeriod", + 'g', "Hoeffding Bound Parameter. 
The number of instances a leaf should observe between split attempts.", + 200, 1, Integer.MAX_VALUE); + + public FlagOption DriftDetectionOption = new FlagOption("DoNotDetectChanges", 'H', + "Drift Detection. Page-Hinkley."); + + public FloatOption pageHinckleyAlphaOption = new FloatOption( + "pageHinckleyAlpha", + 'a', + "The alpha value to use in the Page Hinckley change detection tests.", + 0.005, 0.0, 1.0); + + public IntOption pageHinckleyThresholdOption = new IntOption( + "pageHinckleyThreshold", + 'l', + "The threshold value (Lambda) to be used in the Page Hinckley change detection tests.", + 35, 0, Integer.MAX_VALUE); + + public FlagOption noAnomalyDetectionOption = new FlagOption("noAnomalyDetection", 'A', + "Disable anomaly Detection."); + + public FloatOption multivariateAnomalyProbabilityThresholdOption = new FloatOption( + "multivariateAnomalyProbabilityThresholdd", + 'm', + "Multivariate anomaly threshold value.", + 0.99, 0.0, 1.0); + + public FloatOption univariateAnomalyProbabilityThresholdOption = new FloatOption( + "univariateAnomalyprobabilityThreshold", + 'u', + "Univariate anomaly threshold value.", + 0.10, 0.0, 1.0); + + public IntOption anomalyNumInstThresholdOption = new IntOption( + "anomalyThreshold", + 'n', + "The threshold value of anomalies to be used in the anomaly detection.", + 30, 0, Integer.MAX_VALUE); // num minimum of instances to detect anomalies anomalies. 15. 
+ + public FlagOption unorderedRulesOption = new FlagOption("setUnorderedRulesOn", 'U', + "unorderedRules."); + + public ClassOption numericObserverOption = new ClassOption("numericObserver", + 'z', "Numeric observer.", + FIMTDDNumericAttributeClassLimitObserver.class, + "FIMTDDNumericAttributeClassLimitObserver"); + + public MultiChoiceOption predictionFunctionOption = new MultiChoiceOption( + "predictionFunctionOption", 'P', "The prediction function to use.", new String[] { + "Adaptative", "Perceptron", "Target Mean" }, new String[] { + "Adaptative", "Perceptron", "Target Mean" }, 0); + + public FlagOption constantLearningRatioDecayOption = new FlagOption( + "learningRatio_Decay_set_constant", 'd', + "Learning Ratio Decay in Perceptron set to be constant. (The next parameter)."); + + public FloatOption learningRatioOption = new FloatOption( + "learningRatio", 's', + "Constante Learning Ratio to use for training the Perceptrons in the leaves.", 0.025); + + public ClassOption votingTypeOption = new ClassOption("votingType", + 'V', "Voting Type.", + ErrorWeightedVote.class, + "InverseErrorWeightedVote"); + + // Processor + private AMRulesRegressorProcessor processor; + + // Stream + private Stream resultStream; + + @Override + public void init(TopologyBuilder topologyBuilder, Instances dataset, int parallelism) { + this.processor = new AMRulesRegressorProcessor.Builder(dataset) + .threshold(pageHinckleyThresholdOption.getValue()) + .alpha(pageHinckleyAlphaOption.getValue()) + .changeDetection(this.DriftDetectionOption.isSet()) + .predictionFunction(predictionFunctionOption.getChosenIndex()) + .constantLearningRatioDecay(constantLearningRatioDecayOption.isSet()) + .learningRatio(learningRatioOption.getValue()) + .splitConfidence(splitConfidenceOption.getValue()) + .tieThreshold(tieThresholdOption.getValue()) + .gracePeriod(gracePeriodOption.getValue()) + .noAnomalyDetection(noAnomalyDetectionOption.isSet()) + 
.multivariateAnomalyProbabilityThreshold(multivariateAnomalyProbabilityThresholdOption.getValue()) + .univariateAnomalyProbabilityThreshold(univariateAnomalyProbabilityThresholdOption.getValue()) + .anomalyNumberOfInstancesThreshold(anomalyNumInstThresholdOption.getValue()) + .unorderedRules(unorderedRulesOption.isSet()) + .numericObserver((FIMTDDNumericAttributeClassLimitObserver) numericObserverOption.getValue()) + .voteType((ErrorWeightedVote) votingTypeOption.getValue()) + .build(); + + topologyBuilder.addProcessor(processor, parallelism); + + this.resultStream = topologyBuilder.createStream(processor); + this.processor.setResultStream(resultStream); + } + + @Override + public Processor getInputProcessor() { + return processor; + } + + @Override + public Set getResultStreams() { + return ImmutableSet.of(this.resultStream); + } +} diff --git a/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/rules/HorizontalAMRulesRegressor.java b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/rules/HorizontalAMRulesRegressor.java new file mode 100644 index 00000000000..c2dc769410d --- /dev/null +++ b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/rules/HorizontalAMRulesRegressor.java @@ -0,0 +1,239 @@ +package org.apache.heron.learners.classifiers.rules; + +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import com.google.common.collect.ImmutableSet; + +import java.util.Set; + +import org.apache.samoa.core.Processor; +import org.apache.samoa.instances.Instances; +import org.apache.samoa.learners.RegressionLearner; +import org.apache.samoa.learners.classifiers.rules.distributed.AMRDefaultRuleProcessor; +import org.apache.samoa.learners.classifiers.rules.distributed.AMRLearnerProcessor; +import org.apache.samoa.learners.classifiers.rules.distributed.AMRRuleSetProcessor; +import org.apache.samoa.moa.classifiers.rules.core.attributeclassobservers.FIMTDDNumericAttributeClassLimitObserver; +import org.apache.samoa.topology.Stream; +import org.apache.samoa.topology.TopologyBuilder; + +import com.github.javacliparser.ClassOption; +import com.github.javacliparser.Configurable; +import com.github.javacliparser.FlagOption; +import com.github.javacliparser.FloatOption; +import com.github.javacliparser.IntOption; +import com.github.javacliparser.MultiChoiceOption; + +/** + * Horizontal AMRules Regressor is a distributed learner for regression rules learner. It applies both horizontal + * parallelism (dividing incoming streams) and vertical parallelism on AMRules algorithm. + * + * @author Anh Thu Vu + * + */ +public class HorizontalAMRulesRegressor implements RegressionLearner, Configurable { + + /** + * + */ + private static final long serialVersionUID = 2785944439173586051L; + + // Options + public FloatOption splitConfidenceOption = new FloatOption( + "splitConfidence", + 'c', + "Hoeffding Bound Parameter. 
The allowable error in split decision, values closer to 0 will take longer to decide.", + 0.0000001, 0.0, 1.0); + + public FloatOption tieThresholdOption = new FloatOption("tieThreshold", + 't', "Hoeffding Bound Parameter. Threshold below which a split will be forced to break ties.", + 0.05, 0.0, 1.0); + + public IntOption gracePeriodOption = new IntOption("gracePeriod", + 'g', "Hoeffding Bound Parameter. The number of instances a leaf should observe between split attempts.", + 200, 1, Integer.MAX_VALUE); + + public FlagOption DriftDetectionOption = new FlagOption("DoNotDetectChanges", 'H', + "Drift Detection. Page-Hinkley."); + + public FloatOption pageHinckleyAlphaOption = new FloatOption( + "pageHinckleyAlpha", + 'a', + "The alpha value to use in the Page Hinckley change detection tests.", + 0.005, 0.0, 1.0); + + public IntOption pageHinckleyThresholdOption = new IntOption( + "pageHinckleyThreshold", + 'l', + "The threshold value (Lambda) to be used in the Page Hinckley change detection tests.", + 35, 0, Integer.MAX_VALUE); + + public FlagOption noAnomalyDetectionOption = new FlagOption("noAnomalyDetection", 'A', + "Disable anomaly Detection."); + + public FloatOption multivariateAnomalyProbabilityThresholdOption = new FloatOption( + "multivariateAnomalyProbabilityThresholdd", + 'm', + "Multivariate anomaly threshold value.", + 0.99, 0.0, 1.0); + + public FloatOption univariateAnomalyProbabilityThresholdOption = new FloatOption( + "univariateAnomalyprobabilityThreshold", + 'u', + "Univariate anomaly threshold value.", + 0.10, 0.0, 1.0); + + public IntOption anomalyNumInstThresholdOption = new IntOption( + "anomalyThreshold", + 'n', + "The threshold value of anomalies to be used in the anomaly detection.", + 30, 0, Integer.MAX_VALUE); // num minimum of instances to detect anomalies. 15. 
+ + public FlagOption unorderedRulesOption = new FlagOption("setUnorderedRulesOn", 'U', + "unorderedRules."); + + public ClassOption numericObserverOption = new ClassOption("numericObserver", + 'z', "Numeric observer.", + FIMTDDNumericAttributeClassLimitObserver.class, + "FIMTDDNumericAttributeClassLimitObserver"); + + public MultiChoiceOption predictionFunctionOption = new MultiChoiceOption( + "predictionFunctionOption", 'P', "The prediction function to use.", new String[] { + "Adaptative", "Perceptron", "Target Mean" }, new String[] { + "Adaptative", "Perceptron", "Target Mean" }, 0); + + public FlagOption constantLearningRatioDecayOption = new FlagOption( + "learningRatio_Decay_set_constant", 'd', + "Learning Ratio Decay in Perceptron set to be constant. (The next parameter)."); + + public FloatOption learningRatioOption = new FloatOption( + "learningRatio", 's', + "Constante Learning Ratio to use for training the Perceptrons in the leaves.", 0.025); + + public MultiChoiceOption votingTypeOption = new MultiChoiceOption( + "votingType", 'V', "Voting Type.", new String[] { + "InverseErrorWeightedVote", "UniformWeightedVote" }, new String[] { + "InverseErrorWeightedVote", "UniformWeightedVote" }, 0); + + public IntOption learnerParallelismOption = new IntOption( + "leanerParallelism", + 'p', + "The number of local statistics PI to do distributed computation", + 1, 1, Integer.MAX_VALUE); + public IntOption ruleSetParallelismOption = new IntOption( + "modelParallelism", + 'r', + "The number of replicated model (rule set) PIs", + 1, 1, Integer.MAX_VALUE); + + // Processor + private AMRRuleSetProcessor model; + + private Stream modelResultStream; + + private Stream rootResultStream; + + // private Stream resultStream; + + @Override + public void init(TopologyBuilder topologyBuilder, Instances dataset, int parallelism) { + + // Create MODEL PIs + this.model = new AMRRuleSetProcessor.Builder(dataset) + .noAnomalyDetection(noAnomalyDetectionOption.isSet()) + 
.multivariateAnomalyProbabilityThreshold(multivariateAnomalyProbabilityThresholdOption.getValue()) + .univariateAnomalyProbabilityThreshold(univariateAnomalyProbabilityThresholdOption.getValue()) + .anomalyNumberOfInstancesThreshold(anomalyNumInstThresholdOption.getValue()) + .unorderedRules(unorderedRulesOption.isSet()) + .voteType(votingTypeOption.getChosenIndex()) + .build(); + + topologyBuilder.addProcessor(model, this.ruleSetParallelismOption.getValue()); + + // MODEL PIs streams + Stream forwardToRootStream = topologyBuilder.createStream(this.model); + Stream forwardToLearnerStream = topologyBuilder.createStream(this.model); + this.modelResultStream = topologyBuilder.createStream(this.model); + + this.model.setDefaultRuleStream(forwardToRootStream); + this.model.setStatisticsStream(forwardToLearnerStream); + this.model.setResultStream(this.modelResultStream); + + // Create DefaultRule PI + AMRDefaultRuleProcessor root = new AMRDefaultRuleProcessor.Builder(dataset) + .threshold(pageHinckleyThresholdOption.getValue()) + .alpha(pageHinckleyAlphaOption.getValue()) + .changeDetection(this.DriftDetectionOption.isSet()) + .predictionFunction(predictionFunctionOption.getChosenIndex()) + .constantLearningRatioDecay(constantLearningRatioDecayOption.isSet()) + .learningRatio(learningRatioOption.getValue()) + .splitConfidence(splitConfidenceOption.getValue()) + .tieThreshold(tieThresholdOption.getValue()) + .gracePeriod(gracePeriodOption.getValue()) + .numericObserver((FIMTDDNumericAttributeClassLimitObserver) numericObserverOption.getValue()) + .build(); + + topologyBuilder.addProcessor(root); + + // Default Rule PI streams + Stream newRuleStream = topologyBuilder.createStream(root); + this.rootResultStream = topologyBuilder.createStream(root); + + root.setRuleStream(newRuleStream); + root.setResultStream(this.rootResultStream); + + // Create Learner PIs + AMRLearnerProcessor learner = new AMRLearnerProcessor.Builder(dataset) + 
.splitConfidence(splitConfidenceOption.getValue()) + .tieThreshold(tieThresholdOption.getValue()) + .gracePeriod(gracePeriodOption.getValue()) + .noAnomalyDetection(noAnomalyDetectionOption.isSet()) + .multivariateAnomalyProbabilityThreshold(multivariateAnomalyProbabilityThresholdOption.getValue()) + .univariateAnomalyProbabilityThreshold(univariateAnomalyProbabilityThresholdOption.getValue()) + .anomalyNumberOfInstancesThreshold(anomalyNumInstThresholdOption.getValue()) + .build(); + + topologyBuilder.addProcessor(learner, this.learnerParallelismOption.getValue()); + + Stream predicateStream = topologyBuilder.createStream(learner); + learner.setOutputStream(predicateStream); + + // Connect streams + // to MODEL + topologyBuilder.connectInputAllStream(newRuleStream, this.model); + topologyBuilder.connectInputAllStream(predicateStream, this.model); + // to ROOT + topologyBuilder.connectInputShuffleStream(forwardToRootStream, root); + // to LEARNER + topologyBuilder.connectInputKeyStream(forwardToLearnerStream, learner); + topologyBuilder.connectInputAllStream(newRuleStream, learner); + } + + @Override + public Processor getInputProcessor() { + return model; + } + + @Override + public Set getResultStreams() { + Set streams = ImmutableSet.of(this.modelResultStream, this.rootResultStream); + return streams; + } + +} diff --git a/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/rules/VerticalAMRulesRegressor.java b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/rules/VerticalAMRulesRegressor.java new file mode 100644 index 00000000000..2d9870480d1 --- /dev/null +++ b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/rules/VerticalAMRulesRegressor.java @@ -0,0 +1,200 @@ +package org.apache.heron.learners.classifiers.rules; + +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import java.util.Set; + +import org.apache.samoa.core.Processor; +import org.apache.samoa.instances.Instances; +import org.apache.samoa.learners.RegressionLearner; +import org.apache.samoa.learners.classifiers.rules.distributed.AMRulesAggregatorProcessor; +import org.apache.samoa.learners.classifiers.rules.distributed.AMRulesStatisticsProcessor; +import org.apache.samoa.moa.classifiers.rules.core.attributeclassobservers.FIMTDDNumericAttributeClassLimitObserver; +import org.apache.samoa.topology.Stream; +import org.apache.samoa.topology.TopologyBuilder; + +import com.github.javacliparser.ClassOption; +import com.github.javacliparser.Configurable; +import com.github.javacliparser.FlagOption; +import com.github.javacliparser.FloatOption; +import com.github.javacliparser.IntOption; +import com.github.javacliparser.MultiChoiceOption; +import com.google.common.collect.ImmutableSet; + +/** + * Vertical AMRules Regressor is a distributed learner for regression rules learner. It applies vertical parallelism on + * AMRules regressor. 
+ * + * @author Anh Thu Vu + * + */ + +public class VerticalAMRulesRegressor implements RegressionLearner, Configurable { + + /** + * + */ + private static final long serialVersionUID = 2785944439173586051L; + + // Options + public FloatOption splitConfidenceOption = new FloatOption( + "splitConfidence", + 'c', + "Hoeffding Bound Parameter. The allowable error in split decision, values closer to 0 will take longer to decide.", + 0.0000001, 0.0, 1.0); + + public FloatOption tieThresholdOption = new FloatOption("tieThreshold", + 't', "Hoeffding Bound Parameter. Threshold below which a split will be forced to break ties.", + 0.05, 0.0, 1.0); + + public IntOption gracePeriodOption = new IntOption("gracePeriod", + 'g', "Hoeffding Bound Parameter. The number of instances a leaf should observe between split attempts.", + 200, 1, Integer.MAX_VALUE); + + public FlagOption DriftDetectionOption = new FlagOption("DoNotDetectChanges", 'H', + "Drift Detection. Page-Hinkley."); + + public FloatOption pageHinckleyAlphaOption = new FloatOption( + "pageHinckleyAlpha", + 'a', + "The alpha value to use in the Page Hinckley change detection tests.", + 00.005, 0.0, 1.0); + + public IntOption pageHinckleyThresholdOption = new IntOption( + "pageHinckleyThreshold", + 'l', + "The threshold value (Lambda) to be used in the Page Hinckley change detection tests.", + 35, 0, Integer.MAX_VALUE); + + public FlagOption noAnomalyDetectionOption = new FlagOption("noAnomalyDetection", 'A', + "Disable anomaly Detection."); + + public FloatOption multivariateAnomalyProbabilityThresholdOption = new FloatOption( + "multivariateAnomalyProbabilityThresholdd", + 'm', + "Multivariate anomaly threshold value.", + 0.99, 0.0, 1.0); + + public FloatOption univariateAnomalyProbabilityThresholdOption = new FloatOption( + "univariateAnomalyprobabilityThreshold", + 'u', + "Univariate anomaly threshold value.", + 0.10, 0.0, 1.0); + + public IntOption anomalyNumInstThresholdOption = new IntOption( + "anomalyThreshold", 
+ 'n', + "The threshold value of anomalies to be used in the anomaly detection.", + 30, 0, Integer.MAX_VALUE); // num minimum of instances to detect anomalies. 15. + + public FlagOption unorderedRulesOption = new FlagOption("setUnorderedRulesOn", 'U', + "unorderedRules."); + + public ClassOption numericObserverOption = new ClassOption("numericObserver", + 'z', "Numeric observer.", + FIMTDDNumericAttributeClassLimitObserver.class, + "FIMTDDNumericAttributeClassLimitObserver"); + + public MultiChoiceOption predictionFunctionOption = new MultiChoiceOption( + "predictionFunctionOption", 'P', "The prediction function to use.", new String[] { + "Adaptative", "Perceptron", "Target Mean" }, new String[] { + "Adaptative", "Perceptron", "Target Mean" }, 0); + + public FlagOption constantLearningRatioDecayOption = new FlagOption( + "learningRatio_Decay_set_constant", 'd', + "Learning Ratio Decay in Perceptron set to be constant. (The next parameter)."); + + public FloatOption learningRatioOption = new FloatOption( + "learningRatio", 's', + "Constante Learning Ratio to use for training the Perceptrons in the leaves.", 0.025); + + public MultiChoiceOption votingTypeOption = new MultiChoiceOption( + "votingType", 'V', "Voting Type.", new String[] { + "InverseErrorWeightedVote", "UniformWeightedVote" }, new String[] { + "InverseErrorWeightedVote", "UniformWeightedVote" }, 0); + + public IntOption parallelismHintOption = new IntOption( + "parallelismHint", + 'p', + "The number of local statistics PI to do distributed computation", + 1, 1, Integer.MAX_VALUE); + + // Processor + private AMRulesAggregatorProcessor aggregator; + + // Stream + private Stream resultStream; + + @Override + public void init(TopologyBuilder topologyBuilder, Instances dataset, int parallelism) { + + this.aggregator = new AMRulesAggregatorProcessor.Builder(dataset) + .threshold(pageHinckleyThresholdOption.getValue()) + .alpha(pageHinckleyAlphaOption.getValue()) + 
.changeDetection(this.DriftDetectionOption.isSet()) + .predictionFunction(predictionFunctionOption.getChosenIndex()) + .constantLearningRatioDecay(constantLearningRatioDecayOption.isSet()) + .learningRatio(learningRatioOption.getValue()) + .splitConfidence(splitConfidenceOption.getValue()) + .tieThreshold(tieThresholdOption.getValue()) + .gracePeriod(gracePeriodOption.getValue()) + .noAnomalyDetection(noAnomalyDetectionOption.isSet()) + .multivariateAnomalyProbabilityThreshold(multivariateAnomalyProbabilityThresholdOption.getValue()) + .univariateAnomalyProbabilityThreshold(univariateAnomalyProbabilityThresholdOption.getValue()) + .anomalyNumberOfInstancesThreshold(anomalyNumInstThresholdOption.getValue()) + .unorderedRules(unorderedRulesOption.isSet()) + .numericObserver((FIMTDDNumericAttributeClassLimitObserver) numericObserverOption.getValue()) + .voteType(votingTypeOption.getChosenIndex()) + .build(); + + topologyBuilder.addProcessor(aggregator); + + Stream statisticsStream = topologyBuilder.createStream(aggregator); + this.resultStream = topologyBuilder.createStream(aggregator); + + this.aggregator.setResultStream(resultStream); + this.aggregator.setStatisticsStream(statisticsStream); + + AMRulesStatisticsProcessor learner = new AMRulesStatisticsProcessor.Builder(dataset) + .splitConfidence(splitConfidenceOption.getValue()) + .tieThreshold(tieThresholdOption.getValue()) + .gracePeriod(gracePeriodOption.getValue()) + .build(); + + topologyBuilder.addProcessor(learner, this.parallelismHintOption.getValue()); + + topologyBuilder.connectInputKeyStream(statisticsStream, learner); + + Stream predicateStream = topologyBuilder.createStream(learner); + learner.setOutputStream(predicateStream); + + topologyBuilder.connectInputShuffleStream(predicateStream, aggregator); + } + + @Override + public Processor getInputProcessor() { + return aggregator; + } + + @Override + public Set getResultStreams() { + return ImmutableSet.of(this.resultStream); + } +} diff --git 
a/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/rules/centralized/AMRulesRegressorProcessor.java b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/rules/centralized/AMRulesRegressorProcessor.java new file mode 100644 index 00000000000..239ad6448e7 --- /dev/null +++ b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/rules/centralized/AMRulesRegressorProcessor.java @@ -0,0 +1,512 @@ +package org.apache.heron.learners.classifiers.rules.centralized; + +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; + +import org.apache.samoa.core.ContentEvent; +import org.apache.samoa.core.Processor; +import org.apache.samoa.instances.Instance; +import org.apache.samoa.instances.Instances; +import org.apache.samoa.learners.InstanceContentEvent; +import org.apache.samoa.learners.ResultContentEvent; +import org.apache.samoa.learners.classifiers.rules.common.ActiveRule; +import org.apache.samoa.learners.classifiers.rules.common.Perceptron; +import org.apache.samoa.learners.classifiers.rules.common.RuleActiveRegressionNode; +import org.apache.samoa.moa.classifiers.rules.core.attributeclassobservers.FIMTDDNumericAttributeClassLimitObserver; +import org.apache.samoa.moa.classifiers.rules.core.voting.ErrorWeightedVote; +import org.apache.samoa.topology.Stream; + +/** + * AMRules Regressor Processor is the main (and only) processor for AMRulesRegressor task. It is adapted from the + * AMRules implementation in MOA. + * + * @author Anh Thu Vu + * + */ +public class AMRulesRegressorProcessor implements Processor { + /** + * + */ + private static final long serialVersionUID = 1L; + + private int processorId; + + // Rules & default rule + protected List ruleSet; + protected ActiveRule defaultRule; + protected int ruleNumberID; + protected double[] statistics; + + // SAMOA Stream + private Stream resultStream; + + // Options + protected int pageHinckleyThreshold; + protected double pageHinckleyAlpha; + protected boolean driftDetection; + protected int predictionFunction; // Adaptive=0 Perceptron=1 TargetMean=2 + protected boolean constantLearningRatioDecay; + protected double learningRatio; + + protected double splitConfidence; + protected double tieThreshold; + protected int gracePeriod; + + protected boolean noAnomalyDetection; + protected double multivariateAnomalyProbabilityThreshold; + protected double univariateAnomalyprobabilityThreshold; + protected int anomalyNumInstThreshold; + + protected 
boolean unorderedRules; + + protected FIMTDDNumericAttributeClassLimitObserver numericObserver; + + protected ErrorWeightedVote voteType; + + /* + * Constructor + */ + public AMRulesRegressorProcessor(Builder builder) { + this.pageHinckleyThreshold = builder.pageHinckleyThreshold; + this.pageHinckleyAlpha = builder.pageHinckleyAlpha; + this.driftDetection = builder.driftDetection; + this.predictionFunction = builder.predictionFunction; + this.constantLearningRatioDecay = builder.constantLearningRatioDecay; + this.learningRatio = builder.learningRatio; + this.splitConfidence = builder.splitConfidence; + this.tieThreshold = builder.tieThreshold; + this.gracePeriod = builder.gracePeriod; + + this.noAnomalyDetection = builder.noAnomalyDetection; + this.multivariateAnomalyProbabilityThreshold = builder.multivariateAnomalyProbabilityThreshold; + this.univariateAnomalyprobabilityThreshold = builder.univariateAnomalyprobabilityThreshold; + this.anomalyNumInstThreshold = builder.anomalyNumInstThreshold; + this.unorderedRules = builder.unorderedRules; + + this.numericObserver = builder.numericObserver; + this.voteType = builder.voteType; + } + + /* + * Process + */ + @Override + public boolean process(ContentEvent event) { + InstanceContentEvent instanceEvent = (InstanceContentEvent) event; + + // predict + if (instanceEvent.isTesting()) { + this.predictOnInstance(instanceEvent); + } + + // train + if (instanceEvent.isTraining()) { + this.trainOnInstance(instanceEvent); + } + + return true; + } + + /* + * Prediction + */ + private void predictOnInstance(InstanceContentEvent instanceEvent) { + double[] prediction = getVotesForInstance(instanceEvent.getInstance()); + ResultContentEvent rce = newResultContentEvent(prediction, instanceEvent); + resultStream.put(rce); + } + + /** + * Helper method to generate new ResultContentEvent based on an instance and its prediction result. + * + * @param prediction + * The predicted class label from the decision tree model. 
+ * @param inEvent + * The associated instance content event + * @return ResultContentEvent to be sent into Evaluator PI or other destination PI. + */ + private ResultContentEvent newResultContentEvent(double[] prediction, InstanceContentEvent inEvent) { + ResultContentEvent rce = new ResultContentEvent(inEvent.getInstanceIndex(), inEvent.getInstance(), + inEvent.getClassId(), prediction, inEvent.isLastEvent()); + rce.setClassifierIndex(this.processorId); + rce.setEvaluationIndex(inEvent.getEvaluationIndex()); + return rce; + } + + /** + * getVotesForInstance extension of the instance method getVotesForInstance in moa.classifier.java returns the + * prediction of the instance. Called in EvaluateModelRegression + */ + private double[] getVotesForInstance(Instance instance) { + ErrorWeightedVote errorWeightedVote = newErrorWeightedVote(); + int numberOfRulesCovering = 0; + + for (ActiveRule rule : ruleSet) { + if (rule.isCovering(instance) == true) { + numberOfRulesCovering++; + double[] vote = rule.getPrediction(instance); + double error = rule.getCurrentError(); + errorWeightedVote.addVote(vote, error); + if (!this.unorderedRules) { // Ordered Rules Option. + break; // Only one rule cover the instance. 
+        }
+      }
+    }
+
+    if (numberOfRulesCovering == 0) {
+      double[] vote = defaultRule.getPrediction(instance);
+      double error = defaultRule.getCurrentError();
+      errorWeightedVote.addVote(vote, error);
+    }
+    double[] weightedVote = errorWeightedVote.computeWeightedVote();
+
+    return weightedVote;
+  }
+
+  /**
+   * Returns a fresh vote aggregator obtained from {@code voteType.getACopy()}.
+   */
+  public ErrorWeightedVote newErrorWeightedVote() {
+    return voteType.getACopy();
+  }
+
+  /*
+   * Training
+   */
+  /** Trains on the instance carried by {@code instanceEvent}. */
+  private void trainOnInstance(InstanceContentEvent instanceEvent) {
+    this.trainOnInstanceImpl(instanceEvent.getInstance());
+  }
+
+  /**
+   * Updates the rule set with one training instance.
+   *
+   * <p>Every rule that covers the instance and does not flag it as an anomaly has its change
+   * detector updated; a rule whose detector fires is removed, otherwise its statistics are
+   * updated and it is expanded when its instance count is a multiple of {@code gracePeriod} and
+   * {@code tryToExpand} agrees. With ordered rules the scan stops after the first covering,
+   * non-anomalous rule. If no rule covers the instance, the default rule is trained instead and,
+   * when it expands, is added to the rule set and replaced by a new default rule.
+   *
+   * @param instance the training instance
+   */
+  public void trainOnInstanceImpl(Instance instance) {
+    /*
+     * AMRules Algorithm
+     *
+     * For each rule in the rule set
+     *   If rule covers the instance
+     *     if the instance is not an anomaly
+     *       Update Change Detection Tests
+     *       Compute prediction error
+     *       Call PHTest
+     *       If change is detected then
+     *         Remove rule
+     *       Else
+     *         Update sufficient statistics of rule
+     *         If number of examples in rule > Nmin
+     *           Expand rule
+     *       If ordered set then
+     *         break
+     * If none of the rule covers the instance
+     *   Update sufficient statistics of default rule
+     *   If number of examples in default rule is multiple of Nmin
+     *     Expand default rule and add it to the set of rules
+     *     Reset the default rule
+     */
+    boolean rulesCoveringInstance = false;
+    // Typed iterator: next() has to yield an ActiveRule, and remove() lets a rule be dropped
+    // safely while the set is being traversed.
+    Iterator<ActiveRule> ruleIterator = this.ruleSet.iterator();
+    while (ruleIterator.hasNext()) {
+      ActiveRule rule = ruleIterator.next();
+      if (!rule.isCovering(instance)) {
+        continue;
+      }
+      rulesCoveringInstance = true;
+      if (isAnomaly(instance, rule)) {
+        // An anomalous instance neither trains this rule nor ends the scan.
+        continue;
+      }
+      // Update Change Detection Tests
+      double error = rule.computeError(instance); // Use adaptive mode error
+      boolean changeDetected =
+          ((RuleActiveRegressionNode) rule.getLearningNode()).updateChangeDetection(error);
+      if (changeDetected) {
+        ruleIterator.remove();
+      } else {
+        rule.updateStatistics(instance);
+        if (rule.getInstancesSeen() % this.gracePeriod == 0
+            && rule.tryToExpand(this.splitConfidence, this.tieThreshold)) {
+          rule.split();
+        }
+      }
+      if (!this.unorderedRules) {
+        break;
+      }
+    }
+
+    if (!rulesCoveringInstance) {
+      defaultRule.updateStatistics(instance);
+      if (defaultRule.getInstancesSeen() % this.gracePeriod == 0
+          && defaultRule.tryToExpand(this.splitConfidence, this.tieThreshold)) {
+        // Build the replacement default rule before split() changes the learning node.
+        RuleActiveRegressionNode defaultNode = (RuleActiveRegressionNode) defaultRule.getLearningNode();
+        ActiveRule newDefaultRule = newRule(defaultRule.getRuleNumberID(),
+            defaultNode,
+            ((RuleActiveRegressionNode) defaultRule.getLearningNode()).getStatisticsOtherBranchSplit()); // other branch
+        defaultRule.split();
+        defaultRule.setRuleNumberID(++ruleNumberID);
+        this.ruleSet.add(this.defaultRule);
+
+        defaultRule = newDefaultRule;
+      }
+    }
+  }
+
+  /**
+   * Method to verify if the instance is an anomaly.
+   *
+   * @param instance the instance to check
+   * @param rule the rule whose statistics are used for the check
+   * @return {@code true} only when anomaly detection is enabled, the rule has seen at least
+   *         {@code anomalyNumInstThreshold} instances, and the rule reports the instance as an
+   *         anomaly
+   */
+  private boolean isAnomaly(Instance instance, ActiveRule rule) {
+    // AMRules is equipped with anomaly detection. If on, compute the anomaly
+    // value.
+    boolean isAnomaly = false;
+    if (!this.noAnomalyDetection) {
+      if (rule.getInstancesSeen() >= this.anomalyNumInstThreshold) {
+        isAnomaly = rule.isAnomaly(instance,
+            this.univariateAnomalyprobabilityThreshold,
+            this.multivariateAnomalyProbabilityThreshold,
+            this.anomalyNumInstThreshold);
+      }
+    }
+    return isAnomaly;
+  }
+
+  /*
+   * Create new rules
+   */
+  // TODO check this after finish rule, LN
+  private ActiveRule newRule(int ID, RuleActiveRegressionNode node, double[] statistics) {
+    ActiveRule r = newRule(ID);
+
+    if (node != null)
+    {
+      if (node.getPerceptron() != null)
+      {
+        r.getLearningNode().setPerceptron(new Perceptron(node.getPerceptron()));
+        r.getLearningNode().getPerceptron().setLearningRatio(this.learningRatio);
+      }
+      if (statistics == null)
+      {
+        double mean;
+        if (node.getNodeStatistics().getValue(0) > 0) {
+          mean = node.getNodeStatistics().getValue(1) / node.getNodeStatistics().getValue(0);
+          r.getLearningNode().getTargetMean().reset(mean, 1);
+        }
+      }
+    }
+    if (statistics != null && ((RuleActiveRegressionNode) 
r.getLearningNode()).getTargetMean() != null) + { + double mean; + if (statistics[0] > 0) { + mean = statistics[1] / statistics[0]; + ((RuleActiveRegressionNode) r.getLearningNode()).getTargetMean().reset(mean, (long) statistics[0]); + } + } + return r; + } + + private ActiveRule newRule(int ID) { + ActiveRule r = new ActiveRule.Builder(). + threshold(this.pageHinckleyThreshold). + alpha(this.pageHinckleyAlpha). + changeDetection(this.driftDetection). + predictionFunction(this.predictionFunction). + statistics(new double[3]). + learningRatio(this.learningRatio). + numericObserver(numericObserver). + id(ID).build(); + return r; + } + + /* + * Init processor + */ + @Override + public void onCreate(int id) { + this.processorId = id; + this.statistics = new double[] { 0.0, 0, 0 }; + this.ruleNumberID = 0; + this.defaultRule = newRule(++this.ruleNumberID); + + this.ruleSet = new LinkedList(); + } + + /* + * Clone processor + */ + @Override + public Processor newProcessor(Processor p) { + AMRulesRegressorProcessor oldProcessor = (AMRulesRegressorProcessor) p; + Builder builder = new Builder(oldProcessor); + AMRulesRegressorProcessor newProcessor = builder.build(); + newProcessor.resultStream = oldProcessor.resultStream; + return newProcessor; + } + + /* + * Output stream + */ + public void setResultStream(Stream resultStream) { + this.resultStream = resultStream; + } + + public Stream getResultStream() { + return this.resultStream; + } + + /* + * Others + */ + public boolean isRandomizable() { + return true; + } + + /* + * Builder + */ + public static class Builder { + private int pageHinckleyThreshold; + private double pageHinckleyAlpha; + private boolean driftDetection; + private int predictionFunction; // Adaptive=0 Perceptron=1 TargetMean=2 + private boolean constantLearningRatioDecay; + private double learningRatio; + private double splitConfidence; + private double tieThreshold; + private int gracePeriod; + + private boolean noAnomalyDetection; + private double 
multivariateAnomalyProbabilityThreshold; + private double univariateAnomalyprobabilityThreshold; + private int anomalyNumInstThreshold; + + private boolean unorderedRules; + + private FIMTDDNumericAttributeClassLimitObserver numericObserver; + private ErrorWeightedVote voteType; + + private Instances dataset; + + public Builder(Instances dataset) { + this.dataset = dataset; + } + + public Builder(AMRulesRegressorProcessor processor) { + this.pageHinckleyThreshold = processor.pageHinckleyThreshold; + this.pageHinckleyAlpha = processor.pageHinckleyAlpha; + this.driftDetection = processor.driftDetection; + this.predictionFunction = processor.predictionFunction; + this.constantLearningRatioDecay = processor.constantLearningRatioDecay; + this.learningRatio = processor.learningRatio; + this.splitConfidence = processor.splitConfidence; + this.tieThreshold = processor.tieThreshold; + this.gracePeriod = processor.gracePeriod; + + this.noAnomalyDetection = processor.noAnomalyDetection; + this.multivariateAnomalyProbabilityThreshold = processor.multivariateAnomalyProbabilityThreshold; + this.univariateAnomalyprobabilityThreshold = processor.univariateAnomalyprobabilityThreshold; + this.anomalyNumInstThreshold = processor.anomalyNumInstThreshold; + this.unorderedRules = processor.unorderedRules; + + this.numericObserver = processor.numericObserver; + this.voteType = processor.voteType; + } + + public Builder threshold(int threshold) { + this.pageHinckleyThreshold = threshold; + return this; + } + + public Builder alpha(double alpha) { + this.pageHinckleyAlpha = alpha; + return this; + } + + public Builder changeDetection(boolean changeDetection) { + this.driftDetection = changeDetection; + return this; + } + + public Builder predictionFunction(int predictionFunction) { + this.predictionFunction = predictionFunction; + return this; + } + + public Builder constantLearningRatioDecay(boolean constantDecay) { + this.constantLearningRatioDecay = constantDecay; + return this; + } + + 
public Builder learningRatio(double learningRatio) { + this.learningRatio = learningRatio; + return this; + } + + public Builder splitConfidence(double splitConfidence) { + this.splitConfidence = splitConfidence; + return this; + } + + public Builder tieThreshold(double tieThreshold) { + this.tieThreshold = tieThreshold; + return this; + } + + public Builder gracePeriod(int gracePeriod) { + this.gracePeriod = gracePeriod; + return this; + } + + public Builder noAnomalyDetection(boolean noAnomalyDetection) { + this.noAnomalyDetection = noAnomalyDetection; + return this; + } + + public Builder multivariateAnomalyProbabilityThreshold(double mAnomalyThreshold) { + this.multivariateAnomalyProbabilityThreshold = mAnomalyThreshold; + return this; + } + + public Builder univariateAnomalyProbabilityThreshold(double uAnomalyThreshold) { + this.univariateAnomalyprobabilityThreshold = uAnomalyThreshold; + return this; + } + + public Builder anomalyNumberOfInstancesThreshold(int anomalyNumInstThreshold) { + this.anomalyNumInstThreshold = anomalyNumInstThreshold; + return this; + } + + public Builder unorderedRules(boolean unorderedRules) { + this.unorderedRules = unorderedRules; + return this; + } + + public Builder numericObserver(FIMTDDNumericAttributeClassLimitObserver numericObserver) { + this.numericObserver = numericObserver; + return this; + } + + public Builder voteType(ErrorWeightedVote voteType) { + this.voteType = voteType; + return this; + } + + public AMRulesRegressorProcessor build() { + return new AMRulesRegressorProcessor(this); + } + } +} diff --git a/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/rules/common/ActiveRule.java b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/rules/common/ActiveRule.java new file mode 100644 index 00000000000..0cd079c9312 --- /dev/null +++ b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/rules/common/ActiveRule.java @@ -0,0 +1,228 @@ +package 
org.apache.heron.learners.classifiers.rules.common; + +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + + +import java.io.Serializable; + +import org.apache.samoa.moa.classifiers.core.conditionaltests.InstanceConditionalTest; +import org.apache.samoa.moa.classifiers.core.conditionaltests.NumericAttributeBinaryTest; +import org.apache.samoa.moa.classifiers.rules.core.attributeclassobservers.FIMTDDNumericAttributeClassLimitObserver; +import org.apache.samoa.moa.classifiers.rules.core.conditionaltests.NumericAttributeBinaryRulePredicate; + +/** + * ActiveRule is a LearningRule that actively update its LearningNode with incoming instances. 
+ *
+ * @author Anh Thu Vu
+ *
+ */
+
+public class ActiveRule extends LearningRule {
+
+  private static final long serialVersionUID = 1L;
+
+  // Never assigned in this class; the accessor below returns whatever the default is.
+  private double[] statisticsOtherBranchSplit;
+
+  private Builder builder;
+
+  private RuleActiveRegressionNode learningNode;
+
+  private RuleSplitNode lastUpdatedRuleSplitNode;
+
+  /**
+   * Creates an empty rule: no builder, no learning node, rule number 0.
+   */
+  public ActiveRule() {
+    super();
+    this.ruleNumberID = 0;
+    this.learningNode = null;
+    this.builder = null;
+  }
+
+  /**
+   * Creates a rule from {@code builder}: the builder is stored, a new active learning node is
+   * created from it, and the rule number is taken from {@code builder.id}.
+   */
+  public ActiveRule(Builder builder) {
+    super();
+    this.setBuilder(builder);
+    this.learningNode = newRuleActiveLearningNode(builder);
+    this.ruleNumberID = builder.id;
+  }
+
+  /** Creates the learning node for the given builder configuration. */
+  private RuleActiveRegressionNode newRuleActiveLearningNode(Builder builder) {
+    return new RuleActiveRegressionNode(builder);
+  }
+
+  /*
+   * Accessors
+   */
+  public Builder getBuilder() {
+    return this.builder;
+  }
+
+  public void setBuilder(Builder builder) {
+    this.builder = builder;
+  }
+
+  @Override
+  public RuleRegressionNode getLearningNode() {
+    return this.learningNode;
+  }
+
+  /**
+   * Replaces the learning node; the argument is cast to {@link RuleActiveRegressionNode}.
+   */
+  @Override
+  public void setLearningNode(RuleRegressionNode learningNode) {
+    this.learningNode = (RuleActiveRegressionNode) learningNode;
+  }
+
+  public double[] statisticsOtherBranchSplit() {
+    return this.statisticsOtherBranchSplit;
+  }
+
+  /** Returns the split node created by the most recent {@link #split()} call. */
+  public RuleSplitNode getLastUpdatedRuleSplitNode() {
+    return this.lastUpdatedRuleSplitNode;
+  }
+
+  /**
+   * Fluent holder for the settings used to create an {@link ActiveRule} and its learning node.
+   * Each setter stores its argument and returns this builder.
+   */
+  public static class Builder implements Serializable {
+
+    private static final long serialVersionUID = 1712887264918475622L;
+    protected boolean changeDetection;
+    protected boolean usePerceptron;
+    protected double threshold;
+    protected double alpha;
+    protected int predictionFunction;
+    protected boolean constantLearningRatioDecay;
+    protected double learningRatio;
+
+    protected double[] statistics;
+
+    protected FIMTDDNumericAttributeClassLimitObserver numericObserver;
+
+    protected double lastTargetMean;
+
+    public int id;
+
+    public Builder() {
+    }
+
+    public Builder changeDetection(boolean changeDetection) {
+      this.changeDetection = changeDetection;
+      return this;
+    }
+
+    public Builder threshold(double threshold) {
+      this.threshold = threshold;
+      return this;
+    }
+
+    public Builder alpha(double alpha) {
+      this.alpha = alpha;
+      return this;
+    }
+
+    public Builder predictionFunction(int predictionFunction) {
+      this.predictionFunction = predictionFunction;
+      return this;
+    }
+
+    /** Stores the array reference as given; no copy is made. */
+    public Builder statistics(double[] statistics) {
+      this.statistics = statistics;
+      return this;
+    }
+
+    public Builder constantLearningRatioDecay(boolean constantLearningRatioDecay) {
+      this.constantLearningRatioDecay = constantLearningRatioDecay;
+      return this;
+    }
+
+    public Builder learningRatio(double learningRatio) {
+      this.learningRatio = learningRatio;
+      return this;
+    }
+
+    public Builder numericObserver(FIMTDDNumericAttributeClassLimitObserver numericObserver) {
+      this.numericObserver = numericObserver;
+      return this;
+    }
+
+    public Builder id(int id) {
+      this.id = id;
+      return this;
+    }
+
+    /** Creates a new {@link ActiveRule} configured by this builder. */
+    public ActiveRule build() {
+      return new ActiveRule(this);
+    }
+
+  }
+
+  /**
+   * Asks the learning node whether the rule should be expanded.
+   *
+   * @param splitConfidence passed through to the learning node
+   * @param tieThreshold passed through to the learning node
+   * @return the learning node's answer
+   */
+  public boolean tryToExpand(double splitConfidence, double tieThreshold) {
+    return this.learningNode.tryToExpand(splitConfidence, tieThreshold);
+  }
+
+  /**
+   * Appends a split node for the learning node's best suggestion and, if the node list accepts
+   * it, swaps in a new learning node initialised from the current one.
+   *
+   * <p>JD: Only call after tryToExpand returning true.
+   *
+   * @throws UnsupportedOperationException if the suggested test is not a
+   *         {@link NumericAttributeBinaryTest}
+   */
+  public void split()
+  {
+    int splitIndex = this.learningNode.getSplitIndex();
+    InstanceConditionalTest suggestedTest = this.learningNode.getBestSuggestion().splitTest;
+    if (!(suggestedTest instanceof NumericAttributeBinaryTest)) {
+      throw new UnsupportedOperationException("AMRules (currently) only supports numerical attributes.");
+    }
+
+    NumericAttributeBinaryTest numericTest = (NumericAttributeBinaryTest) suggestedTest;
+    NumericAttributeBinaryRulePredicate predicate = new NumericAttributeBinaryRulePredicate(
+        numericTest.getAttsTestDependsOn()[0], numericTest.getSplitValue(),
+        splitIndex + 1);
+    this.lastUpdatedRuleSplitNode = new RuleSplitNode(predicate, this.learningNode.getStatisticsBranchSplit());
+    if (!this.nodeListAdd(this.lastUpdatedRuleSplitNode)) {
+      return;
+    }
+
+    // The stored builder is re-used (and mutated) to seed the replacement node's statistics.
+    Builder seededBuilder = this.getBuilder().statistics(
+        this.learningNode.getStatisticsNewRuleActiveLearningNode());
+    RuleActiveRegressionNode replacementNode = newRuleActiveLearningNode(seededBuilder);
+    replacementNode.initialize(this.learningNode);
+    this.learningNode = replacementNode;
+  }
+
+  /**
+   * MOA GUI output; intentionally produces nothing.
+   */
+  @Override
+  public void getDescription(StringBuilder sb, int indent) {
+  }
+}
diff --git a/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/rules/common/LearningRule.java b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/rules/common/LearningRule.java
new file mode 100644
index 00000000000..80b23e71ca7
--- 
/dev/null
+++ b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/rules/common/LearningRule.java
@@ -0,0 +1,122 @@
+package org.apache.heron.learners.classifiers.rules.common;
+
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+
+import org.apache.samoa.instances.Instance;
+import org.apache.samoa.moa.core.DoubleVector;
+import org.apache.samoa.moa.core.StringUtils;
+
+/**
+ * Rule with LearningNode (statistical data). Every operation here delegates to the node
+ * returned by {@link #getLearningNode()}.
+ *
+ * @author Anh Thu Vu
+ *
+ */
+public abstract class LearningRule extends Rule {
+
+  /**
+   *
+   */
+  private static final long serialVersionUID = 1L;
+
+  /*
+   * Constructor
+   */
+  public LearningRule() {
+    super();
+  }
+
+  /*
+   * LearningNode
+   */
+  public abstract RuleRegressionNode getLearningNode();
+
+  public abstract void setLearningNode(RuleRegressionNode learningNode);
+
+  /*
+   * No. of instances seen
+   */
+  public long getInstancesSeen() {
+    return this.getLearningNode().getInstancesSeen();
+  }
+
+  /*
+   * Error and change detection
+   */
+  public double computeError(Instance instance) {
+    return this.getLearningNode().computeError(instance);
+  }
+
+  /*
+   * Prediction
+   */
+  public double[] getPrediction(Instance instance, int mode) {
+    return this.getLearningNode().getPrediction(instance, mode);
+  }
+
+  public double[] getPrediction(Instance instance) {
+    return this.getLearningNode().getPrediction(instance);
+  }
+
+  public double getCurrentError() {
+    return this.getLearningNode().getCurrentError();
+  }
+
+  /*
+   * Anomaly detection
+   */
+  public boolean isAnomaly(Instance instance,
+      double uniVariateAnomalyProbabilityThreshold,
+      double multiVariateAnomalyProbabilityThreshold,
+      int numberOfInstanceesForAnomaly) {
+    return this.getLearningNode().isAnomaly(instance, uniVariateAnomalyProbabilityThreshold,
+        multiVariateAnomalyProbabilityThreshold,
+        numberOfInstanceesForAnomaly);
+  }
+
+  /*
+   * Update
+   */
+  public void updateStatistics(Instance instance) {
+    this.getLearningNode().updateStatistics(instance);
+  }
+
+  /**
+   * Builds a textual description of the rule: its number and instance count, each split node's
+   * test and description, the learning node's simple prediction and, when the learning node has
+   * a perceptron, the perceptron's model description.
+   *
+   * @return the description text
+   */
+  public String printRule() {
+    StringBuilder out = new StringBuilder();
+    int indent = 1;
+    StringUtils.appendIndented(out, indent, "Rule Nr." + this.ruleNumberID + " Instances seen:"
+        + this.getLearningNode().getInstancesSeen() + "\n"); // AC
+    for (RuleSplitNode node : nodeList) {
+      StringUtils.appendIndented(out, indent, node.getSplitTest().toString());
+      StringUtils.appendIndented(out, indent, " ");
+      StringUtils.appendIndented(out, indent, node.toString());
+    }
+    DoubleVector pred = new DoubleVector(this.getLearningNode().getSimplePrediction());
+    StringUtils.appendIndented(out, 0, " --> y: " + pred.toString());
+    StringUtils.appendNewline(out);
+
+    // Read the perceptron through the base node type for both the null check and the call.
+    // Casting to RuleActiveRegressionNode here would fail for rules whose learning node is a
+    // different RuleRegressionNode subtype (e.g. the node held by a PassiveRule).
+    RuleRegressionNode learningNode = this.getLearningNode();
+    if (learningNode.perceptron != null) {
+      learningNode.perceptron.getModelDescription(out, 0);
+      StringUtils.appendNewline(out);
+    }
+    return out.toString();
+  }
+}
diff --git a/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/rules/common/NonLearningRule.java b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/rules/common/NonLearningRule.java
new file mode 100644
index 00000000000..efe1383543d
--- /dev/null
+++ b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/rules/common/NonLearningRule.java
@@ -0,0 +1,51 @@
+package org.apache.heron.learners.classifiers.rules.common;
+
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. 
See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+
+/**
+ * Minimal rule that carries only what it inherits from Rule: the rule ID and the list of
+ * split nodes (features).
+ *
+ * @author Anh Thu Vu
+ *
+ */
+/*
+ * NOTE: this non-learning branch dates from an older implementation. The
+ * Non-Learning and Learning rule classes could probably be removed and Rule
+ * merged with LearningRule.
+ */
+public class NonLearningRule extends Rule {
+
+  private static final long serialVersionUID = -1210907339230307784L;
+
+  /**
+   * Builds a non-learning view of {@code rule}. The rule number is copied; the split-node list
+   * is taken over by reference, not copied.
+   */
+  public NonLearningRule(ActiveRule rule) {
+    this.ruleNumberID = rule.ruleNumberID;
+    this.nodeList = rule.nodeList;
+  }
+
+  /** Intentionally produces no description. */
+  @Override
+  public void getDescription(StringBuilder sb, int indent) {
+  }
+
+}
diff --git a/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/rules/common/PassiveRule.java b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/rules/common/PassiveRule.java
new file mode 100644
index 00000000000..5c571a6374f
--- /dev/null
+++ b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/rules/common/PassiveRule.java
@@ -0,0 +1,70 @@
+package org.apache.heron.learners.classifiers.rules.common;
+
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. 
See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+
+import java.util.LinkedList;
+
+/**
+ * PassiveRule is a LearningRule whose LearningNode is replaced by a newly received
+ * LearningNode rather than trained in place.
+ *
+ * @author Anh Thu Vu
+ *
+ */
+public class PassiveRule extends LearningRule {
+
+  private static final long serialVersionUID = -5551571895910530275L;
+
+  private RulePassiveRegressionNode learningNode;
+
+  /**
+   * Turns an ActiveRule into a PassiveRule: each split node of {@code rule} is copied with
+   * {@code getACopy()}, a passive learning node is created from the rule's learning node, and
+   * the rule number is carried over.
+   */
+  public PassiveRule(ActiveRule rule) {
+    this.nodeList = new LinkedList<>();
+    for (RuleSplitNode sourceNode : rule.nodeList) {
+      RuleSplitNode nodeCopy = sourceNode.getACopy();
+      this.nodeList.add(nodeCopy);
+    }
+
+    RuleRegressionNode sourceLearningNode = rule.getLearningNode();
+    this.learningNode = new RulePassiveRegressionNode(sourceLearningNode);
+    this.ruleNumberID = rule.ruleNumberID;
+  }
+
+  @Override
+  public RuleRegressionNode getLearningNode() {
+    return this.learningNode;
+  }
+
+  /** Replaces the learning node; the argument is cast to {@link RulePassiveRegressionNode}. */
+  @Override
+  public void setLearningNode(RuleRegressionNode learningNode) {
+    this.learningNode = (RulePassiveRegressionNode) learningNode;
+  }
+
+  /**
+   * MOA GUI output; not implemented, produces nothing.
+   */
+  @Override
+  public void getDescription(StringBuilder sb, int indent) {
+    // TODO Auto-generated method stub
+  }
+}
diff --git a/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/rules/common/Perceptron.java b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/rules/common/Perceptron.java
new file mode 100644
index 00000000000..bede56b2a65
--- /dev/null
+++ b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/rules/common/Perceptron.java
@@ -0,0 +1,487 @@
+package org.apache.heron.learners.classifiers.rules.common;
+
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + + +import java.io.Serializable; + +import org.apache.samoa.instances.Instance; +import org.apache.samoa.moa.classifiers.AbstractClassifier; +import org.apache.samoa.moa.classifiers.Regressor; +import org.apache.samoa.moa.core.DoubleVector; +import org.apache.samoa.moa.core.Measurement; + +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.Serializer; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; + +/** + * Prediction scheme using Perceptron: Predictions are computed according to a linear function of the attributes. + * + * @author Anh Thu Vu + * + */ +public class Perceptron extends AbstractClassifier implements Regressor { + + private final double SD_THRESHOLD = 0.0000001; // THRESHOLD for normalizing attribute and target values + + private static final long serialVersionUID = 1L; + + // public FlagOption constantLearningRatioDecayOption = new FlagOption( + // "learningRatio_Decay_set_constant", 'd', + // "Learning Ratio Decay in Perceptron set to be constant. 
(The next parameter)."); + // + // public FloatOption learningRatioOption = new FloatOption( + // "learningRatio", 'l', + // "Constante Learning Ratio to use for training the Perceptrons in the leaves.", + // 0.01); + // + // public FloatOption learningRateDecayOption = new FloatOption( + // "learningRateDecay", 'm', + // " Learning Rate decay to use for training the Perceptron.", 0.001); + // + // public FloatOption fadingFactorOption = new FloatOption( + // "fadingFactor", 'e', + // "Fading factor for the Perceptron accumulated error", 0.99, 0, 1); + + protected boolean constantLearningRatioDecay; + protected double originalLearningRatio; + + private double nError; + protected double fadingFactor = 0.99; + private double learningRatio; + protected double learningRateDecay = 0.001; + + // The Perception weights + protected double[] weightAttribute; + + // Statistics used for error calculations + public DoubleVector perceptronattributeStatistics = new DoubleVector(); + public DoubleVector squaredperceptronattributeStatistics = new DoubleVector(); + + // The number of instances contributing to this model + protected int perceptronInstancesSeen; + protected int perceptronYSeen; + + protected double accumulatedError; + + // If the model (weights) should be reset or not + protected boolean initialisePerceptron; + + protected double perceptronsumY; + protected double squaredperceptronsumY; + + public Perceptron() { + this.initialisePerceptron = true; + } + + /* + * Perceptron + */ + public Perceptron(Perceptron p) { + this(p, false); + } + + public Perceptron(Perceptron p, boolean copyAccumulatedError) { + super(); + // this.constantLearningRatioDecayOption = + // p.constantLearningRatioDecayOption; + // this.learningRatioOption = p.learningRatioOption; + // this.learningRateDecayOption=p.learningRateDecayOption; + // this.fadingFactorOption = p.fadingFactorOption; + this.constantLearningRatioDecay = p.constantLearningRatioDecay; + this.originalLearningRatio = 
p.originalLearningRatio; + if (copyAccumulatedError) + this.accumulatedError = p.accumulatedError; + this.nError = p.nError; + this.fadingFactor = p.fadingFactor; + this.learningRatio = p.learningRatio; + this.learningRateDecay = p.learningRateDecay; + if (p.weightAttribute != null) + this.weightAttribute = p.weightAttribute.clone(); + + this.perceptronattributeStatistics = new DoubleVector(p.perceptronattributeStatistics); + this.squaredperceptronattributeStatistics = new DoubleVector(p.squaredperceptronattributeStatistics); + this.perceptronInstancesSeen = p.perceptronInstancesSeen; + + this.initialisePerceptron = p.initialisePerceptron; + this.perceptronsumY = p.perceptronsumY; + this.squaredperceptronsumY = p.squaredperceptronsumY; + this.perceptronYSeen = p.perceptronYSeen; + } + + public Perceptron(PerceptronData p) { + super(); + this.constantLearningRatioDecay = p.constantLearningRatioDecay; + this.originalLearningRatio = p.originalLearningRatio; + this.nError = p.nError; + this.fadingFactor = p.fadingFactor; + this.learningRatio = p.learningRatio; + this.learningRateDecay = p.learningRateDecay; + if (p.weightAttribute != null) + this.weightAttribute = p.weightAttribute.clone(); + + this.perceptronattributeStatistics = new DoubleVector(p.perceptronattributeStatistics); + this.squaredperceptronattributeStatistics = new DoubleVector(p.squaredperceptronattributeStatistics); + this.perceptronInstancesSeen = p.perceptronInstancesSeen; + + this.initialisePerceptron = p.initialisePerceptron; + this.perceptronsumY = p.perceptronsumY; + this.squaredperceptronsumY = p.squaredperceptronsumY; + this.perceptronYSeen = p.perceptronYSeen; + this.accumulatedError = p.accumulatedError; + } + + // private void printPerceptron() { + // System.out.println("Learning Ratio:"+this.learningRatio+" ("+this.originalLearningRatio+")"); + // System.out.println("Constant Learning Ratio Decay:"+this.constantLearningRatioDecay+" ("+this.learningRateDecay+")"); + // 
System.out.println("Error:"+this.accumulatedError+"/"+this.nError); + // System.out.println("Fading factor:"+this.fadingFactor); + // System.out.println("Perceptron Y:"+this.perceptronsumY+"/"+this.squaredperceptronsumY+"/"+this.perceptronYSeen); + // } + + /* + * Weights + */ + public void setWeights(double[] w) { + this.weightAttribute = w; + } + + public double[] getWeights() { + return this.weightAttribute; + } + + /* + * No. of instances seen + */ + public int getInstancesSeen() { + return perceptronInstancesSeen; + } + + public void setInstancesSeen(int pInstancesSeen) { + this.perceptronInstancesSeen = pInstancesSeen; + } + + /** + * A method to reset the model + */ + public void resetLearningImpl() { + this.initialisePerceptron = true; + this.reset(); + } + + public void reset() { + this.nError = 0.0; + this.accumulatedError = 0.0; + this.perceptronInstancesSeen = 0; + this.perceptronattributeStatistics = new DoubleVector(); + this.squaredperceptronattributeStatistics = new DoubleVector(); + this.perceptronsumY = 0.0; + this.squaredperceptronsumY = 0.0; + this.perceptronYSeen = 0; + } + + public void resetError() { + this.nError = 0.0; + this.accumulatedError = 0.0; + } + + /** + * Update the model using the provided instance + */ + public void trainOnInstanceImpl(Instance inst) { + accumulatedError = Math.abs(this.prediction(inst) - inst.classValue()) + fadingFactor * accumulatedError; + nError = 1 + fadingFactor * nError; + // Initialise Perceptron if necessary + if (this.initialisePerceptron) { + // this.fadingFactor=this.fadingFactorOption.getValue(); + // this.classifierRandom.setSeed(randomSeedOption.getValue()); + this.classifierRandom.setSeed(randomSeed); + this.initialisePerceptron = false; // not in resetLearningImpl() because it needs Instance! 
+ this.weightAttribute = new double[inst.numAttributes()]; + for (int j = 0; j < inst.numAttributes(); j++) { + weightAttribute[j] = 2 * this.classifierRandom.nextDouble() - 1; + } + // Update Learning Rate + learningRatio = originalLearningRatio; + // this.learningRateDecay = learningRateDecayOption.getValue(); + + } + + // Update attribute statistics + this.perceptronInstancesSeen++; + this.perceptronYSeen++; + + for (int j = 0; j < inst.numAttributes() - 1; j++) + { + perceptronattributeStatistics.addToValue(j, inst.value(j)); + squaredperceptronattributeStatistics.addToValue(j, inst.value(j) * inst.value(j)); + } + this.perceptronsumY += inst.classValue(); + this.squaredperceptronsumY += inst.classValue() * inst.classValue(); + + if (!constantLearningRatioDecay) { + learningRatio = originalLearningRatio / (1 + perceptronInstancesSeen * learningRateDecay); + } + + this.updateWeights(inst, learningRatio); + // this.printPerceptron(); + } + + /** + * Output the prediction made by this perceptron on the given instance + */ + private double prediction(Instance inst) + { + double[] normalizedInstance = normalizedInstance(inst); + double normalizedPrediction = prediction(normalizedInstance); + return denormalizedPrediction(normalizedPrediction); + } + + public double normalizedPrediction(Instance inst) + { + double[] normalizedInstance = normalizedInstance(inst); + return prediction(normalizedInstance); + } + + private double denormalizedPrediction(double normalizedPrediction) { + if (!this.initialisePerceptron) { + double meanY = perceptronsumY / perceptronYSeen; + double sdY = computeSD(squaredperceptronsumY, perceptronsumY, perceptronYSeen); + if (sdY > SD_THRESHOLD) + return normalizedPrediction * sdY + meanY; + else + return normalizedPrediction + meanY; + } + else + return normalizedPrediction; // Perceptron may have been "reseted". 
Use old weights to predict + + } + + public double prediction(double[] instanceValues) + { + double prediction = 0.0; + if (!this.initialisePerceptron) + { + for (int j = 0; j < instanceValues.length - 1; j++) { + prediction += this.weightAttribute[j] * instanceValues[j]; + } + prediction += this.weightAttribute[instanceValues.length - 1]; + } + return prediction; + } + + public double[] normalizedInstance(Instance inst) { + // Normalize Instance + double[] normalizedInstance = new double[inst.numAttributes()]; + for (int j = 0; j < inst.numAttributes() - 1; j++) { + int instAttIndex = modelAttIndexToInstanceAttIndex(j); + double mean = perceptronattributeStatistics.getValue(j) / perceptronYSeen; + double sd = computeSD(squaredperceptronattributeStatistics.getValue(j), + perceptronattributeStatistics.getValue(j), perceptronYSeen); + if (sd > SD_THRESHOLD) + normalizedInstance[j] = (inst.value(instAttIndex) - mean) / sd; + else + normalizedInstance[j] = inst.value(instAttIndex) - mean; + } + return normalizedInstance; + } + + public double computeSD(double squaredVal, double val, int size) { + if (size > 1) { + return Math.sqrt((squaredVal - ((val * val) / size)) / (size - 1.0)); + } + return 0.0; + } + + public double updateWeights(Instance inst, double learningRatio) { + // Normalize Instance + double[] normalizedInstance = normalizedInstance(inst); + // Compute the Normalized Prediction of Perceptron + double normalizedPredict = prediction(normalizedInstance); + double normalizedY = normalizeActualClassValue(inst); + double sumWeights = 0.0; + double delta = normalizedY - normalizedPredict; + + for (int j = 0; j < inst.numAttributes() - 1; j++) { + int instAttIndex = modelAttIndexToInstanceAttIndex(j); + if (inst.attribute(instAttIndex).isNumeric()) { + this.weightAttribute[j] += learningRatio * delta * normalizedInstance[j]; + sumWeights += Math.abs(this.weightAttribute[j]); + } + } + this.weightAttribute[inst.numAttributes() - 1] += learningRatio * delta; + 
sumWeights += Math.abs(this.weightAttribute[inst.numAttributes() - 1]); + if (sumWeights > inst.numAttributes()) { // Lasso regression + for (int j = 0; j < inst.numAttributes() - 1; j++) { + int instAttIndex = modelAttIndexToInstanceAttIndex(j); + if (inst.attribute(instAttIndex).isNumeric()) { + this.weightAttribute[j] = this.weightAttribute[j] / sumWeights; + } + } + this.weightAttribute[inst.numAttributes() - 1] = this.weightAttribute[inst.numAttributes() - 1] / sumWeights; + } + + return denormalizedPrediction(normalizedPredict); + } + + public void normalizeWeights() { + double sumWeights = 0.0; + + for (double aWeightAttribute : this.weightAttribute) { + sumWeights += Math.abs(aWeightAttribute); + } + for (int j = 0; j < this.weightAttribute.length; j++) { + this.weightAttribute[j] = this.weightAttribute[j] / sumWeights; + } + } + + private double normalizeActualClassValue(Instance inst) { + double meanY = perceptronsumY / perceptronYSeen; + double sdY = computeSD(squaredperceptronsumY, perceptronsumY, perceptronYSeen); + + double normalizedY; + if (sdY > SD_THRESHOLD) { + normalizedY = (inst.classValue() - meanY) / sdY; + } else { + normalizedY = inst.classValue() - meanY; + } + return normalizedY; + } + + @Override + public boolean isRandomizable() { + return true; + } + + @Override + public double[] getVotesForInstance(Instance inst) { + return new double[] { this.prediction(inst) }; + } + + @Override + protected Measurement[] getModelMeasurementsImpl() { + return null; + } + + @Override + public void getModelDescription(StringBuilder out, int indent) { + if (this.weightAttribute != null) { + for (int i = 0; i < this.weightAttribute.length - 1; ++i) + { + if (this.weightAttribute[i] >= 0 && i > 0) + out.append(" +" + Math.round(this.weightAttribute[i] * 1000) / 1000.0 + " X" + i); + else + out.append(" " + Math.round(this.weightAttribute[i] * 1000) / 1000.0 + " X" + i); + } + if (this.weightAttribute[this.weightAttribute.length - 1] >= 0) + out.append(" 
+" + Math.round(this.weightAttribute[this.weightAttribute.length - 1] * 1000) / 1000.0); + else + out.append(" " + Math.round(this.weightAttribute[this.weightAttribute.length - 1] * 1000) / 1000.0); + } + } + + public void setLearningRatio(double learningRatio) { + this.learningRatio = learningRatio; + + } + + public double getCurrentError() + { + if (nError > 0) + return accumulatedError / nError; + else + return Double.MAX_VALUE; + } + + public static class PerceptronData implements Serializable { + /** + * + */ + private static final long serialVersionUID = 6727623208744105082L; + + private boolean constantLearningRatioDecay; + // If the model (weights) should be reset or not + private boolean initialisePerceptron; + + private double nError; + private double fadingFactor; + private double originalLearningRatio; + private double learningRatio; + private double learningRateDecay; + private double accumulatedError; + private double perceptronsumY; + private double squaredperceptronsumY; + + // The Perception weights + private double[] weightAttribute; + + // Statistics used for error calculations + private DoubleVector perceptronattributeStatistics; + private DoubleVector squaredperceptronattributeStatistics; + + // The number of instances contributing to this model + private int perceptronInstancesSeen; + private int perceptronYSeen; + + public PerceptronData() { + + } + + public PerceptronData(Perceptron p) { + this.constantLearningRatioDecay = p.constantLearningRatioDecay; + this.initialisePerceptron = p.initialisePerceptron; + this.nError = p.nError; + this.fadingFactor = p.fadingFactor; + this.originalLearningRatio = p.originalLearningRatio; + this.learningRatio = p.learningRatio; + this.learningRateDecay = p.learningRateDecay; + this.accumulatedError = p.accumulatedError; + this.perceptronsumY = p.perceptronsumY; + this.squaredperceptronsumY = p.squaredperceptronsumY; + this.weightAttribute = p.weightAttribute; + this.perceptronattributeStatistics = 
p.perceptronattributeStatistics; + this.squaredperceptronattributeStatistics = p.squaredperceptronattributeStatistics; + this.perceptronInstancesSeen = p.perceptronInstancesSeen; + this.perceptronYSeen = p.perceptronYSeen; + } + + public Perceptron build() { + return new Perceptron(this); + } + + } + + public static final class PerceptronSerializer extends Serializer { + + @Override + public void write(Kryo kryo, Output output, Perceptron p) { + kryo.writeObjectOrNull(output, new PerceptronData(p), PerceptronData.class); + } + + @Override + public Perceptron read(Kryo kryo, Input input, Class type) { + PerceptronData perceptronData = kryo.readObjectOrNull(input, PerceptronData.class); + return perceptronData.build(); + } + } + +} diff --git a/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/rules/common/Rule.java b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/rules/common/Rule.java new file mode 100644 index 00000000000..fae9a87d67d --- /dev/null +++ b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/rules/common/Rule.java @@ -0,0 +1,111 @@ +package org.apache.heron.learners.classifiers.rules.common; + +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + + +import java.util.LinkedList; +import java.util.List; + +import org.apache.samoa.instances.Instance; +import org.apache.samoa.moa.AbstractMOAObject; +import org.apache.samoa.moa.classifiers.rules.core.conditionaltests.NumericAttributeBinaryRulePredicate; + +/** + * The base class for "rule". Represents the most basic rule with and ID and a list of features (nodeList). + * + * @author Anh Thu Vu + * + */ +public abstract class Rule extends AbstractMOAObject { + private static final long serialVersionUID = 1L; + + protected int ruleNumberID; + + protected List nodeList; + + /* + * Constructor + */ + public Rule() { + this.nodeList = new LinkedList(); + } + + /* + * Rule ID + */ + public int getRuleNumberID() { + return ruleNumberID; + } + + public void setRuleNumberID(int ruleNumberID) { + this.ruleNumberID = ruleNumberID; + } + + /* + * RuleSplitNode list + */ + public List getNodeList() { + return nodeList; + } + + public void setNodeList(List nodeList) { + this.nodeList = nodeList; + } + + /* + * Covering + */ + public boolean isCovering(Instance inst) { + boolean isCovering = true; + for (RuleSplitNode node : nodeList) { + if (node.evaluate(inst) == false) { + isCovering = false; + break; + } + } + return isCovering; + } + + /* + * Add RuleSplitNode + */ + public boolean nodeListAdd(RuleSplitNode ruleSplitNode) { + // Check that the node is not already in the list + boolean isIncludedInNodeList = false; + boolean isUpdated = false; + for (RuleSplitNode node : nodeList) { + NumericAttributeBinaryRulePredicate nodeTest = (NumericAttributeBinaryRulePredicate) node.getSplitTest(); + NumericAttributeBinaryRulePredicate ruleSplitNodeTest = (NumericAttributeBinaryRulePredicate) ruleSplitNode + .getSplitTest(); + if (nodeTest.isUsingSameAttribute(ruleSplitNodeTest)) { + isIncludedInNodeList = true; + if (nodeTest.isIncludedInRuleNode(ruleSplitNodeTest) == true) { // remove this line to keep the most recent attribute value + // replace the value + 
nodeTest.setAttributeValue(ruleSplitNodeTest); + isUpdated = true; // if is updated (i.e. an expansion happened) a new learning node should be created + } + } + } + if (isIncludedInNodeList == false) { + this.nodeList.add(ruleSplitNode); + } + return (!isIncludedInNodeList || isUpdated); + } +} diff --git a/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/rules/common/RuleActiveLearningNode.java b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/rules/common/RuleActiveLearningNode.java new file mode 100644 index 00000000000..a492192f16d --- /dev/null +++ b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/rules/common/RuleActiveLearningNode.java @@ -0,0 +1,34 @@ +package org.apache.heron.learners.classifiers.rules.common; + +/* +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + + +/** + * Interface for Rule's LearningNode that updates both statistics for expanding rule and computing predictions. 
+ * + * @author Anh Thu Vu + * + */ +public interface RuleActiveLearningNode extends RulePassiveLearningNode { + + public boolean tryToExpand(double splitConfidence, double tieThreshold); + +} diff --git a/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/rules/common/RuleActiveRegressionNode.java b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/rules/common/RuleActiveRegressionNode.java new file mode 100644 index 00000000000..aaa83d650de --- /dev/null +++ b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/rules/common/RuleActiveRegressionNode.java @@ -0,0 +1,331 @@ +package org.apache.heron.learners.classifiers.rules.common; + +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + + +import java.util.Arrays; +import java.util.LinkedList; +import java.util.List; + +import org.apache.samoa.instances.Instance; +import org.apache.samoa.moa.classifiers.core.AttributeSplitSuggestion; +import org.apache.samoa.moa.classifiers.core.attributeclassobservers.AttributeClassObserver; +import org.apache.samoa.moa.classifiers.core.attributeclassobservers.FIMTDDNumericAttributeClassObserver; +import org.apache.samoa.moa.classifiers.core.splitcriteria.SplitCriterion; +import org.apache.samoa.moa.classifiers.rules.core.attributeclassobservers.FIMTDDNumericAttributeClassLimitObserver; +import org.apache.samoa.moa.classifiers.rules.core.splitcriteria.SDRSplitCriterionAMRules; +import org.apache.samoa.moa.classifiers.rules.driftdetection.PageHinkleyFading; +import org.apache.samoa.moa.classifiers.rules.driftdetection.PageHinkleyTest; +import org.apache.samoa.moa.core.AutoExpandVector; +import org.apache.samoa.moa.core.DoubleVector; + +/** + * LearningNode for regression rule that updates both statistics for expanding rule and computing predictions. 
+ * + * @author Anh Thu Vu + * + */ +public class RuleActiveRegressionNode extends RuleRegressionNode implements RuleActiveLearningNode { + + /** + * + */ + private static final long serialVersionUID = 519854943188168546L; + + protected int splitIndex = 0; + + protected PageHinkleyTest pageHinckleyTest; + protected boolean changeDetection; + + protected double[] statisticsNewRuleActiveLearningNode = null; + protected double[] statisticsBranchSplit = null; + protected double[] statisticsOtherBranchSplit; + + protected AttributeSplitSuggestion bestSuggestion = null; + + protected AutoExpandVector attributeObservers = new AutoExpandVector<>(); + private FIMTDDNumericAttributeClassLimitObserver numericObserver; + + /* + * Simple setters & getters + */ + public int getSplitIndex() { + return splitIndex; + } + + public void setSplitIndex(int splitIndex) { + this.splitIndex = splitIndex; + } + + public double[] getStatisticsOtherBranchSplit() { + return statisticsOtherBranchSplit; + } + + public void setStatisticsOtherBranchSplit(double[] statisticsOtherBranchSplit) { + this.statisticsOtherBranchSplit = statisticsOtherBranchSplit; + } + + public double[] getStatisticsBranchSplit() { + return statisticsBranchSplit; + } + + public void setStatisticsBranchSplit(double[] statisticsBranchSplit) { + this.statisticsBranchSplit = statisticsBranchSplit; + } + + public double[] getStatisticsNewRuleActiveLearningNode() { + return statisticsNewRuleActiveLearningNode; + } + + public void setStatisticsNewRuleActiveLearningNode( + double[] statisticsNewRuleActiveLearningNode) { + this.statisticsNewRuleActiveLearningNode = statisticsNewRuleActiveLearningNode; + } + + public AttributeSplitSuggestion getBestSuggestion() { + return bestSuggestion; + } + + public void setBestSuggestion(AttributeSplitSuggestion bestSuggestion) { + this.bestSuggestion = bestSuggestion; + } + + /* + * Constructor with builder + */ + public RuleActiveRegressionNode() { + super(); + } + + public 
RuleActiveRegressionNode(ActiveRule.Builder builder) { + super(builder.statistics); + this.changeDetection = builder.changeDetection; + if (!builder.changeDetection) { + this.pageHinckleyTest = new PageHinkleyFading(builder.threshold, builder.alpha); + } + this.predictionFunction = builder.predictionFunction; + this.learningRatio = builder.learningRatio; + this.ruleNumberID = builder.id; + this.numericObserver = builder.numericObserver; + + this.perceptron = new Perceptron(); + this.perceptron.prepareForUse(); + this.perceptron.originalLearningRatio = builder.learningRatio; + this.perceptron.constantLearningRatioDecay = builder.constantLearningRatioDecay; + + if (this.predictionFunction != 1) + { + this.targetMean = new TargetMean(); + if (builder.statistics[0] > 0) + this.targetMean.reset(builder.statistics[1] / builder.statistics[0], (long) builder.statistics[0]); + } + this.predictionFunction = builder.predictionFunction; + if (builder.statistics != null) + this.nodeStatistics = new DoubleVector(builder.statistics); + } + + /* + * Update with input instance + */ + public boolean updatePageHinckleyTest(double error) { + boolean changeDetected = false; + if (!this.changeDetection) { + changeDetected = pageHinckleyTest.update(error); + } + return changeDetected; + } + + public boolean updateChangeDetection(double error) { + return !changeDetection && pageHinckleyTest.update(error); + } + + @Override + public void updateStatistics(Instance inst) { + // Update the statistics for this node + // number of instances passing through the node + nodeStatistics.addToValue(0, 1); + // sum of y values + nodeStatistics.addToValue(1, inst.classValue()); + // sum of squared y values + nodeStatistics.addToValue(2, inst.classValue() * inst.classValue()); + + for (int i = 0; i < inst.numAttributes() - 1; i++) { + int instAttIndex = modelAttIndexToInstanceAttIndex(i, inst); + + AttributeClassObserver obs = this.attributeObservers.get(i); + if (obs == null) { + // At this stage all 
nominal attributes are ignored + if (inst.attribute(instAttIndex).isNumeric()) // instAttIndex + { + obs = newNumericClassObserver(); + this.attributeObservers.set(i, obs); + } + } + if (obs != null) { + ((FIMTDDNumericAttributeClassObserver) obs).observeAttributeClass(inst.value(instAttIndex), inst.classValue(), + inst.weight()); + } + } + + this.perceptron.trainOnInstance(inst); + if (this.predictionFunction != 1) { // Train target mean if prediction function is not Perceptron + this.targetMean.trainOnInstance(inst); + } + } + + protected AttributeClassObserver newNumericClassObserver() { + // return new FIMTDDNumericAttributeClassObserver(); + // return new FIMTDDNumericAttributeClassLimitObserver(); + // return + // (AttributeClassObserver)((AttributeClassObserver)this.numericObserverOption.getPreMaterializedObject()).copy(); + FIMTDDNumericAttributeClassLimitObserver newObserver = new FIMTDDNumericAttributeClassLimitObserver(); + newObserver.setMaxNodes(numericObserver.getMaxNodes()); + return newObserver; + } + + /* + * Init after being split from oldLearningNode + */ + public void initialize(RuleRegressionNode oldLearningNode) { + if (oldLearningNode.perceptron != null) + { + this.perceptron = new Perceptron(oldLearningNode.perceptron); + this.perceptron.resetError(); + this.perceptron.setLearningRatio(oldLearningNode.learningRatio); + } + + if (oldLearningNode.targetMean != null) + { + this.targetMean = new TargetMean(oldLearningNode.targetMean); + this.targetMean.resetError(); + } + // reset statistics + this.nodeStatistics.setValue(0, 0); + this.nodeStatistics.setValue(1, 0); + this.nodeStatistics.setValue(2, 0); + } + + /* + * Expand + */ + @Override + public boolean tryToExpand(double splitConfidence, double tieThreshold) { + + // splitConfidence. Hoeffding Bound test parameter. + // tieThreshold. Hoeffding Bound test parameter. 
+ SplitCriterion splitCriterion = new SDRSplitCriterionAMRules(); + // SplitCriterion splitCriterion = new SDRSplitCriterionAMRulesNode();//JD + // for assessing only best branch + + // Using this criterion, find the best split per attribute and rank the + // results + AttributeSplitSuggestion[] bestSplitSuggestions = this.getBestSplitSuggestions(splitCriterion); + Arrays.sort(bestSplitSuggestions); + // Declare a variable to determine if any of the splits should be performed + boolean shouldSplit = false; + + // If only one split was returned, use it + if (bestSplitSuggestions.length < 2) { + shouldSplit = ((bestSplitSuggestions.length > 0) && (bestSplitSuggestions[0].merit > 0)); + bestSuggestion = bestSplitSuggestions[bestSplitSuggestions.length - 1]; + } // Otherwise, consider which of the splits proposed may be worth trying + else { + // Determine the hoeffding bound value, used to select how many instances + // should be used to make a test decision + // to feel reasonably confident that the test chosen by this sample is the + // same as what would be chosen using infinite examples + double hoeffdingBound = computeHoeffdingBound(1, splitConfidence, getInstancesSeen()); + // Determine the top two ranked splitting suggestions + bestSuggestion = bestSplitSuggestions[bestSplitSuggestions.length - 1]; + AttributeSplitSuggestion secondBestSuggestion = bestSplitSuggestions[bestSplitSuggestions.length - 2]; + + // If the upper bound of the sample mean for the ratio of SDR(best + // suggestion) to SDR(second best suggestion), + // as determined using the hoeffding bound, is less than 1, then the true + // mean is also less than 1, and thus at this + // particular moment of observation the bestSuggestion is indeed the best + // split option with confidence 1-delta, and + // splitting should occur. 
+ // Alternatively, if two or more splits are very similar or identical in + // terms of their splits, then a threshold limit + // (default 0.05) is applied to the hoeffding bound; if the hoeffding + // bound is smaller than this limit then the two + // competing attributes are equally good, and the split will be made on + // the one with the higher SDR value. + + if (bestSuggestion.merit > 0) { + if ((((secondBestSuggestion.merit / bestSuggestion.merit) + hoeffdingBound) < 1) + || (hoeffdingBound < tieThreshold)) { + shouldSplit = true; + } + } + } + + if (shouldSplit) { + AttributeSplitSuggestion splitDecision = bestSplitSuggestions[bestSplitSuggestions.length - 1]; + double minValue = Double.MAX_VALUE; + double[] branchMerits = SDRSplitCriterionAMRules + .computeBranchSplitMerits(bestSuggestion.resultingClassDistributions); + + for (int i = 0; i < bestSuggestion.numSplits(); i++) { + double value = branchMerits[i]; + if (value < minValue) { + minValue = value; + splitIndex = i; + statisticsNewRuleActiveLearningNode = bestSuggestion.resultingClassDistributionFromSplit(i); + } + } + statisticsBranchSplit = splitDecision.resultingClassDistributionFromSplit(splitIndex); + statisticsOtherBranchSplit = bestSuggestion.resultingClassDistributionFromSplit(splitIndex == 0 ? 
1 : 0); + + } + return shouldSplit; + } + + public AutoExpandVector getAttributeObservers() { + return this.attributeObservers; + } + + public AttributeSplitSuggestion[] getBestSplitSuggestions(SplitCriterion criterion) { + + List bestSuggestions = new LinkedList(); + + // Set the nodeStatistics up as the preSplitDistribution, rather than the + // observedClassDistribution + double[] nodeSplitDist = this.nodeStatistics.getArrayCopy(); + for (int i = 0; i < this.attributeObservers.size(); i++) { + AttributeClassObserver obs = this.attributeObservers.get(i); + if (obs != null) { + + // AT THIS STAGE NON-NUMERIC ATTRIBUTES ARE IGNORED + AttributeSplitSuggestion bestSuggestion = null; + if (obs instanceof FIMTDDNumericAttributeClassObserver) { + bestSuggestion = obs.getBestEvaluatedSplitSuggestion(criterion, nodeSplitDist, i, true); + } + + if (bestSuggestion != null) { + bestSuggestions.add(bestSuggestion); + } + } + } + return bestSuggestions.toArray(new AttributeSplitSuggestion[bestSuggestions.size()]); + } + +} diff --git a/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/rules/common/RulePassiveLearningNode.java b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/rules/common/RulePassiveLearningNode.java new file mode 100644 index 00000000000..4f0a45cd866 --- /dev/null +++ b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/rules/common/RulePassiveLearningNode.java @@ -0,0 +1,32 @@ +package org.apache.heron.learners.classifiers.rules.common; + +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
/**
 * Interface for Rule's LearningNode that does not update statistics for expanding rule. It only updates statistics for
 * computing predictions.
 * <p>
 * This is a marker interface: it declares no methods of its own.
 *
 * @author Anh Thu Vu
 *
 */
public interface RulePassiveLearningNode {

}
+ */ + +import org.apache.samoa.instances.Instance; +import org.apache.samoa.moa.core.DoubleVector; + +/** + * LearningNode for regression rule that does not update statistics for expanding rule. It only updates statistics for + * computing predictions. + * + * @author Anh Thu Vu + * + */ +public class RulePassiveRegressionNode extends RuleRegressionNode implements RulePassiveLearningNode { + + /** + * + */ + private static final long serialVersionUID = 3720878438856489690L; + + public RulePassiveRegressionNode(double[] statistics) { + super(statistics); + } + + public RulePassiveRegressionNode() { + super(); + } + + public RulePassiveRegressionNode(RuleRegressionNode activeLearningNode) { + this.predictionFunction = activeLearningNode.predictionFunction; + this.ruleNumberID = activeLearningNode.ruleNumberID; + this.nodeStatistics = new DoubleVector(activeLearningNode.nodeStatistics); + this.learningRatio = activeLearningNode.learningRatio; + this.perceptron = new Perceptron(activeLearningNode.perceptron, true); + this.targetMean = new TargetMean(activeLearningNode.targetMean); + } + + /* + * Update with input instance + */ + @Override + public void updateStatistics(Instance inst) { + // Update the statistics for this node + // number of instances passing through the node + nodeStatistics.addToValue(0, 1); + // sum of y values + nodeStatistics.addToValue(1, inst.classValue()); + // sum of squared y values + nodeStatistics.addToValue(2, inst.classValue() * inst.classValue()); + + this.perceptron.trainOnInstance(inst); + if (this.predictionFunction != 1) { // Train target mean if prediction function is not Perceptron + this.targetMean.trainOnInstance(inst); + } + } +} diff --git a/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/rules/common/RuleRegressionNode.java b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/rules/common/RuleRegressionNode.java new file mode 100644 index 00000000000..15587169882 --- /dev/null +++ 
b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/rules/common/RuleRegressionNode.java @@ -0,0 +1,294 @@ +package org.apache.heron.learners.classifiers.rules.common; + +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + + +import java.io.Serializable; + +import org.apache.samoa.instances.Instance; +import org.apache.samoa.moa.core.DoubleVector; + +/** + * The base class for LearningNode for regression rule. 
+ * + * @author Anh Thu Vu + * + */ +public abstract class RuleRegressionNode implements Serializable { + + private static final long serialVersionUID = 9129659494380381126L; + + protected int predictionFunction; + protected int ruleNumberID; + // The statistics for this node: + // Number of instances that have reached it + // Sum of y values + // Sum of squared y values + protected DoubleVector nodeStatistics; + + protected Perceptron perceptron; + protected TargetMean targetMean; + protected double learningRatio; + + /* + * Simple setters & getters + */ + public Perceptron getPerceptron() { + return perceptron; + } + + public void setPerceptron(Perceptron perceptron) { + this.perceptron = perceptron; + } + + public TargetMean getTargetMean() { + return targetMean; + } + + public void setTargetMean(TargetMean targetMean) { + this.targetMean = targetMean; + } + + /* + * Create a new RuleRegressionNode + */ + public RuleRegressionNode(double[] initialClassObservations) { + this.nodeStatistics = new DoubleVector(initialClassObservations); + } + + public RuleRegressionNode() { + this(new double[0]); + } + + /* + * Update statistics with input instance + */ + public abstract void updateStatistics(Instance instance); + + /* + * Predictions + */ + public double[] getPrediction(Instance instance) { + int predictionMode = this.getLearnerToUse(this.predictionFunction); + return getPrediction(instance, predictionMode); + } + + public double[] getSimplePrediction() { + if (this.targetMean != null) + return this.targetMean.getVotesForInstance(); + else + return new double[] { 0 }; + } + + public double[] getPrediction(Instance instance, int predictionMode) { + double[] ret; + if (predictionMode == 1) + ret = this.perceptron.getVotesForInstance(instance); + else + ret = this.targetMean.getVotesForInstance(instance); + return ret; + } + + public double getNormalizedPrediction(Instance instance) { + double res; + double[] aux; + switch (this.predictionFunction) { + // perceptron 
- 1 + case 1: + res = this.perceptron.normalizedPrediction(instance); + break; + // target mean - 2 + case 2: + aux = this.targetMean.getVotesForInstance(); + res = normalize(aux[0]); + break; + // adaptive - 0 + case 0: + int predictionMode = this.getLearnerToUse(0); + if (predictionMode == 1) + { + res = this.perceptron.normalizedPrediction(instance); + } + else { + aux = this.targetMean.getVotesForInstance(instance); + res = normalize(aux[0]); + } + break; + default: + throw new UnsupportedOperationException("Prediction mode not in range."); + } + return res; + } + + /* + * Get learner mode + */ + public int getLearnerToUse(int predMode) { + int predictionMode = predMode; + if (predictionMode == 0) { + double perceptronError = this.perceptron.getCurrentError(); + double meanTargetError = this.targetMean.getCurrentError(); + if (perceptronError < meanTargetError) + predictionMode = 1; // PERCEPTRON + else + predictionMode = 2; // TARGET MEAN + } + return predictionMode; + } + + /* + * Error and change detection + */ + public double computeError(Instance instance) { + double normalizedPrediction = getNormalizedPrediction(instance); + double normalizedClassValue = normalize(instance.classValue()); + return Math.abs(normalizedClassValue - normalizedPrediction); + } + + public double getCurrentError() { + double error; + if (this.perceptron != null) { + if (targetMean == null) + error = perceptron.getCurrentError(); + else { + double errorP = perceptron.getCurrentError(); + double errorTM = targetMean.getCurrentError(); + error = (errorP < errorTM) ? errorP : errorTM; + } + } + else + error = Double.MAX_VALUE; + return error; + } + + /* + * no. 
of instances seen + */ + public long getInstancesSeen() { + if (nodeStatistics != null) { + return (long) this.nodeStatistics.getValue(0); + } else { + return 0; + } + } + + public DoubleVector getNodeStatistics() { + return this.nodeStatistics; + } + + /* + * Anomaly detection + */ + public boolean isAnomaly(Instance instance, + double uniVariateAnomalyProbabilityThreshold, + double multiVariateAnomalyProbabilityThreshold, + int numberOfInstanceesForAnomaly) { + // AMRUles is equipped with anomaly detection. If on, compute the anomaly + // value. + long perceptronIntancesSeen = this.perceptron.getInstancesSeen(); + if (perceptronIntancesSeen >= numberOfInstanceesForAnomaly) { + double attribSum; + double attribSquaredSum; + double D = 0.0; + double N = 0.0; + double anomaly; + + for (int x = 0; x < instance.numAttributes() - 1; x++) { + // Perceptron is initialized each rule. + // this is a local anomaly. + int instAttIndex = modelAttIndexToInstanceAttIndex(x, instance); + attribSum = this.perceptron.perceptronattributeStatistics.getValue(x); + attribSquaredSum = this.perceptron.squaredperceptronattributeStatistics.getValue(x); + double mean = attribSum / perceptronIntancesSeen; + double sd = computeSD(attribSquaredSum, attribSum, perceptronIntancesSeen); + double probability = computeProbability(mean, sd, instance.value(instAttIndex)); + + if (probability > 0.0) { + D = D + Math.abs(Math.log(probability)); + if (probability < uniVariateAnomalyProbabilityThreshold) {// 0.10 + N = N + Math.abs(Math.log(probability)); + } + } + } + + anomaly = 0.0; + if (D != 0.0) { + anomaly = N / D; + } + if (anomaly >= multiVariateAnomalyProbabilityThreshold) { + // debuganomaly(instance, + // uniVariateAnomalyProbabilityThreshold, + // multiVariateAnomalyProbabilityThreshold, + // anomaly); + return true; + } + } + return false; + } + + /* + * Helpers + */ + public static double computeProbability(double mean, double sd, double value) { + double probability = 0.0; + + if (sd > 
0.0) { + double k = (Math.abs(value - mean) / sd); // One tailed variant of Chebyshev's inequality + probability = 1.0 / (1 + k * k); + } + + return probability; + } + + public static double computeHoeffdingBound(double range, double confidence, double n) { + return Math.sqrt(((range * range) * Math.log(1.0 / confidence)) / (2.0 * n)); + } + + private double normalize(double value) { + double meanY = this.nodeStatistics.getValue(1) / this.nodeStatistics.getValue(0); + double sdY = computeSD(this.nodeStatistics.getValue(2), this.nodeStatistics.getValue(1), + (long) this.nodeStatistics.getValue(0)); + double normalizedY = 0.0; + if (sdY > 0.0000001) { + normalizedY = (value - meanY) / (sdY); + } + return normalizedY; + } + + public double computeSD(double squaredVal, double val, long size) { + if (size > 1) { + return Math.sqrt((squaredVal - ((val * val) / size)) / (size - 1.0)); + } + return 0.0; + } + + /** + * Gets the index of the attribute in the instance, given the index of the attribute in the learner. + * + * @param index + * the index of the attribute in the learner + * @param inst + * the instance + * @return the index in the instance + */ + protected static int modelAttIndexToInstanceAttIndex(int index, Instance inst) { + return index <= inst.classIndex() ? index : index + 1; + } +} diff --git a/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/rules/common/RuleSplitNode.java b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/rules/common/RuleSplitNode.java new file mode 100644 index 00000000000..22714de4be0 --- /dev/null +++ b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/rules/common/RuleSplitNode.java @@ -0,0 +1,68 @@ +package org.apache.heron.learners.classifiers.rules.common; + +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + + +import org.apache.samoa.instances.Instance; +import org.apache.samoa.learners.classifiers.trees.SplitNode; +import org.apache.samoa.moa.classifiers.core.conditionaltests.InstanceConditionalTest; +import org.apache.samoa.moa.classifiers.rules.core.Predicate; +import org.apache.samoa.moa.classifiers.rules.core.conditionaltests.NumericAttributeBinaryRulePredicate; + +/** + * Represents a feature of rules (an element of a rule's nodeList). 
+ * + * @author Anh Thu Vu + * + */ +public class RuleSplitNode extends SplitNode { + + protected double lastTargetMean; + protected int operatorObserver; + + private static final long serialVersionUID = 1L; + + public InstanceConditionalTest getSplitTest() { + return this.splitTest; + } + + /** + * Create a new RuleSplitNode + */ + public RuleSplitNode() { + this(null, new double[0]); + } + + public RuleSplitNode(InstanceConditionalTest splitTest, double[] classObservations) { + super(splitTest, classObservations); + } + + public RuleSplitNode getACopy() { + InstanceConditionalTest splitTest = new NumericAttributeBinaryRulePredicate( + (NumericAttributeBinaryRulePredicate) this.getSplitTest()); + return new RuleSplitNode(splitTest, this.getObservedClassDistribution()); + } + + public boolean evaluate(Instance instance) { + Predicate predicate = (Predicate) this.splitTest; + return predicate.evaluate(instance); + } + +} diff --git a/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/rules/common/TargetMean.java b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/rules/common/TargetMean.java new file mode 100644 index 00000000000..acada4bb396 --- /dev/null +++ b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/rules/common/TargetMean.java @@ -0,0 +1,223 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.heron.learners.classifiers.rules.common; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2014 - 2015 Apache Software Foundation + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * #L% + */ +/** + * Prediction scheme using TargetMean: + * TargetMean - Returns the mean of the target variable of the training instances + * + * @author Joao Duarte + * + * */ + +import org.apache.samoa.instances.Instance; +import org.apache.samoa.moa.classifiers.AbstractClassifier; +import org.apache.samoa.moa.classifiers.Regressor; +import org.apache.samoa.moa.core.Measurement; +import org.apache.samoa.moa.core.StringUtils; + +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.Serializer; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; +import com.github.javacliparser.FloatOption; +
/**
 * Regressor that predicts the running mean of the target values it was trained on and keeps a
 * faded absolute error of those predictions.
 */
public class TargetMean extends AbstractClassifier implements Regressor {

  // Number of training instances and sum of their target values.
  protected long n;
  protected double sum;
  // Faded sum of absolute errors and the matching faded instance count.
  protected double errorSum;
  protected double nError;
  private double fadingErrorFactor;

  private static final long serialVersionUID = 7152547322803559115L;

  public FloatOption fadingErrorFactorOption = new FloatOption(
      "fadingErrorFactor", 'e',
      "Fading error factor for the TargetMean accumulated error", 0.99, 0, 1);

  /** Creates an empty learner whose fading factor is read from the option's current value. */
  public TargetMean() {
    super();
    fadingErrorFactor = fadingErrorFactorOption.getValue();
  }

  /**
   * Copy constructor.
   *
   * <p>The option value is copied into this instance's own {@link FloatOption} instead of
   * sharing the source's option object, so changing the option on one learner no longer changes
   * it on the other.
   *
   * @param t
   *          the learner to copy
   */
  public TargetMean(TargetMean t) {
    super();
    this.n = t.n;
    this.sum = t.sum;
    this.errorSum = t.errorSum;
    this.nError = t.nError;
    this.fadingErrorFactor = t.fadingErrorFactor;
    if (t.fadingErrorFactorOption != null) {
      this.fadingErrorFactorOption.setValue(t.fadingErrorFactorOption.getValue());
    }
  }

  /**
   * Rebuilds a learner from its serialized state.
   *
   * @param td
   *          the state holder produced by {@link TargetMeanData#TargetMeanData(TargetMean)}
   */
  public TargetMean(TargetMeanData td) {
    this();
    this.n = td.n;
    this.sum = td.sum;
    this.errorSum = td.errorSum;
    this.nError = td.nError;
    this.fadingErrorFactor = td.fadingErrorFactor;
    this.fadingErrorFactorOption.setValue(td.fadingErrorFactorOptionValue);
  }

  @Override
  public boolean isRandomizable() {
    return false;
  }

  /** The prediction does not depend on the instance; see {@link #getVotesForInstance()}. */
  @Override
  public double[] getVotesForInstance(Instance inst) {
    return getVotesForInstance();
  }

  /**
   * @return a one-element array holding the current mean, or 0 before any training
   */
  public double[] getVotesForInstance() {
    double[] currentMean = new double[1];
    currentMean[0] = (n > 0) ? sum / n : 0;
    return currentMean;
  }

  @Override
  public void resetLearningImpl() {
    sum = 0;
    n = 0;
    // NOTE(review): unlike resetError(), this seeds errorSum with Double.MAX_VALUE, which
    // updateAccumulatedError() then fades rather than replaces — confirm this is intended.
    errorSum = Double.MAX_VALUE;
    nError = 0;
  }

  /** Updates the faded error with the pre-update mean, then adds the instance to the mean. */
  @Override
  public void trainOnInstanceImpl(Instance inst) {
    updateAccumulatedError(inst);
    ++this.n;
    this.sum += inst.classValue();
  }

  /**
   * Fades the accumulated error and count by the fading factor and adds the absolute error of
   * the current mean (0 before any training) on this instance.
   */
  protected void updateAccumulatedError(Instance inst) {
    double mean = 0;
    nError = 1 + fadingErrorFactor * nError;
    if (n > 0) {
      mean = sum / n;
    }
    errorSum = Math.abs(inst.classValue() - mean) + fadingErrorFactor * errorSum;
  }

  @Override
  protected Measurement[] getModelMeasurementsImpl() {
    return null;
  }

  @Override
  public void getModelDescription(StringBuilder out, int indent) {
    StringUtils.appendIndented(out, indent, "Current Mean: " + this.sum / this.n);
    StringUtils.appendNewline(out);
  }

  /*
   * JD Resets the learner but initializes with a starting point
   */
  public void reset(double currentMean, long numberOfInstances) {
    this.sum = currentMean * numberOfInstances;
    this.n = numberOfInstances;
    this.resetError();
  }

  /**
   * @return the faded mean absolute error, or {@code Double.MAX_VALUE} while no error has been
   *         accumulated
   */
  public double getCurrentError() {
    if (this.nError > 0) {
      return this.errorSum / this.nError;
    }
    return Double.MAX_VALUE;
  }

  /** Clears the accumulated error and its count. */
  public void resetError() {
    this.errorSum = 0;
    this.nError = 0;
  }

  /**
   * Plain holder for the state of a {@link TargetMean}, used by {@link TargetMeanSerializer}.
   */
  public static class TargetMeanData {
    private long n;
    private double sum;
    private double errorSum;
    private double nError;
    private double fadingErrorFactor;
    private double fadingErrorFactorOptionValue;

    public TargetMeanData() {

    }

    /**
     * Captures the state of the given learner; the option value falls back to 0.99 when the
     * learner has no option object.
     */
    public TargetMeanData(TargetMean tm) {
      this.n = tm.n;
      this.sum = tm.sum;
      this.errorSum = tm.errorSum;
      this.nError = tm.nError;
      this.fadingErrorFactor = tm.fadingErrorFactor;
      if (tm.fadingErrorFactorOption != null) {
        this.fadingErrorFactorOptionValue = tm.fadingErrorFactorOption.getValue();
      } else {
        this.fadingErrorFactorOptionValue = 0.99;
      }
    }

    public TargetMean build() {
      return new TargetMean(this);
    }
  }

  /**
   * Kryo serializer that writes a {@link TargetMean} as a {@link TargetMeanData} and rebuilds
   * it on read.
   */
  public static final class TargetMeanSerializer extends Serializer<TargetMean> {

    @Override
    public void write(Kryo kryo, Output output, TargetMean t) {
      kryo.writeObjectOrNull(output, new TargetMeanData(t), TargetMeanData.class);
    }

    @Override
    public TargetMean read(Kryo kryo, Input input, Class<TargetMean> type) {
      TargetMeanData data = kryo.readObjectOrNull(input, TargetMeanData.class);
      return data.build();
    }
  }
}
diff --git a/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/rules/distributed/AMRDefaultRuleProcessor.java b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/rules/distributed/AMRDefaultRuleProcessor.java new file mode 100644 index 00000000000..4b8e3117cfe --- /dev/null +++ b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/rules/distributed/AMRDefaultRuleProcessor.java @@ -0,0 +1,337 @@ +package org.apache.heron.learners.classifiers.rules.distributed; + +/* +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + + +import org.apache.samoa.core.ContentEvent; +import org.apache.samoa.core.Processor; +import org.apache.samoa.instances.Instance; +import org.apache.samoa.instances.Instances; +import org.apache.samoa.learners.InstanceContentEvent; +import org.apache.samoa.learners.ResultContentEvent; +import org.apache.samoa.learners.classifiers.rules.common.ActiveRule; +import org.apache.samoa.learners.classifiers.rules.common.Perceptron; +import org.apache.samoa.learners.classifiers.rules.common.RuleActiveRegressionNode; +import org.apache.samoa.moa.classifiers.rules.core.attributeclassobservers.FIMTDDNumericAttributeClassLimitObserver; +import org.apache.samoa.topology.Stream; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +
/**
 * Default Rule Learner Processor (HAMR).
 *
 * <p>Holds the default rule, answers test instances with its prediction on the result stream
 * and trains it. When the default rule expands, the expanded rule is sent out on the rule
 * stream and a fresh default rule takes its place.
 *
 * @author Anh Thu Vu
 *
 */
public class AMRDefaultRuleProcessor implements Processor {

  private static final long serialVersionUID = 23702084591044447L;

  private static final Logger logger =
      LoggerFactory.getLogger(AMRDefaultRuleProcessor.class);

  private int processorId;

  // Default rule
  protected transient ActiveRule defaultRule;
  protected transient int ruleNumberID;
  protected transient double[] statistics;

  // SAMOA Stream
  private Stream ruleStream;
  private Stream resultStream;

  // Options
  protected int pageHinckleyThreshold;
  protected double pageHinckleyAlpha;
  protected boolean driftDetection;
  protected int predictionFunction; // Adaptive=0 Perceptron=1 TargetMean=2
  protected boolean constantLearningRatioDecay;
  protected double learningRatio;

  protected double splitConfidence;
  protected double tieThreshold;
  protected int gracePeriod;

  protected FIMTDDNumericAttributeClassLimitObserver numericObserver;

  /**
   * Creates a processor configured from the builder's options.
   */
  public AMRDefaultRuleProcessor(Builder builder) {
    this.pageHinckleyThreshold = builder.pageHinckleyThreshold;
    this.pageHinckleyAlpha = builder.pageHinckleyAlpha;
    this.driftDetection = builder.driftDetection;
    this.predictionFunction = builder.predictionFunction;
    this.constantLearningRatioDecay = builder.constantLearningRatioDecay;
    this.learningRatio = builder.learningRatio;
    this.splitConfidence = builder.splitConfidence;
    this.tieThreshold = builder.tieThreshold;
    this.gracePeriod = builder.gracePeriod;

    this.numericObserver = builder.numericObserver;
  }

  /**
   * Handles an {@link InstanceContentEvent}: predicts first when the event is a testing event,
   * then trains when it is a training event.
   *
   * @return always {@code false}
   */
  @Override
  public boolean process(ContentEvent event) {
    InstanceContentEvent instanceEvent = (InstanceContentEvent) event;
    if (instanceEvent.isTesting()) {
      this.predictOnInstance(instanceEvent);
    }
    if (instanceEvent.isTraining()) {
      this.trainOnInstanceImpl(instanceEvent.getInstance());
    }
    return false;
  }

  /*
   * Prediction
   */

  /** Puts the default rule's prediction for the event's instance on the result stream. */
  private void predictOnInstance(InstanceContentEvent instanceEvent) {
    double[] vote = defaultRule.getPrediction(instanceEvent.getInstance());
    resultStream.put(newResultContentEvent(vote, instanceEvent));
  }

  /** Wraps a prediction into a result event that carries over the input event's identifiers. */
  private ResultContentEvent newResultContentEvent(double[] prediction, InstanceContentEvent inEvent) {
    ResultContentEvent result = new ResultContentEvent(inEvent.getInstanceIndex(), inEvent.getInstance(),
        inEvent.getClassId(), prediction, inEvent.isLastEvent());
    result.setClassifierIndex(this.processorId);
    result.setEvaluationIndex(inEvent.getEvaluationIndex());
    return result;
  }

  /*
   * Training
   */

  /**
   * Trains the default rule on the instance and, every {@code gracePeriod} instances, tries to
   * expand it. On a successful expansion the expanded rule receives a new ID and is sent on the
   * rule stream, and a new default rule (built from the other branch's statistics) replaces it.
   */
  public void trainOnInstanceImpl(Instance instance) {
    defaultRule.updateStatistics(instance);

    // NOTE(review): gracePeriod == 0 (the Builder default) would make this modulo throw.
    if (defaultRule.getInstancesSeen() % this.gracePeriod != 0.0) {
      return;
    }
    if (!defaultRule.tryToExpand(this.splitConfidence, this.tieThreshold)) {
      return;
    }

    // Build the replacement before splitting, from the statistics of the other branch.
    int currentId = defaultRule.getRuleNumberID();
    RuleActiveRegressionNode currentNode = (RuleActiveRegressionNode) defaultRule.getLearningNode();
    double[] otherBranchStatistics =
        ((RuleActiveRegressionNode) defaultRule.getLearningNode()).getStatisticsOtherBranchSplit();
    ActiveRule replacement = newRule(currentId, currentNode, otherBranchStatistics);

    defaultRule.split();
    defaultRule.setRuleNumberID(++ruleNumberID);
    // send out the new rule
    sendAddRuleEvent(defaultRule.getRuleNumberID(), this.defaultRule);
    defaultRule = replacement;
  }

  /*
   * Create new rules
   */

  /**
   * Creates a rule and initializes its learners from an existing node.
   *
   * @param ID
   *          the rule ID
   * @param node
   *          node whose perceptron is copied and whose statistics seed the target mean when
   *          {@code statistics} is {@code null}; may be {@code null}
   * @param statistics
   *          statistics ([0] count, [1] sum) used to seed the target mean; may be {@code null}
   */
  private ActiveRule newRule(int ID, RuleActiveRegressionNode node, double[] statistics) {
    ActiveRule r = newRule(ID);

    if (node != null && node.getPerceptron() != null) {
      r.getLearningNode().setPerceptron(new Perceptron(node.getPerceptron()));
      r.getLearningNode().getPerceptron().setLearningRatio(this.learningRatio);
    }

    if (node != null && statistics == null && node.getNodeStatistics().getValue(0) > 0) {
      double mean = node.getNodeStatistics().getValue(1) / node.getNodeStatistics().getValue(0);
      r.getLearningNode().getTargetMean().reset(mean, 1);
    }

    if (statistics != null
        && ((RuleActiveRegressionNode) r.getLearningNode()).getTargetMean() != null
        && statistics[0] > 0) {
      double mean = statistics[1] / statistics[0];
      ((RuleActiveRegressionNode) r.getLearningNode()).getTargetMean().reset(mean, (long) statistics[0]);
    }
    return r;
  }

  /** Creates an empty rule with this processor's options and the given ID. */
  private ActiveRule newRule(int ID) {
    return new ActiveRule.Builder()
        .threshold(this.pageHinckleyThreshold)
        .alpha(this.pageHinckleyAlpha)
        .changeDetection(this.driftDetection)
        .predictionFunction(this.predictionFunction)
        .statistics(new double[3])
        .learningRatio(this.learningRatio)
        .numericObserver(numericObserver)
        .id(ID)
        .build();
  }

  /** Initializes the transient state and creates the first default rule with ID 1. */
  @Override
  public void onCreate(int id) {
    this.processorId = id;
    this.statistics = new double[3];
    this.ruleNumberID = 1;
    this.defaultRule = newRule(this.ruleNumberID);
  }

  /*
   * Clone processor
   */
  @Override
  public Processor newProcessor(Processor p) {
    AMRDefaultRuleProcessor source = (AMRDefaultRuleProcessor) p;
    AMRDefaultRuleProcessor copy = new Builder(source).build();
    copy.resultStream = source.resultStream;
    copy.ruleStream = source.ruleStream;
    return copy;
  }

  /*
   * Send events
   */
  private void sendAddRuleEvent(int ruleID, ActiveRule rule) {
    this.ruleStream.put(new RuleContentEvent(ruleID, rule, false));
  }

  /*
   * Output streams
   */
  public void setRuleStream(Stream ruleStream) {
    this.ruleStream = ruleStream;
  }

  public Stream getRuleStream() {
    return this.ruleStream;
  }

  public void setResultStream(Stream resultStream) {
    this.resultStream = resultStream;
  }

  public Stream getResultStream() {
    return this.resultStream;
  }

  /**
   * Builder for {@link AMRDefaultRuleProcessor}.
   */
  public static class Builder {
    private int pageHinckleyThreshold;
    private double pageHinckleyAlpha;
    private boolean driftDetection;
    private int predictionFunction; // Adaptive=0 Perceptron=1 TargetMean=2
    private boolean constantLearningRatioDecay;
    private double learningRatio;
    private double splitConfidence;
    private double tieThreshold;
    private int gracePeriod;

    private FIMTDDNumericAttributeClassLimitObserver numericObserver;

    private Instances dataset;

    public Builder(Instances dataset) {
      this.dataset = dataset;
    }

    /** Starts from the options of an existing processor. */
    public Builder(AMRDefaultRuleProcessor processor) {
      this.pageHinckleyThreshold = processor.pageHinckleyThreshold;
      this.pageHinckleyAlpha = processor.pageHinckleyAlpha;
      this.driftDetection = processor.driftDetection;
      this.predictionFunction = processor.predictionFunction;
      this.constantLearningRatioDecay = processor.constantLearningRatioDecay;
      this.learningRatio = processor.learningRatio;
      this.splitConfidence = processor.splitConfidence;
      this.tieThreshold = processor.tieThreshold;
      this.gracePeriod = processor.gracePeriod;

      this.numericObserver = processor.numericObserver;
    }

    public Builder threshold(int threshold) {
      this.pageHinckleyThreshold = threshold;
      return this;
    }

    public Builder alpha(double alpha) {
      this.pageHinckleyAlpha = alpha;
      return this;
    }

    public Builder changeDetection(boolean changeDetection) {
      this.driftDetection = changeDetection;
      return this;
    }

    public Builder predictionFunction(int predictionFunction) {
      this.predictionFunction = predictionFunction;
      return this;
    }

    public Builder constantLearningRatioDecay(boolean constantDecay) {
      this.constantLearningRatioDecay = constantDecay;
      return this;
    }

    public Builder learningRatio(double learningRatio) {
      this.learningRatio = learningRatio;
      return this;
    }

    public Builder splitConfidence(double splitConfidence) {
      this.splitConfidence = splitConfidence;
      return this;
    }

    public Builder tieThreshold(double tieThreshold) {
      this.tieThreshold = tieThreshold;
      return this;
    }

    public Builder gracePeriod(int gracePeriod) {
      this.gracePeriod = gracePeriod;
      return this;
    }

    public Builder numericObserver(FIMTDDNumericAttributeClassLimitObserver numericObserver) {
      this.numericObserver = numericObserver;
      return this;
    }

    public AMRDefaultRuleProcessor build() {
      return new AMRDefaultRuleProcessor(this);
    }
  }

}
diff --git a/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/rules/distributed/AMRLearnerProcessor.java b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/rules/distributed/AMRLearnerProcessor.java new file mode 100644 index 00000000000..a9cfe74123d --- /dev/null +++ 
b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/rules/distributed/AMRLearnerProcessor.java @@ -0,0 +1,258 @@ +package org.apache.heron.learners.classifiers.rules.distributed; + +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; + +import org.apache.samoa.core.ContentEvent; +import org.apache.samoa.core.Processor; +import org.apache.samoa.instances.Instance; +import org.apache.samoa.instances.Instances; +import org.apache.samoa.learners.classifiers.rules.common.ActiveRule; +import org.apache.samoa.learners.classifiers.rules.common.LearningRule; +import org.apache.samoa.learners.classifiers.rules.common.RuleActiveRegressionNode; +import org.apache.samoa.learners.classifiers.rules.common.RulePassiveRegressionNode; +import org.apache.samoa.learners.classifiers.rules.common.RuleSplitNode; +import org.apache.samoa.topology.Stream; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Learner Processor (HAMR). 
+ * + * @author Anh Thu Vu + * + */ +public class AMRLearnerProcessor implements Processor { + + /** + * + */ + private static final long serialVersionUID = -2302897295090248013L; + + private static final Logger logger = + LoggerFactory.getLogger(AMRLearnerProcessor.class); + + private int processorId; + + private transient List ruleSet; + + private Stream outputStream; + + private double splitConfidence; + private double tieThreshold; + private int gracePeriod; + + private boolean noAnomalyDetection; + private double multivariateAnomalyProbabilityThreshold; + private double univariateAnomalyprobabilityThreshold; + private int anomalyNumInstThreshold; + + public AMRLearnerProcessor(Builder builder) { + this.splitConfidence = builder.splitConfidence; + this.tieThreshold = builder.tieThreshold; + this.gracePeriod = builder.gracePeriod; + + this.noAnomalyDetection = builder.noAnomalyDetection; + this.multivariateAnomalyProbabilityThreshold = builder.multivariateAnomalyProbabilityThreshold; + this.univariateAnomalyprobabilityThreshold = builder.univariateAnomalyprobabilityThreshold; + this.anomalyNumInstThreshold = builder.anomalyNumInstThreshold; + } + + @Override + public boolean process(ContentEvent event) { + if (event instanceof AssignmentContentEvent) { + AssignmentContentEvent attrContentEvent = (AssignmentContentEvent) event; + trainRuleOnInstance(attrContentEvent.getRuleNumberID(), attrContentEvent.getInstance()); + } + else if (event instanceof RuleContentEvent) { + RuleContentEvent ruleContentEvent = (RuleContentEvent) event; + if (!ruleContentEvent.isRemoving()) { + addRule(ruleContentEvent.getRule()); + } + } + + return false; + } + + /* + * Process input instances + */ + private void trainRuleOnInstance(int ruleID, Instance instance) { + // System.out.println("Processor:"+this.processorId+": Rule:"+ruleID+" -> Counter="+counter); + Iterator ruleIterator = this.ruleSet.iterator(); + while (ruleIterator.hasNext()) { + ActiveRule rule = ruleIterator.next(); 
+ if (rule.getRuleNumberID() == ruleID) { + // Check (again) for coverage + if (rule.isCovering(instance) == true) { + double error = rule.computeError(instance); // Use adaptive mode error + boolean changeDetected = ((RuleActiveRegressionNode) rule.getLearningNode()).updateChangeDetection(error); + if (changeDetected == true) { + ruleIterator.remove(); + + this.sendRemoveRuleEvent(ruleID); + } else { + rule.updateStatistics(instance); + if (rule.getInstancesSeen() % this.gracePeriod == 0.0) { + if (rule.tryToExpand(this.splitConfidence, this.tieThreshold)) { + rule.split(); + + // expanded: update Aggregator with new/updated predicate + this.sendPredicate(rule.getRuleNumberID(), rule.getLastUpdatedRuleSplitNode(), + (RuleActiveRegressionNode) rule.getLearningNode()); + } + + } + + } + } + + return; + } + } + } + + private boolean isAnomaly(Instance instance, LearningRule rule) { + // AMRUles is equipped with anomaly detection. If on, compute the anomaly + // value. + boolean isAnomaly = false; + if (this.noAnomalyDetection == false) { + if (rule.getInstancesSeen() >= this.anomalyNumInstThreshold) { + isAnomaly = rule.isAnomaly(instance, + this.univariateAnomalyprobabilityThreshold, + this.multivariateAnomalyProbabilityThreshold, + this.anomalyNumInstThreshold); + } + } + return isAnomaly; + } + + private void sendRemoveRuleEvent(int ruleID) { + RuleContentEvent rce = new RuleContentEvent(ruleID, null, true); + this.outputStream.put(rce); + } + + private void sendPredicate(int ruleID, RuleSplitNode splitNode, RuleActiveRegressionNode learningNode) { + this.outputStream.put(new PredicateContentEvent(ruleID, splitNode, new RulePassiveRegressionNode(learningNode))); + } + + /* + * Process control message (regarding adding or removing rules) + */ + private boolean addRule(ActiveRule rule) { + this.ruleSet.add(rule); + return true; + } + + @Override + public void onCreate(int id) { + this.processorId = id; + this.ruleSet = new LinkedList(); + } + + @Override + public 
Processor newProcessor(Processor p) { + AMRLearnerProcessor oldProcessor = (AMRLearnerProcessor) p; + AMRLearnerProcessor newProcessor = + new AMRLearnerProcessor.Builder(oldProcessor).build(); + + newProcessor.setOutputStream(oldProcessor.outputStream); + return newProcessor; + } + + /* + * Builder + */ + public static class Builder { + private double splitConfidence; + private double tieThreshold; + private int gracePeriod; + + private boolean noAnomalyDetection; + private double multivariateAnomalyProbabilityThreshold; + private double univariateAnomalyprobabilityThreshold; + private int anomalyNumInstThreshold; + + private Instances dataset; + + public Builder(Instances dataset) { + this.dataset = dataset; + } + + public Builder(AMRLearnerProcessor processor) { + this.splitConfidence = processor.splitConfidence; + this.tieThreshold = processor.tieThreshold; + this.gracePeriod = processor.gracePeriod; + } + + public Builder splitConfidence(double splitConfidence) { + this.splitConfidence = splitConfidence; + return this; + } + + public Builder tieThreshold(double tieThreshold) { + this.tieThreshold = tieThreshold; + return this; + } + + public Builder gracePeriod(int gracePeriod) { + this.gracePeriod = gracePeriod; + return this; + } + + public Builder noAnomalyDetection(boolean noAnomalyDetection) { + this.noAnomalyDetection = noAnomalyDetection; + return this; + } + + public Builder multivariateAnomalyProbabilityThreshold(double mAnomalyThreshold) { + this.multivariateAnomalyProbabilityThreshold = mAnomalyThreshold; + return this; + } + + public Builder univariateAnomalyProbabilityThreshold(double uAnomalyThreshold) { + this.univariateAnomalyprobabilityThreshold = uAnomalyThreshold; + return this; + } + + public Builder anomalyNumberOfInstancesThreshold(int anomalyNumInstThreshold) { + this.anomalyNumInstThreshold = anomalyNumInstThreshold; + return this; + } + + public AMRLearnerProcessor build() { + return new AMRLearnerProcessor(this); + } + } + + /* + * 
Output stream + */ + public void setOutputStream(Stream stream) { + this.outputStream = stream; + } + + public Stream getOutputStream() { + return this.outputStream; + } +} diff --git a/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/rules/distributed/AMRRuleSetProcessor.java b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/rules/distributed/AMRRuleSetProcessor.java new file mode 100644 index 00000000000..201d7182850 --- /dev/null +++ b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/rules/distributed/AMRRuleSetProcessor.java @@ -0,0 +1,372 @@ +package org.apache.heron.learners.classifiers.rules.distributed; + +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + + +import java.util.concurrent.CopyOnWriteArrayList; + +import org.apache.samoa.core.ContentEvent; +import org.apache.samoa.core.Processor; +import org.apache.samoa.instances.Instance; +import org.apache.samoa.instances.Instances; +import org.apache.samoa.learners.InstanceContentEvent; +import org.apache.samoa.learners.ResultContentEvent; +import org.apache.samoa.learners.classifiers.rules.common.ActiveRule; +import org.apache.samoa.learners.classifiers.rules.common.LearningRule; +import org.apache.samoa.learners.classifiers.rules.common.PassiveRule; +import org.apache.samoa.moa.classifiers.rules.core.voting.ErrorWeightedVote; +import org.apache.samoa.moa.classifiers.rules.core.voting.InverseErrorWeightedVote; +import org.apache.samoa.moa.classifiers.rules.core.voting.UniformWeightedVote; +import org.apache.samoa.topology.Stream; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Model Aggregator Processor (HAMR). + * + * @author Anh Thu Vu + * + */ +public class AMRRuleSetProcessor implements Processor { + + /** + * + */ + private static final long serialVersionUID = -6544096255649379334L; + private static final Logger logger = LoggerFactory.getLogger(AMRRuleSetProcessor.class); + + private int processorId; + + // Rules & default rule + protected transient CopyOnWriteArrayList ruleSet; + + // SAMOA Stream + private Stream statisticsStream; + private Stream resultStream; + private Stream defaultRuleStream; + + // Options + protected boolean noAnomalyDetection; + protected double multivariateAnomalyProbabilityThreshold; + protected double univariateAnomalyprobabilityThreshold; + protected int anomalyNumInstThreshold; + + protected boolean unorderedRules; + + protected int voteType; + + /* + * Constructor + */ + public AMRRuleSetProcessor(Builder builder) { + + this.noAnomalyDetection = builder.noAnomalyDetection; + this.multivariateAnomalyProbabilityThreshold = builder.multivariateAnomalyProbabilityThreshold; + 
this.univariateAnomalyprobabilityThreshold = builder.univariateAnomalyprobabilityThreshold; + this.anomalyNumInstThreshold = builder.anomalyNumInstThreshold; + this.unorderedRules = builder.unorderedRules; + + this.voteType = builder.voteType; + } + + /* + * (non-Javadoc) + * + * @see org.apache.samoa.core.Processor#process(org.apache.samoa.core. + * ContentEvent) + */ + @Override + public boolean process(ContentEvent event) { + if (event instanceof InstanceContentEvent) { + this.processInstanceEvent((InstanceContentEvent) event); + } + else if (event instanceof PredicateContentEvent) { + PredicateContentEvent pce = (PredicateContentEvent) event; + if (pce.getRuleSplitNode() == null) { + this.updateLearningNode(pce); + } + else { + this.updateRuleSplitNode(pce); + } + } + else if (event instanceof RuleContentEvent) { + RuleContentEvent rce = (RuleContentEvent) event; + if (rce.isRemoving()) { + this.removeRule(rce.getRuleNumberID()); + } + else { + addRule(rce.getRule()); + } + } + return true; + } + + private void processInstanceEvent(InstanceContentEvent instanceEvent) { + Instance instance = instanceEvent.getInstance(); + boolean predictionCovered = false; + boolean trainingCovered = false; + boolean continuePrediction = instanceEvent.isTesting(); + boolean continueTraining = instanceEvent.isTraining(); + + ErrorWeightedVote errorWeightedVote = newErrorWeightedVote(); + for (PassiveRule aRuleSet : this.ruleSet) { + if (!continuePrediction && !continueTraining) + break; + + if (aRuleSet.isCovering(instance)) { + predictionCovered = true; + + if (continuePrediction) { + double[] vote = aRuleSet.getPrediction(instance); + double error = aRuleSet.getCurrentError(); + errorWeightedVote.addVote(vote, error); + if (!this.unorderedRules) + continuePrediction = false; + } + + if (continueTraining) { + if (!isAnomaly(instance, aRuleSet)) { + trainingCovered = true; + aRuleSet.updateStatistics(instance); + + // Send instance to statistics PIs + sendInstanceToRule(instance, 
aRuleSet.getRuleNumberID()); + + if (!this.unorderedRules) + continueTraining = false; + } + } + } + } + + if (predictionCovered) { + // Combined prediction + ResultContentEvent rce = newResultContentEvent(errorWeightedVote.computeWeightedVote(), instanceEvent); + resultStream.put(rce); + } + + boolean defaultPrediction = instanceEvent.isTesting() && !predictionCovered; + boolean defaultTraining = instanceEvent.isTraining() && !trainingCovered; + if (defaultPrediction || defaultTraining) { + instanceEvent.setTesting(defaultPrediction); + instanceEvent.setTraining(defaultTraining); + this.defaultRuleStream.put(instanceEvent); + } + } + + private ResultContentEvent newResultContentEvent(double[] prediction, InstanceContentEvent inEvent) { + ResultContentEvent rce = new ResultContentEvent(inEvent.getInstanceIndex(), inEvent.getInstance(), + inEvent.getClassId(), prediction, inEvent.isLastEvent()); + rce.setClassifierIndex(this.processorId); + rce.setEvaluationIndex(inEvent.getEvaluationIndex()); + return rce; + } + + public ErrorWeightedVote newErrorWeightedVote() { + // TODO: do a reset instead of init a new object + if (voteType == 1) + return new UniformWeightedVote(); + return new InverseErrorWeightedVote(); + } + + /** + * Method to verify if the instance is an anomaly. + * + * @param instance + * @param rule + * @return + */ + private boolean isAnomaly(Instance instance, LearningRule rule) { + // AMRUles is equipped with anomaly detection. If on, compute the anomaly + // value. 
+ boolean isAnomaly = false; + if (!this.noAnomalyDetection) { + if (rule.getInstancesSeen() >= this.anomalyNumInstThreshold) { + isAnomaly = rule.isAnomaly(instance, + this.univariateAnomalyprobabilityThreshold, + this.multivariateAnomalyProbabilityThreshold, + this.anomalyNumInstThreshold); + } + } + return isAnomaly; + } + + /* + * Add predicate/RuleSplitNode for a rule + */ + private void updateRuleSplitNode(PredicateContentEvent pce) { + int ruleID = pce.getRuleNumberID(); + for (PassiveRule rule : ruleSet) { + if (rule.getRuleNumberID() == ruleID) { + rule.nodeListAdd(pce.getRuleSplitNode()); + rule.setLearningNode(pce.getLearningNode()); + } + } + } + + private void updateLearningNode(PredicateContentEvent pce) { + int ruleID = pce.getRuleNumberID(); + for (PassiveRule rule : ruleSet) { + if (rule.getRuleNumberID() == ruleID) { + rule.setLearningNode(pce.getLearningNode()); + } + } + } + + /* + * Add new rule/Remove rule + */ + private boolean addRule(ActiveRule rule) { + this.ruleSet.add(new PassiveRule(rule)); + return true; + } + + private void removeRule(int ruleID) { + for (PassiveRule rule : ruleSet) { + if (rule.getRuleNumberID() == ruleID) { + ruleSet.remove(rule); + break; + } + } + } + + @Override + public void onCreate(int id) { + this.processorId = id; + this.ruleSet = new CopyOnWriteArrayList(); + + } + + /* + * Clone processor + */ + @Override + public Processor newProcessor(Processor p) { + AMRRuleSetProcessor oldProcessor = (AMRRuleSetProcessor) p; + Builder builder = new Builder(oldProcessor); + AMRRuleSetProcessor newProcessor = builder.build(); + newProcessor.resultStream = oldProcessor.resultStream; + newProcessor.statisticsStream = oldProcessor.statisticsStream; + newProcessor.defaultRuleStream = oldProcessor.defaultRuleStream; + return newProcessor; + } + + /* + * Send events + */ + private void sendInstanceToRule(Instance instance, int ruleID) { + AssignmentContentEvent ace = new AssignmentContentEvent(ruleID, instance); + 
this.statisticsStream.put(ace); + } + + /* + * Output streams + */ + public void setStatisticsStream(Stream statisticsStream) { + this.statisticsStream = statisticsStream; + } + + public Stream getStatisticsStream() { + return this.statisticsStream; + } + + public void setResultStream(Stream resultStream) { + this.resultStream = resultStream; + } + + public Stream getResultStream() { + return this.resultStream; + } + + public Stream getDefaultRuleStream() { + return this.defaultRuleStream; + } + + public void setDefaultRuleStream(Stream defaultRuleStream) { + this.defaultRuleStream = defaultRuleStream; + } + + /* + * Builder + */ + public static class Builder { + private boolean noAnomalyDetection; + private double multivariateAnomalyProbabilityThreshold; + private double univariateAnomalyprobabilityThreshold; + private int anomalyNumInstThreshold; + + private boolean unorderedRules; + + // private FIMTDDNumericAttributeClassLimitObserver numericObserver; + private int voteType; + + private Instances dataset; + + public Builder(Instances dataset) { + this.dataset = dataset; + } + + public Builder(AMRRuleSetProcessor processor) { + + this.noAnomalyDetection = processor.noAnomalyDetection; + this.multivariateAnomalyProbabilityThreshold = processor.multivariateAnomalyProbabilityThreshold; + this.univariateAnomalyprobabilityThreshold = processor.univariateAnomalyprobabilityThreshold; + this.anomalyNumInstThreshold = processor.anomalyNumInstThreshold; + this.unorderedRules = processor.unorderedRules; + + this.voteType = processor.voteType; + } + + public Builder noAnomalyDetection(boolean noAnomalyDetection) { + this.noAnomalyDetection = noAnomalyDetection; + return this; + } + + public Builder multivariateAnomalyProbabilityThreshold(double mAnomalyThreshold) { + this.multivariateAnomalyProbabilityThreshold = mAnomalyThreshold; + return this; + } + + public Builder univariateAnomalyProbabilityThreshold(double uAnomalyThreshold) { + 
this.univariateAnomalyprobabilityThreshold = uAnomalyThreshold; + return this; + } + + public Builder anomalyNumberOfInstancesThreshold(int anomalyNumInstThreshold) { + this.anomalyNumInstThreshold = anomalyNumInstThreshold; + return this; + } + + public Builder unorderedRules(boolean unorderedRules) { + this.unorderedRules = unorderedRules; + return this; + } + + public Builder voteType(int voteType) { + this.voteType = voteType; + return this; + } + + public AMRRuleSetProcessor build() { + return new AMRRuleSetProcessor(this); + } + } + +} diff --git a/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/rules/distributed/AMRulesAggregatorProcessor.java b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/rules/distributed/AMRulesAggregatorProcessor.java new file mode 100644 index 00000000000..efaef75a361 --- /dev/null +++ b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/rules/distributed/AMRulesAggregatorProcessor.java @@ -0,0 +1,530 @@ +package org.apache.heron.learners.classifiers.rules.distributed; + +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + + +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; + +import org.apache.samoa.core.ContentEvent; +import org.apache.samoa.core.Processor; +import org.apache.samoa.instances.Instance; +import org.apache.samoa.instances.Instances; +import org.apache.samoa.learners.InstanceContentEvent; +import org.apache.samoa.learners.ResultContentEvent; +import org.apache.samoa.learners.classifiers.rules.common.ActiveRule; +import org.apache.samoa.learners.classifiers.rules.common.LearningRule; +import org.apache.samoa.learners.classifiers.rules.common.PassiveRule; +import org.apache.samoa.learners.classifiers.rules.common.Perceptron; +import org.apache.samoa.learners.classifiers.rules.common.RuleActiveRegressionNode; +import org.apache.samoa.moa.classifiers.rules.core.attributeclassobservers.FIMTDDNumericAttributeClassLimitObserver; +import org.apache.samoa.moa.classifiers.rules.core.voting.ErrorWeightedVote; +import org.apache.samoa.moa.classifiers.rules.core.voting.InverseErrorWeightedVote; +import org.apache.samoa.moa.classifiers.rules.core.voting.UniformWeightedVote; +import org.apache.samoa.topology.Stream; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Model Aggregator Processor (VAMR). 
+ * + * @author Anh Thu Vu + * + */ +public class AMRulesAggregatorProcessor implements Processor { + + /** + * + */ + private static final long serialVersionUID = 6303385725332704251L; + + private static final Logger logger = + LoggerFactory.getLogger(AMRulesAggregatorProcessor.class); + + private int processorId; + + // Rules & default rule + protected transient List ruleSet; + protected transient ActiveRule defaultRule; + protected transient int ruleNumberID; + protected transient double[] statistics; + + // SAMOA Stream + private Stream statisticsStream; + private Stream resultStream; + + // Options + protected int pageHinckleyThreshold; + protected double pageHinckleyAlpha; + protected boolean driftDetection; + protected int predictionFunction; // Adaptive=0 Perceptron=1 TargetMean=2 + protected boolean constantLearningRatioDecay; + protected double learningRatio; + + protected double splitConfidence; + protected double tieThreshold; + protected int gracePeriod; + + protected boolean noAnomalyDetection; + protected double multivariateAnomalyProbabilityThreshold; + protected double univariateAnomalyprobabilityThreshold; + protected int anomalyNumInstThreshold; + + protected boolean unorderedRules; + + protected FIMTDDNumericAttributeClassLimitObserver numericObserver; + protected int voteType; + + /* + * Constructor + */ + public AMRulesAggregatorProcessor(Builder builder) { + this.pageHinckleyThreshold = builder.pageHinckleyThreshold; + this.pageHinckleyAlpha = builder.pageHinckleyAlpha; + this.driftDetection = builder.driftDetection; + this.predictionFunction = builder.predictionFunction; + this.constantLearningRatioDecay = builder.constantLearningRatioDecay; + this.learningRatio = builder.learningRatio; + this.splitConfidence = builder.splitConfidence; + this.tieThreshold = builder.tieThreshold; + this.gracePeriod = builder.gracePeriod; + + this.noAnomalyDetection = builder.noAnomalyDetection; + this.multivariateAnomalyProbabilityThreshold = 
builder.multivariateAnomalyProbabilityThreshold; + this.univariateAnomalyprobabilityThreshold = builder.univariateAnomalyprobabilityThreshold; + this.anomalyNumInstThreshold = builder.anomalyNumInstThreshold; + this.unorderedRules = builder.unorderedRules; + + this.numericObserver = builder.numericObserver; + this.voteType = builder.voteType; + } + + /* + * Process + */ + @Override + public boolean process(ContentEvent event) { + if (event instanceof InstanceContentEvent) { + InstanceContentEvent instanceEvent = (InstanceContentEvent) event; + this.processInstanceEvent(instanceEvent); + } + else if (event instanceof PredicateContentEvent) { + this.updateRuleSplitNode((PredicateContentEvent) event); + } + else if (event instanceof RuleContentEvent) { + RuleContentEvent rce = (RuleContentEvent) event; + if (rce.isRemoving()) { + this.removeRule(rce.getRuleNumberID()); + } + } + + return true; + } + + // Merge predict and train so we only check for covering rules one time + private void processInstanceEvent(InstanceContentEvent instanceEvent) { + Instance instance = instanceEvent.getInstance(); + boolean predictionCovered = false; + boolean trainingCovered = false; + boolean continuePrediction = instanceEvent.isTesting(); + boolean continueTraining = instanceEvent.isTraining(); + + ErrorWeightedVote errorWeightedVote = newErrorWeightedVote(); + Iterator ruleIterator = this.ruleSet.iterator(); + while (ruleIterator.hasNext()) { + if (!continuePrediction && !continueTraining) + break; + + PassiveRule rule = ruleIterator.next(); + + if (rule.isCovering(instance) == true) { + predictionCovered = true; + + if (continuePrediction) { + double[] vote = rule.getPrediction(instance); + double error = rule.getCurrentError(); + errorWeightedVote.addVote(vote, error); + if (!this.unorderedRules) + continuePrediction = false; + } + + if (continueTraining) { + if (!isAnomaly(instance, rule)) { + trainingCovered = true; + rule.updateStatistics(instance); + // Send instance to 
statistics PIs + sendInstanceToRule(instance, rule.getRuleNumberID()); + + if (!this.unorderedRules) + continueTraining = false; + } + } + } + } + + if (predictionCovered) { + // Combined prediction + ResultContentEvent rce = newResultContentEvent(errorWeightedVote.computeWeightedVote(), instanceEvent); + resultStream.put(rce); + } + else if (instanceEvent.isTesting()) { + // predict with default rule + double[] vote = defaultRule.getPrediction(instance); + ResultContentEvent rce = newResultContentEvent(vote, instanceEvent); + resultStream.put(rce); + } + + if (!trainingCovered && instanceEvent.isTraining()) { + // train default rule with this instance + defaultRule.updateStatistics(instance); + if (defaultRule.getInstancesSeen() % this.gracePeriod == 0.0) { + if (defaultRule.tryToExpand(this.splitConfidence, this.tieThreshold) == true) { + ActiveRule newDefaultRule = newRule(defaultRule.getRuleNumberID(), + (RuleActiveRegressionNode) defaultRule.getLearningNode(), + ((RuleActiveRegressionNode) defaultRule.getLearningNode()).getStatisticsOtherBranchSplit()); // other branch + defaultRule.split(); + defaultRule.setRuleNumberID(++ruleNumberID); + this.ruleSet.add(new PassiveRule(this.defaultRule)); + // send to statistics PI + sendAddRuleEvent(defaultRule.getRuleNumberID(), this.defaultRule); + defaultRule = newDefaultRule; + } + } + } + } + + /** + * Helper method to generate new ResultContentEvent based on an instance and its prediction result. + * + * @param prediction + * The predicted class label from the decision tree model. + * @param inEvent + * The associated instance content event + * @return ResultContentEvent to be sent into Evaluator PI or other destination PI. 
+ */ + private ResultContentEvent newResultContentEvent(double[] prediction, InstanceContentEvent inEvent) { + ResultContentEvent rce = new ResultContentEvent(inEvent.getInstanceIndex(), inEvent.getInstance(), + inEvent.getClassId(), prediction, inEvent.isLastEvent()); + rce.setClassifierIndex(this.processorId); + rce.setEvaluationIndex(inEvent.getEvaluationIndex()); + return rce; + } + + public ErrorWeightedVote newErrorWeightedVote() { + if (voteType == 1) + return new UniformWeightedVote(); + return new InverseErrorWeightedVote(); + } + + /** + * Method to verify if the instance is an anomaly. + * + * @param instance + * @param rule + * @return + */ + private boolean isAnomaly(Instance instance, LearningRule rule) { + // AMRUles is equipped with anomaly detection. If on, compute the anomaly + // value. + boolean isAnomaly = false; + if (this.noAnomalyDetection == false) { + if (rule.getInstancesSeen() >= this.anomalyNumInstThreshold) { + isAnomaly = rule.isAnomaly(instance, + this.univariateAnomalyprobabilityThreshold, + this.multivariateAnomalyProbabilityThreshold, + this.anomalyNumInstThreshold); + } + } + return isAnomaly; + } + + /* + * Create new rules + */ + private ActiveRule newRule(int ID, RuleActiveRegressionNode node, double[] statistics) { + ActiveRule r = newRule(ID); + + if (node != null) + { + if (node.getPerceptron() != null) + { + r.getLearningNode().setPerceptron(new Perceptron(node.getPerceptron())); + r.getLearningNode().getPerceptron().setLearningRatio(this.learningRatio); + } + if (statistics == null) + { + double mean; + if (node.getNodeStatistics().getValue(0) > 0) { + mean = node.getNodeStatistics().getValue(1) / node.getNodeStatistics().getValue(0); + r.getLearningNode().getTargetMean().reset(mean, 1); + } + } + } + if (statistics != null && ((RuleActiveRegressionNode) r.getLearningNode()).getTargetMean() != null) + { + double mean; + if (statistics[0] > 0) { + mean = statistics[1] / statistics[0]; + ((RuleActiveRegressionNode) 
r.getLearningNode()).getTargetMean().reset(mean, (long) statistics[0]); + } + } + return r; + } + + private ActiveRule newRule(int ID) { + ActiveRule r = new ActiveRule.Builder(). + threshold(this.pageHinckleyThreshold). + alpha(this.pageHinckleyAlpha). + changeDetection(this.driftDetection). + predictionFunction(this.predictionFunction). + statistics(new double[3]). + learningRatio(this.learningRatio). + numericObserver(numericObserver). + id(ID).build(); + return r; + } + + /* + * Add predicate/RuleSplitNode for a rule + */ + private void updateRuleSplitNode(PredicateContentEvent pce) { + int ruleID = pce.getRuleNumberID(); + for (PassiveRule rule : ruleSet) { + if (rule.getRuleNumberID() == ruleID) { + if (pce.getRuleSplitNode() != null) + rule.nodeListAdd(pce.getRuleSplitNode()); + if (pce.getLearningNode() != null) + rule.setLearningNode(pce.getLearningNode()); + } + } + } + + /* + * Remove rule + */ + private void removeRule(int ruleID) { + for (PassiveRule rule : ruleSet) { + if (rule.getRuleNumberID() == ruleID) { + ruleSet.remove(rule); + break; + } + } + } + + @Override + public void onCreate(int id) { + this.processorId = id; + this.statistics = new double[] { 0.0, 0, 0 }; + this.ruleNumberID = 0; + this.defaultRule = newRule(++this.ruleNumberID); + + this.ruleSet = new LinkedList(); + } + + /* + * Clone processor + */ + @Override + public Processor newProcessor(Processor p) { + AMRulesAggregatorProcessor oldProcessor = (AMRulesAggregatorProcessor) p; + Builder builder = new Builder(oldProcessor); + AMRulesAggregatorProcessor newProcessor = builder.build(); + newProcessor.resultStream = oldProcessor.resultStream; + newProcessor.statisticsStream = oldProcessor.statisticsStream; + return newProcessor; + } + + /* + * Send events + */ + private void sendInstanceToRule(Instance instance, int ruleID) { + AssignmentContentEvent ace = new AssignmentContentEvent(ruleID, instance); + this.statisticsStream.put(ace); + } + + private void sendAddRuleEvent(int ruleID, 
ActiveRule rule) { + RuleContentEvent rce = new RuleContentEvent(ruleID, rule, false); + this.statisticsStream.put(rce); + } + + /* + * Output streams + */ + public void setStatisticsStream(Stream statisticsStream) { + this.statisticsStream = statisticsStream; + } + + public Stream getStatisticsStream() { + return this.statisticsStream; + } + + public void setResultStream(Stream resultStream) { + this.resultStream = resultStream; + } + + public Stream getResultStream() { + return this.resultStream; + } + + /* + * Others + */ + public boolean isRandomizable() { + return true; + } + + /* + * Builder + */ + public static class Builder { + private int pageHinckleyThreshold; + private double pageHinckleyAlpha; + private boolean driftDetection; + private int predictionFunction; // Adaptive=0 Perceptron=1 TargetMean=2 + private boolean constantLearningRatioDecay; + private double learningRatio; + private double splitConfidence; + private double tieThreshold; + private int gracePeriod; + + private boolean noAnomalyDetection; + private double multivariateAnomalyProbabilityThreshold; + private double univariateAnomalyprobabilityThreshold; + private int anomalyNumInstThreshold; + + private boolean unorderedRules; + + private FIMTDDNumericAttributeClassLimitObserver numericObserver; + private int voteType; + + private Instances dataset; + + public Builder(Instances dataset) { + this.dataset = dataset; + } + + public Builder(AMRulesAggregatorProcessor processor) { + this.pageHinckleyThreshold = processor.pageHinckleyThreshold; + this.pageHinckleyAlpha = processor.pageHinckleyAlpha; + this.driftDetection = processor.driftDetection; + this.predictionFunction = processor.predictionFunction; + this.constantLearningRatioDecay = processor.constantLearningRatioDecay; + this.learningRatio = processor.learningRatio; + this.splitConfidence = processor.splitConfidence; + this.tieThreshold = processor.tieThreshold; + this.gracePeriod = processor.gracePeriod; + + this.noAnomalyDetection = 
processor.noAnomalyDetection; + this.multivariateAnomalyProbabilityThreshold = processor.multivariateAnomalyProbabilityThreshold; + this.univariateAnomalyprobabilityThreshold = processor.univariateAnomalyprobabilityThreshold; + this.anomalyNumInstThreshold = processor.anomalyNumInstThreshold; + this.unorderedRules = processor.unorderedRules; + + this.numericObserver = processor.numericObserver; + this.voteType = processor.voteType; + } + + public Builder threshold(int threshold) { + this.pageHinckleyThreshold = threshold; + return this; + } + + public Builder alpha(double alpha) { + this.pageHinckleyAlpha = alpha; + return this; + } + + public Builder changeDetection(boolean changeDetection) { + this.driftDetection = changeDetection; + return this; + } + + public Builder predictionFunction(int predictionFunction) { + this.predictionFunction = predictionFunction; + return this; + } + + public Builder constantLearningRatioDecay(boolean constantDecay) { + this.constantLearningRatioDecay = constantDecay; + return this; + } + + public Builder learningRatio(double learningRatio) { + this.learningRatio = learningRatio; + return this; + } + + public Builder splitConfidence(double splitConfidence) { + this.splitConfidence = splitConfidence; + return this; + } + + public Builder tieThreshold(double tieThreshold) { + this.tieThreshold = tieThreshold; + return this; + } + + public Builder gracePeriod(int gracePeriod) { + this.gracePeriod = gracePeriod; + return this; + } + + public Builder noAnomalyDetection(boolean noAnomalyDetection) { + this.noAnomalyDetection = noAnomalyDetection; + return this; + } + + public Builder multivariateAnomalyProbabilityThreshold(double mAnomalyThreshold) { + this.multivariateAnomalyProbabilityThreshold = mAnomalyThreshold; + return this; + } + + public Builder univariateAnomalyProbabilityThreshold(double uAnomalyThreshold) { + this.univariateAnomalyprobabilityThreshold = uAnomalyThreshold; + return this; + } + + public Builder 
anomalyNumberOfInstancesThreshold(int anomalyNumInstThreshold) { + this.anomalyNumInstThreshold = anomalyNumInstThreshold; + return this; + } + + public Builder unorderedRules(boolean unorderedRules) { + this.unorderedRules = unorderedRules; + return this; + } + + public Builder numericObserver(FIMTDDNumericAttributeClassLimitObserver numericObserver) { + this.numericObserver = numericObserver; + return this; + } + + public Builder voteType(int voteType) { + this.voteType = voteType; + return this; + } + + public AMRulesAggregatorProcessor build() { + return new AMRulesAggregatorProcessor(this); + } + } + +} diff --git a/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/rules/distributed/AMRulesStatisticsProcessor.java b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/rules/distributed/AMRulesStatisticsProcessor.java new file mode 100644 index 00000000000..616fba0bd42 --- /dev/null +++ b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/rules/distributed/AMRulesStatisticsProcessor.java @@ -0,0 +1,218 @@ +package org.apache.heron.learners.classifiers.rules.distributed; + +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; + +import org.apache.samoa.core.ContentEvent; +import org.apache.samoa.core.Processor; +import org.apache.samoa.instances.Instance; +import org.apache.samoa.instances.Instances; +import org.apache.samoa.learners.classifiers.rules.common.ActiveRule; +import org.apache.samoa.learners.classifiers.rules.common.RuleActiveRegressionNode; +import org.apache.samoa.learners.classifiers.rules.common.RulePassiveRegressionNode; +import org.apache.samoa.learners.classifiers.rules.common.RuleSplitNode; +import org.apache.samoa.topology.Stream; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Learner Processor (VAMR). + * + * @author Anh Thu Vu + * + */ +public class AMRulesStatisticsProcessor implements Processor { + + /** + * + */ + private static final long serialVersionUID = 5268933189695395573L; + + private static final Logger logger = + LoggerFactory.getLogger(AMRulesStatisticsProcessor.class); + + private int processorId; + + private transient List ruleSet; + + private Stream outputStream; + + private double splitConfidence; + private double tieThreshold; + private int gracePeriod; + + private int frequency; + + public AMRulesStatisticsProcessor(Builder builder) { + this.splitConfidence = builder.splitConfidence; + this.tieThreshold = builder.tieThreshold; + this.gracePeriod = builder.gracePeriod; + this.frequency = builder.frequency; + } + + @Override + public boolean process(ContentEvent event) { + if (event instanceof AssignmentContentEvent) { + + AssignmentContentEvent attrContentEvent = (AssignmentContentEvent) event; + trainRuleOnInstance(attrContentEvent.getRuleNumberID(), attrContentEvent.getInstance()); + } + else if (event instanceof RuleContentEvent) { + RuleContentEvent ruleContentEvent = (RuleContentEvent) event; + if (!ruleContentEvent.isRemoving()) { + addRule(ruleContentEvent.getRule()); + } + } + + return false; + } + + /* + * Process input 
instances + */ + private void trainRuleOnInstance(int ruleID, Instance instance) { + Iterator ruleIterator = this.ruleSet.iterator(); + while (ruleIterator.hasNext()) { + ActiveRule rule = ruleIterator.next(); + if (rule.getRuleNumberID() == ruleID) { + // Check (again) for coverage + // Skip anomaly check as Aggregator's perceptron should be well-updated + if (rule.isCovering(instance) == true) { + double error = rule.computeError(instance); // Use adaptive mode error + boolean changeDetected = ((RuleActiveRegressionNode) rule.getLearningNode()).updateChangeDetection(error); + if (changeDetected == true) { + ruleIterator.remove(); + + this.sendRemoveRuleEvent(ruleID); + } else { + rule.updateStatistics(instance); + if (rule.getInstancesSeen() % this.gracePeriod == 0.0) { + if (rule.tryToExpand(this.splitConfidence, this.tieThreshold)) { + rule.split(); + + // expanded: update Aggregator with new/updated predicate + this.sendPredicate(rule.getRuleNumberID(), rule.getLastUpdatedRuleSplitNode(), + (RuleActiveRegressionNode) rule.getLearningNode()); + } + } + } + } + + return; + } + } + } + + private void sendRemoveRuleEvent(int ruleID) { + RuleContentEvent rce = new RuleContentEvent(ruleID, null, true); + this.outputStream.put(rce); + } + + private void sendPredicate(int ruleID, RuleSplitNode splitNode, RuleActiveRegressionNode learningNode) { + this.outputStream.put(new PredicateContentEvent(ruleID, splitNode, new RulePassiveRegressionNode(learningNode))); + } + + /* + * Process control message (regarding adding or removing rules) + */ + private boolean addRule(ActiveRule rule) { + this.ruleSet.add(rule); + return true; + } + + @Override + public void onCreate(int id) { + this.processorId = id; + this.ruleSet = new LinkedList(); + } + + @Override + public Processor newProcessor(Processor p) { + AMRulesStatisticsProcessor oldProcessor = (AMRulesStatisticsProcessor) p; + AMRulesStatisticsProcessor newProcessor = + new 
AMRulesStatisticsProcessor.Builder(oldProcessor).build(); + + newProcessor.setOutputStream(oldProcessor.outputStream); + return newProcessor; + } + + /* + * Builder + */ + public static class Builder { + private double splitConfidence; + private double tieThreshold; + private int gracePeriod; + + private int frequency; + + private Instances dataset; + + public Builder(Instances dataset) { + this.dataset = dataset; + } + + public Builder(AMRulesStatisticsProcessor processor) { + this.splitConfidence = processor.splitConfidence; + this.tieThreshold = processor.tieThreshold; + this.gracePeriod = processor.gracePeriod; + this.frequency = processor.frequency; + } + + public Builder splitConfidence(double splitConfidence) { + this.splitConfidence = splitConfidence; + return this; + } + + public Builder tieThreshold(double tieThreshold) { + this.tieThreshold = tieThreshold; + return this; + } + + public Builder gracePeriod(int gracePeriod) { + this.gracePeriod = gracePeriod; + return this; + } + + public Builder frequency(int frequency) { + this.frequency = frequency; + return this; + } + + public AMRulesStatisticsProcessor build() { + return new AMRulesStatisticsProcessor(this); + } + } + + /* + * Output stream + */ + public void setOutputStream(Stream stream) { + this.outputStream = stream; + } + + public Stream getOutputStream() { + return this.outputStream; + } + +} diff --git a/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/rules/distributed/AssignmentContentEvent.java b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/rules/distributed/AssignmentContentEvent.java new file mode 100644 index 00000000000..33040e18e0c --- /dev/null +++ b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/rules/distributed/AssignmentContentEvent.java @@ -0,0 +1,74 @@ +package org.apache.heron.learners.classifiers.rules.distributed; + +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */


import org.apache.samoa.core.ContentEvent;
import org.apache.samoa.instances.Instance;

/**
 * Forwarded instances from Model Aggregator to Learners/Default Rule Learner.
 *
 * <p>Pairs an instance with the ID of the rule it is assigned to; the rule ID doubles as the
 * event key.
 *
 * @author Anh Thu Vu
 *
 */
public class AssignmentContentEvent implements ContentEvent {

  private static final long serialVersionUID = 1031695762172836629L;

  private int ruleNumberID;
  private Instance instance;

  /** Creates an event for rule 0 that carries no instance. */
  public AssignmentContentEvent() {
    this(0, null);
  }

  /**
   * @param ruleID
   *          ID of the rule the instance is assigned to
   * @param instance
   *          the forwarded instance
   */
  public AssignmentContentEvent(int ruleID, Instance instance) {
    this.instance = instance;
    this.ruleNumberID = ruleID;
  }

  public int getRuleNumberID() {
    return ruleNumberID;
  }

  public Instance getInstance() {
    return instance;
  }

  /** @return the rule ID rendered as a decimal string */
  @Override
  public String getKey() {
    return String.valueOf(ruleNumberID);
  }

  @Override
  public void setKey(String key) {
    // do nothing
  }

  @Override
  public boolean isLastEvent() {
    return false;
  }

}
diff --git a/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/rules/distributed/PredicateContentEvent.java b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/rules/distributed/PredicateContentEvent.java
new file mode 100644
index 00000000000..3b2526598ca
--- /dev/null
+++
b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/rules/distributed/PredicateContentEvent.java @@ -0,0 +1,83 @@ +package org.apache.heron.learners.classifiers.rules.distributed; + +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import org.apache.samoa.core.ContentEvent; +import org.apache.samoa.learners.classifiers.rules.common.RulePassiveRegressionNode; +import org.apache.samoa.learners.classifiers.rules.common.RuleSplitNode; + +/** + * New features (of newly expanded rules) from Learners to Model Aggregators. 
+ * + * @author Anh Thu Vu + * + */ +public class PredicateContentEvent implements ContentEvent { + + /** + * + */ + private static final long serialVersionUID = 7909435830443732451L; + + private int ruleNumberID; + private RuleSplitNode ruleSplitNode; + private RulePassiveRegressionNode learningNode; + + /* + * Constructor + */ + public PredicateContentEvent() { + this(0, null, null); + } + + public PredicateContentEvent(int ruleID, RuleSplitNode ruleSplitNode, RulePassiveRegressionNode learningNode) { + this.ruleNumberID = ruleID; + this.ruleSplitNode = ruleSplitNode; // is this is null: this is for updating learningNode only + this.learningNode = learningNode; + } + + @Override + public String getKey() { + return Integer.toString(this.ruleNumberID); + } + + @Override + public void setKey(String key) { + // do nothing + } + + @Override + public boolean isLastEvent() { + return false; // N/A + } + + public int getRuleNumberID() { + return this.ruleNumberID; + } + + public RuleSplitNode getRuleSplitNode() { + return this.ruleSplitNode; + } + + public RulePassiveRegressionNode getLearningNode() { + return this.learningNode; + } + +} diff --git a/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/rules/distributed/RuleContentEvent.java b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/rules/distributed/RuleContentEvent.java new file mode 100644 index 00000000000..ec9684270b7 --- /dev/null +++ b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/rules/distributed/RuleContentEvent.java @@ -0,0 +1,80 @@ +package org.apache.heron.learners.classifiers.rules.distributed; + +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */


import org.apache.samoa.core.ContentEvent;
import org.apache.samoa.learners.classifiers.rules.common.ActiveRule;

/**
 * New rule from Model Aggregator/Default Rule Learner to Learners or removed rule from Learner to Model Aggregators.
 *
 * <p>The rule ID doubles as the event key.
 *
 * @author Anh Thu Vu
 *
 */
public class RuleContentEvent implements ContentEvent {

  private static final long serialVersionUID = -9046390274402894461L;

  private final int ruleNumberID;
  private final ActiveRule addingRule; // for removing rule, we only need the rule's ID
  private final boolean isRemoving;

  /** Creates a non-removing event for rule 0 that carries no rule. */
  public RuleContentEvent() {
    this(0, null, false);
  }

  /**
   * @param ruleID
   *          ID of the rule being added or removed
   * @param rule
   *          the rule being added
   * @param isRemoving
   *          {@code true} when the event announces a removal
   */
  public RuleContentEvent(int ruleID, ActiveRule rule, boolean isRemoving) {
    this.ruleNumberID = ruleID;
    this.addingRule = rule;
    this.isRemoving = isRemoving;
  }

  public int getRuleNumberID() {
    return ruleNumberID;
  }

  public ActiveRule getRule() {
    return addingRule;
  }

  public boolean isRemoving() {
    return isRemoving;
  }

  /** @return the rule ID rendered as a decimal string */
  @Override
  public String getKey() {
    return String.valueOf(ruleNumberID);
  }

  @Override
  public void setKey(String key) {
    // do nothing
  }

  @Override
  public boolean isLastEvent() {
    return false;
  }

}
diff --git a/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/trees/ActiveLearningNode.java
b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/trees/ActiveLearningNode.java new file mode 100644 index 00000000000..6172f20a19b --- /dev/null +++ b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/trees/ActiveLearningNode.java @@ -0,0 +1,208 @@ +package org.apache.heron.learners.classifiers.trees; + +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + + +import java.util.HashMap; +import java.util.Map; + +import org.apache.samoa.instances.Instance; +import org.apache.samoa.moa.classifiers.core.AttributeSplitSuggestion; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +final class ActiveLearningNode extends LearningNode { + /** + * + */ + private static final long serialVersionUID = -2892102872646338908L; + private static final Logger logger = LoggerFactory.getLogger(ActiveLearningNode.class); + + private double weightSeenAtLastSplitEvaluation; + + private final Map attributeContentEventKeys; + + private AttributeSplitSuggestion bestSuggestion; + private AttributeSplitSuggestion secondBestSuggestion; + + private final long id; + private final int parallelismHint; + private int suggestionCtr; + private int thrownAwayInstance; + + private boolean isSplitting; + + ActiveLearningNode(double[] classObservation, int parallelismHint) { + super(classObservation); + this.weightSeenAtLastSplitEvaluation = this.getWeightSeen(); + this.id = VerticalHoeffdingTree.LearningNodeIdGenerator.generate(); + this.attributeContentEventKeys = new HashMap<>(); + this.isSplitting = false; + this.parallelismHint = parallelismHint; + } + + long getId() { + return id; + } + + protected AttributeBatchContentEvent[] attributeBatchContentEvent; + + public AttributeBatchContentEvent[] getAttributeBatchContentEvent() { + return this.attributeBatchContentEvent; + } + + public void setAttributeBatchContentEvent(AttributeBatchContentEvent[] attributeBatchContentEvent) { + this.attributeBatchContentEvent = attributeBatchContentEvent; + } + + @Override + void learnFromInstance(Instance inst, ModelAggregatorProcessor proc) { + // TODO: what statistics should we keep for unused instance? 
+ if (isSplitting) { // currently throw all instance will splitting + this.thrownAwayInstance++; + return; + } + this.observedClassDistribution.addToValue((int) inst.classValue(), + inst.weight()); + // done: parallelize by sending attributes one by one + // TODO: meanwhile, we can try to use the ThreadPool to execute it + // separately + // TODO: parallelize by sending in batch, i.e. split the attributes into + // chunk instead of send the attribute one by one + for (int i = 0; i < inst.numAttributes() - 1; i++) { + int instAttIndex = modelAttIndexToInstanceAttIndex(i, inst); + Integer obsIndex = i; + String key = attributeContentEventKeys.get(obsIndex); + + if (key == null) { + key = this.generateKey(i); + attributeContentEventKeys.put(obsIndex, key); + } + AttributeContentEvent ace = new AttributeContentEvent.Builder( + this.id, i, key) + .attrValue(inst.value(instAttIndex)) + .classValue((int) inst.classValue()) + .weight(inst.weight()) + .isNominal(inst.attribute(instAttIndex).isNominal()) + .build(); + if (this.attributeBatchContentEvent == null) { + this.attributeBatchContentEvent = new AttributeBatchContentEvent[inst.numAttributes() - 1]; + } + if (this.attributeBatchContentEvent[i] == null) { + this.attributeBatchContentEvent[i] = new AttributeBatchContentEvent.Builder( + this.id, i, key) + // .attrValue(inst.value(instAttIndex)) + // .classValue((int) inst.classValue()) + // .weight(inst.weight()] + .isNominal(inst.attribute(instAttIndex).isNominal()) + .build(); + } + this.attributeBatchContentEvent[i].add(ace); + // proc.sendToAttributeStream(ace); + } + } + + @Override + double[] getClassVotes(Instance inst, ModelAggregatorProcessor map) { + return this.observedClassDistribution.getArrayCopy(); + } + + double getWeightSeen() { + return this.observedClassDistribution.sumOfValues(); + } + + void setWeightSeenAtLastSplitEvaluation(double weight) { + this.weightSeenAtLastSplitEvaluation = weight; + } + + double getWeightSeenAtLastSplitEvaluation() { + 
return this.weightSeenAtLastSplitEvaluation; + } + + void requestDistributedSuggestions(long splitId, ModelAggregatorProcessor modelAggrProc) { + this.isSplitting = true; + this.suggestionCtr = 0; + this.thrownAwayInstance = 0; + + ComputeContentEvent cce = new ComputeContentEvent(splitId, this.id, + this.getObservedClassDistribution()); + modelAggrProc.sendToControlStream(cce); + } + + void addDistributedSuggestions(AttributeSplitSuggestion bestSuggestion, AttributeSplitSuggestion secondBestSuggestion) { + // starts comparing from the best suggestion + if (bestSuggestion != null) { + if ((this.bestSuggestion == null) || (bestSuggestion.compareTo(this.bestSuggestion) > 0)) { + this.secondBestSuggestion = this.bestSuggestion; + this.bestSuggestion = bestSuggestion; + + if (secondBestSuggestion != null) { + + if ((this.secondBestSuggestion == null) || (secondBestSuggestion.compareTo(this.secondBestSuggestion) > 0)) { + this.secondBestSuggestion = secondBestSuggestion; + } + } + } else { + if ((this.secondBestSuggestion == null) || (bestSuggestion.compareTo(this.secondBestSuggestion) > 0)) { + this.secondBestSuggestion = bestSuggestion; + } + } + } + + // TODO: optimize the code to use less memory + this.suggestionCtr++; + } + + boolean isSplitting() { + return this.isSplitting; + } + + void endSplitting() { + this.isSplitting = false; + logger.trace("wasted instance: {}", this.thrownAwayInstance); + this.thrownAwayInstance = 0; + this.bestSuggestion = null; + this.secondBestSuggestion = null; + } + + AttributeSplitSuggestion getDistributedBestSuggestion() { + return this.bestSuggestion; + } + + AttributeSplitSuggestion getDistributedSecondBestSuggestion() { + return this.secondBestSuggestion; + } + + boolean isAllSuggestionsCollected() { + return (this.suggestionCtr == this.parallelismHint); + } + + private static int modelAttIndexToInstanceAttIndex(int index, Instance inst) { + return inst.classIndex() > index ? 
index : index + 1; + } + + private String generateKey(int obsIndex) { + final int prime = 31; + int result = 1; + result = prime * result + (int) (this.id ^ (this.id >>> 32)); + result = prime * result + obsIndex; + return Integer.toString(result); + } +} diff --git a/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/trees/AttributeBatchContentEvent.java b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/trees/AttributeBatchContentEvent.java new file mode 100644 index 00000000000..d7ea35efcc5 --- /dev/null +++ b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/trees/AttributeBatchContentEvent.java @@ -0,0 +1,135 @@ +package org.apache.heron.learners.classifiers.trees; + +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +import java.util.LinkedList; +import java.util.List; + +import org.apache.samoa.core.ContentEvent; + +/** + * Attribute Content Event represents the instances that split vertically based on their attribute + * + * @author Arinto Murdopo + * + */ +final class AttributeBatchContentEvent implements ContentEvent { + + private static final long serialVersionUID = 6652815649846676832L; + + private final long learningNodeId; + private final int obsIndex; + private final List contentEventList; + private final transient String key; + private final boolean isNominal; + + public AttributeBatchContentEvent() { + learningNodeId = -1; + obsIndex = -1; + contentEventList = new LinkedList<>(); + key = ""; + isNominal = true; + } + + private AttributeBatchContentEvent(Builder builder) { + this.learningNodeId = builder.learningNodeId; + this.obsIndex = builder.obsIndex; + this.contentEventList = new LinkedList<>(); + if (builder.contentEvent != null) { + this.contentEventList.add(builder.contentEvent); + } + this.isNominal = builder.isNominal; + this.key = builder.key; + } + + public void add(ContentEvent contentEvent) { + this.contentEventList.add(contentEvent); + } + + @Override + public String getKey() { + return this.key; + } + + @Override + public void setKey(String str) { + // do nothing, maybe useful when we want to reuse the object for + // serialization/deserialization purpose + } + + @Override + public boolean isLastEvent() { + return false; + } + + long getLearningNodeId() { + return this.learningNodeId; + } + + int getObsIndex() { + return this.obsIndex; + } + + public List getContentEventList() { + return this.contentEventList; + } + + boolean isNominal() { + return this.isNominal; + } + + static final class Builder { + + // required parameters + private final long learningNodeId; + private final int obsIndex; + private final String key; + + private ContentEvent contentEvent; + private boolean isNominal = false; + + Builder(long id, int obsIndex, String key) { + 
this.learningNodeId = id; + this.obsIndex = obsIndex; + this.key = key; + } + + private Builder(long id, int obsIndex) { + this.learningNodeId = id; + this.obsIndex = obsIndex; + this.key = ""; + } + + Builder contentEvent(ContentEvent contentEvent) { + this.contentEvent = contentEvent; + return this; + } + + Builder isNominal(boolean val) { + this.isNominal = val; + return this; + } + + AttributeBatchContentEvent build() { + return new AttributeBatchContentEvent(this); + } + } + +} diff --git a/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/trees/AttributeContentEvent.java b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/trees/AttributeContentEvent.java new file mode 100644 index 00000000000..fda2e85edd7 --- /dev/null +++ b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/trees/AttributeContentEvent.java @@ -0,0 +1,224 @@ +package org.apache.heron.learners.classifiers.trees; + +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + + +import org.apache.samoa.core.ContentEvent; + +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.Serializer; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; + +/** + * Attribute Content Event represents the instances that split vertically based on their attribute + * + * @author Arinto Murdopo + * + */ +public final class AttributeContentEvent implements ContentEvent { + + private static final long serialVersionUID = 6652815649846676832L; + + private final long learningNodeId; + private final int obsIndex; + private final double attrVal; + private final int classVal; + private final double weight; + private final transient String key; + private final boolean isNominal; + + public AttributeContentEvent() { + learningNodeId = -1; + obsIndex = -1; + attrVal = 0.0; + classVal = -1; + weight = 0.0; + key = ""; + isNominal = true; + } + + private AttributeContentEvent(Builder builder) { + this.learningNodeId = builder.learningNodeId; + this.obsIndex = builder.obsIndex; + this.attrVal = builder.attrVal; + this.classVal = builder.classVal; + this.weight = builder.weight; + this.isNominal = builder.isNominal; + this.key = builder.key; + } + + @Override + public String getKey() { + return this.key; + } + + @Override + public void setKey(String str) { + // do nothing, maybe useful when we want to reuse the object for + // serialization/deserialization purpose + } + + @Override + public boolean isLastEvent() { + return false; + } + + long getLearningNodeId() { + return this.learningNodeId; + } + + int getObsIndex() { + return this.obsIndex; + } + + int getClassVal() { + return this.classVal; + } + + double getAttrVal() { + return this.attrVal; + } + + double getWeight() { + return this.weight; + } + + boolean isNominal() { + return this.isNominal; + } + + static final class Builder { + + // required parameters + private final long learningNodeId; + private final int obsIndex; + private final String key; + 
+ // optional parameters + private double attrVal = 0.0; + private int classVal = 0; + private double weight = 0.0; + private boolean isNominal = false; + + Builder(long id, int obsIndex, String key) { + this.learningNodeId = id; + this.obsIndex = obsIndex; + this.key = key; + } + + private Builder(long id, int obsIndex) { + this.learningNodeId = id; + this.obsIndex = obsIndex; + this.key = ""; + } + + Builder attrValue(double val) { + this.attrVal = val; + return this; + } + + Builder classValue(int val) { + this.classVal = val; + return this; + } + + Builder weight(double val) { + this.weight = val; + return this; + } + + Builder isNominal(boolean val) { + this.isNominal = val; + return this; + } + + AttributeContentEvent build() { + return new AttributeContentEvent(this); + } + } + + /** + * The Kryo serializer class for AttributeContentEvent when executing on top of Storm. This class allow us to change + * the precision of the statistics. + * + * @author Arinto Murdopo + * + */ + public static final class AttributeCESerializer extends Serializer { + + private static double PRECISION = 1000000.0; + + @Override + public void write(Kryo kryo, Output output, AttributeContentEvent event) { + output.writeLong(event.learningNodeId, true); + output.writeInt(event.obsIndex, true); + output.writeDouble(event.attrVal, PRECISION, true); + output.writeInt(event.classVal, true); + output.writeDouble(event.weight, PRECISION, true); + output.writeBoolean(event.isNominal); + } + + @Override + public AttributeContentEvent read(Kryo kryo, Input input, + Class type) { + AttributeContentEvent ace = new AttributeContentEvent.Builder(input.readLong(true), input.readInt(true)) + .attrValue(input.readDouble(PRECISION, true)) + .classValue(input.readInt(true)) + .weight(input.readDouble(PRECISION, true)) + .isNominal(input.readBoolean()) + .build(); + return ace; + } + } + + /** + * The Kryo serializer class for AttributeContentEvent when executing on top of Storm with full precision of 
the + * statistics. + * + * @author Arinto Murdopo + * + */ + public static final class AttributeCEFullPrecSerializer extends Serializer { + + @Override + public void write(Kryo kryo, Output output, AttributeContentEvent event) { + output.writeLong(event.learningNodeId, true); + output.writeInt(event.obsIndex, true); + output.writeDouble(event.attrVal); + output.writeInt(event.classVal, true); + output.writeDouble(event.weight); + output.writeBoolean(event.isNominal); + } + + @Override + public AttributeContentEvent read(Kryo kryo, Input input, + Class type) { + AttributeContentEvent ace = new AttributeContentEvent.Builder(input.readLong(true), input.readInt(true)) + .attrValue(input.readDouble()) + .classValue(input.readInt(true)) + .weight(input.readDouble()) + .isNominal(input.readBoolean()) + .build(); + return ace; + } + + } +} diff --git a/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/trees/ComputeContentEvent.java b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/trees/ComputeContentEvent.java new file mode 100644 index 00000000000..d8d28a64f2a --- /dev/null +++ b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/trees/ComputeContentEvent.java @@ -0,0 +1,145 @@ +package org.apache.heron.learners.classifiers.trees; + +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + + +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.Serializer; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; + +/** + * Compute content event is the message that is sent by Model Aggregator Processor to request Local Statistic PI to + * start the local statistic calculation for splitting + * + * @author Arinto Murdopo + * + */ +public final class ComputeContentEvent extends ControlContentEvent { + + private static final long serialVersionUID = 5590798490073395190L; + + private final double[] preSplitDist; + private final long splitId; + + public ComputeContentEvent() { + super(-1); + preSplitDist = null; + splitId = -1; + } + + ComputeContentEvent(long splitId, long id, double[] preSplitDist) { + super(id); + // this.preSplitDist = Arrays.copyOf(preSplitDist, preSplitDist.length); + this.preSplitDist = preSplitDist; + this.splitId = splitId; + } + + @Override + LocStatControl getType() { + return LocStatControl.COMPUTE; + } + + double[] getPreSplitDist() { + return this.preSplitDist; + } + + long getSplitId() { + return this.splitId; + } + + /** + * The Kryo serializer class for ComputeContentEevent when executing on top of Storm. This class allow us to change + * the precision of the statistics. 
+ * + * @author Arinto Murdopo + * + */ + public static final class ComputeCESerializer extends Serializer { + + private static double PRECISION = 1000000.0; + + @Override + public void write(Kryo kryo, Output output, ComputeContentEvent object) { + output.writeLong(object.splitId, true); + output.writeLong(object.learningNodeId, true); + + output.writeInt(object.preSplitDist.length, true); + for (int i = 0; i < object.preSplitDist.length; i++) { + output.writeDouble(object.preSplitDist[i], PRECISION, true); + } + } + + @Override + public ComputeContentEvent read(Kryo kryo, Input input, + Class type) { + long splitId = input.readLong(true); + long learningNodeId = input.readLong(true); + + int dataLength = input.readInt(true); + double[] preSplitDist = new double[dataLength]; + + for (int i = 0; i < dataLength; i++) { + preSplitDist[i] = input.readDouble(PRECISION, true); + } + + return new ComputeContentEvent(splitId, learningNodeId, preSplitDist); + } + } + + /** + * The Kryo serializer class for ComputeContentEevent when executing on top of Storm with full precision of the + * statistics. 
+ * + * @author Arinto Murdopo + * + */ + public static final class ComputeCEFullPrecSerializer extends Serializer { + + @Override + public void write(Kryo kryo, Output output, ComputeContentEvent object) { + output.writeLong(object.splitId, true); + output.writeLong(object.learningNodeId, true); + + output.writeInt(object.preSplitDist.length, true); + for (int i = 0; i < object.preSplitDist.length; i++) { + output.writeDouble(object.preSplitDist[i]); + } + } + + @Override + public ComputeContentEvent read(Kryo kryo, Input input, + Class type) { + long splitId = input.readLong(true); + long learningNodeId = input.readLong(true); + + int dataLength = input.readInt(true); + double[] preSplitDist = new double[dataLength]; + + for (int i = 0; i < dataLength; i++) { + preSplitDist[i] = input.readDouble(); + } + + return new ComputeContentEvent(splitId, learningNodeId, preSplitDist); + } + + } + +} diff --git a/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/trees/ControlContentEvent.java b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/trees/ControlContentEvent.java new file mode 100644 index 00000000000..f49050e8b21 --- /dev/null +++ b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/trees/ControlContentEvent.java @@ -0,0 +1,72 @@ +package org.apache.heron.learners.classifiers.trees; + +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + + +import org.apache.samoa.core.ContentEvent; + +/** + * Abstract class to represent ContentEvent to control Local Statistic Processor. + * + * @author Arinto Murdopo + * + */ +abstract class ControlContentEvent implements ContentEvent { + + /** + * + */ + private static final long serialVersionUID = 5837375639629708363L; + + protected final long learningNodeId; + + public ControlContentEvent() { + this.learningNodeId = -1; + } + + ControlContentEvent(long id) { + this.learningNodeId = id; + } + + @Override + public final String getKey() { + return null; + } + + @Override + public void setKey(String str) { + // Do nothing + } + + @Override + public boolean isLastEvent() { + return false; + } + + final long getLearningNodeId() { + return this.learningNodeId; + } + + abstract LocStatControl getType(); + + static enum LocStatControl { + COMPUTE, DELETE + } +} diff --git a/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/trees/DeleteContentEvent.java b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/trees/DeleteContentEvent.java new file mode 100644 index 00000000000..083a5b2bf80 --- /dev/null +++ b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/trees/DeleteContentEvent.java @@ -0,0 +1,47 @@ +package org.apache.heron.learners.classifiers.trees; + +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+
+/**
+ * Control event sent by the Model Aggregator Processor to make the Local Statistic Processor delete the
+ * statistic it no longer needs for a learning node.
+ *
+ * @author Arinto Murdopo
+ *
+ */
+final class DeleteContentEvent extends ControlContentEvent {
+
+  private static final long serialVersionUID = -2105250722560863633L;
+
+  /**
+   * Creates a delete event with the placeholder node id -1.
+   */
+  public DeleteContentEvent() {
+    this(-1);
+  }
+
+  /**
+   * Creates a delete event for the given learning node.
+   *
+   * @param id
+   *          id of the learning node whose statistic should be deleted
+   */
+  DeleteContentEvent(long id) {
+    super(id);
+  }
+
+  /**
+   * @return always {@link LocStatControl#DELETE}
+   */
+  @Override
+  LocStatControl getType() {
+    return LocStatControl.DELETE;
+  }
+
+}
diff --git a/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/trees/FilterProcessor.java b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/trees/FilterProcessor.java
new file mode 100644
index 00000000000..44e6d8a59dc
--- /dev/null
+++ b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/trees/FilterProcessor.java
@@ -0,0 +1,185 @@
+package org.apache.heron.learners.classifiers.trees;
+
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + + +import org.apache.samoa.core.ContentEvent; +import org.apache.samoa.core.Processor; +import org.apache.samoa.instances.Instance; +import org.apache.samoa.instances.Instances; +import org.apache.samoa.instances.InstancesHeader; +import org.apache.samoa.learners.InstanceContentEvent; +import org.apache.samoa.learners.InstancesContentEvent; +import org.apache.samoa.learners.ResultContentEvent; +import org.apache.samoa.topology.Stream; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.LinkedList; +import java.util.List; + +/** + * Filter Processor that stores and filters the instances before sending them to the Model Aggregator Processor. 
+ *
+ * @author Arinto Murdopo
+ *
+ */
+final class FilterProcessor implements Processor {
+
+  private static final long serialVersionUID = -1685875718300564885L;
+  private static final Logger logger = LoggerFactory.getLogger(FilterProcessor.class);
+
+  private int processorId;
+
+  private final Instances dataset;
+  private InstancesHeader modelContext;
+
+  // available streams
+  private Stream outputStream;
+
+  /** Number of instances buffered since the last batch was emitted. */
+  private int waitingInstances = 0;
+
+  /** Pause in milliseconds after each emitted batch; no pause unless positive. */
+  private int delay = 0;
+
+  /** Number of buffered instances that triggers the emission of a batch. */
+  private int batchSize = 200;
+
+  /**
+   * Instances waiting for the next batch, in arrival order. The element type is required:
+   * with a raw list the buffered events could not be read back as InstanceContentEvent.
+   */
+  private List<InstanceContentEvent> contentEventList = new LinkedList<>();
+
+  // private constructor based on Builder pattern
+  private FilterProcessor(Builder builder) {
+    this.dataset = builder.dataset;
+    this.batchSize = builder.batchSize;
+    this.delay = builder.delay;
+  }
+
+  /**
+   * Buffers every incoming {@link InstanceContentEvent} and emits the buffered instances as one
+   * {@link InstancesContentEvent} once {@code batchSize} instances are waiting or the incoming event is the
+   * last one. Events of any other type are ignored.
+   *
+   * @param event
+   *          the incoming event
+   * @return always {@code false}
+   */
+  @Override
+  public boolean process(ContentEvent event) {
+    if (!(event instanceof InstanceContentEvent)) {
+      return false;
+    }
+
+    // Receive a new instance from source
+    InstanceContentEvent instanceContentEvent = (InstanceContentEvent) event;
+    this.contentEventList.add(instanceContentEvent);
+    this.waitingInstances++;
+
+    if (this.waitingInstances == this.batchSize || instanceContentEvent.isLastEvent()) {
+      sendBufferedInstances();
+    }
+    return false;
+  }
+
+  /**
+   * Packs all buffered instances into a single event, puts it on the output stream and resets the buffer.
+   * Afterwards sleeps for {@code delay} milliseconds when a positive delay is configured.
+   */
+  private void sendBufferedInstances() {
+    InstancesContentEvent outputEvent = new InstancesContentEvent();
+    for (InstanceContentEvent buffered : this.contentEventList) {
+      outputEvent.add(buffered.getInstanceContent());
+    }
+    this.contentEventList.clear();
+    this.waitingInstances = 0;
+    this.outputStream.put(outputEvent);
+
+    if (this.delay > 0) {
+      try {
+        Thread.sleep(this.delay);
+      } catch (InterruptedException ex) {
+        // Keep the interrupt visible to whoever runs this processor.
+        Thread.currentThread().interrupt();
+      }
+    }
+  }
+
+  /**
+   * Stores the processor id and resets the count of waiting instances.
+   *
+   * @param id
+   *          the id assigned to this processor
+   */
+  @Override
+  public void onCreate(int id) {
+    this.processorId = id;
+    this.waitingInstances = 0;
+  }
+
+  /**
+   * Creates a processor with the same dataset, delay, batch size and output stream as {@code p}.
+   *
+   * @param p
+   *          the processor to copy; must be a {@code FilterProcessor}
+   * @return the new processor
+   */
+  @Override
+  public Processor newProcessor(Processor p) {
+    FilterProcessor oldProcessor = (FilterProcessor) p;
+    FilterProcessor newProcessor = new FilterProcessor.Builder(oldProcessor).build();
+    newProcessor.setOutputStream(oldProcessor.outputStream);
+    return newProcessor;
+  }
+
+  @Override
+  public String toString() {
+    return super.toString();
+  }
+
+  /**
+   * Sets the stream the batched instances are sent to.
+   *
+   * @param outputStream
+   *          the destination stream
+   */
+  void setOutputStream(Stream outputStream) {
+    this.outputStream = outputStream;
+  }
+
+  /**
+   * Helper method to generate new ResultContentEvent based on an instance and its prediction result. It is
+   * not called from within this class.
+   *
+   * @param prediction
+   *          The predicted class label from the decision tree model.
+   * @param inEvent
+   *          The associated instance content event
+   * @return ResultContentEvent to be sent into Evaluator PI or other destination PI.
+   */
+  private ResultContentEvent newResultContentEvent(double[] prediction, InstanceContentEvent inEvent) {
+    ResultContentEvent rce = new ResultContentEvent(inEvent.getInstanceIndex(), inEvent.getInstance(),
+        inEvent.getClassId(), prediction, inEvent.isLastEvent());
+    rce.setClassifierIndex(this.processorId);
+    rce.setEvaluationIndex(inEvent.getEvaluationIndex());
+    return rce;
+  }
+
+  /**
+   * Builder class to replace constructors with many parameters
+   *
+   * @author Arinto Murdopo
+   *
+   */
+  static class Builder {
+
+    // required parameters
+    private final Instances dataset;
+
+    // optional parameters with their defaults
+    private int delay = 0;
+    private int batchSize = 200;
+
+    Builder(Instances dataset) {
+      this.dataset = dataset;
+    }
+
+    /** Starts from the dataset, delay and batch size of an existing processor. */
+    Builder(FilterProcessor oldProcessor) {
+      this.dataset = oldProcessor.dataset;
+      this.delay = oldProcessor.delay;
+      this.batchSize = oldProcessor.batchSize;
+    }
+
+    public Builder delay(int delay) {
+      this.delay = delay;
+      return this;
+    }
+
+    public Builder batchSize(int val) {
+      this.batchSize = val;
+      return this;
+    }
+
+    FilterProcessor build() {
+      return new FilterProcessor(this);
+    }
+  }
+
+}
diff --git a/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/trees/FoundNode.java b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/trees/FoundNode.java
new file mode 100644
index 
00000000000..36278304da5
--- /dev/null
+++ b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/trees/FoundNode.java
@@ -0,0 +1,77 @@
+package org.apache.heron.learners.classifiers.trees;
+
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+
+/**
+ * Immutable result of routing (filtering) an instance through the decision tree model: the node the
+ * instance reached, that node's parent, and the branch of the parent that leads to it.
+ *
+ * @author Arinto Murdopo
+ *
+ */
+final class FoundNode implements java.io.Serializable {
+
+  private static final long serialVersionUID = -637695387934143293L;
+
+  /** Node the instance was routed to. */
+  private final Node node;
+
+  /** Parent of {@link #node}. */
+  private final SplitNode parent;
+
+  /** Index of {@link #node} within {@link #parent}. */
+  private final int parentBranch;
+
+  /**
+   * @param node
+   *          the node where the instance is routed/filtered
+   * @param splitNode
+   *          the parent of that node
+   * @param parentBranch
+   *          the index of the node in its parent
+   */
+  FoundNode(Node node, SplitNode splitNode, int parentBranch) {
+    this.parentBranch = parentBranch;
+    this.parent = splitNode;
+    this.node = node;
+  }
+
+  /**
+   * Returns the node reached when the instance was routed through the decision tree model for testing and
+   * training.
+   *
+   * @return The node where the instance is routed/filtered
+   */
+  Node getNode() {
+    return node;
+  }
+
+  /**
+   * Returns the parent of the node reached by the instance.
+   *
+   * @return The parent of the node
+   */
+  SplitNode getParent() {
+    return parent;
+  }
+
+  /**
+   * Returns the position of the reached node among the branches of its parent.
+   *
+   * @return The index of the node in its parent node.
+   */
+  int getParentBranch() {
+    return parentBranch;
+  }
+
+}
diff --git a/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/trees/InactiveLearningNode.java b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/trees/InactiveLearningNode.java
new file mode 100644
index 00000000000..a372754e203
--- /dev/null
+++ b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/trees/InactiveLearningNode.java
@@ -0,0 +1,55 @@
+package org.apache.heron.learners.classifiers.trees;
+
+/*
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+
+import org.apache.samoa.instances.Instance;
+
+/**
+ * Learning node that is inactive: it only keeps track of the observed class distribution and does not store
+ * the statistic needed for splitting the node.
+ *
+ * @author Arinto Murdopo
+ *
+ */
+final class InactiveLearningNode extends LearningNode {
+
+  private static final long serialVersionUID = -814552382883472302L;
+
+  /**
+   * @param initialClassObservation
+   *          the class observation handed to the parent constructor
+   */
+  InactiveLearningNode(double[] initialClassObservation) {
+    super(initialClassObservation);
+  }
+
+  /**
+   * Adds the weight of the instance to the observed count of the instance's class.
+   *
+   * @param inst
+   *          the instance to learn from
+   * @param proc
+   *          the model aggregator processor (not used by this node type)
+   */
+  @Override
+  void learnFromInstance(Instance inst, ModelAggregatorProcessor proc) {
+    int classIndex = (int) inst.classValue();
+    double weight = inst.weight();
+    this.observedClassDistribution.addToValue(classIndex, weight);
+  }
+
+  /**
+   * @param inst
+   *          the instance to vote on (not used by this node type)
+   * @param map
+   *          the model aggregator processor (not used by this node type)
+   * @return a copy of the observed class distribution
+   */
+  @Override
+  double[] getClassVotes(Instance inst, ModelAggregatorProcessor map) {
+    double[] votes = this.observedClassDistribution.getArrayCopy();
+    return votes;
+  }
+
+}
diff --git a/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/trees/LearningNode.java b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/trees/LearningNode.java
new file mode 100644
index 00000000000..b6eb08d26d2
--- /dev/null
+++ b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/trees/LearningNode.java
@@ -0,0 +1,59 @@
+package org.apache.heron.learners.classifiers.trees;
+
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+
+import org.apache.samoa.instances.Instance;
+
+/**
+ * Common base of the learning nodes of the tree. A learning node is always a leaf, so routing an instance
+ * to a leaf ends at the node itself.
+ *
+ * @author Arinto Murdopo
+ *
+ */
+abstract class LearningNode extends Node {
+
+  private static final long serialVersionUID = 7157319356146764960L;
+
+  /**
+   * @param classObservation
+   *          the class observation handed to the parent constructor
+   */
+  protected LearningNode(double[] classObservation) {
+    super(classObservation);
+  }
+
+  /**
+   * @return always {@code true}
+   */
+  @Override
+  protected boolean isLeaf() {
+    return true;
+  }
+
+  /**
+   * Ends the routing at this node.
+   *
+   * @param inst
+   *          the routed instance (not inspected here)
+   * @param parent
+   *          the parent through which this node was reached
+   * @param parentBranch
+   *          the index of this node in {@code parent}
+   * @return a {@link FoundNode} holding this node, {@code parent} and {@code parentBranch}
+   */
+  @Override
+  protected FoundNode filterInstanceToLeaf(Instance inst, SplitNode parent, int parentBranch) {
+    FoundNode found = new FoundNode(this, parent, parentBranch);
+    return found;
+  }
+
+  /**
+   * Processes the instance for learning.
+   *
+   * @param inst
+   *          The processed instance
+   * @param proc
+   *          The model aggregator processor where this learning node exists
+   */
+  abstract void learnFromInstance(Instance inst, ModelAggregatorProcessor proc);
+}
diff --git a/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/trees/LocalResultContentEvent.java b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/trees/LocalResultContentEvent.java
new file mode 100644
index 00000000000..32014273a7a
--- /dev/null
+++ b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/trees/LocalResultContentEvent.java
@@ -0,0 +1,95 @@
+package org.apache.heron.learners.classifiers.trees;
+
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.
The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+
+import org.apache.samoa.core.ContentEvent;
+import org.apache.samoa.moa.classifiers.core.AttributeSplitSuggestion;
+
+/**
+ * Content event that carries the outcome of a local statistic calculation done in the Local Statistic
+ * Processor: the best and second best attribute split suggestions, tagged with the id of the split they
+ * belong to.
+ *
+ * @author Arinto Murdopo
+ *
+ */
+final class LocalResultContentEvent implements ContentEvent {
+
+  private static final long serialVersionUID = -4206620993777418571L;
+
+  /** Id of the split this result belongs to; -1 when built through the no-argument constructor. */
+  private final long splitId;
+
+  /** Best suggestion of the local calculation; may be {@code null}. */
+  private final AttributeSplitSuggestion bestSuggestion;
+
+  /** Second best suggestion of the local calculation; may be {@code null}. */
+  private final AttributeSplitSuggestion secondBestSuggestion;
+
+  /**
+   * Creates an empty result: split id -1 and no suggestions.
+   */
+  public LocalResultContentEvent() {
+    this(-1, null, null);
+  }
+
+  /**
+   * @param splitId
+   *          id of the split this result belongs to
+   * @param best
+   *          the best attribute split suggestion
+   * @param secondBest
+   *          the second best attribute split suggestion
+   */
+  LocalResultContentEvent(long splitId, AttributeSplitSuggestion best, AttributeSplitSuggestion secondBest) {
+    this.splitId = splitId;
+    this.bestSuggestion = best;
+    this.secondBestSuggestion = secondBest;
+  }
+
+  /**
+   * Returns the split ID of this local statistic calculation result.
+   *
+   * @return The split id of this local calculation result
+   */
+  long getSplitId() {
+    return splitId;
+  }
+
+  /**
+   * Returns the best attribute split suggestion from this local statistic calculation.
+   *
+   * @return The best attribute split suggestion.
+   */
+  AttributeSplitSuggestion getBestSuggestion() {
+    return bestSuggestion;
+  }
+
+  /**
+   * Returns the second best attribute split suggestion from this local statistic calculation.
+   *
+   * @return The second best attribute split suggestion.
+   */
+  AttributeSplitSuggestion getSecondBestSuggestion() {
+    return secondBestSuggestion;
+  }
+
+  /**
+   * This event carries no key.
+   *
+   * @return always {@code null}
+   */
+  @Override
+  public String getKey() {
+    return null;
+  }
+
+  /**
+   * Ignores the supplied key; this event carries no key.
+   */
+  @Override
+  public void setKey(String str) {
+    // do nothing
+  }
+
+  /**
+   * @return always {@code false}
+   */
+  @Override
+  public boolean isLastEvent() {
+    return false;
+  }
+}
diff --git a/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/trees/LocalStatisticsProcessor.java b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/trees/LocalStatisticsProcessor.java
new file mode 100644
index 00000000000..98761d402ba
--- /dev/null
+++ b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/trees/LocalStatisticsProcessor.java
@@ -0,0 +1,242 @@
+package org.apache.heron.learners.classifiers.trees;
+
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */ + + +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Vector; + +import org.apache.samoa.core.ContentEvent; +import org.apache.samoa.core.Processor; +import org.apache.samoa.moa.classifiers.core.AttributeSplitSuggestion; +import org.apache.samoa.moa.classifiers.core.attributeclassobservers.AttributeClassObserver; +import org.apache.samoa.moa.classifiers.core.attributeclassobservers.GaussianNumericAttributeClassObserver; +import org.apache.samoa.moa.classifiers.core.attributeclassobservers.NominalAttributeClassObserver; +import org.apache.samoa.moa.classifiers.core.splitcriteria.InfoGainSplitCriterion; +import org.apache.samoa.moa.classifiers.core.splitcriteria.SplitCriterion; +import org.apache.samoa.topology.Stream; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.google.common.collect.HashBasedTable; +import com.google.common.collect.Table; + +/** + * Local Statistic Processor contains the local statistic of a subset of the attributes. 
+ *
+ * @author Arinto Murdopo
+ *
+ */
+final class LocalStatisticsProcessor implements Processor {
+
+  private static final long serialVersionUID = -3967695130634517631L;
+  private static final Logger logger = LoggerFactory.getLogger(LocalStatisticsProcessor.class);
+
+  // Collection of AttributeObservers, for each ActiveLearningNode and
+  // AttributeId: row key = learning node id, column key = observer index.
+  private Table<Long, Integer, AttributeClassObserver> localStats;
+
+  private Stream computationResultStream;
+
+  private final SplitCriterion splitCriterion;
+  private final boolean binarySplit;
+
+  // Prototypes only: they are never fed with observations, every new observer is a copy() of one of them.
+  private final AttributeClassObserver nominalClassObserver;
+  private final AttributeClassObserver numericClassObserver;
+
+  // the two observer classes below are also needed to be setup from the Tree
+  private LocalStatisticsProcessor(Builder builder) {
+    this.splitCriterion = builder.splitCriterion;
+    this.binarySplit = builder.binarySplit;
+    this.nominalClassObserver = builder.nominalClassObserver;
+    this.numericClassObserver = builder.numericClassObserver;
+  }
+
+  /**
+   * Dispatches on the event type: attribute batches update the local statistics, compute events trigger
+   * the calculation of split suggestions, delete events drop the statistics of a learning node. Events of
+   * any other type are ignored.
+   *
+   * @param event
+   *          the incoming event
+   * @return always {@code false}
+   */
+  @Override
+  public boolean process(ContentEvent event) {
+    if (event instanceof AttributeBatchContentEvent) {
+      updateLocalStatistics((AttributeBatchContentEvent) event);
+    } else if (event instanceof ComputeContentEvent) {
+      computeAndSendSuggestions((ComputeContentEvent) event);
+    } else if (event instanceof DeleteContentEvent) {
+      deleteLocalStatistics((DeleteContentEvent) event);
+    }
+    return false;
+  }
+
+  /** Feeds every attribute event of the batch to its observer, creating the observer on first use. */
+  private void updateLocalStatistics(AttributeBatchContentEvent abce) {
+    for (ContentEvent contentEvent : abce.getContentEventList()) {
+      AttributeContentEvent ace = (AttributeContentEvent) contentEvent;
+      Long learningNodeId = ace.getLearningNodeId();
+      Integer obsIndex = ace.getObsIndex();
+
+      AttributeClassObserver obs = localStats.get(learningNodeId, obsIndex);
+      if (obs == null) {
+        obs = ace.isNominal() ? newNominalClassObserver() : newNumericClassObserver();
+        localStats.put(learningNodeId, obsIndex, obs);
+      }
+      obs.observeAttributeClass(ace.getAttrVal(), ace.getClassVal(), ace.getWeight());
+    }
+  }
+
+  /**
+   * Calculates the split suggestions of the learning node named by the event and sends the best and the
+   * second best one back via the computation result stream. Either may be {@code null} when fewer
+   * suggestions exist.
+   */
+  private void computeAndSendSuggestions(ComputeContentEvent cce) {
+    Long learningNodeId = cce.getLearningNodeId();
+    double[] preSplitDist = cce.getPreSplitDist();
+
+    Map<Integer, AttributeClassObserver> learningNodeRowMap = localStats.row(learningNodeId);
+    List<AttributeSplitSuggestion> suggestions = new Vector<>();
+
+    for (Entry<Integer, AttributeClassObserver> entry : learningNodeRowMap.entrySet()) {
+      AttributeSplitSuggestion suggestion = entry.getValue()
+          .getBestEvaluatedSplitSuggestion(splitCriterion, preSplitDist, entry.getKey(), binarySplit);
+      if (suggestion != null) {
+        suggestions.add(suggestion);
+      }
+    }
+
+    AttributeSplitSuggestion[] bestSuggestions = suggestions
+        .toArray(new AttributeSplitSuggestion[suggestions.size()]);
+
+    // Ascending sort: the best suggestion ends up last, the second best right before it.
+    Arrays.sort(bestSuggestions);
+
+    AttributeSplitSuggestion bestSuggestion = null;
+    AttributeSplitSuggestion secondBestSuggestion = null;
+
+    if (bestSuggestions.length >= 1) {
+      bestSuggestion = bestSuggestions[bestSuggestions.length - 1];
+
+      if (bestSuggestions.length >= 2) {
+        secondBestSuggestion = bestSuggestions[bestSuggestions.length - 2];
+      }
+    }
+
+    // create the local result content event
+    LocalResultContentEvent lcre =
+        new LocalResultContentEvent(cce.getSplitId(), bestSuggestion, secondBestSuggestion);
+    computationResultStream.put(lcre);
+    logger.debug("Finish compute event");
+  }
+
+  /** Removes all observers kept for the learning node named by the event. */
+  private void deleteLocalStatistics(DeleteContentEvent dce) {
+    Long learningNodeId = dce.getLearningNodeId();
+    localStats.rowMap().remove(learningNodeId);
+  }
+
+  /**
+   * Creates the empty table of local statistics.
+   *
+   * @param id
+   *          the id assigned to this processor (not used here)
+   */
+  @Override
+  public void onCreate(int id) {
+    this.localStats = HashBasedTable.create();
+  }
+
+  /**
+   * Creates a processor with the same split criterion, binary split flag, observer prototypes and
+   * computation result stream as {@code p}.
+   *
+   * @param p
+   *          the processor to copy; must be a {@code LocalStatisticsProcessor}
+   * @return the new processor
+   */
+  @Override
+  public Processor newProcessor(Processor p) {
+    LocalStatisticsProcessor oldProcessor = (LocalStatisticsProcessor) p;
+    LocalStatisticsProcessor newProcessor = new LocalStatisticsProcessor.Builder(oldProcessor).build();
+
+    newProcessor.setComputationResultStream(oldProcessor.computationResultStream);
+
+    return newProcessor;
+  }
+
+  /**
+   * Method to set the computation result when using this processor to build a topology.
+   *
+   * @param computeStream
+   *          the stream the calculation results are sent to
+   */
+  void setComputationResultStream(Stream computeStream) {
+    this.computationResultStream = computeStream;
+  }
+
+  private AttributeClassObserver newNominalClassObserver() {
+    return (AttributeClassObserver) this.nominalClassObserver.copy();
+  }
+
+  private AttributeClassObserver newNumericClassObserver() {
+    return (AttributeClassObserver) this.numericClassObserver.copy();
+  }
+
+  /**
+   * Builder class to replace constructors with many parameters
+   *
+   * @author Arinto Murdopo
+   *
+   */
+  static class Builder {
+
+    private SplitCriterion splitCriterion = new InfoGainSplitCriterion();
+    private boolean binarySplit = false;
+    private AttributeClassObserver nominalClassObserver = new NominalAttributeClassObserver();
+    private AttributeClassObserver numericClassObserver = new GaussianNumericAttributeClassObserver();
+
+    Builder() {
+
+    }
+
+    /**
+     * Starts from the full configuration of an existing processor. The observer prototypes are copied as
+     * well; otherwise a processor cloned through newProcessor() would silently fall back to the default
+     * observers.
+     */
+    Builder(LocalStatisticsProcessor oldProcessor) {
+      this.splitCriterion = oldProcessor.splitCriterion;
+      this.binarySplit = oldProcessor.binarySplit;
+      this.nominalClassObserver = oldProcessor.nominalClassObserver;
+      this.numericClassObserver = oldProcessor.numericClassObserver;
+    }
+
+    Builder splitCriterion(SplitCriterion splitCriterion) {
+      this.splitCriterion = splitCriterion;
+      return this;
+    }
+
+    Builder binarySplit(boolean binarySplit) {
+      this.binarySplit = binarySplit;
+      return this;
+    }
+
+    Builder nominalClassObserver(AttributeClassObserver nominalClassObserver) {
+      this.nominalClassObserver = nominalClassObserver;
+      return this;
+    }
+
+    Builder numericClassObserver(AttributeClassObserver numericClassObserver) {
+      this.numericClassObserver = numericClassObserver;
+      return this;
+    }
+
+    LocalStatisticsProcessor build() {
+      return new LocalStatisticsProcessor(this);
+    }
+  }
+
+}
diff --git a/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/trees/ModelAggregatorProcessor.java b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/trees/ModelAggregatorProcessor.java
new file mode 100644
index 00000000000..01e70a08c3a
--- /dev/null
+++ 
b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/trees/ModelAggregatorProcessor.java @@ -0,0 +1,721 @@ +package org.apache.heron.learners.classifiers.trees; + +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + + +import static org.apache.samoa.moa.core.Utils.maxIndex; + +import java.io.Serializable; +import java.util.HashSet; +import java.util.LinkedList; +import java.util.List; +import java.util.Set; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.Executors; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.TimeUnit; + +import org.apache.samoa.core.ContentEvent; +import org.apache.samoa.core.Processor; +import org.apache.samoa.instances.Instance; +import org.apache.samoa.instances.Instances; +import org.apache.samoa.instances.InstancesHeader; +import org.apache.samoa.learners.InstanceContent; +import org.apache.samoa.learners.InstancesContentEvent; +import org.apache.samoa.learners.ResultContentEvent; +import 
org.apache.samoa.moa.classifiers.core.AttributeSplitSuggestion; +import org.apache.samoa.moa.classifiers.core.driftdetection.ChangeDetector; +import org.apache.samoa.moa.classifiers.core.splitcriteria.InfoGainSplitCriterion; +import org.apache.samoa.moa.classifiers.core.splitcriteria.SplitCriterion; +import org.apache.samoa.topology.Stream; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Model Aggegator Processor consists of the decision tree model. It connects to local-statistic PI via attribute stream + * and control stream. Model-aggregator PI sends the split instances via attribute stream and it sends control messages + * to ask local-statistic PI to perform computation via control stream. + * + * Model-aggregator PI sends the classification result via result stream to an evaluator PI for classifier or other + * destination PI. The calculation results from local statistic arrive to the model-aggregator PI via computation-result + * stream. + * + * @author Arinto Murdopo + * + */ +final class ModelAggregatorProcessor implements Processor { + + private static final long serialVersionUID = -1685875718300564886L; + private static final Logger logger = LoggerFactory.getLogger(ModelAggregatorProcessor.class); + + private int processorId; + + private Node treeRoot; + + private int activeLeafNodeCount; + private int inactiveLeafNodeCount; + private int decisionNodeCount; + private boolean growthAllowed; + + private final Instances dataset; + + // to support concurrent split + private long splitId; + private ConcurrentMap splittingNodes; + private BlockingQueue timedOutSplittingNodes; + + // available streams + private Stream resultStream; + private Stream attributeStream; + private Stream controlStream; + + private transient ScheduledExecutorService executor; + + private final SplitCriterion splitCriterion; + private final double splitConfidence; + private final double tieThreshold; + private final int gracePeriod; + private final int 
parallelismHint; + private final long timeOut; + + // private constructor based on Builder pattern + private ModelAggregatorProcessor(Builder builder) { + this.dataset = builder.dataset; + this.splitCriterion = builder.splitCriterion; + this.splitConfidence = builder.splitConfidence; + this.tieThreshold = builder.tieThreshold; + this.gracePeriod = builder.gracePeriod; + this.parallelismHint = builder.parallelismHint; + this.timeOut = builder.timeOut; + this.changeDetector = builder.changeDetector; + + InstancesHeader ih = new InstancesHeader(dataset); + this.setModelContext(ih); + } + + @Override + public boolean process(ContentEvent event) { + + // Poll the blocking queue shared between ModelAggregator and the time-out + // threads + Long timedOutSplitId = timedOutSplittingNodes.poll(); + if (timedOutSplitId != null) { // time out has been reached! + SplittingNodeInfo splittingNode = splittingNodes.get(timedOutSplitId); + if (splittingNode != null) { + this.splittingNodes.remove(timedOutSplitId); + this.continueAttemptToSplit(splittingNode.activeLearningNode, splittingNode.foundNode); + + } + + } + + // Receive a new instance from source + if (event instanceof InstancesContentEvent) { + InstancesContentEvent instancesEvent = (InstancesContentEvent) event; + this.processInstanceContentEvent(instancesEvent); + // Send information to local-statistic PI + // for each of the nodes + if (this.foundNodeSet != null) { + for (FoundNode foundNode : this.foundNodeSet) { + ActiveLearningNode leafNode = (ActiveLearningNode) foundNode.getNode(); + AttributeBatchContentEvent[] abce = leafNode.getAttributeBatchContentEvent(); + if (abce != null) { + for (int i = 0; i < this.dataset.numAttributes() - 1; i++) { + this.sendToAttributeStream(abce[i]); + } + } + leafNode.setAttributeBatchContentEvent(null); + // this.sendToControlStream(event); //split information + // See if we can ask for splits + if (!leafNode.isSplitting()) { + double weightSeen = leafNode.getWeightSeen(); + // 
check whether it is the time for splitting + if (weightSeen - leafNode.getWeightSeenAtLastSplitEvaluation() >= this.gracePeriod) { + attemptToSplit(leafNode, foundNode); + } + } + } + } + this.foundNodeSet = null; + } else if (event instanceof LocalResultContentEvent) { + LocalResultContentEvent lrce = (LocalResultContentEvent) event; + Long lrceSplitId = lrce.getSplitId(); + SplittingNodeInfo splittingNodeInfo = splittingNodes.get(lrceSplitId); + + if (splittingNodeInfo != null) { // if null, that means + // activeLearningNode has been + // removed by timeout thread + ActiveLearningNode activeLearningNode = splittingNodeInfo.activeLearningNode; + + activeLearningNode.addDistributedSuggestions(lrce.getBestSuggestion(), lrce.getSecondBestSuggestion()); + + if (activeLearningNode.isAllSuggestionsCollected()) { + splittingNodeInfo.scheduledFuture.cancel(false); + this.splittingNodes.remove(lrceSplitId); + this.continueAttemptToSplit(activeLearningNode, splittingNodeInfo.foundNode); + } + } + } + return false; + } + + protected Set foundNodeSet; + + @Override + public void onCreate(int id) { + this.processorId = id; + + this.activeLeafNodeCount = 0; + this.inactiveLeafNodeCount = 0; + this.decisionNodeCount = 0; + this.growthAllowed = true; + + this.splittingNodes = new ConcurrentHashMap<>(); + this.timedOutSplittingNodes = new LinkedBlockingQueue<>(); + this.splitId = 0; + + // Executor for scheduling time-out threads + this.executor = Executors.newScheduledThreadPool(8); + } + + @Override + public Processor newProcessor(Processor p) { + ModelAggregatorProcessor oldProcessor = (ModelAggregatorProcessor) p; + ModelAggregatorProcessor newProcessor = new ModelAggregatorProcessor.Builder(oldProcessor).build(); + + newProcessor.setResultStream(oldProcessor.resultStream); + newProcessor.setAttributeStream(oldProcessor.attributeStream); + newProcessor.setControlStream(oldProcessor.controlStream); + return newProcessor; + } + + @Override + public String toString() { + 
StringBuilder sb = new StringBuilder(); + sb.append(super.toString()); + + sb.append("ActiveLeafNodeCount: ").append(activeLeafNodeCount); + sb.append("InactiveLeafNodeCount: ").append(inactiveLeafNodeCount); + sb.append("DecisionNodeCount: ").append(decisionNodeCount); + sb.append("Growth allowed: ").append(growthAllowed); + return sb.toString(); + } + + void setResultStream(Stream resultStream) { + this.resultStream = resultStream; + } + + void setAttributeStream(Stream attributeStream) { + this.attributeStream = attributeStream; + } + + void setControlStream(Stream controlStream) { + this.controlStream = controlStream; + } + + void sendToAttributeStream(ContentEvent event) { + this.attributeStream.put(event); + } + + void sendToControlStream(ContentEvent event) { + this.controlStream.put(event); + } + + /** + * Helper method to generate new ResultContentEvent based on an instance and its prediction result. + * + * @param prediction + * The predicted class label from the decision tree model. + * @param inEvent + * The associated instance content event + * @return ResultContentEvent to be sent into Evaluator PI or other destination PI. 
+ */ + private ResultContentEvent newResultContentEvent(double[] prediction, InstanceContent inEvent) { + ResultContentEvent rce = new ResultContentEvent(inEvent.getInstanceIndex(), inEvent.getInstance(), + inEvent.getClassId(), prediction, inEvent.isLastEvent()); + rce.setClassifierIndex(this.processorId); + rce.setEvaluationIndex(inEvent.getEvaluationIndex()); + return rce; + } + + private List contentEventList = new LinkedList<>(); + + /** + * Helper method to process the InstanceContentEvent + * + * @param instContentEvent + */ + private void processInstanceContentEvent(InstancesContentEvent instContentEvent) { + this.numBatches++; + this.contentEventList.add(instContentEvent); + if (this.numBatches == 1 || this.numBatches > 4) { + this.processInstances(this.contentEventList.remove(0)); + } + + if (instContentEvent.isLastEvent()) { + // drain remaining instances + while (!contentEventList.isEmpty()) { + processInstances(contentEventList.remove(0)); + } + } + + } + + private int numBatches = 0; + + private void processInstances(InstancesContentEvent instContentEvent) { + for (InstanceContent instContent : instContentEvent.getList()) { + Instance inst = instContent.getInstance(); + boolean isTesting = instContent.isTesting(); + boolean isTraining = instContent.isTraining(); + inst.setDataset(this.dataset); + // Check the instance whether it is used for testing or training + // boolean testAndTrain = isTraining; //Train after testing + double[] prediction = null; + if (isTesting) { + prediction = getVotesForInstance(inst, false); + this.resultStream.put(newResultContentEvent(prediction, instContent)); + } + + if (isTraining) { + trainOnInstanceImpl(inst); + if (this.changeDetector != null) { + if (prediction == null) { + prediction = getVotesForInstance(inst); + } + boolean correctlyClassifies = this.correctlyClassifies(inst, prediction); + double oldEstimation = this.changeDetector.getEstimation(); + this.changeDetector.input(correctlyClassifies ? 
0 : 1); + if (this.changeDetector.getEstimation() > oldEstimation) { + // Start a new classifier + logger.info("Change detected, resetting the classifier"); + this.resetLearning(); + this.changeDetector.resetLearning(); + } + } + } + } + } + + private boolean correctlyClassifies(Instance inst, double[] prediction) { + return maxIndex(prediction) == (int) inst.classValue(); + } + + private void resetLearning() { + this.treeRoot = null; + // Remove nodes + FoundNode[] learningNodes = findNodes(); + for (FoundNode learningNode : learningNodes) { + Node node = learningNode.getNode(); + if (node instanceof SplitNode) { + SplitNode splitNode; + splitNode = (SplitNode) node; + for (int i = 0; i < splitNode.numChildren(); i++) { + splitNode.setChild(i, null); + } + } + } + } + + protected FoundNode[] findNodes() { + List foundList = new LinkedList<>(); + findNodes(this.treeRoot, null, -1, foundList); + return foundList.toArray(new FoundNode[foundList.size()]); + } + + protected void findNodes(Node node, SplitNode parent, int parentBranch, List found) { + if (node != null) { + found.add(new FoundNode(node, parent, parentBranch)); + if (node instanceof SplitNode) { + SplitNode splitNode = (SplitNode) node; + for (int i = 0; i < splitNode.numChildren(); i++) { + findNodes(splitNode.getChild(i), splitNode, i, found); + } + } + } + } + + /** + * Helper method to get the prediction result. The actual prediction result is delegated to the leaf node. 
+ * + * @param inst + * @return + */ + private double[] getVotesForInstance(Instance inst) { + return getVotesForInstance(inst, false); + } + + private double[] getVotesForInstance(Instance inst, boolean isTraining) { + double[] ret; + FoundNode foundNode = null; + if (this.treeRoot != null) { + foundNode = this.treeRoot.filterInstanceToLeaf(inst, null, -1); + Node leafNode = foundNode.getNode(); + if (leafNode == null) { + leafNode = foundNode.getParent(); + } + + ret = leafNode.getClassVotes(inst, this); + } else { + int numClasses = this.dataset.numClasses(); + ret = new double[numClasses]; + + } + + // Training after testing to speed up the process + if (isTraining) { + if (this.treeRoot == null) { + this.treeRoot = newLearningNode(this.parallelismHint); + this.activeLeafNodeCount = 1; + foundNode = this.treeRoot.filterInstanceToLeaf(inst, null, -1); + } + trainOnInstanceImpl(foundNode, inst); + } + return ret; + } + + /** + * Helper method that represent training of an instance. Since it is decision tree, this method routes the incoming + * instance into the correct leaf and then update the statistic on the found leaf. 
+ * + * @param inst + */ + private void trainOnInstanceImpl(Instance inst) { + if (this.treeRoot == null) { + this.treeRoot = newLearningNode(this.parallelismHint); + this.activeLeafNodeCount = 1; + + } + FoundNode foundNode = this.treeRoot.filterInstanceToLeaf(inst, null, -1); + trainOnInstanceImpl(foundNode, inst); + } + + private void trainOnInstanceImpl(FoundNode foundNode, Instance inst) { + + Node leafNode = foundNode.getNode(); + + if (leafNode == null) { + leafNode = newLearningNode(this.parallelismHint); + foundNode.getParent().setChild(foundNode.getParentBranch(), leafNode); + activeLeafNodeCount++; + } + + if (leafNode instanceof LearningNode) { + LearningNode learningNode = (LearningNode) leafNode; + learningNode.learnFromInstance(inst, this); + } + if (this.foundNodeSet == null) { + this.foundNodeSet = new HashSet<>(); + } + this.foundNodeSet.add(foundNode); + } + + /** + * Helper method to represent a split attempt + * + * @param activeLearningNode + * The corresponding active learning node which will be split + * @param foundNode + * The data structure to represents the filtering of the instance using the tree model. 
+ */ + private void attemptToSplit(ActiveLearningNode activeLearningNode, FoundNode foundNode) { + if (!activeLearningNode.observedClassDistributionIsPure()) { + // Increment the split ID + this.splitId++; + + // Schedule time-out thread + ScheduledFuture timeOutHandler = this.executor.schedule(new AggregationTimeOutHandler(this.splitId, + this.timedOutSplittingNodes), this.timeOut, TimeUnit.SECONDS); + + // Keep track of the splitting node information, so that we can continue the + // split + // once we receive all local statistic calculation from Local Statistic PI + // this.splittingNodes.put(Long.valueOf(this.splitId), new + // SplittingNodeInfo(activeLearningNode, foundNode, null)); + this.splittingNodes.put(this.splitId, new SplittingNodeInfo(activeLearningNode, foundNode, timeOutHandler)); + + // Inform Local Statistic PI to perform local statistic calculation + activeLearningNode.requestDistributedSuggestions(this.splitId, this); + } + } + + /** + * Helper method to continue the attempt to split once all local calculation results are received. + * + * @param activeLearningNode + * The corresponding active learning node which will be split + * @param foundNode + * The data structure to represents the filtering of the instance using the tree model. 
+ */
  private void continueAttemptToSplit(ActiveLearningNode activeLearningNode, FoundNode foundNode) {
    AttributeSplitSuggestion bestSuggestion = activeLearningNode.getDistributedBestSuggestion();
    AttributeSplitSuggestion secondBestSuggestion = activeLearningNode.getDistributedSecondBestSuggestion();

    // compare with null split
    // (the "null split" has no split test and is scored on the pre-split class distribution alone)
    double[] preSplitDist = activeLearningNode.getObservedClassDistribution();
    AttributeSplitSuggestion nullSplit = new AttributeSplitSuggestion(null, new double[0][],
        this.splitCriterion.getMeritOfSplit(preSplitDist, new double[][] { preSplitDist }));

    // Insert the null split into the (best, second best) ranking at the position its merit deserves.
    if ((bestSuggestion == null) || (nullSplit.compareTo(bestSuggestion) > 0)) {
      secondBestSuggestion = bestSuggestion;
      bestSuggestion = nullSplit;
    } else {
      if ((secondBestSuggestion == null) || (nullSplit.compareTo(secondBestSuggestion) > 0)) {
        secondBestSuggestion = nullSplit;
      }
    }

    boolean shouldSplit = false;

    if (secondBestSuggestion == null) {
      // only one candidate is left; bestSuggestion is never null at this point because of the ranking above
      shouldSplit = (bestSuggestion != null);
    } else {
      double hoeffdingBound = computeHoeffdingBound(
          this.splitCriterion.getRangeOfMerit(activeLearningNode.getObservedClassDistribution()), this.splitConfidence,
          activeLearningNode.getWeightSeen());

      // split when the best candidate wins by more than the bound, or when the bound is below the tie threshold
      if ((bestSuggestion.merit - secondBestSuggestion.merit > hoeffdingBound) || (hoeffdingBound < tieThreshold)) {
        shouldSplit = true;
      }
      // TODO: add poor attributes removal
    }

    SplitNode parent = foundNode.getParent();
    int parentBranch = foundNode.getParentBranch();

    // split if the Hoeffding bound condition is satisfied
    if (shouldSplit) {

      // a winning null split carries no split test, in which case the leaf is left as it is
      if (bestSuggestion.splitTest != null) {
        SplitNode newSplit = new SplitNode(bestSuggestion.splitTest, activeLearningNode.getObservedClassDistribution());

        for (int i = 0; i < bestSuggestion.numSplits(); i++) {
          Node newChild = newLearningNode(bestSuggestion.resultingClassDistributionFromSplit(i), this.parallelismHint);
          newSplit.setChild(i, newChild);
        }

        // the split leaf becomes a decision node and every branch gets a new active leaf
        this.activeLeafNodeCount--;
        this.decisionNodeCount++;
        this.activeLeafNodeCount += bestSuggestion.numSplits();

        if (parent == null) {
          this.treeRoot = newSplit;
        } else {
          parent.setChild(parentBranch, newSplit);
        }
      }
      // TODO: add check on the model's memory size
    }

    // housekeeping
    // (runs whether or not a split happened, so the next evaluation is measured from the current weight)
    activeLearningNode.endSplitting();
    activeLearningNode.setWeightSeenAtLastSplitEvaluation(activeLearningNode.getWeightSeen());
  }

  /**
   * Helper method to deactivate learning node: replaces it in the tree with an InactiveLearningNode that carries the
   * same observed class distribution, and updates the active/inactive leaf counters.
   *
   * @param toDeactivate
   *          Active Learning Node that will be deactivated
   * @param parent
   *          Parent of the soon-to-be-deactivated Active LearningNode; null when the node is the tree root
   * @param parentBranch
   *          the branch index of the node in the parent node
   */
  private void deactivateLearningNode(ActiveLearningNode toDeactivate, SplitNode parent, int parentBranch) {
    Node newLeaf = new InactiveLearningNode(toDeactivate.getObservedClassDistribution());
    if (parent == null) {
      this.treeRoot = newLeaf;
    } else {
      parent.setChild(parentBranch, newLeaf);
    }

    this.activeLeafNodeCount--;
    this.inactiveLeafNodeCount++;
  }

  /**
   * Creates a learning node that starts with an empty class distribution.
   *
   * @param parallelismHint
   *          value passed on to the node constructor
   * @return the new learning node
   */
  private LearningNode newLearningNode(int parallelismHint) {
    return newLearningNode(new double[0], parallelismHint);
  }

  /**
   * Creates a learning node that starts with the given class distribution.
   *
   * @param initialClassObservations
   *          class distribution the node starts with
   * @param parallelismHint
   *          value passed on to the node constructor
   * @return a new ActiveLearningNode
   */
  private LearningNode newLearningNode(double[] initialClassObservations, int parallelismHint) {
    // for VHT optimization, we need to dynamically instantiate the appropriate
    // ActiveLearningNode
    return new ActiveLearningNode(initialClassObservations, parallelismHint);
  }

  /** + * Helper method to set the model context, i.e. 
how many attributes they are and what is the class index + * + * @param ih + */ + private void setModelContext(InstancesHeader ih) { + // TODO possibly refactored + if ((ih != null) && (ih.classIndex() < 0)) { + throw new IllegalArgumentException("Context for a classifier must include a class to learn"); + } + // TODO: check flag for checking whether training has started or not + + // model context is used to describe the model + logger.trace("Model context: {}", ih.toString()); + } + + private static double computeHoeffdingBound(double range, double confidence, double n) { + return Math.sqrt((Math.pow(range, 2.0) * Math.log(1.0 / confidence)) / (2.0 * n)); + } + + /** + * AggregationTimeOutHandler is a class to support time-out feature while waiting for local computation results from + * the local statistic PIs. + * + * @author Arinto Murdopo + * + */ + static class AggregationTimeOutHandler implements Runnable { + + private static final Logger logger = LoggerFactory.getLogger(AggregationTimeOutHandler.class); + private final Long splitId; + private final BlockingQueue toBeSplittedNodes; + + AggregationTimeOutHandler(Long splitId, BlockingQueue toBeSplittedNodes) { + this.splitId = splitId; + this.toBeSplittedNodes = toBeSplittedNodes; + } + + @Override + public void run() { + logger.debug("Time out is reached. 
AggregationTimeOutHandler is started."); + try { + toBeSplittedNodes.put(splitId); + } catch (InterruptedException e) { + logger.warn("Interrupted while trying to put the ID into the queue"); + } + logger.debug("AggregationTimeOutHandler is finished."); + } + } + + /** + * SplittingNodeInfo is a class to represents the ActiveLearningNode that is splitting + * + * @author Arinto Murdopo + * + */ + static class SplittingNodeInfo implements Serializable { + + private final ActiveLearningNode activeLearningNode; + private final FoundNode foundNode; + private final transient ScheduledFuture scheduledFuture; + + SplittingNodeInfo(ActiveLearningNode activeLearningNode, FoundNode foundNode, ScheduledFuture scheduledFuture) { + this.activeLearningNode = activeLearningNode; + this.foundNode = foundNode; + this.scheduledFuture = scheduledFuture; + } + } + + protected ChangeDetector changeDetector; + + public ChangeDetector getChangeDetector() { + return this.changeDetector; + } + + public void setChangeDetector(ChangeDetector cd) { + this.changeDetector = cd; + } + + /** + * Builder class to replace constructors with many parameters + * + * @author Arinto Murdopo + * + */ + static class Builder { + + // required parameters + private final Instances dataset; + + // default values + private SplitCriterion splitCriterion = new InfoGainSplitCriterion(); + private double splitConfidence = 0.0000001; + private double tieThreshold = 0.05; + private int gracePeriod = 200; + private int parallelismHint = 1; + private long timeOut = 30; + private ChangeDetector changeDetector = null; + + Builder(Instances dataset) { + this.dataset = dataset; + } + + Builder(ModelAggregatorProcessor oldProcessor) { + this.dataset = oldProcessor.dataset; + this.splitCriterion = oldProcessor.splitCriterion; + this.splitConfidence = oldProcessor.splitConfidence; + this.tieThreshold = oldProcessor.tieThreshold; + this.gracePeriod = oldProcessor.gracePeriod; + this.parallelismHint = 
oldProcessor.parallelismHint; + this.timeOut = oldProcessor.timeOut; + } + + Builder splitCriterion(SplitCriterion splitCriterion) { + this.splitCriterion = splitCriterion; + return this; + } + + Builder splitConfidence(double splitConfidence) { + this.splitConfidence = splitConfidence; + return this; + } + + Builder tieThreshold(double tieThreshold) { + this.tieThreshold = tieThreshold; + return this; + } + + Builder gracePeriod(int gracePeriod) { + this.gracePeriod = gracePeriod; + return this; + } + + Builder parallelismHint(int parallelismHint) { + this.parallelismHint = parallelismHint; + return this; + } + + Builder timeOut(long timeOut) { + this.timeOut = timeOut; + return this; + } + + Builder changeDetector(ChangeDetector changeDetector) { + this.changeDetector = changeDetector; + return this; + } + + ModelAggregatorProcessor build() { + return new ModelAggregatorProcessor(this); + } + } + +} diff --git a/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/trees/Node.java b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/trees/Node.java new file mode 100644 index 00000000000..678c4a06e95 --- /dev/null +++ b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/trees/Node.java @@ -0,0 +1,103 @@ +package org.apache.heron.learners.classifiers.trees; + +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + + +import org.apache.samoa.core.DoubleVector; +import org.apache.samoa.instances.Instance; + +/** + * Abstract class that represents a node in the tree model. + * + * @author Arinto Murdopo + * + */ +abstract class Node implements java.io.Serializable { + + private static final long serialVersionUID = 4008521239214180548L; + + protected final DoubleVector observedClassDistribution; + + /** + * Method to route/filter an instance into its corresponding leaf. This method will be invoked recursively. + * + * @param inst + * Instance to be routed + * @param parent + * Parent of the current node + * @param parentBranch + * The index of the current node in the parent + * @return FoundNode which is the data structure to represent the resulting leaf. + */ + abstract FoundNode filterInstanceToLeaf(Instance inst, SplitNode parent, int parentBranch); + + /** + * Method to return the predicted class of the instance based on the statistic inside the node. + * + * @param inst + * To-be-predicted instance + * @param map + * ModelAggregatorProcessor + * @return The prediction result in the form of class distribution + */ + abstract double[] getClassVotes(Instance inst, ModelAggregatorProcessor map); + + /** + * Method to check whether the node is a leaf node or not. + * + * @return Boolean flag to indicate whether the node is a leaf or not + */ + abstract boolean isLeaf(); + + /** + * Constructor of the tree node + * + * @param classObservation + * distribution of the observed classes. 
+ */ + protected Node(double[] classObservation) { + this.observedClassDistribution = new DoubleVector(classObservation); + } + + /** + * Getter method for the class distribution + * + * @return Observed class distribution + */ + protected double[] getObservedClassDistribution() { + return this.observedClassDistribution.getArrayCopy(); + } + + /** + * A method to check whether the class distribution only consists of one class or not. + * + * @return Flag whether class distribution is pure or not. + */ + protected boolean observedClassDistributionIsPure() { + return (observedClassDistribution.numNonZeroEntries() < 2); + } + + protected void describeSubtree(ModelAggregatorProcessor modelAggrProc, StringBuilder out, int indent) { + // TODO: implement method to gracefully define the tree + } + + // TODO: calculate promise for limiting the model based on the memory size + // double calculatePromise(); +} diff --git a/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/trees/SplitNode.java b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/trees/SplitNode.java new file mode 100644 index 00000000000..1b9aff4dbaf --- /dev/null +++ b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/trees/SplitNode.java @@ -0,0 +1,117 @@ +package org.apache.heron.learners.classifiers.trees; + +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + + +import org.apache.samoa.instances.Instance; +import org.apache.samoa.moa.classifiers.core.conditionaltests.InstanceConditionalTest; +import org.apache.samoa.moa.core.AutoExpandVector; + +/** + * SplitNode represents the node that contains one or more questions in the decision tree model, in order to route the + * instances into the correct leaf. + * + * @author Arinto Murdopo + * + */ +public class SplitNode extends Node { + + private static final long serialVersionUID = -7380795529928485792L; + + private final AutoExpandVector children; + protected final InstanceConditionalTest splitTest; + + public SplitNode(InstanceConditionalTest splitTest, + double[] classObservation) { + super(classObservation); + this.children = new AutoExpandVector<>(); + this.splitTest = splitTest; + } + + @Override + FoundNode filterInstanceToLeaf(Instance inst, SplitNode parent, int parentBranch) { + int childIndex = instanceChildIndex(inst); + if (childIndex >= 0) { + Node child = getChild(childIndex); + if (child != null) { + return child.filterInstanceToLeaf(inst, this, childIndex); + } + return new FoundNode(null, this, childIndex); + } + return new FoundNode(this, parent, parentBranch); + } + + @Override + boolean isLeaf() { + return false; + } + + @Override + double[] getClassVotes(Instance inst, ModelAggregatorProcessor vht) { + return this.observedClassDistribution.getArrayCopy(); + } + + /** + * Method to return the number of children of this split node + * + * @return number of children + */ + int numChildren() { + return this.children.size(); + } + + 
/** + * Method to set the children in a specific index of the SplitNode with the appropriate child + * + * @param index + * Index of the child in the SplitNode + * @param child + * The child node + */ + void setChild(int index, Node child) { + if ((this.splitTest.maxBranches() >= 0) + && (index >= this.splitTest.maxBranches())) { + throw new IndexOutOfBoundsException(); + } + this.children.set(index, child); + } + + /** + * Method to get the child node given the index + * + * @param index + * The child node index + * @return The child node in the given index + */ + Node getChild(int index) { + return this.children.get(index); + } + + /** + * Method to route the instance using this split node + * + * @param inst + * The routed instance + * @return The index of the branch where the instance is routed + */ + int instanceChildIndex(Instance inst) { + return this.splitTest.branchForInstance(inst); + } +} diff --git a/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/trees/VerticalHoeffdingTree.java b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/trees/VerticalHoeffdingTree.java new file mode 100644 index 00000000000..8bf4fe01dd7 --- /dev/null +++ b/heron/mlmgr/src/main/java/org/apache/heron/learners/classifiers/trees/VerticalHoeffdingTree.java @@ -0,0 +1,184 @@ +package org.apache.heron.learners.classifiers.trees; + +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + + +import java.util.Set; + +import org.apache.samoa.core.Processor; +import org.apache.samoa.instances.Instances; +import org.apache.samoa.learners.AdaptiveLearner; +import org.apache.samoa.learners.ClassificationLearner; +import org.apache.samoa.moa.classifiers.core.attributeclassobservers.AttributeClassObserver; +import org.apache.samoa.moa.classifiers.core.attributeclassobservers.DiscreteAttributeClassObserver; +import org.apache.samoa.moa.classifiers.core.attributeclassobservers.NumericAttributeClassObserver; +import org.apache.samoa.moa.classifiers.core.driftdetection.ChangeDetector; +import org.apache.samoa.moa.classifiers.core.splitcriteria.SplitCriterion; +import org.apache.samoa.topology.Stream; +import org.apache.samoa.topology.TopologyBuilder; + +import com.github.javacliparser.ClassOption; +import com.github.javacliparser.Configurable; +import com.github.javacliparser.FlagOption; +import com.github.javacliparser.FloatOption; +import com.github.javacliparser.IntOption; +import com.google.common.collect.ImmutableSet; + +/** + * Vertical Hoeffding Tree. + *

+ * Vertical Hoeffding Tree (VHT) classifier is a distributed classifier that utilizes vertical parallelism on top of + * Very Fast Decision Tree (VFDT) classifier. + * + * @author Arinto Murdopo + */ +public final class VerticalHoeffdingTree implements ClassificationLearner, AdaptiveLearner, Configurable { + + private static final long serialVersionUID = -4937416312929984057L; + + public ClassOption numericEstimatorOption = new ClassOption("numericEstimator", + 'n', "Numeric estimator to use.", NumericAttributeClassObserver.class, + "GaussianNumericAttributeClassObserver"); + + public ClassOption nominalEstimatorOption = new ClassOption("nominalEstimator", + 'd', "Nominal estimator to use.", DiscreteAttributeClassObserver.class, + "NominalAttributeClassObserver"); + + public ClassOption splitCriterionOption = new ClassOption("splitCriterion", + 's', "Split criterion to use.", SplitCriterion.class, + "InfoGainSplitCriterion"); + + public FloatOption splitConfidenceOption = new FloatOption( + "splitConfidence", + 'c', + "The allowable error in split decision, values closer to 0 will take longer to decide.", + 0.0000001, 0.0, 1.0); + + public FloatOption tieThresholdOption = new FloatOption("tieThreshold", + 't', "Threshold below which a split will be forced to break ties.", + 0.05, 0.0, 1.0); + + public IntOption gracePeriodOption = new IntOption( + "gracePeriod", + 'g', + "The number of instances a leaf should observe between split attempts.", + 200, 0, Integer.MAX_VALUE); + + public IntOption parallelismHintOption = new IntOption( + "parallelismHint", + 'p', + "The number of local statistics PI to do distributed computation", + 1, 1, Integer.MAX_VALUE); + + public IntOption timeOutOption = new IntOption( + "timeOut", + 'o', + "The duration to wait all distributed computation results from local statistics PI", + 30, 1, Integer.MAX_VALUE); + + public FlagOption binarySplitsOption = new FlagOption("binarySplits", 'b', + "Only allow binary splits."); + + private 
Stream resultStream; + + private FilterProcessor filterProc; + + @Override + public void init(TopologyBuilder topologyBuilder, Instances dataset, int parallelism) { + + this.filterProc = new FilterProcessor.Builder(dataset) + .build(); + topologyBuilder.addProcessor(filterProc, parallelism); + + Stream filterStream = topologyBuilder.createStream(filterProc); + this.filterProc.setOutputStream(filterStream); + + ModelAggregatorProcessor modelAggrProc = new ModelAggregatorProcessor.Builder(dataset) + .splitCriterion((SplitCriterion) this.splitCriterionOption.getValue()) + .splitConfidence(splitConfidenceOption.getValue()) + .tieThreshold(tieThresholdOption.getValue()) + .gracePeriod(gracePeriodOption.getValue()) + .parallelismHint(parallelismHintOption.getValue()) + .timeOut(timeOutOption.getValue()) + .changeDetector(this.getChangeDetector()) + .build(); + + topologyBuilder.addProcessor(modelAggrProc, parallelism); + + topologyBuilder.connectInputShuffleStream(filterStream, modelAggrProc); + + this.resultStream = topologyBuilder.createStream(modelAggrProc); + modelAggrProc.setResultStream(resultStream); + + Stream attributeStream = topologyBuilder.createStream(modelAggrProc); + modelAggrProc.setAttributeStream(attributeStream); + + Stream controlStream = topologyBuilder.createStream(modelAggrProc); + modelAggrProc.setControlStream(controlStream); + + LocalStatisticsProcessor locStatProc = new LocalStatisticsProcessor.Builder() + .splitCriterion((SplitCriterion) this.splitCriterionOption.getValue()) + .binarySplit(binarySplitsOption.isSet()) + .nominalClassObserver((AttributeClassObserver) this.nominalEstimatorOption.getValue()) + .numericClassObserver((AttributeClassObserver) this.numericEstimatorOption.getValue()) + .build(); + + topologyBuilder.addProcessor(locStatProc, parallelismHintOption.getValue()); + topologyBuilder.connectInputKeyStream(attributeStream, locStatProc); + topologyBuilder.connectInputAllStream(controlStream, locStatProc); + + Stream 
computeStream = topologyBuilder.createStream(locStatProc); + + locStatProc.setComputationResultStream(computeStream); + topologyBuilder.connectInputAllStream(computeStream, modelAggrProc); + } + + @Override + public Processor getInputProcessor() { + return this.filterProc; + } + + @Override + public Set getResultStreams() { + return ImmutableSet.of(this.resultStream); + } + + protected ChangeDetector changeDetector; + + @Override + public ChangeDetector getChangeDetector() { + return this.changeDetector; + } + + @Override + public void setChangeDetector(ChangeDetector cd) { + this.changeDetector = cd; + } + + static class LearningNodeIdGenerator { + + // TODO: add code to warn user of when value reaches Long.MAX_VALUES + private static long id = 0; + + static synchronized long generate() { + return id++; + } + } +} diff --git a/heron/mlmgr/src/main/java/org/apache/heron/learners/clusterers/ClusteringContentEvent.java b/heron/mlmgr/src/main/java/org/apache/heron/learners/clusterers/ClusteringContentEvent.java new file mode 100644 index 00000000000..85ce3080aeb --- /dev/null +++ b/heron/mlmgr/src/main/java/org/apache/heron/learners/clusterers/ClusteringContentEvent.java @@ -0,0 +1,89 @@ +package org.apache.heron.learners.clusterers; + +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +import org.apache.samoa.core.ContentEvent; +import org.apache.samoa.instances.Instance; + +import net.jcip.annotations.Immutable; + +/** + * The Class ClusteringContentEvent. + */ +@Immutable +final public class ClusteringContentEvent implements ContentEvent { + + private static final long serialVersionUID = -7746983521296618922L; + private Instance instance; + private boolean isLast = false; + private String key; + private boolean isSample; + + public ClusteringContentEvent() { + // Necessary for kryo serializer + } + + /** + * Instantiates a new clustering event. + * + * @param index + * the index + * @param instance + * the instance + */ + public ClusteringContentEvent(long index, Instance instance) { + /* + * if (instance != null) { this.instance = new + * SerializableInstance(instance); } + */ + this.instance = instance; + this.setKey(Long.toString(index)); + } + + @Override + public String getKey() { + return this.key; + } + + @Override + public void setKey(String str) { + this.key = str; + } + + @Override + public boolean isLastEvent() { + return this.isLast; + } + + public void setLast(boolean isLast) { + this.isLast = isLast; + } + + public Instance getInstance() { + return this.instance; + } + + public boolean isSample() { + return isSample; + } + + public void setSample(boolean b) { + this.isSample = b; + } +} diff --git a/heron/mlmgr/src/main/java/org/apache/heron/learners/clusterers/ClustreamClustererAdapter.java b/heron/mlmgr/src/main/java/org/apache/heron/learners/clusterers/ClustreamClustererAdapter.java new file mode 100644 index 00000000000..f0862123cd6 --- /dev/null +++ b/heron/mlmgr/src/main/java/org/apache/heron/learners/clusterers/ClustreamClustererAdapter.java @@ -0,0 +1,170 @@ +package org.apache.heron.learners.clusterers; + +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license 
agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/** + * License + */ +import org.apache.samoa.instances.Instance; +import org.apache.samoa.instances.Instances; +import org.apache.samoa.instances.InstancesHeader; +import org.apache.samoa.moa.cluster.Clustering; +import org.apache.samoa.moa.clusterers.clustream.Clustream; + +import com.github.javacliparser.ClassOption; +import com.github.javacliparser.Configurable; + +/** + * + * Base class for adapting Clustream clusterer. + * + */ +public class ClustreamClustererAdapter implements LocalClustererAdapter, Configurable { + + /** + * + */ + private static final long serialVersionUID = 4372366401338704353L; + + public ClassOption learnerOption = new ClassOption("learner", 'l', + "Clusterer to train.", org.apache.samoa.moa.clusterers.Clusterer.class, Clustream.class.getName()); + /** + * The learner. + */ + protected org.apache.samoa.moa.clusterers.Clusterer learner; + + /** + * The is init. + */ + protected Boolean isInit; + + /** + * The dataset. + */ + protected Instances dataset; + + @Override + public void setDataset(Instances dataset) { + this.dataset = dataset; + } + + /** + * Instantiates a new learner. 
+ * + * @param learner + * the learner + * @param dataset + * the dataset + */ + public ClustreamClustererAdapter(org.apache.samoa.moa.clusterers.Clusterer learner, Instances dataset) { + this.learner = learner.copy(); + this.isInit = false; + this.dataset = dataset; + } + + /** + * Instantiates a new learner. + * + * @param learner + * the learner + * @param dataset + * the dataset + */ + public ClustreamClustererAdapter() { + this.learner = ((org.apache.samoa.moa.clusterers.Clusterer) this.learnerOption.getValue()).copy(); + this.isInit = false; + // this.dataset = dataset; + } + + /** + * Creates a new learner object. + * + * @return the learner + */ + @Override + public ClustreamClustererAdapter create() { + ClustreamClustererAdapter l = new ClustreamClustererAdapter(learner, dataset); + if (dataset == null) { + System.out.println("dataset null while creating"); + } + return l; + } + + /** + * Trains this classifier incrementally using the given instance. + * + * @param inst + * the instance to be used for training + */ + @Override + public void trainOnInstance(Instance inst) { + if (this.isInit == false) { + this.isInit = true; + InstancesHeader instances = new InstancesHeader(dataset); + this.learner.setModelContext(instances); + this.learner.prepareForUse(); + } + if (inst.weight() > 0) { + inst.setDataset(dataset); + learner.trainOnInstance(inst); + } + } + + /** + * Predicts the class memberships for a given instance. If an instance is unclassified, the returned array elements + * must be all zero. + * + * @param inst + * the instance to be classified + * @return an array containing the estimated membership probabilities of the test instance in each class + */ + @Override + public double[] getVotesForInstance(Instance inst) { + double[] ret; + inst.setDataset(dataset); + if (this.isInit == false) { + ret = new double[dataset.numClasses()]; + } else { + ret = learner.getVotesForInstance(inst); + } + return ret; + } + + /** + * Resets this classifier. 
It must be similar to starting a new classifier from scratch. + * + */ + @Override + public void resetLearning() { + learner.resetLearning(); + } + + public boolean implementsMicroClusterer() { + return this.learner.implementsMicroClusterer(); + } + + public Clustering getMicroClusteringResult() { + return this.learner.getMicroClusteringResult(); + } + + public Instances getDataset() { + return this.dataset; + } + +} diff --git a/heron/mlmgr/src/main/java/org/apache/heron/learners/clusterers/LocalClustererAdapter.java b/heron/mlmgr/src/main/java/org/apache/heron/learners/clusterers/LocalClustererAdapter.java new file mode 100644 index 00000000000..18d970376e0 --- /dev/null +++ b/heron/mlmgr/src/main/java/org/apache/heron/learners/clusterers/LocalClustererAdapter.java @@ -0,0 +1,79 @@ +package org.apache.heron.learners.clusterers; + +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +import java.io.Serializable; + +import org.apache.samoa.instances.Instance; +import org.apache.samoa.instances.Instances; +import org.apache.samoa.moa.cluster.Clustering; + +/** + * Learner interface for non-distributed learners. 
+ * + * @author abifet + */ +public interface LocalClustererAdapter extends Serializable { + + /** + * Creates a new learner object. + * + * @return the learner + */ + LocalClustererAdapter create(); + + /** + * Predicts the class memberships for a given instance. If an instance is unclassified, the returned array elements + * must be all zero. + * + * @param inst + * the instance to be classified + * @return an array containing the estimated membership probabilities of the test instance in each class + */ + double[] getVotesForInstance(Instance inst); + + /** + * Resets this classifier. It must be similar to starting a new classifier from scratch. + * + */ + void resetLearning(); + + /** + * Trains this classifier incrementally using the given instance. + * + * @param inst + * the instance to be used for training + */ + void trainOnInstance(Instance inst); + + /** + * Sets where to obtain the information of attributes of Instances + * + * @param dataset + * the dataset that contains the information + */ + public void setDataset(Instances dataset); + + public Instances getDataset(); + + public boolean implementsMicroClusterer(); + + public Clustering getMicroClusteringResult(); + +} diff --git a/heron/mlmgr/src/main/java/org/apache/heron/learners/clusterers/LocalClustererProcessor.java b/heron/mlmgr/src/main/java/org/apache/heron/learners/clusterers/LocalClustererProcessor.java new file mode 100644 index 00000000000..9fbe4eed077 --- /dev/null +++ b/heron/mlmgr/src/main/java/org/apache/heron/learners/clusterers/LocalClustererProcessor.java @@ -0,0 +1,198 @@ +package org.apache.heron.learners.clusterers; + +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + * License + */ +import org.apache.samoa.core.ContentEvent; +import org.apache.samoa.core.Processor; +import org.apache.samoa.evaluation.ClusteringEvaluationContentEvent; +import org.apache.samoa.evaluation.ClusteringResultContentEvent; +import org.apache.samoa.instances.DenseInstance; +import org.apache.samoa.instances.Instance; +import org.apache.samoa.moa.cluster.Clustering; +import org.apache.samoa.moa.core.DataPoint; +import org.apache.samoa.topology.Stream; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +//import weka.core.Instance; + +/** + * The Class LearnerProcessor. + */ +final public class LocalClustererProcessor implements Processor { + + /** + * + */ + private static final long serialVersionUID = -1577910988699148691L; + private static final Logger logger = LoggerFactory + .getLogger(LocalClustererProcessor.class); + private LocalClustererAdapter model; + private Stream outputStream; + private int modelId; + private long instancesCount = 0; + private long sampleFrequency = 1000; + + public long getSampleFrequency() { + return sampleFrequency; + } + + public void setSampleFrequency(long sampleFrequency) { + this.sampleFrequency = sampleFrequency; + } + + /** + * Sets the learner. + * + * @param model + * the model to set + */ + public void setLearner(LocalClustererAdapter model) { + this.model = model; + } + + /** + * Gets the learner. 
+ * + * @return the model + */ + public LocalClustererAdapter getLearner() { + return model; + } + + /** + * Set the output streams. + * + * @param outputStream + * the new output stream {@link PredictionCombinerPE}. + */ + public void setOutputStream(Stream outputStream) { + + this.outputStream = outputStream; + } + + /** + * Gets the output stream. + * + * @return the output stream + */ + public Stream getOutputStream() { + return outputStream; + } + + /** + * Gets the instances count. + * + * @return number of observation vectors used in training iteration. + */ + public long getInstancesCount() { + return instancesCount; + } + + /** + * Update stats. + * + * @param event + * the event + */ + private void updateStats(ContentEvent event) { + Instance instance; + if (event instanceof ClusteringContentEvent) { + // Local Clustering + ClusteringContentEvent ev = (ClusteringContentEvent) event; + instance = ev.getInstance(); + DataPoint point = new DataPoint(instance, Integer.parseInt(event.getKey())); + model.trainOnInstance(point); + instancesCount++; + } + + if (event instanceof ClusteringResultContentEvent) { + // Global Clustering + ClusteringResultContentEvent ev = (ClusteringResultContentEvent) event; + Clustering clustering = ev.getClustering(); + + for (int i = 0; i < clustering.size(); i++) { + instance = new DenseInstance(1.0, clustering.get(i).getCenter()); + instance.setDataset(model.getDataset()); + DataPoint point = new DataPoint(instance, Integer.parseInt(event.getKey())); + model.trainOnInstance(point); + instancesCount++; + } + } + + if (instancesCount % this.sampleFrequency == 0) { + logger.info("Trained model using {} events with classifier id {}", instancesCount, this.modelId); // getId()); + } + } + + /** + * On event. 
+ * + * @param event + * the event + * @return true, if successful + */ + @Override + public boolean process(ContentEvent event) { + + if (event.isLastEvent() || + (instancesCount > 0 && instancesCount % this.sampleFrequency == 0)) { + if (model.implementsMicroClusterer()) { + + Clustering clustering = model.getMicroClusteringResult(); + + ClusteringResultContentEvent resultEvent = new ClusteringResultContentEvent(clustering, event.isLastEvent()); + + this.outputStream.put(resultEvent); + } + } + + updateStats(event); + return false; + } + + /* + * (non-Javadoc) + * + * @see samoa.core.Processor#onCreate(int) + */ + @Override + public void onCreate(int id) { + this.modelId = id; + model = model.create(); + } + + /* + * (non-Javadoc) + * + * @see samoa.core.Processor#newProcessor(samoa.core.Processor) + */ + @Override + public Processor newProcessor(Processor sourceProcessor) { + LocalClustererProcessor newProcessor = new LocalClustererProcessor(); + LocalClustererProcessor originProcessor = (LocalClustererProcessor) sourceProcessor; + if (originProcessor.getLearner() != null) { + newProcessor.setLearner(originProcessor.getLearner().create()); + } + newProcessor.setOutputStream(originProcessor.getOutputStream()); + return newProcessor; + } +} diff --git a/heron/mlmgr/src/main/java/org/apache/heron/learners/clusterers/SingleLearner.java b/heron/mlmgr/src/main/java/org/apache/heron/learners/clusterers/SingleLearner.java new file mode 100644 index 00000000000..3ed37c11e8d --- /dev/null +++ b/heron/mlmgr/src/main/java/org/apache/heron/learners/clusterers/SingleLearner.java @@ -0,0 +1,97 @@ +package org.apache.heron.learners.clusterers; + +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import com.google.common.collect.ImmutableSet; + +import java.util.Set; + +import org.apache.samoa.core.Processor; +import org.apache.samoa.instances.Instances; +import org.apache.samoa.learners.Learner; +import org.apache.samoa.topology.Stream; +import org.apache.samoa.topology.TopologyBuilder; + +import com.github.javacliparser.ClassOption; +import com.github.javacliparser.Configurable; + +/** + * + * Learner that contain a single learner. 
+ * + */ +public final class SingleLearner implements Learner, Configurable { + + private static final long serialVersionUID = 684111382631697031L; + + private LocalClustererProcessor learnerP; + + private Stream resultStream; + + private Instances dataset; + + public ClassOption learnerOption = new ClassOption("learner", 'l', + "Learner to train.", LocalClustererAdapter.class, ClustreamClustererAdapter.class.getName()); + + private TopologyBuilder builder; + + private int parallelism; + + @Override + public void init(TopologyBuilder builder, Instances dataset, int parallelism) { + this.builder = builder; + this.dataset = dataset; + this.parallelism = parallelism; + this.setLayout(); + } + + protected void setLayout() { + learnerP = new LocalClustererProcessor(); + LocalClustererAdapter learner = (LocalClustererAdapter) this.learnerOption.getValue(); + learner.setDataset(this.dataset); + learnerP.setLearner(learner); + + this.builder.addProcessor(learnerP, this.parallelism); + resultStream = this.builder.createStream(learnerP); + + learnerP.setOutputStream(resultStream); + } + + /* + * (non-Javadoc) + * + * @see samoa.classifiers.Classifier#getInputProcessingItem() + */ + @Override + public Processor getInputProcessor() { + return learnerP; + } + + /* + * (non-Javadoc) + * + * @see samoa.learners.Learner#getResultStreams() + */ + @Override + public Set getResultStreams() { + Set streams = ImmutableSet.of(this.resultStream); + return streams; + } +} diff --git a/heron/mlmgr/src/main/java/org/apache/heron/learners/clusterers/simple/ClusteringDistributorProcessor.java b/heron/mlmgr/src/main/java/org/apache/heron/learners/clusterers/simple/ClusteringDistributorProcessor.java new file mode 100644 index 00000000000..796a33b4417 --- /dev/null +++ b/heron/mlmgr/src/main/java/org/apache/heron/learners/clusterers/simple/ClusteringDistributorProcessor.java @@ -0,0 +1,100 @@ +package org.apache.heron.learners.clusterers.simple; + +/** + * Licensed to the Apache Software 
Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/** + * License + */ +import org.apache.samoa.core.ContentEvent; +import org.apache.samoa.core.Processor; +import org.apache.samoa.evaluation.ClusteringEvaluationContentEvent; +import org.apache.samoa.learners.clusterers.ClusteringContentEvent; +import org.apache.samoa.moa.core.DataPoint; +import org.apache.samoa.topology.Stream; + +/** + * The Class ClusteringDistributorPE. + */ +public class ClusteringDistributorProcessor implements Processor { + + private static final long serialVersionUID = -1550901409625192730L; + + private Stream outputStream; + private Stream evaluationStream; + private int numInstances; + + public Stream getOutputStream() { + return outputStream; + } + + public void setOutputStream(Stream outputStream) { + this.outputStream = outputStream; + } + + public Stream getEvaluationStream() { + return evaluationStream; + } + + public void setEvaluationStream(Stream evaluationStream) { + this.evaluationStream = evaluationStream; + } + + /** + * Process event. 
+ * + * @param event + * the event + * @return true, if successful + */ + public boolean process(ContentEvent event) { + // distinguish between ClusteringContentEvent and + // ClusteringEvaluationContentEvent + if (event instanceof ClusteringContentEvent) { + ClusteringContentEvent cce = (ClusteringContentEvent) event; + outputStream.put(event); + if (cce.isSample()) { + evaluationStream.put(new ClusteringEvaluationContentEvent(null, + new DataPoint(cce.getInstance(), numInstances++), cce.isLastEvent())); + } + } else if (event instanceof ClusteringEvaluationContentEvent) { + evaluationStream.put(event); + } + return true; + } + + /* + * (non-Javadoc) + * + * @see samoa.core.Processor#newProcessor(samoa.core.Processor) + */ + @Override + public Processor newProcessor(Processor sourceProcessor) { + ClusteringDistributorProcessor newProcessor = new ClusteringDistributorProcessor(); + ClusteringDistributorProcessor originProcessor = (ClusteringDistributorProcessor) sourceProcessor; + if (originProcessor.getOutputStream() != null) + newProcessor.setOutputStream(originProcessor.getOutputStream()); + if (originProcessor.getEvaluationStream() != null) + newProcessor.setEvaluationStream(originProcessor.getEvaluationStream()); + return newProcessor; + } + + public void onCreate(int id) { + } +} diff --git a/heron/mlmgr/src/main/java/org/apache/heron/learners/clusterers/simple/DistributedClusterer.java b/heron/mlmgr/src/main/java/org/apache/heron/learners/clusterers/simple/DistributedClusterer.java new file mode 100644 index 00000000000..a564b56c2b5 --- /dev/null +++ b/heron/mlmgr/src/main/java/org/apache/heron/learners/clusterers/simple/DistributedClusterer.java @@ -0,0 +1,121 @@ +package org.apache.heron.learners.clusterers.simple; + + +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/** + * License + */ + +import com.google.common.collect.ImmutableSet; + +import java.util.Set; + +import org.apache.samoa.core.Processor; +import org.apache.samoa.instances.Instances; +import org.apache.samoa.learners.Learner; +import org.apache.samoa.learners.clusterers.*; +import org.apache.samoa.topology.ProcessingItem; +import org.apache.samoa.topology.Stream; +import org.apache.samoa.topology.TopologyBuilder; + +import com.github.javacliparser.ClassOption; +import com.github.javacliparser.Configurable; +import com.github.javacliparser.IntOption; + +/** + * + * Learner that contain a single learner. 
+ * + */ +public final class DistributedClusterer implements Learner, Configurable { + + private static final long serialVersionUID = 684111382631697031L; + + private Stream resultStream; + + private Instances dataset; + + public ClassOption learnerOption = new ClassOption("learner", 'l', "Clusterer to use.", LocalClustererAdapter.class, + ClustreamClustererAdapter.class.getName()); + + public IntOption paralellismOption = new IntOption("paralellismOption", 'P', + "The paralellism level for concurrent processes", 2, 1, Integer.MAX_VALUE); + + private TopologyBuilder builder; + + // private ClusteringDistributorProcessor distributorP; + private LocalClustererProcessor learnerP; + + // private Stream distributorToLocalStream; + private Stream localToGlobalStream; + + // private int parallelism; + + @Override + public void init(TopologyBuilder builder, Instances dataset, int parallelism) { + this.builder = builder; + this.dataset = dataset; + // this.parallelism = parallelism; + this.setLayout(); + } + + protected void setLayout() { + // Distributor + // distributorP = new ClusteringDistributorProcessor(); + // this.builder.addProcessor(distributorP, parallelism); + // distributorToLocalStream = this.builder.createStream(distributorP); + // distributorP.setOutputStream(distributorToLocalStream); + // distributorToGlobalStream = this.builder.createStream(distributorP); + + // Local Clustering + learnerP = new LocalClustererProcessor(); + LocalClustererAdapter learner = (LocalClustererAdapter) this.learnerOption.getValue(); + learner.setDataset(this.dataset); + learnerP.setLearner(learner); + builder.addProcessor(learnerP, this.paralellismOption.getValue()); + localToGlobalStream = this.builder.createStream(learnerP); + learnerP.setOutputStream(localToGlobalStream); + + // Global Clustering + LocalClustererProcessor globalClusteringCombinerP = new LocalClustererProcessor(); + LocalClustererAdapter globalLearner = (LocalClustererAdapter) this.learnerOption.getValue(); + 
globalLearner.setDataset(this.dataset); + globalClusteringCombinerP.setLearner(learner); + builder.addProcessor(globalClusteringCombinerP, 1); + builder.connectInputAllStream(localToGlobalStream, globalClusteringCombinerP); + + // Output Stream + resultStream = this.builder.createStream(globalClusteringCombinerP); + globalClusteringCombinerP.setOutputStream(resultStream); + } + + @Override + public Processor getInputProcessor() { + // return distributorP; + return learnerP; + } + + @Override + public Set getResultStreams() { + Set streams = ImmutableSet.of(this.resultStream); + return streams; + } +} diff --git a/heron/mlmgr/src/main/java/org/apache/heron/topology/impl/StormBoltStream.java b/heron/mlmgr/src/main/java/org/apache/heron/topology/impl/StormBoltStream.java new file mode 100644 index 00000000000..8f7a955899b --- /dev/null +++ b/heron/mlmgr/src/main/java/org/apache/heron/topology/impl/StormBoltStream.java @@ -0,0 +1,65 @@ +package org.apache.heron.topology.impl; + +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +import org.apache.samoa.core.ContentEvent; + +import backtype.storm.task.OutputCollector; +import backtype.storm.tuple.Values; + +/** + * Storm Stream that connects into Bolt. 
It wraps Storm's outputCollector class + * + * @author Arinto Murdopo + * + */ +class StormBoltStream extends StormStream { + + /** + * + */ + private static final long serialVersionUID = -5712513402991550847L; + + private OutputCollector outputCollector; + + StormBoltStream(String stormComponentId) { + super(stormComponentId); + } + + @Override + public void put(ContentEvent contentEvent) { + outputCollector.emit(this.outputStreamId, new Values(contentEvent, contentEvent.getKey())); + } + + public void setCollector(OutputCollector outputCollector) { + this.outputCollector = outputCollector; + } + + // @Override + // public void setStreamId(String streamId) { + // // TODO Auto-generated method stub + // //this.outputStreamId = streamId; + // } + + @Override + public String getStreamId() { + // TODO Auto-generated method stub + return null; + } +} diff --git a/heron/mlmgr/src/main/java/org/apache/heron/topology/impl/StormComponentFactory.java b/heron/mlmgr/src/main/java/org/apache/heron/topology/impl/StormComponentFactory.java new file mode 100644 index 00000000000..b5441039b32 --- /dev/null +++ b/heron/mlmgr/src/main/java/org/apache/heron/topology/impl/StormComponentFactory.java @@ -0,0 +1,89 @@ +package org.apache.heron.topology.impl; + +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import java.util.HashMap; +import java.util.Map; + +import org.apache.samoa.core.EntranceProcessor; +import org.apache.samoa.core.Processor; +import org.apache.samoa.topology.ComponentFactory; +import org.apache.samoa.topology.EntranceProcessingItem; +import org.apache.samoa.topology.IProcessingItem; +import org.apache.samoa.topology.ProcessingItem; +import org.apache.samoa.topology.Stream; +import org.apache.samoa.topology.Topology; + +/** + * Component factory implementation for samoa-storm + */ +public final class StormComponentFactory implements ComponentFactory { + + private final Map processorList; + + public StormComponentFactory() { + processorList = new HashMap<>(); + } + + @Override + public ProcessingItem createPi(Processor processor) { + return new StormProcessingItem(processor, this.getComponentName(processor.getClass()), 1); + } + + @Override + public EntranceProcessingItem createEntrancePi(EntranceProcessor processor) { + return new StormEntranceProcessingItem(processor, this.getComponentName(processor.getClass())); + } + + @Override + public Stream createStream(IProcessingItem sourcePi) { + StormTopologyNode stormCompatiblePi = (StormTopologyNode) sourcePi; + return stormCompatiblePi.createStream(); + } + + @Override + public Topology createTopology(String topoName) { + return new StormTopology(topoName); + } + + private String getComponentName(Class clazz) { + StringBuilder componentName = new StringBuilder(clazz.getCanonicalName()); + String key = componentName.toString(); + Integer index; + + if (!processorList.containsKey(key)) { + index = 1; + } else { + index = processorList.get(key) + 1; + } + + processorList.put(key, index); + + componentName.append('_'); + componentName.append(index); + + return componentName.toString(); + } + + @Override + public ProcessingItem createPi(Processor processor, int parallelism) { + return new 
StormProcessingItem(processor, this.getComponentName(processor.getClass()), parallelism); + } +} diff --git a/heron/mlmgr/src/main/java/org/apache/heron/topology/impl/StormDoTask.java b/heron/mlmgr/src/main/java/org/apache/heron/topology/impl/StormDoTask.java new file mode 100644 index 00000000000..ab7b9eeda86 --- /dev/null +++ b/heron/mlmgr/src/main/java/org/apache/heron/topology/impl/StormDoTask.java @@ -0,0 +1,117 @@ +package org.apache.heron.topology.impl; + +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import backtype.storm.Config; +import backtype.storm.utils.Utils; + +/** + * The main class that used by samoa script to execute SAMOA task. + * + * @author Arinto Murdopo + * + */ +public class StormDoTask { + private static final Logger logger = LoggerFactory.getLogger(StormDoTask.class); + private static String localFlag = "local"; + private static String clusterFlag = "cluster"; + + /** + * The main method. 
+ * + * @param args + * the arguments + */ + public static void main(String[] args) { + + List tmpArgs = new ArrayList(Arrays.asList(args)); + + boolean isLocal = isLocal(tmpArgs); + int numWorker = StormSamoaUtils.numWorkers(tmpArgs); + + args = tmpArgs.toArray(new String[0]); + + // convert the arguments into Storm topology + StormTopology stormTopo = StormSamoaUtils.argsToTopology(args); + String topologyName = stormTopo.getTopologyName(); + + Config conf = new Config(); + conf.putAll(Utils.readStormConfig()); + conf.setDebug(false); + + if (isLocal) { + // local mode + conf.setMaxTaskParallelism(numWorker); + + backtype.storm.LocalCluster cluster = new backtype.storm.LocalCluster(); + cluster.submitTopology(topologyName, conf, stormTopo.getStormBuilder().createTopology()); + + backtype.storm.utils.Utils.sleep(600 * 1000); + + cluster.killTopology(topologyName); + cluster.shutdown(); + + } else { + // cluster mode + conf.setNumWorkers(numWorker); + try { + backtype.storm.StormSubmitter.submitTopology(topologyName, conf, + stormTopo.getStormBuilder().createTopology()); + } catch (backtype.storm.generated.AlreadyAliveException ale) { + ale.printStackTrace(); + } catch (backtype.storm.generated.InvalidTopologyException ite) { + ite.printStackTrace(); + } + } + } + + private static boolean isLocal(List tmpArgs) { + ExecutionMode executionMode = ExecutionMode.UNDETERMINED; + + int position = tmpArgs.size() - 1; + String flag = tmpArgs.get(position); + boolean isLocal = true; + + if (flag.equals(clusterFlag)) { + executionMode = ExecutionMode.CLUSTER; + isLocal = false; + } else if (flag.equals(localFlag)) { + executionMode = ExecutionMode.LOCAL; + isLocal = true; + } + + if (executionMode != ExecutionMode.UNDETERMINED) { + tmpArgs.remove(position); + } + + return isLocal; + } + + private enum ExecutionMode { + LOCAL, CLUSTER, UNDETERMINED + }; +} diff --git a/heron/mlmgr/src/main/java/org/apache/heron/topology/impl/StormEntranceProcessingItem.java 
b/heron/mlmgr/src/main/java/org/apache/heron/topology/impl/StormEntranceProcessingItem.java new file mode 100644 index 00000000000..79e1100ada9 --- /dev/null +++ b/heron/mlmgr/src/main/java/org/apache/heron/topology/impl/StormEntranceProcessingItem.java @@ -0,0 +1,210 @@ +package org.apache.heron.topology.impl; + +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +import java.util.Map; +import java.util.UUID; + +import org.apache.samoa.core.ContentEvent; +import org.apache.samoa.core.EntranceProcessor; +import org.apache.samoa.topology.AbstractEntranceProcessingItem; +import org.apache.samoa.topology.EntranceProcessingItem; +import org.apache.samoa.topology.Stream; + +import backtype.storm.spout.SpoutOutputCollector; +import backtype.storm.task.TopologyContext; +import backtype.storm.topology.OutputFieldsDeclarer; +import backtype.storm.topology.base.BaseRichSpout; +import backtype.storm.tuple.Fields; +import backtype.storm.tuple.Values; +import backtype.storm.utils.Utils; + +/** + * EntranceProcessingItem implementation for Storm. 
+ */
+class StormEntranceProcessingItem extends AbstractEntranceProcessingItem implements StormTopologyNode {
+
+  /** Spout registered with the Storm topology on behalf of this item. */
+  private final StormEntranceSpout piSpout;
+
+  /**
+   * Creates an entrance processing item named after a random UUID.
+   *
+   * @param processor
+   *          the entrance processor driven by the spout
+   */
+  StormEntranceProcessingItem(EntranceProcessor processor) {
+    this(processor, UUID.randomUUID().toString());
+  }
+
+  /**
+   * Creates an entrance processing item with an explicit name.
+   *
+   * @param processor
+   *          the entrance processor driven by the spout
+   * @param friendlyId
+   *          name of this item; also used as the Storm component id
+   */
+  StormEntranceProcessingItem(EntranceProcessor processor, String friendlyId) {
+    super(processor);
+    this.setName(friendlyId);
+    this.piSpout = new StormEntranceSpout(processor);
+  }
+
+  /** Sets the stream the spout emits on; the stream must be a {@link StormStream}. */
+  @Override
+  public EntranceProcessingItem setOutputStream(Stream stream) {
+    StormStream stormStream = (StormStream) stream;
+    piSpout.setOutputStream(stormStream);
+    return this;
+  }
+
+  @Override
+  public Stream getOutputStream() {
+    return piSpout.getOutputStream();
+  }
+
+  /** Registers the spout with the topology's builder under this item's name. */
+  @Override
+  public void addToTopology(StormTopology topology, int parallelismHint) {
+    topology.getStormBuilder().setSpout(this.getName(), piSpout, parallelismHint);
+  }
+
+  @Override
+  public StormStream createStream() {
+    return piSpout.createStream(this.getName());
+  }
+
+  @Override
+  public String getId() {
+    return this.getName();
+  }
+
+  /** Returns the parent description prefixed with {@code "id: <name>, "}. */
+  @Override
+  public String toString() {
+    return new StringBuilder(super.toString())
+        .insert(0, String.format("id: %s, ", this.getName()))
+        .toString();
+  }
+
+  /**
+   * Resulting Spout of StormEntranceProcessingItem
+   */
+  static final class StormEntranceSpout extends BaseRichSpout {
+
+    private static final long serialVersionUID = -9066409791668954099L;
+
+    private final EntranceProcessor entranceProcessor;
+    private StormStream outputStream;
+
+    // Assigned in open(), i.e. when the spout is started by Storm.
+    private SpoutOutputCollector collector;
+
+    StormEntranceSpout(EntranceProcessor processor) {
+      this.entranceProcessor = processor;
+    }
+
+    public StormStream getOutputStream() {
+      return outputStream;
+    }
+
+    public void setOutputStream(StormStream stream) {
+      this.outputStream = stream;
+    }
+
+    /** Keeps the collector and initializes the entrance processor with this task's id. */
+    @Override
+    public void open(@SuppressWarnings("rawtypes") Map conf, TopologyContext context, SpoutOutputCollector collector) {
+      this.collector = collector;
+      this.entranceProcessor.onCreate(context.getThisTaskId());
+    }
+
+    /** Emits the next event of the entrance processor, or sleeps for one second when none is available. */
+    @Override
+    public void nextTuple() {
+      if (!entranceProcessor.hasNext()) {
+        Utils.sleep(1000);
+        return;
+      }
+      ContentEvent event = entranceProcessor.nextEvent();
+      Values tuple = new Values(event, event.getKey());
+      collector.emit(outputStream.getOutputId(), tuple);
+    }
+
+    /** Declares the single output stream with the content-event and key fields. */
+    @Override
+    public void declareOutputFields(OutputFieldsDeclarer declarer) {
+      Fields fields = new Fields(StormSamoaUtils.CONTENT_EVENT_FIELD, StormSamoaUtils.KEY_FIELD);
+      declarer.declareStream(outputStream.getOutputId(), fields);
+    }
+
+    /** Creates a new bolt stream whose source component is {@code piId}. */
+    StormStream createStream(String piId) {
+      return new StormBoltStream(piId);
+    }
+  }
+}
diff --git a/heron/mlmgr/src/main/java/org/apache/heron/topology/impl/StormJarSubmitter.java b/heron/mlmgr/src/main/java/org/apache/heron/topology/impl/StormJarSubmitter.java
new file mode 100644
index 00000000000..1d211d43c55
--- /dev/null
+++ b/heron/mlmgr/src/main/java/org/apache/heron/topology/impl/StormJarSubmitter.java
@@ -0,0 +1,74 @@
+package org.apache.heron.topology.impl;
+
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.
See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.OutputStream;
+import java.util.Properties;
+
+import backtype.storm.Config;
+import backtype.storm.StormSubmitter;
+import backtype.storm.utils.Utils;
+
+/**
+ * Utility class to submit samoa-storm jar to a Storm cluster.
+ *
+ * @author Arinto Murdopo
+ *
+ */
+public class StormJarSubmitter {
+
+  public final static String UPLOADED_JAR_LOCATION_KEY = "UploadedJarLocation";
+
+  /**
+   * Uploads the jar named by the first argument with {@link StormSubmitter#submitJar} and records the returned
+   * location under {@link #UPLOADED_JAR_LOCATION_KEY} in {@code src/main/resources/samoa-storm-cluster.properties}.
+   *
+   * @param args
+   *          {@code args[0]} is the path of the jar to upload
+   * @throws IOException
+   *           when the properties file cannot be read or written
+   */
+  public static void main(String[] args) throws IOException {
+    if (args.length < 1) {
+      System.err.println("Usage: StormJarSubmitter <path to the samoa-storm jar>");
+      return;
+    }
+    String jarPath = args[0];
+
+    Config config = new Config();
+    config.putAll(Utils.readCommandLineOpts());
+    config.putAll(Utils.readStormConfig());
+
+    String nimbusHost = (String) config.get(Config.NIMBUS_HOST);
+    int nimbusThriftPort = Utils.getInt(config
+        .get(Config.NIMBUS_THRIFT_PORT));
+
+    System.out.println("Nimbus host " + nimbusHost);
+    System.out.println("Nimbus thrift port " + nimbusThriftPort);
+
+    System.out.println("uploading jar from " + jarPath);
+    String uploadedJarLocation = StormSubmitter.submitJar(config, jarPath);
+
+    System.out.println("Uploaded jar file location: ");
+    System.out.println(uploadedJarLocation);
+
+    File f = new File("src/main/resources/samoa-storm-cluster.properties");
+
+    // This class is what creates the properties file, so it may not exist yet:
+    // keep existing entries when it does, start from scratch when it does not.
+    Properties props = f.exists() ? StormSamoaUtils.getProperties() : null;
+    if (props == null) {
+      props = new Properties();
+    }
+    props.setProperty(StormJarSubmitter.UPLOADED_JAR_LOCATION_KEY, uploadedJarLocation);
+
+    // try-with-resources: the stream was previously never closed.
+    try (OutputStream out = new FileOutputStream(f)) {
+      props.store(out, "properties file to store uploaded jar location from StormJarSubmitter");
+    }
+  }
+}
diff --git a/heron/mlmgr/src/main/java/org/apache/heron/topology/impl/StormProcessingItem.java b/heron/mlmgr/src/main/java/org/apache/heron/topology/impl/StormProcessingItem.java
new file mode 100644
index 00000000000..180fdb44544
--- /dev/null
+++
b/heron/mlmgr/src/main/java/org/apache/heron/topology/impl/StormProcessingItem.java
@@ -0,0 +1,168 @@
+package org.apache.heron.topology.impl;
+
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+import java.util.UUID;
+
+import org.apache.heron.topology.impl.StormStream.InputStreamId;
+import org.apache.samoa.core.ContentEvent;
+import org.apache.samoa.core.Processor;
+import org.apache.samoa.topology.AbstractProcessingItem;
+import org.apache.samoa.topology.ProcessingItem;
+import org.apache.samoa.topology.Stream;
+import org.apache.samoa.utils.PartitioningScheme;
+
+import backtype.storm.task.OutputCollector;
+import backtype.storm.task.TopologyContext;
+import backtype.storm.topology.BoltDeclarer;
+import backtype.storm.topology.OutputFieldsDeclarer;
+import backtype.storm.topology.TopologyBuilder;
+import backtype.storm.topology.base.BaseRichBolt;
+import backtype.storm.tuple.Fields;
+import backtype.storm.tuple.Tuple;
+
+/**
+ * ProcessingItem implementation for Storm.
+ *
+ * @author Arinto Murdopo
+ *
+ */
+class StormProcessingItem extends AbstractProcessingItem implements StormTopologyNode {
+  private final ProcessingItemBolt piBolt;
+
+  // Set by addToTopology(); input streams can only be connected afterwards.
+  private BoltDeclarer piBoltDeclarer;
+
+  // TODO: should we put parallelism hint here?
+  // imo, parallelism hint only declared when we add this PI in the topology
+  // open for discussion :p
+
+  /** Creates a processing item named after a random UUID. */
+  StormProcessingItem(Processor processor, int parallelismHint) {
+    this(processor, UUID.randomUUID().toString(), parallelismHint);
+  }
+
+  /**
+   * Creates a processing item with an explicit name.
+   *
+   * @param processor
+   *          the processor executed by the bolt
+   * @param friendlyId
+   *          name of this item; also used as the Storm component id
+   * @param parallelismHint
+   *          parallelism passed to the parent class
+   */
+  StormProcessingItem(Processor processor, String friendlyId, int parallelismHint) {
+    super(processor, parallelismHint);
+    this.piBolt = new ProcessingItemBolt(processor);
+    this.setName(friendlyId);
+  }
+
+  /**
+   * Subscribes the bolt to {@code inputStream} with the Storm grouping that matches {@code scheme}.
+   *
+   * @throws IllegalStateException
+   *           when this item has not been added to a topology yet
+   */
+  @Override
+  protected ProcessingItem addInputStream(Stream inputStream, PartitioningScheme scheme) {
+    if (piBoltDeclarer == null) {
+      // Previously surfaced as a bare NullPointerException further down.
+      throw new IllegalStateException("Processing item " + this.getName()
+          + " must be added to a topology before an input stream can be connected");
+    }
+
+    StormStream stormInputStream = (StormStream) inputStream;
+    InputStreamId inputId = stormInputStream.getInputId();
+
+    switch (scheme) {
+    case SHUFFLE:
+      piBoltDeclarer.shuffleGrouping(inputId.getComponentId(), inputId.getStreamId());
+      break;
+    case GROUP_BY_KEY:
+      piBoltDeclarer.fieldsGrouping(
+          inputId.getComponentId(),
+          inputId.getStreamId(),
+          new Fields(StormSamoaUtils.KEY_FIELD));
+      break;
+    case BROADCAST:
+      piBoltDeclarer.allGrouping(
+          inputId.getComponentId(),
+          inputId.getStreamId());
+      break;
+    }
+    return this;
+  }
+
+  /**
+   * Registers the bolt with the topology's builder under this item's name. A second call is ignored.
+   */
+  @Override
+  public void addToTopology(StormTopology topology, int parallelismHint) {
+    if (piBoltDeclarer != null) {
+      // TODO: throw exception that one PI only belong to one topology
+      return;
+    }
+    TopologyBuilder stormBuilder = topology.getStormBuilder();
+    this.piBoltDeclarer = stormBuilder.setBolt(this.getName(),
+        this.piBolt, parallelismHint);
+  }
+
+  @Override
+  public StormStream createStream() {
+    return piBolt.createStream(this.getName());
+  }
+
+  @Override
+  public String getId() {
+    return this.getName();
+  }
+
+  /** Returns the parent description prefixed with {@code "id: <name>, "}. */
+  @Override
+  public String toString() {
+    StringBuilder sb = new StringBuilder(super.toString());
+    sb.insert(0, String.format("id: %s, ", this.getName()));
+    return sb.toString();
+  }
+
+  /** Bolt that feeds every received content event to the wrapped processor. */
+  private final static class ProcessingItemBolt extends BaseRichBolt {
+
+    private static final long serialVersionUID = -6637673741263199198L;
+
+    private final Set<StormBoltStream> streams;
+    private final Processor processor;
+
+    private OutputCollector collector;
+
+    ProcessingItemBolt(Processor processor) {
+      this.streams = new HashSet<StormBoltStream>();
+      this.processor = processor;
+    }
+
+    @Override
+    public void prepare(@SuppressWarnings("rawtypes") Map stormConf, TopologyContext context,
+        OutputCollector collector) {
+      this.collector = collector;
+      // Processor and this class share the same instance of stream
+      for (StormBoltStream stream : streams) {
+        stream.setCollector(this.collector);
+      }
+
+      this.processor.onCreate(context.getThisTaskId());
+    }
+
+    /** Hands the content event stored in the tuple's first field to the processor. */
+    @Override
+    public void execute(Tuple input) {
+      ContentEvent sentEvent = (ContentEvent) input.getValue(0);
+      processor.process(sentEvent);
+    }
+
+    @Override
+    public void declareOutputFields(OutputFieldsDeclarer declarer) {
+      for (StormStream stream : streams) {
+        declarer.declareStream(stream.getOutputId(),
+            new Fields(StormSamoaUtils.CONTENT_EVENT_FIELD,
+                StormSamoaUtils.KEY_FIELD));
+      }
+    }
+
+    /** Creates an output stream for component {@code piId} and remembers it for prepare/declareOutputFields. */
+    StormStream createStream(String piId) {
+      StormBoltStream stream = new StormBoltStream(piId);
+      streams.add(stream);
+      return stream;
+    }
+  }
+}
diff --git a/heron/mlmgr/src/main/java/org/apache/heron/topology/impl/StormSamoaUtils.java b/heron/mlmgr/src/main/java/org/apache/heron/topology/impl/StormSamoaUtils.java
new file mode 100644
index 00000000000..7db00556b8e
--- /dev/null
+++ b/heron/mlmgr/src/main/java/org/apache/heron/topology/impl/StormSamoaUtils.java
@@ -0,0 +1,129 @@
+package org.apache.heron.topology.impl;
+
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import com.github.javacliparser.ClassOption; + +import java.io.File; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.util.List; +import java.util.Properties; + +import org.apache.samoa.tasks.Task; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.commons.configuration.Configuration; +import org.apache.commons.configuration.ConfigurationException; +import org.apache.commons.configuration.PropertiesConfiguration; + +/** + * Utility class for samoa-storm project. It is used by StormDoTask to process its arguments. 
+ *
+ * @author Arinto Murdopo
+ *
+ */
+public class StormSamoaUtils {
+
+  private static final Logger logger = LoggerFactory.getLogger(StormSamoaUtils.class);
+
+  static final String KEY_FIELD = "key";
+  static final String CONTENT_EVENT_FIELD = "content_event";
+
+  /** Number of workers used when the command line does not end with an integer. */
+  private static final int DEFAULT_NUM_WORKERS = 4;
+
+  /**
+   * Loads {@code src/main/resources/samoa-storm-cluster.properties}.
+   *
+   * @return the loaded properties, never {@code null}
+   * @throws IOException
+   *           when the file does not exist or cannot be read
+   */
+  static Properties getProperties() throws IOException {
+    File f = new File("src/main/resources/samoa-storm-cluster.properties"); // FIXME it does not exist anymore
+
+    Properties props = new Properties();
+    // A read failure is propagated instead of being turned into a null return,
+    // which callers dereferenced without checking.
+    try (InputStream is = new FileInputStream(f)) {
+      props.load(is);
+    }
+    return props;
+  }
+
+  /**
+   * Builds the task described by the command line arguments and returns its topology.
+   *
+   * @param args
+   *          task command line, one token per element
+   * @return the Storm topology of the initialized task
+   * @throws IllegalArgumentException
+   *           when no task can be created from the arguments
+   */
+  public static StormTopology argsToTopology(String[] args) {
+    StringBuilder cliBuilder = new StringBuilder();
+    for (String arg : args) {
+      cliBuilder.append(" ").append(arg);
+    }
+    String cliString = cliBuilder.toString();
+    logger.debug("Command line string = {}", cliString);
+
+    Task task = getTask(cliString);
+    if (task == null) {
+      throw new IllegalArgumentException("Cannot create a task from command line:" + cliString);
+    }
+
+    // TODO: remove setFactory method with DynamicBinding
+    task.setFactory(new StormComponentFactory());
+    task.init();
+
+    return (StormTopology) task.getTopology();
+  }
+
+  /**
+   * Reads the number of workers from the last argument.
+   *
+   * @param tmpArgs
+   *          mutable argument list; its last element is removed when it is an integer
+   * @return the parsed value, or 4 when the list is empty or its last element is not an integer
+   */
+  public static int numWorkers(List<String> tmpArgs) {
+    if (tmpArgs.isEmpty()) {
+      return DEFAULT_NUM_WORKERS;
+    }
+
+    int position = tmpArgs.size() - 1;
+    try {
+      int numWorkers = Integer.parseInt(tmpArgs.get(position));
+      tmpArgs.remove(position);
+      return numWorkers;
+    } catch (NumberFormatException e) {
+      // Last argument belongs to the task command line; leave it in place.
+      return DEFAULT_NUM_WORKERS;
+    }
+  }
+
+  /**
+   * Instantiates the task described by {@code cliString}.
+   *
+   * @return the task, or {@code null} when it cannot be created (the failure is logged)
+   */
+  public static Task getTask(String cliString) {
+    Task task = null;
+    try {
+      logger.debug("Providing task [{}]", cliString);
+      task = ClassOption.cliStringToObject(cliString, Task.class, null);
+    } catch (Exception e) {
+      logger.warn("Fail in initializing the task!", e);
+    }
+    return task;
+  }
+
+  /**
+   * Loads a properties file through commons-configuration.
+   *
+   * @param configPropertyPath
+   *          location handed to {@link PropertiesConfiguration}
+   * @return the non-empty configuration
+   * @throws RuntimeException
+   *           when the file cannot be read or contains no entries
+   */
+  public static Configuration getPropertyConfig(String configPropertyPath){
+    try {
+      Configuration config = new PropertiesConfiguration(configPropertyPath);
+      if (config.isEmpty()) {
+        logger.error("Configuration is null or empty at file = {}",configPropertyPath);
+        throw new RuntimeException("Configuration is null or empty : " + configPropertyPath);
+      }
+      return config;
+    } catch (ConfigurationException configurationException) {
+      // Log the path in the placeholder and keep the exception as the cause.
+      logger.error("ConfigurationException while reading property file = {}", configPropertyPath,
+          configurationException);
+      throw new RuntimeException("ConfigurationException while reading property file : " + configPropertyPath,
+          configurationException);
+    }
+  }
+}
diff --git a/heron/mlmgr/src/main/java/org/apache/heron/topology/impl/StormSpoutStream.java b/heron/mlmgr/src/main/java/org/apache/heron/topology/impl/StormSpoutStream.java
new file mode 100644
index 00000000000..4f7ec1584c2
--- /dev/null
+++ b/heron/mlmgr/src/main/java/org/apache/heron/topology/impl/StormSpoutStream.java
@@ -0,0 +1,65 @@
+package org.apache.heron.topology.impl;
+//package org.apache.samoa.topology.impl;
+
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+//
+//import org.apache.samoa.core.ContentEvent;
+//import org.apache.samoa.topology.impl.StormEntranceProcessingItem.StormEntranceSpout;
+//
+///**
+// * Storm Stream that connects into Spout.
It wraps the spout itself +// * @author Arinto Murdopo +// * +// */ +//final class StormSpoutStream extends StormStream{ +// +// /** +// * +// */ +// private static final long serialVersionUID = -7444653177614988650L; +// +// private StormEntranceSpout spout; +// +// StormSpoutStream(String stormComponentId) { +// super(stormComponentId); +// } +// +// @Override +// public void put(ContentEvent contentEvent) { +// spout.put(this, contentEvent); +// } +// +// void setSpout(StormEntranceSpout spout){ +// this.spout = spout; +// } +// +//// @Override +//// public void setStreamId(String stream) { +//// // TODO Auto-generated method stub +//// +//// } +// +// @Override +// public String getStreamId() { +// // TODO Auto-generated method stub +// return null; +// } +// +// } diff --git a/heron/mlmgr/src/main/java/org/apache/heron/topology/impl/StormStream.java b/heron/mlmgr/src/main/java/org/apache/heron/topology/impl/StormStream.java new file mode 100644 index 00000000000..33c768f943d --- /dev/null +++ b/heron/mlmgr/src/main/java/org/apache/heron/topology/impl/StormStream.java @@ -0,0 +1,85 @@ +package org.apache.heron.topology.impl; + +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+
+import java.util.UUID;
+
+import org.apache.samoa.core.ContentEvent;
+import org.apache.samoa.topology.Stream;
+
+/**
+ * Base class of the streams used by the Storm adapter. Every stream gets a random UUID as its output stream id and
+ * an {@link InputStreamId} that pairs this id with the id of the emitting Storm component.
+ *
+ * @author Arinto Murdopo
+ *
+ */
+abstract class StormStream implements Stream, java.io.Serializable {
+
+  private static final long serialVersionUID = 281835563756514852L;
+
+  protected final String outputStreamId;
+  protected final InputStreamId inputStreamId;
+
+  /**
+   * @param stormComponentId
+   *          id of the Storm component that emits on this stream
+   */
+  public StormStream(String stormComponentId) {
+    String streamId = UUID.randomUUID().toString();
+    this.outputStreamId = streamId;
+    this.inputStreamId = new InputStreamId(stormComponentId, streamId);
+  }
+
+  @Override
+  public abstract void put(ContentEvent contentEvent);
+
+  @Override
+  public void setBatchSize(int batchSize) {
+    // Ignore batch size
+  }
+
+  /** Returns the id under which the stream is declared by its emitting component. */
+  String getOutputId() {
+    return outputStreamId;
+  }
+
+  /** Returns the (component id, stream id) pair a consumer needs to subscribe to this stream. */
+  InputStreamId getInputId() {
+    return inputStreamId;
+  }
+
+  /** Immutable pair of the emitting component's id and the stream's id. */
+  static final class InputStreamId implements java.io.Serializable {
+
+    private static final long serialVersionUID = -7457995634133691295L;
+
+    private final String componentId;
+    private final String streamId;
+
+    InputStreamId(String componentId, String streamId) {
+      this.componentId = componentId;
+      this.streamId = streamId;
+    }
+
+    String getComponentId() {
+      return this.componentId;
+    }
+
+    String getStreamId() {
+      return this.streamId;
+    }
+  }
+}
diff --git a/heron/mlmgr/src/main/java/org/apache/heron/topology/impl/StormTopology.java b/heron/mlmgr/src/main/java/org/apache/heron/topology/impl/StormTopology.java
new file mode 100644
index 00000000000..f09eb1d453f
--- /dev/null
+++ b/heron/mlmgr/src/main/java/org/apache/heron/topology/impl/StormTopology.java
@@ -0,0 +1,52 @@
+package org.apache.heron.topology.impl;
+
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.
The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+import org.apache.samoa.topology.AbstractTopology;
+import org.apache.samoa.topology.IProcessingItem;
+
+import backtype.storm.topology.TopologyBuilder;
+
+/**
+ * Adaptation of SAMOA topology in samoa-storm. Processing items added to this topology are also registered with the
+ * wrapped Storm {@link TopologyBuilder}.
+ *
+ * @author Arinto Murdopo
+ *
+ */
+public class StormTopology extends AbstractTopology {
+
+  private final TopologyBuilder builder = new TopologyBuilder();
+
+  public StormTopology(String topologyName) {
+    super(topologyName);
+  }
+
+  /** Returns the Storm builder that collects the spouts and bolts of this topology. */
+  public TopologyBuilder getStormBuilder() {
+    return builder;
+  }
+
+  /**
+   * Lets the item register itself with the Storm builder, then records it in the parent topology. The item must
+   * implement {@link StormTopologyNode}.
+   */
+  @Override
+  public void addProcessingItem(IProcessingItem procItem, int parallelismHint) {
+    ((StormTopologyNode) procItem).addToTopology(this, parallelismHint);
+    super.addProcessingItem(procItem, parallelismHint);
+  }
+}
diff --git a/heron/mlmgr/src/main/java/org/apache/heron/topology/impl/StormTopologyNode.java b/heron/mlmgr/src/main/java/org/apache/heron/topology/impl/StormTopologyNode.java
new file mode 100644
index 00000000000..539414eac15
--- /dev/null
+++ b/heron/mlmgr/src/main/java/org/apache/heron/topology/impl/StormTopologyNode.java
@@ -0,0 +1,36 @@
+package org.apache.heron.topology.impl;
+
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.
See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/**
+ * Interface to represent a node in samoa-storm topology.
+ *
+ * @author Arinto Murdopo
+ *
+ */
+interface StormTopologyNode {
+
+  /** Returns the identifier of this node. */
+  String getId();
+
+  /** Creates a new stream that originates from this node. */
+  StormStream createStream();
+
+  /**
+   * Registers this node with the given topology.
+   *
+   * @param topology
+   *          topology the node is added to
+   * @param parallelismHint
+   *          parallelism requested for the node
+   */
+  void addToTopology(StormTopology topology, int parallelismHint);
+
+}
diff --git a/heron/mlmgr/src/main/java/org/apache/heron/topology/impl/StormTopologySubmitter.java b/heron/mlmgr/src/main/java/org/apache/heron/topology/impl/StormTopologySubmitter.java
new file mode 100644
index 00000000000..3602c5b2806
--- /dev/null
+++ b/heron/mlmgr/src/main/java/org/apache/heron/topology/impl/StormTopologySubmitter.java
@@ -0,0 +1,131 @@
+package org.apache.heron.topology.impl;
+
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+import java.io.IOException;
+import java.io.StringWriter;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Properties;
+
+import org.apache.thrift7.TException;
+import org.json.simple.JSONValue;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import backtype.storm.Config;
+import backtype.storm.generated.AlreadyAliveException;
+import backtype.storm.generated.InvalidTopologyException;
+import backtype.storm.utils.NimbusClient;
+import backtype.storm.utils.Utils;
+
+/**
+ * Helper class to submit SAMOA task into Storm without the need of submitting the jar file. The jar file must be
+ * submitted first using StormJarSubmitter class.
+ *
+ * @author Arinto Murdopo
+ *
+ */
+public class StormTopologySubmitter {
+
+  public static final String YJP_OPTIONS_KEY = "YjpOptions";
+
+  private static final Logger logger = LoggerFactory.getLogger(StormTopologySubmitter.class);
+
+  /**
+   * Builds the topology described by the arguments and submits it to Nimbus, pointing it at the jar location that
+   * StormJarSubmitter stored in the properties file.
+   *
+   * @param args
+   *          task command line, optionally followed by the number of workers
+   * @throws IOException
+   *           when the properties file cannot be read
+   */
+  public static void main(String[] args) throws IOException {
+    Properties props = StormSamoaUtils.getProperties();
+    if (props == null) {
+      logger.error("Cannot load the properties file written by StormJarSubmitter");
+      return;
+    }
+
+    String uploadedJarLocation = props.getProperty(StormJarSubmitter.UPLOADED_JAR_LOCATION_KEY);
+    if (uploadedJarLocation == null) {
+      logger.error("Invalid properties file. It must have key {}",
+          StormJarSubmitter.UPLOADED_JAR_LOCATION_KEY);
+      return;
+    }
+
+    List<String> tmpArgs = new ArrayList<String>(Arrays.asList(args));
+    if (tmpArgs.isEmpty()) {
+      logger.error("No task specified. Usage: <task command line> [numWorkers]");
+      return;
+    }
+    int numWorkers = StormSamoaUtils.numWorkers(tmpArgs);
+
+    String[] taskArgs = tmpArgs.toArray(new String[0]);
+    StormTopology stormTopo = StormSamoaUtils.argsToTopology(taskArgs);
+
+    Config conf = new Config();
+    conf.putAll(Utils.readStormConfig());
+    conf.putAll(Utils.readCommandLineOpts());
+    conf.setDebug(false);
+    conf.setNumWorkers(numWorkers);
+
+    // Append the optional profiler options to whatever worker JVM options are already configured.
+    String profilerOption =
+        props.getProperty(StormTopologySubmitter.YJP_OPTIONS_KEY);
+    if (profilerOption != null) {
+      String topoWorkerChildOpts = (String) conf.get(Config.TOPOLOGY_WORKER_CHILDOPTS);
+      StringBuilder optionBuilder = new StringBuilder();
+      if (topoWorkerChildOpts != null) {
+        optionBuilder.append(topoWorkerChildOpts);
+        optionBuilder.append(' ');
+      }
+      optionBuilder.append(profilerOption);
+      conf.put(Config.TOPOLOGY_WORKER_CHILDOPTS, optionBuilder.toString());
+    }
+
+    // Nimbus expects the topology configuration as a JSON string.
+    Map<String, Object> myConfigMap = new HashMap<String, Object>(conf);
+    StringWriter out = new StringWriter();
+
+    try {
+      JSONValue.writeJSONString(myConfigMap, out);
+    } catch (IOException e) {
+      logger.error("Error in writing JSONString", e);
+      return;
+    }
+
+    Config config = new Config();
+    config.putAll(Utils.readStormConfig());
+
+    NimbusClient nc = NimbusClient.getConfiguredClient(config);
+    String topologyName = stormTopo.getTopologyName();
+    try {
+      System.out.println("Submitting topology with name: "
+          + topologyName);
+      nc.getClient().submitTopology(topologyName, uploadedJarLocation,
+          out.toString(), stormTopo.getStormBuilder().createTopology());
+      System.out.println(topologyName + " is successfully submitted");
+
+    } catch (AlreadyAliveException aae) {
+      System.out.println("Fail to submit " + topologyName
+          + "\nError message: " + aae.get_msg());
+    } catch (InvalidTopologyException ite) {
+      logger.error("Invalid topology for {}", topologyName, ite);
+    } catch (TException te) {
+      logger.error("Texception for {}", topologyName, te);
+    }
+  }
+}
diff --git a/heron/mlmgr/src/test/java/org/apache/heron/AlgosTest.java b/heron/mlmgr/src/test/java/org/apache/heron/AlgosTest.java
new file mode 100644
index 00000000000..1c8e97d6827
--- /dev/null
+++ b/heron/mlmgr/src/test/java/org/apache/heron/AlgosTest.java
@@ -0,0 +1,92 @@
+package org.apache.heron;
+
+/*
+ * #%L
+ * SAMOA
+ * %%
+ * Copyright (C) 2014 - 2015 Apache Software Foundation
+ * %%
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * #L% + */ + +import org.apache.samoa.LocalStormDoTask; +import org.apache.samoa.TestParams; +import org.apache.samoa.TestUtils; +import org.junit.Test; + +public class AlgosTest { + + @Test(timeout = 60000) + public void testVHTWithStorm() throws Exception { + + TestParams vhtConfig = new TestParams.Builder() + .inputInstances(200_000) + .samplingSize(20_000) + .evaluationInstances(200_000) + .classifiedInstances(200_000) + .labelSamplingSize(10l) + .classificationsCorrect(55f) + .kappaStat(-0.1f) + .kappaTempStat(-0.1f) + .cliStringTemplate(TestParams.Templates.PREQEVAL_VHT_RANDOMTREE) + .resultFilePollTimeout(30) + .prePollWait(15) + .taskClassName(LocalStormDoTask.class.getName()) + .build(); + TestUtils.test(vhtConfig); + + } + + @Test(timeout = 120000) + public void testBaggingWithStorm() throws Exception { + TestParams baggingConfig = new TestParams.Builder() + .inputInstances(200_000) + .samplingSize(20_000) + .evaluationInstances(180_000) + .classifiedInstances(190_000) + .labelSamplingSize(10l) + .classificationsCorrect(60f) + .kappaStat(0f) + .kappaTempStat(0f) + .cliStringTemplate(TestParams.Templates.PREQEVAL_BAGGING_RANDOMTREE) + .resultFilePollTimeout(40) + .prePollWait(20) + .taskClassName(LocalStormDoTask.class.getName()) + .build(); + TestUtils.test(baggingConfig); + + } + + @Test(timeout = 240000) + public void testCVPReqVHTWithStorm() throws Exception { + + TestParams vhtConfig = new TestParams.Builder() + .inputInstances(200_000) + .samplingSize(20_000) + .evaluationInstances(200_000) + .classifiedInstances(200_000) + .classificationsCorrect(55f) + .kappaStat(0f) + .kappaTempStat(0f) + .cliStringTemplate(TestParams.Templates.PREQCVEVAL_VHT_RANDOMTREE) + .resultFilePollTimeout(30) + .prePollWait(15) + .taskClassName(LocalStormDoTask.class.getName()) + .labelFileCreated(false) + .build(); + TestUtils.test(vhtConfig); + + } + +} diff --git a/heron/mlmgr/src/test/java/org/apache/heron/topology/impl/StormProcessingItemTest.java 
b/heron/mlmgr/src/test/java/org/apache/heron/topology/impl/StormProcessingItemTest.java new file mode 100644 index 00000000000..e02c7047983 --- /dev/null +++ b/heron/mlmgr/src/test/java/org/apache/heron/topology/impl/StormProcessingItemTest.java @@ -0,0 +1,82 @@ +package org.apache.heron.topology.impl; + +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +import static org.junit.Assert.assertEquals; +import mockit.Expectations; +import mockit.MockUp; +import mockit.Mocked; +import mockit.Tested; +import mockit.Verifications; + +import org.apache.samoa.core.Processor; +import org.apache.samoa.topology.impl.StormProcessingItem; +import org.apache.samoa.topology.impl.StormTopology; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import backtype.storm.topology.BoltDeclarer; +import backtype.storm.topology.IRichBolt; +import backtype.storm.topology.TopologyBuilder; + +public class StormProcessingItemTest { + private static final int PARRALLELISM_HINT_2 = 2; + private static final int PARRALLELISM_HINT_4 = 4; + private static final String ID = "id"; + @Tested + private StormProcessingItem pi; + @Mocked + private Processor processor; + @Mocked + private StormTopology topology; + @Mocked + private TopologyBuilder stormBuilder = new TopologyBuilder(); + + @Before + public void setUp() { + pi = new StormProcessingItem(processor, ID, PARRALLELISM_HINT_2); + } + + @Test + public void testAddToTopology() { + new Expectations() { + { + topology.getStormBuilder(); + result = stormBuilder; + + stormBuilder.setBolt(ID, (IRichBolt) any, anyInt); + result = new MockUp() { + }.getMockInstance(); + } + }; + + pi.addToTopology(topology, PARRALLELISM_HINT_4); // this parallelism hint is ignored + + new Verifications() { + { + assertEquals(pi.getProcessor(), processor); + // TODO add methods to explore a topology and verify them + assertEquals(pi.getParallelism(), PARRALLELISM_HINT_2); + assertEquals(pi.getId(), ID); + } + }; + } +}