
[Relax][PyTorch] CrossEntropyLoss #17863


Draft · wants to merge 12 commits into base: main
19 changes: 19 additions & 0 deletions python/tvm/dlight/gpu/general_reduction.py
@@ -60,6 +60,25 @@ def apply(  # pylint: disable=too-many-locals

        # Align the number of block iters of the last block.
        num_last_block_iter = len(block_infos[-1].dom_kind())

        # If the last block is a scalar value, there is nothing left to
        # tile/parallelise, and `iters` is an empty tuple.
        # Add a unit thread loop so the final write happens inside a valid
        # GPU thread environment.
        if num_last_block_iter == 0:
            # Put every block (both the running reductions and the final
            # scalar write) inside a trivial GPU thread. The very first block
            # gets a `blockIdx.x` wrapper so that kernels still have a unique
            # block scope.
            for i, info in enumerate(block_infos):
                loop_rv = sch.add_unit_loop(info.block_rv)
                if i == 0:
                    sch.bind(loop_rv, "blockIdx.x")
                else:
                    sch.bind(loop_rv, "threadIdx.x")

            return sch

        if num_last_block_iter < len(dom_kind):

            def f_layout_mapping(*iters):
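For context, a minimal standalone sketch (mine, not from the PR) of the two schedule primitives the new branch relies on: `add_unit_loop` wraps a loop-less block in a unit-extent loop, and `bind` attaches that loop to a GPU thread axis. The PrimFunc below is a made-up stand-in for the scalar-write block and assumes a recent TVM build with TVMScript.

# Standalone sketch (assumed API: tvm.tir.Schedule.add_unit_loop / bind).
import tvm
from tvm.script import tir as T

@T.prim_func
def scalar_write(A: T.Buffer((1,), "float32"), B: T.Buffer((1,), "float32")):
    # A block with no surrounding loops, like the final scalar write above.
    with T.block("write"):
        B[0] = A[0]

sch = tvm.tir.Schedule(scalar_write)
loop_rv = sch.add_unit_loop(sch.get_block("write"))  # unit-extent loop around the block
sch.bind(loop_rv, "blockIdx.x")  # the write now runs inside a valid GPU block scope
print(sch.mod.script())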
19 changes: 19 additions & 0 deletions python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -782,6 +782,25 @@ def _conv3d(self, node: fx.Node) -> relax.Var:
            groups=groups,
        )

    def _cross_entropy_loss(
        self,
        preds: relax.Expr,
        targets: relax.Expr,
        weights: Optional[relax.Expr],
        reduction: str,
        ignore_index: int,
    ) -> relax.Expr:
        log_probs = relax.op.nn.log_softmax(preds)
        return self.block_builder.emit(
            relax.op.nn.nll_loss(
                log_probs,
                targets,
                weights,
                reduction,
                ignore_index,
            )
        )

    def _einsum(self, node: fx.Node) -> relax.Var:
        import torch  # type: ignore

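As a sanity check on the decomposition this helper emits (plain PyTorch, not part of the PR): `cross_entropy` is `log_softmax` followed by `nll_loss` with the same `reduction` and `ignore_index`.

# Plain-PyTorch check of the identity behind _cross_entropy_loss.
import torch
import torch.nn.functional as F

x = torch.randn(4, 3)
t = torch.tensor([0, 1, 2, 1])
full = F.cross_entropy(x, t, reduction="mean", ignore_index=-100)
decomposed = F.nll_loss(F.log_softmax(x, dim=-1), t, reduction="mean", ignore_index=-100)
assert torch.allclose(full, decomposed)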
11 changes: 10 additions & 1 deletion python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -66,7 +66,7 @@ def _reciprocal(self, node: fx.Node) -> relax.Var:

    ########## Neural Network ##########

-    def _batch_norm(self, node: fx.Node, training) -> relax.Var:
+    def _batch_norm(self, node: fx.Node, training: bool) -> relax.Var:
        import numpy as np

        x = self.env[node.args[0]]
@@ -113,6 +113,14 @@ def _batch_norm_legit_no_training(self, node: fx.Node) -> relax.Var:
        training = False
        return self._batch_norm(node, training)

    def _cross_entropy_default(self, node: fx.Node) -> relax.Expr:
        preds = self.env[node.args[0]]
        targets = self.env[node.args[1]]
        weight = self.env.get(node.args[2], None) if len(node.args) > 2 else None
        reduction = node.kwargs.get("reduction", "mean")
        ignore_index = node.kwargs.get("ignore_index", -100)
        return self._cross_entropy_loss(preds, targets, weight, reduction, ignore_index)

    def _group_norm(self, node: fx.Node) -> relax.Var:
        x = self.env[node.args[0]]
        num_groups = node.args[1]
@@ -386,6 +394,7 @@ def create_convert_map(
"conv1d.default": self._conv1d,
"conv2d.default": self._conv2d,
"conv3d.default": self._conv3d,
"cross_entropy_loss.default": self._cross_entropy_default,
"einsum.default": self._einsum,
"embedding.default": lambda node: self._embedding_impl(
self.env[node.args[1]], self.env[node.args[0]]
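To see where the `cross_entropy_loss.default` key comes from, a quick probe (hypothetical module, plain PyTorch) of the graph that `torch.export` produces:

# Probe: the exported graph should contain aten.cross_entropy_loss.default,
# the target that the new convert-map entry dispatches on.
import torch
from torch.export import export

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.target = torch.tensor([0, 1, 2, 1])

    def forward(self, x):
        return torch.nn.functional.cross_entropy(x, self.target)

ep = export(M(), (torch.randn(4, 3),))
print(ep.graph)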
17 changes: 7 additions & 10 deletions python/tvm/relax/frontend/torch/fx_translator.py
@@ -260,12 +260,7 @@ def _cross_entropy(self, node: fx.Node) -> relax.Expr:
        weights = self.env.get(node.kwargs["weight"], None)
        reduction = node.kwargs["reduction"]
        ignore_index = node.kwargs["ignore_index"]
-
-        return self.block_builder.emit(
-            relax.op.nn.nll_loss(
-                relax.op.nn.log_softmax(preds), targets, weights, reduction, ignore_index
-            )
-        )
+        return self._cross_entropy_loss(preds, targets, weights, reduction, ignore_index)

    def _cross_entropy_module(self, node: fx.Node) -> relax.Expr:
        preds = self.env[node.args[0]]
@@ -282,10 +277,12 @@ def _cross_entropy_module(self, node: fx.Node) -> relax.Expr:
        reduction = module.reduction
        ignore_index = module.ignore_index

-        return self.block_builder.emit(
-            relax.op.nn.nll_loss(
-                relax.op.nn.log_softmax(preds), targets, weights, reduction, ignore_index
-            )
+        return self._cross_entropy_loss(
+            preds,
+            targets,
+            weights,
+            reduction,
+            ignore_index,
        )

    def _embedding_module(self, node: fx.Node) -> relax.Var:
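For the FX path, `_cross_entropy` reads `weight`, `reduction`, and `ignore_index` from `node.kwargs`. A quick symbolic trace (hypothetical module, plain torch.fx) shows the single call_function node it handles:

# Probe: print the node that fx_translator's _cross_entropy receives.
import torch
import torch.fx as fx

class M(torch.nn.Module):
    def forward(self, x, target):
        return torch.nn.functional.cross_entropy(
            x, target, weight=None, reduction="mean", ignore_index=-100
        )

gm = fx.symbolic_trace(M())
for node in gm.graph.nodes:
    if node.op == "call_function":
        print(node.target, node.kwargs)  # kwargs carry weight/reduction/ignore_index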
31 changes: 31 additions & 0 deletions tests/python/relax/test_frontend_from_exported_program.py
@@ -17,6 +17,7 @@
import operator
import pytest
import torch
from torch import nn
from torch.nn import Module
from torch.export import export

@@ -4419,5 +4420,35 @@ def main(
    verify_model(Eye2(), example_args2, {}, Expected2)


def test_cross_entropy():
    class CrossEntropyModule(Module):
[Contributor review comment] Would be better to include a test case for functional cross entropy. (A sketch of such a test follows this one.)

        def __init__(self):
            super().__init__()
            self.criterion = nn.CrossEntropyLoss()
            self.target = torch.tensor([0, 1, 2, 1])

        def forward(self, x):
            return self.criterion(x, self.target)

    @tvm.script.ir_module
    class Expected1:
        @R.function
        def main(x: R.Tensor((4, 3), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((4, 3), dtype="float32") = R.nn.log_softmax(x, axis=-1)
                lv1: R.Tensor((), dtype="float32") = R.nn.nll_loss(
                    lv,
                    targets=R.const([0, 1, 2, 1], dtype="int64"),
                    reduction="mean",
                    ignore_index=-100,
                )
                gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv1,)
                R.output(gv)
            return gv

    example_args1 = (torch.randn(4, 3, dtype=torch.float32),)
    verify_model(CrossEntropyModule(), example_args1, {}, Expected1)
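Following the review comment above, a functional-API test might look like the sketch below (not part of the PR; it mirrors `CrossEntropyModule` and reuses the same expected IR shape):

def test_cross_entropy_functional():
    # Sketch of the test the reviewer asked for (unverified).
    class CrossEntropyFunctional(Module):
        def __init__(self):
            super().__init__()
            self.target = torch.tensor([0, 1, 2, 1])

        def forward(self, x):
            return torch.nn.functional.cross_entropy(x, self.target)

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(x: R.Tensor((4, 3), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((4, 3), dtype="float32") = R.nn.log_softmax(x, axis=-1)
                lv1: R.Tensor((), dtype="float32") = R.nn.nll_loss(
                    lv,
                    targets=R.const([0, 1, 2, 1], dtype="int64"),
                    reduction="mean",
                    ignore_index=-100,
                )
                gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv1,)
                R.output(gv)
            return gv

    example_args = (torch.randn(4, 3, dtype=torch.float32),)
    verify_model(CrossEntropyFunctional(), example_args, {}, Expected)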


if __name__ == "__main__":
    tvm.testing.main()