diff --git a/hands_on/local_maxima_part2/local_maxima.py b/hands_on/local_maxima_part2/local_maxima.py index db89ba3..abca899 100644 --- a/hands_on/local_maxima_part2/local_maxima.py +++ b/hands_on/local_maxima_part2/local_maxima.py @@ -1,4 +1,6 @@ -def find_maxima(x): +import numpy as np + +def find_maxima(x: np.array): """Find local maxima of x. Input arguments: @@ -7,4 +9,15 @@ def find_maxima(x): Output: idx -- list of indices of the local maxima in x """ - return [] + idx = [] + for i in range(0, len(x)): + pred = True + if i != 0: + pred = pred and (x[i-1] <= x[i]) + if i != len(x)-1: + pred = pred and (x[i+1] < x[i]) + if pred: + idx.append(i) + return idx + + diff --git a/hands_on/local_maxima_part2/test_local_maxima.py b/hands_on/local_maxima_part2/test_local_maxima.py index 316442d..490d1fd 100644 --- a/hands_on/local_maxima_part2/test_local_maxima.py +++ b/hands_on/local_maxima_part2/test_local_maxima.py @@ -1,4 +1,5 @@ from local_maxima import find_maxima +import numpy as np def test_find_maxima(): @@ -23,8 +24,18 @@ def test_find_maxima_empty(): def test_find_maxima_plateau(): - raise Exception('not yet implemented') + values = [1, 2, 2, 1] + expected_v1 = [1] + expected_v2 = [2] + expected_v3 = [1, 2] + + assert (np.all(find_maxima(values) == expected_v1) + or np.all(find_maxima(values) == expected_v2) + or np.all(find_maxima(values) == expected_v3)) def test_find_maxima_not_a_plateau(): - raise Exception('not yet implemented') + values = np.array([1, 2, 2, 3, 1]) + expected = np.array([3]) + + assert np.all(find_maxima(values) == expected)