
Commit 1b926a8

Update asymmetric padding implementation
1 parent e900508 commit 1b926a8

8 files changed: +251 −148 lines changed

include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h (−9)
@@ -130,15 +130,6 @@ LogicalResult createDequantizeTensor(ConversionPatternRewriter &rewriter,
                                      Location loc, Value input, Value scale,
                                      Value zeroPoint, Value &output);
 
-// Checks the validity of pooling parameters and stores them in the respective
-// vector.
-LogicalResult checkAndGetOnnxPoolingOpParameters(
-    OpBinder binder, ConversionPatternRewriter &rewriter, Type resultDtype,
-    std::string autoPad, int64_t spatialRank, Value &input,
-    SmallVectorImpl<int64_t> &kernelSizeInts,
-    SmallVectorImpl<int64_t> &strideInts, SmallVectorImpl<int64_t> &paddingInts,
-    SmallVectorImpl<int64_t> &dilationInts);
-
 } // namespace mlir::torch::onnx_c
 
 #endif // TORCHMLIR_CONVERSION_TORCHONNXTOTORCH_UTILS_H

include/torch-mlir/Conversion/TorchToLinalg/Utils.h (+10)
@@ -47,6 +47,16 @@ Value getOutputDimForConvOps(OpBuilder &b, Location loc, Value in,
                              Value kernelSizeInt, Value strideInt,
                              bool ceilMode = false);
 
+// Helper function to calculate the output tensor dims for pooling-like ops.
+// Along each dim:
+// dim_out =
+//   floor((dim_in + totalPadding - dilation * (kernelSize - 1) - 1) / stride) +
+//   1
+Value getOutputDimForPoolOps(OpBuilder &b, Location loc, Value in,
+                             int64_t totalPadding, int64_t leftPadding,
+                             Value dilationInt, Value kernelSizeInt,
+                             Value strideInt, bool ceilMode);
+
 // As above but for transposed convolution ops
 // Along each dim:
 // dim_out =
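
To make the new helper's formula concrete, here is a standalone sketch (not part of the commit) of the same arithmetic with plain integers in place of MLIR Values; poolOutputDim is an illustrative name, and the real helper additionally uses leftPadding to correct ceil-mode windows that would start entirely inside the padding, which this sketch omits.

#include <cassert>
#include <cstdint>

// Illustrative stand-in for the arithmetic behind getOutputDimForPoolOps,
// using plain integers instead of MLIR Values.
int64_t poolOutputDim(int64_t dimIn, int64_t totalPadding, int64_t dilation,
                      int64_t kernelSize, int64_t stride, bool ceilMode) {
  int64_t numerator = dimIn + totalPadding - dilation * (kernelSize - 1) - 1;
  if (!ceilMode)
    return numerator / stride + 1; // floor division; operands are non-negative
  return (numerator + stride - 1) / stride + 1; // ceil_mode rounds up instead
}

int main() {
  // Input dim 7, one pixel of padding on each side, kernel 3, stride 2:
  // floor((7 + 2 - 1 * (3 - 1) - 1) / 2) + 1 = 4.
  assert(poolOutputDim(7, 2, /*dilation=*/1, 3, 2, /*ceilMode=*/false) == 4);
  return 0;
}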

lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp (+69 −5)
@@ -476,14 +476,78 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
                                              "Unimplemented: unranked tensor");
         unsigned rank = *maybeRank;
 
+        int64_t spatialRank = rank - 2;
         SmallVector<int64_t> kernel, padding, strides, dilations,
             stridesDilations;
-        if (failed(checkAndGetOnnxPoolingOpParameters(
-                binder, rewriter, resultType.getDtype(), autoPad,
-                /*spatialRank=*/rank - 2,
-                /*input=*/operand, kernel, strides, padding, dilations)))
+
+        if (binder.s64IntegerArrayAttr(kernel, "kernel_shape", {}))
+          return rewriter.notifyMatchFailure(binder.op,
+                                             "kernel_shape bind failure");
+        if (kernel.size() != static_cast<size_t>(spatialRank))
+          return rewriter.notifyMatchFailure(
+              binder.op, "kernel list size does not match the number of axes");
+        if (binder.s64IntegerArrayAttr(padding, "pads", {}))
+          return rewriter.notifyMatchFailure(binder.op, "pads bind failure");
+        if (!padding.empty() &&
+            padding.size() != static_cast<size_t>(2 * spatialRank))
+          return rewriter.notifyMatchFailure(
+              binder.op, "padding list must contain (begin,end) pair for each "
+                         "spatial axis");
+        if (binder.s64IntegerArrayAttr(strides, "strides", {}))
+          return rewriter.notifyMatchFailure(binder.op, "strides bind failure");
+        if (!strides.empty() &&
+            strides.size() != static_cast<size_t>(spatialRank))
+          return rewriter.notifyMatchFailure(
+              binder.op, "strides list size does not match the number of axes");
+        if (binder.s64IntegerArrayAttr(dilations, "dilations", {}))
           return rewriter.notifyMatchFailure(binder.op,
-                                             "invalid pooling parameters");
+                                             "dilations bind failure");
+
+        // Set default values for padding, strides, and dilations.
+        if (padding.empty())
+          padding.resize(spatialRank, 0);
+        if (strides.empty())
+          strides.resize(spatialRank, 1);
+        if (dilations.empty())
+          dilations.resize(spatialRank, 1);
+
+        // Padding for the beginning and ending along each spatial axis can
+        // take any value greater than or equal to 0. The value represents
+        // the number of pixels added to the beginning and end of the
+        // corresponding axis. The pads format should be [x1_begin, x2_begin,
+        // ..., x1_end, x2_end, ...], where xi_begin is the number of pixels
+        // added at the beginning of axis i and xi_end is the number of
+        // pixels added at the end of axis i.
+        auto inputTensorType = cast<Torch::ValueTensorType>(operand.getType());
+        if (autoPad != "NOTSET" && autoPad != "VALID") {
+          const bool isSameLower = autoPad == "SAME_LOWER";
+          ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
+          padding.resize_for_overwrite(2 * spatialRank);
+          for (unsigned dimIdx = 0; dimIdx < spatialRank; dimIdx++) {
+            const int64_t dilatedKernelSize =
+                dilations[dimIdx] * (kernel[dimIdx] - 1) + 1;
+            int64_t totalPad = ((inputShape[dimIdx + 2] + strides[dimIdx] - 1) /
+                                    strides[dimIdx] -
+                                1) *
+                                   strides[dimIdx] +
+                               dilatedKernelSize - inputShape[dimIdx + 2];
+            totalPad = totalPad >= 0 ? totalPad : 0;
+            padding[dimIdx] =
+                isSameLower ? ((totalPad + 1) / 2) : (totalPad / 2);
+            padding[spatialRank + dimIdx] = totalPad - padding[dimIdx];
+          }
+        }
+
+        // If the padding is symmetric we can push the padding operation to
+        // the torch operator.
+        if (padding.size() == static_cast<size_t>(2 * spatialRank)) {
+          bool equal = true;
+          for (int i = 0; i < spatialRank; ++i) {
+            equal = equal && (padding[i] == padding[i + spatialRank]);
+          }
+          if (equal)
+            padding.resize(spatialRank);
+        }
 
         // Since the PyTorch AvgPool op does not contain the `dilation` arg,
         // hence we use the trick of encoding dilation into strides. Then,
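
As a concrete illustration of the auto_pad branch above (not part of the commit; samePadding is an illustrative name), the per-axis computation first picks the output size ceil(dimIn / stride) that SAME padding must preserve, then distributes the required total padding, with SAME_LOWER giving the odd pixel to the beginning:

#include <algorithm>
#include <cassert>
#include <cstdint>

// Standalone sketch of the SAME_UPPER/SAME_LOWER pad computation for a
// single spatial axis.
void samePadding(int64_t dimIn, int64_t kernel, int64_t stride,
                 int64_t dilation, bool isSameLower, int64_t &padBegin,
                 int64_t &padEnd) {
  const int64_t dilatedKernel = dilation * (kernel - 1) + 1;
  const int64_t outDim = (dimIn + stride - 1) / stride; // ceil(dimIn / stride)
  int64_t totalPad = (outDim - 1) * stride + dilatedKernel - dimIn;
  totalPad = std::max<int64_t>(totalPad, 0);
  padBegin = isSameLower ? (totalPad + 1) / 2 : totalPad / 2;
  padEnd = totalPad - padBegin;
}

int main() {
  int64_t begin, end;
  // dimIn = 5, kernel = 3, stride = 2: outDim = 3, totalPad = 4 + 3 - 5 = 2.
  samePadding(5, 3, 2, 1, /*isSameLower=*/false, begin, end);
  assert(begin == 1 && end == 1); // symmetric; foldable into the torch op
  // dimIn = 6, kernel = 3, stride = 2: outDim = 3, totalPad = 4 + 3 - 6 = 1.
  samePadding(6, 3, 2, 1, /*isSameLower=*/true, begin, end);
  assert(begin == 1 && end == 0); // SAME_LOWER pads the extra pixel first
  return 0;
}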

lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp (+114 −14)
@@ -1130,38 +1130,138 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
       });
   patterns.onOp(
       "MaxPool", 12, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        std::string autoPad;
+        if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET"))
+          return rewriter.notifyMatchFailure(binder.op,
+                                             "auto_pad bind failure");
+
         Torch::ValueTensorType resultTypeOut;
         Value operand;
         int64_t ceilMode, storageOrder;
-        std::string autoPad;
+        // TODO: Add support for indices output and storage_order
         if (binder.tensorOperand(operand) ||
             binder.s64IntegerAttr(ceilMode, "ceil_mode", 0) ||
             binder.s64IntegerAttr(storageOrder, "storage_order", 0) ||
-            binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET") ||
             binder.tensorResultTypeAtIndex(resultTypeOut, 0))
           return rewriter.notifyMatchFailure(
-              binder.op, "operand/ceil_mode/storage_order/auto_pad/resultType "
-                         "bind failure");
-        // TODO: Add support for storage_order
+              binder.op,
+              "operand/ceil_mode/storage_order/resultType bind failure");
         if (storageOrder != 0)
           return rewriter.notifyMatchFailure(
               binder.op, "storage_order setting is not supported.");
-
         // Determine the rank of input tensor.
         std::optional<unsigned> maybeRank = Torch::getTensorRank(operand);
         if (!maybeRank)
           return rewriter.notifyMatchFailure(binder.op,
                                              "Unimplemented: unranked tensor");
-        unsigned rank = *maybeRank;
+        int64_t rank = *maybeRank;
+        int64_t spatial = rank - 2;
 
-        SmallVector<int64_t> kernel, padding, strides, dilations,
-            stridesDilations;
-        if (failed(checkAndGetOnnxPoolingOpParameters(
-                binder, rewriter, resultTypeOut.getDtype(), autoPad,
-                /*spatialRank=*/rank - 2,
-                /*input=*/operand, kernel, strides, padding, dilations)))
+        SmallVector<int64_t> kernel, padding, strides, dilations;
+        if (binder.s64IntegerArrayAttr(kernel, "kernel_shape", {}))
           return rewriter.notifyMatchFailure(binder.op,
-                                             "invalid pooling parameters");
+                                             "kernel_shape bind failure");
+        if (kernel.size() != static_cast<size_t>(spatial))
+          return rewriter.notifyMatchFailure(
+              binder.op, "kernel list size does not match the number of axes");
+        if (binder.s64IntegerArrayAttr(padding, "pads", {}))
+          return rewriter.notifyMatchFailure(binder.op, "pads bind failure");
+        if (!padding.empty() &&
+            padding.size() != static_cast<size_t>(2 * spatial))
+          return rewriter.notifyMatchFailure(
+              binder.op, "padding list must contain (begin,end) pair for each "
+                         "spatial axis");
+        if (binder.s64IntegerArrayAttr(strides, "strides", {}))
+          return rewriter.notifyMatchFailure(binder.op, "strides bind failure");
+        if (!strides.empty() && strides.size() != static_cast<size_t>(spatial))
+          return rewriter.notifyMatchFailure(
+              binder.op, "strides list size does not match the number of axes");
+        if (binder.s64IntegerArrayAttr(dilations, "dilations", {}))
+          return rewriter.notifyMatchFailure(binder.op,
+                                             "dilations bind failure");
+
+        // Set default values for padding, strides, and dilations.
+        if (padding.empty())
+          padding.resize(spatial, 0);
+        if (strides.empty())
+          strides.resize(spatial, 1);
+        if (dilations.empty())
+          dilations.resize(spatial, 1);
+
+        auto inputTensorType = cast<Torch::ValueTensorType>(operand.getType());
+
+        // Padding for the beginning and ending along each spatial axis can
+        // take any value greater than or equal to 0. The value represents
+        // the number of pixels added to the beginning and end of the
+        // corresponding axis. The pads format should be [x1_begin, x2_begin,
+        // ..., x1_end, x2_end, ...], where xi_begin is the number of pixels
+        // added at the beginning of axis i and xi_end is the number of
+        // pixels added at the end of axis i.
+        if (autoPad != "NOTSET" && autoPad != "VALID") {
+          const bool isSameLower = autoPad == "SAME_LOWER";
+          ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
+          padding.resize_for_overwrite(2 * spatial);
+          for (unsigned dimIdx = 0; dimIdx < spatial; dimIdx++) {
+            const int64_t dilatedKernelSize =
+                dilations[dimIdx] * (kernel[dimIdx] - 1) + 1;
+            int64_t totalPad = ((inputShape[dimIdx + 2] + strides[dimIdx] - 1) /
+                                    strides[dimIdx] -
+                                1) *
+                                   strides[dimIdx] +
+                               dilatedKernelSize - inputShape[dimIdx + 2];
+            totalPad = totalPad >= 0 ? totalPad : 0;
+            padding[dimIdx] =
+                isSameLower ? ((totalPad + 1) / 2) : (totalPad / 2);
+            padding[spatial + dimIdx] = totalPad - padding[dimIdx];
+          }
+        }
+
+        // If the padding is symmetric we can push the padding operation to
+        // the torch operator.
+        if (padding.size() == static_cast<size_t>(2 * spatial)) {
+          bool equal = true;
+          for (int i = 0; i < spatial; ++i) {
+            equal = equal && (padding[i] == padding[i + spatial]);
+          }
+          if (equal)
+            padding.resize(spatial);
+        }
+
+        // Torch pool operators require equal padding on each side of each
+        // dimension, so we materialize the padding behavior explicitly and
+        // set the padding to 0.
+        if (padding.size() == static_cast<size_t>(2 * spatial)) {
+          auto operandTy = cast<Torch::ValueTensorType>(operand.getType());
+          llvm::SmallVector<int64_t> shuffledPadding(spatial * 2);
+          llvm::SmallVector<int64_t> paddedShape(operandTy.getSizes());
+          for (int i = 0; i < spatial; ++i) {
+            paddedShape[i + 2] += padding[i] + padding[i + spatial];
+            shuffledPadding[2 * i] = padding[spatial - i - 1];
+            shuffledPadding[2 * i + 1] = padding[2 * spatial - i - 1];
+          }
+
+          Value shuffledPaddingList =
+              createConstantIntList(binder, rewriter, shuffledPadding);
+          Value zero;
+          if (isa<FloatType>(resultTypeOut.getDtype())) {
+            zero = rewriter.create<Torch::ConstantFloatOp>(
+                binder.getLoc(), rewriter.getType<Torch::FloatType>(),
+                rewriter.getF64FloatAttr(
+                    std::numeric_limits<double>::lowest()));
+          } else if (isa<IntegerType>(resultTypeOut.getDtype())) {
+            zero = rewriter.create<Torch::ConstantIntOp>(
+                binder.getLoc(), rewriter.getI64IntegerAttr(
+                                     std::numeric_limits<int64_t>::lowest()));
+          }
+
+          auto paddedInputTy = rewriter.getType<Torch::ValueTensorType>(
+              paddedShape, operandTy.getDtype());
+          operand = rewriter.create<Torch::AtenConstantPadNdOp>(
+              binder.getLoc(), paddedInputTy, operand, shuffledPaddingList,
+              zero);
+          padding.clear();
+          padding.resize(spatial, 0);
+        }
 
         Value kernelSizeList = createConstantIntList(binder, rewriter, kernel);
         Value paddingList = createConstantIntList(binder, rewriter, padding);
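
The shuffledPadding loop above reorders the ONNX pads [x1_begin, x2_begin, ..., x1_end, x2_end, ...] into the (begin, end) pairs that torch's constant_pad_nd expects, running from the last spatial dimension to the first, and the pad fill value is the dtype's lowest representable number so padded cells can never win the max. Here is a standalone sketch of the reordering (not part of the commit; shufflePads is an illustrative name):

#include <cassert>
#include <cstdint>
#include <vector>

// Reorders ONNX-style pads into constant_pad_nd order: (begin, end) pairs
// from the innermost (last) spatial dimension outward.
std::vector<int64_t> shufflePads(const std::vector<int64_t> &pads) {
  const size_t spatial = pads.size() / 2;
  std::vector<int64_t> shuffled(2 * spatial);
  for (size_t i = 0; i < spatial; ++i) {
    shuffled[2 * i] = pads[spatial - i - 1];         // begin of dim spatial-1-i
    shuffled[2 * i + 1] = pads[2 * spatial - i - 1]; // end of dim spatial-1-i
  }
  return shuffled;
}

int main() {
  // 2-D pool with pads (H_begin=1, W_begin=2, H_end=3, W_end=4): the torch
  // pad list is (W_begin, W_end, H_begin, H_end).
  assert(shufflePads({1, 2, 3, 4}) == std::vector<int64_t>({2, 4, 1, 3}));
  return 0;
}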

lib/Conversion/TorchOnnxToTorch/Utils.cpp (−113)
@@ -201,116 +201,3 @@ LogicalResult mlir::torch::onnx_c::createDequantizeTensor(
                                        quantizedInput);
   return success();
 }
-
-// Checks the validity of pooling parameters and stores them in the respective
-// vector.
-LogicalResult mlir::torch::onnx_c::checkAndGetOnnxPoolingOpParameters(
-    OpBinder binder, ConversionPatternRewriter &rewriter, Type resultDtype,
-    std::string autoPad, int64_t spatialRank, Value &input,
-    SmallVectorImpl<int64_t> &kernelSizeInts,
-    SmallVectorImpl<int64_t> &strideInts, SmallVectorImpl<int64_t> &paddingInts,
-    SmallVectorImpl<int64_t> &dilationInts) {
-  SmallVector<int64_t> kernel, padding, strides, dilations;
-  if (binder.s64IntegerArrayAttr(kernel, "kernel_shape", {}))
-    return rewriter.notifyMatchFailure(binder.op, "kernel_shape bind failure");
-  if (kernel.size() != static_cast<size_t>(spatialRank))
-    return rewriter.notifyMatchFailure(
-        binder.op, "kernel list size does not match the number of axes");
-  if (binder.s64IntegerArrayAttr(padding, "pads", {}))
-    return rewriter.notifyMatchFailure(binder.op, "pads bind failure");
-  if (!padding.empty() &&
-      padding.size() != static_cast<size_t>(2 * spatialRank))
-    return rewriter.notifyMatchFailure(
-        binder.op, "padding list must contain (begin,end) pair for each "
-                   "spatial axis");
-  if (binder.s64IntegerArrayAttr(strides, "strides", {}))
-    return rewriter.notifyMatchFailure(binder.op, "strides bind failure");
-  if (!strides.empty() && strides.size() != static_cast<size_t>(spatialRank))
-    return rewriter.notifyMatchFailure(
-        binder.op, "strides list size does not match the number of axes");
-  if (binder.s64IntegerArrayAttr(dilations, "dilations", {}))
-    return rewriter.notifyMatchFailure(binder.op, "dilations bind failure");
-
-  // set default values for padding, strides, and dilations.
-  if (padding.empty())
-    padding.resize(spatialRank, 0);
-  if (strides.empty())
-    strides.resize(spatialRank, 1);
-  if (dilations.empty())
-    dilations.resize(spatialRank, 1);
-
-  // Padding for the beginning and ending along each spatial axis, it can
-  // take any value greater than or equal to 0. The value represent the
-  // number of pixels added to the beginning and end part of the
-  // corresponding axis. pads format should be as follow [x1_begin,
-  // x2_begin…x1_end, x2_end,…], where xi_begin the number of pixels added
-  // at the beginning of axis i and xi_end, the number of pixels added at
-  // the end of axis i.
-  auto inputTensorType = cast<Torch::ValueTensorType>(input.getType());
-  if (autoPad != "NOTSET" && autoPad != "VALID") {
-    const bool isSameLower = autoPad == "SAME_LOWER";
-    ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
-    padding.resize_for_overwrite(2 * spatialRank);
-    for (unsigned dimIdx = 0; dimIdx < spatialRank; dimIdx++) {
-      const int64_t dilatedKernelSize =
-          dilations[dimIdx] * (kernel[dimIdx] - 1) + 1;
-      int64_t totalPad =
-          ((inputShape[dimIdx + 2] + strides[dimIdx] - 1) / strides[dimIdx] -
-           1) *
-              strides[dimIdx] +
-          dilatedKernelSize - inputShape[dimIdx + 2];
-      totalPad = totalPad >= 0 ? totalPad : 0;
-      padding[dimIdx] = isSameLower ? ((totalPad + 1) / 2) : (totalPad / 2);
-      padding[spatialRank + dimIdx] = totalPad - padding[dimIdx];
-    }
-  }
-
-  // If the padding is symmetric we can push the padding operation to the
-  // torch operator.
-  if (padding.size() == static_cast<size_t>(2 * spatialRank)) {
-    bool equal = true;
-    for (int i = 0; i < spatialRank; ++i) {
-      equal = equal && (padding[i] == padding[i + spatialRank]);
-    }
-    if (equal)
-      padding.resize(spatialRank);
-  }
-
-  // Torch pool operators require equal padding on each size of each
-  // dimension so we materialize the padding behavior explicitly and set
-  // the padding to 0.
-  if (padding.size() == static_cast<size_t>(2 * spatialRank)) {
-    llvm::SmallVector<int64_t> shuffledPadding(spatialRank * 2);
-    llvm::SmallVector<int64_t> paddedShape(inputTensorType.getSizes());
-    for (int i = 0; i < spatialRank; ++i) {
-      paddedShape[i + 2] += padding[i] + padding[i + spatialRank];
-      shuffledPadding[2 * i] = padding[spatialRank - i - 1];
-      shuffledPadding[2 * i + 1] = padding[2 * spatialRank - i - 1];
-    }
-
-    Value shuffledPaddingList =
-        createConstantIntList(binder, rewriter, shuffledPadding);
-    Value zero;
-    if (isa<FloatType>(resultDtype)) {
-      zero = rewriter.create<Torch::ConstantFloatOp>(
-          binder.getLoc(), rewriter.getType<Torch::FloatType>(),
-          rewriter.getF64FloatAttr(std::numeric_limits<double>::lowest()));
-    } else if (isa<IntegerType>(resultDtype)) {
-      zero = rewriter.create<Torch::ConstantIntOp>(
-          binder.getLoc(),
-          rewriter.getI64IntegerAttr(std::numeric_limits<int64_t>::lowest()));
-    }
-
-    auto paddedInputTy = rewriter.getType<Torch::ValueTensorType>(
-        paddedShape, inputTensorType.getDtype());
-    input = rewriter.create<Torch::AtenConstantPadNdOp>(
-        binder.getLoc(), paddedInputTy, input, shuffledPaddingList, zero);
-    padding.clear();
-    padding.resize(spatialRank, 0);
-  }
-
-  kernelSizeInts = kernel;
-  paddingInts = padding;
-  dilationInts = dilations;
-  return success();
-}
