-
Notifications
You must be signed in to change notification settings - Fork 3.6k
[Relax][PyTorch] CrossEntropyLoss #17863
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
@@ -588,5 +584,332 @@ def forward(self, x): | |||
assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) | |||
|
|||
|
|||
@tvm.testing.parametrize_targets("cuda") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not confident whether we need e2e tests for every ops.
The upside is we may not have to care about pytorch incompatible API update such as #17680 we hit recently.
The downside is that they increase the CI pressure, which results in a slow development speed.
Maybe nightly is a good place for e2e tests?
cc @Hzfengsy @tqchen @yongwww
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree we should avoid exccessive e2e tests and infavor of unit tests, we should move the execution based tests to something likely a nightly
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds good, I will undo that and also update #17862
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One reason why I added those tests is that we already had support for torch.sort (with a unit test), but when I tested e2e I realized that we had differerent behavior than pytorch. Maybe it's worth it indeed to add those tests to nightly? If so, what is the way to do that?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nightly tests are located at tests/python/nightly.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
CrossEntropyLoss support looks good to me.
@mshr-h @tqchen @MasterJH5574 I removed the e2e tests. |
@@ -4419,5 +4420,35 @@ def main( | |||
verify_model(Eye2(), example_args2, {}, Expected2) | |||
|
|||
|
|||
def test_cross_entropy(): | |||
class CrossEntropyModule(Module): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would be better to include a test case for functional cross entropy.
Add support for nn.CrossEntropyLoss in exported program translator.
general_reduction.py would fail when the reduction would reduce to a scalar. This PR allows the last block being a scalar.