ted commited on
Commit
5d96d01
1 Parent(s): 6a12783

[feat] Add new codebleu

Browse files
app.py CHANGED
@@ -2,5 +2,5 @@ import evaluate
2
  from evaluate.utils import launch_gradio_widget
3
 
4
 
5
- module = evaluate.load("vichyt/codebleu")
6
  launch_gradio_widget(module)
 
2
  from evaluate.utils import launch_gradio_widget
3
 
4
 
5
+ module = evaluate.load("vichyt/metric-codebleu")
6
  launch_gradio_widget(module)
eval/__init__.py DELETED
@@ -1 +0,0 @@
1
- import code_bleu
 
 
eval/bleu.py DELETED
@@ -1,590 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
- # Natural Language Toolkit: BLEU Score
3
- #
4
- # Copyright (C) 2001-2020 NLTK Project
5
- # Authors: Chin Yee Lee, Hengfeng Li, Ruxin Hou, Calvin Tanujaya Lim
6
- # Contributors: Björn Mattsson, Dmitrijs Milajevs, Liling Tan
7
- # URL: <http://nltk.org/>
8
- # For license information, see LICENSE.TXT
9
-
10
- """BLEU score implementation."""
11
-
12
- import math
13
- import sys
14
- from fractions import Fraction
15
- import warnings
16
- from collections import Counter
17
-
18
- from utils import ngrams
19
- import pdb
20
-
21
-
22
- def sentence_bleu(
23
- references,
24
- hypothesis,
25
- weights=(0.25, 0.25, 0.25, 0.25),
26
- smoothing_function=None,
27
- auto_reweigh=False,
28
- ):
29
- """
30
- Calculate BLEU score (Bilingual Evaluation Understudy) from
31
- Papineni, Kishore, Salim Roukos, Todd Ward, and Wei-Jing Zhu. 2002.
32
- "BLEU: a method for automatic evaluation of machine translation."
33
- In Proceedings of ACL. http://www.aclweb.org/anthology/P02-1040.pdf
34
- >>> hypothesis1 = ['It', 'is', 'a', 'guide', 'to', 'action', 'which',
35
- ... 'ensures', 'that', 'the', 'military', 'always',
36
- ... 'obeys', 'the', 'commands', 'of', 'the', 'party']
37
- >>> hypothesis2 = ['It', 'is', 'to', 'insure', 'the', 'troops',
38
- ... 'forever', 'hearing', 'the', 'activity', 'guidebook',
39
- ... 'that', 'party', 'direct']
40
- >>> reference1 = ['It', 'is', 'a', 'guide', 'to', 'action', 'that',
41
- ... 'ensures', 'that', 'the', 'military', 'will', 'forever',
42
- ... 'heed', 'Party', 'commands']
43
- >>> reference2 = ['It', 'is', 'the', 'guiding', 'principle', 'which',
44
- ... 'guarantees', 'the', 'military', 'forces', 'always',
45
- ... 'being', 'under', 'the', 'command', 'of', 'the',
46
- ... 'Party']
47
- >>> reference3 = ['It', 'is', 'the', 'practical', 'guide', 'for', 'the',
48
- ... 'army', 'always', 'to', 'heed', 'the', 'directions',
49
- ... 'of', 'the', 'party']
50
- >>> sentence_bleu([reference1, reference2, reference3], hypothesis1) # doctest: +ELLIPSIS
51
- 0.5045...
52
- If there is no ngrams overlap for any order of n-grams, BLEU returns the
53
- value 0. This is because the precision for the order of n-grams without
54
- overlap is 0, and the geometric mean in the final BLEU score computation
55
- multiplies the 0 with the precision of other n-grams. This results in 0
56
- (independently of the precision of the othe n-gram orders). The following
57
- example has zero 3-gram and 4-gram overlaps:
58
- >>> round(sentence_bleu([reference1, reference2, reference3], hypothesis2),4) # doctest: +ELLIPSIS
59
- 0.0
60
- To avoid this harsh behaviour when no ngram overlaps are found a smoothing
61
- function can be used.
62
- >>> chencherry = SmoothingFunction()
63
- >>> sentence_bleu([reference1, reference2, reference3], hypothesis2,
64
- ... smoothing_function=chencherry.method1) # doctest: +ELLIPSIS
65
- 0.0370...
66
- The default BLEU calculates a score for up to 4-grams using uniform
67
- weights (this is called BLEU-4). To evaluate your translations with
68
- higher/lower order ngrams, use customized weights. E.g. when accounting
69
- for up to 5-grams with uniform weights (this is called BLEU-5) use:
70
- >>> weights = (1./5., 1./5., 1./5., 1./5., 1./5.)
71
- >>> sentence_bleu([reference1, reference2, reference3], hypothesis1, weights) # doctest: +ELLIPSIS
72
- 0.3920...
73
- :param references: reference sentences
74
- :type references: list(list(str))
75
- :param hypothesis: a hypothesis sentence
76
- :type hypothesis: list(str)
77
- :param weights: weights for unigrams, bigrams, trigrams and so on
78
- :type weights: list(float)
79
- :param smoothing_function:
80
- :type smoothing_function: SmoothingFunction
81
- :param auto_reweigh: Option to re-normalize the weights uniformly.
82
- :type auto_reweigh: bool
83
- :return: The sentence-level BLEU score.
84
- :rtype: float
85
- """
86
- return corpus_bleu(
87
- [references], [hypothesis], weights, smoothing_function, auto_reweigh
88
- )
89
-
90
-
91
- def corpus_bleu(
92
- list_of_references,
93
- hypotheses,
94
- weights=(0.25, 0.25, 0.25, 0.25),
95
- smoothing_function=None,
96
- auto_reweigh=False,
97
- ):
98
- """
99
- Calculate a single corpus-level BLEU score (aka. system-level BLEU) for all
100
- the hypotheses and their respective references.
101
- Instead of averaging the sentence level BLEU scores (i.e. marco-average
102
- precision), the original BLEU metric (Papineni et al. 2002) accounts for
103
- the micro-average precision (i.e. summing the numerators and denominators
104
- for each hypothesis-reference(s) pairs before the division).
105
- >>> hyp1 = ['It', 'is', 'a', 'guide', 'to', 'action', 'which',
106
- ... 'ensures', 'that', 'the', 'military', 'always',
107
- ... 'obeys', 'the', 'commands', 'of', 'the', 'party']
108
- >>> ref1a = ['It', 'is', 'a', 'guide', 'to', 'action', 'that',
109
- ... 'ensures', 'that', 'the', 'military', 'will', 'forever',
110
- ... 'heed', 'Party', 'commands']
111
- >>> ref1b = ['It', 'is', 'the', 'guiding', 'principle', 'which',
112
- ... 'guarantees', 'the', 'military', 'forces', 'always',
113
- ... 'being', 'under', 'the', 'command', 'of', 'the', 'Party']
114
- >>> ref1c = ['It', 'is', 'the', 'practical', 'guide', 'for', 'the',
115
- ... 'army', 'always', 'to', 'heed', 'the', 'directions',
116
- ... 'of', 'the', 'party']
117
- >>> hyp2 = ['he', 'read', 'the', 'book', 'because', 'he', 'was',
118
- ... 'interested', 'in', 'world', 'history']
119
- >>> ref2a = ['he', 'was', 'interested', 'in', 'world', 'history',
120
- ... 'because', 'he', 'read', 'the', 'book']
121
- >>> list_of_references = [[ref1a, ref1b, ref1c], [ref2a]]
122
- >>> hypotheses = [hyp1, hyp2]
123
- >>> corpus_bleu(list_of_references, hypotheses) # doctest: +ELLIPSIS
124
- 0.5920...
125
- The example below show that corpus_bleu() is different from averaging
126
- sentence_bleu() for hypotheses
127
- >>> score1 = sentence_bleu([ref1a, ref1b, ref1c], hyp1)
128
- >>> score2 = sentence_bleu([ref2a], hyp2)
129
- >>> (score1 + score2) / 2 # doctest: +ELLIPSIS
130
- 0.6223...
131
- :param list_of_references: a corpus of lists of reference sentences, w.r.t. hypotheses
132
- :type list_of_references: list(list(list(str)))
133
- :param hypotheses: a list of hypothesis sentences
134
- :type hypotheses: list(list(str))
135
- :param weights: weights for unigrams, bigrams, trigrams and so on
136
- :type weights: list(float)
137
- :param smoothing_function:
138
- :type smoothing_function: SmoothingFunction
139
- :param auto_reweigh: Option to re-normalize the weights uniformly.
140
- :type auto_reweigh: bool
141
- :return: The corpus-level BLEU score.
142
- :rtype: float
143
- """
144
- # Before proceeding to compute BLEU, perform sanity checks.
145
-
146
- p_numerators = Counter() # Key = ngram order, and value = no. of ngram matches.
147
- p_denominators = Counter() # Key = ngram order, and value = no. of ngram in ref.
148
- hyp_lengths, ref_lengths = 0, 0
149
-
150
- assert len(list_of_references) == len(hypotheses), (
151
- "The number of hypotheses and their reference(s) should be the " "same "
152
- )
153
-
154
- # Iterate through each hypothesis and their corresponding references.
155
- for references, hypothesis in zip(list_of_references, hypotheses):
156
- # For each order of ngram, calculate the numerator and
157
- # denominator for the corpus-level modified precision.
158
- for i, _ in enumerate(weights, start=1):
159
- p_i = modified_precision(references, hypothesis, i)
160
- p_numerators[i] += p_i.numerator
161
- p_denominators[i] += p_i.denominator
162
-
163
- # Calculate the hypothesis length and the closest reference length.
164
- # Adds them to the corpus-level hypothesis and reference counts.
165
- hyp_len = len(hypothesis)
166
- hyp_lengths += hyp_len
167
- ref_lengths += closest_ref_length(references, hyp_len)
168
-
169
- # Calculate corpus-level brevity penalty.
170
- bp = brevity_penalty(ref_lengths, hyp_lengths)
171
-
172
- # Uniformly re-weighting based on maximum hypothesis lengths if largest
173
- # order of n-grams < 4 and weights is set at default.
174
- if auto_reweigh:
175
- if hyp_lengths < 4 and weights == (0.25, 0.25, 0.25, 0.25):
176
- weights = (1 / hyp_lengths,) * hyp_lengths
177
-
178
- # Collects the various precision values for the different ngram orders.
179
- p_n = [
180
- Fraction(p_numerators[i], p_denominators[i], _normalize=False)
181
- for i, _ in enumerate(weights, start=1)
182
- ]
183
-
184
- # Returns 0 if there's no matching n-grams
185
- # We only need to check for p_numerators[1] == 0, since if there's
186
- # no unigrams, there won't be any higher order ngrams.
187
- if p_numerators[1] == 0:
188
- return 0
189
-
190
- # If there's no smoothing, set use method0 from SmoothinFunction class.
191
- if not smoothing_function:
192
- smoothing_function = SmoothingFunction().method1
193
- # Smoothen the modified precision.
194
- # Note: smoothing_function() may convert values into floats;
195
- # it tries to retain the Fraction object as much as the
196
- # smoothing method allows.
197
- p_n = smoothing_function(
198
- p_n, references=references, hypothesis=hypothesis, hyp_len=hyp_lengths
199
- )
200
- s = (w_i * math.log(p_i) for w_i, p_i in zip(weights, p_n))
201
- s = bp * math.exp(math.fsum(s))
202
- return s
203
-
204
-
205
- def modified_precision(references, hypothesis, n):
206
- """
207
- Calculate modified ngram precision.
208
- The normal precision method may lead to some wrong translations with
209
- high-precision, e.g., the translation, in which a word of reference
210
- repeats several times, has very high precision.
211
- This function only returns the Fraction object that contains the numerator
212
- and denominator necessary to calculate the corpus-level precision.
213
- To calculate the modified precision for a single pair of hypothesis and
214
- references, cast the Fraction object into a float.
215
- The famous "the the the ... " example shows that you can get BLEU precision
216
- by duplicating high frequency words.
217
- >>> reference1 = 'the cat is on the mat'.split()
218
- >>> reference2 = 'there is a cat on the mat'.split()
219
- >>> hypothesis1 = 'the the the the the the the'.split()
220
- >>> references = [reference1, reference2]
221
- >>> float(modified_precision(references, hypothesis1, n=1)) # doctest: +ELLIPSIS
222
- 0.2857...
223
- In the modified n-gram precision, a reference word will be considered
224
- exhausted after a matching hypothesis word is identified, e.g.
225
- >>> reference1 = ['It', 'is', 'a', 'guide', 'to', 'action', 'that',
226
- ... 'ensures', 'that', 'the', 'military', 'will',
227
- ... 'forever', 'heed', 'Party', 'commands']
228
- >>> reference2 = ['It', 'is', 'the', 'guiding', 'principle', 'which',
229
- ... 'guarantees', 'the', 'military', 'forces', 'always',
230
- ... 'being', 'under', 'the', 'command', 'of', 'the',
231
- ... 'Party']
232
- >>> reference3 = ['It', 'is', 'the', 'practical', 'guide', 'for', 'the',
233
- ... 'army', 'always', 'to', 'heed', 'the', 'directions',
234
- ... 'of', 'the', 'party']
235
- >>> hypothesis = 'of the'.split()
236
- >>> references = [reference1, reference2, reference3]
237
- >>> float(modified_precision(references, hypothesis, n=1))
238
- 1.0
239
- >>> float(modified_precision(references, hypothesis, n=2))
240
- 1.0
241
- An example of a normal machine translation hypothesis:
242
- >>> hypothesis1 = ['It', 'is', 'a', 'guide', 'to', 'action', 'which',
243
- ... 'ensures', 'that', 'the', 'military', 'always',
244
- ... 'obeys', 'the', 'commands', 'of', 'the', 'party']
245
- >>> hypothesis2 = ['It', 'is', 'to', 'insure', 'the', 'troops',
246
- ... 'forever', 'hearing', 'the', 'activity', 'guidebook',
247
- ... 'that', 'party', 'direct']
248
- >>> reference1 = ['It', 'is', 'a', 'guide', 'to', 'action', 'that',
249
- ... 'ensures', 'that', 'the', 'military', 'will',
250
- ... 'forever', 'heed', 'Party', 'commands']
251
- >>> reference2 = ['It', 'is', 'the', 'guiding', 'principle', 'which',
252
- ... 'guarantees', 'the', 'military', 'forces', 'always',
253
- ... 'being', 'under', 'the', 'command', 'of', 'the',
254
- ... 'Party']
255
- >>> reference3 = ['It', 'is', 'the', 'practical', 'guide', 'for', 'the',
256
- ... 'army', 'always', 'to', 'heed', 'the', 'directions',
257
- ... 'of', 'the', 'party']
258
- >>> references = [reference1, reference2, reference3]
259
- >>> float(modified_precision(references, hypothesis1, n=1)) # doctest: +ELLIPSIS
260
- 0.9444...
261
- >>> float(modified_precision(references, hypothesis2, n=1)) # doctest: +ELLIPSIS
262
- 0.5714...
263
- >>> float(modified_precision(references, hypothesis1, n=2)) # doctest: +ELLIPSIS
264
- 0.5882352941176471
265
- >>> float(modified_precision(references, hypothesis2, n=2)) # doctest: +ELLIPSIS
266
- 0.07692...
267
- :param references: A list of reference translations.
268
- :type references: list(list(str))
269
- :param hypothesis: A hypothesis translation.
270
- :type hypothesis: list(str)
271
- :param n: The ngram order.
272
- :type n: int
273
- :return: BLEU's modified precision for the nth order ngram.
274
- :rtype: Fraction
275
- """
276
- # Extracts all ngrams in hypothesis
277
- # Set an empty Counter if hypothesis is empty.
278
-
279
- counts = Counter(ngrams(hypothesis, n)) if len(hypothesis) >= n else Counter()
280
- # Extract a union of references' counts.
281
- # max_counts = reduce(or_, [Counter(ngrams(ref, n)) for ref in references])
282
- max_counts = {}
283
- for reference in references:
284
- reference_counts = (
285
- Counter(ngrams(reference, n)) if len(reference) >= n else Counter()
286
- )
287
- for ngram in counts:
288
- max_counts[ngram] = max(max_counts.get(ngram, 0), reference_counts[ngram])
289
-
290
- # Assigns the intersection between hypothesis and references' counts.
291
- clipped_counts = {
292
- ngram: min(count, max_counts[ngram]) for ngram, count in counts.items()
293
- }
294
-
295
- numerator = sum(clipped_counts.values())
296
- # Ensures that denominator is minimum 1 to avoid ZeroDivisionError.
297
- # Usually this happens when the ngram order is > len(reference).
298
- denominator = max(1, sum(counts.values()))
299
-
300
- return Fraction(numerator, denominator, _normalize=False)
301
-
302
-
303
- def closest_ref_length(references, hyp_len):
304
- """
305
- This function finds the reference that is the closest length to the
306
- hypothesis. The closest reference length is referred to as *r* variable
307
- from the brevity penalty formula in Papineni et. al. (2002)
308
- :param references: A list of reference translations.
309
- :type references: list(list(str))
310
- :param hyp_len: The length of the hypothesis.
311
- :type hyp_len: int
312
- :return: The length of the reference that's closest to the hypothesis.
313
- :rtype: int
314
- """
315
- ref_lens = (len(reference) for reference in references)
316
- closest_ref_len = min(
317
- ref_lens, key=lambda ref_len: (abs(ref_len - hyp_len), ref_len)
318
- )
319
- return closest_ref_len
320
-
321
-
322
- def brevity_penalty(closest_ref_len, hyp_len):
323
- """
324
- Calculate brevity penalty.
325
- As the modified n-gram precision still has the problem from the short
326
- length sentence, brevity penalty is used to modify the overall BLEU
327
- score according to length.
328
- An example from the paper. There are three references with length 12, 15
329
- and 17. And a concise hypothesis of the length 12. The brevity penalty is 1.
330
- >>> reference1 = list('aaaaaaaaaaaa') # i.e. ['a'] * 12
331
- >>> reference2 = list('aaaaaaaaaaaaaaa') # i.e. ['a'] * 15
332
- >>> reference3 = list('aaaaaaaaaaaaaaaaa') # i.e. ['a'] * 17
333
- >>> hypothesis = list('aaaaaaaaaaaa') # i.e. ['a'] * 12
334
- >>> references = [reference1, reference2, reference3]
335
- >>> hyp_len = len(hypothesis)
336
- >>> closest_ref_len = closest_ref_length(references, hyp_len)
337
- >>> brevity_penalty(closest_ref_len, hyp_len)
338
- 1.0
339
- In case a hypothesis translation is shorter than the references, penalty is
340
- applied.
341
- >>> references = [['a'] * 28, ['a'] * 28]
342
- >>> hypothesis = ['a'] * 12
343
- >>> hyp_len = len(hypothesis)
344
- >>> closest_ref_len = closest_ref_length(references, hyp_len)
345
- >>> brevity_penalty(closest_ref_len, hyp_len)
346
- 0.2635971381157267
347
- The length of the closest reference is used to compute the penalty. If the
348
- length of a hypothesis is 12, and the reference lengths are 13 and 2, the
349
- penalty is applied because the hypothesis length (12) is less then the
350
- closest reference length (13).
351
- >>> references = [['a'] * 13, ['a'] * 2]
352
- >>> hypothesis = ['a'] * 12
353
- >>> hyp_len = len(hypothesis)
354
- >>> closest_ref_len = closest_ref_length(references, hyp_len)
355
- >>> brevity_penalty(closest_ref_len, hyp_len) # doctest: +ELLIPSIS
356
- 0.9200...
357
- The brevity penalty doesn't depend on reference order. More importantly,
358
- when two reference sentences are at the same distance, the shortest
359
- reference sentence length is used.
360
- >>> references = [['a'] * 13, ['a'] * 11]
361
- >>> hypothesis = ['a'] * 12
362
- >>> hyp_len = len(hypothesis)
363
- >>> closest_ref_len = closest_ref_length(references, hyp_len)
364
- >>> bp1 = brevity_penalty(closest_ref_len, hyp_len)
365
- >>> hyp_len = len(hypothesis)
366
- >>> closest_ref_len = closest_ref_length(reversed(references), hyp_len)
367
- >>> bp2 = brevity_penalty(closest_ref_len, hyp_len)
368
- >>> bp1 == bp2 == 1
369
- True
370
- A test example from mteval-v13a.pl (starting from the line 705):
371
- >>> references = [['a'] * 11, ['a'] * 8]
372
- >>> hypothesis = ['a'] * 7
373
- >>> hyp_len = len(hypothesis)
374
- >>> closest_ref_len = closest_ref_length(references, hyp_len)
375
- >>> brevity_penalty(closest_ref_len, hyp_len) # doctest: +ELLIPSIS
376
- 0.8668...
377
- >>> references = [['a'] * 11, ['a'] * 8, ['a'] * 6, ['a'] * 7]
378
- >>> hypothesis = ['a'] * 7
379
- >>> hyp_len = len(hypothesis)
380
- >>> closest_ref_len = closest_ref_length(references, hyp_len)
381
- >>> brevity_penalty(closest_ref_len, hyp_len)
382
- 1.0
383
- :param hyp_len: The length of the hypothesis for a single sentence OR the
384
- sum of all the hypotheses' lengths for a corpus
385
- :type hyp_len: int
386
- :param closest_ref_len: The length of the closest reference for a single
387
- hypothesis OR the sum of all the closest references for every hypotheses.
388
- :type closest_ref_len: int
389
- :return: BLEU's brevity penalty.
390
- :rtype: float
391
- """
392
- if hyp_len > closest_ref_len:
393
- return 1
394
- # If hypothesis is empty, brevity penalty = 0 should result in BLEU = 0.0
395
- elif hyp_len == 0:
396
- return 0
397
- else:
398
- return math.exp(1 - closest_ref_len / hyp_len)
399
-
400
-
401
- class SmoothingFunction:
402
- """
403
- This is an implementation of the smoothing techniques
404
- for segment-level BLEU scores that was presented in
405
- Boxing Chen and Collin Cherry (2014) A Systematic Comparison of
406
- Smoothing Techniques for Sentence-Level BLEU. In WMT14.
407
- http://acl2014.org/acl2014/W14-33/pdf/W14-3346.pdf
408
- """
409
-
410
- def __init__(self, epsilon=0.1, alpha=5, k=5):
411
- """
412
- This will initialize the parameters required for the various smoothing
413
- techniques, the default values are set to the numbers used in the
414
- experiments from Chen and Cherry (2014).
415
- >>> hypothesis1 = ['It', 'is', 'a', 'guide', 'to', 'action', 'which', 'ensures',
416
- ... 'that', 'the', 'military', 'always', 'obeys', 'the',
417
- ... 'commands', 'of', 'the', 'party']
418
- >>> reference1 = ['It', 'is', 'a', 'guide', 'to', 'action', 'that', 'ensures',
419
- ... 'that', 'the', 'military', 'will', 'forever', 'heed',
420
- ... 'Party', 'commands']
421
- >>> chencherry = SmoothingFunction()
422
- >>> print(sentence_bleu([reference1], hypothesis1)) # doctest: +ELLIPSIS
423
- 0.4118...
424
- >>> print(sentence_bleu([reference1], hypothesis1, smoothing_function=chencherry.method0)) # doctest: +ELLIPSIS
425
- 0.4118...
426
- >>> print(sentence_bleu([reference1], hypothesis1, smoothing_function=chencherry.method1)) # doctest: +ELLIPSIS
427
- 0.4118...
428
- >>> print(sentence_bleu([reference1], hypothesis1, smoothing_function=chencherry.method2)) # doctest: +ELLIPSIS
429
- 0.4489...
430
- >>> print(sentence_bleu([reference1], hypothesis1, smoothing_function=chencherry.method3)) # doctest: +ELLIPSIS
431
- 0.4118...
432
- >>> print(sentence_bleu([reference1], hypothesis1, smoothing_function=chencherry.method4)) # doctest: +ELLIPSIS
433
- 0.4118...
434
- >>> print(sentence_bleu([reference1], hypothesis1, smoothing_function=chencherry.method5)) # doctest: +ELLIPSIS
435
- 0.4905...
436
- >>> print(sentence_bleu([reference1], hypothesis1, smoothing_function=chencherry.method6)) # doctest: +ELLIPSIS
437
- 0.4135...
438
- >>> print(sentence_bleu([reference1], hypothesis1, smoothing_function=chencherry.method7)) # doctest: +ELLIPSIS
439
- 0.4905...
440
- :param epsilon: the epsilon value use in method 1
441
- :type epsilon: float
442
- :param alpha: the alpha value use in method 6
443
- :type alpha: int
444
- :param k: the k value use in method 4
445
- :type k: int
446
- """
447
- self.epsilon = epsilon
448
- self.alpha = alpha
449
- self.k = k
450
-
451
- def method0(self, p_n, *args, **kwargs):
452
- """
453
- No smoothing.
454
- """
455
- p_n_new = []
456
- for i, p_i in enumerate(p_n):
457
- if p_i.numerator != 0:
458
- p_n_new.append(p_i)
459
- else:
460
- _msg = str(
461
- "\nThe hypothesis contains 0 counts of {}-gram overlaps.\n"
462
- "Therefore the BLEU score evaluates to 0, independently of\n"
463
- "how many N-gram overlaps of lower order it contains.\n"
464
- "Consider using lower n-gram order or use "
465
- "SmoothingFunction()"
466
- ).format(i + 1)
467
- warnings.warn(_msg)
468
- # When numerator==0 where denonminator==0 or !=0, the result
469
- # for the precision score should be equal to 0 or undefined.
470
- # Due to BLEU geometric mean computation in logarithm space,
471
- # we we need to take the return sys.float_info.min such that
472
- # math.log(sys.float_info.min) returns a 0 precision score.
473
- p_n_new.append(sys.float_info.min)
474
- return p_n_new
475
-
476
- def method1(self, p_n, *args, **kwargs):
477
- """
478
- Smoothing method 1: Add *epsilon* counts to precision with 0 counts.
479
- """
480
- return [
481
- (p_i.numerator + self.epsilon) / p_i.denominator
482
- if p_i.numerator == 0
483
- else p_i
484
- for p_i in p_n
485
- ]
486
-
487
- def method2(self, p_n, *args, **kwargs):
488
- """
489
- Smoothing method 2: Add 1 to both numerator and denominator from
490
- Chin-Yew Lin and Franz Josef Och (2004) Automatic evaluation of
491
- machine translation quality using longest common subsequence and
492
- skip-bigram statistics. In ACL04.
493
- """
494
- return [
495
- Fraction(p_i.numerator + 1, p_i.denominator + 1, _normalize=False)
496
- for p_i in p_n
497
- ]
498
-
499
- def method3(self, p_n, *args, **kwargs):
500
- """
501
- Smoothing method 3: NIST geometric sequence smoothing
502
- The smoothing is computed by taking 1 / ( 2^k ), instead of 0, for each
503
- precision score whose matching n-gram count is null.
504
- k is 1 for the first 'n' value for which the n-gram match count is null/
505
- For example, if the text contains:
506
- - one 2-gram match
507
- - and (consequently) two 1-gram matches
508
- the n-gram count for each individual precision score would be:
509
- - n=1 => prec_count = 2 (two unigrams)
510
- - n=2 => prec_count = 1 (one bigram)
511
- - n=3 => prec_count = 1/2 (no trigram, taking 'smoothed' value of 1 / ( 2^k ), with k=1)
512
- - n=4 => prec_count = 1/4 (no fourgram, taking 'smoothed' value of 1 / ( 2^k ), with k=2)
513
- """
514
- incvnt = 1 # From the mteval-v13a.pl, it's referred to as k.
515
- for i, p_i in enumerate(p_n):
516
- if p_i.numerator == 0:
517
- p_n[i] = 1 / (2 ** incvnt * p_i.denominator)
518
- incvnt += 1
519
- return p_n
520
-
521
- def method4(self, p_n, references, hypothesis, hyp_len=None, *args, **kwargs):
522
- """
523
- Smoothing method 4:
524
- Shorter translations may have inflated precision values due to having
525
- smaller denominators; therefore, we give them proportionally
526
- smaller smoothed counts. Instead of scaling to 1/(2^k), Chen and Cherry
527
- suggests dividing by 1/ln(len(T)), where T is the length of the translation.
528
- """
529
- hyp_len = hyp_len if hyp_len else len(hypothesis)
530
- for i, p_i in enumerate(p_n):
531
- if p_i.numerator == 0 and hyp_len != 0:
532
- incvnt = i + 1 * self.k / math.log(
533
- hyp_len
534
- ) # Note that this K is different from the K from NIST.
535
- p_n[i] = incvnt / p_i.denominator
536
- return p_n
537
-
538
- def method5(self, p_n, references, hypothesis, hyp_len=None, *args, **kwargs):
539
- """
540
- Smoothing method 5:
541
- The matched counts for similar values of n should be similar. To a
542
- calculate the n-gram matched count, it averages the n−1, n and n+1 gram
543
- matched counts.
544
- """
545
- hyp_len = hyp_len if hyp_len else len(hypothesis)
546
- m = {}
547
- # Requires an precision value for an addition ngram order.
548
- p_n_plus1 = p_n + [modified_precision(references, hypothesis, 5)]
549
- m[-1] = p_n[0] + 1
550
- for i, p_i in enumerate(p_n):
551
- p_n[i] = (m[i - 1] + p_i + p_n_plus1[i + 1]) / 3
552
- m[i] = p_n[i]
553
- return p_n
554
-
555
- def method6(self, p_n, references, hypothesis, hyp_len=None, *args, **kwargs):
556
- """
557
- Smoothing method 6:
558
- Interpolates the maximum likelihood estimate of the precision *p_n* with
559
- a prior estimate *pi0*. The prior is estimated by assuming that the ratio
560
- between pn and pn−1 will be the same as that between pn−1 and pn−2; from
561
- Gao and He (2013) Training MRF-Based Phrase Translation Models using
562
- Gradient Ascent. In NAACL.
563
- """
564
- hyp_len = hyp_len if hyp_len else len(hypothesis)
565
- # This smoothing only works when p_1 and p_2 is non-zero.
566
- # Raise an error with an appropriate message when the input is too short
567
- # to use this smoothing technique.
568
- assert p_n[2], "This smoothing method requires non-zero precision for bigrams."
569
- for i, p_i in enumerate(p_n):
570
- if i in [0, 1]: # Skips the first 2 orders of ngrams.
571
- continue
572
- else:
573
- pi0 = 0 if p_n[i - 2] == 0 else p_n[i - 1] ** 2 / p_n[i - 2]
574
- # No. of ngrams in translation that matches the reference.
575
- m = p_i.numerator
576
- # No. of ngrams in translation.
577
- l = sum(1 for _ in ngrams(hypothesis, i + 1))
578
- # Calculates the interpolated precision.
579
- p_n[i] = (m + self.alpha * pi0) / (l + self.alpha)
580
- return p_n
581
-
582
- def method7(self, p_n, references, hypothesis, hyp_len=None, *args, **kwargs):
583
- """
584
- Smoothing method 7:
585
- Interpolates methods 4 and 5.
586
- """
587
- hyp_len = hyp_len if hyp_len else len(hypothesis)
588
- p_n = self.method4(p_n, references, hypothesis, hyp_len)
589
- p_n = self.method5(p_n, references, hypothesis, hyp_len)
590
- return p_n
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eval/code_bleu.py DELETED
@@ -1,44 +0,0 @@
1
- import bleu
2
- import weighted_ngram_match
3
- import syntax_match
4
- import dataflow_match
5
-
6
-
7
- def calc(predictions, references):
8
- lang = "python"
9
-
10
- alpha, beta, gamma, theta = (0.1, 0.1, 0.4, 0.4)
11
-
12
- tokenized_pres = [x.split() for x in predictions]
13
- tokenized_refs = [[x.split() for x in reference] for reference in references]
14
-
15
- ngram_match_score = bleu.corpus_bleu(tokenized_refs, tokenized_pres)
16
- keywords = [x.strip() for x in open('./src/eval/keywords/python.txt', 'r', encoding='utf-8').readlines()]
17
-
18
- def make_weights(reference_tokens, key_word_list):
19
- return {token: 1 if token in key_word_list else 0.2 for token in reference_tokens}
20
-
21
- tokenized_refs_with_weights = [[[reference_tokens, make_weights(reference_tokens, keywords)] \
22
- for reference_tokens in reference] for reference in tokenized_refs]
23
-
24
- weighted_ngram_match_score = weighted_ngram_match.corpus_bleu(tokenized_refs_with_weights, tokenized_pres)
25
-
26
- # calculate syntax match
27
- syntax_match_score = syntax_match.corpus_syntax_match(references, predictions, lang)
28
-
29
- # calculate dataflow match
30
- dataflow_match_score = dataflow_match.corpus_dataflow_match(references, predictions, lang)
31
-
32
- code_bleu_score = alpha * ngram_match_score \
33
- + beta * weighted_ngram_match_score \
34
- + gamma * syntax_match_score \
35
- + theta * dataflow_match_score
36
-
37
- return {
38
- 'ngram_match_score': ngram_match_score,
39
- 'weighted_ngram_match_score': weighted_ngram_match_score,
40
- 'syntax_match_score': syntax_match_score,
41
- 'dataflow_match_score': dataflow_match_score,
42
- 'code_bleu_score': code_bleu_score
43
- }
44
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eval/dataflow_match.py DELETED
@@ -1,148 +0,0 @@
1
- # Copyright (c) Microsoft Corporation.
2
- # Licensed under the MIT license.
3
-
4
- from parser import DFG_python, DFG_java, DFG_ruby, DFG_go, DFG_php, DFG_javascript, DFG_csharp
5
- from parser import (remove_comments_and_docstrings,
6
- tree_to_token_index,
7
- index_to_code_token,
8
- tree_to_variable_index)
9
- from tree_sitter import Language, Parser
10
- import pdb
11
-
12
- dfg_function = {
13
- 'python': DFG_python,
14
- 'java': DFG_java,
15
- 'ruby': DFG_ruby,
16
- 'go': DFG_go,
17
- 'php': DFG_php,
18
- 'javascript': DFG_javascript,
19
- 'c_sharp': DFG_csharp,
20
- }
21
-
22
-
23
- def calc_dataflow_match(references, candidate, lang):
24
- return corpus_dataflow_match([references], [candidate], lang)
25
-
26
-
27
- def corpus_dataflow_match(references, candidates, lang):
28
- LANGUAGE = Language('./src/eval/parser/my-languages.so', lang)
29
- parser = Parser()
30
- parser.set_language(LANGUAGE)
31
- parser = [parser, dfg_function[lang]]
32
- match_count = 0
33
- total_count = 0
34
-
35
- for i in range(len(candidates)):
36
- references_sample = references[i]
37
- candidate = candidates[i]
38
- for reference in references_sample:
39
- try:
40
- candidate = remove_comments_and_docstrings(candidate, 'java')
41
- except:
42
- pass
43
- try:
44
- reference = remove_comments_and_docstrings(reference, 'java')
45
- except:
46
- pass
47
-
48
- cand_dfg = get_data_flow(candidate, parser)
49
- ref_dfg = get_data_flow(reference, parser)
50
-
51
- normalized_cand_dfg = normalize_dataflow(cand_dfg)
52
- normalized_ref_dfg = normalize_dataflow(ref_dfg)
53
-
54
- if len(normalized_ref_dfg) > 0:
55
- total_count += len(normalized_ref_dfg)
56
- for dataflow in normalized_ref_dfg:
57
- if dataflow in normalized_cand_dfg:
58
- match_count += 1
59
- normalized_cand_dfg.remove(dataflow)
60
- if total_count == 0:
61
- print(
62
- "WARNING: There is no reference data-flows extracted from the whole corpus, and the data-flow match score degenerates to 0. Please consider ignoring this score.")
63
- return 0
64
- score = match_count / total_count
65
- return score
66
-
67
-
68
- def get_data_flow(code, parser):
69
- try:
70
- tree = parser[0].parse(bytes(code, 'utf8'))
71
- root_node = tree.root_node
72
- tokens_index = tree_to_token_index(root_node)
73
- code = code.split('\n')
74
- code_tokens = [index_to_code_token(x, code) for x in tokens_index]
75
- index_to_code = {}
76
- for idx, (index, code) in enumerate(zip(tokens_index, code_tokens)):
77
- index_to_code[index] = (idx, code)
78
- try:
79
- DFG, _ = parser[1](root_node, index_to_code, {})
80
- except:
81
- DFG = []
82
- DFG = sorted(DFG, key=lambda x: x[1])
83
- indexs = set()
84
- for d in DFG:
85
- if len(d[-1]) != 0:
86
- indexs.add(d[1])
87
- for x in d[-1]:
88
- indexs.add(x)
89
- new_DFG = []
90
- for d in DFG:
91
- if d[1] in indexs:
92
- new_DFG.append(d)
93
- codes = code_tokens
94
- dfg = new_DFG
95
- except:
96
- codes = code.split()
97
- dfg = []
98
- # merge nodes
99
- dic = {}
100
- for d in dfg:
101
- if d[1] not in dic:
102
- dic[d[1]] = d
103
- else:
104
- dic[d[1]] = (d[0], d[1], d[2], list(set(dic[d[1]][3] + d[3])), list(set(dic[d[1]][4] + d[4])))
105
- DFG = []
106
- for d in dic:
107
- DFG.append(dic[d])
108
- dfg = DFG
109
- return dfg
110
-
111
-
112
- def normalize_dataflow_item(dataflow_item):
113
- var_name = dataflow_item[0]
114
- var_pos = dataflow_item[1]
115
- relationship = dataflow_item[2]
116
- par_vars_name_list = dataflow_item[3]
117
- par_vars_pos_list = dataflow_item[4]
118
-
119
- var_names = list(set(par_vars_name_list + [var_name]))
120
- norm_names = {}
121
- for i in range(len(var_names)):
122
- norm_names[var_names[i]] = 'var_' + str(i)
123
-
124
- norm_var_name = norm_names[var_name]
125
- relationship = dataflow_item[2]
126
- norm_par_vars_name_list = [norm_names[x] for x in par_vars_name_list]
127
-
128
- return (norm_var_name, relationship, norm_par_vars_name_list)
129
-
130
-
131
- def normalize_dataflow(dataflow):
132
- var_dict = {}
133
- i = 0
134
- normalized_dataflow = []
135
- for item in dataflow:
136
- var_name = item[0]
137
- relationship = item[2]
138
- par_vars_name_list = item[3]
139
- for name in par_vars_name_list:
140
- if name not in var_dict:
141
- var_dict[name] = 'var_' + str(i)
142
- i += 1
143
- if var_name not in var_dict:
144
- var_dict[var_name] = 'var_' + str(i)
145
- i += 1
146
- normalized_dataflow.append((var_dict[var_name], relationship, [var_dict[x] for x in par_vars_name_list]))
147
- return normalized_dataflow
148
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eval/keywords/python.txt DELETED
@@ -1,35 +0,0 @@
1
- False
2
- None
3
- True
4
- and
5
- as
6
- assert
7
- async
8
- await
9
- break
10
- class
11
- continue
12
- def
13
- del
14
- elif
15
- else
16
- except
17
- finally
18
- for
19
- from
20
- global
21
- if
22
- import
23
- in
24
- is
25
- lambda
26
- nonlocal
27
- not
28
- or
29
- pass
30
- raise
31
- return
32
- try
33
- while
34
- with
35
- yield
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eval/parser/DFG.py DELETED
@@ -1,1186 +0,0 @@
1
- # Copyright (c) Microsoft Corporation.
2
- # Licensed under the MIT license.
3
-
4
- from tree_sitter import Language, Parser
5
- from .utils import (remove_comments_and_docstrings,
6
- tree_to_token_index,
7
- index_to_code_token,
8
- tree_to_variable_index)
9
-
10
-
11
- def DFG_python(root_node, index_to_code, states):
12
- assignment = ['assignment', 'augmented_assignment', 'for_in_clause']
13
- if_statement = ['if_statement']
14
- for_statement = ['for_statement']
15
- while_statement = ['while_statement']
16
- do_first_statement = ['for_in_clause']
17
- def_statement = ['default_parameter']
18
- states = states.copy()
19
- if (len(root_node.children) == 0 or root_node.type in ['string_literal', 'string',
20
- 'character_literal']) and root_node.type != 'comment':
21
- idx, code = index_to_code[(root_node.start_point, root_node.end_point)]
22
- if root_node.type == code:
23
- return [], states
24
- elif code in states:
25
- return [(code, idx, 'comesFrom', [code], states[code].copy())], states
26
- else:
27
- if root_node.type == 'identifier':
28
- states[code] = [idx]
29
- return [(code, idx, 'comesFrom', [], [])], states
30
- elif root_node.type in def_statement:
31
- name = root_node.child_by_field_name('name')
32
- value = root_node.child_by_field_name('value')
33
- DFG = []
34
- if value is None:
35
- indexs = tree_to_variable_index(name, index_to_code)
36
- for index in indexs:
37
- idx, code = index_to_code[index]
38
- DFG.append((code, idx, 'comesFrom', [], []))
39
- states[code] = [idx]
40
- return sorted(DFG, key=lambda x: x[1]), states
41
- else:
42
- name_indexs = tree_to_variable_index(name, index_to_code)
43
- value_indexs = tree_to_variable_index(value, index_to_code)
44
- temp, states = DFG_python(value, index_to_code, states)
45
- DFG += temp
46
- for index1 in name_indexs:
47
- idx1, code1 = index_to_code[index1]
48
- for index2 in value_indexs:
49
- idx2, code2 = index_to_code[index2]
50
- DFG.append((code1, idx1, 'comesFrom', [code2], [idx2]))
51
- states[code1] = [idx1]
52
- return sorted(DFG, key=lambda x: x[1]), states
53
- elif root_node.type in assignment:
54
- if root_node.type == 'for_in_clause':
55
- right_nodes = [root_node.children[-1]]
56
- left_nodes = [root_node.child_by_field_name('left')]
57
- else:
58
- if root_node.child_by_field_name('right') is None:
59
- return [], states
60
- left_nodes = [x for x in root_node.child_by_field_name('left').children if x.type != ',']
61
- right_nodes = [x for x in root_node.child_by_field_name('right').children if x.type != ',']
62
- if len(right_nodes) != len(left_nodes):
63
- left_nodes = [root_node.child_by_field_name('left')]
64
- right_nodes = [root_node.child_by_field_name('right')]
65
- if len(left_nodes) == 0:
66
- left_nodes = [root_node.child_by_field_name('left')]
67
- if len(right_nodes) == 0:
68
- right_nodes = [root_node.child_by_field_name('right')]
69
- DFG = []
70
- for node in right_nodes:
71
- temp, states = DFG_python(node, index_to_code, states)
72
- DFG += temp
73
-
74
- for left_node, right_node in zip(left_nodes, right_nodes):
75
- left_tokens_index = tree_to_variable_index(left_node, index_to_code)
76
- right_tokens_index = tree_to_variable_index(right_node, index_to_code)
77
- temp = []
78
- for token1_index in left_tokens_index:
79
- idx1, code1 = index_to_code[token1_index]
80
- temp.append((code1, idx1, 'computedFrom', [index_to_code[x][1] for x in right_tokens_index],
81
- [index_to_code[x][0] for x in right_tokens_index]))
82
- states[code1] = [idx1]
83
- DFG += temp
84
- return sorted(DFG, key=lambda x: x[1]), states
85
- elif root_node.type in if_statement:
86
- DFG = []
87
- current_states = states.copy()
88
- others_states = []
89
- tag = False
90
- if 'else' in root_node.type:
91
- tag = True
92
- for child in root_node.children:
93
- if 'else' in child.type:
94
- tag = True
95
- if child.type not in ['elif_clause', 'else_clause']:
96
- temp, current_states = DFG_python(child, index_to_code, current_states)
97
- DFG += temp
98
- else:
99
- temp, new_states = DFG_python(child, index_to_code, states)
100
- DFG += temp
101
- others_states.append(new_states)
102
- others_states.append(current_states)
103
- if tag is False:
104
- others_states.append(states)
105
- new_states = {}
106
- for dic in others_states:
107
- for key in dic:
108
- if key not in new_states:
109
- new_states[key] = dic[key].copy()
110
- else:
111
- new_states[key] += dic[key]
112
- for key in new_states:
113
- new_states[key] = sorted(list(set(new_states[key])))
114
- return sorted(DFG, key=lambda x: x[1]), new_states
115
- elif root_node.type in for_statement:
116
- DFG = []
117
- for i in range(2):
118
- right_nodes = [x for x in root_node.child_by_field_name('right').children if x.type != ',']
119
- left_nodes = [x for x in root_node.child_by_field_name('left').children if x.type != ',']
120
- if len(right_nodes) != len(left_nodes):
121
- left_nodes = [root_node.child_by_field_name('left')]
122
- right_nodes = [root_node.child_by_field_name('right')]
123
- if len(left_nodes) == 0:
124
- left_nodes = [root_node.child_by_field_name('left')]
125
- if len(right_nodes) == 0:
126
- right_nodes = [root_node.child_by_field_name('right')]
127
- for node in right_nodes:
128
- temp, states = DFG_python(node, index_to_code, states)
129
- DFG += temp
130
- for left_node, right_node in zip(left_nodes, right_nodes):
131
- left_tokens_index = tree_to_variable_index(left_node, index_to_code)
132
- right_tokens_index = tree_to_variable_index(right_node, index_to_code)
133
- temp = []
134
- for token1_index in left_tokens_index:
135
- idx1, code1 = index_to_code[token1_index]
136
- temp.append((code1, idx1, 'computedFrom', [index_to_code[x][1] for x in right_tokens_index],
137
- [index_to_code[x][0] for x in right_tokens_index]))
138
- states[code1] = [idx1]
139
- DFG += temp
140
- if root_node.children[-1].type == "block":
141
- temp, states = DFG_python(root_node.children[-1], index_to_code, states)
142
- DFG += temp
143
- dic = {}
144
- for x in DFG:
145
- if (x[0], x[1], x[2]) not in dic:
146
- dic[(x[0], x[1], x[2])] = [x[3], x[4]]
147
- else:
148
- dic[(x[0], x[1], x[2])][0] = list(set(dic[(x[0], x[1], x[2])][0] + x[3]))
149
- dic[(x[0], x[1], x[2])][1] = sorted(list(set(dic[(x[0], x[1], x[2])][1] + x[4])))
150
- DFG = [(x[0], x[1], x[2], y[0], y[1]) for x, y in sorted(dic.items(), key=lambda t: t[0][1])]
151
- return sorted(DFG, key=lambda x: x[1]), states
152
- elif root_node.type in while_statement:
153
- DFG = []
154
- for i in range(2):
155
- for child in root_node.children:
156
- temp, states = DFG_python(child, index_to_code, states)
157
- DFG += temp
158
- dic = {}
159
- for x in DFG:
160
- if (x[0], x[1], x[2]) not in dic:
161
- dic[(x[0], x[1], x[2])] = [x[3], x[4]]
162
- else:
163
- dic[(x[0], x[1], x[2])][0] = list(set(dic[(x[0], x[1], x[2])][0] + x[3]))
164
- dic[(x[0], x[1], x[2])][1] = sorted(list(set(dic[(x[0], x[1], x[2])][1] + x[4])))
165
- DFG = [(x[0], x[1], x[2], y[0], y[1]) for x, y in sorted(dic.items(), key=lambda t: t[0][1])]
166
- return sorted(DFG, key=lambda x: x[1]), states
167
- else:
168
- DFG = []
169
- for child in root_node.children:
170
- if child.type in do_first_statement:
171
- temp, states = DFG_python(child, index_to_code, states)
172
- DFG += temp
173
- for child in root_node.children:
174
- if child.type not in do_first_statement:
175
- temp, states = DFG_python(child, index_to_code, states)
176
- DFG += temp
177
-
178
- return sorted(DFG, key=lambda x: x[1]), states
179
-
180
-
181
- def DFG_java(root_node, index_to_code, states):
182
- assignment = ['assignment_expression']
183
- def_statement = ['variable_declarator']
184
- increment_statement = ['update_expression']
185
- if_statement = ['if_statement', 'else']
186
- for_statement = ['for_statement']
187
- enhanced_for_statement = ['enhanced_for_statement']
188
- while_statement = ['while_statement']
189
- do_first_statement = []
190
- states = states.copy()
191
- if (len(root_node.children) == 0 or root_node.type in ['string_literal', 'string',
192
- 'character_literal']) and root_node.type != 'comment':
193
- idx, code = index_to_code[(root_node.start_point, root_node.end_point)]
194
- if root_node.type == code:
195
- return [], states
196
- elif code in states:
197
- return [(code, idx, 'comesFrom', [code], states[code].copy())], states
198
- else:
199
- if root_node.type == 'identifier':
200
- states[code] = [idx]
201
- return [(code, idx, 'comesFrom', [], [])], states
202
- elif root_node.type in def_statement:
203
- name = root_node.child_by_field_name('name')
204
- value = root_node.child_by_field_name('value')
205
- DFG = []
206
- if value is None:
207
- indexs = tree_to_variable_index(name, index_to_code)
208
- for index in indexs:
209
- idx, code = index_to_code[index]
210
- DFG.append((code, idx, 'comesFrom', [], []))
211
- states[code] = [idx]
212
- return sorted(DFG, key=lambda x: x[1]), states
213
- else:
214
- name_indexs = tree_to_variable_index(name, index_to_code)
215
- value_indexs = tree_to_variable_index(value, index_to_code)
216
- temp, states = DFG_java(value, index_to_code, states)
217
- DFG += temp
218
- for index1 in name_indexs:
219
- idx1, code1 = index_to_code[index1]
220
- for index2 in value_indexs:
221
- idx2, code2 = index_to_code[index2]
222
- DFG.append((code1, idx1, 'comesFrom', [code2], [idx2]))
223
- states[code1] = [idx1]
224
- return sorted(DFG, key=lambda x: x[1]), states
225
- elif root_node.type in assignment:
226
- left_nodes = root_node.child_by_field_name('left')
227
- right_nodes = root_node.child_by_field_name('right')
228
- DFG = []
229
- temp, states = DFG_java(right_nodes, index_to_code, states)
230
- DFG += temp
231
- name_indexs = tree_to_variable_index(left_nodes, index_to_code)
232
- value_indexs = tree_to_variable_index(right_nodes, index_to_code)
233
- for index1 in name_indexs:
234
- idx1, code1 = index_to_code[index1]
235
- for index2 in value_indexs:
236
- idx2, code2 = index_to_code[index2]
237
- DFG.append((code1, idx1, 'computedFrom', [code2], [idx2]))
238
- states[code1] = [idx1]
239
- return sorted(DFG, key=lambda x: x[1]), states
240
- elif root_node.type in increment_statement:
241
- DFG = []
242
- indexs = tree_to_variable_index(root_node, index_to_code)
243
- for index1 in indexs:
244
- idx1, code1 = index_to_code[index1]
245
- for index2 in indexs:
246
- idx2, code2 = index_to_code[index2]
247
- DFG.append((code1, idx1, 'computedFrom', [code2], [idx2]))
248
- states[code1] = [idx1]
249
- return sorted(DFG, key=lambda x: x[1]), states
250
- elif root_node.type in if_statement:
251
- DFG = []
252
- current_states = states.copy()
253
- others_states = []
254
- flag = False
255
- tag = False
256
- if 'else' in root_node.type:
257
- tag = True
258
- for child in root_node.children:
259
- if 'else' in child.type:
260
- tag = True
261
- if child.type not in if_statement and flag is False:
262
- temp, current_states = DFG_java(child, index_to_code, current_states)
263
- DFG += temp
264
- else:
265
- flag = True
266
- temp, new_states = DFG_java(child, index_to_code, states)
267
- DFG += temp
268
- others_states.append(new_states)
269
- others_states.append(current_states)
270
- if tag is False:
271
- others_states.append(states)
272
- new_states = {}
273
- for dic in others_states:
274
- for key in dic:
275
- if key not in new_states:
276
- new_states[key] = dic[key].copy()
277
- else:
278
- new_states[key] += dic[key]
279
- for key in new_states:
280
- new_states[key] = sorted(list(set(new_states[key])))
281
- return sorted(DFG, key=lambda x: x[1]), new_states
282
- elif root_node.type in for_statement:
283
- DFG = []
284
- for child in root_node.children:
285
- temp, states = DFG_java(child, index_to_code, states)
286
- DFG += temp
287
- flag = False
288
- for child in root_node.children:
289
- if flag:
290
- temp, states = DFG_java(child, index_to_code, states)
291
- DFG += temp
292
- elif child.type == "local_variable_declaration":
293
- flag = True
294
- dic = {}
295
- for x in DFG:
296
- if (x[0], x[1], x[2]) not in dic:
297
- dic[(x[0], x[1], x[2])] = [x[3], x[4]]
298
- else:
299
- dic[(x[0], x[1], x[2])][0] = list(set(dic[(x[0], x[1], x[2])][0] + x[3]))
300
- dic[(x[0], x[1], x[2])][1] = sorted(list(set(dic[(x[0], x[1], x[2])][1] + x[4])))
301
- DFG = [(x[0], x[1], x[2], y[0], y[1]) for x, y in sorted(dic.items(), key=lambda t: t[0][1])]
302
- return sorted(DFG, key=lambda x: x[1]), states
303
- elif root_node.type in enhanced_for_statement:
304
- name = root_node.child_by_field_name('name')
305
- value = root_node.child_by_field_name('value')
306
- body = root_node.child_by_field_name('body')
307
- DFG = []
308
- for i in range(2):
309
- temp, states = DFG_java(value, index_to_code, states)
310
- DFG += temp
311
- name_indexs = tree_to_variable_index(name, index_to_code)
312
- value_indexs = tree_to_variable_index(value, index_to_code)
313
- for index1 in name_indexs:
314
- idx1, code1 = index_to_code[index1]
315
- for index2 in value_indexs:
316
- idx2, code2 = index_to_code[index2]
317
- DFG.append((code1, idx1, 'computedFrom', [code2], [idx2]))
318
- states[code1] = [idx1]
319
- temp, states = DFG_java(body, index_to_code, states)
320
- DFG += temp
321
- dic = {}
322
- for x in DFG:
323
- if (x[0], x[1], x[2]) not in dic:
324
- dic[(x[0], x[1], x[2])] = [x[3], x[4]]
325
- else:
326
- dic[(x[0], x[1], x[2])][0] = list(set(dic[(x[0], x[1], x[2])][0] + x[3]))
327
- dic[(x[0], x[1], x[2])][1] = sorted(list(set(dic[(x[0], x[1], x[2])][1] + x[4])))
328
- DFG = [(x[0], x[1], x[2], y[0], y[1]) for x, y in sorted(dic.items(), key=lambda t: t[0][1])]
329
- return sorted(DFG, key=lambda x: x[1]), states
330
- elif root_node.type in while_statement:
331
- DFG = []
332
- for i in range(2):
333
- for child in root_node.children:
334
- temp, states = DFG_java(child, index_to_code, states)
335
- DFG += temp
336
- dic = {}
337
- for x in DFG:
338
- if (x[0], x[1], x[2]) not in dic:
339
- dic[(x[0], x[1], x[2])] = [x[3], x[4]]
340
- else:
341
- dic[(x[0], x[1], x[2])][0] = list(set(dic[(x[0], x[1], x[2])][0] + x[3]))
342
- dic[(x[0], x[1], x[2])][1] = sorted(list(set(dic[(x[0], x[1], x[2])][1] + x[4])))
343
- DFG = [(x[0], x[1], x[2], y[0], y[1]) for x, y in sorted(dic.items(), key=lambda t: t[0][1])]
344
- return sorted(DFG, key=lambda x: x[1]), states
345
- else:
346
- DFG = []
347
- for child in root_node.children:
348
- if child.type in do_first_statement:
349
- temp, states = DFG_java(child, index_to_code, states)
350
- DFG += temp
351
- for child in root_node.children:
352
- if child.type not in do_first_statement:
353
- temp, states = DFG_java(child, index_to_code, states)
354
- DFG += temp
355
-
356
- return sorted(DFG, key=lambda x: x[1]), states
357
-
358
-
359
- def DFG_csharp(root_node, index_to_code, states):
360
- assignment = ['assignment_expression']
361
- def_statement = ['variable_declarator']
362
- increment_statement = ['postfix_unary_expression']
363
- if_statement = ['if_statement', 'else']
364
- for_statement = ['for_statement']
365
- enhanced_for_statement = ['for_each_statement']
366
- while_statement = ['while_statement']
367
- do_first_statement = []
368
- states = states.copy()
369
- if (len(root_node.children) == 0 or root_node.type in ['string_literal', 'string',
370
- 'character_literal']) and root_node.type != 'comment':
371
- idx, code = index_to_code[(root_node.start_point, root_node.end_point)]
372
- if root_node.type == code:
373
- return [], states
374
- elif code in states:
375
- return [(code, idx, 'comesFrom', [code], states[code].copy())], states
376
- else:
377
- if root_node.type == 'identifier':
378
- states[code] = [idx]
379
- return [(code, idx, 'comesFrom', [], [])], states
380
- elif root_node.type in def_statement:
381
- if len(root_node.children) == 2:
382
- name = root_node.children[0]
383
- value = root_node.children[1]
384
- else:
385
- name = root_node.children[0]
386
- value = None
387
- DFG = []
388
- if value is None:
389
- indexs = tree_to_variable_index(name, index_to_code)
390
- for index in indexs:
391
- idx, code = index_to_code[index]
392
- DFG.append((code, idx, 'comesFrom', [], []))
393
- states[code] = [idx]
394
- return sorted(DFG, key=lambda x: x[1]), states
395
- else:
396
- name_indexs = tree_to_variable_index(name, index_to_code)
397
- value_indexs = tree_to_variable_index(value, index_to_code)
398
- temp, states = DFG_csharp(value, index_to_code, states)
399
- DFG += temp
400
- for index1 in name_indexs:
401
- idx1, code1 = index_to_code[index1]
402
- for index2 in value_indexs:
403
- idx2, code2 = index_to_code[index2]
404
- DFG.append((code1, idx1, 'comesFrom', [code2], [idx2]))
405
- states[code1] = [idx1]
406
- return sorted(DFG, key=lambda x: x[1]), states
407
- elif root_node.type in assignment:
408
- left_nodes = root_node.child_by_field_name('left')
409
- right_nodes = root_node.child_by_field_name('right')
410
- DFG = []
411
- temp, states = DFG_csharp(right_nodes, index_to_code, states)
412
- DFG += temp
413
- name_indexs = tree_to_variable_index(left_nodes, index_to_code)
414
- value_indexs = tree_to_variable_index(right_nodes, index_to_code)
415
- for index1 in name_indexs:
416
- idx1, code1 = index_to_code[index1]
417
- for index2 in value_indexs:
418
- idx2, code2 = index_to_code[index2]
419
- DFG.append((code1, idx1, 'computedFrom', [code2], [idx2]))
420
- states[code1] = [idx1]
421
- return sorted(DFG, key=lambda x: x[1]), states
422
- elif root_node.type in increment_statement:
423
- DFG = []
424
- indexs = tree_to_variable_index(root_node, index_to_code)
425
- for index1 in indexs:
426
- idx1, code1 = index_to_code[index1]
427
- for index2 in indexs:
428
- idx2, code2 = index_to_code[index2]
429
- DFG.append((code1, idx1, 'computedFrom', [code2], [idx2]))
430
- states[code1] = [idx1]
431
- return sorted(DFG, key=lambda x: x[1]), states
432
- elif root_node.type in if_statement:
433
- DFG = []
434
- current_states = states.copy()
435
- others_states = []
436
- flag = False
437
- tag = False
438
- if 'else' in root_node.type:
439
- tag = True
440
- for child in root_node.children:
441
- if 'else' in child.type:
442
- tag = True
443
- if child.type not in if_statement and flag is False:
444
- temp, current_states = DFG_csharp(child, index_to_code, current_states)
445
- DFG += temp
446
- else:
447
- flag = True
448
- temp, new_states = DFG_csharp(child, index_to_code, states)
449
- DFG += temp
450
- others_states.append(new_states)
451
- others_states.append(current_states)
452
- if tag is False:
453
- others_states.append(states)
454
- new_states = {}
455
- for dic in others_states:
456
- for key in dic:
457
- if key not in new_states:
458
- new_states[key] = dic[key].copy()
459
- else:
460
- new_states[key] += dic[key]
461
- for key in new_states:
462
- new_states[key] = sorted(list(set(new_states[key])))
463
- return sorted(DFG, key=lambda x: x[1]), new_states
464
- elif root_node.type in for_statement:
465
- DFG = []
466
- for child in root_node.children:
467
- temp, states = DFG_csharp(child, index_to_code, states)
468
- DFG += temp
469
- flag = False
470
- for child in root_node.children:
471
- if flag:
472
- temp, states = DFG_csharp(child, index_to_code, states)
473
- DFG += temp
474
- elif child.type == "local_variable_declaration":
475
- flag = True
476
- dic = {}
477
- for x in DFG:
478
- if (x[0], x[1], x[2]) not in dic:
479
- dic[(x[0], x[1], x[2])] = [x[3], x[4]]
480
- else:
481
- dic[(x[0], x[1], x[2])][0] = list(set(dic[(x[0], x[1], x[2])][0] + x[3]))
482
- dic[(x[0], x[1], x[2])][1] = sorted(list(set(dic[(x[0], x[1], x[2])][1] + x[4])))
483
- DFG = [(x[0], x[1], x[2], y[0], y[1]) for x, y in sorted(dic.items(), key=lambda t: t[0][1])]
484
- return sorted(DFG, key=lambda x: x[1]), states
485
- elif root_node.type in enhanced_for_statement:
486
- name = root_node.child_by_field_name('left')
487
- value = root_node.child_by_field_name('right')
488
- body = root_node.child_by_field_name('body')
489
- DFG = []
490
- for i in range(2):
491
- temp, states = DFG_csharp(value, index_to_code, states)
492
- DFG += temp
493
- name_indexs = tree_to_variable_index(name, index_to_code)
494
- value_indexs = tree_to_variable_index(value, index_to_code)
495
- for index1 in name_indexs:
496
- idx1, code1 = index_to_code[index1]
497
- for index2 in value_indexs:
498
- idx2, code2 = index_to_code[index2]
499
- DFG.append((code1, idx1, 'computedFrom', [code2], [idx2]))
500
- states[code1] = [idx1]
501
- temp, states = DFG_csharp(body, index_to_code, states)
502
- DFG += temp
503
- dic = {}
504
- for x in DFG:
505
- if (x[0], x[1], x[2]) not in dic:
506
- dic[(x[0], x[1], x[2])] = [x[3], x[4]]
507
- else:
508
- dic[(x[0], x[1], x[2])][0] = list(set(dic[(x[0], x[1], x[2])][0] + x[3]))
509
- dic[(x[0], x[1], x[2])][1] = sorted(list(set(dic[(x[0], x[1], x[2])][1] + x[4])))
510
- DFG = [(x[0], x[1], x[2], y[0], y[1]) for x, y in sorted(dic.items(), key=lambda t: t[0][1])]
511
- return sorted(DFG, key=lambda x: x[1]), states
512
- elif root_node.type in while_statement:
513
- DFG = []
514
- for i in range(2):
515
- for child in root_node.children:
516
- temp, states = DFG_csharp(child, index_to_code, states)
517
- DFG += temp
518
- dic = {}
519
- for x in DFG:
520
- if (x[0], x[1], x[2]) not in dic:
521
- dic[(x[0], x[1], x[2])] = [x[3], x[4]]
522
- else:
523
- dic[(x[0], x[1], x[2])][0] = list(set(dic[(x[0], x[1], x[2])][0] + x[3]))
524
- dic[(x[0], x[1], x[2])][1] = sorted(list(set(dic[(x[0], x[1], x[2])][1] + x[4])))
525
- DFG = [(x[0], x[1], x[2], y[0], y[1]) for x, y in sorted(dic.items(), key=lambda t: t[0][1])]
526
- return sorted(DFG, key=lambda x: x[1]), states
527
- else:
528
- DFG = []
529
- for child in root_node.children:
530
- if child.type in do_first_statement:
531
- temp, states = DFG_csharp(child, index_to_code, states)
532
- DFG += temp
533
- for child in root_node.children:
534
- if child.type not in do_first_statement:
535
- temp, states = DFG_csharp(child, index_to_code, states)
536
- DFG += temp
537
-
538
- return sorted(DFG, key=lambda x: x[1]), states
539
-
540
-
541
- def DFG_ruby(root_node, index_to_code, states):
542
- assignment = ['assignment', 'operator_assignment']
543
- if_statement = ['if', 'elsif', 'else', 'unless', 'when']
544
- for_statement = ['for']
545
- while_statement = ['while_modifier', 'until']
546
- do_first_statement = []
547
- def_statement = ['keyword_parameter']
548
- if (len(root_node.children) == 0 or root_node.type in ['string_literal', 'string',
549
- 'character_literal']) and root_node.type != 'comment':
550
- states = states.copy()
551
- idx, code = index_to_code[(root_node.start_point, root_node.end_point)]
552
- if root_node.type == code:
553
- return [], states
554
- elif code in states:
555
- return [(code, idx, 'comesFrom', [code], states[code].copy())], states
556
- else:
557
- if root_node.type == 'identifier':
558
- states[code] = [idx]
559
- return [(code, idx, 'comesFrom', [], [])], states
560
- elif root_node.type in def_statement:
561
- name = root_node.child_by_field_name('name')
562
- value = root_node.child_by_field_name('value')
563
- DFG = []
564
- if value is None:
565
- indexs = tree_to_variable_index(name, index_to_code)
566
- for index in indexs:
567
- idx, code = index_to_code[index]
568
- DFG.append((code, idx, 'comesFrom', [], []))
569
- states[code] = [idx]
570
- return sorted(DFG, key=lambda x: x[1]), states
571
- else:
572
- name_indexs = tree_to_variable_index(name, index_to_code)
573
- value_indexs = tree_to_variable_index(value, index_to_code)
574
- temp, states = DFG_ruby(value, index_to_code, states)
575
- DFG += temp
576
- for index1 in name_indexs:
577
- idx1, code1 = index_to_code[index1]
578
- for index2 in value_indexs:
579
- idx2, code2 = index_to_code[index2]
580
- DFG.append((code1, idx1, 'comesFrom', [code2], [idx2]))
581
- states[code1] = [idx1]
582
- return sorted(DFG, key=lambda x: x[1]), states
583
- elif root_node.type in assignment:
584
- left_nodes = [x for x in root_node.child_by_field_name('left').children if x.type != ',']
585
- right_nodes = [x for x in root_node.child_by_field_name('right').children if x.type != ',']
586
- if len(right_nodes) != len(left_nodes):
587
- left_nodes = [root_node.child_by_field_name('left')]
588
- right_nodes = [root_node.child_by_field_name('right')]
589
- if len(left_nodes) == 0:
590
- left_nodes = [root_node.child_by_field_name('left')]
591
- if len(right_nodes) == 0:
592
- right_nodes = [root_node.child_by_field_name('right')]
593
- if root_node.type == "operator_assignment":
594
- left_nodes = [root_node.children[0]]
595
- right_nodes = [root_node.children[-1]]
596
-
597
- DFG = []
598
- for node in right_nodes:
599
- temp, states = DFG_ruby(node, index_to_code, states)
600
- DFG += temp
601
-
602
- for left_node, right_node in zip(left_nodes, right_nodes):
603
- left_tokens_index = tree_to_variable_index(left_node, index_to_code)
604
- right_tokens_index = tree_to_variable_index(right_node, index_to_code)
605
- temp = []
606
- for token1_index in left_tokens_index:
607
- idx1, code1 = index_to_code[token1_index]
608
- temp.append((code1, idx1, 'computedFrom', [index_to_code[x][1] for x in right_tokens_index],
609
- [index_to_code[x][0] for x in right_tokens_index]))
610
- states[code1] = [idx1]
611
- DFG += temp
612
- return sorted(DFG, key=lambda x: x[1]), states
613
- elif root_node.type in if_statement:
614
- DFG = []
615
- current_states = states.copy()
616
- others_states = []
617
- tag = False
618
- if 'else' in root_node.type:
619
- tag = True
620
- for child in root_node.children:
621
- if 'else' in child.type:
622
- tag = True
623
- if child.type not in if_statement:
624
- temp, current_states = DFG_ruby(child, index_to_code, current_states)
625
- DFG += temp
626
- else:
627
- temp, new_states = DFG_ruby(child, index_to_code, states)
628
- DFG += temp
629
- others_states.append(new_states)
630
- others_states.append(current_states)
631
- if tag is False:
632
- others_states.append(states)
633
- new_states = {}
634
- for dic in others_states:
635
- for key in dic:
636
- if key not in new_states:
637
- new_states[key] = dic[key].copy()
638
- else:
639
- new_states[key] += dic[key]
640
- for key in new_states:
641
- new_states[key] = sorted(list(set(new_states[key])))
642
- return sorted(DFG, key=lambda x: x[1]), new_states
643
- elif root_node.type in for_statement:
644
- DFG = []
645
- for i in range(2):
646
- left_nodes = [root_node.child_by_field_name('pattern')]
647
- right_nodes = [root_node.child_by_field_name('value')]
648
- assert len(right_nodes) == len(left_nodes)
649
- for node in right_nodes:
650
- temp, states = DFG_ruby(node, index_to_code, states)
651
- DFG += temp
652
- for left_node, right_node in zip(left_nodes, right_nodes):
653
- left_tokens_index = tree_to_variable_index(left_node, index_to_code)
654
- right_tokens_index = tree_to_variable_index(right_node, index_to_code)
655
- temp = []
656
- for token1_index in left_tokens_index:
657
- idx1, code1 = index_to_code[token1_index]
658
- temp.append((code1, idx1, 'computedFrom', [index_to_code[x][1] for x in right_tokens_index],
659
- [index_to_code[x][0] for x in right_tokens_index]))
660
- states[code1] = [idx1]
661
- DFG += temp
662
- temp, states = DFG_ruby(root_node.child_by_field_name('body'), index_to_code, states)
663
- DFG += temp
664
- dic = {}
665
- for x in DFG:
666
- if (x[0], x[1], x[2]) not in dic:
667
- dic[(x[0], x[1], x[2])] = [x[3], x[4]]
668
- else:
669
- dic[(x[0], x[1], x[2])][0] = list(set(dic[(x[0], x[1], x[2])][0] + x[3]))
670
- dic[(x[0], x[1], x[2])][1] = sorted(list(set(dic[(x[0], x[1], x[2])][1] + x[4])))
671
- DFG = [(x[0], x[1], x[2], y[0], y[1]) for x, y in sorted(dic.items(), key=lambda t: t[0][1])]
672
- return sorted(DFG, key=lambda x: x[1]), states
673
- elif root_node.type in while_statement:
674
- DFG = []
675
- for i in range(2):
676
- for child in root_node.children:
677
- temp, states = DFG_ruby(child, index_to_code, states)
678
- DFG += temp
679
- dic = {}
680
- for x in DFG:
681
- if (x[0], x[1], x[2]) not in dic:
682
- dic[(x[0], x[1], x[2])] = [x[3], x[4]]
683
- else:
684
- dic[(x[0], x[1], x[2])][0] = list(set(dic[(x[0], x[1], x[2])][0] + x[3]))
685
- dic[(x[0], x[1], x[2])][1] = sorted(list(set(dic[(x[0], x[1], x[2])][1] + x[4])))
686
- DFG = [(x[0], x[1], x[2], y[0], y[1]) for x, y in sorted(dic.items(), key=lambda t: t[0][1])]
687
- return sorted(DFG, key=lambda x: x[1]), states
688
- else:
689
- DFG = []
690
- for child in root_node.children:
691
- if child.type in do_first_statement:
692
- temp, states = DFG_ruby(child, index_to_code, states)
693
- DFG += temp
694
- for child in root_node.children:
695
- if child.type not in do_first_statement:
696
- temp, states = DFG_ruby(child, index_to_code, states)
697
- DFG += temp
698
-
699
- return sorted(DFG, key=lambda x: x[1]), states
700
-
701
-
702
- def DFG_go(root_node, index_to_code, states):
703
- assignment = ['assignment_statement', ]
704
- def_statement = ['var_spec']
705
- increment_statement = ['inc_statement']
706
- if_statement = ['if_statement', 'else']
707
- for_statement = ['for_statement']
708
- enhanced_for_statement = []
709
- while_statement = []
710
- do_first_statement = []
711
- states = states.copy()
712
- if (len(root_node.children) == 0 or root_node.type in ['string_literal', 'string',
713
- 'character_literal']) and root_node.type != 'comment':
714
- idx, code = index_to_code[(root_node.start_point, root_node.end_point)]
715
- if root_node.type == code:
716
- return [], states
717
- elif code in states:
718
- return [(code, idx, 'comesFrom', [code], states[code].copy())], states
719
- else:
720
- if root_node.type == 'identifier':
721
- states[code] = [idx]
722
- return [(code, idx, 'comesFrom', [], [])], states
723
- elif root_node.type in def_statement:
724
- name = root_node.child_by_field_name('name')
725
- value = root_node.child_by_field_name('value')
726
- DFG = []
727
- if value is None:
728
- indexs = tree_to_variable_index(name, index_to_code)
729
- for index in indexs:
730
- idx, code = index_to_code[index]
731
- DFG.append((code, idx, 'comesFrom', [], []))
732
- states[code] = [idx]
733
- return sorted(DFG, key=lambda x: x[1]), states
734
- else:
735
- name_indexs = tree_to_variable_index(name, index_to_code)
736
- value_indexs = tree_to_variable_index(value, index_to_code)
737
- temp, states = DFG_go(value, index_to_code, states)
738
- DFG += temp
739
- for index1 in name_indexs:
740
- idx1, code1 = index_to_code[index1]
741
- for index2 in value_indexs:
742
- idx2, code2 = index_to_code[index2]
743
- DFG.append((code1, idx1, 'comesFrom', [code2], [idx2]))
744
- states[code1] = [idx1]
745
- return sorted(DFG, key=lambda x: x[1]), states
746
- elif root_node.type in assignment:
747
- left_nodes = root_node.child_by_field_name('left')
748
- right_nodes = root_node.child_by_field_name('right')
749
- DFG = []
750
- temp, states = DFG_go(right_nodes, index_to_code, states)
751
- DFG += temp
752
- name_indexs = tree_to_variable_index(left_nodes, index_to_code)
753
- value_indexs = tree_to_variable_index(right_nodes, index_to_code)
754
- for index1 in name_indexs:
755
- idx1, code1 = index_to_code[index1]
756
- for index2 in value_indexs:
757
- idx2, code2 = index_to_code[index2]
758
- DFG.append((code1, idx1, 'computedFrom', [code2], [idx2]))
759
- states[code1] = [idx1]
760
- return sorted(DFG, key=lambda x: x[1]), states
761
- elif root_node.type in increment_statement:
762
- DFG = []
763
- indexs = tree_to_variable_index(root_node, index_to_code)
764
- for index1 in indexs:
765
- idx1, code1 = index_to_code[index1]
766
- for index2 in indexs:
767
- idx2, code2 = index_to_code[index2]
768
- DFG.append((code1, idx1, 'computedFrom', [code2], [idx2]))
769
- states[code1] = [idx1]
770
- return sorted(DFG, key=lambda x: x[1]), states
771
- elif root_node.type in if_statement:
772
- DFG = []
773
- current_states = states.copy()
774
- others_states = []
775
- flag = False
776
- tag = False
777
- if 'else' in root_node.type:
778
- tag = True
779
- for child in root_node.children:
780
- if 'else' in child.type:
781
- tag = True
782
- if child.type not in if_statement and flag is False:
783
- temp, current_states = DFG_go(child, index_to_code, current_states)
784
- DFG += temp
785
- else:
786
- flag = True
787
- temp, new_states = DFG_go(child, index_to_code, states)
788
- DFG += temp
789
- others_states.append(new_states)
790
- others_states.append(current_states)
791
- if tag is False:
792
- others_states.append(states)
793
- new_states = {}
794
- for dic in others_states:
795
- for key in dic:
796
- if key not in new_states:
797
- new_states[key] = dic[key].copy()
798
- else:
799
- new_states[key] += dic[key]
800
- for key in states:
801
- if key not in new_states:
802
- new_states[key] = states[key]
803
- else:
804
- new_states[key] += states[key]
805
- for key in new_states:
806
- new_states[key] = sorted(list(set(new_states[key])))
807
- return sorted(DFG, key=lambda x: x[1]), new_states
808
- elif root_node.type in for_statement:
809
- DFG = []
810
- for child in root_node.children:
811
- temp, states = DFG_go(child, index_to_code, states)
812
- DFG += temp
813
- flag = False
814
- for child in root_node.children:
815
- if flag:
816
- temp, states = DFG_go(child, index_to_code, states)
817
- DFG += temp
818
- elif child.type == "for_clause":
819
- if child.child_by_field_name('update') is not None:
820
- temp, states = DFG_go(child.child_by_field_name('update'), index_to_code, states)
821
- DFG += temp
822
- flag = True
823
- dic = {}
824
- for x in DFG:
825
- if (x[0], x[1], x[2]) not in dic:
826
- dic[(x[0], x[1], x[2])] = [x[3], x[4]]
827
- else:
828
- dic[(x[0], x[1], x[2])][0] = list(set(dic[(x[0], x[1], x[2])][0] + x[3]))
829
- dic[(x[0], x[1], x[2])][1] = sorted(list(set(dic[(x[0], x[1], x[2])][1] + x[4])))
830
- DFG = [(x[0], x[1], x[2], y[0], y[1]) for x, y in sorted(dic.items(), key=lambda t: t[0][1])]
831
- return sorted(DFG, key=lambda x: x[1]), states
832
- else:
833
- DFG = []
834
- for child in root_node.children:
835
- if child.type in do_first_statement:
836
- temp, states = DFG_go(child, index_to_code, states)
837
- DFG += temp
838
- for child in root_node.children:
839
- if child.type not in do_first_statement:
840
- temp, states = DFG_go(child, index_to_code, states)
841
- DFG += temp
842
-
843
- return sorted(DFG, key=lambda x: x[1]), states
844
-
845
-
846
- def DFG_php(root_node, index_to_code, states):
847
- assignment = ['assignment_expression', 'augmented_assignment_expression']
848
- def_statement = ['simple_parameter']
849
- increment_statement = ['update_expression']
850
- if_statement = ['if_statement', 'else_clause']
851
- for_statement = ['for_statement']
852
- enhanced_for_statement = ['foreach_statement']
853
- while_statement = ['while_statement']
854
- do_first_statement = []
855
- states = states.copy()
856
- if (len(root_node.children) == 0 or root_node.type in ['string_literal', 'string',
857
- 'character_literal']) and root_node.type != 'comment':
858
- idx, code = index_to_code[(root_node.start_point, root_node.end_point)]
859
- if root_node.type == code:
860
- return [], states
861
- elif code in states:
862
- return [(code, idx, 'comesFrom', [code], states[code].copy())], states
863
- else:
864
- if root_node.type == 'identifier':
865
- states[code] = [idx]
866
- return [(code, idx, 'comesFrom', [], [])], states
867
- elif root_node.type in def_statement:
868
- name = root_node.child_by_field_name('name')
869
- value = root_node.child_by_field_name('default_value')
870
- DFG = []
871
- if value is None:
872
- indexs = tree_to_variable_index(name, index_to_code)
873
- for index in indexs:
874
- idx, code = index_to_code[index]
875
- DFG.append((code, idx, 'comesFrom', [], []))
876
- states[code] = [idx]
877
- return sorted(DFG, key=lambda x: x[1]), states
878
- else:
879
- name_indexs = tree_to_variable_index(name, index_to_code)
880
- value_indexs = tree_to_variable_index(value, index_to_code)
881
- temp, states = DFG_php(value, index_to_code, states)
882
- DFG += temp
883
- for index1 in name_indexs:
884
- idx1, code1 = index_to_code[index1]
885
- for index2 in value_indexs:
886
- idx2, code2 = index_to_code[index2]
887
- DFG.append((code1, idx1, 'comesFrom', [code2], [idx2]))
888
- states[code1] = [idx1]
889
- return sorted(DFG, key=lambda x: x[1]), states
890
- elif root_node.type in assignment:
891
- left_nodes = root_node.child_by_field_name('left')
892
- right_nodes = root_node.child_by_field_name('right')
893
- DFG = []
894
- temp, states = DFG_php(right_nodes, index_to_code, states)
895
- DFG += temp
896
- name_indexs = tree_to_variable_index(left_nodes, index_to_code)
897
- value_indexs = tree_to_variable_index(right_nodes, index_to_code)
898
- for index1 in name_indexs:
899
- idx1, code1 = index_to_code[index1]
900
- for index2 in value_indexs:
901
- idx2, code2 = index_to_code[index2]
902
- DFG.append((code1, idx1, 'computedFrom', [code2], [idx2]))
903
- states[code1] = [idx1]
904
- return sorted(DFG, key=lambda x: x[1]), states
905
- elif root_node.type in increment_statement:
906
- DFG = []
907
- indexs = tree_to_variable_index(root_node, index_to_code)
908
- for index1 in indexs:
909
- idx1, code1 = index_to_code[index1]
910
- for index2 in indexs:
911
- idx2, code2 = index_to_code[index2]
912
- DFG.append((code1, idx1, 'computedFrom', [code2], [idx2]))
913
- states[code1] = [idx1]
914
- return sorted(DFG, key=lambda x: x[1]), states
915
- elif root_node.type in if_statement:
916
- DFG = []
917
- current_states = states.copy()
918
- others_states = []
919
- flag = False
920
- tag = False
921
- if 'else' in root_node.type:
922
- tag = True
923
- for child in root_node.children:
924
- if 'else' in child.type:
925
- tag = True
926
- if child.type not in if_statement and flag is False:
927
- temp, current_states = DFG_php(child, index_to_code, current_states)
928
- DFG += temp
929
- else:
930
- flag = True
931
- temp, new_states = DFG_php(child, index_to_code, states)
932
- DFG += temp
933
- others_states.append(new_states)
934
- others_states.append(current_states)
935
- new_states = {}
936
- for dic in others_states:
937
- for key in dic:
938
- if key not in new_states:
939
- new_states[key] = dic[key].copy()
940
- else:
941
- new_states[key] += dic[key]
942
- for key in states:
943
- if key not in new_states:
944
- new_states[key] = states[key]
945
- else:
946
- new_states[key] += states[key]
947
- for key in new_states:
948
- new_states[key] = sorted(list(set(new_states[key])))
949
- return sorted(DFG, key=lambda x: x[1]), new_states
950
- elif root_node.type in for_statement:
951
- DFG = []
952
- for child in root_node.children:
953
- temp, states = DFG_php(child, index_to_code, states)
954
- DFG += temp
955
- flag = False
956
- for child in root_node.children:
957
- if flag:
958
- temp, states = DFG_php(child, index_to_code, states)
959
- DFG += temp
960
- elif child.type == "assignment_expression":
961
- flag = True
962
- dic = {}
963
- for x in DFG:
964
- if (x[0], x[1], x[2]) not in dic:
965
- dic[(x[0], x[1], x[2])] = [x[3], x[4]]
966
- else:
967
- dic[(x[0], x[1], x[2])][0] = list(set(dic[(x[0], x[1], x[2])][0] + x[3]))
968
- dic[(x[0], x[1], x[2])][1] = sorted(list(set(dic[(x[0], x[1], x[2])][1] + x[4])))
969
- DFG = [(x[0], x[1], x[2], y[0], y[1]) for x, y in sorted(dic.items(), key=lambda t: t[0][1])]
970
- return sorted(DFG, key=lambda x: x[1]), states
971
- elif root_node.type in enhanced_for_statement:
972
- name = None
973
- value = None
974
- for child in root_node.children:
975
- if child.type == 'variable_name' and value is None:
976
- value = child
977
- elif child.type == 'variable_name' and name is None:
978
- name = child
979
- break
980
- body = root_node.child_by_field_name('body')
981
- DFG = []
982
- for i in range(2):
983
- temp, states = DFG_php(value, index_to_code, states)
984
- DFG += temp
985
- name_indexs = tree_to_variable_index(name, index_to_code)
986
- value_indexs = tree_to_variable_index(value, index_to_code)
987
- for index1 in name_indexs:
988
- idx1, code1 = index_to_code[index1]
989
- for index2 in value_indexs:
990
- idx2, code2 = index_to_code[index2]
991
- DFG.append((code1, idx1, 'computedFrom', [code2], [idx2]))
992
- states[code1] = [idx1]
993
- temp, states = DFG_php(body, index_to_code, states)
994
- DFG += temp
995
- dic = {}
996
- for x in DFG:
997
- if (x[0], x[1], x[2]) not in dic:
998
- dic[(x[0], x[1], x[2])] = [x[3], x[4]]
999
- else:
1000
- dic[(x[0], x[1], x[2])][0] = list(set(dic[(x[0], x[1], x[2])][0] + x[3]))
1001
- dic[(x[0], x[1], x[2])][1] = sorted(list(set(dic[(x[0], x[1], x[2])][1] + x[4])))
1002
- DFG = [(x[0], x[1], x[2], y[0], y[1]) for x, y in sorted(dic.items(), key=lambda t: t[0][1])]
1003
- return sorted(DFG, key=lambda x: x[1]), states
1004
- elif root_node.type in while_statement:
1005
- DFG = []
1006
- for i in range(2):
1007
- for child in root_node.children:
1008
- temp, states = DFG_php(child, index_to_code, states)
1009
- DFG += temp
1010
- dic = {}
1011
- for x in DFG:
1012
- if (x[0], x[1], x[2]) not in dic:
1013
- dic[(x[0], x[1], x[2])] = [x[3], x[4]]
1014
- else:
1015
- dic[(x[0], x[1], x[2])][0] = list(set(dic[(x[0], x[1], x[2])][0] + x[3]))
1016
- dic[(x[0], x[1], x[2])][1] = sorted(list(set(dic[(x[0], x[1], x[2])][1] + x[4])))
1017
- DFG = [(x[0], x[1], x[2], y[0], y[1]) for x, y in sorted(dic.items(), key=lambda t: t[0][1])]
1018
- return sorted(DFG, key=lambda x: x[1]), states
1019
- else:
1020
- DFG = []
1021
- for child in root_node.children:
1022
- if child.type in do_first_statement:
1023
- temp, states = DFG_php(child, index_to_code, states)
1024
- DFG += temp
1025
- for child in root_node.children:
1026
- if child.type not in do_first_statement:
1027
- temp, states = DFG_php(child, index_to_code, states)
1028
- DFG += temp
1029
-
1030
- return sorted(DFG, key=lambda x: x[1]), states
1031
-
1032
-
1033
- def DFG_javascript(root_node, index_to_code, states):
1034
- assignment = ['assignment_pattern', 'augmented_assignment_expression']
1035
- def_statement = ['variable_declarator']
1036
- increment_statement = ['update_expression']
1037
- if_statement = ['if_statement', 'else']
1038
- for_statement = ['for_statement']
1039
- enhanced_for_statement = []
1040
- while_statement = ['while_statement']
1041
- do_first_statement = []
1042
- states = states.copy()
1043
- if (len(root_node.children) == 0 or root_node.type in ['string_literal', 'string',
1044
- 'character_literal']) and root_node.type != 'comment':
1045
- idx, code = index_to_code[(root_node.start_point, root_node.end_point)]
1046
- if root_node.type == code:
1047
- return [], states
1048
- elif code in states:
1049
- return [(code, idx, 'comesFrom', [code], states[code].copy())], states
1050
- else:
1051
- if root_node.type == 'identifier':
1052
- states[code] = [idx]
1053
- return [(code, idx, 'comesFrom', [], [])], states
1054
- elif root_node.type in def_statement:
1055
- name = root_node.child_by_field_name('name')
1056
- value = root_node.child_by_field_name('value')
1057
- DFG = []
1058
- if value is None:
1059
- indexs = tree_to_variable_index(name, index_to_code)
1060
- for index in indexs:
1061
- idx, code = index_to_code[index]
1062
- DFG.append((code, idx, 'comesFrom', [], []))
1063
- states[code] = [idx]
1064
- return sorted(DFG, key=lambda x: x[1]), states
1065
- else:
1066
- name_indexs = tree_to_variable_index(name, index_to_code)
1067
- value_indexs = tree_to_variable_index(value, index_to_code)
1068
- temp, states = DFG_javascript(value, index_to_code, states)
1069
- DFG += temp
1070
- for index1 in name_indexs:
1071
- idx1, code1 = index_to_code[index1]
1072
- for index2 in value_indexs:
1073
- idx2, code2 = index_to_code[index2]
1074
- DFG.append((code1, idx1, 'comesFrom', [code2], [idx2]))
1075
- states[code1] = [idx1]
1076
- return sorted(DFG, key=lambda x: x[1]), states
1077
- elif root_node.type in assignment:
1078
- left_nodes = root_node.child_by_field_name('left')
1079
- right_nodes = root_node.child_by_field_name('right')
1080
- DFG = []
1081
- temp, states = DFG_javascript(right_nodes, index_to_code, states)
1082
- DFG += temp
1083
- name_indexs = tree_to_variable_index(left_nodes, index_to_code)
1084
- value_indexs = tree_to_variable_index(right_nodes, index_to_code)
1085
- for index1 in name_indexs:
1086
- idx1, code1 = index_to_code[index1]
1087
- for index2 in value_indexs:
1088
- idx2, code2 = index_to_code[index2]
1089
- DFG.append((code1, idx1, 'computedFrom', [code2], [idx2]))
1090
- states[code1] = [idx1]
1091
- return sorted(DFG, key=lambda x: x[1]), states
1092
- elif root_node.type in increment_statement:
1093
- DFG = []
1094
- indexs = tree_to_variable_index(root_node, index_to_code)
1095
- for index1 in indexs:
1096
- idx1, code1 = index_to_code[index1]
1097
- for index2 in indexs:
1098
- idx2, code2 = index_to_code[index2]
1099
- DFG.append((code1, idx1, 'computedFrom', [code2], [idx2]))
1100
- states[code1] = [idx1]
1101
- return sorted(DFG, key=lambda x: x[1]), states
1102
- elif root_node.type in if_statement:
1103
- DFG = []
1104
- current_states = states.copy()
1105
- others_states = []
1106
- flag = False
1107
- tag = False
1108
- if 'else' in root_node.type:
1109
- tag = True
1110
- for child in root_node.children:
1111
- if 'else' in child.type:
1112
- tag = True
1113
- if child.type not in if_statement and flag is False:
1114
- temp, current_states = DFG_javascript(child, index_to_code, current_states)
1115
- DFG += temp
1116
- else:
1117
- flag = True
1118
- temp, new_states = DFG_javascript(child, index_to_code, states)
1119
- DFG += temp
1120
- others_states.append(new_states)
1121
- others_states.append(current_states)
1122
- if tag is False:
1123
- others_states.append(states)
1124
- new_states = {}
1125
- for dic in others_states:
1126
- for key in dic:
1127
- if key not in new_states:
1128
- new_states[key] = dic[key].copy()
1129
- else:
1130
- new_states[key] += dic[key]
1131
- for key in states:
1132
- if key not in new_states:
1133
- new_states[key] = states[key]
1134
- else:
1135
- new_states[key] += states[key]
1136
- for key in new_states:
1137
- new_states[key] = sorted(list(set(new_states[key])))
1138
- return sorted(DFG, key=lambda x: x[1]), new_states
1139
- elif root_node.type in for_statement:
1140
- DFG = []
1141
- for child in root_node.children:
1142
- temp, states = DFG_javascript(child, index_to_code, states)
1143
- DFG += temp
1144
- flag = False
1145
- for child in root_node.children:
1146
- if flag:
1147
- temp, states = DFG_javascript(child, index_to_code, states)
1148
- DFG += temp
1149
- elif child.type == "variable_declaration":
1150
- flag = True
1151
- dic = {}
1152
- for x in DFG:
1153
- if (x[0], x[1], x[2]) not in dic:
1154
- dic[(x[0], x[1], x[2])] = [x[3], x[4]]
1155
- else:
1156
- dic[(x[0], x[1], x[2])][0] = list(set(dic[(x[0], x[1], x[2])][0] + x[3]))
1157
- dic[(x[0], x[1], x[2])][1] = sorted(list(set(dic[(x[0], x[1], x[2])][1] + x[4])))
1158
- DFG = [(x[0], x[1], x[2], y[0], y[1]) for x, y in sorted(dic.items(), key=lambda t: t[0][1])]
1159
- return sorted(DFG, key=lambda x: x[1]), states
1160
- elif root_node.type in while_statement:
1161
- DFG = []
1162
- for i in range(2):
1163
- for child in root_node.children:
1164
- temp, states = DFG_javascript(child, index_to_code, states)
1165
- DFG += temp
1166
- dic = {}
1167
- for x in DFG:
1168
- if (x[0], x[1], x[2]) not in dic:
1169
- dic[(x[0], x[1], x[2])] = [x[3], x[4]]
1170
- else:
1171
- dic[(x[0], x[1], x[2])][0] = list(set(dic[(x[0], x[1], x[2])][0] + x[3]))
1172
- dic[(x[0], x[1], x[2])][1] = sorted(list(set(dic[(x[0], x[1], x[2])][1] + x[4])))
1173
- DFG = [(x[0], x[1], x[2], y[0], y[1]) for x, y in sorted(dic.items(), key=lambda t: t[0][1])]
1174
- return sorted(DFG, key=lambda x: x[1]), states
1175
- else:
1176
- DFG = []
1177
- for child in root_node.children:
1178
- if child.type in do_first_statement:
1179
- temp, states = DFG_javascript(child, index_to_code, states)
1180
- DFG += temp
1181
- for child in root_node.children:
1182
- if child.type not in do_first_statement:
1183
- temp, states = DFG_javascript(child, index_to_code, states)
1184
- DFG += temp
1185
-
1186
- return sorted(DFG, key=lambda x: x[1]), states
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eval/parser/__init__.py DELETED
@@ -1,8 +0,0 @@
1
- # Copyright (c) Microsoft Corporation.
2
- # Licensed under the MIT license.
3
-
4
- from .utils import (remove_comments_and_docstrings,
5
- tree_to_token_index,
6
- index_to_code_token,
7
- tree_to_variable_index)
8
- from .DFG import DFG_python,DFG_java,DFG_ruby,DFG_go,DFG_php,DFG_javascript,DFG_csharp
 
 
 
 
 
 
 
 
 
