@@ -1130,38 +1130,138 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
1130
1130
});
1131
1131
patterns.onOp (
1132
1132
" MaxPool" , 12 , [](OpBinder binder, ConversionPatternRewriter &rewriter) {
1133
+ std::string autoPad;
1134
+ if (binder.customOpNameStringAttr (autoPad, " auto_pad" , " NOTSET" ))
1135
+ return rewriter.notifyMatchFailure (binder.op ,
1136
+ " auto_pad bind failure" );
1137
+
1133
1138
Torch::ValueTensorType resultTypeOut;
1134
1139
Value operand;
1135
1140
int64_t ceilMode, storageOrder;
1136
- std::string autoPad;
1141
+ // TODO: Add support for indices output and storage_order
1137
1142
if (binder.tensorOperand (operand) ||
1138
1143
binder.s64IntegerAttr (ceilMode, " ceil_mode" , 0 ) ||
1139
1144
binder.s64IntegerAttr (storageOrder, " storage_order" , 0 ) ||
1140
- binder.customOpNameStringAttr (autoPad, " auto_pad" , " NOTSET" ) ||
1141
1145
binder.tensorResultTypeAtIndex (resultTypeOut, 0 ))
1142
1146
return rewriter.notifyMatchFailure (
1143
- binder.op , " operand/ceil_mode/storage_order/auto_pad/resultType "
1144
- " bind failure" );
1145
- // TODO: Add support for storage_order
1147
+ binder.op ,
1148
+ " operand/ceil_mode/storage_order/resultType bind failure" );
1146
1149
if (storageOrder != 0 )
1147
1150
return rewriter.notifyMatchFailure (
1148
1151
binder.op , " storage_order setting is not supported." );
1149
-
1150
1152
// Determine the rank of input tensor.
1151
1153
std::optional<unsigned > maybeRank = Torch::getTensorRank (operand);
1152
1154
if (!maybeRank)
1153
1155
return rewriter.notifyMatchFailure (binder.op ,
1154
1156
" Unimplemented: unranked tensor" );
1155
- unsigned rank = *maybeRank;
1157
+ int64_t rank = *maybeRank;
1158
+ int64_t spatial = rank - 2 ;
1156
1159
1157
- SmallVector<int64_t > kernel, padding, strides, dilations,
1158
- stridesDilations;
1159
- if (failed (checkAndGetOnnxPoolingOpParameters (
1160
- binder, rewriter, resultTypeOut.getDtype (), autoPad,
1161
- /* spatialRank=*/ rank - 2 ,
1162
- /* input=*/ operand, kernel, strides, padding, dilations)))
1160
+ SmallVector<int64_t > kernel, padding, strides, dilations;
1161
+ if (binder.s64IntegerArrayAttr (kernel, " kernel_shape" , {}))
1163
1162
return rewriter.notifyMatchFailure (binder.op ,
1164
- " invalid pooling parameters" );
1163
+ " kernel_shape bind failure" );
1164
+ if (kernel.size () != static_cast <size_t >(spatial))
1165
+ return rewriter.notifyMatchFailure (
1166
+ binder.op , " kernel list size does not match the number of axes" );
1167
+ if (binder.s64IntegerArrayAttr (padding, " pads" , {}))
1168
+ return rewriter.notifyMatchFailure (binder.op , " pads bind failure" );
1169
+ if (!padding.empty () &&
1170
+ padding.size () != static_cast <size_t >(2 * spatial))
1171
+ return rewriter.notifyMatchFailure (
1172
+ binder.op , " padding list must contain (begin,end) pair for each "
1173
+ " spatial axis" );
1174
+ if (binder.s64IntegerArrayAttr (strides, " strides" , {}))
1175
+ return rewriter.notifyMatchFailure (binder.op , " strides bind failure" );
1176
+ if (!strides.empty () && strides.size () != static_cast <size_t >(spatial))
1177
+ return rewriter.notifyMatchFailure (
1178
+ binder.op , " strides list size does not match the number of axes" );
1179
+ if (binder.s64IntegerArrayAttr (dilations, " dilations" , {}))
1180
+ return rewriter.notifyMatchFailure (binder.op ,
1181
+ " dilations bind failure" );
1182
+
1183
+ // set default padding
1184
+ if (padding.empty ())
1185
+ padding.resize (spatial, 0 );
1186
+ if (strides.empty ())
1187
+ strides.resize (spatial, 1 );
1188
+ if (dilations.empty ())
1189
+ dilations.resize (spatial, 1 );
1190
+
1191
+ auto inputTensorType = cast<Torch::ValueTensorType>(operand.getType ());
1192
+
1193
+ // Padding for the beginning and ending along each spatial axis, it can
1194
+ // take any value greater than or equal to 0. The value represent the
1195
+ // number of pixels added to the beginning and end part of the
1196
+ // corresponding axis. pads format should be as follow [x1_begin,
1197
+ // x2_begin…x1_end, x2_end,…], where xi_begin the number of pixels added
1198
+ // at the beginning of axis i and xi_end, the number of pixels added at
1199
+ // the end of axis i.
1200
+ if (autoPad != " NOTSET" && autoPad != " VALID" ) {
1201
+ const bool isSameLower = autoPad == " SAME_LOWER" ;
1202
+ ArrayRef<int64_t > inputShape = inputTensorType.getSizes ();
1203
+ padding.resize_for_overwrite (2 * spatial);
1204
+ for (unsigned dimIdx = 0 ; dimIdx < spatial; dimIdx++) {
1205
+ const int64_t dilatedKernelSize =
1206
+ dilations[dimIdx] * (kernel[dimIdx] - 1 ) + 1 ;
1207
+ int64_t totalPad = ((inputShape[dimIdx + 2 ] + strides[dimIdx] - 1 ) /
1208
+ strides[dimIdx] -
1209
+ 1 ) *
1210
+ strides[dimIdx] +
1211
+ dilatedKernelSize - inputShape[dimIdx + 2 ];
1212
+ totalPad = totalPad >= 0 ? totalPad : 0 ;
1213
+ padding[dimIdx] =
1214
+ isSameLower ? ((totalPad + 1 ) / 2 ) : (totalPad / 2 );
1215
+ padding[spatial + dimIdx] = totalPad - padding[dimIdx];
1216
+ }
1217
+ }
1218
+
1219
+ // If the padding is symmetric we can push the padding operation to the
1220
+ // torch operator.
1221
+ if (padding.size () == static_cast <size_t >(2 * spatial)) {
1222
+ bool equal = true ;
1223
+ for (int i = 0 ; i < spatial; ++i) {
1224
+ equal = equal && (padding[i] == padding[i + spatial]);
1225
+ }
1226
+ if (equal)
1227
+ padding.resize (spatial);
1228
+ }
1229
+
1230
+ // Torch pool operators require equal padding on each size of each
1231
+ // dimension so we materialize the padding behavior explicitly and set
1232
+ // the padding to 0.
1233
+ if (padding.size () == static_cast <size_t >(2 * spatial)) {
1234
+ auto operandTy = cast<Torch::ValueTensorType>(operand.getType ());
1235
+ llvm::SmallVector<int64_t > shuffledPadding (spatial * 2 );
1236
+ llvm::SmallVector<int64_t > paddedShape (operandTy.getSizes ());
1237
+ for (int i = 0 ; i < spatial; ++i) {
1238
+ paddedShape[i + 2 ] += padding[i] + padding[i + spatial];
1239
+ shuffledPadding[2 * i] = padding[spatial - i - 1 ];
1240
+ shuffledPadding[2 * i + 1 ] = padding[2 * spatial - i - 1 ];
1241
+ }
1242
+
1243
+ Value shuffledPaddingList =
1244
+ createConstantIntList (binder, rewriter, shuffledPadding);
1245
+ Value zero;
1246
+ if (isa<FloatType>(resultTypeOut.getDtype ())) {
1247
+ zero = rewriter.create <Torch::ConstantFloatOp>(
1248
+ binder.getLoc (), rewriter.getType <Torch::FloatType>(),
1249
+ rewriter.getF64FloatAttr (
1250
+ std::numeric_limits<double >::lowest ()));
1251
+ } else if (isa<IntegerType>(resultTypeOut.getDtype ())) {
1252
+ zero = rewriter.create <Torch::ConstantIntOp>(
1253
+ binder.getLoc (), rewriter.getI64IntegerAttr (
1254
+ std::numeric_limits<int64_t >::lowest ()));
1255
+ }
1256
+
1257
+ auto paddedInputTy = rewriter.getType <Torch::ValueTensorType>(
1258
+ paddedShape, operandTy.getDtype ());
1259
+ operand = rewriter.create <Torch::AtenConstantPadNdOp>(
1260
+ binder.getLoc (), paddedInputTy, operand, shuffledPaddingList,
1261
+ zero);
1262
+ padding.clear ();
1263
+ padding.resize (spatial, 0 );
1264
+ }
1165
1265
1166
1266
Value kernelSizeList = createConstantIntList (binder, rewriter, kernel);
1167
1267
Value paddingList = createConstantIntList (binder, rewriter, padding);
0 commit comments