
click accuracy and scroll accuracy #6

Open
njucckevin opened this issue Oct 16, 2023 · 4 comments

Comments

@njucckevin

Hi, thanks for the good work.
I wonder how the click accuracy and scroll accuracy are calculated in Section 5.1. I cannot find such code in main.py or action_matching.py.
Thanks~

@njucckevin (Author)

There is also a problem with the typed text accuracy. In main.py, the text accuracy is calculated as:
if check_match and (action_1_typed_text in action_2_typed_text or action_2_typed_text in action_1_typed_text): text_correct += 1
Since the text check is gated by check_match, the text accuracy can never exceed the total action accuracy. However, in Section 5.1 the text accuracy is over 90%.
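For concreteness, a minimal sketch of the concern (hypothetical per-step flags, not code from the repository): because the text check is gated by check_match, text_correct can never exceed action_correct when both are divided by the same total.

# Hypothetical per-step results, only to illustrate the counting described above.
steps = [
    {"check_match": True,  "text_overlap": True},
    {"check_match": True,  "text_overlap": False},
    {"check_match": False, "text_overlap": True},
]

action_correct = sum(s["check_match"] for s in steps)
text_correct = sum(s["check_match"] and s["text_overlap"] for s in steps)

# Over the same denominator, text accuracy <= action accuracy (here 1/3 vs. 2/3).
print(action_correct / len(steps), text_correct / len(steps))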

@cooelf (Owner) commented Oct 16, 2023

Hi, please refer to the following code. In Section 5.1, text accuracy is measured only over typing actions, and a prediction counts as correct when the predicted and reference texts match or overlap (see below). It is not calculated by the code in main.py.

# Parse the predicted action string into a dict and extract its fields.
pred = eval("{" + pred + "}")
action_1_touch_yx = eval(pred["touch_point"])
action_1_lift_yx = eval(pred["lift_point"])
action_1_action_type = int(pred["action_type"])
action_1_typed_text = pred["typed_text"].lower()

try:
    reference = pred_dict["target"]
    # If the closing quote of the lift_point value (just before the 5th comma) is missing, insert it.
    lift_id = [i for i, x in enumerate(reference) if x == ","][4] - 1
    lift_punk = reference[lift_id]
    if lift_punk != "'":
        str_list = list(reference)
        str_list.insert(lift_id + 1, "'")
        reference = ''.join(str_list)
    reference = eval("{" + reference + "}")
except:
    print("reference error")
    continue


# Extract the reference action fields and the gold UI annotation positions for this step.
action_2_touch_yx = eval(reference["touch_point"])
action_2_lift_yx = eval(reference["lift_point"])
action_2_action_type = int(reference["action_type"])
action_2_typed_text = reference["typed_text"].lower()

annotation_positions = gold_ui[idx]

try:
    check_match = action_matching.check_actions_match(
        action_1_touch_yx,
        action_1_lift_yx,
        action_1_action_type,
        action_2_touch_yx,
        action_2_lift_yx,
        action_2_action_type,
        annotation_positions
    )
except Exception as exc:
    print(idx, action_1_touch_yx, action_1_lift_yx)
    check_match = False
    match_label = "invalid"

episode_acc = 0

if check_match:
    partial_correct += 1
    match_label = 1
else:
    match_label = 0

if action_1_action_type == action_2_action_type:
    type_correct += 1

# Dual-point actions (action_type 4): distinguish clicks from scrolls and score each separately.
if action_2_action_type == 4:
    if is_tap_action(action_2_touch_yx, action_2_lift_yx):
        click_num += 1
        if match_label:
            click_correct += 1
    else:
        scroll_num += 1
        if match_label:
            scroll_correct += 1

# Typing actions (action_type 3): text accuracy is measured only over these steps.
if action_2_action_type == 3:
    text_num += 1
    if (action_2_typed_text == action_1_typed_text) or (action_1_typed_text in action_2_typed_text) or (action_2_typed_text in action_1_typed_text):
        text_correct += 1
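The per-type accuracies reported in Section 5.1 would then presumably be derived from the counters above along these lines (a minimal sketch; only the counter names come from the snippet, the zero-guards and printing are my own):

# Per-type accuracies from the counters accumulated in the loop above.
click_acc = click_correct / click_num if click_num else 0.0
scroll_acc = scroll_correct / scroll_num if scroll_num else 0.0
text_acc = text_correct / text_num if text_num else 0.0

print(f"click acc:  {click_acc:.2%} ({click_correct}/{click_num})")
print(f"scroll acc: {scroll_acc:.2%} ({scroll_correct}/{scroll_num})")
print(f"type acc:   {text_acc:.2%} ({text_correct}/{text_num})")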

@njucckevin (Author)

Thanks for the explanation.

@frankfengdi

Thanks for the great work!

I have a question regarding the evaluation metrics. In main.py, the metrics are computed in the code snippet below.

It seems that the action accuracy score is computed as action_correct divided by the total number of frames. This differs from the original definition of the action accuracy score, which is action_correct / len(episode), averaged over the number of episodes (see the sketch after the snippet).

Correct me if I misunderstand the metrics.

Could you share the full eval mentioned here? #6 (comment)

    predict_results = trainer.predict(test_dataset=test_set, max_length=args.output_len) 
    if trainer.is_world_process_zero():
        preds, targets = predict_results.predictions, predict_results.label_ids
        preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
        preds = tokenizer.batch_decode(preds, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        targets = tokenizer.batch_decode(targets, skip_special_tokens=True, clean_up_tokenization_spaces=True)

        action_correct = 0
        text_correct = 0
        type_correct = 0
        
        reference_test_positions = test_set.anno_positions

        output_data = []

        pattern = r'(?<=Action Decision:\s).*'

        assert len(preds) == len(targets)  == len(reference_test_positions)
        for idx, pred in enumerate(preds):
            try:
                result = re.search(pattern, targets[idx])
                target_text = result.group(0)
                target_text = target_text.strip()

                reference = eval("{" + target_text + "}")
            except:
                print("reference error")
                continue

            try:
                result = re.search(pattern, preds[idx])
                pred_text = result.group(0)
                pred_text = pred_text.strip()

                pred = eval("{" + pred_text + "}")
                action_1_touch_yx = eval(pred["touch_point"])
                action_1_lift_yx = eval(pred["lift_point"])
                action_1_action_type = action_type.ActionType[pred["action_type"]].value
                action_1_typed_text = pred["typed_text"].lower()
                action_1_typed_text = action_1_typed_text.strip()

                action_1_wrap = f'"action_type": "{action_1_action_type}", "touch_point": "{action_1_touch_yx}", "lift_point": "{action_1_lift_yx}", "typed_text": "{action_1_typed_text}"'
                action_1_wrap = action_1_wrap.replace('"', "'")
            except:
                pred = '{ "action_type": "TYPE", "touch_point": "[-1.0, -1.0]", "lift_point": "[-1.0, -1.0]", "typed_text": "Invalid"}'
            
            action_2_touch_yx = eval(reference["touch_point"])
            action_2_lift_yx = eval(reference["lift_point"])
            action_2_action_type = action_type.ActionType[reference["action_type"]].value
            action_2_typed_text = reference["typed_text"].lower()
            
            action_2_wrap = f'"action_type": "{action_2_action_type}", "touch_point": "{action_2_touch_yx}", "lift_point": "{action_2_lift_yx}", "typed_text": "{action_2_typed_text}"'
            action_2_wrap = action_2_wrap.replace('"', "'")

            annotation_positions = reference_test_positions[idx]

            try:
                check_match = action_matching.check_actions_match(
                    action_1_touch_yx,
                    action_1_lift_yx,
                    action_1_action_type,
                    action_2_touch_yx,
                    action_2_lift_yx,
                    action_2_action_type,
                    annotation_positions
                )

            except Exception as exc:
                print(idx, action_1_touch_yx, action_1_lift_yx)
                check_match = False
                match_label = "invalid"

            if check_match:
                action_correct += 1
                match_label = 1
            else:
                match_label = 0
            if check_match and (action_1_typed_text in action_2_typed_text or action_2_typed_text in action_1_typed_text):
                text_correct += 1
            if action_1_action_type == action_2_action_type:
                type_correct += 1

            action_data = {"pred": action_1_wrap, "target": action_2_wrap, "match_label": match_label}
            output_data.append(action_data)

        metrics["accuracy"] = "{:.2f}".format(action_correct/len(targets) * 100)
        metrics["text_acc"] = "{:.2f}".format(text_correct/len(targets) * 100)
        metrics["type_acc"] = "{:.2f}".format(type_correct/len(targets) * 100)
        metrics["action_correct"] = action_correct
        metrics["text_correct"] = text_correct
        metrics["type_correct"] = type_correct
        metrics["total_num"] = len(targets)
