-
Notifications
You must be signed in to change notification settings - Fork 64
submit #194
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
submit #194
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -1,92 +1,204 @@ | ||||||
| import json, os, argparse | ||||||
| from tqdm import tqdm, trange | ||||||
| from transformers import AutoTokenizer | ||||||
|
|
||||||
| # from method import build_prompt, select_examples, annotate | ||||||
|
|
||||||
| from method import build_prompt, select_examples | ||||||
|
|
||||||
| from method import annotate_nvidia as annotate # For Nvidia GPU | ||||||
| # from method import annotate_ascend as annotate # For Huawei Ascend | ||||||
|
|
||||||
| TASK_FILES = { | ||||||
| 1: './data/openseek-1_closest_integers.json', | ||||||
| 2: './data/openseek-2_count_nouns_verbs.json', | ||||||
| 3: './data/openseek-3_collatz_conjecture.json', | ||||||
| 4: './data/openseek-4_conala_concat_strings.json', | ||||||
| 5: './data/openseek-5_semeval_2018_task1_tweet_sadness_detection.json', | ||||||
| 6: './data/openseek-6_mnli_same_genre_classification.json', | ||||||
| 7: './data/openseek-7_jeopardy_answer_generation_all.json', | ||||||
| 8: './data/openseek-8_kernel_generation.json', | ||||||
| } | ||||||
|
|
||||||
| def parser_args(): | ||||||
| parser = argparse.ArgumentParser() | ||||||
| parser.add_argument('--task_id', type=int, required=True, | ||||||
| help='Task ID to evaluate, should be in [1, 8].') | ||||||
| parser.add_argument('--max_input_length', type=int, default=10_000, | ||||||
| help='Maximum input length for the model.') | ||||||
| parser.add_argument('--log_path_prefix', type=str, | ||||||
| default='../outputs/', | ||||||
| help='Prefix path to save the evaluation logs.') | ||||||
| parser.add_argument('--tokenizer_path', type=str, | ||||||
| default='/share/project/wuhaiming/spaces/data_agent/OpenSeek-main/openseek/competition/LongContext-ICL-Annotation/src/Qwen3-4B') | ||||||
| args = parser.parse_args() | ||||||
| return args | ||||||
|
|
||||||
| def evaluate(task_id:int, | ||||||
| qwen_tokenizer:AutoTokenizer, | ||||||
| max_input_length:int=128_000, | ||||||
| log_path_prefix:str='./outputs/' | ||||||
| )->float: | ||||||
| assert task_id in [i for i in range(1, 9)],\ | ||||||
| f"task_id should be in [1, 8], but got {task_id}." | ||||||
|
|
||||||
| task_file = TASK_FILES[task_id] | ||||||
| with open(task_file, 'r') as f: | ||||||
| task_dict = json.load(f) | ||||||
|
|
||||||
| task_name = task_dict['task_name'] | ||||||
| task_description = task_dict['Definition'][0] | ||||||
| icl_examples = task_dict['examples'][:100] | ||||||
| test_samples = task_dict['test_samples'] | ||||||
|
|
||||||
| version = 1 | ||||||
| output_file = f'{log_path_prefix}openseek-{task_id}-v{version}.jsonl' | ||||||
| output_path = os.path.dirname(output_file) | ||||||
| os.makedirs(output_path, exist_ok=True) | ||||||
| while os.path.exists(output_file): | ||||||
| version += 1 | ||||||
| output_file = f'{log_path_prefix}openseek-{task_id}-v{version}.jsonl' | ||||||
| with open(output_file, 'w') as f: | ||||||
| pass | ||||||
|
|
||||||
| examples_str = None | ||||||
| for test_sample in tqdm(test_samples, desc=f'Evaluation on Task {task_id}: {task_name}'): | ||||||
| test_record = dict() | ||||||
|
|
||||||
| test_sample_id = test_sample['id'] | ||||||
| test_record['test_sample_id'] = test_sample_id | ||||||
|
|
||||||
|
|
||||||
| text2annotate = test_sample['input'] | ||||||
| prompt = build_prompt(task_description, text2annotate) | ||||||
| if examples_str is None: | ||||||
| examples_str = select_examples(icl_examples, task_description, text2annotate) | ||||||
| input_prompt = prompt.replace("[[EXAMPLES]]\n\n", examples_str+'\n\n') | ||||||
|
|
||||||
| # tokenized_input = qwen_tokenizer(input_prompt, return_tensors="pt") | ||||||
| # if tokenized_input['input_ids'].shape[1] > max_input_length: | ||||||
| # test_record['prediction'] = None | ||||||
| # else: | ||||||
| # prediction = annotate(input_prompt) | ||||||
| # test_record['prediction'] = prediction | ||||||
| prediction = annotate(input_prompt) | ||||||
| test_record['prediction'] = prediction | ||||||
| with open(output_file, 'a') as f: | ||||||
| f.write(json.dumps(test_record)+'\n') | ||||||
|
|
||||||
| if __name__ == '__main__': | ||||||
| args = parser_args() | ||||||
| qwen_tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path) | ||||||
| import json, os, argparse | ||||||
| from tqdm import tqdm, trange | ||||||
| from transformers import AutoTokenizer | ||||||
|
|
||||||
| # from method import build_prompt, select_examples, annotate | ||||||
|
|
||||||
| from method import build_prompt, select_examples | ||||||
|
|
||||||
| # from method import annotate_nvidia as annotate # For Nvidia GPU | ||||||
| from method import annotate_ascend as annotate # For Huawei Ascend | ||||||
| from method import annotate_batch | ||||||
| from method import annotate_with_self_consistency | ||||||
| from method import annotate_with_multi_turn_dialog | ||||||
|
|
||||||
| DATA_DIR = '/root/flagos/OpenSeek/openseek/competition/LongContext-ICL-Annotation/data' | ||||||
| OUTPUT_DIR = '/root/flagos/OpenSeek/openseek/competition/LongContext-ICL-Annotation/outputs' | ||||||
|
|
||||||
| TASK_FILES = { | ||||||
| 1: f'{DATA_DIR}/openseek-1_closest_integers.json', | ||||||
| 2: f'{DATA_DIR}/openseek-2_count_nouns_verbs.json', | ||||||
| 3: f'{DATA_DIR}/openseek-3_collatz_conjecture.json', | ||||||
| 4: f'{DATA_DIR}/openseek-4_conala_concat_strings.json', | ||||||
| 5: f'{DATA_DIR}/openseek-5_semeval_2018_task1_tweet_sadness_detection.json', | ||||||
| 6: f'{DATA_DIR}/openseek-6_mnli_same_genre_classification.json', | ||||||
| 7: f'{DATA_DIR}/openseek-7_jeopardy_answer_generation_all.json', | ||||||
| 8: f'{DATA_DIR}/openseek-8_kernel_generation.json', | ||||||
| } | ||||||
|
|
||||||
| def parser_args(): | ||||||
| parser = argparse.ArgumentParser() | ||||||
| parser.add_argument('--task_id', type=int, required=True, | ||||||
| help='Task ID to evaluate, should be in [1, 8].') | ||||||
| parser.add_argument('--max_input_length', type=int, default=10_000, | ||||||
| help='Maximum input length for the model.') | ||||||
| parser.add_argument('--log_path_prefix', type=str, | ||||||
| default='/root/flagos/OpenSeek/openseek/competition/LongContext-ICL-Annotation/outputs/', | ||||||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The default value for `--log_path_prefix` is hardcoded to a machine-specific absolute path, which makes the script non-portable; consider using a relative path or an environment variable instead.
Suggested change
|
||||||
| help='Prefix path to save the evaluation logs.') | ||||||
| parser.add_argument('--tokenizer_path', type=str, | ||||||
| default='/root/flagos/Qwen3-4B') | ||||||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||||||
| args = parser.parse_args() | ||||||
| return args | ||||||
|
|
||||||
| def parse_category_and_clue(text: str) -> tuple[str, str]: | ||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
| lines = text.strip().split('\n') | ||||||
|
Comment on lines +44 to +50
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||||||
| category = "" | ||||||
| clue = "" | ||||||
|
|
||||||
| for line in lines: | ||||||
| if line.startswith('Category:'): | ||||||
| category = line.replace('Category:', '').strip() | ||||||
| elif line.startswith('Clue:'): | ||||||
| clue = line.replace('Clue:', '').strip() | ||||||
|
|
||||||
| return category, clue | ||||||
|
|
||||||
| def evaluate(task_id:int, | ||||||
| qwen_tokenizer:AutoTokenizer, | ||||||
| max_input_length:int=128_000, | ||||||
| log_path_prefix:str='./outputs/' | ||||||
| )->float: | ||||||
| assert task_id in [i for i in range(1, 9)],\ | ||||||
| f"task_id should be in [1, 8], but got {task_id}." | ||||||
|
|
||||||
| task_file = TASK_FILES[task_id] | ||||||
| with open(task_file, 'r') as f: | ||||||
| task_dict = json.load(f) | ||||||
|
|
||||||
| task_name = task_dict['task_name'] | ||||||
| task_description = task_dict['Definition'][0] | ||||||
| icl_examples = task_dict['examples'][:50] | ||||||
| test_samples = task_dict['test_samples'] | ||||||
|
|
||||||
| version = 1 | ||||||
| output_file = f'{log_path_prefix}openseek-{task_id}-v{version}.jsonl' | ||||||
| output_path = os.path.dirname(output_file) | ||||||
| os.makedirs(output_path, exist_ok=True) | ||||||
| while os.path.exists(output_file): | ||||||
| version += 1 | ||||||
| output_file = f'{log_path_prefix}openseek-{task_id}-v{version}.jsonl' | ||||||
| with open(output_file, 'w') as f: | ||||||
| pass | ||||||
|
|
||||||
| examples_str = None | ||||||
| batch_size = 8 | ||||||
| prompts_batch = [] | ||||||
| sample_ids_batch = [] | ||||||
|
|
||||||
| # Task 8 is code generation, needs more tokens and different post-processing | ||||||
| max_tokens = 1024 if task_id == 8 else 256 | ||||||
| use_count_answer = False if task_id == 8 else True | ||||||
|
|
||||||
|
|
||||||
| use_self_consistency = (task_id == 6) | ||||||
|
|
||||||
|
|
||||||
| use_social_media_enhancement = (task_id == 5) | ||||||
|
|
||||||
|
|
||||||
| use_multi_turn_dialog = (task_id == 7) | ||||||
|
|
||||||
| for test_sample in tqdm(test_samples, desc=f'Evaluation on Task {task_id}: {task_name}'): | ||||||
| test_sample_id = test_sample['id'] | ||||||
| text2annotate = test_sample['input'] | ||||||
|
|
||||||
|
|
||||||
| if use_multi_turn_dialog: | ||||||
|
|
||||||
| category, clue = parse_category_and_clue(text2annotate) | ||||||
|
|
||||||
|
|
||||||
| if examples_str is None: | ||||||
| examples_str = select_examples(icl_examples, task_description, text2annotate, | ||||||
| is_code_generation=False, | ||||||
| use_task_aware=True, task_id=task_id, | ||||||
| use_quality_filter=True, quality_threshold=0.5, | ||||||
| use_diversity=False, use_similarity=False, | ||||||
| use_cot=False, | ||||||
| balance_sentiment=False) | ||||||
|
|
||||||
|
|
||||||
| final_answer, round1_response, round2_response = annotate_with_multi_turn_dialog( | ||||||
| category, clue, examples_str, task_description, max_tokens=max_tokens | ||||||
| ) | ||||||
|
|
||||||
|
|
||||||
| test_record = { | ||||||
| 'test_sample_id': test_sample_id, | ||||||
| 'prediction': final_answer | ||||||
| } | ||||||
| with open(output_file, 'a') as f: | ||||||
| f.write(json.dumps(test_record)+'\n') | ||||||
| else: | ||||||
|
|
||||||
|
|
||||||
| prompt = build_prompt(task_description, text2annotate, task_id=task_id, use_social_media_enhancement=use_social_media_enhancement) | ||||||
| if examples_str is None: | ||||||
| # Task 8 is code generation task | ||||||
| is_code_generation = (task_id == 8) | ||||||
|
|
||||||
|
|
||||||
| use_cot = (1 <= task_id <= 4) | ||||||
|
|
||||||
| balance_sentiment = (task_id == 5) | ||||||
| examples_str = select_examples(icl_examples, task_description, text2annotate, | ||||||
| is_code_generation=is_code_generation, | ||||||
| use_task_aware=True, task_id=task_id, | ||||||
| use_quality_filter=True, quality_threshold=0.5, | ||||||
| use_diversity=False, use_similarity=False, | ||||||
| use_cot=use_cot, | ||||||
| balance_sentiment=balance_sentiment) | ||||||
| input_prompt = prompt.replace("[[EXAMPLES]]\n\n", examples_str+'\n\n') | ||||||
|
|
||||||
|
|
||||||
| if use_self_consistency: | ||||||
|
|
||||||
| final_answer, confidence, all_predictions = annotate_with_self_consistency( | ||||||
| input_prompt, | ||||||
| num_samples=5, | ||||||
| temperature_range=[0.7, 0.85, 1.0, 1.1, 1.2], | ||||||
| max_tokens=max_tokens, | ||||||
| task_id=task_id, | ||||||
| confidence_threshold=0.4 | ||||||
| ) | ||||||
| test_record = { | ||||||
| 'test_sample_id': test_sample_id, | ||||||
| 'prediction': final_answer, | ||||||
| 'confidence': confidence, | ||||||
| 'all_predictions': all_predictions | ||||||
| } | ||||||
| with open(output_file, 'a') as f: | ||||||
| f.write(json.dumps(test_record)+'\n') | ||||||
| else: | ||||||
|
|
||||||
| prompts_batch.append(input_prompt) | ||||||
| sample_ids_batch.append(test_sample_id) | ||||||
|
|
||||||
| # Process batch when full | ||||||
| if len(prompts_batch) >= batch_size: | ||||||
| results = annotate_batch(prompts_batch, num_workers=4, max_tokens=max_tokens, use_count_answer=use_count_answer, task_id=task_id) | ||||||
| for sid, (pred, _) in zip(sample_ids_batch, results): | ||||||
| test_record = {'test_sample_id': sid, 'prediction': pred} | ||||||
| with open(output_file, 'a') as f: | ||||||
| f.write(json.dumps(test_record)+'\n') | ||||||
| prompts_batch = [] | ||||||
| sample_ids_batch = [] | ||||||
|
|
||||||
|
|
||||||
| if not use_self_consistency and prompts_batch: | ||||||
| results = annotate_batch(prompts_batch, num_workers=4, max_tokens=max_tokens, use_count_answer=use_count_answer, task_id=task_id) | ||||||
| for sid, (pred, _) in zip(sample_ids_batch, results): | ||||||
| test_record = {'test_sample_id': sid, 'prediction': pred} | ||||||
| with open(output_file, 'a') as f: | ||||||
| f.write(json.dumps(test_record)+'\n') | ||||||
|
|
||||||
| if __name__ == '__main__': | ||||||
| args = parser_args() | ||||||
| qwen_tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path) | ||||||
| evaluate(args.task_id, qwen_tokenizer, args.max_input_length, args.log_path_prefix) | ||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
The paths `DATA_DIR` and `OUTPUT_DIR` are hardcoded to absolute paths on a specific machine (`/root/flagos/...`). This makes the script non-portable and likely to fail in other environments. Consider using relative paths or environment variables to define these directories.