Arg Max
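Argmax is a data-wrangling operation that finds the index of the maximum value in a list of values. The program below takes a secret array of DIM = 10 integers from Party1 and outputs the index of its largest element to index_party. Because the comparison uses >=, the last index is returned when the maximum value occurs more than once.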
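When the maximum occurs more than once, "the index of the maximum" is ambiguous, so it helps to pin down the convention in plain Python first. This is a small sketch with hypothetical helper names (argmax_first, argmax_last are not part of Nada): argmax_first uses Python's max(), which keeps the earliest position on ties, and argmax_last keeps the latest position, which is what a >= comparison inside a left-to-right loop produces.

```python
from typing import Sequence


def argmax_first(values: Sequence[int]) -> int:
    # max() returns the first index whose key is maximal, so ties resolve to the earliest position.
    # Raises ValueError on an empty sequence.
    return max(range(len(values)), key=lambda i: values[i])


def argmax_last(values: Sequence[int]) -> int:
    # Find the first maximum in the reversed list, then map that position back to the original index.
    return len(values) - 1 - argmax_first(values[::-1])


values = [4, 9, 2, 9]
print(argmax_first(values), argmax_last(values))  # prints: 1 3
```

None of the four test files below contains a tie, so both conventions agree on their expected outputs; the difference only matters for inputs with repeated maxima.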
src/arg_max.py
from nada_dsl import *
import nada_numpy as na

DIM = 10

def argmax(array: na.NadaArray):
    # The result (the index of the argmax). -1 is a placeholder that the first loop iteration overwrites.
    result = Integer(-1)
    # The current index of the loop. It is a public value.
    current_index = Integer(0)
    # Assume the max value is at index 0
    max_val = array[0]
    # Compare every element of the array (including index 0) with max_val
    for v in array:
        # Is the current value, v, greater than or equal to max_val?
        # Using >= means that on a tie the result moves to the later index.
        cond = v >= max_val
        # If true, then max_val is set to v. Otherwise, it is not changed.
        max_val = cond.if_else(v, max_val)
        # If true, then the result index is updated to the current index. Otherwise, it is not changed.
        result = cond.if_else(current_index, result)
        # Increment the index counter.
        current_index = current_index + Integer(1)
    return result

def nada_main():
    party1 = Party(name="Party1")
    # The party that receives the index of the result
    index_party = Party(name="index_party")
    # DIM secret integers provided by Party1, named array_0 to array_9
    array = na.array([DIM], party1, "array", SecretInteger)
    result = argmax(array)
    return na.output(result, index_party, "argmax")

if __name__ == "__main__":
    nada_main()
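The argmax loop can be checked against the test files without running Nada by mirroring it in plain Python. This is a sketch that assumes cond.if_else(a, b) selects a when cond holds and b otherwise; the hypothetical helper if_else below computes that selection arithmetically, so every iteration performs the same operations whatever the data, in the same branch-free style as the program. It replays the inputs and expected outputs of the four test files that follow.

```python
from typing import List


def if_else(cond: bool, a: int, b: int) -> int:
    # Plaintext stand-in for cond.if_else(a, b): returns a when cond is true, else b,
    # computed without an if statement.
    c = int(cond)
    return c * a + (1 - c) * b


def argmax_oblivious(array: List[int]) -> int:
    # Same steps as argmax() in src/arg_max.py, on ordinary integers.
    result = -1
    max_val = array[0]
    for current_index, v in enumerate(array):
        cond = v >= max_val
        max_val = if_else(cond, v, max_val)
        result = if_else(cond, current_index, result)
    return result


# Inputs and expected outputs copied from tests/arg_max_test_1.yaml to arg_max_test_4.yaml.
TESTS = {
    "arg_max_test_1": ([5, 3, 1, 12, 3, 1, 5, 3, 1, 70], 9),
    "arg_max_test_2": ([1, 2, 3, 4, 5, 69, 7, 8, 9, 10], 5),
    "arg_max_test_3": ([10, 0, -1, 40, 99, 33, 22, 11, 0, 3], 4),
    "arg_max_test_4": ([100, 99, 80, 79, 60, 59, 40, 39, 20, 19], 0),
}

for name, (inputs, expected) in TESTS.items():
    assert argmax_oblivious(inputs) == expected, name

print(argmax_oblivious([3, 7, 7, 1]))  # prints: 2 (a tie resolves to the later index)
```

If the script finishes without an AssertionError, the four expected outputs are consistent with the >= rule used in the program.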
tests/arg_max_test_1.yaml
---
program: arg_max
inputs:
  array_0: 5
  array_1: 3
  array_2: 1
  array_3: 12
  array_4: 3
  array_5: 1
  array_6: 5
  array_7: 3
  array_8: 1
  array_9: 70
expected_outputs:
  argmax: 9
tests/arg_max_test_2.yaml
---
program: arg_max
inputs:
  array_0: 1
  array_1: 2
  array_2: 3
  array_3: 4
  array_4: 5
  array_5: 69
  array_6: 7
  array_7: 8
  array_8: 9
  array_9: 10
expected_outputs:
  argmax: 5
tests/arg_max_test_3.yaml
---
program: arg_max
inputs:
  array_0: 10
  array_1: 0
  array_2: -1
  array_3: 40
  array_4: 99
  array_5: 33
  array_6: 22
  array_7: 11
  array_8: 0
  array_9: 3
expected_outputs:
  argmax: 4
tests/arg_max_test_4.yaml
---
program: arg_max
inputs:
  array_0: 100
  array_1: 99
  array_2: 80
  array_3: 79
  array_4: 60
  array_5: 59
  array_6: 40
  array_7: 39
  array_8: 20
  array_9: 19
expected_outputs:
  argmax: 0
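New test cases can be generated instead of written by hand. The sketch below uses a hypothetical helper, make_test_yaml (not part of the Nada tooling), and assumes the layout of the files above: a program key, one array_i entry per element nested under inputs, and the argmax nested under expected_outputs. The expected index follows the same last-occurrence rule as the program's >= comparison.

```python
import random
from typing import List

DIM = 10  # must match DIM in src/arg_max.py


def expected_argmax(values: List[int]) -> int:
    # Last index of the maximum, matching the program's >= tie-breaking.
    return len(values) - 1 - values[::-1].index(max(values))


def make_test_yaml(values: List[int], program: str = "arg_max") -> str:
    # Builds the text of a test file in the same layout as tests/arg_max_test_*.yaml.
    if len(values) != DIM:
        raise ValueError(f"expected {DIM} values, got {len(values)}")
    lines = ["---", f"program: {program}", "inputs:"]
    lines += [f"  array_{i}: {v}" for i, v in enumerate(values)]
    lines += ["expected_outputs:", f"  argmax: {expected_argmax(values)}"]
    return "\n".join(lines) + "\n"


# Example: a reproducible random case (seeded so the same file is produced on every run).
rng = random.Random(42)
print(make_test_yaml([rng.randint(-100, 100) for _ in range(DIM)]))
```

Random inputs can contain repeated maxima, which is a useful way to exercise the tie-breaking behavior that the four hand-written tests do not cover.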
Run and test the arg_max program
1. Open "Nada by Example"
2. Run the program with the inputs from the test file:
   nada run arg_max_test_1
3. Test the program with the inputs from the test file against the expected_outputs in the same file:
   nada test arg_max_test_1
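To repeat steps 2 and 3 for all four test files, the commands can be scripted. This is a sketch, not part of the Nada tooling: it assumes the nada CLI is on PATH, that the script is run from the same directory where the commands above work, and that a failing nada test exits with a non-zero status. Only the run and test subcommands shown above are used.

```python
import subprocess
from typing import List

# Hypothetical list of test names, taken from the four files under tests/.
TEST_NAMES = [f"arg_max_test_{i}" for i in range(1, 5)]


def nada_command(subcommand: str, test_name: str) -> List[str]:
    # Builds e.g. ["nada", "test", "arg_max_test_1"]; only "run" and "test" appear in this guide.
    if subcommand not in ("run", "test"):
        raise ValueError(f"unsupported subcommand: {subcommand}")
    return ["nada", subcommand, test_name]


def run_all(subcommand: str = "test") -> List[str]:
    # Runs the command once per test file and returns the names whose exit status was non-zero.
    # Assumption: nada signals a failed test through its exit status.
    failed = []
    for name in TEST_NAMES:
        completed = subprocess.run(nada_command(subcommand, name))
        if completed.returncode != 0:
            failed.append(name)
    return failed
```

Calling run_all() runs nada test on each file in turn and returns an empty list when every test passes; run_all("run") prints the program output for each file instead.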