eval/parser/build.py DELETED
@@ -1,15 +0,0 @@
1
- # Copyright (c) Microsoft Corporation.
2
- # Licensed under the MIT license.
3
-
4
- from tree_sitter import Language, Parser
5
-
6
- Language.build_library(
7
- # Store the library in the `build` directory
8
- 'my-languages.so',
9
-
10
- # Include one or more languages
11
- [
12
- 'tree-sitter-python'
13
- ]
14
- )
15
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eval/parser/build.sh DELETED
@@ -1,2 +0,0 @@
1
- git clone https://github.com/tree-sitter/tree-sitter-python
2
- python build.py
 
 
 
eval/parser/utils.py DELETED
@@ -1,101 +0,0 @@
1
- # Copyright (c) Microsoft Corporation.
2
- # Licensed under the MIT license.
3
-
4
- import re
5
- from io import StringIO
6
- import tokenize
7
- def remove_comments_and_docstrings(source,lang):
8
- if lang in ['python']:
9
- """
10
- Returns 'source' minus comments and docstrings.
11
- """
12
- io_obj = StringIO(source)
13
- out = ""
14
- prev_toktype = tokenize.INDENT
15
- last_lineno = -1
16
- last_col = 0
17
- for tok in tokenize.generate_tokens(io_obj.readline):
18
- token_type = tok[0]
19
- token_string = tok[1]
20
- start_line, start_col = tok[2]
21
- end_line, end_col = tok[3]
22
- ltext = tok[4]
23
- if start_line > last_lineno:
24
- last_col = 0
25
- if start_col > last_col:
26
- out += (" " * (start_col - last_col))
27
- # Remove comments:
28
- if token_type == tokenize.COMMENT:
29
- pass
30
- # This series of conditionals removes docstrings:
31
- elif token_type == tokenize.STRING:
32
- if prev_toktype != tokenize.INDENT:
33
- # This is likely a docstring; double-check we're not inside an operator:
34
- if prev_toktype != tokenize.NEWLINE:
35
- if start_col > 0:
36
- out += token_string
37
- else:
38
- out += token_string
39
- prev_toktype = token_type
40
- last_col = end_col
41
- last_lineno = end_line
42
- temp=[]
43
- for x in out.split('\n'):
44
- if x.strip()!="":
45
- temp.append(x)
46
- return '\n'.join(temp)
47
- elif lang in ['ruby']:
48
- return source
49
- else:
50
- def replacer(match):
51
- s = match.group(0)
52
- if s.startswith('/'):
53
- return " " # note: a space and not an empty string
54
- else:
55
- return s
56
- pattern = re.compile(
57
- r'//.*?$|/\*.*?\*/|\'(?:\\.|[^\\\'])*\'|"(?:\\.|[^\\"])*"',
58
- re.DOTALL | re.MULTILINE
59
- )
60
- temp=[]
61
- for x in re.sub(pattern, replacer, source).split('\n'):
62
- if x.strip()!="":
63
- temp.append(x)
64
- return '\n'.join(temp)
65
-
66
- def tree_to_token_index(root_node):
67
- if (len(root_node.children)==0 or root_node.type in ['string_literal','string','character_literal']) and root_node.type!='comment':
68
- return [(root_node.start_point,root_node.end_point)]
69
- else:
70
- code_tokens=[]
71
- for child in root_node.children:
72
- code_tokens+=tree_to_token_index(child)
73
- return code_tokens
74
-
75
- def tree_to_variable_index(root_node,index_to_code):
76
- if (len(root_node.children)==0 or root_node.type in ['string_literal','string','character_literal']) and root_node.type!='comment':
77
- index=(root_node.start_point,root_node.end_point)
78
- _,code=index_to_code[index]
79
- if root_node.type!=code:
80
- return [(root_node.start_point,root_node.end_point)]
81
- else:
82
- return []
83
- else:
84
- code_tokens=[]
85
- for child in root_node.children:
86
- code_tokens+=tree_to_variable_index(child,index_to_code)
87
- return code_tokens
88
-
89
- def index_to_code_token(index,code):
90
- start_point=index[0]
91
- end_point=index[1]
92
- if start_point[0]==end_point[0]:
93
- s=code[start_point[0]][start_point[1]:end_point[1]]
94
- else:
95
- s=""
96
- s+=code[start_point[0]][start_point[1]:]
97
- for i in range(start_point[0]+1,end_point[0]):
98
- s+=code[i]
99
- s+=code[end_point[0]][:end_point[1]]
100
- return s
101
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eval/syntax_match.py DELETED
@@ -1,76 +0,0 @@
1
- # Copyright (c) Microsoft Corporation.
2
- # Licensed under the MIT license.
3
-
4
- from parser import DFG_python, DFG_java, DFG_ruby, DFG_go, DFG_php, DFG_javascript, DFG_csharp
5
- from parser import (remove_comments_and_docstrings,
6
- tree_to_token_index,
7
- index_to_code_token,
8
- tree_to_variable_index)
9
- from tree_sitter import Language, Parser
10
-
11
- dfg_function = {
12
- 'python': DFG_python,
13
- 'java': DFG_java,
14
- 'ruby': DFG_ruby,
15
- 'go': DFG_go,
16
- 'php': DFG_php,
17
- 'javascript': DFG_javascript,
18
- 'c_sharp': DFG_csharp,
19
- }
20
-
21
-
22
- def calc_syntax_match(references, candidate, lang):
23
- return corpus_syntax_match([references], [candidate], lang)
24
-
25
-
26
- def corpus_syntax_match(references, candidates, lang):
27
- LANGUAGE = Language('./src/eval/parser/my-languages.so', lang)
28
- parser = Parser()
29
- parser.set_language(LANGUAGE)
30
- match_count = 0
31
- total_count = 0
32
-
33
- for i in range(len(candidates)):
34
- references_sample = references[i]
35
- candidate = candidates[i]
36
- for reference in references_sample:
37
- try:
38
- candidate = remove_comments_and_docstrings(candidate, LANGUAGE)
39
- except:
40
- pass
41
- try:
42
- reference = remove_comments_and_docstrings(reference, LANGUAGE)
43
- except:
44
- pass
45
-
46
- candidate_tree = parser.parse(bytes(candidate, 'utf8')).root_node
47
-
48
- reference_tree = parser.parse(bytes(reference, 'utf8')).root_node
49
-
50
- def get_all_sub_trees(root_node):
51
- node_stack = []
52
- sub_tree_sexp_list = []
53
- depth = 1
54
- node_stack.append([root_node, depth])
55
- while len(node_stack) != 0:
56
- cur_node, cur_depth = node_stack.pop()
57
- sub_tree_sexp_list.append([cur_node.sexp(), cur_depth])
58
- for child_node in cur_node.children:
59
- if len(child_node.children) != 0:
60
- depth = cur_depth + 1
61
- node_stack.append([child_node, depth])
62
- return sub_tree_sexp_list
63
-
64
- cand_sexps = [x[0] for x in get_all_sub_trees(candidate_tree)]
65
- ref_sexps = get_all_sub_trees(reference_tree)
66
-
67
- # print(cand_sexps)
68
- # print(ref_sexps)
69
-
70
- for sub_tree, depth in ref_sexps:
71
- if sub_tree in cand_sexps:
72
- match_count += 1
73
- total_count += len(ref_sexps)
74
-
75
- score = match_count / total_count
76
- return score
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eval/utils.py DELETED
@@ -1,106 +0,0 @@
1
- # Natural Language Toolkit: Utility functions
2
- #
3
- # Copyright (C) 2001-2020 NLTK Project
4
- # Author: Steven Bird <stevenbird1@gmail.com>
5
- # URL: <http://nltk.org/>
6
- # For license information, see LICENSE.TXT
7
-
8
- from itertools import chain
9
-
10
- def pad_sequence(
11
- sequence,
12
- n,
13
- pad_left=False,
14
- pad_right=False,
15
- left_pad_symbol=None,
16
- right_pad_symbol=None,
17
- ):
18
- """
19
- Returns a padded sequence of items before ngram extraction.
20
- >>> list(pad_sequence([1,2,3,4,5], 2, pad_left=True, pad_right=True, left_pad_symbol='<s>', right_pad_symbol='</s>'))
21
- ['<s>', 1, 2, 3, 4, 5, '</s>']
22
- >>> list(pad_sequence([1,2,3,4,5], 2, pad_left=True, left_pad_symbol='<s>'))
23
- ['<s>', 1, 2, 3, 4, 5]
24
- >>> list(pad_sequence([1,2,3,4,5], 2, pad_right=True, right_pad_symbol='</s>'))
25
- [1, 2, 3, 4, 5, '</s>']
26
- :param sequence: the source data to be padded
27
- :type sequence: sequence or iter
28
- :param n: the degree of the ngrams
29
- :type n: int
30
- :param pad_left: whether the ngrams should be left-padded
31
- :type pad_left: bool
32
- :param pad_right: whether the ngrams should be right-padded
33
- :type pad_right: bool
34
- :param left_pad_symbol: the symbol to use for left padding (default is None)
35
- :type left_pad_symbol: any
36
- :param right_pad_symbol: the symbol to use for right padding (default is None)
37
- :type right_pad_symbol: any
38
- :rtype: sequence or iter
39
- """
40
- sequence = iter(sequence)
41
- if pad_left:
42
- sequence = chain((left_pad_symbol,) * (n - 1), sequence)
43
- if pad_right:
44
- sequence = chain(sequence, (right_pad_symbol,) * (n - 1))
45
- return sequence
46
-
47
-
48
- # add a flag to pad the sequence so we get peripheral ngrams?
49
-
50
-
51
- def ngrams(
52
- sequence,
53
- n,
54
- pad_left=False,
55
- pad_right=False,
56
- left_pad_symbol=None,
57
- right_pad_symbol=None,
58
- ):
59
- """
60
- Return the ngrams generated from a sequence of items, as an iterator.
61
- For example:
62
- >>> from nltk.util import ngrams
63
- >>> list(ngrams([1,2,3,4,5], 3))
64
- [(1, 2, 3), (2, 3, 4), (3, 4, 5)]
65
- Wrap with list for a list version of this function. Set pad_left
66
- or pad_right to true in order to get additional ngrams:
67
- >>> list(ngrams([1,2,3,4,5], 2, pad_right=True))
68
- [(1, 2), (2, 3), (3, 4), (4, 5), (5, None)]
69
- >>> list(ngrams([1,2,3,4,5], 2, pad_right=True, right_pad_symbol='</s>'))
70
- [(1, 2), (2, 3), (3, 4), (4, 5), (5, '</s>')]
71
- >>> list(ngrams([1,2,3,4,5], 2, pad_left=True, left_pad_symbol='<s>'))
72
- [('<s>', 1), (1, 2), (2, 3), (3, 4), (4, 5)]
73
- >>> list(ngrams([1,2,3,4,5], 2, pad_left=True, pad_right=True, left_pad_symbol='<s>', right_pad_symbol='</s>'))
74
- [('<s>', 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, '</s>')]
75
- :param sequence: the source data to be converted into ngrams
76
- :type sequence: sequence or iter
77
- :param n: the degree of the ngrams
78
- :type n: int
79
- :param pad_left: whether the ngrams should be left-padded
80
- :type pad_left: bool
81
- :param pad_right: whether the ngrams should be right-padded
82
- :type pad_right: bool
83
- :param left_pad_symbol: the symbol to use for left padding (default is None)
84
- :type left_pad_symbol: any
85
- :param right_pad_symbol: the symbol to use for right padding (default is None)
86
- :type right_pad_symbol: any
87
- :rtype: sequence or iter
88
- """
89
- sequence = pad_sequence(
90
- sequence, n, pad_left, pad_right, left_pad_symbol, right_pad_symbol
91
- )
92
-
93
- history = []
94
- while n > 1:
95
- # PEP 479, prevent RuntimeError from being raised when StopIteration bubbles out of generator
96
- try:
97
- next_item = next(sequence)
98
- except StopIteration:
99
- # no more data, terminate the generator
100
- return
101
- history.append(next_item)
102
- n -= 1
103
- for item in sequence:
104
- history.append(item)
105
- yield tuple(history)
106
- del history[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eval/weighted_ngram_match.py DELETED
@@ -1,558 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
- # Copyright (c) Microsoft Corporation.
3
- # Licensed under the MIT license.
4
-
5
- # Natural Language Toolkit: BLEU Score
6
- #
7
- # Copyright (C) 2001-2020 NLTK Project
8
- # Authors: Chin Yee Lee, Hengfeng Li, Ruxin Hou, Calvin Tanujaya Lim
9
- # Contributors: Björn Mattsson, Dmitrijs Milajevs, Liling Tan
10
- # URL: <http://nltk.org/>
11
- # For license information, see LICENSE.TXT
12
-
13
- """BLEU score implementation."""
14
-
15
- import math
16
- import sys
17
- from fractions import Fraction
18
- import warnings
19
- from collections import Counter
20
-
21
- from utils import ngrams
22
- import pdb
23
-
24
-
25
- def sentence_bleu(
26
- references,
27
- hypothesis,
28
- weights=(0.25, 0.25, 0.25, 0.25),
29
- smoothing_function=None,
30
- auto_reweigh=False,
31
- ):
32
- """
33
- Calculate BLEU score (Bilingual Evaluation Understudy) from
34
- Papineni, Kishore, Salim Roukos, Todd Ward, and Wei-Jing Zhu. 2002.
35
- "BLEU: a method for automatic evaluation of machine translation."
36
- In Proceedings of ACL. http://www.aclweb.org/anthology/P02-1040.pdf
37
- >>> hypothesis1 = ['It', 'is', 'a', 'guide', 'to', 'action', 'which',
38
- ... 'ensures', 'that', 'the', 'military', 'always',
39
- ... 'obeys', 'the', 'commands', 'of', 'the', 'party']
40
- >>> hypothesis2 = ['It', 'is', 'to', 'insure', 'the', 'troops',
41
- ... 'forever', 'hearing', 'the', 'activity', 'guidebook',
42
- ... 'that', 'party', 'direct']
43
- >>> reference1 = ['It', 'is', 'a', 'guide', 'to', 'action', 'that',
44
- ... 'ensures', 'that', 'the', 'military', 'will', 'forever',
45
- ... 'heed', 'Party', 'commands']
46
- >>> reference2 = ['It', 'is', 'the', 'guiding', 'principle', 'which',
47
- ... 'guarantees', 'the', 'military', 'forces', 'always',
48
- ... 'being', 'under', 'the', 'command', 'of', 'the',
49
- ... 'Party']
50
- >>> reference3 = ['It', 'is', 'the', 'practical', 'guide', 'for', 'the',
51
- ... 'army', 'always', 'to', 'heed', 'the', 'directions',
52
- ... 'of', 'the', 'party']
53
- >>> sentence_bleu([reference1, reference2, reference3], hypothesis1) # doctest: +ELLIPSIS
54
- 0.5045...
55
- If there is no ngrams overlap for any order of n-grams, BLEU returns the
56
- value 0. This is because the precision for the order of n-grams without
57
- overlap is 0, and the geometric mean in the final BLEU score computation
58
- multiplies the 0 with the precision of other n-grams. This results in 0
59
- (independently of the precision of the othe n-gram orders). The following
60
- example has zero 3-gram and 4-gram overlaps:
61
- >>> round(sentence_bleu([reference1, reference2, reference3], hypothesis2),4) # doctest: +ELLIPSIS
62
- 0.0
63
- To avoid this harsh behaviour when no ngram overlaps are found a smoothing
64
- function can be used.
65
- >>> chencherry = SmoothingFunction()
66
- >>> sentence_bleu([reference1, reference2, reference3], hypothesis2,
67
- ... smoothing_function=chencherry.method1) # doctest: +ELLIPSIS
68
- 0.0370...
69
- The default BLEU calculates a score for up to 4-grams using uniform
70
- weights (this is called BLEU-4). To evaluate your translations with
71
- higher/lower order ngrams, use customized weights. E.g. when accounting
72
- for up to 5-grams with uniform weights (this is called BLEU-5) use:
73
- >>> weights = (1./5., 1./5., 1./5., 1./5., 1./5.)
74
- >>> sentence_bleu([reference1, reference2, reference3], hypothesis1, weights) # doctest: +ELLIPSIS
75
- 0.3920...
76
- :param references: reference sentences
77
- :type references: list(list(str))
78
- :param hypothesis: a hypothesis sentence
79
- :type hypothesis: list(str)
80
- :param weights: weights for unigrams, bigrams, trigrams and so on
81
- :type weights: list(float)
82
- :param smoothing_function:
83
- :type smoothing_function: SmoothingFunction
84
- :param auto_reweigh: Option to re-normalize the weights uniformly.
85
- :type auto_reweigh: bool
86
- :return: The sentence-level BLEU score.
87
- :rtype: float
88
- """
89
- return corpus_bleu(
90
- [references], [hypothesis], weights, smoothing_function, auto_reweigh
91
- )
92
-
93
-
94
- def corpus_bleu(
95
- list_of_references,
96
- hypotheses,
97
- weights=(0.25, 0.25, 0.25, 0.25),
98
- smoothing_function=None,
99
- auto_reweigh=False,
100
- ):
101
- """
102
- Calculate a single corpus-level BLEU score (aka. system-level BLEU) for all
103
- the hypotheses and their respective references.
104
- Instead of averaging the sentence level BLEU scores (i.e. marco-average
105
- precision), the original BLEU metric (Papineni et al. 2002) accounts for
106
- the micro-average precision (i.e. summing the numerators and denominators
107
- for each hypothesis-reference(s) pairs before the division).
108
- >>> hyp1 = ['It', 'is', 'a', 'guide', 'to', 'action', 'which',
109
- ... 'ensures', 'that', 'the', 'military', 'always',
110
- ... 'obeys', 'the', 'commands', 'of', 'the', 'party']
111
- >>> ref1a = ['It', 'is', 'a', 'guide', 'to', 'action', 'that',
112
- ... 'ensures', 'that', 'the', 'military', 'will', 'forever',
113
- ... 'heed', 'Party', 'commands']
114
- >>> ref1b = ['It', 'is', 'the', 'guiding', 'principle', 'which',
115
- ... 'guarantees', 'the', 'military', 'forces', 'always',
116
- ... 'being', 'under', 'the', 'command', 'of', 'the', 'Party']
117
- >>> ref1c = ['It', 'is', 'the', 'practical', 'guide', 'for', 'the',
118
- ... 'army', 'always', 'to', 'heed', 'the', 'directions',
119
- ... 'of', 'the', 'party']
120
- >>> hyp2 = ['he', 'read', 'the', 'book', 'because', 'he', 'was',
121
- ... 'interested', 'in', 'world', 'history']
122
- >>> ref2a = ['he', 'was', 'interested', 'in', 'world', 'history',
123
- ... 'because', 'he', 'read', 'the', 'book']
124
- >>> list_of_references = [[ref1a, ref1b, ref1c], [ref2a]]
125
- >>> hypotheses = [hyp1, hyp2]
126
- >>> corpus_bleu(list_of_references, hypotheses) # doctest: +ELLIPSIS
127
- 0.5920...
128
- The example below show that corpus_bleu() is different from averaging
129
- sentence_bleu() for hypotheses
130
- >>> score1 = sentence_bleu([ref1a, ref1b, ref1c], hyp1)
131
- >>> score2 = sentence_bleu([ref2a], hyp2)
132
- >>> (score1 + score2) / 2 # doctest: +ELLIPSIS
133
- 0.6223...
134
- :param list_of_references: a corpus of lists of reference sentences, w.r.t. hypotheses
135
- :type list_of_references: list(list(list(str)))
136
- :param hypotheses: a list of hypothesis sentences
137
- :type hypotheses: list(list(str))
138
- :param weights: weights for unigrams, bigrams, trigrams and so on
139
- :type weights: list(float)
140
- :param smoothing_function:
141
- :type smoothing_function: SmoothingFunction
142
- :param auto_reweigh: Option to re-normalize the weights uniformly.
143
- :type auto_reweigh: bool
144
- :return: The corpus-level BLEU score.
145
- :rtype: float
146
- """
147
- # Before proceeding to compute BLEU, perform sanity checks.
148
-
149
- p_numerators = Counter() # Key = ngram order, and value = no. of ngram matches.
150
- p_denominators = Counter() # Key = ngram order, and value = no. of ngram in ref.
151
- hyp_lengths, ref_lengths = 0, 0
152
-
153
- assert len(list_of_references) == len(hypotheses), (
154
- "The number of hypotheses and their reference(s) should be the " "same "
155
- )
156
-
157
- # Iterate through each hypothesis and their corresponding references.
158
- for references, hypothesis in zip(list_of_references, hypotheses):
159
- # For each order of ngram, calculate the numerator and
160
- # denominator for the corpus-level modified precision.
161
- for i, _ in enumerate(weights, start=1):
162
- p_i_numeraotr, p_i_denominator = modified_recall(references, hypothesis, i)
163
- p_numerators[i] += p_i_numeraotr
164
- p_denominators[i] += p_i_denominator
165
-
166
- # Calculate the hypothesis length and the closest reference length.
167
- # Adds them to the corpus-level hypothesis and reference counts.
168
- hyp_len = len(hypothesis)
169
- hyp_lengths += hyp_len
170
- ref_lengths += closest_ref_length(references, hyp_len)
171
-
172
- # Calculate corpus-level brevity penalty.
173
- bp = brevity_penalty(ref_lengths, hyp_lengths)
174
-
175
- # Uniformly re-weighting based on maximum hypothesis lengths if largest
176
- # order of n-grams < 4 and weights is set at default.
177
- if auto_reweigh:
178
- if hyp_lengths < 4 and weights == (0.25, 0.25, 0.25, 0.25):
179
- weights = (1 / hyp_lengths,) * hyp_lengths
180
-
181
- # Collects the various recall values for the different ngram orders.
182
- p_n = [
183
- (p_numerators[i], p_denominators[i])
184
- for i, _ in enumerate(weights, start=1)
185
- ]
186
-
187
- # Returns 0 if there's no matching n-grams
188
- # We only need to check for p_numerators[1] == 0, since if there's
189
- # no unigrams, there won't be any higher order ngrams.
190
- if p_numerators[1] == 0:
191
- return 0
192
-
193
- # If there's no smoothing, set use method0 from SmoothinFunction class.
194
- if not smoothing_function:
195
- smoothing_function = SmoothingFunction().method1
196
- # Smoothen the modified precision.
197
- # Note: smoothing_function() may convert values into floats;
198
- # it tries to retain the Fraction object as much as the
199
- # smoothing method allows.
200
- p_n = smoothing_function(
201
- p_n, references=references, hypothesis=hypothesis, hyp_len=hyp_lengths
202
- )
203
- # pdb.set_trace()
204
- s = (w_i * math.log(p_i[0]/p_i[1]) for w_i, p_i in zip(weights, p_n))
205
- s = bp * math.exp(math.fsum(s))
206
- return s
207
-
208
-
209
- def modified_recall(references, hypothesis, n):
210
- """
211
- Calculate modified ngram recall.
212
- :param references: A list of reference translations.
213
- :type references: list(list(str))
214
- :param hypothesis: A hypothesis translation.
215
- :type hypothesis: list(str)
216
- :param n: The ngram order.
217
- :type n: int
218
- :return: BLEU's modified precision for the nth order ngram.
219
- :rtype: Fraction
220
- """
221
- # Extracts all ngrams in hypothesis
222
- # Set an empty Counter if hypothesis is empty.
223
- # pdb.set_trace()
224
- numerator = 0
225
- denominator = 0
226
-
227
- counts = Counter(ngrams(hypothesis, n)) if len(hypothesis) >= n else Counter()
228
- # Extract a union of references' counts.
229
- # max_counts = reduce(or_, [Counter(ngrams(ref, n)) for ref in references])
230
- max_counts = {}
231
- for reference_and_weights in references:
232
- reference = reference_and_weights[0]
233
- weights = reference_and_weights[1]
234
- reference_counts = (
235
- Counter(ngrams(reference, n)) if len(reference) >= n else Counter()
236
- )
237
- # for ngram in reference_counts:
238
- # max_counts[ngram] = max(max_counts.get(ngram, 0), counts[ngram])
239
- clipped_counts = {
240
- ngram: min(count, counts[ngram]) for ngram, count in reference_counts.items()
241
- }
242
- # reweight
243
- if n == 1 and len(weights) == len(reference_counts):
244
- def weighted_sum(weights, counts):
245
- sum_counts = 0
246
- for ngram, count in counts.items():
247
- sum_counts += count * (weights[ngram[0]] if ngram[0] in weights else 1)
248
- return sum_counts
249
-
250
- numerator += weighted_sum(weights, clipped_counts)
251
- denominator += max(1, weighted_sum(weights, reference_counts))
252
-
253
- else:
254
- numerator += sum(clipped_counts.values())
255
- denominator += max(1, sum(reference_counts.values()))
256
-
257
- # # Assigns the intersection between hypothesis and references' counts.
258
- # clipped_counts = {
259
- # ngram: min(count, max_counts[ngram]) for ngram, count in counts.items()
260
- # }
261
-
262
- # numerator += sum(clipped_counts.values())
263
- # # Ensures that denominator is minimum 1 to avoid ZeroDivisionError.
264
- # # Usually this happens when the ngram order is > len(reference).
265
- # denominator += max(1, sum(counts.values()))
266
-
267
- #return Fraction(numerator, denominator, _normalize=False)
268
- return numerator, denominator
269
-
270
-
271
- def closest_ref_length(references, hyp_len):
272
- """
273
- This function finds the reference that is the closest length to the
274
- hypothesis. The closest reference length is referred to as *r* variable
275
- from the brevity penalty formula in Papineni et. al. (2002)
276
- :param references: A list of reference translations.
277
- :type references: list(list(str))
278
- :param hyp_len: The length of the hypothesis.
279
- :type hyp_len: int
280
- :return: The length of the reference that's closest to the hypothesis.
281
- :rtype: int
282
- """
283
- ref_lens = (len(reference) for reference in references)
284
- closest_ref_len = min(
285
- ref_lens, key=lambda ref_len: (abs(ref_len - hyp_len), ref_len)
286
- )
287
- return closest_ref_len
288
-
289
-
290
- def brevity_penalty(closest_ref_len, hyp_len):
291
- """
292
- Calculate brevity penalty.
293
- As the modified n-gram precision still has the problem from the short
294
- length sentence, brevity penalty is used to modify the overall BLEU
295
- score according to length.
296
- An example from the paper. There are three references with length 12, 15
297
- and 17. And a concise hypothesis of the length 12. The brevity penalty is 1.
298
- >>> reference1 = list('aaaaaaaaaaaa') # i.e. ['a'] * 12
299
- >>> reference2 = list('aaaaaaaaaaaaaaa') # i.e. ['a'] * 15
300
- >>> reference3 = list('aaaaaaaaaaaaaaaaa') # i.e. ['a'] * 17
301
- >>> hypothesis = list('aaaaaaaaaaaa') # i.e. ['a'] * 12
302
- >>> references = [reference1, reference2, reference3]
303
- >>> hyp_len = len(hypothesis)
304
- >>> closest_ref_len = closest_ref_length(references, hyp_len)
305
- >>> brevity_penalty(closest_ref_len, hyp_len)
306
- 1.0
307
- In case a hypothesis translation is shorter than the references, penalty is
308
- applied.
309
- >>> references = [['a'] * 28, ['a'] * 28]
310
- >>> hypothesis = ['a'] * 12
311
- >>> hyp_len = len(hypothesis)
312
- >>> closest_ref_len = closest_ref_length(references, hyp_len)
313
- >>> brevity_penalty(closest_ref_len, hyp_len)
314
- 0.2635971381157267
315
- The length of the closest reference is used to compute the penalty. If the
316
- length of a hypothesis is 12, and the reference lengths are 13 and 2, the
317
- penalty is applied because the hypothesis length (12) is less then the
318
- closest reference length (13).
319
- >>> references = [['a'] * 13, ['a'] * 2]
320
- >>> hypothesis = ['a'] * 12
321
- >>> hyp_len = len(hypothesis)
322
- >>> closest_ref_len = closest_ref_length(references, hyp_len)
323
- >>> brevity_penalty(closest_ref_len, hyp_len) # doctest: +ELLIPSIS
324
- 0.9200...
325
- The brevity penalty doesn't depend on reference order. More importantly,
326
- when two reference sentences are at the same distance, the shortest
327
- reference sentence length is used.
328
- >>> references = [['a'] * 13, ['a'] * 11]
329
- >>> hypothesis = ['a'] * 12
330
- >>> hyp_len = len(hypothesis)
331
- >>> closest_ref_len = closest_ref_length(references, hyp_len)
332
- >>> bp1 = brevity_penalty(closest_ref_len, hyp_len)
333
- >>> hyp_len = len(hypothesis)
334
- >>> closest_ref_len = closest_ref_length(reversed(references), hyp_len)
335
- >>> bp2 = brevity_penalty(closest_ref_len, hyp_len)
336
- >>> bp1 == bp2 == 1
337
- True
338
- A test example from mteval-v13a.pl (starting from the line 705):
339
- >>> references = [['a'] * 11, ['a'] * 8]
340
- >>> hypothesis = ['a'] * 7
341
- >>> hyp_len = len(hypothesis)
342
- >>> closest_ref_len = closest_ref_length(references, hyp_len)
343
- >>> brevity_penalty(closest_ref_len, hyp_len) # doctest: +ELLIPSIS
344
- 0.8668...
345
- >>> references = [['a'] * 11, ['a'] * 8, ['a'] * 6, ['a'] * 7]
346
- >>> hypothesis = ['a'] * 7
347
- >>> hyp_len = len(hypothesis)
348
- >>> closest_ref_len = closest_ref_length(references, hyp_len)
349
- >>> brevity_penalty(closest_ref_len, hyp_len)
350
- 1.0
351
- :param hyp_len: The length of the hypothesis for a single sentence OR the
352
- sum of all the hypotheses' lengths for a corpus
353
- :type hyp_len: int
354
- :param closest_ref_len: The length of the closest reference for a single
355
- hypothesis OR the sum of all the closest references for every hypotheses.
356
- :type closest_ref_len: int
357
- :return: BLEU's brevity penalty.
358
- :rtype: float
359
- """
360
- if hyp_len > closest_ref_len:
361
- return 1
362
- # If hypothesis is empty, brevity penalty = 0 should result in BLEU = 0.0
363
- elif hyp_len == 0:
364
- return 0
365
- else:
366
- return math.exp(1 - closest_ref_len / hyp_len)
367
-
368
-
369
- class SmoothingFunction:
370
- """
371
- This is an implementation of the smoothing techniques
372
- for segment-level BLEU scores that was presented in
373
- Boxing Chen and Collin Cherry (2014) A Systematic Comparison of
374
- Smoothing Techniques for Sentence-Level BLEU. In WMT14.
375
- http://acl2014.org/acl2014/W14-33/pdf/W14-3346.pdf
376
- """
377
-
378
- def __init__(self, epsilon=0.1, alpha=5, k=5):
379
- """
380
- This will initialize the parameters required for the various smoothing
381
- techniques, the default values are set to the numbers used in the
382
- experiments from Chen and Cherry (2014).
383
- >>> hypothesis1 = ['It', 'is', 'a', 'guide', 'to', 'action', 'which', 'ensures',
384
- ... 'that', 'the', 'military', 'always', 'obeys', 'the',
385
- ... 'commands', 'of', 'the', 'party']
386
- >>> reference1 = ['It', 'is', 'a', 'guide', 'to', 'action', 'that', 'ensures',
387
- ... 'that', 'the', 'military', 'will', 'forever', 'heed',
388
- ... 'Party', 'commands']
389
- >>> chencherry = SmoothingFunction()
390
- >>> print(sentence_bleu([reference1], hypothesis1)) # doctest: +ELLIPSIS
391
- 0.4118...
392
- >>> print(sentence_bleu([reference1], hypothesis1, smoothing_function=chencherry.method0)) # doctest: +ELLIPSIS
393
- 0.4118...
394
- >>> print(sentence_bleu([reference1], hypothesis1, smoothing_function=chencherry.method1)) # doctest: +ELLIPSIS
395
- 0.4118...
396
- >>> print(sentence_bleu([reference1], hypothesis1, smoothing_function=chencherry.method2)) # doctest: +ELLIPSIS
397
- 0.4489...
398
- >>> print(sentence_bleu([reference1], hypothesis1, smoothing_function=chencherry.method3)) # doctest: +ELLIPSIS
399
- 0.4118...
400
- >>> print(sentence_bleu([reference1], hypothesis1, smoothing_function=chencherry.method4)) # doctest: +ELLIPSIS
401
- 0.4118...
402
- >>> print(sentence_bleu([reference1], hypothesis1, smoothing_function=chencherry.method5)) # doctest: +ELLIPSIS
403
- 0.4905...
404
- >>> print(sentence_bleu([reference1], hypothesis1, smoothing_function=chencherry.method6)) # doctest: +ELLIPSIS
405
- 0.4135...
406
- >>> print(sentence_bleu([reference1], hypothesis1, smoothing_function=chencherry.method7)) # doctest: +ELLIPSIS
407
- 0.4905...
408
- :param epsilon: the epsilon value use in method 1
409
- :type epsilon: float
410
- :param alpha: the alpha value use in method 6
411
- :type alpha: int
412
- :param k: the k value use in method 4
413
- :type k: int
414
- """
415
- self.epsilon = epsilon
416
- self.alpha = alpha
417
- self.k = k
418
-
419
- def method0(self, p_n, *args, **kwargs):
420
- """
421
- No smoothing.
422
- """
423
- p_n_new = []
424
- for i, p_i in enumerate(p_n):
425
- if p_i[0] != 0:
426
- p_n_new.append(p_i)
427
- else:
428
- _msg = str(
429
- "\nThe hypothesis contains 0 counts of {}-gram overlaps.\n"
430
- "Therefore the BLEU score evaluates to 0, independently of\n"
431
- "how many N-gram overlaps of lower order it contains.\n"
432
- "Consider using lower n-gram order or use "
433
- "SmoothingFunction()"
434
- ).format(i + 1)
435
- warnings.warn(_msg)
436
- # When numerator==0 where denonminator==0 or !=0, the result
437
- # for the precision score should be equal to 0 or undefined.
438
- # Due to BLEU geometric mean computation in logarithm space,
439
- # we we need to take the return sys.float_info.min such that
440
- # math.log(sys.float_info.min) returns a 0 precision score.
441
- p_n_new.append(sys.float_info.min)
442
- return p_n_new
443
-
444
- def method1(self, p_n, *args, **kwargs):
445
- """
446
- Smoothing method 1: Add *epsilon* counts to precision with 0 counts.
447
- """
448
- return [
449
- ((p_i[0] + self.epsilon), p_i[1])
450
- if p_i[0] == 0
451
- else p_i
452
- for p_i in p_n
453
- ]
454
-
455
- def method2(self, p_n, *args, **kwargs):
456
- """
457
- Smoothing method 2: Add 1 to both numerator and denominator from
458
- Chin-Yew Lin and Franz Josef Och (2004) Automatic evaluation of
459
- machine translation quality using longest common subsequence and
460
- skip-bigram statistics. In ACL04.
461
- """
462
- return [
463
- (p_i[0] + 1, p_i[1] + 1)
464
- for p_i in p_n
465
- ]
466
-
467
- def method3(self, p_n, *args, **kwargs):
468
- """
469
- Smoothing method 3: NIST geometric sequence smoothing
470
- The smoothing is computed by taking 1 / ( 2^k ), instead of 0, for each
471
- precision score whose matching n-gram count is null.
472
- k is 1 for the first 'n' value for which the n-gram match count is null/
473
- For example, if the text contains:
474
- - one 2-gram match
475
- - and (consequently) two 1-gram matches
476
- the n-gram count for each individual precision score would be:
477
- - n=1 => prec_count = 2 (two unigrams)
478
- - n=2 => prec_count = 1 (one bigram)
479
- - n=3 => prec_count = 1/2 (no trigram, taking 'smoothed' value of 1 / ( 2^k ), with k=1)
480
- - n=4 => prec_count = 1/4 (no fourgram, taking 'smoothed' value of 1 / ( 2^k ), with k=2)
481
- """
482
- incvnt = 1 # From the mteval-v13a.pl, it's referred to as k.
483
- for i, p_i in enumerate(p_n):
484
- if p_i.numerator == 0:
485
- p_n[i] = 1 / (2 ** incvnt * p_i.denominator)
486
- incvnt += 1
487
- return p_n
488
-
489
- def method4(self, p_n, references, hypothesis, hyp_len=None, *args, **kwargs):
490
- """
491
- Smoothing method 4:
492
- Shorter translations may have inflated precision values due to having
493
- smaller denominators; therefore, we give them proportionally
494
- smaller smoothed counts. Instead of scaling to 1/(2^k), Chen and Cherry
495
- suggests dividing by 1/ln(len(T)), where T is the length of the translation.
496
- """
497
- hyp_len = hyp_len if hyp_len else len(hypothesis)
498
- for i, p_i in enumerate(p_n):
499
- if p_i.numerator == 0 and hyp_len != 0:
500
- incvnt = i + 1 * self.k / math.log(
501
- hyp_len
502
- ) # Note that this K is different from the K from NIST.
503
- p_n[i] = incvnt / p_i.denominator
504
- return p_n
505
-
506
- def method5(self, p_n, references, hypothesis, hyp_len=None, *args, **kwargs):
507
- """
508
- Smoothing method 5:
509
- The matched counts for similar values of n should be similar. To a
510
- calculate the n-gram matched count, it averages the n−1, n and n+1 gram
511
- matched counts.
512
- """
513
- hyp_len = hyp_len if hyp_len else len(hypothesis)
514
- m = {}
515
- # Requires an precision value for an addition ngram order.
516
- p_n_plus1 = p_n + [modified_precision(references, hypothesis, 5)]
517
- m[-1] = p_n[0] + 1
518
- for i, p_i in enumerate(p_n):
519
- p_n[i] = (m[i - 1] + p_i + p_n_plus1[i + 1]) / 3
520
- m[i] = p_n[i]
521
- return p_n
522
-
523
- def method6(self, p_n, references, hypothesis, hyp_len=None, *args, **kwargs):
524
- """
525
- Smoothing method 6:
526
- Interpolates the maximum likelihood estimate of the precision *p_n* with
527
- a prior estimate *pi0*. The prior is estimated by assuming that the ratio
528
- between pn and pn−1 will be the same as that between pn−1 and pn−2; from
529
- Gao and He (2013) Training MRF-Based Phrase Translation Models using
530
- Gradient Ascent. In NAACL.
531
- """
532
- hyp_len = hyp_len if hyp_len else len(hypothesis)
533
- # This smoothing only works when p_1 and p_2 is non-zero.
534
- # Raise an error with an appropriate message when the input is too short
535
- # to use this smoothing technique.
536
- assert p_n[2], "This smoothing method requires non-zero precision for bigrams."
537
- for i, p_i in enumerate(p_n):
538
- if i in [0, 1]: # Skips the first 2 orders of ngrams.
539
- continue
540
- else:
541
- pi0 = 0 if p_n[i - 2] == 0 else p_n[i - 1] ** 2 / p_n[i - 2]
542
- # No. of ngrams in translation that matches the reference.
543
- m = p_i.numerator
544
- # No. of ngrams in translation.
545
- l = sum(1 for _ in ngrams(hypothesis, i + 1))
546
- # Calculates the interpolated precision.
547
- p_n[i] = (m + self.alpha * pi0) / (l + self.alpha)
548
- return p_n
549
-
550
- def method7(self, p_n, references, hypothesis, hyp_len=None, *args, **kwargs):
551
- """
552
- Smoothing method 7:
553
- Interpolates methods 4 and 5.
554
- """
555
- hyp_len = hyp_len if hyp_len else len(hypothesis)
556
- p_n = self.method4(p_n, references, hypothesis, hyp_len)
557
- p_n = self.method5(p_n, references, hypothesis, hyp_len)
558
- return p_n
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
codebleu.py → metric-codebleu.py RENAMED
@@ -11,55 +11,56 @@
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
- """TODO: Add a description here."""
15
 
16
- import evaluate
17
  import datasets
18
- import eval
19
- import os
20
 
21
 
22
- # TODO: Add BibTeX citation
23
  _CITATION = """\
24
- @misc{2009.10297,
25
- Author = {Shuo Ren and Daya Guo and Shuai Lu and Long Zhou and Shujie Liu and Duyu Tang and Neel Sundaresan and Ming Zhou and Ambrosio Blanco and Shuai Ma},
26
- Title = {CodeBLEU: a Method for Automatic Evaluation of Code Synthesis},
27
- Year = {2020},
28
- Eprint = {arXiv:2009.10297},
 
 
29
  }
30
  """
31
 
32
- # TODO: Add description of the module here
33
  _DESCRIPTION = """\
34
- This new module is designed to calculate the CodeBLEU score for code generation tasks.
35
  """
36
 
37
 
38
- # TODO: Add description of the arguments of the module here
39
  _KWARGS_DESCRIPTION = """
40
- Calculates how good are predictions given some references, using certain scores
41
  Args:
42
- predictions: list of predictions to score.
43
- references: list of reference for each prediction.
 
 
 
 
44
  Returns:
45
- ngram_match_score
46
- weighted_ngram_match_score
47
- syntax_match_score
48
- dataflow_match_score
49
- code_bleu_score
50
  Examples:
51
- Examples should be written in doctest format, and should illustrate how
52
- to use the function.
53
-
54
- >>> my_new_module = evaluate.load("my_new_module")
55
- >>> results = my_new_module.compute(references=["def add(a, b): return a + b"], predictions=["def add(a, b): return a + b"])
56
  >>> print(results)
57
- {'ngram_match_score': 1.0, 'weighted_ngram_match_score': 1.0, 'syntax_match_score': 1.0, 'dataflow_match_score': 1.0, 'code_bleu_score': 1.0}
58
  """
59
 
60
 
61
  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
62
  class codebleu(evaluate.Metric):
 
 
63
  def _info(self):
64
  return evaluate.MetricInfo(
65
  # This is the description that will appear on the modules page.
@@ -68,15 +69,38 @@ class codebleu(evaluate.Metric):
68
  citation=_CITATION,
69
  inputs_description=_KWARGS_DESCRIPTION,
70
  # This defines the format of each prediction and reference
71
- features=datasets.Features({
72
- 'predictions': datasets.Value('string'),
73
- 'references': datasets.Value('string'),
74
- })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  )
76
 
77
- def _download_and_prepare(self, dl_manager):
78
- pass
79
 
80
- def _compute(self, predictions, references):
 
 
 
 
 
81
  """Returns the scores"""
82
- return eval.code_bleu.calc(predictions, references)
 
 
 
 
 
 
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
 
14
 
15
+ from codebleu import calc_codebleu
16
  import datasets
17
+ import evaluate
 
18
 
19
 
 
20
# BibTeX entry for the CodeBLEU paper (Ren et al., 2020, arXiv:2009.10297);
# surfaced by `evaluate` on the metric's hub page via MetricInfo.citation.
_CITATION = """\
@misc{ren2020codebleu,
      title={CodeBLEU: a Method for Automatic Evaluation of Code Synthesis},
      author={Shuo Ren and Daya Guo and Shuai Lu and Long Zhou and Shujie Liu and Duyu Tang and Neel Sundaresan and Ming Zhou and Ambrosio Blanco and Shuai Ma},
      year={2020},
      eprint={2009.10297},
      archivePrefix={arXiv},
      primaryClass={cs.SE}
}
"""
30
 
 
31
# One-line summary shown on the evaluate module page (MetricInfo.description).
# "Unofficial" refers to the `codebleu` PyPI package this module wraps, as
# opposed to the original CodeXGLUE evaluator scripts.
_DESCRIPTION = """\
Unofficial `CodeBLEU` implementation that supports Linux and MacOS.
"""
34
 
35
 
 
36
  _KWARGS_DESCRIPTION = """
37
+ Calculate a weighted combination of `n-gram match (BLEU)`, `weighted n-gram match (BLEU-weighted)`, `AST match` and `data-flow match` scores.
38
  Args:
39
+ predictions: list of predictions to score. Each predictions
40
+ should be a string with tokens separated by spaces.
41
+ references: list of reference for each prediction. Each
42
+ reference should be a string with tokens separated by spaces.
43
+ language: programming language in ['java','js','c_sharp','php','c','python','cpp'].
44
+ weights: tuple of 4 floats to use as weights for scores. Defaults to (0.25, 0.25, 0.25, 0.25).
45
  Returns:
46
+ codebleu: resulting `CodeBLEU` score,
47
+ ngram_match_score: resulting `n-gram match (BLEU)` score,
48
+ weighted_ngram_match_score: resulting `weighted n-gram match (BLEU-weighted)` score,
49
+ syntax_match_score: resulting `AST match` score,
50
+ dataflow_match_score: resulting `data-flow match` score,
51
  Examples:
52
+ >>> metric = evaluate.load("k4black/codebleu")
53
+ >>> ref = "def sum ( first , second ) :\n return second + first"
54
+ >>> pred = "def add ( a , b ) :\n return a + b"
55
+ >>> results = metric.compute(references=[ref], predictions=[pred], language="python")
 
56
  >>> print(results)
 
57
  """
58
 
59
 
60
  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
61
  class codebleu(evaluate.Metric):
62
+ """CodeBLEU metric from CodexGLUE"""
63
+
64
  def _info(self):
65
  return evaluate.MetricInfo(
66
  # This is the description that will appear on the modules page.
 
69
  citation=_CITATION,
70
  inputs_description=_KWARGS_DESCRIPTION,
71
  # This defines the format of each prediction and reference
72
+ features=[
73
+ datasets.Features(
74
+ {
75
+ "predictions": datasets.Value("string", id="sequence"),
76
+ "references": datasets.Sequence(datasets.Value("string", id="sequence"), id="references"),
77
+ "lang": datasets.Value("string"),
78
+ "weights": datasets.Value("string")
79
+ }
80
+ )
81
+ ],
82
+ # Homepage of the module for documentation
83
+ homepage="https://github.com/k4black/codebleu",
84
+ # Additional links to the codebase or references
85
+ codebase_urls=["https://github.com/k4black/codebleu"],
86
+ reference_urls=[
87
+ "https://github.com/k4black/codebleu",
88
+ "https://github.com/microsoft/CodeXGLUE/tree/main/Code-Code/code-to-code-trans/evaluator",
89
+ "https://arxiv.org/abs/2009.10297",
90
+ ],
91
  )
92
 
 
 
93
 
94
+ def _compute(
95
+ self,
96
+ predictions,
97
+ references,
98
+ lang,weights=(0.25, 0.25, 0.25, 0.25)
99
+ ):
100
  """Returns the scores"""
101
+ return calc_codebleu(
102
+ references=references,
103
+ predictions=predictions,
104
+ lang=lang,
105
+ weights=weights
106
+ )
requirements.txt CHANGED
@@ -1 +1,2 @@
1
- git+https://github.com/huggingface/evaluate@main
 
 
1
+ git+https://github.com/huggingface/evaluate@main
2
+ codebleu>=0.2.0,<1.0.0