- commit
- 1bf85abcec4588d0e3f0806016148a68ffb13918
- parent
- 39930b5dc497f10d174802bfc943c5c18cdb4c6f
- Author
- Tobias Bengfort <tobias.bengfort@posteo.de>
- Date
- 2025-05-06 05:43
add test.py
Diffstat
| A | test.py | 99 | ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ |
1 files changed, 99 insertions, 0 deletions
diff --git a/test.py b/test.py
@@ -0,0 +1,99 @@
-1 1 import argparse
-1 2 import json
-1 3
-1 4 LANG_MAP = {
-1 5 'afr': 'af',
-1 6 'ara': 'ar',
-1 7 'bul': 'bg',
-1 8 'ben': 'bn',
-1 9 'cat': 'ca',
-1 10 'ces': 'cs',
-1 11 'cym': 'cy',
-1 12 'dan': 'da',
-1 13 'deu': 'de',
-1 14 'ell': 'el',
-1 15 'eng': 'en',
-1 16 'spa': 'es',
-1 17 'est': 'et',
-1 18 'fas': 'fa',
-1 19 'fin': 'fi',
-1 20 'fra': 'fr',
-1 21 'guj': 'gu',
-1 22 'heb': 'he',
-1 23 'hin': 'hi',
-1 24 'hrv': 'hr',
-1 25 'hun': 'hu',
-1 26 'ind': 'id',
-1 27 'ita': 'it',
-1 28 'jpn': 'ja',
-1 29 'kan': 'kn',
-1 30 'kor': 'ko',
-1 31 'lit': 'lt',
-1 32 'lav': 'lv',
-1 33 'mkd': 'mk',
-1 34 'mal': 'ml',
-1 35 'mar': 'mr',
-1 36 'nep': 'ne',
-1 37 'nld': 'nl',
-1 38 'nor': 'no',
-1 39 'pan': 'pa',
-1 40 'pol': 'pl',
-1 41 'por': 'pt',
-1 42 'ron': 'ro',
-1 43 'rus': 'ru',
-1 44 'slk': 'sk',
-1 45 'slv': 'sl',
-1 46 'som': 'so',
-1 47 'sqi': 'sq',
-1 48 'swe': 'sv',
-1 49 'swa': 'sw',
-1 50 'tam': 'ta',
-1 51 'tel': 'te',
-1 52 'tha': 'th',
-1 53 'tgl': 'tl',
-1 54 'tur': 'tr',
-1 55 'ukr': 'uk',
-1 56 'urd': 'ur',
-1 57 'vie': 'vi',
-1 58 'zho': 'zh-cn',
-1 59 # 'zho': 'zh-tw',
-1 60 }
-1 61
-1 62
-1 63 def dist(a, b):
-1 64 return sum((av - bv) ** 2 for av, bv in zip(a, b))
-1 65
-1 66
-1 67 def classify(model, text):
-1 68 n = len(text) + 1
-1 69 freq = [text.count(g) / (n - len(g)) for g in model['ngrams']]
-1 70 return min(model['freq'], key=lambda lang: dist(model['freq'][lang], freq))
-1 71
-1 72
-1 73 def test(model):
-1 74 total = 0
-1 75 correct = 0
-1 76
-1 77 with open('data/wili/x_test.txt') as fh:
-1 78 with open('data/wili/y_test.txt') as fh2:
-1 79 for lang, text in zip(fh2, fh):
-1 80 lang = LANG_MAP.get(lang.rstrip())
-1 81 text = text.rstrip()
-1 82 if lang in model['freq']:
-1 83 actual = classify(model, text)
-1 84 total += 1
-1 85 if actual == lang:
-1 86 correct += 1
-1 87
-1 88 print(f'overall correctness {correct / total:.1%} ({total})')
-1 89
-1 90
-1 91 if __name__ == '__main__':
-1 92 parser = argparse.ArgumentParser()
-1 93 parser.add_argument('model')
-1 94 args = parser.parse_args()
-1 95
-1 96 with open(args.model) as fh:
-1 97 model = json.load(fh)
-1 98
-1 99 test(model)